Source code for vivarium.framework.results.observation

"""
============
Observations
============

An observation is a class object that records simulation results; it is responsible
for initializing, gathering, updating, and formatting results.

The provided :class:`Observation` class is an abstract base class that should
be subclassed by concrete observations. Subclasses must define the `is_stratified`
class method; beyond that, the class provides common attributes as well
as an `observe` method that gathers the results of an observation.

At the highest level, an observation can be categorized as either an
:class:`UnstratifiedObservation` or a :class:`StratifiedObservation`. More specialized
implementations of these classes involve defining the various methods
provided as attributes to the parent class.

"""

from __future__ import annotations

import itertools
from abc import ABC, abstractmethod
from collections.abc import Callable
from dataclasses import dataclass
from typing import TYPE_CHECKING

import pandas as pd
from pandas.api.types import CategoricalDtype
from pandas.core.groupby.generic import DataFrameGroupBy

from vivarium.exceptions import VivariumError
from vivarium.framework.event import Event
from vivarium.framework.results.stratification import Stratification, get_original_col_name

if TYPE_CHECKING:
    from vivarium.framework.results.interface import PopulationFilter

# Name of the column that holds the observed values in stratified results.
VALUE_COLUMN = "value"


@dataclass
class Observation(ABC):
    """Abstract base dataclass from which concrete observations inherit.

    Besides the shared attributes, this class supplies an
    :meth:`observe <observe>` method that gathers the results of the
    observation by delegating to ``results_gatherer``. Whether an observation
    should be performed for a given event is decided by ``to_observe``.

    """

    #: Name of the observation. It will also be the name of the output results
    #: file for this particular observation.
    name: str

    #: A named tuple of population filtering details. The first item is a
    #: Pandas query string to filter the population down to the simulants who
    #: should be considered for the observation. The second item is a boolean
    #: indicating whether untracked simulants are included in the observation.
    population_filter: PopulationFilter

    #: Name of the lifecycle phase the observation should happen. Valid values
    #: are: "time_step__prepare", "time_step", "time_step__cleanup", or
    #: "collect_metrics".
    when: str

    #: The population attributes required for this observation.
    requires_attributes: list[str]

    #: Method or function that initializes the raw observation results prior
    #: to starting the simulation. This could return, for example, an empty
    #: DataFrame or one with a complete set of stratifications as the index
    #: and all values set to 0.0.
    results_initializer: Callable[[], pd.DataFrame]

    #: Method or function that gathers the new observation results.
    results_gatherer: Callable[
        [
            pd.DataFrame | DataFrameGroupBy[tuple[str, ...] | str, bool],
            tuple[str, ...] | None,
        ],
        pd.DataFrame,
    ]

    #: Method or function that updates existing raw observation results with
    #: newly gathered results.
    results_updater: Callable[[pd.DataFrame, pd.DataFrame], pd.DataFrame]

    #: Method or function that formats the raw observation results.
    results_formatter: Callable[[str, pd.DataFrame], pd.DataFrame]

    #: Method or function that determines whether to perform an observation
    #: on this Event.
    to_observe: Callable[[Event], bool]

    #: Optional tuple of the Stratifications this observation should use.
    stratifications: tuple[Stratification, ...] | None = None

    def observe(
        self,
        df: pd.DataFrame | DataFrameGroupBy[tuple[str, ...] | str, bool],
        stratifications: tuple[str, ...] | None,
    ) -> pd.DataFrame:
        """Gathers the results of the observation.

        Parameters
        ----------
        df
            The population or population grouped by the stratifications.
        stratifications
            The stratifications to use for the observation.

        Returns
        -------
            The results of the observation, as produced by ``results_gatherer``.
        """
        gathered = self.results_gatherer(df, stratifications)
        return gathered

    @classmethod
    @abstractmethod
    def is_stratified(cls) -> bool:
        """Report whether this kind of observation is stratified."""
        ...
