Observers
This tutorial serves two purposes: it gives minimal working examples of each public health observer, and it demonstrates three ways to configure stratification.
These observer classes are public-health-specific helpers built on top of the
vivarium framework’s
Observer base class (see the
vivarium results concepts
documentation for details on the underlying results system).
Common Setup
from vivarium import Component, InteractiveContext
from vivarium.framework.engine import Builder
from vivarium_public_health.disease import (
DiseaseModel, DiseaseState, SusceptibleState, SI, SIS,
)
from vivarium_public_health.population import BasePopulation
from vivarium_public_health.results import (
DiseaseObserver, MortalityObserver, DisabilityObserver,
CategoricalRiskObserver, ResultsStratifier,
)
from vivarium_public_health.risks import Risk
from vivarium_public_health._example_data import (
BASE_PLUGINS, make_base_config, build_cause_table,
disease_disability_weight,
)
# Shared by every example below: a base configuration object and the plugin
# configuration that is handed to InteractiveContext.
config = make_base_config()
base_plugins = BASE_PLUGINS
DiseaseObserver
A DiseaseObserver registers two
observations for a disease model:
- person_time_{disease} - person-years spent in each disease state, accumulated each time step. The sub_entity column contains the state name (e.g., "susceptible_to_test_cause", "test_cause").
- transition_count_{disease} - count of simulants transitioning between states each time step. The sub_entity column contains the transition name (e.g., "susceptible_to_test_cause_to_test_cause"). Only transitions that actually occur appear in the output.
# Start from a fresh base configuration and set the population size to 1000.
config = make_base_config()
config.update(
    {"population": {"population_size": 1000}},
    layer="model_override",
)

# Components for this run: a population, the SI("test_cause") disease model,
# the observer for that disease, and the stratifier.
components = [
    BasePopulation(),
    SI("test_cause"),
    DiseaseObserver("test_cause"),
    ResultsStratifier(),
]
sim = InteractiveContext(
    components=components,
    configuration=config,
    plugin_configuration=base_plugins,
)

# Take a single time step, then list the observation names that were collected.
sim.step()
results = sim.get_results()
observation_names = sorted(results.keys())
print(observation_names)
['person_time_test_cause', 'transition_count_test_cause']
# Look at the person-time table: its column layout and the sub_entity value
# carried by each row.
pt = results["person_time_test_cause"]
column_names = pt.columns.tolist()
sub_entities = pt["sub_entity"].tolist()
print(column_names)
print(sub_entities)
['measure', 'entity_type', 'entity', 'sub_entity', 'value']
['susceptible_to_test_cause', 'test_cause']
# Look at the first row of the transition-count table and check that at least
# one transition was counted.
tc = results["transition_count_test_cause"]
first_row = tc.iloc[0]
print(first_row["measure"])
print(first_row["sub_entity"])
assert first_row["value"] > 0
transition_count
susceptible_to_test_cause_to_test_cause
MortalityObserver
A MortalityObserver registers two
observations, stratified by cause of death:
- deaths - count of simulants who died during each time step. The entity column contains the cause name or "other_causes".
- ylls - sum of remaining life expectancy at death (years of life lost). Uses the same cause-level breakdown as deaths.
To produce non-zero values, the simulation needs a disease state with
non-zero excess_mortality_rate.
# Build a two-state disease model whose infected state has an
# excess_mortality_rate, so that deaths can be attributed to "test_cause".
healthy = SusceptibleState("test_cause")
infected = DiseaseState("test_cause", excess_mortality_rate=build_cause_table(5.0))
healthy.add_rate_transition(infected)
fatal_model = DiseaseModel("test_cause", states=[healthy, infected])

# Fresh configuration with a population of 1000 simulants.
config = make_base_config()
config.update({"population": {"population_size": 1000}}, layer="model_override")

sim = InteractiveContext(
    components=[
        BasePopulation(),
        fatal_model,
        MortalityObserver(),
        ResultsStratifier(),
    ],
    configuration=config,
    plugin_configuration=base_plugins,
)

# Take five time steps before reading results. The loop body must be
# indented for the snippet to be valid Python.
for _ in range(5):
    sim.step()

results = sim.get_results()
deaths = results["deaths"]
print(deaths.columns.tolist())
print(deaths["entity"].tolist())
['measure', 'entity_type', 'entity', 'sub_entity', 'value']
['test_cause', 'other_causes']
# Deaths attributed to test_cause must be non-zero.
is_test_cause_death = deaths["entity"] == "test_cause"
test_cause_deaths = deaths.loc[is_test_cause_death, "value"].iloc[0]
assert test_cause_deaths > 0

