# Source code for hpvsim.analysis

'''
Additional analysis functions that are not part of the core workflow,
but which are useful for particular investigations.
'''

import numpy as np
import pylab as pl
import pandas as pd
import sciris as sc
from . import utils as hpu
from . import misc as hpm
from . import plotting as hppl
from . import defaults as hpd
from . import parameters as hppar
from . import interventions as hpi
from .settings import options as hpo # For setting global options


# Names exported by ``from hpvsim.analysis import *``
__all__ = [
    'Analyzer',
    'snapshot',
    'age_pyramid',
    'age_results',
    'age_causal_infection',
    'cancer_detection',
    'daly_computation',
    'analyzer_map',
]


class Analyzer(sc.prettyobj):
    '''
    Base class for analyzers. Based on the Intervention class. Analyzers are used
    to provide more detailed information about a simulation than is available by
    default -- for example, pulling states out of sim.people on a particular
    timestep before it gets updated in the next timestep.

    To retrieve a particular analyzer from a sim, use ``sim.get_analyzer()``.

    Args:
        label (str): a label for the Analyzer (used for ease of identification);
            defaults to the name of the analyzer's class
    '''

    def __init__(self, label=None):
        # Fall back to the class name so every analyzer carries a readable label
        self.label = self.__class__.__name__ if label is None else label # e.g. "Record ages"
        self.initialized = False
        self.finalized = False
        return


    def __call__(self, *args, **kwargs):
        '''
        Make ``analyzer(sim)`` equivalent to ``analyzer.apply(sim)``.

        Raises:
            RuntimeError: if the analyzer has not been initialized yet
        '''
        if self.initialized:
            return self.apply(*args, **kwargs)
        errormsg = f'Analyzer (label={self.label}, {type(self)}) has not been initialized'
        raise RuntimeError(errormsg)


    def initialize(self, sim=None):
        '''
        Initialize the analyzer, e.g. convert date strings to integers.

        The base implementation only marks the analyzer as initialized and not
        yet finalized.
        '''
        self.finalized = False
        self.initialized = True
        return


    def finalize(self, sim=None):
        '''
        Finalize analyzer

        This method is run once as part of ``sim.finalize()`` enabling the analyzer
        to perform any final operations after the simulation is complete (e.g.
        rescaling).

        Raises:
            RuntimeError: if the analyzer has already been finalized
        '''
        # Finalizing twice is treated as an error because it has a high
        # probability of producing incorrect results (e.g. applying rescale
        # factors twice)
        if not self.finalized:
            self.finalized = True
            return
        raise RuntimeError('Analyzer already finalized')


    def apply(self, sim):
        '''
        Apply analyzer at each time point. The analyzer has full access to the
        sim object, and typically stores data/results in itself. This is the core
        method which each analyzer object needs to implement.

        Args:
            sim: the Sim instance

        Raises:
            NotImplementedError: always, in the base class
        '''
        raise NotImplementedError


    def shrink(self, in_place=False):
        '''
        Remove any excess stored data from the intervention; for use with ``sim.shrink()``.

        The base implementation stores nothing extra, so it returns either the
        analyzer itself or a deep copy of it.

        Args:
            in_place (bool): whether to shrink the intervention (else shrink a copy)
        '''
        return self if in_place else sc.dcp(self)


    @staticmethod
    def reduce(analyzers, use_mean=False):
        '''
        Create a reduced analyzer from a list of analyzers, using either medians
        or means to summarize across them.

        The base class does not implement this and returns None; subclasses
        that support reduction override it.

        Args:
            analyzers: list of analyzers
            use_mean (bool): whether to use medians (the default) or means to create the reduced analyzer
        '''
        return None


    def to_json(self):
        '''
        Return JSON-compatible representation

        Custom classes can't be directly represented in JSON. This method is a
        one-way export to produce a JSON-compatible representation of the
        intervention. Each attribute is passed through ``sc.jsonify``; any
        attribute that fails is replaced by a message describing the failure.

        Returns:
            JSON-serializable representation (a dict)
        '''
        output = {
            'analyzer_name': getattr(self, 'label', None),
            'analyzer_class': self.__class__.__name__,
        }

        # Try each attribute in turn, recording an error message rather than failing
        for attr in self.__dict__.keys():
            try:
                data = getattr(self, attr)
                try:
                    output[attr] = sc.jsonify(data)
                except Exception as E:
                    output[attr] = f'Could not jsonify "{attr}" ({type(data)}): "{str(E)}"'
            except Exception as E2:
                output[attr] = f'Could not jsonify "{attr}": "{str(E2)}"'
        return output
def validate_recorded_dates(sim, requested_dates, recorded_dates, die=True):
    '''
    Helper method to ensure that dates recorded by an analyzer match the ones requested.

    Both collections are compared after sorting, so order does not matter.

    Args:
        sim: the sim (only ``sim['start']`` and ``sim['end']`` are read, for the error message)
        requested_dates (iterable): dates the analyzer was asked to record
        recorded_dates (iterable): dates the analyzer actually recorded
        die (bool): if True raise on a mismatch, otherwise print the message

    Raises:
        RuntimeError: if the dates differ and ``die`` is True
    '''
    requested = sorted(requested_dates)
    recorded = sorted(recorded_dates)
    if requested == recorded:
        return

    # pragma: no cover
    errormsg = f'The dates {requested} were requested but only {recorded} were recorded: please check the dates fall between {sim["start"]} and {sim["end"]} and the sim was actually run'
    if die:
        raise RuntimeError(errormsg)
    print(errormsg)
class snapshot(Analyzer):
    '''
    Analyzer that takes a "snapshot" of the sim.people array at specified points
    in time, and saves them to itself. To retrieve them, you can either access
    the dictionary directly, or use the ``get()`` method.

    Args:
        timepoints (list): list of ints/strings/date objects, the days on which to take the snapshot
        args (list): additional timepoints, so that ``snapshot('2015', '2020')`` is equivalent to ``snapshot(['2015', '2020'])``
        die (bool): whether or not to raise an exception if a date is not found (default true)
        kwargs (dict): passed to :py:class:`Analyzer`

    **Example**::

        sim = hpv.Sim(analyzers=hpv.snapshot('2015.4', '2020'))
        sim.run()
        snapshot = sim['analyzers'][0]
        people = snapshot.snapshots[0]       # Option 1: by position
        people = snapshot.snapshots['2020']  # Option 2: by date string (must match an entry of snapshot.dates)
        people = snapshot.get('2020')        # Option 3: via get(), same key rules as option 2
        people = snapshot.get()              # Option 4: the first requested date
    '''

    def __init__(self, timepoints=None, *args, die=True, **kwargs):
        super().__init__(**kwargs) # Initialize the Analyzer object
        if args: # Merge extra positional arguments into the list of timepoints
            if timepoints is None:
                timepoints = []
            elif isinstance(timepoints, (list, tuple, np.ndarray)):
                timepoints = list(timepoints)
            else:
                timepoints = [timepoints]
            timepoints = timepoints + list(args)
        self.timepoints = timepoints
        self.die        = die  # Whether or not to raise an exception
        self.dates      = None # Representations in terms of years, e.g. 2020.4, set during initialization
        self.start      = None # Store the start year of the simulation
        self.snapshots  = sc.odict() # Store the actual snapshots
        return


    def initialize(self, sim):
        '''
        Convert the requested timepoints (default: the end of the sim) with
        ``sim.get_t()`` into the time representation matched against ``sim.t``
        in ``apply()``, plus the date strings used as snapshot keys.
        '''
        self.start = sim['start'] # Store the simulation start
        if self.timepoints is None:
            self.timepoints = [sim['end']] # If no timepoint is supplied, use the end of the sim
        self.timepoints, self.dates = sim.get_t(self.timepoints, return_date_format='str') # Ensure timepoints and dates are in the right format
        super().initialize() # Marks the analyzer as initialized and not yet finalized
        return


    def apply(self, sim):
        ''' Store a deep copy of sim.people for every requested timepoint equal to sim.t '''
        for ind in sc.findinds(self.timepoints, sim.t):
            date = self.dates[ind]
            self.snapshots[date] = sc.dcp(sim.people) # Take snapshot!


    def finalize(self, sim):
        ''' Check that a snapshot was recorded for every requested date '''
        super().finalize()
        validate_recorded_dates(sim, requested_dates=self.dates, recorded_dates=self.snapshots.keys(), die=self.die)
        return


    def get(self, key=None):
        '''
        Retrieve a snapshot by its date key.

        Args:
            key (str): one of the recorded date strings (the keys of ``self.snapshots``);
                if None, use the first requested date

        Returns:
            the copy of sim.people stored for that date

        Raises:
            sc.KeyNotFoundError: if no snapshot was recorded under ``key``
        '''
        if key is None:
            key = self.dates[0]
        date = key # TODO: consider ways to make this more robust
        if date not in self.snapshots:
            choices = ', '.join(str(k) for k in self.snapshots.keys()) or '(no snapshots recorded)'
            errormsg = f'Could not find snapshot date {date}: choices are {choices}'
            raise sc.KeyNotFoundError(errormsg)
        return self.snapshots[date]