class UnstratifiedObservation(Observation):
    """Concrete observation whose results are not stratified.

    The parent class ``stratifications`` are left as None and the
    ``results_initializer`` is fixed to :meth:`create_empty_df`.

    Attributes
    ----------
    name
        Name of the observation. It will also be the name of the output results
        file for this particular observation.
    population_filter
        A named tuple of population filtering details. The first item is a Pandas
        query string to filter the population down to the simulants who should be
        considered for the observation. The second item is a boolean indicating
        whether untracked simulants are included in the observation.
    when
        Name of the lifecycle phase the observation should happen. Valid values are:
        "time_step__prepare", "time_step", "time_step__cleanup", or "collect_metrics".
    requires_attributes
        The population attributes required for this observation.
    results_gatherer
        Method or function that gathers the new observation results.
    results_updater
        Method or function that updates existing raw observation results with
        newly gathered results.
    results_formatter
        Method or function that formats the raw observation results.
    to_observe
        Method or function that determines whether to perform an observation on
        this Event.
    """

    def __init__(
        self,
        name: str,
        population_filter: PopulationFilter,
        when: str,
        requires_attributes: list[str],
        results_gatherer: Callable[[pd.DataFrame], pd.DataFrame],
        results_updater: Callable[[pd.DataFrame, pd.DataFrame], pd.DataFrame],
        results_formatter: Callable[[str, pd.DataFrame], pd.DataFrame],
        to_observe: Callable[[Event], bool] = lambda event: True,
    ):
        def _gather_from_dataframe(
            population: pd.DataFrame | DataFrameGroupBy[tuple[str, ...] | str, bool],
            _stratifications: tuple[str, ...] | None,
        ) -> pd.DataFrame:
            # The user-supplied gatherer only takes the population; the second
            # argument exists to match the two-argument signature the parent
            # class calls, and is ignored.
            if not isinstance(population, DataFrameGroupBy):
                return results_gatherer(population)
            raise TypeError(
                "Must provide a dataframe to an UnstratifiedObservation. "
                "Provided DataFrameGroupBy instead."
            )

        super().__init__(
            name=name,
            population_filter=population_filter,
            when=when,
            requires_attributes=requires_attributes,
            results_initializer=self.create_empty_df,
            results_gatherer=_gather_from_dataframe,
            results_updater=results_updater,
            results_formatter=results_formatter,
            to_observe=to_observe,
        )

    @classmethod
    def is_stratified(cls) -> bool:
        """Unstratified observations always report False."""
        return False

    @staticmethod
    def create_empty_df() -> pd.DataFrame:
        """Initializes an empty dataframe.

        Returns
        -------
            An empty DataFrame.
        """
        empty = pd.DataFrame()
        return empty
