fix: address race condition in FilteredStream by using context managers (#2818)

During the sys.stdout = FilteredStream(old_stdout) assignment, if any code (including logging, print, or internal library output) writes to sys.stdout immediately, and that write happens before __init__ completes, the write() method is called on a not-fully-initialized object.. hence _lock doesn’t exist yet.
This commit is contained in:
Lucas Gomide
2025-05-12 16:05:14 -03:00
committed by GitHub
parent 4e496d7a20
commit f700e014c9

View File

@@ -5,8 +5,7 @@ import sys
import threading
import warnings
from collections import defaultdict
from contextlib import contextmanager
from types import SimpleNamespace
from contextlib import contextmanager, redirect_stderr, redirect_stdout
from typing import (
Any,
DefaultDict,
@@ -31,7 +30,6 @@ from crewai.utilities.events.llm_events import (
LLMCallType,
LLMStreamChunkEvent,
)
from crewai.utilities.events.tool_usage_events import ToolExecutionErrorEvent
with warnings.catch_warnings():
warnings.simplefilter("ignore", UserWarning)
@@ -45,6 +43,9 @@ with warnings.catch_warnings():
from litellm.utils import supports_response_schema
import io
from typing import TextIO
from crewai.llms.base_llm import BaseLLM
from crewai.utilities.events import crewai_event_bus
from crewai.utilities.exceptions.context_window_exceeding_exception import (
@@ -54,12 +55,17 @@ from crewai.utilities.exceptions.context_window_exceeding_exception import (
load_dotenv()
class FilteredStream:
def __init__(self, original_stream):
class FilteredStream(io.TextIOBase):
_lock = None
def __init__(self, original_stream: TextIO):
self._original_stream = original_stream
self._lock = threading.Lock()
def write(self, s) -> int:
def write(self, s: str) -> int:
if not self._lock:
self._lock = threading.Lock()
with self._lock:
# Filter out extraneous messages from LiteLLM
if (
@@ -214,15 +220,11 @@ def suppress_warnings():
)
# Redirect stdout and stderr
old_stdout = sys.stdout
old_stderr = sys.stderr
sys.stdout = FilteredStream(old_stdout)
sys.stderr = FilteredStream(old_stderr)
try:
with (
redirect_stdout(FilteredStream(sys.stdout)),
redirect_stderr(FilteredStream(sys.stderr)),
):
yield
finally:
sys.stdout = old_stdout
sys.stderr = old_stderr
class Delta(TypedDict):