# Source code for vivarium.framework.state_machine

"""
=============
State Machine
=============

A state machine implementation for use in ``vivarium`` simulations.

"""
from enum import Enum
from typing import TYPE_CHECKING, Any, Callable, Dict, Iterable, List, Optional, Tuple

import numpy as np
import pandas as pd

from vivarium import Component

if TYPE_CHECKING:
    from vivarium.framework.engine import Builder
    from vivarium.framework.population import PopulationView
    from vivarium.framework.time import Time


def _next_state(
    index: pd.Index,
    event_time: "Time",
    transition_set: "TransitionSet",
    population_view: "PopulationView",
) -> None:
    """Move simulants into new states chosen from a `TransitionSet`.

    The transition set picks an output for every simulant, the simulants are
    grouped by that output, and each group is handed to its output state.
    An output that is both a `State` and a `Transient` is asked to transition
    its group onward again immediately.

    Parameters
    ----------
    index
        An iterable of integer labels for the simulants.
    event_time
        When this transition is occurring.
    transition_set
        A set of potential transitions available to the simulants.
    population_view
        A view of the internal state of the simulation.

    Raises
    ------
    ValueError
        If a chosen output is neither the ``"null_transition"`` marker nor a
        `State` instance.

    """
    # Nothing to do without transitions or without simulants.
    if not len(transition_set) or index.empty:
        return

    outputs, decisions = transition_set.choose_new_state(index)
    groups = _groupby_new_state(index, outputs, decisions)

    # Visit the groups in a fixed order, keyed on the output's string form.
    ordered_groups = sorted(groups, key=lambda group: str(group[0]))
    for output, affected_index in ordered_groups:
        if output == "null_transition":
            # These simulants stay where they are.
            continue
        if not isinstance(output, State):
            raise ValueError("Invalid transition output: {}".format(output))

        output.transition_effect(affected_index, event_time, population_view)
        if isinstance(output, Transient):
            # Transient states are passed through within the same call.
            output.next_state(affected_index, event_time, population_view)


def _groupby_new_state(
    index: pd.Index, outputs: List, decisions: pd.Series
) -> List[Tuple[str, pd.Index]]:
    """Split the simulants in ``index`` by the output chosen for each of them.

    Parameters
    ----------
    index
        An iterable of integer labels for the simulants.
    outputs
        A list of possible outputs. These become the categories of the
        grouping key.
    decisions
        A series holding the chosen output for each simulant in the index.
        Only its values are used, so they are matched to ``index`` by
        position.

    Returns
    -------
    List[Tuple[str, pandas.Index]]
        The first item in each tuple is one of ``outputs`` and the second
        item is a `pandas.Index` of the simulants assigned to it.

    """
    grouping_key = pd.Categorical(decisions.values, categories=outputs)
    simulants = pd.Series(index)

    grouped = []
    for output, members in simulants.groupby(grouping_key, observed=False):
        grouped.append((output, pd.Index(members.values)))
    return grouped


class Trigger(Enum):
    """Triggering modes a `Transition` can be created with.

    `_process_trigger` turns each member into the transition's initial
    active index and ``start_active`` flag.
    """

    # No active index is tracked for the transition.
    NOT_TRIGGERED = 0
    # An (initially empty) active index is tracked; start_active is False.
    START_INACTIVE = 1
    # An (initially empty) active index is tracked; start_active is True.
    START_ACTIVE = 2
def _process_trigger(trigger):
    """Translate a `Trigger` value into a transition's initial trigger state.

    Parameters
    ----------
    trigger
        The `Trigger` member to interpret.

    Returns
    -------
    tuple
        ``(active_index, start_active)``. ``active_index`` is ``None`` for
        `Trigger.NOT_TRIGGERED` and a new empty `pandas.Index` otherwise.
        ``start_active`` is ``True`` only for `Trigger.START_ACTIVE`.

    Raises
    ------
    ValueError
        If ``trigger`` is not one of the three `Trigger` members.

    """
    if trigger == Trigger.NOT_TRIGGERED:
        return None, False
    if trigger == Trigger.START_INACTIVE:
        return pd.Index([]), False
    if trigger == Trigger.START_ACTIVE:
        return pd.Index([]), True
    raise ValueError("Invalid trigger state provided: {}".format(trigger))
