"""
===================
Disease Transitions
===================
This module contains tools to model transitions between disease states.
"""
from __future__ import annotations
import warnings
from collections.abc import Callable
from typing import TYPE_CHECKING, Any
import pandas as pd
from vivarium.framework.engine import Builder
from vivarium.framework.state_machine import Transition, Trigger
from vivarium.framework.utilities import rate_to_probability
from vivarium.types import DataInput
from vivarium_public_health.causal_factor.calibration_constant import (
register_risk_affected_rate_producer,
)
from vivarium_public_health.disease.exceptions import DiseaseModelError
if TYPE_CHECKING:
from vivarium_public_health.disease import BaseDiseaseState
class TransitionString(str):
    """A string subclass representing a transition between two disease states.

    The transition name has the format ``{from_state}_TO_{to_state}``. The
    string value itself is the lower-cased name. The ``from_state`` and
    ``to_state`` attributes keep the original casing of the two halves.
    """

    from_state: str
    to_state: str

    def __new__(cls, value: str) -> "TransitionString":
        """Create a transition string from a ``{from_state}_TO_{to_state}`` name.

        Parameters
        ----------
        value
            The transition name. It must contain the upper-case separator
            ``"_TO_"`` exactly once.

        Returns
        -------
            A new ``TransitionString`` whose string value is ``value.lower()``.

        Raises
        ------
        ValueError
            If ``value`` does not contain ``"_TO_"`` exactly once.
        """
        parts = value.split("_TO_")
        if len(parts) != 2:
            # Tuple unpacking would also raise ValueError here, but with an
            # opaque "not enough/too many values to unpack" message.
            raise ValueError(
                f"Invalid transition name '{value}': expected the format "
                "'{from_state}_TO_{to_state}' with exactly one '_TO_' separator."
            )
        # noinspection PyArgumentList
        obj = str.__new__(cls, value.lower())
        obj.from_state, obj.to_state = parts
        return obj

    def __getnewargs__(self) -> tuple[str]:
        """Return the argument tuple passed to ``__new__`` when unpickling.

        The original-case name is rebuilt from the two state attributes so
        that ``from_state`` and ``to_state`` keep their casing after a
        pickle round trip.
        """
        return (self.from_state + "_TO_" + self.to_state,)
