"""
=====================
Life Cycle Management
=====================
The life cycle is a representation of the flow of execution states in a
:mod:`vivarium` simulation. The tools in this module allow a simulation to
formally represent its execution state and use the formal representation to
enforce run-time contracts.
There are two flavors of contracts that this system enforces:

- **Constraints**: These are contracts around when certain methods,
  particularly those available off the :ref:`Builder <builder_concept>`,
  can be used. For example, :term:`simulants <Simulant>` should only be
  added to the simulation during initial population creation and during
  the main simulation loop, otherwise services necessary for initializing
  that population's attributes may not exist. By applying a constraint,
  we can provide very clear errors about what went wrong, rather than
  a deep and unintelligible stack trace.
- **Ordering Contracts**: The
  :class:`~vivarium.framework.engine.SimulationContext` will construct
  the formal representation of the life cycle during its initialization.
  Once generated, the context declares each transition between life
  cycle states as it occurs, and the tools here ensure that only valid
  transitions occur. These kinds of contracts are particularly useful
  during interactive usage, as they prevent users from, for example,
  running a simulation whose population has not been created.

The tools here also allow for introspection of the simulation life cycle.
"""
import functools
import textwrap
import time
from collections import defaultdict
from typing import Callable, Dict, List, Optional, Tuple
import numpy as np
from vivarium.exceptions import VivariumError
from vivarium.manager import Manager
class LifeCycleError(VivariumError):
    """Base class for errors raised by the life cycle management system."""
class InvalidTransitionError(LifeCycleError):
    """Raised when a requested state change breaks a life cycle ordering contract."""
class ConstraintError(LifeCycleError):
    """Raised when a life cycle constraint contract is violated."""
class LifeCycleState:
    """A representation of a simulation run state.

    Each state tracks its name, the state(s) allowed to follow it, the
    number of times it has been entered, and a printable description of
    every handler registered against it.
    """

    def __init__(self, name: str):
        self._name = name
        # Linear successor; stays ``None`` until :meth:`add_next` links one.
        self._next: Optional["LifeCycleState"] = None
        # Loop-back successor, set by ``add_next(..., loop=True)``.
        self._loop_next: Optional["LifeCycleState"] = None
        self._entrance_count = 0
        # Printable handler descriptions (strings, not the callables).
        self._handlers: List[str] = []

    @property
    def name(self) -> str:
        """The name of the lifecycle state."""
        return self._name

    @property
    def entrance_count(self) -> int:
        """The number of times this state has been entered."""
        return self._entrance_count

    def add_next(self, next_state: "LifeCycleState", loop: bool = False) -> None:
        """Link this state to the next state in the simulation life cycle.

        States are linked together and used to ensure that the simulation
        life cycle proceeds in the proper order. A life cycle state can be
        bound to two ``next`` states to allow for loops in the life cycle and
        both are considered valid when checking for valid state transitions.
        The first represents the linear progression through the simulation,
        while the second represents a loop in the life cycle.

        Parameters
        ----------
        next_state
            The next state in the simulation life cycle.
        loop
            Whether the provided state is the linear next state or a loop
            back to a previous state in the life cycle.
        """
        if loop:
            self._loop_next = next_state
        else:
            self._next = next_state

    def valid_next_state(self, state: Optional["LifeCycleState"]) -> bool:
        """Check if the provided state is valid for a life cycle transition.

        ``None`` is accepted only while no linear next state has been
        linked; a real state is accepted when it is either the linear next
        state or the loop-back state.

        Parameters
        ----------
        state
            The state to check.

        Returns
        -------
        bool
            Whether the state is valid for a transition.
        """
        if state is self._next:
            return True
        # ``None`` must never match an unset loop target.
        return state is not None and state is self._loop_next

    def enter(self) -> None:
        """Marks an entrance into this state."""
        self._entrance_count += 1

    def add_handlers(self, handlers: List[Callable]) -> None:
        """Registers a set of functions that will be executed during the state.

        Only a printable description of each handler is stored; the
        callables themselves are not retained. The primary use case here is
        for introspection and reporting. For setting constraints, see
        :meth:`LifeCycleInterface.add_constraint`.

        Callables without a ``__name__`` (for example
        :class:`functools.partial` objects) are described by their ``repr``,
        and bound methods whose owner has no ``name`` attribute are described
        by the owner's class name alone.

        Parameters
        ----------
        handlers
            The set of functions that will be executed during this state.
        """
        missing = object()
        for handler in handlers:
            handler_name = getattr(handler, "__name__", None) or repr(handler)
            if not hasattr(handler, "__self__"):
                self._handlers.append(f"Unbound function {handler_name}")
                continue

            owner = handler.__self__
            owner_type = owner.__class__.__name__
            owner_name = getattr(owner, "name", missing)
            if owner_name is missing:
                self._handlers.append(f"{owner_type}.{handler_name}")
            else:
                self._handlers.append(f"{owner_type}({owner_name}).{handler_name}")

    def __repr__(self) -> str:
        return f"LifeCycleState(name={self.name})"

    def __str__(self) -> str:
        # State name first, then one tab-indented line per handler.
        return "\n\t".join([self.name] + self._handlers)
