# Source code for vivarium_public_health.population.base_population

"""
=========================
The Core Population Model
=========================

Provide tools for sampling and assigning core demographic characteristics to simulants.

"""

from collections.abc import Callable, Iterable

import numpy as np
import pandas as pd
from layered_config_tree.exceptions import ConfigurationKeyError
from loguru import logger
from vivarium import Component
from vivarium.framework.engine import Builder
from vivarium.framework.event import Event
from vivarium.framework.population import SimulantData
from vivarium.framework.randomness import RandomnessStream
from vivarium.framework.values import Pipeline, list_combiner, union_post_processor

from vivarium_public_health import utilities
from vivarium_public_health.population.data_transformations import (
    assign_demographic_proportions,
    load_population_structure,
    rescale_binned_proportions,
    smooth_ages,
)
from vivarium_public_health.population.mortality import Mortality


class BasePopulation(Component):
    """Create simulants from demographic data and age them over time.

    During setup this component loads a population structure from the
    artifact, converts it into demographic sampling proportions, and registers
    an initializer that gives every new simulant an age, sex, location,
    entrance time, and exit time. On every time step, living simulants are
    aged forward by the size of the step.
    """

    CONFIGURATION_DEFAULTS = {
        "population": {
            "initialization_age_min": 0,
            "initialization_age_max": 125,
            "untracking_age": None,
            "include_sex": "Both",  # Either Female, Male, or Both
        }
    }

    ##############
    # Properties #
    ##############

    @property
    def time_step_priority(self) -> int:
        """Priority with which this component handles time-step events."""
        return 8

    @property
    def time_step_cleanup_priority(self) -> int:
        """Priority with which this component handles time-step cleanup events."""
        return 9

    #####################
    # Lifecycle methods #
    #####################

    def __init__(self):
        super().__init__()
        self._sub_components += [AgeOutSimulants(), Mortality(), Disability()]

    def setup(self, builder: Builder) -> None:
        """Validate configuration, prepare sampling data, and register the initializer.

        Parameters
        ----------
        builder
            Access point for utilizing framework interfaces during setup.

        Raises
        ------
        ValueError
            If ``population.include_sex`` is not one of ``'Male'``,
            ``'Female'``, or ``'Both'``.
        """
        self.config = builder.configuration.population
        self.key_columns = builder.configuration.randomness.key_columns

        include_sex = self.config.include_sex
        if include_sex not in ("Male", "Female", "Both"):
            raise ValueError(
                "Configuration key 'population.include_sex' must be one "
                "of ['Male', 'Female', 'Both']. "
                f"Provided value: {include_sex}."
            )

        # TODO: Remove this when we remove deprecated keys.
        # Validate configuration for deprecated keys
        self._validate_config_for_deprecated_keys()

        population_structure = self._load_population_structure(builder)
        self.demographic_proportions = assign_demographic_proportions(
            population_structure,
            include_sex=self.config.include_sex,
        )
        self.randomness = self.get_randomness_streams(builder)
        self.register_simulants = builder.randomness.register_simulants

        # HACK / FIXME [MIC-6746]: Simplify initial population creation
        # The current implementation of initialize_population is complicated
        # and should be simplified/streamlined. Of note, the required_resources
        # for the "sex" and "location" columns are different depending on
        # whether or not the simulation's initialized age_start and age_end
        # values are the same. To get around this, we simply register the
        # initializer without specifying any required_resources here. This
        # could potentially lead to a difficult-to-diagnose bug, but we've been
        # doing this for a long time without known issues.
        created_columns = ["age", "sex", "location", "entrance_time", "exit_time"]
        builder.population.register_initializer(
            initializer=self.initialize_population,
            columns=created_columns,
        )

    #################
    # Setup methods #
    #################

    def get_randomness_streams(self, builder: Builder) -> dict[str, RandomnessStream]:
        """Create the randomness streams needed to generate a population.

        Parameters
        ----------
        builder
            Access point for utilizing framework interfaces during setup.

        Returns
        -------
            A mapping from a short stream label to its
            :class:`~vivarium.framework.randomness.stream.RandomnessStream`:

            ``'general_purpose'``
                The ``population_generation`` stream, used for overall
                population generation.
            ``'bin_selection'``
                Used to pick demographic age bins when initializing over an
                age range.
            ``'age_smoothing'``
                Used to smooth ages when initializing at a fixed age.
            ``'age_smoothing_age_bounds'``
                Used to smooth ages when initializing over an age range.
        """
        get_stream = builder.randomness.get_stream
        streams = {"general_purpose": get_stream("population_generation")}
        # The remaining streams share their label with the stream name and
        # all initialize CRN attributes.
        for name in ("bin_selection", "age_smoothing", "age_smoothing_age_bounds"):
            streams[name] = get_stream(name, initializes_crn_attributes=True)
        return streams

    ########################
    # Event-driven methods #
    ########################

    # TODO: Move most of this docstring to an rst file.
    def initialize_population(self, pop_data: SimulantData) -> None:
        """Assign core demographic and simulation attributes to new simulants.

        Parameters
        ----------
        pop_data
            Metadata about the simulants being initialized.

        Notes
        -----
        Newly created simulants receive an 'age', 'sex', and 'location' that
        are consistent with the demographic distributions in the
        population-level data, plus the simulation properties 'entrance_time'
        and 'exit_time'.

        'exit_time' only records when a simulant leaves the simulation; the
        reason for leaving (aging out, dying, etc.) can be inferred from it
        together with other simulant and simulation information.

        'exit_time' is created here, but other components are expected to
        modify it (e.g., a Mortality component when a simulant dies). They do
        so by registering attribute modifiers; this component then writes the
        modified values back to the underlying private column.
        """
        user_data = pop_data.user_data
        # Ages supplied with the creation request take precedence over the
        # configured initialization bounds.
        age_params = {
            "age_start": user_data.get("age_start", self.config.initialization_age_min),
            "age_end": user_data.get("age_end", self.config.initialization_age_max),
        }

        proportions = self.get_demographic_proportions_for_creation_time(
            self.demographic_proportions, pop_data.creation_time.year
        )
        new_simulants = generate_population(
            simulant_ids=pop_data.index,
            creation_time=pop_data.creation_time,
            step_size=pop_data.creation_window,
            age_params=age_params,
            demographic_proportions=proportions,
            randomness_streams=self.randomness,
            register_simulants=self.register_simulants,
            key_columns=self.key_columns,
        )
        self.population_view.initialize(new_simulants)

    def on_time_step(self, event: Event) -> None:
        """Advance the age of living simulants by the step size.

        Parameters
        ----------
        event
            The event that triggered this method call.
        """
        alive = self.population_view.get_filtered_index(
            event.index, query="is_alive == True"
        )
        years_elapsed = utilities.to_years(event.step_size)

        def age_forward(age: pd.Series) -> pd.Series:
            # Only living simulants are returned, so only they are updated.
            return age.loc[alive] + years_elapsed

        self.population_view.update("age", age_forward)

    def on_time_step_cleanup(self, event: Event) -> None:
        """Write the current 'exit_time' attribute values back to the column.

        Parameters
        ----------
        event
            The event that triggered this method call.
        """
        current_exit_times = self.population_view.get(
            event.index, "exit_time", include_untracked=True
        )
        self.population_view.update(
            "exit_time", lambda _: current_exit_times.rename("exit_time")
        )

    ##################
    # Helper methods #
    ##################

    @staticmethod
    def get_demographic_proportions_for_creation_time(
        demographic_proportions, year: int
    ) -> pd.DataFrame:
        """Select the demographic proportions for the closest reference year.

        Parameters
        ----------
        demographic_proportions
            Demographic proportions for all reference years, with a
            ``year_start`` column.
        year
            The simulation year for which to retrieve proportions.

        Returns
        -------
            The rows of ``demographic_proportions`` whose ``year_start`` is
            the closest reference year less than or equal to ``year``.
        """
        year_starts = demographic_proportions.year_start
        reference_years = sorted(set(year_starts))
        closest_year = reference_years[_find_bin_start_index(year, reference_years)]
        return demographic_proportions[demographic_proportions.year_start == closest_year]

    # TODO: Remove this method when we remove the deprecated keys
    def _validate_config_for_deprecated_keys(self) -> None:
        """Warn about deprecated configuration keys and check for conflicts.

        For every deprecated key (``age_start``, ``age_end``, ``exit_age``)
        present in the population configuration, a warning naming the
        replacement key is logged.

        Raises
        ------
        ValueError
            If a deprecated key and its replacement are both explicitly set
            with different values.
        """
        replacements = {
            "age_start": "initialization_age_min",
            "age_end": "initialization_age_max",
            "exit_age": "untracking_age",
        }
        deprecated_keys = set(replacements).intersection(self.config.keys())
        for key in deprecated_keys:
            new_key = replacements[key]
            # Only the first layer in which the replacement key was explicitly
            # provided is compared against the deprecated key.
            for layer in ("override", "model_override"):
                try:
                    new_key_value = self.config.get(new_key, layer=layer)
                except ConfigurationKeyError:
                    continue
                if self.config[key] != new_key_value:
                    raise ValueError(
                        f"Configuration contains both '{key}' and '{new_key}' with different values. "
                        f"These keys cannot both be provided. '{key}' will soon be deprecated so please "
                        f"use '{new_key}'. "
                    )
                break
            logger.warning(
                "FutureWarning: "
                f"Configuration key '{key}' will be deprecated in future versions of Vivarium "
                f"Public Health. Use the new key '{new_key}' instead."
            )

    def _load_population_structure(self, builder: Builder) -> pd.DataFrame:
        """Load the population structure data.

        Subclasses may override this to alter the structure before it is
        turned into demographic proportions.

        Parameters
        ----------
        builder
            Access point for utilizing framework interfaces during setup.

        Returns
        -------
            A :class:`pandas.DataFrame` containing the raw population structure.
        """
        return load_population_structure(builder)
