Source code for poliosim.analysis

'''
Additional analysis functions.
'''

import os
import sys
import numpy as np
import pylab as pl
import pandas as pd
import sciris as sc
from . import utils as psu
from . import interventions as psi
try:
    import optuna as op
except ImportError as E: # pragma: no cover
    errormsg = f'Optuna import failed ({str(E)}), please install first (pip install optuna)'
    op = ImportError(errormsg)


__all__ = ['Analyzer',
           'snapshot',
           'save_states',
           'track_events',
           'track_shedders',
           'infection_report',
           'Fit',
           'TransTree',
           'calculate_contacts_infected',
           ]


class Analyzer(sc.prettyobj):
    '''
    Base class for analyzers. Based on the Intervention class. Analyzers are used
    to provide more detailed information about a simulation than is available by
    default -- for example, pulling states out of sim.people on a particular timestep
    before it gets updated in the next timestep.

    To retrieve a particular analyzer from a sim, use sim.get_analyzer().

    Args:
        label (str): a label for the Analyzer (used for ease of identification)
    '''

    def __init__(self, label=None):
        if label is None:
            label = self.__class__.__name__ # Use the class name if no label is supplied
        self.label = label # e.g. "Record ages"
        self.initialized = False
        self.finalized = False
        return

    def __call__(self, *args, **kwargs):
        # Makes Analyzer(sim) equivalent to Analyzer.apply(sim)
        if not self.initialized:
            errormsg = f'Analyzer (label={self.label}, {type(self)}) has not been initialized'
            raise RuntimeError(errormsg)
        return self.apply(*args, **kwargs)
    def initialize(self, sim=None):
        '''
        Initialize the analyzer, e.g. convert date strings to integers.
        '''
        self.initialized = True
        self.finalized = False
        return
    def finalize(self, sim=None):
        '''
        Finalize analyzer

        This method is run once as part of `sim.finalize()`, enabling the analyzer to perform any
        final operations after the simulation is complete (e.g. rescaling).
        '''
        if self.finalized:
            raise RuntimeError('Analyzer already finalized') # Raise an error because finalizing multiple times has a high probability of producing incorrect results, e.g. applying rescale factors twice
        self.finalized = True
        return
    def apply(self, sim):
        '''
        Apply analyzer at each time point. The analyzer has full access to the
        sim object, and typically stores data/results in itself. This is the core
        method which each analyzer object needs to implement.

        Args:
            sim: the Sim instance
        '''
        raise NotImplementedError
    def shrink(self, in_place=False):
        '''
        Remove any excess stored data from the analyzer; for use with sim.shrink().

        Args:
            in_place (bool): whether to shrink the analyzer (else shrink a copy)
        '''
        if in_place:
            return self
        else:
            return sc.dcp(self)
    def to_json(self):
        '''
        Return JSON-compatible representation

        Custom classes can't be directly represented in JSON. This method is a
        one-way export to produce a JSON-compatible representation of the
        analyzer. This method will attempt to JSONify each attribute of the
        analyzer, skipping any that fail.

        Returns:
            JSON-serializable representation
        '''
        # Set the name
        json = {}
        json['analyzer_name'] = self.label if hasattr(self, 'label') else None
        json['analyzer_class'] = self.__class__.__name__

        # Loop over the attributes and try to process
        attrs = self.__dict__.keys()
        for attr in attrs:
            try:
                data = getattr(self, attr)
                try:
                    attjson = sc.jsonify(data)
                    json[attr] = attjson
                except Exception as E:
                    json[attr] = f'Could not jsonify "{attr}" ({type(data)}): "{str(E)}"'
            except Exception as E2:
                json[attr] = f'Could not jsonify "{attr}": "{str(E2)}"'
        return json
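# Example (illustrative sketch, not part of the original module): the minimal
# pattern for a custom analyzer is to subclass Analyzer and implement apply().
# The `ps.Sim`/`sim.get_analyzer()` usage mirrors the docstring examples in this
# file; the class name and stored attributes here are purely for illustration.
class record_mean_age(Analyzer):
    ''' Illustrative analyzer: record the mean age of the population at each timestep '''

    def __init__(self, label=None):
        super().__init__(label=label)
        self.t = []         # Timesteps at which the analyzer was applied
        self.mean_ages = [] # Mean age at each timestep
        return

    def apply(self, sim):
        self.t.append(sim.t)
        self.mean_ages.append(np.mean(sim.people.age))
        return

# Usage sketch (assuming poliosim is imported as ps):
#     sim = ps.Sim(analyzers=record_mean_age())
#     sim.run()
#     analyzer = sim.get_analyzer()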
class snapshot(Analyzer):
    '''
    Analyzer that takes a "snapshot" of the sim.people array at specified points
    in time, and saves them to itself. To retrieve them, you can either access
    the dictionary directly, or use the get() method.

    Args:
        days   (list): list of ints/strings/date objects, the days on which to take the snapshot
        kwargs (dict): passed to Analyzer()

    **Example**::

        sim = ps.Sim(analyzers=ps.snapshot('2020-04-04', '2020-04-14'))
        sim.run()
        snapshot = sim['analyzers'][0]
        people = snapshot.snapshots[0]            # Option 1
        people = snapshot.snapshots['2020-04-04'] # Option 2
        people = snapshot.get('2020-04-14')       # Option 3
        people = snapshot.get(34)                 # Option 4
        people = snapshot.get()                   # Option 5
    '''

    def __init__(self, days, *args, **kwargs):
        super().__init__(**kwargs) # Initialize the Analyzer object
        days = sc.promotetolist(days) # Combine multiple days
        days.extend(args) # Include additional arguments, if present
        self.days = days # Converted to integer representations
        self.dates = None # String representations
        self.start_day = None # Store the start date of the simulation
        self.snapshots = sc.odict() # Store the actual snapshots
        return
    def initialize(self, sim):
        super().initialize()
        self.start_day = sim['start_day'] # Store the simulation start day
        self.days = psi.process_days(sim, self.days) # Ensure days are in the right format
        self.dates = [sim.date(day) for day in self.days] # Store as date strings
        return
    def apply(self, sim):
        for ind in psi.find_day(self.days, sim.t):
            date = self.dates[ind]
            self.snapshots[date] = sc.dcp(sim.people) # Take snapshot!
        return
    def get(self, key=None):
        ''' Retrieve a snapshot from the given key (int, str, or date) '''
        if key is None:
            key = self.days[0]
        day = psu.day(key, start_day=self.start_day)
        date = psu.date(day, start_date=self.start_day, as_date=False)
        if date in self.snapshots:
            snapshot = self.snapshots[date]
        else:
            dates = ', '.join(list(self.snapshots.keys()))
            errormsg = f'Could not find snapshot date {date} (day {day}): choices are {dates}'
            raise sc.KeyNotFoundError(errormsg)
        return snapshot
# Define the states/events that are stored for save_states and track_events.
full_states = [
    'age',
    'naive',
    'exposed',
    'symptomatic',
    'recovered',
    'tested',
    'quarantined',
    'is_shed',
    'viral_shed',
    'shed_duration',
    'susceptible_to_paralysis',
    'current_immunity',
    'postchallenge_peak_immunity',
    'prechallenge_immunity',
]

main_states = [
    'age',
    'naive',
    'susceptible_to_paralysis',
    'symptomatic',
    'recovered',
    'is_shed',
    'viral_shed',
    'shed_duration',
]
class save_states(Analyzer):
    '''
    Save the states of the people into a big array for polio.

    Args:
        states (list): the list of states to save (default, just shedding)
        full   (bool): whether to save the complete list of states
    '''

    def __init__(self, states=None, full=False, sort_inds=None, **kwargs):
        super().__init__(**kwargs) # Initialize the Analyzer object
        self.states = states if states else (full_states if full else main_states)
        self.n_states = len(self.states)
        self.arr = None
        self.sort_inds = sort_inds
        return
    def initialize(self, sim):
        super().initialize()
        self.arr = np.zeros((self.n_states, len(sim.people), sim.npts))
        return
    def apply(self, sim):
        for i,state in enumerate(self.states):
            self.arr[i, :, sim.t] = sim.people[state]
        return
    def to_dfdict(self):
        output = dict()
        for i, state in enumerate(self.states):
            output[state] = pd.DataFrame(self.arr[i, :, :])
        return output
    def plot(self, sort_inds=None, log_viral_shed=True, nonzero_only=False, filter_states=None, rows=None, **kwargs):
        assert not (nonzero_only and rows is not None), "I don't know what to do with nonzero_only == True and rows is not None"
        if sort_inds is None:
            sort_inds = np.arange(self.arr.shape[1]) # Do not change default order
        arr = sc.dcp(self.arr)
        figs = []
        rows_with_nonzero = []
        for i,state in enumerate(self.states):
            if filter_states is not None:
                if state not in filter_states:
                    continue
            fig = pl.figure(figsize=(26,22))
            vals = arr[i,sort_inds,:]
            if log_viral_shed and state == 'viral_shed':
                eps = 1e-2 # Surely this shouldn't be less than 1 particle, right?
                vals = np.log10(vals + eps)
                state = 'log10(viral_shed)'
            if nonzero_only:
                rows_with_nonzero = np.any(vals, axis=1).nonzero()[0]
                vals = vals[rows_with_nonzero,:]
            if rows is not None:
                vals = vals[rows, :]
            pl.imshow(vals, origin='lower', aspect='auto', **kwargs)
            pl.colorbar()
            pl.title(state)
            pl.xlabel('Time (days)')
            pl.ylabel('Person')
            if nonzero_only:
                pl.yticks(ticks=np.arange(len(rows_with_nonzero)), labels=rows_with_nonzero)
            if rows is not None:
                pl.yticks(ticks=np.arange(len(rows)), labels=rows)
            figs.append(fig)
        return figs, rows_with_nonzero
    def coplot(self, people_inds=None, filter_states=None, transform_state=None, **kwargs):
        '''
        Co-plot selected states of selected people, one plot per person.

        Args:
            people_inds: The indices of the people to create plots for.
            filter_states: If specified, only these states will be shown. The states you want
                also need to be specified when creating the analyzer object, or else be one of
                the defaults if you didn't specify any explicitly.
            transform_state: An optional dictionary, keyed on state names. The dictionary value
                is a function to pass the state values through. The transformed value will be plotted.
            **kwargs:

        Returns:
            Nothing.

        Example of using transform_state::

            # This function returns log10(x) when x > 0, otherwise np.nan:
            def snoop_loggy_log10(x):
                return np.log10(x, where=0 < x, out=np.nan * x)

            # Apply the special log function just to the state 'viral_shed'.
            my_analyzer.plot(people_inds=inds, transform_state={'viral_shed': snoop_loggy_log10})
        '''
        if people_inds is None:
            people_inds = np.arange(self.arr.shape[1]) # Default to all people
        if transform_state is None:
            transform_state = {}
        arr = sc.dcp(self.arr)
        for person_ind in people_inds:
            state_frame = pd.DataFrame()
            for i,state in enumerate(self.states):
                if filter_states is not None and state not in filter_states:
                    continue
                if state in transform_state:
                    state_display_name = f'{transform_state[state].__name__}({state})'
                    state_frame[state_display_name] = transform_state[state](arr[i, person_ind, :])
                else:
                    state_frame[state] = arr[i,person_ind,:]
            state_frame.plot(subplots=True, title=f'person [{person_ind}]', **kwargs)
        return
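# Example (hypothetical usage sketch, not part of the original module): run a sim
# with a save_states analyzer and inspect the stored array. The `poliosim` import,
# `ps.Sim`, and `get_analyzer()` are assumed from the docstring examples above.
def _example_save_states(): # pragma: no cover
    import poliosim as ps # Assumed package import
    sim = ps.Sim(analyzers=save_states(full=True))
    sim.run()
    saver = sim.get_analyzer()
    dfs = saver.to_dfdict() # One DataFrame (people x timesteps) per state
    figs, nonzero_rows = saver.plot(nonzero_only=True) # One heatmap per state, shedders only
    saver.coplot(people_inds=nonzero_rows[:3], filter_states=['viral_shed', 'symptomatic'])
    return dfs, figs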
class track_events(Analyzer):
    '''
    Store a list of events for everyone.

    Args:
        states   (list): the list of states to save (default, just shedding)
        full     (bool): whether to save the complete list of states
        curr_imm (bool): whether to store current immunity, which is updated on every timestep (default: false)

    Data are stored in self.events for the events, and self.init_state for the initial state.
    '''

    def __init__(self, states=None, full=True, curr_imm=False, **kwargs):
        super().__init__(**kwargs) # Initialize the Analyzer object

        # Handle which states to store
        if states is not None:
            self.states = states
        else:
            if full:
                self.states = full_states
            else:
                self.states = main_states
        if not curr_imm:
            self.states = [s for s in self.states if s != 'current_immunity'] # Remove current immunity if it exists

        # Initialize everything else
        self.people = None # The people array from every timestep
        self.init_data = [] # The initial state of everyone
        self.event_data = [] # The data on events, as a dict
        self.init_state = None # Initial state of everyone
        self.events = None # The results as a dataframe
        self.n_people = 0 # Number of people
        return
    def initialize(self, sim):
        super().initialize()
        self.n_people = len(sim.people)
        self.people = dict()
        for state in self.states:
            arr = sim.people[state].copy()
            self.people[state] = arr
            for p in range(self.n_people):
                row = dict(t=-1, uid=p, state=state, value=arr[p]) # Store the initial state
                self.init_data.append(row)
        return
    def finalize(self, sim=None):
        ''' Convert from list-of-dicts to dataframe '''
        super().finalize()
        self.init_state = pd.DataFrame(self.init_data)
        self.events = pd.DataFrame(self.event_data)
        del self.init_data # Save space
        del self.event_data
        return
    def apply(self, sim):
        t = sim.t
        for state in self.states:
            old_arr = self.people[state]
            new_arr = sim.people[state].copy()
            changed = old_arr != new_arr
            not_nan = np.isfinite(new_arr)
            inds = sc.findinds(changed*not_nan) # Find which people have had state changes, and are non-nan
            for p in inds: # Loop over people who've changed
                row = dict(t=t, uid=p, state=state, value=new_arr[p]) # Store event information
                self.event_data.append(row)
            self.people[state] = new_arr # Update for the next timestep
        return
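# Example (hypothetical usage sketch): after a run, self.events is a long-format
# DataFrame with columns t/uid/state/value, so standard pandas filtering applies.
# The `poliosim` import is assumed; the state name 'is_shed' is one of the
# defaults in main_states above.
def _example_track_events(): # pragma: no cover
    import poliosim as ps # Assumed package import
    sim = ps.Sim(analyzers=track_events())
    sim.run()
    tracker = sim.get_analyzer()
    shedding_events = tracker.events[tracker.events.state == 'is_shed'] # All shedding on/off events
    first_changes = shedding_events.groupby('uid').t.min() # First shedding change per person
    return shedding_events, first_changes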
class track_shedders(Analyzer):
    ''' Keep a record of people who shed '''

    def __init__(self, **kwargs):
        super().__init__(**kwargs) # Initialize the Analyzer object
        self.contacts = {}
        self.contact_shed = {}
        return
    def initialize(self, sim):
        super().initialize()
        self.para_all = np.zeros((len(sim.people), sim.npts))
        self.shed_all = np.zeros((len(sim.people), sim.npts))
        return
    def apply(self, sim):
        ''' Record data throughout the sim '''
        self.para_all[:,sim.t] = sim.people.symptomatic
        self.shed_all[:,sim.t] = sim.people.is_shed

        if sim.t == sim.tvec[-1]:

            # Find symptomatics
            self.para_inds = sim.people.true('symptomatic')
            self.n_para = len(self.para_inds)
            self.date_para = {}
            self.age_para = {}
            self.shed_para = {}
            self.lkeys = sim.people.contacts.keys()
            self.tvec = np.arange(sim.npts)

            # Find their contacts
            for ind in self.para_inds:
                self.date_para[ind] = sim.people.date_symptomatic[ind]
                self.age_para[ind] = sim.people.age[ind]
                self.shed_para[ind] = sc.findinds(self.shed_all[ind,:])
                self.contacts[ind] = {}
                self.contact_shed[ind] = {}
                for lkey in self.lkeys:
                    layer = sim.people.contacts[lkey]

                    # Find all the contacts of these people
                    inds_list = []
                    for k1,k2 in [['p1','p2'],['p2','p1']]: # Loop over the contact network in both directions -- k1,k2 are the keys
                        in_k1 = sc.findinds(layer[k1]==ind) # Get all the indices of the pairs that each person is in
                        inds_list.append(layer[k2][in_k1]) # Find their pairing partner
                    self.contacts[ind][lkey] = np.unique(np.concatenate(inds_list)) # Find all edges
                    contact_shed_arr = self.shed_all[self.contacts[ind][lkey],:]
                    contact_shed_vec = contact_shed_arr.sum(axis=0)
                    self.contact_shed[ind][lkey] = contact_shed_vec
        return
    def plot(self, font_family=None):
        if font_family:
            pl.rcParams['font.family'] = font_family
        figs = []
        layer_names = dict(h='Household contacts', s='School', w='Workplace', c='Community contacts')
        colors = dict(h=[0.2, 0.4, 0.6], s=[0.3, 0.3, 0.3], w=[0.7,0.7,0.7], c=[0.8, 0.6, 0.1])
        # colors = sc.gridcolors(len(layer_names), asarray=True)[[0,2,3,1]] # Change order since only household/community
        args1 = dict(marker='o', markersize=4)
        args2 = dict(lw=2)

        fig = pl.figure(figsize=(14,8), dpi=180)
        figs.append(fig)
        pl.subplots_adjust(left=0.05, right=0.95, bottom=0.08, top=0.95, hspace=0.4, wspace=0.2)
        axs = []
        top = 0
        order = np.argsort(list(self.date_para.values()))
        if len(order) % 2 != 0 and len(order)>2:
            order = order[:-1]
        n_para = len(order)
        if n_para <= 4:
            rows = n_para
            cols = 1
        elif n_para <= 8:
            rows = n_para/2
            cols = 2
        else:
            rows = np.ceil(n_para/3)
            cols = 3
        rows = int(rows)
        cols = int(cols)
        medrow = np.median(np.arange(rows))
        axs = fig.subplots(nrows=rows, ncols=cols)
        min_x = -20
        max_x = 50

        # show paralysis duration, change colors, run sim longer, fewer y labels, fewer x labels
        for i,ind in enumerate(self.para_inds[order]):
            col = i//rows
            row = i % rows
            try: # >1 row
                ax = axs[row, col]
            except:
                ax = axs[row]
            x = self.tvec - self.date_para[ind]
            rel_shed = self.shed_para[ind] - self.date_para[ind]
            rel_shed = rel_shed[rel_shed>=min_x]
            rel_shed = rel_shed[rel_shed<=max_x]
            for ti,t in enumerate(x):
                bottom = 0
                for l,lkey in enumerate(self.lkeys):
                    n = int(self.contact_shed[ind][lkey][ti])
                    if n:
                        label = layer_names[lkey] if t==0 else None
                        ax.plot(n*[t], np.arange(bottom+1, bottom+n+1), 's', c=colors[lkey], label=label, zorder=-l, **args1)
                        ax.plot((n+1)*[t], np.arange(bottom, bottom+n+1), c=colors[lkey], zorder=-l, **args2)
                        bottom += n
                        top = max(top, bottom)
            ax.plot(rel_shed, np.zeros(len(rel_shed))-0.15, lw=4, alpha=0.5, c=[0.7,0.1,0.6], label='Index case shedding', clip_on=False)
            ax.axvline(0, linestyle='-', lw=4, c='k', alpha=0.2, zorder=-10)
            dx = 0.5
            minx = max(min_x-dx, x[0]-dx)
            maxx = min(max_x+dx, x[-1]+dx)
            ax.set_xlim(minx, maxx)
            if row==0 and col==cols-1:
                ax.legend(frameon=False, bbox_to_anchor=(0.75,1.00))
            if row == rows-1:
                ax.set_xlabel('Days since onset of paralysis', fontweight='bold')
            if row == medrow and col==0:
                ax.set_ylabel('Number of shedding contacts', fontweight='bold')
            ax.set_title(r'$\bf{Case}$ $\bf{'+f'{i+1}' + '}$' + f': age {self.age_para[ind]:0.1f} years, paralysis on day {self.date_para[ind]:0.0f} of epidemic')

        stride = 2
        for ax in axs.flatten():
            ax.set_ylim([0, top+dx])
            ax.set_yticks(np.arange(0, top+1, stride))
            sc.boxoff(ax=ax)
        pl.show()
        return figs
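# Example (hypothetical usage sketch): the analyzer collects its data on the last
# timestep, so it only needs to be attached before the run; plot() then draws one
# panel per paralytic case. The `poliosim` import is assumed as in the other examples.
def _example_track_shedders(): # pragma: no cover
    import poliosim as ps # Assumed package import
    sim = ps.Sim(analyzers=track_shedders())
    sim.run()
    figs = sim.get_analyzer().plot()
    return figs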
class infection_report(Analyzer):
    '''
    Keep a record of who infects who.

    Args:
        naive_immunity (float): immunity threshold for categorization as naive (units = Nab). Default is 8.0.
        age_breakpoints (array of float): age breakpoints for making categories. Default is [5,15], which
            forms categories 'naive u5', 'non-naive u5', '5-15', 'o15'. All categories with lower bound <5
            will have naive and non-naive variants. The first age breakpoint must be >0 or an exception
            will be raised.
    '''

    def __init__(self, naive_immunity=8.0, age_breakpoints=None, **kwargs):
        super().__init__(**kwargs) # Initialize the Analyzer object
        if age_breakpoints is None:
            age_breakpoints = [5,15]
        if age_breakpoints[0] <= 0:
            raise ValueError(f'First age breakpoint must be >0, was: {age_breakpoints[0]}')
        self.age_breakpoints = age_breakpoints
        self.naive_immunity = naive_immunity # Set before making the categories, which use this threshold
        self.infection_report_categories = self.make_infection_report_categories(self.age_breakpoints)
        self.n_categories = len(self.infection_report_categories.keys())
        return
    def make_infection_report_categories(self, age_breakpoints):
        last_age_breakpoint = 0
        infection_report_categories = dict()
        category_index = 0
        age_breakpoints.append(sys.float_info.max)
        for i, age_breakpoint in enumerate(age_breakpoints):
            age_category = ''
            if i == 0:
                age_category = f'u{age_breakpoint}'
            elif i == len(age_breakpoints) - 1:
                age_category = f'o{last_age_breakpoint}'
            else:
                age_category = f'{last_age_breakpoint}-{age_breakpoint}'
            if last_age_breakpoint < 5:
                infection_report_categories[f'naive {age_category}'] = \
                    {
                        'category_index': category_index,
                        'age_lbound': last_age_breakpoint,
                        'age_ubound': age_breakpoint,
                        'immunity_lbound': 0.0,
                        'immunity_ubound': self.naive_immunity # The configurable naive-immunity threshold (default 8.0)
                    }
                category_index = category_index + 1
                infection_report_categories[f'non-naive {age_category}'] = \
                    {
                        'category_index': category_index,
                        'age_lbound': last_age_breakpoint,
                        'age_ubound': age_breakpoint,
                        'immunity_lbound': self.naive_immunity,
                        'immunity_ubound': sys.float_info.max
                    }
                category_index = category_index + 1
            else:
                infection_report_categories[f'{age_category}'] = \
                    {
                        'category_index': category_index,
                        'age_lbound': last_age_breakpoint,
                        'age_ubound': age_breakpoint,
                        'immunity_lbound': 0.0,
                        'immunity_ubound': sys.float_info.max
                    }
                category_index = category_index + 1
            last_age_breakpoint = age_breakpoint
        return infection_report_categories
    def initialize(self, sim):
        super().initialize()
        self.initial_immunities = np.zeros(len(sim.people))
        self.infection_report = np.zeros((self.n_categories, self.n_categories))
        self.first_infection_report = np.zeros((self.n_categories, self.n_categories))

        # We can derive the data above from the data below, but we have a bunch of post-processing
        # written to the stuff above, so storing this separately.
        self.layer_index_map = dict()
        for i, layer_key in enumerate(sim.people.contacts.keys()):
            self.layer_index_map[layer_key] = i
        self.infection_report_by_layer = np.zeros((len(self.layer_index_map.keys()), self.n_categories, self.n_categories))
        self.first_infection_report_by_layer = np.zeros((len(self.layer_index_map.keys()), self.n_categories, self.n_categories))
        return
    def get_infection_report_category(self, age, immunity):
        for key in self.infection_report_categories.keys():
            if age >= self.infection_report_categories[key]['age_lbound'] \
                    and age < self.infection_report_categories[key]['age_ubound'] \
                    and immunity >= self.infection_report_categories[key]['immunity_lbound'] \
                    and immunity < self.infection_report_categories[key]['immunity_ubound']:
                return self.infection_report_categories[key]['category_index']
        raise ValueError(f'Cannot categorize age {age}, immunity {immunity}')
    def apply(self, sim):
        ''' Record data throughout the sim '''
        # At t = 0 we take a snapshot of initial immunity.
        if sim.t == 0:
            self.initial_immunities[:] = sim.people.current_immunity

        # At last time step, we pull from the infection log.
        if sim.t == sim.tvec[-1]:
            infection_log = sim.people.infection_log
            for infection in infection_log:
                if infection['source'] is not None: # Skip seed infections; "is not None" so person 0 is not skipped as a source
                    buckets = {}
                    for st in ['source', 'target']:
                        age = sim.people.age[infection[st]]
                        if st == 'source':
                            # Use initial immunity for source bucket.
                            immunity = self.initial_immunities[infection['source']]
                        else:
                            # Use immunity at time of infection for target bucket.
                            immunity = infection['target_immunity']
                        buckets[st] = self.get_infection_report_category(age, immunity)
                    source_bucket = buckets['source']
                    target_bucket = buckets['target']
                    infection_layer_index = self.layer_index_map[infection['layer']]
                    self.infection_report[source_bucket][target_bucket] += 1
                    self.infection_report_by_layer[infection_layer_index][source_bucket][target_bucket] += 1
                    if infection['target_exposure_count'] == 1:
                        self.first_infection_report[source_bucket][target_bucket] += 1
                        self.first_infection_report_by_layer[infection_layer_index][source_bucket][target_bucket] += 1
        return
    def plot(self, font_family=None):
        if font_family:
            pl.rcParams['font.family'] = font_family
        fig, ax = pl.subplots()
        im = ax.imshow(self.infection_report, cmap="YlGn")
        infection_report_categories = self.infection_report_categories.keys()
        ax.set_xticks(np.arange(len(infection_report_categories)))
        ax.set_yticks(np.arange(len(infection_report_categories)))
        ax.set_xticklabels(infection_report_categories)
        ax.set_yticklabels(infection_report_categories)
        for i in range(len(infection_report_categories)):
            for j in range(len(infection_report_categories)):
                ax.text(j, i, self.infection_report[i, j], ha="center", va="center", color="k")
        cbar = ax.figure.colorbar(im)
        cbar.ax.set_ylabel('count', rotation=-90, va="bottom")
        ax.set_ylabel('Infection source category')
        ax.set_xlabel('Infection target category')
        ax.set_title("Infection by source/target category")
        pl.gca().invert_yaxis()
        fig.tight_layout()
        pl.show()
        return fig
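# Example (hypothetical usage sketch): custom age breakpoints change the report
# categories. With [2, 10], both categories with lower bound <5 get naive and
# non-naive variants, giving 'naive u2', 'non-naive u2', 'naive 2-10',
# 'non-naive 2-10', and 'o10', per the rules in the class docstring above.
def _example_infection_report(): # pragma: no cover
    import poliosim as ps # Assumed package import
    reporter = infection_report(naive_immunity=8.0, age_breakpoints=[2, 10])
    sim = ps.Sim(analyzers=reporter)
    sim.run()
    fig = reporter.plot() # Heatmap of infections by source/target category
    return reporter.infection_report, fig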
class Fit(sc.prettyobj):
    '''
    A class for calculating the fit between the model and the data. Note the
    following terminology is used here:

        - fit: nonspecific term for how well the model matches the data
        - difference: the absolute numerical differences between the model and the data (one time series per result)
        - goodness-of-fit: the result of passing the difference through a statistical function, such as mean squared error
        - loss: the goodness-of-fit for each result multiplied by user-specified weights (one time series per result)
        - mismatch: the sum of all the losses (a single scalar value) -- this is the value to be minimized during calibration

    Args:
        sim (Sim): the sim object
        weights (dict): the relative weight to place on each result
        keys (list): the keys to use in the calculation
        method (str): the method to be used to calculate the goodness-of-fit
        custom (dict): a custom dictionary of additional data to fit; format is e.g. {'<label>':{'data':[1,2,3], 'sim':[1,2,4], 'weights':2.0}}
        compute (bool): whether to compute the mismatch immediately
        verbose (bool): detail to print
        kwargs (dict): passed to compute_gof()

    **Example**::

        sim = ps.Sim() # Needs data to work
        sim.run()
        fit = sim.compute_fit()
        fit.plot()
    '''

    def __init__(self, sim, weights=None, keys=None, method=None, custom=None, compute=True, verbose=False, **kwargs):

        # Handle inputs
        self.custom = sc.mergedicts(custom)
        self.verbose = verbose
        self.weights = sc.mergedicts({'cum_diagnoses':5}, weights)
        self.keys = keys
        self.gof_kwargs = kwargs

        # Copy data
        if sim.data is None:
            errormsg = 'Model fit cannot be calculated until data are loaded'
            raise RuntimeError(errormsg)
        self.data = sim.data

        # Copy sim results
        if not sim.results_ready:
            errormsg = 'Model fit cannot be calculated until results are run'
            raise RuntimeError(errormsg)
        self.sim_results = sc.objdict()
        for key in sim.result_keys() + ['t', 'date']:
            self.sim_results[key] = sim.results[key]
        self.sim_npts = sim.npts # Number of time points in the sim

        # Copy other things
        self.sim_dates = sim.datevec.tolist()

        # These are populated during initialization
        self.inds = sc.objdict() # To store matching indices between the data and the simulation
        self.inds.sim = sc.objdict() # For storing matching indices in the sim
        self.inds.data = sc.objdict() # For storing matching indices in the data
        self.date_matches = sc.objdict() # For storing matching dates, largely for plotting
        self.pair = sc.objdict() # For storing perfectly paired points between the data and the sim
        self.diffs = sc.objdict() # Differences between pairs
        self.gofs = sc.objdict() # Goodness-of-fit for differences
        self.losses = sc.objdict() # Weighted goodness-of-fit
        self.mismatches = sc.objdict() # Final mismatch values
        self.mismatch = None # The final value

        if compute:
            self.compute()
        return
    def compute(self):
        ''' Perform all required computations '''
        self.reconcile_inputs() # Find matching values
        self.compute_diffs() # Perform calculations
        self.compute_gofs()
        self.compute_losses()
        self.compute_mismatch()
        return self.mismatch
    def reconcile_inputs(self):
        ''' Find matching keys and indices between the model and the data '''

        data_cols = self.data.columns
        if self.keys is None:
            sim_keys = self.sim_results.keys()
            intersection = list(set(sim_keys).intersection(data_cols)) # Find keys in both the sim and data
            self.keys = [key for key in sim_keys if key in intersection and key.startswith('cum_')] # Only keep cumulative keys
        if not len(self.keys):
            errormsg = f'No matches found between simulation result keys ({sim_keys}) and data columns ({data_cols})'
            raise sc.KeyNotFoundError(errormsg)
        mismatches = [key for key in self.keys if key not in data_cols]
        if len(mismatches):
            mismatchstr = ', '.join(mismatches)
            errormsg = f'The following requested key(s) were not found in the data: {mismatchstr}'
            raise sc.KeyNotFoundError(errormsg)

        for key in self.keys: # For keys present in both the results and in the data
            self.inds.sim[key] = []
            self.inds.data[key] = []
            self.date_matches[key] = []
            count = -1
            for d, datum in self.data[key].items(): # Iterate over (date, value) pairs; iteritems() is removed in newer pandas
                count += 1
                if np.isfinite(datum):
                    if d in self.sim_dates:
                        self.date_matches[key].append(d)
                        self.inds.sim[key].append(self.sim_dates.index(d))
                        self.inds.data[key].append(count)
            self.inds.sim[key] = np.array(self.inds.sim[key])
            self.inds.data[key] = np.array(self.inds.data[key])

        # Convert into paired points
        for key in self.keys:
            self.pair[key] = sc.objdict()
            sim_inds = self.inds.sim[key]
            data_inds = self.inds.data[key]
            n_inds = len(sim_inds)
            self.pair[key].sim = np.zeros(n_inds)
            self.pair[key].data = np.zeros(n_inds)
            for i in range(n_inds):
                self.pair[key].sim[i] = self.sim_results[key].values[sim_inds[i]]
                self.pair[key].data[i] = self.data[key].values[data_inds[i]]

        # Process custom inputs
        self.custom_keys = list(self.custom.keys())
        for key in self.custom.keys():

            # Initialize and do error checking
            custom = self.custom[key]
            c_keys = list(custom.keys())
            if 'sim' not in c_keys or 'data' not in c_keys:
                errormsg = f'Custom input must have "sim" and "data" keys, not {c_keys}'
                raise sc.KeyNotFoundError(errormsg)
            c_data = custom['data']
            c_sim = custom['sim']
            try:
                assert len(c_data) == len(c_sim)
            except:
                errormsg = f'Custom data and sim must be arrays, and be of the same length: data = {c_data}, sim = {c_sim} could not be processed'
                raise ValueError(errormsg)
            if key in self.pair:
                errormsg = f'You cannot use a custom key "{key}" that matches one of the existing keys: {self.pair.keys()}'
                raise ValueError(errormsg)

            # If all tests pass, simply copy the data
            self.pair[key] = sc.objdict()
            self.pair[key].sim = c_sim
            self.pair[key].data = c_data

            # Process weight, if available
            wt = custom.get('weight', 1.0) # Attempt to retrieve key 'weight', or use the default if not provided
            wt = custom.get('weights', wt) # ...but also try "weights"
            self.weights[key] = wt # Set the weight

        return
    def compute_diffs(self, absolute=False):
        ''' Find the differences between the sim and the data '''
        for key in self.pair.keys():
            self.diffs[key] = self.pair[key].sim - self.pair[key].data
            if absolute:
                self.diffs[key] = np.abs(self.diffs[key])
        return
    def compute_gofs(self, **kwargs):
        ''' Compute the goodness-of-fit '''
        kwargs = sc.mergedicts(self.gof_kwargs, kwargs)
        for key in self.pair.keys():
            actual = sc.dcp(self.pair[key].data)
            predicted = sc.dcp(self.pair[key].sim)
            self.gofs[key] = psu.compute_gof(actual, predicted, **kwargs)
        return
    def compute_losses(self):
        ''' Compute the weighted goodness-of-fit '''
        for key in self.gofs.keys():
            if key in self.weights:
                weight = self.weights[key]
                if sc.isiterable(weight): # It's an array
                    len_wt = len(weight)
                    len_sim = self.sim_npts
                    len_match = len(self.gofs[key])
                    if len_wt == len_match: # If the weight already is the right length, do nothing
                        pass
                    elif len_wt == len_sim: # Most typical case: it's the length of the simulation, must trim
                        weight = weight[self.inds.sim[key]] # Trim to matching indices
                    else:
                        errormsg = f'Could not map weight array of length {len_wt} onto simulation of length {len_sim} or data-model matches of length {len_match}'
                        raise ValueError(errormsg)
            else:
                weight = 1.0
            self.losses[key] = self.gofs[key]*weight
        return
    def compute_mismatch(self, use_median=False):
        ''' Compute the final mismatch '''
        for key in self.losses.keys():
            if use_median:
                self.mismatches[key] = np.median(self.losses[key])
            else:
                self.mismatches[key] = np.sum(self.losses[key])
        self.mismatch = self.mismatches[:].sum()
        return self.mismatch
    def plot(self, keys=None, width=0.8, font_size=18, fig_args=None, axis_args=None, plot_args=None):
        '''
        Plot the fit of the model to the data. For each result, plot the data
        and the model; the difference; and the loss (weighted difference). Also
        plots the loss as a function of time.

        Args:
            keys (list): which keys to plot (default, all)
            width (float): bar width
            font_size (float): size of font
            fig_args (dict): passed to pl.figure()
            axis_args (dict): passed to pl.subplots_adjust()
            plot_args (dict): passed to pl.plot()
        '''

        fig_args = sc.mergedicts(dict(figsize=(36,22)), fig_args)
        axis_args = sc.mergedicts(dict(left=0.05, right=0.95, bottom=0.05, top=0.95, wspace=0.3, hspace=0.3), axis_args)
        plot_args = sc.mergedicts(dict(lw=4, alpha=0.5, marker='o'), plot_args)
        pl.rcParams['font.size'] = font_size

        if keys is None:
            keys = self.keys + self.custom_keys
        n_keys = len(keys)

        loss_ax = None
        colors = sc.gridcolors(n_keys)
        n_rows = 4

        figs = [pl.figure(**fig_args)]
        pl.subplots_adjust(**axis_args)
        main_ax1 = pl.subplot(n_rows, 2, 1)
        main_ax2 = pl.subplot(n_rows, 2, 2)
        bottom = sc.objdict() # Keep track of the bottoms for plotting cumulative
        bottom.daily = np.zeros(self.sim_npts)
        bottom.cumul = np.zeros(self.sim_npts)
        for k,key in enumerate(keys):
            if key in self.keys: # It's a time series, plot with days and dates
                days = self.inds.sim[key] # The "days" axis (or not, for custom keys)
                daylabel = 'Day'
            else: # It's custom, we don't know what it is
                days = np.arange(len(self.losses[key])) # Just use indices
                daylabel = 'Index'

            # Cumulative totals can't mix daily and non-daily inputs, so skip custom keys
            if key in self.keys:
                for i,ax in enumerate([main_ax1, main_ax2]):
                    if i == 0:
                        data = self.losses[key]
                        ylabel = 'Daily mismatch'
                        title = 'Daily total mismatch'
                    else:
                        data = np.cumsum(self.losses[key])
                        ylabel = 'Cumulative mismatch'
                        title = f'Cumulative mismatch: {self.mismatch:0.3f}'
                    dates = self.sim_results['date'][days] # Show these with dates, rather than days, as a reference point
                    ax.bar(dates, data, width=width, bottom=bottom[i][self.inds.sim[key]], color=colors[k], label=f'{key}')
                    if i == 0:
                        bottom.daily[self.inds.sim[key]] += self.losses[key]
                    else:
                        bottom.cumul = np.cumsum(bottom.daily)
                    if k == len(self.keys)-1:
                        ax.set_xlabel('Date')
                        ax.set_ylabel(ylabel)
                        ax.set_title(title)
                        ax.legend()

            pl.subplot(n_rows, n_keys, k+1*n_keys+1)
            pl.plot(days, self.pair[key].data, c='k', label='Data', **plot_args)
            pl.plot(days, self.pair[key].sim, c=colors[k], label='Simulation', **plot_args)
            pl.title(key)
            if k == 0:
                pl.ylabel('Time series (counts)')
                pl.legend()

            pl.subplot(n_rows, n_keys, k+2*n_keys+1)
            pl.bar(days, self.diffs[key], width=width, color=colors[k], label='Difference')
            pl.axhline(0, c='k')
            if k == 0:
                pl.ylabel('Differences (counts)')
                pl.legend()

            loss_ax = pl.subplot(n_rows, n_keys, k+3*n_keys+1, sharey=loss_ax)
            pl.bar(days, self.losses[key], width=width, color=colors[k], label='Losses')
            pl.xlabel(daylabel)
            pl.title(f'Total loss: {self.losses[key].sum():0.3f}')
            if k == 0:
                pl.ylabel('Losses')
                pl.legend()

        return figs
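# Example (hypothetical usage sketch): the `custom` argument lets arbitrary paired
# arrays contribute to the mismatch, alongside any columns shared between
# sim.results and sim.data. The datafile name, custom key, and array values here
# are placeholders for illustration only.
def _example_fit(): # pragma: no cover
    import poliosim as ps # Assumed package import
    sim = ps.Sim(datafile='data.csv') # Needs data to work, per the class docstring
    sim.run()
    custom = {'peak_shedding': {'data': [1.0, 2.0, 3.0], 'sim': [1.0, 2.0, 4.0], 'weights': 2.0}}
    fit = Fit(sim, weights={'cum_diagnoses': 5}, custom=custom)
    print(fit.mismatch) # The scalar value minimized during calibration
    fit.plot()
    return fit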
class Calibration(Analyzer):
    '''
    A class to handle calibration of Poliosim simulations. Uses the Optuna hyperparameter
    optimization library (optuna.org), which must be installed separately (via
    pip install optuna).

    Note: running a calibration does not guarantee a good fit! You must ensure that
    you run for a sufficient number of iterations, have enough free parameters, and
    that the parameters have wide enough bounds. Please see the tutorial on calibration
    for more information.

    Args:
        sim          (Sim)  : the simulation to calibrate
        calib_pars   (dict) : a dictionary of the parameters to calibrate of the format dict(key1=[best, low, high])
        fit_args     (dict) : a dictionary of options that are passed to sim.compute_fit() to calculate the goodness-of-fit
        par_samplers (dict) : an optional mapping from parameters to the Optuna sampler to use for choosing new points for each; by default, suggest_uniform
        custom_fn    (func) : a custom function for modifying the simulation; receives the sim and calib_pars as inputs, should return the modified sim
        n_trials     (int)  : the number of trials per worker
        n_workers    (int)  : the number of parallel workers (default: maximum)
        total_trials (int)  : if n_trials is not supplied, calculate it by dividing this number by n_workers
        name         (str)  : the name of the database (default: 'poliosim_calibration')
        db_name      (str)  : the name of the database file (default: 'poliosim_calibration.db')
        storage      (str)  : the location of the database (default: sqlite)
        label        (str)  : a label for this calibration object
        verbose      (bool) : whether to print details of the calibration

    Returns:
        A Calibration object

    **Example**::

        sim = ps.Sim(datafile='data.csv')
        calib_pars = dict(beta=[0.015, 0.010, 0.020])
        calib = ps.Calibration(sim, calib_pars, total_trials=100)
        calib.calibrate()
        calib.plot()

    New in version 3.0.3.
    '''

    def __init__(self, sim, calib_pars=None, fit_args=None, custom_fn=None, par_samplers=None,
                 n_trials=None, n_workers=None, total_trials=None, name=None, db_name=None,
                 storage=None, label=None, verbose=True):
        super().__init__(label=label) # Initialize the Analyzer object
        if isinstance(op, Exception):
            raise op # If Optuna failed to import, raise that exception now
        import multiprocessing as mp

        # Handle run arguments
        if n_trials is None:
            n_trials = 20
        if n_workers is None:
            n_workers = mp.cpu_count()
        if name is None:
            name = 'poliosim_calibration'
        if db_name is None:
            db_name = f'{name}.db'
        if storage is None:
            storage = f'sqlite:///{db_name}'
        if total_trials is not None:
            n_trials = total_trials/n_workers
        self.run_args = sc.objdict(n_trials=int(n_trials), n_workers=int(n_workers), name=name, db_name=db_name, storage=storage)

        # Handle other inputs
        self.sim = sim
        self.calib_pars = calib_pars
        self.fit_args = sc.mergedicts(fit_args)
        self.par_samplers = sc.mergedicts(par_samplers)
        self.custom_fn = custom_fn
        self.verbose = verbose
        self.calibrated = False

        # Handle if the sim has already been run
        if self.sim.complete:
            print('Warning: sim has already been run; re-initializing, but in future, use a sim that has not been run')
            self.sim = self.sim.copy()
            self.sim.initialize()
        return


    def run_sim(self, calib_pars, label=None, return_sim=False):
        ''' Create and run a simulation '''
        sim = self.sim.copy()
        if label:
            sim.label = label
        valid_pars = {k:v for k,v in calib_pars.items() if k in sim.pars}
        sim.update_pars(valid_pars)
        if self.custom_fn:
            sim = self.custom_fn(sim, calib_pars)
        else:
            if len(valid_pars) != len(calib_pars):
                extra = set(calib_pars.keys()) - set(valid_pars.keys())
                errormsg = f'The following parameters are not part of the sim, nor is a custom function specified to use them: {sc.strjoin(extra)}'
                raise ValueError(errormsg)
        sim.run()
        sim.compute_fit(**self.fit_args)
        if return_sim:
            return sim
        else:
            return sim.fit.mismatch


    def run_trial(self, trial):
        ''' Define the objective for Optuna '''
        pars = {}
        for key, (best,low,high) in self.calib_pars.items():
            if key in self.par_samplers: # If a custom sampler is used, get it now
                try:
                    sampler_fn = getattr(trial, self.par_samplers[key])
                except Exception as E:
                    errormsg = 'The requested sampler function is not found: ensure it is a valid attribute of an Optuna Trial object'
                    raise AttributeError(errormsg) from E
            else:
                sampler_fn = trial.suggest_uniform
            pars[key] = sampler_fn(key, low, high) # Sample from values within this range
        mismatch = self.run_sim(pars)
        return mismatch


    def worker(self):
        ''' Run a single worker '''
        if self.verbose:
            op.logging.set_verbosity(op.logging.DEBUG)
        else:
            op.logging.set_verbosity(op.logging.ERROR)
        study = op.load_study(storage=self.run_args.storage, study_name=self.run_args.name)
        output = study.optimize(self.run_trial, n_trials=self.run_args.n_trials)
        return output


    def run_workers(self):
        ''' Run multiple workers in parallel '''
        output = sc.parallelize(self.worker, iterarg=self.run_args.n_workers)
        return output


    def make_study(self):
        ''' Make a study, deleting one if it already exists '''
        if os.path.exists(self.run_args.db_name):
            os.remove(self.run_args.db_name)
            print(f'Removed existing calibration {self.run_args.db_name}')
        output = op.create_study(storage=self.run_args.storage, study_name=self.run_args.name)
        return output


    def calibrate(self, calib_pars=None, verbose=True, **kwargs):
        '''
        Actually perform calibration.

        Args:
            calib_pars (dict): if supplied, overwrite stored calib_pars
            verbose (bool): whether to print output from each trial
            kwargs (dict): if supplied, overwrite stored run_args (n_trials, n_workers, etc.)
        '''
        # Load and validate calibration parameters
        if calib_pars is not None:
            self.calib_pars = calib_pars
        if self.calib_pars is None:
            errormsg = 'You must supply calibration parameters either when creating the calibration object or when calling calibrate().'
            raise ValueError(errormsg)
        self.run_args.update(kwargs) # Update optuna settings

        # Run the optimization
        t0 = sc.tic()
        self.make_study()
        self.run_workers()
        self.study = op.load_study(storage=self.run_args.storage, study_name=self.run_args.name)
        self.best_pars = sc.objdict(self.study.best_params)
        self.elapsed = sc.toc(t0, output=True)

        # Compare the results
        self.initial_pars = sc.objdict({k:v[0] for k,v in self.calib_pars.items()})
        self.before = self.run_sim(calib_pars=self.initial_pars, label='Before calibration', return_sim=True)
        self.after = self.run_sim(calib_pars=self.best_pars, label='After calibration', return_sim=True)

        # Tidy up
        self.calibrated = True
        if verbose:
            self.summarize()
        return


    def summarize(self):
        if self.calibrated:
            print(f'Calibration for {self.run_args.n_workers*self.run_args.n_trials} total trials completed in {self.elapsed:0.1f} s.')
            before = self.before.fit.mismatch
            after = self.after.fit.mismatch
            print('\nInitial parameter values:')
            print(self.initial_pars)
            print('\nBest parameter values:')
            print(self.best_pars)
            print(f'\nMismatch before calibration: {before:n}')
            print(f'Mismatch after calibration:  {after:n}')
            print(f'Percent improvement:         {((before-after)/before)*100:0.1f}%')
            return before, after
        else:
            print('Calibration not yet run; please run calib.calibrate()')
            return


    def plot(self, **kwargs):
        from . import run as psr # To avoid circular import
        msim = psr.MultiSim([self.before, self.after])
        fig = msim.plot(**kwargs)
        return fig
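# Example (hypothetical sketch): custom_fn lets calibration parameters that are
# not sim parameters be applied by hand, per the Args documentation above. The
# parameter name 'shed_scale' and the sim key it modifies are placeholders for
# illustration only, not real poliosim parameters.
def _example_custom_fn(sim, calib_pars): # pragma: no cover
    ''' Apply a non-standard calibration parameter to the sim (illustrative) '''
    if 'shed_scale' in calib_pars:
        sim['viral_shed_scale'] = calib_pars['shed_scale'] # Hypothetical parameter name
    return sim

# Usage sketch:
#     calib_pars = dict(beta=[0.015, 0.010, 0.020], shed_scale=[1.0, 0.5, 2.0])
#     calib = Calibration(sim, calib_pars, custom_fn=_example_custom_fn, total_trials=100)
#     calib.calibrate()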
class TransTree(sc.prettyobj):
    '''
    A class for holding a transmission tree. There are several different representations
    of the transmission tree: "infection_log" is copied from the people object and is the
    simplest representation. "detailed" includes additional attributes about the source
    and target. If NetworkX is installed (required for most methods), "graph" includes an
    NX representation of the transmission tree.

    Args:
        sim (Sim): the sim object
        to_networkx (bool): whether to convert the graph to a NetworkX object
    '''

    def __init__(self, sim, to_networkx=False):

        # Pull out each of the attributes relevant to transmission
        attrs = {'age', 'date_exposed', 'date_symptomatic', 'date_tested', 'date_diagnosed', 'date_quarantined', 'date_known_contact', 'date_recovered'}

        # Pull out the people and some of the sim results
        people = sim.people
        self.sim_start = sim['start_day'] # Used for filtering later
        self.sim_results = {}
        self.sim_results['t'] = sim.results['t']
        self.sim_results['cum_infections'] = sim.results['cum_infections'].values
        self.n_days = people.t # people.t should be set to the last simulation timestep in the output (since the TransTree is constructed after the people have been stepped forward in time)
        self.pop_size = len(people)

        # Include the basic line list
        self.infection_log = sc.dcp(people.infection_log)
        self.exposure_log = sc.dcp(people.exposure_log)

        # Parse into sources and targets
        self.sources = [None for i in range(self.pop_size)]
        self.targets = [[] for i in range(self.pop_size)]
        self.source_dates = [None for i in range(self.pop_size)]
        self.target_dates = [[] for i in range(self.pop_size)]

        for entry in self.infection_log:
            source = entry['source']
            target = entry['target']
            date = entry['date']
            if source:
                self.sources[target] = source # Each target has at most one source
                self.targets[source].append(target) # Each source can have multiple targets
                self.source_dates[target] = date # Each target has at most one source
                self.target_dates[source].append(date) # Each source can have multiple targets

        # Count the number of targets each person has
        self.n_targets = self.count_targets()

        # Include the detailed transmission tree as well
        self.detailed = self.make_detailed(people)

        # Optionally convert to NetworkX -- must be done on import since the people object is not kept
        if to_networkx:

            # Initialization
            import networkx as nx
            self.graph = nx.DiGraph()

            # Add the nodes
            for i in range(len(people)):
                d = {}
                for attr in attrs:
                    d[attr] = people[attr][i]
                self.graph.add_node(i, **d)

            # Next, add edges from linelist
            for edge in people.infection_log:
                if edge['source'] is not None: # Skip seed infections
                    self.graph.add_edge(edge['source'],edge['target'],date=edge['date'],layer=edge['layer'])

        return


    def __len__(self):
        '''
        The length of the transmission tree is the length of the line list,
        which should equal the number of infections.
        '''
        try:
            return len(self.infection_log)
        except:
            return 0


    @property
    def transmissions(self):
        """
        Iterable over edges corresponding to transmission events

        This excludes edges corresponding to seeded infections without a source
        """
        output = []
        for d in self.infection_log:
            if d['source'] is not None:
                output.append([d['source'], d['target']])
        return output
    def day(self, day=None, which=None):
        ''' Convenience function for converting an input to an integer day '''
        if day is not None:
            day = psu.day(day, start_day=self.sim_start)
        elif which == 'start':
            day = 0
        elif which == 'end':
            day = self.n_days
        return day
    def count_targets(self, start_day=None, end_day=None):
        '''
        Count the number of targets each infected person has. If start and/or end days are
        given, it will only count the targets of people who got infected between those dates
        (it does not, however, filter on the date the target got infected).

        Args:
            start_day (int/str): the day on which to start counting people who got infected
            end_day (int/str): the day on which to stop counting people who got infected
        '''
        # Handle start and end days
        start_day = self.day(start_day, which='start')
        end_day = self.day(end_day, which='end')

        n_targets = np.nan+np.zeros(self.pop_size)
        for i in range(self.pop_size):
            if self.sources[i] is not None:
                if self.source_dates[i] >= start_day and self.source_dates[i] <= end_day:
                    n_targets[i] = len(self.targets[i])
        n_target_inds = sc.findinds(~np.isnan(n_targets))
        n_targets = n_targets[n_target_inds]
        return n_targets
    def make_detailed(self, people, reset=False):
        ''' Construct a detailed transmission tree, with additional information for each person '''
        detailed = [None]*self.pop_size

        for transdict in self.infection_log:

            # Pull out key quantities
            ddict = sc.dcp(transdict) # For "detailed dictionary"
            source = ddict['source']
            target = ddict['target']
            ddict['s'] = {} # Source properties
            ddict['t'] = {} # Target properties

            # If the source is available (e.g. not a seed infection), loop over both it and the target
            if source is not None:
                stdict = {'s':source, 't':target}
            else:
                stdict = {'t':target}

            # Pull out each of the attributes relevant to transmission
            attrs = ['age', 'date_exposed', 'date_symptomatic', 'date_tested', 'date_diagnosed', 'date_quarantined', 'date_known_contact']
            for st,stind in stdict.items():
                for attr in attrs:
                    ddict[st][attr] = people[attr][stind]
            if source is not None:
                for attr in attrs:
                    if attr.startswith('date_'):
                        is_attr = attr.replace('date_', 'is_') # Convert date to a boolean, e.g. date_diagnosed -> is_diagnosed
                        ddict['s'][is_attr] = ddict['s'][attr] <= ddict['date'] # These don't make sense for people just infected (targets), only sources
                ddict['s']['is_asymp'] = np.isnan(people.date_symptomatic[source])
                ddict['s']['is_presymp'] = ~ddict['s']['is_asymp'] and ~ddict['s']['is_symptomatic'] # Not asymptomatic and not currently symptomatic
            ddict['t']['is_quarantined'] = ddict['t']['date_quarantined'] <= ddict['date'] # This is the only target date that it makes sense to define since it can happen before infection

            detailed[target] = ddict
        return detailed
    def plot(self, *args, **kwargs):
        ''' Plot the transmission tree '''
        fig_args = kwargs.get('fig_args', dict(figsize=(16, 10)))

        ttlist = []
        for source_ind, target_ind in self.transmissions:
            ddict = self.detailed[target_ind]
            source = ddict['s']
            target = ddict['t']

            tdict = {}
            tdict['date'] = ddict['date']
            tdict['layer'] = ddict['layer']
            tdict['s_asymp'] = np.isnan(source['date_symptomatic']) # True if they *never* became symptomatic
            tdict['s_presymp'] = ~tdict['s_asymp'] and tdict['date']<source['date_symptomatic'] # True if they became symptomatic after the transmission date
            tdict['s_diag'] = source['date_diagnosed'] < tdict['date']
            tdict['s_quar'] = source['date_quarantined'] < tdict['date']
            tdict['t_quar'] = target['date_quarantined'] < tdict['date'] # What if the target was released from quarantine?
            ttlist.append(tdict)

        df = pd.DataFrame(ttlist).rename(columns={'date': 'Day'})
        df = df.loc[df['layer'] != 'seed_infection']

        df['Stage'] = 'Symptomatic'
        df.loc[df['s_asymp'], 'Stage'] = 'Asymptomatic'
        df.loc[df['s_presymp'], 'Stage'] = 'Presymptomatic'

        fig = pl.figure(**fig_args)
        i = 1; r = 2; c = 3

        def plot_quantity(key, title, i):
            dat = df.groupby(['Day', key]).size().unstack(key)
            ax = pl.subplot(r, c, i)
            dat.plot(ax=ax, legend=None)
            pl.legend(title=None)
            ax.set_title(title)

        to_plot = {
            'layer': 'Layer',
            'Stage': 'Source stage',
            's_diag': 'Source diagnosed',
            's_quar': 'Source quarantined',
            't_quar': 'Target quarantined',
        }
        for i, (key, title) in enumerate(to_plot.items()):
            plot_quantity(key, title, i + 1)
        return fig
    def plot_histograms(self, start_day=None, end_day=None, bins=None, width=0.8, fig_args=None, font_size=18):
        '''
        Plots a histogram of the number of transmissions.

        Args:
            start_day (int/str): the day on which to start counting people who got infected
            end_day (int/str): the day on which to stop counting people who got infected
            bins (list): bin edges to use for the histogram
            width (float): width of bars
            fig_args (dict): passed to pl.figure()
            font_size (float): size of font
        '''

        # Process targets
        n_targets = self.count_targets(start_day, end_day)

        # Handle bins
        if bins is None:
            max_infections = n_targets.max()
            bins = np.arange(0, max_infections+2)

        # Analysis
        counts = np.histogram(n_targets, bins)[0]
        bins = bins[:-1] # Remove last bin since it's an edge
        total_counts = counts*bins
        n_bins = len(bins)
        index = np.linspace(0, 100, len(n_targets))
        sorted_arr = np.sort(n_targets)
        sorted_sum = np.cumsum(sorted_arr)
        sorted_sum = sorted_sum/sorted_sum.max()*100
        change_inds = sc.findinds(np.diff(sorted_arr) != 0)
        max_labels = 15 # Maximum number of ticks and legend entries to plot

        # Plotting
        fig_args = sc.mergedicts(dict(figsize=(24,15)), fig_args)
        pl.rcParams['font.size'] = font_size
        fig = pl.figure(**fig_args)
        pl.set_cmap('Spectral')
        pl.subplots_adjust(left=0.08, right=0.92, bottom=0.08, top=0.92)
        colors = sc.vectocolor(n_bins)

        pl.subplot(1,2,1)
        w05 = width*0.5
        w025 = w05*0.5
        pl.bar(bins-w025, counts, width=w05, facecolor='k', label='Number of events')
        for i in range(n_bins):
            label = 'Number of transmissions (events × transmissions per event)' if i==0 else None
            pl.bar(bins[i]+w025, total_counts[i], width=w05, facecolor=colors[i], label=label)
        pl.xlabel('Number of transmissions per person')
        pl.ylabel('Count')
        if n_bins<max_labels:
            pl.xticks(ticks=bins)
        pl.legend()
        pl.title('Numbers of events and transmissions')

        pl.subplot(2,2,2)
        total = 0
        for i in range(n_bins):
            pl.bar(bins[i:], total_counts[i], width=width, bottom=total, facecolor=colors[i])
            total += total_counts[i]
        if n_bins<max_labels:
            pl.xticks(ticks=bins)
        pl.xlabel('Number of transmissions per person')
        pl.ylabel('Number of infections caused')
        pl.title('Number of transmissions, by transmissions per person')

        pl.subplot(2,2,4)
        pl.plot(index, sorted_sum, lw=3, c='k', alpha=0.5)
        n_change_inds = len(change_inds)
        label_inds = np.linspace(0, n_change_inds, max_labels).round() # Don't allow more than this many labels
        for i in range(n_change_inds):
            if i in label_inds: # Don't plot more than this many labels
                label = f'Transmitted to {bins[i+1]:n} people'
            else:
                label = None
            pl.scatter([index[change_inds[i]]], [sorted_sum[change_inds[i]]], s=150, zorder=10, c=[colors[i]], label=label)
        pl.xlabel('Proportion of population, ordered by the number of people they infected (%)')
        pl.ylabel('Proportion of infections caused (%)')
        pl.legend()
        pl.ylim([0, 100])
        pl.grid(True)
        pl.title('Proportion of transmissions, by proportion of population')

        # Inset: cumulative infections, with the analysis window shaded
        pl.axes([0.30, 0.65, 0.15, 0.2])
        berry = [0.8, 0.1, 0.2]
        dirty_snow = [0.9, 0.9, 0.9]
        start_day = self.day(start_day, which='start')
        end_day = self.day(end_day, which='end')
        pl.axvspan(start_day, end_day, facecolor=dirty_snow)
        pl.plot(self.sim_results['t'], self.sim_results['cum_infections'], lw=2, c=berry)
        pl.xlabel('Day')
        pl.ylabel('Cumulative infections')

        return fig
    def plot_intervals(self):
        ''' Plot serial interval distribution '''
        intervals = []
        for entry in self.detailed:
            if entry:
                date = entry['date'] # Date this infection occurred
                source = entry['source'] # The person who caused this infection
                if source:
                    source_exposures = np.array(self.exposure_log[source])
                    source_date = np.max(source_exposures[source_exposures<date])
                    diff = date - source_date
                    intervals.append(diff)
        intervals = np.array(intervals)

        #%% Plotting
        fig = pl.figure(figsize=(16,10))
        ax1 = pl.subplot(111)
        bw = 5
        edges = np.arange(0, intervals.max()+bw+1, bw)
        counts, _ = pl.histogram(intervals, bins=edges)
        bins = edges[:-1]
        props = counts/len(intervals)*100
        ax1.bar(bins, props, width=bw*0.8, facecolor='k')
        ax1.set_xlabel('Serial interval (days)')
        ax1.set_ylabel('Proportion of transmissions (%)')
        ax1.set_xticks(bins)
        # pl.grid(True)

        color = [0.2,0.4,0.9]
        ax2 = ax1.twinx() # instantiate a second axes that shares the same x-axis
        cum_props = np.cumsum(props)
        ax2.plot(bins, cum_props, 'o-', lw=3, c=color, markersize=15)
        ax2.set_ylabel('Cumulative proportion of transmissions (%)', color=color)
        ax2.tick_params(axis='y', labelcolor=color)
        ax2.set_ylim([-1,102])
        pl.show()
        return fig
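# Example (hypothetical usage sketch): build the tree from a completed sim and
# query it. The NetworkX conversion requires networkx to be installed, as noted
# in the class docstring; the `poliosim` import is assumed as elsewhere.
def _example_transtree(): # pragma: no cover
    import poliosim as ps # Assumed package import
    sim = ps.Sim()
    sim.run()
    tt = TransTree(sim, to_networkx=True)
    print(len(tt)) # Number of infections in the line list
    n_targets = tt.count_targets(start_day=0, end_day=30) # Transmissions per person infected in the first month
    fig = tt.plot_histograms()
    return tt, n_targets, fig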
class calculate_contacts_infected(Analyzer):
    ''' Calculate the contacts infected '''

    def __init__(self, ignore_last_n_days=None, **kwargs):
        super().__init__(**kwargs)
        self.ignore_last_n_days = ignore_last_n_days
        if self.ignore_last_n_days is None:
            self.ignore_last_n_days = 30
        return
    def apply(self, sim):
        # At last time step, we calculate the numbers.
        if sim.t == sim.tvec[-1]:
            self.calculate_contacts_infected(sim)
        return
    def calculate_contact_counts(self, sim):
        contact_counts = np.full(sim.people.pop_size, None)
        for ind in np.arange(0, sim.people.pop_size):
            contact_counts[ind] = dict()
            for layer in sim.people.contacts.keys():
                contact_counts[ind][layer] = 0

        for layer in sim.people.contacts.keys():
            # TODO: deal with p1 -> p2 and p2 -> p1 duplicating the same contact.
            for p1_ind in sim.people.contacts[layer]['p1']:
                contact_counts[p1_ind][layer] += 1
            for p2_ind in sim.people.contacts[layer]['p2']:
                contact_counts[p2_ind][layer] += 1

        # Sum up and store the total across all layers
        for contact_count in contact_counts:
            total_contact_count = 0
            for layer in contact_count.keys():
                total_contact_count += contact_count[layer]
            contact_count['__total_contact_count'] = total_contact_count

        return contact_counts
    def calculate_contacts_infected(self, sim):
        infection_log = sim.people.infection_log

        # In the numbers below, we want to leave out infections that started within
        # the last N days, because they haven't had time to do their whole thing,
        # and if we include them, it could skew the numbers low.
        t_end = sim.tvec[-1]
        t_cutoff = t_end - self.ignore_last_n_days

        #############################
        # distribution of number of transmissions from an infected individual
        #############################
        transmission_counts = np.full(sim.people.pop_size, 0)
        infected_inds = set()
        for infection in infection_log:
            if infection['date'] <= t_cutoff:
                infected_inds.add(infection['target'])
            # The 2nd half of this check makes sure we only count transmissions from individuals
            # who were infected before the cutoff time. It may count transmissions from them in their
            # first infection or in a subsequent infection.
            if infection['source'] is not None and infection['source'] in infected_inds:
                transmission_counts[infection['source']] += 1

        #############################
        # distribution of number of unique transmissions from an infected individual
        # unique meaning count each (source,target) only once.
        #############################
        unique_transmission_counts = np.full(sim.people.pop_size, 0)
        infected_inds = set()
        infection_pairs = set()
        for infection in infection_log:
            if infection['date'] <= t_cutoff:
                infected_inds.add(infection['target'])
            # Again here we are only counting transmissions from individuals infected before the cutoff time.
            if infection['source'] is not None and infection['source'] in infected_inds:
                infection_pair = (infection['source'], infection['target'])
                if not infection_pair in infection_pairs:
                    unique_transmission_counts[infection['source']] += 1
                    infection_pairs.add(infection_pair)

        #############################
        # distribution of percentage of contacts infected from an infected individual
        #############################
        contact_counts = self.calculate_contact_counts(sim)
        total_contact_counts = [contact_count['__total_contact_count'] for contact_count in contact_counts]
        # TODO: are there any situations of divide by zero that need to be dealt with here?
        percentage_contacts_infected = np.divide(unique_transmission_counts, total_contact_counts)

        self.infected_inds = infected_inds
        self.transmission_counts = transmission_counts
        self.unique_transmission_counts = unique_transmission_counts
        self.percentage_contacts_infected = percentage_contacts_infected

        infected_unique_transmission_counts = self.unique_transmission_counts[list(self.infected_inds)]
        self.r0 = np.mean(infected_unique_transmission_counts)
        self.V = np.var(infected_unique_transmission_counts)
    # Clear out the memory hogs
    def clear(self):
        del self.infected_inds
        del self.transmission_counts
        del self.unique_transmission_counts
        del self.percentage_contacts_infected
    def plot(self):
        self.calculator = self
        plots = dict()
        print("Calculating plots for contacts infected...")

        #############################
        # distribution of number of transmissions from an infected individual
        #############################
        plots['distribution_of_transmission_count'] = pl.figure()
        # We just want transmission counts for people who were infected.
        infected_transmission_counts = self.calculator.transmission_counts[list(self.calculator.infected_inds)]
        max_transmission_count = np.max(infected_transmission_counts)
        counts1, bins1 = np.histogram(infected_transmission_counts, np.arange(0,max_transmission_count+1))
        pl.hist(bins1[:-1], bins1, weights=counts1, log=True, align='left', edgecolor='black', zorder=10)
        the_mean = np.mean(infected_transmission_counts)
        pl.axvline(the_mean, color='k', linestyle='dashed', linewidth=1)
        min_ylim, max_ylim = pl.ylim()
        pl.text(the_mean*1.2, max_ylim*0.7, f'Mean: {the_mean:.1f}')
        pl.grid(axis='y', which='both', linestyle='--', linewidth=0.5, zorder=0)
        pl.xlabel('Number of transmissions from an infected individual')
        pl.ylabel('Count')
        pl.title(f'Total transmissions: {np.sum(np.multiply(counts1, bins1[:-1]))}')
        pl.tight_layout()

        #############################
        # distribution of number of unique transmissions from an infected individual
        # unique meaning count each (source,target) only once.
        #############################
        plots['distribution_of_unique_transmission_count'] = pl.figure()
        # We just want transmission counts for people who were infected.
        infected_unique_transmission_counts = self.calculator.unique_transmission_counts[list(self.calculator.infected_inds)]
        # Using max_transmission_count from the not-unique version above, to make it easier to compare
        # the two plots.
        counts2, bins2 = np.histogram(infected_unique_transmission_counts, np.arange(0,max_transmission_count+1))
        pl.hist(bins2[:-1], bins2, weights=counts2, log=True, align='left', edgecolor='black', zorder=10)
        the_mean = np.mean(infected_unique_transmission_counts)
        pl.axvline(the_mean, color='k', linestyle='dashed', linewidth=1)
        min_ylim, max_ylim = pl.ylim()
        pl.text(the_mean*1.2, max_ylim*0.7, f'Mean: {the_mean:.1f}')
        pl.grid(axis='y', which='both', linestyle='--', linewidth=0.5, zorder=0)
        pl.xlabel('Number of unique transmissions from an infected individual\n(repeated transmissions to same target counted only once)')
        pl.ylabel('Count')
        total_unique_transmissions = np.sum(np.multiply(counts2, bins2[:-1]))
        pl.title(f'Total unique transmissions: {total_unique_transmissions}')
        pl.tight_layout()

        #############################
        # distribution of percentage of contacts infected from an infected individual
        #############################
        infected_percentage_contacts_infected = self.calculator.percentage_contacts_infected[list(self.calculator.infected_inds)]
        plots['distribution_of_percentage_of_contacts_infected'] = pl.figure()
        counts, bins = np.histogram(infected_percentage_contacts_infected, np.linspace(0,1,11))
        pl.hist(bins[:-1], bins, weights=counts, log=True, align='mid', edgecolor='black', zorder=10)
        pl.xticks(bins)
        pl.grid(axis='y', which='both', linestyle='--', linewidth=0.5, zorder=0)
        pl.xlabel('Fraction of contacts infected')
        pl.ylabel('Count')
        pl.title(f'Total unique transmissions: {total_unique_transmissions}')
        pl.tight_layout()

        #############################
        # TODO: x axis is age bucket;
        # whisker plot of percentage of contacts infected.
        #############################

        #############################
        # TODO: one subplot for each layer;
        # x axis is age bucket;
        # whisker plot of percentage of contacts infected in that layer.
        #############################

        print("Done.")
        return plots
    def save(self, folder=None):
        if folder is None:
            folder = "plot_contacts_infected"
        folder_path = sc.thisdir(aspath=True) / folder
        folder_path.mkdir(parents=True, exist_ok=True)
        plots = self.plot()
        for plot_name in plots.keys():
            filepath = folder_path / f'{plot_name}.png' # Save into the directory that was actually created
            plots[plot_name].savefig(filepath)
            print(f'Wrote plot [{plot_name}] to [{filepath}]')
        return
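# Example (hypothetical usage sketch): attach the analyzer, run, then read off the
# reproduction-number estimate and save the diagnostic histograms. The `poliosim`
# import is assumed; r0 and V are the mean and variance computed on the last timestep.
def _example_contacts_infected(): # pragma: no cover
    import poliosim as ps # Assumed package import
    analyzer = calculate_contacts_infected(ignore_last_n_days=30)
    sim = ps.Sim(analyzers=analyzer)
    sim.run()
    print(analyzer.r0, analyzer.V) # Mean and variance of unique transmissions per infected person
    analyzer.save(folder='plot_contacts_infected')
    analyzer.clear() # Free the large stored arrays when done
    return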