Source code for idmtools.core.task_factory
"""
Define our task factory. It is needed to rebuild task objects when fetching them from the server.
Copyright 2021, Bill & Melinda Gates Foundation. All rights reserved.
"""
from logging import getLogger
from typing import NoReturn, Type
from idmtools.entities.itask import ITask
from idmtools.registry.task_specification import TaskSpecification
# Logger named after this module.
logger = getLogger(name=__name__)
# Cache of task plugin specifications; filled in by the first TaskFactory() and shared by all later ones.
TASK_BUILDERS = None
class DynamicTaskSpecification(TaskSpecification):
    """
    Lightweight task specification that wraps an arbitrary task class.

    Lets users register a special task quickly without writing a dedicated plugin specification.
    """

    def __init__(self, task_type: Type[ITask], description: str = ''):
        """
        Remember the task class and its description.

        Args:
            task_type: Task class this specification builds
            description: Free-form text later returned by get_description
        """
        self.description = description
        self.task_type = task_type

    def get(self, configuration: dict) -> ITask:
        """
        Build a task instance from a configuration mapping.

        Args:
            configuration: Keyword arguments forwarded unchanged to the task constructor.

        Returns:
            A new instance of the wrapped task class
        """
        task_cls = self.task_type
        return task_cls(**configuration)

    def get_description(self) -> str:
        """
        Report the description supplied at construction time.

        Returns:
            The user-defined plugin description
        """
        text = self.description
        return text

    def get_type(self) -> Type[ITask]:
        """
        Report the task class this specification builds.

        Returns:
            The wrapped task class
        """
        task_cls = self.task_type
        return task_cls
class TaskFactory:
    """
    TaskFactory allows creation of tasks that are derived from plugins.

    Each specification is reachable by its plugin-map key, by the fully-qualified name of the type
    it builds (``module.ClassName``) and by the bare class name.
    """

    # Key used by create() when it is called with key=None.
    DEFAULT_KEY = 'idmtools.entities.command_task.CommandTask'

    def __init__(self):
        """
        Initialize our Factory.

        The plugin map is loaded once and cached in the module-level ``TASK_BUILDERS``. Every
        factory instance shares, and adds aliases to, that same dictionary.
        """
        global TASK_BUILDERS
        if TASK_BUILDERS is None:
            # Imported lazily so the plugin registry is only touched when a factory is first built.
            from idmtools.registry.task_specification import TaskPlugins
            TASK_BUILDERS = TaskPlugins().get_plugin_map()
        self._builders = TASK_BUILDERS
        aliases = dict()
        # Register types under their full dotted path and bare class name as well.
        for spec in self._builders.values():
            try:
                task_type = spec.get_type()
                aliases[f'{task_type.__module__}.{task_type.__name__}'] = spec
                aliases[task_type.__name__] = spec
            except Exception as e:
                # Best effort: one broken plugin should not prevent the factory from being built.
                logger.warning(f"Could not load alias for {spec}")
                logger.exception(e)
        # Aliases are collected separately so the dict is not mutated while being iterated.
        self._builders.update(aliases)

    def register(self, spec: TaskSpecification) -> None:
        """
        Register a TaskSpecification dynamically.

        The specification is stored under both the bare class name and the fully-qualified
        ``module.ClassName`` of the type it builds.

        Args:
            spec: Specification to register

        Returns:
            None
        """
        task_type = spec.get_type()
        type_name = task_type.__name__
        # Plain string (not a set literal) so the qualified key matches the aliases built in __init__.
        module_name = task_type.__module__
        logger.debug(f'Registering task: {type_name} as both {type_name} and as {module_name}.{type_name}')
        self._builders[type_name] = spec
        self._builders[f'{module_name}.{type_name}'] = spec

    def register_task(self, task: Type[ITask]) -> None:
        """
        Dynamically register a class using the DynamicTaskSpecification.

        Args:
            task: Task class to register

        Returns:
            None
        """
        spec = DynamicTaskSpecification(task)
        self.register(spec)

    def create(self, key, fallback=None, **kwargs) -> ITask:  # noqa: F821
        """
        Create a task of type key.

        Args:
            key: Name of the task type to create. When None, DEFAULT_KEY is used and a warning is logged.
            fallback: Callable (typically a task class) invoked with no arguments when key is not
                registered. If not provided (or falsy), an unknown key raises ValueError.
            **kwargs: Optional arguments passed to the task specification. They are NOT forwarded
                to fallback.

        Returns:
            Task with option specified

        Raises:
            ValueError: If key is not registered and no fallback was given.
        """
        if key is None:
            key = self.DEFAULT_KEY
            logger.warning(f'No task type tag found, assuming type: {key}')
        if key not in self._builders:
            if not fallback:
                raise ValueError(f"The TaskFactory could not create an task of type {key}")
            return fallback()
        task_spec: TaskSpecification = self._builders[key]
        return task_spec.get(kwargs)


# Shared factory instance, created when this module is imported.
task_factory = TaskFactory()