Source code for vivarium.framework.lookup.table

"""
=============
Lookup Tables
=============

Simulations tend to require a large quantity of data to run.  :mod:`vivarium`
provides the :class:`LookupTable` abstraction to ensure that accurate data can
be retrieved when it's needed. It's a callable object that takes in a
population index and returns data specific to the individuals represented by
that index. See the :ref:`lookup concept note <lookup_concept>` for more.

"""
import dataclasses
from abc import ABC, abstractmethod
from datetime import datetime, timedelta
from numbers import Number
from typing import Callable, List, Tuple, Union

import numpy as np
import pandas as pd

from vivarium.framework.lookup.interpolation import Interpolation

ScalarValue = Union[Number, timedelta, datetime]
LookupTableData = Union[ScalarValue, pd.DataFrame, List[ScalarValue], Tuple[ScalarValue]]


[docs] @dataclasses.dataclass class LookupTable(ABC): """A callable to produces values for a population index. In :mod:`vivarium` simulations, the index is synonymous with the simulated population. The lookup system allows the user to provide different kinds of data and strategies for using that data. When the simulation is running, then, components can lookup parameter values based solely on the population index. Notes ----- These should not be created directly. Use the `lookup` method on the builder during setup. """ table_number: int """Unique identifier of the table.""" data: Union[ScalarValue, pd.DataFrame, List[ScalarValue], Tuple[ScalarValue]] """The data from which to build the interpolation.""" population_view_builder: Callable = None """Callable to get a population view to be used by the lookup table.""" key_columns: Union[List[str], Tuple[str]] = () """Column names to be used as categorical parameters in Interpolation to select between interpolation functions.""" parameter_columns: Union[List[str], Tuple] = () """Column names to be used as continuous parameters in Interpolation.""" value_columns: Union[List[str], Tuple[str]] = () """Names of value columns to be interpolated over.""" interpolation_order: int = 0 """Order of interpolation. Used to decide interpolation strategy.""" clock: Callable = None """Callable for current time in simulation.""" extrapolate: bool = True """Whether to extrapolate beyond edges of given bins.""" validate: bool = True """Whether to validate the data before building the LookupTable.""" @property def name(self) -> str: """Tables are generically named after the order they were created.""" return f"lookup_table_{self.table_number}" def __call__(self, index: pd.Index) -> Union[pd.Series, pd.DataFrame]: """Get the mapped values for the given index. Parameters ---------- index Index for population view. Returns ------- pandas.Series if only one value_column, pandas.DataFrame if multiple columns """ return self.call(index).squeeze(axis=1)
[docs] @abstractmethod def call(self, index: pd.Index) -> pd.DataFrame: """Private method to allow LookupManager to add constraints.""" pass
def __repr__(self) -> str: return "LookupTable()"
[docs] class InterpolatedTable(LookupTable): """A callable that interpolates data according to a given strategy. Notes ----- These should not be created directly. Use the `lookup` interface on the :class:`builder <vivarium.framework.engine.Builder>` during setup. """ def __init__( self, table_number: int, data: Union[ScalarValue, pd.DataFrame, List[ScalarValue], Tuple[ScalarValue]], population_view_builder: Callable, key_columns: Union[List[str], Tuple[str]], parameter_columns: Union[List[str], Tuple], value_columns: Union[List[str], Tuple[str]], interpolation_order: int, clock: Callable, extrapolate: bool, validate: bool, ): super().__init__( table_number=table_number, data=data, population_view_builder=population_view_builder, key_columns=key_columns, parameter_columns=parameter_columns, value_columns=value_columns, interpolation_order=interpolation_order, clock=clock, extrapolate=extrapolate, validate=validate, ) param_cols_with_edges = [] for p in parameter_columns: param_cols_with_edges += [(p, f"{p}_start", f"{p}_end")] view_columns = sorted((set(key_columns) | set(parameter_columns)) - {"year"}) + [ "tracked" ] self.parameter_columns_with_edges = param_cols_with_edges extra_columns = self.data.columns.difference( set(self.key_columns) | set([col for p in self.parameter_columns_with_edges for col in p]) | set(self.value_columns) ) if not self.value_columns: self.value_columns = list(extra_columns) else: self.data = self.data.drop(columns=extra_columns) self.population_view = population_view_builder(view_columns) self.interpolation = Interpolation( data, self.key_columns, self.parameter_columns_with_edges, self.value_columns, order=self.interpolation_order, extrapolate=self.extrapolate, validate=self.validate, )
[docs] def call(self, index: pd.Index) -> pd.DataFrame: """Get the interpolated values for the rows in ``index``. Parameters ---------- index Index of the population to interpolate for. Returns ------- pandas.DataFrame A table with the interpolated values for the population requested. """ pop = self.population_view.get(index) del pop["tracked"] if "year" in [col for col in self.parameter_columns]: current_time = self.clock() fractional_year = current_time.year fractional_year += current_time.timetuple().tm_yday / 365.25 pop["year"] = fractional_year return self.interpolation(pop)
def __repr__(self) -> str: return "InterpolatedTable()"
[docs] class CategoricalTable(LookupTable): """ A callable that selects values from a table based on categorical parameters across an index. Notes ----- These should not be created directly. Use the `lookup` interface on the :class:`builder <vivarium.framework.engine.Builder>` during setup. """ def __init__( self, table_number: int, data: Union[ScalarValue, pd.DataFrame, List[ScalarValue], Tuple[ScalarValue]], population_view_builder: Callable, key_columns: Union[List[str], Tuple[str]], value_columns: Union[List[str], Tuple[str]], **kwargs, ): super().__init__( table_number=table_number, data=data, population_view_builder=population_view_builder, key_columns=key_columns, value_columns=value_columns, **kwargs, ) self.population_view = population_view_builder(self.key_columns + ["tracked"]) extra_columns = self.data.columns.difference( set(self.key_columns) | set(self.value_columns) ) if not self.value_columns: self.value_columns = list(extra_columns) else: self.data = self.data.drop(columns=extra_columns)
[docs] def call(self, index: pd.Index) -> pd.DataFrame: """Get the mapped values for the rows in ``index``. Parameters ---------- index Index of the population to interpolate for. Returns ------- pandas.DataFrame A table with the mapped values for the population requested. """ pop = self.population_view.get(index) del pop["tracked"] # specify some numeric type for columns, so they won't be objects but # will be updated with whatever column type it actually is result = pd.DataFrame(index=pop.index, columns=self.value_columns, dtype=np.float64) sub_tables = pop.groupby(list(self.key_columns)) for key, sub_table in list(sub_tables): if sub_table.empty: continue category_masks = [ self.data[self.key_columns[i]] == category for i, category in enumerate(key) ] joint_mask = True for category_mask in category_masks: joint_mask = joint_mask & category_mask values = self.data.loc[joint_mask, self.value_columns].values result.loc[sub_table.index, self.value_columns] = values return result
def __repr__(self) -> str: return "CategoricalTable()"
[docs] class ScalarTable(LookupTable): """A callable that broadcasts a scalar or list of scalars over an index. Notes ----- These should not be created directly. Use the `lookup` interface on the builder during setup. """
[docs] def call(self, index: pd.Index) -> pd.DataFrame: """Broadcast this tables values over the provided index. Parameters ---------- index Index of the population to construct table for. Returns ------- pandas.DataFrame A table with a column for each of the scalar values for the population requested. """ if not isinstance(self.data, (list, tuple)): values = pd.Series( self.data, index=index, name=self.value_columns[0] if self.value_columns else None, ) else: values = dict( zip(self.value_columns, [pd.Series(v, index=index) for v in self.data]) ) return pd.DataFrame(values)
def __repr__(self) -> str: return "ScalarTable(value(s)={})".format(self.data)