feat: pass RuntimeState through event bus, add .checkpoint(directory)

This commit is contained in:
Greyson LaLonde
2026-04-03 04:19:06 +08:00
parent 804c26bd01
commit 6627845372
4 changed files with 64 additions and 7 deletions

View File

@@ -184,7 +184,34 @@ try:
| Annotated[Agent, Tag("agent")],
Discriminator(_entity_discriminator),
]
RuntimeState = RootModel[list[Entity]]
class RuntimeState(RootModel[list[Entity]]):
def checkpoint(self, directory: str) -> str:
"""Write a checkpoint file to the directory.
Args:
directory: Directory to write checkpoint files into.
Returns:
The path of the written file.
"""
from datetime import datetime, timezone
from pathlib import Path as _Path
import uuid
from crewai.context import capture_execution_context
for entity in self.root:
entity.execution_context = capture_execution_context()
dir_path = _Path(directory)
dir_path.mkdir(parents=True, exist_ok=True)
ts = datetime.now(timezone.utc).strftime("%Y%m%dT%H%M%S")
filename = f"{ts}_{uuid.uuid4().hex[:8]}.json"
file_path = dir_path / filename
file_path.write_text(self.model_dump_json())
return str(file_path)
try:
Agent.model_rebuild(force=True, _types_namespace=_full_namespace)

View File

@@ -11,6 +11,7 @@ from collections.abc import Callable, Generator
from concurrent.futures import Future, ThreadPoolExecutor
from contextlib import contextmanager
import contextvars
import inspect
import threading
from typing import Any, Final, ParamSpec, TypeVar
@@ -87,6 +88,7 @@ class CrewAIEventsBus:
_futures_lock: threading.Lock
_executor_initialized: bool
_has_pending_events: bool
_runtime_state: Any
def __new__(cls) -> Self:
"""Create or return the singleton instance.
@@ -122,6 +124,7 @@ class CrewAIEventsBus:
# Lazy initialization flags - executor and loop created on first emit
self._executor_initialized = False
self._has_pending_events = False
self._runtime_state: Any = None
def _ensure_executor_initialized(self) -> None:
"""Lazily initialize the thread pool executor and event loop.
@@ -248,6 +251,10 @@ class CrewAIEventsBus:
return decorator
def set_runtime_state(self, state: Any) -> None:
"""Set the RuntimeState that will be passed to event handlers."""
self._runtime_state = state
def off(
self,
event_type: type[BaseEvent],
@@ -294,10 +301,12 @@ class CrewAIEventsBus:
event: The event instance
handlers: Frozenset of sync handlers to call
"""
state = self._runtime_state
errors: list[tuple[SyncHandler, Exception]] = [
(handler, error)
for handler in handlers
if (error := is_call_handler_safe(handler, source, event)) is not None
if (error := is_call_handler_safe(handler, source, event, state))
is not None
]
if errors:
@@ -319,7 +328,15 @@ class CrewAIEventsBus:
event: The event instance
handlers: Frozenset of async handlers to call
"""
coros = [handler(source, event) for handler in handlers]
state = self._runtime_state
async def _call(handler: AsyncHandler) -> Any:
sig = inspect.signature(handler)
if len(sig.parameters) >= 3:
return await handler(source, event, state) # type: ignore[call-arg]
return await handler(source, event) # type: ignore[call-arg]
coros = [_call(handler) for handler in handlers]
results = await asyncio.gather(*coros, return_exceptions=True)
for handler, result in zip(handlers, results, strict=False):
if isinstance(result, Exception):

View File

@@ -6,10 +6,17 @@ from typing import Any, TypeAlias
from crewai.events.base_events import BaseEvent
SyncHandler: TypeAlias = Callable[[Any, BaseEvent], None]
AsyncHandler: TypeAlias = Callable[[Any, BaseEvent], Coroutine[Any, Any, None]]
SyncHandler: TypeAlias = (
Callable[[Any, BaseEvent], None] | Callable[[Any, BaseEvent, Any], None]
)
AsyncHandler: TypeAlias = (
Callable[[Any, BaseEvent], Coroutine[Any, Any, None]]
| Callable[[Any, BaseEvent, Any], Coroutine[Any, Any, None]]
)
SyncHandlerSet: TypeAlias = frozenset[SyncHandler]
AsyncHandlerSet: TypeAlias = frozenset[AsyncHandler]
Handler: TypeAlias = Callable[[Any, BaseEvent], Any]
Handler: TypeAlias = (
Callable[[Any, BaseEvent], Any] | Callable[[Any, BaseEvent, Any], Any]
)
ExecutionPlan: TypeAlias = list[set[Handler]]

View File

@@ -41,6 +41,7 @@ def is_call_handler_safe(
handler: SyncHandler,
source: Any,
event: BaseEvent,
state: Any = None,
) -> Exception | None:
"""Safely call a single handler and return any exception.
@@ -48,12 +49,17 @@ def is_call_handler_safe(
handler: The handler function to call
source: The object that emitted the event
event: The event instance
state: Optional RuntimeState passed as third arg if handler accepts it
Returns:
Exception if handler raised one, None otherwise
"""
try:
handler(source, event)
sig = inspect.signature(handler)
if len(sig.parameters) >= 3:
handler(source, event, state) # type: ignore[call-arg]
else:
handler(source, event) # type: ignore[call-arg]
return None
except Exception as e:
return e