mirror of https://github.com/crewAIInc/crewAI.git
synced 2026-01-03 13:18:29 +00:00
234 lines · 7.0 KiB · Python
import json
|
|
import random
|
|
from unittest.mock import MagicMock
|
|
|
|
import pytest
|
|
from pydantic import BaseModel, Field
|
|
|
|
from crewai import Agent, Task
|
|
from crewai.tools import BaseTool
|
|
from crewai.tools.tool_usage import ToolUsage
|
|
|
|
|
|
class RandomNumberToolInput(BaseModel):
    # Argument schema for RandomNumberTool. Both bounds are required
    # integers and both are inclusive. Declaration order is kept as-is:
    # it determines the property order in the generated schema.
    # (Comments rather than a class docstring: a pydantic model docstring
    # would be emitted as the schema's "description".)
    min_value: int = Field(
        default=...,
        description="The minimum value of the range (inclusive)",
    )
    max_value: int = Field(
        default=...,
        description="The maximum value of the range (inclusive)",
    )
|
|
|
|
|
|
class RandomNumberTool(BaseTool):
    # Example crewAI tool that returns a random integer from an inclusive
    # range; its arguments are declared by RandomNumberToolInput.
    name: str = "Random Number Generator"
    description: str = "Generates a random number within a specified range"
    args_schema: type[BaseModel] = RandomNumberToolInput

    def _run(self, min_value: int, max_value: int) -> int:
        """Return a random integer N such that ``min_value <= N <= max_value``.

        Args:
            min_value: Lower bound of the range (inclusive).
            max_value: Upper bound of the range (inclusive).

        Returns:
            The randomly chosen integer.

        Raises:
            ValueError: If ``min_value`` is greater than ``max_value``.
        """
        # random.randint would also raise ValueError for an inverted range,
        # but its message refers to randrange's internal arguments; report
        # the tool's own parameter names so the caller can act on it.
        if min_value > max_value:
            raise ValueError(
                f"min_value ({min_value}) must be less than or equal to "
                f"max_value ({max_value})"
            )
        return random.randint(min_value, max_value)
|
|
|
|
|
|
# Example agent and task: a sample Agent equipped with RandomNumberTool.
# Neither object is referenced by the tests in this file.
example_agent = Agent(
    role="Number Generator",
    backstory="You are an AI agent specialized in generating random numbers within specified ranges.",
    goal="Generate random numbers for various purposes",
    verbose=True,
    tools=[RandomNumberTool()],
)
|
|
|
|
# Sample task assigned to example_agent: produce one number in 1..100.
example_task = Task(
    agent=example_agent,
    description="Generate a random number between 1 and 100",
    expected_output="A random number between 1 and 100",
)
|
|
|
|
|
|
def test_random_number_tool_range():
    """_run always returns a value inside the inclusive [min, max] range."""
    tool = RandomNumberTool()
    # One random draw says little about the bounds. Sample repeatedly so
    # that an out-of-range or off-by-one result is very likely to surface.
    results = [tool._run(1, 10) for _ in range(200)]
    assert all(1 <= result <= 10 for result in results)
    # Degenerate range: equal bounds leave exactly one possible value.
    assert tool._run(5, 5) == 5
|
|
|
|
|
|
def test_random_number_tool_invalid_range():
    """An inverted range (min_value > max_value) must raise ValueError."""
    # Built outside pytest.raises so that only the _run call is under test.
    generator = RandomNumberTool()
    lower, upper = 10, 1  # deliberately inverted: min_value > max_value
    with pytest.raises(ValueError):
        generator._run(lower, upper)
|
|
|
|
|
|
def test_random_number_tool_schema():
    """The args schema exposes both bounds as described integer fields."""
    tool = RandomNumberTool()

    # Serialize, then re-parse, so the checks also prove the schema
    # survives a JSON round trip.
    serialized = json.dumps(tool.args_schema.model_json_schema())
    for field_name in ("min_value", "max_value"):
        assert field_name in serialized

    properties = json.loads(serialized)["properties"]
    expected_phrases = {
        "min_value": "minimum value",
        "max_value": "maximum value",
    }
    for field_name, phrase in expected_phrases.items():
        field_schema = properties[field_name]
        assert field_schema["type"] == "integer"
        assert phrase in field_schema["description"].lower()
