diff --git a/tests/test_sys_stream_hijacking.py b/tests/test_sys_stream_hijacking.py index 76cbe84d4..7a352dc5c 100644 --- a/tests/test_sys_stream_hijacking.py +++ b/tests/test_sys_stream_hijacking.py @@ -139,18 +139,25 @@ def test_concurrent_llm_calls(): def test_logger_caching_performance(): """Test that logger instance is cached for performance.""" from crewai.llm import suppress_litellm_output + import crewai.llm - with patch('logging.getLogger') as mock_get_logger: - mock_logger = MagicMock() - mock_get_logger.return_value = mock_logger - - with suppress_litellm_output(): - pass - - with suppress_litellm_output(): - pass + original_logger = crewai.llm._litellm_logger + crewai.llm._litellm_logger = None + + try: + with patch('logging.getLogger') as mock_get_logger: + mock_logger = MagicMock() + mock_get_logger.return_value = mock_logger - mock_get_logger.assert_called_once_with("litellm") + with suppress_litellm_output(): + pass + + with suppress_litellm_output(): + pass + + mock_get_logger.assert_called_once_with("litellm") + finally: + crewai.llm._litellm_logger = original_logger def test_suppression_error_handling():