class RateTransition(Transition):
    """A transition between disease states governed by a rate.

    The rate is exposed through a pipeline (named by
    :attr:`transition_rate_pipeline`). At each time step it is converted to
    a probability using either a linear or an exponential conversion.
    """

    ##############
    # Properties #
    ##############

    @property
    def configuration_defaults(self) -> dict[str, Any]:
        """Provides default configuration values for this transition.

        Configuration structure::

            {transition_name}:
                data_sources:
                    transition_rate:
                        Source for transition rate data. The default value is
                        determined by the ``transition_rate`` constructor argument.
                rate_conversion_type: str
                    Method for converting rates to probabilities. Options
                    are ``"linear"`` (default) or ``"exponential"``. Linear
                    uses ``rate * dt``, exponential uses ``1 - exp(-rate * dt)``.
        """
        return {
            f"{self.name}": {
                "data_sources": {
                    "transition_rate": self.transition_rate,
                },
                "rate_conversion_type": "linear",
            },
        }

    @property
    def transition_rate_pipeline(self) -> str:
        """The name of the pipeline that produces this transition's rate.

        The name depends on ``rate_type``:

        - ``"incidence_rate"``: ``{output_state.state_id}.incidence_rate``
        - ``"remission_rate"``: ``{input_state.state_id}.remission_rate``
        - ``"transition_rate"``:
          ``{input_state.state_id}_to_{output_state.state_id}.transition_rate``

        Raises
        ------
        DiseaseModelError
            If ``rate_type`` is not one of the values above.
        """
        if self.rate_type == "incidence_rate":
            return f"{self.output_state.state_id}.incidence_rate"
        if self.rate_type == "remission_rate":
            return f"{self.input_state.state_id}.remission_rate"
        if self.rate_type == "transition_rate":
            return (
                f"{self.input_state.state_id}_to_{self.output_state.state_id}"
                ".transition_rate"
            )
        raise DiseaseModelError(
            "Cannot determine rate_transition pipeline name: invalid"
            f" rate_type '{self.rate_type}' supplied."
        )

    #####################
    # Lifecycle methods #
    #####################

    def __init__(
        self,
        input_state: "BaseDiseaseState",
        output_state: "BaseDiseaseState",
        transition_rate: DataInput,
        triggered: Trigger = Trigger.NOT_TRIGGERED,
        rate_type: str = "transition_rate",
    ):
        """
        Parameters
        ----------
        input_state
            The starting state of this transition.
        output_state
            The ending state of this transition.
        transition_rate
            The transition rate source. Can be the data itself, a function
            to retrieve the data, or the artifact key containing the data.
        triggered
            The trigger for the transition.
        rate_type
            The type of rate. Can be "incidence_rate", "transition_rate",
            or "remission_rate". It determines the pipeline name (see
            ``transition_rate_pipeline``). An invalid value is only reported
            when that name is first needed (during ``setup``).
        """
        super().__init__(
            input_state, output_state, probability_func=self._probability, triggered=triggered
        )
        self.transition_rate = transition_rate
        self.rate_type = rate_type

    def setup(self, builder: Builder) -> None:
        """Perform this component's setup.

        Builds the ``transition_rate`` lookup table and registers the rate
        pipeline, sourced by :meth:`compute_transition_rate`, via
        ``register_risk_affected_rate_producer``. It also reads the configured
        rate-to-probability conversion type.

        Parameters
        ----------
        builder
            Access point for utilizing framework interfaces during setup.
        """
        self.transition_rate_table = self.build_lookup_table(builder, "transition_rate")
        register_risk_affected_rate_producer(
            builder=builder,
            name=self.transition_rate_pipeline,
            source=self.compute_transition_rate,
            required_resources=["is_alive", self.transition_rate_table],
        )
        self.rate_conversion_type = self.configuration["rate_conversion_type"]

    ##################################
    # Pipeline sources and modifiers #
    ##################################

    def compute_transition_rate(self, index: pd.Index[int]) -> pd.Series[float]:
        """Compute the transition rate for the given simulants.

        Simulants matching ``is_alive == True`` get the value from the
        ``transition_rate`` lookup table. All other simulants get a rate of 0.

        Parameters
        ----------
        index
            An iterable of integer labels for the simulants.

        Returns
        -------
            The transition rates indexed by the provided ``index``.
        """
        transition_rate = pd.Series(0.0, index=index)
        living = self.population_view.get_filtered_index(index, query="is_alive == True")
        base_rates = self.transition_rate_table(living)
        transition_rate.loc[living] = base_rates
        return transition_rate

    ##################
    # Helper methods #
    ##################

    def _probability(self, index: pd.Index[int]) -> pd.Series[float]:
        """Convert the pipeline rate for ``index`` into transition probabilities.

        Parameters
        ----------
        index
            The simulants to compute probabilities for.

        Returns
        -------
            The transition probabilities, indexed by ``index``.
        """
        probability = rate_to_probability(
            self.population_view.get(index, self.transition_rate_pipeline),
            rate_conversion_type=self.rate_conversion_type,
        )
        # Pass ``index`` explicitly so the result stays aligned with the
        # requested simulants even if ``rate_to_probability`` returns a bare
        # array. A same-index Series is reindexed onto identical labels,
        # which changes nothing.
        return pd.Series(probability, index=index)
class ProportionTransition(Transition):
    """A transition between disease states governed by a fixed proportion.

    At each time step, the transition probability for a simulant is the
    value of the ``proportion`` lookup table for that simulant.
    """

    ##############
    # Properties #
    ##############

    @property
    def configuration_defaults(self) -> dict[str, Any]:
        """Provides default configuration values for this transition.

        Configuration structure::

            {transition_name}:
                data_sources:
                    proportion:
                        Source for the proportion of simulants transitioning
                        at each time step. The default value is determined
                        by the ``proportion`` constructor argument.
        """
        return {
            f"{self.name}": {
                "data_sources": {
                    "proportion": self.proportion,
                },
            },
        }

    #####################
    # Lifecycle methods #
    #####################

    def __init__(
        self,
        input_state: "BaseDiseaseState",
        output_state: "BaseDiseaseState",
        proportion: DataInput,
        triggered: Trigger = Trigger.NOT_TRIGGERED,
    ):
        """
        Parameters
        ----------
        input_state
            The starting state of this transition.
        output_state
            The ending state of this transition.
        proportion
            The proportion source. Can be the data itself, a function
            to retrieve the data, or the artifact key containing the data.
        triggered
            The trigger for the transition.
        """
        super().__init__(
            input_state, output_state, probability_func=self._probability, triggered=triggered
        )
        self.proportion = proportion

    def setup(self, builder: Builder) -> None:
        """Perform this component's setup.

        Builds the ``proportion`` lookup table from the configured data source.

        Parameters
        ----------
        builder
            Access point for utilizing framework interfaces during setup.
        """
        self.proportion_table = self.build_lookup_table(builder, "proportion")

    ##################
    # Helper methods #
    ##################

    def _probability(self, index: pd.Index[int]) -> pd.Series[float]:
        """Return the transition probability, i.e. the proportion, for ``index``.

        Parameters
        ----------
        index
            The simulants to compute probabilities for.

        Returns
        -------
            The ``proportion`` lookup-table values for ``index``.
        """
        return self.proportion_table(index)