"""
Coordinate manual and background training jobs across horizons.
This module orchestrates the end-to-end training flow invoked either from the
web application or from scheduled background tasks. It sets and clears the
in-progress flag on the ``TrainingStatus`` singleton, fetches raw
observations as a pandas DataFrame,
iterates over configured horizons, and delegates horizon-specific work to the
training helpers. Errors are logged with stack traces and the database state
is kept consistent by clearing the in-progress flag on failure.
See Also
--------
app.ml_train.HORIZON_SHIFTS
app.training_helpers.get_horizon_shift
app.training_helpers.train_models_for_horizon
app.training_helpers.unpack_training_result
app.models.TrainingStatus
app.models.WeatherObservation
app.database.SessionLocal
Notes
-----
- Primary role: orchestrate training across all horizons and maintain the
``TrainingStatus`` state while delegating scoring to the training pipeline.
- Key dependencies: ``app.database.SessionLocal`` for DB access; ORM tables
``app.models.TrainingStatus`` and ``app.models.WeatherObservation``;
horizon mapping from ``app.ml_train.HORIZON_SHIFTS``; helpers in
``app.training_helpers``.
- Invariants: ``TrainingStatus`` is treated as a singleton with ``id=1``; the
database must be reachable.
Examples
--------
>>> from app.training_jobs import launch_training_thread # doctest: +SKIP
>>> launch_training_thread() # doctest: +SKIP
>>> # Or run synchronously (useful for tests/local debugging) # doctest: +SKIP
>>> from app.training_jobs import run_training_and_update_status # doctest: +SKIP
>>> run_training_and_update_status() # doctest: +SKIP
"""
import logging
from datetime import datetime
from threading import Thread
from typing import Callable, Optional
import pandas as pd
from sqlalchemy.exc import SQLAlchemyError
from sqlalchemy.orm import Session
from .database import SessionLocal
from .ml_train import HORIZON_SHIFTS
from .models import TrainingStatus, WeatherObservation
from .training_helpers import (
get_horizon_shift,
train_models_for_horizon,
unpack_training_result,
)
logger = logging.getLogger(__name__)
TRAINING_STATUS_ID: int = 1
def launch_training_thread(on_complete: Optional[Callable[[], None]] = None) -> None:
"""Launch the training workflow in a daemon thread.
The thread runs :func:`run_training_and_update_status`, and this function
returns to the caller immediately after starting it. The optional
``on_complete`` callback is invoked only when training finishes without
raising an exception.
Parameters
----------
on_complete : Callable[[], None] | None, optional
Callback executed after a successful training run, by default ``None``.
Examples
--------
>>> from app.training_jobs import launch_training_thread # doctest: +SKIP
>>> launch_training_thread() # doctest: +SKIP
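>>> # Optionally pass a completion callback (illustrative lambda):
>>> launch_training_thread(on_complete=lambda: print("training done"))  # doctest: +SKIP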
Notes
-----
- The thread is created with ``daemon=True``; it may be terminated early
if the hosting process exits before training completes.
"""
thread = Thread(
target=run_training_and_update_status, args=(on_complete,), daemon=True
)
thread.start()
logger.info("Training thread started.")
def run_training_and_update_status(
on_complete: Optional[Callable[[], None]] = None,
) -> None:
"""Run the full training workflow across all horizons.
The function toggles the ``TrainingStatus`` singleton to in-progress,
fetches all observations, iterates through the configured horizons, and
delegates per-horizon training. On success, it clears the in-progress flag
and optionally invokes ``on_complete``. On any exception, the error is
logged with a stack trace and the in-progress flag is still cleared (the
failure is reflected only in the logs, not persisted on the row).
Parameters
----------
on_complete : Callable[[], None] | None, optional
Callback executed after a successful run, by default ``None``.
Notes
-----
- All exceptions are caught and logged; they do not propagate to the
caller. Inspect logs and the ``TrainingStatus`` row for error details.
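Examples
--------
Run synchronously, e.g. from tests or a shell; the callback below is an
illustrative placeholder and a reachable database is assumed:
>>> from app.training_jobs import run_training_and_update_status  # doctest: +SKIP
>>> run_training_and_update_status(on_complete=lambda: print("done"))  # doctest: +SKIP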
"""
try:
with SessionLocal() as session:
set_training_in_progress(get_or_create_training_status(session), session)
data_frame = fetch_all_data_points()
for horizon_key in HORIZON_SHIFTS:
update_current_horizon(horizon_key)
train_and_log_for_horizon(data_frame, horizon_key)
with SessionLocal() as session:
clear_training_in_progress(session)
logger.info("Training completed for all horizons.")
if on_complete:
on_complete()
except Exception as error:
logger.error(f"Error during threaded training: {error}", exc_info=True)
with SessionLocal() as session:
clear_training_in_progress(session, error=True)
def get_or_create_training_status(session: Session) -> TrainingStatus:
"""Return the singleton ``TrainingStatus`` row, creating it if missing.
Parameters
----------
session : sqlalchemy.orm.Session
Open SQLAlchemy session bound to the application engine.
Returns
-------
app.models.TrainingStatus
The existing or newly created singleton row with ``id=1``.
Raises
------
sqlalchemy.exc.SQLAlchemyError
If the insert/commit of a missing row fails or if the query fails.
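Examples
--------
Fetch (or lazily create) the singleton row; assumes a reachable database:
>>> from app.database import SessionLocal  # doctest: +SKIP
>>> from app.training_jobs import get_or_create_training_status  # doctest: +SKIP
>>> with SessionLocal() as session:  # doctest: +SKIP
...     status = get_or_create_training_status(session)
...     assert status.id == 1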
"""
logger.debug("Fetching or creating TrainingStatus row.")
try:
status = session.get(TrainingStatus, TRAINING_STATUS_ID)
if status is None:
status = TrainingStatus(
id=TRAINING_STATUS_ID, is_training=False, train_count=0
)
session.add(status)
session.commit()
session.refresh(status)
logger.info(f"Created TrainingStatus row with id {TRAINING_STATUS_ID}.")
else:
logger.debug("Found existing TrainingStatus row.")
return status
except SQLAlchemyError as exc:
logger.error(f"Error fetching or creating TrainingStatus: {exc}", exc_info=True)
raise
def set_training_in_progress(training_status: TrainingStatus, session: Session) -> None:
"""Mark training as in-progress and increment the run counter.
Parameters
----------
training_status : app.models.TrainingStatus
The status row to update (its ``id`` is used in the update query).
session : sqlalchemy.orm.Session
Open SQLAlchemy session.
Raises
------
sqlalchemy.exc.SQLAlchemyError
If the update or commit fails.
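Examples
--------
Mark the singleton row as in-progress; assumes a reachable database:
>>> from app.database import SessionLocal  # doctest: +SKIP
>>> with SessionLocal() as session:  # doctest: +SKIP
...     status = get_or_create_training_status(session)
...     set_training_in_progress(status, session)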
"""
logger.info("Setting TrainingStatus to in-progress.")
try:
session.query(TrainingStatus).filter(
TrainingStatus.id == training_status.id
).update(
{
TrainingStatus.is_training: True,
TrainingStatus.train_count: TrainingStatus.train_count + 1,
TrainingStatus.current_horizon: "",
},
synchronize_session=False,
)
session.commit()
logger.info("TrainingStatus in-progress flag set and count incremented.")
except SQLAlchemyError:
logger.error("Error setting TrainingStatus to in-progress.", exc_info=True)
session.rollback()
raise
def fetch_all_data_points() -> pd.DataFrame:
"""Load all observations as a timestamp-sorted DataFrame.
Returns
-------
pandas.DataFrame
All rows from ``WeatherObservation`` sorted ascending by ``timestamp``.
Raises
------
AssertionError
If the session is not bound to an engine.
sqlalchemy.exc.SQLAlchemyError
If the query or read fails due to database errors.
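Examples
--------
Load all observations; assumes the database is reachable and populated:
>>> data = fetch_all_data_points()  # doctest: +SKIP
>>> data["timestamp"].is_monotonic_increasing  # doctest: +SKIP
True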
"""
with SessionLocal() as session:
assert session.bind is not None, "Session is not bound to an engine."
data_frame = pd.read_sql(
session.query(WeatherObservation).statement, session.bind
)
sorted_df = data_frame.sort_values("timestamp")
logger.info(f"Fetched {len(sorted_df)} data points for training.")
return sorted_df
def update_current_horizon(horizon: str) -> None:
"""Persist the current horizon label to ``TrainingStatus``.
Parameters
----------
horizon : str
Horizon label being trained (for example, ``"5min"`` or ``"1h"``).
Raises
------
sqlalchemy.exc.SQLAlchemyError
If the update or commit fails.
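Examples
--------
Record the horizon label currently being trained (label shown is illustrative):
>>> from app.training_jobs import update_current_horizon  # doctest: +SKIP
>>> update_current_horizon("1h")  # doctest: +SKIP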
"""
logger.info(f"Updating TrainingStatus: current_horizon={horizon}")
try:
with SessionLocal() as session:
session.query(TrainingStatus).filter(
TrainingStatus.id == TRAINING_STATUS_ID
).update(
{TrainingStatus.current_horizon: horizon}, synchronize_session=False
)
session.commit()
logger.info(f"TrainingStatus database updated: current_horizon={horizon}")
except SQLAlchemyError as exc:
logger.error(
f"Error updating current_horizon to '{horizon}': {exc}", exc_info=True
)
raise
def train_and_log_for_horizon(data_frame: pd.DataFrame, horizon: str) -> None:
"""Train for one horizon and log summary metrics.
Delegates to the helpers to obtain the shift value and execute training.
For backward compatibility with a legacy helper signature, the call is
retried without the shift argument when a ``TypeError`` is raised. Metrics are
normalized via :func:`app.training_helpers.unpack_training_result` and
logged for observability.
Parameters
----------
data_frame : pandas.DataFrame
Full dataset used for training across horizons.
horizon : str
The horizon label to train for.
Raises
------
Exception
Any exception raised by the underlying training helpers after the
compatibility retry is re-raised to the caller.
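Examples
--------
Train a single horizon end to end; ``"1h"`` is an illustrative label and
must be one of the keys in ``HORIZON_SHIFTS``:
>>> data = fetch_all_data_points()  # doctest: +SKIP
>>> train_and_log_for_horizon(data, "1h")  # doctest: +SKIP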
"""
logger.info(f"Starting training for horizon '{horizon}'.")
try:
shift_value = get_horizon_shift(horizon)
try:
result = train_models_for_horizon(data_frame, horizon, shift_value)
except TypeError:
result = train_models_for_horizon(data_frame, horizon) # type: ignore[call-arg]
sklearn_score, pytorch_score, data_count = unpack_training_result(result)
logger.info(
f"Trained models for horizon {horizon}: "
f"sklearn_score={sklearn_score:.4f}, "
f"pytorch_score={pytorch_score:.4f}, data_count={data_count}"
)
except Exception as exc:
logger.error(f"Training failed for horizon {horizon}: {exc}", exc_info=True)
raise
else:
logger.info(f"Completed training and logging for horizon '{horizon}'.")
def clear_training_in_progress(session: Session, error: bool = False) -> None:
"""Clear the in-progress flag and update completion timestamp.
Parameters
----------
session : sqlalchemy.orm.Session
Open SQLAlchemy session.
error : bool, optional
Whether training ended with an error (affects log level), by default
``False``.
Raises
------
sqlalchemy.exc.SQLAlchemyError
If the update or commit fails.
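Examples
--------
Clear the in-progress flag after a run; assumes a reachable database:
>>> from app.database import SessionLocal  # doctest: +SKIP
>>> with SessionLocal() as session:  # doctest: +SKIP
...     clear_training_in_progress(session)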
"""
logger.info(f"Clearing TrainingStatus in-progress flag (error={error}).")
try:
session.query(TrainingStatus).filter(
TrainingStatus.id == TRAINING_STATUS_ID
).update(
{
TrainingStatus.is_training: False,
TrainingStatus.last_trained_at: datetime.utcnow(),
TrainingStatus.current_horizon: "",
},
synchronize_session=False,
)
session.commit()
if error:
logger.warning("TrainingStatus cleared after error.")
else:
logger.info("TrainingStatus cleared after successful training.")
except SQLAlchemyError:
logger.error("Error clearing TrainingStatus in-progress flag.", exc_info=True)
session.rollback()
raise