"""
===============
Stratifications
===============
"""
from __future__ import annotations
from dataclasses import dataclass
from typing import Any
import pandas as pd
from pandas.api.types import CategoricalDtype
from vivarium.types import ScalarMapper, VectorMapper
STRATIFICATION_COLUMN_SUFFIX: str = "mapped_values"
# TODO: Parameterizing pandas objects fails below python 3.12
@dataclass
class Stratification:
    """Class for stratifying observed quantities by specified characteristics.

    Each Stratification represents a set of mutually exclusive and collectively
    exhaustive categories into which simulants can be assigned.

    This class includes a :meth:`stratify <stratify>` method that produces an
    output column by calling the mapper on the `requires_attributes` columns.
    """

    name: str
    """Name of the stratification."""
    requires_attributes: list[str]
    """The population attributes needed as input for the `mapper`."""
    categories: list[str]
    """Exhaustive list of all possible stratification values."""
    excluded_categories: list[str]
    """List of possible stratification values to exclude from results processing.
    This must be a list (possibly empty); ``None`` is not supported here because
    :meth:`stratify` concatenates it with `categories`."""
    mapper: VectorMapper | ScalarMapper | None
    """A callable that maps the population attributes specified by the
    `requires_attributes` argument to the stratification categories. It can either
    map the entire population or an individual simulant. A simulation will fail if
    the `mapper` ever produces an invalid value."""
    is_vectorized: bool = False
    """True if the `mapper` function will map the entire population, and False
    if it will only map a single simulant."""

    def __str__(self) -> str:
        """Returns a human-readable summary of the stratification."""
        return (
            f"Stratification '{self.name}' with required attributes {self.requires_attributes}, "
            f"categories {self.categories}, and mapper {getattr(self.mapper, '__name__', repr(self.mapper))}"
        )

    def __post_init__(self) -> None:
        """Assigns a default `mapper` if none was provided and check for non-empty
        `categories` and `requires_attributes` otherwise.

        Raises
        ------
        ValueError
            If no mapper is provided and the number of required attributes is not 1.
        ValueError
            If the categories argument is empty.
        ValueError
            If the requires_attributes argument is empty.
        """
        # The mapper is resolved first so that a missing mapper combined with a
        # wrong number of required attributes is reported with its own, more
        # specific message before the generic emptiness checks run.
        self.vectorized_mapper: VectorMapper = self._get_vectorized_mapper(
            self.mapper, self.is_vectorized
        )
        if not self.categories:
            raise ValueError("The categories argument must be non-empty.")
        if not self.requires_attributes:
            raise ValueError("The requires_attributes argument must be non-empty.")

    def stratify(self, population: pd.DataFrame) -> pd.Series[CategoricalDtype]:
        """Applies the `mapper` to the population `requires_attributes` columns.

        This creates a new Series to be added to the population. Any `excluded_categories`
        (which have already been removed from `categories`) will be converted to
        NaNs in the new column and dropped later at the observation level.

        Parameters
        ----------
        population
            A DataFrame containing the data to be stratified.

        Returns
        -------
            A Series containing the mapped values to be used for stratifying.

        Raises
        ------
        ValueError
            If the mapper returns any values not in `categories` or `excluded_categories`.
        """
        mapped_column = self.vectorized_mapper(population[self.requires_attributes])
        unknown_categories = set(mapped_column) - set(
            self.categories + self.excluded_categories
        )
        # Reduce all nans to a single one
        unknown_categories = {cat for cat in unknown_categories if not pd.isna(cat)}
        if mapped_column.isna().any():
            unknown_categories.add(mapped_column[mapped_column.isna()].iat[0])
        if unknown_categories:
            raise ValueError(f"Invalid values mapped to {self.name}: {unknown_categories}")

        # Convert the dtype to the allowed categories. Note that this will
        # result in Nans for any values in excluded_categories.
        return mapped_column.astype(
            CategoricalDtype(categories=self.categories, ordered=True)
        )

    def _get_vectorized_mapper(
        self,
        user_provided_mapper: VectorMapper | ScalarMapper | None,
        is_vectorized: bool,
    ) -> VectorMapper:
        """Chooses a VectorMapper based on the provided callable mapper.

        Parameters
        ----------
        user_provided_mapper
            The mapper supplied at construction, or None to use the default
            single-column mapper.
        is_vectorized
            Whether `user_provided_mapper` already operates on a whole DataFrame.

        Returns
        -------
            A callable that maps a DataFrame to a Series of mapped values.

        Raises
        ------
        ValueError
            If no mapper is provided and the number of required attributes is not 1.
        """
        if user_provided_mapper is None:
            if len(self.requires_attributes) != 1:
                raise ValueError(
                    f"No mapper but {len(self.requires_attributes)} required attributes are "
                    f"provided for stratification {self.name}. The list of required attributes "
                    "must be of length 1 if no mapper is provided."
                )
            return self._default_mapper
        elif is_vectorized:
            return user_provided_mapper  # type: ignore [return-value]
        else:

            def _apply_row_wise(population: pd.DataFrame) -> pd.Series[Any]:
                # DataFrame.apply(axis=1) on a frame without rows probes the
                # mapper with a dummy row and may hand back a DataFrame rather
                # than a Series. Short-circuit so the scalar mapper is never
                # called without a real simulant and a Series is always returned.
                if len(population) == 0:
                    return pd.Series(index=population.index, dtype=object)
                return population.apply(user_provided_mapper, axis=1)

            return _apply_row_wise

    @staticmethod
    def _default_mapper(pop: pd.DataFrame) -> pd.Series[Any]:
        """Squeezes a DataFrame to a Series.

        Parameters
        ----------
        pop
            The data to be stratified.

        Returns
        -------
            The squeezed data to be stratified.

        Notes
        -----
        The input DataFrame is guaranteed to have a single column.
        """
        squeezed_pop: pd.Series[Any] = pop.squeeze(axis=1)
        return squeezed_pop
def get_mapped_col_name(col_name: str) -> str:
    """Builds the name of the column that holds mapped values for `col_name`.

    The result is `col_name` followed by an underscore and the module-level
    stratification column suffix.
    """
    suffix = f"_{STRATIFICATION_COLUMN_SUFFIX}"
    return f"{col_name}{suffix}"
def get_original_col_name(col_name: str) -> str:
    """Recovers the original column name from a mapped column name.

    If `col_name` ends with an underscore followed by the module-level
    stratification column suffix, that trailing part is stripped. Any other
    name is returned unchanged.
    """
    mapped_suffix = f"_{STRATIFICATION_COLUMN_SUFFIX}"
    if not col_name.endswith(mapped_suffix):
        return col_name
    # Keep everything before the trailing "_<suffix>" marker.
    return col_name[: len(col_name) - len(mapped_suffix)]