"""
==============
Disease States
==============
This module contains tools to manage standard disease states.
"""
from __future__ import annotations
from abc import ABC
from collections.abc import Callable
from typing import Any
import numpy as np
import pandas as pd
from vivarium import Component
from vivarium.framework.engine import Builder
from vivarium.framework.population import PopulationView, SimulantData
from vivarium.framework.randomness import RandomnessStream
from vivarium.framework.state_machine import State, Transient, Transition, Trigger
from vivarium.types import DataInput, LookupTableData
from vivarium_public_health.causal_factor.calibration_constant import (
register_risk_affected_rate_producer,
)
from vivarium_public_health.disease.exceptions import DiseaseModelError
from vivarium_public_health.disease.transition import (
ProportionTransition,
RateTransition,
TransitionString,
)
from vivarium_public_health.utilities import is_non_zero
class BaseDiseaseState(State):
    """Base class for disease states in a state machine model.

    Provides shared infrastructure for tracking state event times,
    prevalence-based initialization weights, and transitions.
    """

    ##############
    # Properties #
    ##############

    @property
    def configuration_defaults(self) -> dict[str, Any]:
        """Provides default configuration values for this state.

        Extends the parent State's configuration with disease-specific
        data sources for prevalence, birth prevalence and dwell time.

        Configuration structure::

            {component_name}:
                data_sources:
                    prevalence:
                        Source for prevalence data used to initialize simulants
                        into this state. Default is the value set on the instance
                        (typically 0.0).
                    birth_prevalence:
                        Source for birth prevalence data used to initialize
                        newborn simulants. Default is the value set on the
                        instance (typically 0.0).
                    dwell_time:
                        Source for the minimum time a simulant must remain in
                        this state before transitioning out. A
                        ``pd.Timedelta`` is converted to days. Default is 0.0
                        (no minimum dwell time).
        """
        configuration_defaults = super().configuration_defaults
        additional_defaults = {
            "prevalence": self.prevalence,
            "birth_prevalence": self.birth_prevalence,
            "dwell_time": 0.0,
        }
        data_sources = {
            **configuration_defaults[self.name]["data_sources"],
            **additional_defaults,
        }
        configuration_defaults[self.name]["data_sources"] = data_sources
        return configuration_defaults

    @property
    def has_dwell_time(self) -> bool:
        """Whether this state has a non-zero dwell time.

        Returns
        -------
        True if the dwell time lookup data is a DataFrame with any non-zero
        ``value`` entry, or a scalar greater than zero.
        """
        dwell_time = self.dwell_time_table.data
        if isinstance(dwell_time, pd.DataFrame):
            # Decide the DataFrame case on its own. Falling through to
            # ``dwell_time > 0`` would produce a DataFrame, whose truth value
            # is ambiguous when this property is used in an ``if``.
            return bool(np.any(dwell_time["value"] != 0))
        return bool(dwell_time > 0)

    #####################
    # Lifecycle methods #
    #####################

    def __init__(
        self,
        state_id: str,
        allow_self_transition: bool = True,
        side_effect_function: Callable | None = None,
        cause_type: str = "cause",
    ) -> None:
        """
        Parameters
        ----------
        state_id
            The name of this state.
        allow_self_transition
            Whether this state allows simulants to remain in the state
            for multiple time steps.
        side_effect_function
            A function to be called when this state is entered. It is
            called with the entering simulants' index and the event time.
        cause_type
            The type of cause. Either "cause" or "sequela".
        """
        super().__init__(state_id, allow_self_transition)
        self.cause_type = cause_type
        self.side_effect_function = side_effect_function

        # State-table columns owned by this state.
        self.event_time_column = self.state_id + "_event_time"
        self.event_count_column = self.state_id + "_event_count"

        # Names of the attribute pipelines registered in ``setup``.
        self.prevalence_pipeline = f"{self.state_id}.prevalence"
        self.birth_prevalence_pipeline = f"{self.state_id}.birth_prevalence"
        self.dwell_time_pipeline = f"{self.state_id}.dwell_time"

        # Default data values used by ``configuration_defaults``.
        self.prevalence = 0.0
        self.birth_prevalence = 0.0

        # Resources the event-time initializer depends on; subclasses
        # populate this before calling ``setup``.
        self.required_resources = []

    def setup(self, builder: Builder) -> None:
        """Perform this component's setup.

        Builds the dwell time, prevalence and birth prevalence lookup
        tables, registers their attribute producers, and registers the
        initializer for this state's event time and count columns.

        Parameters
        ----------
        builder
            Access point for utilizing framework interfaces during setup.

        Raises
        ------
        DiseaseModelError
            If the state has a dwell time but does not allow
            self-transitions.
        """
        self.dwell_time_table = self.build_lookup_table(
            builder, "dwell_time", data_source=self.get_dwell_time(builder)
        )
        # A dwell time requires simulants to stay in the state across time
        # steps, which is impossible without self-transitions.
        if self.has_dwell_time and not self.transition_set.allow_self_transition:
            raise DiseaseModelError(
                f"State '{self.state_id}' has a dwell time but does not allow self-transitions."
            )

        self.prevalence_table = self.build_lookup_table(builder, "prevalence")
        builder.value.register_attribute_producer(
            self.prevalence_pipeline, source=self.prevalence_table
        )

        self.birth_prevalence_table = self.build_lookup_table(builder, "birth_prevalence")
        builder.value.register_attribute_producer(
            self.birth_prevalence_pipeline, source=self.birth_prevalence_table
        )

        builder.value.register_attribute_producer(
            self.dwell_time_pipeline, source=self.dwell_time_table
        )

        builder.population.register_initializer(
            initializer=self.initialize_event_time_and_count,
            columns=[self.event_time_column, self.event_count_column],
            required_resources=self.required_resources,
        )

    #################
    # Setup methods #
    #################

    def get_dwell_time(self, builder: Builder) -> DataInput:
        """Load the dwell time for this state from configuration.

        Parameters
        ----------
        builder
            Access point for utilizing framework interfaces during setup.

        Returns
        -------
        The dwell time data, converted from a Timedelta to days
        if applicable.
        """
        dwell_time = self.get_data(
            builder, self.configuration.get(["data_sources", "dwell_time"])
        )
        if isinstance(dwell_time, pd.Timedelta):
            dwell_time = dwell_time.total_seconds() / (60 * 60 * 24)
        return dwell_time

    ########################
    # Event-driven methods #
    ########################

    def initialize_event_time_and_count(self, pop_data: SimulantData) -> None:
        """Add this state's columns to the simulation state table.

        Also activates, for the new simulants, every outgoing transition
        flagged ``start_active``.

        Parameters
        ----------
        pop_data
            Metadata about the simulants being initialized.
        """
        for transition in self.transition_set:
            if transition.start_active:
                transition.set_active(pop_data.index)

        pop_update = self.get_initial_event_times(pop_data)
        self.population_view.initialize(pop_update)

    ##################
    # Helper methods #
    ##################

    def get_initialization_parameters(self) -> dict[str, Any]:
        """Exclude side effect function and cause type from name and __repr__."""
        initialization_parameters = super().get_initialization_parameters()
        return {"state_id": initialization_parameters["state_id"]}

    def get_initial_event_times(self, pop_data: SimulantData) -> pd.DataFrame:
        """Get initial event times and counts for new simulants.

        Parameters
        ----------
        pop_data
            Metadata about the simulants being initialized.

        Returns
        -------
        A DataFrame with the event time column set to ``NaT`` and the
        event count column set to 0 for every new simulant.
        """
        return pd.DataFrame(
            {self.event_time_column: pd.NaT, self.event_count_column: 0}, index=pop_data.index
        )

    def transition_side_effect(self, index: pd.Index[int], event_time: pd.Timestamp) -> None:
        """Updates the simulation state and triggers any side effects associated with this state.

        Records ``event_time`` and increments the event count for the
        entering simulants, then calls ``side_effect_function`` if one was
        provided.

        Parameters
        ----------
        index
            An iterable of integer labels for the simulants.
        event_time
            The time at which this transition occurs.
        """

        def _bump_event(pop: pd.DataFrame) -> pd.DataFrame:
            # Work on an explicit copy of the selected rows so the column
            # assignments below do not write into a derived slice
            # (avoids pandas' SettingWithCopy behavior).
            pop = pop.loc[index].copy()
            pop[self.event_time_column] = event_time
            pop[self.event_count_column] += 1
            return pop

        self.population_view.update(
            [self.event_time_column, self.event_count_column], _bump_event
        )
        if self.side_effect_function is not None:
            self.side_effect_function(index, event_time)

    ##################
    # Public methods #
    ##################

    def get_transition_names(self) -> list[str]:
        """Get the names of all transitions from this state.

        Returns
        -------
        The transition names formatted as ``{from_state}_TO_{to_state}``.
        """
        transitions = []
        for trans in self.transition_set.transitions:
            # State component names are dotted; the part after the first
            # "." is used as the state label.
            init_state = trans.input_state.name.split(".")[1]
            end_state = trans.output_state.name.split(".")[1]
            transitions.append(TransitionString(f"{init_state}_TO_{end_state}"))
        return transitions

    def add_rate_transition(
        self,
        output: "BaseDiseaseState",
        triggered: Trigger = Trigger.NOT_TRIGGERED,
        transition_rate: DataInput | None = None,
        rate_type: str = "transition_rate",
    ) -> RateTransition:
        """Builds a RateTransition from this state to the given state.

        Parameters
        ----------
        output
            The end state after the transition.
        triggered
            The trigger for the transition.
        transition_rate
            The transition rate source. Can be the data itself, a function to
            retrieve the data, or the artifact key containing the data.
        rate_type
            The type of rate. Can be "incidence_rate", "transition_rate", or
            "remission_rate".

        Returns
        -------
        The created transition object.
        """
        transition = RateTransition(
            input_state=self,
            output_state=output,
            triggered=triggered,
            transition_rate=transition_rate,
            rate_type=rate_type,
        )
        self.add_transition(transition)
        return transition

    def add_proportion_transition(
        self,
        output: "BaseDiseaseState",
        triggered: Trigger = Trigger.NOT_TRIGGERED,
        proportion: DataInput | None = None,
    ) -> ProportionTransition:
        """Builds a ProportionTransition from this state to the given state.

        Parameters
        ----------
        output
            The end state after the transition.
        triggered
            The trigger for the transition.
        proportion
            The proportion source. Can be the data itself, a function to
            retrieve the data, or the artifact key containing the data.

        Returns
        -------
        The created transition object.
        """
        transition = ProportionTransition(
            input_state=self,
            output_state=output,
            triggered=triggered,
            proportion=proportion,
        )
        self.add_transition(transition)
        return transition

    def add_dwell_time_transition(
        self, output: "BaseDiseaseState", triggered: Trigger = Trigger.NOT_TRIGGERED
    ) -> Transition:
        """Build a dwell time transition from this state to the given state.

        Parameters
        ----------
        output
            The end state after the transition.
        triggered
            The trigger for the transition.

        Returns
        -------
        The created transition object.
        """
        transition = Transition(self, output, triggered=triggered)
        self.add_transition(transition)
        return transition