class Transition(Component):
    """A process by which an entity might change into a particular state.

    Parameters
    ----------
    input_state
        The start state of the entity that undergoes the transition.
    output_state
        The end state of the entity that undergoes the transition.
    probability_func
        A method or function that describes the probability of this
        transition occurring. It is called with an index of simulants and
        returns one probability per simulant. Defaults to probability 1.0
        for everyone.
    triggered
        A `Trigger` member. With anything other than
        `Trigger.NOT_TRIGGERED` the transition keeps an active index, and
        simulants outside of it get probability zero.

    """

    #####################
    # Lifecycle methods #
    #####################

    def __init__(
        self,
        input_state: "State",
        output_state: "State",
        probability_func: Callable[[pd.Index], pd.Series] = lambda index: pd.Series(
            1.0, index=index
        ),
        triggered=Trigger.NOT_TRIGGERED,
    ):
        super().__init__()
        self.input_state = input_state
        self.output_state = output_state
        self._probability = probability_func
        # _active_index is None for untriggered transitions, otherwise an
        # index that starts out empty.
        self._active_index, self.start_active = _process_trigger(triggered)

    ##################
    # Public methods #
    ##################

    def set_active(self, index: pd.Index) -> None:
        """Add the simulants in ``index`` to the active index.

        Raises
        ------
        ValueError
            If this transition was created with `Trigger.NOT_TRIGGERED`.

        """
        if self._active_index is None:
            raise ValueError(
                "This transition is not triggered. An active index cannot be set or modified."
            )
        self._active_index = self._active_index.union(pd.Index(index))

    def set_inactive(self, index: pd.Index) -> None:
        """Remove the simulants in ``index`` from the active index.

        Raises
        ------
        ValueError
            If this transition was created with `Trigger.NOT_TRIGGERED`.

        """
        if self._active_index is None:
            raise ValueError(
                "This transition is not triggered. An active index cannot be set or modified."
            )
        self._active_index = self._active_index.difference(pd.Index(index))

    def probability(self, index: pd.Index) -> pd.Series:
        """Compute the probability of this transition for each simulant.

        Parameters
        ----------
        index
            An iterable of integer labels for the simulants.

        Returns
        -------
        pandas.Series
            For an untriggered transition, whatever ``probability_func``
            returns for ``index``. For a triggered transition, a series
            ordered like ``index`` that holds the ``probability_func`` value
            for simulants in the active index and ``0.0`` for all others.

        """
        if self._active_index is None:
            return self._probability(index)

        index = pd.Index(index)

        activated_index = self._active_index.intersection(index)
        null_index = index.difference(self._active_index)
        activated = pd.Series(self._probability(activated_index), index=activated_index)
        null = pd.Series(np.zeros(len(null_index), dtype=float), index=null_index)
        # Series.append no longer exists in pandas 2, so use pd.concat.
        # Reindexing restores the caller's ordering: consumers such as
        # TransitionSet.choose_new_state read the values positionally, and
        # the set operations above do not preserve the order of ``index``.
        return pd.concat([activated, null]).reindex(index)
class State(Component):
    """An abstract representation of a particular position in a state space.

    Parameters
    ----------
    state_id
        The name of this state. This should be unique.
    allow_self_transition
        Passed on to this state's `TransitionSet`.

    Attributes
    ----------
    state_id
        The name of this state. This should be unique.
    transition_set
        A container for potential transitions out of this state.

    """

    ##############
    # Properties #
    ##############

    @property
    def model(self) -> str:
        """The model name stored by `set_model`, or ``None`` if never set."""
        return self._model

    #####################
    # Lifecycle methods #
    #####################

    def __init__(self, state_id: str, allow_self_transition: bool = False):
        super().__init__()
        self.state_id = state_id
        self.transition_set = TransitionSet(
            self.state_id, allow_self_transition=allow_self_transition
        )
        self._model = None
        # NOTE(review): presumably read by Component to discover nested
        # components; confirm against the base class.
        self._sub_components = [self.transition_set]

    ##################
    # Public methods #
    ##################

    def set_model(self, model_name: str) -> None:
        """Defines the column name for the model this state belongs to."""
        self._model = model_name

    def next_state(
        self, index: pd.Index, event_time: "Time", population_view: "PopulationView"
    ) -> None:
        """Moves a population between different states.

        Delegates to `_next_state` using this state's own transition set.

        Parameters
        ----------
        index
            An iterable of integer labels for the simulants.
        event_time
            When this transition is occurring.
        population_view
            A view of the internal state of the simulation.

        """
        return _next_state(index, event_time, self.transition_set, population_view)

    def transition_effect(
        self, index: pd.Index, event_time: "Time", population_view: "PopulationView"
    ) -> None:
        """Updates the simulation state and triggers any side-effects
        associated with entering this state.

        Parameters
        ----------
        index
            An iterable of integer labels for the simulants.
        event_time
            The time at which this transition occurs.
        population_view
            A view of the internal state of the simulation.

        """
        # Record this state's id for every simulant entering it.
        entered = pd.Series(self.state_id, index=index)
        population_view.update(entered)
        self.transition_side_effect(index, event_time)

    def cleanup_effect(self, index: pd.Index, event_time: "Time") -> None:
        """Hook called by `Machine.cleanup`; does nothing in this class."""
        pass

    def add_transition(self, transition: Transition) -> None:
        """Adds a transition to this state and its `TransitionSet`.

        Parameters
        ----------
        transition
            The transition to add

        """
        self.transition_set.append(transition)

    def allow_self_transitions(self) -> None:
        """Enable the null transition on this state's transition set."""
        self.transition_set.allow_null_transition = True

    ##################
    # Helper methods #
    ##################

    def transition_side_effect(self, index: pd.Index, event_time: "Time") -> None:
        """Hook called at the end of `transition_effect`; does nothing in
        this class."""
        pass
