feat: fix tests and adapt code for args

This commit is contained in:
Eduardo Chiarotti
2024-11-25 19:45:41 -03:00
parent 65f0730d5a
commit 44b9fc5058
2 changed files with 20 additions and 29 deletions

View File

@@ -152,7 +152,10 @@ class CrewStructuredTool:
continue continue
# Skip **kwargs parameters # Skip **kwargs parameters
if param.kind == inspect.Parameter.VAR_KEYWORD: if param.kind in (
inspect.Parameter.VAR_KEYWORD,
inspect.Parameter.VAR_POSITIONAL,
):
continue continue
# Only validate required parameters without defaults # Only validate required parameters without defaults
@@ -214,22 +217,17 @@ class CrewStructuredTool:
None, lambda: self.func(**parsed_args, **kwargs) None, lambda: self.func(**parsed_args, **kwargs)
) )
def _run(self, *args, **kwargs) -> Any:
"""Legacy method for compatibility."""
# Convert args/kwargs to our expected format
input_dict = dict(zip(self.args_schema.model_fields.keys(), args))
input_dict.update(kwargs)
return self.invoke(input_dict)
def invoke( def invoke(
self, self, input: Union[str, dict], config: Optional[dict] = None, **kwargs: Any
input: Union[str, dict],
config: Optional[dict] = None,
**kwargs: Any,
) -> Any: ) -> Any:
"""Synchronously invoke the tool. """Main method for tool execution."""
Args:
input: The input arguments
config: Optional configuration
**kwargs: Additional keyword arguments
Returns:
The result of the tool execution
"""
parsed_args = self._parse_args(input) parsed_args = self._parse_args(input)
return self.func(**parsed_args, **kwargs) return self.func(**parsed_args, **kwargs)

View File

@@ -1,4 +1,5 @@
from typing import Callable from typing import Callable
from crewai.tools import BaseTool, tool from crewai.tools import BaseTool, tool
@@ -21,8 +22,7 @@ def test_creating_a_tool_using_annotation():
my_tool.func("What is the meaning of life?") == "What is the meaning of life?" my_tool.func("What is the meaning of life?") == "What is the meaning of life?"
) )
# Assert the langchain tool conversion worked as expected converted_tool = my_tool.to_structured_tool()
converted_tool = my_tool.to_langchain()
assert converted_tool.name == "Name of my tool" assert converted_tool.name == "Name of my tool"
assert ( assert (
@@ -41,9 +41,7 @@ def test_creating_a_tool_using_annotation():
def test_creating_a_tool_using_baseclass(): def test_creating_a_tool_using_baseclass():
class MyCustomTool(BaseTool): class MyCustomTool(BaseTool):
name: str = "Name of my tool" name: str = "Name of my tool"
description: str = ( description: str = "Clear description for what this tool is useful for, you agent will need this information to use it."
"Clear description for what this tool is useful for, you agent will need this information to use it."
)
def _run(self, question: str) -> str: def _run(self, question: str) -> str:
return question return question
@@ -61,8 +59,7 @@ def test_creating_a_tool_using_baseclass():
} }
assert my_tool.run("What is the meaning of life?") == "What is the meaning of life?" assert my_tool.run("What is the meaning of life?") == "What is the meaning of life?"
# Assert the langchain tool conversion worked as expected converted_tool = my_tool.to_structured_tool()
converted_tool = my_tool.to_langchain()
assert converted_tool.name == "Name of my tool" assert converted_tool.name == "Name of my tool"
assert ( assert (
@@ -73,7 +70,7 @@ def test_creating_a_tool_using_baseclass():
"question": {"title": "Question", "type": "string"} "question": {"title": "Question", "type": "string"}
} }
assert ( assert (
converted_tool.run("What is the meaning of life?") converted_tool._run("What is the meaning of life?")
== "What is the meaning of life?" == "What is the meaning of life?"
) )
@@ -81,9 +78,7 @@ def test_creating_a_tool_using_baseclass():
def test_setting_cache_function(): def test_setting_cache_function():
class MyCustomTool(BaseTool): class MyCustomTool(BaseTool):
name: str = "Name of my tool" name: str = "Name of my tool"
description: str = ( description: str = "Clear description for what this tool is useful for, you agent will need this information to use it."
"Clear description for what this tool is useful for, you agent will need this information to use it."
)
cache_function: Callable = lambda: False cache_function: Callable = lambda: False
def _run(self, question: str) -> str: def _run(self, question: str) -> str:
@@ -97,9 +92,7 @@ def test_setting_cache_function():
def test_default_cache_function_is_true(): def test_default_cache_function_is_true():
class MyCustomTool(BaseTool): class MyCustomTool(BaseTool):
name: str = "Name of my tool" name: str = "Name of my tool"
description: str = ( description: str = "Clear description for what this tool is useful for, you agent will need this information to use it."
"Clear description for what this tool is useful for, you agent will need this information to use it."
)
def _run(self, question: str) -> str: def _run(self, question: str) -> str:
return question return question