class ScaledPopulation(BasePopulation):
    """Create and age simulants drawn from a rescaled population structure.

    Use this component instead of :class:`BasePopulation` when simulants
    represent a subset of the true population. The base population structure
    is multiplied element-wise by a ``scaling_factor`` before simulants are
    drawn.

    Attributes
    ----------
    scaling_factor
        A multiplicative scaling factor applied to the population structure.
        May be a :class:`pandas.DataFrame` or a string artifact key whose data
        resolves to a :class:`pandas.DataFrame`.

    Example
    -------
    When specifying via a model configuration file:

    .. code-block:: yaml

        components:
            vivarium_public_health:
                population:
                    - ScaledPopulation("some.artifact.key")

    """

    def __init__(self, scaling_factor: str | pd.DataFrame):
        super().__init__()
        self.scaling_factor = scaling_factor

    def _load_population_structure(self, builder: Builder) -> pd.DataFrame:
        """Load the population structure and multiply it by the scaling factor.

        Parameters
        ----------
        builder
            Access point for utilizing framework interfaces during setup.

        Returns
        -------
            A :class:`pandas.DataFrame` holding the population structure
            rescaled by :attr:`scaling_factor`, with its index reset to columns.

        Raises
        ------
        ValueError
            If the resolved scaling factor is not a :class:`pandas.DataFrame`.
        """
        resolved_scaling = self.get_data(builder, self.scaling_factor)
        raw_structure = load_population_structure(builder)
        if not isinstance(resolved_scaling, pd.DataFrame):
            raise ValueError(
                f"Scaling factor must be a pandas DataFrame. Provided value: {resolved_scaling}"
            )

        sim_start_year = builder.configuration.time.start.year
        structure, scaling = self._format_data_inputs(
            raw_structure, resolved_scaling, sim_start_year
        )
        scaled_structure = structure * scaling
        return scaled_structure.reset_index()

    def _format_data_inputs(
        self, pop_structure: pd.DataFrame, scalar_data: pd.DataFrame, year: int
    ) -> tuple[pd.DataFrame, pd.DataFrame]:
        """Index both inputs and align them to a common reference year.

        Both inputs are indexed by every column other than ``value``. If the
        scaling data has a ``year_start`` level, each input is then subset to
        its own closest reference year that is less than or equal to ``year``;
        otherwise both are returned indexed but without any year subsetting.

        Parameters
        ----------
        pop_structure
            Raw population structure data with a ``year_start`` column.
        scalar_data
            Scaling factor data, optionally with a ``year_start`` column.
        year
            The simulation start year used to select the reference year.

        Returns
        -------
            A tuple of ``(population_structure, scaling_factor)`` dataframes
            indexed by their non-value columns.
        """

        def index_by_non_value_columns(data: pd.DataFrame) -> pd.DataFrame:
            # Everything except the "value" column becomes part of the index.
            index_columns = [col for col in data.columns if col != "value"]
            return data.set_index(index_columns)

        def subset_to_reference_year(data: pd.DataFrame) -> pd.DataFrame:
            # Keep only the rows for the closest reference year that is less
            # than or equal to the requested year.
            year_starts = data.index.get_level_values("year_start")
            reference_years = sorted(set(year_starts))
            reference_year = reference_years[_find_bin_start_index(year, reference_years)]
            return data.loc[year_starts == reference_year]

        scaling_factor = index_by_non_value_columns(scalar_data)
        population_structure = index_by_non_value_columns(pop_structure)

        if "year_start" not in scaling_factor.index.names:
            return population_structure, scaling_factor

        population_structure = subset_to_reference_year(population_structure)
        scaling_factor = subset_to_reference_year(scaling_factor)
        return population_structure, scaling_factor
