"""
=========================
The Core Population Model
=========================
Provide tools for sampling and assigning core demographic characteristics to simulants.
"""
from collections.abc import Callable, Iterable
import numpy as np
import pandas as pd
from layered_config_tree.exceptions import ConfigurationKeyError
from loguru import logger
from vivarium import Component
from vivarium.framework.engine import Builder
from vivarium.framework.event import Event
from vivarium.framework.population import SimulantData
from vivarium.framework.randomness import RandomnessStream
from vivarium.framework.values import Pipeline, list_combiner, union_post_processor
from vivarium_public_health import utilities
from vivarium_public_health.population.data_transformations import (
assign_demographic_proportions,
load_population_structure,
rescale_binned_proportions,
smooth_ages,
)
from vivarium_public_health.population.mortality import Mortality
class BasePopulation(Component):
    """Generate the core simulant population and keep it aging.

    During setup this component loads a population structure, converts it to
    demographic sampling proportions, creates its randomness streams, and
    registers an initializer for the ``age``, ``sex``, ``location``,
    ``entrance_time`` and ``exit_time`` columns. On every time step it adds
    the step size (in years) to the age of living simulants, and during
    time-step cleanup it writes the current ``exit_time`` values back to the
    stored column.

    ``AgeOutSimulants``, ``Mortality`` and ``Disability`` are attached as
    sub-components when the component is constructed.
    """

    CONFIGURATION_DEFAULTS = {
        "population": {
            "initialization_age_min": 0,
            "initialization_age_max": 125,
            "untracking_age": None,
            "include_sex": "Both",  # Either Female, Male, or Both
        }
    }

    ##############
    # Properties #
    ##############

    @property
    def time_step_priority(self) -> int:
        """The event priority used for this component's time-step updates."""
        return 8

    @property
    def time_step_cleanup_priority(self) -> int:
        """The event priority used for this component's time-step cleanup."""
        return 9

    #####################
    # Lifecycle methods #
    #####################

    def __init__(self):
        super().__init__()
        self._sub_components += [AgeOutSimulants(), Mortality(), Disability()]

    def setup(self, builder: Builder) -> None:
        """Validate configuration, prepare sampling data, and register the initializer.

        Parameters
        ----------
        builder
            Access point for utilizing framework interfaces during setup.

        Raises
        ------
        ValueError
            If ``population.include_sex`` is not one of ``'Male'``,
            ``'Female'`` or ``'Both'``, or if deprecated configuration keys
            conflict with their replacements.
        """
        self.config = builder.configuration.population
        self.key_columns = builder.configuration.randomness.key_columns

        allowed_sexes = ["Male", "Female", "Both"]
        if self.config.include_sex not in allowed_sexes:
            raise ValueError(
                "Configuration key 'population.include_sex' must be one "
                "of ['Male', 'Female', 'Both']. "
                f"Provided value: {self.config.include_sex}."
            )

        # TODO: Remove this when we remove deprecated keys.
        # Validate configuration for deprecated keys
        self._validate_config_for_deprecated_keys()

        self.demographic_proportions = assign_demographic_proportions(
            self._load_population_structure(builder),
            include_sex=self.config.include_sex,
        )
        self.randomness = self.get_randomness_streams(builder)
        self.register_simulants = builder.randomness.register_simulants

        # HACK / FIXME [MIC-6746]: Simplify initial population creation
        # The current implementation of initialize_population is complicated
        # and should be simplified/streamlined. Of note, the required_resources for the
        # "sex" and "location" columns are different depending on whether or not
        # the simulation's initialized age_start and age_end values are the same.
        # To get around this, we simply register the initializer without specifying
        # any required_resources here. This could potentially lead to a difficult-to-diagnose
        # bug, but we've been doing this for a long time without known issues.
        initialized_columns = ["age", "sex", "location", "entrance_time", "exit_time"]
        builder.population.register_initializer(
            initializer=self.initialize_population,
            columns=initialized_columns,
        )

    #################
    # Setup methods #
    #################

    def get_randomness_streams(self, builder: Builder) -> dict[str, RandomnessStream]:
        """Create the randomness streams used while generating the population.

        Parameters
        ----------
        builder
            Access point for utilizing framework interfaces during setup.

        Returns
        -------
            A dictionary mapping stream name to a
            :class:`~vivarium.framework.randomness.stream.RandomnessStream`.
            Keys are:

            ``'general_purpose'``
                Used for overall population generation (stream key
                ``"population_generation"``).
            ``'bin_selection'``
                Used for selecting demographic age bins when initializing
                with age bounds.
            ``'age_smoothing'``
                Used for smoothing ages at a fixed initial age.
            ``'age_smoothing_age_bounds'``
                Used for smoothing ages when initializing with a range of ages.

            All streams except ``'general_purpose'`` are requested with
            ``initializes_crn_attributes=True``.
        """
        streams = {
            "general_purpose": builder.randomness.get_stream("population_generation"),
        }
        for stream_name in ("bin_selection", "age_smoothing", "age_smoothing_age_bounds"):
            streams[stream_name] = builder.randomness.get_stream(
                stream_name, initializes_crn_attributes=True
            )
        return streams

    ########################
    # Event-driven methods #
    ########################

    # TODO: Move most of this docstring to an rst file.
    def initialize_population(self, pop_data: SimulantData) -> None:
        """Create a population with fundamental demographic and simulation properties.

        Parameters
        ----------
        pop_data
            Metadata about the simulants being initialized.

        Notes
        -----
        When the simulation framework creates new simulants (essentially producing a new
        set of simulant ids) and this component is being used, the newly created simulants
        arrive here first and are assigned the demographic qualities 'age', 'sex',
        and 'location' in a way that is consistent with the demographic distributions
        represented by the population-level data. Additionally, the simulants are assigned
        the simulation properties 'entrance_time' and 'exit_time'.

        The age range is taken from ``pop_data.user_data`` (keys ``age_start`` and
        ``age_end``) when present, and otherwise from the
        ``initialization_age_min`` / ``initialization_age_max`` configuration values.

        The 'exit_time' attribute simply marks when the simulant exits the simulation.
        Here we are agnostic to the methods of exit (e.g., aging out, dying, etc.) as
        this characteristic can be inferred from this column and other information about
        the simulant and the simulation parameters.

        The 'exit_time' attribute is unique in that it is created by this BasePopulation
        component but we expect other components to be able to modify it as needed
        (e.g., a Mortality component might change the 'exit_time' when a simulant dies).
        We do this by having the components register attribute modifiers as necessary and then
        have the BasePopulation component update the underlying private column data accordingly.
        """
        user_data = pop_data.user_data
        age_params = {
            "age_start": user_data.get("age_start", self.config.initialization_age_min),
            "age_end": user_data.get("age_end", self.config.initialization_age_max),
        }

        proportions_for_year = self.get_demographic_proportions_for_creation_time(
            self.demographic_proportions, pop_data.creation_time.year
        )

        new_simulants = generate_population(
            simulant_ids=pop_data.index,
            creation_time=pop_data.creation_time,
            step_size=pop_data.creation_window,
            age_params=age_params,
            demographic_proportions=proportions_for_year,
            randomness_streams=self.randomness,
            register_simulants=self.register_simulants,
            key_columns=self.key_columns,
        )
        self.population_view.initialize(new_simulants)

    def on_time_step(self, event: Event) -> None:
        """Advance the age of living simulants by the step size.

        Parameters
        ----------
        event
            The event that triggered this method call.
        """
        # Only simulants matching ``is_alive == True`` are aged.
        alive = self.population_view.get_filtered_index(
            event.index, query="is_alive == True"
        )
        years_elapsed = utilities.to_years(event.step_size)
        self.population_view.update(
            "age", lambda age: age.loc[alive] + years_elapsed
        )

    def on_time_step_cleanup(self, event: Event) -> None:
        """Write the current 'exit_time' values back to the 'exit_time' column.

        Parameters
        ----------
        event
            The event that triggered this method call.

        Notes
        -----
        The values are read with ``include_untracked=True`` so simulants that
        are no longer tracked are updated as well.
        NOTE(review): this presumably picks up changes made by ``exit_time``
        attribute modifiers registered by other components -- confirm against
        the population view API.
        """
        current_exit_times = self.population_view.get(
            event.index, "exit_time", include_untracked=True
        )
        self.population_view.update(
            "exit_time", lambda _: current_exit_times.rename("exit_time")
        )

    ##################
    # Helper methods #
    ##################

    @staticmethod
    def get_demographic_proportions_for_creation_time(
        demographic_proportions: pd.DataFrame, year: int
    ) -> pd.DataFrame:
        """Subset the demographic proportions table to the closest reference year.

        Parameters
        ----------
        demographic_proportions
            Full table of demographic proportions across all reference years,
            with a ``year_start`` column.
        year
            The simulation year for which to retrieve proportions.

        Returns
        -------
            Rows from ``demographic_proportions`` whose ``year_start`` matches
            the closest reference year that is less than or equal to ``year``.

        Raises
        ------
        ValueError
            If ``year`` is earlier than every ``year_start`` in the table.
        """
        year_starts = demographic_proportions.year_start
        reference_years = sorted(set(year_starts))
        selected_year = reference_years[_find_bin_start_index(year, reference_years)]
        return demographic_proportions[demographic_proportions.year_start == selected_year]

    # TODO: Remove this method when we remove the deprecated keys
    def _validate_config_for_deprecated_keys(self) -> None:
        """Warn about deprecated configuration keys and validate consistency.

        Checks whether any of the deprecated configuration keys (``age_start``,
        ``age_end``, ``exit_age``) are present in the population configuration.
        For each deprecated key found, a warning is logged indicating the
        preferred replacement key.

        Raises
        ------
        ValueError
            If a deprecated key and its replacement are both explicitly set
            (in the ``override`` or ``model_override`` layer) with different
            values.
        """
        replacements = {
            "age_start": "initialization_age_min",
            "age_end": "initialization_age_max",
            "exit_age": "untracking_age",
        }
        not_provided = object()  # marks "replacement key not set in any checked layer"

        present_deprecated = set(replacements.keys()).intersection(self.config.keys())
        for old_key in present_deprecated:
            new_key = replacements[old_key]

            # Look for an explicit user-provided value of the replacement key;
            # the first layer that has one wins.
            explicit_value = not_provided
            for layer in ["override", "model_override"]:
                try:
                    explicit_value = self.config.get(new_key, layer=layer)
                except ConfigurationKeyError:
                    continue
                break

            if explicit_value is not not_provided and self.config[old_key] != explicit_value:
                raise ValueError(
                    f"Configuration contains both '{old_key}' and '{new_key}' with different values. "
                    f"These keys cannot both be provided. '{old_key}' will soon be deprecated so please "
                    f"use '{new_key}'. "
                )

            logger.warning(
                "FutureWarning: "
                f"Configuration key '{old_key}' will be deprecated in future versions of Vivarium "
                f"Public Health. Use the new key '{new_key}' instead."
            )

    def _load_population_structure(self, builder: Builder) -> pd.DataFrame:
        """Load population structure data.

        Subclasses may override this hook to transform the structure before
        demographic proportions are computed.

        Parameters
        ----------
        builder
            Access point for utilizing framework interfaces during setup.

        Returns
        -------
            A :class:`pandas.DataFrame` containing the raw population structure.
        """
        return load_population_structure(builder)
