fix: make BaseTool survive checkpoint JSON round-trip

BaseTool could not serialize to JSON because args_schema (a class
reference) and cache_function (a lambda) are not JSON-serializable.
This caused checkpointing to crash for any crew with tools.

- Add PlainSerializer to args_schema so it round-trips via JSON schema
- Replace default cache_function lambda with named _default_cache_function
  and type it as SerializableCallable
- Add computed_field tool_type storing the fully qualified class name
- Add __init_subclass__ registry and __get_pydantic_core_schema__ on
  BaseTool so any list[BaseTool] field automatically dispatches to the
  concrete subclass during deserialization via tool_type lookup
- Remove the checkpoint-specific dict handling from BaseAgent.validate_tools
  and Task — no special-casing is needed there anymore because Pydantic
  handles it natively through the custom core schema
This commit is contained in:
Greyson LaLonde
2026-04-07 02:27:21 +08:00
parent 7c3b987037
commit ac0a4b2bd9
3 changed files with 58 additions and 36 deletions

View File

@@ -39,7 +39,7 @@ from crewai.memory.unified_memory import Memory
from crewai.rag.embeddings.types import EmbedderConfig
from crewai.security.security_config import SecurityConfig
from crewai.skills.models import Skill
from crewai.tools.base_tool import BaseTool, Tool, restore_tool_from_dict
from crewai.tools.base_tool import BaseTool, Tool
from crewai.types.callback import SerializableCallable
from crewai.utilities.config import process_config
from crewai.utilities.i18n import I18N, get_i18n
@@ -361,14 +361,13 @@ class BaseAgent(BaseModel, ABC, metaclass=AgentMeta):
def process_model_config(cls, values: Any) -> dict[str, Any]:
return process_config(values, cls)
@field_validator("tools", mode="before")
@field_validator("tools")
@classmethod
def validate_tools(cls, tools: list[Any]) -> list[BaseTool]:
"""Validate and process the tools provided to the agent.
This method ensures that each tool is either an instance of BaseTool,
a dict from checkpoint deserialization with a ``tool_type`` key, or an
object with 'name', 'func', and 'description' attributes. If the
This method ensures that each tool is either an instance of BaseTool
or an object with 'name', 'func', and 'description' attributes. If the
tool meets these criteria, it is processed and added to the list of
tools. Otherwise, a ValueError is raised.
"""
@@ -380,8 +379,6 @@ class BaseAgent(BaseModel, ABC, metaclass=AgentMeta):
for tool in tools:
if isinstance(tool, BaseTool):
processed_tools.append(tool)
elif isinstance(tool, dict) and "tool_type" in tool:
processed_tools.append(restore_tool_from_dict(tool))
elif all(hasattr(tool, attr) for attr in required_attrs):
# Tool has the required attributes, create a Tool instance
processed_tools.append(Tool.from_langchain(tool))

View File

@@ -48,7 +48,7 @@ from crewai.llms.base_llm import BaseLLM
from crewai.security import Fingerprint, SecurityConfig
from crewai.tasks.output_format import OutputFormat
from crewai.tasks.task_output import TaskOutput
from crewai.tools.base_tool import BaseTool, restore_tool_from_dict
from crewai.tools.base_tool import BaseTool
from crewai.utilities.config import process_config
from crewai.utilities.constants import NOT_SPECIFIED, _NotSpecified
from crewai.utilities.converter import Converter, convert_to_model
@@ -237,21 +237,6 @@ class Task(BaseModel):
_thread: threading.Thread | None = PrivateAttr(default=None)
model_config = {"arbitrary_types_allowed": True}
@field_validator("tools", mode="before")
@classmethod
def _restore_tools_from_checkpoint(
    cls, tools: list[Any] | None
) -> list[Any] | None:
    """Rebuild checkpointed tool dicts before normal field validation.

    Entries that are dicts carrying a ``tool_type`` key are passed to
    ``restore_tool_from_dict``; every other entry is kept as-is. A falsy
    ``tools`` value (``None`` or empty) is returned untouched.
    """
    if not tools:
        return tools
    return [
        restore_tool_from_dict(entry)
        if isinstance(entry, dict) and "tool_type" in entry
        else entry
        for entry in tools
    ]
