# Source code for app.training_helpers

"""
Utility helpers that standardize interactions with the training pipeline.

This module provides a compact set of helper functions used by higher-level
training orchestration code. The helpers serve two purposes: (1) centralize
validation and mapping of horizon labels to deterministic shift values, and
(2) normalize return values from legacy training entrypoints during the
architecture transition to the consolidated ``app.ml_train`` module. By
keeping these behaviors in one place, the surrounding application code stays
small, explicit, and easier to test.

See Also
--------
app.training_jobs : Orchestrates background training using these helpers.
app.ml_train.HORIZON_SHIFTS : Source of truth for horizon-to-shift steps.

Notes
-----
- Primary role: validate horizons and normalize training results for code
  paths that orchestrate training across multiple horizons.
- Key dependencies: ``app.ml_train.HORIZON_SHIFTS`` for the canonical mapping
  of horizon labels to step shifts; ``pandas`` only for type annotations in
  compatibility wrappers.
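- Invariants: ``get_horizon_shift`` raises ``KeyError`` for unknown horizons;
  ``unpack_training_result`` accepts exactly two tuple shapes, raising
  ``AssertionError`` when the element types do not match and ``ValueError``
  when the tuple length is neither three nor four.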

Examples
--------
>>> from app.training_helpers import get_horizon_shift, unpack_training_result
>>> get_horizon_shift("5min")
1
>>> unpack_training_result((0.91, 0.87, 120))
(0.91, 0.87, 120)
>>> unpack_training_result(("5min", 0.91, 0.87, 120))
(0.91, 0.87, 120)
"""

import logging
from typing import Tuple, Union

import pandas as pd

from .ml_train import HORIZON_SHIFTS

# Module-level logger named after this module; the helpers below use it to
# report invalid horizons and calls to the legacy compatibility wrapper.
logger = logging.getLogger(__name__)


def get_horizon_shift(horizon: str) -> int:
    """Look up the step shift configured for a horizon label.

    Parameters
    ----------
    horizon : str
        Key into ``app.ml_train.HORIZON_SHIFTS`` (for example, ``"5min"``,
        ``"1h"``, ``"12h"``, ``"24h"``).

    Returns
    -------
    int
        The value stored in ``HORIZON_SHIFTS`` for ``horizon``.

    Raises
    ------
    KeyError
        If ``horizon`` is not a key of ``HORIZON_SHIFTS``. The failure is
        logged at error level, with traceback, before being re-raised.

    Examples
    --------
    >>> from app.training_helpers import get_horizon_shift
    >>> get_horizon_shift("1h")
    12
    """
    try:
        shift = HORIZON_SHIFTS[horizon]
    except KeyError:
        # Report the accepted labels alongside the offending one, then let the
        # original KeyError propagate unchanged to the caller.
        valid_keys = [*HORIZON_SHIFTS]
        logger.exception(f"Invalid horizon '{horizon}'. Valid horizons: {valid_keys}")
        raise
    return shift
def unpack_training_result(
    result: Union[Tuple[float, float, int], Tuple[str, float, float, int]],
) -> Tuple[float, float, int]:
    """Normalize a training result to ``(sklearn_score, pytorch_score, count)``.

    This helper accepts two legacy return shapes from training calls and
    produces a consistent 3-tuple used by callers. Supported shapes are
    ``(float, float, int)`` and ``(str, float, float, int)``; in the latter the
    leading element (a horizon label) is discarded without being inspected.

    Parameters
    ----------
    result : tuple[float, float, int] | tuple[str, float, float, int]
        Raw result from a training function.

    Returns
    -------
    tuple[float, float, int]
        ``(sklearn_score, pytorch_score, data_count)``.

    Raises
    ------
    AssertionError
        If the two scores are not ``float`` instances or the count is not an
        ``int`` instance. The check is raised explicitly, so it stays active
        even when Python runs with ``-O``.
    ValueError
        If ``result`` has neither three nor four elements (raised by tuple
        unpacking).

    Examples
    --------
    >>> from app.training_helpers import unpack_training_result
    >>> unpack_training_result((0.75, 0.60, 42))
    (0.75, 0.6, 42)
    >>> unpack_training_result(("5min", 0.75, 0.60, 42))
    (0.75, 0.6, 42)
    """
    if len(result) == 4:
        # Legacy four-element shape: drop the leading horizon label.
        _, sklearn_score, pytorch_score, data_count = result
    else:
        # Any length other than three fails here with ValueError.
        sklearn_score, pytorch_score, data_count = result

    types_ok = (
        isinstance(sklearn_score, float)
        and isinstance(pytorch_score, float)
        and isinstance(data_count, int)
    )
    if not types_ok:
        # Raise explicitly instead of using ``assert``: assert statements are
        # removed under ``python -O``, which would let malformed results pass.
        # AssertionError is kept so existing callers see the same exception.
        raise AssertionError(
            "Unexpected training result format "
            f"{result!r}; expected (float, float, int) or (str, float, float, int)"
        )
    return sklearn_score, pytorch_score, data_count
def train_models_for_horizon(
    data_frame: pd.DataFrame, horizon: str, shift_value: int
) -> Tuple[float, float, int]:
    """Legacy-signature stand-in that performs no training.

    Older call sites expect a function with this exact signature. The current
    training implementation lives in ``app.ml_train``; this wrapper only logs a
    warning naming the horizon and hands back fixed placeholder values.

    Parameters
    ----------
    data_frame : pandas.DataFrame
        Ignored; accepted for backward compatibility only.
    horizon : str
        Horizon label, used solely in the warning message.
    shift_value : int
        Ignored; accepted for callers that still supply it.

    Returns
    -------
    tuple[float, float, int]
        Always ``(0.0, 0.0, 0)``.

    Notes
    -----
    - Prefer invoking the consolidated training flow in ``app.ml_train``.

    Examples
    --------
    >>> import pandas as pd
    >>> from app.training_helpers import train_models_for_horizon
    >>> df = pd.DataFrame({"a": [1, 2, 3]})
    >>> train_models_for_horizon(df, "5min", 1)
    (0.0, 0.0, 0)
    """
    notice = f"train_models_for_horizon called for horizon {horizon} - this is a compatibility wrapper"
    logger.warning(notice)

    # Placeholder scores and sample count; no model is trained here.
    placeholder_result = (0.0, 0.0, 0)
    return placeholder_result