'''
Defines the Sim class, the core class of the FP model (FPsim).
'''
# %% Imports
import numpy as np # Needed for a few things not provided by pl
import pylab as pl
import seaborn as sns
import sciris as sc
import pandas as pd
from .settings import options as fpo
from . import utils as fpu
from . import defaults as fpd
from . import base as fpb
from . import parameters as fpp
from . import people as fpppl
# Specify all externally visible things this file defines
__all__ = ['Sim', 'MultiSim', 'parallel']
#%% Plotting helper functions
def fixaxis(useSI=True, set_lim=True, legend=True):
    """
    Apply the standard axis formatting to the current axes.

    Args:
        useSI (bool): if true, call ``sc.SIticks()``
        set_lim (bool): if true, call ``sc.setylim()``
        legend (bool): if true, add a legend via ``pl.legend()``
    """
    # Each (flag, action) pair is applied in the same order as before:
    # legend first, then y-limits, then SI tick labels
    steps = (
        (legend, pl.legend),
        (set_lim, sc.setylim),
        (useSI, sc.SIticks),
    )
    for enabled, action in steps:
        if enabled:
            action()
    return
def tidy_up(fig, do_show=None, do_save=None, filename=None):
    """
    Helper function to handle the slightly complex logic of showing, saving, returning -- not for users.

    Args:
        fig (Figure): the figure to show, save, close and/or return
        do_show (bool): whether to show the figure; if None, use ``fpo.show``
        do_save (bool/str): whether to save the figure; if None, use ``fpo.save``.
            If a string, it is used as the path to save to (taking precedence over ``filename``)
        filename (str): the path to save to when ``do_save`` is not a string

    Returns:
        The figure if ``fpo.returnfig`` is true, otherwise None
    """
    # Handle inputs
    if do_show is None:
        do_show = fpo.show
    if do_save is None:
        do_save = fpo.save

    # Handle show
    if pl.get_backend() == 'agg':  # Cannot show plots for a non-interactive backend
        do_show = False
    if do_show:  # Now check whether to show, and actually do it
        pl.show()

    # Handle saving
    if do_save:
        if isinstance(do_save, str):
            # do_save is itself the figure path: use it (previously it was silently
            # ignored and the figure was written to ``filename`` instead)
            filename = sc.makefilepath(do_save)  # Ensure it's valid, including creating the folder
        sc.savefig(fig=fig, filename=filename)  # Save the figure

    # Handle close
    if fpo.close and not do_show:
        pl.close(fig)

    # Return the figure only if requested by the options
    return fig if fpo.returnfig else None
# %% Sim class
[docs]
class Sim(fpb.BaseSim):
"""
The Sim class handles the running of the simulation. This class handles the mechanics
of the actual simulation, while BaseSim takes care of housekeeping (saving,
loading, exporting, etc.). Please see the BaseSim class for additional methods.
When a Sim is initialized, it triggers the creation of the population. Methods related
to creating, initializing, and updating people can be found in the People class.
Args:
pars (dict): parameters to modify from their default values
location (str): name of the location (country) to look for data file to load
label (str): the name of the simulation (useful to distinguish in batch runs)
track_children (bool): whether to track links between mothers and their children (slow, so disabled by default)
kwargs (dict): additional parameters; passed to ``fp.make_pars()``
**Examples**::
sim = fp.Sim()
sim = fp.Sim(n_agents=10e3, location='senegal', label='My small Senegal sim')
"""
def __init__(self, pars=None, location=None, label=None, track_children=False, **kwargs):
    """
    Create the sim: build the parameters, validate the keyword overrides and set default state.

    Args:
        pars (dict): parameters to modify from their default values; not modified by this call
        location (str): name of the location; if None, ``pars['location']`` is used when present
        label (str): the name of the simulation
        track_children (bool): whether to track links between mothers and their children
        kwargs (dict): additional parameters; every key must be in ``fpp.par_keys``

    Raises:
        sc.KeyNotFoundError: if any key in ``kwargs`` is not a known parameter key
    """
    # Merge into a new dict so that the caller's ``pars`` is never mutated
    merged_pars = sc.mergedicts(pars, kwargs)

    # Handle location: an explicit argument wins; otherwise fall back to pars['location'].
    # The key is always removed so it cannot collide with the ``location=`` keyword below.
    pars_location = merged_pars.pop('location', None)
    if location is None and pars_location:
        location = pars_location

    # Make parameters
    pars = fpp.pars(location=location, **merged_pars)  # Update with location-specific parameters

    # Validate and initialize
    mismatches = [key for key in kwargs if key not in fpp.par_keys]
    if mismatches:
        errormsg = f'Key(s) {mismatches} not found; available keys are {fpp.par_keys}'
        raise sc.KeyNotFoundError(errormsg)
    super().__init__(pars, location=location, **kwargs)  # Initialize and set the parameters as attributes

    self.initialized = False
    self.already_run = False
    self.test_mode = False
    self.label = label
    self.track_children = track_children
    self.results = {}
    self.people = None  # Sims are generally constructed without people, since People construction is time-consuming
    fpu.set_metadata(self)  # Set version, date, and git info
    self.summary = None
    return
[docs]
def initialize(self, force=False):
    """
    Fully initialize the Sim with people and result storage.

    Args:
        force (bool): if true, re-initialize even if the sim has already been initialized

    Returns:
        The sim itself, to allow chaining
    """
    if force or not self.initialized:
        fpu.set_seed(self['seed'])
        self.init_results()
        self.init_people()
        # Record completion; without this the guard above is always true and
        # ``force`` has no effect
        self.initialized = True
    return self
[docs]
def init_results(self):
    """
    Create the empty result containers in ``self.results``.

    Keys in ``fpd.array_results`` get a zero array with one entry per timestep and keys in
    ``fpd.list_results`` get an empty list. Age-specific fertility, method usage, and the
    optional switching and age-specific channels are added afterwards.
    """
    results = self.results
    n_steps = int(self.npts)

    for key in fpd.array_results:
        results[key] = np.zeros(n_steps)
    for key in fpd.list_results:
        results[key] = []

    # Age-specific fertility rates and method usage
    results['asfr'] = {}
    results['method_usage'] = []
    for age_bin in fpd.age_bin_map.keys():
        results['asfr'][age_bin] = []
        results[f"tfr_{age_bin}"] = []

    # One (methods x methods) integer matrix per timestep for each switching channel
    if self['track_switching']:
        n_methods = len(self['methods']['map'])
        switching_keys = (
            'switching_events_annual',
            'switching_events_postpartum',
            'switching_events_<18',
            'switching_events_18-20',
            'switching_events_21-25',
            'switching_events_26-35',
            'switching_events_>35',
            'switching_events_pp_<18',
            'switching_events_pp_18-20',
            'switching_events_pp_21-25',
            'switching_events_pp_26-35',
            'switching_events_pp_>35',
        )
        for key in switching_keys:
            results[key] = {
                step: np.zeros((n_methods, n_methods), dtype=int)
                for step in range(self.npts)
            }  # CK: TODO: refactor

    # Age-specific channels
    if self.pars['track_as']:
        results['imr_age_by_group'] = []
        results['mmr_age_by_group'] = []
        results['stillbirth_ages'] = []
        unsplit_markers = ('numerator', 'denominator', 'as_')
        for channel in fpd.by_age_results:
            # Channels matching a marker are stored once under their own name;
            # all others get one list per age group
            is_unsplit = any(marker in channel for marker in unsplit_markers)
            for age_group in fpd.age_specific_channel_bins:
                if is_unsplit:
                    results[channel] = []
                else:
                    results[f"{channel}_{age_group}"] = []
    return
[docs]
def init_people(self):
    """
    Build the population for this sim.

    Constructs a ``People`` object from the sim's parameters and stores it as ``self.people``.
    """
    population = fpppl.People(pars=self.pars)
    self.people = population
