Source code for vivarium_public_health.disease.special_disease

"""
========================
"Special" Disease Models
========================

This module contains frequently used, but non-standard disease models.

"""
from __future__ import annotations

import re
from collections import namedtuple
from collections.abc import Callable
from operator import gt, lt
from typing import Any

import pandas as pd
from vivarium import Component
from vivarium.framework.engine import Builder
from vivarium.framework.event import Event
from vivarium.framework.population import SimulantData

from vivarium_public_health.causal_factor.calibration_constant import (
    register_risk_affected_attribute_producer,
)
from vivarium_public_health.disease.state import ExcessMortalityState
from vivarium_public_health.disease.transition import TransitionString
from vivarium_public_health.utilities import EntityString, is_non_zero


class RiskAttributableDisease(ExcessMortalityState):
    """Component to model a disease fully attributed by a risk.

    For some (risk, cause) pairs with population attributable fraction
    equal to 1, the clinical definition of the with condition state
    corresponds to a particular exposure of a risk.

    For example, a diagnosis of ``diabetes_mellitus`` occurs after
    repeated measurements of fasting plasma glucose above 7 mmol/L.
    Similarly, ``protein_energy_malnutrition`` corresponds to a weight
    for height ratio that is more than two standard deviations below
    the WHO guideline median weight for height. In the Global Burden
    of Disease, this corresponds to a categorical exposure to
    ``child_wasting`` in either ``cat1`` or ``cat2``.

    The definition of the disease in terms of exposure should be provided
    in the ``threshold`` configuration flag. For risks with continuous
    exposure models, the threshold should be provided as a string holding
    a single ``float`` or ``int`` with a proper sign between ">" and "<",
    implying that disease is defined by the exposure level ">" than
    threshold level or, "<" than threshold level, respectively.

    For categorical risks, the threshold should be provided as a
    list of categories. This list contains the categories that indicate
    the simulant is experiencing the condition. For a dichotomous risk
    there will be 2 categories. By convention ``cat1`` is used to
    indicate the with condition state and would be the single item in
    the ``threshold`` setting list.

    In addition to the threshold level, you may configure whether
    there is any mortality associated with this disease with the
    ``mortality`` configuration flag.

    Finally, you may specify whether an individual should "recover"
    from the disease if their exposure level falls outside the
    provided threshold.

    In our provided examples, a person would no longer be experiencing
    ``protein_energy_malnutrition`` if their exposure drifts out (or
    changes via an intervention) of the provided exposure categories.
    Having your ``fasting_plasma_glucose`` drop below a provided level
    does not necessarily mean you're no longer diabetic.

    To add this component, you need to initialize it with full cause name
    and full risk name, e.g.,

    RiskAttributableDisease('cause.protein_energy_malnutrition',
                            'risk_factor.child_wasting')

    The configuration is keyed by this component's ``name``, i.e.
    ``risk_attributable_disease.<cause name>``. Configuration defaults
    should be given as, for the continuous risk factor,

    risk_attributable_disease.diabetes_mellitus:
        threshold : ">7"
        mortality : True
        recoverable : False

    For the categorical risk factor,

    risk_attributable_disease.protein_energy_malnutrition:
        threshold : ['cat1', 'cat2'] # provide the categories to get PEM.
        mortality : True
        recoverable : True

    """

    ##############
    # Properties #
    ##############

    @property
    def name(self) -> str:
        """The component name, ``risk_attributable_disease.<cause name>``."""
        return f"risk_attributable_disease.{self.cause.name}"

    @property
    def configuration_defaults(self) -> dict[str, Any]:
        """Provides default configuration values for this component.

        Configuration structure::

            {component_name}:
                data_sources:
                    raw_disability_weight:
                        Source for disability weight data. Default is
                        the artifact key ``{cause}.disability_weight``.
                    cause_specific_mortality_rate:
                        Source for cause-specific mortality rate data.
                        Default uses ``load_cause_specific_mortality_rate_data``
                        method which loads from artifact if ``mortality``
                        is True.
                    excess_mortality_rate:
                        Source for excess mortality rate data. Default
                        uses ``load_excess_mortality_rate_data`` method
                        which loads from artifact if ``mortality`` is True.
                    population_attributable_fraction:
                        Source for PAF data. Default is 0, indicating
                        no mediated effects from other risks.
                threshold: str or list
                    Exposure threshold defining disease state. For
                    continuous risks, provide a string like ``">7"`` or
                    ``"<5"``. For categorical risks, provide a list of
                    categories (e.g., ``['cat1', 'cat2']``). The default
                    is ``None``, which is not a usable threshold, so this
                    value must be supplied.
                mortality: bool
                    Whether this disease has associated mortality.
                    Default is True.
                recoverable: bool
                    Whether simulants can recover from this disease
                    when their exposure falls outside the threshold.
                    Default is True.
        """
        return {
            self.name: {
                "data_sources": {
                    "raw_disability_weight": f"{self.cause}.disability_weight",
                    "cause_specific_mortality_rate": self.load_cause_specific_mortality_rate_data,
                    "excess_mortality_rate": self.load_excess_mortality_rate_data,
                    "population_attributable_fraction": 0,
                },
                "threshold": None,
                "mortality": True,
                "recoverable": True,
            }
        }

    @property
    def state_names(self) -> list[str]:
        """List of names of all states in this disease model."""
        return self._state_names

    @property
    def transition_names(self) -> list[TransitionString]:
        """List of names of all transitions in this disease model."""
        return self._transition_names

    #####################
    # Lifecycle methods #
    #####################

    def __init__(self, cause: str, risk: str) -> None:
        """
        Parameters
        ----------
        cause
            The full entity string for the cause
            (e.g., "cause.protein_energy_malnutrition").
        risk
            The full entity string for the risk
            (e.g., "risk_factor.child_wasting").
        """
        super().__init__()
        self.cause = EntityString(cause)
        self.risk = EntityString(risk)
        self.state_column = self.cause.name
        self.cause_type = "risk_attributable_disease"
        self.model = self.risk.name
        self.state_id = self.cause.name
        self.diseased_event_time_column = f"{self.cause.name}_event_time"
        self.susceptible_event_time_column = f"susceptible_to_{self.cause.name}_event_time"
        self._state_names = [f"{self.cause.name}", f"susceptible_to_{self.cause.name}"]
        self._transition_names = [
            TransitionString(f"susceptible_to_{self.cause.name}_TO_{self.cause.name}")
        ]

        self.disability_weight_name = f"{self.cause.name}.disability_weight"
        self.excess_mortality_rate_name = f"{self.cause.name}.excess_mortality_rate"
        self.exposure_name = f"{self.risk.name}.exposure"

    # noinspection PyAttributeOutsideInit
    def setup(self, builder: Builder) -> None:
        """Perform this component's setup.

        Parameters
        ----------
        builder
            Access point for utilizing framework interfaces during setup.
        """
        self.recoverable = builder.configuration[self.name].recoverable
        self.adjust_state_and_transitions()
        self.clock = builder.time.clock()

        self.raw_disability_weight_table = self.build_lookup_table(
            builder, "raw_disability_weight"
        )
        self.cause_specific_mortality_rate_table = self.build_lookup_table(
            builder, "cause_specific_mortality_rate"
        )
        self.excess_mortality_rate_table = self.build_lookup_table(
            builder, "excess_mortality_rate"
        )
        # Only infer the flag from the data when it has not been set already.
        if self._has_excess_mortality is None:
            self._has_excess_mortality = is_non_zero(self.excess_mortality_rate_table.data)
        self.population_attributable_fraction_table = self.build_lookup_table(
            builder, "population_attributable_fraction"
        )

        builder.value.register_attribute_producer(
            self.disability_weight_name,
            source=self.compute_disability_weight,
            required_resources=[self.raw_disability_weight_table],
        )
        builder.value.register_attribute_modifier(
            "all_causes.disability_weight", modifier=self.disability_weight_name
        )
        builder.value.register_attribute_modifier(
            "cause_specific_mortality_rate",
            self.adjust_cause_specific_mortality_rate,
            required_resources=[self.cause_specific_mortality_rate_table],
        )
        register_risk_affected_attribute_producer(
            builder=builder,
            name=self.excess_mortality_rate_name,
            source=self.compute_excess_mortality_rate,
            required_resources=[self.excess_mortality_rate_table],
        )
        builder.value.register_attribute_modifier(
            "mortality_rate",
            modifier=self.adjust_mortality_rate,
            required_resources=[self.excess_mortality_rate_name],
        )

        distribution = builder.data.load(f"{self.risk}.distribution")
        threshold = builder.configuration[self.name].threshold
        self.filter_by_exposure = self.get_exposure_filter(distribution, threshold)

        builder.population.register_initializer(
            initializer=self.initialize_disease,
            columns=[
                self.cause.name,
                self.diseased_event_time_column,
                self.susceptible_event_time_column,
            ],
            required_resources=[self.exposure_name],
        )

    #################
    # Setup methods #
    #################

    def adjust_state_and_transitions(self) -> None:
        """Add recovery transition if the disease is recoverable.

        The transition is only added when it is not already present, so
        calling this more than once on the same instance does not
        duplicate entries in ``transition_names``.
        """
        if not self.recoverable:
            return
        recovery = TransitionString(
            f"{self.cause.name}_TO_susceptible_to_{self.cause.name}"
        )
        if recovery not in self._transition_names:
            self._transition_names.append(recovery)

    def load_cause_specific_mortality_rate_data(
        self, builder: Builder
    ) -> float | pd.DataFrame:
        """Load cause-specific mortality rate data.

        Parameters
        ----------
        builder
            Access point for utilizing framework interfaces during setup.

        Returns
        -------
            The cause-specific mortality rate data, or 0 if mortality is disabled.
        """
        if builder.configuration[self.name].mortality:
            csmr_data = builder.data.load(
                f"cause.{self.cause.name}.cause_specific_mortality_rate"
            )
        else:
            csmr_data = 0
        return csmr_data

    def load_excess_mortality_rate_data(self, builder: Builder) -> float | pd.DataFrame:
        """Load excess mortality rate data.

        Parameters
        ----------
        builder
            Access point for utilizing framework interfaces during setup.

        Returns
        -------
            The excess mortality rate data, or 0 if mortality is disabled.
        """
        if builder.configuration[self.name].mortality:
            emr_data = builder.data.load(f"cause.{self.cause.name}.excess_mortality_rate")
        else:
            emr_data = 0
        return emr_data

    def get_exposure_filter(self, distribution: str, threshold: Any) -> Callable:
        """Build a filter function that identifies simulants with the condition.

        Parameters
        ----------
        distribution
            The risk's exposure distribution type. ``"dichotomous"``,
            ``"ordered_polytomous"`` and ``"unordered_polytomous"`` are
            treated as categorical; anything else is treated as continuous.
        threshold
            The exposure threshold defining the disease state. For continuous
            risks, a string like ">7" (whitespace around the operator and the
            number is ignored). For categorical risks, a list of category names.

        Returns
        -------
            A function that takes a simulant index and returns a boolean
            series indicating which simulants have the condition.

        Raises
        ------
        TypeError
            If the threshold is not list-like for a categorical risk, or not
            a string for a continuous risk (this includes the unconfigured
            default of ``None``).
        ValueError
            If a continuous threshold string does not consist of exactly one
            "<" or ">" operator together with exactly one number.
        """
        if distribution in ["dichotomous", "ordered_polytomous", "unordered_polytomous"]:
            # ``Series.isin`` only accepts list-likes; fail here with a clear
            # message instead of on the first call to the filter.
            if not pd.api.types.is_list_like(threshold):
                raise TypeError(
                    f"The threshold for the categorical risk {self.risk} must be a list "
                    f"of exposure categories (e.g., ['cat1', 'cat2']), but got "
                    f"{threshold!r}."
                )

            def categorical_filter(index):
                exposure = self.population_view.get(index, self.exposure_name)
                return exposure.isin(threshold)

            return categorical_filter

        # continuous
        if not isinstance(threshold, str):
            raise TypeError(
                f"The threshold for the continuous risk {self.risk} must be a string "
                f'made of "<" or ">" and a number (e.g., ">7"), but got {threshold!r}.'
            )

        Threshold = namedtuple("Threshold", ["operator", "value"])
        threshold_val = re.findall(r"[-+]?\d*\.?\d+", threshold)

        if len(threshold_val) != 1:
            raise ValueError(
                f"Your {threshold} is an incorrect threshold format. It should include "
                f'"<" or ">" along with an integer or float number. Your threshold does not '
                f"include a number or more than one number."
            )

        allowed_operator = {"<", ">"}
        # Whatever is left after removing the number is the operator; strip
        # each piece so surrounding whitespace does not invalidate the threshold.
        threshold_op = [
            piece.strip()
            for piece in threshold.split(threshold_val[0])
            if piece.strip()
        ]

        # if threshold_op has more than 1 operators or 0 operator
        if len(threshold_op) != 1 or not allowed_operator.intersection(threshold_op):
            raise ValueError(
                f"Your {threshold} is an incorrect threshold format. It should include "
                f'"<" or ">" along with an integer or float number.'
            )

        op = gt if threshold_op[0] == ">" else lt
        parsed_threshold = Threshold(op, float(threshold_val[0]))

        def continuous_filter(index):
            exposure = self.population_view.get(index, self.exposure_name)
            return parsed_threshold.operator(exposure, parsed_threshold.value)

        return continuous_filter

    ########################
    # Event-driven methods #
    ########################

    def initialize_disease(self, pop_data: SimulantData) -> None:
        """Initialize disease state for new simulants based on exposure.

        Every new simulant starts as susceptible with empty event times;
        those whose exposure meets the threshold are then moved into the
        with-condition state.

        Parameters
        ----------
        pop_data
            Metadata about the simulants being initialized.
        """
        new_pop = pd.DataFrame(
            {
                self.cause.name: f"susceptible_to_{self.cause.name}",
                self.diseased_event_time_column: pd.Series(pd.NaT, index=pop_data.index),
                self.susceptible_event_time_column: pd.Series(pd.NaT, index=pop_data.index),
            },
            index=pop_data.index,
        )
        sick = self.filter_by_exposure(pop_data.index)
        new_pop.loc[sick, self.cause.name] = self.cause.name
        new_pop.loc[
            sick, self.diseased_event_time_column
        ] = self.clock()  # match VPH disease, only set w/ condition

        self.population_view.initialize(new_pop)

    def on_time_step(self, event: Event) -> None:
        """Update disease state based on current exposure levels.

        Parameters
        ----------
        event
            The event that triggered this method call.
        """

        def _update_disease_state(pop: pd.DataFrame) -> pd.DataFrame:
            living_idx = self.population_view.get_filtered_index(
                event.index, query="is_alive == True"
            )
            update = pop.loc[living_idx]
            sick = self.filter_by_exposure(living_idx)

            # If this is recoverable, anyone whose exposure no longer meets the
            # threshold goes back to the susceptible state.
            if self.recoverable:
                change_to_susceptible = (~sick) & (
                    update[self.cause.name] != f"susceptible_to_{self.cause.name}"
                )
                update.loc[
                    change_to_susceptible, self.susceptible_event_time_column
                ] = event.time
                update.loc[
                    change_to_susceptible, self.cause.name
                ] = f"susceptible_to_{self.cause.name}"

            # Anyone meeting the threshold who is not already diseased gets the
            # condition, stamped with the current event time.
            change_to_diseased = sick & (update[self.cause.name] != self.cause.name)
            update.loc[change_to_diseased, self.diseased_event_time_column] = event.time
            update.loc[change_to_diseased, self.cause.name] = self.cause.name
            return update

        self.population_view.update(
            self.population_view.private_columns, _update_disease_state
        )

    ##################################
    # Pipeline sources and modifiers #
    ##################################

    def compute_disability_weight(self, index: pd.Index[int]) -> pd.Series[float]:
        """Get the disability weight associated with this disease.

        Simulants without the condition get a weight of 0.

        Parameters
        ----------
        index
            An iterable of integer labels for the simulants.

        Returns
        -------
            An iterable of disability weights indexed by the provided ``index``.
        """
        disability_weight = pd.Series(0.0, index=index)
        with_condition = self.with_condition(index)
        disability_weight.loc[with_condition] = self.raw_disability_weight_table(
            with_condition
        )
        return disability_weight

    def compute_excess_mortality_rate(self, index: pd.Index[int]) -> pd.Series[float]:
        """Get the excess mortality rate associated with this disease.

        Simulants without the condition get a rate of 0.

        Parameters
        ----------
        index
            An iterable of integer labels for the simulants.

        Returns
        -------
            An iterable of excess mortality rates indexed by the provided ``index``.
        """
        excess_mortality_rate = pd.Series(0.0, index=index)
        with_condition = self.with_condition(index)
        base_excess_mort = self.excess_mortality_rate_table(with_condition)
        excess_mortality_rate.loc[with_condition] = base_excess_mort
        return excess_mortality_rate

    def adjust_cause_specific_mortality_rate(
        self, index: pd.Index[int], rate: pd.Series[float]
    ) -> pd.Series[float]:
        """Modify the cause-specific mortality rate for the given simulants.

        Parameters
        ----------
        index
            An iterable of integer labels for the simulants.
        rate
            The base cause-specific mortality rate.

        Returns
        -------
            The adjusted cause-specific mortality rate.
        """
        return rate + self.cause_specific_mortality_rate_table(index)

    def adjust_mortality_rate(
        self, index: pd.Index[int], rates_df: pd.DataFrame
    ) -> pd.DataFrame:
        """Modifies the baseline mortality rate for a simulant if they are in this state.

        The excess mortality rate is written into a column of ``rates_df``
        named after this cause.

        Parameters
        ----------
        index
            An iterable of integer labels for the simulants.
        rates_df
            A DataFrame of mortality rates.

        Returns
        -------
            The modified DataFrame of mortality rates.
        """
        rate = self.population_view.get(
            index, self.excess_mortality_rate_name, skip_post_processor=True
        )
        rates_df[self.cause.name] = rate
        return rates_df

    ##################
    # Helper methods #
    ##################

    def with_condition(self, index: pd.Index[int]) -> pd.Index[int]:
        """Get the subset of simulants who have this condition.

        Parameters
        ----------
        index
            An iterable of integer labels for the simulants.

        Returns
        -------
            The subset of simulants who are alive and have this condition.
        """
        return self.population_view.get_filtered_index(
            index,
            query=f'is_alive == True and {self.cause.name} == "{self.cause.name}"',
        )