Source code for simcraft.activities.state_machine

"""
State machine framework for entity lifecycle management.

Provides a flexible state machine implementation for modeling
complex entity behaviors and workflows.
"""

from __future__ import annotations
from dataclasses import dataclass, field
from typing import (
    TYPE_CHECKING,
    Any,
    Callable,
    Dict,
    Generic,
    List,
    Optional,
    Set,
    Tuple,
    TypeVar,
    Union,
)
from enum import Enum, auto

if TYPE_CHECKING:
    from simcraft.core.simulation import Simulation

# Type variable for the context object (the entity) that StateMachine and
# StateMachineInstance are generic over; callbacks, guards and actions
# receive a value of this type.
T = TypeVar("T")


@dataclass
class State:
    """
    A state in a state machine.

    A ``State`` compares equal to another ``State`` with the same name and
    to a plain ``str`` equal to its name. Its hash is the hash of the name,
    so a ``State`` and its name string are interchangeable as set members
    and dict keys.

    Attributes
    ----------
    name : str
        State identifier
    on_enter : Optional[Callable]
        Called with the context object when entering the state
    on_exit : Optional[Callable]
        Called with the context object when exiting the state
    on_stay : Optional[Callable]
        Intended to be called while staying in the state
        (NOTE(review): nothing in this module invokes it -- confirm intended use)
    is_initial : bool
        Whether this is an initial state
    is_final : bool
        Whether this is a final state
    metadata : Dict
        Additional state metadata
    """

    name: str
    on_enter: Optional[Callable[[Any], None]] = None
    on_exit: Optional[Callable[[Any], None]] = None
    on_stay: Optional[Callable[[Any], None]] = None
    is_initial: bool = False
    is_final: bool = False
    metadata: Dict[str, Any] = field(default_factory=dict)

    def __hash__(self) -> int:
        # Hash by name only, so it agrees with __eq__ (which also matches str).
        return hash(self.name)

    def __eq__(self, other: object) -> bool:
        if isinstance(other, State):
            return self.name == other.name
        if isinstance(other, str):
            return self.name == other
        # Unsupported operand: let Python try the reflected comparison and
        # fall back to identity, per the rich-comparison protocol.
        return NotImplemented

    def __str__(self) -> str:
        return self.name
@dataclass
class Transition:
    """
    Directed edge from one named state to another.

    Attributes
    ----------
    source : str
        Name of the state the transition leaves
    target : str
        Name of the state the transition enters
    trigger : str
        Event name that fires the transition (empty string by default)
    guard : Optional[Callable]
        Predicate on the context; when set, its result decides whether
        the transition may fire
    action : Optional[Callable]
        Callable run with the context when the transition fires
    priority : int
        Ordering hint when several transitions match (higher = checked first)
    """

    source: str
    target: str
    trigger: str = ""
    guard: Optional[Callable[[Any], bool]] = None
    action: Optional[Callable[[Any], None]] = None
    priority: int = 0

    def can_fire(self, context: Any) -> bool:
        """Return the guard's result for *context*, or ``True`` when unguarded."""
        return True if self.guard is None else self.guard(context)

    def fire(self, context: Any) -> None:
        """Run the transition's action on *context* when one is configured."""
        run = self.action
        if run is None:
            return
        run(context)
