Source code for poliosim.interventions

'''
Specify the core interventions available in Poliosim. Other interventions can be
defined by the user by inheriting from these classes.
'''

import numpy as np
import pandas as pd
import pylab as pl
import sciris as sc
import inspect
import datetime as dt
from . import utils as psu



#%% Generic intervention classes

# Names exported by this module; later sections of the file append their own
# entries with "__all__ +=". Intervention.__repr__() also consults this list.
__all__ = ['InterventionDict', 'Intervention']


def find_day(arr, t=None, which='first'):
    '''
    Helper function to find if the current simulation time matches any day in the
    intervention. Although usually never more than one index is returned, it is
    returned as a list for the sake of easy iteration.

    Args:
        arr (list): list of days in the intervention, or else a boolean array
        t (int): current simulation time (can be None if a boolean array is used)
        which (str): what to return: 'first', 'last', or 'all' indices

    Returns:
        inds (list): indices of the matching entries of arr; length zero or one unless which is 'all'

    Raises:
        ValueError: if which is not 'first', 'last', or 'all'
    '''
    # Validate up front, so that an invalid option is reported on every call
    # rather than only on the days that happen to have a match
    if which not in ('first', 'last', 'all'):
        errormsg = f'Argument "which" must be "first", "last", or "all", not "{which}"'
        raise ValueError(errormsg)

    all_inds = sc.findinds(arr=arr, val=t)
    if len(all_inds) == 0 or which == 'all':
        inds = all_inds
    elif which == 'first':
        inds = [all_inds[0]]
    else: # which == 'last'
        inds = [all_inds[-1]]
    return inds


def InterventionDict(which, pars):
    '''
    Generate an intervention from a dictionary. Although a function, it acts
    like a class, since it returns a class instance.

    Args:
        which (str): name of the intervention class: 'change_beta', 'test_prob', or 'contact_tracing'
        pars (dict): keyword arguments passed to the constructor of that class

    Returns:
        The newly created intervention instance

    Raises:
        sc.KeyNotFoundError: if "which" is not one of the names listed above

    **Example**::

        interv = ps.InterventionDict(which='change_beta', pars={'days': 30, 'changes': 0.5, 'layers': None})
    '''
    mapping = dict(
        change_beta     = change_beta,
        test_prob       = test_prob,
        contact_tracing = contact_tracing,
    )
    try:
        IntervClass = mapping[which]
    except (KeyError, TypeError) as E: # TypeError covers unhashable values of "which"; anything else should propagate
        available = ', '.join(mapping.keys())
        errormsg = f'Only interventions "{available}" are available in dictionary representation, not "{which}"'
        raise sc.KeyNotFoundError(errormsg) from E
    intervention = IntervClass(**pars)
    return intervention