class ScaledPopulation(BasePopulation):
    """Produce and age simulants from a rescaled population structure.

    Use this component in place of :class:`BasePopulation` when simulants
    represent a subset of the true population. The base population structure is
    multiplied element-wise by a ``scaling_factor`` before simulants are drawn.

    Attributes
    ----------
    scaling_factor
        A multiplicative scaling factor applied to the population structure.
        May be a :class:`pandas.DataFrame` or a string artifact key whose
        data resolves to a :class:`pandas.DataFrame`.

    Example
    -------
    When specifying via a model configuration file:

    .. code-block:: yaml

        components:
            vivarium_public_health:
                population:
                    - ScaledPopulation("some.artifact.key")
    """

    def __init__(self, scaling_factor: str | pd.DataFrame):
        super().__init__()
        self.scaling_factor = scaling_factor

    def _load_population_structure(self, builder: Builder) -> pd.DataFrame:
        """Load the population structure and apply the scaling factor.

        Parameters
        ----------
        builder
            Access point for utilizing framework interfaces during setup.

        Returns
        -------
            A :class:`pandas.DataFrame` holding the product of the population
            structure and the scaling factor, with the index columns restored
            as regular columns.

        Raises
        ------
        ValueError
            If the resolved scaling factor is not a :class:`pandas.DataFrame`.
        """
        resolved_scaling = self.get_data(builder, self.scaling_factor)
        raw_structure = load_population_structure(builder)

        if not isinstance(resolved_scaling, pd.DataFrame):
            raise ValueError(
                f"Scaling factor must be a pandas DataFrame. Provided value: {resolved_scaling}"
            )

        sim_start_year = builder.configuration.time.start.year
        indexed_structure, indexed_scaling = self._format_data_inputs(
            raw_structure, resolved_scaling, sim_start_year
        )
        scaled = indexed_structure * indexed_scaling
        return scaled.reset_index()

    def _format_data_inputs(
        self, pop_structure: pd.DataFrame, scalar_data: pd.DataFrame, year: int
    ) -> tuple[pd.DataFrame, pd.DataFrame]:
        """Index both inputs and align them to a common reference year.

        Both inputs are indexed by every column other than ``value``. If the
        scaling data has a ``year_start`` level, each input is then subset to
        its own closest reference year that is less than or equal to ``year``.
        If the scaling data has no ``year_start`` level, both indexed frames
        are returned without any year subsetting.

        Parameters
        ----------
        pop_structure
            Raw population structure data with a ``year_start`` column.
        scalar_data
            Scaling factor data, optionally with a ``year_start`` column.
        year
            The simulation start year used to select the reference year.

        Returns
        -------
            A tuple of ``(population_structure, scaling_factor)`` dataframes
            indexed by their non-value columns and, when applicable, subset
            to the relevant reference year.
        """

        def index_by_non_value_columns(frame: pd.DataFrame) -> pd.DataFrame:
            # Everything except the "value" column becomes part of the index.
            return frame.set_index([col for col in frame.columns if col != "value"])

        def subset_to_reference_year(frame: pd.DataFrame) -> pd.DataFrame:
            # Keep only rows for the closest reference year <= ``year``.
            reference_years = sorted(set(frame.index.get_level_values("year_start")))
            selected = reference_years[_find_bin_start_index(year, reference_years)]
            return frame.loc[frame.index.get_level_values("year_start") == selected]

        scaling_factor = index_by_non_value_columns(scalar_data)
        population_structure = index_by_non_value_columns(pop_structure)

        if "year_start" not in scaling_factor.index.names:
            return population_structure, scaling_factor

        # Subset the population structure and scaling factors to the simulation
        # start year. If the data does not contain the exact simulation start
        # year, subset to the closest year less than the simulation start year.
        population_structure = subset_to_reference_year(population_structure)
        scaling_factor = subset_to_reference_year(scaling_factor)
        return population_structure, scaling_factor
