Source code for app.ml_train

# src/app/ml_train.py
"""
Standalone ML training job executed inside the Slurm cluster.

This module encapsulates the horizon-wise training workflow used by the
platform's MLOps pipeline. It loads non-imputed observations for selected
coordinates, prepares future targets per forecasting horizon, trains a
baseline Scikit-learn model and a simple PyTorch regressor, persists model
artifacts under the shared ``/data`` volume, and records scores in the
database for monitoring and visualization.

See Also
--------
app.slurm_job_trigger.create_and_dispatch_training_job
app.slurm_job_trigger.trigger_slurm_job
app.ml_utils.get_latest_training_logs
app.ml_utils.get_historical_scores
app.models.TrainingLog
app.models.TrainingStatus

Notes
-----
- Primary role: execute training for configured horizons and persist both
  models and TrainingLog entries, updating the TrainingStatus singleton
  throughout execution.
- Key dependencies: a reachable database via ``DATABASE_URL``, a writable
  shared volume at ``/data``, and environment variables controlling training.
- Invariants: ``DATABASE_URL`` must be set; the shared ``/data`` volume must
  exist and be writable; minimum data thresholds determine training viability.

Examples
--------
>>> # Executed by Slurm via the dispatcher                      # doctest: +SKIP
>>> # $ python3 /data/app_code_for_slurm/ml_train.py             # doctest: +SKIP
>>> from app.ml_train import main                                # doctest: +SKIP
>>> main()                                                       # doctest: +SKIP
"""


import logging
import os
import uuid
from datetime import datetime
from pathlib import Path
from typing import Tuple, cast

import joblib
import numpy as np
import pandas as pd
import torch
from numpy.typing import NDArray
from sklearn.linear_model import Ridge
from sklearn.metrics import r2_score
from sklearn.model_selection import train_test_split
from sqlalchemy import Boolean, DateTime, Float, Integer, String, create_engine, text
from sqlalchemy.exc import SQLAlchemyError
from sqlalchemy.orm import (
    DeclarativeBase,
    Mapped,
    Session as SQLAlchemySession,
    mapped_column,
    sessionmaker,
)
from torch import nn
from torch.utils.data import DataLoader, TensorDataset

# --- Configuration and Constants ---
LOG_FORMAT = "%(asctime)s - %(name)s - %(levelname)s - %(message)s"
logging.basicConfig(level=logging.INFO, format=LOG_FORMAT)
logger = logging.getLogger(__name__)

DATA_DIRECTORY = Path(os.getenv("SLURM_JOB_DATA_PATH", "/data"))
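# Target shift per horizon, in steps of the assumed 5-minute observation
# cadence: 1 step = 5 min, 12 steps = 1 h, 144 = 12 h, 288 = 24 h.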
HORIZON_SHIFTS = {"5min": 1, "1h": 12, "12h": 144, "24h": 288}
FEATURE_COLUMNS = [
    "air_temperature",
    "wind_speed",
    "wind_direction",
    "precipitation_amount",
]
TARGET_COLUMN = "air_temperature"

DEFAULT_NUM_EPOCHS = int(os.getenv("ML_NUM_EPOCHS", "10"))
DEFAULT_LEARNING_RATE = float(os.getenv("ML_LEARNING_RATE", "1e-3"))
DEFAULT_BATCH_SIZE = int(os.getenv("ML_BATCH_SIZE", "32"))
DEFAULT_TEST_SIZE = float(os.getenv("ML_TEST_SIZE", "0.2"))
MIN_DATA_POINTS_FOR_TRAINING = int(os.getenv("ML_MIN_DATA_POINTS", "50"))
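
# All hyperparameters above are read from the environment, so the dispatcher
# can override them per job without code changes, e.g. (an illustrative
# invocation, not the canonical dispatch command):
#
#   ML_NUM_EPOCHS=25 ML_BATCH_SIZE=64 ML_MIN_DATA_POINTS=100 \
#       python3 /data/app_code_for_slurm/ml_train.py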


# --- ORM Base and Models ---
class Base(DeclarativeBase):
    """Base class for ORM models in this standalone job.

    Notes
    -----
    The standalone training job defines a minimal subset of ORM models to
    avoid importing the main application's metadata. Tables mirror the
    fields in ``app.models`` sufficiently for writing logs and status.
    """


class WeatherObservation(Base):
    """Observational weather record (non-imputed preferred for training).

    Attributes
    ----------
    timestamp : datetime
        Unique timestamp for the observation (UTC).
    latitude : float
        Coordinate latitude in decimal degrees.
    longitude : float
        Coordinate longitude in decimal degrees.
    air_temperature : float | None
        Air temperature in degrees Celsius.
    wind_speed : float | None
        Wind speed in m/s.
    wind_direction : float | None
        Wind direction in degrees.
    cloud_area_fraction : float | None
        Cloud cover fraction (0–1) when available.
    precipitation_amount : float | None
        Precipitation amount in mm for the interval.
    is_imputed : bool
        Whether the record was imputed (training uses non-imputed).
    """

    __tablename__ = "weather_observations"

    timestamp: Mapped[datetime] = mapped_column(DateTime, primary_key=True)
    latitude: Mapped[float] = mapped_column(Float, primary_key=True)
    longitude: Mapped[float] = mapped_column(Float, primary_key=True)
    air_temperature: Mapped[float | None] = mapped_column(Float, nullable=True)
    wind_speed: Mapped[float | None] = mapped_column(Float, nullable=True)
    wind_direction: Mapped[float | None] = mapped_column(Float, nullable=True)
    cloud_area_fraction: Mapped[float | None] = mapped_column(Float, nullable=True)
    precipitation_amount: Mapped[float | None] = mapped_column(Float, nullable=True)
    is_imputed: Mapped[bool] = mapped_column(Boolean, default=False, nullable=False)


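# A minimal construction sketch (illustrative values; assumes an open
# session obtained from get_db_session()):
#
#   obs = WeatherObservation(
#       timestamp=datetime(2024, 1, 1, 12, 0),
#       latitude=57.70,
#       longitude=11.90,
#       air_temperature=2.5,
#       is_imputed=False,
#   )
#   session.add(obs)
#   session.commit()

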
class TrainingLog(Base):
    """Log entry for one training run and horizon.

    Attributes
    ----------
    id : str
        UUID identifier of the run.
    timestamp : datetime
        When the run completed (UTC).
    horizon : str
        Unique horizon key (often ``"<coord>_<horizon>"``).
    sklearn_score : float
        R^2 score for the Sklearn model on the test split.
    pytorch_score : float
        R^2 score for the PyTorch model on the test split.
    data_count : int
        Number of samples used for this horizon after preprocessing.
    coord_latitude : float | None
        Latitude of the coordinate, or ``None`` for aggregate runs.
    coord_longitude : float | None
        Longitude of the coordinate, or ``None`` for aggregate runs.
    horizon_label : str | None
        One of ``{"5min", "1h", "12h", "24h"}``.
    """

    __tablename__ = "training_logs"

    id: Mapped[str] = mapped_column(
        String, primary_key=True, default=lambda: str(uuid.uuid4()), index=True
    )
    timestamp: Mapped[datetime] = mapped_column(DateTime, default=datetime.utcnow)
    horizon: Mapped[str] = mapped_column(String)
    sklearn_score: Mapped[float] = mapped_column(Float)
    pytorch_score: Mapped[float] = mapped_column(Float)
    data_count: Mapped[int] = mapped_column(Integer)
    coord_latitude: Mapped[float | None] = mapped_column(Float, nullable=True)
    coord_longitude: Mapped[float | None] = mapped_column(Float, nullable=True)
    horizon_label: Mapped[str | None] = mapped_column(String, nullable=True)


class TrainingStatus(Base):
    """Singleton table reflecting the current training state.

    Attributes
    ----------
    id : int
        Primary key (always 1 in this job flow).
    is_training : bool
        Whether a training job is running.
    last_trained_at : datetime | None
        Timestamp of the last successful training completion.
    train_count : int
        Number of training runs since system start or initialization.
    current_horizon : str | None
        Human-readable status message or horizon marker.
    """

    __tablename__ = "training_status"

    id: Mapped[int] = mapped_column(Integer, primary_key=True, default=1)
    is_training: Mapped[bool] = mapped_column(Boolean, default=False)
    last_trained_at: Mapped[datetime | None] = mapped_column(DateTime, nullable=True)
    train_count: Mapped[int] = mapped_column(Integer, default=0)
    current_horizon: Mapped[str | None] = mapped_column(String, nullable=True)


class SimpleRegressionNet(nn.Module):
    """Simple feed-forward network for single-output regression.

    Parameters
    ----------
    input_dim : int
        Number of input features (must be positive).

    Examples
    --------
    >>> import torch                                   # doctest: +SKIP
    >>> from app.ml_train import SimpleRegressionNet  # doctest: +SKIP
    >>> net = SimpleRegressionNet(input_dim=4)        # doctest: +SKIP
    >>> x = torch.randn(2, 4)                         # doctest: +SKIP
    >>> y = net(x)                                    # doctest: +SKIP
    >>> y.shape == (2, 1)                             # doctest: +SKIP
    True
    """

    def __init__(self, input_dim: int) -> None:
        super().__init__()
        assert input_dim > 0, f"input_dim must be positive, got {input_dim}"
        self.net = nn.Sequential(
            nn.Linear(input_dim, 64),
            nn.ReLU(),
            nn.Linear(64, 1),
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Compute predictions for input features.

        Parameters
        ----------
        x : torch.Tensor
            Input batch with shape ``(batch_size, input_dim)``.

        Returns
        -------
        torch.Tensor
            Predicted target values with shape ``(batch_size, 1)``.
        """
        return cast(torch.Tensor, self.net(x))


# --- Database Session ---
def get_db_session() -> SQLAlchemySession:
    """Create and return a new SQLAlchemy session.

    Returns
    -------
    sqlalchemy.orm.Session
        Session bound to the engine specified by ``DATABASE_URL``.

    Raises
    ------
    ValueError
        If ``DATABASE_URL`` is not defined in the environment.
    sqlalchemy.exc.SQLAlchemyError
        If the engine cannot be created or the metadata initialization fails.

    Examples
    --------
    >>> import os                                           # doctest: +SKIP
    >>> os.environ["DATABASE_URL"] = "sqlite:///:memory:"   # doctest: +SKIP
    >>> from app.ml_train import get_db_session             # doctest: +SKIP
    >>> s = get_db_session()                                # doctest: +SKIP
    >>> s.close()                                           # doctest: +SKIP
    """
    database_url = os.getenv("DATABASE_URL")
    if not database_url:
        logger.error("DATABASE_URL environment variable not set.")
        raise ValueError("DATABASE_URL is required")

    connect_args = {}
    if database_url.startswith("sqlite"):
        # SQLite connections are thread-bound by default; relax this so the
        # session can be used from worker threads.
        connect_args["check_same_thread"] = False

    engine = create_engine(database_url, connect_args=connect_args)
    Base.metadata.create_all(engine)
    SessionLocal = sessionmaker(bind=engine)
    return SessionLocal()


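# ``DATABASE_URL`` uses the standard SQLAlchemy URL format; both examples
# below are illustrative, not the deployment's actual values:
#
#   postgresql://user:password@db-host:5432/weather
#   sqlite:////data/weather.db

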
# --- Core Functions ---
def load_training_data(
    session: SQLAlchemySession, latitude: float, longitude: float
) -> pd.DataFrame:
    """Load non-imputed observations for one coordinate.

    Parameters
    ----------
    session : sqlalchemy.orm.Session
        Open SQLAlchemy session.
    latitude : float
        Coordinate latitude in decimal degrees.
    longitude : float
        Coordinate longitude in decimal degrees.

    Returns
    -------
    pandas.DataFrame
        DataFrame with a ``timestamp`` column parsed to ``datetime64[ns]``.
        Empty DataFrame if no non-imputed rows are available.

    Examples
    --------
    >>> # Requires seeded DB with weather_observations      # doctest: +SKIP
    >>> from app.ml_train import load_training_data         # doctest: +SKIP
    >>> df = load_training_data(session, 57.70, 11.90)      # doctest: +SKIP
    >>> isinstance(df.empty, bool)                          # doctest: +SKIP
    True
    """
    logger.info(f"Loading non-imputed data for ({latitude}, {longitude})")
    try:
        query = session.query(WeatherObservation).filter_by(
            latitude=latitude, longitude=longitude, is_imputed=False
        )
        assert session.bind, "Session must be bound to an engine"
        df = pd.read_sql(query.statement, session.bind)
        if df.empty:
            logger.warning(f"No actual data for ({latitude}, {longitude})")
            return pd.DataFrame()
        df["timestamp"] = pd.to_datetime(df["timestamp"])
        logger.info(f"Loaded {len(df)} rows for ({latitude}, {longitude})")
        return df
    except SQLAlchemyError as e:
        logger.error(
            f"Error loading data for ({latitude}, {longitude}): {e}", exc_info=True
        )
        return pd.DataFrame()


def prepare_horizon_data(
    df: pd.DataFrame, horizon_label: str, shift_steps: int
) -> Tuple[NDArray[np.float64], NDArray[np.float64], int]:
    """Prepare features and targets for a specific horizon.

    Parameters
    ----------
    df : pandas.DataFrame
        Input data sorted/filtered per coordinate.
    horizon_label : str
        Label for the horizon (e.g., ``"5min"``, ``"1h"``).
    shift_steps : int
        Positive number of steps the target is shifted into the future.

    Returns
    -------
    tuple[numpy.ndarray, numpy.ndarray, int]
        Tuple ``(X, y, count)`` where ``count`` is the number of samples
        retained after dropping NA rows for the chosen horizon.

    Examples
    --------
    >>> import pandas as pd                                  # doctest: +SKIP
    >>> from app.ml_train import prepare_horizon_data        # doctest: +SKIP
    >>> ts = pd.date_range("2024-01-01", periods=12, freq="5min")  # doctest: +SKIP
    >>> df = pd.DataFrame({                                  # doctest: +SKIP
    ...     "timestamp": ts,
    ...     "air_temperature": list(range(12)),
    ...     "wind_speed": [1.0] * 12,
    ...     "wind_direction": [0.0] * 12,
    ...     "precipitation_amount": [0.0] * 12,
    ... })
    >>> X, y, count = prepare_horizon_data(df, "5min", 1)    # doctest: +SKIP
    >>> count >= 10                                          # doctest: +SKIP
    True
    """
    if df.empty or len(df) < shift_steps + 1:
        logger.debug(f"Insufficient rows for horizon '{horizon_label}'")
        return (
            np.empty((0, len(FEATURE_COLUMNS)), dtype=np.float64),
            np.empty(0, dtype=np.float64),
            0,
        )

    df_sorted = df.sort_values("timestamp").copy()
    future_col = f"future_{TARGET_COLUMN}_{horizon_label}"
    df_sorted[future_col] = df_sorted[TARGET_COLUMN].shift(-shift_steps)
    df_sorted.dropna(subset=FEATURE_COLUMNS + [future_col], inplace=True)
    count = len(df_sorted)

    threshold = (
        max(5, int(MIN_DATA_POINTS_FOR_TRAINING * 0.2))
        if horizon_label == "5min"
        else MIN_DATA_POINTS_FOR_TRAINING
    )
    if count < threshold:
        logger.warning(
            f"Not enough data after prep for '{horizon_label}' ({count} < {threshold})"
        )
        return (
            np.empty((0, len(FEATURE_COLUMNS)), dtype=np.float64),
            np.empty(0, dtype=np.float64),
            0,
        )

    X = df_sorted[FEATURE_COLUMNS].to_numpy(dtype=np.float64)
    y = df_sorted[future_col].to_numpy(dtype=np.float64)
    return X, y, count


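# With the default ML_MIN_DATA_POINTS=50, the relaxed "5min" threshold in
# prepare_horizon_data() works out to max(5, int(50 * 0.2)) = 10 samples;
# every other horizon requires the full 50.

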
def train_and_save_model(
    X_train: NDArray[np.float64],
    y_train: NDArray[np.float64],
    X_test: NDArray[np.float64],
    y_test: NDArray[np.float64],
    horizon: str,
    coord_str: str,
) -> Tuple[float, float]:
    """Train Ridge and PyTorch models, persist them, and return R² scores.

    Parameters
    ----------
    X_train, y_train, X_test, y_test : numpy.ndarray
        Training and test splits.
    horizon : str
        Horizon label used for filenames and logging.
    coord_str : str
        Coordinate identifier (e.g., ``"lat57_7000_lon11_9000"``).

    Returns
    -------
    tuple[float, float]
        ``(sklearn_r2, pytorch_r2)`` scores on the test split.

    Notes
    -----
    Models are saved under ``/data/models/<coord_str>/`` with deterministic
    filenames per horizon and framework.

    Raises
    ------
    OSError
        If persisting the model artifacts to disk fails.

    Examples
    --------
    >>> import numpy as np                                   # doctest: +SKIP
    >>> from app.ml_train import train_and_save_model        # doctest: +SKIP
    >>> X = np.random.rand(100, 4); y = np.random.rand(100)  # doctest: +SKIP
    >>> s = int(len(X) * 0.8)                                # doctest: +SKIP
    >>> train_and_save_model(X[:s], y[:s], X[s:], y[s:],     # doctest: +SKIP
    ...                      "5min", "lat57_7000_lon11_9000")
    (..., ...)
    """
    model_dir = DATA_DIRECTORY / "models" / coord_str
    model_dir.mkdir(parents=True, exist_ok=True)
    sklearn_path = model_dir / f"sklearn_model_{horizon}.pkl"
    pytorch_path = model_dir / f"pytorch_model_{horizon}.pt"

    # Sklearn Ridge
    sklearn_model = Ridge()
    sklearn_model.fit(X_train, y_train)
    sklearn_score = sklearn_model.score(X_test, y_test) if len(X_test) else 0.0
    joblib.dump(sklearn_model, sklearn_path)
    logger.info(f"Saved Sklearn model at {sklearn_path} (R²={sklearn_score:.4f})")

    # PyTorch
    Xt = torch.from_numpy(X_train.astype(np.float32))
    yt = torch.from_numpy(y_train.astype(np.float32)).view(-1, 1)
    dataset = TensorDataset(Xt, yt)
    loader = DataLoader(dataset, batch_size=DEFAULT_BATCH_SIZE, shuffle=True)

    net = SimpleRegressionNet(X_train.shape[1])
    criterion = nn.MSELoss()
    optimizer = torch.optim.Adam(net.parameters(), lr=DEFAULT_LEARNING_RATE)

    net.train()
    for epoch in range(DEFAULT_NUM_EPOCHS):
        total_loss = 0.0
        for batch_X, batch_y in loader:
            optimizer.zero_grad()
            preds = net(batch_X)
            loss = criterion(preds, batch_y)
            loss.backward()
            optimizer.step()
            total_loss += loss.item()
        logger.debug(
            f"Epoch {epoch + 1}/{DEFAULT_NUM_EPOCHS}: "
            f"Loss={total_loss / len(loader):.4f}"
        )

    net.eval()
    pytorch_score = 0.0
    if len(X_test):
        with torch.no_grad():
            preds = (
                net(torch.from_numpy(X_test.astype(np.float32)))
                .cpu()
                .numpy()
                .flatten()
            )
        pytorch_score = r2_score(y_test, preds)
    torch.save(net.state_dict(), pytorch_path)
    logger.info(f"Saved PyTorch model at {pytorch_path} (R²={pytorch_score:.4f})")

    return float(sklearn_score), float(pytorch_score)


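# Reloading the persisted artifacts later (a sketch; assumes the directory
# layout produced above and the illustrative coordinate string from the
# docstring):
#
#   model_dir = DATA_DIRECTORY / "models" / "lat57_7000_lon11_9000"
#   sk_model = joblib.load(model_dir / "sklearn_model_1h.pkl")
#   net = SimpleRegressionNet(input_dim=len(FEATURE_COLUMNS))
#   net.load_state_dict(torch.load(model_dir / "pytorch_model_1h.pt"))
#   net.eval()

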
def update_training_status(
    session: SQLAlchemySession,
    is_training: bool,
    current_status_message: str | None = None,
    increment_count: bool = False,
) -> None:
    """Update the ``training_status`` row with the current state.

    Parameters
    ----------
    session : sqlalchemy.orm.Session
        Open SQLAlchemy session.
    is_training : bool
        Flag indicating whether a training run is in progress.
    current_status_message : str | None, optional
        Optional status message or horizon marker.
    increment_count : bool, optional
        Whether to increment the total training counter.

    Examples
    --------
    >>> # Within an open SQLAlchemy session                  # doctest: +SKIP
    >>> from app.ml_train import update_training_status      # doctest: +SKIP
    >>> update_training_status(session, True, "Job started")  # doctest: +SKIP
    """
    try:
        # Fetch the singleton row by primary key (SQLAlchemy 2.0-style
        # Session.get()).
        status = session.get(TrainingStatus, 1)
        if not status:
            status = TrainingStatus(id=1)
            session.add(status)
        status.is_training = is_training
        if current_status_message is not None:
            status.current_horizon = current_status_message
        if not is_training:
            status.last_trained_at = datetime.utcnow()
            # Clear the horizon marker only when no final message was given;
            # otherwise messages such as "Completed" would be wiped out.
            if current_status_message is None:
                status.current_horizon = None
        if increment_count:
            status.train_count = (status.train_count or 0) + 1
        session.commit()
    except SQLAlchemyError as e:
        logger.error(f"Failed to update training status: {e}", exc_info=True)
        session.rollback()


def main() -> None:
    """Entry point for the standalone ML training job.

    Notes
    -----
    - Central coordinates are fetched (with a hardcoded fallback if the
      lookup fails) and iterated. For each horizon, data are prepared,
      models trained, artifacts saved, and a ``TrainingLog`` row written.
    - Status updates are written to ``TrainingStatus`` throughout the run.

    Examples
    --------
    >>> # Executed inside the Slurm job container            # doctest: +SKIP
    >>> from app.ml_train import main                        # doctest: +SKIP
    >>> main()                                               # doctest: +SKIP
    """
    logger.info("=== Starting ML Training Job ===")
    try:
        session = get_db_session()
    except (ValueError, SQLAlchemyError) as e:
        logger.error(f"Configuration error: {e}")
        return

    try:
        central_query = session.execute(
            text("SELECT latitude, longitude FROM coordinates WHERE is_central = true")
        )
        all_coords = [
            (row.latitude, row.longitude) for row in central_query.fetchall()
        ]
    except SQLAlchemyError as e:
        logger.warning(f"Could not fetch central coords: {e}")
        all_coords = [(56.8618, 14.8069)]

    if not all_coords:
        update_training_status(
            session, is_training=False, current_status_message="No coordinates"
        )
        session.close()
        return

    update_training_status(
        session, is_training=True, current_status_message="Job started"
    )

    overall_success = True
    try:
        for lat, lon in all_coords:
            coord_str = f"lat{lat:.4f}_lon{lon:.4f}".replace(".", "_")
            update_training_status(
                session,
                is_training=True,
                current_status_message=f"Training {coord_str}",
            )

            df = load_training_data(session, lat, lon)
            if df.empty:
                logger.warning(f"Skipping {coord_str}, no data")
                continue

            for horizon, steps in HORIZON_SHIFTS.items():
                update_training_status(
                    session,
                    is_training=True,
                    current_status_message=f"{coord_str}:{horizon}",
                )
                X, y, count = prepare_horizon_data(df, horizon, steps)
                if count == 0:
                    # Record a zero-score row so the horizon still shows up
                    # in monitoring.
                    session.add(
                        TrainingLog(
                            id=str(uuid.uuid4()),
                            horizon=f"{coord_str}_{horizon}",
                            sklearn_score=0.0,
                            pytorch_score=0.0,
                            data_count=0,
                            coord_latitude=float(lat),
                            coord_longitude=float(lon),
                            horizon_label=horizon,
                        )
                    )
                    session.commit()
                    continue

                X_train, X_test, y_train, y_test = train_test_split(
                    X, y, test_size=DEFAULT_TEST_SIZE, shuffle=False
                )
                if (
                    len(X_train) < MIN_DATA_POINTS_FOR_TRAINING // 2
                    or len(X_test) < MIN_DATA_POINTS_FOR_TRAINING // 10
                ):
                    logger.warning(
                        f"Skipping {coord_str}:{horizon}, insufficient split data"
                    )
                    continue

                sklearn_score, pytorch_score = train_and_save_model(
                    X_train, y_train, X_test, y_test, horizon, coord_str
                )
                session.add(
                    TrainingLog(
                        id=str(uuid.uuid4()),
                        horizon=f"{coord_str}_{horizon}",
                        sklearn_score=float(sklearn_score),
                        pytorch_score=float(pytorch_score),
                        data_count=int(count),
                        coord_latitude=float(lat),
                        coord_longitude=float(lon),
                        horizon_label=horizon,
                    )
                )
                session.commit()
    except Exception as e:
        # Catch any failure, not only database errors, so the status row is
        # never left stuck in the "training" state.
        logger.error(f"Training loop error: {e}", exc_info=True)
        overall_success = False
        update_training_status(
            session, is_training=False, current_status_message=f"Error: {e}"
        )
    finally:
        if overall_success:
            update_training_status(
                session, is_training=False, current_status_message="Completed"
            )
            logger.info("=== ML Training Job Completed ===")
        else:
            logger.error("=== ML Training Job Failed ===")
        session.close()


if __name__ == "__main__":
    # Prepare directories on the shared volume before running.
    DATA_DIRECTORY.mkdir(parents=True, exist_ok=True)
    (DATA_DIRECTORY / "models").mkdir(parents=True, exist_ok=True)
    main()