Source code for app.ml_utils

# src/app/ml_utils.py

"""
Machine learning utilities for models and training-log queries.

This module provides small, focused utilities that are used across the
training and visualization layers: a minimal PyTorch regression network,
typed structures for training results, and helpers to retrieve and format the
latest and historical training scores from the database. It exists to keep
shared ML-centric logic isolated from the FastAPI views and the standalone
training job.

See Also
--------
app.ml_train : Standalone Slurm-executed training pipeline that writes logs.
app.slurm_job_trigger : Dispatches training jobs into the Slurm cluster.
app.database : Engine and session factories (``SessionLocal``).
app.models.TrainingLog : ORM model consumed by the query helpers here.

Notes
-----
- Primary role: define lightweight ML helpers and expose read-only queries for
  :class:`app.models.TrainingLog` suitable for dashboards and APIs.
- Key dependencies: a reachable database via :data:`app.database.SessionLocal` and
  an optional writable shared volume for model artifacts, located at ``/data`` by
  default and overridable through the ``SLURM_JOB_DATA_PATH`` environment variable.
- Invariants: the database schema for ``training_logs`` must be present.

Examples
--------
>>> # Fetch latest scores grouped by horizon (requires DB)  # doctest: +SKIP
>>> from app.ml_utils import get_latest_training_logs          # doctest: +SKIP
>>> latest = get_latest_training_logs()                        # doctest: +SKIP
>>> isinstance(latest, dict)                                   # doctest: +SKIP
True

>>> # Create a tiny regression net (no DB required)            # doctest: +SKIP
>>> import torch                                               # doctest: +SKIP
>>> from app.ml_utils import SimpleRegressionNet               # doctest: +SKIP
>>> net = SimpleRegressionNet(input_dim=4)                     # doctest: +SKIP
>>> y = net(torch.randn(2, 4))                                 # doctest: +SKIP
>>> y.shape                                                    # doctest: +SKIP
torch.Size([2, 1])
"""


import logging
import os
from datetime import datetime
from pathlib import Path
from typing import Any, Dict, List, Optional, TypedDict

from sqlalchemy.exc import SQLAlchemyError
from sqlalchemy.orm import Session
from torch import Tensor, nn

from .database import SessionLocal
from .models import TrainingLog

logger = logging.getLogger(__name__)

# Name of the environment variable that can relocate the shared data volume.
ENV_DATA_PATH_KEY: str = "SLURM_JOB_DATA_PATH"
# Location assumed when that environment variable is not set.
DEFAULT_DATA_PATH: Path = Path("/data")
# Effective data directory, resolved once when this module is imported.
DATA_DIRECTORY: Path = Path(
    os.environ.get(ENV_DATA_PATH_KEY, str(DEFAULT_DATA_PATH))
)
# Sub-directory name reserved for model artifacts (presumably joined onto
# DATA_DIRECTORY by callers outside this module -- confirm).
MODEL_SUBDIR: str = "models"
# strftime pattern used when serialising historical training timestamps.
TIMESTAMP_FORMAT: str = "%Y-%m-%d %H:%M:%S"


class TrainingLogDetails(TypedDict):
    """Latest-run summary for one horizon, as consumed by the UI and reports.

    Instances are plain ``dict`` objects at runtime; the class only declares
    (and lets type checkers enforce) which keys exist and what they hold.
    Every key is required.

    Attributes
    ----------
    timestamp : datetime | None
        When the training run finished, if recorded (presumably UTC -- confirm
        against the writer in ``app.ml_train``).
    sklearn_score : float
        R^2 score achieved by the Scikit-learn model in this run.
    pytorch_score : float
        R^2 score achieved by the PyTorch model in this run.
    data_count : int
        Number of samples the run trained and evaluated on.
    coord_latitude : float | None
        Latitude of the coordinate the run belongs to, if any.
    coord_longitude : float | None
        Longitude of the coordinate the run belongs to, if any.
    horizon_label : str | None
        Short horizon label such as ``"5min"`` or ``"1h"``, if set.
    horizon_display_name : str
        Ready-to-render label for chart legends and tables.
    """

    # When the run completed.
    timestamp: Optional[datetime]
    # Model quality metrics for this run.
    sklearn_score: float
    pytorch_score: float
    # Size of the dataset that was used.
    data_count: int
    # Optional geographic coordinate tied to the run.
    coord_latitude: Optional[float]
    coord_longitude: Optional[float]
    # Horizon identification: raw label plus a preformatted display string.
    horizon_label: Optional[str]
    horizon_display_name: str
