Feat/remove langchain (#1654)

* feat: add initial changes from langchain

* feat: remove kwargs of being processed

* feat: remove langchain, update uv.lock and fix type_hint

* feat: change docs

* feat: remove forced requirements for parameter

* feat add tests for new structure tool

* feat: fix tests and adapt code for args
This commit is contained in:
Eduardo Chiarotti
2024-11-26 16:59:52 -03:00
committed by GitHub
parent 8bc09eb054
commit 293305790d
8 changed files with 408 additions and 29 deletions

View File

@@ -9,7 +9,6 @@ authors = [
] ]
dependencies = [ dependencies = [
"pydantic>=2.4.2", "pydantic>=2.4.2",
"langchain>=0.2.16",
"openai>=1.13.3", "openai>=1.13.3",
"opentelemetry-api>=1.22.0", "opentelemetry-api>=1.22.0",
"opentelemetry-sdk>=1.22.0", "opentelemetry-sdk>=1.22.0",

View File

@@ -412,7 +412,7 @@ class Agent(BaseAgent):
for tool in tools: for tool in tools:
if isinstance(tool, CrewAITool): if isinstance(tool, CrewAITool):
tools_list.append(tool.to_langchain()) tools_list.append(tool.to_structured_tool())
else: else:
tools_list.append(tool) tools_list.append(tool)
except ModuleNotFoundError: except ModuleNotFoundError:

View File

@@ -1,10 +1,11 @@
from abc import ABC, abstractmethod from abc import ABC, abstractmethod
from typing import Any, Callable, Type, get_args, get_origin from typing import Any, Callable, Type, get_args, get_origin
from langchain_core.tools import StructuredTool
from pydantic import BaseModel, ConfigDict, Field, validator from pydantic import BaseModel, ConfigDict, Field, validator
from pydantic import BaseModel as PydanticBaseModel from pydantic import BaseModel as PydanticBaseModel
from crewai.tools.structured_tool import CrewStructuredTool
class BaseTool(BaseModel, ABC): class BaseTool(BaseModel, ABC):
class _ArgsSchemaPlaceholder(PydanticBaseModel): class _ArgsSchemaPlaceholder(PydanticBaseModel):
@@ -63,9 +64,10 @@ class BaseTool(BaseModel, ABC):
) -> Any: ) -> Any:
"""Here goes the actual implementation of the tool.""" """Here goes the actual implementation of the tool."""
def to_langchain(self) -> StructuredTool: def to_structured_tool(self) -> CrewStructuredTool:
"""Convert this tool to a CrewStructuredTool instance."""
self._set_args_schema() self._set_args_schema()
return StructuredTool( return CrewStructuredTool(
name=self.name, name=self.name,
description=self.description, description=self.description,
args_schema=self.args_schema, args_schema=self.args_schema,
@@ -73,10 +75,10 @@ class BaseTool(BaseModel, ABC):
) )
@classmethod @classmethod
def from_langchain(cls, tool: StructuredTool) -> "BaseTool": def from_langchain(cls, tool: CrewStructuredTool) -> "BaseTool":
if cls == Tool: if cls == Tool:
if tool.func is None: if tool.func is None:
raise ValueError("StructuredTool must have a callable 'func'") raise ValueError("CrewStructuredTool must have a callable 'func'")
return Tool( return Tool(
name=tool.name, name=tool.name,
description=tool.description, description=tool.description,
@@ -142,9 +144,9 @@ class Tool(BaseTool):
def to_langchain( def to_langchain(
tools: list[BaseTool | StructuredTool], tools: list[BaseTool | CrewStructuredTool],
) -> list[StructuredTool]: ) -> list[CrewStructuredTool]:
return [t.to_langchain() if isinstance(t, BaseTool) else t for t in tools] return [t.to_structured_tool() if isinstance(t, BaseTool) else t for t in tools]
def tool(*args): def tool(*args):

View File

@@ -1,6 +1,7 @@
from pydantic import BaseModel, Field from pydantic import BaseModel, Field
from crewai.agents.cache import CacheHandler from crewai.agents.cache import CacheHandler
from crewai.tools.structured_tool import CrewStructuredTool
class CacheTools(BaseModel): class CacheTools(BaseModel):
@@ -13,9 +14,7 @@ class CacheTools(BaseModel):
) )
def tool(self): def tool(self):
from langchain.tools import StructuredTool return CrewStructuredTool.from_function(
return StructuredTool.from_function(
func=self.hit_cache, func=self.hit_cache,
name=self.name, name=self.name,
description="Reads directly from the cache", description="Reads directly from the cache",

View File

@@ -0,0 +1,242 @@
from __future__ import annotations
import inspect
import textwrap
from typing import Any, Callable, Optional, Union, get_type_hints
from pydantic import BaseModel, Field, create_model
from crewai.utilities.logger import Logger
class CrewStructuredTool:
"""A structured tool that can operate on any number of inputs.
This tool intends to replace StructuredTool with a custom implementation
that integrates better with CrewAI's ecosystem.
"""
def __init__(
self,
name: str,
description: str,
args_schema: type[BaseModel],
func: Callable[..., Any],
) -> None:
"""Initialize the structured tool.
Args:
name: The name of the tool
description: A description of what the tool does
args_schema: The pydantic model for the tool's arguments
func: The function to run when the tool is called
"""
self.name = name
self.description = description
self.args_schema = args_schema
self.func = func
self._logger = Logger()
# Validate the function signature matches the schema
self._validate_function_signature()
@classmethod
def from_function(
cls,
func: Callable,
name: Optional[str] = None,
description: Optional[str] = None,
return_direct: bool = False,
args_schema: Optional[type[BaseModel]] = None,
infer_schema: bool = True,
**kwargs: Any,
) -> CrewStructuredTool:
"""Create a tool from a function.
Args:
func: The function to create a tool from
name: The name of the tool. Defaults to the function name
description: The description of the tool. Defaults to the function docstring
return_direct: Whether to return the output directly
args_schema: Optional schema for the function arguments
infer_schema: Whether to infer the schema from the function signature
**kwargs: Additional arguments to pass to the tool
Returns:
A CrewStructuredTool instance
Example:
>>> def add(a: int, b: int) -> int:
... '''Add two numbers'''
... return a + b
>>> tool = CrewStructuredTool.from_function(add)
"""
name = name or func.__name__
description = description or inspect.getdoc(func)
if description is None:
raise ValueError(
f"Function {name} must have a docstring if description not provided."
)
# Clean up the description
description = textwrap.dedent(description).strip()
if args_schema is not None:
# Use provided schema
schema = args_schema
elif infer_schema:
# Infer schema from function signature
schema = cls._create_schema_from_function(name, func)
else:
raise ValueError(
"Either args_schema must be provided or infer_schema must be True."
)
return cls(
name=name,
description=description,
args_schema=schema,
func=func,
)
@staticmethod
def _create_schema_from_function(
name: str,
func: Callable,
) -> type[BaseModel]:
"""Create a Pydantic schema from a function's signature.
Args:
name: The name to use for the schema
func: The function to create a schema from
Returns:
A Pydantic model class
"""
# Get function signature
sig = inspect.signature(func)
# Get type hints
type_hints = get_type_hints(func)
# Create field definitions
fields = {}
for param_name, param in sig.parameters.items():
# Skip self/cls for methods
if param_name in ("self", "cls"):
continue
# Get type annotation
annotation = type_hints.get(param_name, Any)
# Get default value
default = ... if param.default == param.empty else param.default
# Add field
fields[param_name] = (annotation, Field(default=default))
# Create model
schema_name = f"{name.title()}Schema"
return create_model(schema_name, **fields)
def _validate_function_signature(self) -> None:
"""Validate that the function signature matches the args schema."""
sig = inspect.signature(self.func)
schema_fields = self.args_schema.model_fields
# Check required parameters
for param_name, param in sig.parameters.items():
# Skip self/cls for methods
if param_name in ("self", "cls"):
continue
# Skip **kwargs parameters
if param.kind in (
inspect.Parameter.VAR_KEYWORD,
inspect.Parameter.VAR_POSITIONAL,
):
continue
# Only validate required parameters without defaults
if param.default == inspect.Parameter.empty:
if param_name not in schema_fields:
raise ValueError(
f"Required function parameter '{param_name}' "
f"not found in args_schema"
)
def _parse_args(self, raw_args: Union[str, dict]) -> dict:
"""Parse and validate the input arguments against the schema.
Args:
raw_args: The raw arguments to parse, either as a string or dict
Returns:
The validated arguments as a dictionary
"""
if isinstance(raw_args, str):
try:
import json
raw_args = json.loads(raw_args)
except json.JSONDecodeError as e:
raise ValueError(f"Failed to parse arguments as JSON: {e}")
try:
validated_args = self.args_schema.model_validate(raw_args)
return validated_args.model_dump()
except Exception as e:
raise ValueError(f"Arguments validation failed: {e}")
async def ainvoke(
self,
input: Union[str, dict],
config: Optional[dict] = None,
**kwargs: Any,
) -> Any:
"""Asynchronously invoke the tool.
Args:
input: The input arguments
config: Optional configuration
**kwargs: Additional keyword arguments
Returns:
The result of the tool execution
"""
parsed_args = self._parse_args(input)
if inspect.iscoroutinefunction(self.func):
return await self.func(**parsed_args, **kwargs)
else:
# Run sync functions in a thread pool
import asyncio
return await asyncio.get_event_loop().run_in_executor(
None, lambda: self.func(**parsed_args, **kwargs)
)
def _run(self, *args, **kwargs) -> Any:
"""Legacy method for compatibility."""
# Convert args/kwargs to our expected format
input_dict = dict(zip(self.args_schema.model_fields.keys(), args))
input_dict.update(kwargs)
return self.invoke(input_dict)
def invoke(
self, input: Union[str, dict], config: Optional[dict] = None, **kwargs: Any
) -> Any:
"""Main method for tool execution."""
parsed_args = self._parse_args(input)
return self.func(**parsed_args, **kwargs)
@property
def args(self) -> dict:
"""Get the tool's input arguments schema."""
return self.args_schema.model_json_schema()["properties"]
def __repr__(self) -> str:
return (
f"CrewStructuredTool(name='{self.name}', description='{self.description}')"
)

View File

@@ -1,4 +1,5 @@
from typing import Callable from typing import Callable
from crewai.tools import BaseTool, tool from crewai.tools import BaseTool, tool
@@ -21,8 +22,7 @@ def test_creating_a_tool_using_annotation():
my_tool.func("What is the meaning of life?") == "What is the meaning of life?" my_tool.func("What is the meaning of life?") == "What is the meaning of life?"
) )
# Assert the langchain tool conversion worked as expected converted_tool = my_tool.to_structured_tool()
converted_tool = my_tool.to_langchain()
assert converted_tool.name == "Name of my tool" assert converted_tool.name == "Name of my tool"
assert ( assert (
@@ -41,9 +41,7 @@ def test_creating_a_tool_using_annotation():
def test_creating_a_tool_using_baseclass(): def test_creating_a_tool_using_baseclass():
class MyCustomTool(BaseTool): class MyCustomTool(BaseTool):
name: str = "Name of my tool" name: str = "Name of my tool"
description: str = ( description: str = "Clear description for what this tool is useful for, you agent will need this information to use it."
"Clear description for what this tool is useful for, you agent will need this information to use it."
)
def _run(self, question: str) -> str: def _run(self, question: str) -> str:
return question return question
@@ -61,8 +59,7 @@ def test_creating_a_tool_using_baseclass():
} }
assert my_tool.run("What is the meaning of life?") == "What is the meaning of life?" assert my_tool.run("What is the meaning of life?") == "What is the meaning of life?"
# Assert the langchain tool conversion worked as expected converted_tool = my_tool.to_structured_tool()
converted_tool = my_tool.to_langchain()
assert converted_tool.name == "Name of my tool" assert converted_tool.name == "Name of my tool"
assert ( assert (
@@ -73,7 +70,7 @@ def test_creating_a_tool_using_baseclass():
"question": {"title": "Question", "type": "string"} "question": {"title": "Question", "type": "string"}
} }
assert ( assert (
converted_tool.run("What is the meaning of life?") converted_tool._run("What is the meaning of life?")
== "What is the meaning of life?" == "What is the meaning of life?"
) )
@@ -81,9 +78,7 @@ def test_creating_a_tool_using_baseclass():
def test_setting_cache_function(): def test_setting_cache_function():
class MyCustomTool(BaseTool): class MyCustomTool(BaseTool):
name: str = "Name of my tool" name: str = "Name of my tool"
description: str = ( description: str = "Clear description for what this tool is useful for, you agent will need this information to use it."
"Clear description for what this tool is useful for, you agent will need this information to use it."
)
cache_function: Callable = lambda: False cache_function: Callable = lambda: False
def _run(self, question: str) -> str: def _run(self, question: str) -> str:
@@ -97,9 +92,7 @@ def test_setting_cache_function():
def test_default_cache_function_is_true(): def test_default_cache_function_is_true():
class MyCustomTool(BaseTool): class MyCustomTool(BaseTool):
name: str = "Name of my tool" name: str = "Name of my tool"
description: str = ( description: str = "Clear description for what this tool is useful for, you agent will need this information to use it."
"Clear description for what this tool is useful for, you agent will need this information to use it."
)
def _run(self, question: str) -> str: def _run(self, question: str) -> str:
return question return question

