# Source code for idmtools_models.json_configured_task
"""idmtools json configured task.
Copyright 2021, Bill & Melinda Gates Foundation. All rights reserved.
"""
import json
from dataclasses import dataclass, field, fields
from functools import partial
from logging import getLogger, DEBUG
from typing import Union, Dict, Any, List, Optional, Type, TYPE_CHECKING
from idmtools.assets import Asset, AssetCollection
from idmtools.entities.itask import ITask
from idmtools.entities.simulation import Simulation
from idmtools.registry.task_specification import TaskSpecification
if TYPE_CHECKING: # pragma: no cover
from idmtools.entities.iplatform import IPlatform
# Types accepted as a parameter key in the JSON configuration.
TJSONConfigKeyType = Union[str, int, float]
# Types accepted as a parameter value: a scalar, or a nested dictionary keyed by TJSONConfigKeyType.
TJSONConfigValueType = Union[str, int, float, Dict[TJSONConfigKeyType, Any]]
# Module-level logger, named after this module.
logger = getLogger(__name__)
# Logger registered under the name 'user'; it is not referenced elsewhere in the code visible here.
user_logger = getLogger('user')
@dataclass
class JSONConfiguredTask(ITask):
    """
    Task whose configuration is a dictionary of parameters written out as a JSON asset.

    The parameters can optionally be wrapped in an "envelope": a single top-level key under which the
    parameters are nested in the generated file. Internally ``parameters`` always holds the unwrapped
    dictionary; the envelope is re-applied when the file is written.
    """
    # Note: large amounts of parameters will increase size of metadata
    parameters: dict = field(default_factory=dict, metadata={"md": True})
    # Optional top-level key the parameters are nested under in the config file
    envelope: str = field(default=None, metadata={"md": True})
    # If we don't define this we assume a static name that the script consuming the file will know
    config_file_name: str = field(default="config.json", metadata={"md": True})
    # Is the config file a common asset or a transient one. We default to transient
    is_config_common: bool = field(default=False)
    # Command line argument (for example --config) used to point the command at the config file
    configfile_argument: str = field(default=None)
    # If configfile_argument is set, defines whether the filename is omitted after the argument.
    # For example, if the argument is --config and the config file name is config.json, the default
    # (False) runs the command as
    #   cmd --config config.json
    # while True runs it as
    #   cmd --config
    command_line_argument_no_filename: bool = field(default=False)

    def __post_init__(self):
        """Finish construction and unwrap ``parameters`` when they were supplied inside the envelope."""
        super().__post_init__()
        if self.parameters is not None and self.envelope is not None and self.envelope in self.parameters:
            logger.debug(f'Loading parameters from envelope: {self.envelope}')
            self.parameters = self.parameters[self.envelope]

    def gather_common_assets(self) -> AssetCollection:
        """
        Gather assets common across an Experiment (set of Simulations).

        The config file is added here only when ``is_config_common`` is set.

        Returns:
            Common AssetCollection
        """
        if self.is_config_common:
            self.__dump_config(self.common_assets)
        return self.common_assets

    def gather_transient_assets(self) -> AssetCollection:
        """
        Gather assets that are unique to this simulation/workitem.

        The config file is always written to the transient assets.

        Returns:
            Simulation/workitem level AssetCollection
        """
        self.__dump_config(self.transient_assets)
        return self.transient_assets

    def __dump_config(self, assets) -> None:
        """
        Write the configuration out as a JSON asset.

        Nothing is written when ``config_file_name`` is None. When an envelope is set, the parameters
        are nested under it in the generated document.

        Args:
            assets: Asset collection to add the configuration to

        Returns:
            None
        """
        if self.config_file_name is None:
            return
        params = {self.envelope: self.parameters} if self.envelope else self.parameters
        # Serialise once and reuse the result for both logging and the asset content
        content = json.dumps(params)
        if logger.isEnabledFor(DEBUG):
            logger.debug('Adding JSON Configured File %s', self.config_file_name)
            logger.debug(f'Generating {self.config_file_name} as an asset from JSONConfiguredTask')
            logger.debug('Writing Config %s', content)
        assets.add_or_replace_asset(Asset(filename=self.config_file_name, content=content))

    def set_parameter(self, key: TJSONConfigKeyType, value: TJSONConfigValueType):
        """
        Update a parameter. The type hinting encourages JSON supported types.

        Args:
            key: Name of the parameter to set
            value: Value to store for the parameter

        Returns:
            Dictionary ``{key: value}``, usable as tags on the simulation/workitem
        """
        if logger.isEnabledFor(DEBUG):
            # Logged at debug level to match the guard and update_parameters
            logger.debug('Setting parameter %s to %s', key, str(value))
        self.parameters[key] = value
        return {key: value}

    def get_parameter(self, key: TJSONConfigKeyType) -> TJSONConfigValueType:
        """
        Return a parameter value.

        Args:
            key: Key of parameter

        Returns:
            Value of parameter

        Raises:
            KeyError: When the parameter has not been set
        """
        return self.parameters[key]

    def update_parameters(self, values: Dict[TJSONConfigKeyType, TJSONConfigValueType]):
        """
        Perform bulk update from another dictionary.

        Args:
            values: Values to update, as a dictionary

        Returns:
            The ``values`` dictionary that was passed in
        """
        if logger.isEnabledFor(DEBUG):
            for k, p in values.items():
                logger.debug('Setting parameter %s to %s', k, str(p))
        self.parameters.update(values)
        return values

    def reload_from_simulation(self, simulation: 'Simulation', config_file_name: Optional[str] = None,
                               envelope: Optional[str] = None, **kwargs):  # noqa: F821
        """
        Reload the parameters from a Simulation.

        Nothing happens when the simulation has no platform. Otherwise the process is:

        1. Pick the config file name from the arguments, then the simulation tags, then the task default
        2. Load the json config file
        3. Unwrap the envelope, taking its name from the arguments, then the simulation tags, then the
           task itself. The first candidate found as a key of the loaded config is used.

        Args:
            simulation: Simulation object with metadata to load info from
            config_file_name: Optional name of config file
            envelope: Optional name of envelope
            **kwargs: Accepted for compatibility; not used here

        Returns:
            None. ``self.parameters`` is populated with the loaded config.
        """
        if not simulation.platform:
            return
        # Forward the caller supplied file name so it takes priority over tags and the default
        self.parameters = self.__find_config(simulation, config_file_name)
        if envelope and envelope in self.parameters:
            self.parameters = self.parameters[envelope]
        elif 'task_envelope' in simulation.tags and simulation.tags['task_envelope'] in self.parameters:
            self.parameters = self.parameters[simulation.tags['task_envelope']]
        elif self.envelope and self.envelope in self.parameters:
            self.parameters = self.parameters[self.envelope]

    @staticmethod
    def _decode_config(content: Any) -> Any:
        """Parse JSON held as bytes or str; any other content is returned unchanged."""
        if isinstance(content, bytes):
            content = content.decode('utf-8')
        if isinstance(content, str):
            return json.loads(content)
        return content

    def __find_config(self, simulation: Simulation, config_file_name: Optional[str] = None) -> Dict[str, Any]:
        """
        Rebuild the configuration using data from a simulation that has been run.

        When the simulation carries assets, the config is read from them and removed from the
        simulation's assets. Otherwise the file is fetched through the simulation's platform and
        removed from this task's transient assets.

        Args:
            simulation: Simulation to load from
            config_file_name: Config file name. Falls back to the ``task_config_file_name`` tag and
                then to ``self.config_file_name``.

        Returns:
            Config reloaded
        """
        # find the config file name
        if config_file_name:
            cfn = config_file_name
        elif 'task_config_file_name' in simulation.tags:
            cfn = simulation.tags['task_config_file_name']
        else:
            cfn = self.config_file_name
        if logger.isEnabledFor(DEBUG):
            logger.debug(f'Loading Config from {simulation.id}:{cfn}')

        config = dict()
        if simulation.assets and isinstance(simulation.assets, (AssetCollection, list)):
            for file in simulation.assets:
                if file.filename == cfn:
                    # Content may be bytes or the str written by __dump_config
                    config = self._decode_config(file.content)
            # filter our config from the simulation
            remaining = [asset for asset in simulation.assets if asset.filename != cfn]
            if isinstance(simulation.assets, list):
                # A plain list has no inner ``assets`` attribute to replace
                simulation.assets = remaining
            else:
                simulation.assets.assets = remaining
        else:
            # try to load the config from the platform
            files = simulation.platform.get_files(simulation, [cfn])
            config = self._decode_config(files[cfn])
            # filter config from transient assets
            if self.transient_assets:
                kept = AssetCollection()
                for asset in self.transient_assets:
                    if isinstance(asset, dict):
                        if asset['filename'] != cfn:
                            kept.add_asset(Asset(**asset))
                    elif isinstance(asset, Asset) and asset.filename != cfn:
                        kept.add_asset(asset)
                self.transient_assets = kept
        return config

    def pre_creation(self, parent: Union['Simulation', 'WorkflowItem'], platform: 'IPlatform'):  # noqa: F821
        """
        Pre-creation. For JSONConfiguredTask, we finalize our configuration file and command line here.

        A non-default config file name and the envelope name are recorded as tags on the parent so
        they can be recovered by ``reload_from_simulation``.

        Args:
            parent: Parent of task
            platform: Platform task is being created on

        Returns:
            None
        """
        default_name = next(f for f in fields(JSONConfiguredTask) if f.name == "config_file_name").default
        if self.config_file_name != default_name:
            logger.info('Found non-default name for config_file_name. Adding tag task_config_file_name')
            parent.tags['task_config_file_name'] = self.config_file_name

        if self.envelope:
            logger.info('Found envelope name. Adding tag envelope')
            parent.tags['task_envelope'] = self.envelope

        # Ensure our command line argument is added if configured
        if self.configfile_argument:
            logger.debug('Adding command_line_argument to command')
            if self.configfile_argument not in self.command.arguments:
                self.command.add_argument(self.configfile_argument)
                # The filename follows the argument unless explicitly disabled
                if not self.command_line_argument_no_filename:
                    self.command.add_argument(self.config_file_name)

    def __repr__(self):
        """String version of task. Prints config filename and parameters."""
        return f"<JSONConfiguredTask config:{self.config_file_name} parameters: {self.parameters}>"

    @staticmethod
    def set_parameter_sweep_callback(simulation: Simulation, param: str, value: Any) -> Dict[str, Any]:
        """
        Performs a callback with a parameter and a value.

        Most likely users want to use set_parameter_partial instead of this method.

        Args:
            simulation: Simulation object
            param: Param name
            value: Value to set

        Returns:
            Tags to add to simulation

        Raises:
            ValueError: When the simulation's task has no ``set_parameter`` attribute
        """
        if not hasattr(simulation.task, 'set_parameter'):
            raise ValueError("update_task_with_set_parameter can only be used on tasks with a set_parameter")
        return simulation.task.set_parameter(param, value)

    @classmethod
    def set_parameter_partial(cls, parameter: str):
        """
        Callback to be used when sweeping with a json configured model.

        Args:
            parameter: Param name

        Returns:
            Partial of ``set_parameter_sweep_callback`` with ``param`` bound to ``parameter``

        Notes:
            - TODO Reference some examples code here
        """
        return partial(cls.set_parameter_sweep_callback, param=parameter)