class Transient:
    """Used to tell _next_state to transition a second time.

    This is a pure marker: it has no behaviour and is only ever checked with
    ``isinstance``.
    """
class TransientState(State, Transient):
    """A `State` that also carries the `Transient` marker.

    `_next_state` asks such a state to transition its simulants onward
    right after they enter it.
    """
class TransitionSet(Component):
    """A container for state machine transitions.

    Parameters
    ----------
    state_id
        The unique name of the state that instantiated this TransitionSet.
        Typically a string but any object implementing __str__ will do.
    *transitions
        Any number of `Transition` objects to start the set with.
    allow_self_transition
        Specifies whether it is possible not to transition on a given
        time-step. Stored as ``allow_null_transition``.

    """

    ##############
    # Properties #
    ##############

    @property
    def name(self) -> str:
        """Name of this component, derived from the owning state's id."""
        return f"transition_set.{self.state_id}"

    #####################
    # Lifecycle methods #
    #####################

    def __init__(
        self, state_id: str, *transitions: Transition, allow_self_transition: bool = False
    ):
        super().__init__()
        self.state_id = state_id
        self.allow_null_transition = allow_self_transition
        self.transitions = []
        # The very same list object is shared on purpose, so transitions
        # appended later are visible through both attributes.
        self._sub_components = self.transitions
        self.extend(transitions)

    def setup(self, builder: "Builder") -> None:
        """Performs this component's simulation setup.

        Parameters
        ----------
        builder
            Interface to several simulation tools including access to common
            random number generation, in particular.

        """
        self.random = builder.randomness.get_stream(self.name)

    ##################
    # Public methods #
    ##################

    def choose_new_state(self, index: pd.Index) -> Tuple[List, pd.Series]:
        """Chooses a new state for each simulant in the index.

        Parameters
        ----------
        index
            An iterable of integer labels for the simulants.

        Returns
        -------
        List
            The possible end states of this set of transitions.
        pandas.Series
            The result of the random choice among those end states for each
            simulant in the index.

        """
        # One (end state, probability array) pair per transition.
        pairs = [
            (transition.output_state, np.array(transition.probability(index)))
            for transition in self.transitions
        ]
        outputs, probabilities = zip(*pairs)
        # Rows become simulants, columns become transitions.
        probabilities = np.transpose(probabilities)
        outputs, probabilities = self._normalize_probabilities(outputs, probabilities)
        return outputs, self.random.choice(index, outputs, probabilities)

    def append(self, transition: Transition) -> None:
        """Add a single transition to the set.

        Raises
        ------
        TypeError
            If ``transition`` is not a `Transition` instance.

        """
        if isinstance(transition, Transition):
            self.transitions.append(transition)
            return
        raise TypeError(
            "TransitionSet must contain only Transition objects. "
            f"Check constructor arguments: {self}"
        )

    def extend(self, transitions: Iterable[Transition]) -> None:
        """Add every transition in ``transitions``, one `append` at a time."""
        for transition in transitions:
            self.append(transition)

    ##################
    # Helper methods #
    ##################

    def _normalize_probabilities(self, outputs, probabilities):
        """Normalize probabilities to sum to 1 and add a null transition.

        Note that ``probabilities`` is modified in place along the way.

        Parameters
        ----------
        outputs
            List of possible end states corresponding to this containers
            transitions.
        probabilities
            A set of probability weights whose columns correspond to the end
            states in `outputs` and whose rows correspond to each simulant
            undergoing the transition.

        Returns
        -------
        List
            The original output list expanded to include a null transition
            (a transition back to the starting state) if requested.
        numpy.ndarray
            The original probabilities rescaled to sum to 1 and potentially
            expanded to include a null transition weight.

        Raises
        ------
        ValueError
            If a simulant has more than one transition with probability 1,
            if a null transition is requested while a row sums to more than
            1, or if no null transition is allowed and a row sums to 0.

        """
        outputs = list(outputs)

        # This is mainly for flexibility with the triggered transitions.
        # We may have multiple out transitions from a state where one of them
        # is gated until some criteria is met.  After the criteria is
        # met, the gated transition becomes the default (likely as opposed
        # to a self transition).
        certain_count = np.sum(probabilities == 1, axis=1)
        if np.any(certain_count > 1):
            raise ValueError("Multiple transitions specified with probability 1.")
        has_default = certain_count == 1
        row_totals = np.sum(probabilities, axis=1)
        probabilities[has_default] /= row_totals[has_default, np.newaxis]

        # All totals should be ~<= 1 at this point.
        row_totals = np.sum(probabilities, axis=1)

        if not self.allow_null_transition:
            if np.any(row_totals == 0):
                raise ValueError("No valid transitions for some simulants.")
            # Totals may be less than one in some places, so rescale.
            probabilities /= row_totals[:, np.newaxis]
            return outputs, probabilities

        if np.any(row_totals > 1 + 1e-08):  # Accommodate rounding errors
            raise ValueError(
                f"Null transition requested with un-normalized probability weights: {probabilities}"
            )
        row_totals[row_totals > 1] = 1  # Correct allowed rounding errors.
        # The null transition takes whatever probability mass is left over.
        leftover = (1 - row_totals)[:, np.newaxis]
        probabilities = np.concatenate([probabilities, leftover], axis=1)
        outputs.append("null_transition")
        return outputs, probabilities

    def __iter__(self):
        return iter(self.transitions)

    def __len__(self):
        return len(self.transitions)

    def __hash__(self):
        # Identity-based hash.
        return hash(id(self))
