# schemas.py

"""
Pydantic models for request/response payloads and internal validation.

This module defines the data transfer objects (DTOs) used by the FastAPI
application and background jobs to validate inputs and serialize outputs.
Schemas mirror the structure of the SQLAlchemy ORM models where appropriate
and are designed to be safe for JSON encoding in API responses and logs.

See Also
--------
app.models : SQLAlchemy models mirrored by these schemas.
app.main : FastAPI endpoints that emit and consume these schemas.
app.database : Session/engine configuration used to produce ORM objects.

Notes
-----
- Primary role: provide Pydantic v2 models for API endpoints and service
  boundaries, enabling strict, well-typed validation and serialization.
- Key dependencies: relies on Pydantic v2. The ``from_attributes=True``
  configuration allows converting ORM objects (e.g., ``app.models``) into
  schemas via ``model_validate``.
- Invariants: field names intentionally match ORM attributes to simplify
  conversions. Timestamps are naive ``datetime`` objects assumed to be in UTC.

Examples
--------
>>> from datetime import datetime
>>> from app.schemas import CoordinateSchema, WeatherObservationSchema
>>> center = CoordinateSchema(latitude=59.33, longitude=18.06, label="STHLM")
>>> center.model_dump()["label"]
'STHLM'
>>> obs = WeatherObservationSchema(
...     timestamp=datetime(2024, 1, 1, 12, 0, 0),
...     latitude=59.33,
...     longitude=18.06,
...     air_temperature=2.3,
...     is_imputed=False,
... )
>>> isinstance(obs.timestamp, datetime)
True
"""

import logging
from datetime import datetime
from typing import Any, Dict, List, Optional

from pydantic import BaseModel, ConfigDict, Field

logger = logging.getLogger(__name__)


class CoordinateSchema(BaseModel):
    """Geographic point, in decimal degrees, used for collection and grouping.

    Mirrors :class:`app.models.Coordinate`. Instances can be built from
    keyword arguments or from an ORM object via ``model_validate`` (enabled
    by ``from_attributes=True``). Returned by the API for map/grid
    operations and stored as part of training metadata.

    Parameters
    ----------
    id : int, optional
        Surrogate primary key. Defaults to ``None`` and may be omitted in
        requests.
    latitude : float
        Latitude in decimal degrees (required).
    longitude : float
        Longitude in decimal degrees (required).
    label : str, optional
        Human-readable name for the point, e.g. for UI or logs. Defaults to
        ``None``.
    is_central : bool, optional
        ``True`` when this is the designated central reference coordinate.
        Defaults to ``None``.

    Attributes
    ----------
    Same as Parameters.

    Notes
    -----
    No range validation is applied to ``latitude``/``longitude``; any float
    value is accepted.

    Examples
    --------
    >>> CoordinateSchema(latitude=59.33, longitude=18.06).model_dump()["latitude"]
    59.33

    See Also
    --------
    app.models.Coordinate : ORM model providing the database mapping.
    """

    # Allow construction from ORM objects via attribute access.
    model_config = ConfigDict(from_attributes=True)

    id: Optional[int] = Field(default=None, description="Primary key")
    # Required: no default is given, so Pydantic demands a value.
    latitude: float = Field(description="Latitude of the point")
    longitude: float = Field(description="Longitude of the point")
    label: Optional[str] = Field(
        default=None,
        description="Optional label for the point",
    )
    is_central: Optional[bool] = Field(
        default=None,
        description="True if this is the central coordinate",
    )


class WeatherObservationSchema(BaseModel):
    """Weather values measured at one location at one point in time.

    Mirrors :class:`app.models.WeatherObservation`. An observation is
    identified by the triple ``(timestamp, latitude, longitude)``. The
    measurement fields are all optional because raw feeds may omit them;
    they may be filled in later by an imputation step, which is flagged by
    ``is_imputed``.

    Parameters
    ----------
    timestamp : datetime
        When the observation applies. Naive values are assumed to be UTC.
    latitude : float
        Latitude in decimal degrees.
    longitude : float
        Longitude in decimal degrees.
    air_temperature : float, optional
        Air temperature, documented as degrees Celsius.
    wind_speed : float, optional
        Wind speed, documented as meters per second.
    wind_direction : float, optional
        Wind direction in degrees, documented as meteorological convention.
    cloud_area_fraction : float, optional
        Sky cloud cover, documented as a 0–1 fraction. NOTE(review): no range
        is enforced here and some upstream feeds report percent; confirm
        units against the ingest code.
    precipitation_amount : float, optional
        Precipitation over the interval, documented as millimeters.
    is_imputed : bool
        ``True`` if the record was produced by imputation (required).

    Attributes
    ----------
    Same as Parameters.

    Examples
    --------
    >>> WeatherObservationSchema(
    ...     timestamp=datetime(2024, 1, 1, 0, 0, 0),
    ...     latitude=59.33,
    ...     longitude=18.06,
    ...     is_imputed=False,
    ... ).model_dump()["is_imputed"]
    False

    See Also
    --------
    app.models.WeatherObservation : ORM model with the canonical schema.
    """

    # Allow construction from ORM objects via attribute access.
    model_config = ConfigDict(from_attributes=True)

    # Identity of the observation.
    timestamp: datetime
    latitude: float
    longitude: float

    # Measurements; any of them may be absent in raw data.
    air_temperature: Optional[float] = Field(default=None)
    wind_speed: Optional[float] = Field(default=None)
    wind_direction: Optional[float] = Field(default=None)
    cloud_area_fraction: Optional[float] = Field(default=None)
    precipitation_amount: Optional[float] = Field(default=None)

    # Provenance flag; required even though it follows defaulted fields.
    is_imputed: bool


