"""Errors and utility functions for input processing."""
from numbers import Real
from typing import List, Tuple, Union
import numpy as np
import pandas as pd
from gbd_mapping import Cause, RiskFactor, causes, risk_factors
from vivarium_inputs import utility_data
from vivarium_inputs.globals import (
DEMOGRAPHIC_COLUMNS,
DRAW_COLUMNS,
SEXES,
SPECIAL_AGES,
)
# Preferred leading order of index columns/levels: the demographic columns, then
# the affected-entity, affected-measure and parameter columns. Used by
# get_ordered_index_cols (and therefore reshape) to order index levels.
INDEX_COLUMNS = DEMOGRAPHIC_COLUMNS + ["affected_entity", "affected_measure", "parameter"]
##################################################
# Functions to remove GBD conventions from data. #
##################################################
def scrub_gbd_conventions(data, location):
    """Replace GBD id-based index/column conventions with readable ones.

    Runs the location, sex, age, year and affected-entity scrubbers in that
    order and returns the result of the last one.
    """
    data = scrub_location(data, location)
    for scrubber in (scrub_sex, scrub_age, scrub_year, scrub_affected_entity):
        data = scrubber(data)
    return data
def scrub_location(data, location):
    """Label ``data`` with the given ``location`` under a ``location`` index level.

    If the index already has a ``location_id`` level, that level is renamed
    to ``location`` and its values are replaced by ``location`` (the index of
    ``data`` is reassigned in place). Otherwise a new outermost ``location``
    level is added via ``pd.concat`` with a single key.
    """
    if "location_id" not in data.index.names:
        return pd.concat([data], keys=[location], names=["location"])
    relabeled_index = data.index.rename("location", level="location_id")
    data.index = relabeled_index.set_levels([location], level="location")
    return data
def scrub_sex(data):
    """Rename the ``sex_id`` index level to ``sex`` and relabel its values.

    Ids 1 and 2 become "Male" and "Female"; any other value is kept as-is.
    The index of ``data`` is reassigned in place and ``data`` is returned.
    Data without a ``sex_id`` level is returned untouched.
    """
    if "sex_id" not in data.index.names:
        return data
    sex_labels = {1: "Male", 2: "Female"}
    level_position = data.index.names.index("sex_id")
    relabeled = [sex_labels.get(value, value) for value in data.index.levels[level_position]]
    renamed_index = data.index.rename("sex", level="sex_id")
    data.index = renamed_index.set_levels(relabeled, level="sex")
    return data
def scrub_age(data):
    """Replace the ``age_group_id`` index level with an ``age`` level of intervals.

    Each age group id is mapped to the left-closed interval
    ``[age_start, age_end)`` taken from ``utility_data.get_age_bins()``.
    The index of ``data`` is reassigned in place and ``data`` is returned;
    data without an ``age_group_id`` level is returned untouched (and the
    age bins are not fetched).
    """
    if "age_group_id" not in data.index.names:
        return data
    bins_by_id = utility_data.get_age_bins().set_index("age_group_id")
    starts, ends = bins_by_id.age_start, bins_by_id.age_end
    level_position = data.index.names.index("age_group_id")
    age_intervals = [
        pd.Interval(starts[group_id], ends[group_id], closed="left")
        for group_id in data.index.levels[level_position]
    ]
    renamed_index = data.index.rename("age", level="age_group_id")
    data.index = renamed_index.set_levels(age_intervals, level="age")
    return data
def scrub_year(data):
    """Replace the ``year_id`` index level with a ``year`` level of one-year intervals.

    Each year id ``y`` becomes the left-closed interval ``[y, y + 1)``. The
    index of ``data`` is reassigned in place and ``data`` is returned; data
    without a ``year_id`` level is returned untouched.
    """
    if "year_id" not in data.index.names:
        return data
    year_ids = data.index.levels[data.index.names.index("year_id")]
    year_intervals = [pd.Interval(year, year + 1, closed="left") for year in year_ids]
    renamed_index = data.index.rename("year", level="year_id")
    data.index = renamed_index.set_levels(year_intervals, level="year")
    return data
