# training_jobs.py

"""
Coordinate manual and background training jobs across horizons.

This module orchestrates the end-to-end training flow invoked either from the
web application or from scheduled background tasks. It sets and clears the
``TrainingStatus`` singleton, fetches raw observations as a Pandas DataFrame,
iterates over configured horizons, and delegates horizon-specific work to the
training helpers. Errors are logged with stack traces and the database state
is kept consistent by clearing the in-progress flag on failure.

See Also
--------
app.ml_train.HORIZON_SHIFTS
app.training_helpers.get_horizon_shift
app.training_helpers.train_models_for_horizon
app.training_helpers.unpack_training_result
app.models.TrainingStatus
app.models.WeatherObservation
app.database.SessionLocal

Notes
-----
- Primary role: orchestrate training across all horizons and maintain the
  ``TrainingStatus`` state while delegating scoring to the training pipeline.
- Key dependencies: ``app.database.SessionLocal`` for DB access; ORM tables
  ``app.models.TrainingStatus`` and ``app.models.WeatherObservation``;
  horizon mapping from ``app.ml_train.HORIZON_SHIFTS``; helpers in
  ``app.training_helpers``.
- Invariants: ``TrainingStatus`` is treated as a singleton with ``id=1``; the
  database must be reachable.

Examples
--------
>>> from app.training_jobs import launch_training_thread           # doctest: +SKIP
>>> launch_training_thread()                                       # doctest: +SKIP
>>> # Or run synchronously (useful for tests/local debugging)      # doctest: +SKIP
>>> from app.training_jobs import run_training_and_update_status   # doctest: +SKIP
>>> run_training_and_update_status()                               # doctest: +SKIP
"""

import inspect
import logging
from datetime import datetime, timezone
from threading import Thread
from typing import Callable, Optional

import pandas as pd
from sqlalchemy.exc import SQLAlchemyError
from sqlalchemy.orm import Session

from .database import SessionLocal
from .ml_train import HORIZON_SHIFTS
from .models import TrainingStatus, WeatherObservation
from .training_helpers import (
    get_horizon_shift,
    train_models_for_horizon,
    unpack_training_result,
)

# Module-level logger; all functions below report progress and errors here.
logger = logging.getLogger(__name__)

# Primary key of the single ``TrainingStatus`` row that this module creates,
# reads, and updates (treated as a singleton).
TRAINING_STATUS_ID: int = 1


def launch_training_thread(on_complete: Optional[Callable[[], None]] = None) -> None:
    """Start :func:`run_training_and_update_status` on a background daemon thread.

    Control returns to the caller as soon as the thread has been started; the
    training itself runs asynchronously. ``on_complete`` is forwarded to the
    worker, which calls it only after a run that finished without an error.

    Parameters
    ----------
    on_complete : Callable[[], None] | None, optional
        Callback to run after a successful training pass, by default ``None``.

    Examples
    --------
    >>> from app.training_jobs import launch_training_thread       # doctest: +SKIP
    >>> launch_training_thread()                                   # doctest: +SKIP

    Notes
    -----
    - Because the worker is a daemon thread, the interpreter will not wait
      for it at shutdown; an in-flight run can be cut off if the process exits.
    """
    worker = Thread(target=run_training_and_update_status, args=(on_complete,))
    # Daemon flag must be set before start(); equivalent to daemon=True above.
    worker.daemon = True
    worker.start()
    logger.info("Training thread started.")


def run_training_and_update_status(
    on_complete: Optional[Callable[[], None]] = None,
) -> None:
    """Run the full training workflow across all horizons.

    Marks the ``TrainingStatus`` singleton as in-progress, loads all
    observations once, then, for each key of ``HORIZON_SHIFTS``, records the
    current horizon and trains it. After every horizon succeeds, the
    in-progress flag is cleared and ``on_complete`` (if given) is invoked.

    If any step of the workflow raises, the error is logged with its stack
    trace and the in-progress flag is cleared via
    ``clear_training_in_progress(session, error=True)``. ``on_complete`` is
    not called in that case.

    Parameters
    ----------
    on_complete : Callable[[], None] | None, optional
        Callback executed after a successful run, by default ``None``.

    Notes
    -----
    - No exception propagates to the caller. Workflow errors, failures of
      the error-path status reset, and exceptions raised by ``on_complete``
      are all logged instead. Inspect logs and the ``TrainingStatus`` row for
      error details.
    - An exception from ``on_complete`` is logged as a callback failure. It
      does not mark the already-completed training run as failed.
    """
    try:
        with SessionLocal() as session:
            set_training_in_progress(get_or_create_training_status(session), session)

        data_frame = fetch_all_data_points()
        for horizon_key in HORIZON_SHIFTS:
            update_current_horizon(horizon_key)
            train_and_log_for_horizon(data_frame, horizon_key)

        with SessionLocal() as session:
            clear_training_in_progress(session)
    except Exception as error:
        logger.exception("Error during threaded training: %s", error)
        try:
            with SessionLocal() as session:
                clear_training_in_progress(session, error=True)
        except Exception:
            # Without this guard, a DB outage during the reset would escape
            # the worker thread despite the documented "never propagates"
            # contract.
            logger.exception("Failed to reset TrainingStatus after training error.")
        return

    logger.info("Training completed for all horizons.")
    if on_complete is not None:
        # Kept outside the training try-block: a failing callback must not be
        # reported as a failed training run, nor re-clear the status row.
        try:
            on_complete()
        except Exception:
            logger.exception("on_complete callback raised after successful training.")


