"""
Component defining the DiseaseProcess, which simulates the disease progression in the ABM model with MCV1.
"""
import numpy as np
from pydantic import BaseModel
from pydantic import Field
from laser_measles.abm.model import ABMModel
from laser_measles.base import BaseComponent
# Import numba conditionally for the numba implementation.
# NUMBA_AVAILABLE gates definition of the nb_* kernels below; NUM_THREADS sizes
# their per-thread tally arrays.
try:
    import numba as nb

    NUMBA_AVAILABLE = True
    NUM_THREADS = nb.get_num_threads()
except ImportError:
    NUMBA_AVAILABLE = False
    # Keep NUM_THREADS defined so the module exposes the same attributes
    # whether or not numba is installed (the numba kernels that read it are
    # not defined in this branch).
    NUM_THREADS = 1
# Numpy Implementation
def numpy_gamma_update(count, timers_0, timers_1, state, shape, scale, flow, patch_id):
    """Advance exposure timers (E -> I) for the first ``count`` agents using NumPy.

    Every agent among the first ``count`` whose ``timers_0`` entry is positive
    has it decremented by one. Agents whose timer thereby reaches zero:

    * get ``timers_1`` set to a gamma(``shape``, ``scale``) draw, rounded to the
      nearest integer and floored at 1, stored as ``uint16``;
    * have ``state`` set to 2;
    * are tallied into ``flow`` by their ``patch_id`` entry.

    All arrays are modified in place; nothing is returned. ``flow`` is added to,
    not overwritten. ``patch_id`` values are expected to be < ``len(flow)``.
    """
    running = np.flatnonzero(timers_0[:count] > 0)
    if running.size == 0:
        return

    timers_0[running] -= 1

    # For unsigned timer dtypes "<= 0" can only mean "== 0" after the decrement.
    finished = running[timers_0[running] <= 0]
    if finished.size == 0:
        return

    draws = np.random.gamma(shape, scale, len(finished))
    timers_1[finished] = np.maximum(1, np.round(draws)).astype(np.uint16)
    state[finished] = 2

    per_patch = np.bincount(patch_id[finished], minlength=len(flow))
    flow += per_patch.astype(np.uint32)
# Numba Implementation (if available)
if NUMBA_AVAILABLE:

    @nb.njit(
        (nb.uint32, nb.uint16[:], nb.uint16[:], nb.uint8[:], nb.float32, nb.float32, nb.uint32[:], nb.uint16[:]), parallel=True, cache=True
    )
    def nb_gamma_update(count, timers_0, timers_1, state, shape, scale, flow, patch_id):  # pragma: no cover
        """Numba-compiled, parallel equivalent of ``numpy_gamma_update`` (E -> I).

        For each of the first ``count`` agents with a positive exposure timer
        (``timers_0``), decrement it. When it reaches zero:

        * draw an infectious timer from gamma(``shape``, ``scale``), rounded and
          floored at 1, into ``timers_1``;
        * set ``state`` to 2;
        * count the agent in ``flow`` at index ``patch_id[i]``.

        ``flow`` is added to in place. Every ``patch_id[i]`` for ``i < count``
        must be < ``len(flow)``; this is the same contract as the NumPy
        version's ``bincount(minlength=len(flow))``.
        """
        # Size the tallies by the caller's flow array. The previous
        # np.max(patch_id) + 1 width had two failure modes: it could be shorter
        # than flow (trailing patches with no agents), breaking the final
        # "+=", and it raised on an empty patch_id.
        num_patches = len(flow)
        # One row per thread: each thread only increments the row at its own
        # thread id, so increments from different threads never touch the
        # same cell. Assumes nb.get_thread_id() < NUM_THREADS (NUM_THREADS is
        # read at import time) — TODO confirm if thread count is changed later.
        thread_flow = np.zeros((NUM_THREADS, num_patches), dtype=np.uint32)
        for i in nb.prange(count):
            timer_0 = timers_0[i]
            if timer_0 > 0:
                timer_0 -= 1
                # if we have decremented etimer from >0 to <=0, set infectious timer.
                if timer_0 <= 0:
                    timers_1[i] = np.maximum(np.uint16(1), np.uint16(np.round(np.random.gamma(shape, scale))))
                    thread_flow[nb.get_thread_id(), patch_id[i]] += 1
                    state[i] = 2
                timers_0[i] = timer_0
        # Reduce the per-thread tallies into the caller's flow array.
        flow[:] += thread_flow.sum(axis=0)
        return

else:
    nb_gamma_update = None
# Numpy Implementation
def numpy_state_update(count, timers, state, new_state, flow, patch_id):
    """Count down ``timers`` for the first ``count`` agents and move expirers to ``new_state``.

    Agents with a positive timer have it decremented by one. Those whose timer
    becomes exactly zero get ``state`` set to ``new_state`` and are tallied into
    ``flow`` by their ``patch_id`` entry. ``flow`` is added to, not overwritten.

    All arrays are modified in place; nothing is returned. ``patch_id`` values
    are expected to be < ``len(flow)``.
    """
    ticking = np.flatnonzero(timers[:count] > 0)
    if ticking.size == 0:
        return

    timers[ticking] -= 1

    expired = ticking[timers[ticking] == 0]
    if expired.size:
        state[expired] = new_state
        tally = np.bincount(patch_id[expired], minlength=len(flow))
        flow += tally.astype(np.uint32)