def scrub_affected_entity(data):
    """Replace a ``cause_id`` column with an ``affected_entity`` column of cause names.

    Names come from ``gbd_mapping.causes``. The new column is added and
    ``cause_id`` is dropped on ``data`` in place, and ``data`` is returned.
    A cause id missing from ``causes`` raises ``KeyError``. Data without a
    ``cause_id`` column is returned untouched.
    """
    if "cause_id" not in data.columns:
        return data
    cause_by_id = {cause.gbd_id: cause for cause in causes}
    data["affected_entity"] = data.cause_id.apply(lambda cid: cause_by_id[cid].name)
    data.drop("cause_id", axis=1, inplace=True)
    return data
def set_age_interval(data):
    """Collapse ``age_start``/``age_end`` index levels into one ``age`` interval level.

    The new ``age`` level holds left-closed ``pd.Interval(age_start, age_end)``
    values and is appended after the remaining levels. Data without an
    ``age_start`` level is returned untouched.
    """
    if "age_start" not in data.index.names:
        return data
    lower = data.index.get_level_values("age_start")
    upper = data.index.get_level_values("age_end")
    intervals = [pd.Interval(lo, hi, closed="left") for lo, hi in zip(lower, upper)]
    data = data.assign(age=intervals).set_index("age", append=True)
    data.index = data.index.droplevel(["age_start", "age_end"])
    return data
###############################################################
# Functions to normalize GBD data over a standard demography. #
###############################################################
def normalize(
    data: pd.DataFrame, fill_value: Real = None, cols_to_fill: List[str] = DRAW_COLUMNS
) -> pd.DataFrame:
    """Normalize ``data`` over sex, then year, then age.

    ``fill_value`` and ``cols_to_fill`` are forwarded to the sex and age
    steps, which use them for rows they add that the source data lacked.
    """
    by_sex = normalize_sex(data, fill_value, cols_to_fill)
    by_year = normalize_year(by_sex)
    return normalize_age(by_year, fill_value, cols_to_fill)
def normalize_sex(data: pd.DataFrame, fill_value, cols_to_fill) -> pd.DataFrame:
    """Reshape ``data`` so that it is split into Male and Female rows.

    Cases, by the set of ``sex_id`` values present:

    - no ``sex_id`` column, or no rows: returned unchanged;
    - every id in ``SEXES``: the rows that are neither Male nor Female are dropped;
    - only Combined: the data is duplicated once as Male and once as Female;
    - a single other id: rows for the missing one of Male/Female are added,
      with ``cols_to_fill`` set to ``fill_value``;
    - anything else: returned unchanged.

    The input frame is never modified; new frames are built instead.
    """
    sexes = set(data.sex_id.unique()) if "sex_id" in data.columns else set()
    if not sexes:
        # Data does not correspond to individuals, so no sex column necessary.
        return data
    if sexes == set(SEXES.values()):
        # Variation across sex is present; drop the combined rows.
        return data[data.sex_id.isin([SEXES["Male"], SEXES["Female"]])]
    if sexes == {SEXES["Combined"]}:
        # Not sex specific but applies to both sexes, so copy. Build new frames
        # instead of assigning into ``data`` so the caller's frame is untouched.
        male_data = data.assign(sex_id=SEXES["Male"])
        female_data = data.assign(sex_id=SEXES["Female"])
        return pd.concat([male_data, female_data], ignore_index=True)
    if len(sexes) == 1:
        # Sex specific, but only one sex present: fill the other with the default.
        missing_sex = {SEXES["Male"], SEXES["Female"]}.difference(sexes).pop()
        fill_data = data.assign(sex_id=missing_sex)
        fill_data.loc[:, cols_to_fill] = fill_value
        return pd.concat([data, fill_data], ignore_index=True)
    # Expected here: {Male, Female}. NOTE(review): mixtures such as
    # {Male, Combined} also fall through unchanged — confirm that is intended.
    return data
