mirror of
https://github.com/crewAIInc/crewAI.git
synced 2026-07-01 21:28:10 +00:00
refactor: move RuntimeState to state/, add async checkpoint with provider pattern
- Move runtime_state.py to state/runtime.py - Add acheckpoint async method using aiofiles - Introduce BaseProvider protocol and JsonProvider for pluggable storage - Add aiofiles dependency to crewai package - Use PrivateAttr for provider on RootModel
This commit is contained in:
@@ -43,6 +43,7 @@ dependencies = [
|
||||
"uv~=0.9.13",
|
||||
"aiosqlite~=0.21.0",
|
||||
"pyyaml~=6.0",
|
||||
"aiofiles~=24.1.0",
|
||||
"lancedb>=0.29.2,<0.30.1",
|
||||
]
|
||||
|
||||
|
||||
@@ -162,7 +162,7 @@ try:
|
||||
**sys.modules[_BaseAgent.__module__].__dict__,
|
||||
}
|
||||
|
||||
import crewai.runtime_state as _runtime_state_mod
|
||||
import crewai.state.runtime as _runtime_state_mod
|
||||
|
||||
for _mod_name in (
|
||||
_BaseAgent.__module__,
|
||||
@@ -193,7 +193,7 @@ try:
|
||||
|
||||
from pydantic import Discriminator, Tag
|
||||
|
||||
from crewai.runtime_state import RuntimeState, _entity_discriminator
|
||||
from crewai.state.runtime import RuntimeState, _entity_discriminator
|
||||
|
||||
Entity = Annotated[
|
||||
Annotated[Flow, Tag("flow")] # type: ignore[type-arg]
|
||||
@@ -226,6 +226,7 @@ __all__ = [
|
||||
"BaseLLM",
|
||||
"Crew",
|
||||
"CrewOutput",
|
||||
"Entity",
|
||||
"ExecutionContext",
|
||||
"Flow",
|
||||
"Knowledge",
|
||||
|
||||
@@ -90,7 +90,7 @@ class ExecutionContext(BaseModel):
|
||||
flow_id: str | None = Field(default=None)
|
||||
flow_method_name: str = Field(default="unknown")
|
||||
|
||||
event_id_stack: tuple[tuple[str, str], ...] = Field(default=())
|
||||
event_id_stack: tuple[tuple[str, str], ...] = Field(default_factory=tuple)
|
||||
last_event_id: str | None = Field(default=None)
|
||||
triggering_event_id: str | None = Field(default=None)
|
||||
emission_sequence: int = Field(default=0)
|
||||
|
||||
@@ -21,7 +21,7 @@ from typing_extensions import Self
|
||||
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from crewai.runtime_state import RuntimeState
|
||||
from crewai.state.runtime import RuntimeState
|
||||
|
||||
from crewai.events.base_events import BaseEvent, get_next_emission_sequence
|
||||
from crewai.events.depends import Depends
|
||||
|
||||
0
lib/crewai/src/crewai/state/__init__.py
Normal file
0
lib/crewai/src/crewai/state/__init__.py
Normal file
0
lib/crewai/src/crewai/state/provider/__init__.py
Normal file
0
lib/crewai/src/crewai/state/provider/__init__.py
Normal file
59
lib/crewai/src/crewai/state/provider/core.py
Normal file
59
lib/crewai/src/crewai/state/provider/core.py
Normal file
@@ -0,0 +1,59 @@
|
||||
"""Base protocol for state providers."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Any, Protocol, runtime_checkable
|
||||
|
||||
from pydantic import GetCoreSchemaHandler
|
||||
from pydantic_core import CoreSchema, core_schema
|
||||
|
||||
|
||||
@runtime_checkable
|
||||
class BaseProvider(Protocol):
|
||||
"""Interface for persisting and restoring runtime state checkpoints.
|
||||
|
||||
Implementations handle the storage backend (filesystem, cloud, database,
|
||||
etc.) while ``RuntimeState`` handles serialization.
|
||||
"""
|
||||
|
||||
@classmethod
|
||||
def __get_pydantic_core_schema__(
|
||||
cls, source_type: Any, handler: GetCoreSchemaHandler
|
||||
) -> CoreSchema:
|
||||
"""Allow Pydantic to validate any ``BaseProvider`` instance."""
|
||||
|
||||
def _validate(v: Any) -> BaseProvider:
|
||||
if isinstance(v, BaseProvider):
|
||||
return v
|
||||
raise TypeError(f"Expected a BaseProvider instance, got {type(v)}")
|
||||
|
||||
return core_schema.no_info_plain_validator_function(
|
||||
_validate,
|
||||
serialization=core_schema.plain_serializer_function_ser_schema(
|
||||
lambda v: type(v).__name__, info_arg=False
|
||||
),
|
||||
)
|
||||
|
||||
def checkpoint(self, data: str, directory: str) -> str:
|
||||
"""Persist a snapshot synchronously.
|
||||
|
||||
Args:
|
||||
data: The serialized string to persist.
|
||||
directory: Logical destination (path, bucket prefix, etc.).
|
||||
|
||||
Returns:
|
||||
A location identifier for the saved checkpoint (e.g. file path, URI).
|
||||
"""
|
||||
...
|
||||
|
||||
async def acheckpoint(self, data: str, directory: str) -> str:
|
||||
"""Persist a snapshot asynchronously.
|
||||
|
||||
Args:
|
||||
data: The serialized string to persist.
|
||||
directory: Logical destination (path, bucket prefix, etc.).
|
||||
|
||||
Returns:
|
||||
A location identifier for the saved checkpoint (e.g. file path, URI).
|
||||
"""
|
||||
...
|
||||
64
lib/crewai/src/crewai/state/provider/json_provider.py
Normal file
64
lib/crewai/src/crewai/state/provider/json_provider.py
Normal file
@@ -0,0 +1,64 @@
|
||||
"""Filesystem JSON state provider."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from datetime import datetime, timezone
|
||||
from pathlib import Path
|
||||
import uuid
|
||||
|
||||
import aiofiles
|
||||
import aiofiles.os
|
||||
|
||||
from crewai.state.provider.core import BaseProvider
|
||||
|
||||
|
||||
class JsonProvider(BaseProvider):
|
||||
"""Persists runtime state checkpoints as JSON files on the local filesystem."""
|
||||
|
||||
def checkpoint(self, data: str, directory: str) -> str:
|
||||
"""Write a JSON checkpoint file to the directory.
|
||||
|
||||
Args:
|
||||
data: The serialized JSON string to persist.
|
||||
directory: Filesystem path where the checkpoint will be saved.
|
||||
|
||||
Returns:
|
||||
The path to the written checkpoint file.
|
||||
"""
|
||||
file_path = _build_path(directory)
|
||||
file_path.parent.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
with open(file_path, "w") as f:
|
||||
f.write(data)
|
||||
return str(file_path)
|
||||
|
||||
async def acheckpoint(self, data: str, directory: str) -> str:
|
||||
"""Write a JSON checkpoint file to the directory asynchronously.
|
||||
|
||||
Args:
|
||||
data: The serialized JSON string to persist.
|
||||
directory: Filesystem path where the checkpoint will be saved.
|
||||
|
||||
Returns:
|
||||
The path to the written checkpoint file.
|
||||
"""
|
||||
file_path = _build_path(directory)
|
||||
await aiofiles.os.makedirs(str(file_path.parent), exist_ok=True)
|
||||
|
||||
async with aiofiles.open(file_path, "w") as f:
|
||||
await f.write(data)
|
||||
return str(file_path)
|
||||
|
||||
|
||||
def _build_path(directory: str) -> Path:
|
||||
"""Build a timestamped checkpoint file path.
|
||||
|
||||
Args:
|
||||
directory: Parent directory for the checkpoint file.
|
||||
|
||||
Returns:
|
||||
The target file path.
|
||||
"""
|
||||
ts = datetime.now(timezone.utc).strftime("%Y%m%dT%H%M%S")
|
||||
filename = f"{ts}_{uuid.uuid4().hex[:8]}.json"
|
||||
return Path(directory) / filename
|
||||
@@ -9,16 +9,17 @@ via ``RuntimeState.model_rebuild()``.
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from datetime import datetime, timezone
|
||||
from pathlib import Path
|
||||
from typing import TYPE_CHECKING, Any
|
||||
import uuid
|
||||
|
||||
from pydantic import RootModel
|
||||
from pydantic import PrivateAttr, RootModel
|
||||
|
||||
from crewai.context import capture_execution_context
|
||||
from crewai.state.provider.core import BaseProvider
|
||||
from crewai.state.provider.json_provider import JsonProvider
|
||||
|
||||
|
||||
if TYPE_CHECKING:
|
||||
pass
|
||||
from crewai import Entity
|
||||
|
||||
|
||||
def _entity_discriminator(v: dict[str, Any] | object) -> str:
|
||||
@@ -30,7 +31,12 @@ def _entity_discriminator(v: dict[str, Any] | object) -> str:
|
||||
|
||||
|
||||
def _sync_checkpoint_fields(entity: object) -> None:
|
||||
"""Copy private runtime attrs into checkpoint fields before serializing."""
|
||||
"""Copy private runtime attrs into checkpoint fields before serializing.
|
||||
|
||||
Args:
|
||||
entity: The entity whose private runtime attributes will be
|
||||
copied into its public checkpoint fields.
|
||||
"""
|
||||
from crewai.crew import Crew
|
||||
from crewai.flow.flow import Flow
|
||||
|
||||
@@ -56,21 +62,40 @@ def _sync_checkpoint_fields(entity: object) -> None:
|
||||
|
||||
|
||||
class RuntimeState(RootModel): # type: ignore[type-arg]
|
||||
root: list[Entity] # type: ignore[name-defined] # noqa: F821
|
||||
root: list[Entity]
|
||||
_provider: BaseProvider = PrivateAttr(default_factory=JsonProvider)
|
||||
|
||||
def checkpoint(self, directory: str) -> str:
|
||||
"""Write a checkpoint file to the directory."""
|
||||
from crewai.context import capture_execution_context
|
||||
"""Write a checkpoint file to the directory.
|
||||
|
||||
for entity in self.root:
|
||||
entity.execution_context = capture_execution_context()
|
||||
_sync_checkpoint_fields(entity)
|
||||
Args:
|
||||
directory: Filesystem path where the checkpoint JSON will be saved.
|
||||
|
||||
dir_path = Path(directory)
|
||||
dir_path.mkdir(parents=True, exist_ok=True)
|
||||
Returns:
|
||||
A location identifier for the saved checkpoint.
|
||||
"""
|
||||
_prepare_entities(self.root)
|
||||
return self._provider.checkpoint(self.model_dump_json(), directory)
|
||||
|
||||
ts = datetime.now(timezone.utc).strftime("%Y%m%dT%H%M%S")
|
||||
filename = f"{ts}_{uuid.uuid4().hex[:8]}.json"
|
||||
file_path = dir_path / filename
|
||||
file_path.write_text(self.model_dump_json())
|
||||
return str(file_path)
|
||||
async def acheckpoint(self, directory: str) -> str:
|
||||
"""Async version of :meth:`checkpoint`.
|
||||
|
||||
Args:
|
||||
directory: Filesystem path where the checkpoint JSON will be saved.
|
||||
|
||||
Returns:
|
||||
A location identifier for the saved checkpoint.
|
||||
"""
|
||||
_prepare_entities(self.root)
|
||||
return await self._provider.acheckpoint(self.model_dump_json(), directory)
|
||||
|
||||
|
||||
def _prepare_entities(root: list[Entity]) -> None:
|
||||
"""Capture execution context and sync checkpoint fields on each entity.
|
||||
|
||||
Args:
|
||||
root: List of entities to prepare for serialization.
|
||||
"""
|
||||
for entity in root:
|
||||
entity.execution_context = capture_execution_context()
|
||||
_sync_checkpoint_fields(entity)
|
||||
2
uv.lock
generated
2
uv.lock
generated
@@ -1138,6 +1138,7 @@ wheels = [
|
||||
name = "crewai"
|
||||
source = { editable = "lib/crewai" }
|
||||
dependencies = [
|
||||
{ name = "aiofiles" },
|
||||
{ name = "aiosqlite" },
|
||||
{ name = "appdirs" },
|
||||
{ name = "chromadb" },
|
||||
@@ -1234,6 +1235,7 @@ requires-dist = [
|
||||
{ name = "a2a-sdk", marker = "extra == 'a2a'", specifier = "~=0.3.10" },
|
||||
{ name = "aiobotocore", marker = "extra == 'aws'", specifier = "~=2.25.2" },
|
||||
{ name = "aiocache", extras = ["memcached", "redis"], marker = "extra == 'a2a'", specifier = "~=0.12.3" },
|
||||
{ name = "aiofiles", specifier = "~=24.1.0" },
|
||||
{ name = "aiosqlite", specifier = "~=0.21.0" },
|
||||
{ name = "anthropic", marker = "extra == 'anthropic'", specifier = "~=0.73.0" },
|
||||
{ name = "appdirs", specifier = "~=1.4.4" },
|
||||
|
||||
Reference in New Issue
Block a user