diff --git a/src/crewai/llm.py b/src/crewai/llm.py
index 19fbf8c70..4dd1b3cbe 100644
--- a/src/crewai/llm.py
+++ b/src/crewai/llm.py
@@ -201,21 +201,42 @@ def suppress_warnings():
     yield
 
 
+_litellm_logger = None
+
+
 @contextmanager
 def suppress_litellm_output():
     """Contextually suppress litellm-related logging output during LLM calls."""
-    litellm_logger = logging.getLogger("litellm")
-    original_level = litellm_logger.level
+    global _litellm_logger
+    if _litellm_logger is None:
+        _litellm_logger = logging.getLogger("litellm")
+
+    original_level = _litellm_logger.level
+
+    warning_patterns = [
+        ".*give feedback.*",
+        ".*Consider using a smaller input.*",
+        ".*litellm\\.info:.*",
+        ".*text splitting strategy.*",
+    ]
 
     with warnings.catch_warnings():
-        warnings.filterwarnings("ignore", message=".*give feedback.*")
-        warnings.filterwarnings("ignore", message=".*Consider using a smaller input.*")
+        for pattern in warning_patterns:
+            warnings.filterwarnings("ignore", message=pattern)
         try:
-            litellm_logger.setLevel(logging.WARNING)
+            # Guard only the level change: a failing logger must not abort the
+            # LLM call, and the generator must yield exactly once. Wrapping the
+            # yield itself in the except would swallow exceptions raised by the
+            # caller's body and yield a second time, which is a RuntimeError.
+            try:
+                _litellm_logger.setLevel(logging.WARNING)
+            except Exception as e:
+                logging.debug(f"Error in litellm output suppression: {e}")
             yield
         finally:
-            litellm_logger.setLevel(original_level)
+            try:
+                _litellm_logger.setLevel(original_level)
+            except Exception as e:
+                logging.debug(f"Error restoring litellm log level: {e}")
 
 
 class Delta(TypedDict):
@@ -468,8 +489,9 @@ class LLM(BaseLLM):
                     chunk_content = result
 
             except Exception as e:
-                logging.debug(f"Error extracting content from chunk: {e}")
+                logging.error(f"Error extracting content from chunk: {e}", exc_info=True)
                 logging.debug(f"Chunk format: {type(chunk)}, content: {chunk}")
+                continue
 
             # Only add non-None content to the response
             if chunk_content is not None:
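Note for reviewers: a minimal usage sketch of the rewritten context manager. The direct `litellm.completion` call and the model name below are illustrative assumptions, not taken from this patch; only `suppress_litellm_output` itself comes from the hunk above.

```python
# Sketch only: assumes suppress_litellm_output is importable from crewai.llm
# and that the wrapped code ultimately calls litellm, as the hunk suggests.
import litellm

from crewai.llm import suppress_litellm_output

with suppress_litellm_output():
    # Inside the block the "litellm" logger is raised to WARNING and the
    # listed warning patterns are filtered; both are restored on exit.
    response = litellm.completion(
        model="gpt-4o-mini",  # hypothetical model name
        messages=[{"role": "user", "content": "hello"}],
    )
```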
diff --git a/tests/test_sys_stream_hijacking.py b/tests/test_sys_stream_hijacking.py
index d58803c42..76cbe84d4 100644
--- a/tests/test_sys_stream_hijacking.py
+++ b/tests/test_sys_stream_hijacking.py
@@ -2,7 +2,8 @@
 import sys
 import io
-from unittest.mock import patch
+from unittest.mock import patch, MagicMock
+import pytest
 
 
 def test_crewai_hijacks_sys_streams():
@@ -10,7 +11,7 @@ def test_crewai_hijacks_sys_streams():
     original_stdout = sys.stdout
     original_stderr = sys.stderr
 
-    import crewai.llm
+    import crewai.llm  # noqa: F401
 
     try:
         assert sys.stdout is not original_stdout, "sys.stdout should be hijacked by FilteredStream"
@@ -24,7 +25,7 @@ def test_crewai_hijacks_sys_streams():
 
 def test_litellm_output_is_filtered():
     """Test that litellm-related output is currently filtered (before fix)."""
-    import crewai.llm
+    import crewai.llm  # noqa: F401
 
     captured_output = io.StringIO()
@@ -51,7 +52,7 @@ def test_litellm_output_is_filtered():
 
 def test_normal_output_passes_through():
     """Test that normal output passes through correctly after the fix."""
-    import crewai.llm
+    import crewai.llm  # noqa: F401
 
     captured_output = io.StringIO()
     original_stdout = sys.stdout
@@ -76,7 +77,7 @@ def test_crewai_does_not_hijack_sys_streams_after_fix():
     if 'crewai' in sys.modules:
         del sys.modules['crewai']
 
-    import crewai.llm
+    import crewai.llm  # noqa: F401
 
     assert sys.stdout is original_stdout, "sys.stdout should NOT be hijacked after fix"
     assert sys.stderr is original_stderr, "sys.stderr should NOT be hijacked after fix"
@@ -104,3 +105,78 @@ def test_litellm_output_still_suppressed_during_llm_calls():
 
     output = captured_stdout.getvalue() + captured_stderr.getvalue()
     assert "litellm" not in output.lower(), "litellm output should still be suppressed during calls"
+
+
+def test_concurrent_llm_calls():
+    """Test that contextual suppression works correctly with concurrent calls."""
+    import threading
+    from crewai.llm import LLM
+
+    results = []
+
+    def make_llm_call():
+        with patch('litellm.completion') as mock_completion:
+            mock_completion.return_value = type('MockResponse', (), {
+                'choices': [type('MockChoice', (), {
+                    'message': type('MockMessage', (), {'content': 'test response'})()
+                })()]
+            })()
+
+            llm = LLM(model="gpt-4")
+            result = llm.call([{"role": "user", "content": "test"}])
+            results.append(result)
+
+    threads = [threading.Thread(target=make_llm_call) for _ in range(3)]
+    for thread in threads:
+        thread.start()
+    for thread in threads:
+        thread.join()
+
+    assert len(results) == 3
+    assert all("test response" in result for result in results)
+
+
+def test_logger_caching_performance():
+    """Test that the litellm logger is fetched once and then cached."""
+    import crewai.llm
+    from crewai.llm import suppress_litellm_output
+
+    # Reset the module-level cache so the patched getLogger is actually observed.
+    crewai.llm._litellm_logger = None
+    try:
+        with patch('logging.getLogger') as mock_get_logger:
+            mock_get_logger.return_value = MagicMock()
+
+            with suppress_litellm_output():
+                pass
+
+            with suppress_litellm_output():
+                pass
+
+            mock_get_logger.assert_called_once_with("litellm")
+    finally:
+        # Drop the cached mock so it cannot leak into other tests.
+        crewai.llm._litellm_logger = None
+
+
+def test_suppression_error_handling():
+    """Test that suppression continues even if logger operations fail."""
+    import crewai.llm
+    from crewai.llm import suppress_litellm_output
+
+    # Reset the cache so the failing mock logger is the one actually used.
+    crewai.llm._litellm_logger = None
+    try:
+        with patch('logging.getLogger') as mock_get_logger:
+            mock_logger = MagicMock()
+            mock_logger.setLevel.side_effect = Exception("Logger error")
+            mock_get_logger.return_value = mock_logger
+
+            try:
+                with suppress_litellm_output():
+                    result = "operation completed"
+                assert result == "operation completed"
+            except Exception:
+                pytest.fail("Suppression should not fail even if logger operations fail")
+    finally:
+        crewai.llm._litellm_logger = None
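One caveat `test_concurrent_llm_calls` does not fully pin down: the suppression works by mutating the level of a process-global logger, so overlapping contexts in different threads can restore levels out of order. A standalone sketch of the interleaving, with all names illustrative rather than taken from this patch:

```python
import logging
import threading
import time

logger = logging.getLogger("litellm")
logger.setLevel(logging.DEBUG)  # stand-in for whatever litellm configured

def suppress_for(hold_seconds: float) -> None:
    original = logger.level           # the second thread reads WARNING here
    logger.setLevel(logging.WARNING)
    time.sleep(hold_seconds)          # simulate an in-flight LLM call
    logger.setLevel(original)         # may "restore" the wrong level

first = threading.Thread(target=suppress_for, args=(0.1,))
second = threading.Thread(target=suppress_for, args=(0.3,))
first.start()
time.sleep(0.05)  # ensure `second` overlaps with `first`
second.start()
first.join()
second.join()

# DEBUG is never restored: `second` saved WARNING as its "original" level.
print(logging.getLevelName(logger.level))  # -> WARNING
```

If that ever matters in practice, a per-call `logging.Filter` or a lock-guarded re-entrancy counter would avoid the save/restore race; the current approach is acceptable as long as suppression is effectively always-on during overlapping calls.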