class StateMachine(Generic[T]):
    """
    Flexible state machine for entity lifecycle management.

    A ``StateMachine`` is a reusable definition. It holds the named states,
    the transitions between them (with optional guards, actions and
    priorities) and any timed transitions. Runtime state for one entity
    lives in a ``StateMachineInstance`` created with :meth:`create_instance`.

    Parameters
    ----------
    sim : Simulation
        Parent simulation (instances use its ``now`` and ``schedule``)
    name : str
        State machine name; defaults to ``"StateMachine_<id>"``

    Examples
    --------
    >>> sm = StateMachine(sim, name="OrderProcess")
    >>>
    >>> # Define states
    >>> sm.add_state("created", is_initial=True)
    >>> sm.add_state("processing")
    >>> sm.add_state("shipped")
    >>> sm.add_state("delivered", is_final=True)
    >>>
    >>> # Define transitions
    >>> sm.add_transition("created", "processing", trigger="start")
    >>> sm.add_transition("processing", "shipped",
    ...                   trigger="complete",
    ...                   guard=lambda ctx: ctx.payment_verified)
    >>> sm.add_transition("shipped", "delivered", trigger="arrive")
    >>>
    >>> # Create instance and process
    >>> order = Order(id=1)
    >>> instance = sm.create_instance(order)
    >>> instance.trigger("start")     # -> processing
    >>> instance.trigger("complete")  # -> shipped (if payment verified)
    """

    def __init__(
        self,
        sim: "Simulation",
        name: str = "",
    ) -> None:
        """Initialize an empty state machine definition."""
        self._sim = sim
        self._name = name or f"StateMachine_{id(self)}"
        self._states: Dict[str, State] = {}
        # Outgoing transitions per source state, kept sorted by descending priority.
        self._transitions: Dict[str, List[Transition]] = {}
        self._initial_state: Optional[str] = None
        # Trigger name -> duration (or callable returning one) for timed
        # transitions. Created eagerly so the attribute always exists.
        self._timed_transitions: Dict[str, Union[float, Callable[[T], float]]] = {}

        # Global callbacks
        self._on_state_enter: Optional[Callable[[str, T], None]] = None
        self._on_state_exit: Optional[Callable[[str, T], None]] = None
        self._on_transition: Optional[Callable[[str, str, T], None]] = None

    @property
    def name(self) -> str:
        """Get state machine name."""
        return self._name

    @property
    def states(self) -> List[str]:
        """Get list of state names (in insertion order)."""
        return list(self._states.keys())

    @property
    def initial_state(self) -> Optional[str]:
        """Get initial state name."""
        return self._initial_state

    def add_state(
        self,
        name: str,
        on_enter: Optional[Callable[[T], None]] = None,
        on_exit: Optional[Callable[[T], None]] = None,
        is_initial: bool = False,
        is_final: bool = False,
        **metadata: Any,
    ) -> State:
        """
        Add (or redefine) a state.

        Re-adding an existing name replaces its ``State`` definition but keeps
        the outgoing transitions already registered for it. Passing
        ``is_initial=True`` makes this the initial state, replacing any
        earlier one.

        Parameters
        ----------
        name : str
            State identifier
        on_enter : Optional[Callable]
            Called with the context when entering the state
        on_exit : Optional[Callable]
            Called with the context when exiting the state
        is_initial : bool
            Whether this is the initial state
        is_final : bool
            Whether this is a final state
        **metadata
            Additional state metadata

        Returns
        -------
        State
            The created state
        """
        state = State(
            name=name,
            on_enter=on_enter,
            on_exit=on_exit,
            is_initial=is_initial,
            is_final=is_final,
            metadata=metadata,
        )
        self._states[name] = state
        # Don't wipe transitions that were already added from this state.
        self._transitions.setdefault(name, [])

        if is_initial:
            self._initial_state = name

        return state

    def add_transition(
        self,
        source: str,
        target: str,
        trigger: str = "",
        guard: Optional[Callable[[T], bool]] = None,
        action: Optional[Callable[[T], None]] = None,
        priority: int = 0,
    ) -> Transition:
        """
        Add a transition between states.

        Parameters
        ----------
        source : str
            Source state name
        target : str
            Target state name
        trigger : str
            Event trigger name (empty for automatic)
        guard : Optional[Callable]
            Condition for transition
        action : Optional[Callable]
            Action to perform during transition
        priority : int
            Priority (higher = checked first; ties keep insertion order)

        Returns
        -------
        Transition
            The created transition

        Raises
        ------
        ValueError
            If ``source`` or ``target`` is not a known state.
        """
        if source not in self._states:
            raise ValueError(f"Unknown source state: {source}")
        if target not in self._states:
            raise ValueError(f"Unknown target state: {target}")

        transition = Transition(
            source=source,
            target=target,
            trigger=trigger,
            guard=guard,
            action=action,
            priority=priority,
        )

        # Keep the list sorted by descending priority. A new transition goes
        # after existing ones of equal priority, so insertion order breaks ties.
        transitions = self._transitions[source]
        insert_idx = next(
            (
                i
                for i, existing in enumerate(transitions)
                if transition.priority > existing.priority
            ),
            len(transitions),
        )
        transitions.insert(insert_idx, transition)

        return transition

    def add_timed_transition(
        self,
        source: str,
        target: str,
        duration: Union[float, Callable[[T], float]],
        action: Optional[Callable[[T], None]] = None,
    ) -> Transition:
        """
        Add a timed transition that fires after a duration.

        The transition gets the synthetic trigger ``"_timeout_<source>_<target>"``.
        Instances schedule that trigger on the simulation when they enter
        ``source``.

        Parameters
        ----------
        source : str
            Source state name
        target : str
            Target state name
        duration : Union[float, Callable]
            Time to wait, or a function of the context returning it
        action : Optional[Callable]
            Action to perform

        Returns
        -------
        Transition
            The created transition
        """
        transition = self.add_transition(
            source=source,
            target=target,
            trigger=f"_timeout_{source}_{target}",
            action=action,
        )

        # Durations are keyed by trigger name.
        # NOTE(review): a second timed transition for the same source/target,
        # or names like ("a_b", "c") vs ("a", "b_c"), produce the same trigger
        # and overwrite each other's duration here.
        self._timed_transitions[transition.trigger] = duration

        return transition

    def get_state(self, name: str) -> Optional[State]:
        """
        Get a state by name.

        Parameters
        ----------
        name : str
            State name

        Returns
        -------
        Optional[State]
            The state or None
        """
        return self._states.get(name)

    def get_transitions_from(self, state: str) -> List[Transition]:
        """
        Get all transitions from a state, highest priority first.

        Note that the machine's internal list is returned (not a copy), so
        callers should not mutate it.

        Parameters
        ----------
        state : str
            Source state name

        Returns
        -------
        List[Transition]
            Transitions from the state (empty for unknown states)
        """
        return self._transitions.get(state, [])

    def create_instance(self, context: T) -> "StateMachineInstance[T]":
        """
        Create a new state machine instance.

        Parameters
        ----------
        context : T
            Context object (entity) for this instance

        Returns
        -------
        StateMachineInstance
            New instance in initial state
        """
        return StateMachineInstance(self, context)

    def on_state_enter(self, callback: Callable[[str, T], None]) -> None:
        """Set global callback for state entry (replaces any previous one)."""
        self._on_state_enter = callback

    def on_state_exit(self, callback: Callable[[str, T], None]) -> None:
        """Set global callback for state exit (replaces any previous one)."""
        self._on_state_exit = callback

    def on_transition(self, callback: Callable[[str, str, T], None]) -> None:
        """Set global callback for transitions (replaces any previous one)."""
        self._on_transition = callback

    def __repr__(self) -> str:
        """Return detailed representation."""
        return (
            f"StateMachine(name={self._name!r}, "
            f"states={len(self._states)}, "
            f"initial={self._initial_state})"
        )
