mirror of
https://github.com/crewAIInc/crewAI.git
synced 2026-07-03 14:09:24 +00:00
fix: make BaseTool survive checkpoint JSON round-trip
BaseTool could not serialize to JSON because args_schema (a class reference) and cache_function (a lambda) are not JSON-serializable. This caused checkpointing to crash for any crew with tools. - Add PlainSerializer to args_schema so it round-trips via JSON schema - Replace default cache_function lambda with named _default_cache_function and type it as SerializableCallable - Add computed_field tool_type storing the fully qualified class name - Add __init_subclass__ registry and __get_pydantic_core_schema__ on BaseTool so any list[BaseTool] field automatically dispatches to the concrete subclass during deserialization via tool_type lookup - No changes needed to BaseAgent.validate_tools or Task — Pydantic handles it natively through the custom core schema
This commit is contained in:
@@ -39,7 +39,7 @@ from crewai.memory.unified_memory import Memory
|
||||
from crewai.rag.embeddings.types import EmbedderConfig
|
||||
from crewai.security.security_config import SecurityConfig
|
||||
from crewai.skills.models import Skill
|
||||
from crewai.tools.base_tool import BaseTool, Tool, restore_tool_from_dict
|
||||
from crewai.tools.base_tool import BaseTool, Tool
|
||||
from crewai.types.callback import SerializableCallable
|
||||
from crewai.utilities.config import process_config
|
||||
from crewai.utilities.i18n import I18N, get_i18n
|
||||
@@ -361,14 +361,13 @@ class BaseAgent(BaseModel, ABC, metaclass=AgentMeta):
|
||||
def process_model_config(cls, values: Any) -> dict[str, Any]:
|
||||
return process_config(values, cls)
|
||||
|
||||
@field_validator("tools", mode="before")
|
||||
@field_validator("tools")
|
||||
@classmethod
|
||||
def validate_tools(cls, tools: list[Any]) -> list[BaseTool]:
|
||||
"""Validate and process the tools provided to the agent.
|
||||
|
||||
This method ensures that each tool is either an instance of BaseTool,
|
||||
a dict from checkpoint deserialization with a ``tool_type`` key, or an
|
||||
object with 'name', 'func', and 'description' attributes. If the
|
||||
This method ensures that each tool is either an instance of BaseTool
|
||||
or an object with 'name', 'func', and 'description' attributes. If the
|
||||
tool meets these criteria, it is processed and added to the list of
|
||||
tools. Otherwise, a ValueError is raised.
|
||||
"""
|
||||
@@ -380,8 +379,6 @@ class BaseAgent(BaseModel, ABC, metaclass=AgentMeta):
|
||||
for tool in tools:
|
||||
if isinstance(tool, BaseTool):
|
||||
processed_tools.append(tool)
|
||||
elif isinstance(tool, dict) and "tool_type" in tool:
|
||||
processed_tools.append(restore_tool_from_dict(tool))
|
||||
elif all(hasattr(tool, attr) for attr in required_attrs):
|
||||
# Tool has the required attributes, create a Tool instance
|
||||
processed_tools.append(Tool.from_langchain(tool))
|
||||
|
||||
@@ -48,7 +48,7 @@ from crewai.llms.base_llm import BaseLLM
|
||||
from crewai.security import Fingerprint, SecurityConfig
|
||||
from crewai.tasks.output_format import OutputFormat
|
||||
from crewai.tasks.task_output import TaskOutput
|
||||
from crewai.tools.base_tool import BaseTool, restore_tool_from_dict
|
||||
from crewai.tools.base_tool import BaseTool
|
||||
from crewai.utilities.config import process_config
|
||||
from crewai.utilities.constants import NOT_SPECIFIED, _NotSpecified
|
||||
from crewai.utilities.converter import Converter, convert_to_model
|
||||
@@ -237,21 +237,6 @@ class Task(BaseModel):
|
||||
_thread: threading.Thread | None = PrivateAttr(default=None)
|
||||
model_config = {"arbitrary_types_allowed": True}
|
||||
|
||||
@field_validator("tools", mode="before")
|
||||
@classmethod
|
||||
def _restore_tools_from_checkpoint(
|
||||
cls, tools: list[Any] | None
|
||||
) -> list[Any] | None:
|
||||
if not tools:
|
||||
return tools
|
||||
restored: list[Any] = []
|
||||
for tool in tools:
|
||||
if isinstance(tool, dict) and "tool_type" in tool:
|
||||
restored.append(restore_tool_from_dict(tool))
|
||||
else:
|
||||
restored.append(tool)
|
||||
return restored
|
||||
|
||||
@field_validator("guardrail")
|
||||
@classmethod
|
||||
def validate_guardrail_function(
|
||||
|
||||
@@ -21,12 +21,14 @@ from pydantic import (
|
||||
BaseModel as PydanticBaseModel,
|
||||
ConfigDict,
|
||||
Field,
|
||||
GetCoreSchemaHandler,
|
||||
PlainSerializer,
|
||||
PrivateAttr,
|
||||
computed_field,
|
||||
create_model,
|
||||
field_validator,
|
||||
)
|
||||
from pydantic_core import CoreSchema, core_schema
|
||||
from typing_extensions import TypeIs
|
||||
|
||||
from crewai.tools.structured_tool import (
|
||||
@@ -46,21 +48,27 @@ _printer = Printer()
|
||||
P = ParamSpec("P")
|
||||
R = TypeVar("R", covariant=True)
|
||||
|
||||
# Registry populated by BaseTool.__init_subclass__; used for checkpoint
|
||||
# deserialization so that list[BaseTool] fields resolve the concrete class.
|
||||
_TOOL_TYPE_REGISTRY: dict[str, type] = {}
|
||||
|
||||
def restore_tool_from_dict(data: dict[str, Any]) -> BaseTool:
|
||||
"""Reconstruct a concrete ``BaseTool`` subclass from a checkpoint dict.
|
||||
# Sentinel set after BaseTool is defined so __get_pydantic_core_schema__
|
||||
# can distinguish the base class from subclasses despite
|
||||
# ``from __future__ import annotations``.
|
||||
_BASE_TOOL_CLS: type | None = None
|
||||
|
||||
The dict must contain a ``tool_type`` key holding the fully qualified
|
||||
class name (e.g. ``crewai_tools.tools.serper_dev_tool.SerperDevTool``).
|
||||
The class is imported and instantiated with the remaining fields so that
|
||||
its ``_run`` method is available.
|
||||
"""
|
||||
|
||||
data = dict(data) # avoid mutating caller
|
||||
dotted = data.pop("tool_type")
|
||||
mod_path, cls_name = dotted.rsplit(".", 1)
|
||||
cls = getattr(importlib.import_module(mod_path), cls_name)
|
||||
def _resolve_tool_dict(value: dict[str, Any]) -> Any:
|
||||
"""Validate a dict with ``tool_type`` into the concrete BaseTool subclass."""
|
||||
dotted = value.get("tool_type", "")
|
||||
tool_cls = _TOOL_TYPE_REGISTRY.get(dotted)
|
||||
if tool_cls is None:
|
||||
mod_path, cls_name = dotted.rsplit(".", 1)
|
||||
tool_cls = getattr(importlib.import_module(mod_path), cls_name)
|
||||
|
||||
# Pre-resolve serialized callback strings so SerializableCallable's
|
||||
# BeforeValidator sees a callable and skips the env-var guard.
|
||||
data = dict(value)
|
||||
for key in ("cache_function",):
|
||||
val = data.get(key)
|
||||
if isinstance(val, str):
|
||||
@@ -69,8 +77,7 @@ def restore_tool_from_dict(data: dict[str, Any]) -> BaseTool:
|
||||
except (ValueError, ImportError):
|
||||
data.pop(key)
|
||||
|
||||
result: BaseTool = cls(**data)
|
||||
return result
|
||||
return tool_cls.model_validate(data) # type: ignore[union-attr]
|
||||
|
||||
|
||||
def _default_cache_function(_args: Any = None, _result: Any = None) -> bool:
|
||||
@@ -101,6 +108,36 @@ class BaseTool(BaseModel, ABC):
|
||||
|
||||
model_config = ConfigDict(arbitrary_types_allowed=True)
|
||||
|
||||
def __init_subclass__(cls, **kwargs: Any) -> None:
|
||||
super().__init_subclass__(**kwargs)
|
||||
key = f"{cls.__module__}.{cls.__qualname__}"
|
||||
_TOOL_TYPE_REGISTRY[key] = cls
|
||||
|
||||
@classmethod
|
||||
def __get_pydantic_core_schema__(
|
||||
cls, source_type: Any, handler: GetCoreSchemaHandler
|
||||
) -> CoreSchema:
|
||||
default_schema = handler(source_type)
|
||||
if cls is not _BASE_TOOL_CLS:
|
||||
return default_schema
|
||||
|
||||
def _validate_tool(value: Any, nxt: Any) -> Any:
|
||||
if isinstance(value, _BASE_TOOL_CLS):
|
||||
return value
|
||||
if isinstance(value, dict) and "tool_type" in value:
|
||||
return _resolve_tool_dict(value)
|
||||
return nxt(value)
|
||||
|
||||
return core_schema.no_info_wrap_validator_function(
|
||||
_validate_tool,
|
||||
default_schema,
|
||||
serialization=core_schema.plain_serializer_function_ser_schema(
|
||||
lambda v: v.model_dump(mode="json"),
|
||||
info_arg=False,
|
||||
when_used="json",
|
||||
),
|
||||
)
|
||||
|
||||
name: str = Field(
|
||||
description="The unique name of the tool that clearly communicates its purpose."
|
||||
)
|
||||
@@ -421,6 +458,9 @@ class BaseTool(BaseModel, ABC):
|
||||
)
|
||||
|
||||
|
||||
_BASE_TOOL_CLS = BaseTool
|
||||
|
||||
|
||||
class Tool(BaseTool, Generic[P, R]):
|
||||
"""Tool that wraps a callable function.
|
||||
|
||||
|
||||
Reference in New Issue
Block a user