Add usage limit feature to BaseTool class (#2904)
Some checks failed
Notify Downstream / notify-downstream (push) Has been cancelled
Mark stale issues and pull requests / stale (push) Has been cancelled

* Add usage limit feature to BaseTool class

- Add max_usage_count and current_usage_count attributes to BaseTool
- Implement usage limit checking in ToolUsage._use method
- Add comprehensive tests for usage limit functionality
- Maintain backward compatibility with None default for unlimited usage

Co-Authored-By: Joe Moura <joao@crewai.com>

* Fix CI failures and address code review feedback

- Add max_usage_count/current_usage_count to CrewStructuredTool
- Add input validation for positive max_usage_count
- Add reset_usage_count method to BaseTool
- Extract usage limit check into separate method
- Add comprehensive edge case tests
- Add proper type hints throughout
- Fix linting issues

Co-Authored-By: Joe Moura <joao@crewai.com>

---------

Co-authored-by: Devin AI <158243242+devin-ai-integration[bot]@users.noreply.github.com>
Co-authored-by: Joe Moura <joao@crewai.com>
This commit is contained in:
devin-ai-integration[bot]
2025-05-26 08:53:10 -07:00
committed by GitHub
parent 7fe193866d
commit 22db4aae81
4 changed files with 219 additions and 4 deletions

View File

@@ -1,5 +1,4 @@
import asyncio import asyncio
import warnings
from abc import ABC, abstractmethod from abc import ABC, abstractmethod
from inspect import signature from inspect import signature
from typing import Any, Callable, Type, get_args, get_origin from typing import Any, Callable, Type, get_args, get_origin
@@ -36,6 +35,10 @@ class BaseTool(BaseModel, ABC):
"""Function that will be used to determine if the tool should be cached, should return a boolean. If None, the tool will be cached.""" """Function that will be used to determine if the tool should be cached, should return a boolean. If None, the tool will be cached."""
result_as_answer: bool = False result_as_answer: bool = False
"""Flag to check if the tool should be the final agent answer.""" """Flag to check if the tool should be the final agent answer."""
max_usage_count: int | None = None
"""Maximum number of times this tool can be used. None means unlimited usage."""
current_usage_count: int = 0
"""Current number of times this tool has been used."""
@field_validator("args_schema", mode="before") @field_validator("args_schema", mode="before")
@classmethod @classmethod
@@ -55,6 +58,13 @@ class BaseTool(BaseModel, ABC):
}, },
) )
@field_validator("max_usage_count", mode="before")
@classmethod
def validate_max_usage_count(cls, v: int | None) -> int | None:
if v is not None and v <= 0:
raise ValueError("max_usage_count must be a positive integer")
return v
def model_post_init(self, __context: Any) -> None: def model_post_init(self, __context: Any) -> None:
self._generate_description() self._generate_description()
@@ -70,10 +80,16 @@ class BaseTool(BaseModel, ABC):
# If _run is async, we safely run it # If _run is async, we safely run it
if asyncio.iscoroutine(result): if asyncio.iscoroutine(result):
return asyncio.run(result) result = asyncio.run(result)
self.current_usage_count += 1
return result return result
def reset_usage_count(self) -> None:
"""Reset the current usage count to zero."""
self.current_usage_count = 0
@abstractmethod @abstractmethod
def _run( def _run(
self, self,
@@ -91,6 +107,8 @@ class BaseTool(BaseModel, ABC):
args_schema=self.args_schema, args_schema=self.args_schema,
func=self._run, func=self._run,
result_as_answer=self.result_as_answer, result_as_answer=self.result_as_answer,
max_usage_count=self.max_usage_count,
current_usage_count=self.current_usage_count,
) )
@classmethod @classmethod
@@ -251,13 +269,14 @@ def to_langchain(
return [t.to_structured_tool() if isinstance(t, BaseTool) else t for t in tools] return [t.to_structured_tool() if isinstance(t, BaseTool) else t for t in tools]
def tool(*args, result_as_answer=False): def tool(*args, result_as_answer: bool = False, max_usage_count: int | None = None) -> Callable:
""" """
Decorator to create a tool from a function. Decorator to create a tool from a function.
Args: Args:
*args: Positional arguments, either the function to decorate or the tool name. *args: Positional arguments, either the function to decorate or the tool name.
result_as_answer: Flag to indicate if the tool result should be used as the final agent answer. result_as_answer: Flag to indicate if the tool result should be used as the final agent answer.
max_usage_count: Maximum number of times this tool can be used. None means unlimited usage.
""" """
def _make_with_name(tool_name: str) -> Callable: def _make_with_name(tool_name: str) -> Callable:
@@ -284,6 +303,8 @@ def tool(*args, result_as_answer=False):
func=f, func=f,
args_schema=args_schema, args_schema=args_schema,
result_as_answer=result_as_answer, result_as_answer=result_as_answer,
max_usage_count=max_usage_count,
current_usage_count=0,
) )
return _make_tool return _make_tool

