Calibration

class Calibration(sim, calib_pars, n_workers=None, total_trials=None, reseed=True, build_fn=None, build_kw=None, eval_fn=None, eval_kwargs=None, components=None, label=None, study_name=None, db_name=None, keep_db=None, storage=None, sampler=None, die=False, debug=False, verbose=True)

Bases: prettyobj

A class to handle calibration of Starsim simulations. Uses the Optuna hyperparameter optimization library (optuna.org).

Parameters:
  • sim (Sim) – the base simulation to calibrate

  • calib_pars (dict) – a dictionary of the parameters to calibrate, in the format dict(key1=dict(low=1, high=2, guess=1.5, **kwargs), key2=…). The kwargs can include “suggest_type” to choose the trial's suggest method (e.g. suggest_float), plus arguments passed through to that suggest function, such as “log” and “step”

  • n_workers (int) – the number of parallel workers (if None, will use all available CPUs)

  • total_trials (int) – the total number of trials to run; each worker will run approximately n_trials = total_trials / n_workers

  • reseed (bool) – whether to generate new random seeds for each trial

  • build_fn (callable) – function that takes a sim object and calib_pars dictionary and returns a modified sim

  • build_kw (dict) – a dictionary of options that are passed to build_fn to aid in modifying the base simulation. The API is self.build_fn(sim, calib_pars=calib_pars, **self.build_kw), where sim is a copy of the base simulation to be modified with calib_pars

  • components (list) – a list of CalibComponents, each of which independently assesses a pseudo-likelihood as part of evaluating the quality of the input parameters

  • eval_fn (callable) – function mapping a sim to a float (e.g. negative log likelihood) to be minimized. If None, the default will use CalibComponents.

  • eval_kwargs (dict) – additional keyword arguments to pass to eval_fn

  • label (str) – a label for this calibration object

  • study_name (str) – name of the optuna study

  • db_name (str) – the name of the database file (default: ‘starsim_calibration.db’)

  • keep_db (bool) – whether to keep the database after calibration (default: False)

  • storage (str) – the location of the database (default: sqlite)

  • sampler (BaseSampler) – the sampler used by optuna, like optuna.samplers.TPESampler

  • die (bool) – whether to stop if an exception is encountered (default: False)

  • debug (bool) – if True, do not run in parallel

  • verbose (bool) – whether to print details of the calibration

Returns:

A Calibration object
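As a rough illustration of how these arguments fit together, the sketch below builds a calib_pars dictionary in the documented format and applies it with a custom build function that follows the documented call build_fn(sim, calib_pars=calib_pars, **build_kw). It runs without Starsim: the ToySim class, the parameter names beta and init_prev, the scale option, the mismatch function, and the convention that a sampled value is stored under a "value" key are all hypothetical stand-ins, not part of the library's confirmed API.

```python
from dataclasses import dataclass, field


@dataclass
class ToySim:
    """Hypothetical stand-in for a Starsim Sim (illustration only)."""
    pars: dict = field(default_factory=lambda: {"beta": 0.05, "init_prev": 0.01})


# Search space in the documented format: low/high/guess plus optional suggest kwargs
calib_pars = dict(
    beta=dict(low=0.01, high=0.30, guess=0.15, suggest_type="suggest_float", log=True),
    init_prev=dict(low=0.01, high=0.25, guess=0.10),
)


def build_fn(sim, calib_pars=None, scale=1.0):
    """Return the sim modified with the supplied parameter values.

    Assumption: a sampled value, if present, lives under spec["value"];
    otherwise the initial guess is used. `scale` is a made-up build_kw option.
    """
    for name, spec in calib_pars.items():
        sim.pars[name] = spec.get("value", spec["guess"]) * scale
    return sim


def mismatch(sim, target=0.2):
    """Hypothetical eval_fn: map a sim to a single float."""
    return abs(sim.pars["beta"] - target)


# Mirrors the documented call: build_fn(sim, calib_pars=calib_pars, **build_kw)
build_kw = dict(scale=1.0)
sim = build_fn(ToySim(), calib_pars=calib_pars, **build_kw)
score = mismatch(sim, target=0.2)
```

In a real calibration, the functions would be passed as build_fn=build_fn, build_kw=build_kw, and eval_fn=mismatch, with eval_kwargs carrying options such as target.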

Methods

run_sim(calib_pars=None, label=None)

Create and run a simulation

static translate_pars(sim=None, calib_pars=None)

Take the nested dict of calibration pars and modify the sim

run_trial(trial)

Define the objective for Optuna
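The following is a minimal sketch of how an objective of this kind might turn calib_pars entries into sampled values, using the “suggest_type” key to pick the suggest method. It is not the library's implementation. StubTrial is a hypothetical stand-in that mimics the signatures of Optuna's Trial.suggest_float and Trial.suggest_int so the example runs without Optuna; sample_pars, objective, and the parameter names are likewise illustrative.

```python
class StubTrial:
    """Minimal stand-in exposing Optuna-style suggest methods (illustration only)."""

    def suggest_float(self, name, low, high, *, step=None, log=False):
        # Deterministic midpoint (geometric if log=True) instead of real sampling
        return (low * high) ** 0.5 if log else (low + high) / 2

    def suggest_int(self, name, low, high, *, step=1, log=False):
        return (low + high) // 2


def sample_pars(trial, calib_pars):
    """Ask the trial for one value per parameter, honouring 'suggest_type'."""
    values = {}
    for name, spec in calib_pars.items():
        spec = dict(spec)            # copy so the caller's dict is not mutated
        spec.pop("guess", None)      # the guess is not a suggest_* argument
        method = spec.pop("suggest_type", "suggest_float")
        values[name] = getattr(trial, method)(name, **spec)
    return values


def objective(trial, calib_pars, target):
    """Hypothetical objective: sample parameters and return a float mismatch."""
    values = sample_pars(trial, calib_pars)
    return sum((values[k] - target[k]) ** 2 for k in target)


calib_pars = dict(
    beta=dict(low=0.0, high=1.0, guess=0.5),
    n_contacts=dict(low=2, high=10, guess=4, suggest_type="suggest_int"),
)
values = sample_pars(StubTrial(), calib_pars)
loss = objective(StubTrial(), calib_pars, target=dict(beta=0.5, n_contacts=6))
```

With a real Optuna trial, the same sample_pars function would draw values from the study's sampler rather than returning midpoints.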

worker()

Run a single worker

run_workers()

Run multiple workers in parallel
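To make the relationship between total_trials and n_workers concrete, here is a small sketch of splitting a trial budget across workers. It assumes the per-worker count is rounded up, which is one plausible reading of “approximately”, and it uses a thread pool purely for illustration; the library may use a different parallelization mechanism. All function names here are hypothetical.

```python
import math
from concurrent.futures import ThreadPoolExecutor


def trials_per_worker(total_trials, n_workers):
    """Assumed rounding: every worker runs ceil(total_trials / n_workers) trials."""
    return math.ceil(total_trials / n_workers)


def worker(worker_id, n_trials):
    """Stand-in worker: pretend to run n_trials trials and report what was run."""
    return [(worker_id, trial_id) for trial_id in range(n_trials)]


def run_workers(total_trials, n_workers):
    """Run all workers in a pool and flatten their results into one list."""
    n_trials = trials_per_worker(total_trials, n_workers)
    with ThreadPoolExecutor(max_workers=n_workers) as pool:
        chunks = list(pool.map(worker, range(n_workers), [n_trials] * n_workers))
    return [item for chunk in chunks for item in chunk]


completed = run_workers(total_trials=10, n_workers=4)
```

Under this rounding assumption, 10 trials over 4 workers gives 3 trials per worker, so 12 trials run in total, slightly more than requested.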

remove_db()

Remove the database file if keep_db is False and the path exists
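The cleanup rule described here can be sketched as a guarded delete. The standalone function below is a hypothetical equivalent that takes the path and flag as arguments, whereas the actual method reads them from the Calibration object; the file name is made up for the demonstration.

```python
import os
import tempfile


def remove_db(db_path, keep_db=False):
    """Delete db_path unless keep_db is set; return True if a file was removed."""
    if keep_db or not os.path.exists(db_path):
        return False
    os.remove(db_path)
    return True


# Demonstrate on a throwaway file in a temporary directory
tmpdir = tempfile.mkdtemp()
db_path = os.path.join(tmpdir, "example_calibration.db")
open(db_path, "w").close()

kept = remove_db(db_path, keep_db=True)      # file is kept, nothing removed
removed = remove_db(db_path, keep_db=False)  # file exists, so it is removed
again = remove_db(db_path)                   # path no longer exists, so no-op
```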

make_study()

Make a study, deleting one if it already exists

calibrate(calib_pars=None, load=False, tidyup=True, **kwargs)

Perform calibration.

Parameters:
  • calib_pars (dict) – if supplied, overwrite stored calib_pars

  • load (bool) – whether to load existing trials from the database (if rerunning the same calibration)

  • tidyup (bool) – whether to delete temporary files from trial runs

  • verbose (bool) – whether to print output from each trial

  • kwargs (dict) – if supplied, overwrite stored run_args (n_trials, n_workers, etc.)

check_fit(n_runs=5)

Run simulations before and after calibration to validate the fit

parse_study(study)

Parse the study into a data frame – called automatically

to_json(filename=None, indent=2, **kwargs)

Convert the results to JSON

plot_sims(**kwargs)

Plot sims, before and after calibration.

Parameters:
  • kwargs (dict) – passed to MultiSim.plot()

plot_trend(best_thresh=None, fig_kw=None)

Plot the trend in best mismatch over time.

Parameters:
  • best_thresh (int) – Define the threshold for the “best” fits, relative to the lowest mismatch value (if None, show all)

  • fig_kw (dict) – passed to plt.figure()
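The quantities behind a trend plot like this can be sketched without any plotting: the running best mismatch after each trial, and the subset of trials counted as “best”. The sketch assumes that “relative to the lowest mismatch value” means a multiple of the minimum (a trial qualifies if its mismatch is at most best_thresh times the lowest). That interpretation, the function names, and the sample values are assumptions for illustration only.

```python
from itertools import accumulate


def running_best(mismatches):
    """Lowest mismatch seen so far after each trial (cumulative minimum)."""
    return list(accumulate(mismatches, min))


def best_fits(mismatches, best_thresh=None):
    """Indices of trials within best_thresh times the lowest mismatch.

    Assumption: the threshold is multiplicative; None keeps every trial.
    """
    if best_thresh is None:
        return list(range(len(mismatches)))
    cutoff = min(mismatches) * best_thresh
    return [i for i, m in enumerate(mismatches) if m <= cutoff]


mismatches = [9.0, 7.5, 10.0, 4.0, 6.0, 4.5]  # made-up mismatch per trial
trend = running_best(mismatches)
best = best_fits(mismatches, best_thresh=2)
everything = best_fits(mismatches)
```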