fix: make BaseTool survive checkpoint JSON round-trip

BaseTool could not serialize to JSON because args_schema (a class
reference) and cache_function (a lambda) are not JSON-serializable.
This caused checkpointing to crash for any crew with tools.

- Add PlainSerializer to args_schema so it round-trips via JSON schema
- Replace default cache_function lambda with named _default_cache_function
  and type it as SerializableCallable so it serializes to a dotted path
- Add computed_field tool_type that stores the fully qualified class name
- Add restore_tool_from_dict to reconstruct the concrete subclass from
  checkpoint dicts, pre-resolving callback strings to callables
- Update BaseAgent.validate_tools and Task._restore_tools_from_checkpoint
  to handle dict inputs from checkpoint deserialization
This commit is contained in:
Greyson LaLonde
2026-04-07 02:14:03 +08:00
parent 14d5d5bbf8
commit 7c3b987037
3 changed files with 85 additions and 11 deletions

View File

@@ -39,7 +39,7 @@ from crewai.memory.unified_memory import Memory
from crewai.rag.embeddings.types import EmbedderConfig
from crewai.security.security_config import SecurityConfig
from crewai.skills.models import Skill
from crewai.tools.base_tool import BaseTool, Tool
from crewai.tools.base_tool import BaseTool, Tool, restore_tool_from_dict
from crewai.types.callback import SerializableCallable
from crewai.utilities.config import process_config
from crewai.utilities.i18n import I18N, get_i18n
@@ -361,13 +361,14 @@ class BaseAgent(BaseModel, ABC, metaclass=AgentMeta):
def process_model_config(cls, values: Any) -> dict[str, Any]:
return process_config(values, cls)
@field_validator("tools")
@field_validator("tools", mode="before")
@classmethod
def validate_tools(cls, tools: list[Any]) -> list[BaseTool]:
"""Validate and process the tools provided to the agent.
This method ensures that each tool is either an instance of BaseTool
or an object with 'name', 'func', and 'description' attributes. If the
This method ensures that each tool is either an instance of BaseTool,
a dict from checkpoint deserialization with a ``tool_type`` key, or an
object with 'name', 'func', and 'description' attributes. If the
tool meets these criteria, it is processed and added to the list of
tools. Otherwise, a ValueError is raised.
"""
@@ -379,6 +380,8 @@ class BaseAgent(BaseModel, ABC, metaclass=AgentMeta):
for tool in tools:
if isinstance(tool, BaseTool):
processed_tools.append(tool)
elif isinstance(tool, dict) and "tool_type" in tool:
processed_tools.append(restore_tool_from_dict(tool))
elif all(hasattr(tool, attr) for attr in required_attrs):
# Tool has the required attributes, create a Tool instance
processed_tools.append(Tool.from_langchain(tool))

View File

@@ -48,7 +48,7 @@ from crewai.llms.base_llm import BaseLLM
from crewai.security import Fingerprint, SecurityConfig
from crewai.tasks.output_format import OutputFormat
from crewai.tasks.task_output import TaskOutput
from crewai.tools.base_tool import BaseTool
from crewai.tools.base_tool import BaseTool, restore_tool_from_dict
from crewai.utilities.config import process_config
from crewai.utilities.constants import NOT_SPECIFIED, _NotSpecified
from crewai.utilities.converter import Converter, convert_to_model
@@ -237,6 +237,21 @@ class Task(BaseModel):
_thread: threading.Thread | None = PrivateAttr(default=None)
model_config = {"arbitrary_types_allowed": True}
@field_validator("tools", mode="before")
@classmethod
def _restore_tools_from_checkpoint(
    cls, tools: list[Any] | None
) -> list[Any] | None:
    """Rebuild concrete tool instances from checkpoint dicts.

    Runs before pydantic validation. Entries that are dicts carrying a
    ``tool_type`` key came from JSON checkpoint deserialization and are
    reconstructed via ``restore_tool_from_dict``; every other entry is
    passed through unchanged. Empty/None input is returned as-is.
    """
    if not tools:
        return tools
    return [
        restore_tool_from_dict(entry)
        if isinstance(entry, dict) and "tool_type" in entry
        else entry
        for entry in tools
    ]
