Lorenze/fix tool call twice (#3495)

* test: add test to ensure tool is called only once during crew execution

- Introduced a new test case to validate that the counting_tool is executed exactly once during crew execution.
- Created a CountingTool class to track execution counts and log call history.
- Enhanced the test suite with a YAML cassette for consistent tool behavior verification.

* ensure tool function called only once

* refactor: simplify error handling in CrewStructuredTool

- Removed unnecessary try-except block around the tool function call to streamline execution flow.
- Ensured that the tool function is called directly, improving readability and maintainability.

* linted

* need to ignore for now as we cant infer the complex generic type within pydantic create_model_func

* fix tests
This commit is contained in:
Lorenze Jay
2025-09-10 15:20:21 -07:00
committed by GitHub
parent 01be26ce2a
commit 75b916c85a
5 changed files with 583 additions and 233 deletions

View File

@@ -1,26 +1,22 @@
from __future__ import annotations
import asyncio
import inspect
import textwrap
from typing import Any, Callable, Optional, Union, get_type_hints
from collections.abc import Callable
from typing import TYPE_CHECKING, Any, get_type_hints
from pydantic import BaseModel, Field, create_model
from crewai.utilities.logger import Logger
from typing import TYPE_CHECKING
if TYPE_CHECKING:
from crewai.tools.base_tool import BaseTool
class ToolUsageLimitExceeded(Exception):
class ToolUsageLimitExceededError(Exception):
"""Exception raised when a tool has reached its maximum usage limit."""
pass
class CrewStructuredTool:
"""A structured tool that can operate on any number of inputs.
@@ -69,10 +65,10 @@ class CrewStructuredTool:
def from_function(
cls,
func: Callable,
name: Optional[str] = None,
description: Optional[str] = None,
name: str | None = None,
description: str | None = None,
return_direct: bool = False,
args_schema: Optional[type[BaseModel]] = None,
args_schema: type[BaseModel] | None = None,
infer_schema: bool = True,
**kwargs: Any,
) -> CrewStructuredTool:
@@ -164,7 +160,7 @@ class CrewStructuredTool:
# Create model
schema_name = f"{name.title()}Schema"
return create_model(schema_name, **fields)
return create_model(schema_name, **fields) # type: ignore[call-overload]
def _validate_function_signature(self) -> None:
"""Validate that the function signature matches the args schema."""
@@ -192,7 +188,7 @@ class CrewStructuredTool:
f"not found in args_schema"
)
def _parse_args(self, raw_args: Union[str, dict]) -> dict:
def _parse_args(self, raw_args: str | dict) -> dict:
"""Parse and validate the input arguments against the schema.
Args:
@@ -207,18 +203,18 @@ class CrewStructuredTool:
raw_args = json.loads(raw_args)
except json.JSONDecodeError as e:
raise ValueError(f"Failed to parse arguments as JSON: {e}")
raise ValueError(f"Failed to parse arguments as JSON: {e}") from e
try:
validated_args = self.args_schema.model_validate(raw_args)
return validated_args.model_dump()
except Exception as e:
raise ValueError(f"Arguments validation failed: {e}")
raise ValueError(f"Arguments validation failed: {e}") from e
async def ainvoke(
self,
input: Union[str, dict],
config: Optional[dict] = None,
input: str | dict,
config: dict | None = None,
**kwargs: Any,
) -> Any:
"""Asynchronously invoke the tool.
@@ -234,7 +230,7 @@ class CrewStructuredTool:
parsed_args = self._parse_args(input)
if self.has_reached_max_usage_count():
raise ToolUsageLimitExceeded(
raise ToolUsageLimitExceededError(
f"Tool '{self.name}' has reached its maximum usage limit of {self.max_usage_count}. You should not use the {self.name} tool again."
)
@@ -243,44 +239,37 @@ class CrewStructuredTool:
try:
if inspect.iscoroutinefunction(self.func):
return await self.func(**parsed_args, **kwargs)
else:
# Run sync functions in a thread pool
import asyncio
# Run sync functions in a thread pool
import asyncio
return await asyncio.get_event_loop().run_in_executor(
None, lambda: self.func(**parsed_args, **kwargs)
)
return await asyncio.get_event_loop().run_in_executor(
None, lambda: self.func(**parsed_args, **kwargs)
)
except Exception:
raise
def _run(self, *args, **kwargs) -> Any:
"""Legacy method for compatibility."""
# Convert args/kwargs to our expected format
input_dict = dict(zip(self.args_schema.model_fields.keys(), args))
input_dict = dict(zip(self.args_schema.model_fields.keys(), args, strict=False))
input_dict.update(kwargs)
return self.invoke(input_dict)
def invoke(
self, input: Union[str, dict], config: Optional[dict] = None, **kwargs: Any
self, input: str | dict, config: dict | None = None, **kwargs: Any
) -> Any:
"""Main method for tool execution."""
parsed_args = self._parse_args(input)
if self.has_reached_max_usage_count():
raise ToolUsageLimitExceeded(
raise ToolUsageLimitExceededError(
f"Tool '{self.name}' has reached its maximum usage limit of {self.max_usage_count}. You should not use the {self.name} tool again."
)
self._increment_usage_count()
if inspect.iscoroutinefunction(self.func):
result = asyncio.run(self.func(**parsed_args, **kwargs))
return result
try:
result = self.func(**parsed_args, **kwargs)
except Exception:
raise
return asyncio.run(self.func(**parsed_args, **kwargs))
result = self.func(**parsed_args, **kwargs)