[docs]
def update_methods(self):
    """
    Rescale the contraceptive method matrices to the mCPR trend for the current sim year.

    A fresh copy of the raw matrices is stored in ``methods['adjusted']`` and then modified
    in place using the ratio of the current-year mCPR to the mCPR in the normalization year.
    """
    methods = self['methods']
    methods['adjusted'] = sc.dcp(methods['raw'])  # Work on a copy so the raw matrices are untouched

    # Look up the trend data
    data_years = methods['mcpr_years']
    data_rates = methods['mcpr_rates']
    nearest_ind = sc.findnearest(data_years, self.y)  # Data year closest to the sim year
    norm_ind = sc.findnearest(data_years, self['mcpr_norm_year'])  # Year used for normalization
    data_val = data_rates[nearest_ind]
    norm_val = data_rates[norm_ind]

    # Default to the nearest data point; extrapolate only past the end of the data
    trend_val = data_val
    if self.y > max(data_years):
        eps = 1e-3  # Lowest allowed mCPR value (to avoid divide by zero errors)
        years_past_data = self.y - data_years[nearest_ind]
        growth = self['mcpr_growth_rate'] * years_past_data  # Projected change in mCPR
        trend_val = np.clip(data_val * (1 + growth), eps, self['mcpr_max'])  # Keep within bounds

    # The correction factor is 1 at the normalization year
    correction = trend_val / norm_val

    adjusted = methods['adjusted']

    # Annual (non-postpartum) and postpartum switching matrices, stratified by age
    for switchkey in ['annual', 'pp1to6']:
        for matrix in adjusted[switchkey].values():
            matrix[0, 0] /= correction
            for row in range(len(matrix)):
                row_total = matrix[row, :].sum()
                if row_total > 0:
                    matrix[row, :] = matrix[row, :] / row_total  # Probabilities in each row add to 1

    # Postpartum initiation matrices, stratified by age
    for matrix in adjusted['pp0to1'].values():
        matrix[0] /= correction
        matrix /= matrix.sum()
    return
[docs]
def update_mortality(self):
    """
    Refresh ``self['mortality_probs']`` for the sim's current year.

    For each mortality parameter, the entry of ``probs`` whose ``year`` is nearest to the
    current calendar year is stored under a short name.
    """
    # (parameter name, key under which its current value is stored)
    sources = (
        ('age_mortality', 'gen_trend'),
        ('infant_mortality', 'infant'),
        ('maternal_mortality', 'maternal'),
        ('stillbirth_rate', 'stillbirth'),
    )
    self['mortality_probs'] = {}
    for par_name, short_name in sources:
        nearest = sc.findnearest(self[par_name]['year'], self.y)
        self['mortality_probs'][short_name] = self[par_name]['probs'][nearest]
    return
[docs]
def update_mothers(self):
    """
    Record, for each child of a recently postpartum woman, the index of the mother.

    Only women flagged as postpartum with a postpartum duration below 2 are considered.
    """
    everyone = self.people.unfilter()
    for mother, is_postpartum in enumerate(everyone.postpartum):
        if not is_postpartum:
            continue
        if not (everyone.postpartum_dur[mother] < 2):
            continue
        for child in everyone.children[mother]:
            everyone.mothers[child] = mother
    return
[docs]
def apply_interventions(self):
    """
    Run every entry of ``self['interventions']`` against the sim.

    Raises:
        TypeError: if an entry is neither an ``Intervention`` nor callable
    """
    from . import interventions as fpi  # To avoid circular import
    for index, intv in enumerate(sc.tolist(self['interventions'])):
        if isinstance(intv, fpi.Intervention):
            if not intv.initialized:  # pragma: no cover
                intv.initialize(self)
            intv.apply(self)  # Intervention objects are applied via apply()
            continue
        if callable(intv):
            intv(self)  # Plain functions are called directly
            continue
        # pragma: no cover
        errormsg = f'Intervention {index} ({intv}) is neither callable nor an Intervention object: it is {type(intv)}'
        raise TypeError(errormsg)
    return
[docs]
def apply_analyzers(self):
    """
    Run every entry of ``self['analyzers']`` against the sim.

    Raises:
        TypeError: if an entry is neither an ``Analyzer`` nor callable
    """
    from . import analyzers as fpa  # To avoid circular import
    for index, ana in enumerate(sc.tolist(self['analyzers'])):
        if isinstance(ana, fpa.Analyzer):
            if not ana.initialized:  # pragma: no cover
                ana.initialize(self)
            ana.apply(self)  # Analyzer objects are applied via apply()
            continue
        if callable(ana):
            ana(self)  # Plain functions are called directly
            continue
        # pragma: no cover
        errormsg = f'Analyzer {index} ({ana}) is neither callable nor an Analyzer object: it is {type(ana)}'
        raise TypeError(errormsg)
    return
[docs]
def finalize_interventions(self):
    """ Call ``finalize()`` on every ``Intervention`` object; other entries are skipped """
    from . import interventions as fpi  # To avoid circular import
    entries = sc.tolist(self['interventions'])
    for intv in (entry for entry in entries if isinstance(entry, fpi.Intervention)):
        intv.finalize(self)
[docs]
def finalize_analyzers(self):
    """ Call ``finalize()`` on every ``Analyzer`` object; other entries are skipped """
    from . import analyzers as fpa  # To avoid circular import
    entries = sc.tolist(self['analyzers'])
    for ana in (entry for entry in entries if isinstance(entry, fpa.Analyzer)):
        ana.finalize(self)
[docs]
def finalize_people(self):
    """Remove the ``mothers`` attribute from people unless children are being tracked"""
    if self.track_children:
        return
    del self.people.mothers
[docs]
def grow_population(self, new_ppl):
    """
    Add new agents to the population.

    Args:
        new_ppl (int): number of agents to create and append to ``self.people``
    """
    newcomers = fpppl.People(pars=self.pars, n=new_ppl)
    self.people += newcomers
[docs]
def step(self):
    """
    Advance the model by one timestep.

    Returns:
        A ``sc.dictobj`` holding the results returned by ``People.update()``
    """
    # Bring the time-varying parameters up to date for the current year
    self.update_methods()
    self.update_mortality()

    # Interventions first, then analyzers
    self.apply_interventions()
    self.apply_analyzers()

    # Update the people
    ppl = self.people
    ppl.i = self.i
    ppl.t = self.t
    r = sc.dictobj(**ppl.update())

    # Agents who died before age 1 are not added to the population
    self.grow_population(r.births - r.infant_deaths)

    # Link new children to their mothers
    if self.track_children:
        self.update_mothers()
    return r
[docs]
def run(self, verbose=None):
    """
    Run the simulation.

    Args:
        verbose (float): level of progress output; if None, use ``self['verbose']``.
            Values >= 2 print a heading each step; values between 0 and 2 print a
            progress bar on selected steps

    Returns:
        The sim itself, with ``results`` and ``summary`` populated

    Raises:
        RuntimeError: if the sim has already been run
    """
    # Refuse to re-run *before* initializing, so that the results and people of a
    # completed sim are not reset on the way to the error
    if self.already_run:
        errormsg = 'Cannot re-run an already run sim; please recreate or copy prior to a run'
        raise RuntimeError(errormsg)

    # Initialize -- reset settings and results
    T = sc.timer()
    if verbose is None:
        verbose = self['verbose']
    self.initialize()

    # Main simulation loop
    for i in range(self.npts):  # Range over number of timesteps in simulation
        self.i = i  # Timestep - RS TODO, do these need to be set as attributes?
        self.t = self.ind2year(i)  # Time elapsed in years given how many timesteps have passed
        self.y = self.ind2calendar(i)  # Calendar year of timestep

        # Print progress
        elapsed = T.toc(output=True)
        if verbose:
            simlabel = f'"{self.label}": ' if self.label else ''
            string = f' Running {simlabel}{self.y:0.0f} of {self["end_year"]} ({i:2.0f}/{self.npts}) ({elapsed:0.2f} s) '
            if verbose >= 2:
                sc.heading(string)
            elif verbose > 0:
                # int(1.0 / verbose) is 0 for 1 < verbose < 2, which would make the
                # modulo below raise ZeroDivisionError; use an interval of at least 1
                interval = max(1, int(1.0 / verbose))
                if not (self.t % interval):
                    sc.progressbar(self.i + 1, self.npts, label=string, length=20, newline=True)

        # Actually update the model
        res = self.step()

        # Results
        self.update_results(res, i)

    # Finalize people
    self.finalize_people()

    # Finalize results, interventions and analyzers
    self.finalize_results()
    self.finalize_interventions()
    self.finalize_analyzers()

    if verbose:
        print(f'Final population size: {self.n}.')
        elapsed = T.toc(output=True)
        print(f'Run finished for "{self.label}" after {elapsed:0.1f} s')

    self.summary = sc.objdict()
    self.summary.births = np.sum(self.results['births'])
    self.summary.deaths = np.sum(self.results['deaths'])
    self.summary.final = self.results['pop_size'][-1]
    self.already_run = True
    return self