# The same check for the ylls observation.
ylls = results["ylls"]
is_test_cause_yll = ylls["entity"] == "test_cause"
test_cause_ylls = ylls.loc[is_test_cause_yll, "value"].iloc[0]
assert test_cause_ylls > 0
DisabilityObserver
A DisabilityObserver registers one
observation:
- ylds - years lived with disability, computed as each simulant’s disability weight multiplied by the time step duration, summed across simulants. Results are broken out by cause in the entity column, plus an "all_causes" total row.
It requires at least one disease state with a non-zero
disability_weight.
# Build a two-state disease model whose infected state has a
# disability_weight, which the DisabilityObserver needs to report ylds.
healthy = SusceptibleState("test_cause")
infected = DiseaseState("test_cause", disability_weight=disease_disability_weight(0.3))
healthy.add_rate_transition(infected)
disability_model = DiseaseModel("test_cause", states=[healthy, infected])

# Fresh configuration with a population of 1000 simulants.
config = make_base_config()
config.update({"population": {"population_size": 1000}}, layer="model_override")

sim = InteractiveContext(
    components=[
        BasePopulation(),
        disability_model,
        DisabilityObserver(),
        ResultsStratifier(),
    ],
    configuration=config,
    plugin_configuration=base_plugins,
)

# Take three time steps before reading results. The loop body must be
# indented for the snippet to be valid Python.
for _ in range(3):
    sim.step()

results = sim.get_results()
ylds = results["ylds"]
print(ylds.columns.tolist())
print(sorted(ylds["entity"].unique().tolist()))
['measure', 'entity_type', 'entity', 'sub_entity', 'stratification', 'value']
['all_causes', 'test_cause']
# The ylds attributed to test_cause must be non-zero.
is_test_cause = ylds["entity"] == "test_cause"
test_cause_ylds = ylds.loc[is_test_cause, "value"].iloc[0]
assert test_cause_ylds > 0
CategoricalRiskObserver
A CategoricalRiskObserver registers
one observation:
- person_time_{risk} - person-years spent in each exposure category, accumulated each time step. The sub_entity column contains the category name (e.g., "exposed", "unexposed"). The entity_type is "rei" (risk/etiology/impairment).
# Fresh configuration with a population of 1000 simulants.
config = make_base_config()
config.update({"population": {"population_size": 1000}}, layer="model_override")

# The observer is given the bare risk name ("test_risk"), while the Risk
# component is given the dotted name "risk_factor.test_risk".
sim = InteractiveContext(
    components=[
        BasePopulation(),
        Risk("risk_factor.test_risk"),
        CategoricalRiskObserver("test_risk"),
        ResultsStratifier(),
    ],
    configuration=config,
    plugin_configuration=base_plugins,
)

# Take three time steps before reading results. The loop body must be
# indented for the snippet to be valid Python.
for _ in range(3):
    sim.step()

results = sim.get_results()
pt = results["person_time_test_risk"]
print(pt.columns.tolist())
print(sorted(pt["sub_entity"].tolist()))
['measure', 'entity_type', 'entity', 'sub_entity', 'value']
['exposed', 'unexposed']
# Every exposure category must have accumulated some person-time.
has_person_time = pt["value"] > 0
assert all(has_person_time)

# Compare the person-time of the two categories.
is_exposed = pt["sub_entity"] == "exposed"
is_unexposed = pt["sub_entity"] == "unexposed"
exposed_pt = pt.loc[is_exposed, "value"].iloc[0]
unexposed_pt = pt.loc[is_unexposed, "value"].iloc[0]
assert exposed_pt > unexposed_pt
Stratification
A stratification splits observer output into sub-groups based on simulant attributes (e.g. age group, sex, or custom categories). Each stratification adds a column to the results table whose values identify which group each row belongs to. You can include, exclude, or define custom stratifications per observer.
Including a stratification
Add a registered stratification to one observer via
stratification.<observer_name>.include. Here we include sex, one of
the four stratifications registered by ResultsStratifier (age_group,
current_year, event_year, and sex):
# Configuration overrides: 1000 simulants, and "sex" added to the
# stratifications of the "test_cause" observer.
overrides = {
    "population": {"population_size": 1000},
    "stratification": {
        "test_cause": {
            "include": ["sex"],
            "exclude": [],
        },
    },
}
config = make_base_config()
config.update(overrides, layer="model_override")

