Source code for vivarium_public_health.disease.transition

"""
===================
Disease Transitions
===================

This module contains tools to model transitions between disease states.

"""
from __future__ import annotations

import warnings
from collections.abc import Callable
from typing import TYPE_CHECKING, Any

import pandas as pd
from vivarium.framework.engine import Builder
from vivarium.framework.state_machine import Transition, Trigger
from vivarium.framework.utilities import rate_to_probability
from vivarium.types import DataInput

from vivarium_public_health.causal_factor.calibration_constant import (
    register_risk_affected_rate_producer,
)
from vivarium_public_health.disease.exceptions import DiseaseModelError

if TYPE_CHECKING:
    from vivarium_public_health.disease import BaseDiseaseState


class TransitionString(str):
    """A string subclass representing a transition between two disease states.

    Parses the transition name into ``from_state`` and ``to_state`` attributes
    from the format ``{from_state}_TO_{to_state}``. The string value itself is
    lower-cased, while ``from_state`` and ``to_state`` keep the casing of the
    value passed in.
    """

    def __new__(cls, value: str) -> "TransitionString":
        """Create a transition string from ``{from_state}_TO_{to_state}``.

        Parameters
        ----------
        value
            The transition name, containing exactly one ``"_TO_"`` separator.

        Returns
        -------
        A lower-cased string with ``from_state`` and ``to_state`` attributes.

        Raises
        ------
        ValueError
            If ``value`` does not contain exactly one ``"_TO_"`` separator.
        """
        parts = value.split("_TO_")
        if len(parts) != 2:
            raise ValueError(
                f"Invalid transition string '{value}': expected the format "
                "'{from_state}_TO_{to_state}' with exactly one '_TO_' separator."
            )
        # noinspection PyArgumentList
        obj = str.__new__(cls, value.lower())
        obj.from_state, obj.to_state = parts
        return obj

    def __getnewargs__(self) -> tuple[str]:
        # The stored string is lower-cased and so no longer contains the
        # upper-case "_TO_" separator; rebuild the original-case value so that
        # pickling/copying can round-trip through ``__new__``.
        return (self.from_state + "_TO_" + self.to_state,)
