diff --git a/src/crewai/tools/base_tool.py b/src/crewai/tools/base_tool.py index 0e0c35727..fb0428ccd 100644 --- a/src/crewai/tools/base_tool.py +++ b/src/crewai/tools/base_tool.py @@ -1,5 +1,4 @@ import asyncio -import warnings from abc import ABC, abstractmethod from inspect import signature from typing import Any, Callable, Type, get_args, get_origin @@ -58,6 +57,13 @@ class BaseTool(BaseModel, ABC): }, }, ) + + @field_validator("max_usage_count", mode="before") + @classmethod + def validate_max_usage_count(cls, v: int | None) -> int | None: + if v is not None and v <= 0: + raise ValueError("max_usage_count must be a positive integer") + return v def model_post_init(self, __context: Any) -> None: self._generate_description() @@ -74,9 +80,15 @@ class BaseTool(BaseModel, ABC): # If _run is async, we safely run it if asyncio.iscoroutine(result): - return asyncio.run(result) - + result = asyncio.run(result) + + self.current_usage_count += 1 + return result + + def reset_usage_count(self) -> None: + """Reset the current usage count to zero.""" + self.current_usage_count = 0 @abstractmethod def _run( @@ -95,6 +107,8 @@ class BaseTool(BaseModel, ABC): args_schema=self.args_schema, func=self._run, result_as_answer=self.result_as_answer, + max_usage_count=self.max_usage_count, + current_usage_count=self.current_usage_count, ) @classmethod @@ -255,7 +269,7 @@ def to_langchain( return [t.to_structured_tool() if isinstance(t, BaseTool) else t for t in tools] -def tool(*args, result_as_answer=False, max_usage_count=None): +def tool(*args, result_as_answer: bool = False, max_usage_count: int | None = None) -> Callable: """ Decorator to create a tool from a function. diff --git a/src/crewai/tools/structured_tool.py b/src/crewai/tools/structured_tool.py index dfd23a9cb..cff5d8b7a 100644 --- a/src/crewai/tools/structured_tool.py +++ b/src/crewai/tools/structured_tool.py @@ -23,6 +23,8 @@ class CrewStructuredTool: args_schema: type[BaseModel], func: Callable[..., Any], result_as_answer: bool = False, + max_usage_count: int | None = None, + current_usage_count: int = 0, ) -> None: """Initialize the structured tool. @@ -32,6 +34,8 @@ class CrewStructuredTool: args_schema: The pydantic model for the tool's arguments func: The function to run when the tool is called result_as_answer: Whether to return the output directly + max_usage_count: Maximum number of times this tool can be used. None means unlimited usage. + current_usage_count: Current number of times this tool has been used. """ self.name = name self.description = description @@ -39,6 +43,8 @@ class CrewStructuredTool: self.func = func self._logger = Logger() self.result_as_answer = result_as_answer + self.max_usage_count = max_usage_count + self.current_usage_count = current_usage_count # Validate the function signature matches the schema self._validate_function_signature() diff --git a/src/crewai/tools/tool_usage.py b/src/crewai/tools/tool_usage.py index 2ac9198c3..7ebf009fc 100644 --- a/src/crewai/tools/tool_usage.py +++ b/src/crewai/tools/tool_usage.py @@ -200,16 +200,16 @@ class ToolUsage: None, ) - if available_tool and hasattr(available_tool, 'max_usage_count') and available_tool.max_usage_count is not None: - if available_tool.current_usage_count >= available_tool.max_usage_count: - try: - result = f"Tool '{tool.name}' has reached its usage limit of {available_tool.max_usage_count} times and cannot be used anymore." - self._telemetry.tool_usage_error(llm=self.function_calling_llm) - result = self._format_result(result=result) - return result - except Exception: - if self.task: - self.task.increment_tools_errors() + usage_limit_error = self._check_usage_limit(available_tool, tool.name) + if usage_limit_error: + try: + result = usage_limit_error + self._telemetry.tool_usage_error(llm=self.function_calling_llm) + result = self._format_result(result=result) + return result + except Exception: + if self.task: + self.task.increment_tools_errors() if result is None: try: @@ -313,6 +313,11 @@ class ToolUsage: if available_tool and hasattr(available_tool, 'current_usage_count'): available_tool.current_usage_count += 1 + if hasattr(available_tool, 'max_usage_count') and available_tool.max_usage_count is not None: + self._printer.print( + content=f"Tool '{available_tool.name}' usage: {available_tool.current_usage_count}/{available_tool.max_usage_count}", + color="blue" + ) return result @@ -345,6 +350,24 @@ class ToolUsage: calling.arguments == last_tool_usage.arguments ) return False + + def _check_usage_limit(self, tool: Any, tool_name: str) -> str | None: + """Check if tool has reached its usage limit. + + Args: + tool: The tool to check + tool_name: The name of the tool (used for error message) + + Returns: + Error message if limit reached, None otherwise + """ + if ( + hasattr(tool, 'max_usage_count') + and tool.max_usage_count is not None + and tool.current_usage_count >= tool.max_usage_count + ): + return f"Tool '{tool_name}' has reached its usage limit of {tool.max_usage_count} times and cannot be used anymore." + return None def _select_tool(self, tool_name: str) -> Any: order_tools = sorted( diff --git a/tests/tools/test_tool_usage_limit.py b/tests/tools/test_tool_usage_limit.py index 132588a26..7d23f4b63 100644 --- a/tests/tools/test_tool_usage_limit.py +++ b/tests/tools/test_tool_usage_limit.py @@ -1,7 +1,8 @@ import pytest -from unittest.mock import MagicMock, patch +from unittest.mock import MagicMock from crewai.tools import BaseTool, tool +from crewai.tools.tool_usage import ToolUsage def test_tool_usage_limit(): @@ -66,3 +67,85 @@ def test_default_unlimited_usage(): assert default_tool.max_usage_count is None assert default_tool.current_usage_count == 0 + + +def test_invalid_usage_limit(): + """Test that negative usage limits raise ValueError.""" + class ValidTool(BaseTool): + name: str = "Valid Tool" + description: str = "A tool with valid usage limit" + + def _run(self, input_text: str) -> str: + return f"Processed {input_text}" + + with pytest.raises(ValueError, match="max_usage_count must be a positive integer"): + ValidTool(max_usage_count=-1) + + +def test_reset_usage_count(): + """Test that reset_usage_count method works correctly.""" + class LimitedTool(BaseTool): + name: str = "Limited Tool" + description: str = "A tool with usage limits for testing" + max_usage_count: int = 3 + + def _run(self, input_text: str) -> str: + return f"Processed {input_text}" + + tool = LimitedTool() + + tool.run(input_text="test1") + tool.run(input_text="test2") + assert tool.current_usage_count == 2 + + tool.reset_usage_count() + assert tool.current_usage_count == 0 + + result = tool.run(input_text="test3") + assert result == "Processed test3" + assert tool.current_usage_count == 1 + + +def test_tool_usage_with_toolusage_class(): + """Test that ToolUsage class correctly enforces usage limits.""" + class LimitedTool(BaseTool): + name: str = "Limited Tool" + description: str = "A tool with usage limits for testing" + max_usage_count: int = 2 + + def _run(self, input_text: str) -> str: + return f"Processed {input_text}" + + tool = LimitedTool() + + mock_agent = MagicMock() + mock_task = MagicMock() + mock_tools_handler = MagicMock() + + tool_usage = ToolUsage( + tools=[tool], + agent=mock_agent, + task=mock_task, + tools_handler=mock_tools_handler, + function_calling_llm=MagicMock(), + ) + + tool_usage._check_tool_repeated_usage = MagicMock(return_value=False) + tool_usage._format_result = lambda result: result + + mock_calling = MagicMock() + mock_calling.tool_name = "Limited Tool" + mock_calling.arguments = {"input_text": "test"} + + result1 = tool_usage._check_usage_limit(tool, "Limited Tool") + assert result1 is None + + tool.current_usage_count += 1 + + result2 = tool_usage._check_usage_limit(tool, "Limited Tool") + assert result2 is None + + tool.current_usage_count += 1 + + result3 = tool_usage._check_usage_limit(tool, "Limited Tool") + assert "has reached its usage limit of 2 times" in result3