class NonDiseasedState(BaseDiseaseState):
    """Base class for states in which a simulant does not have the condition.

    A ``name_prefix`` (for example ``susceptible_to_`` or
    ``recovered_from_``) is prepended to the state ID unless the ID
    already starts with it.
    """

    #####################
    # Lifecycle methods #
    #####################

    def __init__(
        self,
        state_id: str,
        allow_self_transition: bool = True,
        side_effect_function: Callable | None = None,
        cause_type: str = "cause",
        name_prefix: str = "",
    ) -> None:
        """
        Parameters
        ----------
        state_id
            The name of this state, with or without ``name_prefix``.
        allow_self_transition
            Whether simulants may stay in this state across multiple
            time steps.
        side_effect_function
            A function invoked when simulants enter this state.
        cause_type
            The type of cause. Either "cause" or "sequela".
        name_prefix
            Text prepended to ``state_id`` when it is not already present.
        """
        already_prefixed = state_id.startswith(name_prefix)
        full_state_id = state_id if already_prefixed else name_prefix + state_id
        super().__init__(
            full_state_id,
            allow_self_transition=allow_self_transition,
            side_effect_function=side_effect_function,
            cause_type=cause_type,
        )

    ##################
    # Public methods #
    ##################

    def add_rate_transition(
        self,
        output: BaseDiseaseState,
        triggered: Trigger = Trigger.NOT_TRIGGERED,
        transition_rate: DataInput | None = None,
    ) -> RateTransition:
        """Create an incidence-rate transition from this state to ``output``.

        When ``transition_rate`` is omitted, the artifact key
        ``{cause_type}.{output.state_id}.incidence_rate`` is used.

        Parameters
        ----------
        output
            The state simulants move into.
        triggered
            The trigger for the transition.
        transition_rate
            The rate source: the data itself, a function that retrieves it,
            or an artifact key.

        Returns
        -------
        The newly created transition.
        """
        rate_source = (
            f"{self.cause_type}.{output.state_id}.incidence_rate"
            if transition_rate is None
            else transition_rate
        )
        return super().add_rate_transition(
            output=output,
            triggered=triggered,
            transition_rate=rate_source,
            rate_type="incidence_rate",
        )