def normalize_year(data: pd.DataFrame) -> pd.DataFrame:
    """Put ``data`` on an annual ``year_id`` grid spanning the estimation years.

    The grid is every year from the smallest to the largest value returned
    by ``utility_data.get_estimation_years()``, inclusive.

    - No ``year_id`` column: the data is repeated once per grid year.
    - ``year_id`` values exactly matching the estimation years: gaps are
      filled by ``interpolate_year``.
    - Otherwise the data is used as-is.

    In every case, rows whose ``year_id`` is off the grid are dropped.
    """
    estimation_years = utility_data.get_estimation_years()
    annual_years = list(range(min(estimation_years), max(estimation_years) + 1))

    if "year_id" not in data.columns:
        data = pd.concat(
            [data.assign(year_id=year) for year in annual_years], ignore_index=True
        )
    elif set(data.year_id) == set(estimation_years):
        data = interpolate_year(data)

    return data[data.year_id.isin(annual_years)]
def interpolate_year(data):
    """Fill in years between the estimation years using PCHIP interpolation.

    Every non-draw column is treated as an id column. Returns the original
    rows concatenated with the interpolated rows (columns sorted by name).
    """
    # Hide the central comp dependency unless required.
    from core_maths.interpolate import pchip_interpolate

    draw_columns = set(DRAW_COLUMNS)
    # Keep the frame's own column order. A set difference would make the order
    # of id columns depend on per-process string hash randomization, so the
    # result could vary between runs if the helper is order-sensitive.
    id_cols = [column for column in data.columns if column not in draw_columns]
    fillin_data = pchip_interpolate(data, id_cols, DRAW_COLUMNS)
    return pd.concat([data, fillin_data], sort=True)
def normalize_age(
    data: pd.DataFrame, fill_value: Real, cols_to_fill: List[str]
) -> pd.DataFrame:
    """Expand ``data`` over the age groups from ``utility_data.get_age_group_ids()``.

    Cases, by the set of ``age_group_id`` values present:

    - no ``age_group_id`` column, or no rows: returned unchanged;
    - only the all-ages group: one copy of the data per age group;
    - a strict subset of the age groups: the missing age groups are added by
      reindexing over the product of every non-``cols_to_fill`` column's
      unique values, with new cells set to ``fill_value``;
    - anything else (including exactly the expected groups): returned unchanged.
    """
    has_age = "age_group_id" in data.columns
    observed_ages = set(data.age_group_id.unique()) if has_age else set()
    expected_ages = set(utility_data.get_age_group_ids())

    if not observed_ages:
        # Data does not correspond to individuals, so no age column necessary.
        return data

    if observed_ages == {SPECIAL_AGES["all_ages"]}:

        def _at_age(age_group_id):
            # A copy of the data relabelled to a single age group.
            replica = data.copy()
            replica.loc[:, "age_group_id"] = age_group_id
            return replica

        replicas = [_at_age(age_group_id) for age_group_id in expected_ages]
        return pd.concat(replicas, ignore_index=True)

    if observed_ages < expected_ages:
        index_cols = [
            column
            for column in data.columns.difference(cols_to_fill)
            if column != "age_group_id"
        ]
        full_index = pd.MultiIndex.from_product(
            [data[column].unique() for column in index_cols] + [expected_ages],
            names=index_cols + ["age_group_id"],
        )
        indexed = data.set_index(index_cols + ["age_group_id"])
        return indexed.reindex(full_index, fill_value=fill_value).reset_index()

    return data
