Source code for vivarium.framework.artifact.artifact

"""
=================
The Data Artifact
=================

This module provides tools for interacting with data artifacts.

A data artifact is an archive on disk intended to package up all data
relevant to a particular simulation. This module provides a class to wrap that
archive file for convenient access and inspection.

"""
import re
import warnings
from collections import defaultdict
from collections.abc import Iterator
from pathlib import Path
from typing import Any

from vivarium.framework.artifact import hdf


class ArtifactException(Exception):
    """Exception raised for inconsistent use of the data artifact."""
[docs] class Artifact: """An interface for interacting with :mod:`vivarium` artifacts.""" def __init__(self, path: str | Path, filter_terms: list[str] | None = None) -> None: """ Parameters ---------- path The path to the artifact file. filter_terms A set of terms suitable for usage with the ``where`` kwarg for :func:`pandas.read_hdf`. """ self._path = Path(path) self._filter_terms = filter_terms self._draw_column_filter = _parse_draw_filters(filter_terms) self._cache: dict[str, Any] = {} self.create_hdf_with_keyspace(self._path) self._keys = Keys(self._path) @property def path(self) -> str: """The path to the artifact file.""" return str(self._path) @property def keys(self) -> list[str]: """A list of all the keys contained within the artifact.""" return self._keys.to_list() @property def filter_terms(self) -> list[str] | None: """Filters that will be applied to the requested data on loads.""" return self._filter_terms
[docs] @staticmethod def create_hdf_with_keyspace(path: Path) -> None: """Creates the artifact HDF file and adds a node to track keys.""" if not path.is_file(): warnings.warn(f"No artifact found at {path}. Building new artifact.") hdf.touch(path) keys = hdf.get_keys(path) if keys and "metadata.keyspace" not in keys: raise ArtifactException( "Attempting to construct an Artifact from a malformed existing file. " "This can occur when constructing an Artifact from an existing file when " "the existing file was generated by some other hdf writing mechanism " "(e.g. pandas.to_hdf) rather than generating the the file using this class " "and a non-existent or empty hdf file." ) if not keys: hdf.write(path, "metadata.keyspace", ["metadata.keyspace"])
[docs] def load(self, entity_key: str) -> Any: """Loads the data associated with provided entity_key. Parameters ---------- entity_key The key associated with the expected data. Returns ------- The expected data. Will either be a standard Python object or a :class:`pandas.DataFrame` or :class:`pandas.Series`. Raises ------ ArtifactException If the provided key is not in the artifact. """ if entity_key not in self: raise ArtifactException(f"{entity_key} should be in {self.path}.") if entity_key not in self._cache: data = hdf.load( self._path, entity_key, self._filter_terms, self._draw_column_filter ) # FIXME: Under what conditions do we get None here. assert ( data is not None ), f"Data for {entity_key} is not available. Check your model specification." self._cache[entity_key] = data return self._cache[entity_key]
[docs] def write(self, entity_key: str, data: Any) -> None: """Writes data into the artifact and binds it to the provided key. Parameters ---------- entity_key The key associated with the provided data. data The data to write. Accepted formats are :class:`pandas.Series`, :class:`pandas.DataFrame` or standard python types and containers. Raises ------ ArtifactException If the provided key already exists in the artifact. """ if entity_key in self: raise ArtifactException(f"{entity_key} already in artifact.") elif data is None: raise ArtifactException(f"Attempting to write to key {entity_key} with no data.") else: hdf.write(self._path, entity_key, data) self._keys.append(entity_key)
[docs] def remove(self, entity_key: str) -> None: """Removes data associated with the provided key from the artifact. Parameters ---------- entity_key The key associated with the data to remove. Raises ------ ArtifactException If the key is not present in the artifact. """ if entity_key not in self: raise ArtifactException( f"Trying to remove non-existent key {entity_key} from artifact." ) self._keys.remove(entity_key) if entity_key in self._cache: self._cache.pop(entity_key) hdf.remove(self._path, entity_key)
[docs] def replace(self, entity_key: str, data: Any) -> None: """Replaces the artifact data at the provided key with the new data. Parameters ---------- entity_key The key for which the data should be overwritten. data The data to write. Accepted formats are :class:`pandas.Series`, :class:`pandas.DataFrame` or standard python types and containers. Raises ------ ArtifactException If the provided key does not already exist in the artifact. """ if entity_key not in self: raise ArtifactException( f"Trying to replace non-existent key {entity_key} in artifact." ) self.remove(entity_key) self.write(entity_key, data)
[docs] def clear_cache(self) -> None: """Clears the artifact's cache. The artifact will cache data in memory to improve performance for repeat access. """ self._cache = {}
def __iter__(self) -> Iterator[str]: return iter(self.keys) def __contains__(self, item: str) -> bool: return item in self.keys def __repr__(self) -> str: return f"Artifact(keys={self.keys})" def __str__(self) -> str: key_tree = _to_tree(self.keys) out = "Artifact containing the following keys:\n" for root, children in key_tree.items(): out += f"{root}\n" for child, grandchildren in children.items(): out += f"\t{child}\n" for grandchild in grandchildren: out += f"\t\t{grandchild}\n" return out
def _to_tree(keys: list[str]) -> dict[str, dict[str, list[str]]]: out: defaultdict[str, dict[str, list[str]]] = defaultdict(lambda: defaultdict(list)) for k in keys: key = k.split(".") if len(key) == 3: out[key[0]][key[1]].append(key[2]) else: out[key[0]][key[1]] = [] return dict(out)
class Keys:
    """A convenient wrapper around the keyspace which makes it easier for
    Artifact to maintain its keyspace when an entity key is added or removed.

    A ``Keys`` object is created from the artifact path when the Artifact is
    initialized. It keeps an in-memory list of keys and rewrites the
    ``keyspace_node`` in the artifact file whenever that list changes.
    """

    # Name of the node in the artifact file that stores the list of keys.
    keyspace_node = "metadata.keyspace"

    def __init__(self, artifact_path: Path):
        """
        Parameters
        ----------
        artifact_path
            The path to the artifact file whose keyspace is tracked.
        """
        self._path = artifact_path
        self._keys = [str(k) for k in hdf.load(self._path, self.keyspace_node, None, None)]

    def append(self, new_key: str) -> None:
        """Whenever the artifact gets a new key and new data, append is called
        to remove the old keyspace and to write the updated keyspace.

        Parameters
        ----------
        new_key
            The key to add to the keyspace.
        """
        self._keys.append(new_key)
        self._rewrite_keyspace()

    def remove(self, removing_key: str) -> None:
        """Whenever the artifact removes a key and data, remove is called to
        remove the key from keyspace and write the updated keyspace.

        Parameters
        ----------
        removing_key
            The key to drop from the keyspace.

        Raises
        ------
        ValueError
            If the key is not in the keyspace.
        """
        self._keys.remove(removing_key)
        self._rewrite_keyspace()

    def to_list(self) -> list[str]:
        """A list of all the entity keys in the associated artifact.

        A copy is returned so that callers mutating the result cannot change
        the tracked keyspace without the file being updated.
        """
        return list(self._keys)

    def _rewrite_keyspace(self) -> None:
        """Replace the keyspace node on disk with the current list of keys."""
        hdf.remove(self._path, self.keyspace_node)
        hdf.write(self._path, self.keyspace_node, self._keys)

    def __contains__(self, item: str) -> bool:
        return item in self._keys