def assert_positive_input_dim(input_dim: int) -> None:
    """Validate that ``input_dim`` is a positive integer.

    ``bool`` values are rejected even though ``bool`` is a subclass of
    ``int``; otherwise ``True`` would silently be accepted as ``1``.

    Parameters
    ----------
    input_dim : int
        The number of input features expected by the model. Must be ``> 0``.

    Raises
    ------
    ValueError
        If ``input_dim`` is not a positive integer (``bool`` included).

    Examples
    --------
    >>> assert_positive_input_dim(4)
    >>> assert_positive_input_dim(0)
    Traceback (most recent call last):
    ...
    ValueError: input_dim must be a positive integer, but was 0 (type: <class 'int'>).
    >>> assert_positive_input_dim(True)
    Traceback (most recent call last):
    ...
    ValueError: input_dim must be a positive integer, but was True (type: <class 'bool'>).
    """
    # ``isinstance(True, int)`` is True, so bool must be excluded explicitly.
    is_integer = isinstance(input_dim, int) and not isinstance(input_dim, bool)
    if not is_integer or input_dim <= 0:
        raise ValueError(
            f"input_dim must be a positive integer, but was {input_dim} "
            f"(type: {type(input_dim)})."
        )
class SimpleRegressionNet(nn.Module):
    """A minimal fully connected network for regression tasks.

    Architecture: ``Linear(input_dim, hidden_dim) -> ReLU -> Linear(hidden_dim, 1)``.

    Parameters
    ----------
    input_dim : int
        Number of input features; must be a positive integer.
    hidden_dim : int, default 64
        Width of the single hidden layer; must be a positive integer. The
        default reproduces the original fixed architecture, and the layer
        layout inside ``self.net`` is unchanged, so ``state_dict`` keys
        (``net.0.*`` and ``net.2.*``) stay the same.

    Examples
    --------
    >>> import torch                                   # doctest: +SKIP
    >>> net = SimpleRegressionNet(input_dim=3)         # doctest: +SKIP
    >>> out = net(torch.randn(2, 3))                   # doctest: +SKIP
    >>> out.shape                                      # doctest: +SKIP
    torch.Size([2, 1])
    """

    def __init__(self, input_dim: int, hidden_dim: int = 64) -> None:
        """Construct the module with a single hidden layer.

        Parameters
        ----------
        input_dim : int
            The number of input features.
        hidden_dim : int, default 64
            Number of units in the hidden layer.

        Raises
        ------
        ValueError
            If ``input_dim`` or ``hidden_dim`` is not a positive integer.
        """
        assert_positive_input_dim(input_dim)
        # Same rules as for input_dim: a real int (not bool) and > 0.
        if (
            isinstance(hidden_dim, bool)
            or not isinstance(hidden_dim, int)
            or hidden_dim <= 0
        ):
            raise ValueError(
                f"hidden_dim must be a positive integer, but was {hidden_dim} "
                f"(type: {type(hidden_dim)})."
            )
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(input_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, 1),
        )

    def forward(self, x: Tensor) -> Tensor:
        """Compute predictions for a batch of inputs.

        Parameters
        ----------
        x : torch.Tensor
            Input feature tensor with shape ``(batch_size, input_dim)``.

        Returns
        -------
        torch.Tensor
            Output tensor with shape ``(batch_size, 1)``.
        """
        # Explicitly typed local: nn.Module.__call__ is typed as returning Any.
        output: Tensor = self.net(x)
        return output
