diff --git a/lib/crewai/src/crewai/agents/agent_builder/base_agent.py b/lib/crewai/src/crewai/agents/agent_builder/base_agent.py index e80e190f1..cfa08bbc3 100644 --- a/lib/crewai/src/crewai/agents/agent_builder/base_agent.py +++ b/lib/crewai/src/crewai/agents/agent_builder/base_agent.py @@ -39,7 +39,7 @@ from crewai.memory.unified_memory import Memory from crewai.rag.embeddings.types import EmbedderConfig from crewai.security.security_config import SecurityConfig from crewai.skills.models import Skill -from crewai.tools.base_tool import BaseTool, Tool, restore_tool_from_dict +from crewai.tools.base_tool import BaseTool, Tool from crewai.types.callback import SerializableCallable from crewai.utilities.config import process_config from crewai.utilities.i18n import I18N, get_i18n @@ -361,14 +361,13 @@ class BaseAgent(BaseModel, ABC, metaclass=AgentMeta): def process_model_config(cls, values: Any) -> dict[str, Any]: return process_config(values, cls) - @field_validator("tools", mode="before") + @field_validator("tools") @classmethod def validate_tools(cls, tools: list[Any]) -> list[BaseTool]: """Validate and process the tools provided to the agent. - This method ensures that each tool is either an instance of BaseTool, - a dict from checkpoint deserialization with a ``tool_type`` key, or an - object with 'name', 'func', and 'description' attributes. If the + This method ensures that each tool is either an instance of BaseTool + or an object with 'name', 'func', and 'description' attributes. If the tool meets these criteria, it is processed and added to the list of tools. Otherwise, a ValueError is raised. """ @@ -380,8 +379,6 @@ class BaseAgent(BaseModel, ABC, metaclass=AgentMeta): for tool in tools: if isinstance(tool, BaseTool): processed_tools.append(tool) - elif isinstance(tool, dict) and "tool_type" in tool: - processed_tools.append(restore_tool_from_dict(tool)) elif all(hasattr(tool, attr) for attr in required_attrs): # Tool has the required attributes, create a Tool instance processed_tools.append(Tool.from_langchain(tool)) diff --git a/lib/crewai/src/crewai/task.py b/lib/crewai/src/crewai/task.py index 4fb5aad04..73e49ade9 100644 --- a/lib/crewai/src/crewai/task.py +++ b/lib/crewai/src/crewai/task.py @@ -48,7 +48,7 @@ from crewai.llms.base_llm import BaseLLM from crewai.security import Fingerprint, SecurityConfig from crewai.tasks.output_format import OutputFormat from crewai.tasks.task_output import TaskOutput -from crewai.tools.base_tool import BaseTool, restore_tool_from_dict +from crewai.tools.base_tool import BaseTool from crewai.utilities.config import process_config from crewai.utilities.constants import NOT_SPECIFIED, _NotSpecified from crewai.utilities.converter import Converter, convert_to_model @@ -237,21 +237,6 @@ class Task(BaseModel): _thread: threading.Thread | None = PrivateAttr(default=None) model_config = {"arbitrary_types_allowed": True} - @field_validator("tools", mode="before") - @classmethod - def _restore_tools_from_checkpoint( - cls, tools: list[Any] | None - ) -> list[Any] | None: - if not tools: - return tools - restored: list[Any] = [] - for tool in tools: - if isinstance(tool, dict) and "tool_type" in tool: - restored.append(restore_tool_from_dict(tool)) - else: - restored.append(tool) - return restored - @field_validator("guardrail") @classmethod def validate_guardrail_function( diff --git a/lib/crewai/src/crewai/tools/base_tool.py b/lib/crewai/src/crewai/tools/base_tool.py index 68260c8f1..11f88a768 100644 --- a/lib/crewai/src/crewai/tools/base_tool.py +++ b/lib/crewai/src/crewai/tools/base_tool.py @@ -21,12 +21,14 @@ from pydantic import ( BaseModel as PydanticBaseModel, ConfigDict, Field, + GetCoreSchemaHandler, PlainSerializer, PrivateAttr, computed_field, create_model, field_validator, ) +from pydantic_core import CoreSchema, core_schema from typing_extensions import TypeIs from crewai.tools.structured_tool import ( @@ -46,21 +48,27 @@ _printer = Printer() P = ParamSpec("P") R = TypeVar("R", covariant=True) +# Registry populated by BaseTool.__init_subclass__; used for checkpoint +# deserialization so that list[BaseTool] fields resolve the concrete class. +_TOOL_TYPE_REGISTRY: dict[str, type] = {} -def restore_tool_from_dict(data: dict[str, Any]) -> BaseTool: - """Reconstruct a concrete ``BaseTool`` subclass from a checkpoint dict. +# Sentinel set after BaseTool is defined so __get_pydantic_core_schema__ +# can distinguish the base class from subclasses despite +# ``from __future__ import annotations``. +_BASE_TOOL_CLS: type | None = None - The dict must contain a ``tool_type`` key holding the fully qualified - class name (e.g. ``crewai_tools.tools.serper_dev_tool.SerperDevTool``). - The class is imported and instantiated with the remaining fields so that - its ``_run`` method is available. - """ - data = dict(data) # avoid mutating caller - dotted = data.pop("tool_type") - mod_path, cls_name = dotted.rsplit(".", 1) - cls = getattr(importlib.import_module(mod_path), cls_name) +def _resolve_tool_dict(value: dict[str, Any]) -> Any: + """Validate a dict with ``tool_type`` into the concrete BaseTool subclass.""" + dotted = value.get("tool_type", "") + tool_cls = _TOOL_TYPE_REGISTRY.get(dotted) + if tool_cls is None: + mod_path, cls_name = dotted.rsplit(".", 1) + tool_cls = getattr(importlib.import_module(mod_path), cls_name) + # Pre-resolve serialized callback strings so SerializableCallable's + # BeforeValidator sees a callable and skips the env-var guard. + data = dict(value) for key in ("cache_function",): val = data.get(key) if isinstance(val, str): @@ -69,8 +77,7 @@ def restore_tool_from_dict(data: dict[str, Any]) -> BaseTool: except (ValueError, ImportError): data.pop(key) - result: BaseTool = cls(**data) - return result + return tool_cls.model_validate(data) # type: ignore[union-attr] def _default_cache_function(_args: Any = None, _result: Any = None) -> bool: @@ -101,6 +108,36 @@ class BaseTool(BaseModel, ABC): model_config = ConfigDict(arbitrary_types_allowed=True) + def __init_subclass__(cls, **kwargs: Any) -> None: + super().__init_subclass__(**kwargs) + key = f"{cls.__module__}.{cls.__qualname__}" + _TOOL_TYPE_REGISTRY[key] = cls + + @classmethod + def __get_pydantic_core_schema__( + cls, source_type: Any, handler: GetCoreSchemaHandler + ) -> CoreSchema: + default_schema = handler(source_type) + if cls is not _BASE_TOOL_CLS: + return default_schema + + def _validate_tool(value: Any, nxt: Any) -> Any: + if isinstance(value, _BASE_TOOL_CLS): + return value + if isinstance(value, dict) and "tool_type" in value: + return _resolve_tool_dict(value) + return nxt(value) + + return core_schema.no_info_wrap_validator_function( + _validate_tool, + default_schema, + serialization=core_schema.plain_serializer_function_ser_schema( + lambda v: v.model_dump(mode="json"), + info_arg=False, + when_used="json", + ), + ) + name: str = Field( description="The unique name of the tool that clearly communicates its purpose." ) @@ -421,6 +458,9 @@ class BaseTool(BaseModel, ABC): ) +_BASE_TOOL_CLS = BaseTool + + class Tool(BaseTool, Generic[P, R]): """Tool that wraps a callable function.