stisim.calibration module#

Define the calibration class

class stisim.calibration.Calibration(sim, data, calib_pars, n_trials=None, n_workers=None, total_trials=None, reseed=True, weights=None, fit_args=None, build_fn=None, sep='.', name=None, db_name=None, keep_db=None, storage=None, rand_seed=None, sampler=None, label=None, die=False, debug=False, verbose=True, save_results=False)[source]#

Bases: prettyobj

A class to handle calibration of Starsim simulations. Uses the Optuna hyperparameter optimization library (optuna.org).

Args:

sim (Sim) : the simulation to calibrate
data (df) : pandas dataframe (or dataframe-compatible dict) of the data to calibrate to
calib_pars (dict) : a dictionary of the parameters to calibrate, of the format dict(key1=[best, low, high])
n_trials (int) : the number of trials per worker
n_workers (int) : the number of parallel workers (default: maximum number of available CPUs)
total_trials (int) : if n_trials is not supplied, calculate it by dividing this number by n_workers
reseed (bool) : whether to generate new random seeds for each trial
weights (dict) : the relative weights of each data source
fit_args (dict) : a dictionary of options passed to sim.compute_fit() to calculate the goodness-of-fit
sep (str) : the separator between different types of results, e.g. ‘hiv.deaths’ vs ‘hiv_deaths’
name (str) : the name of the database (default: ‘starsim_calibration’)
db_name (str) : the name of the database file (default: ‘starsim_calibration.db’)
keep_db (bool) : whether to keep the database after calibration (default: False)
storage (str) : the location of the database (default: sqlite)
rand_seed (int) : if provided, use this random seed to initialize Optuna runs (for reproducibility)
label (str) : a label for this calibration object
die (bool) : whether to stop if an exception is encountered (default: False)
debug (bool) : if True, do not run in parallel
verbose (bool) : whether to print details of the calibration

Returns:

A Calibration object
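As a concrete illustration of the calib_pars format described above, the snippet below builds a nested dict(key1=[best, low, high]) specification addressing a parameter inside a disease module; the nesting and the art_efficacy parameter follow the examples later on this page, but this is an illustrative sketch rather than a complete calibration setup:

```python
# Each calibrated parameter is given as [best, low, high]:
# the initial guess plus the bounds sampled between during calibration
calib_pars = dict(
    diseases = dict(
        hiv = dict(
            art_efficacy = [0.96, 0.90, 0.99],  # [best, low, high]
        ),
    ),
)

# Sanity-check the structure: each leaf is a 3-element [best, low, high] list
best, low, high = calib_pars['diseases']['hiv']['art_efficacy']
assert low <= best <= high
```

This dictionary would then be passed as the calib_pars argument when constructing the Calibration object.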

run_sim(calib_pars=None, label=None)[source]#

Create and run a simulation

static translate_pars(sim=None, calib_pars=None)[source]#

Take the nested dict of calibration pars and modify the sim

trial_to_sim_pars(pardict=None, trial=None)[source]#

Take in an optuna trial and sample from pars, after extracting them from the structure they’re provided in

Different use cases:
  • pardict is self.calib_pars, i.e. {‘diseases’:{‘hiv’:{‘art_efficacy’:[0.96, 0.9, 0.99]}}}, need to sample

  • pardict is self.initial_pars, i.e. {‘diseases’:{‘hiv’:{‘art_efficacy’:[0.96, 0.9, 0.99]}}}, pull 1st vals

  • pardict is self.best_pars, i.e. {‘diseases’:{‘hiv’:{‘art_efficacy’:0.96786}}}, pull single vals
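The second and third use cases (pulling first values from [best, low, high] lists, or single values that are already scalars) can be sketched in plain Python. This is an illustration of the dictionary shapes involved, not stisim’s actual implementation:

```python
def pull_pars(pardict):
    """Recursively pull a scalar from each leaf: the first element of a
    [best, low, high] list (the initial_pars case), or the value itself
    if it is already a scalar (the best_pars case)."""
    out = {}
    for key, val in pardict.items():
        if isinstance(val, dict):
            out[key] = pull_pars(val)  # recurse into nested structure
        elif isinstance(val, (list, tuple)):
            out[key] = val[0]  # pull 1st value, i.e. the 'best' guess
        else:
            out[key] = val  # already a single value
    return out

initial = {'diseases': {'hiv': {'art_efficacy': [0.96, 0.9, 0.99]}}}
best    = {'diseases': {'hiv': {'art_efficacy': 0.96786}}}
print(pull_pars(initial))  # {'diseases': {'hiv': {'art_efficacy': 0.96}}}
print(pull_pars(best))     # {'diseases': {'hiv': {'art_efficacy': 0.96786}}}
```

The first use case (sampling) is the same traversal, except each [best, low, high] leaf is handed to the Optuna trial to suggest a value between low and high.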

run_trial(trial)[source]#

Define the objective for Optuna

compute_fit(df_res=None)[source]#

Compute goodness-of-fit

worker()[source]#

Run a single worker

run_workers()[source]#

Run multiple workers in parallel

remove_db()[source]#

Remove the database file if keep_db is false and the path exists

make_study()[source]#

Make a study, deleting one if it already exists

calibrate(calib_pars=None, confirm_fit=False, load=False, tidyup=True, **kwargs)[source]#

Perform calibration.

Args:

calib_pars (dict): if supplied, overwrite stored calib_pars
confirm_fit (bool): if True, run simulations with parameters from before and after calibration
load (bool): whether to load existing trials from the database (if rerunning the same calibration)
tidyup (bool): whether to delete temporary files from trial runs
verbose (bool): whether to print output from each trial
kwargs (dict): if supplied, overwrite stored run_args (n_trials, n_workers, etc.)

confirm_fit()[source]#

Run before and after simulations to validate the fit

parse_study(study)[source]#

Parse the study into a data frame – called automatically

static shrink(calib, n_results=100)[source]#

Shrink the results to only the best fit

to_json(filename=None, indent=2, **kwargs)[source]#

Convert the results to JSON

plot_sims(**kwargs)[source]#

Plot sims, before and after calibration.

Args:

kwargs (dict): passed to MultiSim.plot()

plot_trend(best_thresh=None, fig_kw=None)[source]#

Plot the trend in best mismatch over time.

Args:

best_thresh (int): define the threshold for the “best” fits, relative to the lowest mismatch value (if None, show all)
fig_kw (dict): passed to plt.figure()

stisim.calibration.compute_gof(actual, predicted, normalize=True, use_frac=False, use_squared=False, as_scalar='none', eps=1e-09, skestimator=None, estimator=None, **kwargs)[source]#

Calculate the goodness of fit. By default, uses normalized absolute error, but the calculation is highly customizable. For example, mean squared error is equivalent to setting normalize=False, use_squared=True, as_scalar=’mean’.

Args:

actual (arr): array of actual (data) points
predicted (arr): corresponding array of predicted (model) points
normalize (bool): whether to divide the values by the largest value in either series
use_frac (bool): convert to fractional mismatches rather than absolute
use_squared (bool): square the mismatches
as_scalar (str): return as a scalar instead of a time series; choices are sum, mean, median
eps (float): to avoid divide-by-zero
skestimator (str): if provided, use this scikit-learn estimator instead
estimator (func): if provided, use this custom estimator instead
kwargs (dict): passed to the scikit-learn or custom estimator

Returns:

gofs (arr): array of goodness-of-fit values, or a single value if as_scalar is not ‘none’

Examples:

import numpy as np

x1 = np.cumsum(np.random.random(100))
x2 = np.cumsum(np.random.random(100))

e1 = compute_gof(x1, x2) # Default, normalized absolute error
e2 = compute_gof(x1, x2, normalize=False, use_frac=True) # Fractional error
e3 = compute_gof(x1, x2, normalize=False, use_squared=True, as_scalar='mean') # Mean squared error
e4 = compute_gof(x1, x2, skestimator='mean_squared_error') # Scikit-learn's MSE method
e5 = compute_gof(x1, x2, as_scalar='median') # Normalized median absolute error -- highly robust
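Based on the description above, the default behavior (normalized absolute error) and the mean-squared-error setting can be approximated in a few lines of NumPy. This is a sketch of the documented math for a subset of the options, not the actual implementation; details such as the eps handling may differ:

```python
import numpy as np

def gof_sketch(actual, predicted, normalize=True, use_squared=False,
               as_scalar='none', eps=1e-9):
    """Illustrative approximation of compute_gof for a subset of its options."""
    actual = np.asarray(actual, dtype=float)
    predicted = np.asarray(predicted, dtype=float)
    gofs = np.abs(actual - predicted)
    if normalize:
        # Divide by the largest value in either series (eps avoids divide-by-zero)
        gofs = gofs / max(np.abs(actual).max(), np.abs(predicted).max(), eps)
    if use_squared:
        gofs = gofs**2
    if as_scalar == 'mean':
        return gofs.mean()
    return gofs  # time series of mismatches

rng = np.random.default_rng(1)
x1 = np.cumsum(rng.random(100))
x2 = np.cumsum(rng.random(100))

e1 = gof_sketch(x1, x2)  # normalized absolute error (an array)
mse = gof_sketch(x1, x2, normalize=False, use_squared=True, as_scalar='mean')
assert np.isclose(mse, np.mean((x1 - x2)**2))  # matches the MSE identity
```

The assertion confirms the equivalence stated above: normalize=False, use_squared=True, as_scalar='mean' reduces to ordinary mean squared error.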