class StratifiedObservation(Observation):
    """Concrete class for observing stratified results.

    The parent class `results_initializer` and `results_gatherer` methods are
    explicitly defined and stratification-specific attributes `aggregator_sources`
    and `aggregator` are added.

    Attributes
    ----------
    name
        Name of the observation. It will also be the name of the output results
        file for this particular observation.
    population_filter
        A named tuple of population filtering details. The first item is a Pandas
        query string to filter the population down to the simulants who should be
        considered for the observation. The second item is a boolean indicating
        whether untracked simulants are included in the observation.
    when
        Name of the lifecycle phase the observation should happen. Valid values are:
        "time_step__prepare", "time_step", "time_step__cleanup", or "collect_metrics".
    requires_attributes
        The population attributes required for this observation.
    results_updater
        Method or function that updates existing raw observation results with
        newly gathered results.
    results_formatter
        Method or function that formats the raw observation results.
    aggregator_sources
        List of population view columns to be used in the `aggregator`.
    aggregator
        Method or function that computes the quantity for this observation.
    to_observe
        Method or function that determines whether to perform an observation on
        this Event.
    """

    def __init__(
        self,
        name: str,
        population_filter: PopulationFilter,
        when: str,
        requires_attributes: list[str],
        results_updater: Callable[[pd.DataFrame, pd.DataFrame], pd.DataFrame],
        results_formatter: Callable[[str, pd.DataFrame], pd.DataFrame],
        aggregator_sources: list[str] | None,
        aggregator: Callable[[pd.DataFrame], float | pd.Series[float]],
        to_observe: Callable[[Event], bool] = lambda event: True,
    ):
        super().__init__(
            name=name,
            population_filter=population_filter,
            when=when,
            requires_attributes=requires_attributes,
            results_initializer=self.create_expanded_df,
            results_gatherer=self.get_complete_stratified_results,  # type: ignore [arg-type]
            results_updater=results_updater,
            results_formatter=results_formatter,
            to_observe=to_observe,
        )
        self.aggregator_sources = aggregator_sources
        self.aggregator = aggregator

    @classmethod
    def is_stratified(cls) -> bool:
        """Stratified observations always report True."""
        return True

    def observe(
        self,
        df: pd.DataFrame | DataFrameGroupBy[tuple[str, ...] | str, bool],
        stratifications: tuple[str, ...] | None,
    ) -> pd.DataFrame:
        """Gathers the results of the observation.

        Parameters
        ----------
        df
            The population or population grouped by the stratifications.
        stratifications
            The stratifications to use for the observation.

        Returns
        -------
            The results of the observation, with the index levels renamed
            back to their original stratification names.
        """
        results = super().observe(df, stratifications)
        self._rename_stratification_columns(results)
        return results

    def _rename_stratification_columns(self, results: pd.DataFrame) -> None:
        """Convert the temporary stratified mapped index names back to their original names.

        The index of ``results`` is renamed in place; nothing is returned.
        """
        if isinstance(results.index, pd.MultiIndex):
            idx_names = [get_original_col_name(name) for name in results.index.names]
            results.rename_axis(index=idx_names, inplace=True)
        else:
            idx_name = results.index.name
            # An unnamed single-level index has nothing to map back.
            if idx_name is not None:
                results.index.rename(get_original_col_name(idx_name), inplace=True)

    def create_expanded_df(self) -> pd.DataFrame:
        """Initializes a dataframe of 0s with complete set of stratifications as the index.

        Returns
        -------
            A DataFrame indexed by the complete set of stratifications whose
            single value column is initialized to 0.0.

        Raises
        ------
        VivariumError
            If ``stratifications`` is None.

        Notes
        -----
        If no stratifications are requested, then we are aggregating over the
        entire population and a single-row index named 'stratification' is created.
        """
        # Set up the complete index of all used stratifications
        if self.stratifications is None:
            raise VivariumError(
                f"StratifiedObserver {self.name} has None set as its stratifications."
            )
        stratification_values = {
            stratification.name: stratification.categories
            for stratification in self.stratifications
        }
        if stratification_values:
            stratification_names = list(stratification_values)
            # One row per combination of categories across all stratifications.
            df = pd.DataFrame(
                list(itertools.product(*stratification_values.values())),
                columns=stratification_names,
            ).astype(CategoricalDtype())
        else:
            # We are aggregating the entire population so create a single-row index
            stratification_names = ["stratification"]
            df = pd.DataFrame(["all"], columns=stratification_names).astype(
                CategoricalDtype()
            )

        # Initialize a zeros dataframe
        df[VALUE_COLUMN] = 0.0
        return df.set_index(stratification_names)

    def get_complete_stratified_results(
        self,
        pop_groups: DataFrameGroupBy[str, bool],
        stratifications: tuple[str, ...],
    ) -> pd.DataFrame:
        """Gathers results for this observation.

        Parameters
        ----------
        pop_groups
            The population grouped by the stratifications.
        stratifications
            The stratifications to use for the observation.

        Returns
        -------
            The results of the observation.
        """
        df = self._aggregate(pop_groups, self.aggregator_sources, self.aggregator)
        df = self._format(df)
        df = self._expand_index(df)
        # With no stratifications the index is given the same name that
        # `create_expanded_df` uses for its single-row index.
        if not list(stratifications):
            df.index.name = "stratification"
        return df

    @staticmethod
    def _aggregate(
        pop_groups: DataFrameGroupBy[str, bool],
        aggregator_sources: list[str] | None,
        aggregator: Callable[[pd.DataFrame], float | pd.Series[float]],
    ) -> pd.Series[float] | pd.DataFrame:
        """Applies the provided aggregator to the population groups.

        When ``aggregator_sources`` is given, only those columns are passed to
        the aggregator and missing results are filled with 0.0; otherwise all
        columns of the grouped object are used. The result is cast to float.
        """
        if aggregator_sources:
            aggregates = pop_groups[aggregator_sources].apply(aggregator).fillna(0.0)  # type: ignore [arg-type]
        else:
            aggregates = pop_groups[pop_groups.obj.columns].apply(aggregator)  # type: ignore [arg-type]
        return aggregates.astype(float)

    @staticmethod
    def _format(aggregates: pd.Series[float] | pd.DataFrame) -> pd.DataFrame:
        """Converts the results to a dataframe with a standard value column name.

        A Series is wrapped in a DataFrame. If the result has exactly one
        column, that column is renamed to ``VALUE_COLUMN``; results with more
        columns are returned with their column names untouched.
        """
        df = pd.DataFrame(aggregates) if isinstance(aggregates, pd.Series) else aggregates
        if df.shape[1] == 1:
            # Rename on a copy so a DataFrame passed in by the caller is not
            # modified as a side effect.
            df = df.rename(columns={df.columns[0]: VALUE_COLUMN})
        return df

    @staticmethod
    def _expand_index(aggregates: pd.DataFrame) -> pd.DataFrame:
        """Includes all stratifications in the results by filling missing values with 0."""
        if isinstance(aggregates.index, pd.MultiIndex):
            # Rebuild every combination of the index levels so that groups
            # absent from the aggregation still get a row.
            full_idx = pd.MultiIndex.from_product(aggregates.index.levels)
        else:
            full_idx = aggregates.index
        return aggregates.reindex(full_idx).fillna(0.0)
