diff --git a/lib/crewai/pyproject.toml b/lib/crewai/pyproject.toml
index de26cb784..57942a01e 100644
--- a/lib/crewai/pyproject.toml
+++ b/lib/crewai/pyproject.toml
@@ -43,6 +43,7 @@ dependencies = [
     "uv~=0.9.13",
     "aiosqlite~=0.21.0",
     "pyyaml~=6.0",
+    "aiofiles~=24.1.0",
     "lancedb>=0.29.2,<0.30.1",
 ]
 
diff --git a/lib/crewai/src/crewai/__init__.py b/lib/crewai/src/crewai/__init__.py
index fec6f5f7b..24d9fd085 100644
--- a/lib/crewai/src/crewai/__init__.py
+++ b/lib/crewai/src/crewai/__init__.py
@@ -162,7 +162,7 @@ try:
         **sys.modules[_BaseAgent.__module__].__dict__,
     }
 
-    import crewai.runtime_state as _runtime_state_mod
+    import crewai.state.runtime as _runtime_state_mod
 
     for _mod_name in (
         _BaseAgent.__module__,
@@ -193,7 +193,7 @@ try:
 
     from pydantic import Discriminator, Tag
 
-    from crewai.runtime_state import RuntimeState, _entity_discriminator
+    from crewai.state.runtime import RuntimeState, _entity_discriminator
 
     Entity = Annotated[
         Annotated[Flow, Tag("flow")]  # type: ignore[type-arg]
@@ -226,6 +226,7 @@ __all__ = [
     "BaseLLM",
     "Crew",
     "CrewOutput",
+    "Entity",
     "ExecutionContext",
     "Flow",
     "Knowledge",
diff --git a/lib/crewai/src/crewai/context.py b/lib/crewai/src/crewai/context.py
index e6efe4349..10184ff39 100644
--- a/lib/crewai/src/crewai/context.py
+++ b/lib/crewai/src/crewai/context.py
@@ -90,7 +90,7 @@ class ExecutionContext(BaseModel):
     flow_id: str | None = Field(default=None)
     flow_method_name: str = Field(default="unknown")
 
-    event_id_stack: tuple[tuple[str, str], ...] = Field(default=())
+    event_id_stack: tuple[tuple[str, str], ...] = Field(default_factory=tuple)
     last_event_id: str | None = Field(default=None)
     triggering_event_id: str | None = Field(default=None)
     emission_sequence: int = Field(default=0)
diff --git a/lib/crewai/src/crewai/events/event_bus.py b/lib/crewai/src/crewai/events/event_bus.py
index 045b5bd60..01aa1d9a6 100644
--- a/lib/crewai/src/crewai/events/event_bus.py
+++ b/lib/crewai/src/crewai/events/event_bus.py
@@ -21,7 +21,7 @@ from typing_extensions import Self
 
 
 if TYPE_CHECKING:
-    from crewai.runtime_state import RuntimeState
+    from crewai.state.runtime import RuntimeState
 
 from crewai.events.base_events import BaseEvent, get_next_emission_sequence
 from crewai.events.depends import Depends
diff --git a/lib/crewai/src/crewai/state/__init__.py b/lib/crewai/src/crewai/state/__init__.py
new file mode 100644
index 000000000..e69de29bb
diff --git a/lib/crewai/src/crewai/state/provider/__init__.py b/lib/crewai/src/crewai/state/provider/__init__.py
new file mode 100644
index 000000000..e69de29bb
diff --git a/lib/crewai/src/crewai/state/provider/core.py b/lib/crewai/src/crewai/state/provider/core.py
new file mode 100644
index 000000000..a3f7c9c5a
--- /dev/null
+++ b/lib/crewai/src/crewai/state/provider/core.py
@@ -0,0 +1,59 @@
+"""Base protocol for state providers."""
+
+from __future__ import annotations
+
+from typing import Any, Protocol, runtime_checkable
+
+from pydantic import GetCoreSchemaHandler
+from pydantic_core import CoreSchema, core_schema
+
+
+@runtime_checkable
+class BaseProvider(Protocol):
+    """Interface for persisting and restoring runtime state checkpoints.
+
+    Implementations handle the storage backend (filesystem, cloud, database,
+    etc.) while ``RuntimeState`` handles serialization.
+    """
+
+    @classmethod
+    def __get_pydantic_core_schema__(
+        cls, source_type: Any, handler: GetCoreSchemaHandler
+    ) -> CoreSchema:
+        """Allow Pydantic to validate any ``BaseProvider`` instance."""
+
+        def _validate(v: Any) -> BaseProvider:
+            if isinstance(v, BaseProvider):
+                return v
+            raise TypeError(f"Expected a BaseProvider instance, got {type(v)}")
+
+        return core_schema.no_info_plain_validator_function(
+            _validate,
+            serialization=core_schema.plain_serializer_function_ser_schema(
+                lambda v: type(v).__name__, info_arg=False
+            ),
+        )
+
+    def checkpoint(self, data: str, directory: str) -> str:
+        """Persist a snapshot synchronously.
+
+        Args:
+            data: The serialized string to persist.
+            directory: Logical destination (path, bucket prefix, etc.).
+
+        Returns:
+            A location identifier for the saved checkpoint (e.g. file path, URI).
+        """
+        ...
+
+    async def acheckpoint(self, data: str, directory: str) -> str:
+        """Persist a snapshot asynchronously.
+
+        Args:
+            data: The serialized string to persist.
+            directory: Logical destination (path, bucket prefix, etc.).
+
+        Returns:
+            A location identifier for the saved checkpoint (e.g. file path, URI).
+        """
+        ...
diff --git a/lib/crewai/src/crewai/state/provider/json_provider.py b/lib/crewai/src/crewai/state/provider/json_provider.py
new file mode 100644
index 000000000..67f770925
--- /dev/null
+++ b/lib/crewai/src/crewai/state/provider/json_provider.py
@@ -0,0 +1,66 @@
+"""Filesystem JSON state provider."""
+
+from __future__ import annotations
+
+from datetime import datetime, timezone
+from pathlib import Path
+import uuid
+
+import aiofiles
+import aiofiles.os
+
+from crewai.state.provider.core import BaseProvider
+
+
+class JsonProvider(BaseProvider):
+    """Persists runtime state checkpoints as JSON files on the local filesystem."""
+
+    def checkpoint(self, data: str, directory: str) -> str:
+        """Write a JSON checkpoint file to the directory.
+
+        Args:
+            data: The serialized JSON string to persist.
+            directory: Filesystem path where the checkpoint will be saved.
+
+        Returns:
+            The path to the written checkpoint file.
+        """
+        file_path = _build_path(directory)
+        file_path.parent.mkdir(parents=True, exist_ok=True)
+
+        # Pin UTF-8: open() otherwise falls back to the locale's encoding.
+        with open(file_path, "w", encoding="utf-8") as f:
+            f.write(data)
+        return str(file_path)
+
+    async def acheckpoint(self, data: str, directory: str) -> str:
+        """Write a JSON checkpoint file to the directory asynchronously.
+
+        Args:
+            data: The serialized JSON string to persist.
+            directory: Filesystem path where the checkpoint will be saved.
+
+        Returns:
+            The path to the written checkpoint file.
+        """
+        file_path = _build_path(directory)
+        await aiofiles.os.makedirs(str(file_path.parent), exist_ok=True)
+
+        # Pin UTF-8 so sync and async checkpoints are encoded identically.
+        async with aiofiles.open(file_path, "w", encoding="utf-8") as f:
+            await f.write(data)
+        return str(file_path)
+
+
+def _build_path(directory: str) -> Path:
+    """Build a timestamped checkpoint file path.
+
+    Args:
+        directory: Parent directory for the checkpoint file.
+
+    Returns:
+        The target file path.
+    """
+    ts = datetime.now(timezone.utc).strftime("%Y%m%dT%H%M%S")
+    filename = f"{ts}_{uuid.uuid4().hex[:8]}.json"
+    return Path(directory) / filename
diff --git a/lib/crewai/src/crewai/runtime_state.py b/lib/crewai/src/crewai/state/runtime.py
similarity index 52%
rename from lib/crewai/src/crewai/runtime_state.py
rename to lib/crewai/src/crewai/state/runtime.py
index 0ceff2b85..784154c82 100644
--- a/lib/crewai/src/crewai/runtime_state.py
+++ b/lib/crewai/src/crewai/state/runtime.py
@@ -9,16 +9,17 @@ via ``RuntimeState.model_rebuild()``.
 
 from __future__ import annotations
 
-from datetime import datetime, timezone
-from pathlib import Path
 from typing import TYPE_CHECKING, Any
-import uuid
 
-from pydantic import RootModel
+from pydantic import PrivateAttr, RootModel
+
+from crewai.context import capture_execution_context
+from crewai.state.provider.core import BaseProvider
+from crewai.state.provider.json_provider import JsonProvider
 
 
 if TYPE_CHECKING:
-    pass
+    from crewai import Entity
 
 
 def _entity_discriminator(v: dict[str, Any] | object) -> str:
@@ -30,7 +31,12 @@ def _entity_discriminator(v: dict[str, Any] | object) -> str:
 
 
 def _sync_checkpoint_fields(entity: object) -> None:
-    """Copy private runtime attrs into checkpoint fields before serializing."""
+    """Copy private runtime attrs into checkpoint fields before serializing.
+
+    Args:
+        entity: The entity whose private runtime attributes will be
+            copied into its public checkpoint fields.
+    """
     from crewai.crew import Crew
     from crewai.flow.flow import Flow
 
@@ -56,21 +62,40 @@ def _sync_checkpoint_fields(entity: object) -> None:
 
 
 class RuntimeState(RootModel):  # type: ignore[type-arg]
-    root: list[Entity]  # type: ignore[name-defined]  # noqa: F821
+    root: list[Entity]
+    _provider: BaseProvider = PrivateAttr(default_factory=JsonProvider)
 
     def checkpoint(self, directory: str) -> str:
-        """Write a checkpoint file to the directory."""
-        from crewai.context import capture_execution_context
+        """Write a checkpoint file to the directory.
 
-        for entity in self.root:
-            entity.execution_context = capture_execution_context()
-            _sync_checkpoint_fields(entity)
+        Args:
+            directory: Filesystem path where the checkpoint JSON will be saved.
 
-        dir_path = Path(directory)
-        dir_path.mkdir(parents=True, exist_ok=True)
+        Returns:
+            A location identifier for the saved checkpoint.
+        """
+        _prepare_entities(self.root)
+        return self._provider.checkpoint(self.model_dump_json(), directory)
 
-        ts = datetime.now(timezone.utc).strftime("%Y%m%dT%H%M%S")
-        filename = f"{ts}_{uuid.uuid4().hex[:8]}.json"
-        file_path = dir_path / filename
-        file_path.write_text(self.model_dump_json())
-        return str(file_path)
+    async def acheckpoint(self, directory: str) -> str:
+        """Async version of :meth:`checkpoint`.
+
+        Args:
+            directory: Filesystem path where the checkpoint JSON will be saved.
+
+        Returns:
+            A location identifier for the saved checkpoint.
+        """
+        _prepare_entities(self.root)
+        return await self._provider.acheckpoint(self.model_dump_json(), directory)
+
+
+def _prepare_entities(root: list[Entity]) -> None:
+    """Capture execution context and sync checkpoint fields on each entity.
+
+    Args:
+        root: List of entities to prepare for serialization.
+    """
+    for entity in root:
+        entity.execution_context = capture_execution_context()
+        _sync_checkpoint_fields(entity)
diff --git a/uv.lock b/uv.lock
index 3179fb67d..8445b9b0d 100644
--- a/uv.lock
+++ b/uv.lock
@@ -1138,6 +1138,7 @@ wheels = [
 name = "crewai"
 source = { editable = "lib/crewai" }
 dependencies = [
+    { name = "aiofiles" },
     { name = "aiosqlite" },
     { name = "appdirs" },
     { name = "chromadb" },
@@ -1234,6 +1235,7 @@ requires-dist = [
     { name = "a2a-sdk", marker = "extra == 'a2a'", specifier = "~=0.3.10" },
     { name = "aiobotocore", marker = "extra == 'aws'", specifier = "~=2.25.2" },
     { name = "aiocache", extras = ["memcached", "redis"], marker = "extra == 'a2a'", specifier = "~=0.12.3" },
+    { name = "aiofiles", specifier = "~=24.1.0" },
     { name = "aiosqlite", specifier = "~=0.21.0" },
     { name = "anthropic", marker = "extra == 'anthropic'", specifier = "~=0.73.0" },
     { name = "appdirs", specifier = "~=1.4.4" },