mirror of
https://github.com/crewAIInc/crewAI.git
synced 2026-05-01 23:32:39 +00:00
feat: enhance JSON argument parsing and validation in CrewAgentExecutor and BaseTool
* feat: enhance JSON argument parsing and validation in CrewAgentExecutor and BaseTool - Added error handling for malformed JSON tool arguments in CrewAgentExecutor, providing descriptive error messages. - Implemented schema validation for tool arguments in BaseTool, ensuring that invalid arguments raise appropriate exceptions. - Introduced tests to verify correct behavior for both valid and invalid JSON inputs, enhancing robustness of tool execution. * refactor: improve argument validation in BaseTool - Introduced a new private method to handle argument validation for tools, enhancing code clarity and reusability. - Updated the method to utilize the new validation method, ensuring consistent error handling for invalid arguments. - Enhanced exception handling to specifically catch validation errors, providing clearer error messages for tool argument validation failures. * feat: introduce parse_tool_call_args for improved argument parsing - Added a new utility function, parse_tool_call_args, to handle parsing of tool call arguments from JSON strings or dictionaries, enhancing error handling for malformed JSON inputs. - Updated CrewAgentExecutor and AgentExecutor to utilize the new parsing function, streamlining argument validation and improving clarity in error reporting. - Introduced unit tests for parse_tool_call_args to ensure robust functionality and correct handling of various input scenarios. * feat: add keyword argument validation in BaseTool and Tool classes - Introduced a new method `_validate_kwargs` in BaseTool to validate keyword arguments against the defined schema, ensuring proper argument handling. - Updated the `run` and `arun` methods in both BaseTool and Tool classes to utilize the new validation method, improving error handling and robustness. - Added comprehensive tests for asynchronous execution in `TestBaseToolArunValidation` to verify correct behavior for valid and invalid keyword arguments.
* Potential fix for pull request finding 'Syntax error' Co-authored-by: Copilot Autofix powered by AI <223894421+github-code-quality[bot]@users.noreply.github.com> --------- Co-authored-by: lorenzejay <lorenzejaytech@gmail.com> Co-authored-by: Lorenze Jay <63378463+lorenzejay@users.noreply.github.com> Co-authored-by: Greyson LaLonde <greyson.r.lalonde@gmail.com> Co-authored-by: Copilot Autofix powered by AI <223894421+github-code-quality[bot]@users.noreply.github.com>
This commit is contained in:
@@ -3,6 +3,8 @@ from typing import Callable
|
||||
from unittest.mock import patch
|
||||
|
||||
import pytest
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from crewai.agent import Agent
|
||||
from crewai.crew import Crew
|
||||
from crewai.task import Task
|
||||
@@ -230,3 +232,204 @@ def test_max_usage_count_is_respected():
|
||||
crew.kickoff()
|
||||
assert tool.max_usage_count == 5
|
||||
assert tool.current_usage_count == 5
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# Schema Validation in run() Tests
|
||||
# =============================================================================
|
||||
|
||||
|
||||
class CodeExecutorInput(BaseModel):
    # Argument schema used by the code-executor tools in these tests.
    # `code` has no default, so it is the only required field.
    code: str = Field(
        description="The code to execute",
    )
    # `language` falls back to "python" when the caller omits it.
    language: str = Field(
        default="python",
        description="Programming language",
    )
|
||||
|
||||
|
||||
class CodeExecutorTool(BaseTool):
    # Synchronous test tool; its arguments are described by CodeExecutorInput.
    name: str = "code_executor"
    description: str = "Execute code snippets"
    args_schema: type[BaseModel] = CodeExecutorInput

    def _run(self, code: str, language: str = "python") -> str:
        """Return a string echoing the language and the code that were passed in."""
        outcome = f"Executed {language}: {code}"
        return outcome
|
||||
|
||||
|
||||
class TestBaseToolRunValidation:
    """Checks how BaseTool.run() applies args_schema validation to its kwargs."""

    def test_run_with_valid_kwargs_passes_validation(self) -> None:
        """Only the required kwarg is given; the default language is used."""
        executor = CodeExecutorTool()
        expected = "Executed python: print('hello')"
        assert executor.run(code="print('hello')") == expected

    def test_run_with_all_kwargs_passes_validation(self) -> None:
        """Required and optional kwargs are both given and both reach _run."""
        executor = CodeExecutorTool()
        output = executor.run(code="console.log('hi')", language="javascript")
        expected = "Executed javascript: console.log('hi')"
        assert output == expected

    def test_run_with_missing_required_kwarg_raises(self) -> None:
        """Omitting the required `code` kwarg must raise a ValueError."""
        executor = CodeExecutorTool()
        with pytest.raises(ValueError, match="validation failed"):
            executor.run(language="python")

    def test_run_with_wrong_field_name_raises(self) -> None:
        """A kwarg that matches no schema field leaves `code` missing,
        so the call must raise a ValueError."""
        executor = CodeExecutorTool()
        with pytest.raises(ValueError, match="validation failed"):
            executor.run(wrong_arg="value")

    def test_run_with_positional_args_skips_validation(self) -> None:
        """A purely positional call is expected to work unchanged
        (backwards compatibility)."""

        class SimpleTool(BaseTool):
            name: str = "simple"
            description: str = "A simple tool"

            def _run(self, question: str) -> str:
                return question

        simple = SimpleTool()
        prompt = "What is life?"
        assert simple.run("What is life?") == prompt

    def test_run_strips_extra_kwargs_from_llm(self) -> None:
        """A kwarg outside the schema is dropped instead of reaching _run,
        where it would be an unexpected keyword argument."""
        executor = CodeExecutorTool()
        output = executor.run(code="1+1", extra_hallucinated_field="junk")
        assert output == "Executed python: 1+1"

    def test_run_increments_usage_after_validation(self) -> None:
        """A successful validated run bumps the usage counter from 0 to 1."""
        executor = CodeExecutorTool()
        assert executor.current_usage_count == 0

        executor.run(code="x = 1")

        assert executor.current_usage_count == 1

    def test_run_does_not_increment_usage_on_validation_error(self) -> None:
        """A run rejected by validation leaves the usage counter at 0."""
        executor = CodeExecutorTool()
        assert executor.current_usage_count == 0

        with pytest.raises(ValueError):
            executor.run(wrong="bad")

        # Checked after the context manager exits so the assert is reachable.
        assert executor.current_usage_count == 0