@field_validator("guardrail")
@classmethod
def validate_guardrail_function(

View File

@@ -3,10 +3,12 @@ from __future__ import annotations
from abc import ABC, abstractmethod
import asyncio
from collections.abc import Awaitable, Callable
import importlib
from inspect import Parameter, signature
import json
import threading
from typing import (
Annotated,
Any,
Generic,
ParamSpec,
@@ -19,13 +21,21 @@ from pydantic import (
BaseModel as PydanticBaseModel,
ConfigDict,
Field,
PlainSerializer,
PrivateAttr,
computed_field,
create_model,
field_validator,
)
from typing_extensions import TypeIs
from crewai.tools.structured_tool import CrewStructuredTool, build_schema_hint
from crewai.tools.structured_tool import (
CrewStructuredTool,
_deserialize_schema,
_serialize_schema,
build_schema_hint,
)
from crewai.types.callback import SerializableCallable, _resolve_dotted_path
from crewai.utilities.printer import Printer
from crewai.utilities.pydantic_schema_utils import generate_model_description
from crewai.utilities.string_utils import sanitize_tool_name
@@ -37,6 +47,37 @@ P = ParamSpec("P")
R = TypeVar("R", covariant=True)
def restore_tool_from_dict(data: dict[str, Any]) -> BaseTool:
    """Reconstruct a concrete ``BaseTool`` subclass from a checkpoint dict.

    The dict must contain a ``tool_type`` key holding the fully qualified
    class name (e.g. ``crewai_tools.tools.serper_dev_tool.SerperDevTool``).
    The class is imported and instantiated with the remaining fields so that
    its ``_run`` method is available.

    Args:
        data: Checkpoint dict; not mutated — a shallow copy is worked on.

    Returns:
        An instance of the concrete tool class named by ``tool_type``.
    """
    payload = {**data}  # shallow copy so the caller's dict is untouched
    module_path, class_name = payload.pop("tool_type").rsplit(".", 1)
    tool_cls = getattr(importlib.import_module(module_path), class_name)

    # Callback fields arrive as dotted-path strings after a JSON round-trip;
    # resolve them back to callables, dropping any that cannot be resolved so
    # the class default applies instead.
    for field_name in ("cache_function",):
        candidate = payload.get(field_name)
        if not isinstance(candidate, str):
            continue
        try:
            payload[field_name] = _resolve_dotted_path(candidate)
        except (ValueError, ImportError):
            del payload[field_name]

    tool: BaseTool = tool_cls(**payload)
    return tool
def _default_cache_function(_args: Any = None, _result: Any = None) -> bool:
"""Default cache function that always allows caching."""
return True
def _is_async_callable(func: Callable[..., Any]) -> bool:
"""Check if a callable is async."""
return asyncio.iscoroutinefunction(func)
@@ -70,7 +111,10 @@ class BaseTool(BaseModel, ABC):
default_factory=list,
description="List of environment variables used by the tool.",
)
args_schema: type[PydanticBaseModel] = Field(
args_schema: Annotated[
type[PydanticBaseModel],
PlainSerializer(_serialize_schema, return_type=dict | None, when_used="json"),
] = Field(
default=_ArgsSchemaPlaceholder,
validate_default=True,
description="The schema for the arguments that the tool accepts.",
@@ -80,8 +124,8 @@ class BaseTool(BaseModel, ABC):
default=False, description="Flag to check if the description has been updated."
)
cache_function: Callable[..., bool] = Field(
default=lambda _args=None, _result=None: True,
cache_function: SerializableCallable = Field(
default=_default_cache_function,
description="Function that will be used to determine if the tool should be cached, should return a boolean. If None, the tool will be cached.",
)
result_as_answer: bool = Field(
@@ -98,12 +142,24 @@ class BaseTool(BaseModel, ABC):
)
_usage_lock: threading.Lock = PrivateAttr(default_factory=threading.Lock)
@computed_field # type: ignore[prop-decorator]
@property
def tool_type(self) -> str:
cls = type(self)
return f"{cls.__module__}.{cls.__qualname__}"
@field_validator("args_schema", mode="before")
@classmethod
def _default_args_schema(
cls, v: type[PydanticBaseModel]
cls, v: type[PydanticBaseModel] | dict[str, Any] | None
) -> type[PydanticBaseModel]:
if v != cls._ArgsSchemaPlaceholder:
if isinstance(v, dict):
restored = _deserialize_schema(v)
if restored is not None:
return restored
if v is None or v == cls._ArgsSchemaPlaceholder:
pass # fall through to generate from signature
elif isinstance(v, type):
return v
run_sig = signature(cls._run)