class Intervention:
    '''
    Base class for interventions. By default, interventions are printed using a
    dict format, which they can be recreated from. To display all the attributes
    of the intervention, use disp() instead.

    To retrieve a particular intervention from a sim, use sim.get_intervention().

    Args:
        label       (str): a label for the intervention (used for plotting, and for ease of identification)
        show_label (bool): whether or not to include the label in the legend
        do_plot    (bool): whether or not to plot the intervention
        line_args  (dict): arguments passed to pl.axvline() when plotting
    '''

    def __init__(self, label=None, show_label=False, do_plot=None, line_args=None):
        self._store_args() # Store the input arguments so the intervention can be recreated
        if label is None: label = self.__class__.__name__ # Use the class name if no label is supplied
        self.label = label # e.g. "Close schools"
        self.show_label = show_label # Do not show the label by default
        self.do_plot = do_plot if do_plot is not None else True # Plot the intervention, including if None
        self.line_args = sc.mergedicts(dict(linestyle='--', c='#aaa', lw=1.0), line_args) # Do not set alpha by default due to the issue of overlapping interventions
        self.days = [] # The start and end days of the intervention
        self.initialized = False # Whether or not it has been initialized
        self.finalized = False # Whether or not it has been finalized
        return


    def __repr__(self, jsonify=False):
        ''' Return a JSON-friendly output if possible, else revert to short repr '''

        if self.__class__.__name__ in __all__ or jsonify:
            try:
                json = self.to_json()
                which = json['which']
                pars = json['pars']
                parstr = ', '.join([f'{k}={v}' for k,v in pars.items()])
                output = f"ps.{which}({parstr})"
            except Exception as E:
                output = str(type(self)) + f' (error: {str(E)})' # If that fails, print why
            return output
        else:
            return f'{self.__module__}.{self.__class__.__name__}()'


    def __call__(self, *args, **kwargs):
        # Makes Intervention(sim) equivalent to Intervention.apply(sim)
        if not self.initialized: # pragma: no cover
            errormsg = f'Intervention (label={self.label}, {type(self)}) has not been initialized'
            raise RuntimeError(errormsg)
        return self.apply(*args, **kwargs)


    def disp(self):
        ''' Print a detailed representation of the intervention '''
        return sc.pr(self)


    def _store_args(self):
        '''
        Store the user-supplied arguments for later use in to_json.

        The arguments are read from the local variables of the outermost
        __init__() call that is currently running on this object, i.e. the
        constructor of the most-derived class. This gives the same result
        whether the method is called from Intervention.__init__(), directly
        from a subclass constructor, or through several levels of inheritance.
        The entries of a "kwargs" local are stored individually; "self" and
        "__class__" are skipped. If no local variables are found, input_args
        is left untouched.
        '''
        caller = inspect.currentframe().f_back # The frame that called _store_args()
        frame  = caller
        target = None
        while frame is not None and frame.f_code.co_name == '__init__' and frame.f_locals.get('self') is self:
            target = frame # Still inside the constructor chain of this object, so keep walking outwards
            frame = frame.f_back
        if target is None and caller is not None:
            target = caller.f_back # Not called from a constructor: fall back to the frame two levels up
        values = dict(target.f_locals) if target is not None else {}
        del frame, caller, target # Do not keep frame objects alive longer than needed
        if values:
            self.input_args = {}
            for key,value in values.items():
                if key == 'kwargs': # Store additional kwargs directly
                    for k2,v2 in value.items():
                        self.input_args[k2] = v2 # These are already a dict
                elif key not in ['self', '__class__']: # Everything else, but skip these
                    self.input_args[key] = value
        return


    def initialize(self, sim=None):
        '''
        Initialize intervention -- this is used to make modifications to the intervention
        that can't be done until after the sim is created.
        '''
        self.initialized = True
        self.finalized = False
        return


    def finalize(self, sim=None):
        '''
        Finalize intervention

        This method is run once as part of `sim.finalize()` enabling the intervention to perform any
        final operations after the simulation is complete (e.g. rescaling)

        Raises:
            RuntimeError: if the intervention has already been finalized
        '''
        if self.finalized:
            raise RuntimeError('Intervention already finalized')  # Raise an error because finalizing multiple times has a high probability of producing incorrect results e.g. applying rescale factors twice
        self.finalized = True
        return


    def apply(self, sim):
        '''
        Apply the intervention. This is the core method which each derived intervention
        class must implement. This method gets called at each timestep and can make
        arbitrary changes to the Sim object, as well as storing or modifying the
        state of the intervention.

        Args:
            sim: the Sim instance

        Returns:
            None
        '''
        raise NotImplementedError


    def shrink(self, in_place=False):
        '''
        Remove any excess stored data from the intervention; for use with sim.shrink().
        The base class stores no excess data, so it returns itself or a deep copy.

        Args:
            in_place (bool): whether to shrink the intervention (else shrink a copy)
        '''
        if in_place:
            return self
        else:
            return sc.dcp(self)


    def plot_intervention(self, sim, ax=None, **kwargs):
        '''
        Plot the intervention

        This can be used to do things like add vertical lines on days when
        interventions take place. Can be disabled by setting self.do_plot=False.

        Note 1: you can modify the plotting style via the ``line_args`` argument when
        creating the intervention.

        Note 2: By default, the intervention is plotted at the days stored in self.days.
        However, if there is a self.plot_days attribute, this will be used instead.

        Args:
            sim: the Sim instance
            ax: the axis instance
            kwargs: passed to ax.axvline()

        Returns:
            None
        '''
        line_args = sc.mergedicts(self.line_args, kwargs)
        if self.do_plot or self.do_plot is None:
            if ax is None:
                ax = pl.gca()
            if hasattr(self, 'plot_days'):
                days = self.plot_days
            else:
                days = self.days
            if sc.isiterable(days):
                label_shown = False # Don't show the label more than once
                for day in days:
                    if sc.isnumber(day): # Non-numeric entries are skipped
                        if self.show_label and not label_shown: # Choose whether to include the label in the legend
                            label = self.label
                            label_shown = True
                        else:
                            label = None
                        ax.axvline(day, label=label, **line_args)
        return


    def to_json(self):
        '''
        Return JSON-compatible representation

        Custom classes can't be directly represented in JSON. This method is a
        one-way export to produce a JSON-compatible representation of the
        intervention. In the first instance, the object dict will be returned.
        However, if an intervention itself contains non-standard variables as
        attributes, then its `to_json` method will need to handle those.

        Note that simply printing an intervention will usually return a representation
        that can be used to recreate it.

        Returns:
            JSON-serializable representation (typically a dict, but could be anything else)
        '''
        which = self.__class__.__name__
        pars = sc.jsonify(self.input_args)
        output = dict(which=which, pars=pars)
        return output
