Feat/remove langchain (#1668)

* feat: add initial changes from langchain

* feat: remove kwargs of being processed

* feat: remove langchain, update uv.lock and fix type_hint

* feat: change docs

* feat: remove forced requirements for parameter

* feat add tests for new structure tool

* feat: fix tests and adapt code for args

* fix tool calling for langchain tools

* doc strings

---------

Co-authored-by: Eduardo Chiarotti <dudumelgaco@hotmail.com>
This commit is contained in:
Brandon Hancock (bhancock_ai)
2024-11-27 11:22:49 -05:00
committed by GitHub
parent 293305790d
commit 366bbbbea3
2 changed files with 128 additions and 14 deletions

View File

@@ -1,7 +1,8 @@
from abc import ABC, abstractmethod
from inspect import signature
from typing import Any, Callable, Type, get_args, get_origin
from pydantic import BaseModel, ConfigDict, Field, validator
from pydantic import BaseModel, ConfigDict, Field, create_model, validator
from pydantic import BaseModel as PydanticBaseModel
from crewai.tools.structured_tool import CrewStructuredTool
@@ -75,17 +76,47 @@ class BaseTool(BaseModel, ABC):
)
@classmethod
def from_langchain(cls, tool: CrewStructuredTool) -> "BaseTool":
if cls == Tool:
if tool.func is None:
raise ValueError("CrewStructuredTool must have a callable 'func'")
return Tool(
name=tool.name,
description=tool.description,
args_schema=tool.args_schema,
func=tool.func,
)
raise NotImplementedError(f"from_langchain not implemented for {cls.__name__}")
def from_langchain(cls, tool: Any) -> "BaseTool":
"""Create a Tool instance from a CrewStructuredTool.
This method takes a CrewStructuredTool object and converts it into a
Tool instance. It ensures that the provided tool has a callable 'func'
attribute and infers the argument schema if not explicitly provided.
"""
if not hasattr(tool, "func") or not callable(tool.func):
raise ValueError("The provided tool must have a callable 'func' attribute.")
args_schema = getattr(tool, "args_schema", None)
if args_schema is None:
# Infer args_schema from the function signature if not provided
func_signature = signature(tool.func)
annotations = func_signature.parameters
args_fields = {}
for name, param in annotations.items():
if name != "self":
param_annotation = (
param.annotation if param.annotation != param.empty else Any
)
field_info = Field(
default=...,
description="",
)
args_fields[name] = (param_annotation, field_info)
if args_fields:
args_schema = create_model(f"{tool.name}Input", **args_fields)
else:
# Create a default schema with no fields if no parameters are found
args_schema = create_model(
f"{tool.name}Input", __base__=PydanticBaseModel
)
return cls(
name=getattr(tool, "name", "Unnamed Tool"),
description=getattr(tool, "description", ""),
func=tool.func,
args_schema=args_schema,
)
def _set_args_schema(self):
if self.args_schema is None:
@@ -136,12 +167,65 @@ class BaseTool(BaseModel, ABC):
class Tool(BaseTool):
func: Callable
"""The function that will be executed when the tool is called."""
func: Callable
def _run(self, *args: Any, **kwargs: Any) -> Any:
return self.func(*args, **kwargs)
@classmethod
def from_langchain(cls, tool: Any) -> "Tool":
"""Create a Tool instance from a CrewStructuredTool.
This method takes a CrewStructuredTool object and converts it into a
Tool instance. It ensures that the provided tool has a callable 'func'
attribute and infers the argument schema if not explicitly provided.
Args:
tool (Any): The CrewStructuredTool object to be converted.
Returns:
Tool: A new Tool instance created from the provided CrewStructuredTool.
Raises:
ValueError: If the provided tool does not have a callable 'func' attribute.
"""
if not hasattr(tool, "func") or not callable(tool.func):
raise ValueError("The provided tool must have a callable 'func' attribute.")
args_schema = getattr(tool, "args_schema", None)
if args_schema is None:
# Infer args_schema from the function signature if not provided
func_signature = signature(tool.func)
annotations = func_signature.parameters
args_fields = {}
for name, param in annotations.items():
if name != "self":
param_annotation = (
param.annotation if param.annotation != param.empty else Any
)
field_info = Field(
default=...,
description="",
)
args_fields[name] = (param_annotation, field_info)
if args_fields:
args_schema = create_model(f"{tool.name}Input", **args_fields)
else:
# Create a default schema with no fields if no parameters are found
args_schema = create_model(
f"{tool.name}Input", __base__=PydanticBaseModel
)
return cls(
name=getattr(tool, "name", "Unnamed Tool"),
description=getattr(tool, "description", ""),
func=tool.func,
args_schema=args_schema,
)
def to_langchain(
tools: list[BaseTool | CrewStructuredTool],