mirror of
https://github.com/crewAIInc/crewAI.git
synced 2026-01-11 00:58:30 +00:00
fix tool calling issue (#1467)
* fix tool calling issue * Update tool type check * Drop print
This commit is contained in:
committed by
GitHub
parent
69cd84d78f
commit
b3d16bba5f
@@ -394,7 +394,7 @@ class Agent(BaseAgent):
|
|||||||
"""
|
"""
|
||||||
tool_strings = []
|
tool_strings = []
|
||||||
for tool in tools:
|
for tool in tools:
|
||||||
args_schema = str(tool.args)
|
args_schema = str(tool.model_fields)
|
||||||
if hasattr(tool, "func") and tool.func:
|
if hasattr(tool, "func") and tool.func:
|
||||||
sig = signature(tool.func)
|
sig = signature(tool.func)
|
||||||
description = (
|
description = (
|
||||||
|
|||||||
@@ -2,6 +2,7 @@ import json
|
|||||||
import re
|
import re
|
||||||
from typing import Any, Dict, List, Union
|
from typing import Any, Dict, List, Union
|
||||||
|
|
||||||
|
from crewai.agents.agent_builder.base_agent import BaseAgent
|
||||||
from crewai.agents.agent_builder.base_agent_executor_mixin import CrewAgentExecutorMixin
|
from crewai.agents.agent_builder.base_agent_executor_mixin import CrewAgentExecutorMixin
|
||||||
from crewai.agents.parser import (
|
from crewai.agents.parser import (
|
||||||
FINAL_ANSWER_AND_PARSABLE_ACTION_ERROR_MESSAGE,
|
FINAL_ANSWER_AND_PARSABLE_ACTION_ERROR_MESSAGE,
|
||||||
@@ -19,7 +20,6 @@ from crewai.utilities.exceptions.context_window_exceeding_exception import (
|
|||||||
)
|
)
|
||||||
from crewai.utilities.logger import Logger
|
from crewai.utilities.logger import Logger
|
||||||
from crewai.utilities.training_handler import CrewTrainingHandler
|
from crewai.utilities.training_handler import CrewTrainingHandler
|
||||||
from crewai.agents.agent_builder.base_agent import BaseAgent
|
|
||||||
|
|
||||||
|
|
||||||
class CrewAgentExecutor(CrewAgentExecutorMixin):
|
class CrewAgentExecutor(CrewAgentExecutorMixin):
|
||||||
@@ -323,9 +323,9 @@ class CrewAgentExecutor(CrewAgentExecutorMixin):
|
|||||||
if self.crew is not None and hasattr(self.crew, "_train_iteration"):
|
if self.crew is not None and hasattr(self.crew, "_train_iteration"):
|
||||||
train_iteration = self.crew._train_iteration
|
train_iteration = self.crew._train_iteration
|
||||||
if agent_id in training_data and isinstance(train_iteration, int):
|
if agent_id in training_data and isinstance(train_iteration, int):
|
||||||
training_data[agent_id][train_iteration]["improved_output"] = (
|
training_data[agent_id][train_iteration][
|
||||||
result.output
|
"improved_output"
|
||||||
)
|
] = result.output
|
||||||
training_handler.save(training_data)
|
training_handler.save(training_data)
|
||||||
else:
|
else:
|
||||||
self._logger.log(
|
self._logger.log(
|
||||||
|
|||||||
@@ -6,14 +6,13 @@ from difflib import SequenceMatcher
|
|||||||
from textwrap import dedent
|
from textwrap import dedent
|
||||||
from typing import Any, List, Union
|
from typing import Any, List, Union
|
||||||
|
|
||||||
|
import crewai.utilities.events as events
|
||||||
from crewai.agents.tools_handler import ToolsHandler
|
from crewai.agents.tools_handler import ToolsHandler
|
||||||
from crewai.task import Task
|
from crewai.task import Task
|
||||||
from crewai.telemetry import Telemetry
|
from crewai.telemetry import Telemetry
|
||||||
from crewai.tools.tool_calling import InstructorToolCalling, ToolCalling
|
from crewai.tools.tool_calling import InstructorToolCalling, ToolCalling
|
||||||
from crewai.tools.tool_usage_events import ToolUsageError, ToolUsageFinished
|
from crewai.tools.tool_usage_events import ToolUsageError, ToolUsageFinished
|
||||||
from crewai.utilities import I18N, Converter, ConverterError, Printer
|
from crewai.utilities import I18N, Converter, ConverterError, Printer
|
||||||
import crewai.utilities.events as events
|
|
||||||
|
|
||||||
|
|
||||||
agentops = None
|
agentops = None
|
||||||
if os.environ.get("AGENTOPS_API_KEY"):
|
if os.environ.get("AGENTOPS_API_KEY"):
|
||||||
@@ -300,8 +299,11 @@ class ToolUsage:
|
|||||||
descriptions = []
|
descriptions = []
|
||||||
for tool in self.tools:
|
for tool in self.tools:
|
||||||
args = {
|
args = {
|
||||||
k: {k2: v2 for k2, v2 in v.items() if k2 in ["description", "type"]}
|
name: {
|
||||||
for k, v in tool.args.items()
|
"description": field.description,
|
||||||
|
"type": field.annotation.__name__,
|
||||||
|
}
|
||||||
|
for name, field in tool.args_schema.model_fields.items()
|
||||||
}
|
}
|
||||||
descriptions.append(
|
descriptions.append(
|
||||||
"\n".join(
|
"\n".join(
|
||||||
|
|||||||
143
tests/tools/test_tool_usage.py
Normal file
143
tests/tools/test_tool_usage.py
Normal file
@@ -0,0 +1,143 @@
|
|||||||
|
import json
|
||||||
|
import random
|
||||||
|
from unittest.mock import MagicMock, patch
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
from crewai_tools import BaseTool
|
||||||
|
from pydantic import BaseModel, Field
|
||||||
|
|
||||||
|
from crewai import Agent, Crew, Task
|
||||||
|
from crewai.tools.tool_usage import ToolUsage
|
||||||
|
|
||||||
|
|
||||||
|
class RandomNumberToolInput(BaseModel):
|
||||||
|
min_value: int = Field(
|
||||||
|
..., description="The minimum value of the range (inclusive)"
|
||||||
|
)
|
||||||
|
max_value: int = Field(
|
||||||
|
..., description="The maximum value of the range (inclusive)"
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class RandomNumberTool(BaseTool):
|
||||||
|
name: str = "Random Number Generator"
|
||||||
|
description: str = "Generates a random number within a specified range"
|
||||||
|
args_schema: type[BaseModel] = RandomNumberToolInput
|
||||||
|
|
||||||
|
def _run(self, min_value: int, max_value: int) -> int:
|
||||||
|
return random.randint(min_value, max_value)
|
||||||
|
|
||||||
|
|
||||||
|
# Example agent and task
|
||||||
|
example_agent = Agent(
|
||||||
|
role="Number Generator",
|
||||||
|
goal="Generate random numbers for various purposes",
|
||||||
|
backstory="You are an AI agent specialized in generating random numbers within specified ranges.",
|
||||||
|
tools=[RandomNumberTool()],
|
||||||
|
verbose=True,
|
||||||
|
)
|
||||||
|
|
||||||
|
example_task = Task(
|
||||||
|
description="Generate a random number between 1 and 100",
|
||||||
|
expected_output="A random number between 1 and 100",
|
||||||
|
agent=example_agent,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def test_random_number_tool_usage():
|
||||||
|
crew = Crew(
|
||||||
|
agents=[example_agent],
|
||||||
|
tasks=[example_task],
|
||||||
|
)
|
||||||
|
|
||||||
|
with patch.object(random, "randint", return_value=42):
|
||||||
|
result = crew.kickoff()
|
||||||
|
|
||||||
|
assert "42" in result.raw
|
||||||
|
|
||||||
|
|
||||||
|
def test_random_number_tool_range():
|
||||||
|
tool = RandomNumberTool()
|
||||||
|
result = tool._run(1, 10)
|
||||||
|
assert 1 <= result <= 10
|
||||||
|
|
||||||
|
|
||||||
|
def test_random_number_tool_with_crew():
|
||||||
|
crew = Crew(
|
||||||
|
agents=[example_agent],
|
||||||
|
tasks=[example_task],
|
||||||
|
)
|
||||||
|
|
||||||
|
result = crew.kickoff()
|
||||||
|
|
||||||
|
# Check if the result contains a number between 1 and 100
|
||||||
|
assert any(str(num) in result.raw for num in range(1, 101))
|
||||||
|
|
||||||
|
|
||||||
|
def test_random_number_tool_invalid_range():
|
||||||
|
tool = RandomNumberTool()
|
||||||
|
with pytest.raises(ValueError):
|
||||||
|
tool._run(10, 1) # min_value > max_value
|
||||||
|
|
||||||
|
|
||||||
|
def test_random_number_tool_schema():
|
||||||
|
tool = RandomNumberTool()
|
||||||
|
|
||||||
|
# Get the schema using model_json_schema()
|
||||||
|
schema = tool.args_schema.model_json_schema()
|
||||||
|
|
||||||
|
# Convert the schema to a string
|
||||||
|
schema_str = json.dumps(schema)
|
||||||
|
|
||||||
|
# Check if the schema string contains the expected fields
|
||||||
|
assert "min_value" in schema_str
|
||||||
|
assert "max_value" in schema_str
|
||||||
|
|
||||||
|
# Parse the schema string back to a dictionary
|
||||||
|
schema_dict = json.loads(schema_str)
|
||||||
|
|
||||||
|
# Check if the schema contains the correct field types
|
||||||
|
assert schema_dict["properties"]["min_value"]["type"] == "integer"
|
||||||
|
assert schema_dict["properties"]["max_value"]["type"] == "integer"
|
||||||
|
|
||||||
|
# Check if the schema contains the field descriptions
|
||||||
|
assert (
|
||||||
|
"minimum value" in schema_dict["properties"]["min_value"]["description"].lower()
|
||||||
|
)
|
||||||
|
assert (
|
||||||
|
"maximum value" in schema_dict["properties"]["max_value"]["description"].lower()
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def test_tool_usage_render():
|
||||||
|
tool = RandomNumberTool()
|
||||||
|
|
||||||
|
tool_usage = ToolUsage(
|
||||||
|
tools_handler=MagicMock(),
|
||||||
|
tools=[tool],
|
||||||
|
original_tools=[tool],
|
||||||
|
tools_description="Sample tool for testing",
|
||||||
|
tools_names="random_number_generator",
|
||||||
|
task=MagicMock(),
|
||||||
|
function_calling_llm=MagicMock(),
|
||||||
|
agent=MagicMock(),
|
||||||
|
action=MagicMock(),
|
||||||
|
)
|
||||||
|
|
||||||
|
rendered = tool_usage._render()
|
||||||
|
|
||||||
|
# Updated checks to match the actual output
|
||||||
|
assert "Tool Name: random number generator" in rendered
|
||||||
|
assert (
|
||||||
|
"Random Number Generator(min_value: 'integer', max_value: 'integer') - Generates a random number within a specified range min_value: 'The minimum value of the range (inclusive)', max_value: 'The maximum value of the range (inclusive)'"
|
||||||
|
in rendered
|
||||||
|
)
|
||||||
|
assert "Tool Arguments:" in rendered
|
||||||
|
assert (
|
||||||
|
"'min_value': {'description': 'The minimum value of the range (inclusive)', 'type': 'int'}"
|
||||||
|
in rendered
|
||||||
|
)
|
||||||
|
assert (
|
||||||
|
"'max_value': {'description': 'The maximum value of the range (inclusive)', 'type': 'int'}"
|
||||||
|
in rendered
|
||||||
|
)
|
||||||
Reference in New Issue
Block a user