class AgeOutSimulants(Component):
    """Flag simulants that have reached the configured untracking age.

    When ``population.untracking_age`` is set, simulants whose age is at or
    above it are marked ``is_aged_out = True`` during time-step cleanup. The
    component registers ``is_aged_out == False`` as a tracked query and
    supplies the exit time of aged-out simulants through an ``exit_time``
    attribute modifier.
    """

    #####################
    # Lifecycle methods #
    #####################

    def setup(self, builder: Builder) -> None:
        """Register the exit-time modifier, tracked query, and initializer.

        Parameters
        ----------
        builder
            Access point for utilizing framework interfaces during setup.
        """
        self.config = builder.configuration.population
        builder.value.register_attribute_modifier("exit_time", self.update_exit_times)

        time_interface = builder.time
        self.clock = time_interface.clock()
        self.step_size = time_interface.step_size()

        population_interface = builder.population
        population_interface.register_tracked_query("is_aged_out == False")
        population_interface.register_initializer(
            initializer=self.initialize_is_aged_out, columns="is_aged_out"
        )

    def update_exit_times(self, index: pd.Index, target: pd.Series) -> pd.Series:
        """Fill in the exit time of aged-out simulants that do not have one yet.

        Parameters
        ----------
        index
            Index of the simulants whose exit times are being computed.
        target
            The exit times to modify. It is updated in place.

        Returns
        -------
            ``target``, with missing exit times of aged-out simulants set to
            the current clock time plus the step size.
        """
        aged_out = self.population_view.get_filtered_index(
            index,
            query="is_aged_out == True",
            include_untracked=True,
        )
        # Simulants that already have an exit time keep it.
        without_exit_time = target[target.isna()].index
        newly_aged_out = aged_out.intersection(without_exit_time)

        exit_time = self.clock() + self.step_size()
        target.loc[newly_aged_out] = exit_time
        return target

    def initialize_is_aged_out(self, pop_data: SimulantData) -> None:
        """Create the ``is_aged_out`` column as ``False`` for all new simulants.

        Parameters
        ----------
        pop_data
            Metadata about the simulants being initialized.
        """
        not_aged_out = pd.Series(False, index=pop_data.index, name="is_aged_out")
        self.population_view.initialize(not_aged_out)

    def on_time_step_cleanup(self, event: Event) -> None:
        """Mark simulants at or above the untracking age as aged out.

        Does nothing when ``population.untracking_age`` is not configured.

        Parameters
        ----------
        event
            The event that triggered this method call.
        """
        if self.config.untracking_age is None:
            return

        max_age = float(self.config.untracking_age)
        newly_aged_out = self.population_view.get_filtered_index(
            event.index,
            query=f"age >= {max_age} and is_aged_out == False",
        )
        if len(newly_aged_out) == 0:
            return

        self.population_view.update(
            "is_aged_out",
            lambda _: pd.Series(True, index=newly_aged_out, name="is_aged_out"),
        )
