Source code for vivarium_public_health.plugins.parser

"""
===============================
Component Configuration Parsers
===============================

Component Configuration Parsers in this module are specialized implementations of
:class:`ComponentConfigurationParser <vivarium.framework.components.parser.ComponentConfigurationParser>`
that can parse configurations of components specific to the Vivarium Public
Health package.

"""
import warnings
from collections.abc import Callable
from importlib import import_module
from importlib.resources import files
from typing import Any

import pandas as pd
from layered_config_tree import LayeredConfigTree
from vivarium import Component
from vivarium.framework.components import ComponentConfigurationParser
from vivarium.framework.components.parser import ParsingError
from vivarium.framework.engine import Builder
from vivarium.framework.state_machine import Trigger
from vivarium.framework.utilities import import_by_path

from vivarium_public_health.disease import (
    BaseDiseaseState,
    DiseaseModel,
    DiseaseState,
    RecoveredState,
    SusceptibleState,
    TransientDiseaseState,
)
from vivarium_public_health.utilities import TargetString


class CausesParsingErrors(ParsingError):
    """Aggregate error for one or more problems found in a cause model configuration.

    The individual messages are rendered as a bulleted list, one message per
    line, each prefixed with ``" - "``.

    """

    def __init__(self, messages: list[str]):
        bullet = "\n - "
        # The leading bullet is emitted unconditionally, so an empty list of
        # messages still produces a single (empty) bullet line.
        formatted_messages = bullet + bullet.join(messages)
        super().__init__(formatted_messages)
