'''
Define the calibration class
'''
import os
import numpy as np
import pylab as pl
import pandas as pd
import sciris as sc
from . import misc as hpm
from . import plotting as hppl
from . import analysis as hpa
from . import parameters as hppar
from .settings import options as hpo # For setting global options
__all__ = ['Calibration']
def import_optuna():
    '''
    A helper function to import Optuna, which is an optional dependency.

    The import happens inside the function (rather than at module level) since it is slow.

    Returns:
        The imported ``optuna`` module

    Raises:
        ModuleNotFoundError: if Optuna is not installed; the original import error is chained as the cause
    '''
    try:
        import optuna as op # Import here since it's slow
    except ModuleNotFoundError as E: # pragma: no cover
        errormsg = f'Optuna import failed ({str(E)}), please install first (pip install optuna)'
        raise ModuleNotFoundError(errormsg) from E
    return op
[docs]
class Calibration(sc.prettyobj):
'''
A class to handle calibration of HPVsim simulations. Uses the Optuna hyperparameter
optimization library (optuna.org), which must be installed separately (via
pip install optuna).
Note: running a calibration does not guarantee a good fit! You must ensure that
you run for a sufficient number of iterations, have enough free parameters, and
that the parameters have wide enough bounds. Please see the tutorial on calibration
for more information.
Args:
sim (Sim) : the simulation to calibrate
datafiles (list) : list of datafile strings to calibrate to
calib_pars (dict) : a dictionary of the parameters to calibrate of the format dict(key1=[best, low, high])
genotype_pars(dict) : a dictionary of the genotype-specific parameters to calibrate of the format dict(genotype=dict(key1=[best, low, high]))
hiv_pars (dict) : a dictionary of the hiv-specific parameters to calibrate of the format dict(key1=[best, low, high])
extra_sim_result_keys (list) : list of result strings to store
fit_args (dict) : a dictionary of options that are passed to sim.compute_fit() to calculate the goodness-of-fit
par_samplers (dict) : an optional mapping from parameters to the Optuna sampler to use for choosing new points for each; by default, suggest_float
n_trials (int) : the number of trials per worker
n_workers (int) : the number of parallel workers (default: maximum, i.e. the number of CPUs)
total_trials (int) : if supplied, n_trials is calculated by dividing this number by n_workers (overriding any n_trials given)
name (str) : the name of the database (default: 'hpvsim_calibration')
db_name (str) : the name of the database file (default: 'hpvsim_calibration.db')
keep_db (bool) : whether to keep the database after calibration (default: false)
storage (str) : the location of the database (default: sqlite)
rand_seed (int) : if provided, use this random seed to initialize Optuna runs (for reproducibility)
label (str) : a label for this calibration object
die (bool) : whether to stop if an exception is encountered (default: false)
verbose (bool) : whether to print details of the calibration
kwargs (dict) : passed to hpv.Calibration()
Returns:
A Calibration object
**Example**::
sim = hpv.Sim(pars, genotypes=[16, 18])
calib_pars = dict(beta=[0.05, 0.010, 0.20],hpv_control_prob=[.9, 0.5, 1])
calib = hpv.Calibration(sim, calib_pars=calib_pars,
datafiles=['test_data/south_africa_hpv_data.xlsx',
'test_data/south_africa_cancer_data.xlsx'],
total_trials=10, n_workers=4)
calib.calibrate()
calib.plot()
'''
def __init__(self, sim, datafiles, calib_pars=None, genotype_pars=None, hiv_pars=None, fit_args=None, extra_sim_result_keys=None,
             par_samplers=None, n_trials=None, n_workers=None, total_trials=None, name=None, db_name=None, estimator=None,
             keep_db=None, storage=None, rand_seed=None, sampler=None, label=None, die=False, verbose=True):
    '''
    Store the calibration settings, load the target data, and prepare the sim.

    The supplied sim is modified in place: an age_results analyzer is appended to
    its analyzers, 'model_hiv' is switched on if hiv_pars are supplied, and the
    sim is initialized.

    Args:
        sim       (Sim)      : the simulation to calibrate
        datafiles (str/list) : a single datafile or a list of datafiles to calibrate to; each may contain only one target name
        estimator            : not used by the constructor
        other arguments      : see the class docstring

    Raises:
        ValueError: if a datafile contains more than one target name
    '''
    import multiprocessing as mp # Import here since it's also slow

    # Handle run arguments
    if n_trials is None: n_trials = 20
    if n_workers is None: n_workers = mp.cpu_count()
    if name is None: name = 'hpvsim_calibration'
    if db_name is None: db_name = f'{name}.db'
    if keep_db is None: keep_db = False
    if storage is None: storage = f'sqlite:///{db_name}'
    if total_trials is not None: n_trials = int(np.ceil(total_trials/n_workers)) # Takes precedence over n_trials
    self.run_args = sc.objdict(n_trials=int(n_trials), n_workers=int(n_workers), name=name, db_name=db_name,
                               keep_db=keep_db, storage=storage, rand_seed=rand_seed, sampler=sampler)

    # Handle other inputs
    self.label = label
    self.sim = sim
    self.calib_pars = calib_pars
    self.genotype_pars = genotype_pars
    self.hiv_pars = hiv_pars
    self.extra_sim_result_keys = extra_sim_result_keys
    self.fit_args = sc.mergedicts(fit_args)
    self.par_samplers = sc.mergedicts(par_samplers)
    self.die = die
    self.verbose = verbose
    self.calibrated = False

    # Load the target data; a bare string is wrapped so it is not iterated character by character
    if isinstance(datafiles, str):
        datafiles = [datafiles]
    self.target_data = [hpm.load_data(datafile) for datafile in datafiles]

    # Go through each of the target keys and determine how we are going to get the results from sim
    sim_results = sc.objdict()
    age_result_args = sc.objdict()
    for targ in self.target_data:
        targ_keys = targ.name.unique()
        if len(targ_keys) > 1:
            errormsg = f'Only support one set of targets per datafile, {len(targ_keys)} provided'
            raise ValueError(errormsg)
        if 'age' in targ.columns: # Age-specific targets are handled by the age_results analyzer
            age_result_args[targ_keys[0]] = sc.objdict(
                datafile=sc.dcp(targ),
                compute_fit=True,
            )
        else: # Everything else is compared directly against the sim results
            sim_results[targ_keys[0]] = sc.objdict(
                data=sc.dcp(targ)
            )

    # Create age_results intervention
    ar = hpa.age_results(result_args=age_result_args)
    self.sim['analyzers'] += [ar]
    if hiv_pars is not None:
        self.sim['model_hiv'] = True # if calibrating HIV parameters, make sure model is running HIV
    self.sim.initialize()

    for rkey in sim_results.keys():
        data = sim_results[rkey].data
        sim_results[rkey].timepoints = sim.get_t(data.year.unique()[0], return_date_format='str')[0]//sim.resfreq
        # Always define the weights, since they are needed when the mismatch is computed for each trial
        if 'weights' in data.columns:
            sim_results[rkey].weights = np.asarray(data['weights'], dtype=float)
        else:
            sim_results[rkey].weights = np.ones(len(data))

    self.age_results_keys = age_result_args.keys()
    self.sim_results = sim_results
    self.sim_results_keys = sim_results.keys()

    # Store the name and color of each result for plotting
    self.result_args = sc.objdict()
    for rkey in self.age_results_keys + self.sim_results_keys:
        self.result_args[rkey] = sc.objdict()
        if 'hiv' in rkey: # HIV results are stored on the HIV sim rather than the main sim
            self.result_args[rkey].name = self.sim.hivsim.results[rkey].name
            self.result_args[rkey].color = self.sim.hivsim.results[rkey].color
        else:
            self.result_args[rkey].name = self.sim.results[rkey].name
            self.result_args[rkey].color = self.sim.results[rkey].color

    if self.extra_sim_result_keys:
        for rkey in self.extra_sim_result_keys:
            self.result_args[rkey] = sc.objdict()
            self.result_args[rkey].name = self.sim.results[rkey].name
            self.result_args[rkey].color = self.sim.results[rkey].color

    # Temporarily store a filename
    self.tmp_filename = 'tmp_calibration_%05i.obj'

    return
