"""
=================
The Disease Model
=================
This module contains a state machine driver for disease models. Its primary
function is to provide coordination across a set of disease states and
transitions at simulation initialization and during transitions.
"""
from __future__ import annotations
from collections.abc import Iterable
from functools import partial
from typing import Any
import pandas as pd
from layered_config_tree import ConfigurationError
from vivarium.framework.engine import Builder
from vivarium.framework.event import Event
from vivarium.framework.population import SimulantData
from vivarium.framework.state_machine import Machine
from vivarium.types import DataInput, LookupTableData
from vivarium_public_health.disease.exceptions import DiseaseModelError
from vivarium_public_health.disease.state import BaseDiseaseState, SusceptibleState
from vivarium_public_health.disease.transition import RateTransition, TransitionString
[docs]
class DiseaseModel(Machine):
"""State machine model for disease progression.
This component manages a set of disease states and transitions between
them. It handles initialization of simulant disease states based on
prevalence data and tracks cause-specific mortality rates.
"""
##############
# Properties #
##############
@property
def configuration_defaults(self) -> dict[str, Any]:
"""Provides default configuration values for this disease model.
Configuration structure::
{disease_name}:
data_sources:
cause_specific_mortality_rate:
Source for cause-specific mortality rate (CSMR) data.
Default uses the ``load_cause_specific_mortality_rate``
method which loads from artifact at
``cause.{cause_name}.cause_specific_mortality_rate``.
"""
return {
f"{self.name}": {
"data_sources": {
"cause_specific_mortality_rate": self.load_cause_specific_mortality_rate,
},
},
}
@property
def state_names(self) -> list[str]:
"""List of names of all states in this disease model."""
return [s.state_id for s in self.states]
@property
def transition_names(self) -> list[TransitionString]:
"""List of names of all transitions in this disease model."""
return [
state_name for state in self.states for state_name in state.get_transition_names()
]
#####################
# Lifecycle methods #
#####################
def __init__(
self,
cause: str,
cause_type: str = "cause",
states: Iterable[BaseDiseaseState] = (),
residual_state: BaseDiseaseState | None = None,
cause_specific_mortality_rate: DataInput | None = None,
) -> None:
"""
Parameters
----------
cause
The name of the cause of disease.
cause_type
The type of cause. Either "cause" or "sequela".
states
The disease states to include in the model.
residual_state
The state to use as the residual (whose prevalence is
calculated as 1 minus the sum of all other states). If not
provided, the model's ``SusceptibleState`` is used.
cause_specific_mortality_rate
The source for cause-specific mortality rate data. Can be the
data itself, a function to retrieve the data, or the artifact
key containing the data.
"""
super().__init__(cause, states=states)
self.cause = cause
self.cause_type = cause_type
self.residual_state = self._get_residual_state(residual_state)
self._csmr_source = cause_specific_mortality_rate
[docs]
def setup(self, builder: Builder) -> None:
"""Perform this component's setup.
- Gathers initialization weights pipelines from states contained in the disease model
and registers them to be run during population initialization.
- Registers a modifier to adjust the cause-specific mortality rate based on the model's states.
Parameters
----------
builder
Access point for utilizing framework interfaces during setup.
"""
self.initialization_weights_pipelines = [
*[state.prevalence_pipeline for state in self.states],
*[state.birth_prevalence_pipeline for state in self.states],
]
super().setup(builder)
self.configuration_age_start = builder.configuration.population.initialization_age_min
self.configuration_age_end = builder.configuration.population.initialization_age_max
self.csmr_table = self.build_lookup_table(builder, "cause_specific_mortality_rate")
builder.value.register_attribute_modifier(
"cause_specific_mortality_rate",
self.adjust_cause_specific_mortality_rate,
required_resources=["age", "sex"],
)
[docs]
def on_post_setup(self, event: Event) -> None:
"""Validate that all rate transitions use the same conversion type.
Parameters
----------
event
The event that triggered this method call.
"""
conversion_types = set()
for state in self.states:
for transition in state.transition_set.transitions:
if isinstance(transition, RateTransition):
conversion_types.add(transition.rate_conversion_type)
if len(conversion_types) > 1:
raise ConfigurationError(
"All transitions in a disease model must have the same rate conversion type."
f" Found: {conversion_types}."
)
[docs]
def initialize_state(self, pop_data: SimulantData) -> None:
"""Initialize the simulants in the population.
If all simulants are initialized at age 0, birth prevalence is used.
Otherwise, prevalence is used.
Parameters
----------
pop_data
Metadata about the simulants being initialized.
"""
self.initialization_weights_pipelines = [
state.birth_prevalence_pipeline
if pop_data.user_data.get("age_end", self.configuration_age_end) == 0
else state.prevalence_pipeline
for state in self.states
]
super().initialize_state(pop_data)
#################
# Setup methods #
#################
[docs]
def load_cause_specific_mortality_rate(self, builder: Builder) -> float | pd.DataFrame:
"""Load cause-specific mortality rate data.
If no source was provided at construction, loads CSMR from the
artifact. Returns 0.0 for causes that only have morbidity
(YLD-only causes).
Parameters
----------
builder
Access point for utilizing framework interfaces during setup.
Returns
-------
The cause-specific mortality rate data.
"""
if self._csmr_source is None:
only_morbid = builder.data.load(f"cause.{self.cause}.restrictions")["yld_only"]
if only_morbid:
self._csmr_source = 0.0
else:
self._csmr_source = (
f"{self.cause_type}.{self.cause}.cause_specific_mortality_rate"
)
return self.get_data(builder, self._csmr_source)
##################################
# Pipeline sources and modifiers #
##################################
[docs]
def adjust_cause_specific_mortality_rate(
self, index: pd.Index[int], rate: pd.Series[float]
) -> pd.Series[float]:
"""Modify the cause-specific mortality rate for the given simulants.
Parameters
----------
index
The index of simulants for which to adjust the cause-specific mortality rate.
rate
The base cause-specific mortality rate.
Returns
-------
The adjusted cause-specific mortality rate.
"""
return rate + self.csmr_table(index)
####################
# Helper functions #
####################
def _get_residual_state(
self, residual_state: BaseDiseaseState | None
) -> BaseDiseaseState:
"""Get the residual state for the DiseaseModel.
This will be the residual state if it is provided, otherwise it will be
the model's SusceptibleState. This method also calculates the residual
state's birth_prevalence and prevalence.
Parameters
----------
residual_state
The state to use as the residual, or None to auto-detect
from the model's ``SusceptibleState``.
Returns
-------
The resolved residual state with prevalence and
birth_prevalence functions set.
"""
if residual_state is None:
susceptible_states = [s for s in self.states if isinstance(s, SusceptibleState)]
if len(susceptible_states) != 1:
raise DiseaseModelError(
"DiseaseModel must have exactly one SusceptibleState or it must specify"
" a residual state."
)
residual_state = susceptible_states[0]
if residual_state not in self.states:
raise DiseaseModelError(
f"Residual state '{residual_state}' must be one of the states: {self.states}."
)
residual_state.birth_prevalence = partial(
self._get_residual_state_probabilities, table_name="birth_prevalence"
)
residual_state.prevalence = partial(
self._get_residual_state_probabilities, table_name="prevalence"
)
return residual_state
def _get_residual_state_probabilities(
self, builder: Builder, table_name: str
) -> LookupTableData:
"""Calculate the probabilities of the residual state based on the other states.
Parameters
----------
builder
Access point for utilizing framework interfaces during setup.
table_name
The name of the probability table, either "prevalence" or
"birth_prevalence".
Returns
-------
The residual state probabilities, calculated as 1 minus the
sum of probabilities from all non-residual states.
"""
non_residual_states = [s for s in self.states if s != self.residual_state]
non_residual_probabilities = 0
for state in non_residual_states:
weights_source = builder.configuration[state.name].data_sources[table_name]
weights = state.get_data(builder, weights_source)
if isinstance(weights, pd.DataFrame):
weights = weights.set_index(
[c for c in weights.columns if c != "value"]
).squeeze()
non_residual_probabilities += weights
residual_probabilities = 1 - non_residual_probabilities
if pd.Series(residual_probabilities < 0).any():
raise ValueError(
f"The {table_name} for the states in the DiseaseModel must sum"
" to less than 1."
)
if isinstance(residual_probabilities, pd.Series):
residual_probabilities = residual_probabilities.reset_index()
return residual_probabilities