mirror of
https://github.com/crewAIInc/crewAI.git
synced 2025-12-16 12:28:30 +00:00
Compare commits
14 Commits
devin/1760
...
devin/1740
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
c801cf5279 | ||
|
|
10a0e260a8 | ||
|
|
296039c345 | ||
|
|
33a8e0254b | ||
|
|
2cf0b0c342 | ||
|
|
3c41d3aa60 | ||
|
|
4d8817dfa4 | ||
|
|
d64fde1a79 | ||
|
|
78a4ab6ff6 | ||
|
|
3c9e066779 | ||
|
|
a3b3b411df | ||
|
|
b44842d1de | ||
|
|
570977acf8 | ||
|
|
c41bb4b8c7 |
@@ -1,13 +1,22 @@
|
||||
from typing import Any, Dict, Optional
|
||||
from typing import Any, Dict, Optional, TypedDict
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
from pydantic import BaseModel as PydanticBaseModel
|
||||
from pydantic import Field as PydanticField
|
||||
|
||||
|
||||
class ToolArguments(TypedDict, total=False):
|
||||
"""Arguments that can be passed to a tool.
|
||||
|
||||
Set total=False to make all fields optional, which maintains backward
|
||||
compatibility with existing tools that may not use all arguments.
|
||||
"""
|
||||
question: str
|
||||
|
||||
|
||||
class ToolCalling(BaseModel):
|
||||
tool_name: str = Field(..., description="The name of the tool to be called.")
|
||||
arguments: Optional[Dict[str, Any]] = Field(
|
||||
arguments: Optional[ToolArguments] = Field(
|
||||
..., description="A dictionary of arguments to be passed to the tool."
|
||||
)
|
||||
|
||||
@@ -16,6 +25,6 @@ class InstructorToolCalling(PydanticBaseModel):
|
||||
tool_name: str = PydanticField(
|
||||
..., description="The name of the tool to be called."
|
||||
)
|
||||
arguments: Optional[Dict[str, Any]] = PydanticField(
|
||||
arguments: Optional[ToolArguments] = PydanticField(
|
||||
..., description="A dictionary of arguments to be passed to the tool."
|
||||
)
|
||||
|
||||
@@ -25,6 +25,7 @@ from crewai.utilities.events.tool_usage_events import (
|
||||
ToolUsageFinishedEvent,
|
||||
ToolValidateInputErrorEvent,
|
||||
)
|
||||
from crewai.utilities.logger import Logger
|
||||
|
||||
OPENAI_BIGGER_MODELS = [
|
||||
"gpt-4",
|
||||
@@ -74,6 +75,7 @@ class ToolUsage:
|
||||
self._i18n: I18N = agent.i18n
|
||||
self._printer: Printer = Printer()
|
||||
self._telemetry: Telemetry = Telemetry()
|
||||
self._logger: Logger = Logger()
|
||||
self._run_attempts: int = 1
|
||||
self._max_parsing_attempts: int = 3
|
||||
self._remember_format_after_usages: int = 3
|
||||
@@ -184,22 +186,26 @@ class ToolUsage:
|
||||
)
|
||||
self.task.increment_delegations(coworker)
|
||||
|
||||
if calling.arguments:
|
||||
try:
|
||||
acceptable_args = tool.args_schema.model_json_schema()[
|
||||
"properties"
|
||||
].keys() # type: ignore
|
||||
arguments = {
|
||||
k: v
|
||||
for k, v in calling.arguments.items()
|
||||
if k in acceptable_args
|
||||
}
|
||||
result = tool.invoke(input=arguments)
|
||||
except Exception:
|
||||
arguments = calling.arguments
|
||||
result = tool.invoke(input=arguments)
|
||||
else:
|
||||
result = tool.invoke(input={})
|
||||
if not calling.arguments:
|
||||
raise ValueError("Tool arguments cannot be empty")
|
||||
|
||||
try:
|
||||
acceptable_args = tool.args_schema.model_json_schema()[
|
||||
"properties"
|
||||
].keys() # type: ignore
|
||||
arguments = {
|
||||
k: v
|
||||
for k, v in calling.arguments.items()
|
||||
if k in acceptable_args
|
||||
}
|
||||
result = tool.invoke(input=arguments)
|
||||
except Exception as e:
|
||||
if isinstance(e, TypeError) and "missing 1 required positional argument" in str(e):
|
||||
raise ValueError("Required arguments missing for tool")
|
||||
arguments = calling.arguments
|
||||
result = tool.invoke(input=arguments)
|
||||
except ValueError as ve:
|
||||
raise ve
|
||||
except Exception as e:
|
||||
self.on_tool_error(tool=tool, tool_calling=calling, e=e)
|
||||
self._run_attempts += 1
|
||||
@@ -283,13 +289,50 @@ class ToolUsage:
|
||||
|
||||
def _check_tool_repeated_usage(
|
||||
self, calling: Union[ToolCalling, InstructorToolCalling]
|
||||
) -> None:
|
||||
if not self.tools_handler:
|
||||
return False # type: ignore # No return value expected
|
||||
if last_tool_usage := self.tools_handler.last_used_tool:
|
||||
return (calling.tool_name == last_tool_usage.tool_name) and ( # type: ignore # No return value expected
|
||||
calling.arguments == last_tool_usage.arguments
|
||||
)
|
||||
) -> bool:
|
||||
"""Check if a tool is being called with the same arguments as the last call.
|
||||
|
||||
This method prevents duplicate tool executions by comparing the current tool call
|
||||
with the last one. For WebSocket tools, it specifically checks if the 'question'
|
||||
argument is identical. For other tools, it compares all arguments.
|
||||
|
||||
Args:
|
||||
calling: The tool calling to check for repetition, containing the tool name
|
||||
and arguments.
|
||||
|
||||
Returns:
|
||||
bool: True if the tool is being called with the same name and arguments as
|
||||
the last call, False otherwise.
|
||||
"""
|
||||
self._logger.log("debug", f"Checking for repeated usage of tool: {calling.tool_name}")
|
||||
|
||||
if not self.tools_handler or not self.tools_handler.last_used_tool:
|
||||
self._logger.log("debug", "No previous tool usage found")
|
||||
return False
|
||||
|
||||
last_tool_usage = self.tools_handler.last_used_tool
|
||||
if calling.tool_name != last_tool_usage.tool_name:
|
||||
self._logger.log("debug", f"Different tool name: {calling.tool_name} vs {last_tool_usage.tool_name}")
|
||||
return False
|
||||
|
||||
if not calling.arguments or not last_tool_usage.arguments:
|
||||
self._logger.log("debug", "Missing arguments in current or last tool usage")
|
||||
return False
|
||||
|
||||
try:
|
||||
# For WebSocket tools, only compare the question argument
|
||||
if "question" in calling.arguments and "question" in last_tool_usage.arguments:
|
||||
is_repeated = calling.arguments["question"] == last_tool_usage.arguments["question"]
|
||||
self._logger.log("debug", f"WebSocket tool question comparison: {is_repeated}")
|
||||
return is_repeated
|
||||
|
||||
# For other tools, compare all arguments
|
||||
is_repeated = calling.arguments == last_tool_usage.arguments
|
||||
self._logger.log("debug", f"Full arguments comparison: {is_repeated}")
|
||||
return is_repeated
|
||||
except (KeyError, TypeError) as e:
|
||||
self._logger.log("debug", f"Error comparing arguments: {str(e)}")
|
||||
return False
|
||||
|
||||
def _select_tool(self, tool_name: str) -> Any:
|
||||
order_tools = sorted(
|
||||
|
||||
@@ -6,7 +6,9 @@ import pytest
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from crewai import Agent, Task
|
||||
from crewai.agents.tools_handler import ToolsHandler
|
||||
from crewai.tools import BaseTool
|
||||
from crewai.tools.tool_calling import ToolCalling
|
||||
from crewai.tools.tool_usage import ToolUsage
|
||||
from crewai.utilities.events import crewai_event_bus
|
||||
from crewai.utilities.events.tool_usage_events import (
|
||||
@@ -128,6 +130,177 @@ def test_tool_usage_render():
|
||||
)
|
||||
|
||||
|
||||
class WebSocketToolInput(BaseModel):
|
||||
question: str = Field(..., description="Question to ask")
|
||||
|
||||
|
||||
class MockWebSocketTool(BaseTool):
|
||||
name: str = "WebSocket Tool"
|
||||
description: str = "A tool that uses WebSocket for communication"
|
||||
args_schema: type[BaseModel] = WebSocketToolInput
|
||||
|
||||
def _run(self, question: str) -> str:
|
||||
return f"Answer to: {question}"
|
||||
|
||||
def invoke(self, input: dict) -> str:
|
||||
return self._run(**input)
|
||||
|
||||
|
||||
class TestWebSocketToolUsage:
|
||||
"""Test cases for WebSocket tool usage and duplicate detection."""
|
||||
|
||||
@pytest.fixture
|
||||
def setup_websocket_tool(self):
|
||||
"""Fixture to set up WebSocket tool and agent for testing."""
|
||||
tool = MockWebSocketTool()
|
||||
agent = Agent(
|
||||
role="Test Agent",
|
||||
goal="Test WebSocket tools",
|
||||
backstory="Testing WebSocket tool execution",
|
||||
tools=[tool],
|
||||
verbose=True,
|
||||
)
|
||||
return tool, agent
|
||||
|
||||
def test_first_execution(self, setup_websocket_tool):
|
||||
"""Test first execution of WebSocket tool."""
|
||||
tool, agent = setup_websocket_tool
|
||||
task = Task(
|
||||
description="Test WebSocket tool",
|
||||
expected_output="Test output",
|
||||
agent=agent,
|
||||
)
|
||||
tool_usage = ToolUsage(
|
||||
tools_handler=ToolsHandler(),
|
||||
tools=[tool],
|
||||
original_tools=[tool],
|
||||
tools_description="WebSocket tool for testing",
|
||||
tools_names="websocket_tool",
|
||||
task=task,
|
||||
function_calling_llm=MagicMock(),
|
||||
agent=agent,
|
||||
action=MagicMock(),
|
||||
)
|
||||
calling = ToolCalling(
|
||||
tool_name="WebSocket Tool",
|
||||
arguments={"question": "Test question"},
|
||||
log="Test log",
|
||||
)
|
||||
result = tool_usage.use(calling, "Test string")
|
||||
assert "Answer to: Test question" in result
|
||||
|
||||
def test_repeated_execution(self, setup_websocket_tool):
|
||||
"""Test repeated execution with same question is detected."""
|
||||
tool, agent = setup_websocket_tool
|
||||
task = Task(
|
||||
description="Test WebSocket tool",
|
||||
expected_output="Test output",
|
||||
agent=agent,
|
||||
)
|
||||
tool_usage = ToolUsage(
|
||||
tools_handler=ToolsHandler(),
|
||||
tools=[tool],
|
||||
original_tools=[tool],
|
||||
tools_description="WebSocket tool for testing",
|
||||
tools_names="websocket_tool",
|
||||
task=task,
|
||||
function_calling_llm=MagicMock(),
|
||||
agent=agent,
|
||||
action=MagicMock(),
|
||||
)
|
||||
# First call
|
||||
calling1 = ToolCalling(
|
||||
tool_name="WebSocket Tool",
|
||||
arguments={"question": "Test question"},
|
||||
log="Test log",
|
||||
)
|
||||
result1 = tool_usage.use(calling1, "Test string")
|
||||
assert "Answer to: Test question" in result1
|
||||
|
||||
# Same question should be detected as repeated
|
||||
calling2 = ToolCalling(
|
||||
tool_name="WebSocket Tool",
|
||||
arguments={"question": "Test question"},
|
||||
log="Test log",
|
||||
)
|
||||
result2 = tool_usage.use(calling2, "Test string")
|
||||
assert "reusing the same input" in result2.lower()
|
||||
|
||||
def test_different_question(self, setup_websocket_tool):
|
||||
"""Test execution with different questions works."""
|
||||
tool, agent = setup_websocket_tool
|
||||
task = Task(
|
||||
description="Test WebSocket tool",
|
||||
expected_output="Test output",
|
||||
agent=agent,
|
||||
)
|
||||
tool_usage = ToolUsage(
|
||||
tools_handler=ToolsHandler(),
|
||||
tools=[tool],
|
||||
original_tools=[tool],
|
||||
tools_description="WebSocket tool for testing",
|
||||
tools_names="websocket_tool",
|
||||
task=task,
|
||||
function_calling_llm=MagicMock(),
|
||||
agent=agent,
|
||||
action=MagicMock(),
|
||||
)
|
||||
# First question
|
||||
calling1 = ToolCalling(
|
||||
tool_name="WebSocket Tool",
|
||||
arguments={"question": "First question"},
|
||||
log="Test log",
|
||||
)
|
||||
result1 = tool_usage.use(calling1, "Test string")
|
||||
assert "Answer to: First question" in result1
|
||||
|
||||
# Different question should work
|
||||
calling2 = ToolCalling(
|
||||
tool_name="WebSocket Tool",
|
||||
arguments={"question": "Second question"},
|
||||
log="Test log",
|
||||
)
|
||||
result2 = tool_usage.use(calling2, "Test string")
|
||||
assert "Answer to: Second question" in result2
|
||||
|
||||
def test_invalid_arguments(self, setup_websocket_tool):
|
||||
"""Test handling of invalid arguments."""
|
||||
tool, agent = setup_websocket_tool
|
||||
task = Task(
|
||||
description="Test WebSocket tool",
|
||||
expected_output="Test output",
|
||||
agent=agent,
|
||||
)
|
||||
tool_usage = ToolUsage(
|
||||
tools_handler=ToolsHandler(),
|
||||
tools=[tool],
|
||||
original_tools=[tool],
|
||||
tools_description="WebSocket tool for testing",
|
||||
tools_names="websocket_tool",
|
||||
task=task,
|
||||
function_calling_llm=MagicMock(),
|
||||
agent=agent,
|
||||
action=MagicMock(),
|
||||
)
|
||||
# Test with empty arguments
|
||||
calling = ToolCalling(
|
||||
tool_name="WebSocket Tool",
|
||||
arguments={},
|
||||
log="Test log",
|
||||
)
|
||||
with pytest.raises(ValueError, match="Tool arguments cannot be empty"):
|
||||
tool_usage.use(calling, "Test string")
|
||||
|
||||
# Test with None arguments
|
||||
calling = ToolCalling(
|
||||
tool_name="WebSocket Tool",
|
||||
arguments=None,
|
||||
log="Test log",
|
||||
)
|
||||
with pytest.raises(ValueError, match="Tool arguments cannot be empty"):
|
||||
tool_usage.use(calling, "Test string")
|
||||
|
||||
|
||||
def test_validate_tool_input_booleans_and_none():
|
||||
# Create a ToolUsage instance with mocks
|
||||
tool_usage = ToolUsage(
|
||||
|
||||
Reference in New Issue
Block a user