#%% Beta interventions

__all__+= ['change_beta']


def process_days(sim, days):
    '''
    Ensure lists of days are in consistent format. Used by change_beta, clip_edges,
    and some analyzers. If day is 'end' or -1, use the final day of the simulation.

    The input is not modified; a new array is returned.

    Args:
        sim (Sim): the simulation object, used to convert each entry via sim.day()
        days (int, str, or iterable): a single day or a collection of days

    Returns:
        Array of converted days, one per input entry
    '''
    if sc.isstring(days) or not sc.isiterable(days):
        days = sc.promotetolist(days)
    converted = [] # Build a new list, so the caller's object is left alone and tuples are accepted
    for day in days:
        if day in ['end', -1]:
            day = sim['end_day']
        converted.append(sim.day(day)) # Ensure it's an integer and not a string or something
    return sc.promotetoarray(converted)


def process_changes(sim, changes, days):
    '''
    Ensure lists of changes are in consistent format. Used by change_beta and clip_edges.

    Args:
        sim (Sim): the simulation object (not used here)
        changes (float or iterable): the change(s) to convert to an array
        days (iterable): the already-processed days, used only to check the length

    Returns:
        Array of changes, with the same length as days

    Raises:
        ValueError: if the number of changes differs from the number of days
    '''
    changes = sc.promotetoarray(changes)
    if len(days) != len(changes):
        errormsg = f'Number of days supplied ({len(days)}) does not match number of changes ({len(changes)})'
        raise ValueError(errormsg)
    return changes