def generate_population(
    simulant_ids: pd.Index,
    creation_time: pd.Timestamp,
    step_size: pd.Timedelta,
    age_params: dict[str, float],
    demographic_proportions: pd.DataFrame,
    randomness_streams: dict[str, RandomnessStream],
    register_simulants: Callable[[pd.DataFrame], None],
    key_columns: Iterable[str] = ("entrance_time", "age"),
) -> pd.DataFrame:
    """Produce a random set of simulants sampled from `demographic_proportions`.

    Parameters
    ----------
    simulant_ids
        Values to serve as the index in the newly generated simulant DataFrame.
    creation_time
        The simulation time when the simulants are created.
    step_size
        The size of the initial time step.
    age_params
        Dictionary with keys:

        - ``age_start``: Start of an age range.
        - ``age_end``: End of an age range.

        These keys specify the age interval to use for generating simulants.
    demographic_proportions
        Table with columns 'age', 'age_start', 'age_end', 'sex', 'year',
        'location', 'population', 'P(sex, location, age| year)',
        'P(sex, location | age, year)'.
    randomness_streams
        Source of random number generation within the vivarium common random
        number framework.
    register_simulants
        A function to register the new simulants with the CRN framework.
    key_columns
        A list of key columns for random number generation. Only used when
        initializing over an age range.

    Returns
    -------
        Table with columns

        'entrance_time'
            The `pandas.Timestamp` describing when the simulant entered
            the simulation. Set to `creation_time` for all simulants.
        'exit_time'
            The `pandas.Timestamp` describing when the simulant exited
            the simulation. Set initially to `pandas.NaT`.
        'age'
            The age of the simulant at the current time step.
        'location'
            The location indicating where the simulant resides.
        'sex'
            The sex of the simulant ('Male' or 'Female').

    Notes
    -----
    Two initialization strategies exist, chosen by comparing the two ages
    after conversion to ``float``:

    - **Fixed initial age** (``age_start == age_end``): delegates to
      `_assign_demography_with_initial_age`, which registers
      ``entrance_time`` and ``age`` with the CRN framework.
    - **Age range** (``age_start != age_end``): delegates to
      `_assign_demography_with_age_bounds`, which registers the customizable
      ``key_columns`` with the CRN framework.

    Because the CRN attributes differ between the two cases, the choice is
    made here at runtime rather than through ``required_resources`` declared
    upfront.
    """
    population = pd.DataFrame(
        {"entrance_time": creation_time, "exit_time": pd.NaT},
        index=simulant_ids,
    )
    age_start = float(age_params["age_start"])
    age_end = float(age_params["age_end"])

    if age_start == age_end:
        # Everyone starts at (approximately) the same age.
        return _assign_demography_with_initial_age(
            population=population,
            demographic_proportions=demographic_proportions,
            initial_age=age_start,
            step_size=step_size,
            randomness_streams=randomness_streams,
            register_simulants=register_simulants,
        )

    # Ages are sampled from the range bounded by age_start and age_end.
    return _assign_demography_with_age_bounds(
        population=population,
        demographic_proportions=demographic_proportions,
        age_start=age_start,
        age_end=age_end,
        randomness_streams=randomness_streams,
        register_simulants=register_simulants,
        key_columns=key_columns,
    )
