diff --git a/src/crewai/utilities/events/llm_events.py b/src/crewai/utilities/events/llm_events.py
index 7df638e8b..82250c01d 100644
--- a/src/crewai/utilities/events/llm_events.py
+++ b/src/crewai/utilities/events/llm_events.py
@@ -1,10 +1,13 @@
+import logging
 from enum import Enum
-from typing import Any, Dict, List, Optional, Union
+from typing import Any, Dict, List, Optional, Type, Union
 
 from pydantic import BaseModel, model_validator
 
 from crewai.utilities.events.base_events import BaseEvent
 
+logger = logging.getLogger(__name__)
+
 
 class LLMCallType(Enum):
     """Type of LLM call being made"""
@@ -29,11 +32,40 @@ class LLMCallStartedEvent(BaseEvent):
 
     @model_validator(mode='before')
     @classmethod
-    def sanitize_tools(cls, values):
-        """Sanitize tools list to only include dict objects, filtering out non-dict objects like TokenCalcHandler"""
-        if isinstance(values, dict) and 'tools' in values and values['tools'] is not None:
-            if isinstance(values['tools'], list):
-                values['tools'] = [tool for tool in values['tools'] if isinstance(tool, dict)]
+    def sanitize_tools(cls: Type["LLMCallStartedEvent"], values: Any) -> Any:
+        """Sanitize the tools list to only include dict objects, filtering out non-dict objects like TokenCalcHandler.
+
+        Args:
+            values: Input values dictionary containing tools and other event data.
+
+        Returns:
+            The sanitized values, with the tools list reduced to valid dict objects.
+
+        Example:
+            >>> from crewai.utilities.token_counter_callback import TokenCalcHandler
+            >>> from crewai.agents.agent_builder.utilities.base_token_process import TokenProcess
+            >>> token_handler = TokenCalcHandler(TokenProcess())
+            >>> tools = [{"name": "tool1"}, token_handler, {"name": "tool2"}]
+            >>> sanitized = LLMCallStartedEvent.sanitize_tools({"tools": tools})
+            >>> # Expected: {"tools": [{"name": "tool1"}, {"name": "tool2"}]}
+        """
+        try:
+            if isinstance(values, dict) and 'tools' in values and values['tools'] is not None:
+                if isinstance(values['tools'], list):
+                    sanitized_tools = []
+                    for tool in values['tools']:
+                        if isinstance(tool, dict):
+                            if all(isinstance(v, (str, int, float, bool, dict, list, type(None))) for v in tool.values()):
+                                sanitized_tools.append(tool)
+                            else:
+                                logger.warning(f"Tool dict contains invalid value types: {tool}")
+                        else:
+                            logger.debug(f"Filtering out non-dict tool object: {type(tool).__name__}")
+
+                    values['tools'] = sanitized_tools
+        except Exception as e:
+            logger.warning(f"Error during tools sanitization: {e}")
+
         return values
 
 
diff --git a/tests/utilities/events/test_llm_events_validation.py b/tests/utilities/events/test_llm_events_validation.py
index ee16606d2..abde2d718 100644
--- a/tests/utilities/events/test_llm_events_validation.py
+++ b/tests/utilities/events/test_llm_events_validation.py
@@ -1,4 +1,5 @@
 import pytest
+import logging
 from crewai.utilities.events.llm_events import LLMCallStartedEvent
 from crewai.utilities.token_counter_callback import TokenCalcHandler
 from crewai.agents.agent_builder.utilities.base_token_process import TokenProcess
@@ -133,3 +134,70 @@ class TestLLMCallStartedEventValidation:
 
         assert event.tools == [{"name": "tool1"}]
         assert event.available_functions == available_funcs
+
+    @pytest.mark.parametrize("tools_input,expected", [
+        ([{"name": "tool1"}, TokenCalcHandler(TokenProcess())], [{"name": "tool1"}]),
+        ([{"name": "tool1"}, "string_tool", {"name": "tool2"}], [{"name": "tool1"}, {"name": "tool2"}]),
+        ([TokenCalcHandler(TokenProcess()), 123, ["list_tool"]], []),
+        ([{"name": "tool1", "type": "function", "enabled": True}], [{"name": "tool1", "type": "function", "enabled": True}]),
+        ([], []),
+        (None, None),
+    ])
+    def test_tools_sanitization_parameterized(self, tools_input, expected):
+        """Parameterized test covering the main tools sanitization scenarios"""
+        event = LLMCallStartedEvent(
+            messages=[{"role": "user", "content": "test message"}],
+            tools=tools_input,
+            callbacks=None
+        )
+        assert event.tools == expected
+
+    def test_tools_with_invalid_dict_values_filtered(self):
+        """Test that dicts with invalid value types are filtered out"""
+        class CustomObject:
+            pass
+
+        invalid_tool = {"name": "tool1", "custom_obj": CustomObject()}
+        valid_tool = {"name": "tool2", "type": "function"}
+
+        event = LLMCallStartedEvent(
+            messages=[{"role": "user", "content": "test message"}],
+            tools=[valid_tool, invalid_tool],
+            callbacks=None
+        )
+
+        assert event.tools == [valid_tool]
+
+    def test_sanitize_tools_large_mixed_dataset(self):
+        """Test sanitization on a large list mixing valid and invalid tools"""
+        token_handler = TokenCalcHandler(TokenProcess())
+
+        large_tools_list = []
+        for i in range(1000):
+            if i % 3 == 0:
+                large_tools_list.append({"name": f"tool_{i}", "type": "function"})
+            elif i % 3 == 1:
+                large_tools_list.append(token_handler)
+            else:
+                large_tools_list.append(f"string_tool_{i}")
+
+        event = LLMCallStartedEvent(
+            messages=[{"role": "user", "content": "test message"}],
+            tools=large_tools_list,
+            callbacks=None
+        )
+
+        expected_count = len([i for i in range(1000) if i % 3 == 0])
+        assert len(event.tools) == expected_count
+        assert all(isinstance(tool, dict) for tool in event.tools)
+
+    def test_sanitization_error_handling(self, caplog):
+        """Test that sanitization handles non-dict tools gracefully without raising"""
+        with caplog.at_level(logging.WARNING):
+            event = LLMCallStartedEvent(
+                messages=[{"role": "user", "content": "test message"}],
+                tools=[{"name": "tool1"}, TokenCalcHandler(TokenProcess())],
+                callbacks=None
+            )
+
+        assert event.tools == [{"name": "tool1"}]
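A minimal usage sketch of the behavior this patch adds, assuming LLMCallStartedEvent can be constructed directly with these keyword arguments (as the PR's own tests do). Illustrative only, not part of the patch:

    # Sketch: the new before-mode validator filters a mixed tools list
    # before pydantic field validation runs. Mirrors the test scenarios above.
    from crewai.utilities.events.llm_events import LLMCallStartedEvent
    from crewai.utilities.token_counter_callback import TokenCalcHandler
    from crewai.agents.agent_builder.utilities.base_token_process import TokenProcess

    event = LLMCallStartedEvent(
        messages=[{"role": "user", "content": "hello"}],
        tools=[{"name": "search", "type": "function"}, TokenCalcHandler(TokenProcess())],
        callbacks=None,
    )
    # The non-dict TokenCalcHandler instance was filtered out and logged at DEBUG level:
    assert event.tools == [{"name": "search", "type": "function"}]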