def get_ordered_index_cols(data_columns: Union[pd.Index, set]) -> List:
    """Order column names with the ``INDEX_COLUMNS`` first.

    Names found in ``INDEX_COLUMNS`` come first, in ``INDEX_COLUMNS`` order,
    followed by the remaining names in sorted order.

    Parameters
    ----------
    data_columns
        A ``pd.Index`` (its own ``difference`` does the sorting) or any other
        iterable of names, such as the ``set(data.index.names)`` passed by
        ``reshape``.

    Returns
    -------
    list
        The ordered names.
    """
    if isinstance(data_columns, pd.Index):
        extra_cols = list(data_columns.difference(INDEX_COLUMNS))
    else:
        data_columns = set(data_columns)
        # A bare set difference iterates in hash order, which for strings varies
        # between processes; sort so the resulting level order is reproducible.
        # key=str keeps the sort total if a name is None (an unnamed level).
        extra_cols = sorted(data_columns.difference(INDEX_COLUMNS), key=str)
    known_cols = [col for col in INDEX_COLUMNS if col in data_columns]
    return known_cols + extra_cols
def reshape(data: pd.DataFrame, value_cols: List = DRAW_COLUMNS) -> pd.DataFrame:
    """Move every non-value column of ``data`` into its index, in canonical order.

    Parameters
    ----------
    data
        Frame to reshape. Anything that is not a ``pd.DataFrame`` (e.g. a
        Series) has no columns to move and is returned unchanged.
    value_cols
        Columns that stay as columns; every other column becomes an index level.

    Returns
    -------
    pandas.DataFrame
        - Flat index: a new index built from the non-value columns, ordered by
          ``get_ordered_index_cols`` (the old index is discarded by ``set_index``).
        - MultiIndex with leftover non-value columns: those columns are appended
          as levels and all levels are reordered canonically.
        - Otherwise the data unchanged.
    """
    if not isinstance(data, pd.DataFrame):
        # Previously this fell through to ``data.columns`` and raised AttributeError.
        return data
    non_value_cols = data.columns.difference(value_cols)
    if not isinstance(data.index, pd.MultiIndex):
        # Push all non-value columns into the index.
        return data.set_index(get_ordered_index_cols(non_value_cols))
    if not non_value_cols.empty:
        # Some columns that belong in the index were missed; append them.
        data = data.set_index(list(non_value_cols), append=True)
        data = data.reorder_levels(get_ordered_index_cols(set(data.index.names)))
    return data
def wide_to_long(data: pd.DataFrame, value_cols: List, var_name: str) -> pd.DataFrame:
    """Melt the ``value_cols`` of a wide frame into ``var_name``/``value`` columns.

    All other columns are kept as id columns. If none of ``value_cols`` is
    present, the data is assumed to be long already and is returned unchanged.
    """
    if not set(data.columns).intersection(value_cols):
        return data
    return pd.melt(
        data,
        id_vars=data.columns.difference(value_cols),
        value_vars=value_cols,
        var_name=var_name,
    )
def sort_hierarchical_data(data: pd.DataFrame) -> pd.DataFrame:
    """Reorder index labels of a hierarchical index and sort in level order.

    Known levels come first in the order location, sex, age_start, age_end,
    year_start, year_end; any other levels follow in their existing order.
    Levels are only reordered for a MultiIndex, but the result is always
    sorted by index.
    """
    preferred = ("location", "sex", "age_start", "age_end", "year_start", "year_end")
    current_names = list(data.index.names)
    level_order = [name for name in preferred if name in current_names]
    level_order += [name for name in current_names if name not in level_order]
    if isinstance(data.index, pd.MultiIndex):
        data = data.reorder_levels(level_order)
    return data.sort_index()
def convert_affected_entity(data: pd.DataFrame, column: str) -> pd.DataFrame:
    """Rename ``column`` to ``affected_entity`` and map its GBD ids to entity names.

    ``column == "cause_id"`` looks names up in ``causes``; any other column is
    treated as a risk id and looked up in ``risk_factors``. Ids without a
    match become NaN (``Series.map`` with a dict). The input is not modified.
    """
    present_ids = data[column].unique()
    entities = causes if column == "cause_id" else risk_factors
    name_by_id = {entity.gbd_id: entity.name for entity in entities if entity.gbd_id in present_ids}
    renamed = data.rename(columns={column: "affected_entity"})
    renamed["affected_entity"] = renamed["affected_entity"].map(name_by_id)
    return renamed
