"""
====================
The Artifact Manager
====================
This module contains the :class:`ArtifactManager`, a ``vivarium`` plugin
for handling complex data bound up in a data artifact.
"""
import re
from pathlib import Path
from typing import Any, Sequence, Union
import pandas as pd
from vivarium.config_tree import ConfigTree
from vivarium.framework.artifact.artifact import Artifact
from vivarium.manager import Manager
_Filter = Union[str, int, Sequence[int], Sequence[str]]
class ArtifactManager(Manager):
    """The controller plugin component for managing a data artifact."""

    CONFIGURATION_DEFAULTS = {
        "input_data": {
            "artifact_path": None,
            "artifact_filter_term": None,
            "input_draw_number": None,
        }
    }

    @property
    def name(self):
        return "artifact_manager"

    def setup(self, builder):
        """Performs this component's simulation setup."""
        self.logger = builder.logging.get_logger(self.name)
        # Configuration filter terms may reference columns that the artifact's
        # own filter mechanism cannot see, so they are held separately here and
        # applied when data is loaded.
        self.config_filter_term = validate_filter_term(
            builder.configuration.input_data.artifact_filter_term
        )
        self.artifact = self._load_artifact(builder.configuration)
        builder.lifecycle.add_constraint(self.load, allow_during=["setup"])

    def _load_artifact(self, configuration: ConfigTree) -> Union[Artifact, None]:
        """Constructs the data artifact described by the configuration.

        Resolves the path to the artifact hdf file, assembles the default
        filter terms, and builds the artifact. Filter terms specified in the
        configuration are stored separately to be applied on loading, because
        not all columns are available via artifact filter terms.

        Parameters
        ----------
        configuration :
            Configuration block of the model specification containing the input data parameters.

        Returns
        -------
            An interface to the data artifact, or None when no artifact path
            is configured.
        """
        if not configuration.input_data.artifact_path:
            return None
        artifact_path = parse_artifact_path_config(configuration)
        base_filter_terms = get_base_filter_terms(configuration)
        self.logger.info(f"Running simulation from artifact located at {artifact_path}.")
        self.logger.info(f"Artifact base filter terms are {base_filter_terms}.")
        self.logger.info(f"Artifact additional filter terms are {self.config_filter_term}.")
        return Artifact(artifact_path, base_filter_terms)

    def load(self, entity_key: str, **column_filters: _Filter) -> Any:
        """Loads data associated with the given entity key.

        Parameters
        ----------
        entity_key
            The key associated with the expected data.
        column_filters
            Filters that subset the data by a categorical column and then remove
            the column from the raw data. They are supplied as keyword arguments
            to the load method in the form "column=value".

        Returns
        -------
        Any
            The data associated with the given key, filtered down to the
            requested subset if the data is a dataframe.
        """
        data = self.artifact.load(entity_key)
        # Non-tabular payloads (e.g. metadata dicts) pass through untouched.
        if not isinstance(data, pd.DataFrame):
            return data
        data = data.reset_index()
        # Normalize the first draw-like column to the conventional "value" name.
        draw_columns = [column for column in data if "draw" in column]
        if draw_columns:
            data = data.rename(columns={draw_columns[0]: "value"})
        return filter_data(data, self.config_filter_term, **column_filters)

    def __repr__(self):
        return "ArtifactManager()"
class ArtifactInterface:
    """The builder interface for accessing a data artifact."""

    def __init__(self, manager: ArtifactManager):
        self._manager = manager

    # Fixed: the annotation was ``Union[_Filter]`` — a Union of a single member
    # is a no-op, and ``_Filter`` is already a Union.
    def load(self, entity_key: str, **column_filters: _Filter) -> pd.DataFrame:
        """Loads data associated with a formatted entity key.

        The provided entity key must be of the form
        {entity_type}.{measure} or {entity_type}.{entity_name}.{measure}.

        Here entity_type denotes the kind of entity being described. Examples
        include cause, risk, population, and covariates.

        The entity_name is the name of the specific entity. For example,
        if we had entity_type as cause, we might have entity_name as
        diarrheal_diseases or ischemic_heart_disease.

        Finally, measure is the name of the quantity the data describes.
        Examples of measures are incidence, disability_weight, relative_risk,
        and cost.

        Parameters
        ----------
        entity_key
            The key associated with the expected data.
        column_filters
            Filters that subset the data by a categorical column and then
            remove the column from the raw data. They are supplied as keyword
            arguments to the load method in the form "column=value".

        Returns
        -------
        pandas.DataFrame
            The data associated with the given key filtered down to the requested subset.
        """
        return self._manager.load(entity_key, **column_filters)

    def __repr__(self):
        # NOTE(review): name disagrees with the class name (ArtifactInterface);
        # presumably kept for backward compatibility — confirm before changing.
        return "ArtifactManagerInterface()"
