# Source code for vivarium.framework.randomness.index_map

"""
===================
Randomness IndexMap
===================

The :class:`IndexMap` is an internal abstraction used by the randomness system to help align
random numbers for the same simulants across multiple simulations. Its key idea is to take
a set of static identifying characteristics about a simulant and hash them to a consistent
positional index within a stream of seeded random numbers.

"""
import datetime
from typing import List, Optional, Tuple, Union

import numpy as np
import pandas as pd

from vivarium.framework.randomness.exceptions import RandomnessError


[docs] class IndexMap: """A key-index mapping with a vectorized hash and vectorized lookups.""" SIM_INDEX_COLUMN = "simulant_index" TEN_DIGIT_MODULUS = 10_000_000_000 def __init__(self, key_columns: List[str] = None, size: int = 1_000_000): self._use_crn = bool(key_columns) self._key_columns = key_columns self._map = None self._size = size
[docs] def update(self, new_keys: pd.DataFrame, clock_time: pd.Timestamp) -> None: """Adds the new keys to the mapping. Parameters ---------- new_keys A pandas DataFrame indexed by the simulant index and columns corresponding to the randomness system key columns. clock_time The simulation clock time. Used as the salt during hashing to minimize inter-simulation collisions. """ if new_keys.empty or not self._use_crn: return # Nothing to do new_mapping_index, final_mapping_index = self._parse_new_keys(new_keys) final_keys = final_mapping_index.droplevel(self.SIM_INDEX_COLUMN) if len(final_keys) != len(final_keys.unique()): raise RandomnessError("Non-unique keys in index") final_mapping = self._build_final_mapping(new_mapping_index, clock_time) # Tack on the simulant index to the front of the map. final_mapping.index = final_mapping.index.join(final_mapping_index).reorder_levels( [self.SIM_INDEX_COLUMN] + self._key_columns ) final_mapping = final_mapping.sort_index(level=self.SIM_INDEX_COLUMN) self._map = final_mapping
def _parse_new_keys(self, new_keys: pd.DataFrame) -> Tuple[pd.MultiIndex, pd.MultiIndex]: """Parses raw new keys into the mapping index. Parameters ---------- new_keys A pandas DataFrame indexed by the simulant index and columns corresponding to the randomness system key columns. Returns ------- Tuple[pd.MultiIndex, pd.MultiIndex] A tuple of the new mapping index and the final mapping index. Both are pandas indices with a level for the index assigned by the population system and additional levels for the key columns associated with the simulant index. The new mapping index contains only the values for the new keys and the final mapping combines the existing mapping and the new mapping index. """ keys = new_keys.copy() keys.index.name = self.SIM_INDEX_COLUMN new_mapping_index = keys.set_index(self._key_columns, append=True).index if self._map is None: final_mapping_index = new_mapping_index else: final_mapping_index = self._map.index.append(new_mapping_index) return new_mapping_index, final_mapping_index def _build_final_mapping( self, new_mapping_index: pd.Index, clock_time: pd.Timestamp ) -> pd.Series: """Builds a new mapping between key columns and the randomness index from the new mapping index and the existing map. Parameters ---------- new_mapping_index An index with a level for the index assigned by the population system and additional levels for the key columns associated with the simulant index. clock_time The simulation clock time. Used as the salt during hashing to minimize inter-simulation collisions. Returns ------- pd.Series The new mapping incorporating the updates from the new mapping index and resolving collisions. 
""" new_key_index = new_mapping_index.droplevel(self.SIM_INDEX_COLUMN) mapping_update = self._hash(new_key_index, salt=clock_time) if self._map is None: current_map = mapping_update else: old_map = self._map.droplevel(self.SIM_INDEX_COLUMN) current_map = pd.concat([old_map, mapping_update]) return self._resolve_collisions(new_key_index, current_map) def _resolve_collisions( self, new_key_index: pd.MultiIndex, current_mapping: pd.Series, ) -> pd.Series: """Resolves collisions in the new mapping by perturbing the hash. Parameters ---------- new_key_index The index of new key attributes to hash. current_mapping The new mapping incorporating the updates from the new mapping index with collisions unresolved. Returns ------- pd.Series The new mapping incorporating the updates from the new mapping index and resolving collisions. """ current_mapping = current_mapping.drop_duplicates() collisions = new_key_index.difference(current_mapping.index) salt = 1 while not collisions.empty: mapping_update = self._hash(collisions, salt) current_mapping = pd.concat([current_mapping, mapping_update]).drop_duplicates() collisions = mapping_update.index.difference(current_mapping.index) salt += 1 return current_mapping def _hash(self, keys: pd.Index, salt: int = 0) -> pd.Series: """Hashes the index into an integer index in the range [0, self.stride] Parameters ---------- keys The new index to hash. salt An integer used to perturb the hash in a deterministic way. Useful in dealing with collisions. Returns ------- pandas.Series A pandas series indexed by the given keys and whose values take on integers in the range [0, len(self)]. Duplicates may appear and should be dealt with by the calling code. 
""" key_frame = keys.to_frame() new_map = pd.Series(0, index=keys) salt = self._convert_to_ten_digit_int(pd.Series(salt, index=keys)) for i, column_name in enumerate(key_frame.columns): column = self._convert_to_ten_digit_int(key_frame[column_name]) primes = [2, 3, 5, 7, 11, 13, 17, 19, 23, 27] out = pd.Series(1, index=column.index) for idx, p in enumerate(primes): # numpy will almost always overflow here, but it is equivalent # to modding out by 2**64. Since it's much much larger than # our map size the amount of additional periodicity this # introduces is pretty trivial. out *= np.power(p, self._digit(column, idx)) new_map += out + salt return new_map % len(self) def _convert_to_ten_digit_int(self, column: pd.Series) -> pd.Series: """Converts a column of datetimes, integers, or floats into a column of 10 digit integers. Parameters ---------- column A series of datetimes, integers, or floats. Returns ------- pandas.Series A series of ten digit integers based on the input data. Raises ------ RandomnessError If the column contains data that is neither a datetime-like nor numeric. """ if isinstance(column.iloc[0], datetime.datetime): column = self._clip_to_seconds(column.view(np.int64)) elif np.issubdtype(column.iloc[0], np.integer): if not len(column >= 0) == len(column): raise RandomnessError( "Values in integer columns must be greater than or equal to zero." ) column = self._spread(column) elif np.issubdtype(column.iloc[0], np.floating): column = self._shift(column) else: raise RandomnessError( f"Unhashable column type {type(column.iloc[0])}. " "IndexMap accepts datetime like columns and numeric columns." 
) return column @staticmethod def _digit(m: Union[int, pd.Series], n: int) -> Union[int, pd.Series]: """Returns the nth digit of each number in m.""" return (m // (10**n)) % 10 @staticmethod def _clip_to_seconds(m: Union[int, pd.Series]) -> Union[int, pd.Series]: """Clips UTC datetime in nanoseconds to seconds.""" return m // pd.Timedelta(1, unit="s").value def _spread(self, m: Union[int, pd.Series]) -> Union[int, pd.Series]: """Spreads out integer values to give smaller values more weight.""" return (m * 111_111) % self.TEN_DIGIT_MODULUS def _shift(self, m: Union[float, pd.Series]) -> Union[int, pd.Series]: """Shifts floats so that the first 10 decimal digits are significant.""" out = m % 1 * self.TEN_DIGIT_MODULUS // 1 if isinstance(out, pd.Series): return out.astype("int64") return int(out) def __getitem__(self, index: pd.Index) -> np.ndarray: if self._use_crn: return self._map.loc[index].values else: return index.values def __len__(self) -> int: return self._size def __repr__(self) -> str: return "IndexMap({})".format("\n ".join(repr(self._map).split("\n")))