class change_beta(Intervention):
    '''
    The most basic intervention -- change beta by a certain amount.

    Args:
        days    (int or array): the day or array of days to apply the interventions
        changes (float or array): the changes in beta (1 = no change, 0 = no transmission)
        layers  (str or list): the layers in which to change beta
        kwargs  (dict): passed to Intervention()

    **Examples**::

        interv = ps.change_beta(25, 0.3) # On day 25, reduce overall beta by 70% to 0.3
        interv = ps.change_beta([14, 28], [0.7, 1], layers='s') # On day 14, reduce beta by 30%, and on day 28, return to 1 for schools
    '''

    def __init__(self, days, changes, layers=None, **kwargs):
        # Intervention.__init__() already records the arguments of this constructor via
        # _store_args(). Calling _store_args() again from here would look one frame too
        # far up the stack and overwrite input_args with the local variables of whoever
        # created this object, so it is deliberately not repeated.
        super().__init__(**kwargs) # Initialize the Intervention object
        self.days       = sc.dcp(days)
        self.changes    = sc.dcp(changes)
        self.layers     = sc.dcp(layers)
        self.orig_betas = None # Filled in by initialize()
        return


    def initialize(self, sim):
        ''' Fix days and store beta '''
        super().initialize()
        self.days    = process_days(sim, self.days)
        self.changes = process_changes(sim, self.changes, self.days)
        self.layers  = sc.promotetolist(self.layers, keepnone=True)
        self.orig_betas = {}
        for lkey in self.layers:
            if lkey is None: # No layer specified: act on the overall beta
                self.orig_betas['overall'] = sim['beta']
            else:
                self.orig_betas[lkey] = sim['beta_layer'][lkey]
        return


    def apply(self, sim):
        ''' On each matching day, set beta to the stored original value times that day's change '''

        # If this day is found in the list, apply the intervention
        for ind in find_day(self.days, sim.t):
            for lkey,new_beta in self.orig_betas.items():
                new_beta = new_beta * self.changes[ind] # Changes are relative to the original beta, not cumulative
                if lkey == 'overall':
                    sim['beta'] = new_beta
                else:
                    sim['beta_layer'][lkey] = new_beta

        return
#%% Testing interventions

__all__+= ['test_prob', 'contact_tracing', 'symptomatic_triggered_surveillance']


# Process daily data
def process_daily_data(daily_data, sim, start_day, as_int=False):
    '''
    This function performs one of two things: if the daily data are supplied as
    a number, then it converts it to an array of the right length. If the daily
    data are supplied as a Pandas series or dataframe with a date index, then it
    reindexes it to daily dates (missing dates are filled with 0) and converts it
    to a NumPy array. Otherwise, it does nothing.

    Args:
        daily_data (number, dataframe, or series): the data to convert to standardized format
        sim (Sim): the simulation object
        start_day (int): number of days after sim['start_day'] at which the reindexed data should begin
        as_int (bool): whether to convert a numeric input to an integer

    Returns:
        The converted data (an array), or the input unchanged if it is neither a number nor Pandas data
    '''
    if sc.isnumber(daily_data):  # If a number, convert to an array
        if as_int: daily_data = int(daily_data) # Make it an integer
        daily_data = np.array([daily_data] * sim.npts)
    elif isinstance(daily_data, (pd.Series, pd.DataFrame)):
        start_date = sim['start_day'] + dt.timedelta(days=start_day)
        end_date = daily_data.index[-1] # The data determine where the range ends
        dateindex = pd.date_range(start_date, end_date)
        daily_data = daily_data.reindex(dateindex, fill_value=0).to_numpy()

    return daily_data


def get_subtargets(subtarget, sim):
    '''
    A small helper function to see if subtargeting is a list of indices to use,
    or a function that needs to be called. If a function, it must take a single
    argument, a sim object, and return a list of indices. Also validates the values.
    Currently designed for use with testing interventions, but could be generalized
    to other interventions.

    Args:
        subtarget (dict): dict with keys 'inds' and 'vals', or a function of sim returning such a dict; either entry may itself be a function of sim
        sim (Sim): the simulation object

    Returns:
        Tuple of (subtarget_inds, subtarget_vals)

    Raises:
        ValueError: if either key is missing, or if the values are iterable and differ in length from the indices
    '''

    # Validation
    if callable(subtarget):
        subtarget = subtarget(sim)

    # Both keys are read below, so check both rather than failing later with a bare KeyError
    if 'inds' not in subtarget or 'vals' not in subtarget:
        errormsg = f'The subtarget dict must have keys "inds" and "vals", but you supplied {subtarget}'
        raise ValueError(errormsg)

    # Handle the two options of type
    if callable(subtarget['inds']): # A function has been provided
        subtarget_inds = subtarget['inds'](sim) # Call the function to get the indices
    else:
        subtarget_inds = subtarget['inds'] # The indices are supplied directly

    # Validate the values
    if callable(subtarget['vals']): # A function has been provided
        subtarget_vals = subtarget['vals'](sim) # Call the function to get the values
    else:
        subtarget_vals = subtarget['vals'] # The values are supplied directly
    if sc.isiterable(subtarget_vals): # A non-iterable value is applied to all indices, so has no length to check
        if len(subtarget_vals) != len(subtarget_inds):
            errormsg = f'Length of subtargeting indices ({len(subtarget_inds)}) does not match length of values ({len(subtarget_vals)})'
            raise ValueError(errormsg)

    return subtarget_inds, subtarget_vals


