"""
==========================
Vivarium Testing Utilities
==========================
Utility functions and classes to make testing ``vivarium`` components easier.
"""
from pathlib import Path
from typing import List
import numpy as np
import pandas as pd
from vivarium import Component
from vivarium.framework import randomness
from vivarium.framework.engine import Builder
from vivarium.framework.event import Event
from vivarium.framework.population import SimulantData
from vivarium.framework.randomness.index_map import IndexMap
class NonCRNTestPopulation(Component):
    """Test population that does not register its simulants for common random numbers.

    New simulants receive ``age``, ``sex``, ``location``, ``alive``,
    ``entrance_time`` and ``exit_time`` columns (built by
    ``_non_crn_build_population``). On every time step, the age of each living
    simulant is advanced by the step length expressed in 365-day years.
    Unlike ``TestPopulation``, nothing here calls the randomness system's
    simulant registration.
    """

    CONFIGURATION_DEFAULTS = {
        "population": {
            "initialization_age_min": 0,
            "initialization_age_max": 100,
            "untracking_age": None,
        },
    }

    @property
    def columns_created(self) -> List[str]:
        """State-table columns initialized by this component."""
        return list(("age", "sex", "location", "alive", "entrance_time", "exit_time"))

    def setup(self, builder: Builder) -> None:
        """Keep the configuration and obtain the ``population_age_fuzz`` stream.

        That single stream supplies both the age draws and the sex choice when
        the initial population is built.
        """
        self.config = builder.configuration
        stream_name = "population_age_fuzz"
        self.randomness = builder.randomness.get_stream(
            stream_name, initializes_crn_attributes=True
        )

    def on_initialize_simulants(self, pop_data: SimulantData) -> None:
        """Build the initial state for ``pop_data.index`` and write it to the view.

        ``pop_data.user_data`` may override the configured age bounds through the
        ``age_start`` and ``age_end`` keys. The location is always read from
        ``input_data.location`` in the configuration.
        """
        defaults = self.config.population
        requested = pop_data.user_data
        age_start = requested.get("age_start", defaults.initialization_age_min)
        age_end = requested.get("age_end", defaults.initialization_age_max)
        new_state = _non_crn_build_population(
            index=pop_data.index,
            age_start=age_start,
            age_end=age_end,
            location=self.config.input_data.location,
            creation_time=pop_data.creation_time,
            creation_window=pop_data.creation_window,
            randomness_stream=self.randomness,
        )
        self.population_view.update(new_state)

    def on_time_step(self, event: Event) -> None:
        """Advance the age of every living simulant by the step size, in years."""
        living = self.population_view.get(event.index, query="alive == 'alive'")
        years_elapsed = event.step_size / pd.Timedelta(days=365)
        living["age"] = living["age"] + years_elapsed
        self.population_view.update(living)
class TestPopulation(NonCRNTestPopulation):
    """Test population that registers new simulants with the randomness system.

    Ages come from a dedicated ``age_initialization`` stream. Each new
    simulant's ``entrance_time`` and ``age`` are passed to
    ``builder.randomness.register_simulants`` before location and sex are
    drawn from the parent's stream.
    """

    def setup(self, builder: Builder) -> None:
        """Run the parent setup, then add an age stream and the registration hook."""
        super().setup(builder)
        crn = builder.randomness
        self.age_randomness = crn.get_stream(
            "age_initialization", initializes_crn_attributes=True
        )
        self.register = crn.register_simulants

    def on_initialize_simulants(self, pop_data: SimulantData) -> None:
        """Draw ages, register the new simulants, then fill in the remaining columns.

        ``pop_data.user_data`` may override the configured age bounds through the
        ``age_start`` and ``age_end`` keys. The location is taken from
        ``input_data.location`` when configured. Otherwise it is chosen per
        simulant from USA/Canada/Mexico.
        """
        defaults = self.config.population
        requested = pop_data.user_data
        age_start = requested.get("age_start", defaults.initialization_age_min)
        age_end = requested.get("age_end", defaults.initialization_age_max)

        if age_start == age_end:
            # Zero-width range: ages are spread across the creation window, in years.
            width = pop_data.creation_window / pd.Timedelta(days=365)
        else:
            width = age_end - age_start
        age = self.age_randomness.get_draw(pop_data.index) * width + age_start

        core_population = pd.DataFrame(
            {"entrance_time": pop_data.creation_time, "age": age.values},
            index=pop_data.index,
        )
        # Simulants are registered before any draw from self.randomness below.
        self.register(core_population)

        input_data = self.config.input_data
        if "location" not in input_data.keys():
            location = self.randomness.choice(
                pop_data.index, ["USA", "Canada", "Mexico"], additional_key="location_choice"
            )
        else:
            location = input_data.location

        self.population_view.update(
            _build_population(core_population, location, self.randomness)
        )
def _build_population(core_population, location, randomness_stream):
    """Assemble a full initial state table from already-registered core attributes.

    Parameters
    ----------
    core_population
        Frame with ``age`` and ``entrance_time`` columns. Its index identifies
        the new simulants.
    location
        A scalar or an index-aligned Series, written to the ``location`` column.
    randomness_stream
        Stream whose ``choice`` assigns ``"Male"``/``"Female"`` to ``sex``.

    Returns
    -------
    pandas.DataFrame
        One row per simulant, with columns ``age``, ``entrance_time``, ``sex``,
        ``alive`` (always ``"alive"``), ``location`` and ``exit_time`` (``NaT``).
    """
    simulants = core_population.index
    sex = randomness_stream.choice(
        simulants, ["Male", "Female"], additional_key="sex_choice"
    )
    # Insertion order here fixes the column order of the resulting frame.
    columns = {
        "age": core_population["age"],
        "entrance_time": core_population["entrance_time"],
        "sex": sex,
        "alive": pd.Series("alive", index=simulants),
        "location": location,
        "exit_time": pd.NaT,
    }
    return pd.DataFrame(columns, index=simulants)