class StateMachineInstance(Generic[T]):
    """
    Instance of a state machine for a specific entity.

    Tracks the current state of one context object and fires transitions in
    response to events. It records transition history and schedules or
    cancels the timeouts of timed transitions through the parent simulation.

    Parameters
    ----------
    machine : StateMachine
        Parent state machine definition
    context : T
        Context object (entity)
    """

    def __init__(
        self,
        machine: StateMachine[T],
        context: T,
    ) -> None:
        """Initialize instance and enter the machine's initial state, if any."""
        self._machine = machine
        self._context = context
        self._current_state: Optional[str] = None
        self._history: List[Tuple[float, str, str]] = []
        # Trigger name -> scheduled timeout event for the current state.
        self._pending_timeouts: Dict[str, Any] = {}
        # Incremented on every state entry. Lets _enter_state detect that an
        # entry callback has already moved the instance somewhere else.
        self._entry_count = 0

        # Enter initial state
        if machine.initial_state:
            self._enter_state(machine.initial_state)

    @property
    def current_state(self) -> Optional[str]:
        """Get current state name."""
        return self._current_state

    @property
    def context(self) -> T:
        """Get context object."""
        return self._context

    @property
    def history(self) -> List[Tuple[float, str, str]]:
        """Get a copy of the transition history (time, from_state, to_state)."""
        return self._history.copy()

    @property
    def is_in_final_state(self) -> bool:
        """Check if in a final state."""
        if self._current_state is None:
            return False
        state = self._machine.get_state(self._current_state)
        return state is not None and state.is_final

    def _find_transition(self, event: str) -> Optional[Transition]:
        """Return the first enabled transition for *event* from the current state.

        Transitions are checked in the machine's order (highest priority
        first). Returns None when there is no current state or no transition
        matches with a passing guard.
        """
        if self._current_state is None:
            return None
        for transition in self._machine.get_transitions_from(self._current_state):
            if transition.trigger == event and transition.can_fire(self._context):
                return transition
        return None

    def trigger(self, event: str) -> bool:
        """
        Trigger an event to potentially cause a transition.

        Parameters
        ----------
        event : str
            Event name

        Returns
        -------
        bool
            True if a transition occurred
        """
        transition = self._find_transition(event)
        if transition is None:
            return False
        self._execute_transition(transition)
        return True

    def can_trigger(self, event: str) -> bool:
        """
        Check if an event can cause a transition (guards are evaluated).

        Parameters
        ----------
        event : str
            Event name

        Returns
        -------
        bool
            True if a transition would occur
        """
        return self._find_transition(event) is not None

    def force_state(self, state: str) -> None:
        """
        Force transition to a state (bypassing guards).

        The current state is exited and ``state`` entered. No transition
        action runs, no history entry is recorded and the global
        on_transition callback is not called.

        Parameters
        ----------
        state : str
            Target state name

        Raises
        ------
        ValueError
            If ``state`` is not defined in the machine.
        """
        if self._machine.get_state(state) is None:
            raise ValueError(f"Unknown state: {state}")

        self._exit_current_state()
        self._enter_state(state)

    def _execute_transition(self, transition: Transition) -> None:
        """Execute a transition: exit, action, callbacks, history, enter."""
        from_state = self._current_state
        to_state = transition.target

        # Exit current state
        self._exit_current_state()

        # Execute transition action
        transition.fire(self._context)

        # Global transition callback
        if self._machine._on_transition:
            self._machine._on_transition(from_state, to_state, self._context)

        # Record history
        self._history.append((self._machine._sim.now, from_state, to_state))

        # Enter new state
        self._enter_state(to_state)

    def _enter_state(self, state_name: str) -> None:
        """Make *state_name* current, run entry callbacks, schedule its timeouts."""
        self._entry_count += 1
        entry = self._entry_count
        self._current_state = state_name
        state = self._machine.get_state(state_name)

        if state:
            # State entry callback
            if state.on_enter:
                state.on_enter(self._context)

            # Global entry callback
            if self._machine._on_state_enter:
                self._machine._on_state_enter(state_name, self._context)

        # If an entry callback already triggered or forced another state, that
        # newer entry has scheduled its own timeouts. Scheduling this state's
        # timeouts now would target a state that is no longer current, and
        # could overwrite (and so leak) the newer entry's pending events.
        if self._entry_count == entry:
            self._schedule_timed_transitions(state_name)

    def _exit_current_state(self) -> None:
        """Cancel pending timeouts and run exit callbacks for the current state."""
        if self._current_state is None:
            return

        state = self._machine.get_state(self._current_state)

        # Cancel pending timeouts
        for event in list(self._pending_timeouts.keys()):
            scheduled_event = self._pending_timeouts.pop(event)
            if scheduled_event:
                scheduled_event.cancel()

        if state:
            # State exit callback
            if state.on_exit:
                state.on_exit(self._context)

            # Global exit callback
            if self._machine._on_state_exit:
                self._machine._on_state_exit(self._current_state, self._context)

    def _schedule_timed_transitions(self, state_name: str) -> None:
        """Schedule a timeout trigger for each timed transition leaving the state."""
        timed = getattr(self._machine, "_timed_transitions", None)
        if not timed:
            return

        for transition in self._machine.get_transitions_from(state_name):
            if transition.trigger not in timed:
                continue
            duration = timed[transition.trigger]
            if callable(duration):
                # Per-instance duration computed from the context.
                duration = duration(self._context)

            event = self._machine._sim.schedule(
                self.trigger,
                delay=duration,
                args=(transition.trigger,),
            )
            self._pending_timeouts[transition.trigger] = event

    def __repr__(self) -> str:
        """Return detailed representation."""
        return (
            f"StateMachineInstance("
            f"machine={self._machine.name!r}, "
            f"state={self._current_state}, "
            f"context={self._context})"
        )
