mirror of
https://github.com/crewAIInc/crewAI.git
synced 2026-07-02 13:48:09 +00:00
feat: pass RuntimeState through event bus, add .checkpoint(directory)
This commit is contained in:
@@ -184,7 +184,34 @@ try:
|
||||
| Annotated[Agent, Tag("agent")],
|
||||
Discriminator(_entity_discriminator),
|
||||
]
|
||||
RuntimeState = RootModel[list[Entity]]
|
||||
|
||||
class RuntimeState(RootModel[list[Entity]]):
|
||||
def checkpoint(self, directory: str) -> str:
|
||||
"""Write a checkpoint file to the directory.
|
||||
|
||||
Args:
|
||||
directory: Directory to write checkpoint files into.
|
||||
|
||||
Returns:
|
||||
The path of the written file.
|
||||
"""
|
||||
from datetime import datetime, timezone
|
||||
from pathlib import Path as _Path
|
||||
import uuid
|
||||
|
||||
from crewai.context import capture_execution_context
|
||||
|
||||
for entity in self.root:
|
||||
entity.execution_context = capture_execution_context()
|
||||
|
||||
dir_path = _Path(directory)
|
||||
dir_path.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
ts = datetime.now(timezone.utc).strftime("%Y%m%dT%H%M%S")
|
||||
filename = f"{ts}_{uuid.uuid4().hex[:8]}.json"
|
||||
file_path = dir_path / filename
|
||||
file_path.write_text(self.model_dump_json())
|
||||
return str(file_path)
|
||||
|
||||
try:
|
||||
Agent.model_rebuild(force=True, _types_namespace=_full_namespace)
|
||||
|
||||
@@ -11,6 +11,7 @@ from collections.abc import Callable, Generator
|
||||
from concurrent.futures import Future, ThreadPoolExecutor
|
||||
from contextlib import contextmanager
|
||||
import contextvars
|
||||
import inspect
|
||||
import threading
|
||||
from typing import Any, Final, ParamSpec, TypeVar
|
||||
|
||||
@@ -87,6 +88,7 @@ class CrewAIEventsBus:
|
||||
_futures_lock: threading.Lock
|
||||
_executor_initialized: bool
|
||||
_has_pending_events: bool
|
||||
_runtime_state: Any
|
||||
|
||||
def __new__(cls) -> Self:
|
||||
"""Create or return the singleton instance.
|
||||
@@ -122,6 +124,7 @@ class CrewAIEventsBus:
|
||||
# Lazy initialization flags - executor and loop created on first emit
|
||||
self._executor_initialized = False
|
||||
self._has_pending_events = False
|
||||
self._runtime_state: Any = None
|
||||
|
||||
def _ensure_executor_initialized(self) -> None:
|
||||
"""Lazily initialize the thread pool executor and event loop.
|
||||
@@ -248,6 +251,10 @@ class CrewAIEventsBus:
|
||||
|
||||
return decorator
|
||||
|
||||
def set_runtime_state(self, state: Any) -> None:
|
||||
"""Set the RuntimeState that will be passed to event handlers."""
|
||||
self._runtime_state = state
|
||||
|
||||
def off(
|
||||
self,
|
||||
event_type: type[BaseEvent],
|
||||
@@ -294,10 +301,12 @@ class CrewAIEventsBus:
|
||||
event: The event instance
|
||||
handlers: Frozenset of sync handlers to call
|
||||
"""
|
||||
state = self._runtime_state
|
||||
errors: list[tuple[SyncHandler, Exception]] = [
|
||||
(handler, error)
|
||||
for handler in handlers
|
||||
if (error := is_call_handler_safe(handler, source, event)) is not None
|
||||
if (error := is_call_handler_safe(handler, source, event, state))
|
||||
is not None
|
||||
]
|
||||
|
||||
if errors:
|
||||
@@ -319,7 +328,15 @@ class CrewAIEventsBus:
|
||||
event: The event instance
|
||||
handlers: Frozenset of async handlers to call
|
||||
"""
|
||||
coros = [handler(source, event) for handler in handlers]
|
||||
state = self._runtime_state
|
||||
|
||||
async def _call(handler: AsyncHandler) -> Any:
|
||||
sig = inspect.signature(handler)
|
||||
if len(sig.parameters) >= 3:
|
||||
return await handler(source, event, state) # type: ignore[call-arg]
|
||||
return await handler(source, event) # type: ignore[call-arg]
|
||||
|
||||
coros = [_call(handler) for handler in handlers]
|
||||
results = await asyncio.gather(*coros, return_exceptions=True)
|
||||
for handler, result in zip(handlers, results, strict=False):
|
||||
if isinstance(result, Exception):
|
||||
|
||||
@@ -6,10 +6,17 @@ from typing import Any, TypeAlias
|
||||
from crewai.events.base_events import BaseEvent
|
||||
|
||||
|
||||
SyncHandler: TypeAlias = Callable[[Any, BaseEvent], None]
|
||||
AsyncHandler: TypeAlias = Callable[[Any, BaseEvent], Coroutine[Any, Any, None]]
|
||||
SyncHandler: TypeAlias = (
|
||||
Callable[[Any, BaseEvent], None] | Callable[[Any, BaseEvent, Any], None]
|
||||
)
|
||||
AsyncHandler: TypeAlias = (
|
||||
Callable[[Any, BaseEvent], Coroutine[Any, Any, None]]
|
||||
| Callable[[Any, BaseEvent, Any], Coroutine[Any, Any, None]]
|
||||
)
|
||||
SyncHandlerSet: TypeAlias = frozenset[SyncHandler]
|
||||
AsyncHandlerSet: TypeAlias = frozenset[AsyncHandler]
|
||||
|
||||
Handler: TypeAlias = Callable[[Any, BaseEvent], Any]
|
||||
Handler: TypeAlias = (
|
||||
Callable[[Any, BaseEvent], Any] | Callable[[Any, BaseEvent, Any], Any]
|
||||
)
|
||||
ExecutionPlan: TypeAlias = list[set[Handler]]
|
||||
|
||||
@@ -41,6 +41,7 @@ def is_call_handler_safe(
|
||||
handler: SyncHandler,
|
||||
source: Any,
|
||||
event: BaseEvent,
|
||||
state: Any = None,
|
||||
) -> Exception | None:
|
||||
"""Safely call a single handler and return any exception.
|
||||
|
||||
@@ -48,12 +49,17 @@ def is_call_handler_safe(
|
||||
handler: The handler function to call
|
||||
source: The object that emitted the event
|
||||
event: The event instance
|
||||
state: Optional RuntimeState passed as third arg if handler accepts it
|
||||
|
||||
Returns:
|
||||
Exception if handler raised one, None otherwise
|
||||
"""
|
||||
try:
|
||||
handler(source, event)
|
||||
sig = inspect.signature(handler)
|
||||
if len(sig.parameters) >= 3:
|
||||
handler(source, event, state) # type: ignore[call-arg]
|
||||
else:
|
||||
handler(source, event) # type: ignore[call-arg]
|
||||
return None
|
||||
except Exception as e:
|
||||
return e
|
||||
|
||||
Reference in New Issue
Block a user