"""
==================
Results Stratifier
==================
This module contains tools for stratifying observed quantities
by specified characteristics through the vivarium results interface.
"""
import pandas as pd
from vivarium import Component
from vivarium.framework.engine import Builder
[docs]
class ResultsStratifier(Component):
    """A component for registering common public health stratifications.

    The purpose of this component is to encapsulate all common public health
    stratification registrations in one place. This is not enforced, however,
    and stratification registrations can be done in any component.

    Attributes
    ----------
    age_bins
        The age bins for stratifying by age.
    start_year
        The start year of the simulation.
    end_year
        The end year of the simulation.

    """

    #####################
    # Lifecycle methods #
    #####################

    # noinspection PyAttributeOutsideInit
    def setup(self, builder: Builder) -> None:
        """Set up the stratifier.

        Define age bins and simulation years and register default stratifications.

        Parameters
        ----------
        builder
            The builder object for the simulation.
        """
        self.age_bins = self.get_age_bins(builder)
        self.start_year = builder.configuration.time.start.year
        self.end_year = builder.configuration.time.end.year

        # Registration reads the attributes assigned above, so it must run last.
        self.register_stratifications(builder)

    #################
    # Setup methods #
    #################

    def register_stratifications(self, builder: Builder) -> None:
        """Register stratifications for the simulation.

        Registers the ``age_group``, ``current_year``, ``event_year`` and
        ``sex`` stratifications through ``builder.results``.

        Parameters
        ----------
        builder
            The builder object for the simulation.
        """
        builder.results.register_stratification(
            "age_group",
            self.age_bins["age_group_name"].to_list(),
            mapper=self.map_age_groups,
            is_vectorized=True,
            requires_attributes=["age"],
        )

        builder.results.register_stratification(
            "current_year",
            [str(year) for year in range(self.start_year, self.end_year + 1)],
            mapper=self.map_year,
            is_vectorized=True,
            requires_attributes=["current_time"],
        )

        # The year after ``end_year`` is declared as a category but excluded.
        # NOTE(review): presumably event times can fall just past the end of
        # the simulation -- confirm against the results framework.
        builder.results.register_stratification(
            "event_year",
            [str(year) for year in range(self.start_year, self.end_year + 2)],
            excluded_categories=[str(self.end_year + 1)],
            mapper=self.map_year,
            is_vectorized=True,
            requires_attributes=["event_time"],
        )

        # TODO [MIC-3892]: simulants occasionally have entrance year of start_year-1 if the start time minus step size
        #  lands in the previous year. possible solution detailed in ticket
        # builder.results.register_stratification(
        #     "entrance_year",
        #     [str(year) for year in range(self.start_year, self.end_year + 1)],
        #     self.map_year,
        #     is_vectorized=True,
        #     requires_attributes=["entrance_time"],
        # )

        # TODO [MIC-4083]: Known bug with this registration
        # builder.results.register_stratification(
        #     "exit_year",
        #     [str(year) for year in range(self.start_year, self.end_year + 1)] + ["nan"],
        #     mapper=self.map_year,
        #     is_vectorized=True,
        #     requires_attributes=["exit_time"],
        # )

        builder.results.register_stratification(
            "sex", ["Female", "Male"], requires_attributes=["sex"]
        )

    ###########
    # Mappers #
    ###########

    def map_age_groups(self, pop: pd.DataFrame) -> pd.Series:
        """Map age with age group name strings.

        Bin edges are every ``age_start`` in ``self.age_bins`` followed by the
        final ``age_end``. Bins are closed on the left and open on the right,
        i.e. ``[age_start, age_end)``, so an age exactly equal to a bin's
        ``age_start`` is assigned to that bin. Ages outside the overall range
        map to NaN.

        Parameters
        ----------
        pop
            A table with one column, an age to be mapped to an age group name string.

        Returns
        -------
            The age group name strings corresponding to the pop passed into the function.
        """
        bin_edges = self.age_bins["age_start"].to_list() + [
            self.age_bins["age_end"].iloc[-1]
        ]
        labels = self.age_bins["age_group_name"].to_list()
        ages = pop.squeeze(axis=1)

        # ``right=False`` is required: ``pd.cut`` defaults to right-closed
        # intervals, which would push an age sitting exactly on an
        # ``age_start`` edge into the previous group (or to NaN for the
        # lowest edge).
        age_group = pd.cut(ages, bin_edges, labels=labels, right=False)
        return age_group.rename("age_group")

    @staticmethod
    def map_year(pop: pd.DataFrame) -> pd.Series:
        """Map datetime with year.

        Parameters
        ----------
        pop
            A table with one column, a datetime to be mapped to year.

        Returns
        -------
            The years corresponding to the pop passed into the function,
            converted to strings.
        """
        timestamps = pop.squeeze(axis=1)
        return timestamps.dt.year.apply(str)

    @staticmethod
    def get_age_bins(builder: Builder) -> pd.DataFrame:
        """Get the age bins for stratifying by age.

        Loads ``population.age_bins`` and keeps only the bins that overlap the
        range between the configured ``initialization_age_min`` and
        ``untracking_age``. Age group names are normalized to lower case with
        spaces replaced by underscores.

        Parameters
        ----------
        builder
            The builder object for the simulation.

        Returns
        -------
            The age bins for stratifying by age.
        """
        raw_age_bins = builder.data.load("population.age_bins")
        population_config = builder.configuration.population
        age_start = population_config.initialization_age_min
        exit_age = population_config.untracking_age

        # Keep bins that end after the youngest initialized age.
        keep = age_start < raw_age_bins["age_end"]
        # A falsy ``untracking_age`` (e.g. unset) means no upper limit.
        if exit_age:
            keep = keep & (raw_age_bins["age_start"] < exit_age)

        # Copy so the name normalization below never writes into the loaded data.
        age_bins = raw_age_bins.loc[keep, :].copy()
        age_bins["age_group_name"] = (
            age_bins["age_group_name"].str.replace(" ", "_").str.lower()
        )
        return age_bins