mirror of
https://github.com/crewAIInc/crewAI.git
synced 2026-04-30 23:02:50 +00:00
fix: make BaseTool survive checkpoint JSON round-trip
BaseTool could not serialize to JSON because args_schema (a class reference) and cache_function (a lambda) are not JSON-serializable. This caused checkpointing to crash for any crew with tools. - Add PlainSerializer to args_schema so it round-trips via JSON schema - Replace default cache_function lambda with named _default_cache_function and type it as SerializableCallable so it serializes to a dotted path - Add computed_field tool_type that stores the fully qualified class name - Add restore_tool_from_dict to reconstruct the concrete subclass from checkpoint dicts, pre-resolving callback strings to callables - Update BaseAgent.validate_tools and Task._restore_tools_from_checkpoint to handle dict inputs from checkpoint deserialization
This commit is contained in:
@@ -39,7 +39,7 @@ from crewai.memory.unified_memory import Memory
|
||||
from crewai.rag.embeddings.types import EmbedderConfig
|
||||
from crewai.security.security_config import SecurityConfig
|
||||
from crewai.skills.models import Skill
|
||||
from crewai.tools.base_tool import BaseTool, Tool
|
||||
from crewai.tools.base_tool import BaseTool, Tool, restore_tool_from_dict
|
||||
from crewai.types.callback import SerializableCallable
|
||||
from crewai.utilities.config import process_config
|
||||
from crewai.utilities.i18n import I18N, get_i18n
|
||||
@@ -361,13 +361,14 @@ class BaseAgent(BaseModel, ABC, metaclass=AgentMeta):
|
||||
def process_model_config(cls, values: Any) -> dict[str, Any]:
|
||||
return process_config(values, cls)
|
||||
|
||||
@field_validator("tools")
|
||||
@field_validator("tools", mode="before")
|
||||
@classmethod
|
||||
def validate_tools(cls, tools: list[Any]) -> list[BaseTool]:
|
||||
"""Validate and process the tools provided to the agent.
|
||||
|
||||
This method ensures that each tool is either an instance of BaseTool
|
||||
or an object with 'name', 'func', and 'description' attributes. If the
|
||||
This method ensures that each tool is either an instance of BaseTool,
|
||||
a dict from checkpoint deserialization with a ``tool_type`` key, or an
|
||||
object with 'name', 'func', and 'description' attributes. If the
|
||||
tool meets these criteria, it is processed and added to the list of
|
||||
tools. Otherwise, a ValueError is raised.
|
||||
"""
|
||||
@@ -379,6 +380,8 @@ class BaseAgent(BaseModel, ABC, metaclass=AgentMeta):
|
||||
for tool in tools:
|
||||
if isinstance(tool, BaseTool):
|
||||
processed_tools.append(tool)
|
||||
elif isinstance(tool, dict) and "tool_type" in tool:
|
||||
processed_tools.append(restore_tool_from_dict(tool))
|
||||
elif all(hasattr(tool, attr) for attr in required_attrs):
|
||||
# Tool has the required attributes, create a Tool instance
|
||||
processed_tools.append(Tool.from_langchain(tool))
|
||||
|
||||
@@ -48,7 +48,7 @@ from crewai.llms.base_llm import BaseLLM
|
||||
from crewai.security import Fingerprint, SecurityConfig
|
||||
from crewai.tasks.output_format import OutputFormat
|
||||
from crewai.tasks.task_output import TaskOutput
|
||||
from crewai.tools.base_tool import BaseTool
|
||||
from crewai.tools.base_tool import BaseTool, restore_tool_from_dict
|
||||
from crewai.utilities.config import process_config
|
||||
from crewai.utilities.constants import NOT_SPECIFIED, _NotSpecified
|
||||
from crewai.utilities.converter import Converter, convert_to_model
|
||||
@@ -237,6 +237,21 @@ class Task(BaseModel):
|
||||
_thread: threading.Thread | None = PrivateAttr(default=None)
|
||||
model_config = {"arbitrary_types_allowed": True}
|
||||
|
||||
@field_validator("tools", mode="before")
|
||||
@classmethod
|
||||
def _restore_tools_from_checkpoint(
|
||||
cls, tools: list[Any] | None
|
||||
) -> list[Any] | None:
|
||||
if not tools:
|
||||
return tools
|
||||
restored: list[Any] = []
|
||||
for tool in tools:
|
||||
if isinstance(tool, dict) and "tool_type" in tool:
|
||||
restored.append(restore_tool_from_dict(tool))
|
||||
else:
|
||||
restored.append(tool)
|
||||
return restored
|
||||
|
||||
@field_validator("guardrail")
|
||||
@classmethod
|
||||
def validate_guardrail_function(
|
||||
|
||||
@@ -3,10 +3,12 @@ from __future__ import annotations
|
||||
from abc import ABC, abstractmethod
|
||||
import asyncio
|
||||
from collections.abc import Awaitable, Callable
|
||||
import importlib
|
||||
from inspect import Parameter, signature
|
||||
import json
|
||||
import threading
|
||||
from typing import (
|
||||
Annotated,
|
||||
Any,
|
||||
Generic,
|
||||
ParamSpec,
|
||||
@@ -19,13 +21,21 @@ from pydantic import (
|
||||
BaseModel as PydanticBaseModel,
|
||||
ConfigDict,
|
||||
Field,
|
||||
PlainSerializer,
|
||||
PrivateAttr,
|
||||
computed_field,
|
||||
create_model,
|
||||
field_validator,
|
||||
)
|
||||
from typing_extensions import TypeIs
|
||||
|
||||
from crewai.tools.structured_tool import CrewStructuredTool, build_schema_hint
|
||||
from crewai.tools.structured_tool import (
|
||||
CrewStructuredTool,
|
||||
_deserialize_schema,
|
||||
_serialize_schema,
|
||||
build_schema_hint,
|
||||
)
|
||||
from crewai.types.callback import SerializableCallable, _resolve_dotted_path
|
||||
from crewai.utilities.printer import Printer
|
||||
from crewai.utilities.pydantic_schema_utils import generate_model_description
|
||||
from crewai.utilities.string_utils import sanitize_tool_name
|
||||
@@ -37,6 +47,37 @@ P = ParamSpec("P")
|
||||
R = TypeVar("R", covariant=True)
|
||||
|
||||
|
||||
def restore_tool_from_dict(data: dict[str, Any]) -> BaseTool:
|
||||
"""Reconstruct a concrete ``BaseTool`` subclass from a checkpoint dict.
|
||||
|
||||
The dict must contain a ``tool_type`` key holding the fully qualified
|
||||
class name (e.g. ``crewai_tools.tools.serper_dev_tool.SerperDevTool``).
|
||||
The class is imported and instantiated with the remaining fields so that
|
||||
its ``_run`` method is available.
|
||||
"""
|
||||
|
||||
data = dict(data) # avoid mutating caller
|
||||
dotted = data.pop("tool_type")
|
||||
mod_path, cls_name = dotted.rsplit(".", 1)
|
||||
cls = getattr(importlib.import_module(mod_path), cls_name)
|
||||
|
||||
for key in ("cache_function",):
|
||||
val = data.get(key)
|
||||
if isinstance(val, str):
|
||||
try:
|
||||
data[key] = _resolve_dotted_path(val)
|
||||
except (ValueError, ImportError):
|
||||
data.pop(key)
|
||||
|
||||
result: BaseTool = cls(**data)
|
||||
return result
|
||||
|
||||
|
||||
def _default_cache_function(_args: Any = None, _result: Any = None) -> bool:
|
||||
"""Default cache function that always allows caching."""
|
||||
return True
|
||||
|
||||
|
||||
def _is_async_callable(func: Callable[..., Any]) -> bool:
|
||||
"""Check if a callable is async."""
|
||||
return asyncio.iscoroutinefunction(func)
|
||||
@@ -70,7 +111,10 @@ class BaseTool(BaseModel, ABC):
|
||||
default_factory=list,
|
||||
description="List of environment variables used by the tool.",
|
||||
)
|
||||
args_schema: type[PydanticBaseModel] = Field(
|
||||
args_schema: Annotated[
|
||||
type[PydanticBaseModel],
|
||||
PlainSerializer(_serialize_schema, return_type=dict | None, when_used="json"),
|
||||
] = Field(
|
||||
default=_ArgsSchemaPlaceholder,
|
||||
validate_default=True,
|
||||
description="The schema for the arguments that the tool accepts.",
|
||||
@@ -80,8 +124,8 @@ class BaseTool(BaseModel, ABC):
|
||||
default=False, description="Flag to check if the description has been updated."
|
||||
)
|
||||
|
||||
cache_function: Callable[..., bool] = Field(
|
||||
default=lambda _args=None, _result=None: True,
|
||||
cache_function: SerializableCallable = Field(
|
||||
default=_default_cache_function,
|
||||
description="Function that will be used to determine if the tool should be cached, should return a boolean. If None, the tool will be cached.",
|
||||
)
|
||||
result_as_answer: bool = Field(
|
||||
@@ -98,12 +142,24 @@ class BaseTool(BaseModel, ABC):
|
||||
)
|
||||
_usage_lock: threading.Lock = PrivateAttr(default_factory=threading.Lock)
|
||||
|
||||
@computed_field # type: ignore[prop-decorator]
|
||||
@property
|
||||
def tool_type(self) -> str:
|
||||
cls = type(self)
|
||||
return f"{cls.__module__}.{cls.__qualname__}"
|
||||
|
||||
@field_validator("args_schema", mode="before")
|
||||
@classmethod
|
||||
def _default_args_schema(
|
||||
cls, v: type[PydanticBaseModel]
|
||||
cls, v: type[PydanticBaseModel] | dict[str, Any] | None
|
||||
) -> type[PydanticBaseModel]:
|
||||
if v != cls._ArgsSchemaPlaceholder:
|
||||
if isinstance(v, dict):
|
||||
restored = _deserialize_schema(v)
|
||||
if restored is not None:
|
||||
return restored
|
||||
if v is None or v == cls._ArgsSchemaPlaceholder:
|
||||
pass # fall through to generate from signature
|
||||
elif isinstance(v, type):
|
||||
return v
|
||||
|
||||
run_sig = signature(cls._run)
|
||||
|
||||
Reference in New Issue
Block a user