From a737d4f27b991018c7b7bde10d0b9625edd667d1 Mon Sep 17 00:00:00 2001 From: Brandon Hancock Date: Wed, 27 Nov 2024 10:45:11 -0500 Subject: [PATCH] fix tool calling for langchain tools --- src/crewai/agents/agent_builder/base_agent.py | 25 ++++++++++- src/crewai/tools/base_tool.py | 43 ++++++++++++++++++- 2 files changed, 65 insertions(+), 3 deletions(-) diff --git a/src/crewai/agents/agent_builder/base_agent.py b/src/crewai/agents/agent_builder/base_agent.py index 55315e7ff..6659a16a8 100644 --- a/src/crewai/agents/agent_builder/base_agent.py +++ b/src/crewai/agents/agent_builder/base_agent.py @@ -19,6 +19,7 @@ from crewai.agents.agent_builder.utilities.base_token_process import TokenProces from crewai.agents.cache.cache_handler import CacheHandler from crewai.agents.tools_handler import ToolsHandler from crewai.tools import BaseTool +from crewai.tools.base_tool import Tool from crewai.utilities import I18N, Logger, RPMController from crewai.utilities.config import process_config @@ -106,7 +107,7 @@ class BaseAgent(ABC, BaseModel): default=False, description="Enable agent to delegate and ask questions among each other.", ) - tools: Optional[List[BaseTool]] = Field( + tools: Optional[List[Any]] = Field( default_factory=list, description="Tools at agents' disposal" ) max_iter: Optional[int] = Field( @@ -135,6 +136,28 @@ class BaseAgent(ABC, BaseModel): def process_model_config(cls, values): return process_config(values, cls) + @field_validator("tools") + @classmethod + def validate_tools(cls, tools): + processed_tools = [] + for tool in tools: + if isinstance(tool, BaseTool): + processed_tools.append(tool) + elif ( + hasattr(tool, "name") + and hasattr(tool, "func") + and hasattr(tool, "description") + ): + # Tool has the required attributes, create a Tool instance + processed_tools.append(Tool.from_langchain(tool)) + else: + raise ValueError( + f"Invalid tool type: {type(tool)}. " + "Tool must be an instance of BaseTool or " + "an object with 'name', 'func', and 'description' attributes." + ) + return processed_tools + @model_validator(mode="after") def validate_and_set_attributes(self): # Validate required fields diff --git a/src/crewai/tools/base_tool.py b/src/crewai/tools/base_tool.py index 06e427528..44076eb5c 100644 --- a/src/crewai/tools/base_tool.py +++ b/src/crewai/tools/base_tool.py @@ -1,7 +1,8 @@ from abc import ABC, abstractmethod +from inspect import signature from typing import Any, Callable, Type, get_args, get_origin -from pydantic import BaseModel, ConfigDict, Field, validator +from pydantic import BaseModel, ConfigDict, Field, create_model, validator from pydantic import BaseModel as PydanticBaseModel from crewai.tools.structured_tool import CrewStructuredTool @@ -136,12 +137,50 @@ class BaseTool(BaseModel, ABC): class Tool(BaseTool): - func: Callable """The function that will be executed when the tool is called.""" + func: Callable + def _run(self, *args: Any, **kwargs: Any) -> Any: return self.func(*args, **kwargs) + @classmethod + def from_langchain(cls, tool: Any) -> "Tool": + if not hasattr(tool, "func") or not callable(tool.func): + raise ValueError("The provided tool must have a callable 'func' attribute.") + + args_schema = getattr(tool, "args_schema", None) + + if args_schema is None: + # Infer args_schema from the function signature if not provided + func_signature = signature(tool.func) + annotations = func_signature.parameters + args_fields = {} + for name, param in annotations.items(): + if name != "self": + param_annotation = ( + param.annotation if param.annotation != param.empty else Any + ) + field_info = Field( + default=..., + description="", + ) + args_fields[name] = (param_annotation, field_info) + if args_fields: + args_schema = create_model(f"{tool.name}Input", **args_fields) + else: + # Create a default schema with no fields if no parameters are found + args_schema = create_model( + f"{tool.name}Input", __base__=PydanticBaseModel + ) + + return cls( + name=getattr(tool, "name", "Unnamed Tool"), + description=getattr(tool, "description", ""), + func=tool.func, + args_schema=args_schema, + ) + def to_langchain( tools: list[BaseTool | CrewStructuredTool],