Calibration
- class Calibration(sim, calib_pars, n_workers=None, total_trials=None, reseed=True, build_fn=None, build_kw=None, eval_fn=None, eval_kwargs=None, components=None, label=None, study_name=None, db_name=None, keep_db=None, storage=None, sampler=None, die=False, debug=False, verbose=True)
Bases: prettyobj
A class to handle calibration of Starsim simulations. Uses the Optuna hyperparameter optimization library (optuna.org).
- Parameters:
sim (Sim) – the base simulation to calibrate
calib_pars (dict) – a dictionary of the parameters to calibrate, in the format dict(key1=dict(low=1, high=2, guess=1.5, **kwargs), key2=…). The kwargs can include “suggest_type” to choose the trial's suggest method (e.g. suggest_float), as well as arguments passed to that suggest function, such as “log” and “step”
n_workers (int) – the number of parallel workers (if None, will use all available CPUs)
total_trials (int) – the total number of trials to run; each worker will run approximately n_trials = total_trials / n_workers
reseed (bool) – whether to generate new random seeds for each trial
build_fn (callable) – function that takes a sim object and calib_pars dictionary and returns a modified sim
build_kw (dict) – a dictionary of options that are passed to build_fn to aid in modifying the base simulation. The API is self.build_fn(sim, calib_pars=calib_pars, **self.build_kw), where sim is a copy of the base simulation to be modified with calib_pars
components (list) – CalibComponent objects, each of which independently assesses a pseudo-likelihood as part of evaluating the quality of the input parameters
eval_fn (callable) – function mapping a sim to a float (e.g. negative log likelihood) to be minimized. If None, the components are used to evaluate the fit.
eval_kwargs (dict) – additional keyword arguments to pass to eval_fn
label (str) – a label for this calibration object
study_name (str) – name of the optuna study
db_name (str) – the name of the database file (default: ‘starsim_calibration.db’)
keep_db (bool) – whether to keep the database after calibration (default: False)
storage (str) – the Optuna storage URL for the database (default: an SQLite database)
sampler (BaseSampler) – the Optuna sampler to use, e.g. an instance of optuna.samplers.TPESampler
die (bool) – whether to stop if an exception is encountered (default: False)
debug (bool) – if True, do not run in parallel
verbose (bool) – whether to print details of the calibration
- Returns:
A Calibration object
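The following is a minimal sketch of how calib_pars, build_fn, and eval_fn fit together during a single trial. It does not import Starsim: `ToySim` and `poisson_nll` are hypothetical stand-ins, the build_fn and eval_fn bodies are examples of what a user might write rather than Starsim's own code, and it assumes each calib_pars entry carries its sampled value under a `'value'` key when it reaches build_fn.

```python
import math
import random


class ToySim:
    """Hypothetical stand-in for a Starsim Sim so the sketch runs without Starsim."""

    def __init__(self, beta=0.1, n_days=10):
        self.beta = beta
        self.n_days = n_days
        self.new_infections = []

    def run(self):
        # Toy dynamics: daily new infections grow by a factor (1 + beta)
        cases = 10.0
        self.new_infections = []
        for _ in range(self.n_days):
            cases *= 1 + self.beta
            self.new_infections.append(cases)
        return self


# calib_pars in the documented format (bounds plus an initial guess).
# Optional keys such as suggest_type, log, or step are omitted here.
calib_pars = dict(beta=dict(low=0.01, high=0.30, guess=0.10))


def build_fn(sim, calib_pars, **build_kw):
    """Hypothetical build_fn: write sampled values onto the sim.
    Assumes each entry holds its sampled value under the 'value' key."""
    sim.beta = calib_pars['beta']['value']
    sim.n_days = build_kw.get('n_days', sim.n_days)
    return sim


def poisson_nll(sim, data):
    """Hypothetical eval_fn: Poisson negative log likelihood of observed
    counts given the simulated means (lower means a better fit)."""
    return sum(mu - k * math.log(mu) + math.lgamma(k + 1)
               for mu, k in zip(sim.new_infections, data))


# Synthetic "observed" data generated from beta = 0.15
data = [round(x) for x in ToySim(beta=0.15).run().new_infections]

# One hand-rolled trial: sample within the bounds, build, run, and score
rng = random.Random(0)
bounds = calib_pars['beta']
spec = dict(bounds, value=rng.uniform(bounds['low'], bounds['high']))
sim = build_fn(ToySim(), dict(beta=spec), n_days=10)
score = poisson_nll(sim.run(), data)
print(f"beta={spec['value']:.3f}, nll={score:.2f}")
```

In the real class, Optuna's trial object performs the sampling and the study collects the scores across trials; this toy loop only shows the shape of the user-supplied callbacks.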
Methods
- static translate_pars(sim=None, calib_pars=None)
Take the nested dict of calibration pars and modify the sim
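To illustrate the general idea of applying nested calibration parameters, here is a hypothetical sketch that writes sampled values into a nested parameter dict by following a key path. The helper name `apply_calib_pars`, the `'path'` tuple, and the `'value'` key are assumptions made for illustration; this is not Starsim's translate_pars, which operates on a Sim object.

```python
import copy


def apply_calib_pars(sim_pars, calib_pars):
    """Hypothetical helper: return a copy of a nested parameter dict with
    each calibrated value written at its key path.

    Assumes every calib_pars entry has a 'path' (tuple of keys) and a
    sampled 'value'; this layout is illustrative, not Starsim's schema.
    """
    new_pars = copy.deepcopy(sim_pars)  # leave the base parameters untouched
    for name, spec in calib_pars.items():
        if 'path' not in spec:
            raise ValueError(f'No path given for calibration parameter {name!r}')
        *parents, leaf = spec['path']
        target = new_pars
        for key in parents:
            target = target[key]  # descend; raises KeyError on a bad path
        target[leaf] = spec['value']
    return new_pars


base = {'diseases': {'sis': {'beta': 0.05, 'dur_inf': 10}}}
calib = {'beta': {'path': ('diseases', 'sis', 'beta'), 'value': 0.12}}
updated = apply_calib_pars(base, calib)
print(updated['diseases']['sis']['beta'])  # 0.12
```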
- calibrate(calib_pars=None, load=False, tidyup=True, **kwargs)
Perform calibration.
- Parameters:
calib_pars (dict) – if supplied, overwrite stored calib_pars
load (bool) – whether to load existing trials from the database (if rerunning the same calibration)
tidyup (bool) – whether to delete temporary files from trial runs
verbose (bool) – whether to print output from each trial
kwargs (dict) – if supplied, overwrite stored run_args (n_trials, n_workers, etc.)
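The kwargs override and the trial split described above can be sketched as follows. `resolve_run_args` is a hypothetical helper, and rounding n_trials up is an assumption (the documentation only says each worker runs approximately total_trials / n_workers); only the key names n_workers, total_trials, and n_trials come from this page.

```python
import math
import os


def resolve_run_args(stored, **kwargs):
    """Hypothetical helper: merge per-call overrides into the stored run
    arguments, then split total_trials across the workers."""
    unknown = set(kwargs) - set(stored)
    if unknown:
        raise KeyError(f'Unknown run arguments: {sorted(unknown)}')
    run_args = {**stored, **kwargs}  # overrides win; stored dict is untouched
    if run_args['n_workers'] is None:
        run_args['n_workers'] = os.cpu_count() or 1  # use all available CPUs
    if run_args['n_trials'] is None:
        # Rounding up is an assumption; Starsim may split trials differently
        run_args['n_trials'] = math.ceil(run_args['total_trials'] / run_args['n_workers'])
    return run_args


stored = dict(n_workers=4, total_trials=10, n_trials=None)
print(resolve_run_args(stored))               # n_trials = ceil(10 / 4) = 3
print(resolve_run_args(stored, n_workers=3))  # n_trials = ceil(10 / 3) = 4
```

Returning a new dict rather than mutating the stored one keeps a per-call override from silently changing later calls.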