# Source code for stisim.networks

"""
Define sexual network for STI transmission.

Overview:

- Risk groups: agents are randomly assigned into one of 3 main risk groups:

    - 0 = marry and remain married to a single partner throughout their lifetime
    - 1 = marry and then divorce, or have concurrent partner(s) during their marriage
    - 2 = never marry
    
- In addition, a proportion of each of the groups above engages in sex work.
"""

import starsim as ss
import sciris as sc
import numpy as np
import pandas as pd
import scipy.optimize as spo
import scipy.spatial as spsp

# Shorthand for the float dtype exposed by starsim; used for the numeric arrays built in this module
ss_float_ = ss.dtypes.float

# Names exported by `from stisim.networks import *` (the network classes defined below)
__all__ = ['StructuredSexual', 'FastStructuredSexual', 'AgeMatchedMSM', 'AgeApproxMSM']


class NoPartnersFound(Exception):
    """ Signal that the matching algorithm wasn't able to match any partners """


class StructuredSexual(ss.SexualNetwork):
    """
    Structured sexual network

    Agents are assigned a risk group (0, 1 or 2), a preferred number of
    concurrent partners, a sex-work status (FSW / client) and an age of sexual
    debut. On every step, expired partnerships are removed and new general
    population and sex-work partnerships are formed.
    """

    def __init__(self, pars=None, key_dict=None, condom_data=None, name=None, **kwargs):
        """
        Define the network parameters, per-agent states and extra edge attributes.

        Args:
            pars: parameter overrides, forwarded to ``update_pars``
            key_dict: additional edge attributes, merged with the ones defined here
            condom_data: a number or a DataFrame; see ``process_condom_data``
            name: name of the network
            **kwargs: further parameter overrides, forwarded to ``update_pars``
        """
        key_dict = sc.mergedicts({
            'sw': bool,
            'condoms': ss_float_,
            'age_p1': ss_float_,
            'age_p2': ss_float_,
        }, key_dict)

        super().__init__(key_dict=key_dict, name=name)

        self.define_pars(
            # Settings - generally shouldn't be adjusted
            unit='month',
            store_register=False,
            n_risk_groups=3,
            f_age_group_bins=dict(  # For separating women into age groups: teens, young women, adult women
                teens=(0, 20),
                young=(20, 25),
                adult=(25, np.inf),
            ),

            # Age of sexual debut
            debut=ss.lognorm_ex(20, 3),
            debut_pars_f=[20, 3],
            debut_pars_m=[21, 3],

            # Risk groups
            p_lo_risk=ss.bernoulli(p=0),
            p_hi_risk=ss.bernoulli(p=0),
            prop_f0=0.85,
            prop_m0=0.8,
            prop_f2=0.01,
            prop_m2=0.02,

            # Age difference preferences
            age_diff_pars=dict(
                teens=[(7, 3), (6, 3), (5, 1)],  # (mu,stdev) for levels 0, 1, 2
                young=[(8, 3), (7, 3), (5, 2)],
                adult=[(8, 3), (7, 3), (5, 2)],
            ),

            # Concurrency preferences
            concurrency_dist=ss.poisson(lam=1),
            f0_conc=0.0001,
            f1_conc=0.01,
            f2_conc=0.1,
            m0_conc=0.0001,
            m1_conc=0.2,
            m2_conc=0.5,

            # Relationship initiation, stability, and duration
            p_pair_form=ss.bernoulli(p=0.5),  # Probability of a (stable) pair forming between two matched people
            match_dist=ss.bernoulli(p=0),  # Placeholder value replaced by risk-group stratified values below
            p_matched_stable=[0.9, 0.5, 0],  # Probability of a stable pair forming between matched people (otherwise casual)
            p_mismatched_casual=[0.5, 0.5, 0.5],  # Probability of a casual pair forming between mismatched people (otherwise instantaneous)

            # Durations of stable and casual relationships
            stable_dur_pars=dict(
                teens=[
                    # (mu,stdev) for levels 0, 1, 2
                    [ss.dur(100, 'year'), ss.dur(1, 'year')],
                    [ss.dur(8, 'year'), ss.dur(2, 'year')],
                    [ss.dur(1e-4, 'month'), ss.dur(1e-4, 'month')]
                ],
                young=[
                    [ss.dur(100, 'year'), ss.dur(1, 'year')],
                    [ss.dur(10, 'year'), ss.dur(3, 'year')],
                    [ss.dur(1e-4, 'month'), ss.dur(1e-4, 'month')]
                ],
                adult=[
                    [ss.dur(100, 'year'), ss.dur(1, 'year')],
                    [ss.dur(12, 'year'), ss.dur(3, 'year')],
                    [ss.dur(1e-4, 'month'), ss.dur(1e-4, 'month')]
                ],
            ),
            casual_dur_pars=dict(
                teens=[[ss.dur(1, 'year'), ss.dur(3, 'year')]]*3,
                young=[[ss.dur(1, 'year'), ss.dur(3, 'year')]]*3,
                adult=[[ss.dur(1, 'year'), ss.dur(3, 'year')]]*3,
            ),

            # Acts
            acts=ss.lognorm_ex(ss.peryear(80), ss.peryear(30)),  # Coital acts/year

            # Sex work parameters
            fsw_shares=ss.bernoulli(p=0.05),
            client_shares=ss.bernoulli(p=0.12),
            sw_seeking_rate=ss.rate(1, 'month'),  # Monthly rate at which clients seek FSWs (1 new SW partner / month)
            sw_seeking_dist=ss.bernoulli(p=0.5),  # Placeholder value replaced by dt-adjusted sw_seeking_rate
            sw_beta=1,
            sw_intensity=ss.random(),  # At each time step, FSW may work with varying intensity

            # Distributions derived from parameters above - don't adjust
            age_diffs=ss.normal(),
            dur_dist=ss.lognorm_ex(),
        )

        self.update_pars(pars=pars, **kwargs)

        # Set condom use
        self.condom_data = None
        if condom_data is not None:
            self.condom_data = self.process_condom_data(condom_data)

        # Store register: one slot per step for the last 12 steps. Each slot is
        # its own list (a comprehension avoids aliasing one list 12 times).
        if self.pars.store_register:
            self.breakup_register = [[] for _ in range(12)]

        # Add states
        self.define_states(
            ss.BoolArr('participant', default=True),
            ss.FloatArr('risk_group'),  # Which risk group an agent belongs to
            ss.BoolArr('fsw'),  # Whether an agent is a female sex worker
            ss.BoolArr('client'),  # Whether an agent is a client of sex workers
            ss.FloatArr('concurrency'),  # Preferred number of concurrent partners
            ss.FloatArr('partners', default=0),  # Actual number of concurrent partners
            ss.FloatArr('partners_12', default=0),  # Number of partners over the past 12m
            ss.FloatArr('lifetime_partners', default=0),  # Lifetime total number of partners
            ss.FloatArr('sw_intensity'),  # Intensity of sex work
        )

    @staticmethod
    def process_condom_data(condom_data):
        """
        Convert condom data into the form used by ``set_condom_use``.

        Args:
            condom_data: either a number, which is returned unchanged, or a
                DataFrame with a ``partnership`` column and one further column
                per year. Partnership labels are either ``'(fsw,client)'`` or a
                bracketed pair of integers such as ``'(0,1)'``.

        Returns:
            The number itself, or a dict keyed by partnership tuple whose values
            hold ``'year'`` and ``'val'`` arrays. Any other input type yields None.
        """
        if sc.isnumber(condom_data):
            return condom_data
        elif isinstance(condom_data, pd.DataFrame):
            df = condom_data.melt(id_vars=['partnership'])
            dd = dict()
            for pcombo in df.partnership.unique():
                if pcombo != '(fsw,client)':
                    key = tuple(map(int, pcombo[1:-1].split(',')))  # Strip the brackets, then parse the two risk groups
                else:
                    key = ('fsw', 'client')
                thisdf = df.loc[df.partnership == pcombo]
                dd[key] = dict()
                dd[key]['year'] = thisdf.variable.values.astype(int)
                dd[key]['val'] = thisdf.value.values
            return dd

    def get_age_risk_pars(self, uids, par):
        """
        Look up a pair of parameters for each agent by age bin and risk group.

        Args:
            uids: agents to look up
            par: dict keyed by the labels of ``f_age_group_bins``; each value is
                indexed by risk group and holds a (first, second) parameter pair

        Returns:
            Two arrays (first and second parameter), aligned with ``uids``.

        Raises:
            ValueError: if any agent is left without a value, i.e. falls outside
                every age bin / risk group combination
        """
        loc = np.full(uids.shape, fill_value=np.nan, dtype=ss_float_)
        scale = np.full(uids.shape, fill_value=np.nan, dtype=ss_float_)

        # These lookups do not depend on the loop variables, so do them once
        ages = self.sim.people.age[uids]
        risk_groups = self.risk_group[uids]

        for a_label, (age_lower, age_upper) in self.pars.f_age_group_bins.items():
            in_age_bin = (ages >= age_lower) & (ages < age_upper)
            for rg in range(self.pars.n_risk_groups):
                in_risk_group = in_age_bin & (risk_groups == rg)
                loc[in_risk_group] = par[a_label][rg][0]
                scale[in_risk_group] = par[a_label][rg][1]

        if np.isnan(scale).any() or np.isnan(loc).any():
            errormsg = 'Invalid entries for age difference preferences.'
            raise ValueError(errormsg)

        return loc, scale

    def init_pre(self, sim):
        """ Initialize the network and interpolate dict-form condom data onto the sim time vector """
        super().init_pre(sim)
        if isinstance(self.condom_data, dict):
            for rgtuple, valdict in self.condom_data.items():
                self.condom_data[rgtuple]['simvals'] = sc.smoothinterp(sim.timevec, valdict['year'], valdict['val'])
        return

    def init_results(self):
        """ Define the results stored by this network """
        self.define_results(
            ss.Result('share_active', dtype=float, scale=False),
            ss.Result('partners_f_mean', dtype=float, scale=False),
            ss.Result('partners_m_mean', dtype=float, scale=False),
        )
        return

    def init_post(self):
        """ Finish initialization without adding pairs, then assign the network states """
        super().init_post(add_pairs=False)
        self.set_network_states()
        return

    def set_network_states(self, upper_age=None):
        """ Assign risk group, concurrency, sex-work status and debut age to agents younger than upper_age """
        self.set_risk_groups(upper_age=upper_age)
        self.set_concurrency(upper_age=upper_age)
        self.set_sex_work(upper_age=upper_age)
        self.set_debut(upper_age=upper_age)
        return

    def over_debut(self):
        """ Return a boolean array that is True for agents older than their debut age """
        return self.sim.people.age > self.debut

    def _get_uids(self, upper_age=None, by_sex=True):
        """
        Get the UIDs of agents younger than upper_age.

        Args:
            upper_age: exclusive age limit; None is replaced by 1000
            by_sex: if True return (female UIDs, male UIDs), else a single set of UIDs
        """
        people = self.sim.people
        if upper_age is None:
            upper_age = 1000
        within_age = people.age < upper_age
        if by_sex:
            f_uids = (within_age & people.female).uids
            m_uids = (within_age & people.male).uids
            return f_uids, m_uids
        else:
            uids = within_age.uids
            return uids

    def set_risk_groups(self, upper_age=None):
        """ Assign each person to a risk group """
        ppl = self.sim.people
        uids = self._get_uids(upper_age=upper_age, by_sex=False)

        # First split: low risk (group 0) versus everyone else
        p_lo = np.full(len(uids), fill_value=np.nan, dtype=ss_float_)
        p_lo[ppl.female[uids]] = self.pars.prop_f0
        p_lo[ppl.male[uids]] = self.pars.prop_m0
        self.pars.p_lo_risk.set(p=p_lo)
        lo_risk, hi_med_risk = self.pars.p_lo_risk.split(uids)

        # Second split: high risk (group 2) versus medium risk (group 1). The
        # probability is conditional on not being low risk, hence the rescaling.
        p_hi = np.full(len(hi_med_risk), fill_value=np.nan, dtype=ss_float_)
        p_hi[ppl.female[hi_med_risk]] = self.pars.prop_f2/(1-self.pars.prop_f0)
        p_hi[ppl.male[hi_med_risk]] = self.pars.prop_m2/(1-self.pars.prop_m0)
        self.pars.p_hi_risk.set(p=p_hi)
        hi_risk, med_risk = self.pars.p_hi_risk.split(hi_med_risk)

        self.risk_group[lo_risk] = 0
        self.risk_group[med_risk] = 1
        self.risk_group[hi_risk] = 2
        return

    def set_concurrency(self, upper_age=None):
        """ Assign each person a preferred number of simultaneous partners """
        people = self.sim.people
        if upper_age is None:
            upper_age = 1000
        in_age_lim = (people.age < upper_age)
        uids = in_age_lim.uids

        # Poisson rate for each agent, set by sex and risk group
        lam = np.full(uids.shape, fill_value=np.nan, dtype=ss_float_)
        for rg in range(self.pars.n_risk_groups):
            f_conc = self.pars[f'f{rg}_conc']
            m_conc = self.pars[f'm{rg}_conc']
            in_risk_group = self.risk_group == rg
            in_group = in_risk_group & in_age_lim
            f_in = (people.female & in_group)[uids]
            m_in = (people.male & in_group)[uids]
            if f_in.any():
                lam[f_in] = f_conc
            if m_in.any():
                lam[m_in] = m_conc

        self.pars.concurrency_dist.set(lam=lam)
        self.concurrency[uids] = self.pars.concurrency_dist.rvs(uids) + 1  # +1 so the preferred number is at least one

        return

    def set_sex_work(self, upper_age=None):
        """ Draw which women are sex workers and which men are clients """
        f_uids, m_uids = self._get_uids(upper_age=upper_age)
        self.fsw[f_uids] = self.pars.fsw_shares.rvs(f_uids)
        self.client[m_uids] = self.pars.client_shares.rvs(m_uids)
        return

    def set_debut(self, upper_age=None):
        """ Sample an age of sexual debut for each person, using sex-specific parameters """
        uids = self._get_uids(upper_age=upper_age, by_sex=False)
        is_female = self.sim.people.female[uids]
        is_male = self.sim.people.male[uids]

        par1 = np.full(len(uids), fill_value=np.nan, dtype=ss_float_)
        par2 = np.full(len(uids), fill_value=np.nan, dtype=ss_float_)
        par1[is_female] = self.pars.debut_pars_f[0]
        par2[is_female] = self.pars.debut_pars_f[1]
        par1[is_male] = self.pars.debut_pars_m[0]
        par2[is_male] = self.pars.debut_pars_m[1]

        self.pars.debut.set(mean=par1, std=par2)
        self.debut[uids] = self.pars.debut.rvs(uids)
        return

    def match_pairs(self, ppl):
        """
        Match pairs by age

        Women looking for a partner each sample a desired partner age; they are
        then assigned to eligible men by minimizing the total distance between
        desired and actual ages.

        Returns:
            (p1, p2): UIDs of the matched men and women respectively

        Raises:
            NoPartnersFound: if no women are looking or no men are eligible
        """

        # Find people eligible for a relationship
        active = self.over_debut()
        underpartnered = self.partners < self.concurrency
        f_eligible = active & ppl.female & underpartnered
        m_eligible = active & ppl.male & underpartnered
        f_looking = self.pars.p_pair_form.filter(f_eligible.uids)  # ss.uids of women looking for partners

        if len(f_looking) == 0 or m_eligible.count() == 0:
            raise NoPartnersFound()

        # Get mean age differences and desired ages
        loc, scale = self.get_age_risk_pars(f_looking, self.pars.age_diff_pars)
        self.pars.age_diffs.set(loc=loc, scale=scale)
        age_gaps = self.pars.age_diffs.rvs(f_looking)   # Sample the age differences
        desired_ages = ppl.age[f_looking] + age_gaps    # Desired ages of the male partners
        m_ages = ppl.age[m_eligible]            # Ages of eligible males
        dist_mat = spsp.distance_matrix(m_ages[:, np.newaxis], desired_ages[:, np.newaxis])
        ind_m, ind_f = spo.linear_sum_assignment(dist_mat)
        p1 = m_eligible.uids[ind_m]
        p2 = f_looking[ind_f]

        return p1, p2

    def add_pairs(self, ti=None):
        """
        Add pairs

        Forms general population pairs (``match_pairs``) and sex-work pairs
        (``match_sex_workers``), samples their attributes and appends them to
        the edge list. Returns without adding anything if ``NoPartnersFound``
        is raised during matching.

        Raises:
            ValueError: if the edge list has unequal lengths, or if a network
                named 'structuredsexual' would contain a same-sex pairing
        """
        ppl = self.sim.people
        dt = self.t.dt

        # Obtain new pairs
        try:
            p1_gp, p2_gp = self.match_pairs(ppl)
            p1_sw, p2_sw = self.match_sex_workers(ppl)
        except NoPartnersFound:
            return

        p1 = p1_gp.concat(p1_sw)
        p2 = p2_gp.concat(p2_sw)

        # Checks - done before anything is appended so an invalid edge list never reaches the network
        if len(p1) != len(p2):
            errormsg = 'Unequal lengths in edge list'
            raise ValueError(errormsg)
        if (self.name == 'structuredsexual') and (ppl.female[p1].any() or ppl.male[p2].any()):
            errormsg = 'Same-sex pairings should not be possible in this network'
            raise ValueError(errormsg)

        # Flag for sex-work edges; explicit bool dtype so that ~sw also works when there are no edges
        sw = np.zeros(len(p1), dtype=bool)
        sw[len(p1_gp):] = True

        # Initialize beta, acts, duration
        n_edges = len(p2)
        beta = np.ones(n_edges, dtype=ss_float_)
        condoms = np.zeros(n_edges, dtype=ss_float_)  # FILLED IN LATER
        acts = (self.pars.acts.rvs(p2)).astype(int)  # Number of acts per timestep - does not depend on commitment/risk group
        dur = np.full(n_edges, dtype=ss_float_, fill_value=dt)  # Default duration is dt, replaced for stable matches
        age_p1 = ppl.age[p1]
        age_p2 = ppl.age[p2]

        # Determine whether the pair have matched risk profiles
        # Partners with mismatched risk profiles may still form a casual partnership
        rg_p1 = self.risk_group[p1]
        rg_p2 = self.risk_group[p2]
        matched_risk = (rg_p1 == rg_p2) & ~sw
        mismatched_risk = (rg_p1 != rg_p2) & ~sw

        # Set the probability of forming a partnership (left as NaN for sex-work edges)
        p_match = np.full(len(p1), fill_value=np.nan, dtype=ss_float_)
        for rg in range(self.pars.n_risk_groups):
            p_match[matched_risk & (rg_p1 == rg)] = self.pars.p_matched_stable[rg]
            p_match[mismatched_risk & (rg_p2 == rg)] = self.pars.p_mismatched_casual[rg]
        self.pars.match_dist.set(p=p_match)
        matches = self.pars.match_dist.rvs(p2)

        stable = matches & matched_risk
        casual = matches & mismatched_risk
        any_match = stable | casual

        # Set duration
        n_matched = np.count_nonzero(any_match)
        dur_mean = np.full(n_matched, fill_value=np.nan, dtype=ss_float_)
        dur_std = np.full(n_matched, fill_value=np.nan, dtype=ss_float_)
        for which, bools in {'stable': stable, 'casual': casual}.items():
            if bools.any():
                uids = p2[bools]
                mean, std = self.get_age_risk_pars(uids, self.pars[f'{which}_dur_pars'])
                inds = bools[any_match].nonzero()[-1]  # Positions of this partnership type within the matched subset
                dur_mean[inds] = mean
                dur_std[inds] = std

        self.pars.dur_dist.set(mean=dur_mean, std=dur_std)
        dur[any_match] = self.pars.dur_dist.rvs(p2[any_match])

        self.append(p1=p1, p2=p2, beta=beta, condoms=condoms, dur=dur, acts=acts, sw=sw, age_p1=age_p1, age_p2=age_p2)

        # Add partner counts, not including SW partners
        unique_p1, counts_p1 = np.unique(p1_gp, return_counts=True)
        unique_p2, counts_p2 = np.unique(p2_gp, return_counts=True)
        self.partners[unique_p1] += counts_p1
        self.partners[unique_p2] += counts_p2
        self.lifetime_partners[unique_p1] += counts_p1
        self.lifetime_partners[unique_p2] += counts_p2

        return

    def match_sex_workers(self, ppl):
        """
        Match sex workers to clients

        Returns:
            (p1, p2): UIDs of the clients and of the sex workers assigned to them
        """

        # Find people eligible for a relationship
        active = self.over_debut()
        active_fsw = active & self.fsw
        active_clients = active & self.client
        self.sw_intensity[active_fsw.uids] = self.pars.sw_intensity.rvs(active_fsw.uids)

        # Find clients who will seek FSW
        self.pars.sw_seeking_dist.pars.p = np.clip(self.pars.sw_seeking_rate, 0, 1)
        m_looking = self.pars.sw_seeking_dist.filter(active_clients.uids)

        # Attempt to assign a sex worker to every client by repeat sampling the sex workers.
        # FSW with higher work intensity will be sampled more frequently
        if len(m_looking) > len(active_fsw.uids):
            n_repeats = (self.sw_intensity[active_fsw]*10).astype(int)+1
            fsw_repeats = np.repeat(active_fsw.uids, n_repeats)
            if len(fsw_repeats) < len(m_looking):
                fsw_repeats = np.repeat(fsw_repeats, 10)  # 10x the number of clients each sex worker can have

            # Might still not have enough FSW, so form as many pairs as possible
            n_pairs = min(len(fsw_repeats), len(m_looking))
            if len(fsw_repeats) < len(m_looking):
                p1 = m_looking[:n_pairs]
                p2 = fsw_repeats
            else:
                # Spread each worker's intensity over her repeated entries, then keep the highest-weighted entries
                _, counts_sw = np.unique(fsw_repeats, return_counts=True)
                count_repeats = np.repeat(counts_sw, counts_sw)
                weights = self.sw_intensity[fsw_repeats] / count_repeats
                choices = np.argsort(-weights)[:n_pairs]
                p2 = fsw_repeats[choices]
                p1 = m_looking

        else:
            # Enough sex workers for one client each: pick those with the highest intensity
            n_pairs = len(m_looking)
            weights = self.sw_intensity[active_fsw]
            choices = np.argsort(-weights)[:n_pairs]
            p2 = active_fsw.uids[choices]
            p1 = m_looking

        return p1, p2

    def end_pairs(self):
        """
        Decrement partnership durations and remove partnerships that have ended.

        A partnership ends when its remaining duration is no longer positive or
        when either partner is not alive. If ``store_register`` is set, the
        breakup register and ``partners_12`` are updated too.
        """
        people = self.sim.people

        self.edges.dur = self.edges.dur - 1  # Decrement the duration of each partnership, noting that dur is timesteps

        # Non-alive agents are removed
        alive_bools = people.alive[ss.uids(self.edges.p1)] & people.alive[ss.uids(self.edges.p2)]
        active = (self.edges.dur > 0) & alive_bools

        # Update the breakup register
        if self.pars.store_register:
            over_12m = self.breakup_register[11]
            if len(over_12m):
                u, c = np.unique(over_12m, return_counts=True)
                self.partners_12[u] -= c
            self.breakup_register = self.breakup_register[:11]  # Forget partners from >12m ago

            # Durations are floats, so test for "no longer positive" rather than
            # exact equality with zero. Expired edges are removed every step, so
            # anything at or below zero here ended on this step.
            just_ended = (self.edges.dur <= 0) & alive_bools
            je1 = ss.uids(self.edges.p1[just_ended])
            je2 = ss.uids(self.edges.p2[just_ended])
            je_uids = je1.concat(je2)
            self.breakup_register.insert(0, je_uids)
            if len(je_uids):
                u, c = np.unique(je_uids, return_counts=True)
                self.partners_12[u] += c

        # For gen pop contacts that are due to expire, decrement the partner count.
        # Count occurrences per agent first: an in-place update through an index
        # array is applied only once per unique index, which would under-count
        # agents with several partnerships ending on the same step.
        inactive_gp = ~active & (~self.edges.sw)
        for ended in (ss.uids(self.edges.p1[inactive_gp]), ss.uids(self.edges.p2[inactive_gp])):
            if len(ended):
                u, c = np.unique(ended, return_counts=True)
                self.partners[ss.uids(u)] -= c

        # For all contacts that are due to expire, remove them from the contacts list
        if len(active) > 0:
            for k in self.meta_keys():
                self.edges[k] = (self.edges[k][active])

        return

    def update_results(self):
        """ Store the share of people who are active in the network at this timestep """
        ti = self.ti
        self.results.share_active[ti] = len(self.active(self.sim.people).uids)/len(self.sim.people)

    def net_beta(self, disease_beta=None, uids=None, disease=None):
        """
        Compute the per-edge transmission probability for this timestep.

        Args:
            disease_beta: per-act transmission probability
            uids: indices into the edge list; None selects every edge
            disease: disease module; ``disease.pars.eff_condom`` is read

        Returns:
            Probability of transmission over all acts on each selected edge,
            multiplied by the edge's ``beta``.
        """
        if uids is None: uids = Ellipsis
        acts = self.edges.acts[uids]
        p_condom = self.edges.condoms[uids]
        eff_condom = disease.pars.eff_condom

        # Probability of escaping transmission across the acts with and without condoms
        p_no_trans_condom = (1 - disease_beta*(1-eff_condom))**(acts*p_condom)
        p_no_trans_no_condom = (1 - disease_beta)**(acts*(1-p_condom))
        p_trans = 1 - p_no_trans_condom * p_no_trans_no_condom
        result = p_trans * self.edges.beta[uids]
        return result

    def set_condom_use(self):
        """ Set condom use """
        if self.condom_data is not None:
            if isinstance(self.condom_data, dict):
                # Risk groups of each edge's partners do not change inside the loops
                rg_p1 = self.risk_group[self.p1]
                rg_p2 = self.risk_group[self.p2]
                for rgm in range(self.pars.n_risk_groups):
                    for rgf in range(self.pars.n_risk_groups):
                        risk_pairing = (rg_p1 == rgm) & (rg_p2 == rgf)
                        self.edges.condoms[risk_pairing] = self.condom_data[(rgm, rgf)]['simvals'][self.ti]
                # Sex-work edges take their own value, overriding the risk-group value
                self.edges.condoms[self.edges.sw] = self.condom_data[('fsw','client')]['simvals'][self.ti]

            elif sc.isnumber(self.condom_data):
                self.edges.condoms[:] = self.condom_data

            else:
                raise Exception("Unknown condom data input type")

        return

    def count_partners(self):
        """
        Count the number of partners each person has had over the past 3/12 months

        Currently a placeholder that does nothing: the partner counters are
        updated in ``add_pairs`` and ``end_pairs``.
        """
        return

    def step(self):
        """ Advance one timestep: end pairs, initialize agents younger than one timestep, add pairs, set condom use """
        self.end_pairs()
        self.set_network_states(upper_age=self.t.dt)
        self.add_pairs()
        self.set_condom_use()
        self.count_partners()

        return
