From 293305790d60ff7a8940ff1fafc1ffb27329f4cd Mon Sep 17 00:00:00 2001 From: Eduardo Chiarotti Date: Tue, 26 Nov 2024 16:59:52 -0300 Subject: [PATCH] Feat/remove langchain (#1654) * feat: add initial changes from langchain * feat: remove kwargs of being processed * feat: remove langchain, update uv.lock and fix type_hint * feat: change docs * feat: remove forced requirements for parameter * feat add tests for new structure tool * feat: fix tests and adapt code for args --- pyproject.toml | 1 - src/crewai/agent.py | 2 +- src/crewai/tools/base_tool.py | 18 +- src/crewai/tools/cache_tools/cache_tools.py | 5 +- src/crewai/tools/structured_tool.py | 242 ++++++++++++++++++++ tests/tools/test_base_tool.py | 21 +- tests/tools/test_structured_tool.py | 146 ++++++++++++ uv.lock | 2 - 8 files changed, 408 insertions(+), 29 deletions(-) create mode 100644 src/crewai/tools/structured_tool.py create mode 100644 tests/tools/test_structured_tool.py diff --git a/pyproject.toml b/pyproject.toml index 4bd416ab5..1d7d8cc43 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -9,7 +9,6 @@ authors = [ ] dependencies = [ "pydantic>=2.4.2", - "langchain>=0.2.16", "openai>=1.13.3", "opentelemetry-api>=1.22.0", "opentelemetry-sdk>=1.22.0", diff --git a/src/crewai/agent.py b/src/crewai/agent.py index 520ee40fd..30edb3be4 100644 --- a/src/crewai/agent.py +++ b/src/crewai/agent.py @@ -412,7 +412,7 @@ class Agent(BaseAgent): for tool in tools: if isinstance(tool, CrewAITool): - tools_list.append(tool.to_langchain()) + tools_list.append(tool.to_structured_tool()) else: tools_list.append(tool) except ModuleNotFoundError: diff --git a/src/crewai/tools/base_tool.py b/src/crewai/tools/base_tool.py index f41fb7c0b..06e427528 100644 --- a/src/crewai/tools/base_tool.py +++ b/src/crewai/tools/base_tool.py @@ -1,10 +1,11 @@ from abc import ABC, abstractmethod from typing import Any, Callable, Type, get_args, get_origin -from langchain_core.tools import StructuredTool from pydantic import BaseModel, ConfigDict, Field, validator from pydantic import BaseModel as PydanticBaseModel +from crewai.tools.structured_tool import CrewStructuredTool + class BaseTool(BaseModel, ABC): class _ArgsSchemaPlaceholder(PydanticBaseModel): @@ -63,9 +64,10 @@ class BaseTool(BaseModel, ABC): ) -> Any: """Here goes the actual implementation of the tool.""" - def to_langchain(self) -> StructuredTool: + def to_structured_tool(self) -> CrewStructuredTool: + """Convert this tool to a CrewStructuredTool instance.""" self._set_args_schema() - return StructuredTool( + return CrewStructuredTool( name=self.name, description=self.description, args_schema=self.args_schema, @@ -73,10 +75,10 @@ class BaseTool(BaseModel, ABC): ) @classmethod - def from_langchain(cls, tool: StructuredTool) -> "BaseTool": + def from_langchain(cls, tool: CrewStructuredTool) -> "BaseTool": if cls == Tool: if tool.func is None: - raise ValueError("StructuredTool must have a callable 'func'") + raise ValueError("CrewStructuredTool must have a callable 'func'") return Tool( name=tool.name, description=tool.description, @@ -142,9 +144,9 @@ class Tool(BaseTool): def to_langchain( - tools: list[BaseTool | StructuredTool], -) -> list[StructuredTool]: - return [t.to_langchain() if isinstance(t, BaseTool) else t for t in tools] + tools: list[BaseTool | CrewStructuredTool], +) -> list[CrewStructuredTool]: + return [t.to_structured_tool() if isinstance(t, BaseTool) else t for t in tools] def tool(*args): diff --git a/src/crewai/tools/cache_tools/cache_tools.py b/src/crewai/tools/cache_tools/cache_tools.py index a0bb2dbad..a81ce98cf 100644 --- a/src/crewai/tools/cache_tools/cache_tools.py +++ b/src/crewai/tools/cache_tools/cache_tools.py @@ -1,6 +1,7 @@ from pydantic import BaseModel, Field from crewai.agents.cache import CacheHandler +from crewai.tools.structured_tool import CrewStructuredTool class CacheTools(BaseModel): @@ -13,9 +14,7 @@ class CacheTools(BaseModel): ) def tool(self): - from langchain.tools import StructuredTool - - return StructuredTool.from_function( + return CrewStructuredTool.from_function( func=self.hit_cache, name=self.name, description="Reads directly from the cache", diff --git a/src/crewai/tools/structured_tool.py b/src/crewai/tools/structured_tool.py new file mode 100644 index 000000000..bd6818605 --- /dev/null +++ b/src/crewai/tools/structured_tool.py @@ -0,0 +1,242 @@ +from __future__ import annotations + +import inspect +import textwrap +from typing import Any, Callable, Optional, Union, get_type_hints + +from pydantic import BaseModel, Field, create_model + +from crewai.utilities.logger import Logger + + +class CrewStructuredTool: + """A structured tool that can operate on any number of inputs. + + This tool intends to replace StructuredTool with a custom implementation + that integrates better with CrewAI's ecosystem. + """ + + def __init__( + self, + name: str, + description: str, + args_schema: type[BaseModel], + func: Callable[..., Any], + ) -> None: + """Initialize the structured tool. + + Args: + name: The name of the tool + description: A description of what the tool does + args_schema: The pydantic model for the tool's arguments + func: The function to run when the tool is called + """ + self.name = name + self.description = description + self.args_schema = args_schema + self.func = func + self._logger = Logger() + + # Validate the function signature matches the schema + self._validate_function_signature() + + @classmethod + def from_function( + cls, + func: Callable, + name: Optional[str] = None, + description: Optional[str] = None, + return_direct: bool = False, + args_schema: Optional[type[BaseModel]] = None, + infer_schema: bool = True, + **kwargs: Any, + ) -> CrewStructuredTool: + """Create a tool from a function. + + Args: + func: The function to create a tool from + name: The name of the tool. Defaults to the function name + description: The description of the tool. Defaults to the function docstring + return_direct: Whether to return the output directly + args_schema: Optional schema for the function arguments + infer_schema: Whether to infer the schema from the function signature + **kwargs: Additional arguments to pass to the tool + + Returns: + A CrewStructuredTool instance + + Example: + >>> def add(a: int, b: int) -> int: + ... '''Add two numbers''' + ... return a + b + >>> tool = CrewStructuredTool.from_function(add) + """ + name = name or func.__name__ + description = description or inspect.getdoc(func) + + if description is None: + raise ValueError( + f"Function {name} must have a docstring if description not provided." + ) + + # Clean up the description + description = textwrap.dedent(description).strip() + + if args_schema is not None: + # Use provided schema + schema = args_schema + elif infer_schema: + # Infer schema from function signature + schema = cls._create_schema_from_function(name, func) + else: + raise ValueError( + "Either args_schema must be provided or infer_schema must be True." + ) + + return cls( + name=name, + description=description, + args_schema=schema, + func=func, + ) + + @staticmethod + def _create_schema_from_function( + name: str, + func: Callable, + ) -> type[BaseModel]: + """Create a Pydantic schema from a function's signature. + + Args: + name: The name to use for the schema + func: The function to create a schema from + + Returns: + A Pydantic model class + """ + # Get function signature + sig = inspect.signature(func) + + # Get type hints + type_hints = get_type_hints(func) + + # Create field definitions + fields = {} + for param_name, param in sig.parameters.items(): + # Skip self/cls for methods + if param_name in ("self", "cls"): + continue + + # Get type annotation + annotation = type_hints.get(param_name, Any) + + # Get default value + default = ... if param.default == param.empty else param.default + + # Add field + fields[param_name] = (annotation, Field(default=default)) + + # Create model + schema_name = f"{name.title()}Schema" + return create_model(schema_name, **fields) + + def _validate_function_signature(self) -> None: + """Validate that the function signature matches the args schema.""" + sig = inspect.signature(self.func) + schema_fields = self.args_schema.model_fields + + # Check required parameters + for param_name, param in sig.parameters.items(): + # Skip self/cls for methods + if param_name in ("self", "cls"): + continue + + # Skip **kwargs parameters + if param.kind in ( + inspect.Parameter.VAR_KEYWORD, + inspect.Parameter.VAR_POSITIONAL, + ): + continue + + # Only validate required parameters without defaults + if param.default == inspect.Parameter.empty: + if param_name not in schema_fields: + raise ValueError( + f"Required function parameter '{param_name}' " + f"not found in args_schema" + ) + + def _parse_args(self, raw_args: Union[str, dict]) -> dict: + """Parse and validate the input arguments against the schema. + + Args: + raw_args: The raw arguments to parse, either as a string or dict + + Returns: + The validated arguments as a dictionary + """ + if isinstance(raw_args, str): + try: + import json + + raw_args = json.loads(raw_args) + except json.JSONDecodeError as e: + raise ValueError(f"Failed to parse arguments as JSON: {e}") + + try: + validated_args = self.args_schema.model_validate(raw_args) + return validated_args.model_dump() + except Exception as e: + raise ValueError(f"Arguments validation failed: {e}") + + async def ainvoke( + self, + input: Union[str, dict], + config: Optional[dict] = None, + **kwargs: Any, + ) -> Any: + """Asynchronously invoke the tool. + + Args: + input: The input arguments + config: Optional configuration + **kwargs: Additional keyword arguments + + Returns: + The result of the tool execution + """ + parsed_args = self._parse_args(input) + + if inspect.iscoroutinefunction(self.func): + return await self.func(**parsed_args, **kwargs) + else: + # Run sync functions in a thread pool + import asyncio + + return await asyncio.get_event_loop().run_in_executor( + None, lambda: self.func(**parsed_args, **kwargs) + ) + + def _run(self, *args, **kwargs) -> Any: + """Legacy method for compatibility.""" + # Convert args/kwargs to our expected format + input_dict = dict(zip(self.args_schema.model_fields.keys(), args)) + input_dict.update(kwargs) + return self.invoke(input_dict) + + def invoke( + self, input: Union[str, dict], config: Optional[dict] = None, **kwargs: Any + ) -> Any: + """Main method for tool execution.""" + parsed_args = self._parse_args(input) + return self.func(**parsed_args, **kwargs) + + @property + def args(self) -> dict: + """Get the tool's input arguments schema.""" + return self.args_schema.model_json_schema()["properties"] + + def __repr__(self) -> str: + return ( + f"CrewStructuredTool(name='{self.name}', description='{self.description}')" + ) diff --git a/tests/tools/test_base_tool.py b/tests/tools/test_base_tool.py index eca36739c..cd4b53caf 100644 --- a/tests/tools/test_base_tool.py +++ b/tests/tools/test_base_tool.py @@ -1,4 +1,5 @@ from typing import Callable + from crewai.tools import BaseTool, tool @@ -21,8 +22,7 @@ def test_creating_a_tool_using_annotation(): my_tool.func("What is the meaning of life?") == "What is the meaning of life?" ) - # Assert the langchain tool conversion worked as expected - converted_tool = my_tool.to_langchain() + converted_tool = my_tool.to_structured_tool() assert converted_tool.name == "Name of my tool" assert ( @@ -41,9 +41,7 @@ def test_creating_a_tool_using_annotation(): def test_creating_a_tool_using_baseclass(): class MyCustomTool(BaseTool): name: str = "Name of my tool" - description: str = ( - "Clear description for what this tool is useful for, you agent will need this information to use it." - ) + description: str = "Clear description for what this tool is useful for, you agent will need this information to use it." def _run(self, question: str) -> str: return question @@ -61,8 +59,7 @@ def test_creating_a_tool_using_baseclass(): } assert my_tool.run("What is the meaning of life?") == "What is the meaning of life?" - # Assert the langchain tool conversion worked as expected - converted_tool = my_tool.to_langchain() + converted_tool = my_tool.to_structured_tool() assert converted_tool.name == "Name of my tool" assert ( @@ -73,7 +70,7 @@ def test_creating_a_tool_using_baseclass(): "question": {"title": "Question", "type": "string"} } assert ( - converted_tool.run("What is the meaning of life?") + converted_tool._run("What is the meaning of life?") == "What is the meaning of life?" ) @@ -81,9 +78,7 @@ def test_creating_a_tool_using_baseclass(): def test_setting_cache_function(): class MyCustomTool(BaseTool): name: str = "Name of my tool" - description: str = ( - "Clear description for what this tool is useful for, you agent will need this information to use it." - ) + description: str = "Clear description for what this tool is useful for, you agent will need this information to use it." cache_function: Callable = lambda: False def _run(self, question: str) -> str: @@ -97,9 +92,7 @@ def test_setting_cache_function(): def test_default_cache_function_is_true(): class MyCustomTool(BaseTool): name: str = "Name of my tool" - description: str = ( - "Clear description for what this tool is useful for, you agent will need this information to use it." - ) + description: str = "Clear description for what this tool is useful for, you agent will need this information to use it." def _run(self, question: str) -> str: return question diff --git a/tests/tools/test_structured_tool.py b/tests/tools/test_structured_tool.py new file mode 100644 index 000000000..32ebd805b --- /dev/null +++ b/tests/tools/test_structured_tool.py @@ -0,0 +1,146 @@ +from typing import Optional + +import pytest +from pydantic import BaseModel, Field + +from crewai.tools.structured_tool import CrewStructuredTool + + +# Test fixtures +@pytest.fixture +def basic_function(): + def test_func(param1: str, param2: int = 0) -> str: + """Test function with basic params.""" + return f"{param1} {param2}" + + return test_func + + +@pytest.fixture +def schema_class(): + class TestSchema(BaseModel): + param1: str + param2: int = Field(default=0) + + return TestSchema + + +class TestCrewStructuredTool: + def test_initialization(self, basic_function, schema_class): + """Test basic initialization of CrewStructuredTool""" + tool = CrewStructuredTool( + name="test_tool", + description="Test tool description", + func=basic_function, + args_schema=schema_class, + ) + + assert tool.name == "test_tool" + assert tool.description == "Test tool description" + assert tool.func == basic_function + assert tool.args_schema == schema_class + + def test_from_function(self, basic_function): + """Test creating tool from function""" + tool = CrewStructuredTool.from_function( + func=basic_function, name="test_tool", description="Test description" + ) + + assert tool.name == "test_tool" + assert tool.description == "Test description" + assert tool.func == basic_function + assert isinstance(tool.args_schema, type(BaseModel)) + + def test_validate_function_signature(self, basic_function, schema_class): + """Test function signature validation""" + tool = CrewStructuredTool( + name="test_tool", + description="Test tool", + func=basic_function, + args_schema=schema_class, + ) + + # Should not raise any exceptions + tool._validate_function_signature() + + @pytest.mark.asyncio + async def test_ainvoke(self, basic_function): + """Test asynchronous invocation""" + tool = CrewStructuredTool.from_function(func=basic_function, name="test_tool") + + result = await tool.ainvoke(input={"param1": "test"}) + assert result == "test 0" + + def test_parse_args_dict(self, basic_function): + """Test parsing dictionary arguments""" + tool = CrewStructuredTool.from_function(func=basic_function, name="test_tool") + + parsed = tool._parse_args({"param1": "test", "param2": 42}) + assert parsed["param1"] == "test" + assert parsed["param2"] == 42 + + def test_parse_args_string(self, basic_function): + """Test parsing string arguments""" + tool = CrewStructuredTool.from_function(func=basic_function, name="test_tool") + + parsed = tool._parse_args('{"param1": "test", "param2": 42}') + assert parsed["param1"] == "test" + assert parsed["param2"] == 42 + + def test_complex_types(self): + """Test handling of complex parameter types""" + + def complex_func(nested: dict, items: list) -> str: + """Process complex types.""" + return f"Processed {len(items)} items with {len(nested)} nested keys" + + tool = CrewStructuredTool.from_function( + func=complex_func, name="test_tool", description="Test complex types" + ) + result = tool.invoke({"nested": {"key": "value"}, "items": [1, 2, 3]}) + assert result == "Processed 3 items with 1 nested keys" + + def test_schema_inheritance(self): + """Test tool creation with inherited schema""" + + def extended_func(base_param: str, extra_param: int) -> str: + """Test function with inherited schema.""" + return f"{base_param} {extra_param}" + + class BaseSchema(BaseModel): + base_param: str + + class ExtendedSchema(BaseSchema): + extra_param: int + + tool = CrewStructuredTool.from_function( + func=extended_func, name="test_tool", args_schema=ExtendedSchema + ) + + result = tool.invoke({"base_param": "test", "extra_param": 42}) + assert result == "test 42" + + def test_default_values_in_schema(self): + """Test handling of default values in schema""" + + def default_func( + required_param: str, + optional_param: str = "default", + nullable_param: Optional[int] = None, + ) -> str: + """Test function with default values.""" + return f"{required_param} {optional_param} {nullable_param}" + + tool = CrewStructuredTool.from_function( + func=default_func, name="test_tool", description="Test defaults" + ) + + # Test with minimal parameters + result = tool.invoke({"required_param": "test"}) + assert result == "test default None" + + # Test with all parameters + result = tool.invoke( + {"required_param": "test", "optional_param": "custom", "nullable_param": 42} + ) + assert result == "test custom 42" diff --git a/uv.lock b/uv.lock index baff09ac8..050602e61 100644 --- a/uv.lock +++ b/uv.lock @@ -619,7 +619,6 @@ dependencies = [ { name = "instructor" }, { name = "json-repair" }, { name = "jsonref" }, - { name = "langchain" }, { name = "litellm" }, { name = "openai" }, { name = "openpyxl" }, @@ -692,7 +691,6 @@ requires-dist = [ { name = "instructor", specifier = ">=1.3.3" }, { name = "json-repair", specifier = ">=0.25.2" }, { name = "jsonref", specifier = ">=1.1.0" }, - { name = "langchain", specifier = ">=0.2.16" }, { name = "litellm", specifier = ">=1.44.22" }, { name = "mem0ai", marker = "extra == 'mem0'", specifier = ">=0.1.29" }, { name = "openai", specifier = ">=1.13.3" },