def _assign_demography_with_initial_age(
    population: pd.DataFrame,
    demographic_proportions: pd.DataFrame,
    initial_age: float,
    step_size: pd.Timedelta,
    randomness_streams: dict[str, RandomnessStream],
    register_simulants: Callable[[pd.DataFrame], None],
) -> pd.DataFrame:
    """Assign age, sex, and location to the provided simulants given a fixed age.

    The fixed initial age is fuzzed by a draw from the ``'age_smoothing'``
    randomness stream scaled by the step size (in years). The
    ``entrance_time`` and ``age`` columns are then registered with the common
    random number (CRN) framework before sex and location are sampled.

    Parameters
    ----------
    population
        Table that represents the new cohort of agents being added to the
        simulation. It is modified in place and returned.
    demographic_proportions
        Table with columns 'age', 'age_start', 'age_end', 'sex', 'year',
        'location', 'population', 'P(sex, location, age| year)',
        'P(sex, location | age, year)'
    initial_age
        The age to assign the new simulants.
    step_size
        The size of the initial time step.
    randomness_streams
        Source of random number generation within the vivarium common random
        number framework.
    register_simulants
        A function to register the new simulants with the CRN framework.

    Returns
    -------
        Table with same columns as `population` and with the additional
        columns 'age', 'sex', and 'location'.

    Raises
    ------
    ValueError
        If no age bin in ``demographic_proportions`` contains ``initial_age``.
    """
    demographic_proportions = demographic_proportions[
        (demographic_proportions.age_start <= initial_age)
        & (demographic_proportions.age_end >= initial_age)
    ]

    if demographic_proportions.empty:
        raise ValueError(
            f"The age {initial_age} is not represented by the population data structure"
        )

    age_fuzz = randomness_streams["age_smoothing"].get_draw(
        population.index
    ) * utilities.to_years(step_size)
    population["age"] = initial_age + age_fuzz
    # Registration happens before the general-purpose stream is used below.
    register_simulants(population[["entrance_time", "age"]])

    # Assign a demographically accurate location and sex distribution.
    choices = demographic_proportions.set_index(["sex", "location"])[
        "P(sex, location | age, year)"
    ].reset_index()
    decisions = randomness_streams["general_purpose"].choice(
        population.index, choices=choices.index, p=choices["P(sex, location | age, year)"]
    )

    population["sex"] = choices.loc[decisions, "sex"].values
    population["location"] = choices.loc[decisions, "location"].values

    return population


