# src/app/ml_utils.py
"""Machine learning utilities for models and training-log queries.

This module provides small, focused utilities that are used across the
training and visualization layers: a minimal PyTorch regression network,
typed structures for training results, and helpers to retrieve and format the
latest and historical training scores from the database. It exists to keep
shared ML-centric logic isolated from the FastAPI views and the standalone
training job.

See Also
--------
app.ml_train : Standalone Slurm-executed training pipeline that writes logs.
app.slurm_job_trigger : Dispatches training jobs into the Slurm cluster.
app.database : Engine and session factories (``SessionLocal``).
app.models.TrainingLog : ORM model consumed by the query helpers here.

Notes
-----
- Primary role: define lightweight ML helpers and expose read-only queries for
  :class:`app.models.TrainingLog` suitable for dashboards and APIs.
- Key dependencies: a reachable database via :data:`app.database.SessionLocal`
  and an optional writable shared volume at ``/data`` for model artifacts.
- Invariants: the database schema for ``training_logs`` must be present.

Examples
--------
>>> # Fetch latest scores grouped by horizon (requires DB)  # doctest: +SKIP
>>> from app.ml_utils import get_latest_training_logs  # doctest: +SKIP
>>> latest = get_latest_training_logs()  # doctest: +SKIP
>>> isinstance(latest, dict)  # doctest: +SKIP
True

>>> # Create a tiny regression net (no DB required)  # doctest: +SKIP
>>> import torch  # doctest: +SKIP
>>> from app.ml_utils import SimpleRegressionNet  # doctest: +SKIP
>>> net = SimpleRegressionNet(input_dim=4)  # doctest: +SKIP
>>> y = net(torch.randn(2, 4))  # doctest: +SKIP
>>> y.shape  # doctest: +SKIP
torch.Size([2, 1])
"""

import logging
import os
from datetime import datetime
from pathlib import Path
from typing import Any, Dict, List, Optional, TypedDict

from sqlalchemy.exc import SQLAlchemyError
from sqlalchemy.orm import Session
from torch import Tensor, nn

from .database import SessionLocal
from .models import TrainingLog

logger = logging.getLogger(__name__)

# Name of the environment variable that overrides the data directory.
ENV_DATA_PATH_KEY: str = "SLURM_JOB_DATA_PATH"
# Fallback data directory used when the environment variable is unset.
DEFAULT_DATA_PATH: Path = Path("/data")
# Effective data directory, resolved once at import time.
DATA_DIRECTORY: Path = Path(os.getenv(ENV_DATA_PATH_KEY, str(DEFAULT_DATA_PATH)))
MODEL_SUBDIR: str = "models"
# ``strftime`` pattern used when serialising timestamps for historical scores.
TIMESTAMP_FORMAT: str = "%Y-%m-%d %H:%M:%S"


class TrainingLogDetails(TypedDict):
    """Structured details for a single training log entry.

    This typed mapping captures the essential fields used by the UI and
    reporting layers when presenting the most recent score per horizon.

    Attributes
    ----------
    timestamp : datetime | None
        Completion time of the training run in UTC.
    sklearn_score : float
        R^2 score of the Scikit-learn model for this run.
    pytorch_score : float
        R^2 score of the PyTorch model for this run.
    data_count : int
        Number of samples used for training and evaluation.
    coord_latitude : float | None
        Coordinate latitude associated with the run, if any.
    coord_longitude : float | None
        Coordinate longitude associated with the run, if any.
    horizon_label : str | None
        Human-friendly horizon label (e.g., ``"5min"``, ``"1h"``), if set.
    horizon_display_name : str
        Preformatted string suitable for charts/legends.
    """

    timestamp: Optional[datetime]
    sklearn_score: float
    pytorch_score: float
    data_count: int
    coord_latitude: Optional[float]
    coord_longitude: Optional[float]
    horizon_label: Optional[str]
    horizon_display_name: str


def assert_positive_input_dim(input_dim: int) -> None:
    """Validate that ``input_dim`` is a positive integer.

    Parameters
    ----------
    input_dim : int
        The number of input features expected by the model. Must be ``> 0``.
        ``bool`` values are rejected even though ``bool`` subclasses ``int``.

    Raises
    ------
    ValueError
        If ``input_dim`` is not a positive integer.

    Examples
    --------
    >>> assert_positive_input_dim(4)
    >>> assert_positive_input_dim(0)
    Traceback (most recent call last):
        ...
    ValueError: input_dim must be a positive integer, but was 0 (type: <class 'int'>).
    >>> assert_positive_input_dim(True)
    Traceback (most recent call last):
        ...
    ValueError: input_dim must be a positive integer, but was True (type: <class 'bool'>).
    """
    # ``isinstance(True, int)`` is true, so booleans must be excluded
    # explicitly; otherwise ``True`` would silently mean "one feature".
    is_plain_int = isinstance(input_dim, int) and not isinstance(input_dim, bool)
    if not is_plain_int or input_dim <= 0:
        raise ValueError(
            f"input_dim must be a positive integer, but was {input_dim} "
            f"(type: {type(input_dim)})."
        )