[docs]
def run_sim(self, calib_pars=None, genotype_pars=None, hiv_pars=None, label=None, return_sim=False):
    '''
    Create and run a simulation

    A deep copy of the stored sim is updated with the supplied parameters,
    reinitialized, and run.

    Args:
        calib_pars    (dict): regular sim parameters to apply
        genotype_pars (dict): genotype-specific parameters to apply
        hiv_pars      (dict): hiv-specific parameters to apply
        label         (str) : if provided, used as the label of the new sim
        return_sim    (bool): whether to return the sim itself rather than its fit

    Returns:
        The sim (if return_sim) or its fit. If the run fails and self.die is
        False, a warning is issued and None (if return_sim) or np.inf is returned.
    '''
    trial_sim = sc.dcp(self.sim)
    if label:
        trial_sim.label = label
    new_pars = self.get_full_pars(sim=trial_sim, calib_pars=calib_pars, genotype_pars=genotype_pars, hiv_pars=hiv_pars)
    trial_sim.update_pars(new_pars)
    trial_sim.initialize(reset=True, init_analyzers=False) # Necessary to reinitialize the sim here so that the initial infections get the right parameters

    # Run the sim
    try:
        trial_sim.run()
        return trial_sim if return_sim else trial_sim.fit
    except Exception as E:
        if self.die:
            raise E
        warnmsg = f'Encountered error running sim!\nParameters:\n{new_pars}\nTraceback:\n{sc.traceback()}'
        hpm.warn(warnmsg)
        return None if return_sim else np.inf
