Source code for emod_api.serialization.SerializedPopulation

"""Class to load and manipulate a saved population."""
import collections
import collections.abc
import copy
import difflib
import json

import emod_api.serialization.dtkFileTools as dft


class SerializedPopulation:
    """Opens the passed file and reads in all the nodes.

    Args:
        file: serialized population file

    Examples:
        Create an instance of SerializedPopulation::

            import emod_api.serialization.SerializedPopulation as SerPop
            ser_pop = SerPop.SerializedPopulation('state-00001.dtk')
    """

    def __init__(self, file: str):
        # State of the infection suid counter handed out by
        # get_next_infection_suid(); populated lazily on its first call.
        self.next_infection_suid = None
        self.next_infection_suid_initialized = False
        # Load the serialized population, then keep a working list of its
        # nodes; flush() copies that list back into self.dtk.
        self.dtk = dft.read(file)
        self._nodes = list(self.dtk.nodes)

    @property
    def nodes(self):
        """All nodes.

        Examples:
            Delete number_of_ind individuals from node 0::

                node = ser_pop.nodes[0]
                del node.individualHumans[0:number_of_ind]

            Only keep individuals with a certain condition::

                node.individualHumans = [ind for ind in node.individualHumans if keep_fct(ind)]

            Change susceptibility of an individual::

                print(node.individualHumans[0].susceptibility)
                new_susceptibility = {"age": 101.01, "mod_acquire": 0}
                node.individualHumans[0].susceptibility.update(new_susceptibility)

            Copy individual[0] from node 0, change properties and add
            individual as new individual::

                import copy
                individual_properties={"m_age": 1234}
                individual = copy.deepcopy(node.individualHumans[0])
                individual["suid"] = ser_pop.get_next_individual_suid(0)
                individual.update(individual_properties)
                ser_pop.nodes[0].individualHumans.append(individual)

            Infect an individual with an infection copied from another
            individual::

                infection = node["individualHumans"][0]["infections"][0]
                infection["suid"] = ser_pop.get_next_infection_suid()
                node["individualHumans"][1]["infections"].append(infection)
                node["individualHumans"][1].m_is_infected = True
        """
        return self._nodes

    def flush(self):
        """Save all made changes to the node(s).

        Copies every entry of the working node list back into ``self.dtk.nodes``
        at the same index.
        """
        for idx, node in enumerate(self._nodes):
            self.dtk.nodes[idx] = node

    def write(self, output_file: str = "my_sp_file.dtk"):
        """Write the population to a file.

        Flushes the nodes, stores the next infection suid in the simulation
        object, sets LZ4 compression and writes the result.

        Args:
            output_file: output file
        """
        self.flush()
        sim = self.dtk.simulation
        # Persist the infection suid counter so ids handed out during this
        # session are not reused.
        sim["infectionSuidGenerator"]["next_suid"] = self.get_next_infection_suid()
        self.dtk.simulation = sim
        self.dtk.compression = dft.LZ4
        print(f"Saving file {output_file}.")
        dft.write(self.dtk, output_file)

    def get_next_infection_suid(self):
        """Each infection needs a unique identifier, this function returns one.

        The first call returns the ``next_suid`` stored in the simulation's
        ``infectionSuidGenerator`` unchanged; every later call first advances
        its ``"id"`` by the generator's ``numtasks``.

        Returns:
            A new ``dict`` copy of the current suid.
        """
        sim = self.dtk.simulation
        if not self.next_infection_suid_initialized:
            # NOTE(review): this keeps a reference to the object returned by
            # the simulation, not a copy - confirm that is intended.
            self.next_infection_suid = sim["infectionSuidGenerator"]["next_suid"]
            self.next_infection_suid_initialized = True
        else:
            self.next_infection_suid["id"] = (
                self.next_infection_suid["id"]
                + sim["infectionSuidGenerator"]["numtasks"]
            )
        return dict(self.next_infection_suid)

    def get_next_individual_suid(self, node_id: int) -> dict:
        """Each individual needs a unique identifier, this function returns one.

        Returns the node's current ``next_suid`` and then advances it by the
        generator's ``numtasks``, so consecutive calls yield distinct ids.

        Args:
            node_id: index of the node in ``self.nodes`` whose
                ``m_IndividualHumanSuidGenerator`` is used.

        Returns:
            A new ``dict`` copy of the suid that was issued.

        Examples:
            To get a unique id for an individual::

                print(sp.get_next_individual_suid(0))
                {'id': 2}
        """
        generator = self._nodes[node_id]["m_IndividualHumanSuidGenerator"]
        next_suid = generator["next_suid"]
        # Copy before advancing: the caller gets the id that was free, the
        # generator keeps the one after it.
        issued = dict(next_suid)
        next_suid["id"] = next_suid["id"] + generator["numtasks"]
        return issued