class JSONConfiguredTaskSpecification(TaskSpecification):
    """
    Plugin specification that exposes JSONConfiguredTask.
    """

    def get(self, configuration: dict) -> JSONConfiguredTask:
        """
        Build a JSONConfiguredTask from keyword configuration.

        Args:
            configuration: Keyword arguments forwarded to the JSONConfiguredTask constructor

        Returns:
            JSONConfiguredTask with configuration
        """
        task = JSONConfiguredTask(**configuration)
        return task

    def get_description(self) -> str:
        """
        Describe the plugin.

        Returns:
            Description of plugin
        """
        description = "Defines a general command that has a simple JSON based config"
        return description

    def get_example_urls(self) -> List[str]:
        """
        List the example urls for JSONConfiguredTask.

        Returns:
            One url per example directory, produced by ``get_version_url`` for the current plugin version
        """
        from idmtools_models import __version__ as plugin_version
        version_tag = f'v{plugin_version}'
        urls = []
        for example_name in ('python_model', 'load_lib'):
            urls.append(self.get_version_url(version_tag, f'examples/{example_name}'))
        return urls

    def get_type(self) -> Type[JSONConfiguredTask]:
        """
        Report the task type provided by the plugin.

        Returns:
            JSONConfiguredTask
        """
        return JSONConfiguredTask

    def get_version(self) -> str:
        """
        Report the version of the plugin.

        Returns:
            Plugin Version
        """
        from idmtools_models import __version__ as plugin_version
        return plugin_version