From 7c3b987037eada482d8e9ab30321ae68d8655989 Mon Sep 17 00:00:00 2001 From: Greyson LaLonde Date: Tue, 7 Apr 2026 02:14:03 +0800 Subject: [PATCH] fix: make BaseTool survive checkpoint JSON round-trip BaseTool could not serialize to JSON because args_schema (a class reference) and cache_function (a lambda) are not JSON-serializable. This caused checkpointing to crash for any crew with tools. - Add PlainSerializer to args_schema so it round-trips via JSON schema - Replace default cache_function lambda with named _default_cache_function and type it as SerializableCallable so it serializes to a dotted path - Add computed_field tool_type that stores the fully qualified class name - Add restore_tool_from_dict to reconstruct the concrete subclass from checkpoint dicts, pre-resolving callback strings to callables - Update BaseAgent.validate_tools and Task._restore_tools_from_checkpoint to handle dict inputs from checkpoint deserialization --- .../crewai/agents/agent_builder/base_agent.py | 11 +-- lib/crewai/src/crewai/task.py | 17 ++++- lib/crewai/src/crewai/tools/base_tool.py | 68 +++++++++++++++++-- 3 files changed, 85 insertions(+), 11 deletions(-) diff --git a/lib/crewai/src/crewai/agents/agent_builder/base_agent.py b/lib/crewai/src/crewai/agents/agent_builder/base_agent.py index cfa08bbc3..e80e190f1 100644 --- a/lib/crewai/src/crewai/agents/agent_builder/base_agent.py +++ b/lib/crewai/src/crewai/agents/agent_builder/base_agent.py @@ -39,7 +39,7 @@ from crewai.memory.unified_memory import Memory from crewai.rag.embeddings.types import EmbedderConfig from crewai.security.security_config import SecurityConfig from crewai.skills.models import Skill -from crewai.tools.base_tool import BaseTool, Tool +from crewai.tools.base_tool import BaseTool, Tool, restore_tool_from_dict from crewai.types.callback import SerializableCallable from crewai.utilities.config import process_config from crewai.utilities.i18n import I18N, get_i18n @@ -361,13 +361,14 @@ class BaseAgent(BaseModel, ABC, metaclass=AgentMeta): def process_model_config(cls, values: Any) -> dict[str, Any]: return process_config(values, cls) - @field_validator("tools") + @field_validator("tools", mode="before") @classmethod def validate_tools(cls, tools: list[Any]) -> list[BaseTool]: """Validate and process the tools provided to the agent. - This method ensures that each tool is either an instance of BaseTool - or an object with 'name', 'func', and 'description' attributes. If the + This method ensures that each tool is either an instance of BaseTool, + a dict from checkpoint deserialization with a ``tool_type`` key, or an + object with 'name', 'func', and 'description' attributes. If the tool meets these criteria, it is processed and added to the list of tools. Otherwise, a ValueError is raised. """ @@ -379,6 +380,8 @@ class BaseAgent(BaseModel, ABC, metaclass=AgentMeta): for tool in tools: if isinstance(tool, BaseTool): processed_tools.append(tool) + elif isinstance(tool, dict) and "tool_type" in tool: + processed_tools.append(restore_tool_from_dict(tool)) elif all(hasattr(tool, attr) for attr in required_attrs): # Tool has the required attributes, create a Tool instance processed_tools.append(Tool.from_langchain(tool)) diff --git a/lib/crewai/src/crewai/task.py b/lib/crewai/src/crewai/task.py index 73e49ade9..4fb5aad04 100644 --- a/lib/crewai/src/crewai/task.py +++ b/lib/crewai/src/crewai/task.py @@ -48,7 +48,7 @@ from crewai.llms.base_llm import BaseLLM from crewai.security import Fingerprint, SecurityConfig from crewai.tasks.output_format import OutputFormat from crewai.tasks.task_output import TaskOutput -from crewai.tools.base_tool import BaseTool +from crewai.tools.base_tool import BaseTool, restore_tool_from_dict from crewai.utilities.config import process_config from crewai.utilities.constants import NOT_SPECIFIED, _NotSpecified from crewai.utilities.converter import Converter, convert_to_model @@ -237,6 +237,21 @@ class Task(BaseModel): _thread: threading.Thread | None = PrivateAttr(default=None) model_config = {"arbitrary_types_allowed": True} + @field_validator("tools", mode="before") + @classmethod + def _restore_tools_from_checkpoint( + cls, tools: list[Any] | None + ) -> list[Any] | None: + if not tools: + return tools + restored: list[Any] = [] + for tool in tools: + if isinstance(tool, dict) and "tool_type" in tool: + restored.append(restore_tool_from_dict(tool)) + else: + restored.append(tool) + return restored + @field_validator("guardrail") @classmethod def validate_guardrail_function( diff --git a/lib/crewai/src/crewai/tools/base_tool.py b/lib/crewai/src/crewai/tools/base_tool.py index 118fa307b..68260c8f1 100644 --- a/lib/crewai/src/crewai/tools/base_tool.py +++ b/lib/crewai/src/crewai/tools/base_tool.py @@ -3,10 +3,12 @@ from __future__ import annotations from abc import ABC, abstractmethod import asyncio from collections.abc import Awaitable, Callable +import importlib from inspect import Parameter, signature import json import threading from typing import ( + Annotated, Any, Generic, ParamSpec, @@ -19,13 +21,21 @@ from pydantic import ( BaseModel as PydanticBaseModel, ConfigDict, Field, + PlainSerializer, PrivateAttr, + computed_field, create_model, field_validator, ) from typing_extensions import TypeIs -from crewai.tools.structured_tool import CrewStructuredTool, build_schema_hint +from crewai.tools.structured_tool import ( + CrewStructuredTool, + _deserialize_schema, + _serialize_schema, + build_schema_hint, +) +from crewai.types.callback import SerializableCallable, _resolve_dotted_path from crewai.utilities.printer import Printer from crewai.utilities.pydantic_schema_utils import generate_model_description from crewai.utilities.string_utils import sanitize_tool_name @@ -37,6 +47,37 @@ P = ParamSpec("P") R = TypeVar("R", covariant=True) +def restore_tool_from_dict(data: dict[str, Any]) -> BaseTool: + """Reconstruct a concrete ``BaseTool`` subclass from a checkpoint dict. + + The dict must contain a ``tool_type`` key holding the fully qualified + class name (e.g. ``crewai_tools.tools.serper_dev_tool.SerperDevTool``). + The class is imported and instantiated with the remaining fields so that + its ``_run`` method is available. + """ + + data = dict(data) # avoid mutating caller + dotted = data.pop("tool_type") + mod_path, cls_name = dotted.rsplit(".", 1) + cls = getattr(importlib.import_module(mod_path), cls_name) + + for key in ("cache_function",): + val = data.get(key) + if isinstance(val, str): + try: + data[key] = _resolve_dotted_path(val) + except (ValueError, ImportError): + data.pop(key) + + result: BaseTool = cls(**data) + return result + + +def _default_cache_function(_args: Any = None, _result: Any = None) -> bool: + """Default cache function that always allows caching.""" + return True + + def _is_async_callable(func: Callable[..., Any]) -> bool: """Check if a callable is async.""" return asyncio.iscoroutinefunction(func) @@ -70,7 +111,10 @@ class BaseTool(BaseModel, ABC): default_factory=list, description="List of environment variables used by the tool.", ) - args_schema: type[PydanticBaseModel] = Field( + args_schema: Annotated[ + type[PydanticBaseModel], + PlainSerializer(_serialize_schema, return_type=dict | None, when_used="json"), + ] = Field( default=_ArgsSchemaPlaceholder, validate_default=True, description="The schema for the arguments that the tool accepts.", @@ -80,8 +124,8 @@ class BaseTool(BaseModel, ABC): default=False, description="Flag to check if the description has been updated." ) - cache_function: Callable[..., bool] = Field( - default=lambda _args=None, _result=None: True, + cache_function: SerializableCallable = Field( + default=_default_cache_function, description="Function that will be used to determine if the tool should be cached, should return a boolean. If None, the tool will be cached.", ) result_as_answer: bool = Field( @@ -98,12 +142,24 @@ class BaseTool(BaseModel, ABC): ) _usage_lock: threading.Lock = PrivateAttr(default_factory=threading.Lock) + @computed_field # type: ignore[prop-decorator] + @property + def tool_type(self) -> str: + cls = type(self) + return f"{cls.__module__}.{cls.__qualname__}" + @field_validator("args_schema", mode="before") @classmethod def _default_args_schema( - cls, v: type[PydanticBaseModel] + cls, v: type[PydanticBaseModel] | dict[str, Any] | None ) -> type[PydanticBaseModel]: - if v != cls._ArgsSchemaPlaceholder: + if isinstance(v, dict): + restored = _deserialize_schema(v) + if restored is not None: + return restored + if v is None or v == cls._ArgsSchemaPlaceholder: + pass # fall through to generate from signature + elif isinstance(v, type): return v run_sig = signature(cls._run)