class age_pyramid(Analyzer):
    '''
    Constructs an age/sex pyramid at specified points within the sim. Can be used with data

    Args:
        timepoints (list): list of ints/strings/date objects, the days on which to take the snapshot
        args (list): additional timepoints, so that ``age_pyramid('2015', '2020')`` is equivalent to ``age_pyramid(['2015', '2020'])``
        edges (array): edges of the age bins (default: 0, 10, ..., 100)
        age_labels (list): labels for the age bins; generated from ``edges`` if not provided
        datafile (str or DataFrame): optional data to compare against. It needs ``year`` and ``age``
            columns; when plotting, columns are identified by the first letter of their name ('m', 'f', 'a')
        die (bool): whether or not to raise an exception if the datafile dates do not match the
            requested dates, or if a requested date is not recorded (default false)
        kwargs (dict): passed to :py:class:`Analyzer`

    **Example**::

        sim = hpv.Sim(analyzers=hpv.age_pyramid('2015', '2020'))
        sim.run()
        age_pyramid = sim['analyzers'][0]
    '''

    def __init__(self, timepoints=None, *args, edges=None, age_labels=None, datafile=None, die=False, **kwargs):
        super().__init__(**kwargs) # Initialize the Analyzer object
        if args: # Merge extra positional arguments into the list of timepoints
            if timepoints is None:
                timepoints = []
            elif isinstance(timepoints, (list, tuple, np.ndarray)):
                timepoints = list(timepoints)
            else:
                timepoints = [timepoints]
            timepoints = timepoints + list(args)
        self.timepoints   = timepoints
        self.edges        = edges # Edges of bins
        self.datafile     = datafile # Data file to load
        self.bins         = None # Age bins, calculated from edges
        self.data         = None # Store the loaded data
        self.die          = die # Whether or not to raise an exception
        self.dates        = None # Representations in terms of years, e.g. 2020.4, set during initialization
        self.start        = None # Store the start year of the simulation
        self.age_labels   = age_labels # Labels for the age bins - will be automatically generated if not provided
        self.age_pyramids = sc.odict() # Store the age pyramids
        return


    def initialize(self, sim):
        '''
        Resolve timepoints/dates, set up age bins and labels, and load and
        validate the optional datafile.

        Raises:
            ValueError: if a requested timepoint is after the end of the sim, if the
                datafile's age bins differ from ``edges``, or (when ``die`` is True)
                if the datafile dates differ from the requested dates
        '''
        super().initialize()

        # Handle timepoints and dates
        self.start = sim['start'] # Store the simulation start
        self.end   = sim['end']   # Store simulation end
        if self.timepoints is None:
            self.timepoints = [self.end] # If no day is supplied, use the last day
        self.timepoints, self.dates = sim.get_t(self.timepoints, return_date_format='str') # Ensure timepoints and dates are in the right format

        # Timepoints are matched against sim.t in apply(), so compare them with the
        # last entry of sim.tvec rather than with the end year
        last = int(np.argmax(self.timepoints))
        max_hist_time = self.timepoints[last]
        max_sim_time = sim.tvec[-1]
        if max_hist_time > max_sim_time:
            errormsg = f'Cannot create histogram for {self.dates[last]} ({max_hist_time}) because the simulation ends on {self.end} ({max_sim_time})'
            raise ValueError(errormsg)

        # Handle edges, age bins, and labels
        if self.edges is None: # Default age bins
            self.edges = np.linspace(0,100,11)
        self.bins = self.edges[:-1] # Don't include the last edge in the bins
        if self.age_labels is None:
            self.age_labels = [f'{int(lo)}-{int(hi)}' for lo, hi in zip(self.edges[:-1], self.edges[1:])]
            self.age_labels.append(f'{int(self.edges[-1])}+')

        # Handle the data file
        if self.datafile is not None:
            if sc.isstring(self.datafile):
                self.data = hpm.load_data(self.datafile, check_date=False)
            else:
                self.data = self.datafile # Use it directly
                self.datafile = None

            # Validate the data. Currently we only allow the same timepoints and age brackets
            data_dates = {str(float(i)) for i in self.data.year}
            if set(self.dates) != data_dates:
                string = f'Dates provided in the age pyramid datafile ({data_dates}) are not the same as the age pyramid dates that were requested ({self.dates}).'
                if self.die:
                    raise ValueError(string)
                string += '\nPlots will only show requested dates, not all dates in the datafile.'
                print(string)
            self.data_dates = data_dates

            # Validate the edges - must be the same as requested edges from the model output
            data_edges = np.array(self.data.age.unique(), dtype=float)
            if not np.array_equal(np.sort(self.edges),np.sort(data_edges)):
                errormsg = f'Age bins provided in the age pyramid datafile ({data_edges}) are not the same as the age pyramid age bins that were requested ({self.edges}).'
                raise ValueError(errormsg)

        self.initialized = True
        return


    def apply(self, sim):
        '''
        If sim.t is a requested timepoint, record the age histogram (weighted by
        ``people.scale``) of living males ('m', sex==0) and females ('f', sex==1).
        '''
        ppl = sim.people
        for ind in sc.findinds(self.timepoints, sim.t):
            date = self.dates[ind]
            pyramid = sc.objdict() # Initialize the dictionary
            pyramid['bins'] = self.bins # Copy here for convenience
            for sb, sex in enumerate(['m','f']): # sb stands for sex boolean, translating the labels to 0/1
                inds = (ppl.alive*(ppl.sex==sb)).nonzero()[0]
                pyramid[sex] = np.histogram(ppl.age[inds], bins=self.edges, weights=ppl.scale[inds])[0] # Bin people
            self.age_pyramids[date] = pyramid


    def finalize(self, sim):
        ''' Check that a pyramid was recorded for every requested date '''
        super().finalize()
        validate_recorded_dates(sim, requested_dates=self.dates, recorded_dates=self.age_pyramids.keys(), die=self.die)
        return


    @staticmethod
    def reduce(analyzers, use_mean=False, bounds=None, quantiles=None):
        '''
        Create an averaged age pyramid from a list of age pyramid analyzers.

        Args:
            analyzers (list): age_pyramid analyzers with identical timepoints and edges
            use_mean (bool): if True, summarize with mean +/- ``bounds`` standard deviations; otherwise use the median and quantiles
            bounds (float): number of standard deviations for low/high when ``use_mean`` is True (default 2)
            quantiles (dict or list): low/high quantiles when ``use_mean`` is False (default 0.1 and 0.9)

        Returns:
            A copy of the first analyzer whose ``age_pyramids[date][sex]`` entries are
            objdicts with ``best``, ``low`` and ``high`` arrays

        Raises:
            TypeError: if ``analyzers`` is not a list of consistent age pyramids
            ValueError: if the list is empty or ``quantiles`` cannot be interpreted
        '''
        # Process inputs for statistical calculations
        if use_mean:
            if bounds is None:
                bounds = 2
        else:
            if quantiles is None:
                quantiles = {'low':0.1, 'high':0.9}
            if not isinstance(quantiles, dict):
                try:
                    quantiles = {'low':float(quantiles[0]), 'high':float(quantiles[1])}
                except Exception as E:
                    errormsg = f'Could not figure out how to convert {quantiles} into a quantiles object: must be a dict with keys low, high or a 2-element array ({str(E)})'
                    raise ValueError(errormsg) from E

        # Check that a non-empty list of analyzers of the right type has been provided
        if not isinstance(analyzers, list):
            errormsg = 'age_pyramid.reduce() expects a list of age pyramid analyzers'
            raise TypeError(errormsg)
        if not analyzers:
            errormsg = 'age_pyramid.reduce() needs at least one analyzer to reduce'
            raise ValueError(errormsg)
        for analyzer in analyzers:
            if not isinstance(analyzer, age_pyramid):
                errormsg = 'All items in the list of analyzers provided to age_pyramid.reduce must be age pyramids'
                raise TypeError(errormsg)

        # Check that all the analyzers have the same timepoints and age bins
        base_analyzer = analyzers[0]
        for analyzer in analyzers:
            if not np.array_equal(analyzer.timepoints, base_analyzer.timepoints):
                errormsg = 'The list of analyzers provided to age_pyramid.reduce have different timepoints.'
                raise TypeError(errormsg)
            if not np.array_equal(analyzer.edges, base_analyzer.edges):
                errormsg = 'The list of analyzers provided to age_pyramid.reduce have different age bin edges.'
                raise TypeError(errormsg)

        # Initialize the reduced analyzer
        reduced_analyzer = sc.dcp(base_analyzer)
        reduced_analyzer.age_pyramids = sc.objdict() # Remove the age pyramids so we can rebuild them

        # Aggregate the list of analyzers: one column per run, then summarize across columns
        for date, _ in zip(base_analyzer.dates, base_analyzer.timepoints):
            reduced_pyramid = sc.objdict()
            reduced_pyramid['bins'] = base_analyzer.age_pyramids[date]['bins']
            for sk in ['f','m']:
                raw = np.stack([analyzer.age_pyramids[date][sk] for analyzer in analyzers], axis=-1)
                summary = sc.objdict()
                if use_mean:
                    r_mean = np.mean(raw, axis=1)
                    r_std = np.std(raw, axis=1)
                    summary.best = r_mean
                    summary.low  = r_mean - bounds * r_std
                    summary.high = r_mean + bounds * r_std
                else:
                    summary.best = np.quantile(raw, q=0.5, axis=1)
                    summary.low  = np.quantile(raw, q=quantiles['low'], axis=1)
                    summary.high = np.quantile(raw, q=quantiles['high'], axis=1)
                reduced_pyramid[sk] = summary
            reduced_analyzer.age_pyramids[date] = reduced_pyramid

        return reduced_analyzer


    @staticmethod
    def _draw_pyramid(ax, df, ycol, order, yticklabels, xlabel, percentages, m_color, f_color, title):
        ''' Draw one back-to-back bar chart (column 'm', plus females stored as negative values in column 'f') on ax '''
        import seaborn as sns # Import here since slow
        sns.barplot(x='m', y=ycol, data=df, order=order, orient='h', ax=ax, color=m_color)
        sns.barplot(x='f', y=ycol, data=df, order=order, orient='h', ax=ax, color=f_color)
        ax.set_xlabel(xlabel)
        ax.set_ylabel('Age group')
        ax.set_yticklabels(yticklabels)
        xticks = ax.get_xticks()
        if percentages:
            xlabels = [f'{abs(i):.2f}' for i in xticks]
        else:
            xlabels = [f'{sc.sigfig(abs(i), sigfigs=2, SI=True)}' for i in xticks]
        ax.set_xticks(xticks)
        ax.set_xticklabels(xlabels) # Show absolute values on both sides of the axis
        ax.set_title(title)
        return ax


    def plot(self, m_color='#4682b4', f_color='#ee7989', fig_args=None, axis_args=None, data_args=None,
             percentages=True, do_save=None, fig_path=None, do_show=True, **kwargs):
        '''
        Plot the age pyramids

        Args:
            m_color (hex or rgb): the color of the bars for males
            f_color (hex or rgb): the color of the bars for females
            fig_args (dict): passed to pl.figure()
            axis_args (dict): passed to pl.subplots_adjust()
            data_args (dict): 'width', 'color', and 'offset' arguments for the data
            percentages (bool): whether to plot the pyramid as percentages or numbers
            do_save (bool): whether to save
            fig_path (str or filepath): filepath to save to
            do_show (bool): whether to show the figure
            kwargs (dict): passed to ``hpv.options.with_style()``; see that function for choices

        Raises:
            ValueError: if no age pyramids have been recorded
        '''
        # Check there is something to plot before creating a figure, so that an
        # empty figure is not left open on error
        pyramidsdict = self.age_pyramids
        if not len(pyramidsdict):
            errormsg = f'Cannot plot since no age pyramids were recorded (scheduled timepoints: {self.timepoints})'
            raise ValueError(errormsg)

        # Handle inputs
        fig_args  = sc.mergedicts(dict(figsize=(12,8)), fig_args)
        axis_args = sc.mergedicts(dict(left=0.08, right=0.92, bottom=0.08, top=0.92), axis_args)
        d_args    = sc.objdict(sc.mergedicts(dict(width=0.3, color='#000000', offset=0), data_args))
        all_args  = sc.mergedicts(fig_args, axis_args, d_args)

        # Initialize
        fig = pl.figure(**fig_args)
        labels = list(reversed(self.age_labels))
        yticklabels = labels[1:] # With auto-generated labels, this drops the open-ended "N+" label

        # Set properties depending on data
        if self.data is None: # Simple case: just plot model output
            n_rows, n_cols = sc.get_rows_cols(len(self.timepoints))
        else: # Complex case: add data
            n_cols = 2 # Data plots go in the right column
            n_rows = len(self.timepoints) # We only show plots for requested timepoints

        # Make the figure(s)
        xlabel = 'Share of population by sex' if percentages else 'Population by sex'
        with hpo.with_style(**kwargs):
            count = 1
            for date, pyramid in pyramidsdict.items():
                pl.subplots_adjust(**axis_args)
                order = np.flip(pyramid['bins']) # Oldest age group at the top

                # Prepare model output
                pydf = pd.DataFrame(pyramid)
                if percentages:
                    pydf['m'] = pydf['m'] / sum(pydf['m'])
                    pydf['f'] = pydf['f'] / sum(pydf['f'])
                pydf['f'] = -pydf['f'] # Reverse values for females to get on same axis

                ax = pl.subplot(n_rows, n_cols, count)
                self._draw_pyramid(ax, pydf, 'bins', order, yticklabels, xlabel, percentages, m_color, f_color, f'{date}')
                count += 1

                if self.data is not None:
                    if date in self.data_dates:
                        datadf = self.data[self.data.year==float(date)].copy()
                        # Consistent naming of males and females
                        datadf.columns = datadf.columns.str[0]
                        datadf.columns = datadf.columns.str.lower()
                        if percentages:
                            datadf = datadf.assign(m=datadf['m'] / sum(datadf['m']), f=datadf['f'] / sum(datadf['f']))
                        datadf = datadf.assign(f=-datadf['f'])

                        ax = pl.subplot(n_rows, n_cols, count)
                        self._draw_pyramid(ax, datadf, 'a', order, yticklabels, xlabel, percentages, m_color, f_color, f'{date} - data')
                    count += 1 # Always advance, so each date keeps its own row even without data

        return hppl.tidy_up(fig, do_save=do_save, fig_path=fig_path, do_show=do_show, args=all_args)