View File

@@ -23,6 +23,8 @@ class CrewStructuredTool:
args_schema: type[BaseModel], args_schema: type[BaseModel],
func: Callable[..., Any], func: Callable[..., Any],
result_as_answer: bool = False, result_as_answer: bool = False,
max_usage_count: int | None = None,
current_usage_count: int = 0,
) -> None: ) -> None:
"""Initialize the structured tool. """Initialize the structured tool.
@@ -32,6 +34,8 @@ class CrewStructuredTool:
args_schema: The pydantic model for the tool's arguments args_schema: The pydantic model for the tool's arguments
func: The function to run when the tool is called func: The function to run when the tool is called
result_as_answer: Whether to return the output directly result_as_answer: Whether to return the output directly
max_usage_count: Maximum number of times this tool can be used. None means unlimited usage.
current_usage_count: Current number of times this tool has been used.
""" """
self.name = name self.name = name
self.description = description self.description = description
@@ -39,6 +43,8 @@ class CrewStructuredTool:
self.func = func self.func = func
self._logger = Logger() self._logger = Logger()
self.result_as_answer = result_as_answer self.result_as_answer = result_as_answer
self.max_usage_count = max_usage_count
self.current_usage_count = current_usage_count
# Validate the function signature matches the schema # Validate the function signature matches the schema
self._validate_function_signature() self._validate_function_signature()

View File