def filter_data(
    data: pd.DataFrame, config_filter_term: Union[str, None] = None, **column_filters: _Filter
) -> pd.DataFrame:
    """Uses the provided column filters and age_group conditions to subset the raw data.

    Parameters
    ----------
    data
        The raw data to subset.
    config_filter_term
        A single filter term from the configuration, applied as a query only
        when its column is present in the data. Fixed: the annotation was the
        implicit-Optional ``str = None``, which PEP 484 disallows.
    column_filters
        Filters that subset rows by a categorical column; the filtered columns
        are then dropped from the result.

    Returns
    -------
    pandas.DataFrame
        The filtered data.
    """
    # Apply the configuration query first, then row filters, then drop the
    # filter columns themselves.
    data = _config_filter(data, config_filter_term)
    data = _subset_rows(data, **column_filters)
    data = _subset_columns(data, **column_filters)
    return data
def _config_filter(data, config_filter_term):
if config_filter_term:
filter_column = re.split("[<=>]", config_filter_term.split()[0])[0]
if filter_column in data.columns:
data = data.query(config_filter_term)
return data
def validate_filter_term(config_filter_term):
    """Rejects filter terms that contain multiple clauses; passes single-clause
    (or absent) terms through unchanged."""
    compound_markers = (" and ", " or ", "|", "&")
    if config_filter_term is not None:
        for marker in compound_markers:
            if marker in config_filter_term:
                raise NotImplementedError(
                    "Only a single filter term via the configuration is currently supported."
                )
    return config_filter_term
def _subset_rows(data: pd.DataFrame, **column_filters: _Filter) -> pd.DataFrame:
"""Filters out unwanted rows from the data using the provided filters."""
extra_filters = set(column_filters.keys()) - set(data.columns)
if extra_filters:
raise ValueError(
f"Filtering by non-existent columns: {extra_filters}. "
f"Available columns: {data.columns}"
)
for column, condition in column_filters.items():
if column in data.columns:
if not isinstance(condition, (list, tuple)):
condition = [condition]
mask = pd.Series(False, index=data.index)
for c in condition:
mask |= data[f"{column}"] == c
row_indexer = data[mask].index
data = data.loc[row_indexer, :]
return data
def _subset_columns(data: pd.DataFrame, **column_filters) -> pd.DataFrame:
"""Filters out unwanted columns and default columns from the data using provided filters."""
columns_to_remove = set(list(column_filters.keys()) + ["draw"])
columns_to_remove = columns_to_remove.intersection(data.columns)
return data.drop(columns=columns_to_remove)
def get_base_filter_terms(configuration: ConfigTree):
    """Parses default filter terms from the artifact configuration."""
    # The input draw is currently the only configuration value encoded as a
    # base filter term.
    draw = configuration.input_data.input_draw_number
    if draw is None:
        return []
    return [f"draw == {draw}"]
def parse_artifact_path_config(config: ConfigTree) -> str:
    """Gets the path to the data artifact from the simulation configuration.

    The path specified in the configuration may be absolute or it may be relative
    to the location of the configuration file.

    Parameters
    ----------
    config
        The configuration block of the simulation model specification containing the artifact path.

    Returns
    -------
    str
        The path to the data artifact.

    Raises
    ------
    ValueError
        If a relative path is given but the configuration carries no source
        file to resolve it against.
    FileNotFoundError
        If no file exists at the resolved path.
    """
    artifact_path = Path(config.input_data.artifact_path)
    if artifact_path.is_absolute():
        resolved = artifact_path
    else:
        # Relative paths are resolved against the configuration file that
        # supplied the artifact_path value.
        source_entry = config.input_data.metadata("artifact_path")[-1]
        source_file = source_entry["source"]
        if source_file is None:
            raise ValueError("Insufficient information provided to find artifact.")
        resolved = Path(source_file).parent.joinpath(artifact_path).resolve()
    if not resolved.exists():
        raise FileNotFoundError(f"Cannot find artifact at path {resolved}")
    return str(resolved)