# Source code for laser_measles.abm.components.process_disease

"""
Component defining the DiseaseProcess, which simulates the disease progression in the ABM model with MCV1.
"""

import numpy as np
from pydantic import BaseModel
from pydantic import Field

from laser_measles.abm.model import ABMModel
from laser_measles.base import BaseComponent

# Import numba conditionally for the numba implementation
try:
    import numba as nb

    NUMBA_AVAILABLE = True
    NUM_THREADS = nb.get_num_threads()
except ImportError:
    NUMBA_AVAILABLE = False


# Numpy Implementation
def numpy_gamma_update(count, timers_0, timers_1, state, shape, scale, flow, patch_id):
    """Count down exposure timers and move agents whose timer expires from E to I.

    Operates in place on the first ``count`` agents:

    * every positive entry of ``timers_0`` (exposure timer) is decremented by one;
    * agents whose exposure timer thereby reaches zero receive an infectious
      timer in ``timers_1`` drawn from ``Gamma(shape, scale)``, rounded and
      clamped to at least 1, and have ``state`` set to 2;
    * the number of such transitions in each patch (indexed by ``patch_id``)
      is added to ``flow``.
    """
    (exposed,) = np.nonzero(timers_0[:count] > 0)
    if exposed.size == 0:
        return

    timers_0[exposed] -= 1

    # Agents whose exposure timer has just run out become infectious.
    newly_infectious = exposed[timers_0[exposed] <= 0]
    if newly_infectious.size == 0:
        return

    n_new = len(newly_infectious)
    durations = np.round(np.random.gamma(shape, scale, n_new))
    timers_1[newly_infectious] = np.maximum(1, durations).astype(np.uint16)
    state[newly_infectious] = 2

    # minlength keeps the tally aligned with flow even for patches with no transitions.
    per_patch = np.bincount(patch_id[newly_infectious], minlength=len(flow))
    flow += per_patch.astype(np.uint32)


# Numba Implementation (if available)
if NUMBA_AVAILABLE:

    @nb.njit(
        (nb.uint32, nb.uint16[:], nb.uint16[:], nb.uint8[:], nb.float32, nb.float32, nb.uint32[:], nb.uint16[:]),
        parallel=True,
        cache=True,
    )
    def nb_gamma_update(count, timers_0, timers_1, state, shape, scale, flow, patch_id):  # pragma: no cover
        """Numba compiled function to check and update exposed timers for the population in parallel.

        For each of the first ``count`` agents with a positive exposure timer
        (``timers_0``), decrement it. Agents whose timer reaches zero get an
        infectious timer (``timers_1``) drawn from ``Gamma(shape, scale)``,
        rounded and clamped to at least 1, and have ``state`` set to 2. The
        number of transitions per patch (indexed by ``patch_id``) is added to
        ``flow`` in place.
        """
        # Size the per-patch tally by len(flow), not max(patch_id) + 1: when the
        # highest-numbered patches hold no agents, the latter is smaller than
        # flow and the final in-place add either fails on a shape mismatch or
        # silently broadcasts. This mirrors minlength=len(flow) in the numpy path.
        num_patches = flow.shape[0]
        # One row per thread so parallel iterations never race on a counter.
        # NOTE(review): NUM_THREADS is a module global frozen at compile time;
        # this assumes nb.get_thread_id() always stays below it — confirm when
        # the thread count can differ between runs sharing the cache.
        thread_flow = np.zeros((NUM_THREADS, num_patches), dtype=np.uint32)

        for i in nb.prange(count):
            timer_0 = timers_0[i]
            if timer_0 > 0:
                timer_0 -= 1
                # Decremented from >0 to 0: draw the infectious timer (E -> I).
                if timer_0 <= 0:
                    timers_1[i] = np.maximum(np.uint16(1), np.uint16(np.round(np.random.gamma(shape, scale))))
                    thread_flow[nb.get_thread_id(), patch_id[i]] += 1
                    state[i] = 2
                timers_0[i] = timer_0
        flow[:] += thread_flow.sum(axis=0)
        return
else:
    nb_gamma_update = None


# Numpy Implementation
def numpy_state_update(count, timers, state, new_state, flow, patch_id):
    """Count down timers and move agents whose timer expires into ``new_state``.

    Operates in place on the first ``count`` agents:

    * every positive entry of ``timers`` is decremented by one;
    * agents whose timer thereby reaches exactly zero have ``state`` set to
      ``new_state``;
    * the number of such transitions in each patch (indexed by ``patch_id``)
      is added to ``flow``.
    """
    (running,) = np.nonzero(timers[:count] > 0)
    if running.size == 0:
        return

    timers[running] -= 1

    expired = running[timers[running] == 0]
    if expired.size == 0:
        return

    state[expired] = new_state

    # minlength keeps the tally aligned with flow even for patches with no transitions.
    per_patch = np.bincount(patch_id[expired], minlength=len(flow))
    flow += per_patch.astype(np.uint32)