class LifeCyclePhase:
    """A representation of a distinct lifecycle phase in the simulation.

    A lifecycle phase is composed of one or more unique lifecycle states.
    There is exactly one state within the phase which serves as a valid
    exit point from the phase. The states may operate in a loop.
    """

    def __init__(self, name: str, states: List[str], loop: bool):
        """Build the phase's states and chain them in the order given.

        Parameters
        ----------
        name
            The name of the phase.
        states
            The names of the states in the phase, in order of execution.
            Must contain at least one name.
        loop
            Whether the last state links back to the first one.
        """
        self._name = name
        self._loop = loop
        # A phase needs at least one state; indexing an empty ``states``
        # raises ``IndexError`` here.
        self._states = [LifeCycleState(states[0])]
        for state_name in states[1:]:
            following = LifeCycleState(state_name)
            self._states[-1].add_next(following)
            self._states.append(following)
        if self._loop:
            self._states[-1].add_next(self._states[0], loop=True)

    @property
    def name(self) -> str:
        """The name of this life cycle phase."""
        return self._name

    @property
    def states(self) -> Tuple[LifeCycleState, ...]:
        """The states in this life cycle phase in order of execution."""
        return tuple(self._states)

    def add_next(self, phase: "LifeCyclePhase") -> None:
        """Link the provided phase as the next phase in the life cycle.

        The last state of this phase gets the first state of ``phase`` as
        its linear next state.
        """
        self._states[-1].add_next(phase._states[0])

    def get_state(self, state_name: str) -> LifeCycleState:
        """Retrieve a life cycle state by name from the phase.

        Raises
        ------
        IndexError
            If no state with the given name exists in this phase.
        """
        # Scan from the end so the last match wins if names ever repeat.
        for state in reversed(self._states):
            if state.name == state_name:
                return state
        raise IndexError(f"No state named {state_name} in life cycle phase {self.name}.")

    def __contains__(self, state_name: str) -> bool:
        return any(state.name == state_name for state in self._states)

    def __repr__(self) -> str:
        return f"LifeCyclePhase(name={self.name}, states={[s.name for s in self.states]})"

    def __str__(self) -> str:
        # A trailing ``*`` on the phase name marks a looping phase.
        header = f"{self.name}*" if self._loop else self.name
        body = textwrap.indent("\n".join(str(state) for state in self.states), "\t")
        return f"{header}\n{body}"
