Fix CI failures and address code review feedback

- Add max_usage_count/current_usage_count to CrewStructuredTool
- Add input validation for positive max_usage_count
- Add reset_usage_count method to BaseTool
- Extract usage limit check into separate method
- Add comprehensive edge case tests
- Add proper type hints throughout
- Fix linting issues

Co-Authored-By: Joe Moura <joao@crewai.com>
This commit is contained in:
Devin AI
2025-05-26 07:01:15 +00:00
parent 7f730fbe02
commit 3b81adefbc
4 changed files with 141 additions and 15 deletions

View File

@@ -1,5 +1,4 @@
import asyncio import asyncio
import warnings
from abc import ABC, abstractmethod from abc import ABC, abstractmethod
from inspect import signature from inspect import signature
from typing import Any, Callable, Type, get_args, get_origin from typing import Any, Callable, Type, get_args, get_origin
@@ -58,6 +57,13 @@ class BaseTool(BaseModel, ABC):
}, },
}, },
) )
@field_validator("max_usage_count", mode="before")
@classmethod
def validate_max_usage_count(cls, v: int | None) -> int | None:
if v is not None and v <= 0:
raise ValueError("max_usage_count must be a positive integer")
return v
def model_post_init(self, __context: Any) -> None: def model_post_init(self, __context: Any) -> None:
self._generate_description() self._generate_description()
@@ -74,9 +80,15 @@ class BaseTool(BaseModel, ABC):
# If _run is async, we safely run it # If _run is async, we safely run it
if asyncio.iscoroutine(result): if asyncio.iscoroutine(result):
return asyncio.run(result) result = asyncio.run(result)
self.current_usage_count += 1
return result return result
def reset_usage_count(self) -> None:
"""Reset the current usage count to zero."""
self.current_usage_count = 0
@abstractmethod @abstractmethod
def _run( def _run(
@@ -95,6 +107,8 @@ class BaseTool(BaseModel, ABC):
args_schema=self.args_schema, args_schema=self.args_schema,
func=self._run, func=self._run,
result_as_answer=self.result_as_answer, result_as_answer=self.result_as_answer,
max_usage_count=self.max_usage_count,
current_usage_count=self.current_usage_count,
) )
@classmethod @classmethod
@@ -255,7 +269,7 @@ def to_langchain(
return [t.to_structured_tool() if isinstance(t, BaseTool) else t for t in tools] return [t.to_structured_tool() if isinstance(t, BaseTool) else t for t in tools]
def tool(*args, result_as_answer=False, max_usage_count=None): def tool(*args, result_as_answer: bool = False, max_usage_count: int | None = None) -> Callable:
""" """
Decorator to create a tool from a function. Decorator to create a tool from a function.

View File

@@ -23,6 +23,8 @@ class CrewStructuredTool:
args_schema: type[BaseModel], args_schema: type[BaseModel],
func: Callable[..., Any], func: Callable[..., Any],
result_as_answer: bool = False, result_as_answer: bool = False,
max_usage_count: int | None = None,
current_usage_count: int = 0,
) -> None: ) -> None:
"""Initialize the structured tool. """Initialize the structured tool.
@@ -32,6 +34,8 @@ class CrewStructuredTool:
args_schema: The pydantic model for the tool's arguments args_schema: The pydantic model for the tool's arguments
func: The function to run when the tool is called func: The function to run when the tool is called
result_as_answer: Whether to return the output directly result_as_answer: Whether to return the output directly
max_usage_count: Maximum number of times this tool can be used. None means unlimited usage.
current_usage_count: Current number of times this tool has been used.
""" """
self.name = name self.name = name
self.description = description self.description = description
@@ -39,6 +43,8 @@ class CrewStructuredTool:
self.func = func self.func = func
self._logger = Logger() self._logger = Logger()
self.result_as_answer = result_as_answer self.result_as_answer = result_as_answer
self.max_usage_count = max_usage_count
self.current_usage_count = current_usage_count
# Validate the function signature matches the schema # Validate the function signature matches the schema
self._validate_function_signature() self._validate_function_signature()

View File

