# Source code for fpsim.base

'''
Base classes for loading parameters and for running simulations with FP model
'''
import numpy as np
import pandas as pd
import sciris as sc
import pylab as pl
from . import utils as fpu
from . import defaults as fpd

# Aliases for the default attribute getter/setter of ``object``; they let a
# class that overrides __getattribute__/__setattr__ reach the stock behavior
obj_get, obj_set = object.__getattribute__, object.__setattr__


# Public names exported by this module
__all__ = [
    'ParsObj',
    'BasePeople',
    'BaseSim',
    'ndict',
]



class FlexPretty(sc.prettyobj):
    '''
    Base class offering two levels of display: obj.brief() gives a one-line
    description, while obj.disp() gives the full (verbose) description.
    '''

    def __repr__(self):
        ''' The default repr is the brief one; printing should never raise '''
        try:
            return self._brief()
        except Exception as E:
            fallback = sc.objectid(self)
            return fallback + f'Warning, something went wrong printing object:\n{str(E)}'

    def _brief(self):
        ''' Build the brief, one-line string (like Python's default repr) '''
        return sc.objectid(self)

    def brief(self, output=False):
        ''' Print the brief representation, or return it if output is truthy '''
        text = self._brief()
        if output:
            return text
        print(text)

    def _disp(self):
        ''' Build the verbose string using Sciris' pretty repr '''
        return sc.prepr(self)

    def disp(self, output=False):
        ''' Print the verbose representation, or return it if output is truthy '''
        text = self._disp()
        if output:
            return text
        print(text)



class ParsObj(FlexPretty):
    '''
    A class based around performing operations on a self.pars dict.

    Args:
        pars (dict): the initial parameters
        kwargs (dict): additional parameters, merged on top of ``pars``
    '''

    def __init__(self, pars, **kwargs):
        self.update_pars(pars, create=True, **kwargs)
        return

    def _key_error(self, key):
        ''' Build the exception raised when ``key`` is not a known parameter '''
        all_keys = '\n'.join(str(k) for k in self.pars.keys()) # str() so non-string keys cannot break the message
        errormsg = f'Key "{key}" not found; available keys:\n{all_keys}'
        return sc.KeyNotFoundError(errormsg)

    def __getitem__(self, key):
        ''' Allow sim['par_name'] instead of sim.pars['par_name'] '''
        try:
            return self.pars[key]
        except (KeyError, TypeError) as E: # Missing key, or a key that cannot be looked up (unhashable)
            raise self._key_error(key) from E

    def __setitem__(self, key, value):
        ''' Allow sim['par_name'] = value, but only for keys that already exist '''
        if key in self.pars:
            self.pars[key] = value
        else:
            raise self._key_error(key)
        return

    def update_pars(self, pars=None, create=False, **kwargs):
        '''
        Update internal dict with new pars.

        Args:
            pars (dict): the parameters to update (if None, do nothing)
            create (bool): if create is False, then raise a KeyNotFoundError if the key does not already exist
            kwargs (dict): additional parameters

        Raises:
            TypeError: if ``pars`` is supplied but is not a dict
            sc.KeyNotFoundError: if ``create`` is False and a supplied key is not already present
        '''
        # Validate before merging, so this message (rather than one from the merge itself) is what the user sees
        if pars is not None and not isinstance(pars, dict):
            raise TypeError(f'The pars object must be a dict; you supplied a {type(pars)}')

        pars = sc.mergedicts(pars, kwargs)
        if pars:
            if not hasattr(self, 'pars'): # First call: adopt the merged dict as the parameters
                self.pars = pars
            if not create:
                available_keys = list(self.pars.keys())
                mismatches = [key for key in pars.keys() if key not in self.pars]
                if len(mismatches):
                    errormsg = f'Key(s) {mismatches} not found; available keys are {available_keys}'
                    raise sc.KeyNotFoundError(errormsg)
            self.pars.update(pars)
        return
