mirror of
https://github.com/crewAIInc/crewAI.git
synced 2026-05-01 07:13:00 +00:00
Merge branch 'main' into feat/bump-version-1.10.0
This commit is contained in:
@@ -50,6 +50,7 @@ from crewai.utilities.agent_utils import (
|
|||||||
handle_unknown_error,
|
handle_unknown_error,
|
||||||
has_reached_max_iterations,
|
has_reached_max_iterations,
|
||||||
is_context_length_exceeded,
|
is_context_length_exceeded,
|
||||||
|
parse_tool_call_args,
|
||||||
process_llm_response,
|
process_llm_response,
|
||||||
track_delegation_if_needed,
|
track_delegation_if_needed,
|
||||||
)
|
)
|
||||||
@@ -894,13 +895,9 @@ class CrewAgentExecutor(CrewAgentExecutorMixin):
|
|||||||
ToolUsageStartedEvent,
|
ToolUsageStartedEvent,
|
||||||
)
|
)
|
||||||
|
|
||||||
if isinstance(func_args, str):
|
args_dict, parse_error = parse_tool_call_args(func_args, func_name, call_id, original_tool)
|
||||||
try:
|
if parse_error is not None:
|
||||||
args_dict = json.loads(func_args)
|
return parse_error
|
||||||
except json.JSONDecodeError:
|
|
||||||
args_dict = {}
|
|
||||||
else:
|
|
||||||
args_dict = func_args
|
|
||||||
|
|
||||||
if original_tool is None:
|
if original_tool is None:
|
||||||
for tool in self.original_tools or []:
|
for tool in self.original_tools or []:
|
||||||
|
|||||||
@@ -66,6 +66,7 @@ from crewai.utilities.agent_utils import (
|
|||||||
has_reached_max_iterations,
|
has_reached_max_iterations,
|
||||||
is_context_length_exceeded,
|
is_context_length_exceeded,
|
||||||
is_inside_event_loop,
|
is_inside_event_loop,
|
||||||
|
parse_tool_call_args,
|
||||||
process_llm_response,
|
process_llm_response,
|
||||||
track_delegation_if_needed,
|
track_delegation_if_needed,
|
||||||
)
|
)
|
||||||
@@ -848,13 +849,9 @@ class AgentExecutor(Flow[AgentReActState], CrewAgentExecutorMixin):
|
|||||||
call_id, func_name, func_args = info
|
call_id, func_name, func_args = info
|
||||||
|
|
||||||
# Parse arguments
|
# Parse arguments
|
||||||
if isinstance(func_args, str):
|
args_dict, parse_error = parse_tool_call_args(func_args, func_name, call_id)
|
||||||
try:
|
if parse_error is not None:
|
||||||
args_dict = json.loads(func_args)
|
return parse_error
|
||||||
except json.JSONDecodeError:
|
|
||||||
args_dict = {}
|
|
||||||
else:
|
|
||||||
args_dict = func_args
|
|
||||||
|
|
||||||
# Get agent_key for event tracking
|
# Get agent_key for event tracking
|
||||||
agent_key = getattr(self.agent, "key", "unknown") if self.agent else "unknown"
|
agent_key = getattr(self.agent, "key", "unknown") if self.agent else "unknown"
|
||||||
|
|||||||
@@ -18,6 +18,7 @@ from pydantic import (
|
|||||||
BaseModel as PydanticBaseModel,
|
BaseModel as PydanticBaseModel,
|
||||||
ConfigDict,
|
ConfigDict,
|
||||||
Field,
|
Field,
|
||||||
|
ValidationError,
|
||||||
create_model,
|
create_model,
|
||||||
field_validator,
|
field_validator,
|
||||||
)
|
)
|
||||||
@@ -150,14 +151,37 @@ class BaseTool(BaseModel, ABC):
|
|||||||
|
|
||||||
super().model_post_init(__context)
|
super().model_post_init(__context)
|
||||||
|
|
||||||
|
def _validate_kwargs(self, kwargs: dict[str, Any]) -> dict[str, Any]:
    """Check keyword arguments against ``args_schema`` when one is usable.

    Validation is skipped, and ``kwargs`` is handed back untouched, when no
    keyword arguments were supplied, when the tool has no ``args_schema``,
    or when the schema declares no fields.

    Args:
        kwargs: The keyword arguments to validate.

    Returns:
        The ``model_dump()`` of the validated schema instance, or the
        original ``kwargs`` when validation was skipped.

    Raises:
        ValueError: If validating or dumping the arguments raises; the
            original exception is chained as the cause.
    """
    # Nothing to validate (e.g. a purely positional call).
    if not kwargs:
        return kwargs

    schema = self.args_schema
    if schema is None or not schema.model_fields:
        return kwargs

    try:
        return schema.model_validate(kwargs).model_dump()
    except Exception as exc:
        raise ValueError(
            f"Tool '{self.name}' arguments validation failed: {exc}"
        ) from exc