@@ -200,6 +200,17 @@ class ToolUsage:
None, None,
) )
usage_limit_error = self._check_usage_limit(available_tool, tool.name)
if usage_limit_error:
try:
result = usage_limit_error
self._telemetry.tool_usage_error(llm=self.function_calling_llm)
result = self._format_result(result=result)
return result
except Exception:
if self.task:
self.task.increment_tools_errors()
if result is None: if result is None:
try: try:
if calling.tool_name in [ if calling.tool_name in [
@@ -300,6 +311,14 @@ class ToolUsage:
if self.agent and hasattr(self.agent, "tools_results"): if self.agent and hasattr(self.agent, "tools_results"):
self.agent.tools_results.append(data) self.agent.tools_results.append(data)
if available_tool and hasattr(available_tool, 'current_usage_count'):
available_tool.current_usage_count += 1
if hasattr(available_tool, 'max_usage_count') and available_tool.max_usage_count is not None:
self._printer.print(
content=f"Tool '{available_tool.name}' usage: {available_tool.current_usage_count}/{available_tool.max_usage_count}",
color="blue"
)
return result return result
def _format_result(self, result: Any) -> str: def _format_result(self, result: Any) -> str:
@@ -332,6 +351,24 @@ class ToolUsage:
) )
return False return False
def _check_usage_limit(self, tool: Any, tool_name: str) -> str | None:
"""Check if tool has reached its usage limit.
Args:
tool: The tool to check
tool_name: The name of the tool (used for error message)
Returns:
Error message if limit reached, None otherwise
"""
if (
hasattr(tool, 'max_usage_count')
and tool.max_usage_count is not None
and tool.current_usage_count >= tool.max_usage_count
):
return f"Tool '{tool_name}' has reached its usage limit of {tool.max_usage_count} times and cannot be used anymore."
return None
def _select_tool(self, tool_name: str) -> Any: def _select_tool(self, tool_name: str) -> Any:
order_tools = sorted( order_tools = sorted(
self.tools, self.tools,

View File

@@ -0,0 +1,151 @@
import pytest
from unittest.mock import MagicMock
from crewai.tools import BaseTool, tool
from crewai.tools.tool_usage import ToolUsage
def test_tool_usage_limit():
"""Test that tools respect usage limits."""
class LimitedTool(BaseTool):
name: str = "Limited Tool"
description: str = "A tool with usage limits for testing"
max_usage_count: int = 2
def _run(self, input_text: str) -> str:
return f"Processed {input_text}"
tool = LimitedTool()
result1 = tool.run(input_text="test1")
assert result1 == "Processed test1"
assert tool.current_usage_count == 1
result2 = tool.run(input_text="test2")
assert result2 == "Processed test2"
assert tool.current_usage_count == 2
def test_unlimited_tool_usage():
"""Test that tools without usage limits work normally."""
class UnlimitedTool(BaseTool):
name: str = "Unlimited Tool"
description: str = "A tool without usage limits"
def _run(self, input_text: str) -> str:
return f"Processed {input_text}"
tool = UnlimitedTool()
for i in range(5):
result = tool.run(input_text=f"test{i}")
assert result == f"Processed test{i}"
assert tool.current_usage_count == i + 1
def test_tool_decorator_with_usage_limit():
"""Test usage limit with @tool decorator."""
@tool("Test Tool", max_usage_count=3)
def test_tool(input_text: str) -> str:
"""A test tool."""
return f"Result: {input_text}"
assert test_tool.max_usage_count == 3
assert test_tool.current_usage_count == 0
result = test_tool.run(input_text="test")
assert result == "Result: test"
assert test_tool.current_usage_count == 1
def test_default_unlimited_usage():
"""Test that tools have unlimited usage by default."""
@tool("Default Tool")
def default_tool(input_text: str) -> str:
"""A default tool."""
return f"Result: {input_text}"
assert default_tool.max_usage_count is None
assert default_tool.current_usage_count == 0
def test_invalid_usage_limit():
"""Test that negative usage limits raise ValueError."""
class ValidTool(BaseTool):
name: str = "Valid Tool"
description: str = "A tool with valid usage limit"
def _run(self, input_text: str) -> str:
return f"Processed {input_text}"
with pytest.raises(ValueError, match="max_usage_count must be a positive integer"):
ValidTool(max_usage_count=-1)
def test_reset_usage_count():
"""Test that reset_usage_count method works correctly."""
class LimitedTool(BaseTool):
name: str = "Limited Tool"
description: str = "A tool with usage limits for testing"
max_usage_count: int = 3
def _run(self, input_text: str) -> str:
return f"Processed {input_text}"
tool = LimitedTool()
tool.run(input_text="test1")
tool.run(input_text="test2")
assert tool.current_usage_count == 2
tool.reset_usage_count()
assert tool.current_usage_count == 0
result = tool.run(input_text="test3")
assert result == "Processed test3"
assert tool.current_usage_count == 1
def test_tool_usage_with_toolusage_class():
"""Test that ToolUsage class correctly enforces usage limits."""
class LimitedTool(BaseTool):
name: str = "Limited Tool"
description: str = "A tool with usage limits for testing"
max_usage_count: int = 2
def _run(self, input_text: str) -> str:
return f"Processed {input_text}"
tool = LimitedTool()
mock_agent = MagicMock()
mock_task = MagicMock()
mock_tools_handler = MagicMock()
tool_usage = ToolUsage(
tools=[tool],
agent=mock_agent,
task=mock_task,
tools_handler=mock_tools_handler,
function_calling_llm=MagicMock(),
)
tool_usage._check_tool_repeated_usage = MagicMock(return_value=False)
tool_usage._format_result = lambda result: result
mock_calling = MagicMock()
mock_calling.tool_name = "Limited Tool"
mock_calling.arguments = {"input_text": "test"}
result1 = tool_usage._check_usage_limit(tool, "Limited Tool")
assert result1 is None
tool.current_usage_count += 1
result2 = tool_usage._check_usage_limit(tool, "Limited Tool")
assert result2 is None
tool.current_usage_count += 1
result3 = tool_usage._check_usage_limit(tool, "Limited Tool")
assert "has reached its usage limit of 2 times" in result3