def update_results(self, r, i):
    """
    Store the outputs of one timestep and, every ``fpd.mpy`` timesteps, the annual summaries.

    Per-timestep count channels are multiplied by ``scaled_pop / n_agents`` (or 1 if
    ``scaled_pop`` is not set) when they are stored. The annual totals are sums of those
    stored values, so they are not multiplied by the scale factor a second time.

    Args:
        r (dictobj): the results of the current step, as returned by ``step()``
        i (int): index of the current timestep
    """
    res = self.results

    # Scale factor from agents to the represented population
    if self['scaled_pop']:
        scale = self['scaled_pop'] / self['n_agents']
    else:
        scale = 1

    res['t'][i] = self.tvec[i]
    res['pop_size_months'][i] = self.n * scale

    # Counts: stored scaled
    scaled_channels = (
        'births', 'deaths', 'stillbirths', 'miscarriages', 'abortions', 'short_intervals',
        'secondary_births', 'pregnancies', 'total_births', 'maternal_deaths', 'infant_deaths',
        'total_women_fecund', 'unintended_pregs',
    )
    for channel in scaled_channels:
        res[channel][i] = getattr(r, channel) * scale

    # Contraceptive prevalence: raw numerators/denominators plus the resulting proportion
    for rate in ('mcpr', 'cpr', 'acpr'):
        on_methods = getattr(r, f'on_methods_{rate}')
        no_methods = getattr(r, f'no_methods_{rate}')
        res[f'on_methods_{rate}'][i] = on_methods
        res[f'no_methods_{rate}'][i] = no_methods
        res[rate][i] = on_methods / (no_methods + on_methods)

    # Postpartum status as a percentage of fecund women
    fecund = r.total_women_fecund
    res['pp0to5'][i] = (r.pp0to5 / fecund) * 100
    res['pp6to11'][i] = (r.pp6to11 / fecund) * 100
    res['pp12to23'][i] = (r.pp12to23 / fecund) * 100
    res['nonpostpartum'][i] = ((fecund - r.pp0to5 - r.pp6to11 - r.pp12to23) / fecund) * 100

    if self.pars['track_as']:
        rolling_channels = (
            'imr_numerator', 'imr_denominator', 'mmr_numerator', 'mmr_denominator',
            'as_stillbirths', 'imr_age_by_group', 'mmr_age_by_group', 'stillbirth_ages',
        )
        for channel in rolling_channels:
            res[channel].append(getattr(r, channel))
            if len(res[channel]) > 12:  # Keep only the 12 most recent entries
                res[channel] = res[channel][1:]

        for channel in ('acpr', 'cpr', 'mcpr', 'pregnancies', 'births'):
            for method_agekey in fpd.age_specific_channel_bins:
                full_key = f"{channel}_{method_agekey}"
                res[full_key].append(getattr(r, full_key))

    # Totals per age bin, used for ASFR
    for agekey in fpd.age_bin_map.keys():
        res[f'total_births_{agekey}'][i] = r.birth_bins[agekey] * scale
        res[f'total_women_{agekey}'][i] = r.age_bin_totals[agekey] * scale

    # Number of switching events in each age group
    if self['track_switching']:
        switch_events = r.pop('switching')
        for age_key in ('<18', '18-20', '21-25', '26-35', '>35'):
            res[f'switching_events_{age_key}'][i] = scale * r.switching_annual[age_key]
            res[f'switching_events_pp_{age_key}'][i] = scale * r.switching_postpartum[age_key]
        res['switching_events_annual'][i] = scale * switch_events['annual']
        res['switching_events_postpartum'][i] = scale * switch_events['postpartum']

    # Calculate metrics over the last year in the model and save whole years and stats
    if i % fpd.mpy == 0:
        res['tfr_years'].append(self.y)
        start_index = (int(self.t) - 1) * fpd.mpy
        stop_index = int(self.t) * fpd.mpy
        # When int(self.t) is 0 the start index is negative and the stop index is 0;
        # clamping the start keeps the window explicitly empty in that case
        window = slice(max(start_index, 0), stop_index)

        # (annual result key, per-timestep channel it sums); the per-timestep values
        # are already scaled, so no further scaling is applied here
        annual_sources = (
            ('method_failures_over_year', 'unintended_pregs'),
            ('infant_deaths_over_year', 'infant_deaths'),
            ('total_births_over_year', 'total_births'),
            ('live_births_over_year', 'births'),
            ('stillbirths_over_year', 'stillbirths'),
            ('miscarriages_over_year', 'miscarriages'),
            ('abortions_over_year', 'abortions'),
            ('short_intervals_over_year', 'short_intervals'),
            ('secondary_births_over_year', 'secondary_births'),
            ('maternal_deaths_over_year', 'maternal_deaths'),
            ('pregnancies_over_year', 'pregnancies'),
        )
        annual = {}
        for annual_key, channel in annual_sources:
            annual[annual_key] = np.sum(res[channel][window])
            res[annual_key].append(annual[annual_key])

        res['method_usage'].append(self.compute_method_usage())  # only want this per year
        res['pop_size'].append(scale * self.n)  # CK: TODO: replace with arrays
        res['mcpr_by_year'].append(res['mcpr'][i])
        res['cpr_by_year'].append(res['cpr'][i])

        if self.pars['track_as']:
            imr_results_dict = self.people.log_age_split(binned_ages_t=res['imr_age_by_group'],
                                                         channel='imr',
                                                         numerators=res['imr_numerator'],
                                                         denominators=res['imr_denominator'])
            mmr_results_dict = self.people.log_age_split(binned_ages_t=res['mmr_age_by_group'],
                                                         channel='mmr',
                                                         numerators=res['mmr_numerator'],
                                                         denominators=res['mmr_denominator'])
            stillbirths_results_dict = self.people.log_age_split(binned_ages_t=res['stillbirth_ages'],
                                                                 channel='stillbirths',
                                                                 numerators=res['as_stillbirths'],
                                                                 denominators=None)
            for age_key in fpd.age_specific_channel_bins:
                res[f"imr_{age_key}"].append(imr_results_dict[f"imr_{age_key}"])
                res[f"mmr_{age_key}"].append(mmr_results_dict[f"mmr_{age_key}"])
                res[f"stillbirths_{age_key}"].append(stillbirths_results_dict[f"stillbirths_{age_key}"])

        live_births_over_year = annual['live_births_over_year']
        maternal_deaths_over_year = annual['maternal_deaths_over_year']
        infant_deaths_over_year = annual['infant_deaths_over_year']
        secondary_births_over_year = annual['secondary_births_over_year']
        short_intervals_over_year = annual['short_intervals_over_year']

        # Maternal mortality ratio, per 100,000 live births
        if maternal_deaths_over_year == 0:
            res['mmr'].append(0)
        else:
            res['mmr'].append(maternal_deaths_over_year / live_births_over_year * 100000)

        # Infant mortality rate, per 1,000 live births
        if infant_deaths_over_year == 0:
            res['imr'].append(infant_deaths_over_year)
        else:
            res['imr'].append(infant_deaths_over_year / live_births_over_year * 1000)

        # Proportion of secondary births that followed a short interval
        if secondary_births_over_year == 0:
            res['proportion_short_interval_by_year'].append(secondary_births_over_year)
        else:
            res['proportion_short_interval_by_year'].append(short_intervals_over_year / secondary_births_over_year)

        # Age-specific and total fertility rates
        tfr = 0
        for key in fpd.age_bin_map.keys():
            age_bin_births_year = np.sum(res['total_births_' + key][window])
            age_bin_total_women_year = res['total_women_' + key][stop_index]
            age_bin_births_per_woman = sc.safedivide(age_bin_births_year, age_bin_total_women_year)
            res['asfr'][key].append(age_bin_births_per_woman * 1000)
            res[f'tfr_{key}'].append(age_bin_births_per_woman * 1000)
            tfr += age_bin_births_per_woman  # CK: TODO: check if this is right
        res['tfr_rates'].append(tfr * 5)  # CK: TODO: why *5? # SB: I think this corresponds to size of age bins?
