Source code for laser.measles.abm.base
"""
Base classes for the agent-based (ABM) measles model: the people and patch
LaserFrames, the patito schema for scenario tables, and the scenario wrapper
that validates a Polars DataFrame against that schema.
"""
import numpy as np
import patito as pt
import polars as pl
from laser.measles.base import BasePatchLaserFrame
from laser.measles.base import BasePeopleLaserFrame
from laser.measles.base import BaseScenario as LaserMeaslesBaseScenario
[docs]
class PeopleLaserFrame(BasePeopleLaserFrame):
    """
    LaserFrame for people (i.e., agent) properties.

    Only type annotations are declared here; no values are assigned at class
    level. The arrays are presumably allocated and attached by the base class
    or by model components at runtime (TODO confirm against
    ``BasePeopleLaserFrame``).
    """
    patch_id: np.ndarray  # presumably the patch/node index of each agent; TODO confirm
    state: np.ndarray  # presumably a per-agent disease-state code; TODO confirm
    susceptibility: np.ndarray  # presumably per-agent susceptibility; TODO confirm
[docs]
class PatchLaserFrame(BasePatchLaserFrame):
    """
    LaserFrame for patch-level properties in ABM models.

    This class adds no attributes or methods of its own. It exists so that
    ABM code has a dedicated patch-frame type, and it inherits all storage and
    access behavior from ``BasePatchLaserFrame``.
    """
[docs]
class BaseABMScenarioSchema(pt.Model):
    """
    Patito schema for the scenario table used by ABM models.

    Each field names a required DataFrame column together with its Python
    type. The field names also serve as the list of required columns in
    ``BaseABMScenario._validate``, so adding a field here makes the column
    mandatory there as well.
    """
    pop: int # population
    lat: float # latitude
    lon: float # longitude
    id: str # ids of the nodes
    mcv1: float # Routine MCV1 coverage for newborns (0.0-1.0). Only affects births via VitalDynamicsProcess; does NOT vaccinate existing population.
[docs]
class BaseABMScenario(LaserMeaslesBaseScenario):
    """
    Scenario wrapper for ABM models around a Polars DataFrame.

    The DataFrame must contain at least the columns declared by
    ``BaseABMScenarioSchema`` (``pop``, ``lat``, ``lon``, ``id``, ``mcv1``);
    extra columns are allowed.
    """

    def __init__(self, df: pl.DataFrame):
        """
        Build the scenario and validate ``df`` against the patito schema.

        Args:
            df: Scenario table with one row per node.
        """
        # NOTE(review): the base-class constructor is not visible here; it is
        # assumed to store ``df`` and to invoke ``self._validate(df)``.
        super().__init__(df)
        # Schema-level validation; columns beyond the schema are permitted.
        BaseABMScenarioSchema.validate(df, allow_superfluous_columns=True)

    def _validate(self, df: pl.DataFrame):
        """
        Check that ``df`` has the columns, dtypes and values an ABM scenario needs.

        Args:
            df: Scenario table to check.

        Raises:
            ValueError: ``"Missing required columns: [...]"`` if any schema
                column is absent. Any later failure (wrong dtype, ``mcv1``
                outside [0, 1], unsupported ``id`` dtype, or null values) is
                re-raised as ``"DataFrame validation error:\\n<detail>"``.
        """
        # Required columns come from the schema so the two cannot drift apart.
        present = set(df.columns)
        missing_columns = [
            col for col in BaseABMScenarioSchema.model_fields if col not in present
        ]
        if missing_columns:
            raise ValueError(f"Missing required columns: {missing_columns}")

        try:
            # Exact dtypes are required: e.g. Int32 population or Float32
            # coordinates are rejected rather than silently cast.
            if df["pop"].dtype != pl.Int64:
                raise ValueError("Column 'pop' must be integer type")
            for column in ("lat", "lon", "mcv1"):
                if df[column].dtype != pl.Float64:
                    raise ValueError(f"Column '{column}' must be float type")

            # mcv1 is a coverage *fraction* (not a percentage); both bounds
            # are inclusive. Nulls are ignored by ``.all()`` here and are
            # reported by the null check below instead.
            if not df["mcv1"].is_between(0, 1).all():
                raise ValueError("Column 'mcv1' must be between 0 and 1")

            # NOTE(review): integer ids pass this check, but the schema
            # declares ``id: str``, so the patito validation in ``__init__``
            # will likely still reject them. Confirm which is intended.
            id_dtype = df["id"].dtype
            if not (id_dtype == pl.String or id_dtype == pl.Int64):
                raise ValueError("Column 'id' must be either string or integer type")

            # null_count() returns a single row holding one count per column.
            null_counts = df.null_count()
            if any(null_counts.row(0)):
                raise ValueError(f"DataFrame contains null values:\n{null_counts}")
        except Exception as e:
            # Deliberately broad: every failure, including the ValueErrors
            # raised above, surfaces to callers as a single prefixed ValueError.
            raise ValueError(f"DataFrame validation error:\n{e}") from e
# Module-level alias: exposes the ABM scenario class under the generic
# ``BaseScenario`` name (presumably so model code can import it uniformly
# across model types; confirm against callers).
BaseScenario = BaseABMScenario