# Source code for vivarium_public_health.results.mortality

"""
===================
Mortality Observers
===================

This module contains tools for observing cause-specific and
excess mortality in the simulation, including "other causes".

"""

from typing import Any

import pandas as pd
from pandas.api.types import CategoricalDtype
from vivarium.framework.engine import Builder

from vivarium_public_health.disease import DiseaseState, RiskAttributableDisease
from vivarium_public_health.disease.state import ExcessMortalityState
from vivarium_public_health.results.columns import COLUMNS
from vivarium_public_health.results.observer import PublicHealthObserver
from vivarium_public_health.results.simple_cause import SimpleCause


class MortalityObserver(PublicHealthObserver):
    """Observe cause-specific deaths and YLLs (including "other causes").

    By default, this counts cause-specific deaths and years of life lost over
    the full course of the simulation. It can be configured to add or remove
    stratification groups to the default groups defined by a
    :class:`~vivarium_public_health.results.stratification.ResultsStratifier`.
    The ``aggregate`` configuration key can be set to True to aggregate all
    deaths and YLLs into a single observation and remove the stratification by
    cause of death to improve runtime.

    In the model specification, your configuration for this component should
    be specified as, e.g.:

    .. code-block:: yaml

        configuration:
            stratification:
                mortality:
                    exclude:
                        - "sex"
                    include:
                        - "sample_stratification"

    This observer needs to access ``has_excess_mortality`` on the causes we're
    observing, but this gets defined in the setup of the cause models. As a
    result, the model specification should list this observer after causes.

    Attributes
    ----------
    clock
        The simulation clock.
    causes_of_death
        Causes of death to be observed.

    """

    ##############
    # Properties #
    ##############

    @property
    def configuration_defaults(self) -> dict[str, Any]:
        """Default configuration values for this observer.

        Extends the base PublicHealthObserver configuration with
        mortality-specific settings.

        Configuration structure::

            stratification:
                {observer_name}:
                    exclude: list[str]
                        Stratification groups to exclude from results.
                        Inherited from base observer.
                    include: list[str]
                        Additional stratification groups to include.
                        Inherited from base observer.
                    aggregate: bool
                        If True, aggregates all deaths and YLLs into a
                        single observation without cause-specific
                        breakdown. Default is False (cause-specific
                        results).
        """
        config_defaults = super().configuration_defaults
        config_defaults["stratification"][self.get_configuration_name()]["aggregate"] = False
        return config_defaults

    #################
    # Setup methods #
    #################

    def setup(self, builder: Builder) -> None:
        """Set up the observer.

        Stores the simulation clock and collects the causes of death to
        observe.
        """
        self.clock = builder.time.clock()
        self.set_causes_of_death(builder)

    def set_causes_of_death(self, builder: Builder) -> None:
        """Set the causes of death to be observed.

        The causes to be observed are all registered components of type
        ``ExcessMortalityState`` for which ``has_excess_mortality(builder)``
        is truthy.

        Notes
        -----
        We do not actually exclude any categories in this method.

        Also note that we add 'not_dead' and 'other_causes' categories here.
        """
        excess_mortality_states: list[ExcessMortalityState] = list(
            builder.components.get_components_by_type(ExcessMortalityState)
        )

        # Convert to SimpleCauses and add on other_causes and not_dead
        self.causes_of_death = [
            SimpleCause.create_from_specific_cause(cause)
            for cause in excess_mortality_states
            if cause.has_excess_mortality(builder)
        ] + [
            SimpleCause("not_dead", "not_dead", "cause"),
            SimpleCause("other_causes", "other_causes", "cause"),
        ]

    def register_observations(self, builder: Builder) -> None:
        """Register stratifications and observations.

        Notes
        -----
        Ideally, each observer registers a single observation. This one,
        however, registers two ('deaths' and 'ylls').

        While it's typical for all stratification registrations to be
        encapsulated in a single class (i.e. the
        :class:`ResultsStratifier <vivarium_public_health.results.stratification.ResultsStratifier>`),
        this observer potentially registers an additional one. While it could
        be registered in the ``ResultsStratifier`` as well, it is specific to
        this observer and so it is registered here while we have easy access
        to the required categories.
        """
        pop_filter = "is_alive == False"
        # Work on a copy: extending the configured list in place would write
        # 'cause_of_death' back into this observer's configuration.
        additional_stratifications = list(self.configuration.include)
        if not self.configuration.aggregate:
            # manually append 'not_dead' as an excluded cause
            excluded_categories = (
                builder.configuration.stratification.excluded_categories.to_dict().get(
                    "cause_of_death", []
                )
            ) + ["not_dead"]
            builder.results.register_stratification(
                "cause_of_death",
                [cause.state_id for cause in self.causes_of_death],
                excluded_categories=excluded_categories,
                requires_attributes=["cause_of_death"],
            )
            additional_stratifications.append("cause_of_death")
        self.register_adding_observation(
            builder=builder,
            name="deaths",
            pop_filter=pop_filter,
            requires_attributes=["exit_time"],
            additional_stratifications=additional_stratifications,
            excluded_stratifications=self.configuration.exclude,
            aggregator=self.count_deaths,
        )
        self.register_adding_observation(
            builder=builder,
            name="ylls",
            pop_filter=pop_filter,
            requires_attributes=["exit_time", "years_of_life_lost"],
            additional_stratifications=additional_stratifications,
            excluded_stratifications=self.configuration.exclude,
            aggregator=self.calculate_ylls,
        )

    ###############
    # Aggregators #
    ###############

    def count_deaths(self, x: pd.DataFrame) -> float:
        """Count the number of deaths that occurred during this time step.

        A row is counted when its 'exit_time' is later than the current
        clock time.
        """
        died_of_cause = x["exit_time"] > self.clock()
        # Vectorized sum of the boolean mask (same value as builtin ``sum``).
        return died_of_cause.sum()

    def calculate_ylls(self, x: pd.DataFrame) -> float:
        """Calculate the years of life lost during this time step.

        Sums 'years_of_life_lost' over rows whose 'exit_time' is later than
        the current clock time.
        """
        died_of_cause = x["exit_time"] > self.clock()
        return x.loc[died_of_cause, "years_of_life_lost"].sum()

    ##############################
    # Results formatting methods #
    ##############################

    def format(self, measure: str, results: pd.DataFrame) -> pd.DataFrame:
        """Rename the appropriate column to 'entity'.

        The primary thing this method does is rename the 'cause_of_death'
        column to 'entity' (or, if we are aggregating, and there is no
        'cause_of_death' column, we simply create a new 'entity' column). We
        do this here instead of the 'get_entity_column' method simply because
        we do not want the 'cause_of_death' at all. If we keep it here and
        then return it as the entity column later, the final results would
        have both.

        Parameters
        ----------
        measure
            The measure.
        results
            The results to format.

        Returns
        -------
            The formatted results.
        """
        # reset_index returns a new frame, so the caller's frame is untouched.
        results = results.reset_index()
        if self.configuration.aggregate:
            results[COLUMNS.ENTITY] = "all_causes"
        else:
            results = results.rename(columns={"cause_of_death": COLUMNS.ENTITY})
        return results

    def get_entity_type_column(self, measure: str, results: pd.DataFrame) -> pd.Series:
        """Get the 'entity_type' column values.

        Each entity is mapped to the ``cause_type`` of the cause in
        ``causes_of_death`` with the matching ``state_id``.
        """
        # NOTE(review): entities with no matching state_id (presumably the
        # 'all_causes' entity written by ``format`` when aggregating) map to
        # NaN here -- confirm that is intended.
        entity_type_map = {cause.state_id: cause.cause_type for cause in self.causes_of_death}
        return results[COLUMNS.ENTITY].map(entity_type_map).astype(CategoricalDtype())

    def get_entity_column(self, measure: str, results: pd.DataFrame) -> pd.Series:
        """Get the 'entity' column values."""
        # The entity col was created in the 'format' method
        return results[COLUMNS.ENTITY]

    def get_sub_entity_column(self, measure: str, results: pd.DataFrame) -> pd.Series:
        """Get the 'sub_entity' column values.

        The sub-entity is the same as the entity for this observer.
        """
        return results[COLUMNS.ENTITY]