|
||||
|
||||
|
||||
class TestToolDecoratorRunValidation:
    """Checks kwarg validation in run() for tools built with the @tool decorator."""

    def test_decorator_tool_run_validates_kwargs(self) -> None:
        """A decorator-built tool accepts valid kwargs and applies defaults."""

        @tool("execute_code")
        def execute_code(code: str, language: str = "python") -> str:
            """Execute a code snippet."""
            return f"Executed {language}: {code}"

        expected = "Executed python: x = 1"
        assert execute_code.run(code="x = 1") == expected

    def test_decorator_tool_run_rejects_missing_required(self) -> None:
        """A decorator-built tool raises ValueError when `code` is not supplied."""

        @tool("execute_code")
        def execute_code(code: str) -> str:
            """Execute a code snippet."""
            return f"Executed: {code}"

        with pytest.raises(ValueError, match="validation failed"):
            execute_code.run(wrong_arg="value")

    def test_decorator_tool_positional_args_still_work(self) -> None:
        """A positional call to a decorator-built tool keeps working."""

        @tool("greet")
        def greet(name: str) -> str:
            """Greet someone."""
            return f"Hello, {name}!"

        greeting = greet.run("World")
        assert greeting == "Hello, World!"
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# Async arun() Schema Validation Tests
|
||||
# =============================================================================
|
||||
|
||||
|
||||
class AsyncCodeExecutorTool(BaseTool):
    # Test tool exposing both a sync and an async entry point over the same
    # CodeExecutorInput schema; the two outputs differ only in their prefix.
    name: str = "async_code_executor"
    description: str = "Execute code snippets asynchronously"
    args_schema: type[BaseModel] = CodeExecutorInput

    def _run(self, code: str, language: str = "python") -> str:
        """Return the synchronous echo string for the given language and code."""
        outcome = f"Executed {language}: {code}"
        return outcome

    async def _arun(self, code: str, language: str = "python") -> str:
        """Return the asynchronous echo string for the given language and code."""
        outcome = f"Async executed {language}: {code}"
        return outcome
|
||||
|
||||
|
||||
class TestBaseToolArunValidation:
    """Checks how BaseTool.arun() applies args_schema validation to its kwargs."""

    @pytest.mark.asyncio
    async def test_arun_with_valid_kwargs_passes_validation(self) -> None:
        """Only the required kwarg is given; arun uses the default language."""
        executor = AsyncCodeExecutorTool()
        expected = "Async executed python: print('hello')"
        output = await executor.arun(code="print('hello')")
        assert output == expected

    @pytest.mark.asyncio
    async def test_arun_with_missing_required_kwarg_raises(self) -> None:
        """Omitting the required `code` kwarg makes arun raise a ValueError."""
        executor = AsyncCodeExecutorTool()
        with pytest.raises(ValueError, match="validation failed"):
            await executor.arun(language="python")

    @pytest.mark.asyncio
    async def test_arun_with_wrong_field_name_raises(self) -> None:
        """A kwarg that matches no schema field makes arun raise a ValueError."""
        executor = AsyncCodeExecutorTool()
        with pytest.raises(ValueError, match="validation failed"):
            await executor.arun(wrong_arg="value")

    @pytest.mark.asyncio
    async def test_arun_strips_extra_kwargs(self) -> None:
        """A kwarg outside the schema is dropped before the async call runs."""
        executor = AsyncCodeExecutorTool()
        output = await executor.arun(code="1+1", extra_field="junk")
        assert output == "Async executed python: 1+1"

    @pytest.mark.asyncio
    async def test_arun_does_not_increment_usage_on_validation_error(self) -> None:
        """An arun call rejected by validation leaves the usage counter at 0."""
        executor = AsyncCodeExecutorTool()
        assert executor.current_usage_count == 0

        with pytest.raises(ValueError):
            await executor.arun(wrong="bad")

        # Checked after the context manager exits so the assert is reachable.
        assert executor.current_usage_count == 0
|
||||
|
||||
|
||||
class TestToolDecoratorArunValidation:
    """Checks kwarg validation in arun() for async tools built with @tool."""

    @pytest.mark.asyncio
    async def test_async_decorator_tool_arun_validates_kwargs(self) -> None:
        """An async decorator-built tool accepts valid kwargs and applies defaults."""

        @tool("async_execute")
        async def async_execute(code: str, language: str = "python") -> str:
            """Execute code asynchronously."""
            return f"Async {language}: {code}"

        expected = "Async python: x = 1"
        output = await async_execute.arun(code="x = 1")
        assert output == expected

    @pytest.mark.asyncio
    async def test_async_decorator_tool_arun_rejects_missing_required(self) -> None:
        """An async decorator-built tool raises ValueError when `code` is not supplied."""

        @tool("async_execute")
        async def async_execute(code: str) -> str:
            """Execute code asynchronously."""
            return f"Async: {code}"

        with pytest.raises(ValueError, match="validation failed"):
            await async_execute.arun(wrong_arg="value")
|
||||
|
||||
Reference in New Issue
Block a user