def _get_distinct_horizon_keys(session: Session) -> List[str]:
    """Retrieve distinct, non-null horizon keys from ``TrainingLog``.

    Parameters
    ----------
    session : sqlalchemy.orm.Session
        Active database session bound to the application engine.

    Returns
    -------
    list[str]
        Unique horizon keys in ascending order; ``NULL`` horizons are dropped.

    Examples
    --------
    >>> from app.database import SessionLocal          # doctest: +SKIP
    >>> with SessionLocal() as s:                      # doctest: +SKIP
    ...     keys = _get_distinct_horizon_keys(s)       # doctest: +SKIP
    ...     isinstance(keys, list)                     # doctest: +SKIP
    True
    """
    rows = (
        session.query(TrainingLog.horizon)
        .distinct()
        # DISTINCT alone yields a backend-dependent order; sort explicitly so
        # the key order of the public mappings is stable.
        .order_by(TrainingLog.horizon)
        .all()
    )
    return [row[0] for row in rows if row[0] is not None]


def _format_display_name(
    latitude: Optional[float], longitude: Optional[float], label: Optional[str]
) -> str:
    """Format a compact, human-readable display name.

    Parameters
    ----------
    latitude : float | None
        Coordinate latitude in decimal degrees, if available.
    longitude : float | None
        Coordinate longitude in decimal degrees, if available.
    label : str | None
        Horizon label (e.g., ``"5min"`` or ``"1h"``) if known.

    Returns
    -------
    str
        A formatted display name, for example
        ``"Coord (59.33, 18.07) - Horizon: 1h"``. Missing parts (``None``, and
        an empty label) are rendered as ``"N/A"``.
    """
    lat_str = f"{latitude:.2f}" if latitude is not None else "N/A"
    lon_str = f"{longitude:.2f}" if longitude is not None else "N/A"
    label_str = label if label else "N/A"
    return f"Coord ({lat_str}, {lon_str}) - Horizon: {label_str}"


def _build_latest_training_log_details(
    session: Session, key: str
) -> TrainingLogDetails:
    """Build details for the most recent log of a horizon.

    Rows with a timestamp are always preferred over rows without one.

    Parameters
    ----------
    session : sqlalchemy.orm.Session
        Active database session.
    key : str
        Horizon key to filter by.

    Returns
    -------
    TrainingLogDetails
        Structured details for the latest log entry with the given key.

    Raises
    ------
    ValueError
        If no entry is found for the provided horizon key.
    """
    entry = (
        session.query(TrainingLog)
        .filter(TrainingLog.horizon == key)
        # Sort undated rows last. Without the IS NULL key, PostgreSQL places
        # NULLs first under DESC and an undated row would become "latest".
        .order_by(TrainingLog.timestamp.is_(None), TrainingLog.timestamp.desc())
        .first()
    )
    if entry is None:
        raise ValueError(f"No training log entry found for horizon key: {key}")
    display_name = _format_display_name(
        entry.coord_latitude, entry.coord_longitude, entry.horizon_label
    )
    return TrainingLogDetails(
        timestamp=entry.timestamp,
        sklearn_score=float(entry.sklearn_score),
        pytorch_score=float(entry.pytorch_score),
        data_count=int(entry.data_count),
        coord_latitude=entry.coord_latitude,
        coord_longitude=entry.coord_longitude,
        horizon_label=entry.horizon_label,
        horizon_display_name=display_name,
    )


def _format_log_timestamp(value: Optional[datetime]) -> Optional[str]:
    """Render a log timestamp with ``TIMESTAMP_FORMAT``; ``None`` stays ``None``."""
    return value.strftime(TIMESTAMP_FORMAT) if value is not None else None


