Compare commits

...

12 Commits

Author SHA1 Message Date
Brandon Hancock
e748615c9c doc strings 2024-11-27 11:04:24 -05:00
Brandon Hancock
94176552f3 Merge branch 'main' into feat/remove-langchain 2024-11-27 10:47:37 -05:00
Brandon Hancock
a737d4f27b fix tool calling for langchain tools 2024-11-27 10:45:11 -05:00
Eduardo Chiarotti
4d67b2db8e Merge branch 'main' into feat/remove-langchain 2024-11-26 16:56:41 -03:00
Eduardo Chiarotti
44b9fc5058 feat: fix tests and adapt code for args 2024-11-25 19:45:41 -03:00
Eduardo Chiarotti
65f0730d5a Merge branch 'main' into feat/remove-langchain 2024-11-25 19:21:49 -03:00
Eduardo Chiarotti
5195d9edce feat add tests for new structure tool 2024-11-25 16:53:33 -03:00
Eduardo Chiarotti
3b8ab25714 feat: remove forced requirements for parameter 2024-11-25 16:37:25 -03:00
Eduardo Chiarotti
9c3843f925 feat: change docs 2024-11-25 16:10:22 -03:00
Eduardo Chiarotti
6e492e90e4 feat: remove langchain, update uv.lock and fix type_hint 2024-11-25 15:30:29 -03:00
Eduardo Chiarotti
aaec3e1713 feat: remove kwargs of being processed 2024-11-22 16:11:39 -03:00
Eduardo Chiarotti
1edff675eb feat: add initial changes from langchain 2024-11-22 15:49:40 -03:00
2 changed files with 128 additions and 14 deletions

View File

@@ -19,6 +19,7 @@ from crewai.agents.agent_builder.utilities.base_token_process import TokenProces
from crewai.agents.cache.cache_handler import CacheHandler
from crewai.agents.tools_handler import ToolsHandler
from crewai.tools import BaseTool
from crewai.tools.base_tool import Tool
from crewai.utilities import I18N, Logger, RPMController
from crewai.utilities.config import process_config
@@ -106,7 +107,7 @@ class BaseAgent(ABC, BaseModel):
default=False,
description="Enable agent to delegate and ask questions among each other.",
)
tools: Optional[List[BaseTool]] = Field(
tools: Optional[List[Any]] = Field(
default_factory=list, description="Tools at agents' disposal"
)
max_iter: Optional[int] = Field(
@@ -135,6 +136,35 @@ class BaseAgent(ABC, BaseModel):
def process_model_config(cls, values):
return process_config(values, cls)
@field_validator("tools")
@classmethod
def validate_tools(cls, tools: List[Any]) -> List[BaseTool]:
"""Validate and process the tools provided to the agent.
This method ensures that each tool is either an instance of BaseTool
or an object with 'name', 'func', and 'description' attributes. If the
tool meets these criteria, it is processed and added to the list of
tools. Otherwise, a ValueError is raised.
"""
processed_tools = []
for tool in tools:
if isinstance(tool, BaseTool):
processed_tools.append(tool)
elif (
hasattr(tool, "name")
and hasattr(tool, "func")
and hasattr(tool, "description")
):
# Tool has the required attributes, create a Tool instance
processed_tools.append(Tool.from_langchain(tool))
else:
raise ValueError(
f"Invalid tool type: {type(tool)}. "
"Tool must be an instance of BaseTool or "
"an object with 'name', 'func', and 'description' attributes."
)
return processed_tools
@model_validator(mode="after")
def validate_and_set_attributes(self):
# Validate required fields

View File

@@ -1,7 +1,8 @@
from abc import ABC, abstractmethod
from inspect import signature
from typing import Any, Callable, Type, get_args, get_origin
from pydantic import BaseModel, ConfigDict, Field, validator
from pydantic import BaseModel, ConfigDict, Field, create_model, validator
from pydantic import BaseModel as PydanticBaseModel
from crewai.tools.structured_tool import CrewStructuredTool
@@ -75,17 +76,47 @@ class BaseTool(BaseModel, ABC):
)
@classmethod
def from_langchain(cls, tool: CrewStructuredTool) -> "BaseTool":
if cls == Tool:
if tool.func is None:
raise ValueError("CrewStructuredTool must have a callable 'func'")
return Tool(
name=tool.name,
description=tool.description,
args_schema=tool.args_schema,
func=tool.func,
)
raise NotImplementedError(f"from_langchain not implemented for {cls.__name__}")
def from_langchain(cls, tool: Any) -> "BaseTool":
"""Create a Tool instance from a CrewStructuredTool.
This method takes a CrewStructuredTool object and converts it into a
Tool instance. It ensures that the provided tool has a callable 'func'
attribute and infers the argument schema if not explicitly provided.
"""
if not hasattr(tool, "func") or not callable(tool.func):
raise ValueError("The provided tool must have a callable 'func' attribute.")
args_schema = getattr(tool, "args_schema", None)
if args_schema is None:
# Infer args_schema from the function signature if not provided
func_signature = signature(tool.func)
annotations = func_signature.parameters
args_fields = {}
for name, param in annotations.items():
if name != "self":
param_annotation = (
param.annotation if param.annotation != param.empty else Any
)
field_info = Field(
default=...,
description="",
)
args_fields[name] = (param_annotation, field_info)
if args_fields:
args_schema = create_model(f"{tool.name}Input", **args_fields)
else:
# Create a default schema with no fields if no parameters are found
args_schema = create_model(
f"{tool.name}Input", __base__=PydanticBaseModel
)
return cls(
name=getattr(tool, "name", "Unnamed Tool"),
description=getattr(tool, "description", ""),
func=tool.func,
args_schema=args_schema,
)
def _set_args_schema(self):
if self.args_schema is None:
@@ -136,12 +167,65 @@ class BaseTool(BaseModel, ABC):
class Tool(BaseTool):
func: Callable
"""The function that will be executed when the tool is called."""
func: Callable
def _run(self, *args: Any, **kwargs: Any) -> Any:
return self.func(*args, **kwargs)
@classmethod
def from_langchain(cls, tool: Any) -> "Tool":
"""Create a Tool instance from a CrewStructuredTool.
This method takes a CrewStructuredTool object and converts it into a
Tool instance. It ensures that the provided tool has a callable 'func'
attribute and infers the argument schema if not explicitly provided.
Args:
tool (Any): The CrewStructuredTool object to be converted.
Returns:
Tool: A new Tool instance created from the provided CrewStructuredTool.
Raises:
ValueError: If the provided tool does not have a callable 'func' attribute.
"""
if not hasattr(tool, "func") or not callable(tool.func):
raise ValueError("The provided tool must have a callable 'func' attribute.")
args_schema = getattr(tool, "args_schema", None)
if args_schema is None:
# Infer args_schema from the function signature if not provided
func_signature = signature(tool.func)
annotations = func_signature.parameters
args_fields = {}
for name, param in annotations.items():
if name != "self":
param_annotation = (
param.annotation if param.annotation != param.empty else Any
)
field_info = Field(
default=...,
description="",
)
args_fields[name] = (param_annotation, field_info)
if args_fields:
args_schema = create_model(f"{tool.name}Input", **args_fields)
else:
# Create a default schema with no fields if no parameters are found
args_schema = create_model(
f"{tool.name}Input", __base__=PydanticBaseModel
)
return cls(
name=getattr(tool, "name", "Unnamed Tool"),
description=getattr(tool, "description", ""),
func=tool.func,
args_schema=args_schema,
)
def to_langchain(
tools: list[BaseTool | CrewStructuredTool],