def get_quar_inds(quar_policy, sim):
    '''
    Helper function to return the appropriate indices for people in quarantine
    based on the current quarantine testing "policy". Used by the testing
    interventions (e.g. test_prob).

    Args:
        quar_policy (str): 'start', people entering quarantine; 'end', people leaving; 'both', entering and leaving; 'daily', every day in quarantine
        sim (Sim): the simulation object

    Returns:
        Indices of the people selected by the policy on the current timestep

    Raises:
        ValueError: if the policy is not one of the four options
    '''
    t = sim.t
    if   quar_policy == 'start': quar_inds = psu.true(sim.people.date_quarantined==t-1) # Actually do the day before
    elif quar_policy == 'end':   quar_inds = psu.true(sim.people.date_end_quarantine==t)
    elif quar_policy == 'both':  quar_inds = np.concatenate([psu.true(sim.people.date_quarantined==t-1), psu.true(sim.people.date_end_quarantine==t)])
    elif quar_policy == 'daily': quar_inds = psu.true(sim.people.quarantined)
    else:
        errormsg = f'Quarantine policy "{quar_policy}" not recognized: must be start, end, both, or daily'
        raise ValueError(errormsg)
    return quar_inds
class test_prob(Intervention):
    '''
    Test as many people as required based on test probability.
    The probabilities are applied in sequence, with later categories overriding
    earlier ones, so choose wisely.

    Args:
        symp_prob        (float): Probability of testing a symptomatic (unquarantined) person
        asymp_prob       (float): Probability of testing an asymptomatic (unquarantined) person
        symp_quar_prob   (float): Probability of testing a symptomatic quarantined person (defaults to symp_prob)
        asymp_quar_prob  (float): Probability of testing an asymptomatic quarantined person (defaults to asymp_prob)
        quar_policy      (str):   Policy for testing in quarantine: options are 'start', 'end', 'both' (start and end), 'daily' (defaults to 'start')
        subtarget        (dict):  subtarget intervention to people with particular indices; a dict with keys 'inds' and 'vals', or a function of sim returning one (see test_num() for details)
        test_sensitivity (float): Probability of a true positive
        loss_prob        (float): Probability of loss to follow-up
        test_delay       (int):   How long testing takes
        start_day        (int):   When to start the intervention
        end_day          (int):   When to end the intervention (None means no end)
        kwargs           (dict):  passed to Intervention()

    **Examples**::

        interv = ps.test_prob(symp_prob=0.1, asymp_prob=0.01) # Test 10% of symptomatics and 1% of asymptomatics
        interv = ps.test_prob(symp_quar_prob=0.4) # Test 40% of those in quarantine with symptoms
    '''
    def __init__(self, symp_prob, asymp_prob=0.0, symp_quar_prob=None, asymp_quar_prob=None, quar_policy=None, subtarget=None,
                 test_sensitivity=1.0, loss_prob=0.0, test_delay=0, start_day=0, end_day=None, **kwargs):
        super().__init__(**kwargs) # Initialize the Intervention object
        self.symp_prob        = symp_prob
        self.asymp_prob       = asymp_prob
        self.symp_quar_prob   = symp_quar_prob  if  symp_quar_prob is not None else  symp_prob
        self.asymp_quar_prob  = asymp_quar_prob if asymp_quar_prob is not None else asymp_prob
        self.quar_policy      = quar_policy if quar_policy else 'start'
        self.subtarget        = subtarget
        self.test_sensitivity = test_sensitivity
        self.loss_prob        = loss_prob
        self.test_delay       = test_delay
        self.start_day        = start_day
        self.end_day          = end_day
        return


    def initialize(self, sim):
        ''' Fix the dates '''
        super().initialize()
        self.start_day   = sim.day(self.start_day)
        self.end_day     = sim.day(self.end_day)
        self.days        = [self.start_day, self.end_day]
        return


    def apply(self, sim):
        ''' Perform testing '''
        t = sim.t
        if t < self.start_day:
            return
        elif self.end_day is not None and t > self.end_day:
            return

        # Find probability for symptomatics to be tested
        symp_inds  = psu.true(sim.people.symptomatic)
        symp_prob  = self.symp_prob

        # Define asymptomatics: everyone who is not currently symptomatic
        pop_size   = sim['pop_size']
        asymp_inds = np.setdiff1d(np.arange(pop_size), symp_inds)

        # Handle quarantine and other testing criteria
        quar_inds       = get_quar_inds(self.quar_policy, sim)
        symp_quar_inds  = np.intersect1d(quar_inds, symp_inds)
        asymp_quar_inds = np.intersect1d(quar_inds, asymp_inds)
        diag_inds       = psu.true(sim.people.diagnosed)

        # Construct the testing probabilities piece by piece -- complicated, since need to do it in the right order
        test_probs = np.zeros(sim.n) # Begin by assigning equal testing probability to everyone
        test_probs[symp_inds]       = symp_prob            # People with symptoms
        test_probs[asymp_inds]      = self.asymp_prob      # People without symptoms
        test_probs[symp_quar_inds]  = self.symp_quar_prob  # People with symptoms in quarantine
        test_probs[asymp_quar_inds] = self.asymp_quar_prob # People without symptoms in quarantine
        if self.subtarget is not None:
            # get_subtargets() handles both dict and callable subtargets, so the subtarget
            # must not be indexed directly here (a callable is not subscriptable)
            subtarget_inds, subtarget_vals = get_subtargets(self.subtarget, sim)
            test_probs[subtarget_inds] = subtarget_vals # People being explicitly subtargeted
        test_probs[diag_inds] = 0.0 # People who are diagnosed don't test
        test_inds = psu.true(psu.binomial_arr(test_probs)) # Finally, calculate who actually tests

        sim.people.test(test_inds, test_sensitivity=self.test_sensitivity, loss_prob=self.loss_prob, test_delay=self.test_delay) # Actually test people
        sim.results['new_tests'][t] += int(len(test_inds)*sim['pop_scale']/sim.rescale_vec[t]) # If we're using dynamic scaling, we have to scale by pop_scale, not rescale_vec

        return
