From f700e014c9e6494ad693a6e48a5da7594efc88ad Mon Sep 17 00:00:00 2001 From: Lucas Gomide Date: Mon, 12 May 2025 16:05:14 -0300 Subject: [PATCH] fix: address race condition in FilteredStream by using context managers (#2818) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit During the sys.stdout = FilteredStream(old_stdout) assignment, if any code (including logging, print, or internal library output) writes to sys.stdout immediately, and that write happens before __init__ completes, the write() method is called on a not-fully-initialized object.. hence _lock doesn’t exist yet. --- src/crewai/llm.py | 30 ++++++++++++++++-------------- 1 file changed, 16 insertions(+), 14 deletions(-) diff --git a/src/crewai/llm.py b/src/crewai/llm.py index c8c456297..440cbf903 100644 --- a/src/crewai/llm.py +++ b/src/crewai/llm.py @@ -5,8 +5,7 @@ import sys import threading import warnings from collections import defaultdict -from contextlib import contextmanager -from types import SimpleNamespace +from contextlib import contextmanager, redirect_stderr, redirect_stdout from typing import ( Any, DefaultDict, @@ -31,7 +30,6 @@ from crewai.utilities.events.llm_events import ( LLMCallType, LLMStreamChunkEvent, ) -from crewai.utilities.events.tool_usage_events import ToolExecutionErrorEvent with warnings.catch_warnings(): warnings.simplefilter("ignore", UserWarning) @@ -45,6 +43,9 @@ with warnings.catch_warnings(): from litellm.utils import supports_response_schema +import io +from typing import TextIO + from crewai.llms.base_llm import BaseLLM from crewai.utilities.events import crewai_event_bus from crewai.utilities.exceptions.context_window_exceeding_exception import ( @@ -54,12 +55,17 @@ from crewai.utilities.exceptions.context_window_exceeding_exception import ( load_dotenv() -class FilteredStream: - def __init__(self, original_stream): +class FilteredStream(io.TextIOBase): + _lock = None + + def __init__(self, original_stream: TextIO): self._original_stream = original_stream self._lock = threading.Lock() - def write(self, s) -> int: + def write(self, s: str) -> int: + if not self._lock: + self._lock = threading.Lock() + with self._lock: # Filter out extraneous messages from LiteLLM if ( @@ -214,15 +220,11 @@ def suppress_warnings(): ) # Redirect stdout and stderr - old_stdout = sys.stdout - old_stderr = sys.stderr - sys.stdout = FilteredStream(old_stdout) - sys.stderr = FilteredStream(old_stderr) - try: + with ( + redirect_stdout(FilteredStream(sys.stdout)), + redirect_stderr(FilteredStream(sys.stderr)), + ): yield - finally: - sys.stdout = old_stdout - sys.stderr = old_stderr class Delta(TypedDict):