# Source code for vivarium_public_health.mslt.disease

"""
==============
Disease Models
==============

This module contains tools for modeling diseases in multi-state lifetable
simulations.

"""
from typing import Any, Dict, List, Optional

import numpy as np
import pandas as pd
from vivarium import Component
from vivarium.framework.engine import Builder
from vivarium.framework.event import Event
from vivarium.framework.population import SimulantData


class AcuteDisease(Component):
    """
    Model a disease whose episodes are short relative to the time-step size.

    Because prevalence is not a meaningful quantity for such a disease, it is
    not tracked. Instead the disease contributes an excess mortality rate
    and/or a disability rate. Interventions may affect these rates through:

    - `<disease>_intervention.excess_mortality`
    - `<disease>_intervention.yld_rate`

    where `<disease>` is the name as provided to the constructor. The
    difference between each intervention rate and its BAU counterpart is added
    to the ``mortality_rate`` and ``yld_rate`` pipelines respectively.

    Parameters
    ----------
    disease
        The disease name (referred to as `<disease>` here).
    """

    #####################
    # Lifecycle methods #
    #####################

    def __init__(self, disease: str):
        super().__init__()
        self.disease = disease

    def setup(self, builder: Builder) -> None:
        """Load the morbidity and mortality data and register the rate pipelines."""

        def rate_table(data_key: str):
            # Both inputs are keyed by sex, with age and year as parameter columns.
            return builder.lookup.build_table(
                builder.data.load(data_key),
                key_columns=["sex"],
                parameter_columns=["age", "year"],
            )

        mortality_table = rate_table(f"acute_disease.{self.disease}.mortality")
        morbidity_table = rate_table(f"acute_disease.{self.disease}.morbidity")

        # The BAU and intervention pipelines share a source table; only the
        # `_intervention` pipelines are documented as intervention targets.
        produce = builder.value.register_rate_producer
        self.excess_mortality = produce(
            f"{self.disease}.excess_mortality", source=mortality_table
        )
        self.int_excess_mortality = produce(
            f"{self.disease}_intervention.excess_mortality", source=mortality_table
        )
        self.disability_rate = produce(f"{self.disease}.yld_rate", source=morbidity_table)
        self.int_disability_rate = produce(
            f"{self.disease}_intervention.yld_rate", source=morbidity_table
        )

        modify = builder.value.register_value_modifier
        modify("mortality_rate", self.mortality_adjustment)
        modify("yld_rate", self.disability_adjustment)

    ##################################
    # Pipeline sources and modifiers #
    ##################################

    def mortality_adjustment(self, index, mortality_rate):
        """
        Shift the all-cause mortality rate by the intervention-minus-BAU excess
        mortality, so that the intervention scenario reflects any change in
        this disease's mortality (relative to the BAU scenario).
        """
        intervention = self.int_excess_mortality(index)
        business_as_usual = self.excess_mortality(index)
        return mortality_rate + (intervention - business_as_usual)

    def disability_adjustment(self, index, yld_rate):
        """
        Shift the years lost due to disability (YLD) rate by the
        intervention-minus-BAU disability rate of this disease (relative to
        the BAU scenario).
        """
        intervention = self.int_disability_rate(index)
        business_as_usual = self.disability_rate(index)
        return yld_rate + (intervention - business_as_usual)