class symptomatic_triggered_surveillance(test_prob):
    '''
    Trigger surveillance whenever the number of symptomatic people rises above
    the number recorded at the previous trigger. Surveillance is implemented as
    randomly choosing a set of individuals in the population and testing them.
    The size of the set is drawn from a normal distribution, rounded, and limited
    to between zero and the population size. The other parameters associated with
    test_prob are also available.

    Args:
        num_sample_mu    (float): mean number of people to test per trigger
        num_sample_sigma (float): standard deviation of the number of people to test
        test_sensitivity (float): Probability of a true positive
        test_delay       (int):   How long testing takes
        kwargs           (dict):  passed to test_prob()
    '''
    def __init__(self, num_sample_mu=10.0, num_sample_sigma=0.0, test_sensitivity=1.0, test_delay=0, **kwargs):
        # Initialize the test_prob object.
        # First argument is effectively ignored, since apply() is overridden below.
        super().__init__(symp_prob=0.0, **kwargs)

        self.num_sample_mu    = num_sample_mu
        self.num_sample_sigma = num_sample_sigma
        self.test_sensitivity = test_sensitivity # Set after super().__init__(), so these override its defaults
        self.test_delay       = test_delay
        self.last_num_symptomatic = 0 # Number of symptomatic people at the last trigger
        return


    def apply(self, sim):
        ''' Test a random sample of the population if the number of symptomatic people has risen '''
        t = sim.t
        if t < self.start_day:
            return
        elif self.end_day is not None and t > self.end_day:
            return

        # If no one new is detected, we're done
        # NOTE(review): the stored count is not refreshed on this early return, so after the
        # number of symptomatic people falls, new cases only trigger surveillance once the
        # count exceeds its previous high -- confirm whether this is intended
        current_num_symptomatic = len(np.nonzero(sim.people.symptomatic)[0])
        if self.last_num_symptomatic >= current_num_symptomatic:
            return

        # Figure out who to test. The draw is a float and, with a nonzero sigma, can be
        # negative or larger than the population, so convert it to a valid integer count
        num_pop = sim.n
        draw = np.random.normal(loc=self.num_sample_mu, scale=self.num_sample_sigma)
        num_to_surveil = int(np.clip(np.rint(draw), 0, num_pop))
        test_inds = psu.choose(num_pop, num_to_surveil)
        sim.people.test(test_inds, test_sensitivity=self.test_sensitivity, test_delay=self.test_delay)

        # Update results
        sim.results['new_symp_triggered_surveilled'][t] += int(len(test_inds) * sim['pop_scale'] / sim.rescale_vec[t])

        # Update the last number symptomatic
        current_num_symptomatic = len(np.nonzero(sim.people.symptomatic)[0])
        self.last_num_symptomatic = current_num_symptomatic

        return