class LifeCycle:
    """A concrete representation of the flow of simulation execution states.

    A new life cycle starts with a single non-looping ``initialization``
    phase containing one ``initialization`` state.
    """

    def __init__(self):
        self._state_names = set()
        self._phase_names = set()
        # Phases in the order they were added.
        self._phases = []
        self.add_phase("initialization", ["initialization"], loop=False)

    def add_phase(self, phase_name: str, states: List[str], loop: bool):
        """Add a new phase to the lifecycle.

        Phases must be added in order.

        Parameters
        ----------
        phase_name
            The name of the phase to add. Phase names must be unique.
        states
            The list of names (in order) of the states that make up the
            life cycle phase. State names must be unique across the entire
            life cycle.
        loop
            Whether the life cycle phase states loop.

        Raises
        ------
        LifeCycleError
            If the phase or state names are non-unique, or if ``states``
            is empty.
        """
        # Validate before mutating anything so a rejected phase leaves the
        # life cycle untouched.
        self._validate(phase_name, states)
        new_phase = LifeCyclePhase(phase_name, states, loop)
        if self._phases:
            self._phases[-1].add_next(new_phase)
        self._state_names.update(states)
        self._phase_names.add(phase_name)
        self._phases.append(new_phase)

    def get_state(self, state_name: str) -> LifeCycleState:
        """Retrieve a life cycle state from the life cycle.

        Parameters
        ----------
        state_name
            The name of the state to retrieve

        Returns
        -------
        LifeCycleState
            The requested state.

        Raises
        ------
        LifeCycleError
            If the requested state does not exist.
        """
        if state_name not in self:
            raise LifeCycleError(f"Attempting to look up non-existent state {state_name}.")
        # State names are unique across phases, so exactly one phase matches.
        owning_phase = next(p for p in self._phases if state_name in p)
        return owning_phase.get_state(state_name)

    def get_state_names(self, phase_name: str) -> List[str]:
        """Retrieve the names of all states in the provided phase.

        Parameters
        ----------
        phase_name
            The name of the phase to retrieve the state names from.

        Returns
        -------
        List[str]
            The state names in the provided phase.

        Raises
        ------
        LifeCycleError
            If the phase does not exist in the life cycle.
        """
        if phase_name not in self._phase_names:
            raise LifeCycleError(
                f"Attempting to look up states from non-existent phase {phase_name}."
            )
        phase = next(p for p in self._phases if p.name == phase_name)
        return [s.name for s in phase.states]

    def _validate(self, phase_name: str, states: List[str]):
        """Validates that a phase and set of states are unique and non-empty."""
        if phase_name in self._phase_names:
            raise LifeCycleError(
                f"Lifecycle phase names must be unique. You're attempting "
                f"to add {phase_name} but it already exists."
            )
        if not states:
            # Otherwise building the phase fails with a bare IndexError.
            raise LifeCycleError(
                f"Attempting to create life cycle phase {phase_name} with no states."
            )
        if len(states) != len(set(states)):
            raise LifeCycleError(
                f"Attempting to create a life cycle phase with duplicate state names. "
                f"States: {states}"
            )
        duplicates = self._state_names.intersection(states)
        if duplicates:
            raise LifeCycleError(
                f"Lifecycle state names must be unique. You're attempting "
                f"to add {duplicates} but they already exist."
            )

    def __contains__(self, state_name: str) -> bool:
        return state_name in self._state_names

    def __repr__(self) -> str:
        # List phases in insertion order; a set's order varies between runs.
        return f"LifeCycle(phases={[phase.name for phase in self._phases]})"

    def __str__(self) -> str:
        return "\n".join([str(phase) for phase in self._phases])
