refactor: dedupe checkpoint helpers and tighten state type hints

This commit is contained in:
Greyson LaLonde
2026-04-23 19:29:04 +08:00
committed by GitHub
parent 55937d7523
commit 69461076df
3 changed files with 93 additions and 87 deletions

View File

@@ -63,12 +63,26 @@ def _resolve(value: CheckpointConfig | bool | None) -> CheckpointConfig | None |
if isinstance(value, CheckpointConfig):
_ensure_handlers_registered()
return value
if value is True:
if value:
_ensure_handlers_registered()
return CheckpointConfig()
if value is False:
return _SENTINEL
return None # None = inherit
return None
def _resolve_from_agent(agent: BaseAgent) -> CheckpointConfig | None:
"""Resolve a checkpoint config starting from an agent, walking to its crew."""
result = _resolve(agent.checkpoint)
if isinstance(result, CheckpointConfig):
return result
if result is _SENTINEL:
return None
crew = agent.crew
if isinstance(crew, Crew):
crew_result = _resolve(crew.checkpoint)
return crew_result if isinstance(crew_result, CheckpointConfig) else None
return None
def _find_checkpoint(source: Any) -> CheckpointConfig | None:
@@ -87,28 +101,11 @@ def _find_checkpoint(source: Any) -> CheckpointConfig | None:
result = _resolve(source.checkpoint)
return result if isinstance(result, CheckpointConfig) else None
if isinstance(source, BaseAgent):
result = _resolve(source.checkpoint)
if isinstance(result, CheckpointConfig):
return result
if result is _SENTINEL:
return None
crew = source.crew
if isinstance(crew, Crew):
result = _resolve(crew.checkpoint)
return result if isinstance(result, CheckpointConfig) else None
return None
return _resolve_from_agent(source)
if isinstance(source, Task):
agent = source.agent
if isinstance(agent, BaseAgent):
result = _resolve(agent.checkpoint)
if isinstance(result, CheckpointConfig):
return result
if result is _SENTINEL:
return None
crew = agent.crew
if isinstance(crew, Crew):
result = _resolve(crew.checkpoint)
return result if isinstance(result, CheckpointConfig) else None
return _resolve_from_agent(agent)
return None
return None
@@ -255,7 +252,8 @@ def _register_all_handlers(event_bus: CrewAIEventsBus) -> None:
seen: set[type] = set()
def _collect(cls: type[BaseEvent]) -> None:
for sub in cls.__subclasses__():
subclasses: list[type[BaseEvent]] = cls.__subclasses__()
for sub in subclasses:
if sub not in seen:
seen.add(sub)
type_field = sub.model_fields.get("type")

View File

@@ -39,7 +39,8 @@ def _build_event_type_map() -> None:
"""Populate _event_type_map from all BaseEvent subclasses."""
def _collect(cls: type[BaseEvent]) -> None:
for sub in cls.__subclasses__():
subclasses: list[type[BaseEvent]] = cls.__subclasses__()
for sub in subclasses:
type_field = sub.model_fields.get("type")
if type_field and type_field.default:
_event_type_map[type_field.default] = sub

View File

