mirror of
https://github.com/crewAIInc/crewAI.git
synced 2026-01-21 13:58:15 +00:00
Compare commits
3 Commits
devin/1768
...
devin/1757
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
acff336363 | ||
|
|
bbf76d0e42 | ||
|
|
a533e111e8 |
@@ -16,7 +16,7 @@ if TYPE_CHECKING:
|
|||||||
from crewai.tools.base_tool import BaseTool
|
from crewai.tools.base_tool import BaseTool
|
||||||
|
|
||||||
|
|
||||||
class ToolUsageLimitExceeded(Exception):
|
class ToolUsageLimitExceededError(Exception):
|
||||||
"""Exception raised when a tool has reached its maximum usage limit."""
|
"""Exception raised when a tool has reached its maximum usage limit."""
|
||||||
|
|
||||||
pass
|
pass
|
||||||
@@ -164,7 +164,7 @@ class CrewStructuredTool:
|
|||||||
|
|
||||||
# Create model
|
# Create model
|
||||||
schema_name = f"{name.title()}Schema"
|
schema_name = f"{name.title()}Schema"
|
||||||
return create_model(schema_name, **fields)
|
return create_model(schema_name, **fields) # type: ignore[call-overload]
|
||||||
|
|
||||||
def _validate_function_signature(self) -> None:
|
def _validate_function_signature(self) -> None:
|
||||||
"""Validate that the function signature matches the args schema."""
|
"""Validate that the function signature matches the args schema."""
|
||||||
@@ -207,13 +207,13 @@ class CrewStructuredTool:
|
|||||||
|
|
||||||
raw_args = json.loads(raw_args)
|
raw_args = json.loads(raw_args)
|
||||||
except json.JSONDecodeError as e:
|
except json.JSONDecodeError as e:
|
||||||
raise ValueError(f"Failed to parse arguments as JSON: {e}")
|
raise ValueError(f"Failed to parse arguments as JSON: {e}") from e
|
||||||
|
|
||||||
try:
|
try:
|
||||||
validated_args = self.args_schema.model_validate(raw_args)
|
validated_args = self.args_schema.model_validate(raw_args)
|
||||||
return validated_args.model_dump()
|
return validated_args.model_dump()
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
raise ValueError(f"Arguments validation failed: {e}")
|
raise ValueError(f"Arguments validation failed: {e}") from e
|
||||||
|
|
||||||
async def ainvoke(
|
async def ainvoke(
|
||||||
self,
|
self,
|
||||||
@@ -234,7 +234,7 @@ class CrewStructuredTool:
|
|||||||
parsed_args = self._parse_args(input)
|
parsed_args = self._parse_args(input)
|
||||||
|
|
||||||
if self.has_reached_max_usage_count():
|
if self.has_reached_max_usage_count():
|
||||||
raise ToolUsageLimitExceeded(
|
raise ToolUsageLimitExceededError(
|
||||||
f"Tool '{self.name}' has reached its maximum usage limit of {self.max_usage_count}. You should not use the {self.name} tool again."
|
f"Tool '{self.name}' has reached its maximum usage limit of {self.max_usage_count}. You should not use the {self.name} tool again."
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -267,23 +267,20 @@ class CrewStructuredTool:
|
|||||||
parsed_args = self._parse_args(input)
|
parsed_args = self._parse_args(input)
|
||||||
|
|
||||||
if self.has_reached_max_usage_count():
|
if self.has_reached_max_usage_count():
|
||||||
raise ToolUsageLimitExceeded(
|
raise ToolUsageLimitExceededError(
|
||||||
f"Tool '{self.name}' has reached its maximum usage limit of {self.max_usage_count}. You should not use the {self.name} tool again."
|
f"Tool '{self.name}' has reached its maximum usage limit of {self.max_usage_count}. You should not use the {self.name} tool again."
|
||||||
)
|
)
|
||||||
|
|
||||||
self._increment_usage_count()
|
self._increment_usage_count()
|
||||||
|
|
||||||
if inspect.iscoroutinefunction(self.func):
|
if inspect.iscoroutinefunction(self.func):
|
||||||
result = asyncio.run(self.func(**parsed_args, **kwargs))
|
return asyncio.run(self.func(**parsed_args, **kwargs))
|
||||||
return result
|
|
||||||
|
|
||||||
try:
|
try:
|
||||||
result = self.func(**parsed_args, **kwargs)
|
result = self.func(**parsed_args, **kwargs)
|
||||||
except Exception:
|
except Exception:
|
||||||
raise
|
raise
|
||||||
|
|
||||||
result = self.func(**parsed_args, **kwargs)
|
|
||||||
|
|
||||||
if asyncio.iscoroutine(result):
|
if asyncio.iscoroutine(result):
|
||||||
return asyncio.run(result)
|
return asyncio.run(result)
|
||||||
|
|
||||||
|
|||||||
@@ -34,10 +34,10 @@ def test_initialization(basic_function, schema_class):
|
|||||||
args_schema=schema_class,
|
args_schema=schema_class,
|
||||||
)
|
)
|
||||||
|
|
||||||
assert tool.name == "test_tool"
|
assert tool.name == "test_tool" # noqa: S101
|
||||||
assert tool.description == "Test tool description"
|
assert tool.description == "Test tool description" # noqa: S101
|
||||||
assert tool.func == basic_function
|
assert tool.func == basic_function # noqa: S101
|
||||||
assert tool.args_schema == schema_class
|
assert tool.args_schema == schema_class # noqa: S101
|
||||||
|
|
||||||
def test_from_function(basic_function):
|
def test_from_function(basic_function):
|
||||||
"""Test creating tool from function"""
|
"""Test creating tool from function"""
|
||||||
@@ -45,10 +45,10 @@ def test_from_function(basic_function):
|
|||||||
func=basic_function, name="test_tool", description="Test description"
|
func=basic_function, name="test_tool", description="Test description"
|
||||||
)
|
)
|
||||||
|
|
||||||
assert tool.name == "test_tool"
|
assert tool.name == "test_tool" # noqa: S101
|
||||||
assert tool.description == "Test description"
|
assert tool.description == "Test description" # noqa: S101
|
||||||
assert tool.func == basic_function
|
assert tool.func == basic_function # noqa: S101
|
||||||
assert isinstance(tool.args_schema, type(BaseModel))
|
assert isinstance(tool.args_schema, type(BaseModel)) # noqa: S101
|
||||||
|
|
||||||
def test_validate_function_signature(basic_function, schema_class):
|
def test_validate_function_signature(basic_function, schema_class):
|
||||||
"""Test function signature validation"""
|
"""Test function signature validation"""
|
||||||
@@ -68,23 +68,23 @@ async def test_ainvoke(basic_function):
|
|||||||
tool = CrewStructuredTool.from_function(func=basic_function, name="test_tool")
|
tool = CrewStructuredTool.from_function(func=basic_function, name="test_tool")
|
||||||
|
|
||||||
result = await tool.ainvoke(input={"param1": "test"})
|
result = await tool.ainvoke(input={"param1": "test"})
|
||||||
assert result == "test 0"
|
assert result == "test 0" # noqa: S101
|
||||||
|
|
||||||
def test_parse_args_dict(basic_function):
|
def test_parse_args_dict(basic_function):
|
||||||
"""Test parsing dictionary arguments"""
|
"""Test parsing dictionary arguments"""
|
||||||
tool = CrewStructuredTool.from_function(func=basic_function, name="test_tool")
|
tool = CrewStructuredTool.from_function(func=basic_function, name="test_tool")
|
||||||
|
|
||||||
parsed = tool._parse_args({"param1": "test", "param2": 42})
|
parsed = tool._parse_args({"param1": "test", "param2": 42})
|
||||||
assert parsed["param1"] == "test"
|
assert parsed["param1"] == "test" # noqa: S101
|
||||||
assert parsed["param2"] == 42
|
assert parsed["param2"] == 42 # noqa: S101
|
||||||
|
|
||||||
def test_parse_args_string(basic_function):
|
def test_parse_args_string(basic_function):
|
||||||
"""Test parsing string arguments"""
|
"""Test parsing string arguments"""
|
||||||
tool = CrewStructuredTool.from_function(func=basic_function, name="test_tool")
|
tool = CrewStructuredTool.from_function(func=basic_function, name="test_tool")
|
||||||
|
|
||||||
parsed = tool._parse_args('{"param1": "test", "param2": 42}')
|
parsed = tool._parse_args('{"param1": "test", "param2": 42}')
|
||||||
assert parsed["param1"] == "test"
|
assert parsed["param1"] == "test" # noqa: S101
|
||||||
assert parsed["param2"] == 42
|
assert parsed["param2"] == 42 # noqa: S101
|
||||||
|
|
||||||
def test_complex_types():
|
def test_complex_types():
|
||||||
"""Test handling of complex parameter types"""
|
"""Test handling of complex parameter types"""
|
||||||
@@ -97,7 +97,7 @@ def test_complex_types():
|
|||||||
func=complex_func, name="test_tool", description="Test complex types"
|
func=complex_func, name="test_tool", description="Test complex types"
|
||||||
)
|
)
|
||||||
result = tool.invoke({"nested": {"key": "value"}, "items": [1, 2, 3]})
|
result = tool.invoke({"nested": {"key": "value"}, "items": [1, 2, 3]})
|
||||||
assert result == "Processed 3 items with 1 nested keys"
|
assert result == "Processed 3 items with 1 nested keys" # noqa: S101
|
||||||
|
|
||||||
def test_schema_inheritance():
|
def test_schema_inheritance():
|
||||||
"""Test tool creation with inherited schema"""
|
"""Test tool creation with inherited schema"""
|
||||||
@@ -117,7 +117,7 @@ def test_schema_inheritance():
|
|||||||
)
|
)
|
||||||
|
|
||||||
result = tool.invoke({"base_param": "test", "extra_param": 42})
|
result = tool.invoke({"base_param": "test", "extra_param": 42})
|
||||||
assert result == "test 42"
|
assert result == "test 42" # noqa: S101
|
||||||
|
|
||||||
def test_default_values_in_schema():
|
def test_default_values_in_schema():
|
||||||
"""Test handling of default values in schema"""
|
"""Test handling of default values in schema"""
|
||||||
@@ -136,13 +136,39 @@ def test_default_values_in_schema():
|
|||||||
|
|
||||||
# Test with minimal parameters
|
# Test with minimal parameters
|
||||||
result = tool.invoke({"required_param": "test"})
|
result = tool.invoke({"required_param": "test"})
|
||||||
assert result == "test default None"
|
assert result == "test default None" # noqa: S101
|
||||||
|
|
||||||
# Test with all parameters
|
# Test with all parameters
|
||||||
result = tool.invoke(
|
result = tool.invoke(
|
||||||
{"required_param": "test", "optional_param": "custom", "nullable_param": 42}
|
{"required_param": "test", "optional_param": "custom", "nullable_param": 42}
|
||||||
)
|
)
|
||||||
assert result == "test custom 42"
|
assert result == "test custom 42" # noqa: S101
|
||||||
|
|
||||||
|
|
||||||
|
def test_tool_not_executed_twice():
|
||||||
|
"""Test that tool function is only executed once per invoke call (bug #3489)"""
|
||||||
|
call_count = 0
|
||||||
|
|
||||||
|
def counting_func(param: str) -> str:
|
||||||
|
"""Function that counts how many times it's called."""
|
||||||
|
nonlocal call_count
|
||||||
|
call_count += 1
|
||||||
|
return f"Called {call_count} times with {param}"
|
||||||
|
|
||||||
|
tool = CrewStructuredTool.from_function(
|
||||||
|
func=counting_func, name="counting_tool", description="Counts calls"
|
||||||
|
)
|
||||||
|
|
||||||
|
call_count = 0
|
||||||
|
|
||||||
|
result = tool.invoke({"param": "test"})
|
||||||
|
|
||||||
|
assert call_count == 1, f"Expected function to be called once, but was called {call_count} times" # noqa: S101
|
||||||
|
assert result == "Called 1 times with test" # noqa: S101
|
||||||
|
|
||||||
|
result = tool.invoke({"param": "test2"})
|
||||||
|
assert call_count == 2, f"Expected function to be called twice total, but was called {call_count} times" # noqa: S101
|
||||||
|
assert result == "Called 2 times with test2" # noqa: S101
|
||||||
|
|
||||||
@pytest.fixture
|
@pytest.fixture
|
||||||
def custom_tool_decorator():
|
def custom_tool_decorator():
|
||||||
@@ -178,22 +204,21 @@ def build_simple_crew(tool):
|
|||||||
description="Use the custom tool result as answer.", agent=agent1, expected_output="Use the tool result"
|
description="Use the custom tool result as answer.", agent=agent1, expected_output="Use the tool result"
|
||||||
)
|
)
|
||||||
|
|
||||||
crew = Crew(agents=[agent1], tasks=[say_hi_task])
|
return Crew(agents=[agent1], tasks=[say_hi_task])
|
||||||
return crew
|
|
||||||
|
|
||||||
@pytest.mark.vcr(filter_headers=["authorization"])
|
@pytest.mark.vcr(filter_headers=["authorization"])
|
||||||
def test_async_tool_using_within_isolated_crew(custom_tool):
|
def test_async_tool_using_within_isolated_crew(custom_tool):
|
||||||
crew = build_simple_crew(custom_tool)
|
crew = build_simple_crew(custom_tool)
|
||||||
result = crew.kickoff()
|
result = crew.kickoff()
|
||||||
|
|
||||||
assert result.raw == "Hello World from Custom Tool"
|
assert result.raw == "Hello World from Custom Tool" # noqa: S101
|
||||||
|
|
||||||
@pytest.mark.vcr(filter_headers=["authorization"])
|
@pytest.mark.vcr(filter_headers=["authorization"])
|
||||||
def test_async_tool_using_decorator_within_isolated_crew(custom_tool_decorator):
|
def test_async_tool_using_decorator_within_isolated_crew(custom_tool_decorator):
|
||||||
crew = build_simple_crew(custom_tool_decorator)
|
crew = build_simple_crew(custom_tool_decorator)
|
||||||
result = crew.kickoff()
|
result = crew.kickoff()
|
||||||
|
|
||||||
assert result.raw == "Hello World from Custom Tool"
|
assert result.raw == "Hello World from Custom Tool" # noqa: S101
|
||||||
|
|
||||||
@pytest.mark.vcr(filter_headers=["authorization"])
|
@pytest.mark.vcr(filter_headers=["authorization"])
|
||||||
def test_async_tool_within_flow(custom_tool):
|
def test_async_tool_within_flow(custom_tool):
|
||||||
@@ -205,12 +230,11 @@ def test_async_tool_within_flow(custom_tool):
|
|||||||
@start()
|
@start()
|
||||||
async def start(self):
|
async def start(self):
|
||||||
crew = build_simple_crew(custom_tool)
|
crew = build_simple_crew(custom_tool)
|
||||||
result = await crew.kickoff_async()
|
return await crew.kickoff_async()
|
||||||
return result
|
|
||||||
|
|
||||||
flow = StructuredExampleFlow()
|
flow = StructuredExampleFlow()
|
||||||
result = flow.kickoff()
|
result = flow.kickoff()
|
||||||
assert result.raw == "Hello World from Custom Tool"
|
assert result.raw == "Hello World from Custom Tool" # noqa: S101
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.vcr(filter_headers=["authorization"])
|
@pytest.mark.vcr(filter_headers=["authorization"])
|
||||||
@@ -222,9 +246,8 @@ def test_async_tool_using_decorator_within_flow(custom_tool_decorator):
|
|||||||
@start()
|
@start()
|
||||||
async def start(self):
|
async def start(self):
|
||||||
crew = build_simple_crew(custom_tool_decorator)
|
crew = build_simple_crew(custom_tool_decorator)
|
||||||
result = await crew.kickoff_async()
|
return await crew.kickoff_async()
|
||||||
return result
|
|
||||||
|
|
||||||
flow = StructuredExampleFlow()
|
flow = StructuredExampleFlow()
|
||||||
result = flow.kickoff()
|
result = flow.kickoff()
|
||||||
assert result.raw == "Hello World from Custom Tool"
|
assert result.raw == "Hello World from Custom Tool" # noqa: S101
|
||||||
|
|||||||
Reference in New Issue
Block a user