class ConstraintMaker:
    """Factory for making state-based constraints on component methods."""

    def __init__(self, lifecycle_manager):
        self.lifecycle_manager = lifecycle_manager
        # Global ids (see :meth:`to_guid`) of methods already constrained.
        self.constraints = set()

    def check_valid_state(self, method: Callable, permitted_states: List[str]):
        """Ensures a component method is being called during an allowed state.

        Parameters
        ----------
        method
            The method the constraint is applied to.
        permitted_states
            The states in which the method is permitted to be called.

        Raises
        ------
        ConstraintError
            If the method is being called outside the permitted states.
        """
        state_now = self.lifecycle_manager.current_state
        if state_now in permitted_states:
            return
        raise ConstraintError(
            f"Trying to call {method} during {state_now},"
            f" but it may only be called during {permitted_states}."
        )

    def constrain_normal_method(
        self, method: Callable, permitted_states: List[str]
    ) -> Callable:
        """Only permit a method to be called during the provided states.

        The bound method is wrapped in a state check and the wrapper is
        stored on the owning instance under the method's original name.

        Parameters
        ----------
        method
            The method to constrain.
        permitted_states
            The life cycle states in which the method can be called.

        Returns
        -------
        Callable
            The constrained method.
        """
        owner = method.__self__

        @functools.wraps(method)
        def _wrapped(*args, **kwargs):
            self.check_valid_state(method, permitted_states)
            # The wrapper is bound to the owner below, so ``args`` already
            # starts with the instance. Use the underlying function;
            # calling ``method`` itself would pass the instance twice.
            return method.__func__(*args, **kwargs)

        # Bind the wrapper to the instance through the descriptor protocol,
        # then store it on the instance under the original name.
        rebound_method = _wrapped.__get__(owner, owner.__class__)
        setattr(owner, method.__name__, rebound_method)
        return rebound_method

    @staticmethod
    def to_guid(method: Callable) -> str:
        """Convert a method on to a global id.

        Because we dynamically rebind methods, the old ones will get garbage
        collected, making :func:`id` unreliable for checking if a method
        has been constrained before.
        """
        return f"{method.__self__.name}.{method.__name__}"

    def __call__(self, method: Callable, permitted_states: List[str]) -> Callable:
        """Only permit a method to be called during the provided states.

        Validates the method, records it as constrained, and delegates the
        wrapping to :meth:`constrain_normal_method`.

        Parameters
        ----------
        method
            The method to constrain.
        permitted_states
            The life cycle states in which the method can be called.

        Returns
        -------
            The constrained method.

        Raises
        ------
        TypeError
            If an unbound function is supplied for constraint.
        ValueError
            If the provided method is a python "special" method (i.e. a
            method surrounded by double underscores).
        ConstraintError
            If the method has already been constrained.
        """
        if not hasattr(method, "__self__"):
            raise TypeError(
                "Can only apply constraints to bound object methods. "
                f"You supplied the function {method}."
            )

        method_name = method.__name__
        is_special = method_name.startswith("__") and method_name.endswith("__")
        if is_special:
            raise ValueError(
                "Can only apply constraints to normal object methods. "
                f" You supplied {method}."
            )

        guid = self.to_guid(method)
        if guid in self.constraints:
            raise ConstraintError(f"Method {method} has already been constrained.")

        self.constraints.add(guid)
        return self.constrain_normal_method(method, permitted_states)
class LifeCycleManager(Manager):
    """Manages ordering- and constraint-based contracts in the simulation."""

    def __init__(self):
        self.lifecycle = LifeCycle()
        self._current_state = self.lifecycle.get_state("initialization")
        # Use a monotonic clock: wall-clock ``time.time()`` can jump when
        # the system clock is adjusted, which would corrupt the durations.
        self._current_state_start_time = time.perf_counter()
        # State name -> elapsed seconds, one entry per completed visit.
        self._timings = defaultdict(list)
        self._make_constraint = ConstraintMaker(self)

    @property
    def name(self) -> str:
        """The name of this component."""
        return "life_cycle_manager"

    @property
    def current_state(self) -> str:
        """The name of the current life cycle state."""
        return self._current_state.name

    @property
    def timings(self) -> Dict[str, List[float]]:
        """Elapsed seconds spent in each state, keyed by state name.

        A duration is appended each time the simulation transitions out of
        a state, so the state currently in progress is not yet included.
        """
        return self._timings

    def add_phase(self, phase_name: str, states: List[str], loop: bool = False):
        """Add a new phase to the lifecycle.

        Phases must be added in order.

        Parameters
        ----------
        phase_name
            The name of the phase to add. Phase names must be unique.
        states
            The list of names (in order) of the states that make up the
            life cycle phase. State names must be unique across the entire
            life cycle.
        loop
            Whether the life cycle phase states loop.

        Raises
        ------
        LifeCycleError
            If the phase or state names are non-unique.
        """
        self.lifecycle.add_phase(phase_name, states, loop)

    def set_state(self, state: str):
        """Sets the current life cycle state to the provided state.

        On a valid transition the time spent in the outgoing state is
        recorded in :attr:`timings` and the new state is marked as entered.

        Parameters
        ----------
        state
            The name of the state to set.

        Raises
        ------
        LifeCycleError
            If the requested state doesn't exist in the life cycle.
        InvalidTransitionError
            If setting the provided state represents an invalid life cycle
            transition.
        """
        new_state = self.lifecycle.get_state(state)
        if not self._current_state.valid_next_state(new_state):
            raise InvalidTransitionError(
                f"Invalid transition from {self.current_state} "
                f"to {new_state.name} requested."
            )

        # One clock reading closes the old state and opens the new one, so
        # no time falls between the recorded intervals.
        now = time.perf_counter()
        self._timings[self._current_state.name].append(
            now - self._current_state_start_time
        )
        new_state.enter()
        self._current_state = new_state
        self._current_state_start_time = now

    def get_state_names(self, phase: str) -> List[str]:
        """Gets all states in the phase in their order of execution.

        Parameters
        ----------
        phase
            The name of the phase to retrieve the states for.

        Returns
        -------
        List[str]
            A list of state names in order of execution.
        """
        return self.lifecycle.get_state_names(phase)

    def add_handlers(self, state_name: str, handlers: List[Callable]):
        """Registers a set of functions to be called during a life cycle state.

        This method does not apply any constraints, rather it is used
        to build up an execution order for introspection.

        Parameters
        ----------
        state_name
            The name of the state to register the handlers for.
        handlers
            A list of functions that will execute during the state.
        """
        self.lifecycle.get_state(state_name).add_handlers(handlers)

    def add_constraint(
        self, method: Callable, allow_during: List[str] = (), restrict_during: List[str] = ()
    ):
        """Constrains a function to be executable only during certain states.

        Parameters
        ----------
        method
            The method to add constraints to.
        allow_during
            An optional list of life cycle states in which the provided
            method is allowed to be called.
        restrict_during
            An optional list of life cycle states in which the provided
            method is restricted from being called.

        Raises
        ------
        ValueError
            If neither ``allow_during`` nor ``restrict_during`` are provided,
            or if both are provided.
        LifeCycleError
            If states provided as arguments are not in the life cycle.
        ConstraintError
            If a lifecycle constraint has already been applied to the provided
            method.
        """
        if allow_during and restrict_during or not (allow_during or restrict_during):
            raise ValueError(
                'Must provide exactly one of "allow_during" or "restrict_during".'
            )

        # De-duplicate while keeping the caller's order so the error message
        # is the same on every run.
        requested = dict.fromkeys([*allow_during, *restrict_during])
        unknown_states = [s for s in requested if s not in self.lifecycle]
        if unknown_states:
            raise LifeCycleError(
                f"Attempting to constrain {method} with "
                f"states not in the life cycle: {list(unknown_states)}."
            )

        if restrict_during:
            # Express the restriction as its complement, listed in life
            # cycle order rather than arbitrary set order.
            restricted = set(restrict_during)
            allow_during = [
                state.name
                for phase in self.lifecycle._phases
                for state in phase.states
                if state.name not in restricted
            ]

        self._make_constraint(method, allow_during)

    def __repr__(self) -> str:
        return f"LifeCycleManager(state={self.current_state})"

    def __str__(self) -> str:
        return str(self.lifecycle)