class FastStructuredSexual(StructuredSexual):
    """ Structured sexual network that matches pairs by sorting instead of solving an assignment problem """

    def __init__(self, **kwargs):
        super().__init__(name='structuredsexual', **kwargs)

    def match_pairs(self, ppl):
        """
        Match pairs by age, using sorting rather than the linear sum assignment

        Returns:
            (p1, p2): UIDs of the matched men and women respectively

        Raises:
            NoPartnersFound: if no women are looking or no men are eligible
        """

        # Work out who could enter a new relationship
        past_debut = self.over_debut()
        wants_more = self.partners < self.concurrency
        women_ok = past_debut & ppl.female & wants_more
        men_ok = past_debut & ppl.male & wants_more
        women_seeking = self.pars.p_pair_form.filter(women_ok.uids)  # ss.uids of women looking for partners

        if not len(women_seeking) or not men_ok.count():
            raise NoPartnersFound()

        # Sample each seeking woman's preferred age gap and hence preferred partner age
        gap_mean, gap_std = self.get_age_risk_pars(women_seeking, self.pars.age_diff_pars)
        self.pars.age_diffs.set(loc=gap_mean, scale=gap_std)
        sampled_gaps = self.pars.age_diffs.rvs(women_seeking)
        wanted_ages = ppl.age[women_seeking] + sampled_gaps

        # Rank men by age and women by preferred partner age, then pair rank by rank
        men_order = np.argsort(ppl.age[men_ok])
        women_order = np.argsort(wanted_ages)
        n_pairs = min(len(men_order), len(women_order))  # The longer list is truncated
        p1 = men_ok.uids[men_order[:n_pairs]]
        p2 = women_seeking[women_order[:n_pairs]]

        return p1, p2
