Source code for vivarium.framework.state_machine

"""
=============
State Machine
=============

A state machine implementation for use in ``vivarium`` simulations.

"""

from __future__ import annotations

from collections.abc import Callable, Iterable, Sequence
from enum import Enum
from typing import TYPE_CHECKING, Any, Iterator

import numpy as np
import pandas as pd

from vivarium import Component

if TYPE_CHECKING:
    from vivarium.framework.engine import Builder
    from vivarium.framework.event import Event
    from vivarium.framework.population import PopulationView, SimulantData
    from vivarium.types import ClockTime, DataInput, NumericArray


[docs] def default_probability_function(index: pd.Index[int]) -> pd.Series[float]: """Returns a series of ones for the provided index. This is the default transition decision function (always triggers this transition). """ return pd.Series(1.0, index=index)
def _next_state( index: pd.Index[int], event_time: ClockTime, transition_set: TransitionSet, population_view: PopulationView, ) -> None: """Moves a population between different states using information from a `TransitionSet`. Parameters ---------- index An iterable of integer labels for the simulants. event_time When this transition is occurring. transition_set A set of potential transitions available to the simulants. population_view A view of the internal state of the simulation. """ if len(transition_set) == 0 or index.empty: return outputs, decisions = transition_set.choose_new_state(index) groups = _groupby_new_state(index, outputs, decisions) if groups: for output, affected_index in sorted(groups, key=lambda x: str(x[0])): if output == "null_transition": pass elif isinstance(output, Transient): if not isinstance(output, State): raise ValueError("Invalid transition output: {}".format(output)) output.transition_effect(affected_index, event_time, population_view) output.next_state(affected_index, event_time, population_view) elif isinstance(output, State): output.transition_effect(affected_index, event_time, population_view) else: raise ValueError("Invalid transition output: {}".format(output)) def _groupby_new_state( index: pd.Index[int], outputs: list[State | str], decisions: pd.Series[Any] ) -> list[tuple[State | str, pd.Index[int]]]: """Groups the simulants in the index by their new output state. Parameters ---------- index An iterable of integer labels for the simulants. outputs A list of possible output states. decisions A series containing the name of the next state for each simulant in the index. Returns ------- The first item in each tuple is the name of an output state and the second item is a `pandas.Index` representing the simulants to transition into that state. """ groups = pd.Series(index).groupby( pd.CategoricalIndex(decisions.values, categories=outputs), observed=False ) return [(output, pd.Index(sub_group.values)) for output, sub_group in groups]
[docs] class Trigger(Enum): NOT_TRIGGERED = 0 START_INACTIVE = 1 START_ACTIVE = 2
def _process_trigger(trigger: Trigger) -> tuple[pd.Index[int] | None, bool]: if trigger == Trigger.NOT_TRIGGERED: return None, False elif trigger == Trigger.START_INACTIVE: return pd.Index([]), False elif trigger == Trigger.START_ACTIVE: return pd.Index([]), True else: raise ValueError("Invalid trigger state provided: {}".format(trigger))
[docs] class Transition(Component): """A process by which an entity might change into a particular state.""" ##################### # Lifecycle methods # ##################### def __init__( self, input_state: State, output_state: State, probability_func: Callable[ [pd.Index[int]], pd.Series[float] ] = lambda index: pd.Series(1.0, index=index), triggered: Trigger = Trigger.NOT_TRIGGERED, ) -> None: """Initializes a transition between two states. Parameters ---------- input_state The start state of the entity that undergoes the transition. output_state The end state of the entity that undergoes the transition. probability_func A method or function that describing the probability of this transition occurring. triggered A flag indicating whether this transition is triggered by some event. """ super().__init__() self.input_state = input_state self.output_state = output_state self._probability = probability_func self._active_index, self.start_active = _process_trigger(triggered) ################## # Public methods # ##################
[docs] def set_active(self, index: pd.Index[int]) -> None: if self._active_index is None: raise ValueError( "This transition is not triggered. An active index cannot be set or modified." ) else: self._active_index = self._active_index.union(pd.Index(index))
[docs] def set_inactive(self, index: pd.Index[int]) -> None: if self._active_index is None: raise ValueError( "This transition is not triggered. An active index cannot be set or modified." ) else: self._active_index = self._active_index.difference(pd.Index(index))
[docs] def probability(self, index: pd.Index[int]) -> pd.Series[float]: if self._active_index is None: return self._probability(index) index = pd.Index(index) activated_index = self._active_index.intersection(index) null_index = index.difference(self._active_index) activated = pd.Series(self._probability(activated_index), index=activated_index) null = pd.Series(np.zeros(len(null_index), dtype=float), index=null_index) activated.update(null) return activated
[docs] class State(Component): """An abstract representation of a particular position in a state space. Attributes ---------- state_id The name of this state. This should be unique transition_set A container for potential transitions out of this state. """ ############## # Properties # ############## @property def configuration_defaults(self) -> dict[str, Any]: return { f"{self.name}": { "data_sources": { "initialization_weights": self.initialization_weights, }, }, } @property def model(self) -> str | None: return self._model ##################### # Lifecycle methods # ##################### def __init__( self, state_id: str, allow_self_transition: bool = True, initialization_weights: DataInput = 0.0, ) -> None: super().__init__() self.state_id = state_id self.transition_set = TransitionSet( self.state_id, allow_self_transition=allow_self_transition ) self.initialization_weights = initialization_weights self._model: str | None = None self._sub_components = [self.transition_set] self.initialization_weights_pipeline = f"{self.state_id}.initialization_weights"
[docs] def setup(self, builder: Builder) -> None: self.initialization_weights_table = self.build_lookup_table( builder, "initialization_weights" ) builder.value.register_attribute_producer( self.initialization_weights_pipeline, self.initialization_weights_table )
################## # Public methods # ##################
[docs] def has_initialization_weights(self) -> bool: """Determines if state has explicitly defined initialization weights.""" return not ( not isinstance(self.initialization_weights_table.data, pd.DataFrame) and self.initialization_weights_table.data == 0.0 )
[docs] def set_model(self, model_name: str) -> None: """Defines the column name for the model this state belongs to""" self._model = model_name
[docs] def next_state( self, index: pd.Index[int], event_time: ClockTime, population_view: PopulationView ) -> None: """Moves a population between different states. Parameters ---------- index An iterable of integer labels for the simulants. event_time When this transition is occurring. population_view A view of the internal state of the simulation. """ return _next_state(index, event_time, self.transition_set, population_view)
[docs] def transition_effect( self, index: pd.Index[int], event_time: ClockTime, population_view: PopulationView ) -> None: """Updates the simulation state and triggers any side-effects associated with entering this state. Parameters ---------- index An iterable of integer labels for the simulants. event_time The time at which this transition occurs. population_view A view of the internal state of the simulation. """ if self.model is None: raise ValueError( f"State '{self.state_id}' has no model set. " "Call set_model() before transitioning." ) population_view.update(self.model, lambda _: pd.Series(self.state_id, index=index)) self.transition_side_effect(index, event_time)
[docs] def cleanup_effect(self, index: pd.Index[int], event_time: ClockTime) -> None: pass
[docs] def add_transition( self, transition: Transition | None = None, output_state: State | None = None, probability_function: Callable[ [pd.Index[int]], pd.Series[float] ] = default_probability_function, triggered: Trigger = Trigger.NOT_TRIGGERED, ) -> Transition: """Adds a transition to this state and its `TransitionSet`. A transition can be added by passing a `Transition` object or by specifying an output state and a decision function. If a transition is provided, the output state and decision function must not be. Parameters ---------- transition The transition to add. output_state The state to transition to probability_function A function that determines the probability that this transition should happen. By default, this is a function that will produce a probability of 1.0 for all simulants in the state. triggered A flag indicating whether this transition is triggered by some event. """ if transition is not None: if ( output_state is not None or probability_function != default_probability_function or triggered != Trigger.NOT_TRIGGERED ): raise ValueError( "Cannot provide an output state or a decision function if a" " transition is provided." ) else: if output_state is None: raise ValueError("Must specify either a transition or an output state.") transition = Transition(self, output_state, probability_function, triggered) self.transition_set.append(transition) return transition
################## # Helper methods # ##################
[docs] def transition_side_effect(self, index: pd.Index[int], event_time: ClockTime) -> None: pass
[docs] class Transient: """Used to tell _next_state to transition a second time.""" pass
[docs] class TransientState(State, Transient): pass
[docs] class TransitionSet(Component): """A container for state machine transitions. Attributes ---------- state_id The unique name of the state that instantiated this TransitionSet. Typically a string but any object implementing __str__ will do. allow_null_transition Specified whether it is possible not to transition on a given time-step transitions A list of transitions that can be taken from this state. random The randomness stream. """ ############## # Properties # ############## @property def name(self) -> str: return f"transition_set.{self.state_id}" ##################### # Lifecycle methods # ##################### def __init__( self, state_id: str, *transitions: Transition, allow_self_transition: bool = True ): super().__init__() self.state_id = state_id self.allow_self_transition = allow_self_transition self.transitions: list[Transition] = [] self._sub_components = self.transitions self.extend(transitions)
[docs] def setup(self, builder: Builder) -> None: """Performs this component's simulation setup and return sub-components. Parameters ---------- builder Interface to several simulation tools including access to common random number generation, in particular. """ self.random = builder.randomness.get_stream(self.name)
################## # Public methods # ##################
[docs] def choose_new_state( self, index: pd.Index[int] ) -> tuple[list[State | str], pd.Series[Any]]: """Chooses a new state for each simulant in the index. Parameters ---------- index An iterable of integer labels for the simulants. Returns ------- A tuple of the possible end states of this set of transitions and a series containing the name of the next state for each simulant in the index. """ outputs, probabilities = zip( *[ (transition.output_state, np.array(transition.probability(index))) for transition in self.transitions ] ) probabilities = np.transpose(probabilities) outputs, probabilities = self._normalize_probabilities(outputs, probabilities) return outputs, self.random.choice(index, outputs, probabilities)
[docs] def append(self, transition: Transition) -> None: if not isinstance(transition, Transition): raise TypeError( "TransitionSet must contain only Transition objects. " f"Check constructor arguments: {self}" ) self.transitions.append(transition)
[docs] def extend(self, transitions: Iterable[Transition]) -> None: for transition in transitions: self.append(transition)
################## # Helper methods # ################## def _normalize_probabilities( self, outputs: list[State | str], probabilities: NumericArray ) -> tuple[list[State | str], NumericArray]: """Normalizes probabilities to sum to 1 and add a null transition. Parameters ---------- outputs List of possible end states corresponding to this containers transitions. probabilities A set of probability weights whose columns correspond to the end states in `outputs` and whose rows correspond to each simulant undergoing the transition. Returns ------- A tuple of the original output list expanded to include a null transition (a transition back to the starting state) if requested and the original probabilities rescaled to sum to 1 and potentially expanded to include a null transition weight. """ outputs = list(outputs) # This is mainly for flexibility with the triggered transitions. # We may have multiple out transitions from a state where one of them # is gated until some criteria is met. After the criteria is # met, the gated transition becomes the default (likely as opposed # to a self transition). default_transition_count = np.sum(probabilities == 1, axis=1) if np.any(default_transition_count > 1): raise ValueError("Multiple transitions specified with probability 1.") has_default = default_transition_count == 1 total = np.sum(probabilities, axis=1) probabilities[has_default] /= total[has_default, np.newaxis] total = np.sum(probabilities, axis=1) # All totals should be ~<= 1 at this point. if self.allow_self_transition: if np.any(total > 1 + 1e-08): # Accommodate rounding errors raise ValueError( f"Null transition requested with un-normalized " f"probability weights: {probabilities}" ) total[total > 1] = 1 # Correct allowed rounding errors. probabilities = np.concatenate( [probabilities, (1 - total)[:, np.newaxis]], axis=1 ) outputs.append("null_transition") else: if np.any(total == 0): raise ValueError("No valid transitions for some simulants.") else: # total might be less than zero in some places probabilities = np.divide(probabilities, total[:, np.newaxis]) return outputs, probabilities def __iter__(self) -> Iterator[Transition]: return iter(self.transitions) def __len__(self) -> int: return len(self.transitions) def __hash__(self) -> int: return hash(id(self))
[docs] class Machine(Component): """A collection of states and transitions between those states. Attributes ---------- states The collection of states represented by this state machine. state_column A label for the piece of simulation state governed by this state machine. """ ############## # Properties # ############## @property def sub_components(self) -> Sequence[Component]: return self.states ##################### # Lifecycle methods # ##################### def __init__( self, state_column: str, states: Iterable[State] = (), initial_state: State | None = None, ) -> None: super().__init__() self.states: list[State] = [] self.state_column = state_column self._initial_state = initial_state self.initialization_weights_pipelines: list[str] = [] if states: self.add_states(states) if initial_state is not None: if initial_state not in self.states: raise ValueError( f"Initial state '{initial_state}' must be one of the" f" states: {self.states}." ) initial_state.initialization_weights = 1.0
[docs] def setup(self, builder: Builder) -> None: self.randomness = builder.randomness.get_stream(self.name) builder.population.register_initializer( initializer=self.initialize_state, columns=self.state_column, required_resources=[self.randomness, *self.initialization_weights_pipelines], )
[docs] def on_post_setup(self, event: Event) -> None: states_with_initialization_weights = [ state for state in self.states if state.has_initialization_weights() ] if self._initial_state is not None and states_with_initialization_weights != [ self._initial_state ]: raise ValueError( "Cannot specify both an initial state and provide initialization" " weights to states." ) elif self._initial_state is None and not states_with_initialization_weights: raise ValueError( "Must specify either an initial state or provide" " initialization weights to states." )
[docs] def initialize_state(self, pop_data: SimulantData) -> None: state_ids = [s.state_id for s in self.states] state_weights = self.population_view.get( pop_data.index, self.initialization_weights_pipelines ) initial_states = self.randomness.choice( pop_data.index, state_ids, state_weights.to_numpy(), "initialization" ).rename(self.state_column) self.population_view.initialize(initial_states)
[docs] def on_time_step(self, event: Event) -> None: self.transition(event.index, event.time)
[docs] def on_time_step_cleanup(self, event: Event) -> None: self.cleanup(event.index, event.time)
################## # Public methods # ##################
[docs] def add_states(self, states: Iterable[State]) -> None: for state in states: self.states.append(state) self.initialization_weights_pipelines.append( state.initialization_weights_pipeline ) state.set_model(self.state_column)
[docs] def transition(self, index: pd.Index[int], event_time: ClockTime) -> None: """Finds the population in each state and moves them to the next state. Parameters ---------- index An iterable of integer labels for the simulants. event_time The time at which this transition occurs. """ for state, affected in self._get_state_pops(index): if not affected.empty: state.next_state( affected.index, event_time, self.population_view, )
[docs] def cleanup(self, index: pd.Index[int], event_time: ClockTime) -> None: for state, affected in self._get_state_pops(index): if not affected.empty: state.cleanup_effect(affected.index, event_time)
def _get_state_pops(self, index: pd.Index[int]) -> list[tuple[State, pd.Series[Any]]]: population = self.population_view.get(index, self.state_column) if not isinstance(population, pd.Series): raise TypeError( "Expected population view to return a pandas Series for" f" state column '{self.state_column}', but got: {type(population)}" ) return [(state, population[population == state.state_id]) for state in self.states] ################## # Helper methods # ##################
[docs] def get_initialization_parameters(self) -> dict[str, Any]: """ Gets the values of the state column specified in the __init__`. Returns ------- The value of the state column. Notes ----- This retrieves the value of the attribute at the time of calling which is not guaranteed to be the same as the original value. """ return {"state_column": self.state_column}