View File

@@ -0,0 +1,146 @@
from typing import Optional
import pytest
from pydantic import BaseModel, Field
from crewai.tools.structured_tool import CrewStructuredTool
# Test fixtures
@pytest.fixture
def basic_function():
def test_func(param1: str, param2: int = 0) -> str:
"""Test function with basic params."""
return f"{param1} {param2}"
return test_func
@pytest.fixture
def schema_class():
class TestSchema(BaseModel):
param1: str
param2: int = Field(default=0)
return TestSchema
class TestCrewStructuredTool:
def test_initialization(self, basic_function, schema_class):
"""Test basic initialization of CrewStructuredTool"""
tool = CrewStructuredTool(
name="test_tool",
description="Test tool description",
func=basic_function,
args_schema=schema_class,
)
assert tool.name == "test_tool"
assert tool.description == "Test tool description"
assert tool.func == basic_function
assert tool.args_schema == schema_class
def test_from_function(self, basic_function):
"""Test creating tool from function"""
tool = CrewStructuredTool.from_function(
func=basic_function, name="test_tool", description="Test description"
)
assert tool.name == "test_tool"
assert tool.description == "Test description"
assert tool.func == basic_function
assert isinstance(tool.args_schema, type(BaseModel))
def test_validate_function_signature(self, basic_function, schema_class):
"""Test function signature validation"""
tool = CrewStructuredTool(
name="test_tool",
description="Test tool",
func=basic_function,
args_schema=schema_class,
)
# Should not raise any exceptions
tool._validate_function_signature()
@pytest.mark.asyncio
async def test_ainvoke(self, basic_function):
"""Test asynchronous invocation"""
tool = CrewStructuredTool.from_function(func=basic_function, name="test_tool")
result = await tool.ainvoke(input={"param1": "test"})
assert result == "test 0"
def test_parse_args_dict(self, basic_function):
"""Test parsing dictionary arguments"""
tool = CrewStructuredTool.from_function(func=basic_function, name="test_tool")
parsed = tool._parse_args({"param1": "test", "param2": 42})
assert parsed["param1"] == "test"
assert parsed["param2"] == 42
def test_parse_args_string(self, basic_function):
"""Test parsing string arguments"""
tool = CrewStructuredTool.from_function(func=basic_function, name="test_tool")
parsed = tool._parse_args('{"param1": "test", "param2": 42}')
assert parsed["param1"] == "test"
assert parsed["param2"] == 42
def test_complex_types(self):
"""Test handling of complex parameter types"""
def complex_func(nested: dict, items: list) -> str:
"""Process complex types."""
return f"Processed {len(items)} items with {len(nested)} nested keys"
tool = CrewStructuredTool.from_function(
func=complex_func, name="test_tool", description="Test complex types"
)
result = tool.invoke({"nested": {"key": "value"}, "items": [1, 2, 3]})
assert result == "Processed 3 items with 1 nested keys"
def test_schema_inheritance(self):
"""Test tool creation with inherited schema"""
def extended_func(base_param: str, extra_param: int) -> str:
"""Test function with inherited schema."""
return f"{base_param} {extra_param}"
class BaseSchema(BaseModel):
base_param: str
class ExtendedSchema(BaseSchema):
extra_param: int
tool = CrewStructuredTool.from_function(
func=extended_func, name="test_tool", args_schema=ExtendedSchema
)
result = tool.invoke({"base_param": "test", "extra_param": 42})
assert result == "test 42"
def test_default_values_in_schema(self):
"""Test handling of default values in schema"""
def default_func(
required_param: str,
optional_param: str = "default",
nullable_param: Optional[int] = None,
) -> str:
"""Test function with default values."""
return f"{required_param} {optional_param} {nullable_param}"
tool = CrewStructuredTool.from_function(
func=default_func, name="test_tool", description="Test defaults"
)
# Test with minimal parameters
result = tool.invoke({"required_param": "test"})
assert result == "test default None"
# Test with all parameters
result = tool.invoke(
{"required_param": "test", "optional_param": "custom", "nullable_param": 42}
)
assert result == "test custom 42"