class contact_tracing(Intervention):
    '''
    Trace the contacts of people who have tested positive.

    Args:
        trace_probs (float or dict): probability of tracing; a single number is applied to every layer (default 1.0)
        trace_time  (float or dict): days required to trace; a single number is applied to every layer (default 0.0)
        start_day   (int):   intervention start day
        end_day     (int):   intervention end day (None means no end)
        presumptive (bool):  whether to trace from people as soon as they are tested, rather than once they are diagnosed
        kwargs      (dict):  passed to Intervention()
    '''

    def __init__(self, trace_probs=None, trace_time=None, start_day=0, end_day=None, presumptive=False, **kwargs):
        super().__init__(**kwargs) # Set up the base Intervention
        self.trace_probs = trace_probs
        self.trace_time  = trace_time
        self.start_day   = start_day
        self.end_day     = end_day
        self.presumptive = presumptive
        return


    def initialize(self, sim):
        ''' Convert the days, and expand scalar settings into per-layer dictionaries '''
        super().initialize()

        # Days
        self.start_day = sim.day(self.start_day)
        self.end_day   = sim.day(self.end_day)
        self.days      = [self.start_day, self.end_day]

        # Fill in the defaults
        defaults = dict(trace_probs=1.0, trace_time=0.0)
        for attr,default in defaults.items():
            if getattr(self, attr) is None:
                setattr(self, attr, default)

        # A single number applies to every layer
        for attr in defaults:
            scalar = getattr(self, attr)
            if sc.isnumber(scalar):
                setattr(self, attr, {lkey:scalar for lkey in sim.people.layer_keys()})
        return


    def apply(self, sim):
        ''' Trace the contacts of the people tested or diagnosed on this timestep '''
        t = sim.t
        if t < self.start_day or (self.end_day is not None and t > self.end_day):
            return # Outside the period in which the intervention is active

        # Figure out whom to trace from
        if self.presumptive:
            just_tested = psu.true(sim.people.date_tested == t) # Tested this time step, time to trace
            trace_from_inds = psu.itruei(sim.people.exposed, just_tested) # This is necessary to avoid infinite chains of asymptomatic testing
        else:
            trace_from_inds = psu.true(sim.people.date_diagnosed == t) # Without presumptive tracing, it is simply the people diagnosed on this timestep

        if len(trace_from_inds): # Only trace if there is someone to trace from
            sim.people.trace(trace_from_inds, self.trace_probs, self.trace_time) # This is the step that actually traces people!

        return