# Source code for laser_measles.biweekly.model

"""
A class to represent the biweekly model.
"""

import numpy as np
import polars as pl

from laser_measles.base import BaseLaserModel
from laser_measles.biweekly.base import BaseBiweeklyScenario
from laser_measles.biweekly.base import PatchLaserFrame
from laser_measles.biweekly.params import BiweeklyParams
from laser_measles.utils import StateArray
from laser_measles.utils import cast_type


class BiweeklyModel(BaseLaserModel):
    """
    A class to represent the biweekly model.

    Args:
        scenario (BaseBiweeklyScenario | pl.DataFrame): The scenario data,
            including population, latitude, and longitude.
        params (BiweeklyParams): A set of parameters for the model.
        name (str, optional): The name of the model. Defaults to "biweekly".

    Notes:
        This class initializes the model with the given scenario and parameters.
        The scenario must include the following columns:

        - `id` (string): The name of the patch or location.
        - `pop` (integer): The population count for the patch.
        - `lat` (float degrees): The latitude of the patches (e.g., from geographic or population centroid).
        - `lon` (float degrees): The longitude of the patches (e.g., from geographic or population centroid).
        - `mcv1` (float): The MCV1 coverage for the patches.
    """

    patches: PatchLaserFrame

    # Specify the scenario wrapper class for auto-wrapping DataFrames
    scenario_wrapper_class = BaseBiweeklyScenario

    def __init__(
        self,
        scenario: BaseBiweeklyScenario | pl.DataFrame,
        params: BiweeklyParams,
        name: str = "biweekly",
    ) -> None:
        """
        Initialize the disease model with the given scenario and parameters.

        Args:
            scenario (BaseBiweeklyScenario | pl.DataFrame): The scenario data,
                including population, latitude, and longitude.
            params (BiweeklyParams): A set of parameters for the model, including
                seed, nticks, k, a, b, c, max_frac, cbr, verbose, and pyramid_file.
            name (str, optional): The name of the model. Defaults to "biweekly".
        """
        super().__init__(scenario, params, name)

        # Add patches to the model: one slot per scenario row
        self.patches = PatchLaserFrame(capacity=len(scenario))

        # Create the state vector for each patch, one entry per name in
        # params.states (S, I, R)
        self.patches.add_vector_property("states", len(self.params.states))

        # Wrap the states array with StateArray for attribute access (.S, .I, .R)
        self.patches.states = StateArray(self.patches.states, state_names=self.params.states)

        # Start with a totally susceptible population
        self.patches.states.S[:] = scenario["pop"]

    def __call__(self, model, tick: int) -> None:
        """
        Updates the model for the next tick.

        The model itself performs no per-tick work here; this is a no-op.

        Args:
            model: The model containing the patches and their populations.
            tick (int): The current time step or tick.
        """

    def infect(self, indices: int | np.ndarray, num_infected: int | np.ndarray) -> None:
        """
        Infects the given nodes with the given number of infected individuals.

        Moves individuals from the Susceptible to the Infected compartment.
        No bounds checking is performed on the resulting counts.

        Args:
            indices (int | np.ndarray): The indices of the nodes to infect.
            num_infected (int | np.ndarray): The number of individuals to infect.
        """
        states = self.patches.states
        # Convert once so exactly the same amount is added to I and removed from S.
        moved = cast_type(num_infected, states.dtype)
        # NOTE(review): assuming StateArray follows NumPy indexing semantics, repeated
        # entries in `indices` are applied only once by these in-place updates.
        states.I[indices] += moved
        states.S[indices] -= moved

    def recover(self, indices: int | np.ndarray, num_recovered: int | np.ndarray) -> None:
        """
        Recovers the given nodes with the given number of recovered individuals.

        Moves individuals from the Infected to the Recovered compartment.
        No bounds checking is performed on the resulting counts.

        Args:
            indices (int | np.ndarray): The indices of the nodes to recover.
            num_recovered (int | np.ndarray): The number of recovered individuals.
        """
        states = self.patches.states
        # Convert once so exactly the same amount is added to R and removed from I.
        moved = cast_type(num_recovered, states.dtype)
        # NOTE(review): assuming StateArray follows NumPy indexing semantics, repeated
        # entries in `indices` are applied only once by these in-place updates.
        states.R[indices] += moved  # Add to R
        states.I[indices] -= moved  # Remove from I

    def _setup_components(self) -> None:
        """Set up model components; intentionally does nothing for this model."""
# Create an alias for BiweeklyModel as Model
Model = BiweeklyModel