def finalize_results(self):
    """
    Convert list results to arrays, add cumulative annual totals and wrap the results
    in an ``sc.objdict``.
    """
    # These results are lists of lists with different lengths, so they need dtype=object
    ragged_keys = {
        'imr_numerator', 'imr_denominator', 'mmr_numerator', 'mmr_denominator',
        'imr_age_by_group', 'mmr_age_by_group', 'as_stillbirths', 'stillbirth_ages',
    }
    for key, value in self.results.items():
        if not isinstance(value, list):
            continue
        if key in ragged_keys:
            self.results[key] = np.array(value, dtype=object)
        else:
            self.results[key] = np.array(value)

    # (cumulative result key, annual result it accumulates)
    cumulative_sources = (
        ('cum_maternal_deaths_by_year', 'maternal_deaths_over_year'),
        ('cum_infant_deaths_by_year', 'infant_deaths_over_year'),
        ('cum_live_births_by_year', 'live_births_over_year'),
        ('cum_stillbirths_by_year', 'stillbirths_over_year'),
        ('cum_miscarriages_by_year', 'miscarriages_over_year'),
        ('cum_abortions_by_year', 'abortions_over_year'),
        ('cum_short_intervals_by_year', 'short_intervals_over_year'),
        ('cum_secondary_births_by_year', 'secondary_births_over_year'),
        ('cum_pregnancies_by_year', 'pregnancies_over_year'),
    )
    for cum_key, annual_key in cumulative_sources:
        self.results[cum_key] = np.cumsum(self.results[annual_key])

    # Convert to an objdict for easier access
    self.results = sc.objdict(self.results)
[docs]
def store_postpartum(self):
    """
    Build a dataframe describing the current state of living women of eligible age.

    Each row holds the rounded age, postpartum-duration flags, a non-postpartum flag,
    a pregnancy flag and parity for one woman aged between 12.5 and
    ``self['age_limit_fecundity']``.

    Returns:
        A dataframe with columns Age, PP0to5, PP6to11, PP12to23, NonPP, Pregnant, Parity
    """
    min_age = 12.5
    max_age = self['age_limit_fecundity']
    ppl = self.people
    columns = ['Age', 'PP0to5', 'PP6to11', 'PP12to23', 'NonPP', 'Pregnant', 'Parity']

    records = []
    for idx in range(len(ppl)):
        eligible = ppl.alive[idx] and ppl.sex[idx] == 0 and min_age <= ppl.age[idx] < max_age
        if not eligible:
            continue

        # Duration flags stay None for women who are not postpartum (filled with 0 below)
        pp0to5 = pp6to11 = pp12to23 = None
        is_postpartum = ppl.postpartum[idx]
        if is_postpartum:
            pp_dur = ppl.postpartum_dur[idx]
            pp0to5 = 1 if 0 <= pp_dur < 6 else 0
            pp6to11 = 1 if 6 <= pp_dur < 12 else 0
            pp12to23 = 1 if 12 <= pp_dur <= 24 else 0

        records.append(dict(
            Age=int(round(ppl.age[idx])),
            PP0to5=pp0to5,
            PP6to11=pp6to11,
            PP12to23=pp12to23,
            NonPP=0 if is_postpartum else 1,
            Pregnant=1 if ppl.pregnant[idx] else 0,
            Parity=ppl.parity[idx],
        ))

    pp = pd.DataFrame(records, index=None, columns=columns)
    pp.fillna(0, inplace=True)
    return pp
[docs]
def to_df(self, include_range=False):
    """
    Export the per-timestep sim results to a dataframe (also stored as ``self.df``).

    Only results whose length equals ``self.npts`` are exported.

    Args:
        include_range (bool): if True, and if the sim results have best, high, and low, then export all of them; else just best

    Returns:
        The dataframe of results
    """
    raw_res = sc.odict(defaultdict=list)
    for reskey, res in self.results.items():
        if isinstance(res, dict):
            for blh, blhres in res.items():  # Best, low, high
                if len(blhres) != self.npts:
                    continue
                if blh != 'best' and not include_range:
                    continue
                blhkey = f'{reskey}_{blh}' if include_range else reskey
                raw_res[blhkey] += blhres.tolist()
        elif sc.isarray(res) and len(res) == self.npts:
            raw_res[reskey] += res.tolist()
    df = pd.DataFrame(raw_res)
    self.df = df
    return df
# Function to scale all y-axes in fig based on input channel
@staticmethod
def conform_y_axes(figure, bottom=0, top=100):
    """
    Give every axes in a figure the same y-limits.

    Args:
        figure (Figure): the figure whose axes are updated
        bottom (float): lower y-limit
        top (float): upper y-limit

    Returns:
        The same figure
    """
    for ax in figure.axes:
        ax.set_ylim([bottom, top])
    return figure