|
|
|
|
|
|
def test_tool_usage_render():
    """_render() lists the tool's name, argument schema and description."""
    tool = RandomNumberTool()
    tool_usage = ToolUsage(
        tools_handler=MagicMock(),
        tools=[tool],
        original_tools=[tool],
        tools_description="Sample tool for testing",
        tools_names="random_number_generator",
        task=MagicMock(),
        function_calling_llm=MagicMock(),
        agent=MagicMock(),
        action=MagicMock(),
    )

    rendered = tool_usage._render()

    # Expected fragments of the rendered text.
    name_line = "Tool Name: Random Number Generator"
    min_arg = "'min_value': {'description': 'The minimum value of the range (inclusive)', 'type': 'int'}"
    max_arg = "'max_value': {'description': 'The maximum value of the range (inclusive)', 'type': 'int'}"
    description_line = "Tool Description: Generates a random number within a specified range"

    # Each fragment must appear on its own...
    for fragment in (name_line, "Tool Arguments:", min_arg, max_arg, description_line):
        assert fragment in rendered

    # ...and the complete three-line block must appear in exactly this order.
    arguments_line = "Tool Arguments: {" + min_arg + ", " + max_arg + "}"
    full_block = "\n".join((name_line, arguments_line, description_line))
    assert full_block in rendered
|
|
|
|
|
|
def _make_tool_usage() -> ToolUsage:
    """Build a ToolUsage with no tools and mocked collaborators.

    The _validate_tool_input tests below only exercise argument parsing, so
    every collaborator is a MagicMock and the tool lists are empty.
    """
    return ToolUsage(
        tools_handler=MagicMock(),
        tools=[],
        original_tools=[],
        tools_description="",
        tools_names="",
        task=MagicMock(),
        function_calling_llm=MagicMock(),
        agent=MagicMock(),
        action=MagicMock(),
    )


def test_validate_tool_input_booleans_and_none():
    """Python-style True/False/None literals parse to Python values."""
    tool_usage = _make_tool_usage()

    # True/False/None are Python spellings, not valid JSON.
    tool_input = '{"key1": True, "key2": False, "key3": None}'
    expected_arguments = {"key1": True, "key2": False, "key3": None}

    arguments = tool_usage._validate_tool_input(tool_input)
    assert arguments == expected_arguments


def test_validate_tool_input_mixed_types():
    """Ints, strings and booleans in one input all parse correctly."""
    tool_usage = _make_tool_usage()

    tool_input = '{"number": 123, "text": "Some text", "flag": True}'
    expected_arguments = {"number": 123, "text": "Some text", "flag": True}

    arguments = tool_usage._validate_tool_input(tool_input)
    assert arguments == expected_arguments


def test_validate_tool_input_single_quotes():
    """Single-quoted keys and values are accepted in place of double quotes."""
    tool_usage = _make_tool_usage()

    tool_input = "{'key': 'value', 'flag': True}"
    expected_arguments = {"key": "value", "flag": True}

    arguments = tool_usage._validate_tool_input(tool_input)
    assert arguments == expected_arguments


def test_validate_tool_input_invalid_json_repairable():
    """A trailing comma inside a list is repaired rather than rejected."""
    tool_usage = _make_tool_usage()

    tool_input = '{"key": "value", "list": [1, 2, 3,]}'
    expected_arguments = {"key": "value", "list": [1, 2, 3]}

    arguments = tool_usage._validate_tool_input(tool_input)
    assert arguments == expected_arguments


def test_validate_tool_input_with_special_characters():
    """Punctuation and non-ASCII characters survive parsing unchanged."""
    tool_usage = _make_tool_usage()

    # NOTE: this is a regular (non-raw) literal, so Python decodes \u263A to
    # the character itself before _validate_tool_input runs; the input holds
    # the literal smiley, not a JSON "\u" escape sequence.
    tool_input = '{"message": "Hello, world! \u263A", "valid": True}'
    expected_arguments = {"message": "Hello, world! ☺", "valid": True}

    arguments = tool_usage._validate_tool_input(tool_input)
    assert arguments == expected_arguments
|