# slurm_job_trigger.py

"""
Dispatch ML training jobs to a Slurm cluster from the web app or CLI.

This module prepares and submits training jobs to a containerized Slurm master.
It synchronizes the application code to a shared ``/data`` volume, generates a
deterministic ``sbatch`` script, and executes ``sbatch`` inside the Slurm
master container via the Docker Engine. It also updates the training status in
the database to reflect submission progress or failure.

See Also
--------
app.ml_train : Standalone training job executed by the sbatch script.
app.main : FastAPI endpoints that initiate training dispatch.
app.database.SessionLocal : Session factory used to update status.
app.models.TrainingStatus : Singleton status row updated on dispatch.

Notes
-----
- Primary role: synchronize code to shared volume, write the sbatch script,
  and trigger job submission within the Slurm master container.
- Key dependencies: a running Docker daemon, a Slurm master container whose
  name matches one of :data:`SLURM_MASTER_CONTAINER_NAMES`, a writable
  filesystem mounted at ``/data`` shared by Slurm nodes, and a valid
  ``DATABASE_URL`` environment variable for the job.
- Invariants: the shared path :data:`SHARED_DATA_PATH` (``/data``)
  must exist and be writable; the sbatch script name is fixed by
  :data:`SBATCH_SCRIPT_NAME`.

Examples
--------
>>> from app.slurm_job_trigger import create_and_dispatch_training_job  # doctest: +SKIP
>>> create_and_dispatch_training_job()                                   # doctest: +SKIP
"""

import logging
import os
import shutil
import subprocess
from pathlib import Path
from typing import Any, Optional, Tuple

import docker
from docker.errors import DockerException, NotFound as DockerNotFound
from sqlalchemy.exc import SQLAlchemyError

from .database import SessionLocal
from .models import TrainingStatus

logger = logging.getLogger(__name__)

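# Candidate container names checked in order; the first running match is used.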
SLURM_MASTER_CONTAINER_NAMES = (
    "slurm-master-test",
    "slurm-master",
    "ml_weather-slurm-master-1",
)
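# Shared volume mounted by the web app and all Slurm nodes.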
SHARED_DATA_PATH = Path("/data")
SBATCH_SCRIPT_NAME = "run_ml_training_job.sbatch"
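# Prefer rsync for code sync; fall back to cp -rT when rsync is unavailable.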
RSYNC_COMMAND = "rsync"
RSYNC_ARGS = ("-a", "--delete")
CP_COMMAND = "cp"
CP_ARGS = ("-rT",)


def _find_slurm_master_container() -> Optional[Any]:
    """Locate the running Slurm master Docker container.

    Iterates over :data:`SLURM_MASTER_CONTAINER_NAMES` and returns the first
    matching, running container object if found. Logs at ``INFO`` when a
    container is located and at ``ERROR`` if none are found.

    Returns
    -------
    object | None
        A Docker SDK container object if found; otherwise ``None``.

    Raises
    ------
    docker.errors.DockerException
        If communicating with the Docker daemon fails while probing names.

    Examples
    --------
    >>> # Requires a running Slurm master container                  # doctest: +SKIP
    >>> from app.slurm_job_trigger import _find_slurm_master_container  # doctest: +SKIP
    >>> container = _find_slurm_master_container()                    # doctest: +SKIP
    >>> bool(container)                                               # doctest: +SKIP
    True
    """
    client = docker.from_env()
    for name in SLURM_MASTER_CONTAINER_NAMES:
        try:
            container = client.containers.get(name)
            logger.info(f"Found Slurm master container: {name}")
            return container
        except DockerNotFound:
            logger.debug(f"Container '{name}' not found, trying next.")
    logger.error(
        f"No Slurm master container found. Checked: {SLURM_MASTER_CONTAINER_NAMES}"
    )
    return None


def _execute_in_container(
    container: Any,
    command: str,
    user: str = "slurm",
) -> Tuple[int, str, str]:
    """Execute a command inside a container.

    Parameters
    ----------
    container : object
        Docker SDK container object returned by :func:`docker.from_env`.
    command : str
        The command to execute inside the container.
    user : str, optional
        Linux user to execute as within the container, by default ``"slurm"``.

    Returns
    -------
    tuple[int, str, str]
        A tuple of ``(exit_code, stdout, stderr)`` where both streams are
        decoded as UTF-8 with replacement and stripped of surrounding whitespace.

    Raises
    ------
    docker.errors.DockerException
        If the Docker Engine cannot execute the command.

    Examples
    --------
    >>> # Requires a suitable container instance                   # doctest: +SKIP
    >>> from app.slurm_job_trigger import _execute_in_container     # doctest: +SKIP
    >>> _execute_in_container(container, "echo hello")             # doctest: +SKIP
    (0, 'hello', '')
    """
    exit_code, output = container.exec_run(command, user=user, demux=True)
    stdout_str, stderr_str = "", ""
    if output:
        out, err = output
        if out:
            stdout_str = out.decode("utf-8", errors="replace").strip()
        if err:
            stderr_str = err.decode("utf-8", errors="replace").strip()
    return exit_code, stdout_str, stderr_str


def clear_training_flag_on_failure(reason: str = "Failed to dispatch job") -> None:
    """Reset the training flag and record a failure reason.

    Sets ``TrainingStatus.is_training`` to ``False`` and stores the provided
    ``reason`` in ``TrainingStatus.current_horizon`` for traceability. Intended
    to be called whenever dispatching fails so that the UI reflects a stopped
    state with a short explanation.

    Parameters
    ----------
    reason : str, optional
        Human-readable failure reason persisted to the database. Must be
        non-empty, by default ``"Failed to dispatch job"``.

    Raises
    ------
    AssertionError
        If ``reason`` is an empty string.

    Notes
    -----
    - Any :class:`sqlalchemy.exc.SQLAlchemyError` is caught and logged; the
      function does not raise on database errors.
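
    Examples
    --------
    >>> # Illustrative reason string; requires database connectivity     # doctest: +SKIP
    >>> from app.slurm_job_trigger import clear_training_flag_on_failure  # doctest: +SKIP
    >>> clear_training_flag_on_failure("Slurm master unreachable")        # doctest: +SKIP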
    """
    assert reason, f"Reason must be non-empty, got '{reason}'."
    try:
        with SessionLocal() as session:
            status = session.query(TrainingStatus).get(1)
            if status and status.is_training:
                status.is_training = False
                status.current_horizon = reason
                session.commit()
                logger.info(f"Training flag cleared. Reason: {reason}")
    except SQLAlchemyError as error:
        logger.error(f"Failed to clear training flag: {error}", exc_info=True)


def trigger_slurm_job(script_path: str) -> bool:
    """Submit an ``sbatch`` job to the Slurm master container.

    Finds the Slurm master container, executes ``sbatch`` with the provided
    script path inside that container, and parses the output to confirm
    submission. Logs the full outcome including STDOUT/STDERR for diagnosis.

    Parameters
    ----------
    script_path : str
        Absolute path to the ``.sbatch`` script inside the shared volume
        (typically under ``/data``) that Slurm should execute.

    Returns
    -------
    bool
        ``True`` if the job was successfully submitted (detected by the
        presence of ``"Submitted batch job"`` in STDOUT); ``False`` otherwise,
        including when the Slurm master container cannot be found.

    Raises
    ------
    AssertionError
        If ``script_path`` is an empty string.

    Examples
    --------
    >>> # Requires a running Slurm master container and shared volume   # doctest: +SKIP
    >>> from app.slurm_job_trigger import trigger_slurm_job              # doctest: +SKIP
    >>> trigger_slurm_job("/data/run_ml_training_job.sbatch")           # doctest: +SKIP
    True
    """
    assert script_path, f"Script path must be non-empty, got '{script_path}'."
    container = _find_slurm_master_container()
    if not container:
        return False

    command = f"sbatch {script_path}"
    logger.info(f"Executing '{command}' in container '{container.name}'.")
    try:
        exit_code, stdout, stderr = _execute_in_container(container, command)
        output_log = f"STDOUT: {stdout}" if stdout else "No STDOUT."
        if stderr:
            output_log += f" STDERR: {stderr}"
        if exit_code == 0:
            logger.info(f"sbatch executed. {output_log}")
            if "Submitted batch job" in stdout:
                logger.info("Job submission confirmed.")
                return True
            logger.warning(f"No submission confirmation. {output_log}")
            return False
        logger.error(f"sbatch failed with exit code {exit_code}. {output_log}")
        return False
    except DockerException as error:
        logger.error(f"Error executing sbatch: {error}", exc_info=True)
        return False


def _sync_app_source(source_dir: Path, target_dir: Path) -> None:
    """Synchronize application source code for Slurm workers.

    Copies the current application source from ``source_dir`` to
    ``target_dir`` inside the shared volume used by Slurm nodes. Uses
    ``rsync`` with archive and delete flags (for correctness and speed) when
    available; falls back to ``cp -rT`` otherwise.

    Parameters
    ----------
    source_dir : Path
        Local path to the code to be synchronized (usually ``app/``).
    target_dir : Path
        Destination under ``/data`` (e.g., ``/data/app_code_for_slurm``).

    Raises
    ------
    AssertionError
        If ``source_dir`` does not exist.
    OSError
        If the synchronization command exits with a non-zero status.

    Examples
    --------
    >>> from pathlib import Path                                   # doctest: +SKIP
    >>> from app.slurm_job_trigger import _sync_app_source         # doctest: +SKIP
    >>> _sync_app_source(Path("src/app"), Path("/data/app_code_for_slurm"))  # doctest: +SKIP
    """
    assert source_dir.exists(), f"Source directory {source_dir} does not exist."
    if target_dir.exists():
        logger.info("Removing old app_code_for_slurm for clean sync.")
        shutil.rmtree(target_dir, ignore_errors=True)
    target_dir.mkdir(parents=True, exist_ok=True)

    if shutil.which(RSYNC_COMMAND):
        cmd = [RSYNC_COMMAND] + list(RSYNC_ARGS) + [f"{source_dir}/", f"{target_dir}/"]
    else:
        logger.warning(
            "rsync not available, falling back to cp. Install rsync for better performance."
        )
        cmd = [CP_COMMAND] + list(CP_ARGS) + [str(source_dir), f"{target_dir}/"]

    logger.info(f"Synchronizing code with: {' '.join(cmd)}")
    result = subprocess.run(cmd, capture_output=True, text=True, check=False)
    if result.returncode != 0:
        raise OSError(
            f"Code synchronization failed with exit code {result.returncode}: {result.stderr}"
        )
    logger.info("Application source synchronized successfully.")


def _write_sbatch_script(
    script_path: Path,
    database_url: str,
    python_script: str,
    python_path: str,
) -> None:
    """Write a deterministic Slurm ``sbatch`` script to the shared volume.

    Parameters
    ----------
    script_path : Path
        Absolute path under ``/data`` where the script will be written.
    database_url : str
        Database URL exported to the job via ``DATABASE_URL``.
    python_script : str
        Path to the Python entrypoint executed by the job
        (e.g., ``/data/app_code_for_slurm/ml_train.py``).
    python_path : str
        Value for ``PYTHONPATH`` so that the dispatched code can import
        application modules (typically the directory containing the code copy).

    Raises
    ------
    OSError
        If the script cannot be written to disk.

    Notes
    -----
    - The script logs basic job metadata and exits with the Python process
      exit code for Slurm accounting.
    - Output and error logs are written to ``/data/logs``.
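
    Examples
    --------
    >>> # Sketch with illustrative arguments; writes under /data          # doctest: +SKIP
    >>> from pathlib import Path                                          # doctest: +SKIP
    >>> from app.slurm_job_trigger import _write_sbatch_script            # doctest: +SKIP
    >>> _write_sbatch_script(                                             # doctest: +SKIP
    ...     Path("/data/run_ml_training_job.sbatch"),
    ...     "postgresql://user:password@db:5432/weather",  # placeholder URL
    ...     "/data/app_code_for_slurm/ml_train.py",
    ...     "/data/app_code_for_slurm",
    ... )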
    """
    script_content = f"""#!/bin/bash
#SBATCH --job-name=ml_training_job
#SBATCH --partition=cpu-nodes
#SBATCH --output={SHARED_DATA_PATH}/logs/ml_train_%j.out
#SBATCH --error={SHARED_DATA_PATH}/logs/ml_train_%j.err
#SBATCH --nodes=1
#SBATCH --ntasks-per-node=1
#SBATCH --cpus-per-task=2
#SBATCH --mem=2G

export DATABASE_URL="{database_url}"
export PYTHONPATH="{python_path}"
export SLURM_JOB_DATA_PATH="{SHARED_DATA_PATH}"

echo "--- SLURM JOB START ---"
echo "Job ID: $SLURM_JOB_ID"
echo "Running on node: $(hostname)"
echo "Date: $(date)"
echo "Python path: $PYTHONPATH"
echo "--- Python Execution ---"
python3 {python_script}
exit_code=$?
echo "Script exit code: $exit_code"
exit $exit_code
"""
    try:
        script_path.write_text(script_content)
        logger.info(f"sbatch script written to {script_path}")
    except OSError as error:
        logger.error(f"Failed to write sbatch script: {error}", exc_info=True)
        raise


def _update_training_status(message: str) -> None:
    """Update the training status record with the given message.

    Parameters
    ----------
    message : str
        Human-readable status update stored in ``TrainingStatus.current_horizon``.

    Raises
    ------
    AssertionError
        If ``message`` is empty.

    Notes
    -----
    - Any :class:`sqlalchemy.exc.SQLAlchemyError` is caught and logged; the
      function does not raise on database errors.
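
    Examples
    --------
    >>> # Requires database connectivity                                  # doctest: +SKIP
    >>> from app.slurm_job_trigger import _update_training_status         # doctest: +SKIP
    >>> _update_training_status("Job submitted to Slurm")                 # doctest: +SKIP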
    """
    assert message, f"Status message must be non-empty, got '{message}'."
    try:
        with SessionLocal() as session:
            status = session.query(TrainingStatus).get(1)
            if status:
                status.current_horizon = message
                session.commit()
                logger.info(f"Training status updated: {message}")
    except SQLAlchemyError as error:
        logger.error(f"Failed to update training status: {error}", exc_info=True)


def create_and_dispatch_training_job() -> None:
    """Prepare code, write sbatch script, and dispatch the training job to Slurm.

    High-level orchestration that ensures the shared ``/data`` volume has a
    fresh copy of the application code, writes a minimal ``sbatch`` script with
    required environment variables, and triggers submission on the Slurm master
    container. Updates :class:`app.models.TrainingStatus` accordingly, or resets
    the training flag with a human-readable reason on failure.

    Returns
    -------
    None

    Notes
    -----
    - The function relies on the ``DATABASE_URL`` environment variable. If it
      is missing, no submission is attempted and the status is cleared.
    - Operational errors from Docker, filesystem, or database are logged and
      cause a safe status reset without raising exceptions to the caller.

    Examples
    --------
    >>> from app.slurm_job_trigger import create_and_dispatch_training_job  # doctest: +SKIP
    >>> create_and_dispatch_training_job()                                   # doctest: +SKIP
    """
    logger.info("Dispatching a new ML training job to Slurm cluster.")
    database_url = os.getenv("DATABASE_URL")
    if not database_url:
        logger.error("DATABASE_URL not set.")
        clear_training_flag_on_failure("DATABASE_URL not set")
        return

    source_dir = Path(__file__).parent.resolve()
    target_dir = SHARED_DATA_PATH / "app_code_for_slurm"
    sbatch_script_path = SHARED_DATA_PATH / SBATCH_SCRIPT_NAME

    try:
        (SHARED_DATA_PATH / "logs").mkdir(parents=True, exist_ok=True)
        _sync_app_source(source_dir, target_dir)
        _write_sbatch_script(
            sbatch_script_path,
            database_url,
            str(target_dir / "ml_train.py"),
            str(target_dir),
        )
        if trigger_slurm_job(str(sbatch_script_path)):
            _update_training_status("Job submitted to Slurm")
        else:
            clear_training_flag_on_failure("Dispatch process failed")
    except (OSError, DockerException, SQLAlchemyError) as error:
        logger.error(f"Dispatch failed: {error}", exc_info=True)
        clear_training_flag_on_failure("Dispatch process failed")