Source code for vivarium_public_health.causal_factor.effect

"""
==========================
CausalFactor Effect Models
==========================

This module contains tools for modeling the relationship between causal factor
exposure models and the models they affect.

"""

from abc import ABC
from collections.abc import Callable
from importlib import import_module
from typing import Any

import numpy as np
import pandas as pd
from layered_config_tree import ConfigurationError
from vivarium import Component
from vivarium.framework.engine import Builder
from vivarium.framework.lookup import LookupTable
from vivarium.types import LookupTableData

from vivarium_public_health.causal_factor.calibration_constant import (
    get_calibration_constant_pipeline_name,
)
from vivarium_public_health.causal_factor.distributions import DichotomousDistribution
from vivarium_public_health.causal_factor.exposure import CausalFactor
from vivarium_public_health.causal_factor.utilities import (
    load_exposure_data,
    pivot_categorical,
)
from vivarium_public_health.utilities import EntityString, TargetString


class CausalFactorEffect(Component, ABC):
    """A component to model the effect of a causal factor on an affected entity's target measure.

    This component can source data either from builder.data or from parameters
    supplied in the configuration.

    For a causal factor named 'causal_factor' that affects 'affected_target',
    the configuration would look like:

    .. code-block:: yaml

       configuration:
            causal_factor_effect.causal_factor_name_on_affected_target:
               exposure_parameters: 2
               incidence_rate: 10

    """

    # Type the paired exposure component must be an instance of; checked in
    # ``_get_causal_factor_exposure_component``.
    EXPOSURE_CLASS = CausalFactor

    ##############
    # Properties #
    ##############

    @property
    def name(self) -> str:
        """The name of this causal factor effect component."""
        return self.get_name(self.causal_factor, self.target)

    @staticmethod
    def get_name(causal_factor: EntityString, target: TargetString) -> str:
        """Return the component name for a causal factor and target pair."""
        return f"causal_factor_effect.{causal_factor.name}_on_{target}"

    @property
    def configuration_defaults(self) -> dict[str, Any]:
        """Default configuration values for this component.

        Configuration structure::

            {causal_factor_effect_name}:
                data_sources:
                    relative_risk:
                        Source for relative risk data. Default is the artifact
                        key ``{causal_factor}.relative_risk``. Can also be:

                        - A scalar value (e.g., ``1.5``)
                        - A scipy.stats distribution name (e.g., ``"uniform"``)
                          with parameters in ``data_source_parameters``
                    population_attributable_fraction:
                        Source for PAF data. Default is the artifact key
                        ``{causal_factor}.population_attributable_fraction``.
                        Used to adjust the target measure to account for the
                        portion attributable to this causal factor.
                data_source_parameters:
                    relative_risk: dict
                        Parameters for scipy.stats distributions when using a
                        distribution name as the ``relative_risk`` source.
                        For example, ``{"loc": 1.0, "scale": 0.5}`` for a
                        uniform distribution.
        """
        return {
            self.name: {
                "data_sources": {
                    "relative_risk": f"{self.causal_factor}.relative_risk",
                    "population_attributable_fraction": f"{self.causal_factor}.population_attributable_fraction",
                },
                "data_source_parameters": {
                    "relative_risk": {},
                },
            }
        }

    @property
    def is_exposure_categorical(self) -> bool:
        """Whether the exposure distribution is categorical."""
        return self._exposure_distribution_type in [
            "dichotomous",
            "ordered_polytomous",
            "unordered_polytomous",
        ]

    #####################
    # Lifecycle methods #
    #####################

    def __init__(self, causal_factor: str, target: str):
        """
        Parameters
        ----------
        causal_factor
            Type and name of causal factor, supplied in the form
            "causal_factor_type.causal_factor_name" where causal_factor_type
            should be singular (e.g., risk_factor instead of risk_factors).
        target
            Type, name, and target measure of entity to be affected by causal
            factor, supplied in the form "entity_type.entity_name.measure"
            where entity_type should be singular (e.g., cause instead of
            causes).
        """
        super().__init__()
        self.causal_factor = EntityString(causal_factor)
        self.target = TargetString(target)

        # Unknown until ``setup`` resolves it from the exposure component.
        self._exposure_distribution_type = None

        self.exposure_name = f"{self.causal_factor.name}.exposure"
        self.target_name = f"{self.target.name}.{self.target.measure}"
        self.relative_risk_name = (
            f"{self.causal_factor.name}_on_{self.target_name}.relative_risk"
        )

    def setup(self, builder: Builder) -> None:
        """Set up the causal factor effect component.

        Load distribution type and PAF data, define relative risk source,
        build relative risk lookup tables, register relative risk pipeline,
        and register target and calibration constant modifiers.

        Parameters
        ----------
        builder
            Access point for utilizing framework interfaces during setup.
        """
        self.causal_factor_exposure_component = self._get_causal_factor_exposure_component(
            builder
        )
        # The distribution type must be known before the RR table is built:
        # ``build_rr_lookup_table`` branches on ``is_exposure_categorical``.
        self._exposure_distribution_type = self.get_distribution_type(builder)
        self.relative_risk_table = self.build_rr_lookup_table(builder)
        self.paf_data = self.get_calibration_constant_data(builder)
        self._relative_risk_source = self.get_relative_risk_source(builder)
        self.register_relative_risk_pipeline(builder)
        self.register_target_modifier(builder)
        self.register_calibration_constant_modifier(builder)

    #################
    # Setup methods #
    #################

    def build_rr_lookup_table(self, builder: Builder) -> LookupTable:
        """Build a lookup table for relative risk data.

        Parameters
        ----------
        builder
            Access point for utilizing framework interfaces during setup.

        Returns
        -------
            A lookup table of relative risk values.
        """
        rr_data = self.load_relative_risk(builder)
        rr_value_cols = None
        if self.is_exposure_categorical:
            rr_data, rr_value_cols = self.process_categorical_data(builder, rr_data)
        return self.build_lookup_table(
            builder, "relative_risk", data_source=rr_data, value_columns=rr_value_cols
        )

    def get_calibration_constant_data(self, builder: Builder) -> LookupTableData:
        """Load calibration constant (PAF) data for this effect.

        Parameters
        ----------
        builder
            Access point for utilizing framework interfaces during setup.

        Returns
        -------
            The calibration constant data.
        """
        return self.get_filtered_data(
            builder, self.configuration.data_sources.population_attributable_fraction
        )

    def get_distribution_type(self, builder: Builder) -> str:
        """Get the distribution type from the causal factor exposure component.

        Use the component's ``distribution_type`` attribute when it is truthy;
        otherwise ask the component to resolve it via its own
        ``get_distribution_type(builder)``.
        """
        return (
            self.causal_factor_exposure_component.distribution_type
            or self.causal_factor_exposure_component.get_distribution_type(builder)
        )

    def load_relative_risk(
        self,
        builder: Builder,
        configuration=None,
    ) -> str | float | pd.DataFrame:
        """Load relative risk data from the configuration.

        Attempt to interpret the configured source as a scipy.stats
        distribution name; if that fails, load it as artifact data.

        Parameters
        ----------
        builder
            Access point for utilizing framework interfaces during setup.
        configuration
            Optional configuration override. If ``None``, use
            ``self.configuration``.

        Returns
        -------
            The relative risk data.

        Raises
        ------
        ConfigurationError
            If the distribution parameters are invalid.
        """
        if configuration is None:
            configuration = self.configuration

        rr_source = configuration.data_sources.relative_risk
        rr_dist_parameters = configuration.data_source_parameters.relative_risk.to_dict()

        if not isinstance(rr_source, str):
            return self.get_filtered_data(builder, rr_source)

        # Only the name lookup decides whether the source is a distribution.
        # Keeping the ``try`` this narrow stops an unrelated AttributeError
        # (e.g. raised while seeding or sampling) from being mistaken for
        # "not a scipy distribution" and silently falling back to the artifact.
        try:
            distribution = getattr(import_module("scipy.stats"), rr_source)
        except AttributeError:
            return self.get_filtered_data(builder, rr_source)

        rng = np.random.default_rng(builder.randomness.get_seed(self.name))
        try:
            # Draw a single value by passing one uniform variate through the
            # distribution's ``ppf``.
            return distribution(**rr_dist_parameters).ppf(rng.random())
        except TypeError as err:
            raise ConfigurationError(
                f"Parameters {rr_dist_parameters} are not valid for distribution {rr_source}."
            ) from err

    def get_filtered_data(
        self, builder: Builder, data_source: str | float | pd.DataFrame
    ) -> float | pd.DataFrame:
        """Load data and filter to the target entity and measure.

        When the loaded data is a DataFrame, rows are restricted to this
        effect's target using whichever of the ``affected_entity`` and
        ``affected_measure`` columns are present, and those columns are then
        dropped. A DataFrame with neither column is returned unfiltered.
        Non-DataFrame data is returned as loaded.

        Parameters
        ----------
        builder
            Access point for utilizing framework interfaces during setup.
        data_source
            The data source identifier, scalar, or DataFrame.

        Returns
        -------
            The filtered data.
        """
        data = self.get_data(builder, data_source)
        if not isinstance(data, pd.DataFrame):
            return data

        # Seed with an all-True Series rather than the scalar ``True``: when
        # neither filter column exists, ``data[True]`` would be treated as a
        # column lookup and raise KeyError.
        correct_target_mask = pd.Series(True, index=data.index, dtype=bool)
        columns_to_drop = []
        if "affected_entity" in data.columns:
            correct_target_mask &= data["affected_entity"] == self.target.name
            columns_to_drop.append("affected_entity")
        if "affected_measure" in data.columns:
            correct_target_mask &= data["affected_measure"] == self.target.measure
            columns_to_drop.append("affected_measure")
        return data[correct_target_mask].drop(columns=columns_to_drop)

    def process_categorical_data(
        self, builder: Builder, rr_data: str | float | pd.DataFrame
    ) -> tuple[str | float | pd.DataFrame, list[str]]:
        """Process relative risk data for categorical exposures.

        For scalar RR data with a dichotomous distribution, construct a
        DataFrame with exposed/unexposed categories. Pivot the data to wide
        format for use in a lookup table.

        Parameters
        ----------
        builder
            Access point for utilizing framework interfaces during setup.
        rr_data
            The relative risk data.

        Returns
        -------
            A tuple of the pivoted RR data and the list of value column names.

        Raises
        ------
        ValueError
            If scalar RR data is provided with a non-dichotomous distribution.
        """
        exposure_distribution = self.causal_factor_exposure_component.exposure_distribution
        is_dichotomous = isinstance(exposure_distribution, DichotomousDistribution)

        if not isinstance(rr_data, pd.DataFrame):
            if not is_dichotomous:
                raise ValueError(
                    f"Relative risk data for categorical exposure must be a DataFrame unless the "
                    f"exposure distribution is dichotomous. Found type {type(rr_data)} with "
                    f"exposure distribution type {exposure_distribution.distribution_type}."
                )
            # Expand the scalar over the demographic dimensions: the exposed
            # category carries the configured RR, the unexposed category 1.
            cat1 = builder.data.load("population.demographic_dimensions")
            cat1["parameter"] = exposure_distribution.exposed
            cat1["value"] = rr_data
            cat2 = cat1.copy()
            cat2["parameter"] = exposure_distribution.unexposed
            cat2["value"] = 1
            rr_data = pd.concat([cat1, cat2], ignore_index=True)

        # The code below reads ``parameter`` as a column, so move it out of
        # the index if that is where it lives.
        if "parameter" in rr_data.index.names:
            rr_data = rr_data.reset_index("parameter")

        if is_dichotomous:
            rr_data = exposure_distribution.rename_deprecated_categories(rr_data)

        rr_value_cols = list(rr_data["parameter"].unique())
        rr_data = pivot_categorical(rr_data, "parameter")
        return rr_data, rr_value_cols

    # todo currently this isn't being called. we need to properly set rrs if
    # the exposure has been rebinned
    def rebin_relative_risk_data(
        self, builder: Builder, relative_risk_data: pd.DataFrame
    ) -> pd.DataFrame:
        """Rebin relative risk data.

        When the polytomous risk is rebinned, matching relative risk needs to
        be rebinned. After rebinning, rr for both exposed and unexposed
        categories should be the weighted sum of relative risk of the
        component categories where weights are relative proportions of
        exposure of those categories.

        For example, if cat1, cat2, cat3 are exposed categories and cat4 is
        unexposed with exposure [0.1,0.2,0.3,0.4], for the matching
        rr = [rr1, rr2, rr3, 1], rebinned rr for the rebinned cat1 should be:
        (0.1 *rr1 + 0.2 * rr2 + 0.3* rr3) / (0.1+0.2+0.3)

        Parameters
        ----------
        builder
            Access point for utilizing framework interfaces during setup.
        relative_risk_data
            The relative risk data to rebin.

        Returns
        -------
            The rebinned data, or the input unchanged when the causal factor
            has no configuration block or no ``rebinned_exposed`` categories.
        """
        if self.causal_factor not in builder.configuration.to_dict():
            return relative_risk_data

        rebin_exposed_categories = set(
            builder.configuration[self.causal_factor]["rebinned_exposed"]
        )

        if rebin_exposed_categories:
            # todo make sure this works
            exposure_data = load_exposure_data(builder, self.causal_factor)
            relative_risk_data = self._rebin_relative_risk_data(
                relative_risk_data, exposure_data, rebin_exposed_categories
            )

        return relative_risk_data

    def _rebin_relative_risk_data(
        self,
        relative_risk_data: pd.DataFrame,
        exposure_data: pd.DataFrame,
        rebin_exposed_categories: set,
    ) -> pd.DataFrame:
        """Compute exposure-weighted relative risks for rebinned categories."""
        # Join on every exposure column except ``value``; after the merge the
        # RR values are in ``value_x`` and the exposure values in ``value_y``.
        cols = list(exposure_data.columns.difference(["value"]))

        relative_risk_data = relative_risk_data.merge(exposure_data, on=cols)
        relative_risk_data["value_x"] = relative_risk_data.value_x.multiply(
            relative_risk_data.value_y
        )
        # Collapse the original categories into the two rebinned ones.
        relative_risk_data.parameter = relative_risk_data["parameter"].map(
            lambda p: "cat1" if p in rebin_exposed_categories else "cat2"
        )
        relative_risk_data = relative_risk_data.groupby(cols).sum().reset_index()
        # Weighted mean = sum(rr * exposure) / sum(exposure); a NaN result
        # (e.g. 0 / 0) is replaced with 0.
        relative_risk_data["value"] = relative_risk_data.value_x.divide(
            relative_risk_data.value_y
        ).fillna(0)
        return relative_risk_data.drop(columns=["value_x", "value_y"])

    def get_relative_risk_source(self, builder: Builder) -> Callable[[pd.Index], pd.Series]:
        """Build a callable that computes relative risk from exposure.

        For continuous exposures, use TMRED-based log-linear scaling. For
        categorical exposures, look up the RR for each simulant's exposure
        category.

        Parameters
        ----------
        builder
            Access point for utilizing framework interfaces during setup.

        Returns
        -------
            A callable that accepts a simulant index and returns relative
            risk values.
        """
        if not self.is_exposure_categorical:
            tmred = builder.data.load(f"{self.causal_factor}.tmred")
            # The reference exposure level is the midpoint of the TMRED range.
            tmrel = 0.5 * (tmred["min"] + tmred["max"])
            scale = builder.data.load(f"{self.causal_factor}.relative_risk_scalar")

            def generate_relative_risk(index: pd.Index) -> pd.Series:
                rr = self.relative_risk_table(index)
                exposure = self.population_view.get(index, self.exposure_name)
                # Floor at 1 so the computed relative risk is never below 1.
                relative_risk = np.maximum(rr.values ** ((exposure - tmrel) / scale), 1)
                return relative_risk

        else:
            index_columns = ["index", self.causal_factor.name]

            def generate_relative_risk(index: pd.Index) -> pd.Series:
                rr = self.relative_risk_table(index)
                # Key both frames by (simulant, category) so each simulant's
                # RR can be selected by its current exposure category.
                exposure = self.population_view.get(index, self.exposure_name).reset_index()
                exposure.columns = index_columns
                exposure = exposure.set_index(index_columns)

                relative_risk = rr.stack().reset_index()
                relative_risk.columns = index_columns + ["value"]
                relative_risk = relative_risk.set_index(index_columns)

                effect = relative_risk.loc[exposure.index, "value"].droplevel(
                    self.causal_factor.name
                )
                return effect

        return generate_relative_risk

    def register_relative_risk_pipeline(self, builder: Builder) -> None:
        """Register the relative risk pipeline with the simulation.

        Parameters
        ----------
        builder
            Access point for utilizing framework interfaces during setup.
        """
        builder.value.register_attribute_producer(
            self.relative_risk_name,
            self._relative_risk_source,
            required_resources=[self.exposure_name],
        )

    def register_target_modifier(self, builder: Builder) -> None:
        """Register the relative risk as a modifier on the target pipeline.

        Parameters
        ----------
        builder
            Access point for utilizing framework interfaces during setup.
        """
        builder.value.register_attribute_modifier(
            self.target_name, modifier=self.relative_risk_name
        )

    def register_calibration_constant_modifier(self, builder: Builder) -> None:
        """Register the PAF data as a modifier on the calibration constant pipeline.

        Parameters
        ----------
        builder
            Access point for utilizing framework interfaces during setup.
        """
        builder.value.register_value_modifier(
            get_calibration_constant_pipeline_name(self.target_name),
            modifier=lambda: self.paf_data,
        )

    ##################
    # Helper methods #
    ##################

    def _get_causal_factor_exposure_component(self, builder: Builder) -> CausalFactor:
        """Retrieve the causal factor exposure component and validate that it
        is an instance of ``EXPOSURE_CLASS``.

        Raises
        ------
        ValueError
            If the component registered under ``self.causal_factor`` is not an
            instance of ``EXPOSURE_CLASS``.
        """
        causal_factor_exposure_component = builder.components.get_component(
            self.causal_factor
        )
        if not isinstance(causal_factor_exposure_component, self.EXPOSURE_CLASS):
            raise ValueError(
                f"{self.__class__.__name__} model {self.name} requires a {self.EXPOSURE_CLASS.__name__} component named {self.causal_factor}"
            )
        return causal_factor_exposure_component