"""
Base classes for laser-measles components and models.
This module contains the base classes for laser-measles components and models.
The BaseComponent class is the base class for all laser-measles components.
It provides a uniform interface for all components with a __call__(model, tick) method
for execution during simulation loops.
The BaseLaserModel class is the base class for all laser-measles models.
"""
from __future__ import annotations
from abc import ABC
from abc import abstractmethod
from datetime import datetime
from datetime import timedelta
from typing import Any
from typing import Protocol
from typing import TypeVar
import alive_progress
import matplotlib.pyplot as plt
import polars as pl
from laser_core.laserframe import LaserFrame
from laser_core.random import seed as seed_prng
from matplotlib.backends.backend_pdf import PdfPages
from matplotlib.figure import Figure
from pydantic import BaseModel
from pydantic import ConfigDict
from pydantic import Field
from laser_measles.utils import StateArray
from laser_measles.utils import get_laserframe_properties
from laser_measles.utils import select_implementation
from laser_measles.wrapper import PrettyComponentsList
from laser_measles.wrapper import pretty_laserframe
class ParamsProtocol(Protocol):
    """
    Structural type for the parameter object a model works with.

    Any object exposing these attributes and read-only properties satisfies
    the protocol; no inheritance is required.
    """

    # Plain data attributes.
    seed: int
    start_time: str
    num_ticks: int
    verbose: bool
    show_progress: bool

    # Read-only, derived values.
    @property
    def time_step_days(self) -> int:
        """Number of days covered by one tick."""
        ...

    @property
    def states(self) -> list[str]:
        """Names of the model states."""
        ...
class BaseModelParams(BaseModel):
    """
    Base parameters for all laser-measles models.
    This class provides common parameters that are shared across all model types.
    Model-specific parameter classes should inherit from this class.
    """

    # Unknown keyword arguments are rejected instead of being silently ignored.
    model_config = ConfigDict(extra="forbid")

    seed: int = Field(
        description="Random seed",
        default=20250314,
    )
    start_time: str = Field(
        description="Initial start time of simulation in YYYY-MM format",
        default="2000-01",
    )
    num_ticks: int = Field(
        description="Number of time steps",
        default=365,
    )
    verbose: bool = Field(
        description="Whether to print verbose output",
        default=False,
    )
    show_progress: bool = Field(
        description="Whether to show progress bar during simulation",
        default=True,
    )
    use_numba: bool = Field(
        description="Whether to use numba acceleration when available",
        default=True,
    )

    @property
    def states(self) -> list[str]:
        """
        Names of the model states.

        Raises:
            NotImplementedError: Always, in this base class; concrete
                parameter classes are expected to override the property.
        """
        raise NotImplementedError("Subclasses must implement states")

    @property
    def time_step_days(self) -> int:
        """
        Length of one tick in days.

        Raises:
            NotImplementedError: Always, in this base class; concrete
                parameter classes are expected to override the property.
        """
        raise NotImplementedError("Subclasses must implement time_step_days")
@pretty_laserframe
class BasePatchLaserFrame(LaserFrame):
    """
    Patch-level LaserFrame that declares a ``states`` property.

    The class adds no behavior of its own; it only tells type checkers that
    a ``states`` attribute of type ``StateArray`` is present.
    """

    # StateArray with attribute access (S, E, I, R, etc.)
    states: StateArray