def compute_categorical_paf(
    rr_data: pd.DataFrame, e: pd.DataFrame, affected_entity: str
) -> pd.DataFrame:
    """Compute the population attributable fraction (PAF) of a categorical risk.

    The exposure-weighted mean relative risk is the sum over ``parameter``
    (exposure category) of ``exposure * rr``, and the PAF is
    ``(mean_rr - 1) / mean_rr``.

    Parameters
    ----------
    rr_data
        Relative risks with 'affected_entity' and 'affected_measure' columns,
        key columns 'sex_id', 'age_group_id', 'year_id', 'parameter', 'draw',
        and a 'value' column. Not modified.
    e
        Exposure with the same key columns and a 'value' column. Not modified.
    affected_entity
        Entity whose relative risks are selected from ``rr_data``.

    Returns
    -------
    pandas.DataFrame
        One row per sex/age/year/draw with the PAF in 'value', plus
        'affected_entity' and 'affected_measure' columns.
    """
    selected = rr_data[rr_data.affected_entity == affected_entity]
    # Only the first affected measure is kept; presumably there is exactly one
    # per entity — NOTE(review): confirm.
    affected_measure = selected.affected_measure.unique()[0]
    # Drop into a new frame instead of ``inplace=True`` on a filtered slice,
    # which is pandas' chained-assignment pattern (SettingWithCopyWarning).
    rr = selected.drop(columns=["affected_entity", "affected_measure"])

    key_cols = ["sex_id", "age_group_id", "year_id", "parameter", "draw"]
    e = e.set_index(key_cols).sort_index(level=key_cols)
    rr = rr.set_index(key_cols).sort_index(level=key_cols)
    weighted_rr = e * rr

    groupby_cols = [c for c in key_cols if c != "parameter"]
    mean_rr = weighted_rr.reset_index().groupby(groupby_cols)["value"].sum()
    paf = ((mean_rr - 1) / mean_rr).reset_index()
    # A zero mean RR (rows with zero exposure) gives (0 - 1) / 0 == -inf.
    paf = paf.replace(-np.inf, 0)
    paf["affected_entity"] = affected_entity
    paf["affected_measure"] = affected_measure
    return paf
def get_age_group_ids_by_restriction(
    entity: "Union[RiskFactor, Cause]", which_age: str
) -> Tuple[Union[int, None], Union[int, None]]:
    """Return the (start, end) age group ids of ``entity``'s age restriction.

    Parameters
    ----------
    entity
        Cause or RiskFactor whose ``restrictions`` are read.
    which_age
        'yll' or 'yld' to use that restriction directly; 'inner' for the
        narrowest combination of the two; 'outer' for the broadest (both
        computed by ``get_restriction_age_boundary``).

    Returns
    -------
    tuple
        Start and end age group ids. Either may be None when the
        corresponding restriction is not set.

    Raises
    ------
    NotImplementedError
        If ``which_age`` is not one of the four supported values.
    """
    if which_age == "yll":
        restrictions = entity.restrictions
        return (
            restrictions.yll_age_group_id_start,
            restrictions.yll_age_group_id_end,
        )
    if which_age == "yld":
        restrictions = entity.restrictions
        return (
            restrictions.yld_age_group_id_start,
            restrictions.yld_age_group_id_end,
        )
    if which_age == "inner":
        return (
            get_restriction_age_boundary(entity, "start", reverse=True),
            get_restriction_age_boundary(entity, "end", reverse=True),
        )
    if which_age == "outer":
        return (
            get_restriction_age_boundary(entity, "start"),
            get_restriction_age_boundary(entity, "end"),
        )
    raise NotImplementedError(
        "The second argument of this function should be one of [yll, yld, inner, outer]."
    )