class TrainingStatusSchema(BaseModel):
    """Snapshot of the current ML training state.

    Represents the logical record the application uses to report whether a
    training job is running, along with metadata about completed runs.

    Parameters
    ----------
    id : int
        Primary key of the status record (presumably ``1`` for a singleton
        row — confirm against ``app.models.TrainingStatus``).
    is_training : bool
        Whether a training job is currently running.
    last_trained_at : datetime, optional
        Completion time of the most recent training job (naive, assumed
        UTC). Defaults to ``None``, e.g. when no run has completed yet.
    train_count : int
        Number of completed training runs (expected to only increase, but
        this is not enforced here).
    current_horizon : str, optional
        Human-readable horizon label (e.g., ``"5min"``) or status text.
        Defaults to ``None``.

    Attributes
    ----------
    Same as Parameters.

    Notes
    -----
    In Pydantic v2, an ``Optional[...]`` annotation without a default still
    declares a *required* field that merely accepts ``None``. Both optional
    fields therefore carry an explicit ``None`` default so they can be
    omitted, as documented above and shown in the example.

    Examples
    --------
    >>> TrainingStatusSchema(id=1, is_training=False, train_count=3).train_count
    3

    See Also
    --------
    app.models.TrainingStatus : ORM model used by the backend.
    """

    id: int
    is_training: bool
    # Explicit ``None`` default: without it Pydantic v2 would require this
    # field, and the constructor call in the example would fail validation.
    last_trained_at: Optional[datetime] = None
    train_count: int
    current_horizon: Optional[str] = None
    # Allow construction from ORM objects via attribute access.
    model_config = ConfigDict(from_attributes=True)


class TrainingLogSchema(BaseModel):
    """One append-only training-log entry holding model scores and metadata.

    Mirrors :class:`app.models.TrainingLog`. Each entry records the scores of
    both models for a grouping key and, optionally, the coordinate and
    horizon label that key was built from.

    Parameters
    ----------
    id : str
        Identifier of the run, held as a string (intended to be a UUID; the
        format is not validated here).
    timestamp : datetime
        Completion time of the training run (naive, assumed UTC).
    horizon : str
        Combined grouping key, typically ``"<lat>_<lon>_<horizon_label>"``.
        Emptiness is not validated here.
    sklearn_score : float
        Score of the scikit-learn model (documented as R^2).
    pytorch_score : float
        Score of the PyTorch model (documented as R^2).
    data_count : int
        Number of data points used for the run.
    coord_latitude : float, optional
        Latitude of the associated coordinate. Defaults to ``None``.
    coord_longitude : float, optional
        Longitude of the associated coordinate. Defaults to ``None``.
    horizon_label : str, optional
        Bare horizon label such as ``"5min"``. Defaults to ``None``.

    Attributes
    ----------
    Same as Parameters.

    Examples
    --------
    >>> TrainingLogSchema(
    ...     id="00000000-0000-0000-0000-000000000000",
    ...     timestamp=datetime(2024, 1, 1, 12, 0, 0),
    ...     horizon="59.33_18.06_5min",
    ...     sklearn_score=0.92,
    ...     pytorch_score=0.93,
    ...     data_count=1000,
    ... ).model_dump()["horizon"]
    '59.33_18.06_5min'

    See Also
    --------
    app.models.TrainingLog : ORM model persisted by the training jobs.
    """

    # Allow construction from ORM objects via attribute access.
    model_config = ConfigDict(from_attributes=True)

    # String identifier so UUID values can be carried.
    id: str
    timestamp: datetime
    # Combined key, e.g. "latX_lonY_5min".
    horizon: str
    sklearn_score: float
    pytorch_score: float
    data_count: int
    coord_latitude: Optional[float] = Field(default=None)
    coord_longitude: Optional[float] = Field(default=None)
    # Bare horizon label, e.g. "5min".
    horizon_label: Optional[str] = Field(default=None)


class CoordinateListResponse(BaseModel):
    """API envelope wrapping a list of :class:`CoordinateSchema` items.

    Parameters
    ----------
    coordinates : list[CoordinateSchema]
        Coordinates returned to the client (required; may be empty).

    Attributes
    ----------
    Same as Parameters.

    Examples
    --------
    >>> resp = CoordinateListResponse(
    ...     coordinates=[CoordinateSchema(latitude=0, longitude=0)]
    ... )
    >>> len(resp.coordinates)
    1
    """

    # ``...`` marks the field as required.
    coordinates: List[CoordinateSchema] = Field(...)


class PredictionDataResponse(BaseModel):
    """Historical model-performance data, grouped by horizon key.

    Parameters
    ----------
    history : dict[str, dict[str, Any]]
        Maps each horizon key (e.g., ``"latX_lonY_5min"``) to a free-form
        dictionary of values intended for charting or tabular display. Inner
        values are not validated.

    Attributes
    ----------
    Same as Parameters.

    Examples
    --------
    >>> payload = PredictionDataResponse(history={"h1": {"r2": 0.9}})
    >>> float(payload.history["h1"]["r2"]) == 0.9
    True
    """

    # Outer key is the horizon log name, e.g. "latX_lonY_5min"; required.
    history: Dict[str, Dict[str, Any]] = Field(...)


class GenericStatusResponse(BaseModel):
    """Minimal status envelope with an optional explanatory message.

    Parameters
    ----------
    status : str
        Short status value for programmatic checks (e.g., ``"ok"`` or
        ``"error"``). Any string is accepted.
    message : str, optional
        Free-text context for humans. Defaults to ``None``.

    Attributes
    ----------
    Same as Parameters.

    Examples
    --------
    >>> GenericStatusResponse(status="ok").model_dump()["status"]
    'ok'
    """

    # Required: ``...`` means no default.
    status: str = Field(...)
    message: Optional[str] = Field(default=None)