2
uv.lock generated
View File

@@ -619,7 +619,6 @@ dependencies = [
{ name = "instructor" }, { name = "instructor" },
{ name = "json-repair" }, { name = "json-repair" },
{ name = "jsonref" }, { name = "jsonref" },
{ name = "langchain" },
{ name = "litellm" }, { name = "litellm" },
{ name = "openai" }, { name = "openai" },
{ name = "openpyxl" }, { name = "openpyxl" },
@@ -692,7 +691,6 @@ requires-dist = [
{ name = "instructor", specifier = ">=1.3.3" }, { name = "instructor", specifier = ">=1.3.3" },
{ name = "json-repair", specifier = ">=0.25.2" }, { name = "json-repair", specifier = ">=0.25.2" },
{ name = "jsonref", specifier = ">=1.1.0" }, { name = "jsonref", specifier = ">=1.1.0" },
{ name = "langchain", specifier = ">=0.2.16" },
{ name = "litellm", specifier = ">=1.44.22" }, { name = "litellm", specifier = ">=1.44.22" },
{ name = "mem0ai", marker = "extra == 'mem0'", specifier = ">=0.1.29" }, { name = "mem0ai", marker = "extra == 'mem0'", specifier = ">=0.1.29" },
{ name = "openai", specifier = ">=1.13.3" }, { name = "openai", specifier = ">=1.13.3" },