class SusceptibleState(NonDiseasedState):
    """A state for simulants who can still acquire a disease.

    The prefix ``susceptible_to_`` is added to the given state ID.
    """

    #####################
    # Lifecycle methods #
    #####################

    def __init__(
        self,
        state_id: str,
        allow_self_transition: bool = True,
        side_effect_function: Callable | None = None,
        cause_type: str = "cause",
    ) -> None:
        """
        Parameters
        ----------
        state_id
            The disease this state is susceptible to.
        allow_self_transition
            Whether simulants may stay in this state across multiple
            time steps.
        side_effect_function
            A function invoked when simulants enter this state.
        cause_type
            The type of cause. Either "cause" or "sequela".
        """
        super().__init__(
            state_id,
            name_prefix="susceptible_to_",
            cause_type=cause_type,
            side_effect_function=side_effect_function,
            allow_self_transition=allow_self_transition,
        )

    ##################
    # Public methods #
    ##################

    def has_initialization_weights(self) -> bool:
        """Report that this state contributes initialization weights.

        Returns
        -------
        ``True`` unconditionally.
        """
        return True
class RecoveredState(NonDiseasedState):
    """A state for simulants who have recovered from a disease.

    The prefix ``recovered_from_`` is added to the given state ID.
    """

    def __init__(
        self,
        state_id: str,
        allow_self_transition: bool = True,
        side_effect_function: Callable | None = None,
        cause_type: str = "cause",
    ) -> None:
        """
        Parameters
        ----------
        state_id
            The disease this state represents recovery from.
        allow_self_transition
            Whether simulants may stay in this state across multiple
            time steps.
        side_effect_function
            A function invoked when simulants enter this state.
        cause_type
            The type of cause. Either "cause" or "sequela".
        """
        super().__init__(
            state_id,
            name_prefix="recovered_from_",
            cause_type=cause_type,
            side_effect_function=side_effect_function,
            allow_self_transition=allow_self_transition,
        )