# Numba Implementation (if available)
if NUMBA_AVAILABLE:

    @nb.njit((nb.uint32, nb.uint16[:], nb.uint8[:], nb.uint8, nb.uint32[:], nb.uint16[:]), parallel=True, cache=True)
    def nb_state_update(count, timers, state, new_state, flow, patch_id):  # pragma: no cover
        """Numba compiled function to check and update infection timers for the population in parallel.

        For each of the first ``count`` agents with a positive timer, decrement
        it. Agents whose timer reaches exactly zero have ``state`` set to
        ``new_state``. The number of transitions per patch (indexed by
        ``patch_id``) is added to ``flow`` in place.
        """
        # Size the per-patch tally by len(flow), not max(patch_id) + 1: when the
        # highest-numbered patches hold no agents, the latter is smaller than
        # flow and the final in-place add either fails on a shape mismatch or
        # silently broadcasts. This mirrors minlength=len(flow) in the numpy path.
        num_patches = flow.shape[0]
        # One row per thread so parallel iterations never race on a counter.
        # NOTE(review): NUM_THREADS is a module global frozen at compile time;
        # this assumes nb.get_thread_id() always stays below it — confirm when
        # the thread count can differ between runs sharing the cache.
        thread_flow = np.zeros((NUM_THREADS, num_patches), dtype=np.uint32)
        for i in nb.prange(count):
            timer = timers[i]
            if timer > 0:
                timer -= 1
                if timer == 0:
                    thread_flow[nb.get_thread_id(), patch_id[i]] += 1
                    state[i] = new_state
                timers[i] = timer
        flow[:] += thread_flow.sum(axis=0)
        return
else:
    nb_state_update = None


class DiseaseParams(BaseModel):
    """Parameters of the gamma-distributed infectious period used by DiseaseProcess.

    The infectious period is described by its mean (``inf_mean``) and standard
    deviation (``inf_sigma``). ``inf_shape`` and ``inf_scale`` convert these
    moments to the (shape, scale) parameterization, using the gamma identities
    mean = shape * scale and variance = shape * scale**2.
    """

    # Both values must be strictly positive: inf_shape divides by inf_sigma and
    # inf_scale divides by inf_mean.
    inf_mean: float = Field(default=8.0, description="Mean infectious period (days)", gt=0.0)
    inf_sigma: float = Field(default=2.0, description="Standard deviation of the infectious period (days)", gt=0.0)

    @property
    def inf_shape(self) -> float:
        """Gamma shape parameter, ``(inf_mean / inf_sigma) ** 2``."""
        return (self.inf_mean / self.inf_sigma) ** 2

    @property
    def inf_scale(self) -> float:
        """Gamma scale parameter in days, ``inf_sigma ** 2 / inf_mean``."""
        return self.inf_sigma**2 / self.inf_mean


class DiseaseProcess(BaseComponent):
    """Advance agents through disease progression (E -> I -> R) once per tick.

    Each call counts down infectious timers first (I -> R) and exposure timers
    second (E -> I). It applies the resulting per-patch flows to
    ``model.patches.states`` so the compartment counts track the agent-level
    transitions.
    """

    def __init__(self, model, verbose: bool = False, params: DiseaseParams | None = None):
        super().__init__(model, verbose)
        # Fall back to the default infectious-period parameters.
        if params is None:
            params = DiseaseParams()
        self.params = params

    def __call__(self, model, tick: int) -> None:
        people = model.people
        patches = model.patches
        n_patches = len(model.patches)
        recovered_state = np.uint8(model.params.states.index("R"))

        # I -> R: agents whose infectious timer runs out recover.
        recovered_flow = np.zeros(n_patches, dtype=np.uint32)
        self.state_update_func(
            people.count,
            people.itimer,
            people.state,
            recovered_state,
            recovered_flow,
            people.patch_id,
        )
        patches.states.I -= recovered_flow
        patches.states.R += recovered_flow

        # E -> I: agents whose exposure timer runs out become infectious, with an
        # infectious timer drawn from a gamma distribution. Because this runs
        # after the I -> R step, the new infectious timers are not counted down
        # until the next call.
        infectious_flow = np.zeros(n_patches, dtype=np.uint32)
        self.gamma_update_func(
            people.count,
            people.etimer,
            people.itimer,
            people.state,
            self.params.inf_shape,
            self.params.inf_scale,
            infectious_flow,
            people.patch_id,
        )
        patches.states.E -= infectious_flow
        patches.states.I += infectious_flow

    def _initialize(self, model: ABMModel) -> None:
        # Pick the numpy or numba kernels according to the model configuration.
        self.state_update_func = self.select_function(numpy_state_update, nb_state_update)
        self.gamma_update_func = self.select_function(numpy_gamma_update, nb_gamma_update)