# Source code for laser_measles.biweekly.model
"""
A class to represent the biweekly model.
"""
import numpy as np
import polars as pl
from laser_measles.base import BaseLaserModel
from laser_measles.biweekly.base import BaseBiweeklyScenario
from laser_measles.biweekly.base import PatchLaserFrame
from laser_measles.biweekly.params import BiweeklyParams
from laser_measles.utils import StateArray
from laser_measles.utils import cast_type
# [docs]  (documentation-viewer link marker; not part of the module code)
class BiweeklyModel(BaseLaserModel):
    """
    A class to represent the biweekly model.

    Per-patch compartment counts are stored in ``self.patches.states``, a
    ``StateArray`` with one named row per entry of ``params.states``.

    Args:
        scenario (BaseBiweeklyScenario | pl.DataFrame): The scenario data, including population, latitude, and longitude.
        params (BiweeklyParams): A set of parameters for the model.
        name (str, optional): The name of the model. Defaults to "biweekly".

    Notes:
        This class initializes the model with the given scenario and parameters. The scenario must include the following columns:
            - `id` (string): The name of the patch or location.
            - `pop` (integer): The population count for the patch.
            - `lat` (float degrees): The latitude of the patches (e.g., from geographic or population centroid).
            - `lon` (float degrees): The longitude of the patches (e.g., from geographic or population centroid).
            - `mcv1` (float): The MCV1 coverage for the patches.
    """

    patches: PatchLaserFrame

    # Specify the scenario wrapper class for auto-wrapping DataFrames
    scenario_wrapper_class = BaseBiweeklyScenario

    def __init__(self, scenario: BaseBiweeklyScenario | pl.DataFrame, params: BiweeklyParams, name: str = "biweekly") -> None:
        """
        Initialize the model and allocate the per-patch state array.

        Every patch starts with its whole population in the ``S`` compartment.

        Args:
            scenario (BaseBiweeklyScenario | pl.DataFrame): The scenario data; ``len(scenario)`` gives the
                number of patches and ``scenario["pop"]`` the initial population of each patch.
            params (BiweeklyParams): Model parameters. ``self.params.states`` (presumably set by the base
                initializer from this argument -- TODO confirm) supplies the compartment names.
            name (str, optional): The name of the model. Defaults to "biweekly".
        """
        super().__init__(scenario, params, name)

        # One frame slot per scenario row (patch).
        self.patches = PatchLaserFrame(capacity=len(scenario))

        # One state row per compartment named in params.states
        # (presumably S, I, R, giving shape (3, num_patches) -- TODO confirm).
        self.patches.add_vector_property("states", len(self.params.states))

        # Wrap the raw array so compartments are reachable by name (states.S, states.I, ...).
        self.patches.states = StateArray(self.patches.states, state_names=self.params.states)

        # Start with a totally susceptible population.
        self.patches.states.S[:] = scenario["pop"]

    def __call__(self, model, tick: int) -> None:
        """
        No-op: the model itself performs no work when called for a tick.

        Args:
            model: The model containing the patches and their populations (unused).
            tick (int): The current time step or tick (unused).

        Returns:
            None
        """
        return

    def infect(self, indices: int | np.ndarray, num_infected: int | np.ndarray) -> None:
        """
        Move individuals from the Susceptible to the Infected compartment.

        Args:
            indices (int | np.ndarray): The indices of the patches to infect.
            num_infected (int | np.ndarray): The number of individuals to move from S to I
                in each selected patch.

        Notes:
            NOTE(review): there is no check that ``num_infected`` does not exceed the current
            ``S`` count -- confirm callers guarantee this.
            NOTE(review): if ``states.I`` / ``states.S`` are plain NumPy views, repeated entries in
            ``indices`` are applied only once by ``+=`` / ``-=`` -- confirm callers pass unique indices.
        """
        states = self.patches.states
        # Cast once so both compartments are adjusted by exactly the same amount.
        delta = cast_type(num_infected, states.dtype)
        states.I[indices] += delta
        states.S[indices] -= delta

    def recover(self, indices: int | np.ndarray, num_recovered: int | np.ndarray) -> None:
        """
        Move individuals from the Infected to the Recovered compartment.

        Args:
            indices (int | np.ndarray): The indices of the patches to recover.
            num_recovered (int | np.ndarray): The number of individuals to move from I to R
                in each selected patch.

        Notes:
            NOTE(review): there is no check that ``num_recovered`` does not exceed the current
            ``I`` count -- confirm callers guarantee this.
            NOTE(review): if ``states.R`` / ``states.I`` are plain NumPy views, repeated entries in
            ``indices`` are applied only once by ``+=`` / ``-=`` -- confirm callers pass unique indices.
        """
        states = self.patches.states
        # Cast once so both compartments are adjusted by exactly the same amount.
        delta = cast_type(num_recovered, states.dtype)
        states.R[indices] += delta  # Add to R
        states.I[indices] -= delta  # Remove from I

    def _setup_components(self) -> None:
        """Intentionally does nothing: this class performs no component setup of its own."""
# Alias: ``Model`` is bound to the same class object as ``BiweeklyModel``
# (no subclass or copy is created), so the two names are interchangeable.
Model = BiweeklyModel