# src/app/ml_utils.py
"""
Machine learning utilities for models and training-log queries.
This module provides small, focused utilities that are used across the
training and visualization layers: a minimal PyTorch regression network,
typed structures for training results, and helpers to retrieve and format the
latest and historical training scores from the database. It exists to keep
shared ML-centric logic isolated from the FastAPI views and the standalone
training job.
See Also
--------
app.ml_train : Standalone Slurm-executed training pipeline that writes logs.
app.slurm_job_trigger : Dispatches training jobs into the Slurm cluster.
app.database : Engine and session factories (``SessionLocal``).
app.models.TrainingLog : ORM model consumed by the query helpers here.
Notes
-----
- Primary role: define lightweight ML helpers and expose read-only queries for
:class:`app.models.TrainingLog` suitable for dashboards and APIs.
- Key dependencies: a reachable database via :data:`app.database.SessionLocal` and
an optional writable shared volume at ``/data`` for model artifacts.
- Invariants: the database schema for ``training_logs`` must be present.
Examples
--------
>>> # Fetch latest scores grouped by horizon (requires DB) # doctest: +SKIP
>>> from app.ml_utils import get_latest_training_logs # doctest: +SKIP
>>> latest = get_latest_training_logs() # doctest: +SKIP
>>> isinstance(latest, dict) # doctest: +SKIP
True
>>> # Create a tiny regression net (no DB required) # doctest: +SKIP
>>> import torch # doctest: +SKIP
>>> from app.ml_utils import SimpleRegressionNet # doctest: +SKIP
>>> net = SimpleRegressionNet(input_dim=4) # doctest: +SKIP
>>> y = net(torch.randn(2, 4)) # doctest: +SKIP
>>> y.shape # doctest: +SKIP
torch.Size([2, 1])
"""
import logging
import os
from datetime import datetime
from pathlib import Path
from typing import Any, Dict, List, Optional, TypedDict
from sqlalchemy.exc import SQLAlchemyError
from sqlalchemy.orm import Session
from torch import Tensor, nn
from .database import SessionLocal
from .models import TrainingLog
logger = logging.getLogger(__name__)
ENV_DATA_PATH_KEY: str = "SLURM_JOB_DATA_PATH"
DEFAULT_DATA_PATH: Path = Path("/data")
DATA_DIRECTORY: Path = Path(os.getenv(ENV_DATA_PATH_KEY, str(DEFAULT_DATA_PATH)))
MODEL_SUBDIR: str = "models"
TIMESTAMP_FORMAT: str = "%Y-%m-%d %H:%M:%S"
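# Note: this module only defines the artifact locations below; composing and
# writing the actual model files is assumed to happen in the training job
# (see app.ml_train). A hypothetical sketch of that usage:
#
#     model_dir = DATA_DIRECTORY / MODEL_SUBDIR      # e.g. Path("/data/models")
#     model_dir.mkdir(parents=True, exist_ok=True)
#     model_path = model_dir / "pytorch_model.pt"    # hypothetical filename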
class TrainingLogDetails(TypedDict):
"""Structured details for a single training log entry.
This typed mapping captures the essential fields used by the UI and
reporting layers when presenting the most recent score per horizon.
Attributes
----------
timestamp : datetime | None
Completion time of the training run in UTC.
sklearn_score : float
R^2 score of the Scikit-learn model for this run.
pytorch_score : float
R^2 score of the PyTorch model for this run.
data_count : int
Number of samples used for training and evaluation.
coord_latitude : float | None
Coordinate latitude associated with the run, if any.
coord_longitude : float | None
Coordinate longitude associated with the run, if any.
horizon_label : str | None
Human-friendly horizon label (e.g., ``"5min"``, ``"1h"``), if set.
horizon_display_name : str
Preformatted string suitable for charts/legends.
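    Examples
    --------
    A minimal construction sketch; the values below are purely illustrative
    and not taken from a real training run:

    >>> from datetime import datetime
    >>> details: TrainingLogDetails = {
    ...     "timestamp": datetime(2024, 1, 1, 12, 0),
    ...     "sklearn_score": 0.91,
    ...     "pytorch_score": 0.89,
    ...     "data_count": 1200,
    ...     "coord_latitude": 59.33,
    ...     "coord_longitude": 18.07,
    ...     "horizon_label": "1h",
    ...     "horizon_display_name": "Coord (59.33, 18.07) - Horizon: 1h",
    ... }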
"""
timestamp: Optional[datetime]
sklearn_score: float
pytorch_score: float
data_count: int
coord_latitude: Optional[float]
coord_longitude: Optional[float]
horizon_label: Optional[str]
horizon_display_name: str
class SimpleRegressionNet(nn.Module):
"""A minimal fully connected network for regression tasks.
Parameters
----------
input_dim : int
Number of input features; must be a positive integer.
Examples
--------
>>> import torch # doctest: +SKIP
>>> net = SimpleRegressionNet(input_dim=3) # doctest: +SKIP
>>> out = net(torch.randn(2, 3)) # doctest: +SKIP
>>> out.shape # doctest: +SKIP
torch.Size([2, 1])
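    A single-step training sketch (illustrative only; the optimizer and loss
    used by the real pipeline in :mod:`app.ml_train` may differ):

    >>> opt = torch.optim.Adam(net.parameters(), lr=1e-3)  # doctest: +SKIP
    >>> pred = net(torch.randn(8, 3))  # doctest: +SKIP
    >>> loss = torch.nn.functional.mse_loss(pred, torch.randn(8, 1))  # doctest: +SKIP
    >>> loss.backward()  # doctest: +SKIP
    >>> opt.step()  # doctest: +SKIP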
"""
def __init__(self, input_dim: int) -> None:
"""Construct the module with a single hidden layer.
Parameters
----------
        input_dim : int
            The number of input features; must be a positive integer.

        Raises
        ------
        ValueError
            If ``input_dim`` is not a positive integer.
        """
        if not isinstance(input_dim, int) or input_dim <= 0:
            raise ValueError(f"input_dim must be a positive integer, got {input_dim!r}")
super().__init__()
self.net = nn.Sequential(
nn.Linear(input_dim, 64),
nn.ReLU(),
nn.Linear(64, 1),
)
def forward(self, x: Tensor) -> Tensor:
"""Compute predictions for a batch of inputs.
Parameters
----------
x : torch.Tensor
Input feature tensor with shape ``(batch_size, input_dim)``.
Returns
-------
torch.Tensor
Output tensor with shape ``(batch_size, 1)``.
"""
output: Tensor = self.net(x)
return output
def _get_distinct_horizon_keys(session: Session) -> List[str]:
"""Retrieve distinct, non-null horizon keys from ``TrainingLog``.
Parameters
----------
session : sqlalchemy.orm.Session
Active database session bound to the application engine.
Returns
-------
    list[str]
        Unique, non-null horizon key strings, in the order returned by the
        database.
Examples
--------
>>> from app.database import SessionLocal # doctest: +SKIP
>>> with SessionLocal() as s: # doctest: +SKIP
... keys = _get_distinct_horizon_keys(s) # doctest: +SKIP
... isinstance(keys, list) # doctest: +SKIP
True
"""
rows = session.query(TrainingLog.horizon).distinct().all()
return [row[0] for row in rows if row[0] is not None]
def _format_display_name(
latitude: Optional[float], longitude: Optional[float], label: Optional[str]
) -> str:
"""Format a compact, human-readable display name.
Parameters
----------
latitude : float | None
Coordinate latitude in decimal degrees, if available.
longitude : float | None
Coordinate longitude in decimal degrees, if available.
label : str | None
Horizon label (e.g., ``"5min"`` or ``"1h"``) if known.
Returns
-------
str
A formatted display name, for example
``"Coord (59.33, 18.07) - Horizon: 1h"``.
"""
lat_str = f"{latitude:.2f}" if latitude is not None else "N/A"
lon_str = f"{longitude:.2f}" if longitude is not None else "N/A"
label_str = label if label else "N/A"
return f"Coord ({lat_str}, {lon_str}) - Horizon: {label_str}"
def _build_latest_training_log_details(
session: Session, key: str
) -> TrainingLogDetails:
"""Build details for the most recent log of a horizon.
Parameters
----------
session : sqlalchemy.orm.Session
Active database session.
key : str
Horizon key to filter by.
Returns
-------
TrainingLogDetails
Structured details for the latest log entry with the given key.
Raises
------
ValueError
If no entry is found for the provided horizon key.
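    Examples
    --------
    A usage sketch assuming the database holds at least one row for a
    (hypothetical) horizon key ``"1h"``:

    >>> from app.database import SessionLocal  # doctest: +SKIP
    >>> with SessionLocal() as s:  # doctest: +SKIP
    ...     details = _build_latest_training_log_details(s, "1h")  # doctest: +SKIP
    ...     print(details["horizon_display_name"])  # doctest: +SKIP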
"""
entry = (
session.query(TrainingLog)
.filter(TrainingLog.horizon == key)
.order_by(TrainingLog.timestamp.desc())
.first()
)
if entry is None:
raise ValueError(f"No training log entry found for horizon key: {key}")
display_name = _format_display_name(
entry.coord_latitude, entry.coord_longitude, entry.horizon_label
)
return TrainingLogDetails(
timestamp=entry.timestamp,
sklearn_score=float(entry.sklearn_score),
pytorch_score=float(entry.pytorch_score),
data_count=int(entry.data_count),
coord_latitude=entry.coord_latitude,
coord_longitude=entry.coord_longitude,
horizon_label=entry.horizon_label,
horizon_display_name=display_name,
)
def _build_historical_scores_for_key(session: Session, key: str) -> Dict[str, Any]:
"""Collect historical scores for a specific horizon key.
Parameters
----------
session : sqlalchemy.orm.Session
Active database session.
key : str
Horizon key to filter by.
Returns
-------
dict[str, Any]
Mapping with keys ``"timestamps"``, ``"sklearn_scores"``,
``"pytorch_scores"``, and ``"display_name"``.
Raises
------
ValueError
If no entries exist for the provided horizon key.
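    Examples
    --------
    A usage sketch assuming at least one row exists for a (hypothetical)
    horizon key ``"1h"``:

    >>> from app.database import SessionLocal  # doctest: +SKIP
    >>> with SessionLocal() as s:  # doctest: +SKIP
    ...     history = _build_historical_scores_for_key(s, "1h")  # doctest: +SKIP
    ...     print(len(history["timestamps"]), history["display_name"])  # doctest: +SKIP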
"""
entries = (
session.query(TrainingLog)
.filter(TrainingLog.horizon == key)
.order_by(TrainingLog.timestamp)
.all()
)
if not entries:
raise ValueError(f"No training log entries found for horizon key: {key}")
display_name = _format_display_name(
entries[0].coord_latitude, entries[0].coord_longitude, entries[0].horizon_label
)
return {
"timestamps": [entry.timestamp.strftime(TIMESTAMP_FORMAT) for entry in entries],
"sklearn_scores": [entry.sklearn_score for entry in entries],
"pytorch_scores": [entry.pytorch_score for entry in entries],
"display_name": display_name,
}
def get_latest_training_logs() -> Dict[str, TrainingLogDetails]:
"""Fetch the latest training log per horizon.
Iterates over distinct horizon keys in ``training_logs`` and returns the
most recent entry for each. Database errors are logged and an empty
mapping is returned on failure to keep callers resilient.
Returns
-------
dict[str, TrainingLogDetails]
Mapping from horizon key to latest log details.
Notes
-----
- All exceptions are caught and logged; on any error this function returns
an empty dictionary.
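    Examples
    --------
    Requires a reachable database; on any error an empty mapping is returned:

    >>> logs = get_latest_training_logs()  # doctest: +SKIP
    >>> for key, details in logs.items():  # doctest: +SKIP
    ...     print(key, details["sklearn_score"], details["pytorch_score"])  # doctest: +SKIP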
"""
try:
with SessionLocal() as session:
horizon_keys = _get_distinct_horizon_keys(session)
logger.debug("Found %d distinct horizon keys", len(horizon_keys))
return {
key: _build_latest_training_log_details(session, key)
for key in horizon_keys
}
except SQLAlchemyError as err:
logger.error(
"Database error fetching latest training logs: %s", err, exc_info=True
)
except Exception as err:
logger.error(
"Unexpected error fetching latest training logs: %s", err, exc_info=True
)
return {}
def get_historical_scores() -> Dict[str, Dict[str, Any]]:
"""Fetch historical scores grouped by horizon.
Returns time-ordered scores for every distinct horizon key found in the
database. Database errors are logged and an empty mapping is returned on
failure to keep callers resilient.
Returns
-------
dict[str, dict[str, Any]]
Mapping from horizon key to a dictionary with keys ``"timestamps"``,
``"sklearn_scores"``, ``"pytorch_scores"``, and ``"display_name"``.
Notes
-----
- All exceptions are caught and logged; on any error this function returns
an empty dictionary.
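    Examples
    --------
    Requires a reachable database; on any error an empty mapping is returned:

    >>> history = get_historical_scores()  # doctest: +SKIP
    >>> for key, series in history.items():  # doctest: +SKIP
    ...     print(key, series["display_name"], len(series["timestamps"]))  # doctest: +SKIP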
"""
try:
with SessionLocal() as session:
horizon_keys = _get_distinct_horizon_keys(session)
logger.debug("Found %d distinct horizon keys", len(horizon_keys))
return {
key: _build_historical_scores_for_key(session, key)
for key in horizon_keys
}
except SQLAlchemyError as err:
logger.error(
"Database error fetching historical training scores: %s", err, exc_info=True
)
except Exception as err:
logger.error(
"Unexpected error fetching historical scores: %s", err, exc_info=True
)
return {}