class Machine(Component):
    """A collection of states and transitions between those states.

    Parameters
    ----------
    state_column
        A label for the piece of simulation state governed by this state
        machine.
    states
        The states to start the machine with.

    Attributes
    ----------
    states
        The collection of states represented by this state machine.
    state_column
        A label for the piece of simulation state governed by this state
        machine.

    """

    ##############
    # Properties #
    ##############

    @property
    def sub_components(self):
        """The states of this machine."""
        return self.states

    @property
    def columns_required(self) -> Optional[List[str]]:
        """A one-element list holding the state column."""
        return [self.state_column]

    #####################
    # Lifecycle methods #
    #####################

    def __init__(self, state_column: str, states: Iterable[State] = ()):
        super().__init__()
        self.states = []
        self.state_column = state_column
        if states:
            self.add_states(states)

    ##################
    # Public methods #
    ##################

    def add_states(self, states: Iterable[State]) -> None:
        """Register each state and tell it which column this machine uses."""
        for new_state in states:
            self.states.append(new_state)
            new_state.set_model(self.state_column)

    def transition(self, index: pd.Index, event_time: "Time") -> None:
        """Finds the population in each state and moves them to the next state.

        Parameters
        ----------
        index
            An iterable of integer labels for the simulants.
        event_time
            The time at which this transition occurs.

        """
        for state, affected in self._get_state_pops(index):
            if affected.empty:
                continue
            state_view = self.population_view.subview([self.state_column])
            state.next_state(affected.index, event_time, state_view)

    def cleanup(self, index: pd.Index, event_time: "Time") -> None:
        """Call `State.cleanup_effect` on every state that currently holds
        simulants from ``index``."""
        for state, affected in self._get_state_pops(index):
            if affected.empty:
                continue
            state.cleanup_effect(affected.index, event_time)

    ##################
    # Helper methods #
    ##################

    def _get_state_pops(self, index: pd.Index) -> List[Tuple[State, pd.DataFrame]]:
        """Pair every state with the rows of the population that are in it.

        The population is read once; a state with no simulants is paired
        with an empty frame.
        """
        population = self.population_view.get(index)
        current_state = population[self.state_column]

        state_pops = []
        for state in self.states:
            in_state = population[current_state == state.state_id]
            state_pops.append((state, in_state))
        return state_pops

    def get_initialization_parameters(self) -> Dict[str, Any]:
        """
        Gets the values of the state column specified in the `__init__`.

        Note: this retrieves the value of the attribute at the time of calling
        which is not guaranteed to be the same as the original value.
        """
        return {"state_column": self.state_column}