class CausesConfigurationParser(ComponentConfigurationParser):
    """Parser for cause model configurations.

    Component configuration parser that acts the same as the standard vivarium
    `ComponentConfigurationParser` but adds the additional ability to parse a
    configuration to create `DiseaseModel` components. These DiseaseModel
    configurations can either be specified directly in the configuration in a
    `causes` key or in external configuration files that are specified in the
    `external_configuration` key.

    """

    DEFAULT_MODEL_CONFIG = {
        "model_type": f"{DiseaseModel.__module__}.{DiseaseModel.__name__}",
        "residual_state": None,
    }
    """Default cause model configuration if it's not explicitly specified.

    Initial state and residual state cannot both be provided. If neither initial
    state nor residual state has been specified, the cause model must have a
    state named 'susceptible'.
    """

    DEFAULT_STATE_CONFIG = {
        "cause_type": "cause",
        "transient": False,
        "allow_self_transition": True,
        "side_effect": None,
        "cleanup_function": None,
        "state_type": None,
    }
    """Default state configuration if it's not explicitly specified."""

    DEFAULT_TRANSITION_CONFIG = {"triggered": "NOT_TRIGGERED"}
    """Default triggered value.

    This value is used if the transition configuration does not explicitly
    specify it.
    """

    def parse_component_config(self, component_config: LayeredConfigTree) -> list[Component]:
        """Parses the component configuration and returns a list of components.

        In particular, this method looks for an `external_configuration` key
        and/or a `causes` key.

        The `external_configuration` key should have names of packages that
        contain cause model configuration files. Within that key should be a list
        of paths to cause model configuration files relative to the package.

        .. code-block:: yaml

            external_configuration:
                some_package:
                    - some/path/cause_model_1.yaml
                    - some/path/cause_model_2.yaml

        The `causes` key should contain configuration information for cause
        models.

        .. code-block:: yaml

            causes:
                cause_1:
                    model_type: vivarium_public_health.disease.DiseaseModel
                    residual_state: susceptible
                    states:
                        susceptible:
                            cause_type: cause
                            data_sources: {}
                        infected:
                            cause_type: cause
                            transient: false
                            allow_self_transition: true
                            data_sources: {}
                    transitions:
                        transition_1:
                            source: susceptible
                            sink: infected
                            transition_type: rate
                            data_sources: {}

        # todo add information about the data_sources configuration

        Note that this method modifies the simulation's component configuration
        by adding the contents of external configuration files to the
        `model_override` layer and adding default cause model configuration
        values for all cause models to the `component_config` layer.

        Parameters
        ----------
        component_config
            A LayeredConfigTree defining the components to initialize.

        Returns
        -------
            A list of initialized components.

        Raises
        ------
        CausesParsingErrors
            If the cause model configuration is invalid
        """
        components = []

        if "external_configuration" in component_config:
            self._validate_external_configuration(component_config["external_configuration"])
            for package, config_files in component_config["external_configuration"].items():
                for config_file in config_files.get_value():
                    source = f"{package}::{config_file}"
                    config_path = str(files(package).joinpath(config_file))
                    external_config = LayeredConfigTree(config_path)
                    # Merge the external file into the live configuration so any
                    # `causes` it defines are picked up by the step below.
                    component_config.update(
                        external_config, layer="model_override", source=source
                    )

        if "causes" in component_config:
            causes_config = component_config["causes"]
            self._validate_causes_config(causes_config)
            self._add_default_config_layer(causes_config)
            components += self._get_cause_model_components(causes_config)

        # Parse standard components (i.e. not cause models)
        standard_component_config = component_config.to_dict()
        standard_component_config.pop("external_configuration", None)
        standard_component_config.pop("causes", None)
        standard_components = (
            self.process_level(standard_component_config, [])
            if standard_component_config
            else []
        )

        return components + standard_components

    #########################
    # Configuration methods #
    #########################

    def _add_default_config_layer(self, causes_config: LayeredConfigTree) -> None:
        """Adds a default layer to the provided configuration.

        This default layer specifies values for the cause model configuration.
        Every state and transition receives its own copy of the class-level
        defaults so that no two configuration entries share a mutable dict.

        Parameters
        ----------
        causes_config
            A LayeredConfigTree defining the cause model configurations
        """
        default_config = {}
        for cause_name, cause_config in causes_config.items():
            default_states_config = {
                state_name: dict(self.DEFAULT_STATE_CONFIG)
                for state_name in cause_config.states.keys()
            }
            # `transitions` is optional in a cause configuration (the validator
            # accepts its absence), so only read it when it is present.
            default_transitions_config = (
                {
                    transition_name: dict(self.DEFAULT_TRANSITION_CONFIG)
                    for transition_name in cause_config.transitions.keys()
                }
                if "transitions" in cause_config
                else {}
            )
            default_config[cause_name] = {
                **self.DEFAULT_MODEL_CONFIG,
                "states": default_states_config,
                "transitions": default_transitions_config,
            }

        causes_config.update(
            default_config, layer="component_configs", source="causes_configuration_parser"
        )

    ################################
    # Cause model creation methods #
    ################################

    def _get_cause_model_components(
        self, causes_config: LayeredConfigTree
    ) -> list[Component]:
        """Parses the cause model configuration and returns the `DiseaseModel` components.

        Parameters
        ----------
        causes_config
            A LayeredConfigTree defining the cause model components to initialize

        Returns
        -------
            A list of initialized `DiseaseModel` components
        """
        cause_models = []

        for cause_name, cause_config in causes_config.items():
            data_sources = self._get_data_sources(cause_config)

            states: dict[str, BaseDiseaseState] = {
                state_name: self._get_state(state_name, state_config, cause_name)
                for state_name, state_config in cause_config.states.items()
            }

            # Transitions are optional; skip the wiring step when none are configured.
            if "transitions" in cause_config:
                for transition_config in cause_config.transitions.values():
                    self._add_transition(
                        states[transition_config.source],
                        states[transition_config.sink],
                        transition_config,
                    )

            model_type = import_by_path(cause_config.model_type)
            # `residual_state` may be None, in which case no state is looked up.
            residual_state = states.get(cause_config.residual_state, None)
            model = model_type(
                cause_name,
                states=list(states.values()),
                residual_state=residual_state,
                **data_sources,
            )
            cause_models.append(model)

        return cause_models

    def _get_state(
        self, state_name: str, state_config: LayeredConfigTree, cause_name: str
    ) -> BaseDiseaseState:
        """Parses a state configuration and returns an initialized `BaseDiseaseState` object.

        Parameters
        ----------
        state_name
            The name of the state to initialize
        state_config
            A LayeredConfigTree defining the state to initialize
        cause_name
            The name of the cause to which the state belongs

        Returns
        -------
            An initialized `BaseDiseaseState` object
        """
        # Susceptible and recovered states are identified by the cause name;
        # all other states are identified by their own name.
        state_id = cause_name if state_name in ["susceptible", "recovered"] else state_name
        state_kwargs = {
            "cause_type": state_config.cause_type,
            "allow_self_transition": state_config.allow_self_transition,
        }
        if state_config.side_effect:
            # todo handle side effects properly
            state_kwargs["side_effect"] = lambda *x: x
        if state_config.cleanup_function:
            # todo handle cleanup functions properly
            state_kwargs["cleanup_function"] = lambda *x: x

        state_kwargs = {**state_kwargs, **self._get_data_sources(state_config)}

        # An explicit state_type wins, then the transient flag, then the
        # reserved state names, and finally the generic disease state.
        if state_config.state_type is not None:
            state_type = import_by_path(state_config.state_type)
        elif state_config.transient:
            state_type = TransientDiseaseState
        elif state_name == "susceptible":
            state_type = SusceptibleState
        elif state_name == "recovered":
            state_type = RecoveredState
        else:
            state_type = DiseaseState

        state = state_type(state_id, **state_kwargs)
        return state

    def _add_transition(
        self,
        source_state: BaseDiseaseState,
        sink_state: BaseDiseaseState,
        transition_config: LayeredConfigTree,
    ) -> None:
        """Adds a transition between two states.

        Parameters
        ----------
        source_state
            The state the transition starts from
        sink_state
            The state the transition ends at
        transition_config
            A `LayeredConfigTree` defining the transition to add

        Raises
        ------
        ValueError
            If the configured transition type is not one of 'rate',
            'proportion', or 'dwell_time'.
        """
        triggered = Trigger[transition_config.triggered]
        data_sources = self._get_data_sources(transition_config)
        transition_type = transition_config["transition_type"]

        if transition_type == "rate":
            source_state.add_rate_transition(sink_state, **data_sources, triggered=triggered)
        elif transition_type == "proportion":
            source_state.add_proportion_transition(
                sink_state, **data_sources, triggered=triggered
            )
        elif transition_type == "dwell_time":
            # Dwell time is configured on the state, so no data sources are passed.
            source_state.add_dwell_time_transition(sink_state, triggered=triggered)
        else:
            # Report the value that was actually inspected above; the
            # configuration has no `type` key, only `transition_type`.
            raise ValueError(
                f"Invalid transition data type '{transition_type}'"
                f" provided for transition '{transition_config}'."
            )

    def _get_data_sources(
        self, config: LayeredConfigTree
    ) -> dict[str, str | float | pd.Timedelta | Callable[[Builder], Any]]:
        """Parses a configuration and returns the data sources.

        Parameters
        ----------
        config
            A LayeredConfigTree defining the data sources to initialize

        Returns
        -------
            A dictionary of data source getters. Empty if the configuration has
            no `data_sources` key.
        """
        if "data_sources" not in config:
            return {}
        return {
            name: self._get_data_source(name, config.data_sources[name])
            for name in config.data_sources.keys()
        }

    @staticmethod
    def _get_data_source(
        name: str,
        source: str | float,
    ) -> str | float | pd.Timedelta | Callable[[Builder], Any]:
        """Parses a data source and returns a callable that can be used to retrieve the data.

        String sources are interpreted, in order, as a `pandas.Timedelta`, as a
        ``module::attribute`` reference, and finally as a `TargetString`.
        Non-string sources are returned unchanged.

        Parameters
        ----------
        name
            The name of the data source
        source
            The data source to parse

        Returns
        -------
            A callable that can be used to retrieve the data

        Raises
        ------
        ValueError
            If a string source cannot be interpreted in any of the supported forms.
        """
        if isinstance(source, str):
            try:
                return pd.Timedelta(source)
            except ValueError:
                # Not a timedelta string; fall through to the other formats.
                pass

            if "::" in source:
                module, method = source.split("::")
                return getattr(import_module(module), method)

            try:
                source = TargetString(source)
            except ValueError as err:
                raise ValueError(f"Invalid data source '{source}' for '{name}'.") from err

        return source

    ######################
    # Validation methods #
    ######################

    _CAUSE_KEYS = {
        "model_type",
        "states",
        "transitions",
        "data_sources",
        "residual_state",
    }
    _STATE_KEYS = {
        "state_type",
        "cause_type",
        "transient",
        "allow_self_transition",
        "side_effect",
        "data_sources",
        "cleanup_function",
    }

    _DATA_SOURCE_KEYS = {
        "cause": {"cause_specific_mortality_rate"},
        "state": {
            "prevalence",
            "birth_prevalence",
            "dwell_time",
            "disability_weight",
            "excess_mortality_rate",
        },
        "rate_transition": {
            "incidence_rate",
            "transition_rate",
            "remission_rate",
        },
        "proportion_transition": {"proportion"},
    }
    _TRANSITION_KEYS = {"source", "sink", "transition_type", "triggered", "data_sources"}
    _TRANSITION_TYPE_KEYS = {"rate", "proportion", "dwell_time"}

    @staticmethod
    def _validate_external_configuration(external_configuration: LayeredConfigTree) -> None:
        """Validates the external configuration.

        Parameters
        ----------
        external_configuration
            A LayeredConfigTree defining the external configuration

        Raises
        ------
        CausesParsingErrors
            If the external configuration is invalid
        """
        external_configuration = external_configuration.to_dict()
        error_messages = []
        for package, config_files in external_configuration.items():
            if not isinstance(package, str):
                error_messages.append(
                    f"Key '{package}' must be a string definition of a package."
                )
            if not isinstance(config_files, list):
                error_messages.append(
                    f"External configuration for package '{package}' must be a list."
                )
            else:
                for config_file in config_files:
                    if not isinstance(config_file, str) or not config_file.endswith(".yaml"):
                        error_messages.append(
                            f"External configuration for package '{package}' must "
                            "be a list of paths to yaml files."
                        )
        if error_messages:
            raise CausesParsingErrors(error_messages)

    def _validate_causes_config(self, causes_config: LayeredConfigTree) -> None:
        """Validates the cause model configuration.

        Parameters
        ----------
        causes_config
            A LayeredConfigTree defining the cause model configurations

        Raises
        ------
        CausesParsingErrors
            If the cause model configuration is invalid
        """
        causes_config = causes_config.to_dict()
        error_messages = []
        for cause_name, cause_config in causes_config.items():
            error_messages += self._validate_cause(cause_name, cause_config)

        if error_messages:
            raise CausesParsingErrors(error_messages)

    def _validate_cause(self, cause_name: str, cause_config: dict[str, Any]) -> list[str]:
        """Validates a cause configuration and returns a list of error messages.

        Parameters
        ----------
        cause_name
            The name of the cause to validate
        cause_config
            A dictionary defining the cause to validate

        Returns
        -------
            A list of error messages
        """
        error_messages = []
        if not isinstance(cause_config, dict):
            error_messages.append(
                f"Cause configuration for cause '{cause_name}' must be a dictionary."
            )
            return error_messages

        if not set(cause_config.keys()).issubset(self._CAUSE_KEYS):
            error_messages.append(
                f"Cause configuration for cause '{cause_name}' may only"
                " contain the following keys: "
                f"{self._CAUSE_KEYS}."
            )

        if "model_type" in cause_config:
            error_messages += self._validate_imported_type(
                cause_config["model_type"], cause_name, "model"
            )

        states_config = cause_config.get("states", {})
        if not states_config:
            error_messages.append(
                f"Cause configuration for cause '{cause_name}' must define "
                "at least one state."
            )

        if not isinstance(states_config, dict):
            error_messages.append(
                f"States configuration for cause '{cause_name}' must be a dictionary."
            )
            # Transition validation performs membership tests against the
            # states; give it an empty mapping so a malformed `states` value
            # yields collected error messages rather than a TypeError.
            known_states: dict[str, Any] = {}
        else:
            known_states = states_config
            residual_state = cause_config.get("residual_state", None)
            if residual_state is not None and residual_state not in states_config:
                error_messages.append(
                    f"Residual state '{residual_state}' for cause '{cause_name}'"
                    f" must be present in the states for cause '{cause_name}."
                )
            for state_name, state_config in states_config.items():
                error_messages += self._validate_state(cause_name, state_name, state_config)

        transitions_config = cause_config.get("transitions", {})
        if not isinstance(transitions_config, dict):
            error_messages.append(
                f"Transitions configuration for cause '{cause_name}' must be "
                "a dictionary if it is present."
            )
        else:
            for transition_name, transition_config in transitions_config.items():
                error_messages += self._validate_transition(
                    cause_name, transition_name, transition_config, known_states
                )

        error_messages += self._validate_data_sources(
            cause_config, cause_name, "cause", cause_name
        )

        return error_messages

    def _validate_state(
        self, cause_name: str, state_name: str, state_config: dict[str, Any]
    ) -> list[str]:
        """Validates a state configuration and returns a list of error messages.

        Parameters
        ----------
        cause_name
            The name of the cause to which the state belongs
        state_name
            The name of the state to validate
        state_config
            A dictionary defining the state to validate

        Returns
        -------
            A list of error messages
        """
        error_messages = []

        if not isinstance(state_config, dict):
            error_messages.append(
                f"State configuration for in cause '{cause_name}' and "
                f"state '{state_name}' must be a dictionary."
            )
            return error_messages

        # Work on a copy so the class-level key set is never mutated.
        allowable_keys = set(self._STATE_KEYS)
        if state_name in ["susceptible", "recovered"]:
            allowable_keys.remove("data_sources")

        if not set(state_config.keys()).issubset(allowable_keys):
            error_messages.append(
                f"State configuration for in cause '{cause_name}' and "
                f"state '{state_name}' may only contain the following "
                f"keys: {allowable_keys}."
            )

        state_type = state_config.get("state_type", "")
        error_messages += self._validate_imported_type(
            state_type, cause_name, "state", state_name
        )

        if not isinstance(state_config.get("cause_type", ""), str):
            error_messages.append(
                f"Cause type for state '{state_name}' in cause '{cause_name}' "
                f"must be a string. Provided {state_config['cause_type']}."
            )

        is_transient = state_config.get("transient", False)
        if not isinstance(is_transient, bool):
            error_messages.append(
                f"Transient flag for state '{state_name}' in cause '{cause_name}' "
                f"must be a boolean. Provided {state_config['transient']}."
            )

        if state_name in ["susceptible", "recovered"] and state_type:
            error_messages.append(
                f"The name '{state_name}' in cause '{cause_name}' concretely "
                f"specifies the state type, so state_type is not an allowed "
                "configuration."
            )

        if state_name in ["susceptible", "recovered"] and is_transient:
            error_messages.append(
                f"The name '{state_name}' in cause '{cause_name}' concretely "
                f"specifies the state type, so transient is not an allowed "
                "configuration."
            )

        if is_transient and state_type:
            error_messages.append(
                f"Specifying transient as True for state '{state_name}' in cause "
                f"'{cause_name}' concretely specifies the state type, so "
                "state_type is not an allowed configuration."
            )

        if not isinstance(state_config.get("allow_self_transition", True), bool):
            error_messages.append(
                f"Allow self transition flag for state '{state_name}' in cause "
                f"'{cause_name}' must be a boolean. Provided "
                f"'{state_config['allow_self_transition']}'."
            )

        error_messages += self._validate_data_sources(
            state_config, cause_name, "state", state_name
        )

        return error_messages

    def _validate_transition(
        self,
        cause_name: str,
        transition_name: str,
        transition_config: dict[str, Any],
        states_config: dict[str, Any],
    ) -> list[str]:
        """Validates a transition configuration and returns a list of error messages.

        Parameters
        ----------
        cause_name
            The name of the cause to which the transition belongs
        transition_name
            The name of the transition to validate
        transition_config
            A dictionary defining the transition to validate
        states_config
            A dictionary defining the states for the cause

        Returns
        -------
            A list of error messages
        """
        error_messages = []

        if not isinstance(transition_config, dict):
            error_messages.append(
                f"Transition configuration for in cause '{cause_name}' and "
                f"transition '{transition_name}' must be a dictionary."
            )
            return error_messages

        if not set(transition_config.keys()).issubset(self._TRANSITION_KEYS):
            error_messages.append(
                f"Transition configuration for in cause '{cause_name}' and "
                f"transition '{transition_name}' may only contain the "
                f"following keys: {self._TRANSITION_KEYS}."
            )

        source = transition_config.get("source", None)
        sink = transition_config.get("sink", None)
        if sink is None or source is None:
            error_messages.append(
                f"Transition configuration for in cause '{cause_name}' and "
                f"transition '{transition_name}' must contain both a source "
                f"and a sink."
            )

        if source is not None and source not in states_config:
            error_messages.append(
                f"Transition configuration for in cause '{cause_name}' and "
                f"transition '{transition_name}' must contain a source that "
                f"is present in the states."
            )

        if sink is not None and sink not in states_config:
            error_messages.append(
                f"Transition configuration for in cause '{cause_name}' and "
                f"transition '{transition_name}' must contain a sink that "
                f"is present in the states."
            )

        if (
            "triggered" in transition_config
            and transition_config["triggered"] not in Trigger.__members__
        ):
            error_messages.append(
                f"Transition configuration for in cause '{cause_name}' and "
                f"transition '{transition_name}' may only have one of the following "
                f"values: {Trigger.__members__}."
            )

        if "transition_type" not in transition_config:
            error_messages.append(
                f"Transition configuration for in cause '{cause_name}' and "
                f"transition '{transition_name}' must contain a transition type."
            )
        else:
            transition_type = transition_config["transition_type"]
            if transition_type not in self._TRANSITION_TYPE_KEYS:
                error_messages.append(
                    f"Transition configuration for in cause '{cause_name}' and "
                    f"transition '{transition_name}' may only contain the "
                    f"following values: {self._TRANSITION_TYPE_KEYS}."
                )
            if transition_type == "dwell_time" and "data_sources" in transition_config:
                error_messages.append(
                    f"Transition configuration for in cause '{cause_name}' and "
                    f"transition '{transition_name}' is a dwell-time transition and "
                    f"may not have data sources as dwell-time is configured on the state."
                )
            elif transition_type in self._TRANSITION_TYPE_KEYS.difference({"dwell_time"}):
                # Only rate and proportion transitions have data sources to check.
                error_messages += self._validate_data_sources(
                    transition_config,
                    cause_name,
                    f"{transition_type}_transition",
                    transition_name,
                )
        return error_messages

    @staticmethod
    def _validate_imported_type(
        import_path: str, cause_name: str, entity_type: str, entity_name: str | None = None
    ) -> list[str]:
        """Validates an imported type and returns a list of error messages.

        Parameters
        ----------
        import_path
            The import path to validate
        cause_name
            The name of the cause to which the imported type belongs
        entity_type
            The type of the entity to which the imported type belongs
        entity_name
            The name of the entity to which the imported type belongs, if it is
            not a cause

        Returns
        -------
            A list of error messages
        """
        expected_type = {"model": DiseaseModel, "state": BaseDiseaseState}[entity_type]

        error_messages = []
        # An empty or missing import path means "use the default"; nothing to check.
        if not import_path:
            return error_messages

        try:
            imported_type = import_by_path(import_path)
            if not (
                isinstance(imported_type, type) and issubclass(imported_type, expected_type)
            ):
                # Funnel "imported but wrong kind of object" into the same
                # handler as the import failures below.
                raise TypeError
        except (ModuleNotFoundError, AttributeError, TypeError, ValueError):
            error_messages.append(
                f"If '{entity_type}_type' is provided for cause '{cause_name}' "
                f"{f'and {entity_type} {entity_name} ' if entity_name else ''}it "
                f"must be the fully qualified import path to a {expected_type} "
                f"implementation. Provided'{import_path}'."
            )
        return error_messages

    def _validate_data_sources(
        self, config: dict[str, Any], cause_name: str, config_type: str, config_name: str
    ) -> list[str]:
        """Validates the data sources in a configuration and returns any error messages.

        Parameters
        ----------
        config
            A dictionary defining the configuration to validate
        cause_name
            The name of the cause to which the configuration belongs
        config_type
            The type of the configuration to validate; must be a key of
            `_DATA_SOURCE_KEYS`
        config_name
            The name of the configuration being validated

        Returns
        -------
            A list of error messages
        """
        error_messages = []
        data_sources_config = config.get("data_sources", {})
        if not isinstance(data_sources_config, dict):
            # Identify the offending entry by name rather than dumping the
            # whole configuration dictionary into the message.
            error_messages.append(
                f"Data sources configuration for {config_type} '{config_name}' in "
                f"cause '{cause_name}' must be a dictionary if it is present."
            )
            return error_messages

        if not set(data_sources_config.keys()).issubset(self._DATA_SOURCE_KEYS[config_type]):
            error_messages.append(
                f"Data sources configuration for {config_type} '{config_name}' "
                f"in cause '{cause_name}' may only contain the following keys: "
                f"{self._DATA_SOURCE_KEYS[config_type]}."
            )

        # Use a distinct loop variable so `config_name` keeps referring to the
        # state/transition/cause being validated in the message below.
        for source_name, source in data_sources_config.items():
            try:
                self._get_data_source(source_name, source)
            except ValueError:
                error_messages.append(
                    f"Configuration for {config_type} '{config_name}' in cause "
                    f"'{cause_name}' has an invalid data source at '{source}'."
                )

        return error_messages