Compare commits

...

2 Commits

Author SHA1 Message Date
Devin AI
1b75090dc2 Address code review feedback: enhance sanitize_tools with better documentation, error handling, and type validation
- Add comprehensive docstring with Args, Returns, and Example sections
- Implement try/except error handling with logging for unexpected scenarios
- Add stronger type validation for dictionary values
- Include logging for debugging when non-dict objects are filtered
- Add type annotations for better maintainability and IDE support
- Add parameterized tests for better coverage and organization
- Add performance tests for large datasets
- Add tests for invalid dict value types and error handling scenarios

Addresses feedback from joaomdmoura and mplachta on PR #3044

Co-Authored-By: João <joao@crewai.com>
2025-06-21 16:38:19 +00:00
Devin AI
b2bda39e56 Fix Pydantic validation error in LLMCallStartedEvent when TokenCalcHandler is in the tools list
- Add model_validator to sanitize tools list before validation
- Filter out non-dict objects like TokenCalcHandler from tools list
- Preserve dict tools while removing problematic objects
- Add comprehensive test coverage for the fix and edge cases
- Resolves issue #3043

Co-Authored-By: João <joao@crewai.com>
2025-06-21 16:33:22 +00:00
2 changed files with 246 additions and 2 deletions
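
For context, here is a minimal reproduction sketch of the failure the fix commit addresses. The imports mirror the test file below; this assumes a crewai checkout that includes this patch. Before the model_validator was added, the same construction raised a pydantic ValidationError (issue #3043).

# Minimal reproduction sketch, assuming crewai with this patch installed.
from crewai.utilities.events.llm_events import LLMCallStartedEvent
from crewai.utilities.token_counter_callback import TokenCalcHandler
from crewai.agents.agent_builder.utilities.base_token_process import TokenProcess

token_handler = TokenCalcHandler(TokenProcess())
event = LLMCallStartedEvent(
    messages=[{"role": "user", "content": "test message"}],
    tools=[{"name": "tool1"}, token_handler],  # token_handler is not a dict
)
print(event.tools)  # expected: [{'name': 'tool1'}]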


@@ -1,10 +1,13 @@
+import logging
 from enum import Enum
-from typing import Any, Dict, List, Optional, Union
+from typing import Any, Dict, List, Optional, Union, Type
-from pydantic import BaseModel
+from pydantic import BaseModel, model_validator
 from crewai.utilities.events.base_events import BaseEvent
+
+logger = logging.getLogger(__name__)
class LLMCallType(Enum):
"""Type of LLM call being made"""
@@ -27,6 +30,44 @@ class LLMCallStartedEvent(BaseEvent):
     callbacks: Optional[List[Any]] = None
     available_functions: Optional[Dict[str, Any]] = None

+    @model_validator(mode='before')
+    @classmethod
+    def sanitize_tools(cls: Type["LLMCallStartedEvent"], values: Any) -> Any:
+        """Sanitize the tools list to only include dict objects, filtering out
+        non-dict objects like TokenCalcHandler.
+
+        Args:
+            values (dict): Input values dictionary containing tools and other event data.
+
+        Returns:
+            dict: Sanitized values with a filtered tools list containing only valid dict objects.
+
+        Example:
+            >>> from crewai.utilities.token_counter_callback import TokenCalcHandler
+            >>> from crewai.agents.agent_builder.utilities.base_token_process import TokenProcess
+            >>> token_handler = TokenCalcHandler(TokenProcess())
+            >>> tools = [{"name": "tool1"}, token_handler, {"name": "tool2"}]
+            >>> sanitized = LLMCallStartedEvent.sanitize_tools({"tools": tools})
+            >>> # Expected: {"tools": [{"name": "tool1"}, {"name": "tool2"}]}
+        """
+        try:
+            if isinstance(values, dict) and 'tools' in values and values['tools'] is not None:
+                if isinstance(values['tools'], list):
+                    sanitized_tools = []
+                    for tool in values['tools']:
+                        if isinstance(tool, dict):
+                            if all(isinstance(v, (str, int, float, bool, dict, list, type(None))) for v in tool.values()):
+                                sanitized_tools.append(tool)
+                            else:
+                                logger.warning(f"Tool dict contains invalid value types: {tool}")
+                        else:
+                            logger.debug(f"Filtering out non-dict tool object: {type(tool).__name__}")
+                    values['tools'] = sanitized_tools
+        except Exception as e:
+            logger.warning(f"Error during tools sanitization: {e}")
+        return values

 class LLMCallCompletedEvent(BaseEvent):
     """Event emitted when a LLM call completes"""


@@ -0,0 +1,203 @@
import pytest
import logging
from crewai.utilities.events.llm_events import LLMCallStartedEvent
from crewai.utilities.token_counter_callback import TokenCalcHandler
from crewai.agents.agent_builder.utilities.base_token_process import TokenProcess


class TestLLMCallStartedEventValidation:
    """Test cases for LLMCallStartedEvent validation and sanitization"""

    def test_normal_dict_tools_work(self):
        """Test that normal dict tools work correctly"""
        event = LLMCallStartedEvent(
            messages=[{"role": "user", "content": "test message"}],
            tools=[{"name": "tool1"}, {"name": "tool2"}],
            callbacks=None
        )
        assert event.tools == [{"name": "tool1"}, {"name": "tool2"}]
        assert event.type == "llm_call_started"

    def test_token_calc_handler_in_tools_filtered_out(self):
        """Test that TokenCalcHandler objects in tools list are filtered out"""
        token_handler = TokenCalcHandler(TokenProcess())
        event = LLMCallStartedEvent(
            messages=[{"role": "user", "content": "test message"}],
            tools=[{"name": "tool1"}, token_handler, {"name": "tool2"}],
            callbacks=None
        )
        assert event.tools == [{"name": "tool1"}, {"name": "tool2"}]
        assert len(event.tools) == 2

    def test_mixed_objects_in_tools_only_dicts_preserved(self):
        """Test that only dict objects are preserved when mixed types are in tools"""
        token_handler = TokenCalcHandler(TokenProcess())
        event = LLMCallStartedEvent(
            messages=[{"role": "user", "content": "test message"}],
            tools=[
                {"name": "tool1"},
                token_handler,
                "string_tool",
                {"name": "tool2"},
                123,
                {"name": "tool3"}
            ],
            callbacks=None
        )
        assert event.tools == [{"name": "tool1"}, {"name": "tool2"}, {"name": "tool3"}]
        assert len(event.tools) == 3

    def test_empty_tools_list_handled(self):
        """Test that empty tools list is handled correctly"""
        event = LLMCallStartedEvent(
            messages=[{"role": "user", "content": "test message"}],
            tools=[],
            callbacks=None
        )
        assert event.tools == []

    def test_none_tools_handled(self):
        """Test that None tools value is handled correctly"""
        event = LLMCallStartedEvent(
            messages=[{"role": "user", "content": "test message"}],
            tools=None,
            callbacks=None
        )
        assert event.tools is None

    def test_all_non_dict_tools_results_in_empty_list(self):
        """Test that when all tools are non-dict objects, result is empty list"""
        token_handler = TokenCalcHandler(TokenProcess())
        event = LLMCallStartedEvent(
            messages=[{"role": "user", "content": "test message"}],
            tools=[token_handler, "string_tool", 123, ["list_tool"]],
            callbacks=None
        )
        assert event.tools == []

    def test_reproduction_case_from_issue_3043(self):
        """Test the exact reproduction case from GitHub issue #3043"""
        token_handler = TokenCalcHandler(TokenProcess())
        event = LLMCallStartedEvent(
            messages=[{"role": "user", "content": "test message"}],
            tools=[{"name": "tool1"}, token_handler],
            callbacks=None
        )
        assert event.tools == [{"name": "tool1"}]
        assert len(event.tools) == 1

    def test_callbacks_with_token_handler_still_work(self):
        """Test that TokenCalcHandler in callbacks still works normally"""
        token_handler = TokenCalcHandler(TokenProcess())
        event = LLMCallStartedEvent(
            messages=[{"role": "user", "content": "test message"}],
            tools=[{"name": "tool1"}],
            callbacks=[token_handler]
        )
        assert event.tools == [{"name": "tool1"}]
        assert event.callbacks == [token_handler]

    def test_string_messages_work(self):
        """Test that string messages work with tool sanitization"""
        token_handler = TokenCalcHandler(TokenProcess())
        event = LLMCallStartedEvent(
            messages="test message",
            tools=[{"name": "tool1"}, token_handler],
            callbacks=None
        )
        assert event.messages == "test message"
        assert event.tools == [{"name": "tool1"}]

    def test_available_functions_preserved(self):
        """Test that available_functions are preserved during sanitization"""
        token_handler = TokenCalcHandler(TokenProcess())
        available_funcs = {"func1": lambda x: x}
        event = LLMCallStartedEvent(
            messages=[{"role": "user", "content": "test message"}],
            tools=[{"name": "tool1"}, token_handler],
            callbacks=None,
            available_functions=available_funcs
        )
        assert event.tools == [{"name": "tool1"}]
        assert event.available_functions == available_funcs

    @pytest.mark.parametrize("tools_input,expected", [
        ([{"name": "tool1"}, TokenCalcHandler(TokenProcess())], [{"name": "tool1"}]),
        ([{"name": "tool1"}, "string_tool", {"name": "tool2"}], [{"name": "tool1"}, {"name": "tool2"}]),
        ([TokenCalcHandler(TokenProcess()), 123, ["list_tool"]], []),
        ([{"name": "tool1", "type": "function", "enabled": True}], [{"name": "tool1", "type": "function", "enabled": True}]),
        ([], []),
        (None, None),
    ])
    def test_tools_sanitization_parameterized(self, tools_input, expected):
        """Parameterized test for various tools sanitization scenarios"""
        event = LLMCallStartedEvent(
            messages=[{"role": "user", "content": "test message"}],
            tools=tools_input,
            callbacks=None
        )
        assert event.tools == expected

    def test_tools_with_invalid_dict_values_filtered(self):
        """Test that dicts with invalid value types are filtered out"""
        class CustomObject:
            pass

        invalid_tool = {"name": "tool1", "custom_obj": CustomObject()}
        valid_tool = {"name": "tool2", "type": "function"}
        event = LLMCallStartedEvent(
            messages=[{"role": "user", "content": "test message"}],
            tools=[valid_tool, invalid_tool],
            callbacks=None
        )
        assert event.tools == [valid_tool]

    def test_sanitize_tools_performance_large_dataset(self):
        """Test sanitization performance with large datasets"""
        token_handler = TokenCalcHandler(TokenProcess())
        large_tools_list = []
        for i in range(1000):
            if i % 3 == 0:
                large_tools_list.append({"name": f"tool_{i}", "type": "function"})
            elif i % 3 == 1:
                large_tools_list.append(token_handler)
            else:
                large_tools_list.append(f"string_tool_{i}")
        event = LLMCallStartedEvent(
            messages=[{"role": "user", "content": "test message"}],
            tools=large_tools_list,
            callbacks=None
        )
        expected_count = len([i for i in range(1000) if i % 3 == 0])
        assert len(event.tools) == expected_count
        assert all(isinstance(tool, dict) for tool in event.tools)

    def test_sanitization_error_handling(self, caplog):
        """Test that sanitization errors are handled gracefully"""
        with caplog.at_level(logging.WARNING):
            event = LLMCallStartedEvent(
                messages=[{"role": "user", "content": "test message"}],
                tools=[{"name": "tool1"}, TokenCalcHandler(TokenProcess())],
                callbacks=None
            )
        assert event.tools == [{"name": "tool1"}]
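
To run just this suite from a checkout (assuming pytest collects the new test file), something like:

pytest -q -k TestLLMCallStartedEventValidation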