### Some useful functions ###
def find(name: str, handle, currentlevel="dtk.nodes"):
    """Recursively search for parameters that match or are close to ``name``.

    Every string reached while walking ``handle`` (dict keys, list items and
    string values) is compared to ``name`` with
    ``difflib.get_close_matches(..., cutoff=0.6)``. Each hit is printed with a
    running number and the path at which it was found. The running number is
    kept in the module-level ``COUNTER`` and therefore continues across calls.

    Args:
        name: the parameter you are looking for e.g. "age", "gender".
        handle: some iterable data structure, can be a list of nodes, a node,
            list of individuals, etc
        currentlevel: just a string to print out where the found item is
            located e.g. "dtk.nodes" or "dtk.node.individuals"

    Returns:
        None. Results are printed.

    Examples:
        What is the exact parameter name used for the age of an individual?::

            SerPop.find("age", node)
            ...
            1998 Found in: dtk.nodes.individualHumans[999].m_age
            1999 Found in: dtk.nodes.individualHumans[999].susceptibility.age
            2000 Found in: dtk.nodes.m_vectorpopulations[0].EggQueues[0].age
            2001 Found in: dtk.nodes.m_vectorpopulations[0].EggQueues[1].age
            ...
    """
    global COUNTER
    # Create the hit counter on first use so the search does not fail with a
    # NameError when the module never initialised it.
    globals().setdefault("COUNTER", 0)

    if isinstance(handle, str):
        if difflib.get_close_matches(name, [handle], cutoff=0.6):
            print(COUNTER, " Found in: ", currentlevel)
            COUNTER += 1
        return
    if not isinstance(handle, collections.abc.Iterable):
        return

    # key can be a dict key (usually a string) or an element of a list/...
    for idx, key in enumerate(handle):
        if isinstance(key, str):
            level = currentlevel + "." + key
        else:
            level = currentlevel + "[" + str(idx) + "]"

        try:
            value = handle[key]
        except Exception:
            # Plain sequences cannot be indexed by their own elements; the
            # container type is arbitrary, so stay best-effort here while
            # letting KeyboardInterrupt/SystemExit propagate.
            key_level = level
        else:
            # Mark keys whose value is itself a container.
            if isinstance(value, collections.abc.Iterable):
                key_level = level + "[]"
            else:
                key_level = level

        # Check the key / element itself (recurses when it is a container).
        find(name, key, key_level)
        if isinstance(handle, dict):
            # Descend into the value stored under this key.
            find(name, handle[key], level)
def get_parameters(handle, currentlevel="dtk.nodes"):
    """Return a set of all parameters in the serialized population file.

    Helpful to get an overview about what is in the serialized population
    file. Each entry of the result is a path string: ``currentlevel`` followed
    by the space-separated string keys/items that lead to a parameter.

    Args:
        handle: some iterable data structure, can be a list of nodes, a node,
            list of individuals, etc
        currentlevel: just a string to print out where the found item is
            located e.g. "dtk.nodes" or "dtk.node.individuals"

    Returns:
        A ``set`` of path strings; empty when ``handle`` is not iterable.

    Examples:
        Print all parameters in serialized population file::

            for n in sorted(SerPop.get_parameters(node)):
                print(n)
    """
    param = set()
    if isinstance(handle, str):
        # A string is a leaf: the path built so far names the parameter.
        param.add(currentlevel)
        return param
    if not isinstance(handle, collections.abc.Iterable):
        return param

    for item in handle:
        # Only string keys/items extend the path; other elements (e.g. the
        # dicts inside a list) are reported under the current level.
        if isinstance(item, str):
            level = currentlevel + " " + item
        else:
            level = currentlevel
        param.update(get_parameters(item, level))
        if isinstance(handle, dict):
            # Also descend into the value stored under this key.
            param.update(get_parameters(handle[item], level))
    return param