def _assign_demography_with_age_bounds(
    population: pd.DataFrame,
    demographic_proportions: pd.DataFrame,
    age_start: float,
    age_end: float,
    randomness_streams: dict[str, RandomnessStream],
    register_simulants: Callable[[pd.DataFrame], None],
    key_columns: Iterable[str] = ("entrance_time", "age"),
) -> pd.DataFrame:
    """Assign an age, sex, and location to the provided simulants given a range of ages.

    Age bins are selected probabilistically with the ``'bin_selection'``
    randomness stream, and ages are then smoothed with the
    ``'age_smoothing_age_bounds'`` stream. Finally the requested
    ``key_columns`` are registered with the common random number (CRN)
    framework, which lets callers customize the columns used for CRN.

    Parameters
    ----------
    population
        Table that represents the new cohort of agents being added to the
        simulation.
    demographic_proportions
        Table with columns 'age', 'age_start', 'age_end', 'sex', 'year',
        'location', 'population', 'P(sex, location, age| year)',
        'P(sex, location | age, year)'
    age_start, age_end
        The start and end of the age range of interest, respectively.
    randomness_streams
        Source of random number generation within the vivarium common random
        number framework.
    register_simulants
        A function to register the new simulants with the CRN framework.
    key_columns
        A list of key columns for random number generation.

    Returns
    -------
        Table with same columns as `population` and with the additional
        columns 'age', 'sex', and 'location'.

    Raises
    ------
    ValueError
        If the rescaled demographic proportions for the age range are empty.
    """
    demographic_proportions = rescale_binned_proportions(
        demographic_proportions, age_start, age_end
    )
    if demographic_proportions.empty:
        raise ValueError(
            f"The age range ({age_start}, {age_end}) is not represented by the "
            f"population data structure."
        )

    # Assign a demographically accurate age, location, and sex distribution.
    # Only bins that lie entirely inside the requested range are candidates.
    sub_demographic_proportions = demographic_proportions[
        (demographic_proportions.age_start >= age_start)
        & (demographic_proportions.age_end <= age_end)
    ]
    choices = sub_demographic_proportions.set_index(["age", "sex", "location"])[
        "P(sex, location, age| year)"
    ].reset_index()
    decisions = randomness_streams["bin_selection"].choice(
        population.index, choices=choices.index, p=choices["P(sex, location, age| year)"]
    )
    population["age"] = choices.loc[decisions, "age"].values
    population["sex"] = choices.loc[decisions, "sex"].values
    population["location"] = choices.loc[decisions, "location"].values
    population = smooth_ages(
        population, demographic_proportions, randomness_streams["age_smoothing_age_bounds"]
    )
    register_simulants(population[list(key_columns)])
    return population