class AgeMatchedMSM(StructuredSexual):
    """ Network of men who have sex with men, pairing men of adjacent ages """

    def __init__(self, **kwargs):
        super().__init__(name='msm', **kwargs)

    def match_pairs(self, ppl):
        """
        Match males by age using sorting

        Raises:
            NoPartnersFound: if no men are looking for a partner
            ValueError: if anyone ends up on both sides of the pairing
        """

        # Work out who could enter a new relationship
        past_debut = self.over_debut()
        wants_more = self.partners < self.concurrency
        men_ok = past_debut & ppl.male & wants_more
        men_seeking = self.pars.p_pair_form.filter(men_ok.uids)

        if not len(men_seeking):
            raise NoPartnersFound()

        # Sort the men looking for partners by age, then pair them off two at a
        # time down the sorted list; with an odd number the last man is left out
        by_age = men_seeking[np.argsort(ppl.age[men_seeking])]
        n_pairs = len(by_age) // 2
        p1 = by_age[0:2*n_pairs:2]
        p2 = by_age[1:2*n_pairs:2]

        # Make sure everyone only appears once (?)
        if np.intersect1d(p1, p2).size:
            errormsg = 'Some people appear in both p1 and p2'
            raise ValueError(errormsg)

        return p1, p2
class AgeApproxMSM(StructuredSexual):
    """ Network of men who have sex with men, pairing men by sampled age preference """

    def __init__(self, **kwargs):
        super().__init__(name='msm', **kwargs)

    def match_pairs(self, ppl):
        """
        Match males using sampled age preferences and sorting

        The men looking for partners are split into two groups. Men in the first
        group sample a desired partner age; they are then paired, rank by rank,
        with the men of the second group sorted by actual age.

        Returns:
            (p1, p2): UIDs from the second group (sorted by age) and from the
            first group (sorted by desired partner age)

        Raises:
            NoPartnersFound: if fewer than two men are looking, so no pair can form
        """

        # Find people eligible for a relationship
        active = self.over_debut()
        underpartnered = self.partners < self.concurrency
        m_eligible = active & ppl.male & underpartnered
        m_looking = self.pars.p_pair_form.filter(m_eligible.uids)

        if len(m_looking) < 2:
            raise NoPartnersFound()

        # Split the total number of males looking for partners into 2 groups
        # The first group will be matched with the second group
        group1 = m_looking[::2]
        group2 = m_looking[1::2]
        loc, scale = self.get_age_risk_pars(group1, self.pars.age_diff_pars)
        self.pars.age_diffs.set(loc=loc, scale=scale)
        age_gaps = self.pars.age_diffs.rvs(group1)
        desired_ages = ppl.age[group1] + age_gaps
        g2_ages = ppl.age[group2]

        # Each sort order must index the array it was computed from: ind_p1
        # ranks group2 by age, ind_p2 ranks group1 by desired partner age
        ind_p1 = np.argsort(g2_ages)
        ind_p2 = np.argsort(desired_ages)
        p1 = group2[ind_p1]
        p2 = group1[ind_p2]

        # group1 can be one longer than group2; drop the unmatched entry
        maxlen = min(len(p1), len(p2))
        p1 = p1[:maxlen]
        p2 = p2[:maxlen]

        return p1, p2