# src/app/ml_train.py
"""
Standalone ML training job executed inside the Slurm cluster.
This module encapsulates the horizon-wise training workflow used by the
platform's MLOps pipeline. It loads non-imputed observations for selected
coordinates, prepares future targets per forecasting horizon, trains a
baseline Scikit-learn model and a simple PyTorch regressor, persists model
artifacts under the shared ``/data`` volume, and records scores in the
database for monitoring and visualization.
See Also
--------
app.slurm_job_trigger.create_and_dispatch_training_job
app.slurm_job_trigger.trigger_slurm_job
app.ml_utils.get_latest_training_logs
app.ml_utils.get_historical_scores
app.models.TrainingLog
app.models.TrainingStatus
Notes
-----
- Primary role: execute training for configured horizons and persist both
models and TrainingLog entries, updating the TrainingStatus singleton
throughout execution.
- Key dependencies: a reachable database via ``DATABASE_URL``, a writable
shared volume at ``/data``, and environment variables controlling training.
- Invariants: ``DATABASE_URL`` must be set; the shared ``/data`` volume must
exist and be writable; minimum data thresholds determine training viability.
Examples
--------
>>> # Executed by Slurm via the dispatcher # doctest: +SKIP
>>> # $ python3 /data/app_code_for_slurm/ml_train.py # doctest: +SKIP
>>> from app.ml_train import main # doctest: +SKIP
>>> main() # doctest: +SKIP
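>>> # Hypothetical local dry run against SQLite; the values below are
>>> # illustrative, not the production configuration. # doctest: +SKIP
>>> # $ DATABASE_URL=sqlite:////tmp/dev.db SLURM_JOB_DATA_PATH=/tmp/data \
>>> #     ML_NUM_EPOCHS=2 python3 ml_train.py # doctest: +SKIP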
"""
import logging
import os
import uuid
from datetime import datetime
from pathlib import Path
from typing import Tuple, cast
import joblib
import numpy as np
import pandas as pd
import torch
from numpy.typing import NDArray
from sklearn.linear_model import Ridge
from sklearn.metrics import r2_score
from sklearn.model_selection import train_test_split
from sqlalchemy import Boolean, DateTime, Float, Integer, String, create_engine
from sqlalchemy.exc import SQLAlchemyError
from sqlalchemy.orm import (
DeclarativeBase,
Mapped,
Session as SQLAlchemySession,
mapped_column,
sessionmaker,
)
from torch import nn
from torch.utils.data import DataLoader, TensorDataset
# --- Configuration and Constants ---
LOG_FORMAT = "%(asctime)s - %(name)s - %(levelname)s - %(message)s"
logging.basicConfig(level=logging.INFO, format=LOG_FORMAT)
logger = logging.getLogger(__name__)
DATA_DIRECTORY = Path(os.getenv("SLURM_JOB_DATA_PATH", "/data"))
HORIZON_SHIFTS = {"5min": 1, "1h": 12, "12h": 144, "24h": 288}
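# One shift step equals the 5-minute observation cadence, so 1h = 12 steps,
# 12h = 144 steps, and 24h = 288 steps.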
FEATURE_COLUMNS = [
"air_temperature",
"wind_speed",
"wind_direction",
"precipitation_amount",
]
TARGET_COLUMN = "air_temperature"
DEFAULT_NUM_EPOCHS = int(os.getenv("ML_NUM_EPOCHS", "10"))
DEFAULT_LEARNING_RATE = float(os.getenv("ML_LEARNING_RATE", "1e-3"))
DEFAULT_BATCH_SIZE = int(os.getenv("ML_BATCH_SIZE", "32"))
DEFAULT_TEST_SIZE = float(os.getenv("ML_TEST_SIZE", "0.2"))
MIN_DATA_POINTS_FOR_TRAINING = int(os.getenv("ML_MIN_DATA_POINTS", "50"))
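# Each of the hyperparameters above can be overridden through environment
# variables at dispatch time; the literals are fallback defaults.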
# --- ORM Base and Models ---
class Base(DeclarativeBase):
"""Base class for ORM models in this standalone job.
Notes
-----
The standalone training job defines a minimal subset of ORM models to
avoid importing the main application's metadata. Tables mirror the
fields in ``app.models`` sufficiently for writing logs and status.
"""
pass
class WeatherObservation(Base):
"""Observational weather record (non-imputed preferred for training).
Attributes
----------
timestamp : datetime
Unique timestamp for the observation (UTC).
latitude : float
Coordinate latitude in decimal degrees.
longitude : float
Coordinate longitude in decimal degrees.
air_temperature : float | None
Air temperature in degrees Celsius.
wind_speed : float | None
Wind speed in m/s.
wind_direction : float | None
Wind direction in degrees.
cloud_area_fraction : float | None
Cloud cover fraction (0–1) when available.
precipitation_amount : float | None
Precipitation amount in mm for the interval.
is_imputed : bool
Whether the record was imputed (training uses non-imputed).
"""
__tablename__ = "weather_observations"
timestamp: Mapped[datetime] = mapped_column(DateTime, primary_key=True)
latitude: Mapped[float] = mapped_column(Float, primary_key=True)
longitude: Mapped[float] = mapped_column(Float, primary_key=True)
air_temperature: Mapped[float | None] = mapped_column(Float, nullable=True)
wind_speed: Mapped[float | None] = mapped_column(Float, nullable=True)
wind_direction: Mapped[float | None] = mapped_column(Float, nullable=True)
cloud_area_fraction: Mapped[float | None] = mapped_column(Float, nullable=True)
precipitation_amount: Mapped[float | None] = mapped_column(Float, nullable=True)
is_imputed: Mapped[bool] = mapped_column(Boolean, default=False, nullable=False)
class TrainingLog(Base):
"""Log entry for one training run and horizon.
Attributes
----------
id : str
UUID identifier of the run.
timestamp : datetime
When the run completed (UTC).
horizon : str
        Composite horizon key (``"<coord>_<horizon>"`` as written by this job).
sklearn_score : float
R^2 score for the Sklearn model on the test split.
pytorch_score : float
R^2 score for the PyTorch model on the test split.
data_count : int
Number of samples used for this horizon after preprocessing.
coord_latitude : float | None
Latitude of the coordinate or ``None`` for aggregate runs.
coord_longitude : float | None
Longitude of the coordinate or ``None`` for aggregate runs.
horizon_label : str | None
One of ``{"5min", "1h", "12h", "24h"}``.
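    Examples
    --------
    >>> # Illustrative row; the values are placeholders # doctest: +SKIP
    >>> log = TrainingLog(horizon="lat57_7000_lon11_9000_1h", # doctest: +SKIP
    ...     sklearn_score=0.91, pytorch_score=0.88, data_count=240,
    ...     coord_latitude=57.70, coord_longitude=11.90, horizon_label="1h")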
"""
__tablename__ = "training_logs"
id: Mapped[str] = mapped_column(
String, primary_key=True, default=lambda: str(uuid.uuid4()), index=True
)
timestamp: Mapped[datetime] = mapped_column(DateTime, default=datetime.utcnow)
horizon: Mapped[str] = mapped_column(String)
sklearn_score: Mapped[float] = mapped_column(Float)
pytorch_score: Mapped[float] = mapped_column(Float)
data_count: Mapped[int] = mapped_column(Integer)
coord_latitude: Mapped[float | None] = mapped_column(Float, nullable=True)
coord_longitude: Mapped[float | None] = mapped_column(Float, nullable=True)
horizon_label: Mapped[str | None] = mapped_column(String, nullable=True)
class TrainingStatus(Base):
"""Singleton table reflecting the current training state.
Attributes
----------
id : int
Primary key (always 1 in this job flow).
is_training : bool
Whether a training job is running.
last_trained_at : datetime | None
Timestamp of the last successful training completion.
train_count : int
Number of training runs since system start or initialization.
current_horizon : str | None
Human-readable status message or horizon marker.
"""
__tablename__ = "training_status"
id: Mapped[int] = mapped_column(Integer, primary_key=True, default=1)
is_training: Mapped[bool] = mapped_column(Boolean, default=False)
last_trained_at: Mapped[datetime | None] = mapped_column(DateTime, nullable=True)
train_count: Mapped[int] = mapped_column(Integer, default=0)
current_horizon: Mapped[str | None] = mapped_column(String, nullable=True)
class SimpleRegressionNet(nn.Module):
"""Simple feed-forward regression network for 1-step regression.
Parameters
----------
input_dim : int
Number of input features (must be positive).
Examples
--------
>>> import torch # doctest: +SKIP
>>> from app.ml_train import SimpleRegressionNet # doctest: +SKIP
>>> net = SimpleRegressionNet(input_dim=4) # doctest: +SKIP
>>> x = torch.randn(2, 4) # doctest: +SKIP
>>> y = net(x) # doctest: +SKIP
>>> y.shape == (2, 1) # doctest: +SKIP
True
"""
def __init__(self, input_dim: int) -> None:
super().__init__()
assert input_dim > 0, f"input_dim must be positive, got {input_dim}"
self.net = nn.Sequential(
nn.Linear(input_dim, 64),
nn.ReLU(),
nn.Linear(64, 1),
)
def forward(self, x: torch.Tensor) -> torch.Tensor:
"""Compute predictions for input features.
Parameters
----------
x : torch.Tensor
Input batch with shape ``(batch_size, input_dim)``.
Returns
-------
torch.Tensor
Predicted target values with shape ``(batch_size, 1)``.
"""
return cast(torch.Tensor, self.net(x))
# --- Database Session ---
def get_db_session() -> SQLAlchemySession:
"""Create and return a new SQLAlchemy session.
Returns
-------
sqlalchemy.orm.Session
Session bound to the engine specified by ``DATABASE_URL``.
Raises
------
ValueError
If ``DATABASE_URL`` is not defined in the environment.
sqlalchemy.exc.SQLAlchemyError
If the engine cannot be created or the metadata initialization fails.
Examples
--------
>>> import os # doctest: +SKIP
>>> os.environ["DATABASE_URL"] = "sqlite:///:memory:" # doctest: +SKIP
>>> from app.ml_train import get_db_session # doctest: +SKIP
>>> s = get_db_session() # doctest: +SKIP
>>> s.close() # doctest: +SKIP
"""
database_url = os.getenv("DATABASE_URL")
if not database_url:
logger.error("DATABASE_URL environment variable not set.")
raise ValueError("DATABASE_URL is required")
connect_args = {}
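    # SQLite restricts connections to their creating thread by default; relax
    # this so the session also works in threaded local/test setups.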
if database_url.startswith("sqlite"):
connect_args["check_same_thread"] = False
engine = create_engine(database_url, connect_args=connect_args)
Base.metadata.create_all(engine)
SessionLocal = sessionmaker(bind=engine)
return SessionLocal()
# --- Core Functions ---
def load_training_data(
session: SQLAlchemySession, latitude: float, longitude: float
) -> pd.DataFrame:
"""Load non-imputed observations for one coordinate.
Parameters
----------
session : sqlalchemy.orm.Session
Open SQLAlchemy session.
latitude : float
Coordinate latitude in decimal degrees.
longitude : float
Coordinate longitude in decimal degrees.
Returns
-------
pandas.DataFrame
DataFrame with a ``timestamp`` column parsed to ``datetime64[ns]``.
Empty DataFrame if no non-imputed rows are available.
Examples
--------
>>> # Requires seeded DB with weather_observations # doctest: +SKIP
>>> from app.ml_train import load_training_data # doctest: +SKIP
>>> df = load_training_data(session, 57.70, 11.90) # doctest: +SKIP
>>> isinstance(df.empty, bool) # doctest: +SKIP
True
"""
logger.info(f"Loading non-imputed data for ({latitude}, {longitude})")
try:
query = session.query(WeatherObservation).filter_by(
latitude=latitude, longitude=longitude, is_imputed=False
)
assert session.bind, "Session must be bound to an engine"
df = pd.read_sql(query.statement, session.bind)
if df.empty:
logger.warning(f"No actual data for ({latitude}, {longitude})")
return pd.DataFrame()
df["timestamp"] = pd.to_datetime(df["timestamp"])
logger.info(f"Loaded {len(df)} rows for ({latitude}, {longitude})")
return df
except SQLAlchemyError as e:
logger.error(
f"Error loading data for ({latitude}, {longitude}): {e}", exc_info=True
)
return pd.DataFrame()
def prepare_horizon_data(
df: pd.DataFrame, horizon_label: str, shift_steps: int
) -> Tuple[NDArray[np.float64], NDArray[np.float64], int]:
"""Prepare features and targets for a specific horizon.
Parameters
----------
df : pandas.DataFrame
        Observations for a single coordinate; the function sorts them by
        ``timestamp`` internally.
horizon_label : str
Label for the horizon (e.g., ``"5min"``, ``"1h"``).
shift_steps : int
Positive number of steps the target is shifted into the future.
Returns
-------
tuple[numpy.ndarray, numpy.ndarray, int]
Tuple ``(X, y, count)`` where count is the number of samples retained
after dropping NA rows for the chosen horizon.
Examples
--------
>>> import pandas as pd # doctest: +SKIP
>>> from app.ml_train import prepare_horizon_data # doctest: +SKIP
>>> ts = pd.date_range("2024-01-01", periods=12, freq="5min") # doctest: +SKIP
>>> df = pd.DataFrame({ # doctest: +SKIP
... "timestamp": ts,
... "air_temperature": list(range(12)),
... "wind_speed": [1.0]*12,
... "wind_direction": [0.0]*12,
... "precipitation_amount": [0.0]*12,
... })
>>> X, y, count = prepare_horizon_data(df, "5min", 1) # doctest: +SKIP
>>> count >= 10 # doctest: +SKIP
True
"""
if df.empty or len(df) < shift_steps + 1:
logger.debug(f"Insufficient rows for horizon '{horizon_label}'")
return (
np.empty((0, len(FEATURE_COLUMNS)), dtype=np.float64),
np.empty(0, dtype=np.float64),
0,
)
df_sorted = df.sort_values("timestamp").copy()
future_col = f"future_{TARGET_COLUMN}_{horizon_label}"
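    # shift(-shift_steps) pairs each row's features with the target observed
    # shift_steps rows (one horizon) later; trailing rows become NaN.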
df_sorted[future_col] = df_sorted[TARGET_COLUMN].shift(-shift_steps)
df_sorted.dropna(subset=FEATURE_COLUMNS + [future_col], inplace=True)
count = len(df_sorted)
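    # The 5-minute horizon is granted a relaxed floor (20% of the configured
    # minimum, never below 5 rows); longer horizons require the full minimum.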
threshold = (
max(5, int(MIN_DATA_POINTS_FOR_TRAINING * 0.2))
if horizon_label == "5min"
else MIN_DATA_POINTS_FOR_TRAINING
)
if count < threshold:
logger.warning(
f"Not enough data after prep for '{horizon_label}' ({count} < {threshold})"
)
return (
np.empty((0, len(FEATURE_COLUMNS)), dtype=np.float64),
np.empty(0, dtype=np.float64),
0,
)
X = df_sorted[FEATURE_COLUMNS].to_numpy(dtype=np.float64)
y = df_sorted[future_col].to_numpy(dtype=np.float64)
return X, y, count
def train_and_save_model(
X_train: NDArray[np.float64],
y_train: NDArray[np.float64],
X_test: NDArray[np.float64],
y_test: NDArray[np.float64],
horizon: str,
coord_str: str,
) -> Tuple[float, float]:
"""Train Ridge and PyTorch models, persist them, and return R² scores.
Parameters
----------
X_train, y_train, X_test, y_test : numpy.ndarray
Training and test splits.
horizon : str
Horizon label used for filenames and logging.
coord_str : str
Coordinate identifier (e.g., ``"lat57_7000_lon11_9000"``).
Returns
-------
tuple[float, float]
``(sklearn_r2, pytorch_r2)`` scores on the test split.
Notes
-----
    - Models are saved under ``<DATA_DIRECTORY>/models/<coord_str>/``
      (``/data`` by default) with deterministic filenames per horizon and
      framework.
Raises
------
OSError
If persisting the model artifacts to disk fails.
Examples
--------
>>> import numpy as np # doctest: +SKIP
>>> from app.ml_train import train_and_save_model # doctest: +SKIP
>>> X = np.random.rand(100, 4); y = np.random.rand(100) # doctest: +SKIP
>>> s = int(len(X)*0.8) # doctest: +SKIP
>>> train_and_save_model(X[:s], y[:s], X[s:], y[s:], # doctest: +SKIP
... "5min", "lat57_7000_lon11_9000") # doctest: +SKIP
(..., ...)
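    >>> # Reloading the saved artifacts later (a sketch; the paths mirror the
    >>> # save logic above): # doctest: +SKIP
    >>> import joblib, torch # doctest: +SKIP
    >>> ridge = joblib.load(
    ...     "/data/models/lat57_7000_lon11_9000/sklearn_model_5min.pkl") # doctest: +SKIP
    >>> net = SimpleRegressionNet(input_dim=4) # doctest: +SKIP
    >>> net.load_state_dict(torch.load(
    ...     "/data/models/lat57_7000_lon11_9000/pytorch_model_5min.pt")) # doctest: +SKIP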
"""
model_dir = DATA_DIRECTORY / "models" / coord_str
model_dir.mkdir(parents=True, exist_ok=True)
sklearn_path = model_dir / f"sklearn_model_{horizon}.pkl"
pytorch_path = model_dir / f"pytorch_model_{horizon}.pt"
    # Scikit-learn Ridge baseline; score falls back to 0.0 on an empty test split.
sklearn_model = Ridge()
sklearn_model.fit(X_train, y_train)
sklearn_score = sklearn_model.score(X_test, y_test) if len(X_test) else 0.0
joblib.dump(sklearn_model, sklearn_path)
logger.info(f"Saved Sklearn model at {sklearn_path} (R²={sklearn_score:.4f})")
    # PyTorch: a small feed-forward net trained on mini-batches with Adam on MSE loss.
Xt = torch.from_numpy(X_train.astype(np.float32))
yt = torch.from_numpy(y_train.astype(np.float32)).view(-1, 1)
dataset = TensorDataset(Xt, yt)
loader = DataLoader(dataset, batch_size=DEFAULT_BATCH_SIZE, shuffle=True)
net = SimpleRegressionNet(X_train.shape[1])
criterion = nn.MSELoss()
optimizer = torch.optim.Adam(net.parameters(), lr=DEFAULT_LEARNING_RATE)
net.train()
for epoch in range(DEFAULT_NUM_EPOCHS):
total_loss = 0.0
for batch_X, batch_y in loader:
optimizer.zero_grad()
preds = net(batch_X)
loss = criterion(preds, batch_y)
loss.backward()
optimizer.step()
total_loss += loss.item()
logger.debug(
f"Epoch {epoch + 1}/{DEFAULT_NUM_EPOCHS}: Loss={total_loss / len(loader):.4f}"
)
net.eval()
pytorch_score = 0.0
if len(X_test):
with torch.no_grad():
preds = (
net(torch.from_numpy(X_test.astype(np.float32))).cpu().numpy().flatten()
)
pytorch_score = r2_score(y_test, preds)
torch.save(net.state_dict(), pytorch_path)
logger.info(f"Saved PyTorch model at {pytorch_path} (R²={pytorch_score:.4f})")
return float(sklearn_score), float(pytorch_score)
def update_training_status(
session: SQLAlchemySession,
is_training: bool,
current_status_message: str | None = None,
increment_count: bool = False,
) -> None:
"""Update the ``training_status`` row with the current state.
Parameters
----------
session : sqlalchemy.orm.Session
Open SQLAlchemy session.
is_training : bool
Flag indicating whether a training is in progress.
current_status_message : str | None, optional
Optional status message or horizon marker.
increment_count : bool, optional
Whether to increment the total training counter.
Examples
--------
>>> # Within an open SQLAlchemy session # doctest: +SKIP
>>> from app.ml_train import update_training_status # doctest: +SKIP
>>> update_training_status(session, True, "Job started") # doctest: +SKIP
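    >>> # Marking completion and bumping the run counter (hypothetical call):
    >>> update_training_status(session, False, "Completed",
    ...     increment_count=True) # doctest: +SKIP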
"""
try:
        status = session.get(TrainingStatus, 1)
if not status:
status = TrainingStatus(id=1)
session.add(status)
status.is_training = is_training
if current_status_message is not None:
status.current_horizon = current_status_message
        if not is_training:
            status.last_trained_at = datetime.utcnow()
            # Preserve a final message (e.g. "Completed") when one is provided.
            if current_status_message is None:
                status.current_horizon = None
if increment_count:
status.train_count = (status.train_count or 0) + 1
session.commit()
except SQLAlchemyError as e:
logger.error(f"Failed to update training status: {e}", exc_info=True)
session.rollback()
def main() -> None:
"""Entry point for the standalone ML training job.
Notes
-----
- Coordinates are fetched (preferentially central ones) and iterated. For
each horizon, data are prepared, models trained, artifacts saved, and a
``TrainingLog`` row written.
- Status updates are written to ``TrainingStatus`` throughout the run.
Examples
--------
>>> # Executed inside Slurm job container # doctest: +SKIP
>>> from app.ml_train import main # doctest: +SKIP
>>> main() # doctest: +SKIP
"""
logger.info("=== Starting ML Training Job ===")
try:
session = get_db_session()
    except (ValueError, SQLAlchemyError) as e:
        logger.error(f"Could not initialize the database session: {e}")
return
try:
from sqlalchemy import text
central_query = session.execute(
text("SELECT latitude, longitude FROM coordinates WHERE is_central = true")
)
all_coords = [(row.latitude, row.longitude) for row in central_query.fetchall()]
except SQLAlchemyError as e:
logger.warning(f"Could not fetch central coords: {e}")
all_coords = [(56.8618, 14.8069)]
if not all_coords:
update_training_status(
session, is_training=False, current_status_message="No data"
)
session.close()
return
update_training_status(
session, is_training=True, current_status_message="Job started"
)
overall_success = True
try:
for lat, lon in all_coords:
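            # e.g. (57.7, 11.9) -> "lat57_7000_lon11_9000"; dots are replaced
            # so the identifier is safe to use as a directory name.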
coord_str = f"lat{lat:.4f}_lon{lon:.4f}".replace(".", "_")
update_training_status(
session,
is_training=True,
current_status_message=f"Training {coord_str}",
)
df = load_training_data(session, lat, lon)
if df.empty:
logger.warning(f"Skipping {coord_str}, no data")
continue
for horizon, steps in HORIZON_SHIFTS.items():
update_training_status(
session,
is_training=True,
current_status_message=f"{coord_str}:{horizon}",
)
X, y, count = prepare_horizon_data(df, horizon, steps)
if count == 0:
session.add(
TrainingLog(
id=str(uuid.uuid4()),
horizon=f"{coord_str}_{horizon}",
sklearn_score=float(0.0),
pytorch_score=float(0.0),
data_count=int(0),
coord_latitude=float(lat),
coord_longitude=float(lon),
horizon_label=horizon,
)
)
session.commit()
continue
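                # Chronological split (shuffle=False) keeps the test set
                # strictly after the training data, avoiding temporal leakage.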
X_train, X_test, y_train, y_test = train_test_split(
X, y, test_size=DEFAULT_TEST_SIZE, shuffle=False
)
if (
len(X_train) < MIN_DATA_POINTS_FOR_TRAINING // 2
or len(X_test) < MIN_DATA_POINTS_FOR_TRAINING // 10
):
logger.warning(
f"Skipping {coord_str}:{horizon}, insufficient split data"
)
continue
sklearn_score, pytorch_score = train_and_save_model(
X_train, y_train, X_test, y_test, horizon, coord_str
)
session.add(
TrainingLog(
id=str(uuid.uuid4()),
horizon=f"{coord_str}_{horizon}",
sklearn_score=float(sklearn_score),
pytorch_score=float(pytorch_score),
data_count=int(count),
coord_latitude=float(lat),
coord_longitude=float(lon),
horizon_label=horizon,
)
)
session.commit()
    except (SQLAlchemyError, OSError) as e:
logger.error(f"Training loop error: {e}", exc_info=True)
overall_success = False
update_training_status(
session, is_training=False, current_status_message=f"Error: {e}"
)
finally:
if overall_success:
update_training_status(
session, is_training=False, current_status_message="Completed"
)
logger.info("=== ML Training Job Completed ===")
else:
logger.error("=== ML Training Job Failed ===")
session.close()
if __name__ == "__main__":
# Prepare directories
DATA_DIRECTORY.mkdir(parents=True, exist_ok=True)
(DATA_DIRECTORY / "models").mkdir(parents=True, exist_ok=True)
main()