# training_helpers.py

"""
Utility helpers that standardize interactions with the training pipeline.

This module provides a compact set of helper functions used by higher-level
training orchestration code. The helpers serve two purposes: (1) centralize
validation and mapping of horizon labels to deterministic shift values, and
(2) normalize return values from legacy training entrypoints during the
architecture transition to the consolidated ``app.ml_train`` module. By
keeping these behaviors in one place, the surrounding application code stays
small, explicit, and easier to test.

See Also
--------
app.training_jobs : Orchestrates background training using these helpers.
app.ml_train.HORIZON_SHIFTS : Source of truth for horizon-to-shift steps.

Notes
-----
- Primary role: validate horizons and normalize training results for code
  paths that orchestrate training across multiple horizons.
- Key dependencies: ``app.ml_train.HORIZON_SHIFTS`` for the canonical mapping
  of horizon labels to step shifts; ``pandas`` only for type annotations in
  compatibility wrappers.
- Invariants: ``get_horizon_shift`` raises ``KeyError`` for unknown horizons;
  ``unpack_training_result`` accepts exactly two tuple shapes and asserts
  otherwise.

Examples
--------
>>> from app.training_helpers import get_horizon_shift, unpack_training_result
>>> get_horizon_shift("5min")
1
>>> unpack_training_result((0.91, 0.87, 120))
(0.91, 0.87, 120)
>>> unpack_training_result(("5min", 0.91, 0.87, 120))
(0.91, 0.87, 120)
"""

import logging
from typing import Tuple, Union

import pandas as pd

from .ml_train import HORIZON_SHIFTS

# Module-level logger named after this module; the helpers below use it to
# report invalid horizons (error) and calls to the legacy wrapper (warning).
logger = logging.getLogger(__name__)


def get_horizon_shift(horizon: str) -> int:
    """Look up the shift (in time steps) registered for a horizon label.

    The lookup is a plain subscript into ``HORIZON_SHIFTS``. When the label
    is missing, the failure is logged (with traceback) together with the
    list of known labels, and the original ``KeyError`` is re-raised
    unchanged.

    Parameters
    ----------
    horizon : str
        Horizon label used as the key into ``HORIZON_SHIFTS`` (for example
        ``"5min"`` or ``"1h"``).

    Returns
    -------
    int
        The value stored under ``horizon`` in ``HORIZON_SHIFTS``.

    Raises
    ------
    KeyError
        If ``horizon`` is not a key of ``HORIZON_SHIFTS``. An error record
        is emitted on this module's logger before the exception propagates.

    Examples
    --------
    >>> from app.training_helpers import get_horizon_shift
    >>> get_horizon_shift("1h")  # actual value is defined by HORIZON_SHIFTS
    12
    """
    try:
        shift = HORIZON_SHIFTS[horizon]
    except KeyError:
        # Log inside the handler so ``exc_info=True`` captures the active
        # KeyError, then let that same exception continue to the caller.
        valid_keys = [*HORIZON_SHIFTS]
        message = f"Invalid horizon '{horizon}'. Valid horizons: {valid_keys}"
        logger.error(message, exc_info=True)
        raise
    else:
        return shift

def unpack_training_result(
    result: Union[Tuple[float, float, int], Tuple[str, float, float, int]],
) -> Tuple[float, float, int]:
    """Normalize a training result to ``(sklearn_score, pytorch_score, count)``.

    Two legacy return shapes are accepted: ``(float, float, int)`` and
    ``(str, float, float, int)``. In the four-element form the leading
    element (a horizon label) is discarded without being inspected; only the
    trailing three values are validated and returned.

    Validation is performed with an explicit ``raise`` rather than an
    ``assert`` statement, so it remains active when Python runs with ``-O``.

    Parameters
    ----------
    result : tuple[float, float, int] | tuple[str, float, float, int]
        Raw result from a training function.

    Returns
    -------
    tuple[float, float, int]
        ``(sklearn_score, pytorch_score, data_count)``.

    Raises
    ------
    AssertionError
        If ``result`` does not contain exactly three or four elements, or if
        the trailing three values are not ``(float, float, int)`` instances.
    TypeError
        If ``result`` does not support ``len()``.

    Examples
    --------
    >>> from app.training_helpers import unpack_training_result
    >>> unpack_training_result((0.75, 0.60, 42))
    (0.75, 0.6, 42)
    >>> unpack_training_result(("5min", 0.75, 0.60, 42))
    (0.75, 0.6, 42)
    """
    if len(result) in (3, 4):
        # The starred target absorbs the optional leading horizon label
        # (empty for the three-element shape).
        *_, sklearn_score, pytorch_score, data_count = result
        if (
            isinstance(sklearn_score, float)
            and isinstance(pytorch_score, float)
            and isinstance(data_count, int)
        ):
            return sklearn_score, pytorch_score, data_count
    # Wrong length and wrong element types are reported the same way so that
    # callers only need to handle one exception type for malformed results.
    raise AssertionError(
        "Unexpected training result format "
        f"{result!r}; expected (float, float, int) or (str, float, float, int)"
    )


def train_models_for_horizon(
    data_frame: pd.DataFrame, horizon: str, shift_value: int
) -> Tuple[float, float, int]:
    """Legacy-signature stub that performs no training.

    The function keeps an older call signature available. It does no work
    with its inputs: it emits a warning on this module's logger naming the
    requested horizon and hands back fixed placeholder values.

    Parameters
    ----------
    data_frame : pandas.DataFrame
        Accepted for signature compatibility; never read.
    horizon : str
        Horizon label; used only in the warning message.
    shift_value : int
        Accepted for signature compatibility; never read.

    Returns
    -------
    tuple[float, float, int]
        The constant ``(0.0, 0.0, 0)``.

    Notes
    -----
    - The returned numbers are placeholders, not real scores or counts.
      Per the module documentation, the current training flow lives in
      ``app.ml_train``.

    Examples
    --------
    >>> import pandas as pd
    >>> from app.training_helpers import train_models_for_horizon
    >>> train_models_for_horizon(pd.DataFrame({"a": [1, 2, 3]}), "5min", 1)
    (0.0, 0.0, 0)
    """
    notice = f"train_models_for_horizon called for horizon {horizon} - this is a compatibility wrapper"
    logger.warning(notice)
    # Placeholder triple in the (sklearn_score, pytorch_score, data_count)
    # layout expected by callers of the legacy API.
    placeholder: Tuple[float, float, int] = (0.0, 0.0, 0)
    return placeholder