components = [
    BasePopulation(),
    SI("test_cause"),
    DiseaseObserver("test_cause"),
    ResultsStratifier(),
]
sim = InteractiveContext(
    components=components,
    configuration=config,
    plugin_configuration=base_plugins,
)

# Take one time step, then show the columns, the sex categories present,
# and the number of rows in the person-time table.
sim.step()
pt = sim.get_results()["person_time_test_cause"]
print(pt.columns.tolist())
sex_categories = sorted(pt["sex"].unique().tolist())
print(sex_categories)
print(len(pt))
['measure', 'entity_type', 'entity', 'sub_entity', 'sex', 'value']
['Female', 'Male']
4
Excluding a default stratification
Set global defaults with stratification.default, then exclude specific
ones per observer with stratification.<observer_name>.exclude:
# Configuration overrides: 1000 simulants, default stratifications of
# age_group and sex, with age_group excluded for the "test_cause" observer.
overrides = {
    "population": {"population_size": 1000},
    "stratification": {
        "default": ["age_group", "sex"],
        "test_cause": {
            "include": [],
            "exclude": ["age_group"],
        },
    },
}
config = make_base_config()
config.update(overrides, layer="model_override")

components = [
    BasePopulation(),
    SI("test_cause"),
    DiseaseObserver("test_cause"),
    ResultsStratifier(),
]
sim = InteractiveContext(
    components=components,
    configuration=config,
    plugin_configuration=base_plugins,
)

sim.step()
pt = sim.get_results()["person_time_test_cause"]
# age_group excluded - only sex remains from defaults
print(pt.columns.tolist())
print(len(pt))
['measure', 'entity_type', 'entity', 'sub_entity', 'sex', 'value']
4
Including a custom stratification
Register a custom stratification from any component, then reference it by
name in the observer’s include list:
import pandas as pd
class AgeCohortStratifier(Component):
    """Register a binary young/old stratification.

    The stratification is named ``age_cohort``, has the categories
    ``"young"`` and ``"old"``, and is derived from the ``age`` attribute.
    """

    def setup(self, builder: Builder) -> None:
        """Register the ``age_cohort`` stratification with the results system.

        Parameters
        ----------
        builder
            The builder passed in by the framework; only
            ``builder.results.register_stratification`` is used here.
        """
        builder.results.register_stratification(
            "age_cohort",
            ["young", "old"],
            mapper=self.map_age_cohort,
            # The mapper receives the whole population frame at once.
            is_vectorized=True,
            requires_attributes=["age"],
        )

    @staticmethod
    def map_age_cohort(pop: pd.DataFrame) -> pd.Series:
        """Label each row ``"young"`` if its age is below 50, else ``"old"``.

        Parameters
        ----------
        pop
            Frame that is squeezed along the column axis, so it is expected to
            hold a single column (the requested ``age`` attribute).

        Returns
        -------
        pd.Series
            One label per row, indexed like ``pop``.
        """
        age = pop.squeeze(axis=1)
        # One vectorized comparison instead of a per-element lambda; anything
        # that is not strictly below 50 maps to "old".
        return age.lt(50).map({True: "young", False: "old"})
# Configuration overrides: 1000 simulants, and the custom "age_cohort"
# stratification added to the "test_cause" observer.
overrides = {
    "population": {"population_size": 1000},
    "stratification": {
        "test_cause": {
            "include": ["age_cohort"],
            "exclude": [],
        },
    },
}
config = make_base_config()
config.update(overrides, layer="model_override")

# AgeCohortStratifier is added as a component so that "age_cohort" is
# registered alongside the stratifications from ResultsStratifier.
components = [
    BasePopulation(),
    SI("test_cause"),
    DiseaseObserver("test_cause"),
    AgeCohortStratifier(),
    ResultsStratifier(),
]
sim = InteractiveContext(
    components=components,
    configuration=config,
    plugin_configuration=base_plugins,
)

sim.step()
pt = sim.get_results()["person_time_test_cause"]
print(pt.columns.tolist())
cohort_labels = sorted(pt["age_cohort"].unique().tolist())
print(cohort_labels)
['measure', 'entity_type', 'entity', 'sub_entity', 'age_cohort', 'value']
['old', 'young']
# Both cohorts must have accumulated some person-time.
is_young = pt["age_cohort"] == "young"
is_old = pt["age_cohort"] == "old"
young_total = pt.loc[is_young, "value"].sum()
old_total = pt.loc[is_old, "value"].sum()
assert young_total > 0
assert old_total > 0