From 6627845372dcd07a610b2b1dfab48e0a6bb6d9e9 Mon Sep 17 00:00:00 2001 From: Greyson LaLonde Date: Fri, 3 Apr 2026 04:19:06 +0800 Subject: [PATCH] feat: pass RuntimeState through event bus, add .checkpoint(directory) --- lib/crewai/src/crewai/__init__.py | 29 ++++++++++++++++++- lib/crewai/src/crewai/events/event_bus.py | 21 ++++++++++++-- .../crewai/events/types/event_bus_types.py | 13 +++++++-- .../src/crewai/events/utils/handlers.py | 8 ++++- 4 files changed, 64 insertions(+), 7 deletions(-) diff --git a/lib/crewai/src/crewai/__init__.py b/lib/crewai/src/crewai/__init__.py index a5344c8eb..c05e62db2 100644 --- a/lib/crewai/src/crewai/__init__.py +++ b/lib/crewai/src/crewai/__init__.py @@ -184,7 +184,34 @@ try: | Annotated[Agent, Tag("agent")], Discriminator(_entity_discriminator), ] - RuntimeState = RootModel[list[Entity]] + + class RuntimeState(RootModel[list[Entity]]): + def checkpoint(self, directory: str) -> str: + """Write a checkpoint file to the directory. + + Args: + directory: Directory to write checkpoint files into. + + Returns: + The path of the written file. + """ + from datetime import datetime, timezone + from pathlib import Path as _Path + import uuid + + from crewai.context import capture_execution_context + + for entity in self.root: + entity.execution_context = capture_execution_context() + + dir_path = _Path(directory) + dir_path.mkdir(parents=True, exist_ok=True) + + ts = datetime.now(timezone.utc).strftime("%Y%m%dT%H%M%S") + filename = f"{ts}_{uuid.uuid4().hex[:8]}.json" + file_path = dir_path / filename + file_path.write_text(self.model_dump_json()) + return str(file_path) try: Agent.model_rebuild(force=True, _types_namespace=_full_namespace) diff --git a/lib/crewai/src/crewai/events/event_bus.py b/lib/crewai/src/crewai/events/event_bus.py index eefe1ad88..9675d9b4f 100644 --- a/lib/crewai/src/crewai/events/event_bus.py +++ b/lib/crewai/src/crewai/events/event_bus.py @@ -11,6 +11,7 @@ from collections.abc import Callable, Generator from concurrent.futures import Future, ThreadPoolExecutor from contextlib import contextmanager import contextvars +import inspect import threading from typing import Any, Final, ParamSpec, TypeVar @@ -87,6 +88,7 @@ class CrewAIEventsBus: _futures_lock: threading.Lock _executor_initialized: bool _has_pending_events: bool + _runtime_state: Any def __new__(cls) -> Self: """Create or return the singleton instance. @@ -122,6 +124,7 @@ class CrewAIEventsBus: # Lazy initialization flags - executor and loop created on first emit self._executor_initialized = False self._has_pending_events = False + self._runtime_state: Any = None def _ensure_executor_initialized(self) -> None: """Lazily initialize the thread pool executor and event loop. @@ -248,6 +251,10 @@ class CrewAIEventsBus: return decorator + def set_runtime_state(self, state: Any) -> None: + """Set the RuntimeState that will be passed to event handlers.""" + self._runtime_state = state + def off( self, event_type: type[BaseEvent], @@ -294,10 +301,12 @@ class CrewAIEventsBus: event: The event instance handlers: Frozenset of sync handlers to call """ + state = self._runtime_state errors: list[tuple[SyncHandler, Exception]] = [ (handler, error) for handler in handlers - if (error := is_call_handler_safe(handler, source, event)) is not None + if (error := is_call_handler_safe(handler, source, event, state)) + is not None ] if errors: @@ -319,7 +328,15 @@ class CrewAIEventsBus: event: The event instance handlers: Frozenset of async handlers to call """ - coros = [handler(source, event) for handler in handlers] + state = self._runtime_state + + async def _call(handler: AsyncHandler) -> Any: + sig = inspect.signature(handler) + if len(sig.parameters) >= 3: + return await handler(source, event, state) # type: ignore[call-arg] + return await handler(source, event) # type: ignore[call-arg] + + coros = [_call(handler) for handler in handlers] results = await asyncio.gather(*coros, return_exceptions=True) for handler, result in zip(handlers, results, strict=False): if isinstance(result, Exception): diff --git a/lib/crewai/src/crewai/events/types/event_bus_types.py b/lib/crewai/src/crewai/events/types/event_bus_types.py index 8a650a731..677f6ce93 100644 --- a/lib/crewai/src/crewai/events/types/event_bus_types.py +++ b/lib/crewai/src/crewai/events/types/event_bus_types.py @@ -6,10 +6,17 @@ from typing import Any, TypeAlias from crewai.events.base_events import BaseEvent -SyncHandler: TypeAlias = Callable[[Any, BaseEvent], None] -AsyncHandler: TypeAlias = Callable[[Any, BaseEvent], Coroutine[Any, Any, None]] +SyncHandler: TypeAlias = ( + Callable[[Any, BaseEvent], None] | Callable[[Any, BaseEvent, Any], None] +) +AsyncHandler: TypeAlias = ( + Callable[[Any, BaseEvent], Coroutine[Any, Any, None]] + | Callable[[Any, BaseEvent, Any], Coroutine[Any, Any, None]] +) SyncHandlerSet: TypeAlias = frozenset[SyncHandler] AsyncHandlerSet: TypeAlias = frozenset[AsyncHandler] -Handler: TypeAlias = Callable[[Any, BaseEvent], Any] +Handler: TypeAlias = ( + Callable[[Any, BaseEvent], Any] | Callable[[Any, BaseEvent, Any], Any] +) ExecutionPlan: TypeAlias = list[set[Handler]] diff --git a/lib/crewai/src/crewai/events/utils/handlers.py b/lib/crewai/src/crewai/events/utils/handlers.py index bc3e76eee..4c7dc31a3 100644 --- a/lib/crewai/src/crewai/events/utils/handlers.py +++ b/lib/crewai/src/crewai/events/utils/handlers.py @@ -41,6 +41,7 @@ def is_call_handler_safe( handler: SyncHandler, source: Any, event: BaseEvent, + state: Any = None, ) -> Exception | None: """Safely call a single handler and return any exception. @@ -48,12 +49,17 @@ def is_call_handler_safe( handler: The handler function to call source: The object that emitted the event event: The event instance + state: Optional RuntimeState passed as third arg if handler accepts it Returns: Exception if handler raised one, None otherwise """ try: - handler(source, event) + sig = inspect.signature(handler) + if len(sig.parameters) >= 3: + handler(source, event, state) # type: ignore[call-arg] + else: + handler(source, event) # type: ignore[call-arg] return None except Exception as e: return e