@pretty_laserframe
class BasePeopleLaserFrame(LaserFrame):
    """
    Base LaserFrame for agent ("people") collections.

    Besides the enhanced printing supplied by ``pretty_laserframe``, it offers
    a factory that builds a frame of a different capacity carrying the same
    property layout as an existing frame, which is what resizing needs.
    """

    @classmethod
    def create_with_capacity(cls, capacity: int, source_frame: BasePeopleLaserFrame, initial_count: int = -1) -> Any:
        """
        Build a frame of this class with ``capacity`` and the source's properties.

        Args:
            capacity: Capacity handed to the constructor of the new frame.
            source_frame: Frame whose property definitions are replicated.
            initial_count: Initial number of "active" agents, forwarded to the
                constructor unchanged. Defaults to -1, which the original
                documentation describes as "use the capacity".

        Returns:
            The newly constructed frame (an instance of ``cls``).
        """
        clone = cls(capacity=capacity, initial_count=initial_count)
        clone.copy_properties_from(source_frame)
        return clone

    def copy_properties_from(self, source_frame: BasePeopleLaserFrame) -> None:
        """
        Re-declare on this frame every property found on ``source_frame``.

        One-dimensional properties are added as scalar properties and
        two-dimensional ones as vector properties, each with the source's
        dtype. The default is taken from the source's first element (scalar)
        or first column (vector), falling back to 0 when the source is empty.

        Args:
            source_frame: Frame to read property definitions from.

        Raises:
            NotImplementedError: If a property has neither 1 nor 2 dimensions.
        """
        for property_name in get_laserframe_properties(source_frame):
            values = getattr(source_frame, property_name)
            rank = values.ndim

            # Guard first: only 1-D and 2-D layouts are supported.
            if rank not in (1, 2):
                raise NotImplementedError(f"Property {property_name} has {rank} dimensions, not supported")

            if rank == 1:
                scalar_default = values[0] if len(values) > 0 else 0
                self.add_scalar_property(property_name, dtype=values.dtype, default=scalar_default)
                continue

            # rank == 2: the first axis length is passed as the vector length.
            # NOTE(review): the default here is a 1-D slice (first column);
            # confirm that add_vector_property accepts an array default of
            # this shape for the target capacity.
            vector_default = values[:, 0] if values.shape[1] > 0 else 0
            self.add_vector_property(
                property_name,
                len(values),
                dtype=values.dtype,
                default=vector_default,
            )
