Source code for idmtools_platform_slurm.utils.slurm_job.slurm_job

"""
This is a SlurmPlatform utility.

Copyright 2021, Bill & Melinda Gates Foundation. All rights reserved.
"""
import os
import subprocess
from os import PathLike
from pathlib import Path
from dataclasses import dataclass, field
from typing import NoReturn, Union, List, TYPE_CHECKING
from idmtools.core import NoPlatformException
from jinja2 import Template
from logging import getLogger
from idmtools_platform_slurm.utils.slurm_job import create_slurm_indicator, slurm_installed

user_logger = getLogger('user')

if TYPE_CHECKING:
    from idmtools_platform_slurm.slurm_platform import SlurmPlatform

DEFAULT_TEMPLATE_FILE = "script_sbatch.sh.jinja2"
MSG = """Note: any output information from your script is stored in file stdout.txt under the script folder. For example, if you are running a script under current directory which kicks out another Slurm job, then the second Slurm job id is stored in stdout.txt under the current directory."""

TEMP_FILES = ['sbatch.sh', 'job_id.txt', 'job_status.txt', 'stdout.txt', 'stderr.txt']


[docs]def generate_script(platform: 'SlurmPlatform', command: str, template: Union[Path, str] = DEFAULT_TEMPLATE_FILE, batch_dir: str = None, **kwargs) -> None: """ Generate batch file sbatch.sh Args: platform: Slurm Platform command: execution command template: template to be used to build batch file kwargs: keyword arguments used to expand functionality Returns: None """ from idmtools_platform_slurm.slurm_platform import CONFIG_PARAMETERS template_vars = dict( platform=platform, command=command ) # Populate from our platform config vars for p in CONFIG_PARAMETERS: if getattr(platform, p) is not None: template_vars[p] = getattr(platform, p) template_vars.update(kwargs) if platform.modules: template_vars['modules'] = platform.modules with open(Path(__file__).parent.joinpath(template)) as tin: t = Template(tin.read()) # Write our file if batch_dir is None: output_target = Path.cwd().joinpath("sbatch.sh") else: output_target = Path(batch_dir).joinpath("sbatch.sh") with open(output_target, "w") as tout: tout.write(t.render(template_vars)) # Make executable platform._op_client.update_script_mode(output_target)
[docs]@dataclass(repr=False) class SlurmJob: script_path: PathLike = field(init=True) platform: 'SlurmPlatform' = field(default=None, init=True) executable: str = field(default='python3', init=True) script_params: List[str] = field(default=None, init=True) cleanup: bool = field(default=True, init=True) def __post_init__(self): if self.script_path is None: raise RuntimeError("script_path is missing!") # load platform from context or from passed in value self.platform = self.__check_for_platform_from_context(self.platform) self.working_directory = Path(self.script_path).parent self.script_params = self.script_params if self.script_params is not None and len( self.script_params) > 0 else None self.slurm_job_id = None
[docs] def initialization(self): # make str list so that we may join them together if self.script_params is not None: self.script_params = [str(i) for i in self.script_params] if self.script_params is not None: command = f"{self.executable} {Path(self.script_path).name} {' '.join(self.script_params)}" else: command = f"{self.executable} {Path(self.script_path).name}" generate_script(self.platform, command, batch_dir=self.working_directory)
[docs] def run(self, dry_run: bool = False, **kwargs) -> NoReturn: if self.cleanup: self.clean(self.working_directory) self.initialization() if not dry_run: if not slurm_installed(): user_logger.warning('Slurm is not installed/available!') exit(-1) user_logger.info('Script is running as a slurm job!\n') # Add indicator to avoid recursive loop create_slurm_indicator() # Run script as Slurm job result = subprocess.run(['sbatch', '--parsable', 'sbatch.sh'], stdout=subprocess.PIPE, cwd=str(self.working_directory)) self.slurm_job_id = result.stdout.decode('utf-8').strip().split(';')[0] user_logger.info(f"{'job_id: '.ljust(20)} {self.slurm_job_id}") user_logger.info(f"{'job_directory: '.ljust(20)} {self.platform.job_directory}\n") user_logger.warning(MSG) else: print('Script run with dry_run = True')
def __check_for_platform_from_context(self, platform) -> 'IPlatform': # noqa: F821 """ Try to determine platform of current object from self or current platform. Args: platform: Passed in platform object Raises: NoPlatformException: when no platform is on current context Returns: Platform object """ if self.platform is None: # check context for current platform if platform is None: from idmtools.core.context import CURRENT_PLATFORM if CURRENT_PLATFORM is None: raise NoPlatformException("No Platform defined on object, in current context, or passed to run") platform = CURRENT_PLATFORM self.platform = platform return self.platform
[docs] def clean(self, cwd: str = os.getcwd()): """ Delete generated slurm job related files. Args: cwd: the directory containing the files Returns: None """ for file_path in TEMP_FILES: f = os.path.join(cwd, file_path) if os.path.exists(f): try: os.remove(f) except: pass