class BasePeople(sc.prettyobj):
    '''
    Class for all the people in the simulation.

    Attribute names listed in ``self._keys`` are treated as per-person
    quantities: when ``self._inds`` is set (see ``filter()``), reading or
    assigning them via attribute access only touches those indices.
    '''

    def __init__(self):
        ''' Initialize essential attributes used for filtering '''
        obj_set(self, '_keys', []) # Since getattribute is overwritten
        obj_set(self, '_inds', None)
        return

    def __len__(self):
        ''' Number of people, taken from the length of self.uid (0, with a warning, if unavailable) '''
        try:
            return len(self.uid)
        except Exception as E:
            print(f'Warning: could not get length of People (could not get self.uid: {E})')
            return 0

    def __getitem__(self, key):
        '''
        Allow people['attr'] instead of getattr(people, 'attr')

        If the key is an integer, alias `people.person()` to return a `Person` instance
        '''
        try:
            return self.__getattribute__(key)
        except Exception as E: # pragma: no cover
            if isinstance(key, (int, np.integer)): # Also accept NumPy integers, e.g. taken from an index array
                return self.person(key)
            else:
                errormsg = f'Key "{key}" is not a valid attribute of people'
                raise AttributeError(errormsg) from E

    def __setitem__(self, key, value):
        ''' Allow people['attr'] = value; writes straight to the instance dict, so any filter is bypassed '''
        self.__dict__[key] = value
        return

    def _is_filtered(self, attr):
        ''' Determine if a given attribute is filtered (e.g. people.age is, people.inds isn't) '''
        is_filtered = (self._inds is not None and attr in self._keys)
        return is_filtered

    def __getattribute__(self, attr):
        ''' For array quantities, handle filtering '''
        output = obj_get(self, attr)
        if attr[0] == '_': # Short-circuit for built-in methods to save time
            return output

        try: # _keys can be missing while the object is being initialized/unpickled
            keys = obj_get(self, '_keys')
        except AttributeError:
            keys = []

        if attr in keys and self._is_filtered(attr):
            output = output[self.inds]
        return output

    def __setattr__(self, attr, value):
        ''' For array quantities on a filtered object, assign only to the filtered indices '''
        if self._is_filtered(attr):
            array = obj_get(self, attr)
            array[self.inds] = value
        else: # If not initialized, rely on the default behavior
            obj_set(self, attr, value)
        return

    def __add__(self, people2):
        '''
        Combine two people arrays

        NB: this modifies (and returns) ``self`` rather than creating a new
        object. UIDs of the appended people are reassigned to stay unique.
        '''

        # Preliminaries
        newpeople = self
        keys = self.keys()
        n_orig = len(newpeople)
        max_uid = newpeople.uid.max() + 1
        n_new = len(people2)

        # Merge arrays
        for key in keys:
            npval = newpeople[key]
            p2val = people2[key]
            if isinstance(npval, np.ndarray):
                newpeople[key] = np.concatenate([npval, p2val], axis=0)
            elif isinstance(npval, list):
                newpeople[key] += p2val
            else:
                errormsg = f'Not sure what to do with object of type {type(npval)}'
                raise TypeError(errormsg)

        newpeople.uid[n_orig:] = max_uid + np.arange(n_new) # Reassign UIDs so they're unique

        return newpeople

    def __radd__(self, people2):
        ''' Allows sum() to work correctly (its falsy starting value is ignored) '''
        if not people2:
            return self
        else:
            return self.__add__(people2)

    def keys(self):
        ''' Returns a copy of the list of keys for all properties of the people object '''
        try: # _keys can be missing while the object is being initialized/unpickled
            keys = obj_get(self, '_keys')[:]
        except (AttributeError, TypeError):
            keys = []
        return keys

    @property
    def is_female(self):
        ''' Boolean array of everyone female '''
        return self.sex == 0

    @property
    def is_male(self):
        ''' Boolean array of everyone male '''
        return self.sex == 1

    @property
    def int_age(self):
        ''' Return ages as an integer '''
        return np.array(self.age, dtype=np.int64)

    @property
    def ceil_age(self):
        ''' Rounds age up to the next highest integer '''
        return np.array(np.ceil(self.age))

    @property
    def int_age_clip(self):
        ''' Return ages as integers, clipped to maximum allowable age for pregnancy '''
        return np.minimum(self.int_age, fpd.max_age_preg)

    @property
    def n(self):
        ''' Number of people alive '''
        return self.alive.sum()

    @property
    def inds(self):
        ''' Alias to self._inds to prevent accidental overwrite & increase speed '''
        return self._inds

    @property
    def len_inds(self):
        ''' Number of filtered indices, or len(self) if no filter is active '''
        if self._inds is not None:
            return len(self._inds)
        else:
            return len(self)

    @property
    def len_people(self):
        ''' Full length of People array, ignoring filtering '''
        return len(self.unfilter())

    def plot(self, fig_args=None, hist_args=None):
        '''
        Plot histograms of each quantity

        Args:
            fig_args (dict): passed to pl.figure()
            hist_args (dict): passed to pl.hist(); defaults to 50 bins

        Returns:
            The figure; quantities that cannot be plotted get a placeholder title instead
        '''

        fig_args  = sc.mergedicts(fig_args)
        hist_args = sc.mergedicts(dict(bins=50), hist_args)
        keys = self.keys()
        nkeys = len(keys)
        rows,cols = sc.get_rows_cols(nkeys)

        fig = pl.figure(**fig_args)

        for k,key in enumerate(keys):
            pl.subplot(rows,cols,k+1)
            try:
                data = np.array(self[key], dtype=float)
                mean = data.mean()
                label = f'mean: {mean}'
                pl.hist(data, label=label, **hist_args)
                pl.title(key)
                pl.legend()
            except Exception: # Best-effort: e.g. the quantity is not numeric
                pl.title(f'Could not plot {key}')

        return fig

    def filter(self, criteria=None, inds=None):
        '''
        Store indices to allow for easy filtering of the People object.

        Args:
            criteria (bool array): a boolean array for the filtering criteria
            inds (array): alternatively, explicitly filter by these indices (only used if criteria is None)

        Returns:
            A filtered People object, which works just like a normal People
            object except only operates on a subset of indices.

        Raises:
            ValueError: if the length of criteria matches neither the current filter nor the full population
        '''

        # Create a new People object with the same properties as the original
        filtered = object.__new__(self.__class__) # Create a new People instance
        BasePeople.__init__(filtered) # Perform essential initialization
        filtered.__dict__ = sc.cp(self.__dict__) # Copy pointers to the arrays in People

        # Perform the filtering
        if criteria is None: # No filtering: reset
            filtered._inds = None
            if inds is not None: # Unless indices are supplied directly, in which case use them
                filtered._inds = inds

        else: # Main use case: perform filtering
            len_criteria = len(criteria)
            if len_criteria == self.len_inds: # Main use case: a new filter applied on an already filtered object, e.g. filtered.filter(filtered.age > 5)
                new_inds = criteria.nonzero()[0] # Criteria is already filtered, just get the indices
            elif len_criteria == self.len_people: # Alternative: a filter on the underlying People object is applied to the filtered object, e.g. filtered.filter(people.age > 5)
                new_inds = criteria[filtered.inds].nonzero()[0] # Apply filtering before getting the new indices
            else:
                errormsg = f'"criteria" must be boolean array matching either current filter length ({self.len_inds}) or else the total number of people ({self.len_people}), not {len(criteria)}'
                raise ValueError(errormsg)

            if filtered.inds is None: # Not yet filtered: use the indices directly
                filtered._inds = new_inds
            else: # Already filtered: map them back onto the original People indices
                filtered._inds = filtered.inds[new_inds]

        return filtered

    def unfilter(self):
        '''
        An easy way of unfiltering the People object: returns a copy with the filter removed.
        '''
        unfiltered = self.filter(criteria=None)
        return unfiltered

    def binomial(self, prob, as_inds=False, as_filter=False):
        '''
        Return indices either by a single probability or by an array of probabilities.
        By default just return the boolean array, but can also return the indices,
        or the filtered People object.

        Args:
            prob (float/array): either a scalar probability, or an array of probabilities of the same length as People
            as_inds (bool): return as list of indices instead of a boolean array
            as_filter (bool): return as a filtered People object rather than a boolean array (ignored if as_inds is True)

        Raises:
            TypeError: if prob is neither a number nor an array
        '''
        if sc.isnumber(prob):
            arr = fpu.n_binomial(prob, len(self))
        elif sc.isarray(prob):
            arr = fpu.binomial_arr(prob)
        else:
            errormsg = f'Could not recognize {type(prob)} as a scalar or array'
            raise TypeError(errormsg)

        if as_inds:
            output = sc.findinds(arr)
        elif as_filter:
            output = self.filter(arr)
        else:
            output = arr

        return output