# Common state machine patterns


def create_simple_workflow(
    sim: "Simulation",
    states: List[str],
    name: str = "Workflow",
) -> StateMachine:
    """
    Build a linear workflow in which each state advances on ``"next"``.

    The first listed state is marked initial and the last one final. Every
    consecutive pair is linked by a transition triggered by ``"next"``.

    Parameters
    ----------
    sim : Simulation
        Parent simulation
    states : List[str]
        State names in workflow order
    name : str
        Machine name

    Returns
    -------
    StateMachine
        Configured state machine

    Examples
    --------
    >>> sm = create_simple_workflow(sim, ["start", "step1", "step2", "end"])
    """
    sm: StateMachine = StateMachine(sim, name)

    last_index = len(states) - 1
    for index, state_name in enumerate(states):
        sm.add_state(
            state_name,
            is_initial=index == 0,
            is_final=index == last_index,
        )

    for current, following in zip(states, states[1:]):
        sm.add_transition(current, following, trigger="next")

    return sm


def create_processing_workflow(
    sim: "Simulation",
    processing_time: Union[float, Callable[..., float]],
    name: str = "ProcessingWorkflow",
) -> StateMachine:
    """
    Build a three-state workflow: waiting -> processing -> completed.

    ``"start"`` moves waiting to processing. Processing then moves to
    completed automatically after ``processing_time`` via a timed transition.

    Parameters
    ----------
    sim : Simulation
        Parent simulation
    processing_time : Union[float, Callable]
        Processing duration, or a function of the context returning it
    name : str
        Machine name

    Returns
    -------
    StateMachine
        Configured state machine
    """
    sm: StateMachine = StateMachine(sim, name)

    for state_name, initial, final in (
        ("waiting", True, False),
        ("processing", False, False),
        ("completed", False, True),
    ):
        sm.add_state(state_name, is_initial=initial, is_final=final)

    sm.add_transition("waiting", "processing", trigger="start")
    sm.add_timed_transition("processing", "completed", duration=processing_time)

    return sm