# src/app/ml_train.py"""Standalone ML training job executed inside the Slurm cluster.This module encapsulates the horizon-wise training workflow used by theplatform's MLOps pipeline. It loads non-imputed observations for selectedcoordinates, prepares future targets per forecasting horizon, trains abaseline Scikit-learn model and a simple PyTorch regressor, persists modelartifacts under the shared ``/data`` volume, and records scores in thedatabase for monitoring and visualization.See Also--------app.slurm_job_trigger.create_and_dispatch_training_jobapp.slurm_job_trigger.trigger_slurm_jobapp.ml_utils.get_latest_training_logsapp.ml_utils.get_historical_scoresapp.models.TrainingLogapp.models.TrainingStatusNotes------ Primary role: execute training for configured horizons and persist both models and TrainingLog entries, updating the TrainingStatus singleton throughout execution.- Key dependencies: a reachable database via ``DATABASE_URL``, a writable shared volume at ``/data``, and environment variables controlling training.- Invariants: ``DATABASE_URL`` must be set; the shared ``/data`` volume must exist and be writable; minimum data thresholds determine training viability.Examples-------->>> # Executed by Slurm via the dispatcher # doctest: +SKIP>>> # $ python3 /data/app_code_for_slurm/ml_train.py # doctest: +SKIP>>> from app.ml_train import main # doctest: +SKIP>>> main() # doctest: +SKIP"""importloggingimportosimportuuidfromdatetimeimportdatetimefrompathlibimportPathfromtypingimportTuple,castimportjoblibimportnumpyasnpimportpandasaspdimporttorchfromnumpy.typingimportNDArrayfromsklearn.linear_modelimportRidgefromsklearn.metricsimportr2_scorefromsklearn.model_selectionimporttrain_test_splitfromsqlalchemyimportBoolean,DateTime,Float,Integer,String,create_enginefromsqlalchemy.excimportSQLAlchemyErrorfromsqlalchemy.ormimport(DeclarativeBase,Mapped,SessionasSQLAlchemySession,mapped_column,sessionmaker,)fromtorchimportnnfromtorch.utils.dataimportDataLoader,TensorDataset# --- Configuration and Constants ---LOG_FORMAT="%(asctime)s - %(name)s - %(levelname)s - %(message)s"logging.basicConfig(level=logging.INFO,format=LOG_FORMAT)logger=logging.getLogger(__name__)DATA_DIRECTORY=Path(os.getenv("SLURM_JOB_DATA_PATH","/data"))HORIZON_SHIFTS={"5min":1,"1h":12,"12h":144,"24h":288}FEATURE_COLUMNS=["air_temperature","wind_speed","wind_direction","precipitation_amount",]TARGET_COLUMN="air_temperature"DEFAULT_NUM_EPOCHS=int(os.getenv("ML_NUM_EPOCHS","10"))DEFAULT_LEARNING_RATE=float(os.getenv("ML_LEARNING_RATE","1e-3"))DEFAULT_BATCH_SIZE=int(os.getenv("ML_BATCH_SIZE","32"))DEFAULT_TEST_SIZE=float(os.getenv("ML_TEST_SIZE","0.2"))MIN_DATA_POINTS_FOR_TRAINING=int(os.getenv("ML_MIN_DATA_POINTS","50"))# --- ORM Base and Models ---classBase(DeclarativeBase):"""Base class for ORM models in this standalone job. Notes ----- The standalone training job defines a minimal subset of ORM models to avoid importing the main application's metadata. Tables mirror the fields in ``app.models`` sufficiently for writing logs and status. """passclassWeatherObservation(Base):"""Observational weather record (non-imputed preferred for training). Attributes ---------- timestamp : datetime Unique timestamp for the observation (UTC). latitude : float Coordinate latitude in decimal degrees. longitude : float Coordinate longitude in decimal degrees. air_temperature : float | None Air temperature in degrees Celsius. wind_speed : float | None Wind speed in m/s. wind_direction : float | None Wind direction in degrees. cloud_area_fraction : float | None Cloud cover fraction (0–1) when available. precipitation_amount : float | None Precipitation amount in mm for the interval. is_imputed : bool Whether the record was imputed (training uses non-imputed). """__tablename__="weather_observations"timestamp:Mapped[datetime]=mapped_column(DateTime,primary_key=True)latitude:Mapped[float]=mapped_column(Float,primary_key=True)longitude:Mapped[float]=mapped_column(Float,primary_key=True)air_temperature:Mapped[float|None]=mapped_column(Float,nullable=True)wind_speed:Mapped[float|None]=mapped_column(Float,nullable=True)wind_direction:Mapped[float|None]=mapped_column(Float,nullable=True)cloud_area_fraction:Mapped[float|None]=mapped_column(Float,nullable=True)precipitation_amount:Mapped[float|None]=mapped_column(Float,nullable=True)is_imputed:Mapped[bool]=mapped_column(Boolean,default=False,nullable=False)classTrainingLog(Base):"""Log entry for one training run and horizon. Attributes ---------- id : str UUID identifier of the run. timestamp : datetime When the run completed (UTC). horizon : str Unique horizon key (often ``"<coord>_<horizon>"``). sklearn_score : float R^2 score for the Sklearn model on the test split. pytorch_score : float R^2 score for the PyTorch model on the test split. data_count : int Number of samples used for this horizon after preprocessing. coord_latitude : float | None Latitude of the coordinate or ``None`` for aggregate runs. coord_longitude : float | None Longitude of the coordinate or ``None`` for aggregate runs. horizon_label : str | None One of ``{"5min", "1h", "12h", "24h"}``. """__tablename__="training_logs"id:Mapped[str]=mapped_column(String,primary_key=True,default=lambda:str(uuid.uuid4()),index=True)timestamp:Mapped[datetime]=mapped_column(DateTime,default=datetime.utcnow)horizon:Mapped[str]=mapped_column(String)sklearn_score:Mapped[float]=mapped_column(Float)pytorch_score:Mapped[float]=mapped_column(Float)data_count:Mapped[int]=mapped_column(Integer)coord_latitude:Mapped[float|None]=mapped_column(Float,nullable=True)coord_longitude:Mapped[float|None]=mapped_column(Float,nullable=True)horizon_label:Mapped[str|None]=mapped_column(String,nullable=True)classTrainingStatus(Base):"""Singleton table reflecting the current training state. Attributes ---------- id : int Primary key (always 1 in this job flow). is_training : bool Whether a training job is running. last_trained_at : datetime | None Timestamp of the last successful training completion. train_count : int Number of training runs since system start or initialization. current_horizon : str | None Human-readable status message or horizon marker. """__tablename__="training_status"id:Mapped[int]=mapped_column(Integer,primary_key=True,default=1)is_training:Mapped[bool]=mapped_column(Boolean,default=False)last_trained_at:Mapped[datetime|None]=mapped_column(DateTime,nullable=True)train_count:Mapped[int]=mapped_column(Integer,default=0)current_horizon:Mapped[str|None]=mapped_column(String,nullable=True)classSimpleRegressionNet(nn.Module):"""Simple feed-forward regression network for 1-step regression. Parameters ---------- input_dim : int Number of input features (must be positive). Examples -------- >>> import torch # doctest: +SKIP >>> from app.ml_train import SimpleRegressionNet # doctest: +SKIP >>> net = SimpleRegressionNet(input_dim=4) # doctest: +SKIP >>> x = torch.randn(2, 4) # doctest: +SKIP >>> y = net(x) # doctest: +SKIP >>> y.shape == (2, 1) # doctest: +SKIP True """def__init__(self,input_dim:int)->None:super().__init__()assertinput_dim>0,f"input_dim must be positive, got {input_dim}"self.net=nn.Sequential(nn.Linear(input_dim,64),nn.ReLU(),nn.Linear(64,1),)defforward(self,x:torch.Tensor)->torch.Tensor:"""Compute predictions for input features. Parameters ---------- x : torch.Tensor Input batch with shape ``(batch_size, input_dim)``. Returns ------- torch.Tensor Predicted target values with shape ``(batch_size, 1)``. """returncast(torch.Tensor,self.net(x))# --- Database Session ---defget_db_session()->SQLAlchemySession:"""Create and return a new SQLAlchemy session. Returns ------- sqlalchemy.orm.Session Session bound to the engine specified by ``DATABASE_URL``. Raises ------ ValueError If ``DATABASE_URL`` is not defined in the environment. sqlalchemy.exc.SQLAlchemyError If the engine cannot be created or the metadata initialization fails. Examples -------- >>> import os # doctest: +SKIP >>> os.environ["DATABASE_URL"] = "sqlite:///:memory:" # doctest: +SKIP >>> from app.ml_train import get_db_session # doctest: +SKIP >>> s = get_db_session() # doctest: +SKIP >>> s.close() # doctest: +SKIP """database_url=os.getenv("DATABASE_URL")ifnotdatabase_url:logger.error("DATABASE_URL environment variable not set.")raiseValueError("DATABASE_URL is required")connect_args={}ifdatabase_url.startswith("sqlite"):connect_args["check_same_thread"]=Falseengine=create_engine(database_url,connect_args=connect_args)Base.metadata.create_all(engine)SessionLocal=sessionmaker(bind=engine)returnSessionLocal()# --- Core Functions ---defload_training_data(session:SQLAlchemySession,latitude:float,longitude:float)->pd.DataFrame:"""Load non-imputed observations for one coordinate. Parameters ---------- session : sqlalchemy.orm.Session Open SQLAlchemy session. latitude : float Coordinate latitude in decimal degrees. longitude : float Coordinate longitude in decimal degrees. Returns ------- pandas.DataFrame DataFrame with a ``timestamp`` column parsed to ``datetime64[ns]``. Empty DataFrame if no non-imputed rows are available. Examples -------- >>> # Requires seeded DB with weather_observations # doctest: +SKIP >>> from app.ml_train import load_training_data # doctest: +SKIP >>> df = load_training_data(session, 57.70, 11.90) # doctest: +SKIP >>> isinstance(df.empty, bool) # doctest: +SKIP True """logger.info(f"Loading non-imputed data for ({latitude}, {longitude})")try:query=session.query(WeatherObservation).filter_by(latitude=latitude,longitude=longitude,is_imputed=False)assertsession.bind,"Session must be bound to an engine"df=pd.read_sql(query.statement,session.bind)ifdf.empty:logger.warning(f"No actual data for ({latitude}, {longitude})")returnpd.DataFrame()df["timestamp"]=pd.to_datetime(df["timestamp"])logger.info(f"Loaded {len(df)} rows for ({latitude}, {longitude})")returndfexceptSQLAlchemyErrorase:logger.error(f"Error loading data for ({latitude}, {longitude}): {e}",exc_info=True)returnpd.DataFrame()defprepare_horizon_data(df:pd.DataFrame,horizon_label:str,shift_steps:int)->Tuple[NDArray[np.float64],NDArray[np.float64],int]:"""Prepare features and targets for a specific horizon. Parameters ---------- df : pandas.DataFrame Input data sorted/filtered per coordinate. horizon_label : str Label for the horizon (e.g., ``"5min"``, ``"1h"``). shift_steps : int Positive number of steps the target is shifted into the future. Returns ------- tuple[numpy.ndarray, numpy.ndarray, int] Tuple ``(X, y, count)`` where count is the number of samples retained after dropping NA rows for the chosen horizon. Examples -------- >>> import pandas as pd # doctest: +SKIP >>> from app.ml_train import prepare_horizon_data # doctest: +SKIP >>> ts = pd.date_range("2024-01-01", periods=12, freq="5min") # doctest: +SKIP >>> df = pd.DataFrame({ # doctest: +SKIP ... "timestamp": ts, ... "air_temperature": list(range(12)), ... "wind_speed": [1.0]*12, ... "wind_direction": [0.0]*12, ... "precipitation_amount": [0.0]*12, ... }) >>> X, y, count = prepare_horizon_data(df, "5min", 1) # doctest: +SKIP >>> count >= 10 # doctest: +SKIP True """ifdf.emptyorlen(df)<shift_steps+1:logger.debug(f"Insufficient rows for horizon '{horizon_label}'")return(np.empty((0,len(FEATURE_COLUMNS)),dtype=np.float64),np.empty(0,dtype=np.float64),0,)df_sorted=df.sort_values("timestamp").copy()future_col=f"future_{TARGET_COLUMN}_{horizon_label}"df_sorted[future_col]=df_sorted[TARGET_COLUMN].shift(-shift_steps)df_sorted.dropna(subset=FEATURE_COLUMNS+[future_col],inplace=True)count=len(df_sorted)threshold=(max(5,int(MIN_DATA_POINTS_FOR_TRAINING*0.2))ifhorizon_label=="5min"elseMIN_DATA_POINTS_FOR_TRAINING)ifcount<threshold:logger.warning(f"Not enough data after prep for '{horizon_label}' ({count} < {threshold})")return(np.empty((0,len(FEATURE_COLUMNS)),dtype=np.float64),np.empty(0,dtype=np.float64),0,)X=df_sorted[FEATURE_COLUMNS].to_numpy(dtype=np.float64)y=df_sorted[future_col].to_numpy(dtype=np.float64)returnX,y,countdeftrain_and_save_model(X_train:NDArray[np.float64],y_train:NDArray[np.float64],X_test:NDArray[np.float64],y_test:NDArray[np.float64],horizon:str,coord_str:str,)->Tuple[float,float]:"""Train Ridge and PyTorch models, persist them, and return R² scores. Parameters ---------- X_train, y_train, X_test, y_test : numpy.ndarray Training and test splits. horizon : str Horizon label used for filenames and logging. coord_str : str Coordinate identifier (e.g., ``"lat57_7000_lon11_9000"``). Returns ------- tuple[float, float] ``(sklearn_r2, pytorch_r2)`` scores on the test split. Notes ----- - Models are saved under ``/data/models/<coord_str>/`` with deterministic filenames per horizon and framework. Raises ------ OSError If persisting the model artifacts to disk fails. Examples -------- >>> import numpy as np # doctest: +SKIP >>> from app.ml_train import train_and_save_model # doctest: +SKIP >>> X = np.random.rand(100, 4); y = np.random.rand(100) # doctest: +SKIP >>> s = int(len(X)*0.8) # doctest: +SKIP >>> train_and_save_model(X[:s], y[:s], X[s:], y[s:], # doctest: +SKIP ... "5min", "lat57_7000_lon11_9000") # doctest: +SKIP (..., ...) """model_dir=DATA_DIRECTORY/"models"/coord_strmodel_dir.mkdir(parents=True,exist_ok=True)sklearn_path=model_dir/f"sklearn_model_{horizon}.pkl"pytorch_path=model_dir/f"pytorch_model_{horizon}.pt"# Sklearn Ridgesklearn_model=Ridge()sklearn_model.fit(X_train,y_train)sklearn_score=sklearn_model.score(X_test,y_test)iflen(X_test)else0.0joblib.dump(sklearn_model,sklearn_path)logger.info(f"Saved Sklearn model at {sklearn_path} (R²={sklearn_score:.4f})")# PyTorchXt=torch.from_numpy(X_train.astype(np.float32))yt=torch.from_numpy(y_train.astype(np.float32)).view(-1,1)dataset=TensorDataset(Xt,yt)loader=DataLoader(dataset,batch_size=DEFAULT_BATCH_SIZE,shuffle=True)net=SimpleRegressionNet(X_train.shape[1])criterion=nn.MSELoss()optimizer=torch.optim.Adam(net.parameters(),lr=DEFAULT_LEARNING_RATE)net.train()forepochinrange(DEFAULT_NUM_EPOCHS):total_loss=0.0forbatch_X,batch_yinloader:optimizer.zero_grad()preds=net(batch_X)loss=criterion(preds,batch_y)loss.backward()optimizer.step()total_loss+=loss.item()logger.debug(f"Epoch {epoch+1}/{DEFAULT_NUM_EPOCHS}: Loss={total_loss/len(loader):.4f}")net.eval()pytorch_score=0.0iflen(X_test):withtorch.no_grad():preds=(net(torch.from_numpy(X_test.astype(np.float32))).cpu().numpy().flatten())pytorch_score=r2_score(y_test,preds)torch.save(net.state_dict(),pytorch_path)logger.info(f"Saved PyTorch model at {pytorch_path} (R²={pytorch_score:.4f})")returnfloat(sklearn_score),float(pytorch_score)defupdate_training_status(session:SQLAlchemySession,is_training:bool,current_status_message:str|None=None,increment_count:bool=False,)->None:"""Update the ``training_status`` row with the current state. Parameters ---------- session : sqlalchemy.orm.Session Open SQLAlchemy session. is_training : bool Flag indicating whether a training is in progress. current_status_message : str | None, optional Optional status message or horizon marker. increment_count : bool, optional Whether to increment the total training counter. Examples -------- >>> # Within an open SQLAlchemy session # doctest: +SKIP >>> from app.ml_train import update_training_status # doctest: +SKIP >>> update_training_status(session, True, "Job started") # doctest: +SKIP """try:status=session.query(TrainingStatus).get(1)ifnotstatus:status=TrainingStatus(id=1)session.add(status)status.is_training=is_trainingifcurrent_status_messageisnotNone:status.current_horizon=current_status_messageifnotis_training:status.last_trained_at=datetime.utcnow()status.current_horizon=Noneifincrement_count:status.train_count=(status.train_countor0)+1session.commit()exceptSQLAlchemyErrorase:logger.error(f"Failed to update training status: {e}",exc_info=True)session.rollback()defmain()->None:"""Entry point for the standalone ML training job. Notes ----- - Coordinates are fetched (preferentially central ones) and iterated. For each horizon, data are prepared, models trained, artifacts saved, and a ``TrainingLog`` row written. - Status updates are written to ``TrainingStatus`` throughout the run. Examples -------- >>> # Executed inside Slurm job container # doctest: +SKIP >>> from app.ml_train import main # doctest: +SKIP >>> main() # doctest: +SKIP """logger.info("=== Starting ML Training Job ===")try:session=get_db_session()except(ValueError,AssertionError)ase:logger.error(f"Configuration error: {e}")returntry:fromsqlalchemyimporttextcentral_query=session.execute(text("SELECT latitude, longitude FROM coordinates WHERE is_central = true"))all_coords=[(row.latitude,row.longitude)forrowincentral_query.fetchall()]exceptSQLAlchemyErrorase:logger.warning(f"Could not fetch central coords: {e}")all_coords=[(56.8618,14.8069)]ifnotall_coords:update_training_status(session,is_training=False,current_status_message="No data")session.close()returnupdate_training_status(session,is_training=True,current_status_message="Job started")overall_success=Truetry:forlat,loninall_coords:coord_str=f"lat{lat:.4f}_lon{lon:.4f}".replace(".","_")update_training_status(session,is_training=True,current_status_message=f"Training {coord_str}",)df=load_training_data(session,lat,lon)ifdf.empty:logger.warning(f"Skipping {coord_str}, no data")continueforhorizon,stepsinHORIZON_SHIFTS.items():update_training_status(session,is_training=True,current_status_message=f"{coord_str}:{horizon}",)X,y,count=prepare_horizon_data(df,horizon,steps)ifcount==0:session.add(TrainingLog(id=str(uuid.uuid4()),horizon=f"{coord_str}_{horizon}",sklearn_score=float(0.0),pytorch_score=float(0.0),data_count=int(0),coord_latitude=float(lat),coord_longitude=float(lon),horizon_label=horizon,))session.commit()continueX_train,X_test,y_train,y_test=train_test_split(X,y,test_size=DEFAULT_TEST_SIZE,shuffle=False)if(len(X_train)<MIN_DATA_POINTS_FOR_TRAINING//2orlen(X_test)<MIN_DATA_POINTS_FOR_TRAINING//10):logger.warning(f"Skipping {coord_str}:{horizon}, insufficient split data")continuesklearn_score,pytorch_score=train_and_save_model(X_train,y_train,X_test,y_test,horizon,coord_str)session.add(TrainingLog(id=str(uuid.uuid4()),horizon=f"{coord_str}_{horizon}",sklearn_score=float(sklearn_score),pytorch_score=float(pytorch_score),data_count=int(count),coord_latitude=float(lat),coord_longitude=float(lon),horizon_label=horizon,))session.commit()exceptSQLAlchemyErrorase:logger.error(f"Training loop error: {e}",exc_info=True)overall_success=Falseupdate_training_status(session,is_training=False,current_status_message=f"Error: {e}")finally:ifoverall_success:update_training_status(session,is_training=False,current_status_message="Completed")logger.info("=== ML Training Job Completed ===")else:logger.error("=== ML Training Job Failed ===")session.close()if__name__=="__main__":# Prepare directoriesDATA_DIRECTORY.mkdir(parents=True,exist_ok=True)(DATA_DIRECTORY/"models").mkdir(parents=True,exist_ok=True)main()