class ExcessMortalityState(Component, ABC):
    """Mixin for disease states that may have excess mortality.

    Caches whether the configured excess mortality rate is non-zero, so
    the data is only inspected once.
    """

    def __init__(self, *args: Any, **kwargs: Any) -> None:
        super().__init__()
        # None means "not determined yet"; resolved lazily below or by a
        # subclass during setup.
        self._has_excess_mortality = None

    def has_excess_mortality(self, builder: Builder) -> bool:
        """Return whether this state's excess mortality data is non-zero.

        The result is computed on first call and cached for later calls.

        Parameters
        ----------
        builder
            Access point for utilizing framework interfaces during setup.

        Returns
        -------
        True if the configured excess mortality rate source yields
        non-zero data.
        """
        if self._has_excess_mortality is not None:
            return self._has_excess_mortality

        config_path = [self.name, "data_sources", "excess_mortality_rate"]
        source = builder.configuration.get(config_path)
        self._has_excess_mortality = is_non_zero(self.get_data(builder, source))
        return self._has_excess_mortality
class DiseaseState(BaseDiseaseState, ExcessMortalityState):
    """State representing a disease in a state machine model.

    On top of the shared disease-state machinery, this state provides
    disability weight and excess mortality rate pipelines. It assigns
    back-dated event times to prevalent cases at initialization and
    withholds transitions until a simulant's dwell time has elapsed.
    """

    ##############
    # Properties #
    ##############

    @property
    def configuration_defaults(self) -> dict[str, Any]:
        """Provides default configuration values for this disease state.

        Extends BaseDiseaseState's configuration with additional data sources
        for disease burden metrics.

        Configuration structure::

            {component_name}:
                data_sources:
                    prevalence:
                        Source for prevalence data. Defaults to the
                        ``prevalence`` constructor argument, or if not
                        provided, loads from artifact at
                        ``{cause_type}.{state_id}.prevalence``.
                    birth_prevalence:
                        Source for birth prevalence data. Defaults to the
                        ``birth_prevalence`` constructor argument, which is
                        0.0 if not provided.
                    dwell_time:
                        Source for dwell time data (minimum time in state
                        before transition). Defaults to the ``dwell_time``
                        constructor argument, or if not provided, defaults
                        to 0 (no minimum dwell time).
                    disability_weight:
                        Source for disability weight data used to calculate
                        years lived with disability (YLDs). Defaults to the
                        ``disability_weight`` constructor argument, or if not
                        provided, loads from artifact at
                        ``{cause_type}.{state_id}.disability_weight``.
                        Single-row DataFrames are collapsed to a scalar.
                    excess_mortality_rate:
                        Source for excess mortality rate data. Defaults to the
                        ``excess_mortality_rate`` constructor argument. If
                        not provided, it is 0 when the artifact's
                        ``cause.{model}.restrictions`` marks the cause as
                        ``yld_only``, and otherwise loads from artifact at
                        ``{cause_type}.{state_id}.excess_mortality_rate``.
        """
        configuration_defaults = super().configuration_defaults
        additional_defaults = {
            "prevalence": self._prevalence_source,
            "birth_prevalence": self._birth_prevalence_source,
            "dwell_time": self._dwell_time_source,
            "disability_weight": self._disability_weight_source,
            "excess_mortality_rate": self._excess_mortality_rate_source,
        }
        data_sources = {
            **configuration_defaults[self.name]["data_sources"],
            **additional_defaults,
        }
        configuration_defaults[self.name]["data_sources"] = data_sources
        return configuration_defaults

    #####################
    # Lifecycle methods #
    #####################

    def __init__(
        self,
        state_id: str,
        allow_self_transition: bool = True,
        side_effect_function: Callable | None = None,
        cause_type: str = "cause",
        prevalence: DataInput | None = None,
        birth_prevalence: DataInput = 0.0,
        dwell_time: DataInput = 0.0,
        disability_weight: DataInput | None = None,
        excess_mortality_rate: DataInput | None = None,
    ) -> None:
        """
        Parameters
        ----------
        state_id
            The name of this state.
        allow_self_transition
            Whether this state allows simulants to remain in the state for
            multiple time-steps.
        side_effect_function
            A function to be called when this state is entered.
        cause_type
            The type of cause represented by this state. Either "cause" or "sequela".
        prevalence
            The prevalence source. This is used to initialize simulants. Can be
            the data itself, a function to retrieve the data, or the artifact
            key containing the data.
        birth_prevalence
            The birth prevalence source. This is used to initialize newborn
            simulants. Can be the data itself, a function to retrieve the data,
            or the artifact key containing the data.
        dwell_time
            The dwell time source. This is used to determine how long a simulant
            must remain in the state before transitioning. Can be the data
            itself, a function to retrieve the data, or the artifact key
            containing the data.
        disability_weight
            The disability weight source. This is used to calculate the
            disability weight for simulants in this state. Can be the data
            itself, a function to retrieve the data, or the artifact key
            containing the data.
        excess_mortality_rate
            The excess mortality rate source. This is used to calculate the
            excess mortality rate for simulants in this state. Can be the data
            itself, a function to retrieve the data, or the artifact key
            containing the data.
        """
        super().__init__(
            state_id,
            allow_self_transition=allow_self_transition,
            side_effect_function=side_effect_function,
            cause_type=cause_type,
        )
        self.excess_mortality_rate_pipeline = f"{self.state_id}.excess_mortality_rate"
        self.dw_pipeline = f"{self.state_id}.disability_weight"

        self._prevalence_source = self.get_prevalence_source(prevalence)
        self._birth_prevalence_source = birth_prevalence
        self._dwell_time_source = dwell_time
        self._disability_weight_source = self.get_disability_weight_source(disability_weight)
        self._excess_mortality_rate_source = self.get_excess_mortality_rate_source(
            excess_mortality_rate
        )

    def setup(self, builder: Builder) -> None:
        """Performs this component's simulation setup.

        Parameters
        ----------
        builder
            Access point for utilizing framework interfaces during setup.
        """
        # Must be in place before ``super().setup`` registers the event-time
        # initializer, which depends on these resources.
        self.randomness_prevalence = self.get_randomness_prevalence(builder)
        self.required_resources = [
            self.model,
            self.randomness_prevalence,
            self.dwell_time_pipeline,
        ]
        super().setup(builder)
        self.clock = builder.time.clock()

        self.disability_weight_table = self.build_lookup_table(builder, "disability_weight")
        self.excess_mortality_rate_table = self.build_lookup_table(
            builder, "excess_mortality_rate"
        )
        if self._has_excess_mortality is None:
            self._has_excess_mortality = is_non_zero(self.excess_mortality_rate_table.data)

        self.register_disability_weight_pipeline(builder)
        builder.value.register_attribute_modifier(
            "all_causes.disability_weight", modifier=self.dw_pipeline
        )

        self.register_excess_mortality_rate_pipeline(builder)
        builder.value.register_attribute_modifier(
            "mortality_rate",
            modifier=self.adjust_mortality_rate,
            required_resources=[self.excess_mortality_rate_pipeline],
        )

    #################
    # Setup methods #
    #################

    def get_prevalence_source(self, prevalence: DataInput | None) -> DataInput:
        """Resolve the prevalence data source.

        Parameters
        ----------
        prevalence
            The prevalence source provided at construction, or None to
            use the default artifact key.

        Returns
        -------
        The resolved prevalence data source.
        """
        return (
            prevalence
            if prevalence is not None
            else f"{self.cause_type}.{self.state_id}.prevalence"
        )

    def get_disability_weight_source(self, disability_weight: DataInput | None) -> DataInput:
        """Resolve the disability weight data source.

        Parameters
        ----------
        disability_weight
            The disability weight source provided at construction, or
            None to use the default artifact key.

        Returns
        -------
        A function of the builder that loads the disability weight data,
        collapsing a single-row DataFrame to its scalar ``value``.
        """
        if disability_weight is None:
            disability_weight = f"{self.cause_type}.{self.state_id}.disability_weight"

        def disability_weight_source(builder: Builder) -> LookupTableData:
            disability_weight_ = self.get_data(builder, disability_weight)
            if isinstance(disability_weight_, pd.DataFrame) and len(disability_weight_) == 1:
                # Sequela only have a single value. Select it by position:
                # the row's index label is not guaranteed to be 0.
                disability_weight_ = disability_weight_["value"].iloc[0]
            return disability_weight_

        return disability_weight_source

    def register_disability_weight_pipeline(self, builder: Builder) -> None:
        """Register the disability weight pipeline with the simulation.

        Parameters
        ----------
        builder
            Access point for utilizing framework interfaces during setup.
        """
        builder.value.register_attribute_producer(
            self.dw_pipeline,
            source=self.compute_disability_weight,
            required_resources=["is_alive", self.model, self.disability_weight_table],
        )

    def get_excess_mortality_rate_source(
        self, excess_mortality_rate: DataInput | None
    ) -> DataInput:
        """Resolve the excess mortality rate data source.

        Parameters
        ----------
        excess_mortality_rate
            The excess mortality rate source provided at construction,
            or None to use the default behavior.

        Returns
        -------
        A function of the builder that loads the excess mortality rate.
        An explicitly provided source is loaded as given. Otherwise the
        function returns 0 when the artifact restrictions for
        ``cause.{model}`` mark the cause as ``yld_only``, and loads
        ``{cause_type}.{state_id}.excess_mortality_rate`` from the
        artifact when they do not.
        """

        def excess_mortality_rate_source(builder: Builder) -> LookupTableData:
            if excess_mortality_rate is not None:
                return self.get_data(builder, excess_mortality_rate)
            # The default artifact key must only be used after the
            # restriction check. Substituting it for ``None`` up front would
            # make this YLD-only branch unreachable.
            if builder.data.load(f"cause.{self.model}.restrictions")["yld_only"]:
                return 0
            return builder.data.load(
                f"{self.cause_type}.{self.state_id}.excess_mortality_rate"
            )

        return excess_mortality_rate_source

    def register_excess_mortality_rate_pipeline(self, builder: Builder) -> None:
        """Register the excess mortality rate pipeline with the simulation.

        Parameters
        ----------
        builder
            Access point for utilizing framework interfaces during setup.
        """
        register_risk_affected_rate_producer(
            builder=builder,
            name=self.excess_mortality_rate_pipeline,
            source=self.compute_excess_mortality_rate,
            required_resources=["is_alive", self.model, self.excess_mortality_rate_table],
        )

    def get_randomness_prevalence(self, builder: Builder) -> RandomnessStream:
        """Get a randomness stream for assigning prevalent cases.

        Parameters
        ----------
        builder
            Access point for utilizing framework interfaces during setup.

        Returns
        -------
        A randomness stream for prevalent case assignment.
        """
        return builder.randomness.get_stream(f"{self.state_id}_prevalent_cases")

    ##################
    # Public methods #
    ##################

    def has_initialization_weights(self) -> bool:
        """Whether this state has initialization weights.

        Returns
        -------
        Always True for disease states.
        """
        return True

    def add_rate_transition(
        self,
        output: BaseDiseaseState,
        triggered: Trigger = Trigger.NOT_TRIGGERED,
        transition_rate: DataInput | None = None,
        rate_type: str = "transition_rate",
    ) -> RateTransition:
        """Build a rate transition from this state to the given state.

        If no transition rate is provided, the remission rate for this state
        is loaded from the artifact and ``rate_type`` is forced to
        "remission_rate".

        Parameters
        ----------
        output
            The end state after the transition.
        triggered
            The trigger for the transition.
        transition_rate
            The transition rate source. Can be the data itself, a function
            to retrieve the data, or the artifact key containing the data.
        rate_type
            The type of rate. Can be "incidence_rate", "transition_rate",
            or "remission_rate".

        Returns
        -------
        The created transition object.
        """
        if transition_rate is None:
            transition_rate = f"{self.cause_type}.{self.state_id}.remission_rate"
            rate_type = "remission_rate"
        return super().add_rate_transition(
            output=output,
            triggered=triggered,
            transition_rate=transition_rate,
            rate_type=rate_type,
        )

    def next_state(
        self, index: pd.Index[int], event_time: pd.Timestamp, population_view: PopulationView
    ) -> None:
        """Moves a population among different disease states.

        Only simulants whose dwell time in this state has elapsed are
        considered for transition.

        Parameters
        ----------
        index
            An iterable of integer labels for the simulants.
        event_time
            The time at which this transition occurs.
        population_view
            A view of the internal state of the simulation.
        """
        eligible_index = self._filter_for_transition_eligibility(index, event_time)
        return super().next_state(eligible_index, event_time, population_view)

    ##################################
    # Pipeline sources and modifiers #
    ##################################

    def compute_disability_weight(self, index: pd.Index[int]) -> pd.Series[float]:
        """Gets the disability weight associated with this state.

        Parameters
        ----------
        index
            An iterable of integer labels for the simulants.

        Returns
        -------
        Disability weights indexed by ``index``. Simulants that are not
        alive and in this state get 0.
        """
        disability_weight = pd.Series(0.0, index=index)
        with_condition = self.with_condition(index)
        disability_weight.loc[with_condition] = self.disability_weight_table(with_condition)
        return disability_weight

    def compute_excess_mortality_rate(self, index: pd.Index[int]) -> pd.Series[float]:
        """Get the excess mortality rate associated with this state.

        Parameters
        ----------
        index
            An iterable of integer labels for the simulants.

        Returns
        -------
        Excess mortality rates indexed by ``index``. Simulants that are not
        alive and in this state get 0.
        """
        excess_mortality_rate = pd.Series(0.0, index=index)
        with_condition = self.with_condition(index)
        base_excess_mort = self.excess_mortality_rate_table(with_condition)
        excess_mortality_rate.loc[with_condition] = base_excess_mort
        return excess_mortality_rate

    def adjust_mortality_rate(
        self, index: pd.Index[int], rates_df: pd.DataFrame
    ) -> pd.DataFrame:
        """Modifies the baseline mortality rate for a simulant if they are in this state.

        Adds a column named after this state holding the (un-post-processed)
        excess mortality rate.

        Parameters
        ----------
        index
            An iterable of integer labels for the simulants.
        rates_df
            A DataFrame of mortality rates.

        Returns
        -------
        The modified DataFrame of mortality rates.
        """
        rate = self.population_view.get(
            index, self.excess_mortality_rate_pipeline, skip_post_processor=True
        )
        rates_df[self.state_id] = rate
        return rates_df

    ##################
    # Helper methods #
    ##################

    def get_initial_event_times(self, pop_data: SimulantData) -> pd.DataFrame:
        """Get initial event times for new simulants, including prevalent cases.

        Simulants initialized into this state get an event time back-dated
        by a random fraction of their dwell time.

        Parameters
        ----------
        pop_data
            Metadata about the simulants being initialized.

        Returns
        -------
        A DataFrame with event time and count columns for new simulants.
        """
        pop_update = super().get_initial_event_times(pop_data)

        simulants_with_condition = self.population_view.get(
            pop_data.index,
            self.model,
            query=f'{self.model}=="{self.state_id}"',
        )
        if not simulants_with_condition.empty:
            infected_at = self._assign_event_time_for_prevalent_cases(
                simulants_with_condition,
                self.clock(),
                self.randomness_prevalence.get_draw,
                self.population_view.get(
                    simulants_with_condition.index, self.dwell_time_pipeline
                ),
            )
            pop_update.loc[infected_at.index, self.event_time_column] = infected_at

        return pop_update

    def with_condition(self, index: pd.Index[int]) -> pd.Index[int]:
        """Get the subset of simulants who are in this disease state.

        Parameters
        ----------
        index
            An iterable of integer labels for the simulants.

        Returns
        -------
        The subset of simulants who are alive and in this state.
        """
        return self.population_view.get_filtered_index(
            index, query=f'{self.model}=="{self.state_id}" and is_alive == True'
        )

    @staticmethod
    def _assign_event_time_for_prevalent_cases(
        infected, current_time, randomness_func, dwell_time
    ):
        """Back-date prevalent cases' event times by a random share of their dwell time.

        ``randomness_func`` is called with ``infected.index``. Its draws
        are presumably uniform on [0, 1); confirm with the randomness
        stream. The draws scale ``dwell_time``, which is in days, and the
        result is subtracted from ``current_time``.
        """
        infected_at = dwell_time * randomness_func(infected.index)
        infected_at = current_time - pd.to_timedelta(infected_at, unit="D")
        return infected_at

    def _filter_for_transition_eligibility(
        self, index: pd.Index[int], event_time: pd.Timestamp
    ) -> pd.Index[int]:
        """Filter out all simulants who haven't been in the state for the prescribed dwell time.

        Parameters
        ----------
        index
            An iterable of integer labels for the simulants.
        event_time
            The time at which this transition occurs.

        Returns
        -------
        A filtered index of the simulants. When no simulant has a positive
        dwell time, ``index`` is returned unchanged.
        """
        event_times = self.population_view.get(
            index, self.event_time_column, query="is_alive == True"
        )
        dwell_time = self.population_view.get(index, self.dwell_time_pipeline)
        # Test element-wise for a positive dwell time. The previous form,
        # ``np.any(dwell_time) > 0``, compared a bool with 0.
        if np.any(dwell_time > 0):
            state_exit_time = event_times + pd.to_timedelta(dwell_time, unit="D")
            return event_times.loc[state_exit_time <= event_time].index
        return index

    def _cleanup_effect(self, index: pd.Index[int], event_time: pd.Timestamp) -> None:
        """Invoke the optional cleanup function for simulants leaving this state."""
        # ``_cleanup_function`` is never assigned by this class. Read it
        # defensively so that an unset attribute means "no cleanup" instead
        # of raising AttributeError.
        cleanup_function = getattr(self, "_cleanup_function", None)
        if cleanup_function is not None:
            cleanup_function(index, event_time)
class TransientDiseaseState(BaseDiseaseState, Transient):
    """A disease state that simulants only pass through.

    Simulants never remain here. Entering this state moves them on to
    another state within the same time step.
    """