mirror of
https://github.com/crewAIInc/crewAI.git
synced 2026-07-01 13:18:10 +00:00
refactor: dedupe checkpoint helpers and tighten state type hints
This commit is contained in:
@@ -63,12 +63,26 @@ def _resolve(value: CheckpointConfig | bool | None) -> CheckpointConfig | None |
|
||||
if isinstance(value, CheckpointConfig):
|
||||
_ensure_handlers_registered()
|
||||
return value
|
||||
if value is True:
|
||||
if value:
|
||||
_ensure_handlers_registered()
|
||||
return CheckpointConfig()
|
||||
if value is False:
|
||||
return _SENTINEL
|
||||
return None # None = inherit
|
||||
return None
|
||||
|
||||
|
||||
def _resolve_from_agent(agent: BaseAgent) -> CheckpointConfig | None:
|
||||
"""Resolve a checkpoint config starting from an agent, walking to its crew."""
|
||||
result = _resolve(agent.checkpoint)
|
||||
if isinstance(result, CheckpointConfig):
|
||||
return result
|
||||
if result is _SENTINEL:
|
||||
return None
|
||||
crew = agent.crew
|
||||
if isinstance(crew, Crew):
|
||||
crew_result = _resolve(crew.checkpoint)
|
||||
return crew_result if isinstance(crew_result, CheckpointConfig) else None
|
||||
return None
|
||||
|
||||
|
||||
def _find_checkpoint(source: Any) -> CheckpointConfig | None:
|
||||
@@ -87,28 +101,11 @@ def _find_checkpoint(source: Any) -> CheckpointConfig | None:
|
||||
result = _resolve(source.checkpoint)
|
||||
return result if isinstance(result, CheckpointConfig) else None
|
||||
if isinstance(source, BaseAgent):
|
||||
result = _resolve(source.checkpoint)
|
||||
if isinstance(result, CheckpointConfig):
|
||||
return result
|
||||
if result is _SENTINEL:
|
||||
return None
|
||||
crew = source.crew
|
||||
if isinstance(crew, Crew):
|
||||
result = _resolve(crew.checkpoint)
|
||||
return result if isinstance(result, CheckpointConfig) else None
|
||||
return None
|
||||
return _resolve_from_agent(source)
|
||||
if isinstance(source, Task):
|
||||
agent = source.agent
|
||||
if isinstance(agent, BaseAgent):
|
||||
result = _resolve(agent.checkpoint)
|
||||
if isinstance(result, CheckpointConfig):
|
||||
return result
|
||||
if result is _SENTINEL:
|
||||
return None
|
||||
crew = agent.crew
|
||||
if isinstance(crew, Crew):
|
||||
result = _resolve(crew.checkpoint)
|
||||
return result if isinstance(result, CheckpointConfig) else None
|
||||
return _resolve_from_agent(agent)
|
||||
return None
|
||||
return None
|
||||
|
||||
@@ -255,7 +252,8 @@ def _register_all_handlers(event_bus: CrewAIEventsBus) -> None:
|
||||
seen: set[type] = set()
|
||||
|
||||
def _collect(cls: type[BaseEvent]) -> None:
|
||||
for sub in cls.__subclasses__():
|
||||
subclasses: list[type[BaseEvent]] = cls.__subclasses__()
|
||||
for sub in subclasses:
|
||||
if sub not in seen:
|
||||
seen.add(sub)
|
||||
type_field = sub.model_fields.get("type")
|
||||
|
||||
@@ -39,7 +39,8 @@ def _build_event_type_map() -> None:
|
||||
"""Populate _event_type_map from all BaseEvent subclasses."""
|
||||
|
||||
def _collect(cls: type[BaseEvent]) -> None:
|
||||
for sub in cls.__subclasses__():
|
||||
subclasses: list[type[BaseEvent]] = cls.__subclasses__()
|
||||
for sub in subclasses:
|
||||
type_field = sub.model_fields.get("type")
|
||||
if type_field and type_field.default:
|
||||
_event_type_map[type_field.default] = sub
|
||||
|
||||
@@ -101,7 +101,7 @@ def _migrate(data: dict[str, Any]) -> dict[str, Any]:
|
||||
"""
|
||||
raw = data.get("crewai_version")
|
||||
current = Version(get_crewai_version())
|
||||
stored = Version(raw) if raw else Version("0.0.0")
|
||||
stored = Version(raw) if isinstance(raw, str) and raw else Version("0.0.0")
|
||||
|
||||
if raw is None:
|
||||
logger.warning("Checkpoint has no crewai_version — treating as 0.0.0")
|
||||
@@ -171,16 +171,8 @@ class RuntimeState(RootModel): # type: ignore[type-arg]
|
||||
self._checkpoint_id = provider.extract_id(location)
|
||||
self._parent_id = self._checkpoint_id
|
||||
|
||||
def checkpoint(self, location: str) -> str:
|
||||
"""Write a checkpoint.
|
||||
|
||||
Args:
|
||||
location: Storage destination. For JsonProvider this is a directory
|
||||
path; for SqliteProvider it is a database file path.
|
||||
|
||||
Returns:
|
||||
A location identifier for the saved checkpoint.
|
||||
"""
|
||||
def _begin_checkpoint(self, location: str) -> tuple[str, str | None, str, float]:
|
||||
"""Emit the start event and return the invariant context for a checkpoint."""
|
||||
provider_name: str = type(self._provider).__name__
|
||||
parent_id_snapshot: str | None = self._parent_id
|
||||
branch_snapshot: str = self._branch
|
||||
@@ -193,29 +185,37 @@ class RuntimeState(RootModel): # type: ignore[type-arg]
|
||||
parent_id=parent_id_snapshot,
|
||||
),
|
||||
)
|
||||
start: float = time.perf_counter()
|
||||
try:
|
||||
_prepare_entities(self.root)
|
||||
result = self._provider.checkpoint(
|
||||
self.model_dump_json(),
|
||||
location,
|
||||
parent_id=parent_id_snapshot,
|
||||
branch=branch_snapshot,
|
||||
)
|
||||
self._chain_lineage(self._provider, result)
|
||||
except Exception as exc:
|
||||
crewai_event_bus.emit(
|
||||
self,
|
||||
CheckpointFailedEvent(
|
||||
location=location,
|
||||
provider=provider_name,
|
||||
branch=branch_snapshot,
|
||||
parent_id=parent_id_snapshot,
|
||||
error=str(exc),
|
||||
),
|
||||
)
|
||||
raise
|
||||
return provider_name, parent_id_snapshot, branch_snapshot, time.perf_counter()
|
||||
|
||||
def _emit_checkpoint_failed(
|
||||
self,
|
||||
location: str,
|
||||
provider_name: str,
|
||||
branch_snapshot: str,
|
||||
parent_id_snapshot: str | None,
|
||||
exc: Exception,
|
||||
) -> None:
|
||||
"""Emit the failure event for a checkpoint write."""
|
||||
crewai_event_bus.emit(
|
||||
self,
|
||||
CheckpointFailedEvent(
|
||||
location=location,
|
||||
provider=provider_name,
|
||||
branch=branch_snapshot,
|
||||
parent_id=parent_id_snapshot,
|
||||
error=str(exc),
|
||||
),
|
||||
)
|
||||
|
||||
def _emit_checkpoint_completed(
|
||||
self,
|
||||
result: str,
|
||||
provider_name: str,
|
||||
branch_snapshot: str,
|
||||
parent_id_snapshot: str | None,
|
||||
start: float,
|
||||
) -> None:
|
||||
"""Emit the completion event for a successful checkpoint write."""
|
||||
crewai_event_bus.emit(
|
||||
self,
|
||||
CheckpointCompletedEvent(
|
||||
@@ -227,6 +227,38 @@ class RuntimeState(RootModel): # type: ignore[type-arg]
|
||||
duration_ms=(time.perf_counter() - start) * 1000.0,
|
||||
),
|
||||
)
|
||||
|
||||
def checkpoint(self, location: str) -> str:
|
||||
"""Write a checkpoint.
|
||||
|
||||
Args:
|
||||
location: Storage destination. For JsonProvider this is a directory
|
||||
path; for SqliteProvider it is a database file path.
|
||||
|
||||
Returns:
|
||||
A location identifier for the saved checkpoint.
|
||||
"""
|
||||
provider_name, parent_id_snapshot, branch_snapshot, start = (
|
||||
self._begin_checkpoint(location)
|
||||
)
|
||||
try:
|
||||
_prepare_entities(self.root)
|
||||
result = self._provider.checkpoint(
|
||||
self.model_dump_json(),
|
||||
location,
|
||||
parent_id=parent_id_snapshot,
|
||||
branch=branch_snapshot,
|
||||
)
|
||||
self._chain_lineage(self._provider, result)
|
||||
except Exception as exc:
|
||||
self._emit_checkpoint_failed(
|
||||
location, provider_name, branch_snapshot, parent_id_snapshot, exc
|
||||
)
|
||||
raise
|
||||
|
||||
self._emit_checkpoint_completed(
|
||||
result, provider_name, branch_snapshot, parent_id_snapshot, start
|
||||
)
|
||||
return result
|
||||
|
||||
async def acheckpoint(self, location: str) -> str:
|
||||
@@ -239,19 +271,9 @@ class RuntimeState(RootModel): # type: ignore[type-arg]
|
||||
Returns:
|
||||
A location identifier for the saved checkpoint.
|
||||
"""
|
||||
provider_name: str = type(self._provider).__name__
|
||||
parent_id_snapshot: str | None = self._parent_id
|
||||
branch_snapshot: str = self._branch
|
||||
crewai_event_bus.emit(
|
||||
self,
|
||||
CheckpointStartedEvent(
|
||||
location=location,
|
||||
provider=provider_name,
|
||||
branch=branch_snapshot,
|
||||
parent_id=parent_id_snapshot,
|
||||
),
|
||||
provider_name, parent_id_snapshot, branch_snapshot, start = (
|
||||
self._begin_checkpoint(location)
|
||||
)
|
||||
start: float = time.perf_counter()
|
||||
try:
|
||||
_prepare_entities(self.root)
|
||||
result = await self._provider.acheckpoint(
|
||||
@@ -262,28 +284,13 @@ class RuntimeState(RootModel): # type: ignore[type-arg]
|
||||
)
|
||||
self._chain_lineage(self._provider, result)
|
||||
except Exception as exc:
|
||||
crewai_event_bus.emit(
|
||||
self,
|
||||
CheckpointFailedEvent(
|
||||
location=location,
|
||||
provider=provider_name,
|
||||
branch=branch_snapshot,
|
||||
parent_id=parent_id_snapshot,
|
||||
error=str(exc),
|
||||
),
|
||||
self._emit_checkpoint_failed(
|
||||
location, provider_name, branch_snapshot, parent_id_snapshot, exc
|
||||
)
|
||||
raise
|
||||
|
||||
crewai_event_bus.emit(
|
||||
self,
|
||||
CheckpointCompletedEvent(
|
||||
location=result,
|
||||
provider=provider_name,
|
||||
branch=branch_snapshot,
|
||||
parent_id=parent_id_snapshot,
|
||||
checkpoint_id=self._provider.extract_id(result),
|
||||
duration_ms=(time.perf_counter() - start) * 1000.0,
|
||||
),
|
||||
self._emit_checkpoint_completed(
|
||||
result, provider_name, branch_snapshot, parent_id_snapshot, start
|
||||
)
|
||||
return result
|
||||
|
||||
|
||||
Reference in New Issue
Block a user