[docs]
def plot(self, to_plot=None, xlims=None, ylims=None, do_save=None, do_show=True, filename='fpsim.png', style=None,
         fig_args=None,
         plot_args=None, axis_args=None, fill_args=None, label=None, new_fig=True, colors=None):
    """
    Plot the results -- can supply arguments for both the figure and the plots.

    Plotting does not modify ``self.results``: percentage channels are scaled on a copy,
    and the age group that is excluded from age-specific plots is filtered out of a copy
    of the results rather than deleted from them.

    Args:
        to_plot (str/dict): What to plot (e.g. 'default' or 'cpr'), or a dictionary of result:label pairs
        xlims (list/dict): passed to pl.xlim() (use ``[None, None]`` for default)
        ylims (list/dict): passed to pl.ylim()
        do_save (bool): Whether or not to save the figure. If a string, save to that filename.
        do_show (bool): Whether to show the plots at the end
        filename (str): If a figure is saved, use this filename
        style: Custom style arguments, passed to ``fpo.with_style()``
        fig_args (dict): Passed to pl.figure() (plus ``nrows`` and ``ncols`` for overriding defaults)
        plot_args (dict): Passed to pl.plot()
        axis_args (dict): Passed to pl.subplots_adjust()
        fill_args (dict): Passed to pl.fill_between()
        label (str): Label to override default
        new_fig (bool): Whether to create a new figure (true unless part of a multisim)
        colors (list/dict): Colors for plots with multiple lines

    Raises:
        ValueError: if ``to_plot`` is an unknown string, or an age-specific plot is requested
            while ``sim.pars['track_as']`` is False
        TypeError: if ``to_plot`` is neither a string nor a dict
        RuntimeError: if a result's length matches neither the yearly nor the per-timestep x-axis
    """
    if to_plot is None:
        to_plot = 'default'
    fig_args = sc.mergedicts(dict(figsize=(16, 10), nrows=None, ncols=None), fig_args)
    plot_args = sc.mergedicts(dict(lw=2, alpha=0.7), plot_args)
    axis_args = sc.mergedicts(dict(left=0.1, bottom=0.05, right=0.9, top=0.97, wspace=0.2, hspace=0.25), axis_args)
    fill_args = sc.mergedicts(dict(alpha=0.2), fill_args)

    with fpo.with_style(style):

        nrows, ncols = fig_args.pop('nrows'), fig_args.pop('ncols')
        fig = pl.figure(**fig_args) if new_fig else pl.gcf()
        pl.subplots_adjust(**axis_args)

        if 'as_' in to_plot:
            nrows, ncols = 2, 3

        res = self.results  # Shorten since heavily used
        method_age_groups = list(fpd.age_specific_channel_bins.keys())
        no_plot_age = None  # Defined up front so it can be tested safely further down
        if self.pars['track_as']:
            # The last age group is excluded from age-specific plots. Work on a filtered
            # copy so the stored results are left intact.
            no_plot_age = method_age_groups[-1]
            method_age_groups.remove(no_plot_age)
            res = sc.objdict({key: val for key, val in self.results.items() if no_plot_age not in key})

        # Age limits to be added to the label of the short birth interval plot
        agelim = '-'.join([str(self.pars['low_age_short_int']), str(self.pars['high_age_short_int'])])

        # Plot everything
        if 'as_' in to_plot and not self.pars['track_as']:
            raise ValueError("Age specific plot selected but sim.pars['track_as'] is False")

        # Remember which named preset (if any) was requested before it is expanded to a dict
        preset_name = to_plot if isinstance(to_plot, str) else None

        presets = {
            'default': {
                'mcpr_by_year': 'Modern contraceptive prevalence rate (%)',
                'cum_live_births_by_year': 'Live births',
                'cum_stillbirths_by_year': 'Stillbirths',
                'cum_maternal_deaths_by_year': 'Maternal deaths',
                'cum_infant_deaths_by_year': 'Infant deaths',
                'imr': 'Infant mortality rate',
            },
            'cpr': {
                'mcpr': 'MCPR (modern contraceptive prevalence rate)',
                'cpr': 'CPR (contraceptive prevalence rate)',
                'acpr': 'ACPR (alternative contraceptive prevalence rate)',
            },
            'mortality': {
                'mmr': 'Maternal mortality ratio',
                'cum_maternal_deaths_by_year': 'Maternal deaths',
                'cum_infant_deaths_by_year': 'Infant deaths',
                'imr': 'Infant mortality rate',
            },
            'apo': {  # Adverse pregnancy outcomes
                'cum_pregnancies_by_year': 'Pregnancies',
                'cum_stillbirths_by_year': 'Stillbirths',
                'cum_miscarriages_by_year': 'Miscarriages',
                'cum_abortions_by_year': 'Abortions',
            },
            'method': {
                'method_usage': 'Method usage',
            },
            'short-interval': {
                'proportion_short_interval_by_year': f"Proportion of short birth interval [{agelim})",
            },
        }

        # Age-specific presets: name -> (result prefix, label template)
        age_specific = {
            'as_cpr': ('cpr', 'Contraceptive Prevalence Rate ({})'),
            'as_acpr': ('acpr', 'Alternative Contraceptive Prevalence Rate ({})'),
            'as_mcpr': ('mcpr', 'Modern Contraceptive Prevalence Rate ({})'),
            'as_pregnancies': ('pregnancies', 'Number of Pregnancies for ({})'),
            'as_tfr': ('tfr', 'Fertility Rate for ({})'),
            'as_imr': ('imr', 'Infant Mortality Rate for ({})'),
            'as_mmr': ('mmr', 'Maternal Mortality Rate for ({})'),
            'as_stillbirths': ('stillbirths', 'Stillbirths for ({})'),
            'as_births': ('births', 'Live births for ({})'),
        }

        if isinstance(to_plot, dict):
            pass
        elif isinstance(to_plot, str):
            if to_plot in presets:
                to_plot = presets[to_plot]
            elif to_plot in age_specific:
                prefix, template = age_specific[to_plot]
                # Fertility rates use the fertility age bins; everything else uses the method age groups
                groups = fpd.age_bin_map if to_plot == 'as_tfr' else method_age_groups
                to_plot = {f"{prefix}_{age_group}": template.format(age_group) for age_group in groups}
            else:
                errormsg = f"Your to_plot value: {to_plot} is not a valid option"
                raise ValueError(errormsg)
        else:
            errmsg = f"to_plot can be a dictionary or a string. A {type(to_plot)} is not a valid option."
            raise TypeError(errmsg)

        rows, cols = sc.getrowscols(len(to_plot), nrows=nrows, ncols=ncols)
        if preset_name == 'cpr':
            # The 'cpr' preset uses a single row of three panels. This check used to compare
            # the already-expanded dict with the string 'cpr' and so never took effect.
            rows, cols = 1, 3

        for p, key, reslabel in sc.odict(to_plot).enumitems():
            ax = pl.subplot(rows, cols, p + 1)

            this_res = res[key]
            is_dist = hasattr(this_res, 'best')
            if is_dist:
                y, low, high = this_res.best, this_res.low, this_res.high
            else:
                y, low, high = this_res, None, None

            # Figure out x axis: yearly results first, then per-timestep results
            years = res['tfr_years']
            timepoints = res['t']
            x = None
            for x_opt in [years, timepoints]:
                if len(y) == len(x_opt):
                    x = x_opt
                    break
            if x is None:
                errormsg = f'Could not figure out how to plot {key}: result of length {len(y)} does not match a known x-axis'
                raise RuntimeError(errormsg)

            percent_keys = ['mcpr_by_year', 'mcpr', 'cpr', 'acpr', 'method_usage',
                            'proportion_short_interval_by_year']
            if ('cpr_' in key or 'acpr_' in key or 'mcpr_' in key or 'proportion_short_interval_' in key) and 'by_year' not in key:
                percent_keys = percent_keys + list(to_plot.keys())

            # Proportions are shown as percentages. Scale copies rather than multiplying
            # in place, which would rescale the stored results on every call.
            pct_scale = 100 if (key in percent_keys and key != 'method_usage') else 1
            if pct_scale != 1:
                y = np.asarray(y) * pct_scale
                if is_dist:
                    low = np.asarray(low) * pct_scale
                    high = np.asarray(high) * pct_scale

            # Handle label
            if label is not None:
                plotlabel = label
            else:
                if new_fig:  # It's a new figure, use the result label
                    plotlabel = reslabel
                else:  # Replace with sim label to avoid duplicate labels
                    plotlabel = self.label

            # Actually plot
            if key == "method_usage":
                data = self.format_method_df(timeseries=True)
                method_names = data['Method'].unique()
                flipped_data = {method: [percentage for percentage in data[data['Method'] == method]['Percentage']]
                                for method in method_names}
                colors = [colors[method] for method in method_names] if isinstance(colors, dict) else colors
                ax.stackplot(data["Year"].unique(), list(flipped_data.values()), labels=method_names, colors=colors)
            else:
                ax.plot(x, y, label=plotlabel, **plot_args)

            if is_dist:
                if 'c' in plot_args:
                    fill_args['facecolor'] = plot_args['c']
                ax.fill_between(x, low, high, **fill_args)

            # Handle annotations
            as_plot = ('cpr_' in key or 'acpr_' in key or 'mcpr_' in key or 'pregnancies_' in key
                       or 'stillbirths' in key or 'tfr_' in key or 'imr_' in key or 'mmr_' in key
                       or 'births_' in key or 'proportion_short_interval_' in key) and 'by_year' not in key

            fixaxis(useSI=fpd.useSI, set_lim=new_fig)  # If it's not a new fig, don't set the lim
            if key in percent_keys:
                pl.ylabel('Percentage')
            elif 'mmr' in key:
                pl.ylabel('Deaths per 100,000 live births')
            elif 'imr' in key:
                pl.ylabel('Deaths per 1,000 live births')
            elif 'tfr_' in key:
                pl.ylabel('Fertility rate per 1,000 women')
            elif 'stillbirths_' in key:
                pl.ylabel('Number of stillbirths')
            else:
                pl.ylabel('Count')

            pl.xlabel('Year')
            pl.title(reslabel, fontweight='bold')
            if xlims is not None:
                pl.xlim(xlims)
            if ylims is not None:
                pl.ylim(ylims)
            if (key == "method_usage") or as_plot:  # need to overwrite legend for some plots
                ax.legend(loc='upper left', frameon=True)

            if 'cpr' in to_plot and '_' not in to_plot:
                # Use a common upper limit based on ACPR, expressed as a percentage
                # and rounded up to the nearest 10
                acpr = self.results['acpr']
                acpr_vals = acpr.high if is_dist else acpr
                top = int(np.ceil(100 * max(acpr_vals) / 10.0)) * 10
                self.conform_y_axes(figure=fig, top=top)

            if as_plot:
                channel_type = key.split("_")[0]
                tfr_scaling = 'tfr_' in key
                age_bins = fpd.age_bin_map if tfr_scaling else fpd.age_specific_channel_bins
                if no_plot_age is not None:
                    age_bins = {age_bin: interval for age_bin, interval in age_bins.items() if no_plot_age not in age_bin}
                if is_dist:
                    group_results = [res[f'{channel_type}_{age_group}'].high for age_group in age_bins]
                else:
                    group_results = [res[f'{channel_type}_{age_group}'] for age_group in age_bins]
                # Stored values are unscaled, so apply the same percentage scaling as the plotted lines
                top = pct_scale * max([max(group_result) for group_result in group_results])
                tidy_top = int(np.ceil(top / 10.0)) * 10
                tidy_top = tidy_top + 20 if tfr_scaling or 'imr_' in key else tidy_top
                tidy_top = tidy_top + 50 if 'mmr_' in key else tidy_top
                self.conform_y_axes(figure=fig, top=tidy_top)

    return tidy_up(fig=fig, do_show=do_show, do_save=do_save, filename=filename)