class AgeOutSimulants(Component):
    """Flag simulants that age beyond the configured untracking age.

    When ``population.untracking_age`` is configured, simulants whose age is
    greater than or equal to that value are marked ``is_aged_out = True``
    during time-step cleanup. The component registers the tracked query
    ``is_aged_out == False`` and an ``exit_time`` attribute modifier that
    fills in the exit time of aged-out simulants.
    """

    #####################
    # Lifecycle methods #
    #####################

    def setup(self, builder: Builder) -> None:
        """Register the exit-time modifier, tracked query, and column initializer.

        Parameters
        ----------
        builder
            Access point for utilizing framework interfaces during setup.
        """
        self.config = builder.configuration.population
        builder.value.register_attribute_modifier("exit_time", self.update_exit_times)
        self.clock = builder.time.clock()
        self.step_size = builder.time.step_size()
        builder.population.register_tracked_query("is_aged_out == False")
        builder.population.register_initializer(
            initializer=self.initialize_is_aged_out,
            columns="is_aged_out",
        )

    def update_exit_times(self, index: pd.Index, target: pd.Series) -> pd.Series:
        """Fill in exit times for simulants who have aged out of the simulation.

        Parameters
        ----------
        index
            The simulants for which exit times are being computed.
        target
            The exit times produced so far; modified in place.

        Returns
        -------
            ``target`` with missing exit times of aged-out simulants set to
            the current clock time plus the step size.
        """
        aged_out = self.population_view.get_filtered_index(
            index,
            query="is_aged_out == True",
            include_untracked=True,
        )
        # Only simulants without an exit time yet are considered newly aged out,
        # so previously recorded exit times are never overwritten.
        without_exit_time = target[target.isna()].index
        newly_aged_out = aged_out.intersection(without_exit_time)
        target.loc[newly_aged_out] = self.clock() + self.step_size()
        return target

    def initialize_is_aged_out(self, pop_data: SimulantData) -> None:
        """Initialize the ``is_aged_out`` column to ``False`` for all new simulants.

        Parameters
        ----------
        pop_data
            Metadata about the simulants being initialized.
        """
        not_aged_out = pd.Series(False, index=pop_data.index, name="is_aged_out")
        self.population_view.initialize(not_aged_out)

    def on_time_step_cleanup(self, event: Event) -> None:
        """Mark simulants at or above the untracking age as aged out.

        Does nothing when ``population.untracking_age`` is ``None``.

        Parameters
        ----------
        event
            The event that triggered this method call.
        """
        untracking_age = self.config.untracking_age
        if untracking_age is None:
            return

        max_age = float(untracking_age)
        newly_aged_out = self.population_view.get_filtered_index(
            event.index,
            query=f"age >= {max_age} and is_aged_out == False",
        )
        if len(newly_aged_out) == 0:
            return

        self.population_view.update(
            "is_aged_out",
            lambda _: pd.Series(True, index=newly_aged_out, name="is_aged_out"),
        )