class age_results(Analyzer):
    '''
    Constructs results by age at specified points within the sim. Can be used with data

    Args:
        result_args (dict): dict of results to generate and associated years/age-bins to generate
            each result as well as whether to compute_fit; each value should be an objdict
            (e.g. with ``years``, ``edges``, ``datafile``, ``compute_fit``)
        die (bool): whether or not to raise an exception if errors are found
        kwargs (dict): passed to :py:class:`Analyzer`

    **Example**::

        result_args=sc.objdict(
            hpv_prevalence=sc.objdict(
                years=[1990],
                edges=np.array([0.,20.,25.,30.,40.,45.,50.,55.,65.,100.]),
            ),
            hpv_incidence=sc.objdict(
                years=[1990, 2000],
                edges=np.array([0.,20.,30.,40.,50.,60.,70.,80.,100.]),
            ),
        )
        sim = hpv.Sim(analyzers=hpv.age_results(result_args=result_args))
        sim.run()
        age_results = sim['analyzers'][0]
    '''

    def __init__(self, result_args=None, die=False, **kwargs):
        super().__init__(**kwargs) # Initialize the Analyzer object
        self.result_args = result_args  # Specification of the results to compute; validated in initialize()
        self.results     = sc.objdict() # Store the age results
        self.die         = die          # Whether or not to raise an exception
        self.mismatch    = 0            # Accumulated mismatch with data. TODO, should this be set to np.nan initially?
        return
