# Source code for vivarium_public_health.disease.transition

"""
===================
Disease Transitions
===================

This module contains tools to model transitions between disease states.

"""
from typing import TYPE_CHECKING, Callable, Dict

import pandas as pd
from vivarium.framework.engine import Builder
from vivarium.framework.state_machine import Transition, Trigger
from vivarium.framework.utilities import rate_to_probability
from vivarium.framework.values import list_combiner, union_post_processor

if TYPE_CHECKING:
    from vivarium_public_health.disease import BaseDiseaseState


class TransitionString(str):
    """Lower-cased transition name that also records its two endpoint names.

    The string value itself is ``value.lower()``. The ``from_state`` and
    ``to_state`` attributes are taken from the original (not lower-cased)
    ``value`` by splitting it on the upper-case separator ``"_TO_"``, so they
    keep whatever casing the caller supplied.

    Raises
    ------
    ValueError
        If splitting ``value`` on ``"_TO_"`` does not produce exactly two
        parts (separator missing, lower-cased, or present more than once).
    """

    def __new__(cls, value):
        # noinspection PyArgumentList
        instance = str.__new__(cls, value.lower())
        # The separator is case-sensitive, so the raw value is split rather
        # than the lower-cased string stored in the instance.
        source_name, sink_name = value.split("_TO_")
        instance.from_state = source_name
        instance.to_state = sink_name
        return instance
class RateTransition(Transition):
    """Transition whose probability is derived from a rate pipeline.

    During ``setup`` the rate data is loaded through one of the callables in
    ``get_data_functions`` and exposed as a rate pipeline. The transition
    probability is that pipeline's value passed through
    ``rate_to_probability``.

    Parameters
    ----------
    input_state
        State this transition leaves; forwarded to ``Transition.__init__``.
    output_state
        State this transition enters; forwarded to ``Transition.__init__``.
    get_data_functions
        Mapping from a data key to a callable that loads the data. The keys
        consulted are ``"incidence_rate"``, ``"remission_rate"`` and
        ``"transition_rate"``, in that order. ``None`` is treated as an empty
        mapping.
    triggered
        Forwarded unchanged to ``Transition.__init__``.
    """

    #####################
    # Lifecycle methods #
    #####################

    def __init__(
        self,
        input_state: "BaseDiseaseState",
        output_state: "BaseDiseaseState",
        get_data_functions: Dict[str, Callable] = None,
        triggered=Trigger.NOT_TRIGGERED,
    ):
        super().__init__(
            input_state,
            output_state,
            probability_func=self._probability,
            triggered=triggered,
        )
        if get_data_functions is None:
            get_data_functions = {}
        self._get_data_functions = get_data_functions

    # noinspection PyAttributeOutsideInit
    def setup(self, builder: Builder) -> None:
        """Build the lookup table and register the rate and ``.paf`` pipelines.

        Sets ``base_rate``, ``transition_rate``, ``joint_paf`` and
        ``population_view`` on the instance.
        """
        rate_data, pipeline_name = self.load_transition_rate_data(builder)
        paf_pipeline_name = f"{pipeline_name}.paf"

        self.base_rate = builder.lookup.build_table(
            rate_data,
            key_columns=["sex"],
            parameter_columns=["age", "year"],
        )
        self.transition_rate = builder.value.register_rate_producer(
            pipeline_name,
            source=self.compute_transition_rate,
            requires_columns=["age", "sex", "alive"],
            requires_values=[paf_pipeline_name],
        )

        # Scalar table of 0 used as the source of the ".paf" pipeline
        # (presumably population attributable fractions contributed by other
        # components -- confirm against the framework documentation).
        default_paf = builder.lookup.build_table(0)
        self.joint_paf = builder.value.register_value_producer(
            paf_pipeline_name,
            source=lambda index: [default_paf(index)],
            preferred_combiner=list_combiner,
            preferred_post_processor=union_post_processor,
        )

        self.population_view = builder.population.get_view(["alive"])

    #################
    # Setup methods #
    #################

    def load_transition_rate_data(self, builder):
        """Load rate data with the first matching data function.

        Returns
        -------
        tuple
            ``(rate_data, pipeline_name)`` where ``pipeline_name`` is built
            from the relevant state id(s) and the kind of rate loaded.

        Raises
        ------
        ValueError
            If none of ``"incidence_rate"``, ``"remission_rate"`` or
            ``"transition_rate"`` is present in the data functions.
        """
        getters = self._get_data_functions

        # Incidence is keyed on the state being entered.
        if "incidence_rate" in getters:
            target_id = self.output_state.state_id
            data = getters["incidence_rate"](builder, target_id)
            return data, f"{target_id}.incidence_rate"

        # Remission is keyed on the state being left.
        if "remission_rate" in getters:
            origin_id = self.input_state.state_id
            data = getters["remission_rate"](builder, origin_id)
            return data, f"{origin_id}.remission_rate"

        # A generic transition rate is keyed on both states.
        if "transition_rate" in getters:
            origin_id = self.input_state.state_id
            target_id = self.output_state.state_id
            data = getters["transition_rate"](builder, origin_id, target_id)
            return data, f"{origin_id}_to_{target_id}.transition_rate"

        raise ValueError("No valid data functions supplied.")

    ##################################
    # Pipeline sources and modifiers #
    ##################################

    def compute_transition_rate(self, index: pd.Index) -> pd.Series:
        """Source of the transition rate pipeline.

        Rows of ``index`` returned by the population view query
        ``alive == "alive"`` get ``base_rate * (1 - joint_paf)``; every other
        row keeps a rate of ``0.0``.
        """
        alive_index = self.population_view.get(index, query='alive == "alive"').index
        rates = pd.Series(0.0, index=index)

        unadjusted = self.base_rate(alive_index)
        paf = self.joint_paf(alive_index)
        rates.loc[alive_index] = unadjusted * (1 - paf)
        return rates

    ##################
    # Helper methods #
    ##################

    def _probability(self, index: pd.Index) -> pd.Series:
        """Convert the transition rate pipeline's value to a probability."""
        rate = self.transition_rate(index)
        return pd.Series(rate_to_probability(rate))
