"""
=================
Disease Observers
=================
This module contains tools for observing disease incidence and prevalence
in the simulation.
"""
import pandas as pd
from vivarium.framework.engine import Builder
from vivarium.framework.event import Event
from vivarium.framework.population import SimulantData
from vivarium_public_health.results.columns import COLUMNS
from vivarium_public_health.results.observer import PublicHealthObserver
from vivarium_public_health.utilities import to_years
class DiseaseObserver(PublicHealthObserver):
    """Observes disease counts and person time for a cause.

    By default, this observer computes aggregate disease state person time and
    counts of disease events over the full course of the simulation. It can be
    configured to add or remove stratification groups to the default groups
    defined by a ResultsStratifier.

    In the model specification, your configuration for this component should
    be specified as, e.g.:

    .. code-block:: yaml

        configuration:
            stratification:
                cause_name:
                    exclude:
                        - "sex"
                    include:
                        - "sample_stratification"

    Attributes
    ----------
    disease
        The name of the disease being observed.
    previous_state_column_name
        The name of the column that stores the previous state of the disease.
    step_size
        A callable that returns the time step size of the simulation.
    disease_model
        The disease model for the disease being observed.
    entity_type
        The type of entity being observed.
    entity
        The entity being observed.
    transition_stratification_name
        The stratification name for transitions between disease states.

    """

    #####################
    # Lifecycle methods #
    #####################

    def __init__(self, disease: str) -> None:
        """Constructor for this observer.

        Parameters
        ----------
        disease
            The name of the disease being observed.
        """
        super().__init__()
        self.disease = disease
        self.previous_state_column_name = f"previous_{self.disease}"

    #################
    # Setup methods #
    #################

    def setup(self, builder: Builder) -> None:
        """Set up the observer.

        Looks up the disease model component for ``self.disease``, caches the
        attributes derived from it, and registers the initializer for the
        previous-state column.

        Parameters
        ----------
        builder
            The simulation builder.
        """
        # Stored un-called: ``aggregate_state_person_time`` invokes it on
        # every aggregation to obtain the current step size.
        self.step_size = builder.time.step_size()
        self.disease_model = builder.components.get_component(
            f"disease_model.{self.disease}"
        )
        self.entity_type = self.disease_model.cause_type
        self.entity = self.disease_model.cause
        self.transition_stratification_name = f"transition_{self.disease}"

        builder.population.register_initializer(
            initializer=self.initialize_previous_state,
            columns=self.previous_state_column_name,
            required_resources=[self.disease],
        )

    def get_configuration_name(self) -> str:
        """Return the name used to look up this observer's configuration.

        Returns
        -------
            The name of the disease being observed.
        """
        return self.disease

    def register_observations(self, builder: Builder) -> None:
        """Register stratifications and observations.

        Parameters
        ----------
        builder
            The simulation builder.

        Notes
        -----
        Ideally, each observer registers a single observation. This one, however,
        registers two.

        While it's typical for all stratification registrations to be encapsulated
        in a single class (i.e. the
        :class:`ResultsStratifier <vivarium_public_health.results.stratification.ResultsStratifier>`),
        this observer registers two additional stratifications. While they could
        be registered in the ``ResultsStratifier`` as well, they are specific to
        this observer and so they are registered here while we have easy access
        to the required names and categories.
        """
        self.register_disease_state_stratification(builder)
        self.register_transition_stratification(builder)

        # Both observations use the same population filter.
        pop_filter = "is_alive == True"
        self.register_person_time_observation(builder, pop_filter)
        self.register_transition_count_observation(builder, pop_filter)

    def register_disease_state_stratification(self, builder: Builder) -> None:
        """Register the disease state stratification.

        The categories are the ``state_id`` of every state in the disease model.

        Parameters
        ----------
        builder
            The simulation builder.
        """
        builder.results.register_stratification(
            self.disease,
            [state.state_id for state in self.disease_model.states],
            requires_attributes=[self.disease],
        )

    def register_transition_stratification(self, builder: Builder) -> None:
        """Register the transition stratification.

        This stratification is used to track transitions between disease states.
        It appends 'no_transition' to the list of transition categories and also
        includes it as an excluded category.

        Parameters
        ----------
        builder
            The simulation builder.

        Notes
        -----
        It is important to include 'no_transition' in both the list of transition
        categories as well as the list of excluded categories. This is because
        it must exist as a category for the transition mapping to work correctly,
        but then we don't want to include it later during the actual stratification
        process.
        """
        transitions = [
            str(transition) for transition in self.disease_model.transition_names
        ] + ["no_transition"]

        # Manually append 'no_transition' to whatever categories the
        # configuration already excludes for this stratification.
        excluded_categories = (
            builder.configuration.stratification.excluded_categories.to_dict().get(
                self.transition_stratification_name, []
            )
        ) + ["no_transition"]

        builder.results.register_stratification(
            self.transition_stratification_name,
            categories=transitions,
            excluded_categories=excluded_categories,
            mapper=self.map_transitions,
            requires_attributes=[self.disease, self.previous_state_column_name],
            is_vectorized=True,
        )

    def register_person_time_observation(self, builder: Builder, pop_filter: str) -> None:
        """Register a person time observation.

        The observation is named ``person_time_<disease>``, is additionally
        stratified by the disease state, and is aggregated with
        :meth:`aggregate_state_person_time`.

        Parameters
        ----------
        builder
            The simulation builder.
        pop_filter
            The population filter to apply to the observation.
        """
        self.register_adding_observation(
            builder=builder,
            name=f"person_time_{self.disease}",
            pop_filter=pop_filter,
            when="time_step__prepare",
            additional_stratifications=self.configuration.include + [self.disease],
            excluded_stratifications=self.configuration.exclude,
            aggregator=self.aggregate_state_person_time,
        )

    def register_transition_count_observation(
        self, builder: Builder, pop_filter: str
    ) -> None:
        """Register a transition count observation.

        The observation is named ``transition_count_<disease>`` and is
        additionally stratified by the transition stratification.

        Parameters
        ----------
        builder
            The simulation builder.
        pop_filter
            The population filter to apply to the observation.
        """
        self.register_adding_observation(
            builder=builder,
            name=f"transition_count_{self.disease}",
            pop_filter=pop_filter,
            additional_stratifications=self.configuration.include
            + [self.transition_stratification_name],
            excluded_stratifications=self.configuration.exclude,
        )

    def map_transitions(self, df: pd.DataFrame) -> pd.Series:
        """Map previous and current disease states to transition string.

        Rows whose previous and current states are equal are labelled
        ``"no_transition"``; all other rows are labelled
        ``"<previous>_to_<current>"``.

        Parameters
        ----------
        df
            The DataFrame containing the disease states.

        Returns
        -------
            The transitions between disease states.
        """
        transitions = pd.Series(index=df.index, dtype=str)
        transition_mask = df[self.previous_state_column_name] != df[self.disease]
        transitions[~transition_mask] = "no_transition"
        # The right-hand side is built for every row; the masked assignment
        # only keeps the rows where a transition actually occurred.
        transitions[transition_mask] = (
            df[self.previous_state_column_name].astype(str)
            + "_to_"
            + df[self.disease].astype(str)
        )
        return transitions

    ########################
    # Event-driven methods #
    ########################

    def initialize_previous_state(self, pop_data: SimulantData) -> None:
        """Initialize the previous state column to the current state.

        Parameters
        ----------
        pop_data
            Data about the simulants being initialized; only its index is used.
        """
        previous_states = self.population_view.get(pop_data.index, self.disease)
        previous_states.name = self.previous_state_column_name
        self.population_view.initialize(previous_states)

    def on_time_step_prepare(self, event: Event) -> None:
        """Update the previous state column to the current state.

        This enables tracking of transitions between states.

        Parameters
        ----------
        event
            The event that triggered this method; only its index is used.
        """
        current_states = self.population_view.get(event.index, self.disease)
        self.population_view.update(
            self.previous_state_column_name,
            lambda _: current_states.rename(self.previous_state_column_name),
        )

    ###############
    # Aggregators #
    ###############

    def aggregate_state_person_time(self, x: pd.DataFrame) -> float:
        """Aggregate person time for the time step.

        Parameters
        ----------
        x
            The DataFrame containing the population.

        Returns
        -------
            The aggregated person time: the number of rows in ``x`` multiplied
            by the current step size converted with ``to_years``.
        """
        return len(x) * to_years(self.step_size())

    ##############################
    # Results formatting methods #
    ##############################

    def get_measure_column(self, measure: str, results: pd.DataFrame) -> pd.Series:
        """Get the 'measure' column values.

        Parameters
        ----------
        measure
            The name of the registered observation.
        results
            The results DataFrame whose index the returned Series shares.

        Returns
        -------
            A Series filled with ``"person_time"`` or ``"transition_count"``.

        Raises
        ------
        ValueError
            If ``measure`` contains neither ``"person_time_"`` nor
            ``"transition_count_"``.
        """
        # 'person_time_' is tested first so that it keeps precedence when a
        # measure name happens to contain both substrings.
        if "person_time_" in measure:
            measure_name = "person_time"
        elif "transition_count_" in measure:
            measure_name = "transition_count"
        else:
            raise ValueError(
                f"Unrecognized measure '{measure}' for {type(self).__name__}; expected "
                "a name containing 'person_time_' or 'transition_count_'."
            )
        return pd.Series(measure_name, index=results.index)

    def get_entity_type_column(self, measure: str, results: pd.DataFrame) -> pd.Series:
        """Get the 'entity_type' column values."""
        return pd.Series(self.entity_type, index=results.index)

    def get_entity_column(self, measure: str, results: pd.DataFrame) -> pd.Series:
        """Get the 'entity' column values."""
        return pd.Series(self.entity, index=results.index)

    def get_sub_entity_column(self, measure: str, results: pd.DataFrame) -> pd.Series:
        """Get the 'sub_entity' column values."""
        # The sub-entity col was created in the 'format' method
        return results[COLUMNS.SUB_ENTITY]