def get_or_create_training_status(session: Session) -> TrainingStatus:
    """Return the singleton ``TrainingStatus`` row, creating it if missing.

    When no row with ``id=TRAINING_STATUS_ID`` exists, one is inserted with
    ``is_training=False`` and ``train_count=0``, then committed and refreshed.

    Parameters
    ----------
    session : sqlalchemy.orm.Session
        Open SQLAlchemy session bound to the application engine.

    Returns
    -------
    app.models.TrainingStatus
        The existing or newly created singleton row with ``id=1``.

    Raises
    ------
    sqlalchemy.exc.SQLAlchemyError
        If the query or the insert/commit of a missing row fails. The session
        is rolled back before the error is re-raised.
    """
    logger.debug("Fetching or creating TrainingStatus row.")
    try:
        status = session.get(TrainingStatus, TRAINING_STATUS_ID)
        if status is not None:
            logger.debug("Found existing TrainingStatus row.")
            return status

        status = TrainingStatus(id=TRAINING_STATUS_ID, is_training=False, train_count=0)
        session.add(status)
        session.commit()
        session.refresh(status)
        logger.info(f"Created TrainingStatus row with id {TRAINING_STATUS_ID}.")
        return status
    except SQLAlchemyError as exc:
        logger.error(f"Error fetching or creating TrainingStatus: {exc}", exc_info=True)
        # Match the other writers in this module: after a failed commit the
        # session cannot be reused until the transaction is rolled back.
        session.rollback()
        raise


def set_training_in_progress(training_status: TrainingStatus, session: Session) -> None:
    """Flag training as running, bump ``train_count``, and reset the horizon.

    Issues one bulk UPDATE against the row whose ``id`` matches
    ``training_status.id``. It sets ``is_training`` to ``True``, increments
    ``train_count`` in SQL, and blanks ``current_horizon``, then commits.

    Parameters
    ----------
    training_status : app.models.TrainingStatus
        Row whose ``id`` selects the record to update.
    session : sqlalchemy.orm.Session
        Open SQLAlchemy session used for the update and commit.

    Raises
    ------
    sqlalchemy.exc.SQLAlchemyError
        If the update or commit fails; the session is rolled back first.
    """
    logger.info("Setting TrainingStatus to in-progress.")
    try:
        new_values = {
            TrainingStatus.is_training: True,
            # Increment is evaluated by the database, not in Python.
            TrainingStatus.train_count: TrainingStatus.train_count + 1,
            TrainingStatus.current_horizon: "",
        }
        matching_rows = session.query(TrainingStatus).filter(
            TrainingStatus.id == training_status.id
        )
        matching_rows.update(new_values, synchronize_session=False)
        session.commit()
    except SQLAlchemyError:
        logger.error("Error setting TrainingStatus to in-progress.", exc_info=True)
        session.rollback()
        raise
    logger.info("TrainingStatus in-progress flag set and count incremented.")


def fetch_all_data_points() -> pd.DataFrame:
    """Load all observations as a timestamp-sorted DataFrame.

    Reads every ``WeatherObservation`` row through the session's bound engine
    with :func:`pandas.read_sql`, then sorts by the ``timestamp`` column.

    Returns
    -------
    pandas.DataFrame
        All rows from ``WeatherObservation`` sorted ascending by ``timestamp``.
        The original index labels are kept (not reset).

    Raises
    ------
    AssertionError
        If the session is not bound to an engine. It is raised explicitly, so
        the check still runs under ``python -O``.
    sqlalchemy.exc.SQLAlchemyError
        If the query or read fails due to database errors.
    """
    with SessionLocal() as session:
        engine = session.bind
        if engine is None:
            # An ``assert`` statement is stripped under -O. Raise explicitly
            # but keep AssertionError so existing callers that catch it still
            # work.
            raise AssertionError("Session is not bound to an engine.")
        data_frame = pd.read_sql(session.query(WeatherObservation).statement, engine)
    sorted_df = data_frame.sort_values("timestamp")
    logger.info(f"Fetched {len(sorted_df)} data points for training.")
    return sorted_df