def generate_population(
    simulant_ids: pd.Index,
    creation_time: pd.Timestamp,
    step_size: pd.Timedelta,
    age_params: dict[str, float],
    demographic_proportions: pd.DataFrame,
    randomness_streams: dict[str, RandomnessStream],
    register_simulants: Callable[[pd.DataFrame], None],
    key_columns: Iterable[str] = ("entrance_time", "age"),
) -> pd.DataFrame:
    """Produce a random set of simulants sampled from `demographic_proportions`.

    Parameters
    ----------
    simulant_ids
        Values to serve as the index in the newly generated simulant DataFrame.
    creation_time
        The simulation time when the simulants are created.
    step_size
        The size of the initial time step.
    age_params
        Dictionary with keys:

        - ``age_start``: Start of an age range.
        - ``age_end``: End of an age range.

        These keys specify the age interval to use for generating simulants.
        Both values are converted with ``float``.
    demographic_proportions
        Table of demographic sampling proportions. The helpers called here
        read the columns 'age', 'age_start', 'age_end', 'sex', 'location',
        'P(sex, location, age| year)' and 'P(sex, location | age, year)'.
    randomness_streams
        Source of random number generation within the vivarium common random number framework.
    register_simulants
        A function to register the new simulants with the CRN framework.
    key_columns
        The key columns for random number generation. Only used when
        initializing with a range of ages.

    Returns
    -------
        Table with columns

        'entrance_time'
            The `pandas.Timestamp` describing when the simulant entered
            the simulation. Set to `creation_time` for all simulants.
        'exit_time'
            The `pandas.Timestamp` describing when the simulant exited
            the simulation. Set initially to `pandas.NaT`.
        'age'
            The age of the simulant at the current time step.
        'location'
            The location indicating where the simulant resides.
        'sex'
            The sex of the simulant.

    Notes
    -----
    This function branches on whether ``age_start == age_end`` to handle two
    distinct initialization strategies with different common random number (CRN)
    registration requirements:

    - **Fixed initial age** (``age_start == age_end``): Calls
      `_assign_demography_with_initial_age`, which applies age fuzz smoothing
      and registers ``entrance_time`` and ``age`` to the CRN framework.
    - **Age range** (``age_start != age_end``): Calls
      `_assign_demography_with_age_bounds`, which selects from age bins,
      smooths ages further, and registers the customizable ``key_columns``
      to the CRN framework.

    This branching pattern is necessary because the required CRN attributes
    differ between the two cases. Rather than specify ``required_resources`` in
    the framework upfront, this initializer defers that distinction to runtime
    based on the provided age parameters.
    """
    simulants = pd.DataFrame(
        {
            "entrance_time": creation_time,
            "exit_time": pd.NaT,
        },
        index=simulant_ids,
    )

    lower_age = float(age_params["age_start"])
    upper_age = float(age_params["age_end"])

    if lower_age != upper_age:
        # A genuine age range: sample age bins, then smooth within them.
        return _assign_demography_with_age_bounds(
            population=simulants,
            demographic_proportions=demographic_proportions,
            age_start=lower_age,
            age_end=upper_age,
            randomness_streams=randomness_streams,
            register_simulants=register_simulants,
            key_columns=key_columns,
        )

    # Degenerate range: everyone starts at the same age (plus a small fuzz).
    return _assign_demography_with_initial_age(
        population=simulants,
        demographic_proportions=demographic_proportions,
        initial_age=lower_age,
        step_size=step_size,
        randomness_streams=randomness_streams,
        register_simulants=register_simulants,
    )
