# Source code for vivarium_public_health.disease.special_disease

"""
========================
"Special" Disease Models
========================

This module contains frequently used, but non-standard disease models.

"""
import re
from collections import namedtuple
from operator import gt, lt
from typing import Any, Dict, List, Optional

import pandas as pd
from vivarium import Component
from vivarium.framework.event import Event
from vivarium.framework.population import SimulantData
from vivarium.framework.values import list_combiner, union_post_processor

from vivarium_public_health.disease.transition import TransitionString
from vivarium_public_health.utilities import EntityString, is_non_zero


class RiskAttributableDisease(Component):
    """Component to model a disease fully attributed by a risk.

    For some (risk, cause) pairs with population attributable fraction equal
    to 1, the clinical definition of the with-condition state corresponds to a
    particular exposure of a risk. For example, a diagnosis of
    ``diabetes_mellitus`` occurs after repeated measurements of fasting plasma
    glucose above 7 mmol/L. Similarly, ``protein_energy_malnutrition``
    corresponds to a weight-for-height ratio that is more than two standard
    deviations below the WHO guideline median weight for height. In the Global
    Burden of Disease, this corresponds to a categorical exposure to
    ``child_wasting`` in either ``cat1`` or ``cat2``.

    The definition of the disease in terms of exposure should be provided in
    the ``threshold`` configuration flag. For risks with continuous exposure
    models, the threshold should be provided as a string containing a single
    ``float`` or ``int`` together with one of the signs ``">"`` or ``"<"``
    (e.g. ``">7"``). This means a simulant has the disease when their exposure
    level is greater than, or less than, the threshold level, respectively.
    Whitespace around the sign and the number is ignored.

    For categorical risks, the threshold should be provided as a list of
    categories. This list contains the categories that indicate the simulant
    is experiencing the condition. For a dichotomous risk there will be 2
    categories. By convention ``cat1`` is used to indicate the with-condition
    state and would be the single item in the ``threshold`` setting list. A
    single category may also be given as a bare string.

    In addition to the threshold level, you may configure whether there is any
    mortality associated with this disease with the ``mortality``
    configuration flag.

    Finally, you may specify with the ``recoverable`` flag whether an
    individual should "recover" from the disease if their exposure level falls
    outside the provided threshold. In our provided examples, a person would no
    longer be experiencing ``protein_energy_malnutrition`` if their exposure
    drifts out of the provided exposure categories (or is moved out of them by
    an intervention). Having your ``fasting_plasma_glucose`` drop below a
    provided level does not necessarily mean you're no longer diabetic.

    To add this component, you need to initialize it with the full cause name
    and the full risk name, e.g.::

        RiskAttributableDisease(
            "cause.protein_energy_malnutrition", "risk_factor.child_wasting"
        )

    Configuration defaults should be given as follows. For the continuous risk
    factor::

        diabetes_mellitus:
            threshold : ">7"
            mortality : True
            recoverable : False

    For the categorical risk factor::

        protein_energy_malnutrition:
            threshold : ['cat1', 'cat2']  # provide the categories to get PEM.
            mortality : True
            recoverable : True
    """

    CONFIGURATION_DEFAULTS = {
        "risk_attributable_disease": {
            "threshold": None,
            "mortality": True,
            "recoverable": True,
        }
    }

    ##############
    # Properties #
    ##############

    @property
    def name(self):
        return f"disease_model.{self.cause.name}"

    @property
    def configuration_defaults(self) -> Dict[str, Any]:
        # Hand out a copy: returning the class-level dict itself would let a
        # caller's mutation leak into every instance's defaults.
        return {
            self.cause.name: dict(
                self.CONFIGURATION_DEFAULTS["risk_attributable_disease"]
            )
        }

    @property
    def columns_created(self) -> List[str]:
        return [
            self.cause.name,
            self.diseased_event_time_column,
            self.susceptible_event_time_column,
        ]

    @property
    def columns_required(self) -> Optional[List[str]]:
        return ["alive"]

    @property
    def initialization_requirements(self) -> Dict[str, List[str]]:
        return {
            "requires_columns": [],
            "requires_values": [f"{self.risk.name}.exposure"],
            "requires_streams": [],
        }

    @property
    def state_names(self):
        return self._state_names

    @property
    def transition_names(self):
        return self._transition_names

    #####################
    # Lifecycle methods #
    #####################

    def __init__(self, cause: str, risk: str):
        """
        Parameters
        ----------
        cause
            Full cause name, e.g. ``"cause.protein_energy_malnutrition"``.
        risk
            Full risk name, e.g. ``"risk_factor.child_wasting"``.
        """
        super().__init__()
        self.cause = EntityString(cause)
        self.risk = EntityString(risk)
        self.state_column = self.cause.name
        self.state_id = self.cause.name
        self.diseased_event_time_column = f"{self.cause.name}_event_time"
        self.susceptible_event_time_column = (
            f"susceptible_to_{self.cause.name}_event_time"
        )
        self._state_names = [
            f"{self.cause.name}",
            f"susceptible_to_{self.cause.name}",
        ]
        self._transition_names = [
            TransitionString(f"susceptible_to_{self.cause.name}_TO_{self.cause.name}")
        ]

        self.excess_mortality_rate_pipeline_name = (
            f"{self.cause.name}.excess_mortality_rate"
        )
        self.excess_mortality_rate_paf_pipeline_name = (
            f"{self.excess_mortality_rate_pipeline_name}.paf"
        )

    # noinspection PyAttributeOutsideInit
    def setup(self, builder):
        """Load data, build lookup tables, and register pipelines.

        Reads ``recoverable`` and ``threshold`` from this cause's configuration
        block. The ``mortality`` flag is read indirectly through the
        ``load_*_data`` helpers. The method then builds the disability-weight
        and mortality lookup tables. It registers this cause's
        disability-weight, excess-mortality-rate and excess-mortality PAF
        pipelines. It modifies the ``disability_weight``,
        ``cause_specific_mortality_rate`` and ``mortality_rate`` pipelines.
        Finally, it builds the exposure filter that defines who has the
        condition.
        """
        self.recoverable = builder.configuration[self.cause.name].recoverable
        self.adjust_state_and_transitions()
        self.clock = builder.time.clock()

        disability_weight_data = builder.data.load(f"{self.cause}.disability_weight")
        self.has_disability = is_non_zero(disability_weight_data)
        self.base_disability_weight = builder.lookup.build_table(
            disability_weight_data, key_columns=["sex"], parameter_columns=["age", "year"]
        )
        self.disability_weight = builder.value.register_value_producer(
            f"{self.cause.name}.disability_weight",
            source=self.compute_disability_weight,
            requires_columns=["age", "sex", "alive", self.cause.name],
        )
        builder.value.register_value_modifier(
            "disability_weight", modifier=self.disability_weight
        )

        cause_specific_mortality_rate = self.load_cause_specific_mortality_rate_data(
            builder
        )
        self.cause_specific_mortality_rate = builder.lookup.build_table(
            cause_specific_mortality_rate,
            key_columns=["sex"],
            parameter_columns=["age", "year"],
        )
        builder.value.register_value_modifier(
            "cause_specific_mortality_rate",
            self.adjust_cause_specific_mortality_rate,
            requires_columns=["age", "sex"],
        )

        excess_mortality_data = self.load_excess_mortality_rate_data(builder)
        self.has_excess_mortality = is_non_zero(excess_mortality_data)
        self.base_excess_mortality_rate = builder.lookup.build_table(
            excess_mortality_data, key_columns=["sex"], parameter_columns=["age", "year"]
        )
        self.excess_mortality_rate = builder.value.register_value_producer(
            self.excess_mortality_rate_pipeline_name,
            source=self.compute_excess_mortality_rate,
            requires_columns=["age", "sex", "alive", self.cause.name],
            requires_values=[self.excess_mortality_rate_paf_pipeline_name],
        )

        # The source contributes a single zero PAF wrapped in a list. The
        # list_combiner / union_post_processor pair presumably collects PAFs
        # from modifiers and joins them. NOTE(review): confirm against the
        # vivarium values documentation.
        paf = builder.lookup.build_table(0)
        self.joint_paf = builder.value.register_value_producer(
            self.excess_mortality_rate_paf_pipeline_name,
            source=lambda idx: [paf(idx)],
            preferred_combiner=list_combiner,
            preferred_post_processor=union_post_processor,
        )
        builder.value.register_value_modifier(
            "mortality_rate",
            modifier=self.adjust_mortality_rate,
            requires_values=[self.excess_mortality_rate_pipeline_name],
        )

        distribution = builder.data.load(f"{self.risk}.distribution")
        exposure_pipeline = builder.value.get_value(f"{self.risk.name}.exposure")
        threshold = builder.configuration[self.cause.name].threshold
        self.filter_by_exposure = self.get_exposure_filter(
            distribution, exposure_pipeline, threshold
        )

    #################
    # Setup methods #
    #################

    def adjust_state_and_transitions(self):
        """Add the with-condition -> susceptible transition if recoverable.

        Safe to call more than once: the recovery transition is only added
        when it is not already present.
        """
        if self.recoverable:
            recovery = TransitionString(
                f"{self.cause.name}_TO_susceptible_to_{self.cause.name}"
            )
            if recovery not in self._transition_names:
                self._transition_names.append(recovery)

    def load_cause_specific_mortality_rate_data(self, builder):
        """Return CSMR data for this cause, or 0 if mortality is disabled."""
        if builder.configuration[self.cause.name].mortality:
            csmr_data = builder.data.load(
                f"cause.{self.cause.name}.cause_specific_mortality_rate"
            )
        else:
            csmr_data = 0
        return csmr_data

    def load_excess_mortality_rate_data(self, builder):
        """Return excess mortality data for this cause, or 0 if mortality is disabled."""
        if builder.configuration[self.cause.name].mortality:
            emr_data = builder.data.load(f"cause.{self.cause.name}.excess_mortality_rate")
        else:
            emr_data = 0
        return emr_data

    def get_exposure_filter(self, distribution, exposure_pipeline, threshold):
        """Build a function that flags simulants whose exposure defines the disease.

        Parameters
        ----------
        distribution
            The risk's exposure distribution name, as loaded from
            ``{risk}.distribution``. ``"dichotomous"``, ``"ordered_polytomous"``
            and ``"unordered_polytomous"`` are treated as categorical. Any other
            value is treated as continuous.
        exposure_pipeline
            Callable returning the risk exposure for a simulant index.
        threshold
            Categorical: a list of categories (or a single category string)
            that define the with-condition state. Continuous: a string such as
            ``">7"`` or ``"< -0.5"``.

        Returns
        -------
            A callable taking a simulant index and returning a boolean series
            that is True where the simulant's exposure meets the threshold.

        Raises
        ------
        ValueError
            If ``threshold`` is missing or malformed.
        """
        if distribution in ["dichotomous", "ordered_polytomous", "unordered_polytomous"]:
            # Fail at setup rather than on the first exposure evaluation.
            if threshold is None:
                raise ValueError(
                    f"No threshold configured for {self.cause.name}. For a "
                    f"categorical risk it should be a list of categories."
                )
            # Series.isin rejects bare strings, so wrap a single category.
            categories = [threshold] if isinstance(threshold, str) else threshold

            def categorical_filter(index):
                exposure = exposure_pipeline(index)
                return exposure.isin(categories)

            return categorical_filter

        # Continuous exposure: parse "<op><number>".
        if not isinstance(threshold, str):
            raise ValueError(
                f"Your {threshold} is an incorrect threshold format. It should be a "
                f'string including "<" or ">" along with an integer or float number.'
            )

        threshold_val = re.findall(r"[-+]?\d*\.?\d+", threshold)
        if len(threshold_val) != 1:
            raise ValueError(
                f"Your {threshold} is an incorrect threshold format. It should include "
                f'"<" or ">" along with an integer or float number. Your threshold does not '
                f"include a number or more than one number."
            )

        allowed_operator = {"<", ">"}
        # Whatever remains around the number must be exactly one operator;
        # surrounding whitespace is ignored.
        threshold_op = [
            part.strip() for part in threshold.split(threshold_val[0]) if part.strip()
        ]
        if len(threshold_op) != 1 or not allowed_operator.intersection(threshold_op):
            raise ValueError(
                f"Your {threshold} is an incorrect threshold format. It should include "
                f'"<" or ">" along with an integer or float number.'
            )

        Threshold = namedtuple("Threshold", ["operator", "value"])
        parsed = Threshold(
            gt if threshold_op[0] == ">" else lt, float(threshold_val[0])
        )

        def continuous_filter(index):
            exposure = exposure_pipeline(index)
            return parsed.operator(exposure, parsed.value)

        return continuous_filter

    ########################
    # Event-driven methods #
    ########################

    def on_initialize_simulants(self, pop_data: SimulantData) -> None:
        """Set new simulants' disease state from their current exposure.

        Everyone starts susceptible with no event times. Simulants passing the
        exposure filter are placed in the with-condition state and stamped with
        the current clock time.
        """
        new_pop = pd.DataFrame(
            {
                self.cause.name: f"susceptible_to_{self.cause.name}",
                self.diseased_event_time_column: pd.Series(pd.NaT, index=pop_data.index),
                self.susceptible_event_time_column: pd.Series(
                    pd.NaT, index=pop_data.index
                ),
            },
            index=pop_data.index,
        )
        sick = self.filter_by_exposure(pop_data.index)
        new_pop.loc[sick, self.cause.name] = self.cause.name
        # Match VPH disease: only the with-condition event time is set here.
        new_pop.loc[sick, self.diseased_event_time_column] = self.clock()
        self.population_view.update(new_pop)

    def on_time_step(self, event: Event) -> None:
        """Move living simulants between states according to current exposure.

        Simulants passing the filter who are not yet diseased become diseased.
        If the disease is recoverable, diseased simulants who no longer pass
        the filter return to susceptible. Each transition records
        ``event.time`` in the matching event-time column.
        """
        pop = self.population_view.get(event.index, query='alive == "alive"')
        sick = self.filter_by_exposure(pop.index)
        # If this is recoverable, anyone whose exposure no longer meets the
        # threshold goes back to susceptible status.
        if self.recoverable:
            change_to_susceptible = (~sick) & (
                pop[self.cause.name] != f"susceptible_to_{self.cause.name}"
            )
            pop.loc[change_to_susceptible, self.susceptible_event_time_column] = event.time
            pop.loc[
                change_to_susceptible, self.cause.name
            ] = f"susceptible_to_{self.cause.name}"

        change_to_diseased = sick & (pop[self.cause.name] != self.cause.name)
        pop.loc[change_to_diseased, self.diseased_event_time_column] = event.time
        pop.loc[change_to_diseased, self.cause.name] = self.cause.name

        self.population_view.update(pop)

    ##################################
    # Pipeline sources and modifiers #
    ##################################

    def compute_disability_weight(self, index):
        """Disability weight: base weight for living diseased simulants, else 0."""
        disability_weight = pd.Series(0.0, index=index)
        with_condition = self.with_condition(index)
        disability_weight.loc[with_condition] = self.base_disability_weight(with_condition)
        return disability_weight

    def compute_excess_mortality_rate(self, index):
        """Excess mortality for living diseased simulants, scaled by ``1 - joint PAF``.

        Simulants without the condition get 0.
        """
        excess_mortality_rate = pd.Series(0.0, index=index)
        with_condition = self.with_condition(index)
        base_excess_mort = self.base_excess_mortality_rate(with_condition)
        joint_mediated_paf = self.joint_paf(with_condition)
        excess_mortality_rate.loc[with_condition] = base_excess_mort * (
            1 - joint_mediated_paf.values
        )
        return excess_mortality_rate

    def adjust_cause_specific_mortality_rate(self, index, rate):
        """Add this cause's CSMR to the incoming ``rate``."""
        return rate + self.cause_specific_mortality_rate(index)

    def adjust_mortality_rate(self, index, rates_df):
        """Modifies the baseline mortality rate for a simulant if they are in this state.

        Parameters
        ----------
        index
            An iterable of integer labels for the simulants.
        rates_df
            DataFrame of mortality rates. The column named after this cause is
            set to the excess mortality rate, taken without post-processing.

        Returns
        -------
            ``rates_df`` with this cause's column set.
        """
        rate = self.excess_mortality_rate(index, skip_post_processor=True)
        rates_df[self.cause.name] = rate
        return rates_df

    ##################
    # Helper methods #
    ##################

    def with_condition(self, index):
        """Return the subset of ``index`` that is alive and in the with-condition state."""
        pop = self.population_view.subview(["alive", self.cause.name]).get(index)
        with_condition = pop.loc[
            (pop[self.cause.name] == self.cause.name) & (pop["alive"] == "alive")
        ].index
        return with_condition