"""
===============
Basic Observers
===============
This module contains convenience classes for building concrete observers in
public health models.
"""
from collections.abc import Callable
import pandas as pd
from vivarium.framework.engine import Builder
from vivarium.framework.event import Event
from vivarium.framework.results import Observer
from vivarium_public_health.results.columns import COLUMNS
[docs]
class PublicHealthObserver(Observer):
"""A convenience class for typical public health observers.
It exposes a method for registering the most common observation type
(adding observation) as well as methods for formatting public health results
in a standardized way (to be overridden as necessary).
"""
[docs]
def register_adding_observation(
    self,
    builder: Builder,
    name: str,
    pop_filter: str = "",
    include_untracked: bool = False,
    when: str = "collect_metrics",
    requires_attributes: list[str] | None = None,
    additional_stratifications: list[str] | None = None,
    excluded_stratifications: list[str] | None = None,
    aggregator: Callable[[pd.DataFrame], float | pd.Series] = len,
    to_observe: Callable[[Event], bool] = lambda event: True,
) -> None:
    """Register an adding observation to the results system.

    An "adding" observation is one that adds/sums new results to existing
    result values. It is the most common type of observation used in public
    health models.

    Parameters
    ----------
    builder
        The builder object.
    name
        Name of the observation. It will also be the name of the output results
        file for this particular observation.
    pop_filter
        A Pandas query filter string to filter the population down to the
        simulants who should be considered for the observation.
    include_untracked
        Whether to include untracked simulants in the observation.
    when
        Name of the lifecycle phase in which the observation should happen.
        Valid values are: "time_step__prepare", "time_step",
        "time_step__cleanup", or "collect_metrics".
    requires_attributes
        The population attributes that are required by the `aggregator`.
        ``None`` (the default) means no attributes. The same list is also
        forwarded as ``aggregator_sources``.
    additional_stratifications
        List of additional stratification names by which to stratify this
        observation. ``None`` (the default) means none.
    excluded_stratifications
        List of default stratification names to remove from this observation.
        ``None`` (the default) means none.
    aggregator
        Function that computes the quantity for this observation.
    to_observe
        Function that takes an event and returns a boolean indicating whether
        the observation should be performed for that event.
    """
    # Build fresh lists per call. A mutable `[]` default would be a single
    # object shared by every call, so any downstream mutation would leak
    # into later registrations.
    requires_attributes = [] if requires_attributes is None else requires_attributes
    additional_stratifications = (
        [] if additional_stratifications is None else additional_stratifications
    )
    excluded_stratifications = (
        [] if excluded_stratifications is None else excluded_stratifications
    )

    builder.results.register_adding_observation(
        name=name,
        pop_filter=pop_filter,
        include_untracked=include_untracked,
        when=when,
        requires_attributes=requires_attributes,
        results_formatter=self.format_results,
        additional_stratifications=additional_stratifications,
        excluded_stratifications=excluded_stratifications,
        # TODO: Remove aggregator_sources from vivarium
        aggregator_sources=requires_attributes,
        aggregator=aggregator,
        to_observe=to_observe,
    )
[docs]
def get_measure_column(self, measure: str, results: pd.DataFrame) -> pd.Series:
    """Build the 'measure' column for a set of raw results.

    Subclasses may override this method to supply their own 'measure' values.

    Parameters
    ----------
    measure
        The measure name.
    results
        The raw results.

    Returns
    -------
        A Series indexed like ``results`` whose every entry is ``measure``.
    """
    target_index = results.index
    return pd.Series(data=measure, index=target_index)
[docs]
def get_entity_type_column(self, measure: str, results: pd.DataFrame) -> pd.Series:
    """Build the 'entity_type' column for a set of raw results.

    Subclasses may override this method to supply real 'entity_type' values.
    The default implementation ignores ``measure`` and fills the column with
    empty strings.

    Parameters
    ----------
    measure
        The measure name.
    results
        The raw results.

    Returns
    -------
        A Series indexed like ``results`` containing only empty strings.
    """
    row_labels = results.index
    blank = ""
    return pd.Series(data=blank, index=row_labels)
[docs]
def get_entity_column(self, measure: str, results: pd.DataFrame) -> pd.Series:
    """Build the 'entity' column for a set of raw results.

    Subclasses may override this method to supply real 'entity' values.
    By default ``measure`` is unused and every row gets an empty string.

    Parameters
    ----------
    measure
        The measure name.
    results
        The raw results.

    Returns
    -------
        A Series indexed like ``results`` containing only empty strings.
    """
    empty_label = ""
    entity_values = pd.Series(data=empty_label, index=results.index)
    return entity_values
[docs]
def get_sub_entity_column(self, measure: str, results: pd.DataFrame) -> pd.Series:
    """Build the 'sub_entity' column for a set of raw results.

    Subclasses may override this method to supply real 'sub_entity' values.
    The default leaves ``measure`` unused and assigns an empty string to
    every row.

    Parameters
    ----------
    measure
        The measure name.
    results
        The raw results.

    Returns
    -------
        A Series indexed like ``results`` containing only empty strings.
    """
    idx = results.index
    sub_entities = pd.Series(data="", index=idx)
    return sub_entities