class SimpleRegressionNet(nn.Module):
    """A minimal fully connected network for regression tasks.

    Parameters
    ----------
    input_dim : int
        Number of input features; must be a positive integer.

    Examples
    --------
    >>> import torch  # doctest: +SKIP
    >>> net = SimpleRegressionNet(input_dim=3)  # doctest: +SKIP
    >>> out = net(torch.randn(2, 3))  # doctest: +SKIP
    >>> out.shape  # doctest: +SKIP
    torch.Size([2, 1])
    """

    def __init__(self, input_dim: int) -> None:
        """Construct the module with a single hidden layer.

        Parameters
        ----------
        input_dim : int
            The number of input features.

        Raises
        ------
        ValueError
            If ``input_dim`` is not a positive integer.
        """
        # Validate before touching ``nn.Module`` state so a bad argument
        # fails with a clear message instead of a layer-construction error.
        assert_positive_input_dim(input_dim)
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(input_dim, 64),
            nn.ReLU(),
            nn.Linear(64, 1),
        )

    def forward(self, x: Tensor) -> Tensor:
        """Compute predictions for a batch of inputs.

        Parameters
        ----------
        x : torch.Tensor
            Input feature tensor with shape ``(batch_size, input_dim)``.

        Returns
        -------
        torch.Tensor
            Output tensor with shape ``(batch_size, 1)``.
        """
        # The explicit annotation keeps the return type precise for type
        # checkers, since calling an ``nn.Module`` is typed as ``Any``.
        output: Tensor = self.net(x)
        return output


def _get_distinct_horizon_keys(session: Session) -> List[str]:
    """Retrieve distinct, non-null horizon keys from ``TrainingLog``.

    Parameters
    ----------
    session : sqlalchemy.orm.Session
        Active database session bound to the application engine.

    Returns
    -------
    list[str]
        Sorted list of unique horizon key strings.

    Examples
    --------
    >>> from app.database import SessionLocal  # doctest: +SKIP
    >>> with SessionLocal() as s:  # doctest: +SKIP
    ...     keys = _get_distinct_horizon_keys(s)  # doctest: +SKIP
    ...     isinstance(keys, list)  # doctest: +SKIP
    True
    """
    rows = session.query(TrainingLog.horizon).distinct().all()
    # ``DISTINCT`` gives no ordering guarantee; sort here so the mappings
    # built from these keys have a stable, reproducible order.
    return sorted(row[0] for row in rows if row[0] is not None)


def _format_display_name(
    latitude: Optional[float],
    longitude: Optional[float],
    label: Optional[str],
) -> str:
    """Format a compact, human-readable display name.

    Parameters
    ----------
    latitude : float | None
        Coordinate latitude in decimal degrees, if available.
    longitude : float | None
        Coordinate longitude in decimal degrees, if available.
    label : str | None
        Horizon label (e.g., ``"5min"`` or ``"1h"``) if known.

    Returns
    -------
    str
        A formatted display name, for example
        ``"Coord (59.33, 18.07) - Horizon: 1h"``. Missing coordinates and
        missing or empty labels are rendered as ``"N/A"``.
    """
    # Compare against ``None`` rather than truthiness: ``0.0`` is a valid
    # coordinate and must not be rendered as "N/A".
    lat_str = f"{latitude:.2f}" if latitude is not None else "N/A"
    lon_str = f"{longitude:.2f}" if longitude is not None else "N/A"
    label_str = label or "N/A"
    return f"Coord ({lat_str}, {lon_str}) - Horizon: {label_str}"


def _build_latest_training_log_details(
    session: Session, key: str
) -> TrainingLogDetails:
    """Build details for the most recent log of a horizon.

    Parameters
    ----------
    session : sqlalchemy.orm.Session
        Active database session.
    key : str
        Horizon key to filter by.

    Returns
    -------
    TrainingLogDetails
        Structured details for the latest log entry with the given key.

    Raises
    ------
    ValueError
        If no entry is found for the provided horizon key.
    """
    entry = (
        session.query(TrainingLog)
        .filter(TrainingLog.horizon == key)
        .order_by(TrainingLog.timestamp.desc())
        .first()
    )
    if entry is None:
        raise ValueError(f"No training log entry found for horizon key: {key}")

    display_name = _format_display_name(
        entry.coord_latitude,
        entry.coord_longitude,
        entry.horizon_label,
    )
    return TrainingLogDetails(
        timestamp=entry.timestamp,
        sklearn_score=float(entry.sklearn_score),
        pytorch_score=float(entry.pytorch_score),
        data_count=int(entry.data_count),
        coord_latitude=entry.coord_latitude,
        coord_longitude=entry.coord_longitude,
        horizon_label=entry.horizon_label,
        horizon_display_name=display_name,
    )