[docs]
@staticmethod
def update_dict_pars(name_pars, value_pars):
    '''
    Function to update parameters from nested dict to nested dict's value

    Args:
        name_pars  (dict): the nested dict to update; it is deep-copied, not modified
        value_pars (dict): nested dict whose leaf values are written into the copy

    Returns:
        The updated copy of name_pars

    Raises:
        ValueError: if a value cannot be set
    '''
    updated = sc.dcp(name_pars)
    for path, leaf in sc.flattendict(value_pars).items():
        try:
            sc.setnested(updated, list(path), leaf)
        except Exception:
            errormsg = f"Parameter {'_'.join(path)} is not part of the sim, nor is a custom function specified to use them"
            raise ValueError(errormsg)
    return updated
[docs]
def update_dict_pars_from_trial(self, name_pars, value_pars):
    '''
    Function to update parameters from nested dict to trial parameter's value

    Args:
        name_pars  (dict): nested dict giving the structure of the parameters
        value_pars (dict): flat dict of trial values, keyed by the nested keys joined with '_'

    Returns:
        A new nested dict with the same structure as name_pars, holding the trial values
    '''
    rebuilt = {}
    for path in sc.flattendict(name_pars).keys():
        flat_name = '_'.join(path)
        sc.setnested(rebuilt, list(path), value_pars[flat_name])
    return rebuilt
[docs]
def update_dict_pars_init_and_bounds(self, initial_pars, par_bounds, target_pars):
    '''
    Function to update initial parameters and parameter bounds from a trial pars dict

    Args:
        initial_pars (dict): updated in place with the first entry of each leaf of target_pars
        par_bounds   (dict): updated in place with an array of the second and third entries of each leaf
        target_pars  (dict): nested dict of parameters; entries are keyed by the nested keys joined with '_'

    Returns:
        The (initial_pars, par_bounds) tuple
    '''
    for path, spec in sc.flattendict(target_pars).items():
        flat_name = '_'.join(path)
        initial_pars[flat_name] = spec[0]
        par_bounds[flat_name] = np.array([spec[1], spec[2]])
    return initial_pars, par_bounds
[docs]
def get_full_pars(self, sim=None, calib_pars=None, genotype_pars=None, hiv_pars=None):
    '''
    Make a full pardict from the subset of regular sim parameters, genotype parameters, and hiv parameters used in calibration

    Args:
        sim           (Sim) : the sim whose current parameters are used as the starting point; if None, self.sim is used
        calib_pars    (dict): nested dict of regular sim parameter values
        genotype_pars (dict): nested dict of genotype-specific parameter values
        hiv_pars      (dict): nested dict of hiv-specific parameter values

    Returns:
        A dict containing only the top-level parameters that were updated

    Raises:
        ValueError: if a parameter value cannot be set
    '''
    if sim is None: # Fall back to the stored sim rather than failing on None
        sim = self.sim

    # Prepare the parameters
    new_pars = {}

    if genotype_pars is not None:
        new_pars['genotype_pars'] = self.update_dict_pars(sim['genotype_pars'], genotype_pars)

    if hiv_pars is not None:
        new_pars['hiv_pars'] = self.update_dict_pars(sim.hivsim.pars['hiv_pars'], hiv_pars)

    if calib_pars is not None:
        for key, val in sc.flattendict(calib_pars).items():
            top_key = key[0]
            # Start from a copy of the sim's current value so that nested updates keep the other entries
            if top_key in sim.pars and top_key not in new_pars:
                new_pars[top_key] = sc.dcp(sim.pars[top_key])
            try:
                sc.setnested(new_pars, list(key), val)
            except Exception as e:
                errormsg = f"Parameter {'_'.join(key)} is not part of the sim, nor is a custom function specified to use them"
                raise ValueError(errormsg) from e

    return new_pars
