mirror of
https://github.com/crewAIInc/crewAI.git
synced 2025-12-16 12:28:30 +00:00
Compare commits
2 Commits
devin/1761
...
devin/1750
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
1b75090dc2 | ||
|
|
b2bda39e56 |
@@ -1,10 +1,13 @@
|
||||
import logging
|
||||
from enum import Enum
|
||||
from typing import Any, Dict, List, Optional, Union
|
||||
from typing import Any, Dict, List, Optional, Union, Type
|
||||
|
||||
from pydantic import BaseModel
|
||||
from pydantic import BaseModel, model_validator
|
||||
|
||||
from crewai.utilities.events.base_events import BaseEvent
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class LLMCallType(Enum):
|
||||
"""Type of LLM call being made"""
|
||||
@@ -27,6 +30,44 @@ class LLMCallStartedEvent(BaseEvent):
|
||||
callbacks: Optional[List[Any]] = None
|
||||
available_functions: Optional[Dict[str, Any]] = None
|
||||
|
||||
@model_validator(mode='before')
|
||||
@classmethod
|
||||
def sanitize_tools(cls: Type["LLMCallStartedEvent"], values: Any) -> Any:
|
||||
"""Sanitize tools list to only include dict objects, filtering out non-dict objects like TokenCalcHandler.
|
||||
|
||||
Args:
|
||||
values (dict): Input values dictionary containing tools and other event data.
|
||||
|
||||
Returns:
|
||||
dict: Sanitized values with filtered tools list containing only valid dict objects.
|
||||
|
||||
Example:
|
||||
>>> from crewai.utilities.token_counter_callback import TokenCalcHandler
|
||||
>>> from crewai.agents.agent_builder.utilities.base_token_process import TokenProcess
|
||||
>>> token_handler = TokenCalcHandler(TokenProcess())
|
||||
>>> tools = [{"name": "tool1"}, token_handler, {"name": "tool2"}]
|
||||
>>> sanitized = cls.sanitize_tools({"tools": tools})
|
||||
>>> # Expected: {"tools": [{"name": "tool1"}, {"name": "tool2"}]}
|
||||
"""
|
||||
try:
|
||||
if isinstance(values, dict) and 'tools' in values and values['tools'] is not None:
|
||||
if isinstance(values['tools'], list):
|
||||
sanitized_tools = []
|
||||
for tool in values['tools']:
|
||||
if isinstance(tool, dict):
|
||||
if all(isinstance(v, (str, int, float, bool, dict, list, type(None))) for v in tool.values()):
|
||||
sanitized_tools.append(tool)
|
||||
else:
|
||||
logger.warning(f"Tool dict contains invalid value types: {tool}")
|
||||
else:
|
||||
logger.debug(f"Filtering out non-dict tool object: {type(tool).__name__}")
|
||||
|
||||
values['tools'] = sanitized_tools
|
||||
except Exception as e:
|
||||
logger.warning(f"Error during tools sanitization: {e}")
|
||||
|
||||
return values
|
||||
|
||||
|
||||
class LLMCallCompletedEvent(BaseEvent):
|
||||
"""Event emitted when a LLM call completes"""
|
||||
|
||||
203
tests/utilities/events/test_llm_events_validation.py
Normal file
203
tests/utilities/events/test_llm_events_validation.py
Normal file
@@ -0,0 +1,203 @@
|
||||
import pytest
|
||||
import logging
|
||||
from crewai.utilities.events.llm_events import LLMCallStartedEvent
|
||||
from crewai.utilities.token_counter_callback import TokenCalcHandler
|
||||
from crewai.agents.agent_builder.utilities.base_token_process import TokenProcess
|
||||
|
||||
|
||||
class TestLLMCallStartedEventValidation:
|
||||
"""Test cases for LLMCallStartedEvent validation and sanitization"""
|
||||
|
||||
def test_normal_dict_tools_work(self):
|
||||
"""Test that normal dict tools work correctly"""
|
||||
event = LLMCallStartedEvent(
|
||||
messages=[{"role": "user", "content": "test message"}],
|
||||
tools=[{"name": "tool1"}, {"name": "tool2"}],
|
||||
callbacks=None
|
||||
)
|
||||
assert event.tools == [{"name": "tool1"}, {"name": "tool2"}]
|
||||
assert event.type == "llm_call_started"
|
||||
|
||||
def test_token_calc_handler_in_tools_filtered_out(self):
|
||||
"""Test that TokenCalcHandler objects in tools list are filtered out"""
|
||||
token_handler = TokenCalcHandler(TokenProcess())
|
||||
|
||||
event = LLMCallStartedEvent(
|
||||
messages=[{"role": "user", "content": "test message"}],
|
||||
tools=[{"name": "tool1"}, token_handler, {"name": "tool2"}],
|
||||
callbacks=None
|
||||
)
|
||||
|
||||
assert event.tools == [{"name": "tool1"}, {"name": "tool2"}]
|
||||
assert len(event.tools) == 2
|
||||
|
||||
def test_mixed_objects_in_tools_only_dicts_preserved(self):
|
||||
"""Test that only dict objects are preserved when mixed types are in tools"""
|
||||
token_handler = TokenCalcHandler(TokenProcess())
|
||||
|
||||
event = LLMCallStartedEvent(
|
||||
messages=[{"role": "user", "content": "test message"}],
|
||||
tools=[
|
||||
{"name": "tool1"},
|
||||
token_handler,
|
||||
"string_tool",
|
||||
{"name": "tool2"},
|
||||
123,
|
||||
{"name": "tool3"}
|
||||
],
|
||||
callbacks=None
|
||||
)
|
||||
|
||||
assert event.tools == [{"name": "tool1"}, {"name": "tool2"}, {"name": "tool3"}]
|
||||
assert len(event.tools) == 3
|
||||
|
||||
def test_empty_tools_list_handled(self):
|
||||
"""Test that empty tools list is handled correctly"""
|
||||
event = LLMCallStartedEvent(
|
||||
messages=[{"role": "user", "content": "test message"}],
|
||||
tools=[],
|
||||
callbacks=None
|
||||
)
|
||||
assert event.tools == []
|
||||
|
||||
def test_none_tools_handled(self):
|
||||
"""Test that None tools value is handled correctly"""
|
||||
event = LLMCallStartedEvent(
|
||||
messages=[{"role": "user", "content": "test message"}],
|
||||
tools=None,
|
||||
callbacks=None
|
||||
)
|
||||
assert event.tools is None
|
||||
|
||||
def test_all_non_dict_tools_results_in_empty_list(self):
|
||||
"""Test that when all tools are non-dict objects, result is empty list"""
|
||||
token_handler = TokenCalcHandler(TokenProcess())
|
||||
|
||||
event = LLMCallStartedEvent(
|
||||
messages=[{"role": "user", "content": "test message"}],
|
||||
tools=[token_handler, "string_tool", 123, ["list_tool"]],
|
||||
callbacks=None
|
||||
)
|
||||
|
||||
assert event.tools == []
|
||||
|
||||
def test_reproduction_case_from_issue_3043(self):
|
||||
"""Test the exact reproduction case from GitHub issue #3043"""
|
||||
token_handler = TokenCalcHandler(TokenProcess())
|
||||
|
||||
event = LLMCallStartedEvent(
|
||||
messages=[{"role": "user", "content": "test message"}],
|
||||
tools=[{"name": "tool1"}, token_handler],
|
||||
callbacks=None
|
||||
)
|
||||
|
||||
assert event.tools == [{"name": "tool1"}]
|
||||
assert len(event.tools) == 1
|
||||
|
||||
def test_callbacks_with_token_handler_still_work(self):
|
||||
"""Test that TokenCalcHandler in callbacks still works normally"""
|
||||
token_handler = TokenCalcHandler(TokenProcess())
|
||||
|
||||
event = LLMCallStartedEvent(
|
||||
messages=[{"role": "user", "content": "test message"}],
|
||||
tools=[{"name": "tool1"}],
|
||||
callbacks=[token_handler]
|
||||
)
|
||||
|
||||
assert event.tools == [{"name": "tool1"}]
|
||||
assert event.callbacks == [token_handler]
|
||||
|
||||
def test_string_messages_work(self):
|
||||
"""Test that string messages work with tool sanitization"""
|
||||
token_handler = TokenCalcHandler(TokenProcess())
|
||||
|
||||
event = LLMCallStartedEvent(
|
||||
messages="test message",
|
||||
tools=[{"name": "tool1"}, token_handler],
|
||||
callbacks=None
|
||||
)
|
||||
|
||||
assert event.messages == "test message"
|
||||
assert event.tools == [{"name": "tool1"}]
|
||||
|
||||
def test_available_functions_preserved(self):
|
||||
"""Test that available_functions are preserved during sanitization"""
|
||||
token_handler = TokenCalcHandler(TokenProcess())
|
||||
available_funcs = {"func1": lambda x: x}
|
||||
|
||||
event = LLMCallStartedEvent(
|
||||
messages=[{"role": "user", "content": "test message"}],
|
||||
tools=[{"name": "tool1"}, token_handler],
|
||||
callbacks=None,
|
||||
available_functions=available_funcs
|
||||
)
|
||||
|
||||
assert event.tools == [{"name": "tool1"}]
|
||||
assert event.available_functions == available_funcs
|
||||
|
||||
@pytest.mark.parametrize("tools_input,expected", [
|
||||
([{"name": "tool1"}, TokenCalcHandler(TokenProcess())], [{"name": "tool1"}]),
|
||||
([{"name": "tool1"}, "string_tool", {"name": "tool2"}], [{"name": "tool1"}, {"name": "tool2"}]),
|
||||
([TokenCalcHandler(TokenProcess()), 123, ["list_tool"]], []),
|
||||
([{"name": "tool1", "type": "function", "enabled": True}], [{"name": "tool1", "type": "function", "enabled": True}]),
|
||||
([], []),
|
||||
(None, None),
|
||||
])
|
||||
def test_tools_sanitization_parameterized(self, tools_input, expected):
|
||||
"""Parameterized test for various tools sanitization scenarios"""
|
||||
event = LLMCallStartedEvent(
|
||||
messages=[{"role": "user", "content": "test message"}],
|
||||
tools=tools_input,
|
||||
callbacks=None
|
||||
)
|
||||
assert event.tools == expected
|
||||
|
||||
def test_tools_with_invalid_dict_values_filtered(self):
|
||||
"""Test that dicts with invalid value types are filtered out"""
|
||||
class CustomObject:
|
||||
pass
|
||||
|
||||
invalid_tool = {"name": "tool1", "custom_obj": CustomObject()}
|
||||
valid_tool = {"name": "tool2", "type": "function"}
|
||||
|
||||
event = LLMCallStartedEvent(
|
||||
messages=[{"role": "user", "content": "test message"}],
|
||||
tools=[valid_tool, invalid_tool],
|
||||
callbacks=None
|
||||
)
|
||||
|
||||
assert event.tools == [valid_tool]
|
||||
|
||||
def test_sanitize_tools_performance_large_dataset(self):
|
||||
"""Test sanitization performance with large datasets"""
|
||||
token_handler = TokenCalcHandler(TokenProcess())
|
||||
|
||||
large_tools_list = []
|
||||
for i in range(1000):
|
||||
if i % 3 == 0:
|
||||
large_tools_list.append({"name": f"tool_{i}", "type": "function"})
|
||||
elif i % 3 == 1:
|
||||
large_tools_list.append(token_handler)
|
||||
else:
|
||||
large_tools_list.append(f"string_tool_{i}")
|
||||
|
||||
event = LLMCallStartedEvent(
|
||||
messages=[{"role": "user", "content": "test message"}],
|
||||
tools=large_tools_list,
|
||||
callbacks=None
|
||||
)
|
||||
|
||||
expected_count = len([i for i in range(1000) if i % 3 == 0])
|
||||
assert len(event.tools) == expected_count
|
||||
assert all(isinstance(tool, dict) for tool in event.tools)
|
||||
|
||||
def test_sanitization_error_handling(self, caplog):
|
||||
"""Test that sanitization errors are handled gracefully"""
|
||||
with caplog.at_level(logging.WARNING):
|
||||
event = LLMCallStartedEvent(
|
||||
messages=[{"role": "user", "content": "test message"}],
|
||||
tools=[{"name": "tool1"}, TokenCalcHandler(TokenProcess())],
|
||||
callbacks=None
|
||||
)
|
||||
|
||||
assert event.tools == [{"name": "tool1"}]
|
||||
Reference in New Issue
Block a user