[docs] def initialize(self, sim): super().initialize() # Handle which results to make. Specification of the results to make is stored in result_args if sc.checktype(self.result_args, dict): # Ensure it's an object dict self.result_args = sc.objdict(self.result_args) else: # Raise an error errormsg = f'result_args must be a dict with keys for the years and edges you want to compute, not {type(self.result_args)}.' raise TypeError(errormsg) self.result_keys = self.result_args.keys() # Handle dt - if we're storing annual results we'll need to aggregate them over several consecutive timesteps self.dt = sim['dt'] self.resfreq = sim.resfreq # Store genotypes self.ng = sim['n_genotypes'] self.glabels = [g.upper() for g in sim['genotype_map'].values()] # Initialize result structure and validate the result variable arguments self.validate_variables(sim) # Store colors for rkey in self.result_args.keys(): if 'hiv' in rkey: self.result_args[rkey].color = sim.hivsim.results[rkey].color self.result_args[rkey].name = sim.hivsim.results[rkey].name else: self.result_args[rkey].color = sim.results[rkey].color self.result_args[rkey].name = sim.results[rkey].name self.initialized = True return
[docs] def validate_variables(self, sim): ''' Check that the variables in result_args are valid, and initialize the result structure ''' choices = sim.result_keys('total')+[k for k in sim.result_keys('genotype')] if sim['model_hiv']: choices += list(sim.hivsim.results.keys()) for rk, rdict in self.result_args.items(): if rk not in choices: strm = '\n'.join(choices) errormsg = f'Cannot compute age results for {rk}. Please enter one of the standard sim result_keys to the age_results analyzer; choices are {strm}.' raise ValueError(errormsg) else: self.results[rk] = dict() # Store the results. Not an odict because keyed by year # If a datafile has been provided, read it in and get the age bins and years if 'datafile' in rdict.keys(): if sc.isstring(rdict.datafile): rdict.data = hpm.load_data(rdict.datafile, check_date=False) else: rdict.data = rdict.datafile # Use it directly rdict.datafile = None # Get edges, age bins, and labels from datafile. This assumes # that the datafile has bins, and we make the edges by appending # the last point of the sim age bin edges. 
rdict.years = rdict.data.year.unique() rdict.bins = np.array(rdict.data.age.unique(), dtype=float) rdict.edges = np.append(rdict.bins, sim['age_bin_edges'][-1]) self.results[rk]['bins'] = rdict.bins else: # Use years and age bin edges provided, or use defaults from sim if (not hasattr(rdict,'edges')) or rdict.edges is None: # Default age bins warnmsg = f'Did not provide edges for age analyzer {rk}' if self.die: raise ValueError(warnmsg) else: warnmsg += ', using age bin edges from sim' hpm.warn(warnmsg) rdict.edges = sim['age_bin_edges'] rdict.bins = rdict.edges[:-1] # Don't include the last edge in the bins self.results[rk]['bins'] = rdict.bins if 'years' not in rdict.keys(): warnmsg = f'Did not provide years for age analyzer {rk}' if self.die: raise ValueError(warnmsg) else: warnmsg += ', using final year of sim' hpm.warn(warnmsg) rdict.years = sim['end'] rdict.years = sc.promotetoarray(rdict.years) # Construct age labels used for plotting rdict.age_labels = [f'{int(rdict.bins[i])}-{int(rdict.bins[i + 1])}' for i in range(len(rdict.bins) - 1)] rdict.age_labels.append(f'{int(rdict.bins[-1])}+') # Construct timepoints if not rdict.get('timepoints') or rdict.timepoints is None: rdict.timepoints = [] for y in rdict.years: rdict.timepoints.append(sc.findinds(sim.yearvec, y)[0] + int(1 / sim['dt']) - 1) # Check that the requested timepoints are in the sim max_hist_time = rdict.timepoints[-1] max_sim_time = sim.tvec[-1] if max_hist_time > max_sim_time: errormsg = f'Cannot create age results for {rdict.years[-1]} ({max_hist_time}) because the simulation ends on {self.end} ({max_sim_time})' raise ValueError(errormsg) # Translate the name of the result to the people attribute result_name = sc.dcp(rk) na = len(rdict.bins) ng = sim['n_genotypes'] # Clean up the name if 'genotype' in result_name: # Results by genotype result_name = result_name.replace('_by_genotype','') # remove "by_genotype" from result name rdict.size = (na, ng) rdict.by_genotype = True else: # Total 
results rdict.size = na rdict.by_genotype = False rdict.by_hiv = False if '_with_hiv' in result_name: result_name = result_name.replace('_with_hiv', '') # remove "_with_hiv" from result name rdict.by_hiv = True rdict.hiv_attr = 'hiv' elif '_no_hiv' in result_name: result_name = result_name.replace('_no_hiv', '') # remove "_no_hiv" from result name rdict.by_hiv = True # attr3 = ~ppl['hiv'] # Figure out if it's a flow or incidence if result_name in hpd.flow_keys or 'incidence' in result_name or 'mortality' in result_name: date_attr, attr = self.convert_rname_flows(result_name) rdict.result_type = 'flow' elif result_name[:2] == 'n_' or 'prevalence' in result_name: attr = self.convert_rname_stocks(result_name) # Convert to a people attribute date_attr = None rdict.result_type = 'stock' rdict.attr = attr rdict.date_attr = date_attr # Initialize results for year in rdict.years: self.results[rk][year] = np.zeros(rdict.size) # For flows, we calculate results on all the timepoints throughout the year, not just the last one if rdict.result_type == 'flow': rdict.calcpoints = [] rdict.calcpointyears = [] for tpi, tp in enumerate(rdict.timepoints): rdict.calcpoints += [tp+i+1 for i in range(-int(1/self.dt),0)] rdict.calcpointyears += [sim.yearvec[tp-(int(1/sim['dt'])-1)]]*int(1/self.dt) else: rdict.calcpoints = sc.dcp(rdict.timepoints) rdict.calcpointyears = [sim.yearvec[tp-(int(1/sim['dt'])-1)] for tp in rdict.calcpoints] if 'compute_fit' in rdict.keys() and rdict.compute_fit: if rdict.data is None: errormsg = 'Cannot compute fit without data' raise ValueError(errormsg) else: if 'weights' in rdict.data.columns: rdict.weights = rdict.data['weights'].values else: rdict.weights = np.ones(len(rdict.data)) rdict.mismatch = 0 # The final value
[docs] def convert_rname_stocks(self, rname): ''' Helper function for converting stock result names to people attributes ''' attr = rname.replace('_prevalence', '') # Strip out terms that aren't stored in the people if attr[0] == 'n': attr = attr[2:] # Remove n, used to identify stocks if attr == 'hpv': attr = 'infectious' # People with HPV are referred to as infectious in the sim if attr == 'cancer': attr = 'cancerous' return attr
[docs] def convert_rname_flows(self, rname): ''' Helper function for converting flow result names to people attributes ''' attr = rname.replace('_incidence', '') # Name of the actual state if attr == 'hpv': attr = 'infections' # HPV is referred to as infections in the sim if attr == 'cancer': attr = 'cancers' # cancer is referred to as cancers in the sim if attr == 'cancer_mortality': attr = 'cancer_deaths' # Handle variable names mapping = { 'infections': ['date_exposed', 'infectious'], 'cin': ['date_cin', 'cin'], 'dysplasias': ['date_cin', 'cin'], 'cins': ['date_cin', 'cin'], 'cancers': ['date_cancerous', 'cancerous'], 'cancer': ['date_cancerous', 'cancerous'], 'detected_cancer': ['date_detected_cancer', 'detected_cancer'], 'detected_cancers': ['date_detected_cancer', 'detected_cancer'], 'cancer_deaths': ['date_dead_cancer', 'dead_cancer'], 'detected_cancer_deaths': ['date_dead_cancer', 'dead_cancer'] } attr1 = mapping[attr][0] # Messy way of turning 'total cancers' into 'date_cancerous' and 'cancerous' etc attr2 = mapping[attr][1] # As above return attr1, attr2
    def apply(self, sim):
        ''' Calculate age results

        Called on every timestep. For each requested result whose calcpoints
        include ``sim.t``, the selected people are binned by age (weighted by
        ``ppl.scale``). Flows are accumulated into ``self.results[rkey][date]``
        and stocks overwrite it. On the result's timepoints the value is then
        divided by an age-binned denominator (see the comments below).

        NOTE(review): genotype results are allocated as (n_age_bins, n_genotypes)
        in validate_variables, and the flow branch writes columns ``[:, g]``.
        However, the stock branch writes rows ``[g, :]`` and the prevalence
        normalization uses ``denom[None, :]``, which both assume
        (n_genotypes, n_age_bins). The incidence normalization divides an
        (n_age_bins, n_genotypes) array by an (n_age_bins,) denominator. These
        only broadcast when the two sizes happen to be equal -- confirm the
        intended layout.
        '''

        # Shorten variables that are used a lot
        ng = self.ng
        ppl = sim.people

        def bin_ages(inds=None, bins=None):
            # Scale-weighted number of the selected people in each age bin
            return np.histogram(ppl.age[inds], bins=bins, weights=ppl.scale[inds])[0] # Bin the people

        # Go through each result key and determine if this is a timepoint where age results are requested
        for rkey, rdict in self.result_args.items():

            # Establish initial quantities
            bins = rdict.edges # The full edges array is what np.histogram expects as bins
            na = len(rdict.bins) # NOTE(review): not used below

            # Calculate flows and stocks over all calcpoints
            if sim.t in rdict.calcpoints:
                date_ind = sc.findinds(rdict.calcpoints, sim.t)[0] # Get the index
                date = rdict.calcpointyears[date_ind] # Create the date which will be used to key the results

                if 'compute_fit' in rdict.keys():
                    thisdatadf = rdict.data[(rdict.data.year == float(date)) & (rdict.data.name == rkey)]
                    unique_genotypes = thisdatadf.genotype.unique()
                    ng = len(unique_genotypes) # CAREFUL, THIS IS OVERWRITING
                    # NOTE(review): this reassignment also applies to later rkeys in this loop
                    # (ng is only reset from self.ng on the next call to apply)

                # Figure out if it's a flow
                if rdict.result_type == 'flow':
                    if not rdict.by_genotype: # Results across all genotypes
                        # People whose event date (date_attr) is this timestep and who have the state (attr)
                        if rkey == 'detected_cancer_deaths':
                            inds = ((ppl[rdict.date_attr] == sim.t) * (ppl[rdict.attr]) * (ppl['detected_cancer'])).nonzero()[-1]
                        else:
                            if rdict.by_hiv:
                                inds = ((ppl[rdict.date_attr] == sim.t) * (ppl[rdict.attr]) * (ppl[rdict.hiv_attr])).nonzero()[-1]
                            else:
                                inds = ((ppl[rdict.date_attr] == sim.t) * (ppl[rdict.attr])).nonzero()[-1]
                        self.results[rkey][date] += bin_ages(inds, bins) # Bin the people
                    else: # Results by genotype
                        for g in range(ng): # Loop over genotypes
                            inds = ((ppl[rdict.date_attr][g, :] == sim.t) * (ppl[rdict.attr][g, :])).nonzero()[-1]
                            self.results[rkey][date][:, g] += bin_ages(inds, bins) # Bin the people

                # This section is completed for stocks
                elif rdict.result_type == 'stock':
                    if not rdict.by_genotype:
                        # .any(axis=0) collapses the genotype axis (axis 0, as indexed by [g, :] below)
                        if rdict.by_hiv:
                            inds = (ppl[rdict.attr].any(axis=0) * ppl[rdict.hiv_attr]).nonzero()[-1]
                        elif isinstance(rdict.attr, list):
                            inds = (ppl[rdict.attr[0]].any(axis=0) + ppl[rdict.attr[1]].any(axis=0)).nonzero()[-1]
                            inds = np.unique(inds)
                        else:
                            inds = ppl[rdict.attr].any(axis=0).nonzero()[-1]
                        self.results[rkey][date] = bin_ages(inds, bins) # Stocks are overwritten, not accumulated
                    else:
                        for g in range(ng):
                            inds = ppl[rdict.attr][g, :].nonzero()[-1]
                            self.results[rkey][date][g, :] = bin_ages(inds, bins) # Bin the people
                            # NOTE(review): row indexing [g, :] here vs column indexing [:, g] for flows above

                # On the final timepoint in the year, normalize
                if sim.t in rdict.timepoints:
                    if 'prevalence' in rkey:
                        if 'hpv' in rkey: # Denominator is whole population
                            if rdict.by_hiv:
                                inds = sc.findinds(ppl[rdict.hiv_attr])
                                denom = bin_ages(inds=inds, bins=bins)
                            else:
                                denom = bin_ages(inds=ppl.alive, bins=bins)
                        else: # Denominator is females
                            denom = bin_ages(inds=ppl.is_female_alive, bins=bins)
                        if rdict.by_genotype: denom = denom[None, :]
                        self.results[rkey][date] = self.results[rkey][date] / (denom)

                    if 'incidence' in rkey:
                        if 'hpv' in rkey: # Denominator is susceptible population
                            inds = sc.findinds(ppl.is_female_alive & ~ppl.cancerous.any(axis=0)) # NOTE(review): not used; the denominator below uses ppl.sus_pool
                            denom = bin_ages(inds=hpu.true(ppl.sus_pool), bins=bins)
                        else: # Denominator is females at risk for cancer
                            if rdict.by_hiv:
                                inds = sc.findinds(ppl.is_female_alive & ppl[rdict.hiv_attr] * ~ppl.cancerous.any(axis=0))
                            else:
                                inds = sc.findinds(ppl.is_female_alive & ~ppl.cancerous.any(axis=0))
                            denom = bin_ages(inds, bins) / 1e5 # CIN and cancer are per 100,000 women
                        # if 'total' not in result and 'cancer' not in result: denom = denom[None, :] # THIS IS IT!!!!
                        self.results[rkey][date] = self.results[rkey][date] / denom

                    if 'mortality' in rkey:
                        # first need to find people who died of other causes today and add them back into denom
                        # NOTE(review): the code below does not add anyone back; it uses living females only
                        denom = bin_ages(inds=ppl.is_female_alive, bins=bins)
                        scale_factor = 1e5 # per 100,000 women
                        denom /= scale_factor
                        self.results[rkey][date] = self.results[rkey][date] / denom

        return
