mirror of
https://github.com/crewAIInc/crewAI.git
synced 2025-12-16 04:18:35 +00:00
Feat/remove langchain (#1654)
* feat: add initial changes from langchain * feat: remove kwargs of being processed * feat: remove langchain, update uv.lock and fix type_hint * feat: change docs * feat: remove forced requirements for parameter * feat add tests for new structure tool * feat: fix tests and adapt code for args
This commit is contained in:
committed by
GitHub
parent
8bc09eb054
commit
293305790d
@@ -9,7 +9,6 @@ authors = [
|
||||
]
|
||||
dependencies = [
|
||||
"pydantic>=2.4.2",
|
||||
"langchain>=0.2.16",
|
||||
"openai>=1.13.3",
|
||||
"opentelemetry-api>=1.22.0",
|
||||
"opentelemetry-sdk>=1.22.0",
|
||||
|
||||
@@ -412,7 +412,7 @@ class Agent(BaseAgent):
|
||||
|
||||
for tool in tools:
|
||||
if isinstance(tool, CrewAITool):
|
||||
tools_list.append(tool.to_langchain())
|
||||
tools_list.append(tool.to_structured_tool())
|
||||
else:
|
||||
tools_list.append(tool)
|
||||
except ModuleNotFoundError:
|
||||
|
||||
@@ -1,10 +1,11 @@
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import Any, Callable, Type, get_args, get_origin
|
||||
|
||||
from langchain_core.tools import StructuredTool
|
||||
from pydantic import BaseModel, ConfigDict, Field, validator
|
||||
from pydantic import BaseModel as PydanticBaseModel
|
||||
|
||||
from crewai.tools.structured_tool import CrewStructuredTool
|
||||
|
||||
|
||||
class BaseTool(BaseModel, ABC):
|
||||
class _ArgsSchemaPlaceholder(PydanticBaseModel):
|
||||
@@ -63,9 +64,10 @@ class BaseTool(BaseModel, ABC):
|
||||
) -> Any:
|
||||
"""Here goes the actual implementation of the tool."""
|
||||
|
||||
def to_langchain(self) -> StructuredTool:
|
||||
def to_structured_tool(self) -> CrewStructuredTool:
|
||||
"""Convert this tool to a CrewStructuredTool instance."""
|
||||
self._set_args_schema()
|
||||
return StructuredTool(
|
||||
return CrewStructuredTool(
|
||||
name=self.name,
|
||||
description=self.description,
|
||||
args_schema=self.args_schema,
|
||||
@@ -73,10 +75,10 @@ class BaseTool(BaseModel, ABC):
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def from_langchain(cls, tool: StructuredTool) -> "BaseTool":
|
||||
def from_langchain(cls, tool: CrewStructuredTool) -> "BaseTool":
|
||||
if cls == Tool:
|
||||
if tool.func is None:
|
||||
raise ValueError("StructuredTool must have a callable 'func'")
|
||||
raise ValueError("CrewStructuredTool must have a callable 'func'")
|
||||
return Tool(
|
||||
name=tool.name,
|
||||
description=tool.description,
|
||||
@@ -142,9 +144,9 @@ class Tool(BaseTool):
|
||||
|
||||
|
||||
def to_langchain(
|
||||
tools: list[BaseTool | StructuredTool],
|
||||
) -> list[StructuredTool]:
|
||||
return [t.to_langchain() if isinstance(t, BaseTool) else t for t in tools]
|
||||
tools: list[BaseTool | CrewStructuredTool],
|
||||
) -> list[CrewStructuredTool]:
|
||||
return [t.to_structured_tool() if isinstance(t, BaseTool) else t for t in tools]
|
||||
|
||||
|
||||
def tool(*args):
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from crewai.agents.cache import CacheHandler
|
||||
from crewai.tools.structured_tool import CrewStructuredTool
|
||||
|
||||
|
||||
class CacheTools(BaseModel):
|
||||
@@ -13,9 +14,7 @@ class CacheTools(BaseModel):
|
||||
)
|
||||
|
||||
def tool(self):
|
||||
from langchain.tools import StructuredTool
|
||||
|
||||
return StructuredTool.from_function(
|
||||
return CrewStructuredTool.from_function(
|
||||
func=self.hit_cache,
|
||||
name=self.name,
|
||||
description="Reads directly from the cache",
|
||||
|
||||
242
src/crewai/tools/structured_tool.py
Normal file
242
src/crewai/tools/structured_tool.py
Normal file
@@ -0,0 +1,242 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import inspect
|
||||
import textwrap
|
||||
from typing import Any, Callable, Optional, Union, get_type_hints
|
||||
|
||||
from pydantic import BaseModel, Field, create_model
|
||||
|
||||
from crewai.utilities.logger import Logger
|
||||
|
||||
|
||||
class CrewStructuredTool:
|
||||
"""A structured tool that can operate on any number of inputs.
|
||||
|
||||
This tool intends to replace StructuredTool with a custom implementation
|
||||
that integrates better with CrewAI's ecosystem.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
name: str,
|
||||
description: str,
|
||||
args_schema: type[BaseModel],
|
||||
func: Callable[..., Any],
|
||||
) -> None:
|
||||
"""Initialize the structured tool.
|
||||
|
||||
Args:
|
||||
name: The name of the tool
|
||||
description: A description of what the tool does
|
||||
args_schema: The pydantic model for the tool's arguments
|
||||
func: The function to run when the tool is called
|
||||
"""
|
||||
self.name = name
|
||||
self.description = description
|
||||
self.args_schema = args_schema
|
||||
self.func = func
|
||||
self._logger = Logger()
|
||||
|
||||
# Validate the function signature matches the schema
|
||||
self._validate_function_signature()
|
||||
|
||||
@classmethod
|
||||
def from_function(
|
||||
cls,
|
||||
func: Callable,
|
||||
name: Optional[str] = None,
|
||||
description: Optional[str] = None,
|
||||
return_direct: bool = False,
|
||||
args_schema: Optional[type[BaseModel]] = None,
|
||||
infer_schema: bool = True,
|
||||
**kwargs: Any,
|
||||
) -> CrewStructuredTool:
|
||||
"""Create a tool from a function.
|
||||
|
||||
Args:
|
||||
func: The function to create a tool from
|
||||
name: The name of the tool. Defaults to the function name
|
||||
description: The description of the tool. Defaults to the function docstring
|
||||
return_direct: Whether to return the output directly
|
||||
args_schema: Optional schema for the function arguments
|
||||
infer_schema: Whether to infer the schema from the function signature
|
||||
**kwargs: Additional arguments to pass to the tool
|
||||
|
||||
Returns:
|
||||
A CrewStructuredTool instance
|
||||
|
||||
Example:
|
||||
>>> def add(a: int, b: int) -> int:
|
||||
... '''Add two numbers'''
|
||||
... return a + b
|
||||
>>> tool = CrewStructuredTool.from_function(add)
|
||||
"""
|
||||
name = name or func.__name__
|
||||
description = description or inspect.getdoc(func)
|
||||
|
||||
if description is None:
|
||||
raise ValueError(
|
||||
f"Function {name} must have a docstring if description not provided."
|
||||
)
|
||||
|
||||
# Clean up the description
|
||||
description = textwrap.dedent(description).strip()
|
||||
|
||||
if args_schema is not None:
|
||||
# Use provided schema
|
||||
schema = args_schema
|
||||
elif infer_schema:
|
||||
# Infer schema from function signature
|
||||
schema = cls._create_schema_from_function(name, func)
|
||||
else:
|
||||
raise ValueError(
|
||||
"Either args_schema must be provided or infer_schema must be True."
|
||||
)
|
||||
|
||||
return cls(
|
||||
name=name,
|
||||
description=description,
|
||||
args_schema=schema,
|
||||
func=func,
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def _create_schema_from_function(
|
||||
name: str,
|
||||
func: Callable,
|
||||
) -> type[BaseModel]:
|
||||
"""Create a Pydantic schema from a function's signature.
|
||||
|
||||
Args:
|
||||
name: The name to use for the schema
|
||||
func: The function to create a schema from
|
||||
|
||||
Returns:
|
||||
A Pydantic model class
|
||||
"""
|
||||
# Get function signature
|
||||
sig = inspect.signature(func)
|
||||
|
||||
# Get type hints
|
||||
type_hints = get_type_hints(func)
|
||||
|
||||
# Create field definitions
|
||||
fields = {}
|
||||
for param_name, param in sig.parameters.items():
|
||||
# Skip self/cls for methods
|
||||
if param_name in ("self", "cls"):
|
||||
continue
|
||||
|
||||
# Get type annotation
|
||||
annotation = type_hints.get(param_name, Any)
|
||||
|
||||
# Get default value
|
||||
default = ... if param.default == param.empty else param.default
|
||||
|
||||
# Add field
|
||||
fields[param_name] = (annotation, Field(default=default))
|
||||
|
||||
# Create model
|
||||
schema_name = f"{name.title()}Schema"
|
||||
return create_model(schema_name, **fields)
|
||||
|
||||
def _validate_function_signature(self) -> None:
|
||||
"""Validate that the function signature matches the args schema."""
|
||||
sig = inspect.signature(self.func)
|
||||
schema_fields = self.args_schema.model_fields
|
||||
|
||||
# Check required parameters
|
||||
for param_name, param in sig.parameters.items():
|
||||
# Skip self/cls for methods
|
||||
if param_name in ("self", "cls"):
|
||||
continue
|
||||
|
||||
# Skip **kwargs parameters
|
||||
if param.kind in (
|
||||
inspect.Parameter.VAR_KEYWORD,
|
||||
inspect.Parameter.VAR_POSITIONAL,
|
||||
):
|
||||
continue
|
||||
|
||||
# Only validate required parameters without defaults
|
||||
if param.default == inspect.Parameter.empty:
|
||||
if param_name not in schema_fields:
|
||||
raise ValueError(
|
||||
f"Required function parameter '{param_name}' "
|
||||
f"not found in args_schema"
|
||||
)
|
||||
|
||||
def _parse_args(self, raw_args: Union[str, dict]) -> dict:
|
||||
"""Parse and validate the input arguments against the schema.
|
||||
|
||||
Args:
|
||||
raw_args: The raw arguments to parse, either as a string or dict
|
||||
|
||||
Returns:
|
||||
The validated arguments as a dictionary
|
||||
"""
|
||||
if isinstance(raw_args, str):
|
||||
try:
|
||||
import json
|
||||
|
||||
raw_args = json.loads(raw_args)
|
||||
except json.JSONDecodeError as e:
|
||||
raise ValueError(f"Failed to parse arguments as JSON: {e}")
|
||||
|
||||
try:
|
||||
validated_args = self.args_schema.model_validate(raw_args)
|
||||
return validated_args.model_dump()
|
||||
except Exception as e:
|
||||
raise ValueError(f"Arguments validation failed: {e}")
|
||||
|
||||
async def ainvoke(
|
||||
self,
|
||||
input: Union[str, dict],
|
||||
config: Optional[dict] = None,
|
||||
**kwargs: Any,
|
||||
) -> Any:
|
||||
"""Asynchronously invoke the tool.
|
||||
|
||||
Args:
|
||||
input: The input arguments
|
||||
config: Optional configuration
|
||||
**kwargs: Additional keyword arguments
|
||||
|
||||
Returns:
|
||||
The result of the tool execution
|
||||
"""
|
||||
parsed_args = self._parse_args(input)
|
||||
|
||||
if inspect.iscoroutinefunction(self.func):
|
||||
return await self.func(**parsed_args, **kwargs)
|
||||
else:
|
||||
# Run sync functions in a thread pool
|
||||
import asyncio
|
||||
|
||||
return await asyncio.get_event_loop().run_in_executor(
|
||||
None, lambda: self.func(**parsed_args, **kwargs)
|
||||
)
|
||||
|
||||
def _run(self, *args, **kwargs) -> Any:
|
||||
"""Legacy method for compatibility."""
|
||||
# Convert args/kwargs to our expected format
|
||||
input_dict = dict(zip(self.args_schema.model_fields.keys(), args))
|
||||
input_dict.update(kwargs)
|
||||
return self.invoke(input_dict)
|
||||
|
||||
def invoke(
|
||||
self, input: Union[str, dict], config: Optional[dict] = None, **kwargs: Any
|
||||
) -> Any:
|
||||
"""Main method for tool execution."""
|
||||
parsed_args = self._parse_args(input)
|
||||
return self.func(**parsed_args, **kwargs)
|
||||
|
||||
@property
|
||||
def args(self) -> dict:
|
||||
"""Get the tool's input arguments schema."""
|
||||
return self.args_schema.model_json_schema()["properties"]
|
||||
|
||||
def __repr__(self) -> str:
|
||||
return (
|
||||
f"CrewStructuredTool(name='{self.name}', description='{self.description}')"
|
||||
)
|
||||
@@ -1,4 +1,5 @@
|
||||
from typing import Callable
|
||||
|
||||
from crewai.tools import BaseTool, tool
|
||||
|
||||
|
||||
@@ -21,8 +22,7 @@ def test_creating_a_tool_using_annotation():
|
||||
my_tool.func("What is the meaning of life?") == "What is the meaning of life?"
|
||||
)
|
||||
|
||||
# Assert the langchain tool conversion worked as expected
|
||||
converted_tool = my_tool.to_langchain()
|
||||
converted_tool = my_tool.to_structured_tool()
|
||||
assert converted_tool.name == "Name of my tool"
|
||||
|
||||
assert (
|
||||
@@ -41,9 +41,7 @@ def test_creating_a_tool_using_annotation():
|
||||
def test_creating_a_tool_using_baseclass():
|
||||
class MyCustomTool(BaseTool):
|
||||
name: str = "Name of my tool"
|
||||
description: str = (
|
||||
"Clear description for what this tool is useful for, you agent will need this information to use it."
|
||||
)
|
||||
description: str = "Clear description for what this tool is useful for, you agent will need this information to use it."
|
||||
|
||||
def _run(self, question: str) -> str:
|
||||
return question
|
||||
@@ -61,8 +59,7 @@ def test_creating_a_tool_using_baseclass():
|
||||
}
|
||||
assert my_tool.run("What is the meaning of life?") == "What is the meaning of life?"
|
||||
|
||||
# Assert the langchain tool conversion worked as expected
|
||||
converted_tool = my_tool.to_langchain()
|
||||
converted_tool = my_tool.to_structured_tool()
|
||||
assert converted_tool.name == "Name of my tool"
|
||||
|
||||
assert (
|
||||
@@ -73,7 +70,7 @@ def test_creating_a_tool_using_baseclass():
|
||||
"question": {"title": "Question", "type": "string"}
|
||||
}
|
||||
assert (
|
||||
converted_tool.run("What is the meaning of life?")
|
||||
converted_tool._run("What is the meaning of life?")
|
||||
== "What is the meaning of life?"
|
||||
)
|
||||
|
||||
@@ -81,9 +78,7 @@ def test_creating_a_tool_using_baseclass():
|
||||
def test_setting_cache_function():
|
||||
class MyCustomTool(BaseTool):
|
||||
name: str = "Name of my tool"
|
||||
description: str = (
|
||||
"Clear description for what this tool is useful for, you agent will need this information to use it."
|
||||
)
|
||||
description: str = "Clear description for what this tool is useful for, you agent will need this information to use it."
|
||||
cache_function: Callable = lambda: False
|
||||
|
||||
def _run(self, question: str) -> str:
|
||||
@@ -97,9 +92,7 @@ def test_setting_cache_function():
|
||||
def test_default_cache_function_is_true():
|
||||
class MyCustomTool(BaseTool):
|
||||
name: str = "Name of my tool"
|
||||
description: str = (
|
||||
"Clear description for what this tool is useful for, you agent will need this information to use it."
|
||||
)
|
||||
description: str = "Clear description for what this tool is useful for, you agent will need this information to use it."
|
||||
|
||||
def _run(self, question: str) -> str:
|
||||
return question
|
||||
|
||||
146
tests/tools/test_structured_tool.py
Normal file
146
tests/tools/test_structured_tool.py
Normal file
@@ -0,0 +1,146 @@
|
||||
from typing import Optional
|
||||
|
||||
import pytest
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from crewai.tools.structured_tool import CrewStructuredTool
|
||||
|
||||
|
||||
# Test fixtures
|
||||
@pytest.fixture
|
||||
def basic_function():
|
||||
def test_func(param1: str, param2: int = 0) -> str:
|
||||
"""Test function with basic params."""
|
||||
return f"{param1} {param2}"
|
||||
|
||||
return test_func
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def schema_class():
|
||||
class TestSchema(BaseModel):
|
||||
param1: str
|
||||
param2: int = Field(default=0)
|
||||
|
||||
return TestSchema
|
||||
|
||||
|
||||
class TestCrewStructuredTool:
|
||||
def test_initialization(self, basic_function, schema_class):
|
||||
"""Test basic initialization of CrewStructuredTool"""
|
||||
tool = CrewStructuredTool(
|
||||
name="test_tool",
|
||||
description="Test tool description",
|
||||
func=basic_function,
|
||||
args_schema=schema_class,
|
||||
)
|
||||
|
||||
assert tool.name == "test_tool"
|
||||
assert tool.description == "Test tool description"
|
||||
assert tool.func == basic_function
|
||||
assert tool.args_schema == schema_class
|
||||
|
||||
def test_from_function(self, basic_function):
|
||||
"""Test creating tool from function"""
|
||||
tool = CrewStructuredTool.from_function(
|
||||
func=basic_function, name="test_tool", description="Test description"
|
||||
)
|
||||
|
||||
assert tool.name == "test_tool"
|
||||
assert tool.description == "Test description"
|
||||
assert tool.func == basic_function
|
||||
assert isinstance(tool.args_schema, type(BaseModel))
|
||||
|
||||
def test_validate_function_signature(self, basic_function, schema_class):
|
||||
"""Test function signature validation"""
|
||||
tool = CrewStructuredTool(
|
||||
name="test_tool",
|
||||
description="Test tool",
|
||||
func=basic_function,
|
||||
args_schema=schema_class,
|
||||
)
|
||||
|
||||
# Should not raise any exceptions
|
||||
tool._validate_function_signature()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_ainvoke(self, basic_function):
|
||||
"""Test asynchronous invocation"""
|
||||
tool = CrewStructuredTool.from_function(func=basic_function, name="test_tool")
|
||||
|
||||
result = await tool.ainvoke(input={"param1": "test"})
|
||||
assert result == "test 0"
|
||||
|
||||
def test_parse_args_dict(self, basic_function):
|
||||
"""Test parsing dictionary arguments"""
|
||||
tool = CrewStructuredTool.from_function(func=basic_function, name="test_tool")
|
||||
|
||||
parsed = tool._parse_args({"param1": "test", "param2": 42})
|
||||
assert parsed["param1"] == "test"
|
||||
assert parsed["param2"] == 42
|
||||
|
||||
def test_parse_args_string(self, basic_function):
|
||||
"""Test parsing string arguments"""
|
||||
tool = CrewStructuredTool.from_function(func=basic_function, name="test_tool")
|
||||
|
||||
parsed = tool._parse_args('{"param1": "test", "param2": 42}')
|
||||
assert parsed["param1"] == "test"
|
||||
assert parsed["param2"] == 42
|
||||
|
||||
def test_complex_types(self):
|
||||
"""Test handling of complex parameter types"""
|
||||
|
||||
def complex_func(nested: dict, items: list) -> str:
|
||||
"""Process complex types."""
|
||||
return f"Processed {len(items)} items with {len(nested)} nested keys"
|
||||
|
||||
tool = CrewStructuredTool.from_function(
|
||||
func=complex_func, name="test_tool", description="Test complex types"
|
||||
)
|
||||
result = tool.invoke({"nested": {"key": "value"}, "items": [1, 2, 3]})
|
||||
assert result == "Processed 3 items with 1 nested keys"
|
||||
|
||||
def test_schema_inheritance(self):
|
||||
"""Test tool creation with inherited schema"""
|
||||
|
||||
def extended_func(base_param: str, extra_param: int) -> str:
|
||||
"""Test function with inherited schema."""
|
||||
return f"{base_param} {extra_param}"
|
||||
|
||||
class BaseSchema(BaseModel):
|
||||
base_param: str
|
||||
|
||||
class ExtendedSchema(BaseSchema):
|
||||
extra_param: int
|
||||
|
||||
tool = CrewStructuredTool.from_function(
|
||||
func=extended_func, name="test_tool", args_schema=ExtendedSchema
|
||||
)
|
||||
|
||||
result = tool.invoke({"base_param": "test", "extra_param": 42})
|
||||
assert result == "test 42"
|
||||
|
||||
def test_default_values_in_schema(self):
|
||||
"""Test handling of default values in schema"""
|
||||
|
||||
def default_func(
|
||||
required_param: str,
|
||||
optional_param: str = "default",
|
||||
nullable_param: Optional[int] = None,
|
||||
) -> str:
|
||||
"""Test function with default values."""
|
||||
return f"{required_param} {optional_param} {nullable_param}"
|
||||
|
||||
tool = CrewStructuredTool.from_function(
|
||||
func=default_func, name="test_tool", description="Test defaults"
|
||||
)
|
||||
|
||||
# Test with minimal parameters
|
||||
result = tool.invoke({"required_param": "test"})
|
||||
assert result == "test default None"
|
||||
|
||||
# Test with all parameters
|
||||
result = tool.invoke(
|
||||
{"required_param": "test", "optional_param": "custom", "nullable_param": 42}
|
||||
)
|
||||
assert result == "test custom 42"
|
||||
2
uv.lock
generated
2
uv.lock
generated
@@ -619,7 +619,6 @@ dependencies = [
|
||||
{ name = "instructor" },
|
||||
{ name = "json-repair" },
|
||||
{ name = "jsonref" },
|
||||
{ name = "langchain" },
|
||||
{ name = "litellm" },
|
||||
{ name = "openai" },
|
||||
{ name = "openpyxl" },
|
||||
@@ -692,7 +691,6 @@ requires-dist = [
|
||||
{ name = "instructor", specifier = ">=1.3.3" },
|
||||
{ name = "json-repair", specifier = ">=0.25.2" },
|
||||
{ name = "jsonref", specifier = ">=1.1.0" },
|
||||
{ name = "langchain", specifier = ">=0.2.16" },
|
||||
{ name = "litellm", specifier = ">=1.44.22" },
|
||||
{ name = "mem0ai", marker = "extra == 'mem0'", specifier = ">=0.1.29" },
|
||||
{ name = "openai", specifier = ">=1.13.3" },
|
||||
|
||||
Reference in New Issue
Block a user