[docs]
def trial_pars_to_sim_pars(self, trial_pars=None, which_pars=None, return_full=True):
    '''
    Create genotype_pars and pars dicts from the trial parameters.

    Note: not used during self.calibrate.

    Args:
        trial_pars (dict): dictionary of parameters from a single trial. If not provided, best parameters will be used
        which_pars (int): if trial_pars is not provided, the position of the row of self.df (sorted by mismatch) to use; None or 0 uses the best parameters
        return_full (bool): whether to return a unified par dict ready for use in a sim, or the sim pars and genotype pars separately

    Returns:
        A single parameter dict if return_full, otherwise the tuple (calib_pars, genotype_pars, hiv_pars)

    Raises:
        ValueError: if trial_pars is not provided and cannot be retrieved from the calibration results

    **Example**::

        sim = hpv.Sim(genotypes=[16, 18])
        calib_pars = dict(beta=[0.05, 0.010, 0.20],hpv_control_prob=[.9, 0.5, 1])
        genotype_pars = dict(hpv16=dict(prog_time=[3, 3, 10]))
        calib = hpv.Calibration(sim, calib_pars=calib_pars, genotype_pars=genotype_pars,
                                datafiles=['test_data/south_africa_hpv_data.xlsx',
                                           'test_data/south_africa_cancer_data.xlsx'],
                                total_trials=10, n_workers=4)
        calib.calibrate()
        new_pars = calib.trial_pars_to_sim_pars() # Returns best parameters from calibration in a format ready for sim running
        sim.update_pars(new_pars)
        sim.run()
    '''
    # Initialize
    calib_pars = sc.objdict()
    genotype_pars = sc.objdict()
    hiv_pars = sc.objdict()

    # Deal with trial parameters
    if trial_pars is None:
        try:
            if which_pars is None or which_pars==0:
                trial_pars = self.best_pars
            else:
                ddict = self.df.to_dict(orient='records')[which_pars]
                trial_pars = {k:v for k,v in ddict.items() if k not in ['index','mismatch']}
        except Exception as E: # Keep the underlying cause, and let KeyboardInterrupt/SystemExit through
            errormsg = 'No trial parameters provided.'
            raise ValueError(errormsg) from E

    # Handle genotype parameters
    if self.genotype_pars is not None:
        genotype_pars = self.update_dict_pars_from_trial(self.genotype_pars, trial_pars)

    # Handle hiv sim parameters
    if self.hiv_pars is not None:
        hiv_pars = self.update_dict_pars_from_trial(self.hiv_pars, trial_pars)

    # Handle regular sim parameters
    if self.calib_pars is not None:
        calib_pars = self.update_dict_pars_from_trial(self.calib_pars, trial_pars)

    # Return
    if return_full:
        all_pars = self.get_full_pars(sim=self.sim, calib_pars=calib_pars, genotype_pars=genotype_pars, hiv_pars=hiv_pars)
        return all_pars
    else:
        return calib_pars, genotype_pars, hiv_pars
[docs]
def sim_to_sample_pars(self):
    '''
    Convert sim pars to sample pars

    Returns:
        The tuple (initial_pars, par_bounds), filled in from the regular sim
        pars, then the genotype pars, then the hiv pars (skipping any that are None)
    '''
    initial_pars = sc.objdict()
    par_bounds = sc.objdict()
    for pardict in (self.calib_pars, self.genotype_pars, self.hiv_pars):
        if pardict is None:
            continue
        initial_pars, par_bounds = self.update_dict_pars_init_and_bounds(initial_pars, par_bounds, pardict)
    return initial_pars, par_bounds
[docs]
def trial_to_sim_pars(self, pardict=None, trial=None):
    '''
    Take in an optuna trial and sample from pars, after extracting them from the structure they're provided in

    Args:
        pardict (dict): nested dict whose leaves are [best, low, high] or [best, low, high, step]
        trial         : the Optuna trial used to sample each value

    Returns:
        A deep copy of pardict in which each leaf is replaced by the sampled value

    Raises:
        AttributeError: if a sampler named in self.par_samplers is not an attribute of the trial
    '''
    pars = sc.dcp(pardict)
    for key, val in sc.flattendict(pardict).items():
        sampler_key = '_'.join(key) # The name under which Optuna records this parameter
        low, high = val[1], val[2]
        step = val[3] if len(val) > 3 else None

        # A custom sampler may be registered under the tuple of nested keys or under the flattened name
        if key in self.par_samplers:
            sampler_name = self.par_samplers[key]
            use_custom = True
        elif sampler_key in self.par_samplers:
            sampler_name = self.par_samplers[sampler_key]
            use_custom = True
        else:
            use_custom = False

        if use_custom:
            try:
                sampler_fn = getattr(trial, sampler_name)
            except Exception as E:
                errormsg = 'The requested sampler function is not found: ensure it is a valid attribute of an Optuna Trial object'
                raise AttributeError(errormsg) from E
        else:
            sampler_fn = trial.suggest_float

        sc.setnested(pars, list(key), sampler_fn(sampler_key, low, high, step=step))
    return pars
