"""
Dispatch ML training jobs to a Slurm cluster from the web app or CLI.
This module prepares and submits training jobs to a containerized Slurm master.
It synchronizes the application code to a shared ``/data`` volume, generates a
deterministic ``sbatch`` script, and executes ``sbatch`` inside the Slurm
master container via the Docker Engine. It also updates the training status in
the database to reflect submission progress or failure.
See Also
--------
app.ml_train : Standalone training job executed by the sbatch script.
app.main : FastAPI endpoints that initiate training dispatch.
app.database.SessionLocal : Session factory used to update status.
app.models.TrainingStatus : Singleton status row updated on dispatch.
Notes
-----
- Primary role: synchronize code to the shared volume, write the sbatch script,
and trigger job submission within the Slurm master container.
- Key dependencies: a running Docker daemon, a Slurm master container whose
name matches one of :data:`SLURM_MASTER_CONTAINER_NAMES`, a writable
filesystem mounted at ``/data`` shared by Slurm nodes, and a valid
``DATABASE_URL`` environment variable for the job.
- Invariants: the shared path :data:`SHARED_DATA_PATH` (``/data``)
must exist and be writable; the sbatch script name is fixed by
:data:`SBATCH_SCRIPT_NAME`.
Examples
--------
>>> from app.slurm_job_trigger import create_and_dispatch_training_job # doctest: +SKIP
>>> create_and_dispatch_training_job() # doctest: +SKIP
"""
import logging
import os
import shutil
import subprocess
from pathlib import Path
from typing import Any, Optional, Tuple
import docker
from docker.errors import DockerException, NotFound as DockerNotFound
from sqlalchemy.exc import SQLAlchemyError
from .database import SessionLocal
from .models import TrainingStatus
logger = logging.getLogger(__name__)
SLURM_MASTER_CONTAINER_NAMES = (
"slurm-master-test",
"slurm-master",
"ml_weather-slurm-master-1",
)
SHARED_DATA_PATH = Path("/data")
SBATCH_SCRIPT_NAME = "run_ml_training_job.sbatch"
RSYNC_COMMAND = "rsync"
RSYNC_ARGS = ("-a", "--delete")
CP_COMMAND = "cp"
CP_ARGS = ("-rT",)
def _find_slurm_master_container() -> Optional[Any]:
"""Locate the running Slurm master Docker container.
Iterates over :data:`SLURM_MASTER_CONTAINER_NAMES` and returns the first
matching, running container object if found. Logs at ``INFO`` when a
container is located and at ``ERROR`` if none are found.
Returns
-------
object | None
A Docker SDK container object if found; otherwise ``None``.
Raises
------
docker.errors.DockerException
If communicating with the Docker daemon fails while probing names.
Examples
--------
>>> # Requires a running Slurm master container # doctest: +SKIP
>>> from app.slurm_job_trigger import _find_slurm_master_container # doctest: +SKIP
>>> container = _find_slurm_master_container() # doctest: +SKIP
>>> bool(container) # doctest: +SKIP
True
"""
client = docker.from_env()
for name in SLURM_MASTER_CONTAINER_NAMES:
try:
container = client.containers.get(name)
logger.info(f"Found Slurm master container: {name}")
return container
except DockerNotFound:
logger.debug(f"Container '{name}' not found, trying next.")
logger.error(
f"No Slurm master container found. Checked: {SLURM_MASTER_CONTAINER_NAMES}"
)
return None
def _execute_in_container(
container: Any,
command: str,
user: str = "slurm",
) -> Tuple[int, str, str]:
"""Execute a command inside a container.
Parameters
----------
container : object
Docker SDK container object, as returned by :func:`_find_slurm_master_container`.
command : str
The command to execute inside the container.
user : str, optional
Linux user to execute as within the container, by default ``"slurm"``.
Returns
-------
tuple[int, str, str]
A tuple of ``(exit_code, stdout, stderr)`` where both streams are
decoded as UTF-8 with replacement and stripped of surrounding whitespace.
Raises
------
docker.errors.DockerException
If the Docker Engine cannot execute the command.
Examples
--------
>>> # Requires a suitable container instance # doctest: +SKIP
>>> from app.slurm_job_trigger import _execute_in_container # doctest: +SKIP
>>> _execute_in_container(container, "echo hello") # doctest: +SKIP
(0, 'hello', '')
"""
exit_code, output = container.exec_run(command, user=user, demux=True)
stdout_str, stderr_str = "", ""
if output:
out, err = output
if out:
stdout_str = out.decode("utf-8", errors="replace").strip()
if err:
stderr_str = err.decode("utf-8", errors="replace").strip()
return exit_code, stdout_str, stderr_str
def clear_training_flag_on_failure(reason: str = "Failed to dispatch job") -> None:
"""Reset the training flag and record a failure reason.
Sets ``TrainingStatus.is_training`` to ``False`` and stores the provided
``reason`` in ``TrainingStatus.current_horizon`` for traceability. Intended
to be called whenever dispatching fails so that the UI reflects a stopped
state with a short explanation.
Parameters
----------
reason : str, optional
Human-readable failure reason persisted to the database. Must be
non-empty, by default ``"Failed to dispatch job"``.
Raises
------
AssertionError
If ``reason`` is an empty string.
Notes
-----
- Any :class:`sqlalchemy.exc.SQLAlchemyError` is caught and logged; the
function does not raise on database errors.
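Examples
--------
>>> # Requires a reachable database with the singleton status row # doctest: +SKIP
>>> from app.slurm_job_trigger import clear_training_flag_on_failure # doctest: +SKIP
>>> clear_training_flag_on_failure("DATABASE_URL not set") # doctest: +SKIP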
"""
assert reason, f"Reason must be non-empty, got '{reason}'."
try:
with SessionLocal() as session:
status = session.query(TrainingStatus).get(1)
if status and status.is_training:
status.is_training = False
status.current_horizon = reason
session.commit()
logger.info(f"Training flag cleared. Reason: {reason}")
except SQLAlchemyError as error:
logger.error(f"Failed to clear training flag: {error}", exc_info=True)
def trigger_slurm_job(script_path: str) -> bool:
"""Submit an ``sbatch`` job to the Slurm master container.
Finds the Slurm master container, executes ``sbatch`` with the provided
script path inside that container, and parses the output to confirm
submission. Logs the full outcome including STDOUT/STDERR for diagnosis.
Parameters
----------
script_path : str
Absolute path to the ``.sbatch`` script inside the shared volume
(typically under ``/data``) that Slurm should execute.
Returns
-------
bool
``True`` if the job was successfully submitted (detected by the
presence of ``"Submitted batch job"`` in STDOUT); ``False`` otherwise,
including when the Slurm master container cannot be found.
Raises
------
AssertionError
If ``script_path`` is an empty string.
Examples
--------
>>> # Requires a running Slurm master container and shared volume # doctest: +SKIP
>>> from app.slurm_job_trigger import trigger_slurm_job # doctest: +SKIP
>>> trigger_slurm_job("/data/run_ml_training_job.sbatch") # doctest: +SKIP
True
"""
assert script_path, f"Script path must be non-empty, got '{script_path}'."
container = _find_slurm_master_container()
if not container:
return False
command = f"sbatch {script_path}"
logger.info(f"Executing '{command}' in container '{container.name}'.")
try:
exit_code, stdout, stderr = _execute_in_container(container, command)
output_log = f"STDOUT: {stdout}" if stdout else "No STDOUT."
if stderr:
output_log += f" STDERR: {stderr}"
if exit_code == 0:
logger.info(f"sbatch executed. {output_log}")
if "Submitted batch job" in stdout:
logger.info("Job submission confirmed.")
return True
logger.warning(f"No submission confirmation. {output_log}")
return False
logger.error(f"sbatch failed with exit code {exit_code}. {output_log}")
return False
except DockerException as error:
logger.error(f"Error executing sbatch: {error}", exc_info=True)
return False
def _sync_app_source(source_dir: Path, target_dir: Path) -> None:
"""Synchronize application source code for Slurm workers.
Copies the current application source from ``source_dir`` to
``target_dir`` inside the shared volume used by Slurm nodes. Uses
``rsync`` with archive and delete flags when available for correctness and
performance; falls back to ``cp -rT`` otherwise.
Parameters
----------
source_dir : Path
Local path to the code to be synchronized (usually ``app/``).
target_dir : Path
Destination under ``/data`` (e.g., ``/data/app_code_for_slurm``).
Raises
------
AssertionError
If ``source_dir`` does not exist.
OSError
If the synchronization command exits with a non-zero status.
Examples
--------
>>> from pathlib import Path # doctest: +SKIP
>>> from app.slurm_job_trigger import _sync_app_source # doctest: +SKIP
>>> _sync_app_source(Path("src/app"), Path("/data/app_code_for_slurm")) # doctest: +SKIP
"""
assert source_dir.exists(), f"Source directory {source_dir} does not exist."
if target_dir.exists():
logger.info("Removing old app_code_for_slurm for clean sync.")
shutil.rmtree(target_dir, ignore_errors=True)
target_dir.mkdir(parents=True, exist_ok=True)
if shutil.which(RSYNC_COMMAND):
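# Trailing slashes make rsync copy the directory contents rather than the directory itself.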
cmd = [RSYNC_COMMAND] + list(RSYNC_ARGS) + [f"{source_dir}/", f"{target_dir}/"]
else:
logger.warning(
"rsync not available, falling back to cp. Install rsync for better performance."
)
cmd = [CP_COMMAND] + list(CP_ARGS) + [str(source_dir), f"{target_dir}/"]
logger.info(f"Synchronizing code with: {' '.join(cmd)}")
result = subprocess.run(cmd, capture_output=True, text=True, check=False)
if result.returncode != 0:
raise OSError(
f"Code synchronization failed with exit code {result.returncode}: {result.stderr}"
)
logger.info("Application source synchronized successfully.")
def _write_sbatch_script(
script_path: Path,
database_url: str,
python_script: str,
python_path: str,
) -> None:
"""Write a deterministic Slurm ``sbatch`` script to the shared volume.
Parameters
----------
script_path : Path
Absolute path under ``/data`` where the script will be written.
database_url : str
Database URL exported to the job via ``DATABASE_URL``.
python_script : str
Path to the Python entrypoint executed by the job
(e.g., ``/data/app_code_for_slurm/ml_train.py``).
python_path : str
Value for ``PYTHONPATH`` so that the dispatched code can import
application modules (typically the directory containing the code copy).
Raises
------
OSError
If the script cannot be written to disk.
Notes
-----
- The script logs basic job metadata and exits with the Python process
exit code for Slurm accounting.
- Output and error logs are written to ``/data/logs``.
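Examples
--------
>>> # Illustrative values; requires a writable /data volume # doctest: +SKIP
>>> from pathlib import Path # doctest: +SKIP
>>> from app.slurm_job_trigger import _write_sbatch_script # doctest: +SKIP
>>> _write_sbatch_script(
...     Path("/data/run_ml_training_job.sbatch"),
...     "postgresql://user:password@db:5432/weather",
...     "/data/app_code_for_slurm/ml_train.py",
...     "/data/app_code_for_slurm",
... ) # doctest: +SKIP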
"""
script_content = f"""#!/bin/bash
#SBATCH --job-name=ml_training_job
#SBATCH --partition=cpu-nodes
#SBATCH --output={SHARED_DATA_PATH}/logs/ml_train_%j.out
#SBATCH --error={SHARED_DATA_PATH}/logs/ml_train_%j.err
#SBATCH --nodes=1
#SBATCH --ntasks-per-node=1
#SBATCH --cpus-per-task=2
#SBATCH --mem=2G
export DATABASE_URL="{database_url}"
export PYTHONPATH={python_path}
export SLURM_JOB_DATA_PATH={SHARED_DATA_PATH}
echo "--- SLURM JOB START ---"
echo "Job ID: $SLURM_JOB_ID"
echo "Running on node: $(hostname)"
echo "Date: $(date)"
echo "Python path: $PYTHONPATH"
echo "--- Python Execution ---"
python3 {python_script}
exit_code=$?
echo "Script exit code: $exit_code"
exit $exit_code
"""
try:
script_path.write_text(script_content)
logger.info(f"sbatch script written to {script_path}")
except OSError as error:
logger.error(f"Failed to write sbatch script: {error}", exc_info=True)
raise
def _update_training_status(message: str) -> None:
"""Update the training status record with the given message.
Parameters
----------
message : str
Human-readable status update stored in ``TrainingStatus.current_horizon``.
Raises
------
AssertionError
If ``message`` is empty.
Notes
-----
- Any :class:`sqlalchemy.exc.SQLAlchemyError` is caught and logged; the
function does not raise on database errors.
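Examples
--------
>>> # Requires a reachable database with the singleton status row # doctest: +SKIP
>>> from app.slurm_job_trigger import _update_training_status # doctest: +SKIP
>>> _update_training_status("Job submitted to Slurm") # doctest: +SKIP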
"""
assert message, f"Status message must be non-empty, got '{message}'."
try:
with SessionLocal() as session:
status = session.query(TrainingStatus).get(1)
if status:
status.current_horizon = message
session.commit()
logger.info(f"Training status updated: {message}")
except SQLAlchemyError as error:
logger.error(f"Failed to update training status: {error}", exc_info=True)
def create_and_dispatch_training_job() -> None:
"""Prepare code, write sbatch script, and dispatch the training job to Slurm.
High-level orchestration that ensures the shared ``/data`` volume has a
fresh copy of the application code, writes a minimal ``sbatch`` script with
required environment variables, and triggers submission on the Slurm master
container. Updates :class:`app.models.TrainingStatus` accordingly, or resets
the training flag with a human-readable reason on failure.
Returns
-------
None
Notes
-----
- The function relies on the ``DATABASE_URL`` environment variable. If it
is missing, no submission is attempted and the status is cleared.
- Operational errors from Docker, filesystem, or database are logged and
cause a safe status reset without raising exceptions to the caller.
Examples
--------
>>> from app.slurm_job_trigger import create_and_dispatch_training_job # doctest: +SKIP
>>> create_and_dispatch_training_job() # doctest: +SKIP
"""
logger.info("Dispatching a new ML training job to Slurm cluster.")
database_url = os.getenv("DATABASE_URL")
if not database_url:
logger.error("DATABASE_URL not set.")
clear_training_flag_on_failure("DATABASE_URL not set")
return
source_dir = Path(__file__).parent.resolve()
target_dir = SHARED_DATA_PATH / "app_code_for_slurm"
sbatch_script_path = SHARED_DATA_PATH / SBATCH_SCRIPT_NAME
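# Everything the job needs (code copy, sbatch script, logs) lives under the shared /data volume.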
try:
(SHARED_DATA_PATH / "logs").mkdir(parents=True, exist_ok=True)
_sync_app_source(source_dir, target_dir)
_write_sbatch_script(
sbatch_script_path,
database_url,
str(target_dir / "ml_train.py"),
str(target_dir),
)
if trigger_slurm_job(str(sbatch_script_path)):
_update_training_status("Job submitted to Slurm")
else:
clear_training_flag_on_failure("Dispatch process failed")
except (OSError, DockerException, SQLAlchemyError) as error:
logger.error(f"Dispatch failed: {error}", exc_info=True)
clear_training_flag_on_failure("Dispatch process failed")