def _find_bin_start_index(value: int, sorted_reference_values: list[int]) -> int:
    """Find the index of the closest reference value less than or equal to the provided value.

    Parameters
    ----------
    value
        The value for which to find the closest reference value.
    sorted_reference_values
        A sorted list of reference values.

    Returns
    -------
        The index of the closest reference value less than or equal to the
        provided value.

    Raises
    ------
    ValueError
        If there are no reference values, or if the provided value is less
        than the minimum reference value.
    """
    # Without this guard an empty list would fall through to ``min()`` below
    # and fail with an unrelated "min() arg is an empty sequence" message.
    # ``len`` is used rather than truthiness so array-like inputs also work.
    if len(sorted_reference_values) == 0:
        raise ValueError(
            f"Cannot find a reference value for {value}: the list of reference "
            "values is empty."
        )

    # ``np.digitize`` returns the number of reference values <= ``value``;
    # subtracting one turns that count into the index of the containing bin.
    ref_value_index = np.digitize(value, sorted_reference_values).item() - 1
    if ref_value_index < 0:
        raise ValueError(
            f"The provided value {value} is less than the minimum reference value "
            f"{min(sorted_reference_values)}."
        )
    return ref_value_index
class Disability(Component):
    """Provide disability-related attributes and values.

    At present the only responsibility of this component is to register the
    all-cause disability weight pipeline.

    Attributes
    ----------
    disability_weight_pipeline
        Name of the pipeline used to produce disability weights.
    """

    #####################
    # Lifecycle methods #
    #####################

    def __init__(self) -> None:
        super().__init__()
        self.disability_weight_pipeline = "all_causes.disability_weight"

    def setup(self, builder: Builder) -> None:
        """Register the all-cause disability weight attribute producer.

        The registered source yields a single-element list holding a series
        of zeros over the requested index.

        Parameters
        ----------
        builder
            Access point for utilizing framework interfaces during setup.
        """
        pipeline_name = self.disability_weight_pipeline
        value_interface = builder.value
        value_interface.register_attribute_producer(
            pipeline_name,
            source=lambda index: [pd.Series(0.0, index=index)],
            preferred_combiner=list_combiner,
            preferred_post_processor=union_post_processor,
        )