[docs]
def run_trial(self, trial, save=True):
    '''
    Define the objective for Optuna

    Samples the parameters for this trial, runs the sim, adds the mismatch of
    each non-age sim result to the sim's fit, and optionally saves the trial's
    outputs to a temporary file.

    Args:
        trial      : the Optuna trial used to sample parameter values
        save (bool): whether to save the results to a temporary file named from self.tmp_filename and the trial number

    Returns:
        The total mismatch for this trial, or np.inf if the sim could not be run
    '''
    # Sample in a fixed order: genotype parameters, then HIV parameters, then regular sim parameters
    genotype_pars = None
    if self.genotype_pars is not None:
        genotype_pars = self.trial_to_sim_pars(self.genotype_pars, trial)
    hiv_pars = None
    if self.hiv_pars is not None:
        hiv_pars = self.trial_to_sim_pars(self.hiv_pars, trial)
    calib_pars = None
    if self.calib_pars is not None:
        calib_pars = self.trial_to_sim_pars(self.calib_pars, trial)

    sim = self.run_sim(calib_pars, genotype_pars, hiv_pars, return_sim=True)

    # run_sim returns None when the sim fails and self.die is False: report the worst fit rather than dereferencing None
    if sim is None:
        return np.inf

    # Compute fit for sim results and save sim results (TODO: THIS IS FOR A SINGLE TIMEPOINT. GENERALIZE THIS)
    sim_results = sc.objdict()
    for rkey in self.sim_results:
        target = self.sim_results[rkey]
        if sim.results[rkey][:].ndim==1:
            model_output = sim.results[rkey][target.timepoints[0]]
        else:
            model_output = sim.results[rkey][:,target.timepoints[0]]
        gofs = hpm.compute_gof(target.data.value, model_output)
        losses = gofs * target.weights
        mismatch = losses.sum()
        sim.fit += mismatch
        sim_results[rkey] = model_output

    extra_sim_results = sc.objdict()
    if self.extra_sim_result_keys:
        for rkey in self.extra_sim_result_keys:
            extra_sim_results[rkey] = sim.results[rkey]

    # Store results in temporary files (TODO: consider alternatives)
    if save:
        results = dict(sim=sim_results, analyzer=sim.get_analyzer('age_results').results,
                       extra_sim_results=extra_sim_results)
        filename = self.tmp_filename % trial.number
        sc.save(filename, results)

    return sim.fit
[docs]
def worker(self):
    '''
    Run a single worker

    Loads the existing study from storage and optimizes self.run_trial for
    self.run_args.n_trials trials.

    Returns:
        The output of study.optimize()
    '''
    op = import_optuna()
    log_level = op.logging.DEBUG if self.verbose else op.logging.ERROR
    op.logging.set_verbosity(log_level)
    study = op.load_study(storage=self.run_args.storage, study_name=self.run_args.name, sampler = self.run_args.sampler)
    return study.optimize(self.run_trial, n_trials=self.run_args.n_trials, callbacks=None)
[docs]
def run_workers(self):
    '''
    Run multiple workers in parallel

    Returns:
        A list with the output of each worker (a one-element list if only one worker is used)
    '''
    if self.run_args.n_workers <= 1: # Special case: just run one
        return [self.worker()]
    return sc.parallelize(self.worker, iterarg=self.run_args.n_workers) # Normal use case: run in parallel
[docs]
def remove_db(self):
    '''
    Delete the Optuna study and remove the database file if the path exists.

    Note: keep_db is not checked here; callers decide whether to call this method.
    Failure to delete the study is reported but does not raise an error.
    '''
    try:
        op = import_optuna()
        op.delete_study(study_name=self.run_args.name, storage=self.run_args.storage)
        if self.verbose:
            print(f'Deleted study {self.run_args.name} in {self.run_args.storage}')
    except Exception as E:
        print('Could not delete study, skipping...')
        print(str(E))

    db_path = self.run_args.db_name
    if os.path.exists(db_path):
        os.remove(db_path)
        if self.verbose:
            print(f'Removed existing calibration {db_path}')
    return
