mirror of
https://github.com/crewAIInc/crewAI.git
synced 2026-05-04 08:42:38 +00:00
Compare commits
4 Commits
devin/1747
...
devin/1747
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
aa1b915209 | ||
|
|
8483d1c772 | ||
|
|
aac875508d | ||
|
|
fed397f745 |
@@ -20,12 +20,12 @@ from crewai.tools.agent_tools.agent_tools import AgentTools
|
||||
from crewai.utilities import Converter, Prompts
|
||||
from crewai.utilities.agent_utils import (
|
||||
get_tool_names,
|
||||
load_agent_from_repository,
|
||||
parse_tools,
|
||||
render_text_description_and_args,
|
||||
)
|
||||
from crewai.utilities.constants import TRAINED_AGENTS_DATA_FILE, TRAINING_DATA_FILE
|
||||
from crewai.utilities.converter import generate_model_description
|
||||
from crewai.utilities.errors import AgentRepositoryError
|
||||
from crewai.utilities.events.agent_events import (
|
||||
AgentExecutionCompletedEvent,
|
||||
AgentExecutionErrorEvent,
|
||||
@@ -143,42 +143,9 @@ class Agent(BaseAgent):
|
||||
@model_validator(mode="before")
|
||||
def validate_from_repository(cls, v):
|
||||
if v is not None and (from_repository := v.get("from_repository")):
|
||||
return cls._load_agent_from_repository(from_repository) | v
|
||||
return load_agent_from_repository(from_repository) | v
|
||||
return v
|
||||
|
||||
@classmethod
|
||||
def _load_agent_from_repository(cls, from_repository: str) -> Dict[str, Any]:
|
||||
attributes: Dict[str, Any] = {}
|
||||
if from_repository:
|
||||
import importlib
|
||||
|
||||
from crewai.cli.authentication.token import get_auth_token
|
||||
from crewai.cli.plus_api import PlusAPI
|
||||
|
||||
client = PlusAPI(api_key=get_auth_token())
|
||||
response = client.get_agent(from_repository)
|
||||
if response.status_code != 200:
|
||||
raise AgentRepositoryError(
|
||||
f"Agent {from_repository} could not be loaded: {response.text}"
|
||||
)
|
||||
|
||||
agent = response.json()
|
||||
for key, value in agent.items():
|
||||
if key == "tools":
|
||||
attributes[key] = []
|
||||
for tool_name in value:
|
||||
try:
|
||||
module = importlib.import_module("crewai_tools")
|
||||
tool_class = getattr(module, tool_name)
|
||||
attributes[key].append(tool_class())
|
||||
except Exception as e:
|
||||
raise AgentRepositoryError(
|
||||
f"Tool {tool_name} could not be loaded: {e}"
|
||||
) from e
|
||||
else:
|
||||
attributes[key] = value
|
||||
return attributes
|
||||
|
||||
@model_validator(mode="after")
|
||||
def post_init_setup(self):
|
||||
self.agent_ops_agent_name = self.role
|
||||
|
||||
@@ -173,11 +173,18 @@ class CrewStructuredTool:
|
||||
def _parse_args(self, raw_args: Union[str, dict]) -> dict:
|
||||
"""Parse and validate the input arguments against the schema.
|
||||
|
||||
This method handles different input formats from various LLM providers,
|
||||
including nested dictionaries with 'value' fields that some providers use.
|
||||
|
||||
Args:
|
||||
raw_args: The raw arguments to parse, either as a string or dict
|
||||
raw_args: The raw arguments to parse, either as a string or dict.
|
||||
Supports nested dictionaries with 'value' field for LLM provider compatibility.
|
||||
|
||||
Returns:
|
||||
The validated arguments as a dictionary
|
||||
|
||||
Raises:
|
||||
ValueError: If argument parsing or validation fails
|
||||
"""
|
||||
if isinstance(raw_args, str):
|
||||
try:
|
||||
@@ -187,6 +194,31 @@ class CrewStructuredTool:
|
||||
except json.JSONDecodeError as e:
|
||||
raise ValueError(f"Failed to parse arguments as JSON: {e}")
|
||||
|
||||
# Handle nested dictionaries with 'value' field for all parameter types
|
||||
if isinstance(raw_args, dict):
|
||||
schema_fields = self.args_schema.model_fields
|
||||
|
||||
for field_name, field_value in list(raw_args.items()):
|
||||
# Check if this field exists in the schema
|
||||
if field_name in schema_fields:
|
||||
# Handle nested dictionaries with 'value' field
|
||||
if isinstance(field_value, dict):
|
||||
if 'value' in field_value:
|
||||
# Extract the value from the nested dictionary
|
||||
value = field_value['value']
|
||||
self._logger.debug(f"Extracting value from nested dict for {field_name}")
|
||||
|
||||
expected_type = schema_fields[field_name].annotation
|
||||
|
||||
if expected_type in (str, int, float, bool) and not isinstance(value, expected_type):
|
||||
self._logger.warning(
|
||||
f"Type mismatch for {field_name}: expected {expected_type}, got {type(value)}"
|
||||
)
|
||||
|
||||
raw_args[field_name] = value
|
||||
else:
|
||||
self._logger.debug(f"Nested dict for {field_name} has no 'value' key")
|
||||
|
||||
try:
|
||||
validated_args = self.args_schema.model_validate(raw_args)
|
||||
return validated_args.model_dump()
|
||||
|
||||
@@ -16,6 +16,7 @@ from crewai.tools.base_tool import BaseTool
|
||||
from crewai.tools.structured_tool import CrewStructuredTool
|
||||
from crewai.tools.tool_types import ToolResult
|
||||
from crewai.utilities import I18N, Printer
|
||||
from crewai.utilities.errors import AgentRepositoryError
|
||||
from crewai.utilities.exceptions.context_window_exceeding_exception import (
|
||||
LLMContextLengthExceededException,
|
||||
)
|
||||
@@ -428,3 +429,36 @@ def show_agent_logs(
|
||||
printer.print(
|
||||
content=f"\033[95m## Final Answer:\033[00m \033[92m\n{formatted_answer.output}\033[00m\n\n"
|
||||
)
|
||||
|
||||
|
||||
def load_agent_from_repository(from_repository: str) -> Dict[str, Any]:
|
||||
attributes: Dict[str, Any] = {}
|
||||
if from_repository:
|
||||
import importlib
|
||||
|
||||
from crewai.cli.authentication.token import get_auth_token
|
||||
from crewai.cli.plus_api import PlusAPI
|
||||
|
||||
client = PlusAPI(api_key=get_auth_token())
|
||||
response = client.get_agent(from_repository)
|
||||
if response.status_code != 200:
|
||||
raise AgentRepositoryError(
|
||||
f"Agent {from_repository} could not be loaded: {response.text}"
|
||||
)
|
||||
|
||||
agent = response.json()
|
||||
for key, value in agent.items():
|
||||
if key == "tools":
|
||||
attributes[key] = []
|
||||
for tool_name in value:
|
||||
try:
|
||||
module = importlib.import_module("crewai_tools")
|
||||
tool_class = getattr(module, tool_name)
|
||||
attributes[key].append(tool_class())
|
||||
except Exception as e:
|
||||
raise AgentRepositoryError(
|
||||
f"Tool {tool_name} could not be loaded: {e}"
|
||||
) from e
|
||||
else:
|
||||
attributes[key] = value
|
||||
return attributes
|
||||
|
||||
173
tests/tools/test_structured_tool_nested_dict.py
Normal file
173
tests/tools/test_structured_tool_nested_dict.py
Normal file
@@ -0,0 +1,173 @@
|
||||
import pytest
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from crewai.tools.structured_tool import CrewStructuredTool
|
||||
|
||||
|
||||
class StringInputSchema(BaseModel):
|
||||
"""Schema with a string input field."""
|
||||
query: str = Field(description="A string input parameter")
|
||||
|
||||
|
||||
class IntInputSchema(BaseModel):
|
||||
"""Schema with an integer input field."""
|
||||
number: int = Field(description="An integer input parameter")
|
||||
|
||||
|
||||
class ComplexInputSchema(BaseModel):
|
||||
"""Schema with multiple fields of different types."""
|
||||
text: str = Field(description="A string parameter")
|
||||
number: int = Field(description="An integer parameter")
|
||||
flag: bool = Field(description="A boolean parameter")
|
||||
|
||||
|
||||
def test_parse_args_with_string_input():
|
||||
"""Test that string inputs are parsed correctly."""
|
||||
def test_func(query: str) -> str:
|
||||
return f"Processed: {query}"
|
||||
|
||||
tool = CrewStructuredTool.from_function(
|
||||
func=test_func,
|
||||
name="StringTool",
|
||||
description="A tool that processes string input"
|
||||
)
|
||||
|
||||
# Test with direct string input
|
||||
result = tool._parse_args({"query": "test string"})
|
||||
assert result["query"] == "test string"
|
||||
assert isinstance(result["query"], str)
|
||||
|
||||
# Test with JSON string input
|
||||
result = tool._parse_args('{"query": "json string"}')
|
||||
assert result["query"] == "json string"
|
||||
assert isinstance(result["query"], str)
|
||||
|
||||
|
||||
def test_parse_args_with_nested_dict_for_string():
|
||||
"""Test that nested dictionaries with 'value' field are handled correctly for string fields."""
|
||||
def test_func(query: str) -> str:
|
||||
return f"Processed: {query}"
|
||||
|
||||
tool = CrewStructuredTool.from_function(
|
||||
func=test_func,
|
||||
name="StringTool",
|
||||
description="A tool that processes string input"
|
||||
)
|
||||
|
||||
# Test with nested dict input (simulating the issue from different LLM providers)
|
||||
nested_input = {"query": {"description": "A string input parameter", "value": "test value"}}
|
||||
result = tool._parse_args(nested_input)
|
||||
assert result["query"] == "test value"
|
||||
assert isinstance(result["query"], str)
|
||||
|
||||
|
||||
def test_parse_args_with_nested_dict_for_int():
|
||||
"""Test that nested dictionaries with 'value' field are handled correctly for int fields."""
|
||||
def test_func(number: int) -> str:
|
||||
return f"Processed: {number}"
|
||||
|
||||
tool = CrewStructuredTool.from_function(
|
||||
func=test_func,
|
||||
name="IntTool",
|
||||
description="A tool that processes integer input"
|
||||
)
|
||||
|
||||
# Test with nested dict input for int field
|
||||
nested_input = {"number": {"description": "An integer input parameter", "value": 42}}
|
||||
result = tool._parse_args(nested_input)
|
||||
assert result["number"] == 42
|
||||
assert isinstance(result["number"], int)
|
||||
|
||||
|
||||
def test_parse_args_with_complex_input():
|
||||
"""Test that complex inputs with multiple fields are handled correctly."""
|
||||
def test_func(text: str, number: int, flag: bool) -> str:
|
||||
return f"Processed: {text}, {number}, {flag}"
|
||||
|
||||
tool = CrewStructuredTool.from_function(
|
||||
func=test_func,
|
||||
name="ComplexTool",
|
||||
description="A tool that processes complex input"
|
||||
)
|
||||
|
||||
# Test with mixed nested dict input
|
||||
complex_input = {
|
||||
"text": {"description": "A string parameter", "value": "test text"},
|
||||
"number": 42,
|
||||
"flag": True
|
||||
}
|
||||
result = tool._parse_args(complex_input)
|
||||
assert result["text"] == "test text"
|
||||
assert isinstance(result["text"], str)
|
||||
assert result["number"] == 42
|
||||
assert isinstance(result["number"], int)
|
||||
assert result["flag"] is True
|
||||
assert isinstance(result["flag"], bool)
|
||||
|
||||
|
||||
def test_invoke_with_nested_dict():
|
||||
"""Test that invoking a tool with nested dict input works correctly."""
|
||||
def test_func(query: str) -> str:
|
||||
return f"Processed: {query}"
|
||||
|
||||
tool = CrewStructuredTool.from_function(
|
||||
func=test_func,
|
||||
name="StringTool",
|
||||
description="A tool that processes string input"
|
||||
)
|
||||
|
||||
# Test invoking with nested dict input
|
||||
nested_input = {"query": {"description": "A string input parameter", "value": "test value"}}
|
||||
result = tool.invoke(nested_input)
|
||||
assert result == "Processed: test value"
|
||||
|
||||
|
||||
def test_nested_dict_without_value_key():
|
||||
"""Test that nested dictionaries without 'value' field raise appropriate errors."""
|
||||
def test_func(query: str) -> str:
|
||||
return f"Processed: {query}"
|
||||
|
||||
tool = CrewStructuredTool.from_function(
|
||||
func=test_func,
|
||||
name="StringTool",
|
||||
description="A tool that processes string input"
|
||||
)
|
||||
|
||||
# Test with nested dict without 'value' key
|
||||
invalid_input = {"query": {"description": "A string input parameter", "other_key": "test"}}
|
||||
with pytest.raises(ValueError):
|
||||
tool._parse_args(invalid_input)
|
||||
|
||||
|
||||
def test_empty_nested_dict():
|
||||
"""Test handling of empty nested dictionaries."""
|
||||
def test_func(query: str) -> str:
|
||||
return f"Processed: {query}"
|
||||
|
||||
tool = CrewStructuredTool.from_function(
|
||||
func=test_func,
|
||||
name="StringTool",
|
||||
description="A tool that processes string input"
|
||||
)
|
||||
|
||||
# Test with empty nested dict
|
||||
empty_dict_input = {"query": {}}
|
||||
with pytest.raises(ValueError):
|
||||
tool._parse_args(empty_dict_input)
|
||||
|
||||
|
||||
def test_deeply_nested_structure():
|
||||
"""Test handling of deeply nested structures."""
|
||||
def test_func(query: str) -> str:
|
||||
return f"Processed: {query}"
|
||||
|
||||
tool = CrewStructuredTool.from_function(
|
||||
func=test_func,
|
||||
name="StringTool",
|
||||
description="A tool that processes string input"
|
||||
)
|
||||
|
||||
# Test with deeply nested structure
|
||||
deeply_nested = {"query": {"nested": {"deeper": {"value": "deep value"}}}}
|
||||
with pytest.raises(ValueError):
|
||||
tool._parse_args(deeply_nested)
|
||||
Reference in New Issue
Block a user