class LifeCycleInterface:
    """Interface to the life cycle management system.

    Every method here forwards to the wrapped :class:`LifeCycleManager`.
    Through it, components can record which handlers run in a state,
    constrain methods to particular life cycle states, and look up the
    current state.
    """

    def __init__(self, manager: LifeCycleManager):
        self._manager = manager

    def add_handlers(self, state: str, handlers: List[Callable]):
        """Registers a set of functions to be called during a life cycle state.

        No constraints are applied by this call; the registration only
        feeds the execution order used for introspection.

        Parameters
        ----------
        state
            The name of the state to register the handlers for.
        handlers
            A list of functions that will execute during the state.
        """
        self._manager.add_handlers(state, handlers)

    def add_constraint(
        self, method: Callable, allow_during: List[str] = (), restrict_during: List[str] = ()
    ):
        """Constrains a function to be executable only during certain states.

        Parameters
        ----------
        method
            The method to add constraints to.
        allow_during
            An optional list of life cycle states in which the provided
            method is allowed to be called.
        restrict_during
            An optional list of life cycle states in which the provided
            method is restricted from being called.

        Raises
        ------
        ValueError
            If neither ``allow_during`` nor ``restrict_during`` are provided,
            or if both are provided.
        LifeCycleError
            If states provided as arguments are not in the life cycle.
        ConstraintError
            If a life cycle constraint has already been applied to the
            provided method.
        """
        self._manager.add_constraint(
            method,
            allow_during=allow_during,
            restrict_during=restrict_during,
        )

    def current_state(self) -> Callable[[], str]:
        """Returns a callable that gets the current simulation lifecycle state.

        The state is looked up each time the callable is invoked, not when
        this method is called.

        Returns
        -------
        Callable[[], str]
            A callable that returns the current simulation lifecycle state.
        """
        return lambda: self._manager.current_state