@@ -200,16 +200,16 @@ class ToolUsage:
None, None,
) )
if available_tool and hasattr(available_tool, 'max_usage_count') and available_tool.max_usage_count is not None: usage_limit_error = self._check_usage_limit(available_tool, tool.name)
if available_tool.current_usage_count >= available_tool.max_usage_count: if usage_limit_error:
try: try:
result = f"Tool '{tool.name}' has reached its usage limit of {available_tool.max_usage_count} times and cannot be used anymore." result = usage_limit_error
self._telemetry.tool_usage_error(llm=self.function_calling_llm) self._telemetry.tool_usage_error(llm=self.function_calling_llm)
result = self._format_result(result=result) result = self._format_result(result=result)
return result return result
except Exception: except Exception:
if self.task: if self.task:
self.task.increment_tools_errors() self.task.increment_tools_errors()
if result is None: if result is None:
try: try:
@@ -313,6 +313,11 @@ class ToolUsage:
if available_tool and hasattr(available_tool, 'current_usage_count'): if available_tool and hasattr(available_tool, 'current_usage_count'):
available_tool.current_usage_count += 1 available_tool.current_usage_count += 1
if hasattr(available_tool, 'max_usage_count') and available_tool.max_usage_count is not None:
self._printer.print(
content=f"Tool '{available_tool.name}' usage: {available_tool.current_usage_count}/{available_tool.max_usage_count}",
color="blue"
)
return result return result
@@ -345,6 +350,24 @@ class ToolUsage:
calling.arguments == last_tool_usage.arguments calling.arguments == last_tool_usage.arguments
) )
return False return False
def _check_usage_limit(self, tool: Any, tool_name: str) -> str | None:
"""Check if tool has reached its usage limit.
Args:
tool: The tool to check
tool_name: The name of the tool (used for error message)
Returns:
Error message if limit reached, None otherwise
"""
if (
hasattr(tool, 'max_usage_count')
and tool.max_usage_count is not None
and tool.current_usage_count >= tool.max_usage_count
):
return f"Tool '{tool_name}' has reached its usage limit of {tool.max_usage_count} times and cannot be used anymore."
return None
def _select_tool(self, tool_name: str) -> Any: def _select_tool(self, tool_name: str) -> Any:
order_tools = sorted( order_tools = sorted(

View File

@@ -1,7 +1,8 @@
import pytest import pytest
from unittest.mock import MagicMock, patch from unittest.mock import MagicMock
from crewai.tools import BaseTool, tool from crewai.tools import BaseTool, tool
from crewai.tools.tool_usage import ToolUsage
def test_tool_usage_limit(): def test_tool_usage_limit():
@@ -66,3 +67,85 @@ def test_default_unlimited_usage():
assert default_tool.max_usage_count is None assert default_tool.max_usage_count is None
assert default_tool.current_usage_count == 0 assert default_tool.current_usage_count == 0
def test_invalid_usage_limit():
"""Test that negative usage limits raise ValueError."""
class ValidTool(BaseTool):
name: str = "Valid Tool"
description: str = "A tool with valid usage limit"
def _run(self, input_text: str) -> str:
return f"Processed {input_text}"
with pytest.raises(ValueError, match="max_usage_count must be a positive integer"):
ValidTool(max_usage_count=-1)
def test_reset_usage_count():
"""Test that reset_usage_count method works correctly."""
class LimitedTool(BaseTool):
name: str = "Limited Tool"
description: str = "A tool with usage limits for testing"
max_usage_count: int = 3
def _run(self, input_text: str) -> str:
return f"Processed {input_text}"
tool = LimitedTool()
tool.run(input_text="test1")
tool.run(input_text="test2")
assert tool.current_usage_count == 2
tool.reset_usage_count()
assert tool.current_usage_count == 0
result = tool.run(input_text="test3")
assert result == "Processed test3"
assert tool.current_usage_count == 1
def test_tool_usage_with_toolusage_class():
"""Test that ToolUsage class correctly enforces usage limits."""
class LimitedTool(BaseTool):
name: str = "Limited Tool"
description: str = "A tool with usage limits for testing"
max_usage_count: int = 2
def _run(self, input_text: str) -> str:
return f"Processed {input_text}"
tool = LimitedTool()
mock_agent = MagicMock()
mock_task = MagicMock()
mock_tools_handler = MagicMock()
tool_usage = ToolUsage(
tools=[tool],
agent=mock_agent,
task=mock_task,
tools_handler=mock_tools_handler,
function_calling_llm=MagicMock(),
)
tool_usage._check_tool_repeated_usage = MagicMock(return_value=False)
tool_usage._format_result = lambda result: result
mock_calling = MagicMock()
mock_calling.tool_name = "Limited Tool"
mock_calling.arguments = {"input_text": "test"}
result1 = tool_usage._check_usage_limit(tool, "Limited Tool")
assert result1 is None
tool.current_usage_count += 1
result2 = tool_usage._check_usage_limit(tool, "Limited Tool")
assert result2 is None
tool.current_usage_count += 1
result3 = tool_usage._check_usage_limit(tool, "Limited Tool")
assert "has reached its usage limit of 2 times" in result3