def _build_historical_scores_for_key(session: Session, key: str) -> Dict[str, Any]:
    """Collect historical scores for a specific horizon key.

    Parameters
    ----------
    session : sqlalchemy.orm.Session
        Active database session.
    key : str
        Horizon key to filter by.

    Returns
    -------
    dict[str, Any]
        Mapping with parallel lists ``"timestamps"`` (strings formatted with
        ``TIMESTAMP_FORMAT``, or ``None`` for rows without a timestamp),
        ``"sklearn_scores"`` and ``"pytorch_scores"``, ordered by ascending
        timestamp (where undated rows land is backend-dependent), plus
        ``"display_name"`` derived from the first row.

    Raises
    ------
    ValueError
        If no entries exist for the provided horizon key.
    """
    entries = (
        session.query(TrainingLog)
        .filter(TrainingLog.horizon == key)
        .order_by(TrainingLog.timestamp)
        .all()
    )
    if not entries:
        raise ValueError(f"No training log entries found for horizon key: {key}")
    first = entries[0]
    display_name = _format_display_name(
        first.coord_latitude, first.coord_longitude, first.horizon_label
    )
    return {
        # timestamp is Optional (see TrainingLogDetails); a single undated row
        # must not make the whole history query fail.
        "timestamps": [_format_log_timestamp(entry.timestamp) for entry in entries],
        "sklearn_scores": [entry.sklearn_score for entry in entries],
        "pytorch_scores": [entry.pytorch_score for entry in entries],
        "display_name": display_name,
    }
def get_latest_training_logs() -> Dict[str, TrainingLogDetails]:
    """Return the most recent training log for every horizon.

    One entry is produced per distinct, non-null horizon key found in
    ``training_logs``. Errors never propagate. Any exception raised while
    querying is logged with its traceback, and an empty dict is returned so
    callers (dashboards, APIs) can keep rendering.

    Returns
    -------
    dict[str, TrainingLogDetails]
        Horizon key -> details of that horizon's latest log entry. On any
        error the result is ``{}``, never a partial mapping.
    """
    latest_by_horizon: Dict[str, TrainingLogDetails] = {}
    try:
        with SessionLocal() as session:
            keys = _get_distinct_horizon_keys(session)
            logger.debug("Found %d distinct horizon keys", len(keys))
            for horizon_key in keys:
                latest_by_horizon[horizon_key] = _build_latest_training_log_details(
                    session, horizon_key
                )
    except SQLAlchemyError as err:
        # Database-layer failure: log with traceback and degrade to empty.
        logger.exception("Database error fetching latest training logs: %s", err)
        return {}
    except Exception as err:  # deliberate catch-all at this API boundary
        logger.exception("Unexpected error fetching latest training logs: %s", err)
        return {}
    return latest_by_horizon
def get_historical_scores() -> Dict[str, Dict[str, Any]]:
    """Return the time-ordered score history of every horizon.

    For each distinct, non-null horizon key in ``training_logs``, the full
    series of scores is collected. Errors never propagate. Any exception
    raised while querying is logged with its traceback, and an empty dict is
    returned so callers can keep rendering.

    Returns
    -------
    dict[str, dict[str, Any]]
        Horizon key -> mapping with ``"timestamps"``, ``"sklearn_scores"``,
        ``"pytorch_scores"`` and ``"display_name"``. On any error the result
        is ``{}``, never a partial mapping.
    """
    history: Dict[str, Dict[str, Any]] = {}
    try:
        with SessionLocal() as session:
            keys = _get_distinct_horizon_keys(session)
            logger.debug("Found %d distinct horizon keys", len(keys))
            for horizon_key in keys:
                history[horizon_key] = _build_historical_scores_for_key(
                    session, horizon_key
                )
    except SQLAlchemyError as err:
        # Database-layer failure: log with traceback and degrade to empty.
        logger.exception(
            "Database error fetching historical training scores: %s", err
        )
        return {}
    except Exception as err:  # deliberate catch-all at this API boundary
        logger.exception("Unexpected error fetching historical scores: %s", err)
        return {}
    return history