class ProportionTransition(Transition):
    """Transition whose probability is read directly from a proportion table.

    Parameters
    ----------
    input_state
        State this transition leaves; forwarded to ``Transition.__init__``.
    output_state
        State this transition enters; forwarded to ``Transition.__init__``.
    get_data_functions
        Mapping from a data key to a callable that loads the data. A
        ``"proportion"`` entry is required by ``setup``. ``None`` is treated
        as an empty mapping.
    triggered
        Forwarded unchanged to ``Transition.__init__``.
    """

    #####################
    # Lifecycle methods #
    #####################

    def __init__(
        self,
        input_state: "BaseDiseaseState",
        output_state: "BaseDiseaseState",
        get_data_functions: Dict[str, Callable] = None,
        triggered=Trigger.NOT_TRIGGERED,
    ):
        super().__init__(
            input_state,
            output_state,
            probability_func=self._probability,
            triggered=triggered,
        )
        if get_data_functions is None:
            get_data_functions = {}
        self._get_data_functions = get_data_functions

    # noinspection PyAttributeOutsideInit
    def setup(self, builder):
        """Load the proportion data and build its lookup table.

        Raises
        ------
        ValueError
            If no ``"proportion"`` data function was supplied (or it is
            ``None``).
        """
        super().setup(builder)

        proportion_getter = self._get_data_functions.get("proportion")
        if proportion_getter is None:
            raise ValueError("Must supply a proportion function")

        # The proportion data is requested for the state being entered.
        target_id = self.output_state.state_id
        self._proportion_data = proportion_getter(builder, target_id)
        self.proportion = builder.lookup.build_table(
            self._proportion_data,
            key_columns=["sex"],
            parameter_columns=["age", "year"],
        )

    def _probability(self, index):
        """Return the proportion table's values for ``index``."""
        return self.proportion(index)