# --- Model base class ---
class BaseLaserModel(ABC):
    """
    Base class for laser-measles simulation models.

    Provides common functionality for model initialization, component management,
    timing, metrics collection, and execution loops.

    Subclasses must implement ``__call__`` and ``_setup_components``. A subclass
    may set ``scenario_wrapper_class`` so that a raw polars DataFrame passed as
    ``scenario`` is wrapped automatically in ``__init__``.
    """

    ScenarioType = TypeVar("ScenarioType")
    ParamsType = TypeVar("ParamsType", bound=ParamsProtocol)

    # Set by subclasses; declared here for type checkers only (no value is bound).
    patches: BasePatchLaserFrame

    # Scenario wrapper used to auto-wrap raw DataFrames. It lives at class level
    # so that __init__ can read it before any instance attribute exists and so
    # that a subclass override is not overwritten by the base initializer.
    scenario_wrapper_class: type[BaseScenario] | None = None

    def __init__(self, scenario: pl.DataFrame | BaseScenario, params: BaseModelParams, name: str) -> None:
        """
        Initialize the model with common attributes.

        Args:
            scenario: Scenario data. A polars DataFrame is wrapped in
                ``scenario_wrapper_class`` when that attribute is not None.
            params: Model parameters (type varies by model).
            name: Model name.
        """
        self.tinit = datetime.now(tz=None)  # noqa: DTZ005
        if params.verbose:
            print(f"{self.tinit}: Creating the {name} model…")

        # Auto-wrap polars DataFrame in appropriate scenario class if needed
        wrapper = self.scenario_wrapper_class
        if wrapper is not None and isinstance(scenario, pl.DataFrame):
            scenario = wrapper(scenario)

        self.scenario: BaseScenario = scenario
        self.params: BaseModelParams = params
        self.name = name

        # Initialize random number generator; fall back to the creation
        # timestamp's microseconds when no seed is available.
        seed_value = getattr(params, "seed", None)
        if seed_value is None:
            seed_value = self.tinit.microsecond
        self.prng = seed_prng(seed_value)

        # Component management attributes
        self._components: list = []
        self.instances: list = []
        self.phases: list = []  # Called every tick

        # Metrics and timing
        self.metrics: list = []
        self._tstart: datetime | None = None
        self._tfinish: datetime | None = None

        # Time tracking
        self.start_time = datetime.strptime(self.params.start_time, "%Y-%m")  # noqa DTZ007
        self.current_date = self.start_time

    def __repr__(self) -> str:
        """
        Return a one-line representation of the model, showing key attributes.

        Only public attributes are listed: LaserFrames are summarized by their
        capacity, and other values are shown only when they are simple scalars.

        Returns:
            str: String representation of the model, including LaserFrame attributes.
        """
        simple_types = (int, float, str, bool, type(None))
        attrs = []
        for attr in dir(self):
            if attr.startswith("_"):
                continue
            value = getattr(self, attr)
            if isinstance(value, LaserFrame):
                attrs.append(f"{attr}=<LaserFrame capacity={getattr(value, 'capacity', None)}>")
            elif isinstance(value, simple_types):
                # Only show simple types to avoid clutter
                attrs.append(f"{attr}={value!r}")
        return f"<{self.__class__.__name__}({', '.join(attrs)})>"

    def __str__(self) -> str:
        """
        Return a multi-line description of the model, one public attribute per line.

        LaserFrames are rendered with their own ``__str__`` on a new line; other
        values are included only when they are simple scalars.
        """
        simple_types = (int, float, str, bool, type(None))
        attrs = {}
        for attr in dir(self):
            if attr.startswith("_"):
                continue
            value = getattr(self, attr)
            if isinstance(value, LaserFrame):
                attrs[attr] = "\n" + value.__str__()
            elif isinstance(value, simple_types):
                # Only show simple types to avoid clutter
                attrs[attr] = value.__str__()
        newline = "\n"
        return f"<{self.__class__.__name__}>:\n{newline.join([f'{k}: {v}' for k, v in attrs.items()])}>"

    @abstractmethod
    def __call__(self, model: BaseLaserModel, tick: int) -> None:
        """
        Hook for subclasses to update the model for a given tick.

        Args:
            model: The model instance.
            tick: The current time step or tick.
        """

    @property
    def components(self) -> PrettyComponentsList:
        """
        Retrieve the list of model components.

        Returns:
            A PrettyComponentsList containing the components with enhanced formatting.
        """
        return PrettyComponentsList(self._components)

    @components.setter
    def components(self, components: list[type[BaseComponent]]) -> None:
        """
        Set the model's components and construct one instance of each.

        Existing instances and phases are discarded. Callable instances are
        also registered as phases, and ``_setup_components`` runs at the end.

        Args:
            components: A list of component classes to be initialized and integrated into the model.
        """
        # Copy, so that later add_component/prepend_component calls do not
        # mutate the list object owned by the caller.
        self._components = list(components)
        self.instances = []
        self.phases = []
        for component in self._components:
            instance = self._create_instance(component)
            self.instances.append(instance)
            if callable(instance):
                self.phases.append(instance)
        # Allow subclasses to perform additional component setup
        self._setup_components()

    def _create_instance(self, component: type[BaseComponent]) -> Any:
        """Instantiate ``component`` for this model, forwarding the verbose flag."""
        return component(self, verbose=getattr(self.params, "verbose", False))

    def add_component(self, component: type[BaseComponent]) -> None:
        """
        Append a component class and a new instance of it to the model.

        Note that this does not create new instances of other components.

        Args:
            component: A component class to be initialized and integrated into the model.
        """
        self._components.append(component)
        instance = self._create_instance(component)
        self.instances.append(instance)
        if callable(instance):
            self.phases.append(instance)
        self._setup_components()

    def prepend_component(self, component: type[BaseComponent]) -> None:
        """
        Add a component to the beginning of the component list.

        The new instance is placed first in ``instances`` and, when callable,
        first in ``phases`` as well.

        Args:
            component: A component class to be initialized and integrated into the model.
        """
        self._components.insert(0, component)
        instance = self._create_instance(component)
        self.instances.insert(0, instance)
        if callable(instance):
            self.phases.insert(0, instance)
        self._setup_components()

    def run(self) -> None:
        """
        Execute the model for ``params.num_ticks`` ticks, recording timing metrics.

        Component instances are initialized first and ``metrics`` is reset.
        ``current_date`` is not reset here, so repeated calls keep advancing it.

        Raises:
            RuntimeError: If no components have been added to the model.
        """
        # Check that there are some components to the model
        if not self._components:
            raise RuntimeError("No components have been added to the model")

        # Initialize all component instances
        self._initialize()
        # TODO: Check that the model has been initialized
        num_ticks = self.params.num_ticks
        self._tstart = datetime.now(tz=None)  # noqa: DTZ005
        if self.params.verbose:
            print(f"{self._tstart}: Running the {self.name} model for {num_ticks} ticks…")
        self.metrics = []

        # Create progress bar only if show_progress is True
        if self.params.show_progress:
            with alive_progress.alive_bar(num_ticks) as bar:
                for tick in range(num_ticks):
                    self._execute_tick(tick)
                    bar()
        else:
            for tick in range(num_ticks):
                self._execute_tick(tick)

        self._tfinish = datetime.now(tz=None)  # noqa: DTZ005
        if self.params.verbose:
            print(f"Completed the {self.name} model at {self._tfinish}…")
            self._print_timing_summary()

    def time_elapsed(self, units: str = "days") -> int | float:
        """
        Return time elapsed since the start of the model.

        Args:
            units: Time units to return, case-insensitive. Supports "days"
                and "ticks".

        Returns:
            Whole days between ``start_time`` and ``current_date``, or that
            number divided by ``params.time_step_days`` for "ticks".

        Raises:
            ValueError: If invalid time units are specified.
        """
        unit = units.lower()
        if unit not in ("days", "ticks"):
            raise ValueError(f"Invalid time units: {units}")
        elapsed_days = (self.current_date - self.start_time).days
        if unit == "days":
            return elapsed_days
        return elapsed_days / self.params.time_step_days

    def _initialize(self) -> None:
        """
        Initialize all component instances in the model.

        Calls ``_initialize(self)`` on every instance that has both an
        ``_initialize`` method and an ``initialized`` attribute, then sets the
        instance's ``initialized`` flag to True.
        """
        for instance in self.instances:
            if hasattr(instance, "_initialize") and hasattr(instance, "initialized"):
                instance._initialize(self)
                instance.initialized = True

    def get_tick_date(self, tick: int) -> datetime:
        """
        Return the date for a given tick.

        Args:
            tick: Tick index counted from ``start_time``.

        Returns:
            ``start_time`` advanced by ``tick * params.time_step_days`` days.
        """
        return self.start_time + timedelta(days=tick * self.params.time_step_days)

    @staticmethod
    def _release_laserframe(frame: Any) -> None:
        """Drop every property of a LaserFrame-like object and zero its capacity/count."""
        if hasattr(frame, "_properties"):
            for prop_name in list(frame._properties.keys()):
                setattr(frame, prop_name, None)
            frame._properties.clear()
        # Reset LaserFrame capacity and count
        if hasattr(frame, "_capacity"):
            frame._capacity = 0
        if hasattr(frame, "_count"):
            frame._count = 0

    def cleanup(self) -> None:
        """
        Clean up model resources to prevent memory leaks.

        This method should be called when the model is no longer needed
        to free up memory from LaserFrame objects and other large data structures.
        It is best-effort: any exception is reported as a warning on stdout
        instead of being propagated. After it runs, ``scenario``, ``params`` and
        ``prng`` no longer exist on the model.
        """
        try:
            # Clear LaserFrame objects
            for frame_name in ("patches", "people"):
                frame = getattr(self, frame_name, None)
                if frame is not None:
                    self._release_laserframe(frame)
                    setattr(self, frame_name, None)

            # Clear component instances and their references
            if hasattr(self, "instances"):
                for instance in self.instances:
                    # Break the back-reference from the component to this model
                    if hasattr(instance, "model"):
                        instance.model = None
                    # Clear any sized, non-callable public attribute (arrays,
                    # lists, strings, ...) except the two bookkeeping flags.
                    for attr_name in dir(instance):
                        if attr_name.startswith("_") or attr_name in ["initialized", "verbose"]:
                            continue
                        attr_value = getattr(instance, attr_name, None)
                        if hasattr(attr_value, "__len__") and not callable(attr_value):
                            try:
                                setattr(instance, attr_name, None)
                            except (AttributeError, TypeError):
                                pass  # Skip if attribute is read-only
                self.instances.clear()

            # Clear phases and components
            if hasattr(self, "phases"):
                self.phases.clear()
            if hasattr(self, "_components"):
                self._components.clear()

            # Clear metrics and other large data structures
            if hasattr(self, "metrics"):
                self.metrics.clear()

            # Drop scenario, params and random number generator references
            for attr_name in ("scenario", "params", "prng"):
                if hasattr(self, attr_name):
                    delattr(self, attr_name)
        except Exception as e:
            # Don't let cleanup errors crash the program
            print(f"Warning: Error during model cleanup: {e}")

    def get_instance(self, cls: type | str) -> list:
        """
        Get all instances of a specific component class.

        Args:
            cls: The component class to search for (matched with
                ``isinstance``, so subclasses match parent class searches), or
                a string compared against each instance's ``name`` attribute.
                Instances without a ``name`` never match a string.

        Returns:
            List of matching instances, or ``[None]`` if none found. Because
            the "not found" result is a non-empty list, test its first element
            rather than the truthiness of the list.

        Example:
            state_tracker = model.get_instance(StateTracker)[0]
            if state_tracker is not None:
                ...
        """
        if isinstance(cls, str):
            matches = [instance for instance in self.instances if getattr(instance, "name", None) == cls]
        else:
            matches = [instance for instance in self.instances if isinstance(instance, cls)]
        return matches if matches else [None]

    def get_component(self, cls: type | str) -> list:
        """
        Alias for get_instance (instances are instantiated, components are not).

        Args:
            cls: The component class (or instance name) to search for.

        Returns:
            List of instances of the specified class, or [None] if none found.
        """
        return self.get_instance(cls)

    def visualize(self, pdf: bool = True) -> None:
        """
        Visualize each component instance, either on screen or in a PDF file.

        Args:
            pdf: If True, save the plots to a PDF file named after the model
                and the time of the last ``run()`` (or the model creation time
                when ``run()`` has not been called). If False, display the
                plots interactively. Defaults to True.

        Returns:
            None
        """
        if not pdf:
            for instance in self.instances:
                for _plot in instance.plot():
                    plt.show()
            return

        print("Generating PDF output…")
        # _tstart is None until run() has been called; formatting None with a
        # strftime spec raises TypeError, so fall back to the creation time.
        stamp = self._tstart if self._tstart is not None else self.tinit
        pdf_filename = f"{self.name} {stamp:%Y-%m-%d %H%M%S}.pdf"
        with PdfPages(pdf_filename) as pdf_file:
            for instance in self.instances:
                for _plot in instance.plot():
                    pdf_file.savefig()
                    plt.close()
        print(f"PDF output saved to '{pdf_filename}'.")

    def plot(self, fig: Figure | None = None):
        """
        Placeholder for plotting method.

        Args:
            fig: Optional matplotlib figure to plot on.

        Raises:
            NotImplementedError: Subclasses must implement this method.
        """
        raise NotImplementedError("Subclasses must implement this method")

    def _execute_tick(self, tick: int) -> None:
        """
        Execute a single tick.

        Every phase is called as ``phase(self, tick)`` and timed. A row
        ``[tick, <microseconds per phase>...]`` is appended to ``metrics`` and
        ``current_date`` advances by ``params.time_step_days`` days.
        Can be overridden by subclasses for custom behavior.

        Args:
            tick: The current tick number.
        """
        one_microsecond = timedelta(microseconds=1)
        timing = [tick]
        for phase in self.phases:
            tstart = datetime.now(tz=None)  # noqa: DTZ005
            phase(self, tick)
            tfinish = datetime.now(tz=None)  # noqa: DTZ005
            # Floor-dividing by one microsecond counts days as well, which a
            # manual seconds/microseconds sum would drop.
            timing.append((tfinish - tstart) // one_microsecond)
        self.metrics.append(timing)
        # Update current date by time_step_days
        self.current_date += timedelta(days=self.params.time_step_days)

    def _print_timing_summary(self) -> None:
        """
        Print the total time spent in each phase, plus a grand total.

        Totals are summed from ``metrics`` in pure Python, so no dataframe
        library is needed. Nothing is printed when there are no phases or no
        recorded ticks.
        """
        names = [type(phase).__name__ for phase in self.phases]
        if not names or not self.metrics:
            return

        # Each metrics row is [tick, t_phase0, t_phase1, ...]; transpose the
        # rows and drop the leading tick column to get per-phase series.
        per_phase = list(zip(*self.metrics))[1:]
        totals = [sum(series) for series in per_phase]

        width = max(map(len, names))
        for key, total in zip(names, totals):
            print(f"{key:{width}}: {total:13,} µs")
        print("=" * (width + 2 + 13 + 3))
        print(f"{'Total:':{width + 1}} {sum(totals):13,} microseconds")

    @abstractmethod
    def _setup_components(self) -> None:
        """
        Hook for subclasses to perform additional component setup.

        Called after the component list changes (setter, add, prepend).
        """
# --- Component base classes ---
class BaseComponent:
    """
    Base class for all laser-measles components.
    Components follow a uniform interface with __call__(model, tick) method
    for execution during simulation loops.
    """

    # NOTE: the class docstring above is returned by __str__ at runtime.

    ModelType = TypeVar("ModelType")

    def __init__(self, model: BaseLaserModel, verbose: bool = False, params: None = None) -> None:  # TODO: add ParamsType
        """
        Store the owning model and the component's configuration.

        Args:
            model: The model instance this component belongs to.
            verbose: Whether to enable verbose output. Defaults to False.
            params: Optional component parameters, stored as given.
        """
        self.model = model
        self.params = params
        self.verbose = verbose
        # Flipped to True by the model once _initialize() has run.
        self.initialized = False
        # Respect a ``name`` already provided (e.g. by a subclass); otherwise
        # fall back to the class name.
        if not hasattr(self, "name"):
            self.name = self.__class__.__name__

    def __str__(self) -> str:
        """
        Describe the component using a class docstring.

        Returns:
            The stripped docstring of the concrete class, else that of
            ``BaseComponent``, else ``"<ClassName> component"``.
        """
        for candidate in (self.__class__.__doc__, BaseComponent.__doc__):
            if candidate:
                return candidate.strip()
        return f"{self.__class__.__name__} component"

    def select_function(self, numpy_func: Any, numba_func: Any) -> Any:
        """
        Choose between a numpy and a numba implementation.

        The decision is delegated to ``select_implementation``, which receives
        the model's ``use_numba`` parameter (treated as True when the model's
        params do not define it).

        Args:
            numpy_func: The numpy implementation function.
            numba_func: The numba implementation function.

        Returns:
            The selected function implementation.

        Example:
            >>> # In a component's __init__ or _initialize method:
            >>> self.update_func = self.select_function(numpy_update, numba_update)
        """
        prefer_numba = getattr(self.model.params, "use_numba", True)
        return select_implementation(numpy_func, numba_func, prefer_numba)

    def plot(self, fig: Figure | None = None):
        """
        Placeholder plotting generator.

        Args:
            fig: Optional matplotlib figure to plot on (unused here).

        Yields:
            None: A single placeholder value.
        """
        yield None

    # NOTE: this class has no ABC metaclass, so @abstractmethod does not block
    # instantiation here; the default implementation simply does nothing.
    @abstractmethod
    def _initialize(self, model: BaseLaserModel) -> None:
        """
        Hook for subclasses to initialize the component based on other existing components.

        This is run at the beginning of model.run().

        Args:
            model: The model instance.
        """
class BasePhase(BaseComponent):
    """
    Base class for all laser-measles phases.
    Phases are components that are called every tick and include a __call__ method.
    """

    # NOTE: @abstractmethod is declarative here; no ABC metaclass is involved
    # in the visible class hierarchy, so it does not prevent instantiation.

    @abstractmethod
    def _initialize(self, model: BaseLaserModel) -> None:
        """
        Initialization hook; does nothing by default.

        Args:
            model: The model instance.
        """
        return None

    @abstractmethod
    def __call__(self, model, tick: int) -> None:
        """
        Run the phase's logic for one simulation tick.

        Args:
            model: The model instance.
            tick: The current simulation tick.
        """
class BaseScenario(ABC):
    """
    Base class for scenario data wrappers.

    Wraps a polars DataFrame and forwards attribute access, item access,
    ``len()`` and ``repr()`` to it. Subclasses must implement ``_validate``;
    this base class does not call it on construction.
    """

    def __init__(self, df: pl.DataFrame):
        """
        Initialize the scenario with a DataFrame.

        Args:
            df: The polars DataFrame containing scenario data.
        """
        self._df = df

    def __getattr__(self, attr):
        """
        Forward attribute access to the underlying DataFrame.

        Python calls this only after normal attribute lookup has failed.

        Args:
            attr: The attribute name.

        Returns:
            The attribute value from the underlying DataFrame.

        Raises:
            AttributeError: If the wrapped DataFrame has no such attribute, or
                if the wrapper itself has not been given a DataFrame yet.
        """
        # If "_df" itself is missing (instance created through __new__, as
        # copy and pickle do), forwarding through self._df would re-enter this
        # method forever. Fail fast with a normal AttributeError instead.
        if attr == "_df":
            raise AttributeError(f"{type(self).__name__!r} object has no attribute {attr!r}")
        return getattr(self._df, attr)

    def __getitem__(self, key):
        """
        Forward item access to the underlying DataFrame.

        Args:
            key: The key to access.

        Returns:
            The value from the underlying DataFrame.
        """
        return self._df[key]

    def __repr__(self):
        """
        Return string representation of the scenario.

        Returns:
            String representation of the underlying DataFrame.
        """
        return repr(self._df)

    def __len__(self):
        """
        Return the length of the underlying DataFrame.

        Returns:
            ``len()`` of the wrapped DataFrame.
        """
        return len(self._df)

    def unwrap(self) -> pl.DataFrame:
        """
        Return the underlying polars DataFrame.

        Returns:
            The same DataFrame object that was passed to the constructor.
        """
        return self._df

    def find_row_number(self, column: str, target_value: str) -> int:
        """
        Find the row number (0-based index) of a target string in a DataFrame column.

        Args:
            column: Column name to search in.
            target_value: String value to find.

        Returns:
            Row number (0-based index) of the first row equal to the target.

        Raises:
            ValueError: If the target string is not found.
        """
        # Boolean mask of the rows equal to the target.
        mask = self._df[column] == target_value

        # Check if value exists
        if not mask.any():
            raise ValueError(f"String '{target_value}' not found in column '{column}'")

        # arg_max returns the index of the first True value
        result = mask.arg_max()
        # Defensive: arg_max may report "no result" as None.
        if result is None:
            raise ValueError(f"String '{target_value}' not found in column '{column}'")
        return result

    @abstractmethod
    def _validate(self, df: pl.DataFrame):
        """
        Validate required columns exist - derive from schema.

        Args:
            df: The DataFrame to validate.

        Raises:
            NotImplementedError: Subclasses must implement this method.
        """
        # Validate required columns exist - derive from schema
        raise NotImplementedError("Subclasses must implement this method")