class RateTransition(Transition):
    """A transition between disease states governed by a rate.

    Converts the rate to a probability using either a linear or exponential
    conversion at each time step.
    """

    ##############
    # Properties #
    ##############

    @property
    def configuration_defaults(self) -> dict[str, Any]:
        """Provides default configuration values for this transition.

        Configuration structure::

            {transition_name}:
                data_sources:
                    transition_rate:
                        Source for transition rate data. The default value is
                        determined by the ``transition_rate`` constructor
                        argument.
                rate_conversion_type: str
                    Method for converting rates to probabilities. Options are
                    ``"linear"`` (default) or ``"exponential"``. Linear uses
                    ``rate * dt``, exponential uses ``1 - exp(-rate * dt)``.
        """
        return {
            f"{self.name}": {
                "data_sources": {
                    "transition_rate": self.transition_rate,
                },
                "rate_conversion_type": "linear",
            },
        }

    @property
    def transition_rate_pipeline(self) -> str:
        """Name of the value pipeline that produces this transition's rate.

        The name depends on ``rate_type``:

        - ``"incidence_rate"``: ``{output_state.state_id}.incidence_rate``
        - ``"remission_rate"``: ``{input_state.state_id}.remission_rate``
        - ``"transition_rate"``:
          ``{input_state.state_id}_to_{output_state.state_id}.transition_rate``

        Raises
        ------
        DiseaseModelError
            If ``rate_type`` is not one of the values above.
        """
        if self.rate_type == "incidence_rate":
            pipeline_name = f"{self.output_state.state_id}.incidence_rate"
        elif self.rate_type == "remission_rate":
            pipeline_name = f"{self.input_state.state_id}.remission_rate"
        elif self.rate_type == "transition_rate":
            pipeline_name = (
                f"{self.input_state.state_id}_to_{self.output_state.state_id}"
                ".transition_rate"
            )
        else:
            raise DiseaseModelError(
                "Cannot determine rate_transition pipeline name: invalid"
                f" rate_type '{self.rate_type}' supplied."
            )
        return pipeline_name

    #####################
    # Lifecycle methods #
    #####################

    def __init__(
        self,
        input_state: "BaseDiseaseState",
        output_state: "BaseDiseaseState",
        transition_rate: DataInput,
        triggered: Trigger = Trigger.NOT_TRIGGERED,
        rate_type: str = "transition_rate",
    ):
        """
        Parameters
        ----------
        input_state
            The starting state of this transition.
        output_state
            The ending state of this transition.
        transition_rate
            The transition rate source. Can be the data itself, a function to
            retrieve the data, or the artifact key containing the data.
        triggered
            The trigger for the transition.
        rate_type
            The type of rate. Can be "incidence_rate", "transition_rate", or
            "remission_rate".
        """
        super().__init__(
            input_state, output_state, probability_func=self._probability, triggered=triggered
        )
        self.transition_rate = transition_rate
        self.rate_type = rate_type

    def setup(self, builder: Builder) -> None:
        """Perform this component's setup.

        Parameters
        ----------
        builder
            Access point for utilizing framework interfaces during setup.
        """
        self.transition_rate_table = self.build_lookup_table(builder, "transition_rate")
        register_risk_affected_rate_producer(
            builder=builder,
            name=self.transition_rate_pipeline,
            source=self.compute_transition_rate,
            required_resources=["is_alive", self.transition_rate_table],
        )
        self.rate_conversion_type = self.configuration["rate_conversion_type"]

    ##################################
    # Pipeline sources and modifiers #
    ##################################

    def compute_transition_rate(self, index: pd.Index[int]) -> pd.Series[float]:
        """Compute the transition rate for the given simulants.

        Parameters
        ----------
        index
            An iterable of integer labels for the simulants.

        Returns
        -------
        The transition rates indexed by the provided ``index``.
        """
        # Simulants that are not alive keep a rate of 0.0; only living
        # simulants are looked up in the transition rate table.
        transition_rate = pd.Series(0.0, index=index)
        living = self.population_view.get_filtered_index(index, query="is_alive == True")
        base_rates = self.transition_rate_table(living)
        transition_rate.loc[living] = base_rates
        return transition_rate

    ##################
    # Helper methods #
    ##################

    def _probability(self, index: pd.Index[int]) -> pd.Series[float]:
        """Convert this transition's rate into per-time-step probabilities.

        Parameters
        ----------
        index
            The simulants to compute transition probabilities for.

        Returns
        -------
        Transition probabilities indexed by the provided ``index``.
        """
        rates = self.population_view.get(index, self.transition_rate_pipeline)
        probabilities = rate_to_probability(
            rates,
            rate_conversion_type=self.rate_conversion_type,
        )
        # ``rate_to_probability`` may return a bare array; attach the simulant
        # index explicitly so the result is not given a default RangeIndex.
        return pd.Series(probabilities, index=index)
class ProportionTransition(Transition):
    """A transition between disease states governed by a fixed proportion.

    At each time step, a fixed proportion of eligible simulants transition to
    the output state.
    """

    ##############
    # Properties #
    ##############

    @property
    def configuration_defaults(self) -> dict[str, Any]:
        """Provides default configuration values for this transition.

        Configuration structure::

            {transition_name}:
                data_sources:
                    proportion:
                        Source for the proportion of simulants transitioning
                        at each time step. The default value is determined by
                        the ``proportion`` constructor argument.
        """
        return {
            f"{self.name}": {
                "data_sources": {
                    "proportion": self.proportion,
                },
            },
        }

    #####################
    # Lifecycle methods #
    #####################

    def __init__(
        self,
        input_state: "BaseDiseaseState",
        output_state: "BaseDiseaseState",
        proportion: DataInput,
        triggered: Trigger = Trigger.NOT_TRIGGERED,
    ):
        """
        Parameters
        ----------
        input_state
            The starting state of this transition.
        output_state
            The ending state of this transition.
        proportion
            The proportion source. Can be the data itself, a function to
            retrieve the data, or the artifact key containing the data.
        triggered
            The trigger for the transition.
        """
        super().__init__(
            input_state, output_state, probability_func=self._probability, triggered=triggered
        )
        self.proportion = proportion

    def setup(self, builder: Builder) -> None:
        """Perform this component's setup.

        Parameters
        ----------
        builder
            Access point for utilizing framework interfaces during setup.
        """
        self.proportion_table = self.build_lookup_table(builder, "proportion")

    ##################
    # Helper methods #
    ##################

    def _probability(self, index: pd.Index[int]) -> pd.Series[float]:
        """Look up the per-time-step transition probability for ``index``.

        The configured proportion is used directly as the probability; no
        rate-to-probability conversion is applied.
        """
        return self.proportion_table(index)