"""
===================
The Population View
===================
The :class:`PopulationView` is a user-facing abstraction that manages read and write
access to the underlying :term:`population state table <Population State Table>`.
It has two primary responsibilities:

1. To provide user access to subsets of the state table when it is safe to do so.
2. To allow the user to update private data in a controlled way.
"""
from __future__ import annotations

import warnings
from collections.abc import Callable
from typing import TYPE_CHECKING, Any, Literal, overload

import pandas as pd

import vivarium.framework.population.utilities as pop_utils
from vivarium.framework.lifecycle import lifecycle_states
from vivarium.framework.population.exceptions import PopulationError

if TYPE_CHECKING:
    from vivarium.component import Component
    from vivarium.framework.population.manager import PopulationManager
class PopulationView:
    """A read/write manager for the population state table.

    A view reads population data through the population manager (``get``,
    ``get_frame``, ``get_filtered_index``) and writes data through
    ``initialize`` and ``update``. Writes are restricted to the private
    columns of the component the view is attached to; data that names any
    other column is rejected with a ``PopulationError``.

    A view created without a component (``component=None``) is read-only:
    ``private_columns``, ``initialize`` and ``update`` all raise
    ``PopulationError``.
    """

    def __init__(
        self,
        manager: PopulationManager,
        component: Component | None,
        view_id: int,
    ):
        """
        Parameters
        ----------
        manager
            The population manager for the simulation.
        component
            The component requesting this view. If None, the view will provide
            read-only access.
        view_id
            The unique identifier for this view.
        """
        self._manager = manager
        self._component = component
        self._id = view_id

    ##############
    # Properties #
    ##############

    @property
    def name(self) -> str:
        """The name of this view, built from its identifier."""
        return f"population_view_{self._id}"

    @property
    def private_columns(self) -> list[str]:
        """The names of private columns managed by this PopulationView.

        The names are looked up from the population manager using the name of
        the component this view is attached to.

        Raises
        ------
        PopulationError
            If this view is read-only (it has no component).
        """
        if self._component is None:
            raise PopulationError(
                "This PopulationView is read-only, so it doesn't have access to private_columns."
            )
        return self._manager.get_private_column_names(self._component.name)

    ###########
    # Methods #
    ###########

    @overload
    def get(
        self,
        index: pd.Index[int],
        attributes: str,
        query: str = "",
        include_untracked: bool | None = None,
        skip_post_processor: Literal[False] = False,
        mode: Literal["default"] = "default",
    ) -> pd.Series[Any]:
        ...

    @overload
    def get(
        self,
        index: pd.Index[int],
        attributes: list[str] | tuple[str, ...],
        query: str = "",
        include_untracked: bool | None = None,
        skip_post_processor: Literal[False] = False,
        mode: Literal["default"] = "default",
    ) -> pd.DataFrame:
        ...

    @overload
    def get(
        self,
        index: pd.Index[int],
        attributes: str | list[str] | tuple[str, ...],
        query: str = "",
        include_untracked: bool | None = None,
        skip_post_processor: Literal[True] = ...,
        mode: Literal["default", "source", "no-post-processors"] = "default",
    ) -> Any:
        ...

    @overload
    def get(
        self,
        index: pd.Index[int],
        attributes: str | list[str] | tuple[str, ...],
        query: str = "",
        include_untracked: bool | None = None,
        skip_post_processor: Literal[False] = False,
        mode: Literal["source", "no-post-processors"] = ...,
    ) -> Any:
        ...

    def get(
        self,
        index: pd.Index[int],
        attributes: str | list[str] | tuple[str, ...],
        query: str = "",
        include_untracked: bool | None = None,
        skip_post_processor: Literal[True, False] = False,
        mode: Literal["default", "source", "no-post-processors"] = "default",
    ) -> Any:
        """Gets a specific subset of the population state table.

        For the rows in ``index``, return the ``attributes`` (i.e. columns) from the
        state table. The resulting rows may be further filtered by the call's ``query``
        and whether or not to include untracked simulants.

        Parameters
        ----------
        index
            Index of the population to get. This may be further filtered by various
            query conditions.
        attributes
            The attributes to retrieve. Passing a single string requests a
            squeezed result from the manager; passing a list or tuple does not.
        query
            Additional conditions used to filter the index.
        include_untracked
            Whether to include untracked simulants. If True they are always
            included and if False they are always excluded. If None (default),
            they are excluded unless the simulation is in the population
            creation lifecycle state or this call happens during a nested
            pipeline evaluation.
        skip_post_processor
            Deprecated. If truthy, a ``DeprecationWarning`` is emitted and the
            call behaves as if ``mode="no-post-processors"`` had been passed.
        mode
            The mode for pipeline evaluation. One of "default", "source",
            or "no-post-processors".

        Notes
        -----
        The "single attribute must be a Series" check is only enforced when the
        effective mode is "default". For the other modes the manager's result
        is returned unchanged.

        Returns
        -------
            The attribute(s) requested subset to the ``index`` and filtered using
            the various optional queries. In "default" mode this is a Series when
            a single attribute name (string) is requested.

        Raises
        ------
        ValueError
            If an invalid mode is provided.
            If ``skip_post_processor`` is used together with ``mode="source"``.
            If a single attribute was requested in "default" mode but the
            manager did not return a Series.
        """
        valid_modes = ("default", "source", "no-post-processors")
        if mode not in valid_modes:
            raise ValueError(f"Invalid mode '{mode}'. Must be one of {valid_modes}.")

        if skip_post_processor:
            warnings.warn(
                "The 'skip_post_processor' parameter is deprecated. "
                "Use mode='no-post-processors' instead.",
                DeprecationWarning,
                stacklevel=2,
            )
            if mode == "source":
                raise ValueError("Cannot use skip_post_processor=True with mode='source'.")
            mode = "no-post-processors"

        # Decide on squeezing before ``attributes`` is normalized to a list.
        squeeze: Literal[True, False] = isinstance(attributes, str)
        attributes = [attributes] if isinstance(attributes, str) else list(attributes)

        population = self._manager.get_population(
            attributes=attributes,
            index=index,
            query=self._build_query(query, include_untracked),
            squeeze=squeeze,
            mode=mode,
        )

        if mode == "default" and squeeze and not isinstance(population, pd.Series):
            raise ValueError(
                "Expected a pandas Series to be returned when requesting a single "
                "attribute, but got a DataFrame instead. If you expect this attribute "
                "to be a DataFrame, you should call `get_frame()` instead."
            )
        return population

    def get_frame(
        self,
        index: pd.Index[int],
        attribute: str,
        query: str = "",
        include_untracked: bool | None = None,
    ) -> pd.DataFrame:
        """Gets a single attribute as a DataFrame.

        For the rows in ``index``, return the ``attribute`` from the state
        table. The resulting rows may be further filtered by the call's ``query``
        and whether or not to include untracked simulants.

        Parameters
        ----------
        index
            Index of the population to get.
        attribute
            The attribute to retrieve. This attribute may contain one or more columns.
        query
            Additional conditions used to filter the index.
        include_untracked
            Whether to include untracked simulants. If True they are always
            included and if False they are always excluded. If None (default),
            they are excluded unless the simulation is in the population
            creation lifecycle state or this call happens during a nested
            pipeline evaluation.

        Notes
        -----
        The manager's result is always wrapped in a :class:`pandas.DataFrame`,
        so a DataFrame is returned even if the attribute has a single column.
        The manager is called with its default ``squeeze``/``mode`` settings;
        how multi-level columns are squeezed is decided by the manager and is
        not visible from this class.

        Returns
        -------
            The attribute requested subset to the ``index`` and filtered using
            the various optional queries. Will always return a DataFrame.
        """
        return pd.DataFrame(
            self._manager.get_population(
                index=index,
                attributes=[attribute],
                query=self._build_query(query, include_untracked),
            )
        )

    def get_filtered_index(
        self,
        index: pd.Index[int],
        query: str = "",
        include_untracked: bool | None = None,
    ) -> pd.Index[int]:
        """Gets a specific index of the population.

        The requested index may be further filtered by the call's ``query`` and
        whether or not to include untracked simulants. This is implemented as a
        ``get`` call that requests no attributes and keeps only the index.

        Parameters
        ----------
        index
            Index of the population to get.
        query
            Additional conditions used to filter the index.
        include_untracked
            Whether to include untracked simulants. If True they are always
            included and if False they are always excluded. If None (default),
            they are excluded unless the simulation is in the population
            creation lifecycle state or this call happens during a nested
            pipeline evaluation.

        Returns
        -------
            The requested and filtered population index.
        """
        return self.get(
            index,
            attributes=[],
            query=query,
            include_untracked=include_untracked,
        ).index

    def initialize(self, data: pd.Series[Any] | pd.DataFrame) -> None:
        """Initialize private columns with the provided data.

        Use this method during simulant initialization (both initial and when adding
        new simulants) to set the initial values of private columns. Column names
        are inferred from the data (Series name or DataFrame columns).

        During initial population creation, the columns of ``data`` that do not
        yet exist in the private data are handed to the manager. Otherwise,
        the columns that already exist are merged row-wise with the existing
        values before being handed to the manager. Empty data is a no-op in
        the latter case.

        Parameters
        ----------
        data
            The initial values for private columns. If a :class:`pandas.Series`,
            its ``name`` identifies the column. If a :class:`pandas.DataFrame`,
            its column names identify the columns.

        Raises
        ------
        PopulationError
            - If this view is read-only.
            - If called outside of simulant initialization.
            - If the data contains columns not managed by this view.
            - If the data has simulants not in the population.
            - If the data is missing simulants during initial population creation.
        TypeError
            If the data is not a Series or DataFrame.
        """
        if self._component is None:
            raise PopulationError(
                "This PopulationView is read-only, so it doesn't have access to initialize()."
            )
        if not self._manager.adding_simulants:
            raise PopulationError(
                "initialize() can only be called during simulant initialization. "
                "Use update() to modify existing data."
            )

        data_df = self._coerce_init_data(data, self.private_columns)
        existing = pd.DataFrame(self._manager.get_private_columns(self._component))

        unknown_simulants = len(data_df.index.difference(existing.index))
        if unknown_simulants:
            raise PopulationError(
                "Population updates must have an index that is a subset of the current "
                f"private data. {unknown_simulants} simulants were provided "
                "in an update with no matching index in the existing table."
            )

        # Select columns in the order they appear in ``data`` (de-duplicated)
        # rather than via set arithmetic, whose iteration order depends on
        # string hashing and can differ between runs.
        existing_columns = set(existing.columns)
        data_columns = list(dict.fromkeys(data_df.columns))

        if self._manager.creating_initial_population:
            missing_pops = len(existing.index.difference(data_df.index))
            if missing_pops:
                raise PopulationError(
                    "Components must initialize all simulants during population "
                    f"initialization. Component '{self._component.name}' is missing "
                    f"updates for {missing_pops} simulants."
                )
            new_columns = [col for col in data_columns if col not in existing_columns]
            self._manager.update(data_df[new_columns])
        elif not data_df.empty:
            update_columns = [col for col in data_columns if col in existing_columns]
            updated_cols_list = [
                self._update_column_and_ensure_dtype(
                    data_df[column],
                    existing[column],
                    adding_simulants=True,
                )
                for column in update_columns
            ]
            self._manager.update(pd.concat(updated_cols_list, axis=1))

    @overload
    def update(
        self,
        columns: str,
        modifier: Callable[[pd.Series[Any]], pd.Series[Any]],
    ) -> None:
        ...

    @overload
    def update(
        self,
        columns: list[str],
        modifier: Callable[[pd.DataFrame], pd.DataFrame],
    ) -> None:
        ...

    def update(
        self,
        columns: str | list[str],
        modifier: Callable[..., pd.Series[Any] | pd.DataFrame],
    ) -> None:
        """Update private columns by applying a modifier to the current data.

        Read the current values of the specified private columns from the
        manager, pass a copy of them to ``modifier``, and write the result
        back. The modifier's result may cover a subset of the original index,
        in which case only those rows are changed; rows it omits keep their
        current values. An empty result is a no-op.

        Parameters
        ----------
        columns
            The private column(s) to update. A string for a single column
            or a list of strings for multiple columns.
        modifier
            A callable that takes the current column data and returns the
            updated values. May return a subset of the original index to
            update only some rows.

        Raises
        ------
        PopulationError
            - If this view is read-only.
            - If the modifier returns data with unexpected or missing columns.
            - If the modifier returns simulants that are not in the current data.
            - If the update changes a column's dtype while simulants are not
              being added.
        TypeError
            If the modifier does not return a Series, DataFrame, or scalar.
        """
        if self._component is None:
            raise PopulationError(
                "This PopulationView is read-only, so it doesn't have access to update()."
            )

        if isinstance(columns, str):
            squeeze = True
            column_list = [columns]
        else:
            squeeze = False
            column_list = list(columns)

        current_data = self._manager.get_private_columns(self._component, columns=columns)
        # The modifier works on a copy so it cannot mutate the manager's data in place.
        result = modifier(current_data.copy())
        result_df = self._coerce_update_result(result, column_list, current_data.index)

        if not result_df.empty:
            # Normalize to a frame so each column can be looked up by name.
            existing_full = pd.DataFrame(current_data) if squeeze else current_data
            updated_cols_list = [
                self._update_column_and_ensure_dtype(
                    result_df[column],
                    existing_full[column],
                    adding_simulants=self._manager.adding_simulants,
                )
                for column in result_df.columns
            ]
            self._manager.update(pd.concat(updated_cols_list, axis=1))

    def __repr__(self) -> str:
        name = self._component.name if self._component else "None"
        private_columns = self.private_columns if self._component else "N/A"
        return f"PopulationView(_id={self._id}, _component={name}, private_columns={private_columns})"

    ##################
    # Helper methods #
    ##################

    @staticmethod
    def _coerce_update_result(
        result: Any,
        columns: list[str],
        existing_index: pd.Index[int],
    ) -> pd.DataFrame:
        """Coerce the return value of a modifier callable to a DataFrame.

        A DataFrame is used as is. A Series becomes a single-column frame; an
        unnamed Series is given the requested column name when exactly one
        column was requested. Anything else is handed to the DataFrame
        constructor as the value of every requested column over
        ``existing_index``.

        Parameters
        ----------
        result
            The return value of the modifier callable.
        columns
            The column names that were passed to the modifier.
        existing_index
            The index of all simulants in the private data.

        Returns
        -------
            The result coerced to a DataFrame.

        Raises
        ------
        PopulationError
            If the result is an unnamed Series but several columns were
            requested, or if it contains unexpected columns, is missing
            requested columns, or contains unknown simulants.
        TypeError
            If the result is None or cannot be turned into a DataFrame.
        """
        if result is None:
            raise TypeError("The modifier returned None. Did you forget a return statement?")

        if isinstance(result, pd.DataFrame):
            coerced = result
        elif isinstance(result, pd.Series):
            if result.name is None:
                if len(columns) == 1:
                    result = result.rename(columns[0])
                else:
                    raise PopulationError(
                        "The modifier returned an unnamed Series, but multiple columns "
                        "were requested. The Series must be named to identify which "
                        "column it corresponds to, or return a DataFrame instead."
                    )
            coerced = pd.DataFrame(result)
        else:
            try:
                coerced = pd.DataFrame({col: result for col in columns}, index=existing_index)
            except (ValueError, TypeError) as err:
                # Chain explicitly so the pandas failure is reported as the cause.
                raise TypeError(
                    "The modifier must return a pandas Series, DataFrame, or scalar. "
                    f"Got {type(result)}."
                ) from err

        extra_cols = set(coerced.columns).difference(columns)
        if extra_cols:
            raise PopulationError(
                f"The modifier returned data with unexpected columns: {extra_cols}."
            )

        missing_cols = set(columns).difference(coerced.columns)
        if missing_cols:
            raise PopulationError(
                f"The modifier did not return data for all requested columns. "
                f"Missing: {missing_cols}."
            )

        unknown = coerced.index.difference(existing_index)
        if len(unknown):
            raise PopulationError(
                f"The modifier returned {len(unknown)} simulants not in the population."
            )
        return coerced

    @staticmethod
    def _coerce_init_data(
        update: pd.Series[Any] | pd.DataFrame,
        private_columns: list[str],
    ) -> pd.DataFrame:
        """Coerces initialization data to a :class:`pandas.DataFrame` format.

        The caller's object is never modified: an unnamed Series is renamed on
        a new object rather than by assigning to its ``name``.

        Parameters
        ----------
        update
            The update to the private data owned by the component that created this view.
        private_columns
            The private column names owned by the component that created this view.

        Returns
        -------
            The input data formatted as a DataFrame.

        Raises
        ------
        TypeError
            If ``update`` is neither a Series nor a DataFrame.
        PopulationError
            If ``update`` is an unnamed Series and there is not exactly one
            private column, if it names columns outside ``private_columns``,
            or if it has no columns at all.
        """
        if not isinstance(update, (pd.Series, pd.DataFrame)):
            raise TypeError(
                "The population update must be a pandas Series or DataFrame. "
                f"A {type(update)} was provided."
            )

        if isinstance(update, pd.Series):
            if update.name is None:
                if len(private_columns) == 1:
                    # ``rename`` returns a new Series; the caller's stays unnamed.
                    update = update.rename(private_columns[0])
                else:
                    raise PopulationError(
                        "Cannot update with an unnamed pandas series unless there "
                        "is only a single column in the view."
                    )
            update = pd.DataFrame(update)

        if not set(update.columns).issubset(private_columns):
            raise PopulationError(
                f"Cannot update with a DataFrame or Series that contains columns "
                f"the view does not. Dataframe contains the following extra columns: "
                f"{set(update.columns).difference(private_columns)}."
            )

        update_columns = list(update)
        if not update_columns:
            raise PopulationError(
                "The update method of population view is being called on a DataFrame "
                "with no columns."
            )
        return update

    @staticmethod
    def _update_column_and_ensure_dtype(
        update: pd.Series[Any],
        existing: pd.Series[Any],
        adding_simulants: bool,
    ) -> pd.Series[Any]:
        """Builds the updated private column with an appropriate dtype.

        The values of ``update`` are written over a copy of ``existing`` at the
        positions of the matching index labels; all other values are kept. The
        merged values are then cast to the dtype of ``update``.

        Parameters
        ----------
        update
            The new column values for a subset of the existing index. Every
            label must be present in ``existing.index``; callers validate this
            before calling.
        existing
            The existing column values for all simulants.
        adding_simulants
            Whether new simulants are currently being initialized.

        Returns
        -------
            The column with the provided update applied, indexed and named
            like ``existing``.

        Raises
        ------
        PopulationError
            If the merged values and ``update`` have different dtypes while
            simulants are not being added.
        """
        # FIXME: This code does not work as described. I'm leaving it here because writing
        # real dtype checking code is a pain and we never seem to hit the actual edge cases.
        # I've also seen this error, though I don't have a reproducible and useful example.
        # I'm reasonably sure what's really being accounted for here is non-nullable columns
        # that temporarily have null values introduced in the space between rows being
        # added to the private data and initializers filling them with their first values.
        # That means the space of dtype casting issues is actually quite small. What should
        # actually happen in the long term is to separate the population creation entirely
        # from the mutation of existing state. I.e. there's not an actual reason we need
        # to do all these sequential operations on a single underlying dataframe during
        # the creation of new simulants besides the fact that it's the existing
        # implementation.
        update_values = update.array.copy()
        new_values = existing.array.copy()
        update_index_positional = existing.index.get_indexer(update.index)  # type: ignore [no-untyped-call]

        # ``get_indexer`` translates the update's labels into positions within
        # ``existing``. A label missing from ``existing`` would map to -1 and
        # silently overwrite the last row, hence the precondition above.
        new_values[update_index_positional] = update_values

        unmatched_dtypes = new_values.dtype != update_values.dtype
        if unmatched_dtypes and not adding_simulants:
            # A mismatch is only tolerated while the population is being grown,
            # because extending the index forces columns that don't have a
            # natural null type to become 'object'.
            raise PopulationError(
                "A component is corrupting the population table by modifying the dtype of "
                f"the {update.name} column from {existing.dtype} to {update.dtype}."
            )

        new_values = new_values.astype(update_values.dtype)
        new_data: pd.Series[Any] = pd.Series(
            new_values, index=existing.index, name=existing.name
        )
        return new_data

    def _build_query(self, query: str, include_untracked: bool | None) -> str:
        """Builds the full query for this PopulationView.

        This combines the provided query with the population manager's tracked
        query, or with an empty string when the tracked query is skipped.

        Parameters
        ----------
        query
            An explicit query string to filter the index.
        include_untracked
            Controls whether the tracked query is applied:

            - None (default): The tracked query is applied, unless the manager
              reports the population creation lifecycle state or a nested
              pipeline evaluation (``pipeline_evaluation_depth > 0``).
            - True: The tracked query is always skipped (untracked simulants
              are included).
            - False: The tracked query is always applied (untracked simulants
              are excluded).

        Notes
        -----
        Only the tracked query is affected. The explicit ``query`` argument is
        always passed on to the combination step.

        Returns
        -------
            The combined query string.
        """
        skip_tracked_query = include_untracked is True or (
            include_untracked is None
            and (
                self._manager.get_current_state() == lifecycle_states.POPULATION_CREATION
                or self._manager.pipeline_evaluation_depth > 0
            )
        )
        return pop_utils.combine_queries(
            query,
            self._manager.get_tracked_query() if not skip_tracked_query else "",
        )