def _assign_demography_with_initial_age(
    population: pd.DataFrame,
    demographic_proportions: pd.DataFrame,
    initial_age: float,
    step_size: pd.Timedelta,
    randomness_streams: dict[str, RandomnessStream],
    register_simulants: Callable[[pd.DataFrame], None],
) -> pd.DataFrame:
    """Assign age, sex, and location information to the provided simulants given a fixed age.

    Applies age fuzz smoothing to the fixed initial age using the
    ``'age_smoothing'`` randomness stream, then registers the ``entrance_time``
    and ``age`` columns to the common random number (CRN) framework. Sex and
    location are then drawn jointly using the ``'general_purpose'`` stream.

    Parameters
    ----------
    population
        Table that represents the new cohort of agents being added to the
        simulation. It is modified in place and also returned.
    demographic_proportions
        Table of demographic sampling proportions. This function reads the
        columns 'age_start', 'age_end', 'sex', 'location' and
        'P(sex, location | age, year)'.
    initial_age
        The age to assign the new simulants.
    step_size
        The size of the initial time step.
    randomness_streams
        Source of random number generation within the vivarium common random number framework.
    register_simulants
        A function to register the new simulants with the CRN framework.

    Returns
    -------
        Table with same columns as `population` and with the additional
        columns 'age', 'sex', and 'location'.

    Raises
    ------
    ValueError
        If no row of `demographic_proportions` has an age interval
        containing `initial_age`.
    """
    probability_column = "P(sex, location | age, year)"

    # Keep only the age bins whose closed interval contains the initial age.
    contains_initial_age = (demographic_proportions.age_start <= initial_age) & (
        demographic_proportions.age_end >= initial_age
    )
    demographic_proportions = demographic_proportions[contains_initial_age]
    if demographic_proportions.empty:
        raise ValueError(
            f"The age {initial_age} is not represented by the population data structure"
        )

    # Spread ages across one step so that simulants do not all share an
    # identical age.
    fuzz_draw = randomness_streams["age_smoothing"].get_draw(population.index)
    population["age"] = initial_age + fuzz_draw * utilities.to_years(step_size)
    register_simulants(population[["entrance_time", "age"]])

    # Assign a demographically accurate location and sex distribution.
    indexed = demographic_proportions.set_index(["sex", "location"])
    choices = indexed[probability_column].reset_index()
    decisions = randomness_streams["general_purpose"].choice(
        population.index, choices=choices.index, p=choices[probability_column]
    )

    for column in ("sex", "location"):
        population[column] = choices.loc[decisions, column].values

    return population
