# models.py

# src/app/models.py
"""
Declarative SQLAlchemy ORM models used across the application.

This module defines the core database entities for the platform: coordinates
for data collection, time-stamped weather observations, a singleton table for
training status, and an append-only log of ML training runs. These models are
the canonical schema definition for the application and are used by both the
FastAPI backend and the standalone Slurm-executed training jobs for reading
and writing domain data.

See Also
--------
app.database : Engine/session factories and declarative ``Base``.
app.schemas : Pydantic schemas mirroring these ORM models.
app.ml_utils : Read-only helpers that query ``TrainingLog``.
app.ml_train : Slurm-run training job that writes ``TrainingLog`` entries.
app.coordinates_manager : Utilities for seeding coordinate grids.
app.imputation : Imputation routines operating on ``WeatherObservation``.

Notes
-----
- Primary role: provide ORM mappings for coordinates, observations,
  training status, and training logs bound to :class:`app.database.Base`.
- Key dependencies: a configured SQLAlchemy engine via
  :data:`app.database.engine` and corresponding session factories. The schema
  is created via :func:`app.database.ensure_database_schema`.
- Invariants: ``WeatherObservation`` uses a composite primary key
  ``(timestamp, latitude, longitude)``. ``TrainingStatus`` is treated as a
  singleton with primary key ``id=1``. ``TrainingLog.horizon`` is a
  non-null string key identifying coordinate+horizon groupings.

Examples
--------
>>> # Basic query pattern (requires an initialized DB)         # doctest: +SKIP
>>> from app.database import SessionLocal, ensure_database_schema
>>> from app.models import Coordinate, TrainingLog
>>> ensure_database_schema()                                   # doctest: +SKIP
>>> with SessionLocal() as session:                            # doctest: +SKIP
...     count = session.query(Coordinate).count()              # doctest: +SKIP
...     latest = (session.query(TrainingLog)
...               .order_by(TrainingLog.timestamp.desc())
...               .first())                                    # doctest: +SKIP
"""


import logging
import uuid
from datetime import datetime
from typing import Optional

from sqlalchemy import Boolean, DateTime, Float, Integer, String
from sqlalchemy.orm import Mapped, mapped_column

from .database import Base

# Module-level logger named after this module. None of the model classes
# visible in this file reference it.
logger = logging.getLogger(__name__)


class TrainingStatus(Base):
    """Current state of ML training, stored in the ``training_status`` table.

    The table is meant to hold a single row describing whether training is
    in progress, when it last finished, how many runs have completed, and
    which horizon is being worked on. Nothing in this mapping enforces the
    single-row rule; the ``id`` default of ``1`` only supports the convention,
    and callers are expected to read and update that one row.

    Attributes
    ----------
    id : int
        Primary key; defaults to ``1`` so the singleton row has a fixed key.
    is_training : bool
        ``True`` while a training job is running. Defaults to ``False``.
    last_trained_at : datetime | None
        When the most recent training finished (stored as a naive datetime,
        UTC by convention), or ``None`` if no run has completed yet.
    train_count : int
        Number of completed training runs. Defaults to ``0``.
    current_horizon : str | None
        Label of the horizon being processed (e.g., ``"5min"``) or a
        free-form status message.
    """

    __tablename__ = "training_status"

    # Fixed default key keeps the conventional singleton row at id=1.
    id: Mapped[int] = mapped_column(Integer, default=1, primary_key=True)
    is_training: Mapped[bool] = mapped_column(
        Boolean,
        nullable=False,
        default=False,
    )
    last_trained_at: Mapped[Optional[datetime]] = mapped_column(
        DateTime,
        nullable=True,
        default=None,
    )
    train_count: Mapped[int] = mapped_column(
        Integer,
        nullable=False,
        default=0,
    )
    current_horizon: Mapped[Optional[str]] = mapped_column(
        String,
        nullable=True,
    )


class Coordinate(Base):
    """A latitude/longitude point at which weather data is collected.

    Rows live in the ``coordinates`` table and are identified by a surrogate
    integer key rather than by their position.

    Attributes
    ----------
    id : int
        Surrogate primary key (also explicitly indexed).
    latitude : float
        Latitude in decimal degrees. Required.
    longitude : float
        Longitude in decimal degrees. Required.
    label : str | None
        Optional human-readable name for the point.
    is_central : bool
        ``True`` for the coordinate used as the central reference point.
        Defaults to ``False``.

    Notes
    -----
    - This mapping declares no unique constraint on
      ``(latitude, longitude)``, so the same position can appear in several
      rows unless the database or the application prevents it.
    """

    __tablename__ = "coordinates"

    id: Mapped[int] = mapped_column(Integer, index=True, primary_key=True)

    # Position of the point; both components are mandatory.
    latitude: Mapped[float] = mapped_column(Float, nullable=False)
    longitude: Mapped[float] = mapped_column(Float, nullable=False)

    # Optional descriptive metadata.
    label: Mapped[Optional[str]] = mapped_column(String, nullable=True)
    is_central: Mapped[bool] = mapped_column(
        Boolean,
        nullable=False,
        default=False,
    )