@field_validator("guardrail")
@classmethod
def validate_guardrail_function(

View File

@@ -21,12 +21,14 @@ from pydantic import (
BaseModel as PydanticBaseModel,
ConfigDict,
Field,
GetCoreSchemaHandler,
PlainSerializer,
PrivateAttr,
computed_field,
create_model,
field_validator,
)
from pydantic_core import CoreSchema, core_schema
from typing_extensions import TypeIs
from crewai.tools.structured_tool import (
@@ -46,21 +48,27 @@ _printer = Printer()
P = ParamSpec("P")
R = TypeVar("R", covariant=True)
# Registry populated by BaseTool.__init_subclass__; used for checkpoint
# deserialization so that list[BaseTool] fields resolve the concrete class.
_TOOL_TYPE_REGISTRY: dict[str, type] = {}
def restore_tool_from_dict(data: dict[str, Any]) -> BaseTool:
"""Reconstruct a concrete ``BaseTool`` subclass from a checkpoint dict.
# Sentinel set after BaseTool is defined so __get_pydantic_core_schema__
# can distinguish the base class from subclasses despite
# ``from __future__ import annotations``.
_BASE_TOOL_CLS: type | None = None
The dict must contain a ``tool_type`` key holding the fully qualified
class name (e.g. ``crewai_tools.tools.serper_dev_tool.SerperDevTool``).
The class is imported and instantiated with the remaining fields so that
its ``_run`` method is available.
"""
data = dict(data) # avoid mutating caller
dotted = data.pop("tool_type")
mod_path, cls_name = dotted.rsplit(".", 1)
cls = getattr(importlib.import_module(mod_path), cls_name)
def _resolve_tool_dict(value: dict[str, Any]) -> Any:
"""Validate a dict with ``tool_type`` into the concrete BaseTool subclass."""
dotted = value.get("tool_type", "")
tool_cls = _TOOL_TYPE_REGISTRY.get(dotted)
if tool_cls is None:
mod_path, cls_name = dotted.rsplit(".", 1)
tool_cls = getattr(importlib.import_module(mod_path), cls_name)
# Pre-resolve serialized callback strings so SerializableCallable's
# BeforeValidator sees a callable and skips the env-var guard.
data = dict(value)
for key in ("cache_function",):
val = data.get(key)
if isinstance(val, str):
@@ -69,8 +77,7 @@ def restore_tool_from_dict(data: dict[str, Any]) -> BaseTool:
except (ValueError, ImportError):
data.pop(key)
result: BaseTool = cls(**data)
return result
return tool_cls.model_validate(data) # type: ignore[union-attr]
def _default_cache_function(_args: Any = None, _result: Any = None) -> bool:
@@ -101,6 +108,36 @@ class BaseTool(BaseModel, ABC):
model_config = ConfigDict(arbitrary_types_allowed=True)
def __init_subclass__(cls, **kwargs: Any) -> None:
    """Record every subclass in the module-level tool registry.

    The registry key is ``"<module>.<qualname>"`` of the subclass, which is
    the form ``_resolve_tool_dict`` looks up from a ``tool_type`` value.
    """
    super().__init_subclass__(**kwargs)
    _TOOL_TYPE_REGISTRY[f"{cls.__module__}.{cls.__qualname__}"] = cls
@classmethod
def __get_pydantic_core_schema__(
    cls, source_type: Any, handler: GetCoreSchemaHandler
) -> CoreSchema:
    """Build the core schema, adding ``tool_type`` dispatch on the base class.

    Any class other than ``_BASE_TOOL_CLS`` gets the handler's default schema
    unchanged. For the base class the default schema is wrapped so that:

    * values that are already instances of the base class are returned as-is;
    * dicts carrying a ``tool_type`` key are handed to ``_resolve_tool_dict``;
    * everything else falls through to the default validation.

    In JSON mode the value is serialized with ``model_dump(mode="json")``.
    """
    inner_schema = handler(source_type)
    if cls is not _BASE_TOOL_CLS:
        # Only the base class gets the dispatching wrapper.
        return inner_schema

    def _dispatch(candidate: Any, validate_default: Any) -> Any:
        if isinstance(candidate, _BASE_TOOL_CLS):
            return candidate
        if isinstance(candidate, dict) and "tool_type" in candidate:
            return _resolve_tool_dict(candidate)
        return validate_default(candidate)

    def _dump_json(tool: Any) -> Any:
        return tool.model_dump(mode="json")

    json_serializer = core_schema.plain_serializer_function_ser_schema(
        _dump_json,
        info_arg=False,
        when_used="json",
    )
    return core_schema.no_info_wrap_validator_function(
        _dispatch,
        inner_schema,
        serialization=json_serializer,
    )
name: str = Field(
description="The unique name of the tool that clearly communicates its purpose."
)
@@ -421,6 +458,9 @@ class BaseTool(BaseModel, ABC):
)
_BASE_TOOL_CLS = BaseTool
class Tool(BaseTool, Generic[P, R]):
"""Tool that wraps a callable function.