[docs] def finalize(self, sim): super().finalize() for rkey, rdict in self.result_args.items(): recorded_dates = [k for k in self.results[rkey].keys()][1:] validate_recorded_dates(sim, requested_dates=rdict.years, recorded_dates=recorded_dates, die=self.die) if 'compute_fit' in rdict.keys(): self.mismatch += self.compute_mismatch(rkey) # Add to sim.fit if hasattr(sim,'fit'): sim.fit += self.mismatch else: sim.fit = self.mismatch return
@staticmethod
def reduce(analyzers, use_mean=False, bounds=None, quantiles=None):
    '''
    Create an averaged age result from a list of age result analyzers.

    Args:
        analyzers (list): age_results analyzers to combine. All of them must
            have the same result keys, timepoints, and age bin edges.
        use_mean (bool): if True, summarize with the mean +/- ``bounds``
            standard deviations; otherwise use the median and quantiles
        bounds (float): number of standard deviations for low/high when
            ``use_mean`` is True (default 2)
        quantiles (dict or 2-element array): low/high quantiles used when
            ``use_mean`` is False (default {'low':0.1, 'high':0.9})

    Returns:
        A deep copy of the first analyzer. Its results hold, for each result
        key, the 'bins' entry of the first analyzer plus, for each year, an
        objdict with ``best``, ``low`` and ``high`` arrays.

    Raises:
        TypeError: if ``analyzers`` is not a list of age_results instances
        ValueError: if the list is empty, the analyzers are incompatible, or
            ``quantiles`` cannot be interpreted
    '''
    # Process inputs for statistical calculations
    if use_mean:
        if bounds is None:
            bounds = 2
    else:
        if quantiles is None:
            quantiles = {'low':0.1, 'high':0.9}
        if not isinstance(quantiles, dict):
            try:
                quantiles = {'low':float(quantiles[0]), 'high':float(quantiles[1])}
            except Exception as E:
                errormsg = f'Could not figure out how to convert {quantiles} into a quantiles object: must be a dict with keys low, high or a 2-element array ({str(E)})'
                raise ValueError(errormsg) from E

    # Check that a non-empty list of analyzers has been provided
    if not isinstance(analyzers, list):
        errormsg = 'age_results.reduce() expects a list of age_results analyzers'
        raise TypeError(errormsg)
    if not analyzers:
        errormsg = 'age_results.reduce() expects a non-empty list of age_results analyzers'
        raise ValueError(errormsg)

    # Check that everything in the list is an analyzer of the right type
    for analyzer in analyzers:
        if not isinstance(analyzer, age_results):
            errormsg = 'All items in the list of analyzers provided to age_results.reduce must be age_results instances'
            raise TypeError(errormsg)

    # Check that all the analyzers have the same timepoints and age bins
    base_analyzer = analyzers[0]
    for analyzer in analyzers:
        if set(analyzer.results.keys()) != set(base_analyzer.results.keys()):
            errormsg = 'The list of analyzers provided to age_results.reduce have different result keys.'
            raise ValueError(errormsg)
        for reskey in base_analyzer.results.keys():
            if not np.array_equal(base_analyzer.result_args[reskey]['timepoints'], analyzer.result_args[reskey]['timepoints']):
                errormsg = 'The list of analyzers provided to age_results.reduce have different timepoints.'
                raise ValueError(errormsg)
            if not np.array_equal(base_analyzer.result_args[reskey]['edges'], analyzer.result_args[reskey]['edges']):
                errormsg = 'The list of analyzers provided to age_results.reduce have different age bin edges.'
                raise ValueError(errormsg)

    # Initialize the reduced analyzer
    reduced_analyzer = sc.dcp(base_analyzer)
    reduced_analyzer.results = sc.objdict()  # Remove the age results so we can rebuild them

    # Aggregate the list of analyzers
    for reskey in base_analyzer.results.keys():
        reduced_analyzer.results[reskey] = dict()
        reduced_analyzer.results[reskey]['bins'] = base_analyzer.results[reskey]['bins']
        rargs = base_analyzer.result_args[reskey]
        for year, _ in zip(rargs.years, rargs.timepoints):
            # Stack each analyzer's result along a new trailing axis; this works
            # for results of any dimensionality (e.g. totals or genotype x age)
            raw = np.stack([np.asarray(a.results[reskey][year], dtype=float) for a in analyzers], axis=-1)

            # Summarize across analyzers
            summary = sc.objdict()
            if use_mean:
                r_mean = np.mean(raw, axis=-1)
                r_std = np.std(raw, axis=-1)
                summary.best = r_mean
                summary.low = r_mean - bounds * r_std
                summary.high = r_mean + bounds * r_std
            else:
                summary.best = np.quantile(raw, q=0.5, axis=-1)
                summary.low = np.quantile(raw, q=quantiles['low'], axis=-1)
                summary.high = np.quantile(raw, q=quantiles['high'], axis=-1)
            reduced_analyzer.results[reskey][year] = summary

    return reduced_analyzer