def filter_data_by_restrictions(
    data: pd.DataFrame,
    entity: Union[RiskFactor, Cause],
    which_age: str,
    age_group_ids: List[int],
) -> pd.DataFrame:
    """Keep only the rows of ``data`` inside ``entity``'s sex and age restrictions.

    Sex: male-only entities keep Male rows, female-only entities keep Female
    rows, and everything else (including the contradictory "both" flag
    combination) keeps Male, Female and Combined rows.

    Age: the allowed age groups are the ``age_group_ids`` between the start
    and end ids chosen by ``which_age``:

    - 'yld'
    - 'yll'
    - 'inner': narrowest range of yll and yld
    - 'outer': broadest range of yll and yld

    Parameters
    ----------
    data
        DataFrame containing 'age_group_id' and 'sex_id' columns.
    entity
        Cause or RiskFactor.
    which_age
        One of 'yll', 'yld', 'inner', 'outer'.
    age_group_ids
        List of possible age group ids.

    Returns
    -------
    pandas.DataFrame
        The rows of ``data`` inside the sex/age restriction ranges.
    """
    restrictions = entity.restrictions
    sex_flags = (bool(restrictions.male_only), bool(restrictions.female_only))
    allowed_sexes = {
        (True, False): [SEXES["Male"]],
        (False, True): [SEXES["Female"]],
    }.get(sex_flags, [SEXES["Male"], SEXES["Female"], SEXES["Combined"]])
    data = data[data.sex_id.isin(allowed_sexes)]

    start, end = get_age_group_ids_by_restriction(entity, which_age)
    allowed_ages = get_restriction_age_ids(start, end, age_group_ids)
    return data[data.age_group_id.isin(allowed_ages)]
def clear_disability_weight_outside_restrictions(
    data: pd.DataFrame, cause: Cause, fill_value: float, age_group_ids: List[int]
) -> pd.DataFrame:
    """Set draws outside ``cause``'s yld sex/age restrictions to ``fill_value``.

    Sequela disability weights are not age/sex specific, so after they have
    been expanded over age and sex the rows outside the cause's restrictions
    need clearing. ``data`` is modified in place (``DRAW_COLUMNS`` of the
    out-of-range rows) and also returned.
    """
    restrictions = cause.restrictions
    sex_flags = (bool(restrictions.male_only), bool(restrictions.female_only))
    valid_sexes = {
        (True, False): [SEXES["Male"]],
        (False, True): [SEXES["Female"]],
    }.get(sex_flags, [SEXES["Male"], SEXES["Female"], SEXES["Combined"]])

    start, end = get_age_group_ids_by_restriction(cause, "yld")
    valid_ages = get_restriction_age_ids(start, end, age_group_ids)

    inside = data.sex_id.isin(valid_sexes) & data.age_group_id.isin(valid_ages)
    data.loc[~inside, DRAW_COLUMNS] = fill_value
    return data
def filter_to_most_detailed_causes(data: pd.DataFrame) -> pd.DataFrame:
    """Keep only rows whose ``cause_id`` is a most-detailed cause in ``causes``."""
    present_ids = set(data.cause_id)
    keep_ids = [
        cause.gbd_id
        for cause in causes
        if cause.gbd_id in present_ids and cause.most_detailed
    ]
    keep_mask = data.cause_id.isin(keep_ids)
    return data[keep_mask]
def get_restriction_age_ids(
    start_id: Union[int, None], end_id: Union[int, None], age_group_ids: List[int]
) -> List[int]:
    """Return the age group ids from ``start_id`` to ``end_id``, inclusive.

    Positions are taken from ``age_group_ids`` itself, which is presumably
    ordered by age. Returns an empty list when either bound is None or when
    ``end_id`` precedes ``start_id``. Raises ``ValueError`` (from
    ``list.index``) when a bound is not in ``age_group_ids``.
    """
    if start_id is None or end_id is None:
        return []
    first, last = (age_group_ids.index(bound) for bound in (start_id, end_id))
    return age_group_ids[first : last + 1]
