"""
Utility helpers that standardize interactions with the training pipeline.

This module provides a compact set of helper functions used by higher-level
training orchestration code. The helpers serve two purposes: (1) centralize
validation and mapping of horizon labels to deterministic shift values, and
(2) normalize return values from legacy training entrypoints during the
architecture transition to the consolidated ``app.ml_train`` module. By
keeping these behaviors in one place, the surrounding application code stays
small, explicit, and easier to test.
See Also
--------
app.training_jobs : Orchestrates background training using these helpers.
app.ml_train.HORIZON_SHIFTS : Source of truth for horizon-to-shift steps.
Notes
-----
- Primary role: validate horizons and normalize training results for code
  paths that orchestrate training across multiple horizons.
- Key dependencies: ``app.ml_train.HORIZON_SHIFTS`` for the canonical mapping
  of horizon labels to step shifts; ``pandas`` only for type annotations in
  compatibility wrappers.
- Invariants: ``get_horizon_shift`` raises ``KeyError`` for unknown horizons;
  ``unpack_training_result`` accepts exactly two tuple shapes and asserts
  otherwise.
Examples
--------
>>> from app.training_helpers import get_horizon_shift, unpack_training_result
>>> get_horizon_shift("5min")
1
>>> unpack_training_result((0.91, 0.87, 120))
(0.91, 0.87, 120)
>>> unpack_training_result(("5min", 0.91, 0.87, 120))
(0.91, 0.87, 120)
"""
import logging
from typing import Tuple, Union
import pandas as pd
from .ml_train import HORIZON_SHIFTS
logger = logging.getLogger(__name__)
def get_horizon_shift(horizon: str) -> int:
    """Look up how many time steps a horizon label corresponds to.

    Parameters
    ----------
    horizon : str
        Horizon label used as the key into ``HORIZON_SHIFTS`` (for example
        ``"5min"`` or ``"1h"``).

    Returns
    -------
    int
        The value stored in ``HORIZON_SHIFTS`` for ``horizon``.

    Raises
    ------
    KeyError
        If ``horizon`` is not a key of ``HORIZON_SHIFTS``. Before the
        original exception propagates, the failure is logged at error level
        together with the traceback and the list of valid labels.

    Examples
    --------
    >>> from app.training_helpers import get_horizon_shift
    >>> get_horizon_shift("1h")
    12
    """
    try:
        shift = HORIZON_SHIFTS[horizon]
    except KeyError:
        known_labels = [*HORIZON_SHIFTS]
        # Must stay inside the handler: the traceback attached to the log
        # record comes from the exception currently being handled.
        logger.exception(
            f"Invalid horizon '{horizon}'. Valid horizons: {known_labels}"
        )
        raise
    else:
        return shift
def unpack_training_result(
    result: Union[Tuple[float, float, int], Tuple[str, float, float, int]],
) -> Tuple[float, float, int]:
    """Normalize a training result to ``(sklearn_score, pytorch_score, count)``.

    Two legacy return shapes are accepted: ``(float, float, int)`` and
    ``(str, float, float, int)``. In the 4-element form the leading element
    (a horizon label) is discarded without being inspected.

    Parameters
    ----------
    result : tuple[float, float, int] | tuple[str, float, float, int]
        Raw result from a training function.

    Returns
    -------
    tuple[float, float, int]
        ``(sklearn_score, pytorch_score, data_count)``.

    Raises
    ------
    AssertionError
        If ``result`` does not have 3 or 4 elements, or if the two scores are
        not ``float`` instances or the count is not an ``int`` instance.

    Notes
    -----
    The check is raised explicitly rather than written as an ``assert``
    statement, so it is still enforced when Python runs with ``-O``.

    Examples
    --------
    >>> from app.training_helpers import unpack_training_result
    >>> unpack_training_result((0.75, 0.60, 42))
    (0.75, 0.6, 42)
    >>> unpack_training_result(("5min", 0.75, 0.60, 42))
    (0.75, 0.6, 42)
    """
    failure_message = (
        "Unexpected training result format "
        f"{result!r}; expected (float, float, int) or (str, float, float, int)"
    )

    size = len(result)
    if size == 4:
        _, sklearn_score, pytorch_score, data_count = result
    elif size == 3:
        sklearn_score, pytorch_score, data_count = result
    else:
        # Any other length used to surface as a ValueError from tuple
        # unpacking; report it through the documented exception instead.
        raise AssertionError(failure_message)

    well_typed = (
        isinstance(sklearn_score, float)
        and isinstance(pytorch_score, float)
        and isinstance(data_count, int)
    )
    if not well_typed:
        raise AssertionError(failure_message)

    return sklearn_score, pytorch_score, data_count
def train_models_for_horizon(
    data_frame: pd.DataFrame, horizon: str, shift_value: int
) -> Tuple[float, float, int]:
    """Legacy-signature stub kept so older call sites keep working.

    No training happens here. The function emits a warning that names the
    requested horizon and then hands back fixed placeholder values.

    Parameters
    ----------
    data_frame : pandas.DataFrame
        Accepted for signature compatibility; never read.
    horizon : str
        Horizon label, used only in the warning message.
    shift_value : int
        Accepted for signature compatibility; never read.

    Returns
    -------
    tuple[float, float, int]
        Always ``(0.0, 0.0, 0)``.

    Notes
    -----
    - Prefer invoking the consolidated training flow in ``app.ml_train``.

    Examples
    --------
    >>> import pandas as pd
    >>> from app.training_helpers import train_models_for_horizon
    >>> df = pd.DataFrame({"a": [1, 2, 3]})
    >>> train_models_for_horizon(df, "5min", 1)
    (0.0, 0.0, 0)
    """
    notice = f"train_models_for_horizon called for horizon {horizon} - this is a compatibility wrapper"
    logger.warning(notice)

    # Placeholder scores and count: this wrapper performs no training.
    placeholder: Tuple[float, float, int] = (0.0, 0.0, 0)
    return placeholder