@@ -101,7 +101,7 @@ def _migrate(data: dict[str, Any]) -> dict[str, Any]:
"""
raw = data.get("crewai_version")
current = Version(get_crewai_version())
stored = Version(raw) if raw else Version("0.0.0")
stored = Version(raw) if isinstance(raw, str) and raw else Version("0.0.0")
if raw is None:
logger.warning("Checkpoint has no crewai_version — treating as 0.0.0")
@@ -171,16 +171,8 @@ class RuntimeState(RootModel): # type: ignore[type-arg]
self._checkpoint_id = provider.extract_id(location)
self._parent_id = self._checkpoint_id
def checkpoint(self, location: str) -> str:
"""Write a checkpoint.
Args:
location: Storage destination. For JsonProvider this is a directory
path; for SqliteProvider it is a database file path.
Returns:
A location identifier for the saved checkpoint.
"""
def _begin_checkpoint(self, location: str) -> tuple[str, str | None, str, float]:
"""Emit the start event and return the invariant context for a checkpoint."""
provider_name: str = type(self._provider).__name__
parent_id_snapshot: str | None = self._parent_id
branch_snapshot: str = self._branch
@@ -193,29 +185,37 @@ class RuntimeState(RootModel): # type: ignore[type-arg]
parent_id=parent_id_snapshot,
),
)
start: float = time.perf_counter()
try:
_prepare_entities(self.root)
result = self._provider.checkpoint(
self.model_dump_json(),
location,
parent_id=parent_id_snapshot,
branch=branch_snapshot,
)
self._chain_lineage(self._provider, result)
except Exception as exc:
crewai_event_bus.emit(
self,
CheckpointFailedEvent(
location=location,
provider=provider_name,
branch=branch_snapshot,
parent_id=parent_id_snapshot,
error=str(exc),
),
)
raise
return provider_name, parent_id_snapshot, branch_snapshot, time.perf_counter()
def _emit_checkpoint_failed(
self,
location: str,
provider_name: str,
branch_snapshot: str,
parent_id_snapshot: str | None,
exc: Exception,
) -> None:
"""Emit the failure event for a checkpoint write."""
crewai_event_bus.emit(
self,
CheckpointFailedEvent(
location=location,
provider=provider_name,
branch=branch_snapshot,
parent_id=parent_id_snapshot,
error=str(exc),
),
)
def _emit_checkpoint_completed(
self,
result: str,
provider_name: str,
branch_snapshot: str,
parent_id_snapshot: str | None,
start: float,
) -> None:
"""Emit the completion event for a successful checkpoint write."""
crewai_event_bus.emit(
self,
CheckpointCompletedEvent(
@@ -227,6 +227,38 @@ class RuntimeState(RootModel): # type: ignore[type-arg]
duration_ms=(time.perf_counter() - start) * 1000.0,
),
)
def checkpoint(self, location: str) -> str:
"""Write a checkpoint.
Args:
location: Storage destination. For JsonProvider this is a directory
path; for SqliteProvider it is a database file path.
Returns:
A location identifier for the saved checkpoint.
"""
provider_name, parent_id_snapshot, branch_snapshot, start = (
self._begin_checkpoint(location)
)
try:
_prepare_entities(self.root)
result = self._provider.checkpoint(
self.model_dump_json(),
location,
parent_id=parent_id_snapshot,
branch=branch_snapshot,
)
self._chain_lineage(self._provider, result)
except Exception as exc:
self._emit_checkpoint_failed(
location, provider_name, branch_snapshot, parent_id_snapshot, exc
)
raise
self._emit_checkpoint_completed(
result, provider_name, branch_snapshot, parent_id_snapshot, start
)
return result
async def acheckpoint(self, location: str) -> str:
@@ -239,19 +271,9 @@ class RuntimeState(RootModel): # type: ignore[type-arg]
Returns:
A location identifier for the saved checkpoint.
"""
provider_name: str = type(self._provider).__name__
parent_id_snapshot: str | None = self._parent_id
branch_snapshot: str = self._branch
crewai_event_bus.emit(
self,
CheckpointStartedEvent(
location=location,
provider=provider_name,
branch=branch_snapshot,
parent_id=parent_id_snapshot,
),
provider_name, parent_id_snapshot, branch_snapshot, start = (
self._begin_checkpoint(location)
)
start: float = time.perf_counter()
try:
_prepare_entities(self.root)
result = await self._provider.acheckpoint(
@@ -262,28 +284,13 @@ class RuntimeState(RootModel): # type: ignore[type-arg]
)
self._chain_lineage(self._provider, result)
except Exception as exc:
crewai_event_bus.emit(
self,
CheckpointFailedEvent(
location=location,
provider=provider_name,
branch=branch_snapshot,
parent_id=parent_id_snapshot,
error=str(exc),
),
self._emit_checkpoint_failed(
location, provider_name, branch_snapshot, parent_id_snapshot, exc
)
raise
crewai_event_bus.emit(
self,
CheckpointCompletedEvent(
location=result,
provider=provider_name,
branch=branch_snapshot,
parent_id=parent_id_snapshot,
checkpoint_id=self._provider.extract_id(result),
duration_ms=(time.perf_counter() - start) * 1000.0,
),
self._emit_checkpoint_completed(
result, provider_name, branch_snapshot, parent_id_snapshot, start
)
return result