[docs]
def make_study(self):
    '''
    Make a study, deleting one if it already exists

    Returns:
        The newly created Optuna study

    Raises:
        NotImplementedError: if a random seed was supplied, since seeded runs do not currently work
    '''
    op = import_optuna()
    if not self.run_args.keep_db:
        self.remove_db()

    sampler = None
    seed = self.run_args.rand_seed
    if seed is not None:
        sampler = op.samplers.RandomSampler(seed)
        sampler.reseed_rng()
        raise NotImplementedError('Implemented but does not work')

    return op.create_study(storage=self.run_args.storage, study_name=self.run_args.name, sampler=sampler)
[docs]
def calibrate(self, calib_pars=None, genotype_pars=None, hiv_pars=None, verbose=True, load=True, tidyup=True, **kwargs):
    '''
    Actually perform calibration.

    Args:
        calib_pars (dict): if supplied, overwrite stored calib_pars
        genotype_pars (dict): if supplied, overwrite stored genotype_pars
        hiv_pars (dict): if supplied, overwrite stored hiv_pars
        verbose (bool): whether to print output from each trial (not currently used by this method)
        load (bool): whether to load the results that each trial saved to a temporary file
        tidyup (bool): whether to remove each temporary file after loading it
        kwargs (dict): if supplied, overwrite stored run_args (n_trials, n_workers, etc.)

    Returns:
        The calibration object itself

    Raises:
        ValueError: if no calibration parameters of any kind are available
    '''
    op = import_optuna()

    # Load and validate calibration parameters
    overrides = (('calib_pars', calib_pars), ('genotype_pars', genotype_pars), ('hiv_pars', hiv_pars))
    for attr, override in overrides:
        if override is not None:
            setattr(self, attr, override)
    if all(stored is None for stored in (self.calib_pars, self.genotype_pars, self.hiv_pars)):
        errormsg = 'You must supply calibration parameters (calib_pars or genotype_pars) either when creating the calibration object or when calling calibrate().'
        raise ValueError(errormsg)
    self.run_args.update(kwargs) # Update optuna settings

    # Run the optimization
    t0 = sc.tic()
    self.make_study()
    self.run_workers()
    study = op.load_study(storage=self.run_args.storage, study_name=self.run_args.name, sampler = self.run_args.sampler)
    self.best_pars = sc.objdict(study.best_params)
    self.elapsed = sc.toc(t0, output=True)

    # Collect analyzer results
    # Load a single sim
    sim = self.sim # TODO: make sure this is OK #sc.jsonpickle(self.study.trials[0].user_attrs['jsonpickle_sim'])
    self.ng = sim['n_genotypes']
    self.glabels = [g.upper() for g in sim['genotype_map'].values()]

    # Replace with something else, this is fragile
    # NOTE: from here on self.sim_results holds the per-trial outputs rather than the target-data structure set in __init__
    self.analyzer_results = []
    self.sim_results = []
    self.extra_sim_results = []
    if load:
        print('Loading saved results...')
        for trial in study.trials:
            n = trial.number
            try:
                filename = self.tmp_filename % trial.number
                saved = sc.load(filename)
                self.sim_results.append(saved['sim'])
                self.analyzer_results.append(saved['analyzer'])
                self.extra_sim_results.append(saved['extra_sim_results'])
                if tidyup:
                    try:
                        os.remove(filename)
                        print(f' Removed temporary file {filename}')
                    except Exception as E:
                        errormsg = f'Could not remove {filename}: {str(E)}'
                        print(errormsg)
                print(f' Loaded trial {n}')
            except Exception as E:
                errormsg = f'Warning, could not load trial {n}: {str(E)}'
                print(errormsg)

    # Compare the results
    self.initial_pars, self.par_bounds = self.sim_to_sample_pars()
    self.parse_study(study)

    # Tidy up
    self.calibrated = True
    if not self.run_args.keep_db:
        self.remove_db()

    return self
[docs]
def parse_study(self, study):
    '''
    Parse the study into a data frame -- called automatically

    Sets self.best_pars (the study's best parameters), self.data (one list per
    column), and self.df (a data frame sorted by mismatch). Trials without a
    value are counted as failed and left out.

    Args:
        study: the Optuna study to parse
    '''
    best = study.best_params
    self.best_pars = best

    print('Making results structure...')
    results = []
    failed_trials = []
    n_trials = len(study.trials)
    for trial in study.trials:
        record = {'index':trial.number, 'mismatch': trial.value}
        record.update(trial.params)
        if record['mismatch'] is None:
            failed_trials.append(record['index'])
        else:
            results.append(record)
    print(f'Processed {n_trials} trials; {len(failed_trials)} failed')

    keys = ['index', 'mismatch'] + list(best.keys())
    data = sc.objdict().make(keys=keys, vals=[])
    for i, record in enumerate(results):
        for key in keys:
            if key not in record: # Fill parameters missing from this trial with the best value
                warnmsg = f'Key {key} is missing from trial {i}, replacing with default'
                hpm.warn(warnmsg)
                record[key] = best[key]
            data[key].append(record[key])
    self.data = data
    self.df = pd.DataFrame.from_dict(data).sort_values(by=['mismatch']) # Sort

    return