class Disease(Component):
    """This component characterises a chronic disease.

    For each simulant row, the disease status is tracked as the number of
    people who are susceptible (``S``) and diseased (``C``), out of an initial
    total of 1000. It is tracked for both the business-as-usual (BAU) and the
    intervention scenarios. The values from the previous time step are also
    kept, so that changes over the year can be computed.

    It registers the following rate pipelines:

    - `<disease>.incidence`
    - `<disease>_intervention.incidence`
    - `<disease>.remission`
    - `<disease>.excess_mortality`
    - `<disease>.yld_rate`

    where `<disease>` is the name as provided to the constructor.

    - Interventions affect the intervention scenario through
      `<disease>_intervention.incidence`.
    - The remission, excess mortality and YLD rates are shared by both
      scenarios.
    - Differences in disease status between the scenarios are applied to the
      ``mortality_rate`` and ``yld_rate`` pipelines.

    Parameters
    ----------
    disease
        The disease name (referred to as `<disease>` here).
    """

    ##############
    # Properties #
    ##############

    @property
    def configuration_defaults(self) -> Dict[str, Any]:
        return {
            self.disease: {
                "simplified_no_remission_equations": False,
            },
        }

    @property
    def columns_created(self) -> List[str]:
        # For each scenario and each state: the current value, then the
        # previous value (e.g. `<disease>_S`, `<disease>_S_previous`, ...).
        return [
            self.disease + state + scenario + when
            for scenario in ["", "_intervention"]
            for state in ["_S", "_C"]
            for when in ["", "_previous"]
        ]

    @property
    def columns_required(self) -> Optional[List[str]]:
        return ["age", "sex"]

    @property
    def initialization_requirements(self) -> Dict[str, List[str]]:
        return {
            "requires_columns": ["age", "sex"],
            "requires_values": [],
            "requires_streams": [],
        }

    def __init__(self, disease: str):
        super().__init__()
        self.disease = disease

    def setup(self, builder: Builder) -> None:
        """Load the disease prevalence and rates data."""
        data_prefix = "chronic_disease.{}.".format(self.disease)
        bau_prefix = self.disease + "."
        int_prefix = self.disease + "_intervention."

        self.clock = builder.time.clock()
        self.start_year = builder.configuration.time.start.year
        self.simplified_equations = builder.configuration[
            self.disease
        ].simplified_no_remission_equations

        def load_table(measure: str):
            # Every input is keyed by sex, with age and year as parameter columns.
            data = builder.data.load(data_prefix + measure)
            return builder.lookup.build_table(
                data, key_columns=["sex"], parameter_columns=["age", "year"]
            )

        i = load_table("incidence")
        self.incidence = builder.value.register_rate_producer(
            bau_prefix + "incidence", source=i
        )
        self.incidence_intervention = builder.value.register_rate_producer(
            int_prefix + "incidence", source=i
        )

        r = load_table("remission")
        self.remission = builder.value.register_rate_producer(
            bau_prefix + "remission", source=r
        )

        f = load_table("mortality")
        self.excess_mortality = builder.value.register_rate_producer(
            bau_prefix + "excess_mortality", source=f
        )

        yld_rate = load_table("morbidity")
        self.disability_rate = builder.value.register_rate_producer(
            bau_prefix + "yld_rate", source=yld_rate
        )

        self.initial_prevalence = load_table("prevalence")

        builder.value.register_value_modifier("mortality_rate", self.mortality_adjustment)
        builder.value.register_value_modifier("yld_rate", self.disability_adjustment)

    ########################
    # Event-driven methods #
    ########################

    def on_initialize_simulants(self, pop_data: SimulantData) -> None:
        """Initialize the test population for which this disease is modeled.

        Each row starts with 1000 people, split between the diseased (``C``)
        and susceptible (``S``) states according to the initial prevalence.
        Both scenarios, and their "previous" values, start from the same state.
        """
        C = 1000 * self.initial_prevalence(pop_data.index)
        S = 1000 - C
        state = (S, C)
        self.population_view.update(
            self._state_frame(pop_data.index, state, state, state, state)
        )

    def on_time_step_prepare(self, event: Event) -> None:
        """
        Update the disease status for both the BAU and intervention scenarios.

        The current ``S`` and ``C`` values become the ``_previous`` values.
        They are replaced by the state advanced by one unit of time at the
        current rates.
        """
        # Do not update the disease status in the first year, the initial data
        # describe the disease state at the end of the year.
        if self.clock().year == self.start_year:
            return

        pop = self.population_view.get(event.index)
        if pop.empty:
            return

        idx = pop.index
        S_bau, C_bau = pop[f"{self.disease}_S"], pop[f"{self.disease}_C"]
        S_int = pop[f"{self.disease}_S_intervention"]
        C_int = pop[f"{self.disease}_C_intervention"]

        # Extract all of the required rates *once only*.
        i_bau = self.incidence(idx)
        i_int = self.incidence_intervention(idx)
        r = self.remission(idx)
        f = self.excess_mortality(idx)

        # NOTE: if the remission rate is always zero, which is the case for a
        # number of chronic diseases, we can make some simplifications.
        if np.all(r == 0):
            r = 0
            if self.simplified_equations:
                # NOTE: for the 'mslt_reduce_chd' experiment, this results in a
                # slightly lower HALY gain than that obtained when using the
                # full equations (below).
                new_S_bau = S_bau * np.exp(-i_bau)
                new_S_int = S_int * np.exp(-i_int)
                new_C_bau = C_bau * np.exp(-f) + S_bau - new_S_bau
                new_C_int = C_int * np.exp(-f) + S_int - new_S_int
                self.population_view.update(
                    self._state_frame(
                        idx,
                        (new_S_bau, new_C_bau),
                        (S_bau, C_bau),
                        (new_S_int, new_C_int),
                        (S_int, C_int),
                    )
                )
                return

        new_S_bau, new_C_bau = self._advance_states(S_bau, C_bau, i_bau, r, f)
        new_S_int, new_C_int = self._advance_states(S_int, C_int, i_int, r, f)
        self.population_view.update(
            self._state_frame(
                idx,
                (new_S_bau, new_C_bau),
                (S_bau, C_bau),
                (new_S_int, new_C_int),
                (S_int, C_int),
            )
        )

    @staticmethod
    def _advance_states(S, C, i, r, f):
        """Advance the susceptible (S) and diseased (C) counts by one unit of time.

        Returns the exact solution, after one unit of time, of the linear system

            dS/dt = -i * S + r * C
            dC/dt =  i * S - (r + f) * C

        for constant incidence ``i``, remission ``r`` and excess mortality ``f``.
        It uses the closed-form 2x2 matrix exponential. The system's
        eigenvalues are ``-(lam - q) / 2`` and ``-(lam + q) / 2``, where
        ``lam = i + r + f``.
        """
        lam = i + r + f
        # q**2 == lam**2 - 4*i*f == (i + r - f)**2 + 4*r*f. The last form is
        # never negative for non-negative rates. This avoids NaN from
        # floating-point cancellation (e.g. when r == 0 and i is close to f).
        q = np.sqrt((i + r - f) ** 2 + 4 * r * f)
        v = np.exp(-(lam - q) / 2)
        w = np.exp(-(lam + q) / 2)
        # (v - w) / q is 0/0 where q == 0, i.e. when r == 0 and i == f.
        # Use its q -> 0 limit, exp(-lam / 2), so such rows are still advanced.
        with np.errstate(divide="ignore", invalid="ignore"):
            phi = np.where(q != 0, (v - w) / q, np.exp(-lam / 2))
        half_sum = (v + w) / 2
        new_S = half_sum * S + phi * ((r + f - i) / 2 * S + r * C)
        new_C = half_sum * C + phi * (i * S + (i - r - f) / 2 * C)
        return new_S, new_C

    def _state_frame(
        self, index, bau, bau_previous, intervention, intervention_previous
    ) -> pd.DataFrame:
        """Build a frame holding every state column created by this component.

        Each state argument is an ``(S, C)`` pair of values aligned with ``index``.
        """
        (S, C), (S_prev, C_prev) = bau, bau_previous
        (S_int, C_int), (S_int_prev, C_int_prev) = intervention, intervention_previous
        d = self.disease
        return pd.DataFrame(
            {
                f"{d}_S": S,
                f"{d}_C": C,
                f"{d}_S_previous": S_prev,
                f"{d}_C_previous": C_prev,
                f"{d}_S_intervention": S_int,
                f"{d}_C_intervention": C_int,
                f"{d}_S_intervention_previous": S_int_prev,
                f"{d}_C_intervention_previous": C_int_prev,
            },
            index=index,
        )

    ##################################
    # Pipeline sources and modifiers #
    ##################################

    def mortality_adjustment(self, index, mortality_rate):
        """
        Adjust the all-cause mortality rate in the intervention scenario, to
        account for any change in disease prevalence (relative to the BAU
        scenario).

        ``D = 1000 - S - C`` counts people removed from ``S + C``. In the
        model solved by ``on_time_step_prepare``, only excess mortality does
        that. The change in ``D`` over the step, relative to the people alive
        at the previous step, gives a disease mortality risk for each
        scenario.
        """
        pop = self.population_view.get(index)
        S, C = pop[f"{self.disease}_S"], pop[f"{self.disease}_C"]
        S_prev, C_prev = pop[f"{self.disease}_S_previous"], pop[f"{self.disease}_C_previous"]
        D, D_prev = 1000 - S - C, 1000 - S_prev - C_prev
        S_int, C_int = (
            pop[f"{self.disease}_S_intervention"],
            pop[f"{self.disease}_C_intervention"],
        )
        S_int_prev, C_int_prev = (
            pop[f"{self.disease}_S_intervention_previous"],
            pop[f"{self.disease}_C_intervention_previous"],
        )
        D_int, D_int_prev = 1000 - S_int - C_int, 1000 - S_int_prev - C_int_prev
        # NOTE: as per the spreadsheet, the denominator is from the same point
        # in time as the term being subtracted in the numerator.
        mortality_risk = (D - D_prev) / (S_prev + C_prev)
        mortality_risk_int = (D_int - D_int_prev) / (S_int_prev + C_int_prev)
        # Convert the ratio of survival probabilities into a rate difference.
        delta = np.log((1 - mortality_risk) / (1 - mortality_risk_int))
        return mortality_rate + delta

    def disability_adjustment(self, index, yld_rate):
        """
        Adjust the years lost due to disability (YLD) rate in the intervention
        scenario, to account for any change in disease prevalence (relative to
        the BAU scenario).

        The adjustment is this disease's YLD rate multiplied by the difference
        between the intervention and BAU prevalence over the year.
        """
        pop = self.population_view.get(index)
        S, S_prev = pop[f"{self.disease}_S"], pop[f"{self.disease}_S_previous"]
        C, C_prev = pop[f"{self.disease}_C"], pop[f"{self.disease}_C_previous"]
        S_int, S_int_prev = (
            pop[f"{self.disease}_S_intervention"],
            pop[f"{self.disease}_S_intervention_previous"],
        )
        C_int, C_int_prev = (
            pop[f"{self.disease}_C_intervention"],
            pop[f"{self.disease}_C_intervention_previous"],
        )
        # The prevalence rate is the mean number of diseased people over the
        # year, divided by the mean number of alive people over the year.
        # The 0.5 multipliers in the numerator and denominator therefore cancel
        # each other out, and can be removed.
        prevalence_rate = (C + C_prev) / (S + C + S_prev + C_prev)
        prevalence_rate_int = (C_int + C_int_prev) / (S_int + C_int + S_int_prev + C_int_prev)
        delta = prevalence_rate_int - prevalence_rate
        return yld_rate + self.disability_rate(index) * delta