def _build_historical_scores_for_key(session: Session, key: str) -> Dict[str, Any]:
    """Collect historical scores for a specific horizon key.

    Parameters
    ----------
    session : sqlalchemy.orm.Session
        Active database session.
    key : str
        Horizon key to filter by.

    Returns
    -------
    dict[str, Any]
        Mapping with keys ``"timestamps"``, ``"sklearn_scores"``,
        ``"pytorch_scores"``, and ``"display_name"``. The three lists are
        index-aligned and ordered by timestamp. Entries without a timestamp
        are omitted from all three lists.

    Raises
    ------
    ValueError
        If no entries exist for the provided horizon key.
    """
    entries = (
        session.query(TrainingLog)
        .filter(TrainingLog.horizon == key)
        .order_by(TrainingLog.timestamp)
        .all()
    )
    if not entries:
        raise ValueError(f"No training log entries found for horizon key: {key}")

    display_name = _format_display_name(
        entries[0].coord_latitude,
        entries[0].coord_longitude,
        entries[0].horizon_label,
    )

    # A row without a timestamp cannot be placed on a time axis, and calling
    # ``strftime`` on ``None`` would raise and discard the scores of every
    # horizon in the caller. Drop such rows and keep the rest.
    dated_entries = [entry for entry in entries if entry.timestamp is not None]
    skipped = len(entries) - len(dated_entries)
    if skipped:
        logger.warning(
            "Skipping %d training log entries without a timestamp "
            "for horizon key: %s",
            skipped,
            key,
        )

    return {
        "timestamps": [
            entry.timestamp.strftime(TIMESTAMP_FORMAT) for entry in dated_entries
        ],
        "sklearn_scores": [entry.sklearn_score for entry in dated_entries],
        "pytorch_scores": [entry.pytorch_score for entry in dated_entries],
        "display_name": display_name,
    }


def get_latest_training_logs() -> Dict[str, TrainingLogDetails]:
    """Fetch the latest training log per horizon.

    Iterates over distinct horizon keys in ``training_logs`` and returns the
    most recent entry for each. Database errors are logged and an empty
    mapping is returned on failure to keep callers resilient.

    Returns
    -------
    dict[str, TrainingLogDetails]
        Mapping from horizon key to latest log details.

    Notes
    -----
    - All exceptions are caught and logged; on any error this function
      returns an empty dictionary.
    """
    try:
        with SessionLocal() as session:
            horizon_keys = _get_distinct_horizon_keys(session)
            logger.debug("Found %d distinct horizon keys", len(horizon_keys))
            return {
                key: _build_latest_training_log_details(session, key)
                for key in horizon_keys
            }
    except SQLAlchemyError as err:
        logger.exception("Database error fetching latest training logs: %s", err)
    except Exception as err:
        # Deliberate catch-all: this is a boundary used by views, which
        # should render an empty state rather than fail.
        logger.exception("Unexpected error fetching latest training logs: %s", err)
    return {}


def get_historical_scores() -> Dict[str, Dict[str, Any]]:
    """Fetch historical scores grouped by horizon.

    Returns time-ordered scores for every distinct horizon key found in the
    database. Database errors are logged and an empty mapping is returned on
    failure to keep callers resilient.

    Returns
    -------
    dict[str, dict[str, Any]]
        Mapping from horizon key to a dictionary with keys ``"timestamps"``,
        ``"sklearn_scores"``, ``"pytorch_scores"``, and ``"display_name"``.

    Notes
    -----
    - All exceptions are caught and logged; on any error this function
      returns an empty dictionary.
    - Log entries without a timestamp are left out of the returned series.
    """
    try:
        with SessionLocal() as session:
            horizon_keys = _get_distinct_horizon_keys(session)
            logger.debug("Found %d distinct horizon keys", len(horizon_keys))
            return {
                key: _build_historical_scores_for_key(session, key)
                for key in horizon_keys
            }
    except SQLAlchemyError as err:
        logger.exception(
            "Database error fetching historical training scores: %s", err
        )
    except Exception as err:
        # Deliberate catch-all: this is a boundary used by views, which
        # should render an empty state rather than fail.
        logger.exception("Unexpected error fetching historical scores: %s", err)
    return {}