diff --git a/src/crewai/llm.py b/src/crewai/llm.py index c8c456297..440cbf903 100644 --- a/src/crewai/llm.py +++ b/src/crewai/llm.py @@ -5,8 +5,7 @@ import sys import threading import warnings from collections import defaultdict -from contextlib import contextmanager -from types import SimpleNamespace +from contextlib import contextmanager, redirect_stderr, redirect_stdout from typing import ( Any, DefaultDict, @@ -31,7 +30,6 @@ from crewai.utilities.events.llm_events import ( LLMCallType, LLMStreamChunkEvent, ) -from crewai.utilities.events.tool_usage_events import ToolExecutionErrorEvent with warnings.catch_warnings(): warnings.simplefilter("ignore", UserWarning) @@ -45,6 +43,9 @@ with warnings.catch_warnings(): from litellm.utils import supports_response_schema +import io +from typing import TextIO + from crewai.llms.base_llm import BaseLLM from crewai.utilities.events import crewai_event_bus from crewai.utilities.exceptions.context_window_exceeding_exception import ( @@ -54,12 +55,17 @@ from crewai.utilities.exceptions.context_window_exceeding_exception import ( load_dotenv() -class FilteredStream: - def __init__(self, original_stream): +class FilteredStream(io.TextIOBase): + _lock = None + + def __init__(self, original_stream: TextIO): self._original_stream = original_stream self._lock = threading.Lock() - def write(self, s) -> int: + def write(self, s: str) -> int: + if not self._lock: + self._lock = threading.Lock() + with self._lock: # Filter out extraneous messages from LiteLLM if ( @@ -214,15 +220,11 @@ def suppress_warnings(): ) # Redirect stdout and stderr - old_stdout = sys.stdout - old_stderr = sys.stderr - sys.stdout = FilteredStream(old_stdout) - sys.stderr = FilteredStream(old_stderr) - try: + with ( + redirect_stdout(FilteredStream(sys.stdout)), + redirect_stderr(FilteredStream(sys.stderr)), + ): yield - finally: - sys.stdout = old_stdout - sys.stderr = old_stderr class Delta(TypedDict):