"""
====================================
Low Birth Weight and Short Gestation
====================================
Low birth weight and short gestation (LBWSG) is a non-standard risk
implementation that has been used in several public health models.
"""
import pickle
from typing import Callable, Dict, List, Optional
import numpy as np
import pandas as pd
from vivarium.framework.engine import Builder
from vivarium.framework.lifecycle import LifeCycleError
from vivarium.framework.lookup import LookupTable
from vivarium.framework.population import SimulantData
from vivarium.framework.values import Pipeline
from vivarium_public_health.risks import Risk, RiskEffect
from vivarium_public_health.risks.data_transformations import (
get_exposure_data,
get_exposure_post_processor,
)
from vivarium_public_health.risks.distributions import PolytomousDistribution
from vivarium_public_health.utilities import EntityString, to_snake_case
CATEGORICAL = "categorical"
BIRTH_WEIGHT = "birth_weight"
GESTATIONAL_AGE = "gestational_age"
[docs]class LBWSGDistribution(PolytomousDistribution):
CONFIGURATION_DEFAULTS = {
"lbwsg_distribution": {
"age_column": "age",
"sex_column": "sex",
"year_column": "year",
}
}
#####################
# Lifecycle methods #
#####################
def __init__(self, exposure_data: pd.DataFrame = None):
super().__init__(
EntityString("risk_factor.low_birth_weight_and_short_gestation"), exposure_data
)
# noinspection PyAttributeOutsideInit
[docs] def setup(self, builder: Builder) -> None:
self.config = builder.configuration.lbwsg_distribution
self._exposure_data = self.get_exposure_data(builder)
super().setup(builder)
self.category_intervals = self.get_category_intervals(builder)
#################
# Setup methods #
#################
[docs] def get_exposure_data(self, builder: Builder) -> pd.DataFrame:
if self._exposure_data is None:
self._exposure_data = get_exposure_data(builder, self.risk)
return self._exposure_data.rename(
columns={
"sex": self.config.sex_column,
"age_start": f"{self.config.age_column}_start",
"age_end": f"{self.config.age_column}_end",
"year_start": f"{self.config.year_column}_start",
"year_end": f"{self.config.year_column}_end",
}
)
[docs] def get_exposure_parameters(self, builder: Builder) -> Pipeline:
return builder.value.register_value_producer(
self.exposure_parameters_pipeline_name,
source=builder.lookup.build_table(
self._exposure_data,
key_columns=[self.config.sex_column],
parameter_columns=[self.config.age_column, self.config.year_column],
),
requires_columns=[
self.config.sex_column,
self.config.age_column,
],
)
[docs] def get_category_intervals(self, builder: Builder) -> Dict[str, Dict[str, pd.Interval]]:
"""
Gets the intervals for each category. It is a dictionary from the string
"birth_weight" or "gestational_age" to a dictionary from the category
name to the interval
:param builder:
:return:
"""
categories = builder.data.load(f"{self.risk}.categories")
category_intervals = {
axis: {
category: self._parse_description(axis, description)
for category, description in categories.items()
}
for axis in [BIRTH_WEIGHT, GESTATIONAL_AGE]
}
return category_intervals
##################
# Public methods #
##################
[docs] def ppf(self, propensities: pd.DataFrame) -> pd.DataFrame:
"""
Takes a DataFrame with three columns: 'categorical.propensity',
'birth_weight.propensity', and 'gestational_age.propensity' which
contain each of those propensities for each simulant.
Returns a DataFrame with two columns for birth-weight and gestational
age exposures.
:param propensities:
:return:
"""
categorical_exposure = super().ppf(propensities[f"{CATEGORICAL}_propensity"])
continuous_exposures = [
self.single_axis_ppf(
axis,
propensities[f"{axis}.propensity"],
categorical_exposure=categorical_exposure,
)
for axis in self.category_intervals
]
return pd.concat(continuous_exposures, axis=1)
[docs] def single_axis_ppf(
self,
axis: str,
propensity: pd.Series,
categorical_propensity: pd.Series = None,
categorical_exposure: pd.Series = None,
) -> pd.Series:
"""
Takes an axis (either 'birth_weight' or 'gestational_age'), a propensity
and either a categorical propensity or a categorical exposure and
returns continuous exposures for that axis.
If categorical propensity is provided rather than exposure, this
function requires access to the low birth weight and short gestation
categorical exposure parameters pipeline
("risk_factor.low_birth_weight_and_short_gestation.exposure_parameters").
:param axis:
:param propensity:
:param categorical_propensity:
:param categorical_exposure:
:return:
"""
if (categorical_propensity is None) == (categorical_exposure is None):
raise ValueError(
"Either categorical propensity of categorical exposure may be provided, but not"
" both or neither."
)
if categorical_exposure is None:
categorical_exposure = super().ppf(categorical_propensity)
exposure_intervals = categorical_exposure.apply(
lambda category: self.category_intervals[axis][category]
)
exposure_left = exposure_intervals.apply(lambda interval: interval.left)
exposure_right = exposure_intervals.apply(lambda interval: interval.right)
continuous_exposure = propensity * (exposure_right - exposure_left) + exposure_left
continuous_exposure = continuous_exposure.rename(f"{axis}.exposure")
return continuous_exposure
##################
# Helper methods #
##################
@staticmethod
def _parse_description(axis: str, description: str) -> pd.Interval:
"""
Parses a string corresponding to a low birth weight and short gestation
category to an Interval
:param axis:
:param description:
:return:
"""
endpoints = {
BIRTH_WEIGHT: [
float(val) for val in description.split(", [")[1].split(")")[0].split(", ")
],
GESTATIONAL_AGE: [
float(val) for val in description.split("- [")[1].split(")")[0].split(", ")
],
}[axis]
return pd.Interval(*endpoints, closed="left") # noqa
[docs]class LBWSGRisk(Risk):
AXES = [BIRTH_WEIGHT, GESTATIONAL_AGE]
[docs] @staticmethod
def birth_exposure_pipeline_name(axis: str) -> str:
return f"{axis}.birth_exposure"
[docs] @staticmethod
def exposure_column_name(axis: str) -> str:
return f"{axis}_exposure"
##############
# Properties #
##############
@property
def columns_created(self) -> List[str]:
return [self.exposure_column_name(axis) for axis in self.AXES]
#####################
# Lifecycle methods #
#####################
def __init__(self):
super().__init__("risk_factor.low_birth_weight_and_short_gestation")
# noinspection PyAttributeOutsideInit
[docs] def setup(self, builder: Builder) -> None:
super().setup(builder)
self.birth_exposures = self.get_birth_exposure_pipelines(builder)
##########################
# Initialization methods #
##########################
[docs] def get_exposure_distribution(self) -> LBWSGDistribution:
return LBWSGDistribution()
#################
# Setup methods #
#################
[docs] def get_propensity_pipeline(self, builder: Builder) -> Optional[Pipeline]:
# Propensity only used on initialization; not being saved to avoid a cycle
return None
[docs] def get_exposure_pipeline(self, builder: Builder) -> Optional[Pipeline]:
# Exposure only used on initialization; not being saved to avoid a cycle
return None
[docs] def get_birth_exposure_pipelines(self, builder: Builder) -> Dict[str, Pipeline]:
def get_pipeline(axis_: str):
return builder.value.register_value_producer(
self.birth_exposure_pipeline_name(axis_),
source=lambda index: self.get_birth_exposure(axis_, index),
requires_columns=["age", "sex"],
requires_streams=[self.randomness_stream_name],
preferred_post_processor=get_exposure_post_processor(builder, self.risk),
)
return {
self.birth_exposure_pipeline_name(axis): get_pipeline(axis) for axis in self.AXES
}
########################
# Event-driven methods #
########################
[docs] def on_initialize_simulants(self, pop_data: SimulantData) -> None:
birth_exposures = {
self.exposure_column_name(axis): self.birth_exposures[
self.birth_exposure_pipeline_name(axis)
](pop_data.index)
for axis in self.AXES
}
self.population_view.update(pd.DataFrame(birth_exposures))
##################################
# Pipeline sources and modifiers #
##################################
[docs] def get_birth_exposure(self, axis: str, index: pd.Index) -> pd.DataFrame:
categorical_propensity = self.randomness.get_draw(index, additional_key=CATEGORICAL)
continuous_propensity = self.randomness.get_draw(index, additional_key=axis)
return self.exposure_distribution.single_axis_ppf(
axis, continuous_propensity, categorical_propensity
)
[docs] def get_current_exposure(self, index: pd.Index) -> pd.DataFrame:
raise LifeCycleError(
f"The {self.risk.name} exposure pipeline should not be called. You probably want to"
f" refer directly one of the exposure columns. During simulant initialization the birth"
f" exposure pipelines should be used instead."
)
[docs]class LBWSGRiskEffect(RiskEffect):
TMREL_BIRTH_WEIGHT_INTERVAL: pd.Interval = pd.Interval(3500.0, 4500.0)
TMREL_GESTATIONAL_AGE_INTERVAL: pd.Interval = pd.Interval(38.0, 42.0)
##############
# Properties #
##############
@property
def columns_created(self) -> List[str]:
return self._rr_column_names
@property
def columns_required(self) -> Optional[List[str]]:
return ["age", "sex"] + self.lbwsg_exposure_column_names
@property
def initialization_requirements(self) -> Dict[str, List[str]]:
return {
"requires_columns": ["sex"] + self.lbwsg_exposure_column_names,
"requires_values": [],
"requires_streams": [],
}
#####################
# Lifecycle methods #
#####################
def __init__(self, target: str):
super().__init__("risk_factor.low_birth_weight_and_short_gestation", target)
self.lbwsg_exposure_column_names = [
LBWSGRisk.exposure_column_name(axis) for axis in LBWSGRisk.AXES
]
self.relative_risk_pipeline_name = (
f"effect_of_{self.risk.name}_on_{self.target.name}.relative_risk"
)
[docs] def relative_risk_column_name(self, age_group_id) -> str:
return (
f"effect_of_{self.risk.name}_on_{age_group_id}_{self.target.name}_relative_risk"
)
# noinspection PyAttributeOutsideInit
[docs] def setup(self, builder: Builder) -> None:
self.age_intervals = self.get_age_intervals(builder)
self._rr_column_names = self.get_rr_column_names()
super().setup(builder)
self.interpolator = self.get_interpolator(builder)
#################
# Setup methods #
#################
[docs] def get_risk_exposure(self, builder: Builder) -> Callable[[pd.Index], pd.DataFrame]:
def exposure(index: pd.Index) -> pd.DataFrame:
return self.population_view.subview(self.lbwsg_exposure_column_names).get(index)
return exposure
[docs] def get_target_modifier(
self, builder: Builder
) -> Callable[[pd.Index, pd.Series], pd.Series]:
def adjust_target(index: pd.Index, target: pd.Series) -> pd.Series:
return target * self.relative_risk(index)
return adjust_target
[docs] def register_target_modifier(self, builder: Builder) -> None:
builder.value.register_value_modifier(
self.target_pipeline_name,
modifier=self.target_modifier,
requires_columns=["age", "sex"],
)
[docs] def get_age_intervals(self, builder: Builder) -> Dict[str, pd.Interval]:
age_bins = builder.data.load("population.age_bins").set_index("age_start")
exposure = builder.data.load(f"{self.risk}.exposure")
exposure = exposure[exposure["age_end"] > 0]
exposed_age_group_starts = (
exposure.groupby("age_start")["value"].any().reset_index()["age_start"]
)
return {
to_snake_case(age_bins.loc[age_start, "age_group_name"]): pd.Interval(
age_start, age_bins.loc[age_start, "age_end"]
)
for age_start in exposed_age_group_starts
}
[docs] def get_rr_column_names(self) -> List[str]:
return [self.relative_risk_column_name(age_group) for age_group in self.age_intervals]
[docs] def get_relative_risk_source(self, builder: Builder) -> Pipeline:
return builder.value.register_value_producer(
self.relative_risk_pipeline_name,
source=self.get_relative_risk,
requires_columns=["age"] + self._rr_column_names,
)
[docs] def get_population_attributable_fraction_source(self, builder: Builder) -> LookupTable:
return builder.lookup.build_table(
builder.data.load(f"{self.risk}.population_attributable_fraction"),
key_columns=["sex"],
parameter_columns=["age", "year"],
)
[docs] def get_interpolator(self, builder: Builder) -> pd.Series:
age_start_to_age_group_name_map = {
interval.left: to_snake_case(age_group_name)
for age_group_name, interval in self.age_intervals.items()
}
# get relative risk data for target
interpolators = builder.data.load(f"{self.risk}.relative_risk_interpolator")
interpolators = (
# isolate RRs for target and drop non-neonatal age groups since they have RR == 1.0
interpolators[
interpolators["age_start"].isin(
[interval.left for interval in self.age_intervals.values()]
)
]
.drop(columns=["age_end", "year_start", "year_end"])
.set_index(["sex", "value"])
.apply(lambda row: (age_start_to_age_group_name_map[row["age_start"]]), axis=1)
.rename("age_group_name")
.reset_index()
.set_index(["sex", "age_group_name"])
)["value"]
interpolators = interpolators.apply(lambda x: pickle.loads(bytes.fromhex(x)))
return interpolators
########################
# Event-driven methods #
########################
[docs] def on_initialize_simulants(self, pop_data: SimulantData) -> None:
pop = self.population_view.subview(["sex"] + self.lbwsg_exposure_column_names).get(
pop_data.index
)
birth_weight = pop[LBWSGRisk.exposure_column_name(BIRTH_WEIGHT)]
gestational_age = pop[LBWSGRisk.exposure_column_name(GESTATIONAL_AGE)]
is_male = pop["sex"] == "Male"
is_tmrel = (self.TMREL_GESTATIONAL_AGE_INTERVAL.left <= gestational_age) & (
self.TMREL_BIRTH_WEIGHT_INTERVAL.left <= birth_weight
)
def get_relative_risk_for_age_group(age_group: int) -> pd.Series:
column_name = self.relative_risk_column_name(age_group)
log_relative_risk = pd.Series(0.0, index=pop_data.index, name=column_name)
male_interpolator = self.interpolator["Male", age_group]
log_relative_risk[is_male & ~is_tmrel] = male_interpolator(
gestational_age[is_male & ~is_tmrel],
birth_weight[is_male & ~is_tmrel],
grid=False,
)
female_interpolator = self.interpolator["Female", age_group]
log_relative_risk[~is_male & ~is_tmrel] = female_interpolator(
gestational_age[~is_male & ~is_tmrel],
birth_weight[~is_male & ~is_tmrel],
grid=False,
)
return np.exp(log_relative_risk)
relative_risk_columns = [
get_relative_risk_for_age_group(age_group) for age_group in self.age_intervals
]
self.population_view.update(pd.concat(relative_risk_columns, axis=1))
##################################
# Pipeline sources and modifiers #
##################################
[docs] def get_relative_risk(self, index: pd.Index) -> pd.Series:
pop = self.population_view.get(index)
relative_risk = pd.Series(1.0, index=index, name=self.relative_risk_pipeline_name)
for age_group, interval in self.age_intervals.items():
age_group_mask = (interval.left <= pop["age"]) & (pop["age"] < interval.right)
relative_risk[age_group_mask] = pop.loc[
age_group_mask, self.relative_risk_column_name(age_group)
]
return relative_risk