|
||||||
|
|
||||||
def run(
|
def run(
|
||||||
self,
|
self,
|
||||||
*args: Any,
|
*args: Any,
|
||||||
**kwargs: Any,
|
**kwargs: Any,
|
||||||
) -> Any:
|
) -> Any:
|
||||||
|
kwargs = self._validate_kwargs(kwargs)
|
||||||
|
|
||||||
result = self._run(*args, **kwargs)
|
result = self._run(*args, **kwargs)
|
||||||
|
|
||||||
# If _run is async, we safely run it
|
|
||||||
if asyncio.iscoroutine(result):
|
if asyncio.iscoroutine(result):
|
||||||
result = asyncio.run(result)
|
result = asyncio.run(result)
|
||||||
|
|
||||||
@@ -179,6 +203,7 @@ class BaseTool(BaseModel, ABC):
|
|||||||
Returns:
|
Returns:
|
||||||
The result of the tool execution.
|
The result of the tool execution.
|
||||||
"""
|
"""
|
||||||
|
kwargs = self._validate_kwargs(kwargs)
|
||||||
result = await self._arun(*args, **kwargs)
|
result = await self._arun(*args, **kwargs)
|
||||||
self.current_usage_count += 1
|
self.current_usage_count += 1
|
||||||
return result
|
return result
|
||||||
@@ -331,6 +356,8 @@ class Tool(BaseTool, Generic[P, R]):
|
|||||||
Returns:
|
Returns:
|
||||||
The result of the tool execution.
|
The result of the tool execution.
|
||||||
"""
|
"""
|
||||||
|
kwargs = self._validate_kwargs(kwargs)
|
||||||
|
|
||||||
result = self.func(*args, **kwargs)
|
result = self.func(*args, **kwargs)
|
||||||
|
|
||||||
if asyncio.iscoroutine(result):
|
if asyncio.iscoroutine(result):
|
||||||
@@ -361,6 +388,7 @@ class Tool(BaseTool, Generic[P, R]):
|
|||||||
Returns:
|
Returns:
|
||||||
The result of the tool execution.
|
The result of the tool execution.
|
||||||
"""
|
"""
|
||||||
|
kwargs = self._validate_kwargs(kwargs)
|
||||||
result = await self._arun(*args, **kwargs)
|
result = await self._arun(*args, **kwargs)
|
||||||
self.current_usage_count += 1
|
self.current_usage_count += 1
|
||||||
return result
|
return result
|
||||||
|
|||||||
@@ -1146,6 +1146,36 @@ def extract_tool_call_info(
|
|||||||
return None
|
return None
|
||||||
|
|
||||||
|
|
||||||
|
def parse_tool_call_args(
|
||||||
|
func_args: dict[str, Any] | str,
|
||||||
|
func_name: str,
|
||||||
|
call_id: str,
|
||||||
|
original_tool: Any = None,
|
||||||
|
) -> tuple[dict[str, Any], None] | tuple[None, dict[str, Any]]:
|
||||||
|
"""Parse tool call arguments from a JSON string or dict.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
``(args_dict, None)`` on success, or ``(None, error_result)`` on
|
||||||
|
JSON parse failure where ``error_result`` is a ready-to-return dict
|
||||||
|
with the same shape as ``_execute_single_native_tool_call`` return values.
|
||||||
|
"""
|
||||||
|
if isinstance(func_args, str):
|
||||||
|
try:
|
||||||
|
return json.loads(func_args), None
|
||||||
|
except json.JSONDecodeError as e:
|
||||||
|
return None, {
|
||||||
|
"call_id": call_id,
|
||||||
|
"func_name": func_name,
|
||||||
|
"result": (
|
||||||
|
f"Error: Failed to parse tool arguments as JSON: {e}. "
|
||||||
|
f"Please provide valid JSON arguments for the '{func_name}' tool."
|
||||||
|
),
|
||||||
|
"from_cache": False,
|
||||||
|
"original_tool": original_tool,
|
||||||
|
}
|
||||||
|
return func_args, None
|
||||||
|
|
||||||
|
|
||||||
def _setup_before_llm_call_hooks(
|
def _setup_before_llm_call_hooks(
|
||||||
executor_context: CrewAgentExecutor | AgentExecutor | LiteAgent | None,
|
executor_context: CrewAgentExecutor | AgentExecutor | LiteAgent | None,
|
||||||
printer: Printer,
|
printer: Printer,
|
||||||
|
|||||||
@@ -11,7 +11,7 @@ import os
|
|||||||
import threading
|
import threading
|
||||||
import time
|
import time
|
||||||
from collections import Counter
|
from collections import Counter
|
||||||
from unittest.mock import patch
|
from unittest.mock import Mock, patch
|
||||||
|
|
||||||
import pytest
|
import pytest
|
||||||
from pydantic import BaseModel, Field
|
from pydantic import BaseModel, Field
|
||||||
@@ -1129,3 +1129,150 @@ class TestMaxUsageCountWithNativeToolCalling:
|
|||||||
# Verify the requested calls occurred while keeping usage bounded.
|
# Verify the requested calls occurred while keeping usage bounded.
|
||||||
assert tool.current_usage_count >= 2
|
assert tool.current_usage_count >= 2
|
||||||
assert tool.current_usage_count <= tool.max_usage_count
|
assert tool.current_usage_count <= tool.max_usage_count
|
||||||
|
|
||||||
|
|
||||||
|
# =============================================================================
|
||||||
|
# JSON Parse Error Handling Tests
|
||||||
|
# =============================================================================
|
||||||
|
|
||||||
|
|
||||||
|
class TestNativeToolCallingJsonParseError:
    """Native tool calls with malformed JSON arguments must surface a clear
    error rather than quietly running the tool with no arguments."""

    @staticmethod
    def _new_code_tool() -> BaseTool:
        """Return a fresh single-argument ``execute_code`` tool."""

        class CodeTool(BaseTool):
            name: str = "execute_code"
            description: str = "Run code"

            def _run(self, code: str) -> str:
                return f"ran: {code}"

        return CodeTool()

    def _make_executor(self, tools: list[BaseTool]) -> "CrewAgentExecutor":
        """Assemble a bare CrewAgentExecutor wired to mock collaborators."""
        from crewai.agents.crew_agent_executor import CrewAgentExecutor
        from crewai.tools.base_tool import to_langchain

        structured_tools = to_langchain(tools)

        agent_stub = Mock(
            key="test_agent",
            role="tester",
            verbose=False,
            fingerprint=None,
            tools_results=[],
        )

        # ``name`` is reserved by the Mock constructor, so set it afterwards.
        task_stub = Mock()
        task_stub.configure_mock(name="test", description="test", id="test-id")

        # Skip __init__ entirely; only the attributes below are populated.
        executor = object.__new__(CrewAgentExecutor)
        attributes = {
            "agent": agent_stub,
            "task": task_stub,
            "crew": Mock(),
            "tools": structured_tools,
            "original_tools": tools,
            "tools_handler": None,
            "_printer": Mock(),
            "messages": [],
        }
        for attr_name, attr_value in attributes.items():
            setattr(executor, attr_name, attr_value)

        return executor

    def _invoke(
        self,
        tool: BaseTool,
        *,
        call_id: str,
        func_name: str,
        func_args: str | dict[str, str],
    ) -> dict:
        """Run one native tool call for ``tool`` through a minimal executor."""
        executor = self._make_executor([tool])

        from crewai.utilities.agent_utils import convert_tools_to_openai_schema

        _, functions = convert_tools_to_openai_schema([tool])
        return executor._execute_single_native_tool_call(
            call_id=call_id,
            func_name=func_name,
            func_args=func_args,
            available_functions=functions,
        )

    def test_malformed_json_returns_parse_error(self) -> None:
        """Broken JSON yields a descriptive error and the tool is never run."""
        code_tool = self._new_code_tool()

        outcome = self._invoke(
            code_tool,
            call_id="call_123",
            func_name="execute_code",
            func_args='{"code": "print("hello")"}',
        )

        assert "Failed to parse tool arguments as JSON" in outcome["result"]
        assert code_tool.current_usage_count == 0

    def test_valid_json_still_executes_normally(self) -> None:
        """Well-formed JSON arguments reach the tool unchanged."""
        outcome = self._invoke(
            self._new_code_tool(),
            call_id="call_456",
            func_name="execute_code",
            func_args='{"code": "print(1)"}',
        )

        assert outcome["result"] == "ran: print(1)"

    def test_dict_args_bypass_json_parsing(self) -> None:
        """Arguments that are already a dict are used without JSON parsing."""
        outcome = self._invoke(
            self._new_code_tool(),
            call_id="call_789",
            func_name="execute_code",
            func_args={"code": "x = 42"},
        )

        assert outcome["result"] == "ran: x = 42"

    def test_schema_validation_catches_missing_args_on_native_path(self) -> None:
        """The native path enforces args_schema, so a missing required field
        is reported as an error."""

        class StrictTool(BaseTool):
            name: str = "strict_tool"
            description: str = "A tool with required args"

            def _run(self, code: str, language: str) -> str:
                return f"{language}: {code}"

        outcome = self._invoke(
            StrictTool(),
            call_id="call_schema",
            func_name="strict_tool",
            func_args={"code": "print(1)"},
        )

        message = outcome["result"]
        assert "Error" in message
        lowered = message.lower()
        assert "validation failed" in lowered or "missing" in lowered
|
||||||
|
|||||||
@@ -3,6 +3,8 @@ from typing import Callable
|
|||||||
from unittest.mock import patch
|
from unittest.mock import patch
|
||||||
|
|
||||||
import pytest
|
import pytest
|
||||||
|
from pydantic import BaseModel, Field
|
||||||
|
|
||||||
from crewai.agent import Agent
|
from crewai.agent import Agent
|
||||||
from crewai.crew import Crew
|
from crewai.crew import Crew
|
||||||
from crewai.task import Task
|
from crewai.task import Task
|
||||||
@@ -230,3 +232,204 @@ def test_max_usage_count_is_respected():
|
|||||||
crew.kickoff()
|
crew.kickoff()
|
||||||
assert tool.max_usage_count == 5
|
assert tool.max_usage_count == 5
|
||||||
assert tool.current_usage_count == 5
|
assert tool.current_usage_count == 5
|
||||||
|
|
||||||
|
|
||||||
|
# =============================================================================
|
||||||
|
# Schema Validation in run() Tests
|
||||||
|
# =============================================================================
|
||||||
|
|
||||||
|
|
||||||
|
# Argument schema used as ``args_schema`` by the code-executor test tools.
class CodeExecutorInput(BaseModel):
    # Required: no default is provided.
    code: str = Field(description="The code to execute")
    # Optional: falls back to "python" when omitted.
    language: str = Field(default="python", description="Programming language")
|
||||||
|
|
||||||
|
|
||||||
|
# Synchronous test tool whose keyword arguments are described by
# CodeExecutorInput.
class CodeExecutorTool(BaseTool):
    name: str = "code_executor"
    description: str = "Execute code snippets"
    args_schema: type[BaseModel] = CodeExecutorInput

    # Echoes both arguments so tests can see exactly what reached _run.
    def _run(self, code: str, language: str = "python") -> str:
        return f"Executed {language}: {code}"
|
||||||
|
|
||||||
|
|
||||||
|
class TestBaseToolRunValidation:
    """Checks that BaseTool.run() validates keyword arguments via args_schema."""

    def test_run_with_valid_kwargs_passes_validation(self) -> None:
        """A call supplying only the required field runs with the default language."""
        executor_tool = CodeExecutorTool()
        assert executor_tool.run(code="print('hello')") == "Executed python: print('hello')"

    def test_run_with_all_kwargs_passes_validation(self) -> None:
        """Supplying the optional field as well is accepted and honoured."""
        executor_tool = CodeExecutorTool()
        output = executor_tool.run(code="console.log('hi')", language="javascript")
        assert output == "Executed javascript: console.log('hi')"

    def test_run_with_missing_required_kwarg_raises(self) -> None:
        """Omitting the required ``code`` field raises ValueError."""
        executor_tool = CodeExecutorTool()
        with pytest.raises(ValueError, match="validation failed"):
            executor_tool.run(language="python")

    def test_run_with_wrong_field_name_raises(self) -> None:
        """An unknown keyword leaves the required field missing, so
        validation fails."""
        executor_tool = CodeExecutorTool()
        with pytest.raises(ValueError, match="validation failed"):
            executor_tool.run(wrong_arg="value")

    def test_run_with_positional_args_skips_validation(self) -> None:
        """A purely positional call still works (backwards compatibility)."""

        class SimpleTool(BaseTool):
            name: str = "simple"
            description: str = "A simple tool"

            def _run(self, question: str) -> str:
                return question

        assert SimpleTool().run("What is life?") == "What is life?"

    def test_run_strips_extra_kwargs_from_llm(self) -> None:
        """Keywords absent from the schema are dropped before _run is called,
        so they cannot cause an unexpected-keyword failure."""
        executor_tool = CodeExecutorTool()
        output = executor_tool.run(code="1+1", extra_hallucinated_field="junk")
        assert output == "Executed python: 1+1"

    def test_run_increments_usage_after_validation(self) -> None:
        """A validated, successful run bumps the usage counter by one."""
        executor_tool = CodeExecutorTool()
        assert executor_tool.current_usage_count == 0
        executor_tool.run(code="x = 1")
        assert executor_tool.current_usage_count == 1

    def test_run_does_not_increment_usage_on_validation_error(self) -> None:
        """A run rejected by validation leaves the usage counter untouched."""
        executor_tool = CodeExecutorTool()
        assert executor_tool.current_usage_count == 0
        with pytest.raises(ValueError):
            executor_tool.run(wrong="bad")
        assert executor_tool.current_usage_count == 0
|
||||||
|
|
||||||
|
|
||||||
|
class TestToolDecoratorRunValidation:
    """Checks that tools built with the ``@tool`` decorator validate keyword
    arguments in run()."""

    def test_decorator_tool_run_validates_kwargs(self) -> None:
        """Valid keyword arguments pass validation and reach the function."""

        @tool("execute_code")
        def execute_code(code: str, language: str = "python") -> str:
            """Execute a code snippet."""
            return f"Executed {language}: {code}"

        assert execute_code.run(code="x = 1") == "Executed python: x = 1"

    def test_decorator_tool_run_rejects_missing_required(self) -> None:
        """A call that omits the required argument is rejected."""

        @tool("execute_code")
        def execute_code(code: str) -> str:
            """Execute a code snippet."""
            return f"Executed: {code}"

        with pytest.raises(ValueError, match="validation failed"):
            execute_code.run(wrong_arg="value")

    def test_decorator_tool_positional_args_still_work(self) -> None:
        """Positional calls to decorator tools keep working."""

        @tool("greet")
        def greet(name: str) -> str:
            """Greet someone."""
            return f"Hello, {name}!"

        assert greet.run("World") == "Hello, World!"
|
||||||
|
|
||||||
|
|
||||||
|
# =============================================================================
|
||||||
|
# Async arun() Schema Validation Tests
|
||||||
|
# =============================================================================
|
||||||
|
|
||||||
|
|
||||||
|
# Test tool offering both an async and a sync implementation, with keyword
# arguments described by CodeExecutorInput.
class AsyncCodeExecutorTool(BaseTool):
    name: str = "async_code_executor"
    description: str = "Execute code snippets asynchronously"
    args_schema: type[BaseModel] = CodeExecutorInput

    # Async implementation; its "Async executed" prefix differs from _run's
    # output, so tests can tell which one produced a result.
    async def _arun(self, code: str, language: str = "python") -> str:
        return f"Async executed {language}: {code}"

    # Synchronous counterpart.
    def _run(self, code: str, language: str = "python") -> str:
        return f"Executed {language}: {code}"
|
||||||
|
|
||||||
|
|
||||||
|
class TestBaseToolArunValidation:
    """Checks that BaseTool.arun() validates keyword arguments via args_schema."""

    @pytest.mark.asyncio
    async def test_arun_with_valid_kwargs_passes_validation(self) -> None:
        """Valid keyword arguments are accepted and reach ``_arun``."""
        async_tool = AsyncCodeExecutorTool()
        output = await async_tool.arun(code="print('hello')")
        assert output == "Async executed python: print('hello')"

    @pytest.mark.asyncio
    async def test_arun_with_missing_required_kwarg_raises(self) -> None:
        """Omitting the required ``code`` field raises ValueError."""
        async_tool = AsyncCodeExecutorTool()
        with pytest.raises(ValueError, match="validation failed"):
            await async_tool.arun(language="python")

    @pytest.mark.asyncio
    async def test_arun_with_wrong_field_name_raises(self) -> None:
        """An unknown keyword leaves the required field missing, so
        validation fails."""
        async_tool = AsyncCodeExecutorTool()
        with pytest.raises(ValueError, match="validation failed"):
            await async_tool.arun(wrong_arg="value")

    @pytest.mark.asyncio
    async def test_arun_strips_extra_kwargs(self) -> None:
        """Keywords absent from the schema are dropped before ``_arun`` runs."""
        async_tool = AsyncCodeExecutorTool()
        output = await async_tool.arun(code="1+1", extra_field="junk")
        assert output == "Async executed python: 1+1"

    @pytest.mark.asyncio
    async def test_arun_does_not_increment_usage_on_validation_error(self) -> None:
        """A call rejected by validation leaves the usage counter untouched."""
        async_tool = AsyncCodeExecutorTool()
        assert async_tool.current_usage_count == 0
        with pytest.raises(ValueError):
            await async_tool.arun(wrong="bad")
        assert async_tool.current_usage_count == 0
|
||||||
|
|
||||||
|
|
||||||
|
class TestToolDecoratorArunValidation:
    """Checks that async tools built with the ``@tool`` decorator validate
    keyword arguments in arun()."""

    @pytest.mark.asyncio
    async def test_async_decorator_tool_arun_validates_kwargs(self) -> None:
        """Valid keyword arguments pass validation and reach the coroutine."""

        @tool("async_execute")
        async def async_execute(code: str, language: str = "python") -> str:
            """Execute code asynchronously."""
            return f"Async {language}: {code}"

        assert await async_execute.arun(code="x = 1") == "Async python: x = 1"

    @pytest.mark.asyncio
    async def test_async_decorator_tool_arun_rejects_missing_required(self) -> None:
        """A call that omits the required argument is rejected."""

        @tool("async_execute")
        async def async_execute(code: str) -> str:
            """Execute code asynchronously."""
            return f"Async: {code}"

        with pytest.raises(ValueError, match="validation failed"):
            await async_execute.arun(wrong_arg="value")
|
||||||
|
|||||||
@@ -17,6 +17,7 @@ from crewai.utilities.agent_utils import (
|
|||||||
_format_messages_for_summary,
|
_format_messages_for_summary,
|
||||||
_split_messages_into_chunks,
|
_split_messages_into_chunks,
|
||||||
convert_tools_to_openai_schema,
|
convert_tools_to_openai_schema,
|
||||||
|
parse_tool_call_args,
|
||||||
summarize_messages,
|
summarize_messages,
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -922,3 +923,56 @@ class TestParallelSummarizationVCR:
|
|||||||
assert summary_msg["role"] == "user"
|
assert summary_msg["role"] == "user"
|
||||||
assert "files" in summary_msg
|
assert "files" in summary_msg
|
||||||
assert "report.pdf" in summary_msg["files"]
|
assert "report.pdf" in summary_msg["files"]
|
||||||
|
|
||||||
|
|
||||||
|
class TestParseToolCallArgs:
    """Unit tests covering the success and failure paths of parse_tool_call_args."""

    def test_valid_json_string_returns_dict(self) -> None:
        """A well-formed JSON string is decoded and no error is reported."""
        parsed, failure = parse_tool_call_args('{"code": "print(1)"}', "run_code", "call_1")
        assert failure is None
        assert parsed == {"code": "print(1)"}

    def test_malformed_json_returns_error_dict(self) -> None:
        """Broken JSON yields no arguments and a fully populated error dict."""
        parsed, failure = parse_tool_call_args('{"code": "print("hi")"}', "run_code", "call_1")
        assert parsed is None
        assert failure is not None
        assert failure["call_id"] == "call_1"
        assert failure["func_name"] == "run_code"
        assert failure["from_cache"] is False
        message = failure["result"]
        assert "Failed to parse tool arguments as JSON" in message
        assert "run_code" in message

    def test_malformed_json_preserves_original_tool(self) -> None:
        """The supplied ``original_tool`` object is passed through on failure."""
        sentinel_tool = object()
        _, failure = parse_tool_call_args(
            "{bad}", "my_tool", "call_2", original_tool=sentinel_tool
        )
        assert failure is not None
        assert failure["original_tool"] is sentinel_tool

    def test_malformed_json_original_tool_defaults_to_none(self) -> None:
        """Without an ``original_tool`` argument the error carries ``None``."""
        _, failure = parse_tool_call_args("{bad}", "my_tool", "call_3")
        assert failure is not None
        assert failure["original_tool"] is None

    def test_dict_input_returned_directly(self) -> None:
        """A dict argument comes back with equal contents and no error."""
        given = {"code": "x = 42"}
        parsed, failure = parse_tool_call_args(given, "run_code", "call_4")
        assert failure is None
        assert parsed == {"code": "x = 42"}

    def test_empty_dict_input_returned_directly(self) -> None:
        """An empty dict is a valid argument set."""
        parsed, failure = parse_tool_call_args({}, "run_code", "call_5")
        assert failure is None
        assert parsed == {}

    def test_valid_json_with_nested_values(self) -> None:
        """Nested JSON objects are decoded into nested dicts."""
        raw = '{"query": "hello", "options": {"limit": 10}}'
        parsed, failure = parse_tool_call_args(raw, "search", "call_6")
        assert failure is None
        assert parsed == {"query": "hello", "options": {"limit": 10}}

    def test_error_result_has_correct_keys(self) -> None:
        """The error dict exposes exactly the expected set of keys."""
        _, failure = parse_tool_call_args("{bad json}", "tool", "call_7")
        assert failure is not None
        expected_keys = {"call_id", "func_name", "result", "from_cache", "original_tool"}
        assert set(failure.keys()) == expected_keys
|
||||||
|
|||||||
Reference in New Issue
Block a user