diff --git a/src/crewai/tools/base_tool.py b/src/crewai/tools/base_tool.py index 0e8a7a22b..fb0428ccd 100644 --- a/src/crewai/tools/base_tool.py +++ b/src/crewai/tools/base_tool.py @@ -1,5 +1,4 @@ import asyncio -import warnings from abc import ABC, abstractmethod from inspect import signature from typing import Any, Callable, Type, get_args, get_origin @@ -36,6 +35,10 @@ class BaseTool(BaseModel, ABC): """Function that will be used to determine if the tool should be cached, should return a boolean. If None, the tool will be cached.""" result_as_answer: bool = False """Flag to check if the tool should be the final agent answer.""" + max_usage_count: int | None = None + """Maximum number of times this tool can be used. None means unlimited usage.""" + current_usage_count: int = 0 + """Current number of times this tool has been used.""" @field_validator("args_schema", mode="before") @classmethod @@ -54,6 +57,13 @@ class BaseTool(BaseModel, ABC): }, }, ) + + @field_validator("max_usage_count", mode="before") + @classmethod + def validate_max_usage_count(cls, v: int | None) -> int | None: + if v is not None and v <= 0: + raise ValueError("max_usage_count must be a positive integer") + return v def model_post_init(self, __context: Any) -> None: self._generate_description() @@ -70,9 +80,15 @@ class BaseTool(BaseModel, ABC): # If _run is async, we safely run it if asyncio.iscoroutine(result): - return asyncio.run(result) - + result = asyncio.run(result) + + self.current_usage_count += 1 + return result + + def reset_usage_count(self) -> None: + """Reset the current usage count to zero.""" + self.current_usage_count = 0 @abstractmethod def _run( @@ -91,6 +107,8 @@ class BaseTool(BaseModel, ABC): args_schema=self.args_schema, func=self._run, result_as_answer=self.result_as_answer, + max_usage_count=self.max_usage_count, + current_usage_count=self.current_usage_count, ) @classmethod @@ -251,13 +269,14 @@ def to_langchain( return [t.to_structured_tool() if isinstance(t, BaseTool) else t for t in tools] -def tool(*args, result_as_answer=False): +def tool(*args, result_as_answer: bool = False, max_usage_count: int | None = None) -> Callable: """ Decorator to create a tool from a function. Args: *args: Positional arguments, either the function to decorate or the tool name. result_as_answer: Flag to indicate if the tool result should be used as the final agent answer. + max_usage_count: Maximum number of times this tool can be used. None means unlimited usage. """ def _make_with_name(tool_name: str) -> Callable: @@ -284,6 +303,8 @@ def tool(*args, result_as_answer=False): func=f, args_schema=args_schema, result_as_answer=result_as_answer, + max_usage_count=max_usage_count, + current_usage_count=0, ) return _make_tool diff --git a/src/crewai/tools/structured_tool.py b/src/crewai/tools/structured_tool.py index dfd23a9cb..cff5d8b7a 100644 --- a/src/crewai/tools/structured_tool.py +++ b/src/crewai/tools/structured_tool.py @@ -23,6 +23,8 @@ class CrewStructuredTool: args_schema: type[BaseModel], func: Callable[..., Any], result_as_answer: bool = False, + max_usage_count: int | None = None, + current_usage_count: int = 0, ) -> None: """Initialize the structured tool. @@ -32,6 +34,8 @@ class CrewStructuredTool: args_schema: The pydantic model for the tool's arguments func: The function to run when the tool is called result_as_answer: Whether to return the output directly + max_usage_count: Maximum number of times this tool can be used. None means unlimited usage. + current_usage_count: Current number of times this tool has been used. """ self.name = name self.description = description @@ -39,6 +43,8 @@ class CrewStructuredTool: self.func = func self._logger = Logger() self.result_as_answer = result_as_answer + self.max_usage_count = max_usage_count + self.current_usage_count = current_usage_count # Validate the function signature matches the schema self._validate_function_signature() diff --git a/src/crewai/tools/tool_usage.py b/src/crewai/tools/tool_usage.py index dc5f8f29a..7ebf009fc 100644 --- a/src/crewai/tools/tool_usage.py +++ b/src/crewai/tools/tool_usage.py @@ -200,6 +200,17 @@ class ToolUsage: None, ) + usage_limit_error = self._check_usage_limit(available_tool, tool.name) + if usage_limit_error: + try: + result = usage_limit_error + self._telemetry.tool_usage_error(llm=self.function_calling_llm) + result = self._format_result(result=result) + return result + except Exception: + if self.task: + self.task.increment_tools_errors() + if result is None: try: if calling.tool_name in [ @@ -300,6 +311,14 @@ class ToolUsage: if self.agent and hasattr(self.agent, "tools_results"): self.agent.tools_results.append(data) + if available_tool and hasattr(available_tool, 'current_usage_count'): + available_tool.current_usage_count += 1 + if hasattr(available_tool, 'max_usage_count') and available_tool.max_usage_count is not None: + self._printer.print( + content=f"Tool '{available_tool.name}' usage: {available_tool.current_usage_count}/{available_tool.max_usage_count}", + color="blue" + ) + return result def _format_result(self, result: Any) -> str: @@ -331,6 +350,24 @@ class ToolUsage: calling.arguments == last_tool_usage.arguments ) return False + + def _check_usage_limit(self, tool: Any, tool_name: str) -> str | None: + """Check if tool has reached its usage limit. + + Args: + tool: The tool to check + tool_name: The name of the tool (used for error message) + + Returns: + Error message if limit reached, None otherwise + """ + if ( + hasattr(tool, 'max_usage_count') + and tool.max_usage_count is not None + and tool.current_usage_count >= tool.max_usage_count + ): + return f"Tool '{tool_name}' has reached its usage limit of {tool.max_usage_count} times and cannot be used anymore." + return None def _select_tool(self, tool_name: str) -> Any: order_tools = sorted( diff --git a/tests/tools/test_tool_usage_limit.py b/tests/tools/test_tool_usage_limit.py new file mode 100644 index 000000000..7d23f4b63 --- /dev/null +++ b/tests/tools/test_tool_usage_limit.py @@ -0,0 +1,151 @@ +import pytest +from unittest.mock import MagicMock + +from crewai.tools import BaseTool, tool +from crewai.tools.tool_usage import ToolUsage + + +def test_tool_usage_limit(): + """Test that tools respect usage limits.""" + class LimitedTool(BaseTool): + name: str = "Limited Tool" + description: str = "A tool with usage limits for testing" + max_usage_count: int = 2 + + def _run(self, input_text: str) -> str: + return f"Processed {input_text}" + + tool = LimitedTool() + + result1 = tool.run(input_text="test1") + assert result1 == "Processed test1" + assert tool.current_usage_count == 1 + + result2 = tool.run(input_text="test2") + assert result2 == "Processed test2" + assert tool.current_usage_count == 2 + + +def test_unlimited_tool_usage(): + """Test that tools without usage limits work normally.""" + class UnlimitedTool(BaseTool): + name: str = "Unlimited Tool" + description: str = "A tool without usage limits" + + def _run(self, input_text: str) -> str: + return f"Processed {input_text}" + + tool = UnlimitedTool() + + for i in range(5): + result = tool.run(input_text=f"test{i}") + assert result == f"Processed test{i}" + assert tool.current_usage_count == i + 1 + + +def test_tool_decorator_with_usage_limit(): + """Test usage limit with @tool decorator.""" + @tool("Test Tool", max_usage_count=3) + def test_tool(input_text: str) -> str: + """A test tool.""" + return f"Result: {input_text}" + + assert test_tool.max_usage_count == 3 + assert test_tool.current_usage_count == 0 + + result = test_tool.run(input_text="test") + assert result == "Result: test" + assert test_tool.current_usage_count == 1 + + +def test_default_unlimited_usage(): + """Test that tools have unlimited usage by default.""" + @tool("Default Tool") + def default_tool(input_text: str) -> str: + """A default tool.""" + return f"Result: {input_text}" + + assert default_tool.max_usage_count is None + assert default_tool.current_usage_count == 0 + + +def test_invalid_usage_limit(): + """Test that negative usage limits raise ValueError.""" + class ValidTool(BaseTool): + name: str = "Valid Tool" + description: str = "A tool with valid usage limit" + + def _run(self, input_text: str) -> str: + return f"Processed {input_text}" + + with pytest.raises(ValueError, match="max_usage_count must be a positive integer"): + ValidTool(max_usage_count=-1) + + +def test_reset_usage_count(): + """Test that reset_usage_count method works correctly.""" + class LimitedTool(BaseTool): + name: str = "Limited Tool" + description: str = "A tool with usage limits for testing" + max_usage_count: int = 3 + + def _run(self, input_text: str) -> str: + return f"Processed {input_text}" + + tool = LimitedTool() + + tool.run(input_text="test1") + tool.run(input_text="test2") + assert tool.current_usage_count == 2 + + tool.reset_usage_count() + assert tool.current_usage_count == 0 + + result = tool.run(input_text="test3") + assert result == "Processed test3" + assert tool.current_usage_count == 1 + + +def test_tool_usage_with_toolusage_class(): + """Test that ToolUsage class correctly enforces usage limits.""" + class LimitedTool(BaseTool): + name: str = "Limited Tool" + description: str = "A tool with usage limits for testing" + max_usage_count: int = 2 + + def _run(self, input_text: str) -> str: + return f"Processed {input_text}" + + tool = LimitedTool() + + mock_agent = MagicMock() + mock_task = MagicMock() + mock_tools_handler = MagicMock() + + tool_usage = ToolUsage( + tools=[tool], + agent=mock_agent, + task=mock_task, + tools_handler=mock_tools_handler, + function_calling_llm=MagicMock(), + ) + + tool_usage._check_tool_repeated_usage = MagicMock(return_value=False) + tool_usage._format_result = lambda result: result + + mock_calling = MagicMock() + mock_calling.tool_name = "Limited Tool" + mock_calling.arguments = {"input_text": "test"} + + result1 = tool_usage._check_usage_limit(tool, "Limited Tool") + assert result1 is None + + tool.current_usage_count += 1 + + result2 = tool_usage._check_usage_limit(tool, "Limited Tool") + assert result2 is None + + tool.current_usage_count += 1 + + result3 = tool_usage._check_usage_limit(tool, "Limited Tool") + assert "has reached its usage limit of 2 times" in result3