"""
====================
Disability Observers
====================
This module contains tools for observing years lived with disability (YLDs)
in the simulation.
"""
import pandas as pd
from loguru import logger
from pandas.api.types import CategoricalDtype
from vivarium.framework.engine import Builder
from vivarium_public_health.disease import DiseaseState, RiskAttributableDisease
from vivarium_public_health.results.columns import COLUMNS
from vivarium_public_health.results.observer import PublicHealthObserver
from vivarium_public_health.results.simple_cause import SimpleCause
from vivarium_public_health.utilities import to_years
class DisabilityObserver(PublicHealthObserver):
    """Count years lived with disability.

    By default, this counts both aggregate and cause-specific years lived
    with disability over the full course of the simulation.

    In the model specification, your configuration for this component should
    be specified as, e.g.:

    .. code-block:: yaml

        configuration:
            stratification:
                disability:
                    exclude:
                        - "sex"
                    include:
                        - "sample_stratification"

    Attributes
    ----------
    step_size
        The time step size of the simulation.
    causes_of_disability
        The causes of disability to be observed.
    """

    ##############
    # Properties #
    ##############

    @property
    def disability_classes(self) -> list[type]:
        """The classes to be considered as causes of disability."""
        return [DiseaseState, RiskAttributableDisease]

    #################
    # Setup methods #
    #################

    def setup(self, builder: Builder) -> None:
        """Set up the observer.

        Records the simulation step size and determines which causes of
        disability will be observed.

        Parameters
        ----------
        builder
            The simulation builder.
        """
        # The configured step size is interpreted as a number of days.
        self.step_size = pd.Timedelta(days=builder.configuration.time.step_size)
        self.set_causes_of_disability(builder)

    def set_causes_of_disability(self, builder: Builder) -> None:
        """Set the causes of disability to be observed.

        The causes to be observed are any registered components of class types
        found in the ``disability_classes`` property *excluding* any listed
        under ``disability`` in the model spec's ``excluded_categories``.

        Parameters
        ----------
        builder
            The simulation builder.

        Raises
        ------
        ValueError
            If an excluded cause is not one of the available causes of
            disability (which include ``all_causes``).

        Notes
        -----
        We implement exclusions here instead of during the stratification call
        like most other categories because disabilities are unique in that they
        are *not* actually registered stratifications.

        Also note that we add an 'all_causes' category here.
        """
        disability_components = builder.components.get_components_by_type(
            self.disability_classes
        )
        # Convert to SimpleCause instances and add on all_causes
        causes_of_disability = [
            SimpleCause.create_from_specific_cause(cause)
            for cause in disability_components
        ] + [SimpleCause("all_causes", "all_causes", "cause")]

        excluded_causes = (
            builder.configuration.stratification.excluded_categories.to_dict().get(
                "disability", []
            )
        )

        # Handle exclusions that don't exist in the list of causes
        cause_names = [cause.state_id for cause in causes_of_disability]
        unknown_exclusions = set(excluded_causes) - set(cause_names)
        if unknown_exclusions:
            raise ValueError(
                f"Excluded 'disability' causes {unknown_exclusions} not found in "
                f"expected categories: {cause_names}"
            )

        # Drop excluded causes
        if excluded_causes:
            logger.debug(
                f"'disability' has category exclusion requests: {excluded_causes}\n"
                "Removing these from the allowable categories."
            )
        excluded = set(excluded_causes)
        self.causes_of_disability = [
            cause for cause in causes_of_disability if cause.state_id not in excluded
        ]

    def register_observations(self, builder: Builder) -> None:
        """Register an observation for years lived with disability.

        The observation is named ``ylds``, only considers simulants where
        ``is_alive == True``, is recorded during ``time_step__prepare``, and
        requires the ``<state_id>.disability_weight`` attribute of every
        observed cause.

        Parameters
        ----------
        builder
            The simulation builder.
        """
        cause_pipelines = [
            f"{cause.state_id}.disability_weight" for cause in self.causes_of_disability
        ]
        self.register_adding_observation(
            builder=builder,
            name="ylds",
            pop_filter="is_alive == True",
            when="time_step__prepare",
            requires_attributes=cause_pipelines,
            additional_stratifications=self.configuration.include,
            excluded_stratifications=self.configuration.exclude,
            aggregator=self.disability_weight_aggregator,
        )

    ###############
    # Aggregators #
    ###############

    def disability_weight_aggregator(self, dw: pd.DataFrame) -> float | pd.Series:
        """Aggregate disability weights for the time step.

        Each weight is multiplied by the step size expressed in years and the
        result is summed down each column.

        Parameters
        ----------
        dw
            The disability weights to aggregate (presumably one column per
            requested ``<cause>.disability_weight`` attribute -- confirm
            against the results framework).

        Returns
        -------
            The aggregated disability weights: a scalar when ``dw`` has a
            single column (``squeeze`` collapses it), otherwise a Series whose
            index is named ``cause_of_disability``.
        """
        aggregated_dw = (dw * to_years(self.step_size)).sum().squeeze()
        if isinstance(aggregated_dw, pd.Series):
            # ``rename_axis`` returns a new object. Assigning ``index.name``
            # in place could also rename the input's columns, because the
            # summed Series' index may be the very same Index object.
            aggregated_dw = aggregated_dw.rename_axis("cause_of_disability")
        return aggregated_dw

    ##############################
    # Results formatting methods #
    ##############################

    def get_entity_type_column(self, measure: str, results: pd.DataFrame) -> pd.Series:
        """Get the 'entity_type' column values.

        Parameters
        ----------
        measure
            The measure being formatted (unused here).
        results
            The results, which must already contain the sub-entity column.

        Returns
        -------
            The ``cause_type`` of each row's cause as a categorical Series;
            sub-entities without a matching observed cause map to NaN.
        """
        entity_type_map = {
            cause.state_id: cause.cause_type for cause in self.causes_of_disability
        }
        return results[COLUMNS.SUB_ENTITY].map(entity_type_map).astype(CategoricalDtype())

    def get_entity_column(self, measure: str, results: pd.DataFrame) -> pd.Series:
        """Get the 'entity' column values.

        Parameters
        ----------
        measure
            The measure being formatted (unused here).
        results
            The results, which must already contain the sub-entity column.

        Returns
        -------
            The ``model`` of each row's cause as a categorical Series;
            sub-entities without a matching observed cause map to NaN.
        """
        entity_map = {cause.state_id: cause.model for cause in self.causes_of_disability}
        return results[COLUMNS.SUB_ENTITY].map(entity_map).astype(CategoricalDtype())

    def get_sub_entity_column(self, measure: str, results: pd.DataFrame) -> pd.Series:
        """Get the 'sub_entity' column values.

        Parameters
        ----------
        measure
            The measure being formatted (unused here).
        results
            The results, which must already contain the sub-entity column.

        Returns
        -------
            The existing sub-entity column, unchanged.
        """
        # The sub-entity col was created in the 'format' method
        return results[COLUMNS.SUB_ENTITY]