def compute_mismatch(self, key):
    '''
    Compute the mismatch between analyzer results and the datafile for one result key.

    For every (genotype, year) group in ``self.result_args[key].data``, the
    matching model result is written into the group's rows. For
    genotype-specific keys, this is the row of the genotype found via
    ``self.glabels``.

    The function then adds the ``model_output``, ``diffs``, ``gofs`` and
    ``losses`` columns to the data. It also stores the summed losses as
    ``self.result_args[key].mismatch``.

    Args:
        key (str): the result key to compare against data

    Returns:
        The summed losses (the mismatch) for this result key

    Raises:
        ValueError: if some data rows could not be matched to a model result
    '''
    resargs = self.result_args[key]
    results = self.results[key]
    data = resargs.data

    # Fill model outputs by row *position* so they line up with the data rows
    # regardless of how the dataframe is sorted (groupby iterates in sorted
    # key order, which need not match the row order of the data)
    model_output = np.full(len(data), np.nan)
    filled = np.zeros(len(data), dtype=bool)
    for (genotype, year), positions in data.groupby(['genotype', 'year']).indices.items():
        if 'genotype' in key:
            sim_res = results[year][self.glabels.index(genotype)]
        else:
            sim_res = results[year]
        model_output[positions] = np.asarray(sim_res, dtype=float)
        filled[positions] = True

    if not filled.all():
        errormsg = f'Could not match {int((~filled).sum())} data row(s) for "{key}" to model results'
        raise ValueError(errormsg)

    data['model_output'] = model_output
    data['diffs'] = data['model_output'] - data['value']
    data['gofs'] = hpm.compute_gof(data['value'].values, data['model_output'].values)
    data['losses'] = data['gofs'].values * resargs.weights
    resargs.mismatch = data['losses'].sum()
    return resargs.mismatch