def _assign_demography_with_age_bounds(
    population: pd.DataFrame,
    demographic_proportions: pd.DataFrame,
    age_start: float,
    age_end: float,
    randomness_streams: dict[str, RandomnessStream],
    register_simulants: Callable[[pd.DataFrame], None],
    key_columns: Iterable[str] = ("entrance_time", "age"),
) -> pd.DataFrame:
    """Assign an age, sex, and location to the provided simulants given a range of ages.

    Selects demographic age bins probabilistically using the ``'bin_selection'``
    randomness stream, then smooths ages within selected bins using the
    ``'age_smoothing_age_bounds'`` stream. Registers the specified ``key_columns``
    to the common random number (CRN) framework; this allows customization of which
    columns participate in CRN generation.

    Parameters
    ----------
    population
        Table that represents the new cohort of agents being added to the
        simulation. The 'age', 'sex' and 'location' columns are assigned on
        this table before smoothing.
    demographic_proportions
        Table of demographic sampling proportions. This function reads the
        columns 'age', 'age_start', 'age_end', 'sex', 'location' and
        'P(sex, location, age| year)'.
    age_start, age_end
        The start and end of the age range of interest, respectively.
    randomness_streams
        Source of random number generation within the vivarium common random number framework.
    register_simulants
        A function to register the new simulants with the CRN framework.
    key_columns
        The key columns for random number generation.

    Returns
    -------
        Table with same columns as `population` and with the additional columns
        'age', 'sex', and 'location'.

    Raises
    ------
    ValueError
        If rescaling the proportions to the requested age range leaves no rows.
    """
    probability_column = "P(sex, location, age| year)"

    demographic_proportions = rescale_binned_proportions(
        demographic_proportions, age_start, age_end
    )
    if demographic_proportions.empty:
        raise ValueError(
            f"The age range ({age_start}, {age_end}) is not represented by the "
            f"population data structure."
        )

    # Assign a demographically accurate age, location, and sex distribution.
    # Only bins lying entirely inside the requested range are candidates.
    within_bounds = (demographic_proportions.age_start >= age_start) & (
        demographic_proportions.age_end <= age_end
    )
    candidate_bins = demographic_proportions[within_bounds]
    indexed = candidate_bins.set_index(["age", "sex", "location"])
    choices = indexed[probability_column].reset_index()

    decisions = randomness_streams["bin_selection"].choice(
        population.index, choices=choices.index, p=choices[probability_column]
    )
    for column in ("age", "sex", "location"):
        population[column] = choices.loc[decisions, column].values

    # Replace the sampled bin ages with smoothed ages.
    population = smooth_ages(
        population,
        demographic_proportions,
        randomness_streams["age_smoothing_age_bounds"],
    )

    register_simulants(population[list(key_columns)])
    return population