class BaseSim(ParsObj):
    '''
    The BaseSim class handles the dynamics of the simulation.
    '''

    def year2ind(self, year):
        ''' Convert a calendar year into a timestep index, counted from pars['start_year'] '''
        index = int((year - self.pars['start_year']) * fpd.mpy / self.pars['timestep'])
        return index

    def ind2year(self, ind):
        ''' Convert a timestep index into time elapsed since the start (in years, assuming timestep is in months and fpd.mpy is months per year) '''
        year = ind * self.pars['timestep'] / fpd.mpy
        return year

    def ind2calendar(self, ind):
        ''' Convert a timestep index into a calendar year '''
        year = self.ind2year(ind) + self.pars['start_year']
        return year

    @property
    def npts(self):
        ''' Count the number of points in timesteps between the starting year and the ending year (0 if it cannot be computed) '''
        try:
            return int(fpd.mpy * (self.pars['end_year'] - self.pars['start_year']) / self.pars['timestep'] + 1)
        except Exception: # E.g. parameters not yet set
            return 0

    @property
    def tvec(self):
        ''' Create a time vector array at intervals of the timestep in years (empty if it cannot be computed) '''
        try:
            return self.pars['start_year'] + np.arange(self.npts) * self.pars['timestep'] / fpd.mpy
        except Exception: # E.g. parameters not yet set
            return np.array([])

    @property
    def n(self):
        ''' Number of people alive '''
        return self.people.alive.sum()

    def _brief(self):
        '''
        Return a one-line description of a sim -- used internally and by repr();
        see sim.brief() for the user version.
        '''
        # Try to get a detailed description of the sim...
        try:
            if self.already_run:
                s = self.summary
                results = f'b={s.births:n} ☠={s.deaths:n} pop={s.final:n}'
            else:
                results = 'not run'

            # Set label string
            labelstr = f'"{self.label}"' if self.label else '<no label>'

            start = self['start_year']
            end = self['end_year']
            n_agents = self['n_agents']
            string = f'Sim({labelstr}; n={n_agents:n}; {start}-{end}; results: {results})'

        # ...but if anything goes wrong, return the default with a warning
        except Exception as E: # pragma: no cover
            string = sc.objectid(self)
            string += f'Warning, sim appears to be malformed; use sim.disp() for details:\n{str(E)}'

        return string

    def _get_ia(self, which, label=None, partial=False, as_list=False, as_inds=False, die=True, first=False):
        '''
        Helper method for get_interventions() and get_analyzers(); see get_interventions() docstring

        Args:
            which (str): either 'interventions' or 'analyzers'
            label (str, int, type, list): the label, index, or class to match; None matches everything; 'summary' prints a table and returns None
            partial (bool): if true, string labels also match as substrings
            as_list (bool): return all matches as a list
            as_inds (bool): return the indices of the matches (takes precedence over as_list)
            die (bool): when a single match is requested but none is found, raise an exception rather than returning None
            first (bool): when a single match is requested, return the first rather than the last
        '''

        # Handle inputs
        if which not in ['interventions', 'analyzers']: # pragma: no cover
            errormsg = f'This method is only defined for interventions and analyzers, not "{which}"'
            raise ValueError(errormsg)

        ia_list = sc.tolist(self.pars[which]) # List of interventions or analyzers
        n_ia = len(ia_list) # Number of interventions/analyzers

        if sc.isstring(label) and label == 'summary': # Print a summary of the interventions; the string check avoids an elementwise comparison on array labels
            rows = [dict(ind=ind, label=str(ia_obj.label), type=type(ia_obj)) for ind,ia_obj in enumerate(ia_list)]
            df = pd.DataFrame(rows, columns=['ind', 'label', 'type']) # Built in one go: DataFrame.append() no longer exists in pandas 2+
            print(f'Summary of {which}:')
            print(df)
            return

        else: # Standard usage case
            position = 0 if first else -1 # Choose either the first or last element
            if label is None: # Get all interventions if no label is supplied, e.g. sim.get_interventions()
                label = np.arange(n_ia)
            if isinstance(label, np.ndarray): # Allow arrays to be provided
                label = label.tolist()
            labels = sc.tolist(label)

            # Calculate the matches
            matches = []
            match_inds = []
            for this_label in labels: # Separate name, so the error message below can report what was requested
                if sc.isnumber(this_label):
                    matches.append(ia_list[this_label]) # This will raise an exception if an invalid index is given
                    this_ind = n_ia + this_label if this_label<0 else this_label # Convert to a positive number
                    match_inds.append(this_ind)
                elif sc.isstring(this_label) or isinstance(this_label, type):
                    is_str = sc.isstring(this_label)
                    for ind,ia_obj in enumerate(ia_list):
                        if is_str: # Exact or (optionally) substring match; never attempted for class labels
                            is_match = (ia_obj.label == this_label) or (partial and (this_label in str(ia_obj.label)))
                        else: # Match by class
                            is_match = isinstance(ia_obj, this_label)
                        if is_match:
                            matches.append(ia_obj)
                            match_inds.append(ind)
                else: # pragma: no cover
                    errormsg = f'Could not interpret label type "{type(this_label)}": should be str, int, list, or {which} class'
                    raise TypeError(errormsg)

            # Parse the output options
            if as_inds:
                output = match_inds
            elif as_list: # Used by get_interventions()
                output = matches
            else:
                if len(matches) == 0: # pragma: no cover
                    if die:
                        errormsg = f'No {which} matching "{label}" were found'
                        raise ValueError(errormsg)
                    else:
                        output = None
                else:
                    output = matches[position] # Return either the first or last match (usually), used by get_intervention()

            return output

    def get_interventions(self, label=None, partial=False, as_inds=False):
        '''
        Find the matching intervention(s) by label, index, or type. If None, return
        all interventions. If the label provided is "summary", then print a summary
        of the interventions (index, label, type).

        Args:
            label (str, int, Intervention, list): the label, index, or type of intervention to get; if a list, iterate over one of those types
            partial (bool): if true, return partial matches (e.g. 'beta' will match all beta interventions)
            as_inds (bool): if true, return matching indices instead of the actual interventions
        '''
        return self._get_ia('interventions', label=label, partial=partial, as_inds=as_inds, as_list=True)

    def get_intervention(self, label=None, partial=False, first=False, die=True):
        '''
        Like get_interventions(), find the matching intervention(s) by label,
        index, or type. If more than one intervention matches, return the last
        by default. If no label is provided, return the last intervention in the list.

        Args:
            label (str, int, Intervention, list): the label, index, or type of intervention to get; if a list, iterate over one of those types
            partial (bool): if true, return partial matches (e.g. 'beta' will match all beta interventions)
            first (bool): if true, return first matching intervention (otherwise, return last)
            die (bool): whether to raise an exception if no intervention is found
        '''
        return self._get_ia('interventions', label=label, partial=partial, first=first, die=die, as_inds=False, as_list=False)

    def get_analyzers(self, label=None, partial=False, as_inds=False):
        '''
        Same as get_interventions(), but for analyzers.
        '''
        return self._get_ia('analyzers', label=label, partial=partial, as_list=True, as_inds=as_inds)

    def get_analyzer(self, label=None, partial=False, first=False, die=True):
        '''
        Same as get_intervention(), but for analyzers.
        '''
        return self._get_ia('analyzers', label=label, partial=partial, first=first, die=die, as_inds=False, as_list=False)