# Numba Implementation (if available)
if NUMBA_AVAILABLE:

    @nb.njit((nb.uint32, nb.uint16[:], nb.uint8[:], nb.uint8, nb.uint32[:], nb.uint16[:]), parallel=True, cache=True)
    def nb_state_update(count, timers, state, new_state, flow, patch_id):  # pragma: no cover
        """Numba-compiled, parallel equivalent of ``numpy_state_update``.

        For each of the first ``count`` agents with a positive timer, decrement
        it. When it hits zero, set ``state`` to ``new_state`` and count the
        agent in ``flow`` at index ``patch_id[i]``.

        ``flow`` is added to in place. Every ``patch_id[i]`` for ``i < count``
        must be < ``len(flow)``; this is the same contract as the NumPy
        version's ``bincount(minlength=len(flow))``.
        """
        # Size the tallies by the caller's flow array. The previous
        # np.max(patch_id) + 1 width could be shorter than flow (trailing
        # patches with no agents), breaking the final "+=", and it raised on
        # an empty patch_id.
        num_patches = len(flow)
        # One row per thread: each thread only increments the row at its own
        # thread id. Assumes nb.get_thread_id() < NUM_THREADS (NUM_THREADS is
        # read at import time) — TODO confirm if thread count is changed later.
        thread_flow = np.zeros((NUM_THREADS, num_patches), dtype=np.uint32)
        for i in nb.prange(count):
            timer = timers[i]
            if timer > 0:
                timer -= 1
                if timer == 0:
                    thread_flow[nb.get_thread_id(), patch_id[i]] += 1
                    state[i] = new_state
                timers[i] = timer
        # Reduce the per-thread tallies into the caller's flow array.
        flow[:] += thread_flow.sum(axis=0)
        return

else:
    nb_state_update = None
class DiseaseParams(BaseModel):
    """Parameters of the gamma-distributed infectious period.

    The period is specified by its mean (``inf_mean``) and standard deviation
    (``inf_sigma``). ``inf_shape`` and ``inf_scale`` convert these to gamma
    shape/scale by matching moments:
    ``shape * scale == inf_mean`` and ``shape * scale**2 == inf_sigma**2``.
    Both inputs must be strictly positive; otherwise the conversions would
    divide by zero or yield a negative scale.
    """

    inf_mean: float = Field(default=8.0, gt=0, description="Mean infectious period (days)")
    inf_sigma: float = Field(default=2.0, gt=0, description="Standard deviation of the infectious period (days)")

    @property
    def inf_shape(self) -> float:
        """Gamma shape parameter: ``(inf_mean / inf_sigma) ** 2``."""
        return (self.inf_mean / self.inf_sigma) ** 2

    @property
    def inf_scale(self) -> float:
        """Gamma scale parameter: ``inf_sigma ** 2 / inf_mean``."""
        return self.inf_sigma**2 / self.inf_mean
class DiseaseProcess(BaseComponent):
    """
    This component provides disease progression (E->I->R).

    On each call it first counts down the infectious timers (``itimer``),
    moving agents whose timer expires from I to R. It then counts down the
    exposure timers (``etimer``), moving agents whose timer expires from E to
    I and drawing their infectious timer from a gamma distribution
    parameterised by ``DiseaseParams``. The per-patch compartment counts
    (``patches.states``) are adjusted by the same per-patch flows.
    """

    def __init__(self, model, verbose: bool = False, params: DiseaseParams | None = None):
        """
        Args:
            model: The model this component belongs to (passed to ``BaseComponent``).
            verbose: Passed through to ``BaseComponent``.
            params: Infectious-period parameters; a default ``DiseaseParams()``
                is used when omitted.
        """
        super().__init__(model, verbose)
        self.params = params if params is not None else DiseaseParams()

    def __call__(self, model, tick: int) -> None:
        """Advance disease progression by one tick.

        Args:
            model: Model exposing ``people`` (``count``, ``itimer``, ``etimer``,
                ``state``, ``patch_id``), ``patches.states`` and
                ``params.states``.
            tick: Current simulation tick (not used by this component).
        """
        people = model.people
        patches = model.patches

        # I --> R first. Agents that become infectious below therefore do not
        # have their freshly drawn itimer decremented on the same tick.
        flow = np.zeros(len(model.patches), dtype=np.uint32)
        self.state_update_func(people.count, people.itimer, people.state, np.uint8(model.params.states.index("R")), flow, people.patch_id)
        patches.states.I -= flow
        patches.states.R += flow

        # E --> I: count down exposure timers; on expiry the kernel draws a
        # gamma-distributed infectious timer and sets state to I (code 2).
        flow = np.zeros(len(model.patches), dtype=np.uint32)
        self.gamma_update_func(
            people.count,
            people.etimer,
            people.itimer,
            people.state,
            self.params.inf_shape,
            self.params.inf_scale,
            flow,
            people.patch_id,
        )
        patches.states.E -= flow
        patches.states.I += flow
        return

    def _initialize(self, model: ABMModel) -> None:
        """Choose the NumPy or Numba kernels for this run.

        Presumably invoked by ``BaseComponent`` before the first ``__call__``
        (not visible here). ``select_function`` is inherited from
        ``BaseComponent``.
        """
        # Select function implementations based on model configuration
        self.state_update_func = self.select_function(numpy_state_update, nb_state_update)
        self.gamma_update_func = self.select_function(numpy_gamma_update, nb_gamma_update)