[docs]
def to_json(self, filename=None, indent=2, **kwargs):
    '''
    Convert the data to JSON.

    Each row of self.df becomes a dict with 'index', 'mismatch', and 'pars'
    (all remaining columns), ordered by mismatch. The list is also stored as self.json.

    Args:
        filename (str): if provided, save the JSON to this file
        indent (int): indentation to use when saving
        kwargs (dict): passed to sc.savejson()

    Returns:
        The output of sc.savejson() if a filename is provided, otherwise the list of dicts
    '''
    ranking = np.argsort(self.df['mismatch'])
    entries = []
    for pos in ranking:
        record = self.df.iloc[pos,:].to_dict()
        trial_index = record.pop('index')
        mismatch = record.pop('mismatch')
        entries.append(dict(index=trial_index, mismatch=mismatch, pars=dict(record))) # Whatever is left are the parameters
    self.json = entries
    if not filename:
        return entries
    return sc.savejson(filename, entries, indent=indent, **kwargs)
[docs]
def plot(self, res_to_plot=None, fig_args=None, axis_args=None, data_args=None, show_args=None, do_save=None,
         fig_path=None, do_show=True, plot_type='sns.boxplot', **kwargs):
    '''
    Plot the calibration results

    One panel is drawn for each date of each age-specific result, followed by one
    panel for each of the other sim results. Target data are drawn as square
    markers; the model outputs across trials are drawn with ``plot_type``.

    Args:
        res_to_plot (int): number of results to plot. if None, plot them all
        fig_args (dict): passed to pl.subplots()
        axis_args (dict): passed to pl.subplots_adjust()
        data_args (dict): 'width', 'color', and 'offset' arguments for the data
        show_args (dict): merged with the other arguments and passed on when tidying up the figure
        do_save (bool): whether to save
        fig_path (str or filepath): filepath to save to
        do_show (bool): whether to show the figure
        plot_type (str or callable): either a string of the form 'sns.<name>' naming a Seaborn plotting function, or a plotting function to call directly
        kwargs (dict): passed to ``hpv.options.with_style()``; see that function for choices

    Raises:
        ValueError: if no results were recorded, or if no target data can be found for a result
    '''

    # Import Seaborn here since slow
    if sc.isstring(plot_type) and plot_type.startswith('sns'):
        import seaborn as sns
        sns_name = plot_type.split('.')[1]
        if sns_name == 'boxplot':
            extra_args = dict(boxprops=dict(alpha=.3), showfliers=False)
        else:
            extra_args = dict()
        plot_func = getattr(sns, sns_name)
    else:
        plot_func = plot_type
        extra_args = dict()

    # Handle inputs
    fig_args = sc.mergedicts(dict(figsize=(12,8)), fig_args)
    axis_args = sc.mergedicts(dict(left=0.08, right=0.92, bottom=0.08, top=0.92), axis_args)
    d_args = sc.objdict(sc.mergedicts(dict(width=0.3, color='#000000', offset=0), data_args))
    show_args = sc.objdict(sc.mergedicts(dict(show=dict(tight=True, maximize=False)), show_args))
    all_args = sc.objdict(sc.mergedicts(fig_args, axis_args, d_args, show_args))

    def find_target(resname, by_age):
        ''' Return the target data rows for this result, looked up by name rather than by position in the datafile list '''
        for targ in self.target_data:
            if ('age' in targ.columns) == by_age and (targ.name == resname).any():
                return targ[targ.name == resname]
        errormsg = f'No target data found for result "{resname}"'
        raise ValueError(errormsg)

    # Pull out results to use
    analyzer_results = sc.dcp(self.analyzer_results)
    sim_results = sc.dcp(self.sim_results)

    # Get rows and columns
    if not len(analyzer_results) and not len(sim_results):
        errormsg = 'Cannot plot since no results were recorded)'
        raise ValueError(errormsg)
    all_dates = [[date for date in r.keys() if date != 'bins'] for r in analyzer_results[0].values()]
    dates_per_result = [len(date_list) for date_list in all_dates]
    other_results = len(sim_results[0].keys())
    n_plots = sum(dates_per_result) + other_results
    n_rows, n_cols = sc.get_rows_cols(n_plots)

    # Initialize
    fig, axes = pl.subplots(n_rows, n_cols, **fig_args)
    if n_plots>1: # With a single plot, axes is a single axis object rather than an array
        for ax in axes.flat[n_plots:]:
            ax.set_visible(False)
        axes = axes.flatten()
    pl.subplots_adjust(**axis_args)

    # Pull out attributes that don't vary by run
    age_labels = sc.objdict()
    for resname,resdict in zip(self.age_results_keys, analyzer_results[0].values()):
        age_labels[resname] = [str(int(resdict['bins'][i])) + '-' + str(int(resdict['bins'][i + 1])) for i in range(len(resdict['bins']) - 1)]
        age_labels[resname].append(str(int(resdict['bins'][-1])) + '+')

    # determine how many results to plot
    if res_to_plot is not None:
        index_to_plot = self.df.iloc[0:res_to_plot, 0].values
        analyzer_results = [analyzer_results[i] for i in index_to_plot]
        sim_results = [sim_results[i] for i in index_to_plot]

    # Make the figure
    with hpo.with_style(**kwargs):

        plot_count = 0

        # Age-specific results: one panel per result and date
        for rn, resname in enumerate(self.age_results_keys):
            x = np.arange(len(age_labels[resname])) # the label locations
            agedf = find_target(resname, by_age=True)

            for date in all_dates[rn]:

                # Initialize axis and data storage structures
                ax = axes[plot_count] if n_plots>1 else axes
                bins = []
                genotypes = []
                values = []

                # Pull out data
                thisdatadf = agedf[agedf.year == float(date)]

                # Start making plot
                if 'genotype' in resname:
                    unique_genotypes = thisdatadf.genotype.unique() # Only read the genotype column where it is needed
                    for g in range(self.ng):
                        glabel = self.glabels[g].upper()

                        # Plot data
                        if glabel in unique_genotypes:
                            ydata = np.array(thisdatadf[thisdatadf.genotype == glabel].value)
                            ax.scatter(x, ydata, color=self.result_args[resname].color[g], marker='s', label=f'Data - {glabel}')

                        # Construct a dataframe with things in the most logical order for plotting
                        for run_num, run in enumerate(analyzer_results):
                            genotypes += [glabel]*len(x)
                            bins += x.tolist()
                            values += list(run[resname][date][g])

                    # Plot model
                    modeldf = pd.DataFrame({'bins':bins, 'values':values, 'genotypes':genotypes})
                    ax = plot_func(ax=ax, x='bins', y='values', hue="genotypes", data=modeldf, **extra_args)

                else:
                    # Plot data
                    ydata = np.array(thisdatadf.value)
                    ax.scatter(x, ydata, color=self.result_args[resname].color, marker='s', label='Data')

                    # Construct a dataframe with things in the most logical order for plotting
                    for run_num, run in enumerate(analyzer_results):
                        bins += x.tolist()
                        values += list(run[resname][date])

                    # Plot model
                    modeldf = pd.DataFrame({'bins':bins, 'values':values})
                    ax = plot_func(ax=ax, x='bins', y='values', data=modeldf, color=self.result_args[resname].color, **extra_args)

                # Set title and labels
                ax.set_xlabel('Age group')
                ax.set_title(f'{self.result_args[resname].name}, {date}')
                ax.legend()
                ax.set_xticks(x, age_labels[resname], rotation=45)
                plot_count += 1

        # Other sim results: one panel per result
        for resname in self.sim_results_keys:
            ax = axes[plot_count] if n_plots > 1 else axes
            bins = sc.autolist()
            values = sc.autolist()
            thisdatadf = find_target(resname, by_age=False)
            ydata = np.array(thisdatadf.value)
            x = np.arange(len(ydata))
            ax.scatter(x, ydata, color=pl.cm.Reds(0.95), marker='s', label='Data')

            # Construct a dataframe with things in the most logical order for plotting
            for run_num, run in enumerate(sim_results):
                bins += x.tolist()
                if sc.isnumber(run[resname]):
                    values += sc.promotetolist(run[resname])
                else:
                    values += run[resname].tolist()

            # Plot model
            modeldf = pd.DataFrame({'bins': bins, 'values': values})
            ax = plot_func(ax=ax, x='bins', y='values', data=modeldf, **extra_args)

            # Set title and labels
            date = thisdatadf.year.iloc[0] # Positional, so it does not depend on the data frame's index labels
            ax.set_title(self.result_args[resname].name + ', ' + str(date))
            ax.legend()
            if 'genotype_dist' in resname:
                ax.set_xticks(x, self.glabels)
                ax.set_xlabel('Genotype')
            plot_count += 1

    return hppl.tidy_up(fig, do_save=do_save, fig_path=fig_path, do_show=do_show, args=all_args)