
# src/app/ml_train.py
"""
Standalone ML training job executed inside the Slurm cluster.

This module encapsulates the horizon-wise training workflow used by the
platform's MLOps pipeline. It loads non-imputed observations for selected
coordinates, prepares future targets per forecasting horizon, trains a
baseline scikit-learn model and a simple PyTorch regressor, persists model
artifacts under the shared ``/data`` volume, and records scores in the
database for monitoring and visualization.

See Also
--------
app.slurm_job_trigger.create_and_dispatch_training_job
app.slurm_job_trigger.trigger_slurm_job
app.ml_utils.get_latest_training_logs
app.ml_utils.get_historical_scores
app.models.TrainingLog
app.models.TrainingStatus

Notes
-----
- Primary role: execute training for configured horizons and persist both
  models and TrainingLog entries, updating the TrainingStatus singleton
  throughout execution.
- Key dependencies: a reachable database via ``DATABASE_URL``, a writable
  shared volume at ``/data``, and environment variables controlling training.
- Invariants: ``DATABASE_URL`` must be set; the shared ``/data`` volume must
  exist and be writable; minimum data thresholds determine training viability.

Examples
--------
>>> # Executed by Slurm via the dispatcher                      # doctest: +SKIP
>>> # $ python3 /data/app_code_for_slurm/ml_train.py             # doctest: +SKIP
>>> from app.ml_train import main                                # doctest: +SKIP
>>> main()                                                       # doctest: +SKIP
"""


import logging
import os
import uuid
from datetime import datetime
from pathlib import Path
from typing import Tuple, cast

import joblib
import numpy as np
import pandas as pd
import torch
from numpy.typing import NDArray
from sklearn.linear_model import Ridge
from sklearn.metrics import r2_score
from sklearn.model_selection import train_test_split
from sqlalchemy import Boolean, DateTime, Float, Integer, String, create_engine, text
from sqlalchemy.exc import SQLAlchemyError
from sqlalchemy.orm import (
    DeclarativeBase,
    Mapped,
    Session as SQLAlchemySession,
    mapped_column,
    sessionmaker,
)
from torch import nn
from torch.utils.data import DataLoader, TensorDataset

# --- Configuration and Constants ---
LOG_FORMAT = "%(asctime)s - %(name)s - %(levelname)s - %(message)s"
logging.basicConfig(level=logging.INFO, format=LOG_FORMAT)
logger = logging.getLogger(__name__)

DATA_DIRECTORY = Path(os.getenv("SLURM_JOB_DATA_PATH", "/data"))
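# Target shift steps per horizon, assuming a 5-minute observation cadence
# (1 step = 5 min, so 12 steps = 1 h, 144 = 12 h, 288 = 24 h).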
HORIZON_SHIFTS = {"5min": 1, "1h": 12, "12h": 144, "24h": 288}
FEATURE_COLUMNS = [
    "air_temperature",
    "wind_speed",
    "wind_direction",
    "precipitation_amount",
]
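# Note: the target column also appears among the features, so training is
# effectively autoregressive on air_temperature.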
TARGET_COLUMN = "air_temperature"

DEFAULT_NUM_EPOCHS = int(os.getenv("ML_NUM_EPOCHS", "10"))
DEFAULT_LEARNING_RATE = float(os.getenv("ML_LEARNING_RATE", "1e-3"))
DEFAULT_BATCH_SIZE = int(os.getenv("ML_BATCH_SIZE", "32"))
DEFAULT_TEST_SIZE = float(os.getenv("ML_TEST_SIZE", "0.2"))
MIN_DATA_POINTS_FOR_TRAINING = int(os.getenv("ML_MIN_DATA_POINTS", "50"))
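# All training knobs above can be overridden via the environment, e.g.
# (illustrative): ML_NUM_EPOCHS=25 ML_BATCH_SIZE=64 python3 ml_train.py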


# --- ORM Base and Models ---
class Base(DeclarativeBase):
    """Base class for ORM models in this standalone job.

    Notes
    -----
    The standalone training job defines a minimal subset of ORM models to
    avoid importing the main application's metadata. Tables mirror the
    fields in ``app.models`` sufficiently for writing logs and status.
    """

    pass


class WeatherObservation(Base):
    """Observational weather record (non-imputed preferred for training).

    Attributes
    ----------
    timestamp : datetime
        Unique timestamp for the observation (UTC).
    latitude : float
        Coordinate latitude in decimal degrees.
    longitude : float
        Coordinate longitude in decimal degrees.
    air_temperature : float | None
        Air temperature in degrees Celsius.
    wind_speed : float | None
        Wind speed in m/s.
    wind_direction : float | None
        Wind direction in degrees.
    cloud_area_fraction : float | None
        Cloud cover fraction (0–1) when available.
    precipitation_amount : float | None
        Precipitation amount in mm for the interval.
    is_imputed : bool
        Whether the record was imputed (training uses non-imputed).
    """

    __tablename__ = "weather_observations"

    timestamp: Mapped[datetime] = mapped_column(DateTime, primary_key=True)
    latitude: Mapped[float] = mapped_column(Float, primary_key=True)
    longitude: Mapped[float] = mapped_column(Float, primary_key=True)
    air_temperature: Mapped[float | None] = mapped_column(Float, nullable=True)
    wind_speed: Mapped[float | None] = mapped_column(Float, nullable=True)
    wind_direction: Mapped[float | None] = mapped_column(Float, nullable=True)
    cloud_area_fraction: Mapped[float | None] = mapped_column(Float, nullable=True)
    precipitation_amount: Mapped[float | None] = mapped_column(Float, nullable=True)
    is_imputed: Mapped[bool] = mapped_column(Boolean, default=False, nullable=False)


class TrainingLog(Base):
    """Log entry for one training run and horizon.

    Attributes
    ----------
    id : str
        UUID identifier of the run.
    timestamp : datetime
        When the run completed (UTC).
    horizon : str
        Horizon key for the run, typically ``"<coord>_<horizon>"``.
    sklearn_score : float
        R² score of the scikit-learn model on the test split.
    pytorch_score : float
        R² score of the PyTorch model on the test split.
    data_count : int
        Number of samples used for this horizon after preprocessing.
    coord_latitude : float | None
        Latitude of the coordinate or ``None`` for aggregate runs.
    coord_longitude : float | None
        Longitude of the coordinate or ``None`` for aggregate runs.
    horizon_label : str | None
        One of ``{"5min", "1h", "12h", "24h"}``.
    """

    __tablename__ = "training_logs"

    id: Mapped[str] = mapped_column(
        String, primary_key=True, default=lambda: str(uuid.uuid4()), index=True
    )
    timestamp: Mapped[datetime] = mapped_column(DateTime, default=datetime.utcnow)
    horizon: Mapped[str] = mapped_column(String)
    sklearn_score: Mapped[float] = mapped_column(Float)
    pytorch_score: Mapped[float] = mapped_column(Float)
    data_count: Mapped[int] = mapped_column(Integer)
    coord_latitude: Mapped[float | None] = mapped_column(Float, nullable=True)
    coord_longitude: Mapped[float | None] = mapped_column(Float, nullable=True)
    horizon_label: Mapped[str | None] = mapped_column(String, nullable=True)


class TrainingStatus(Base):
    """Singleton table reflecting the current training state.

    Attributes
    ----------
    id : int
        Primary key (always 1 in this job flow).
    is_training : bool
        Whether a training job is running.
    last_trained_at : datetime | None
        Timestamp of the most recent run completion (updated whether the run
        succeeded or failed).
    train_count : int
        Number of training runs since system start or initialization.
    current_horizon : str | None
        Human-readable status message or horizon marker.
    """

    __tablename__ = "training_status"

    id: Mapped[int] = mapped_column(Integer, primary_key=True, default=1)
    is_training: Mapped[bool] = mapped_column(Boolean, default=False)
    last_trained_at: Mapped[datetime | None] = mapped_column(DateTime, nullable=True)
    train_count: Mapped[int] = mapped_column(Integer, default=0)
    current_horizon: Mapped[str | None] = mapped_column(String, nullable=True)


class SimpleRegressionNet(nn.Module):
    """Simple feed-forward regression network for 1-step regression.

    Parameters
    ----------
    input_dim : int
        Number of input features (must be positive).

    Examples
    --------
    >>> import torch                                            # doctest: +SKIP
    >>> from app.ml_train import SimpleRegressionNet            # doctest: +SKIP
    >>> net = SimpleRegressionNet(input_dim=4)                  # doctest: +SKIP
    >>> x = torch.randn(2, 4)                                   # doctest: +SKIP
    >>> y = net(x)                                              # doctest: +SKIP
    >>> y.shape == (2, 1)                                       # doctest: +SKIP
    True
    """

    def __init__(self, input_dim: int) -> None:
        super().__init__()
        if input_dim <= 0:
            raise ValueError(f"input_dim must be positive, got {input_dim}")
        self.net = nn.Sequential(
            nn.Linear(input_dim, 64),
            nn.ReLU(),
            nn.Linear(64, 1),
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Compute predictions for input features.

        Parameters
        ----------
        x : torch.Tensor
            Input batch with shape ``(batch_size, input_dim)``.

        Returns
        -------
        torch.Tensor
            Predicted target values with shape ``(batch_size, 1)``.
        """
        return cast(torch.Tensor, self.net(x))


# --- Database Session ---
def get_db_session() -> SQLAlchemySession:
    """Create and return a new SQLAlchemy session.

    Returns
    -------
    sqlalchemy.orm.Session
        Session bound to the engine specified by ``DATABASE_URL``.

    Raises
    ------
    ValueError
        If ``DATABASE_URL`` is not defined in the environment.
    sqlalchemy.exc.SQLAlchemyError
        If the engine cannot be created or the metadata initialization fails.

    Examples
    --------
    >>> import os                                               # doctest: +SKIP
    >>> os.environ["DATABASE_URL"] = "sqlite:///:memory:"       # doctest: +SKIP
    >>> from app.ml_train import get_db_session                 # doctest: +SKIP
    >>> s = get_db_session()                                    # doctest: +SKIP
    >>> s.close()                                               # doctest: +SKIP
    """
    database_url = os.getenv("DATABASE_URL")
    if not database_url:
        logger.error("DATABASE_URL environment variable not set.")
        raise ValueError("DATABASE_URL is required")

    connect_args = {}
    if database_url.startswith("sqlite"):
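        # SQLite restricts connections to their creating thread by default;
        # relax this so SQLAlchemy's connection pool can reuse them.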
        connect_args["check_same_thread"] = False

    engine = create_engine(database_url, connect_args=connect_args)
    Base.metadata.create_all(engine)
    SessionLocal = sessionmaker(bind=engine)
    return SessionLocal()


# --- Core Functions ---
def load_training_data(
    session: SQLAlchemySession, latitude: float, longitude: float
) -> pd.DataFrame:
    """Load non-imputed observations for one coordinate.

    Parameters
    ----------
    session : sqlalchemy.orm.Session
        Open SQLAlchemy session.
    latitude : float
        Coordinate latitude in decimal degrees.
    longitude : float
        Coordinate longitude in decimal degrees.

    Returns
    -------
    pandas.DataFrame
        DataFrame with a ``timestamp`` column parsed to ``datetime64[ns]``.
        Empty DataFrame if no non-imputed rows are available.

    Examples
    --------
    >>> # Requires seeded DB with weather_observations          # doctest: +SKIP
    >>> from app.ml_train import load_training_data             # doctest: +SKIP
    >>> df = load_training_data(session, 57.70, 11.90)          # doctest: +SKIP
    >>> isinstance(df.empty, bool)                              # doctest: +SKIP
    True
    """
    logger.info(f"Loading non-imputed data for ({latitude}, {longitude})")
    try:
        query = session.query(WeatherObservation).filter_by(
            latitude=latitude, longitude=longitude, is_imputed=False
        )
        assert session.bind, "Session must be bound to an engine"
        df = pd.read_sql(query.statement, session.bind)
        if df.empty:
            logger.warning(f"No actual data for ({latitude}, {longitude})")
            return pd.DataFrame()

        df["timestamp"] = pd.to_datetime(df["timestamp"])
        logger.info(f"Loaded {len(df)} rows for ({latitude}, {longitude})")
        return df
    except SQLAlchemyError as e:
        logger.error(
            f"Error loading data for ({latitude}, {longitude}): {e}", exc_info=True
        )
        return pd.DataFrame()


def prepare_horizon_data(
    df: pd.DataFrame, horizon_label: str, shift_steps: int
) -> Tuple[NDArray[np.float64], NDArray[np.float64], int]:
    """Prepare features and targets for a specific horizon.

    Parameters
    ----------
    df : pandas.DataFrame
        Input data sorted/filtered per coordinate.
    horizon_label : str
        Label for the horizon (e.g., ``"5min"``, ``"1h"``).
    shift_steps : int
        Positive number of steps the target is shifted into the future.

    Returns
    -------
    tuple[numpy.ndarray, numpy.ndarray, int]
        Tuple ``(X, y, count)`` where count is the number of samples retained
        after dropping NA rows for the chosen horizon.

    Examples
    --------
    >>> import pandas as pd                                     # doctest: +SKIP
    >>> from app.ml_train import prepare_horizon_data           # doctest: +SKIP
    >>> ts = pd.date_range("2024-01-01", periods=12, freq="5min")  # doctest: +SKIP
    >>> df = pd.DataFrame({                                     # doctest: +SKIP
    ...     "timestamp": ts,
    ...     "air_temperature": list(range(12)),
    ...     "wind_speed": [1.0]*12,
    ...     "wind_direction": [0.0]*12,
    ...     "precipitation_amount": [0.0]*12,
    ... })
    >>> X, y, count = prepare_horizon_data(df, "5min", 1)       # doctest: +SKIP
    >>> count >= 10                                              # doctest: +SKIP
    True
    """
    if df.empty or len(df) < shift_steps + 1:
        logger.debug(f"Insufficient rows for horizon '{horizon_label}'")
        return (
            np.empty((0, len(FEATURE_COLUMNS)), dtype=np.float64),
            np.empty(0, dtype=np.float64),
            0,
        )

    df_sorted = df.sort_values("timestamp").copy()
    future_col = f"future_{TARGET_COLUMN}_{horizon_label}"
    df_sorted[future_col] = df_sorted[TARGET_COLUMN].shift(-shift_steps)
    df_sorted.dropna(subset=FEATURE_COLUMNS + [future_col], inplace=True)

    count = len(df_sorted)
    threshold = (
        max(5, int(MIN_DATA_POINTS_FOR_TRAINING * 0.2))
        if horizon_label == "5min"
        else MIN_DATA_POINTS_FOR_TRAINING
    )
    if count < threshold:
        logger.warning(
            f"Not enough data after prep for '{horizon_label}' ({count} < {threshold})"
        )
        return (
            np.empty((0, len(FEATURE_COLUMNS)), dtype=np.float64),
            np.empty(0, dtype=np.float64),
            0,
        )

    X = df_sorted[FEATURE_COLUMNS].to_numpy(dtype=np.float64)
    y = df_sorted[future_col].to_numpy(dtype=np.float64)
    return X, y, count


def train_and_save_model(
    X_train: NDArray[np.float64],
    y_train: NDArray[np.float64],
    X_test: NDArray[np.float64],
    y_test: NDArray[np.float64],
    horizon: str,
    coord_str: str,
) -> Tuple[float, float]:
    """Train Ridge and PyTorch models, persist them, and return R² scores.

    Parameters
    ----------
    X_train, y_train, X_test, y_test : numpy.ndarray
        Training and test splits.
    horizon : str
        Horizon label used for filenames and logging.
    coord_str : str
        Coordinate identifier (e.g., ``"lat57_7000_lon11_9000"``).

    Returns
    -------
    tuple[float, float]
        ``(sklearn_r2, pytorch_r2)`` scores on the test split.

    Notes
    -----
    - Models are saved under ``/data/models/<coord_str>/`` with deterministic
      filenames per horizon and framework.

    Raises
    ------
    OSError
        If persisting the model artifacts to disk fails.

    Examples
    --------
    >>> import numpy as np                                      # doctest: +SKIP
    >>> from app.ml_train import train_and_save_model           # doctest: +SKIP
    >>> X = np.random.rand(100, 4); y = np.random.rand(100)     # doctest: +SKIP
    >>> s = int(len(X)*0.8)                                     # doctest: +SKIP
    >>> train_and_save_model(X[:s], y[:s], X[s:], y[s:],        # doctest: +SKIP
    ...                      "5min", "lat57_7000_lon11_9000")  # doctest: +SKIP
    (..., ...)
    """
    model_dir = DATA_DIRECTORY / "models" / coord_str
    model_dir.mkdir(parents=True, exist_ok=True)

    sklearn_path = model_dir / f"sklearn_model_{horizon}.pkl"
    pytorch_path = model_dir / f"pytorch_model_{horizon}.pt"

    # Baseline: scikit-learn Ridge regression
    sklearn_model = Ridge()
    sklearn_model.fit(X_train, y_train)
    sklearn_score = sklearn_model.score(X_test, y_test) if len(X_test) else 0.0
    joblib.dump(sklearn_model, sklearn_path)
    logger.info(f"Saved Sklearn model at {sklearn_path} (R²={sklearn_score:.4f})")

    # PyTorch regressor: inputs cast to float32, targets reshaped to a column vector
    Xt = torch.from_numpy(X_train.astype(np.float32))
    yt = torch.from_numpy(y_train.astype(np.float32)).view(-1, 1)
    dataset = TensorDataset(Xt, yt)
    loader = DataLoader(dataset, batch_size=DEFAULT_BATCH_SIZE, shuffle=True)

    net = SimpleRegressionNet(X_train.shape[1])
    criterion = nn.MSELoss()
    optimizer = torch.optim.Adam(net.parameters(), lr=DEFAULT_LEARNING_RATE)

    net.train()
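    # Mini-batch training: one Adam step per batch; mean epoch loss is logged at DEBUG.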
    for epoch in range(DEFAULT_NUM_EPOCHS):
        total_loss = 0.0
        for batch_X, batch_y in loader:
            optimizer.zero_grad()
            preds = net(batch_X)
            loss = criterion(preds, batch_y)
            loss.backward()
            optimizer.step()
            total_loss += loss.item()
        logger.debug(
            f"Epoch {epoch + 1}/{DEFAULT_NUM_EPOCHS}: Loss={total_loss / len(loader):.4f}"
        )

    net.eval()
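    # Score the held-out split in eval mode with gradient tracking disabled.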
    pytorch_score = 0.0
    if len(X_test):
        with torch.no_grad():
            preds = (
                net(torch.from_numpy(X_test.astype(np.float32))).cpu().numpy().flatten()
            )
        pytorch_score = r2_score(y_test, preds)

    torch.save(net.state_dict(), pytorch_path)
    logger.info(f"Saved PyTorch model at {pytorch_path} (R²={pytorch_score:.4f})")

    return float(sklearn_score), float(pytorch_score)


def update_training_status(
    session: SQLAlchemySession,
    is_training: bool,
    current_status_message: str | None = None,
    increment_count: bool = False,
) -> None:
    """Update the ``training_status`` row with the current state.

    Parameters
    ----------
    session : sqlalchemy.orm.Session
        Open SQLAlchemy session.
    is_training : bool
        Flag indicating whether a training is in progress.
    current_status_message : str | None, optional
        Optional status message or horizon marker.
    increment_count : bool, optional
        Whether to increment the total training counter.

    Examples
    --------
    >>> # Within an open SQLAlchemy session                    # doctest: +SKIP
    >>> from app.ml_train import update_training_status        # doctest: +SKIP
    >>> update_training_status(session, True, "Job started")   # doctest: +SKIP
    """
    try:
        status = session.get(TrainingStatus, 1)
        if not status:
            status = TrainingStatus(id=1)
            session.add(status)

        status.is_training = is_training
        if current_status_message is not None:
            status.current_horizon = current_status_message
        if not is_training:
            status.last_trained_at = datetime.utcnow()
            # Preserve an explicit final message (e.g. "Completed"); clear the
            # marker only when no message was provided.
            if current_status_message is None:
                status.current_horizon = None
        if increment_count:
            status.train_count = (status.train_count or 0) + 1

        session.commit()
    except SQLAlchemyError as e:
        logger.error(f"Failed to update training status: {e}", exc_info=True)
        session.rollback()


def main() -> None:
    """Entry point for the standalone ML training job.

    Notes
    -----
    - Coordinates are fetched (preferentially central ones) and iterated. For
      each horizon, data are prepared, models trained, artifacts saved, and a
      ``TrainingLog`` row written.
    - Status updates are written to ``TrainingStatus`` throughout the run.

    Examples
    --------
    >>> # Executed inside Slurm job container                   # doctest: +SKIP
    >>> from app.ml_train import main                           # doctest: +SKIP
    >>> main()                                                  # doctest: +SKIP
    """
    logger.info("=== Starting ML Training Job ===")
    try:
        session = get_db_session()
    except (ValueError, SQLAlchemyError) as e:
        logger.error(f"Configuration or database error: {e}")
        return

    try:
        central_query = session.execute(
            text("SELECT latitude, longitude FROM coordinates WHERE is_central = true")
        )
        all_coords = [(row.latitude, row.longitude) for row in central_query.fetchall()]
    except SQLAlchemyError as e:
        logger.warning(f"Could not fetch central coords: {e}")
        all_coords = [(56.8618, 14.8069)]

    if not all_coords:
        update_training_status(
            session, is_training=False, current_status_message="No data"
        )
        session.close()
        return

    update_training_status(
        session, is_training=True, current_status_message="Job started"
    )

    overall_success = True
    try:
        for lat, lon in all_coords:
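            # e.g. (57.7, 11.9) -> "lat57_7000_lon11_9000"; dots replaced for
            # filesystem-safe directory names.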
            coord_str = f"lat{lat:.4f}_lon{lon:.4f}".replace(".", "_")
            update_training_status(
                session,
                is_training=True,
                current_status_message=f"Training {coord_str}",
            )

            df = load_training_data(session, lat, lon)
            if df.empty:
                logger.warning(f"Skipping {coord_str}, no data")
                continue

            for horizon, steps in HORIZON_SHIFTS.items():
                update_training_status(
                    session,
                    is_training=True,
                    current_status_message=f"{coord_str}:{horizon}",
                )

                X, y, count = prepare_horizon_data(df, horizon, steps)
                if count == 0:
                    session.add(
                        TrainingLog(
                            id=str(uuid.uuid4()),
                            horizon=f"{coord_str}_{horizon}",
                            sklearn_score=float(0.0),
                            pytorch_score=float(0.0),
                            data_count=int(0),
                            coord_latitude=float(lat),
                            coord_longitude=float(lon),
                            horizon_label=horizon,
                        )
                    )
                    session.commit()
                    continue

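                # Chronological split: shuffle=False keeps test samples strictly
                # after the training samples, avoiding temporal leakage.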
                X_train, X_test, y_train, y_test = train_test_split(
                    X, y, test_size=DEFAULT_TEST_SIZE, shuffle=False
                )
                if (
                    len(X_train) < MIN_DATA_POINTS_FOR_TRAINING // 2
                    or len(X_test) < MIN_DATA_POINTS_FOR_TRAINING // 10
                ):
                    logger.warning(
                        f"Skipping {coord_str}:{horizon}, insufficient split data"
                    )
                    continue

                sklearn_score, pytorch_score = train_and_save_model(
                    X_train, y_train, X_test, y_test, horizon, coord_str
                )

                session.add(
                    TrainingLog(
                        id=str(uuid.uuid4()),
                        horizon=f"{coord_str}_{horizon}",
                        sklearn_score=float(sklearn_score),
                        pytorch_score=float(pytorch_score),
                        data_count=int(count),
                        coord_latitude=float(lat),
                        coord_longitude=float(lon),
                        horizon_label=horizon,
                    )
                )
                session.commit()

    except (SQLAlchemyError, OSError) as e:
        logger.error(f"Training loop error: {e}", exc_info=True)
        overall_success = False
        update_training_status(
            session, is_training=False, current_status_message=f"Error: {e}"
        )
    finally:
        if overall_success:
            update_training_status(
                session, is_training=False, current_status_message="Completed"
            )
            logger.info("=== ML Training Job Completed ===")
        else:
            logger.error("=== ML Training Job Failed ===")
        session.close()


if __name__ == "__main__":
    # Prepare directories
    DATA_DIRECTORY.mkdir(parents=True, exist_ok=True)
    (DATA_DIRECTORY / "models").mkdir(parents=True, exist_ok=True)
    main()