def _non_crn_build_population(
    index, age_start, age_end, location, creation_time, creation_window, randomness_stream
):
    """Build an initial state table without any simulant registration.

    Ages are ``draw * width + age_start``, where ``draw`` comes from
    ``randomness_stream.get_draw(index)``. ``width`` is ``age_end - age_start``,
    or, when the two bounds are equal, ``creation_window`` expressed in 365-day
    years. Sex is chosen from the same stream.

    Returns
    -------
    pandas.DataFrame
        Indexed by ``index``, with columns ``age``, ``sex``, ``alive`` (always
        ``"alive"``), ``location``, ``entrance_time`` (``creation_time``) and
        ``exit_time`` (``NaT``).
    """
    if age_start == age_end:
        width = creation_window / pd.Timedelta(days=365)
    else:
        width = age_end - age_start
    # The age draw happens before the sex choice, matching the stream's call order.
    age = randomness_stream.get_draw(index) * width + age_start
    sex = randomness_stream.choice(index, ["Male", "Female"], additional_key="sex_choice")
    return pd.DataFrame(
        {
            "age": age,
            "sex": sex,
            "alive": pd.Series("alive", index=index),
            "location": location,
            "entrance_time": creation_time,
            "exit_time": pd.NaT,
        },
        index=index,
    )
def build_table(value, year_start, year_end, columns=("age", "year", "sex", "value")):
    """Build a test lookup table over single-year age and year bins for both sexes.

    Rows cover ages 0 through 139 (bins ``[age, age + 1)``), years
    ``year_start`` through ``year_end`` inclusive (bins ``[year, year + 1)``),
    and the sexes ``"Male"`` and ``"Female"``, in that nesting order.

    Parameters
    ----------
    value
        Either a list with one entry per value column, or a single entry reused
        for every value column. An entry may be ``None`` (a fresh
        ``np.random.random()`` draw per cell), a callable invoked as
        ``entry(age, sex, year)``, or a constant.
    year_start, year_end
        Inclusive bounds of the year range.
    columns
        Only the entries after the first three are used. They name the value
        columns.

    Returns
    -------
    pandas.DataFrame
        Columns ``age_start``, ``age_end``, ``year_start``, ``year_end``,
        ``sex``, followed by the value columns.

    Raises
    ------
    ValueError
        If a list ``value`` does not have exactly one entry per value column.
    """
    value_columns = columns[3:]
    entries = value if isinstance(value, list) else [value] * len(value_columns)
    if len(entries) != len(value_columns):
        raise ValueError("Number of values must match number of value columns")

    def _resolve(entry, age, sex, year):
        # None -> fresh uniform draw; callable -> evaluated per cell; else constant.
        if entry is None:
            return np.random.random()
        if callable(entry):
            return entry(age, sex, year)
        return entry

    rows = []
    for age in range(140):
        age_bin = [age, age + 1]
        for year in range(year_start, year_end + 1):
            year_bin = [year, year + 1]
            for sex in ("Male", "Female"):
                cells = [_resolve(entry, age, sex, year) for entry in entries]
                rows.append(age_bin + year_bin + [sex] + cells)

    header = ["age_start", "age_end", "year_start", "year_end", "sex"]
    return pd.DataFrame(rows, columns=header + list(value_columns))
def make_dummy_column(name, initial_value):
    """Return a component that creates column ``name`` filled with ``initial_value``.

    During ``setup``, the component requests a population view on that single
    column and registers an initializer declaring it creates the column. Each
    initialization writes a constant Series over ``pop_data.index``.
    """
    # Alias the factory arguments so they are not visually confused with the
    # ``name`` property defined on the class below.
    column = name
    fill = initial_value

    class DummyColumnMaker:
        @property
        def name(self):
            # NOTE(review): every instance reports this same name; if component
            # names must be unique, two dummy columns in one simulation may
            # conflict. Confirm before relying on multiple instances.
            return "dummy_column_maker"

        def setup(self, builder):
            population = builder.population
            self.population_view = population.get_view([column])
            population.initializes_simulants(self.make_column, creates_columns=[column])

        def make_column(self, pop_data):
            series = pd.Series(fill, index=pop_data.index, name=column)
            self.population_view.update(series)

        def __repr__(self):
            return f"dummy_column(name={column}, initial_value={fill})"

    return DummyColumnMaker()
def get_randomness(
    key="test",
    clock=lambda: pd.Timestamp(1990, 7, 2),
    seed=12345,
    initializes_crn_attributes=False,
):
    """Create a standalone ``RandomnessStream`` for tests.

    Parameters
    ----------
    key
        Name of the stream.
    clock
        Zero-argument callable passed through as the stream's clock. It
        defaults to one that always returns 1990-07-02.
    seed
        Seed passed to the stream.
    initializes_crn_attributes
        Passed through unchanged to the stream.

    Returns
    -------
    RandomnessStream
        A stream backed by a freshly constructed ``IndexMap``.
    """
    stream_kwargs = dict(
        seed=seed,
        index_map=IndexMap(),
        initializes_crn_attributes=initializes_crn_attributes,
    )
    return randomness.RandomnessStream(key, clock, **stream_kwargs)
def reset_mocks(mocks):
    """Call ``reset_mock()`` on every object in ``mocks``, in iteration order.

    Returns ``None``; an empty iterable is a no-op.
    """
    for each_mock in mocks:
        each_mock.reset_mock()