class ndict(sc.objdict):
    """
    A dictionary-like class that provides additional functionalities for handling named items.

    Items are keyed by one of their own attributes (by default ``item.name``).
    Adding an item whose key is already present raises ``fpu.DuplicateNameException``.

    Args:
        name (str): The items' attribute to use as keys.
        type (type): The expected type of items.
        strict (bool): If True, only items with the specified attribute will be accepted;
            if False, other items are stored under an automatically generated key ('item1', 'item2', ...).

    **Examples**::

        items = ndict(item_a, item_b)
        items = ndict([item_a, item_b])
        items = ndict({'a':item_a, 'b':item_b})

    """

    def __init__(self, *args, name='name', type=None, strict=True, **kwargs):
        self.setattribute('_name', name)  # Since otherwise treated as keys
        self.setattribute('_type', type)
        self.setattribute('_strict', strict)
        self._initialize(*args, **kwargs)
        return

    def append(self, arg, key=None):
        """
        Add an item, or every value of a dict of items.

        Args:
            arg: the item to add; None is ignored. A dict that does not itself
                contain the name key is unpacked, each value being added under its dict key.
            key (str): the key to store the item under; if not supplied, taken from the item

        Raises:
            ValueError: if no key can be determined for the item and strict is True
            TypeError: if the item is not of the expected type
            fpu.DuplicateNameException: if the key already exists
        """
        if arg is None:
            return  # Nothing to do

        if hasattr(arg, self._name):  # The item carries its own name as an attribute
            key = key or getattr(arg, self._name)
        elif isinstance(arg, dict):
            if self._name in arg:  # The dict is itself an item with a name entry
                key = key or arg[self._name]
            else:  # The dict is a container of items: add each one and skip final processing
                for k, v in arg.items():
                    self.append(v, key=k)
                return
        elif not self._strict:  # Unnamed items are allowed: generate a key
            key = key or f'item{len(self) + 1}'
        else:
            errormsg = f'Could not interpret argument {arg}: does not have expected attribute "{self._name}"'
            raise ValueError(errormsg)

        self._check_type(arg)

        # Check if this key already exists
        if key in self:
            raise fpu.DuplicateNameException(self[key])
        self[key] = arg
        return

    def _check_type(self, arg):
        """ Check that the item has the expected type, if one was specified """
        if self._type is not None:
            if not isinstance(arg, self._type):
                errormsg = f'The following item does not have the expected type {self._type}:\n{arg}'
                raise TypeError(errormsg)
        return

    def _initialize(self, *args, **kwargs):
        """ Add all positional items (flattening any lists), then all keyword items under their keyword """
        args = sc.mergelists(*args)
        for arg in args:
            self.append(arg)
        for key, arg in kwargs.items():
            self.append(arg, key=key)
        return

    def copy(self):
        """ Return a new ndict with the same settings and the same items (the items themselves are not copied) """
        # Call the class itself: calling __new__ with keywords only fails, since the class argument is missing
        new = self.__class__(name=self._name, type=self._type, strict=self._strict)
        new.update(self)
        return new

    def __add__(self, dict2):
        """ Allow c = a + b """
        new = self.copy()
        new.append(dict2)
        return new

    def __iadd__(self, dict2):
        """ Allow a += b """
        self.append(dict2)
        return self