def get_restriction_age_boundary(
    entity: "Union[RiskFactor, Cause]", boundary: str, reverse=False
) -> Union[int, None]:
    """Combine the 'yll' and 'yld' age restrictions of ``entity`` at one boundary.

    Parameters
    ----------
    entity
        RiskFactor or Cause whose ``restrictions`` provide
        ``yll_age_group_id_<boundary>`` and ``yld_age_group_id_<boundary>``.
    boundary
        'start' or 'end'. Any value other than 'start' is combined like 'end'.
    reverse
        If False (default), return the broadest boundary: the minimum start
        or the maximum end. If True, return the narrowest boundary: the
        maximum start or the minimum end.

    Returns
    -------
    The combined age group id when both restrictions exist; otherwise
    whichever one exists (None if neither does).

    Notes
    -----
    Boundaries are compared by age group id value, which assumes the ids
    involved increase with age. NOTE(review): confirm for the GBD round in use.
    """
    restrictions = entity.restrictions
    # Each restriction is read from its own field (yld vs. yll).
    yld_age = getattr(restrictions, f"yld_age_group_id_{boundary}")
    yll_age = getattr(restrictions, f"yll_age_group_id_{boundary}")
    if yld_age is None:
        return yll_age
    if yll_age is None:
        return yld_age
    # Broadest range (reverse=False): earliest start, latest end.
    # Narrowest range (reverse=True): latest start, earliest end.
    if boundary == "start":
        pick = max if reverse else min
    else:
        pick = min if reverse else max
    return pick(yld_age, yll_age)
def get_exposure_and_restriction_ages(exposure: pd.DataFrame, entity: RiskFactor) -> set:
    """Return the age groups present in both ``exposure`` and ``entity``'s restrictions.

    The restriction range is the broadest ('outer') combination of the yll
    and yld restrictions, expanded over ``utility_data.get_age_group_ids()``.
    Useful for filtering other risk data where the exposure age groups alone
    are not sufficient, because exposure was already filtered by age
    restrictions when it was extracted.

    Parameters
    ----------
    exposure
        Exposure data for ``entity``; must have an 'age_group_id' column.
    entity
        Entity whose restriction ages are intersected with the exposure ages.

    Returns
    -------
    set
        Age group ids found in both the exposure data and the restriction range.
    """
    observed_ages = set(exposure.age_group_id)
    start, end = get_age_group_ids_by_restriction(entity, "outer")
    all_age_ids = utility_data.get_age_group_ids()
    return observed_ages.intersection(get_restriction_age_ids(start, end, all_age_ids))
def split_interval(data, interval_column, split_column_prefix):
    """Split an interval-valued index level into ``<prefix>_start``/``<prefix>_end`` levels.

    Only acts when ``data`` is a DataFrame whose index has ``interval_column``;
    anything else is returned unchanged.

    Side effect: a ``<prefix>_end`` column is first added to ``data`` itself
    (and, for a flat index, a ``<prefix>_start`` column too) before being
    moved into the index.

    - Flat index: the index is replaced by a (start, end) MultiIndex.
    - MultiIndex: the interval level is renamed to ``<prefix>_start`` with its
      values replaced by the interval left edges (``data.index`` is reassigned
      in place), and ``<prefix>_end`` is appended as the last level.
    """
    if not isinstance(data, pd.DataFrame) or interval_column not in data.index.names:
        return data

    start_name = f"{split_column_prefix}_start"
    end_name = f"{split_column_prefix}_end"
    intervals = data.index.get_level_values(interval_column)
    data[end_name] = [interval.right for interval in intervals]

    if isinstance(data.index, pd.MultiIndex):
        level_position = data.index.names.index(interval_column)
        left_edges = [interval.left for interval in data.index.levels[level_position]]
        renamed_index = data.index.rename(start_name, level=interval_column)
        data.index = renamed_index.set_levels(left_edges, level=start_name)
        return data.set_index(end_name, append=True)

    data[start_name] = [interval.left for interval in intervals]
    return data.set_index([start_name, end_name])