def _find_bin_start_index(value: int, sorted_reference_values: list[int]) -> int:
"""Find the index of the closest reference value less than or equal to the provided value.
Parameters
----------
value
The value for which to find the closest reference value.
sorted_reference_values
A sorted list of reference values.
Returns
-------
The index of the closest reference value less than or equal to the provided value.
Raises
------
ValueError
If the provided value is less than the minimum reference value.
"""
ref_value_index = np.digitize(value, sorted_reference_values).item() - 1
if ref_value_index < 0:
raise ValueError(
f"The provided value {value} is less than the minimum reference value "
f"{min(sorted_reference_values)}."
)
return ref_value_index
class Disability(Component):
    """Handle disability-related attributes and values.

    Currently this component only registers the all-cause disability weight
    attribute producer.

    Attributes
    ----------
    disability_weight_pipeline
        Name of the pipeline used to produce disability weights.
    """

    #####################
    # Lifecycle methods #
    #####################

    def __init__(self) -> None:
        super().__init__()
        self.disability_weight_pipeline = "all_causes.disability_weight"

    def setup(self, builder: Builder) -> None:
        """Register the all-cause disability weight pipeline.

        The source contributes a single all-zero series; values are combined
        with ``list_combiner`` and post-processed with
        ``union_post_processor``.

        Parameters
        ----------
        builder
            Access point for utilizing framework interfaces during setup.
        """
        pipeline_name = self.disability_weight_pipeline
        builder.value.register_attribute_producer(
            pipeline_name,
            # Baseline source: a one-element list holding a zero weight per simulant.
            source=lambda index: [pd.Series(0.0, index=index)],
            preferred_combiner=list_combiner,
            preferred_post_processor=union_post_processor,
        )