def update_current_horizon(horizon: str) -> None:
    """Write the label of the horizon now being trained to ``TrainingStatus``.

    Opens its own short-lived session, sets ``current_horizon`` on the
    singleton row (``id=TRAINING_STATUS_ID``), and commits.

    Parameters
    ----------
    horizon : str
        Horizon label being trained (for example, ``"5min"`` or ``"1h"``).

    Raises
    ------
    sqlalchemy.exc.SQLAlchemyError
        If the update or commit fails; the error is logged and re-raised.
    """
    logger.info(f"Updating TrainingStatus: current_horizon={horizon}")
    try:
        with SessionLocal() as session:
            singleton_row = session.query(TrainingStatus).filter(
                TrainingStatus.id == TRAINING_STATUS_ID
            )
            singleton_row.update(
                {TrainingStatus.current_horizon: horizon},
                synchronize_session=False,
            )
            session.commit()
    except SQLAlchemyError as exc:
        logger.error(
            f"Error updating current_horizon to '{horizon}': {exc}", exc_info=True
        )
        raise
    logger.info(f"TrainingStatus database updated: current_horizon={horizon}")


def _accepts_positional_args(func: Callable[..., object], *args: object) -> bool:
    """Return True if ``func``'s signature can bind ``args`` positionally."""
    try:
        inspect.signature(func).bind(*args)
    except TypeError:
        return False
    except ValueError:
        # No introspectable signature: assume the modern call form works.
        return True
    return True


def train_and_log_for_horizon(data_frame: pd.DataFrame, horizon: str) -> None:
    """Train for one horizon and log summary metrics.

    Obtains the horizon's shift value from
    :func:`app.training_helpers.get_horizon_shift` and runs
    :func:`app.training_helpers.train_models_for_horizon`. For backward
    compatibility with a legacy helper that takes only
    ``(data_frame, horizon)``, the call form is chosen by inspecting the
    helper's signature. A ``TypeError`` raised while training is therefore
    never mistaken for a signature mismatch, and training is never re-run.
    Metrics are normalized via
    :func:`app.training_helpers.unpack_training_result` and logged.

    Parameters
    ----------
    data_frame : pandas.DataFrame
        Full dataset used for training across horizons.
    horizon : str
        The horizon label to train for.

    Raises
    ------
    Exception
        Any exception raised by the underlying training helpers (or while
        formatting the metrics) is logged and re-raised to the caller.
    """
    logger.info(f"Starting training for horizon '{horizon}'.")
    try:
        shift_value = get_horizon_shift(horizon)
        if _accepts_positional_args(
            train_models_for_horizon, data_frame, horizon, shift_value
        ):
            result = train_models_for_horizon(data_frame, horizon, shift_value)
        else:
            # Legacy helper signature: (data_frame, horizon).
            result = train_models_for_horizon(data_frame, horizon)  # type: ignore[call-arg]
        sklearn_score, pytorch_score, data_count = unpack_training_result(result)
        # NOTE(review): ``:.4f`` assumes both scores are numeric; confirm that
        # unpack_training_result never returns None for a score.
        logger.info(
            f"Trained models for horizon {horizon}: "
            f"sklearn_score={sklearn_score:.4f}, "
            f"pytorch_score={pytorch_score:.4f}, data_count={data_count}"
        )
    except Exception as exc:
        logger.error(f"Training failed for horizon {horizon}: {exc}", exc_info=True)
        raise
    else:
        logger.info(f"Completed training and logging for horizon '{horizon}'.")


def clear_training_in_progress(session: Session, error: bool = False) -> None:
    """Clear the in-progress flag and update the completion timestamp.

    On the singleton row (``id=TRAINING_STATUS_ID``), sets ``is_training`` to
    ``False``, blanks ``current_horizon``, and sets ``last_trained_at`` to the
    current UTC time, then commits. The timestamp is written on both the
    success and the error path.

    Parameters
    ----------
    session : sqlalchemy.orm.Session
        Open SQLAlchemy session.
    error : bool, optional
        Whether training ended with an error. It only affects the log level of
        the completion message. By default ``False``.

    Raises
    ------
    sqlalchemy.exc.SQLAlchemyError
        If the update or commit fails; the session is rolled back first.
    """
    logger.info(f"Clearing TrainingStatus in-progress flag (error={error}).")
    # Naive UTC, the same value ``datetime.utcnow()`` produced; ``utcnow`` is
    # deprecated since Python 3.12. The tzinfo is stripped so the stored value
    # keeps its previous (naive) form.
    finished_at = datetime.now(timezone.utc).replace(tzinfo=None)
    try:
        session.query(TrainingStatus).filter(
            TrainingStatus.id == TRAINING_STATUS_ID
        ).update(
            {
                TrainingStatus.is_training: False,
                TrainingStatus.last_trained_at: finished_at,
                TrainingStatus.current_horizon: "",
            },
            synchronize_session=False,
        )
        session.commit()
        if error:
            logger.warning("TrainingStatus cleared after error.")
        else:
            logger.info("TrainingStatus cleared after successful training.")
    except SQLAlchemyError:
        logger.error("Error clearing TrainingStatus in-progress flag.", exc_info=True)
        session.rollback()
        raise