def _new_training_log_id() -> str:
    """Return a fresh random ``uuid4`` string for a ``TrainingLog`` row."""
    return str(uuid.uuid4())


class TrainingLog(Base):
    """Append-only log of ML training runs and scores.

    Each row represents one training execution for a given horizon and
    (optionally) a specific coordinate. Scores from both a Scikit-learn model
    and a PyTorch model are recorded, along with the number of data points
    used.

    Attributes
    ----------
    id : str
        Primary key string. When the caller does not supply one, a new
        ``uuid4`` string is generated for each inserted row. Callers may
        still assign an id explicitly.
    timestamp : datetime
        Completion time of the training run. Defaults to the current time
        as a naive UTC datetime at insert.
    horizon : str
        Non-null key identifying the grouping (often
        ``"<coord>_<horizon_label>"``).
    sklearn_score : float
        R^2 score from the Scikit-learn model.
    pytorch_score : float
        R^2 score from the PyTorch model.
    data_count : int
        Number of samples used for training/evaluation for this run.
    coord_latitude : float | None
        Coordinate latitude associated with the run, if available.
    coord_longitude : float | None
        Coordinate longitude associated with the run, if available.
    horizon_label : str | None
        Human-friendly label for the horizon (e.g., ``"5min"``, ``"1h"``).
    """

    __tablename__ = "training_logs"

    # The default must be a callable. Passing ``str(uuid.uuid4())`` directly
    # would evaluate once at import time and hand the same primary key to
    # every row inserted without an explicit id, causing collisions.
    id: Mapped[str] = mapped_column(
        String, primary_key=True, default=_new_training_log_id, index=True
    )
    # NOTE: ``datetime.utcnow`` returns a naive datetime and is deprecated
    # since Python 3.12. It is kept so stored values stay naive UTC.
    timestamp: Mapped[datetime] = mapped_column(
        DateTime, default=datetime.utcnow, nullable=False
    )
    horizon: Mapped[str] = mapped_column(String, nullable=False)
    sklearn_score: Mapped[float] = mapped_column(Float, nullable=False)
    pytorch_score: Mapped[float] = mapped_column(Float, nullable=False)
    data_count: Mapped[int] = mapped_column(Integer, nullable=False)
    coord_latitude: Mapped[Optional[float]] = mapped_column(Float, nullable=True)
    coord_longitude: Mapped[Optional[float]] = mapped_column(Float, nullable=True)
    horizon_label: Mapped[Optional[str]] = mapped_column(String, nullable=True)


class WeatherObservation(Base):
    """One set of nowcast weather values at a given time and place.

    Rows in ``weather_observations`` follow the MET Norway Nowcast 2.0 fields
    the application consumes. There is no surrogate key: an observation is
    identified by the composite primary key
    ``(timestamp, latitude, longitude)``.

    Attributes
    ----------
    timestamp : datetime
        Time of the observation, UTC by convention (primary-key component).
    latitude : float
        Latitude in decimal degrees (primary-key component).
    longitude : float
        Longitude in decimal degrees (primary-key component).
    air_temperature : float | None
        Air temperature, degrees Celsius.
    wind_speed : float | None
        Wind speed, meters per second.
    wind_direction : float | None
        Direction the wind is coming from, in degrees.
    cloud_area_fraction : float | None
        Share of the sky covered by clouds (0–1).
    precipitation_amount : float | None
        Precipitation over the interval, millimeters.
    is_imputed : bool
        ``True`` when the row was filled in by preprocessing/imputation
        rather than observed. Defaults to ``False``.
    """

    __tablename__ = "weather_observations"

    # Composite primary key: when and where the observation applies.
    timestamp: Mapped[datetime] = mapped_column(
        DateTime,
        nullable=False,
        primary_key=True,
    )
    latitude: Mapped[float] = mapped_column(
        Float,
        nullable=False,
        primary_key=True,
    )
    longitude: Mapped[float] = mapped_column(
        Float,
        nullable=False,
        primary_key=True,
    )

    # Measured values; each may be missing for a given observation.
    air_temperature: Mapped[Optional[float]] = mapped_column(
        Float, nullable=True
    )
    wind_speed: Mapped[Optional[float]] = mapped_column(Float, nullable=True)
    wind_direction: Mapped[Optional[float]] = mapped_column(
        Float, nullable=True
    )
    cloud_area_fraction: Mapped[Optional[float]] = mapped_column(
        Float, nullable=True
    )
    precipitation_amount: Mapped[Optional[float]] = mapped_column(
        Float, nullable=True
    )

    # Provenance flag set by the imputation routines.
    is_imputed: Mapped[bool] = mapped_column(
        Boolean,
        nullable=False,
        default=False,
    )