[docs]
def plot_age_first_birth(self, do_show=None, do_save=None, fig_args=None, filename="first_birth_age.png"):
    """
    Plot the distribution of age at first birth as a notched box plot.

    Only the strictly positive entries of ``self.people.first_birth_age`` are plotted.

    Args:
        do_show (bool): whether to show the plot; if None, ``tidy_up()`` falls back to ``fpo.show``
        do_save (bool): whether to save the plot; if None, ``tidy_up()`` falls back to ``fpo.save``
        fig_args (dict): arguments to pass to ``pl.figure()``, merged over the default ``figsize=(7, 5)``
        filename (str): the output path handed to ``tidy_up()``

    Returns:
        Whatever ``tidy_up()`` returns (the figure, or None when ``fpo.returnfig`` is off)
    """
    birth_age = self.people.first_birth_age
    data = birth_age[birth_age > 0]  # Drop non-positive entries (presumably "no birth yet" placeholders -- TODO confirm)
    fig = pl.figure(**sc.mergedicts(dict(figsize=(7, 5)), fig_args))
    pl.title("Age at first birth")
    sns.boxplot(x=data, orient='v', notch=True)
    pl.xlabel('Age (years)')  # Fixed: the label was missing its closing parenthesis
    return tidy_up(fig=fig, do_show=do_show, do_save=do_save, filename=filename)
[docs]
def compute_method_usage(self):
    """
    Compute the current method mix among living women of fecund age.

    A person is counted when they are alive, have ``sex == 0``, and satisfy
    ``fpd.min_age <= age < self['age_limit_fecundity']``.

    Returns:
        list: one entry per method (same length as ``self.pars['methods']['eff']``),
        where ``result[method_index]`` is the proportion of counted women using
        that method; methods nobody uses stay at 0
    """
    people = self.people
    youngest = fpd.min_age
    oldest = self['age_limit_fecundity']

    # Build the eligibility mask one criterion at a time
    is_female = [sex == 0 for sex in people.sex]
    old_enough = [youngest <= age for age in people.age]
    young_enough = [age < oldest for age in people.age]
    eligible = people.alive * is_female * old_enough * young_enough

    methods_in_use = [method for method, keep in zip(people.method, eligible) if keep]
    n_eligible = len(methods_in_use)

    # Turn per-method counts into proportions of all eligible women
    proportions = [0] * len(self.pars['methods']['eff'])
    method_ids, tallies = np.unique(methods_in_use, return_counts=True)
    for method_id, tally in zip(method_ids, tallies):
        proportions[method_id] = tally / n_eligible
    return proportions
# %% Multisim and running
[docs]
class MultiSim(sc.prettyobj):
"""
The MultiSim class handles the running of multiple simulations
"""
def __init__(self, sims=None, base_sim=None, label=None, n=None, **kwargs):
    """
    Create a MultiSim from one or more sims.

    Args:
        sims (Sim or list): a single sim, or a list of sims
        base_sim (Sim): the sim used for shared properties and plotting; if None, it is taken from ``sims``
        label (str): the name of the MultiSim; if None, the base sim's label is used
        n (int): currently not used by this method
        kwargs (dict): stored in ``self.run_args``

    Raises:
        TypeError: if ``base_sim`` is None and ``sims`` is neither a Sim nor a list
        ValueError: if ``base_sim`` is None and ``sims`` is an empty list
    """
    # Handle inputs
    if base_sim is None:
        if isinstance(sims, Sim):
            base_sim = sims
            sims = [sims]  # Keep a one-element list so that len(), run(), merge() and split() can all see the sim
        elif isinstance(sims, list):
            if not sims:  # Give a clear error instead of an IndexError from sims[0]
                errormsg = 'If base_sim is not supplied, the list of sims must contain at least one sim'
                raise ValueError(errormsg)
            base_sim = sims[0]
        else:
            errormsg = f'If base_sim is not supplied, sims must be either a single sim (treated as base_sim) or a list of sims, not {type(sims)}'
            raise TypeError(errormsg)

    # Set properties
    self.sims = sims
    self.base_sim = base_sim
    self.label = base_sim.label if (label is None and base_sim is not None) else label
    self.run_args = sc.mergedicts(kwargs)
    self.results = None
    self.which = None  # Whether the multisim is to be reduced, combined, etc.
    self.already_run = False
    fpu.set_metadata(self)  # Set version, date, and git info
    return
def __len__(self):
    """ Number of sims held: the list length, 1 for a bare Sim, and 0 for anything else """
    sims = self.sims
    if isinstance(sims, list):
        return len(sims)
    return 1 if isinstance(sims, Sim) else 0
[docs]
def run(self, compute_stats=True, **kwargs):
    """
    Run every sim in the MultiSim via ``multi_run()``.

    Args:
        compute_stats (bool): whether to call ``self.compute_stats()`` once the sims have run
        kwargs (dict): passed to ``multi_run()``

    Returns:
        The MultiSim itself

    Raises:
        RuntimeError: if this MultiSim has already been run
    """
    # Give every unlabelled sim a positional default label
    for index, sim in enumerate(sc.tolist(self.sims)):
        if sim.label is None:
            sim.label = f'Sim {index}'

    # Refuse to run twice
    if self.already_run:
        errormsg = 'Cannot re-run an already run MultiSim'
        raise RuntimeError(errormsg)

    self.sims = multi_run(self.sims, **kwargs)
    if compute_stats:
        self.compute_stats()  # Summarize across the freshly run sims
    self.already_run = True
    return self
[docs]
def compute_stats(self, return_raw=False, quantiles=None, use_mean=False, bounds=None):
    """
    Compute statistics across multiple sims and store them as the MultiSim's results.

    Each non-dict result is stacked into an array with one column per sim and
    reduced to ``best``/``low``/``high`` series: either the median and two
    quantiles (default), or the mean plus/minus ``bounds`` standard deviations.
    The reduced results are stored in ``self.results`` and also replace
    ``self.base_sim.results``.

    Args:
        return_raw (bool): if True, return the unreduced per-sim values instead of None
        quantiles (dict or 2-element sequence): the low and high quantiles (default ``{'low': 0.1, 'high': 0.9}``); only used if ``use_mean`` is False
        use_mean (bool): if True, use the mean and standard deviation instead of quantiles
        bounds (float): how many standard deviations to use for low/high (default 1); only used if ``use_mean`` is True

    Returns:
        An objdict of raw values keyed by result name if ``return_raw`` is True, otherwise None

    Raises:
        ValueError: if ``quantiles`` cannot be converted, or if the sims' time vectors do not share the same start and end values
    """
    # Handle the arguments for whichever reduction is requested
    if use_mean:
        if bounds is None:
            bounds = 1
    else:
        if quantiles is None:
            quantiles = {'low': 0.1, 'high': 0.9}
        if not isinstance(quantiles, dict):
            try:
                quantiles = {'low': float(quantiles[0]), 'high': float(quantiles[1])}
            except Exception as E:
                errormsg = f'Could not figure out how to convert {quantiles} into a quantiles object: must be a dict with keys low, high or a 2-element array ({str(E)})'
                raise ValueError(errormsg) from E  # Chain the cause so the original traceback is kept

    base_sim = sc.dcp(self.sims[0])
    raw = sc.objdict()
    results = sc.objdict()
    axis = 1  # Sims are stacked as columns, so reduce across axis 1

    # Check that all sims cover the same time span
    start_end = np.array([sim.tvec[[0, -1]] for sim in self.sims])
    if len(np.unique(start_end)) != 2:
        errormsg = f'Cannot compute stats for sims: start and end values do not match:\n{start_end}'
        raise ValueError(errormsg)

    reskeys = list(base_sim.results.keys())
    if self.sims[0].pars['track_as']:
        for bad_key in ['imr_numerator', 'imr_denominator', 'mmr_numerator', 'mmr_denominator', 'imr_age_by_group',
                        'mmr_age_by_group', 'as_stillbirths', 'stillbirth_ages']:
            reskeys.remove(bad_key)  # these keys are intermediate results so we don't really want to save them

    bad_keys = ['t', 'tfr_years', 'method_usage']
    for key in bad_keys:  # Don't compute high/low for these; copy them from the first sim
        results[key] = base_sim.results[key]
        reskeys.remove(key)

    for reskey in reskeys:
        if isinstance(base_sim.results[reskey], dict):
            if return_raw:
                # Fixed: raw[reskey] was indexed before being created, and every slot
                # received the base sim's value; now slot s holds sim s's own value
                raw[reskey] = [sim.results[reskey] for sim in self.sims]
        else:
            results[reskey] = sc.objdict()
            npts = len(base_sim.results[reskey])
            raw[reskey] = np.zeros((npts, len(self.sims)))
            for s, sim in enumerate(self.sims):
                raw[reskey][:, s] = sim.results[reskey]  # Stack into an array for processing

            if use_mean:
                r_mean = np.mean(raw[reskey], axis=axis)
                r_std = np.std(raw[reskey], axis=axis)
                results[reskey].best = r_mean
                results[reskey].low = r_mean - bounds * r_std
                results[reskey].high = r_mean + bounds * r_std
            else:
                results[reskey].best = np.quantile(raw[reskey], q=0.5, axis=axis)
                results[reskey].low = np.quantile(raw[reskey], q=quantiles['low'], axis=axis)
                results[reskey].high = np.quantile(raw[reskey], q=quantiles['high'], axis=axis)

    self.results = results
    self.base_sim.results = results  # Store here too, to enable plotting

    return raw if return_raw else None