class AddingObservation(StratifiedObservation):
    """Concrete observation for results that are both additive and stratified.

    The parent class ``results_updater`` is fixed to :meth:`add_results`;
    the stratification-specific ``aggregator_sources`` and ``aggregator``
    are passed through to the parent.

    Attributes
    ----------
    name
        Name of the observation. It will also be the name of the output results
        file for this particular observation.
    population_filter
        A named tuple of population filtering details. The first item is a Pandas
        query string to filter the population down to the simulants who should be
        considered for the observation. The second item is a boolean indicating
        whether untracked simulants are included in the observation.
    when
        Name of the lifecycle phase the observation should happen. Valid values are:
        "time_step__prepare", "time_step", "time_step__cleanup", or "collect_metrics".
    requires_attributes
        The population attributes required for this observation.
    results_formatter
        Method or function that formats the raw observation results.
    stratifications
        Tuple of Stratifications to be used by the observation. If empty, the
        observation is aggregated over the entire population.
    aggregator_sources
        List of population view columns to be used in the `aggregator`.
    aggregator
        Method or function that computes the quantity for this observation.
    to_observe
        Method or function that determines whether to perform an observation on
        this Event.
    """

    def __init__(
        self,
        name: str,
        population_filter: PopulationFilter,
        when: str,
        requires_attributes: list[str],
        results_formatter: Callable[[str, pd.DataFrame], pd.DataFrame],
        aggregator_sources: list[str] | None,
        aggregator: Callable[[pd.DataFrame], float | pd.Series[float]],
        to_observe: Callable[[Event], bool] = lambda event: True,
    ):
        super().__init__(
            name=name,
            population_filter=population_filter,
            when=when,
            requires_attributes=requires_attributes,
            results_updater=self.add_results,
            results_formatter=results_formatter,
            aggregator_sources=aggregator_sources,
            aggregator=aggregator,
            to_observe=to_observe,
        )

    @staticmethod
    def add_results(
        existing_results: pd.DataFrame, new_observations: pd.DataFrame
    ) -> pd.DataFrame:
        """Adds newly-observed results to the existing results.

        Parameters
        ----------
        existing_results
            The existing results DataFrame.
        new_observations
            The new observations DataFrame.

        Returns
        -------
            A copy of the existing results with the new observations added in.

        Notes
        -----
        Columns that appear only in the new observations are first added to
        the copy with a value of 0.0, so that the addition below has a
        starting point for them.
        """
        observed_columns = new_observations.columns
        known_columns = existing_results.columns
        unseen_columns = [
            column for column in observed_columns if column not in known_columns
        ]

        # Work on a copy so the caller's existing results are left untouched.
        totals = existing_results.copy()
        if unseen_columns:
            totals[unseen_columns] = 0.0
        for column in observed_columns:
            totals[column] += new_observations[column]
        return totals
class ConcatenatingObservation(UnstratifiedObservation):
    """Concrete observation whose results are concatenated (and so unstratified).

    The parent class ``results_gatherer`` and ``results_updater`` are fixed to
    :meth:`get_results_of_interest` and :meth:`concatenate_results`.

    Attributes
    ----------
    name
        Name of the observation. It will also be the name of the output results
        file for this particular observation.
    population_filter
        A named tuple of population filtering details. The first item is a Pandas
        query string to filter the population down to the simulants who should be
        considered for the observation. The second item is a boolean indicating
        whether untracked simulants are included in the observation.
    when
        Name of the lifecycle phase the observation should happen. Valid values are:
        "time_step__prepare", "time_step", "time_step__cleanup", or "collect_metrics".
    requires_attributes
        The population attributes required for this observation. "event_time"
        is always prepended to the list that is provided.
    results_formatter
        Method or function that formats the raw observation results.
    to_observe
        Method or function that determines whether to perform an observation on
        this Event.
    """

    def __init__(
        self,
        name: str,
        population_filter: PopulationFilter,
        when: str,
        requires_attributes: list[str],
        results_formatter: Callable[[str, pd.DataFrame], pd.DataFrame],
        to_observe: Callable[[Event], bool] = lambda event: True,
    ):
        # Every concatenated observation also records the event time.
        attributes_with_time = ["event_time"] + requires_attributes
        super().__init__(
            name=name,
            population_filter=population_filter,
            when=when,
            requires_attributes=attributes_with_time,
            results_gatherer=self.get_results_of_interest,
            results_updater=self.concatenate_results,
            results_formatter=results_formatter,
            to_observe=to_observe,
        )

    def get_results_of_interest(self, pop: pd.DataFrame) -> pd.DataFrame:
        """Return the population restricted to the ``requires_attributes`` columns."""
        columns_of_interest = self.requires_attributes
        return pop[columns_of_interest]

    @staticmethod
    def concatenate_results(
        existing_results: pd.DataFrame, new_observations: pd.DataFrame
    ) -> pd.DataFrame:
        """Concatenates the existing results with the new observations.

        Parameters
        ----------
        existing_results
            The existing results.
        new_observations
            The new observations.

        Returns
        -------
            The new observations appended below the existing results with a
            fresh index, or the new observations unchanged when there are no
            existing results yet.
        """
        if not existing_results.empty:
            combined = pd.concat([existing_results, new_observations], axis=0)
            return combined.reset_index(drop=True)
        return new_observations