def get_to_plot(self):
    '''
    Get the number of plots to make, and the (result key, year) pair for each.

    Returns:
        n_plots (int): total number of requested years, summed over ``self.result_args``
        to_plot_args (list): ``[rkey, year]`` pairs, in the order of ``self.result_keys``

    Raises:
        ValueError: if no age results were recorded
    '''
    if not self.results:
        errormsg = 'Cannot plot since no age results were recorded'
        raise ValueError(errormsg)

    n_plots = sum(len(rargs['years']) for rargs in self.result_args.values())
    to_plot_args = [[rkey, year] for rkey in self.result_keys for year in self.result_args[rkey]['years']]
    return n_plots, to_plot_args
def plot_single(self, ax, rkey, date, by_genotype, plot_args=None, scatter_args=None):
    '''
    Plot a single age result for a single date.

    Requires an axis as input and will generally be called by a helper
    function (e.g. ``plot()``) rather than directly.

    Args:
        ax: the matplotlib axes to draw on
        rkey (str): the result key to plot
        date: the year to plot (a key of ``self.results[rkey]``)
        by_genotype (bool): if True, draw one line per genotype; otherwise draw the total
        plot_args (dict): passed to ``ax.plot()`` (default linestyle '--')
        scatter_args (dict): passed to ``ax.scatter()`` for data points (default marker 's')

    Returns:
        The axes that were drawn on
    '''
    args = sc.objdict()
    args.plot = sc.objdict(sc.mergedicts(dict(linestyle='--'), plot_args))
    args.scatter = sc.objdict(sc.mergedicts(dict(marker='s'), scatter_args))

    resdict = self.results[rkey]  # Extract the result dictionary...
    resargs = self.result_args[rkey]  # ... and the result arguments
    x = np.arange(len(resargs.age_labels))  # Create the label locations

    # Pull out the data rows for this date and result, if data are available
    has_data = False
    if 'data' in resargs.keys():
        thisdatadf = resargs.data[(resargs.data.year == float(date)) & (resargs.data.name == rkey)]
        unique_genotypes = thisdatadf.genotype.unique()
        has_data = len(thisdatadf) > 0

    if by_genotype:
        # Plot by genotype, one color per genotype
        colors = sc.gridcolors(self.ng)
        for g in range(self.ng):
            color = colors[g]
            glabel = self.glabels[g].upper()
            ax.plot(x, resdict[date][g,:], color=color, **args.plot, label=glabel)
            # Only overlay data if this genotype is in the dataframe
            if has_data and glabel in unique_genotypes:
                ydata = np.array(thisdatadf[thisdatadf.genotype == glabel].value)
                ax.scatter(x, ydata, color=color, **args.scatter, label=f'Data - {glabel}')
    else:
        # Plot totals
        ax.plot(x, resdict[date].T, color=resargs.color, **args.plot, label='Model')
        if has_data:
            ydata = np.array(thisdatadf.value)
            ax.scatter(x, ydata, color=resargs.color, **args.scatter, label='Data')

    # Labels and legends
    ax.set_xlabel('Age')
    ax.set_title(f'{resargs.name} - {date}')
    ax.legend()
    # Set ticks on the supplied axes; pl.xticks() would act on pyplot's current
    # axes, which is not necessarily ``ax``
    ax.set_xticks(x)
    ax.set_xticklabels(resargs.age_labels)
    return ax
def plot(self, fig_args=None, axis_args=None, plot_args=None, scatter_args=None,
         do_save=None, fig_path=None, do_show=True, fig=None, ax=None, **kwargs):
    '''
    Plot the age results, one subplot per (result key, year) pair.

    Args:
        fig_args (dict): passed to pl.figure()
        axis_args (dict): passed to pl.subplots_adjust()
        plot_args (dict): passed to plot_single
        scatter_args (dict): passed to plot_single
        do_save (bool): whether to save
        fig_path (str or filepath): filepath to save to
        do_show (bool): whether to show the figure
        fig: currently unused -- a new figure is always created
        ax: currently unused -- subplots are created on the new figure
        kwargs (dict): passed to ``hpv.options.with_style()``; see that function for choices

    Returns:
        The value returned by ``hppl.tidy_up()`` for the new figure
    '''
    # Handle inputs
    fig_args = sc.mergedicts(dict(figsize=(12,8)), fig_args)
    axis_args = sc.mergedicts(dict(left=0.08, right=0.92, bottom=0.08, top=0.92), axis_args)
    all_args = sc.mergedicts(fig_args, axis_args)

    # Initialize (NB: the fig and ax arguments are not used)
    fig = pl.figure(**fig_args)
    n_plots, _ = self.get_to_plot()  # Also raises if there is nothing to plot
    n_rows, n_cols = sc.get_rows_cols(n_plots)

    # Make the figure(s)
    with hpo.with_style(**kwargs):
        pl.subplots_adjust(**axis_args)  # Identical for every subplot, so set it once
        plot_count = 1
        for rkey in self.results.keys():
            by_genotype = 'genotype' in rkey
            for year in self.result_args[rkey]['years']:
                ax = pl.subplot(n_rows, n_cols, plot_count)
                ax = self.plot_single(ax, rkey, year, by_genotype, plot_args=plot_args, scatter_args=scatter_args)
                plot_count += 1

    return hppl.tidy_up(fig, do_save=do_save, fig_path=fig_path, do_show=do_show, args=all_args)