[docs]
@staticmethod
def merge(*args, base=False):
    """
    Convenience method for merging several MultiSim objects into a new one.

    The new MultiSim records, in ``msim.chunks``, which indices of its ``sims``
    list came from each input, so that ``split()`` can undo the merge.

    Args:
        args (MultiSim): the MultiSims to merge (either a single list, or separate arguments)
        base (bool): if True, build the new list of sims from each MultiSim's base sim; otherwise, concatenate the MultiSims' lists of sims

    Returns:
        msim (MultiSim): a new MultiSim object

    **Examples**::

        mm1 = fp.MultiSim.merge(msim1, msim2, base=True)
        mm2 = fp.MultiSim.merge([m1, m2, m3, m4], base=False)
    """
    # Handle arguments
    if len(args) == 1 and isinstance(args[0], list):
        args = args[0]  # A single list of MultiSims has been provided

    # Create the multisim from the base sim of the first argument
    msim = MultiSim(base_sim=sc.dcp(args[0].base_sim), sims=[], label=args[0].label)
    msim.sims = []
    msim.chunks = []  # This is used to enable automatic splitting later

    # Handle different options for combining
    if base:  # Only keep the base sims
        for i, ms in enumerate(args):
            sim = sc.dcp(ms.base_sim)
            sim.label = ms.label
            msim.sims.append(sim)
            # Fixed: store a flat list of indices, as the other branch does;
            # split() uses each element of a chunk directly as an index into self.sims
            msim.chunks.append([i])
    else:  # Keep all the sims
        for ms in args:
            len_before = len(msim.sims)
            msim.sims += list(sc.dcp(ms.sims))
            len_after = len(msim.sims)
            msim.chunks.append(list(range(len_before, len_after)))

    return msim
[docs]
def split(self, inds=None, chunks=None):
    """
    Break one MultiSim up into several new ones.

    The sims to extract can be given explicitly via ``inds``, or as consecutive
    runs via ``chunks``. If neither is given and the MultiSim has a ``chunks``
    attribute (as set by ``merge()``), that attribute is used.

    Args:
        inds (list): a list of lists of indices, with each list turned into a MultiSim
        chunks (int or list): if an int, split the MultiSim into that many chunks; if a list, return chunks of that many sims

    Returns:
        A list of MultiSim objects, each holding deep copies of the selected sims

    Raises:
        ValueError: if neither ``inds`` nor ``chunks`` is given and the MultiSim has no stored chunks

    **Examples** (for a MultiSim ``msim`` holding six sims)::

        m1, m2 = msim.split(inds=[[0,2,4], [1,3,5]])
        mlist1 = msim.split(chunks=[2,4]) # Equivalent to inds=[[0,1], [2,3,4,5]]
        mlist2 = msim.split(chunks=2) # Equivalent to inds=[[0,1,2], [3,4,5]]
    """
    # Work out the index lists if they were not supplied directly
    if inds is None:
        if chunks is not None:  # Derive the indices from the chunk specification
            sim_inds = np.arange(len(self))  # Indices for the simulations
            if sc.isiterable(chunks):  # e.g. chunks = [2,4]
                boundaries = np.cumsum(chunks)[:-1]
                inds = np.split(sim_inds, boundaries)
            else:  # e.g. chunks = 3
                inds = np.split(sim_inds, chunks)  # This will fail if the length is wrong
        elif hasattr(self, 'chunks'):  # Created from a merged MultiSim
            inds = self.chunks
        else:  # Nothing to go on
            errormsg = 'If a MultiSim has not been created via merge(), you must supply either inds or chunks to split it'
            raise ValueError(errormsg)

    # Build one MultiSim per index list, from deep copies of the selected sims
    return [MultiSim(sims=sc.dcp([self.sims[i] for i in indlist])) for indlist in inds]
[docs]
def remerge(self, base=True, recompute=True, **kwargs):
    """
    Split the MultiSim, optionally recompute statistics on each piece, and merge the pieces again.

    Args:
        base (bool): passed to ``MultiSim.merge()``; if True, the pieces' base sims are merged
        recompute (bool): whether to run ``compute_stats()`` on each piece before merging
        kwargs (dict): passed to ``self.split()``

    Returns:
        A new MultiSim object
    """
    pieces = self.split(**kwargs)
    if recompute:
        for piece in pieces:
            piece.compute_stats()  # Refresh the statistics of each separate MultiSim
    return MultiSim.merge(*pieces, base=base)
[docs]
def to_df(self, yearly=False, mean=False):
    """
    Export sim results to a pandas dataframe, which is also stored as ``self.df``.

    Args:
        yearly (bool): if True, export the array results whose length matches ``results['tfr_years']``; otherwise, those whose length matches ``sim.npts``
        mean (bool): if True, return ``self.base_sim.to_df()`` instead of stacking the individual sims

    Returns:
        df (DataFrame): the exported results; when stacking sims, it also has ``sim`` (index) and ``sim_label`` columns
    """
    if mean:
        df = self.base_sim.to_df()
    else:
        stacked = sc.odict(defaultdict=list)
        for s, sim in enumerate(self.sims):
            for reskey in sim.results.keys():
                res = sim.results[reskey]
                if not sc.isarray(res):
                    continue
                per_timestep = len(res) == sim.npts and not yearly
                if per_timestep or (len(res) == len(sim.results['tfr_years']) and yearly):
                    stacked[reskey] += res.tolist()

            # Tag every row contributed by this sim with its index and label
            n_rows = len(sim.results['tfr_years']) if yearly else sim.npts
            stacked['sim'] += [s] * n_rows
            stacked['sim_label'] += [sim.label] * n_rows
        df = pd.DataFrame(stacked)

    self.df = df
    return df
