From 69461076df442d9e660b99d24e9c19ad59578301 Mon Sep 17 00:00:00 2001 From: Greyson LaLonde Date: Thu, 23 Apr 2026 19:29:04 +0800 Subject: [PATCH] refactor: dedupe checkpoint helpers and tighten state type hints --- .../src/crewai/state/checkpoint_listener.py | 42 +++--- lib/crewai/src/crewai/state/event_record.py | 3 +- lib/crewai/src/crewai/state/runtime.py | 135 +++++++++--------- 3 files changed, 93 insertions(+), 87 deletions(-) diff --git a/lib/crewai/src/crewai/state/checkpoint_listener.py b/lib/crewai/src/crewai/state/checkpoint_listener.py index f9634005e..0c2adc127 100644 --- a/lib/crewai/src/crewai/state/checkpoint_listener.py +++ b/lib/crewai/src/crewai/state/checkpoint_listener.py @@ -63,12 +63,26 @@ def _resolve(value: CheckpointConfig | bool | None) -> CheckpointConfig | None | if isinstance(value, CheckpointConfig): _ensure_handlers_registered() return value - if value is True: + if value: _ensure_handlers_registered() return CheckpointConfig() if value is False: return _SENTINEL - return None # None = inherit + return None + + +def _resolve_from_agent(agent: BaseAgent) -> CheckpointConfig | None: + """Resolve a checkpoint config starting from an agent, walking to its crew.""" + result = _resolve(agent.checkpoint) + if isinstance(result, CheckpointConfig): + return result + if result is _SENTINEL: + return None + crew = agent.crew + if isinstance(crew, Crew): + crew_result = _resolve(crew.checkpoint) + return crew_result if isinstance(crew_result, CheckpointConfig) else None + return None def _find_checkpoint(source: Any) -> CheckpointConfig | None: @@ -87,28 +101,11 @@ def _find_checkpoint(source: Any) -> CheckpointConfig | None: result = _resolve(source.checkpoint) return result if isinstance(result, CheckpointConfig) else None if isinstance(source, BaseAgent): - result = _resolve(source.checkpoint) - if isinstance(result, CheckpointConfig): - return result - if result is _SENTINEL: - return None - crew = source.crew - if isinstance(crew, Crew): - result = _resolve(crew.checkpoint) - return result if isinstance(result, CheckpointConfig) else None - return None + return _resolve_from_agent(source) if isinstance(source, Task): agent = source.agent if isinstance(agent, BaseAgent): - result = _resolve(agent.checkpoint) - if isinstance(result, CheckpointConfig): - return result - if result is _SENTINEL: - return None - crew = agent.crew - if isinstance(crew, Crew): - result = _resolve(crew.checkpoint) - return result if isinstance(result, CheckpointConfig) else None + return _resolve_from_agent(agent) return None return None @@ -255,7 +252,8 @@ def _register_all_handlers(event_bus: CrewAIEventsBus) -> None: seen: set[type] = set() def _collect(cls: type[BaseEvent]) -> None: - for sub in cls.__subclasses__(): + subclasses: list[type[BaseEvent]] = cls.__subclasses__() + for sub in subclasses: if sub not in seen: seen.add(sub) type_field = sub.model_fields.get("type") diff --git a/lib/crewai/src/crewai/state/event_record.py b/lib/crewai/src/crewai/state/event_record.py index 7b8c20c5b..866398e0a 100644 --- a/lib/crewai/src/crewai/state/event_record.py +++ b/lib/crewai/src/crewai/state/event_record.py @@ -39,7 +39,8 @@ def _build_event_type_map() -> None: """Populate _event_type_map from all BaseEvent subclasses.""" def _collect(cls: type[BaseEvent]) -> None: - for sub in cls.__subclasses__(): + subclasses: list[type[BaseEvent]] = cls.__subclasses__() + for sub in subclasses: type_field = sub.model_fields.get("type") if type_field and type_field.default: _event_type_map[type_field.default] = sub diff --git a/lib/crewai/src/crewai/state/runtime.py b/lib/crewai/src/crewai/state/runtime.py index a815845d6..471107997 100644 --- a/lib/crewai/src/crewai/state/runtime.py +++ b/lib/crewai/src/crewai/state/runtime.py @@ -101,7 +101,7 @@ def _migrate(data: dict[str, Any]) -> dict[str, Any]: """ raw = data.get("crewai_version") current = Version(get_crewai_version()) - stored = Version(raw) if raw else Version("0.0.0") + stored = Version(raw) if isinstance(raw, str) and raw else Version("0.0.0") if raw is None: logger.warning("Checkpoint has no crewai_version — treating as 0.0.0") @@ -171,16 +171,8 @@ class RuntimeState(RootModel): # type: ignore[type-arg] self._checkpoint_id = provider.extract_id(location) self._parent_id = self._checkpoint_id - def checkpoint(self, location: str) -> str: - """Write a checkpoint. - - Args: - location: Storage destination. For JsonProvider this is a directory - path; for SqliteProvider it is a database file path. - - Returns: - A location identifier for the saved checkpoint. - """ + def _begin_checkpoint(self, location: str) -> tuple[str, str | None, str, float]: + """Emit the start event and return the invariant context for a checkpoint.""" provider_name: str = type(self._provider).__name__ parent_id_snapshot: str | None = self._parent_id branch_snapshot: str = self._branch @@ -193,29 +185,37 @@ class RuntimeState(RootModel): # type: ignore[type-arg] parent_id=parent_id_snapshot, ), ) - start: float = time.perf_counter() - try: - _prepare_entities(self.root) - result = self._provider.checkpoint( - self.model_dump_json(), - location, - parent_id=parent_id_snapshot, - branch=branch_snapshot, - ) - self._chain_lineage(self._provider, result) - except Exception as exc: - crewai_event_bus.emit( - self, - CheckpointFailedEvent( - location=location, - provider=provider_name, - branch=branch_snapshot, - parent_id=parent_id_snapshot, - error=str(exc), - ), - ) - raise + return provider_name, parent_id_snapshot, branch_snapshot, time.perf_counter() + def _emit_checkpoint_failed( + self, + location: str, + provider_name: str, + branch_snapshot: str, + parent_id_snapshot: str | None, + exc: Exception, + ) -> None: + """Emit the failure event for a checkpoint write.""" + crewai_event_bus.emit( + self, + CheckpointFailedEvent( + location=location, + provider=provider_name, + branch=branch_snapshot, + parent_id=parent_id_snapshot, + error=str(exc), + ), + ) + + def _emit_checkpoint_completed( + self, + result: str, + provider_name: str, + branch_snapshot: str, + parent_id_snapshot: str | None, + start: float, + ) -> None: + """Emit the completion event for a successful checkpoint write.""" crewai_event_bus.emit( self, CheckpointCompletedEvent( @@ -227,6 +227,38 @@ class RuntimeState(RootModel): # type: ignore[type-arg] duration_ms=(time.perf_counter() - start) * 1000.0, ), ) + + def checkpoint(self, location: str) -> str: + """Write a checkpoint. + + Args: + location: Storage destination. For JsonProvider this is a directory + path; for SqliteProvider it is a database file path. + + Returns: + A location identifier for the saved checkpoint. + """ + provider_name, parent_id_snapshot, branch_snapshot, start = ( + self._begin_checkpoint(location) + ) + try: + _prepare_entities(self.root) + result = self._provider.checkpoint( + self.model_dump_json(), + location, + parent_id=parent_id_snapshot, + branch=branch_snapshot, + ) + self._chain_lineage(self._provider, result) + except Exception as exc: + self._emit_checkpoint_failed( + location, provider_name, branch_snapshot, parent_id_snapshot, exc + ) + raise + + self._emit_checkpoint_completed( + result, provider_name, branch_snapshot, parent_id_snapshot, start + ) return result async def acheckpoint(self, location: str) -> str: @@ -239,19 +271,9 @@ class RuntimeState(RootModel): # type: ignore[type-arg] Returns: A location identifier for the saved checkpoint. """ - provider_name: str = type(self._provider).__name__ - parent_id_snapshot: str | None = self._parent_id - branch_snapshot: str = self._branch - crewai_event_bus.emit( - self, - CheckpointStartedEvent( - location=location, - provider=provider_name, - branch=branch_snapshot, - parent_id=parent_id_snapshot, - ), + provider_name, parent_id_snapshot, branch_snapshot, start = ( + self._begin_checkpoint(location) ) - start: float = time.perf_counter() try: _prepare_entities(self.root) result = await self._provider.acheckpoint( @@ -262,28 +284,13 @@ class RuntimeState(RootModel): # type: ignore[type-arg] ) self._chain_lineage(self._provider, result) except Exception as exc: - crewai_event_bus.emit( - self, - CheckpointFailedEvent( - location=location, - provider=provider_name, - branch=branch_snapshot, - parent_id=parent_id_snapshot, - error=str(exc), - ), + self._emit_checkpoint_failed( + location, provider_name, branch_snapshot, parent_id_snapshot, exc ) raise - crewai_event_bus.emit( - self, - CheckpointCompletedEvent( - location=result, - provider=provider_name, - branch=branch_snapshot, - parent_id=parent_id_snapshot, - checkpoint_id=self._provider.extract_id(result), - duration_ms=(time.perf_counter() - start) * 1000.0, - ), + self._emit_checkpoint_completed( + result, provider_name, branch_snapshot, parent_id_snapshot, start ) return result