class age_causal_infection(Analyzer):
    '''
    Determine the age at which people with cervical cancer were causally infected and
    time spent between infection and cancer.

    Args:
        start_year (float): first year in which to record new cancers; defaults to the sim start
        kwargs (dict): passed to ``Analyzer`` (e.g. ``label``)

    Recorded per new cancer (lists while running, converted to arrays by ``finalize()``):

        - ``age_causal``: age at the causal infection (current age minus total dwell time)
        - ``age_cancer``: age when the cancer started
        - ``dwelltime``: dict with 'precin' (infection to CIN), 'cin' (CIN to cancer)
          and 'total' (infection to cancer) durations, in timesteps multiplied by ``sim['dt']``
    '''

    def __init__(self, start_year=None, **kwargs):
        super().__init__(**kwargs)
        self.start_year = start_year
        self.years = None

    def initialize(self, sim):
        ''' Set the start year and create empty containers for the recorded values '''
        super().initialize(sim)
        self.years = sim.yearvec
        if self.start_year is None:
            self.start_year = sim['start']
        self.age_causal = []
        self.age_cancer = []
        self.dwelltime = {state: [] for state in ['precin', 'cin', 'total']}

    def apply(self, sim):
        ''' Record ages and dwell times for everyone whose cancer starts on this timestep '''
        if sim.yearvec[sim.t] >= self.start_year:
            ppl = sim.people
            dt = sim['dt']
            cancer_genotypes, cancer_inds = (ppl.date_cancerous == sim.t).nonzero()
            if len(cancer_inds):
                current_age = ppl.age[cancer_inds]
                date_exposed = ppl.date_exposed[cancer_genotypes, cancer_inds]
                date_cin = ppl.date_cin[cancer_genotypes, cancer_inds]
                hpv_time = (date_cin - date_exposed) * dt  # Infection -> CIN
                cin_time = (sim.t - date_cin) * dt  # CIN -> cancer
                total_time = (sim.t - date_exposed) * dt  # Infection -> cancer

                self.age_causal += (current_age - total_time).tolist()
                self.age_cancer += current_age.tolist()
                self.dwelltime['precin'] += hpv_time.tolist()
                self.dwelltime['cin'] += cin_time.tolist()
                self.dwelltime['total'] += total_time.tolist()
        return

    def finalize(self, sim=None):
        ''' Convert the recorded lists to arrays '''
        super().finalize(sim)
        self.age_causal = np.array(self.age_causal)
        self.age_cancer = np.array(self.age_cancer)
        self.dwelltime = {state: np.array(times) for state, times in self.dwelltime.items()}
        return
class cancer_detection(Analyzer):
    '''
    Cancer detection via symptoms

    Args:
        symp_prob: Probability of having cancer detected via symptoms, rather than screening
        treat_prob: Probability of receiving treatment for those with symptom-detected cancer
        product: treatment product administered to treated people (default ``hpi.radiation()``)
        kwargs (dict): passed to ``Analyzer``
    '''

    def __init__(self, symp_prob=0.01, treat_prob=0.01, product=None, **kwargs):
        super().__init__(**kwargs)
        self.symp_prob = symp_prob
        self.treat_prob = treat_prob
        self.product = product or hpi.radiation()

    def initialize(self, sim):
        super().initialize(sim)
        self.dt = sim['dt']

    def apply(self, sim):
        '''
        Check for new cancer detection, treat subset of detected cancers

        Returns:
            new_detections (int): number of people newly detected this timestep
            new_treatments (int): number of newly detected people who were treated
        '''
        # nonzero() yields one (genotype, person) pair per cancerous genotype, so
        # deduplicate persons: otherwise someone with cancer in several genotypes
        # would get several detection draws (and possibly repeated treatment)
        _, cancer_inds = sim.people.cancerous.nonzero()
        cancer_inds = np.unique(cancer_inds)
        new_detections, new_treatments = 0, 0

        if len(cancer_inds) > 0:
            # NOTE(review): per-step probability is symp_prob / dt; if symp_prob is an
            # annual probability, symp_prob * dt may be what is intended -- confirm
            detection_probs = np.full(len(cancer_inds), self.symp_prob / self.dt, dtype=hpd.default_float)
            detection_probs[sim.people.detected_cancer[cancer_inds]] = 0  # Already detected: no new detection
            is_detected_inds = cancer_inds[hpu.binomial_arr(detection_probs)]
            new_detections = len(is_detected_inds)

            if new_detections > 0:
                sim.people.detected_cancer[is_detected_inds] = True
                sim.people.date_detected_cancer[is_detected_inds] = sim.t
                treat_probs = np.full(new_detections, self.treat_prob)
                treat_inds = is_detected_inds[hpu.binomial_arr(treat_probs)]
                new_treatments = len(treat_inds)
                if new_treatments > 0:
                    self.product.administer(sim, treat_inds)

        # Update flows
        sim.people.flows['detected_cancers'] = new_detections
        return new_detections, new_treatments
class daly_computation(Analyzer):
    '''
    Analyzer for computing DALYs. Produces a dataframe by year storing:

        - Cases/deaths: number of new cancer cases and cancer deaths
        - Average age of new cases, average age of deaths, average age of noncancer death

    Average ages are accumulated over every timestep that falls in a year.

    Args:
        start (float): first year to record; looked up in ``sim.res_yearvec``
        kwargs (dict): passed to ``Analyzer``
    '''

    # Average-age column -> people attribute holding the date of that event
    _age_columns = {
        'av_age_other_deaths': 'date_dead_other',
        'av_age_cancer_deaths': 'date_dead_cancer',
        'av_age_cancers': 'date_cancerous',
    }

    def __init__(self, start, **kwargs):
        super().__init__(**kwargs)
        self.start = start
        return

    def initialize(self, sim):
        super().initialize(sim)
        columns = ['new_cancers', 'new_cancer_deaths', 'new_other_deaths',
                   'av_age_cancers', 'av_age_cancer_deaths', 'av_age_other_deaths']
        self.si = sc.findinds(sim.res_yearvec, self.start)[0]
        index = pd.Index(sim.res_yearvec[self.si:], name='year')
        self.df = pd.DataFrame(0.0, index=index, columns=columns)
        # Running sums/counts of ages so averages cover all timesteps in a year
        age_cols = list(self._age_columns)
        self._age_sums = pd.DataFrame(0.0, index=index, columns=age_cols)
        self._age_counts = pd.DataFrame(0, index=index, columns=age_cols)
        return

    def apply(self, sim):
        if sim.yearvec[sim.t] >= self.start:
            ppl = sim.people
            li = np.floor(sim.yearvec[sim.t])  # Result year this timestep belongs to
            lt = sim.t - 1  # Events dated on the previous timestep are counted
            for col, date_attr in self._age_columns.items():
                inds = hpu.true(getattr(ppl, date_attr) == lt)
                if len(inds):
                    self._age_sums.loc[li, col] += ppl.age[inds].sum()
                    self._age_counts.loc[li, col] += len(inds)
                count = self._age_counts.loc[li, col]
                # Write with a single .loc call: chained assignment such as
                # df.loc[li].col = value may modify a copy and never reach df
                self.df.loc[li, col] = self._age_sums.loc[li, col] / count if count else np.nan
        return

    def finalize(self, sim):
        super().finalize(sim)
        # Add in results that are already generated (NB, these have all been scaled already)
        self.df['new_cancers'] = sim.results['cancers'][self.si:]
        self.df['new_cancer_deaths'] = sim.results['cancer_deaths'][self.si:]
        self.df['new_other_deaths'] = sim.results['other_deaths'][self.si:]
        return
#%% Additional utilities

# Lookup table from analyzer name to analyzer class
analyzer_map = dict(
    snapshot=snapshot,
    age_pyramid=age_pyramid,
    age_results=age_results,
    age_causal_infection=age_causal_infection,
    cancer_detection=cancer_detection,
    daly_computation=daly_computation,
)