[docs]
def plot(self, to_plot=None, plot_sims=True, do_save=None, filename='fp_multisim.png',
fig_args=None, axis_args=None, plot_args=None, style=None, colors=None, **kwargs):
"""
Plot the MultiSim
Three cases are handled, in this order: ``to_plot == 'method'`` draws one stacked
method-mix plot per unique sim label; otherwise, if ``plot_sims`` is true, every sim
is drawn via ``sim.plot()``; otherwise ``self.base_sim.plot()`` is called.
Args:
to_plot (str): what to plot; ``'method'``, ``'cpr'`` and values containing ``'as_'`` get extra handling here, and the value is otherwise passed on to ``sim.plot()``
plot_sims (bool): whether to plot individual sims (else, plot with uncertainty bands)
do_save (bool): passed to ``tidy_up()``; not forwarded when falling back to ``self.base_sim.plot()``
filename (str): passed to ``tidy_up()``; not forwarded when falling back to ``self.base_sim.plot()``
fig_args (dict): passed to ``pl.figure()``, merged over the default ``figsize=(16, 10)``
axis_args (dict): passed to ``pl.subplots_adjust()``, merged over defaults; used for the method plot only
plot_args (dict): merged over the per-sim alpha and color and passed to ``sim.plot()``
style (str): passed to ``fpo.with_style()``; used for the method plot only
colors (dict or list): colors for the method stack plot (a dict is indexed by method name); replaced when plotting individual sims
kwargs (dict): ``do_show`` (default True) is popped from here; the rest is passed to ``sim.plot()``
See ``sim.plot()`` for additional args.
"""
fig_args = sc.mergedicts(dict(figsize=(16, 10)), fig_args)
# Last key of fpd.age_specific_channel_bins; channels whose name contains it are filtered out in get_scale_ceil()
no_plot_age = list(fpd.age_specific_channel_bins.keys())[-1]
fig = pl.figure(**fig_args)
do_show = kwargs.pop('do_show', True)
# labels collects each distinct label once; labellist has one entry per sim ('' for repeated labels)
labels = sc.autolist()
labellist = sc.autolist() # TODO: shouldn't need this
for sim in self.sims: # Loop over and find unique labels
if sim.label not in labels:
labels += sim.label
labellist += sim.label
# NOTE(review): this assignment looks unused; label is rebound before it is next read -- confirm
label = sim.label
else:
labellist += ''
n_unique = len(np.unique(labels)) # How many unique sims there are
# Helper: largest value of a channel across the sims, rounded up to the nearest 10
def get_scale_ceil(channel):
is_dist = hasattr(self.sims[0].results['acpr'], 'best') # picking a random channel
if is_dist:
maximum_value = max([max(sim.results[channel].high) for sim in self.sims if no_plot_age not in channel])
else:
maximum_value = max([max(sim.results[channel]) for sim in self.sims if no_plot_age not in channel])
top = int(np.ceil(maximum_value / 10.0)) * 10 # rounding up to nearest 10
return top
# Case 1: stacked method-mix plot, one subplot per unique label
if to_plot == 'method':
axis_args_method = sc.mergedicts(dict(left=0.1, bottom=0.05, right=0.9, top=0.97, wspace=0.2, hspace=0.30),
axis_args)
with fpo.with_style(style):
pl.subplots_adjust(**axis_args_method)
for axis_index, label in enumerate(np.unique(labels)):
total_df = pd.DataFrame()
# Optional nrows/ncols layout overrides are read from fig_args (None if absent)
return_default = lambda name: fig_args[name] if name in fig_args else None
rows, cols = sc.getrowscols(n_unique, nrows=return_default('nrows'), ncols=return_default('ncols'))
ax = pl.subplot(rows, cols, axis_index + 1)
# Gather the method dataframe of every sim that carries this label
for sim in self.sims:
if sim.label == label:
total_df = pd.concat([total_df, sim.format_method_df(timeseries=True)], ignore_index=True)
method_names = total_df['Method'].unique()
# Getting the mean of each seed as a list of lists, could add conditional here if different method plots are added
percentage_by_method = []
for method in method_names:
method_df = total_df[(total_df['Method'] == method) & (total_df['Sim'] == label)]
seed_split = [method_df[method_df['Seed'] == seed]['Percentage'].values for seed in
method_df['Seed'].unique()]
percentage_by_method.append(
[np.mean([seed[i] for seed in seed_split]) for i in range(len(seed_split[0]))])
legend = axis_index + 1 == cols # True for last plot in first row
# NOTE(review): colors is rebound here, so if a dict was passed, later passes see the list built on the first pass -- confirm intended
colors = [colors[method] for method in method_names] if isinstance(colors, dict) else colors
ax.stackplot(total_df["Year"].unique(), percentage_by_method, labels=method_names, colors=colors)
ax.set_title(label.capitalize())
ax.legend().set_visible(legend)
ax.set_xlabel('Year')
ax.set_ylabel('Percentage')
if legend:
ax.legend(loc='lower left', bbox_to_anchor=(1, -0.05), frameon=True) if len(
labels) > 1 else ax.legend(loc='upper left', frameon=True)
# Upper y-limit: the largest sum over method_usage entries after the first, scaled by 100, across all sims, plus 1
pl.ylim(0, max(
max([sum(proportion[1:] * 100) for proportion in results['method_usage']]) for results in
[sim.results for sim in self.sims]) + 1)
return tidy_up(fig=fig, do_show=do_show, do_save=do_save, filename=filename)
# Case 2: draw every sim onto the shared figure
elif plot_sims:
# NOTE(review): any colors argument supplied by the caller is replaced here by one grid color per label
colors = sc.gridcolors(n_unique)
colors = {k: c for k, c in zip(labels, colors)}
for s, sim in enumerate(self.sims): # Note: produces duplicate legend entries
label = labellist[s]
color = colors[sim.label]
alpha = max(0.2, 1 / np.sqrt(n_unique))
sim_plot_args = sc.mergedicts(dict(alpha=alpha, c=color), plot_args)
kw = dict(new_fig=False, do_show=False, label=label, plot_args=sim_plot_args)
sim.plot(to_plot=to_plot, **kw, **kwargs)
if to_plot is not None:
# Scale axes
if to_plot == 'cpr':
fig = self.base_sim.conform_y_axes(figure=fig, top=get_scale_ceil('acpr'))
if 'as_' in to_plot:
# The channel type is the second underscore-separated piece of to_plot
channel_type = to_plot.split("_")[1]
is_tfr = "tfr" in to_plot
age_bins = list(fpd.age_specific_channel_bins)[:-1]
if is_tfr:
age_bins = fpd.age_bin_map
# NOTE(review): sim here is presumably the last sim left over from the loop above -- confirm
if hasattr(sim.results[f'cpr_{list(fpd.age_specific_channel_bins.keys())[0]}'],
'best'): # if compute_stats has been applied
top = max([max([max(group_result) for group_result in
[sim.results[f'{channel_type}_{age_group}'].high for age_group in age_bins]])
for sim in self.sims])
else:
top = max([max([max(group_result) for group_result in
[sim.results[f'{channel_type}_{age_group}'] for age_group in age_bins]]) for sim
in self.sims])
tidy_top = int(np.ceil(top / 10.0)) * 10 # rounds top of y axis up to the nearest ten
tidy_top = tidy_top + 20 if is_tfr or 'imr' in to_plot else tidy_top # some custom axis adjustments for neatness
tidy_top = tidy_top + 50 if 'mmr' in to_plot else tidy_top
self.base_sim.conform_y_axes(figure=fig, top=tidy_top)
return tidy_up(fig=fig, do_show=do_show, do_save=do_save, filename=filename)
# Case 3: plot only the base sim; do_save and filename are not forwarded here
else:
return self.base_sim.plot(to_plot=to_plot, do_show=do_show, fig_args=fig_args, plot_args=plot_args,
**kwargs)
def plot_age_first_birth(self, do_show=False, do_save=True, output_file='age_first_birth_multi.png'):
    """
    Plot age at first birth for every sim as side-by-side notched box plots.

    Entries of ``sim.people.first_birth_age`` that are None or not strictly
    positive are skipped, which matches the ``> 0`` filter used by
    ``Sim.plot_age_first_birth()``.

    Args:
        do_show (bool): whether to show the plot
        do_save (bool): whether to save the plot to ``output_file``
        output_file (str): the path the plot is saved to
    """
    ages = []
    sim_labels = []
    for sim in self.sims:
        # Fixed: "is not None" alone lets every entry of a numeric array through,
        # so non-positive placeholder values ended up in the box plot
        valid_ages = [age for age in sim.people.first_birth_age if age is not None and age > 0]
        ages.extend(valid_ages)
        sim_labels.extend([sim.label] * len(valid_ages))

    data = pd.DataFrame({"age": ages, "sim": sim_labels})
    pl.title("Age at first birth")
    sns.boxplot(data=data, y='age', x='sim', orient='v', notch=True)

    # Save before showing: once a blocking show() returns, the figure may already be gone
    if do_save:
        pl.savefig(output_file)
        print(f"Saved age at first birth plot at {output_file}")  # Report only after the file is written
    if do_show:
        pl.show()
def single_run(sim):
    """ Run one sim and hand the same object back; used as the worker function by multi_run() """
    sim.run()
    return sim
def multi_run(sims, **kwargs):
    """
    Run multiple sims in parallel; usually used via the MultiSim class, not directly.

    Args:
        sims (list): the sims to run, each handed to ``single_run()``
        kwargs (dict): passed to ``sc.parallelize()``

    Returns:
        The output of ``sc.parallelize()`` for the supplied sims
    """
    return sc.parallelize(single_run, iterarg=sims, **kwargs)
[docs]
def parallel(*args, **kwargs):
    """
    Build a MultiSim from the supplied sims and run it in one step; a shortcut
    for ``fp.MultiSim(sims).run()``.

    Args:
        args (list): the simulations to run (merged into one list with ``sc.mergelists()``)
        kwargs (dict): passed to ``MultiSim.run()``

    Returns:
        A run MultiSim object.

    **Examples**::

        s1 = fp.Sim(exposure_factor=0.5, label='Low')
        s2 = fp.Sim(exposure_factor=2.0, label='High')
        fp.parallel(s1, s2).plot()
        msim = fp.parallel(s1, s2)
    """
    msim = MultiSim(sims=sc.mergelists(*args))
    return msim.run(**kwargs)