def _parse_draw_filters(filter_terms: list[str] | None) -> list[str] | None: """Given a list of filter terms, parse out any related to draws and convert to the list of column names. Also include 'value' column for compatibility with data that is long on draws. """ columns = None if filter_terms: draw_terms = [] for term in filter_terms: # first strip out all the parentheses strip_t: str = re.sub("[()]", "", term) # then split each condition out t: list[str] = re.split("[&|]", strip_t) # then split condition to see if it relates to draws split_term = [re.split("([<=>in])", i) for i in t] draw_terms.extend([t for t in split_term if t[0].strip() == "draw"]) if len(draw_terms) > 1: raise ValueError( f"You can only supply one filter term related to draws. " f"You supplied {filter_terms}, {len(draw_terms)} of which pertain to draws." ) if draw_terms: # convert term to columns columns_term: list[str] = [s.strip() for s in draw_terms[0] if s.strip()] if ( len(columns_term) == 4 and columns_term[1].lower() == "i" and columns_term[2].lower() == "n" ): draws = [int(d) for d in columns_term[-1][1:-1].split(",")] elif (len(columns_term) == 4 and columns_term[1] == columns_term[2] == "=") or ( len(columns_term) == 3 and columns_term[1] == "=" ): draws = [int(columns_term[-1])] else: raise NotImplementedError( f"The only supported draw filters are =, ==, or in. " f'You supplied {"".join(columns_term)}.' ) columns = [f"draw_{n}" for n in draws] + ["value"] return columns