mirror of
https://github.com/crewAIInc/crewAI.git
synced 2026-01-10 00:28:31 +00:00
fix: prevent duplicate execution of WebSocket tools
- Add specific handling for WebSocket tools in _check_tool_repeated_usage - Add test cases for WebSocket tool execution - Fix issue #2209 Co-Authored-By: Joe Moura <joao@crewai.com>
This commit is contained in:
@@ -283,13 +283,22 @@ class ToolUsage:
|
|||||||
|
|
||||||
def _check_tool_repeated_usage(
|
def _check_tool_repeated_usage(
|
||||||
self, calling: Union[ToolCalling, InstructorToolCalling]
|
self, calling: Union[ToolCalling, InstructorToolCalling]
|
||||||
) -> None:
|
) -> bool:
|
||||||
if not self.tools_handler:
|
if not self.tools_handler:
|
||||||
return False # type: ignore # No return value expected
|
return False
|
||||||
if last_tool_usage := self.tools_handler.last_used_tool:
|
if last_tool_usage := self.tools_handler.last_used_tool:
|
||||||
return (calling.tool_name == last_tool_usage.tool_name) and ( # type: ignore # No return value expected
|
# For WebSocket tools, we need to check if the question is the same
|
||||||
calling.arguments == last_tool_usage.arguments
|
if "question" in calling.arguments and "question" in last_tool_usage.arguments:
|
||||||
|
return (
|
||||||
|
calling.tool_name == last_tool_usage.tool_name
|
||||||
|
and calling.arguments["question"] == last_tool_usage.arguments["question"]
|
||||||
|
)
|
||||||
|
# For other tools, check all arguments
|
||||||
|
return (
|
||||||
|
calling.tool_name == last_tool_usage.tool_name
|
||||||
|
and calling.arguments == last_tool_usage.arguments
|
||||||
)
|
)
|
||||||
|
return False
|
||||||
|
|
||||||
def _select_tool(self, tool_name: str) -> Any:
|
def _select_tool(self, tool_name: str) -> Any:
|
||||||
order_tools = sorted(
|
order_tools = sorted(
|
||||||
|
|||||||
@@ -6,7 +6,9 @@ import pytest
|
|||||||
from pydantic import BaseModel, Field
|
from pydantic import BaseModel, Field
|
||||||
|
|
||||||
from crewai import Agent, Task
|
from crewai import Agent, Task
|
||||||
|
from crewai.agents.tools_handler import ToolsHandler
|
||||||
from crewai.tools import BaseTool
|
from crewai.tools import BaseTool
|
||||||
|
from crewai.tools.tool_calling import ToolCalling
|
||||||
from crewai.tools.tool_usage import ToolUsage
|
from crewai.tools.tool_usage import ToolUsage
|
||||||
from crewai.utilities.events import crewai_event_bus
|
from crewai.utilities.events import crewai_event_bus
|
||||||
from crewai.utilities.events.tool_usage_events import (
|
from crewai.utilities.events.tool_usage_events import (
|
||||||
@@ -128,6 +130,78 @@ def test_tool_usage_render():
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class WebSocketToolInput(BaseModel):
|
||||||
|
question: str = Field(..., description="Question to ask")
|
||||||
|
|
||||||
|
|
||||||
|
class MockWebSocketTool(BaseTool):
|
||||||
|
name: str = "WebSocket Tool"
|
||||||
|
description: str = "A tool that uses WebSocket for communication"
|
||||||
|
args_schema: type[BaseModel] = WebSocketToolInput
|
||||||
|
|
||||||
|
def _run(self, question: str) -> str:
|
||||||
|
return f"Answer to: {question}"
|
||||||
|
|
||||||
|
def invoke(self, input: dict) -> str:
|
||||||
|
return self._run(**input)
|
||||||
|
|
||||||
|
|
||||||
|
def test_websocket_tool_repeated_usage():
|
||||||
|
tool = MockWebSocketTool()
|
||||||
|
agent = Agent(
|
||||||
|
role="Test Agent",
|
||||||
|
goal="Test WebSocket tools",
|
||||||
|
backstory="Testing WebSocket tool execution",
|
||||||
|
tools=[tool],
|
||||||
|
verbose=True,
|
||||||
|
)
|
||||||
|
|
||||||
|
task = Task(
|
||||||
|
description="Test WebSocket tool",
|
||||||
|
expected_output="Test output",
|
||||||
|
agent=agent,
|
||||||
|
)
|
||||||
|
|
||||||
|
tool_usage = ToolUsage(
|
||||||
|
tools_handler=ToolsHandler(),
|
||||||
|
tools=[tool],
|
||||||
|
original_tools=[tool],
|
||||||
|
tools_description="WebSocket tool for testing",
|
||||||
|
tools_names="websocket_tool",
|
||||||
|
task=task,
|
||||||
|
function_calling_llm=MagicMock(),
|
||||||
|
agent=agent,
|
||||||
|
action=MagicMock(),
|
||||||
|
)
|
||||||
|
|
||||||
|
# First call
|
||||||
|
calling1 = ToolCalling(
|
||||||
|
tool_name="WebSocket Tool",
|
||||||
|
arguments={"question": "Test question"},
|
||||||
|
log="Test log",
|
||||||
|
)
|
||||||
|
result1 = tool_usage.use(calling1, "Test string")
|
||||||
|
assert "Answer to: Test question" in result1
|
||||||
|
|
||||||
|
# Same question should be detected as repeated
|
||||||
|
calling2 = ToolCalling(
|
||||||
|
tool_name="WebSocket Tool",
|
||||||
|
arguments={"question": "Test question"},
|
||||||
|
log="Test log",
|
||||||
|
)
|
||||||
|
result2 = tool_usage.use(calling2, "Test string")
|
||||||
|
assert "reusing the same input" in result2.lower()
|
||||||
|
|
||||||
|
# Different question should work
|
||||||
|
calling3 = ToolCalling(
|
||||||
|
tool_name="WebSocket Tool",
|
||||||
|
arguments={"question": "Different question"},
|
||||||
|
log="Test log",
|
||||||
|
)
|
||||||
|
result3 = tool_usage.use(calling3, "Test string")
|
||||||
|
assert "Answer to: Different question" in result3
|
||||||
|
|
||||||
|
|
||||||
def test_validate_tool_input_booleans_and_none():
|
def test_validate_tool_input_booleans_and_none():
|
||||||
# Create a ToolUsage instance with mocks
|
# Create a ToolUsage instance with mocks
|
||||||
tool_usage = ToolUsage(
|
tool_usage = ToolUsage(
|
||||||
|
|||||||
Reference in New Issue
Block a user