mirror of
https://github.com/crewAIInc/crewAI.git
synced 2026-01-21 13:58:15 +00:00
fix tests
This commit is contained in:
@@ -1,6 +1,8 @@
|
||||
# conftest.py
|
||||
import os
|
||||
import tempfile
|
||||
import threading
|
||||
import time
|
||||
from pathlib import Path
|
||||
|
||||
import pytest
|
||||
@@ -75,18 +77,214 @@ def filter_stream_parameter(request):
|
||||
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
def configure_litellm_for_testing():
|
||||
"""Configure litellm to work better with VCR cassettes."""
|
||||
def reset_litellm_state():
|
||||
"""Reset LiteLLM global state before each test to ensure isolation."""
|
||||
import litellm
|
||||
|
||||
# Disable litellm's internal streaming optimizations that might conflict with VCR
|
||||
original_drop_params = litellm.drop_params
|
||||
# Save original state
|
||||
original_drop_params = getattr(litellm, 'drop_params', None)
|
||||
original_success_callback = getattr(litellm, 'success_callback', [])
|
||||
original_async_success_callback = getattr(litellm, '_async_success_callback', [])
|
||||
original_callbacks = getattr(litellm, 'callbacks', [])
|
||||
original_failure_callback = getattr(litellm, 'failure_callback', [])
|
||||
|
||||
# Reset to clean state for the test
|
||||
litellm.drop_params = True
|
||||
litellm.success_callback = []
|
||||
litellm._async_success_callback = []
|
||||
litellm.callbacks = []
|
||||
litellm.failure_callback = []
|
||||
|
||||
yield
|
||||
|
||||
# Restore original setting
|
||||
litellm.drop_params = original_drop_params
|
||||
# Restore original state after test
|
||||
if original_drop_params is not None:
|
||||
litellm.drop_params = original_drop_params
|
||||
litellm.success_callback = original_success_callback
|
||||
litellm._async_success_callback = original_async_success_callback
|
||||
litellm.callbacks = original_callbacks
|
||||
litellm.failure_callback = original_failure_callback
|
||||
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
def cleanup_background_threads():
|
||||
"""Cleanup all background threads and resources after each test."""
|
||||
original_threads = set(threading.enumerate())
|
||||
|
||||
yield
|
||||
|
||||
# Force cleanup of various background resources
|
||||
try:
|
||||
# Cleanup RPM controllers first - they create Timer threads
|
||||
try:
|
||||
from crewai.utilities.rpm_controller import RPMController
|
||||
# Find any RPMController instances and stop their timers
|
||||
import gc
|
||||
for obj in gc.get_objects():
|
||||
if isinstance(obj, RPMController):
|
||||
if hasattr(obj, 'stop_rpm_counter'):
|
||||
obj.stop_rpm_counter()
|
||||
if hasattr(obj, '_timer') and obj._timer:
|
||||
obj._timer.cancel()
|
||||
obj._timer = None
|
||||
if hasattr(obj, '_shutdown_flag'):
|
||||
obj._shutdown_flag = True
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
# Cleanup all Timer threads explicitly
|
||||
try:
|
||||
for thread in threading.enumerate():
|
||||
if isinstance(thread, threading.Timer):
|
||||
thread.cancel()
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
# Cleanup CrewAI telemetry
|
||||
try:
|
||||
from crewai.telemetry.telemetry import Telemetry
|
||||
telemetry = Telemetry()
|
||||
if hasattr(telemetry, 'provider') and telemetry.provider:
|
||||
if hasattr(telemetry.provider, 'shutdown'):
|
||||
telemetry.provider.shutdown()
|
||||
if hasattr(telemetry.provider, 'force_flush'):
|
||||
telemetry.provider.force_flush(timeout_millis=1000)
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
# Cleanup console formatters and spinners
|
||||
try:
|
||||
from crewai.utilities.events.utils.console_formatter import ConsoleFormatter, SimpleSpinner
|
||||
# Force stop any running spinners and formatters
|
||||
import gc
|
||||
for obj in gc.get_objects():
|
||||
if isinstance(obj, ConsoleFormatter):
|
||||
if hasattr(obj, 'cleanup'):
|
||||
obj.cleanup()
|
||||
elif isinstance(obj, SimpleSpinner):
|
||||
if hasattr(obj, 'stop'):
|
||||
obj.stop()
|
||||
|
||||
# Also check threads for spinners
|
||||
for thread in threading.enumerate():
|
||||
if hasattr(thread, '_target') and thread._target:
|
||||
target_name = getattr(thread._target, '__name__', '')
|
||||
if 'spin' in target_name.lower():
|
||||
if hasattr(thread, '_stop_event'):
|
||||
thread._stop_event.set()
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
# Cleanup HTTP clients and connection pools
|
||||
try:
|
||||
import httpx
|
||||
import gc
|
||||
# Force close any httpx clients
|
||||
for obj in gc.get_objects():
|
||||
if isinstance(obj, httpx.Client):
|
||||
try:
|
||||
obj.close()
|
||||
except Exception:
|
||||
pass
|
||||
elif isinstance(obj, httpx.AsyncClient):
|
||||
try:
|
||||
import asyncio
|
||||
try:
|
||||
loop = asyncio.get_event_loop()
|
||||
if loop.is_running():
|
||||
loop.create_task(obj.aclose())
|
||||
else:
|
||||
asyncio.run(obj.aclose())
|
||||
except Exception:
|
||||
pass
|
||||
except Exception:
|
||||
pass
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
# Cleanup any ThreadPoolExecutor instances
|
||||
try:
|
||||
import concurrent.futures
|
||||
import gc
|
||||
for obj in gc.get_objects():
|
||||
if isinstance(obj, concurrent.futures.ThreadPoolExecutor):
|
||||
try:
|
||||
obj.shutdown(wait=False)
|
||||
except Exception:
|
||||
pass
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
# Force cleanup of any remaining daemon threads
|
||||
new_threads = set(threading.enumerate()) - original_threads
|
||||
for thread in new_threads:
|
||||
if thread.daemon and thread.is_alive():
|
||||
try:
|
||||
# Try to stop the thread gracefully
|
||||
if hasattr(thread, '_stop_event'):
|
||||
thread._stop_event.set()
|
||||
elif hasattr(thread, 'stop'):
|
||||
thread.stop()
|
||||
|
||||
# Give it a moment to stop
|
||||
thread.join(timeout=0.1)
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
# Final aggressive cleanup - wait a bit for threads to finish
|
||||
time.sleep(0.1)
|
||||
|
||||
# Force garbage collection to clean up any remaining references
|
||||
import gc
|
||||
gc.collect()
|
||||
|
||||
except Exception:
|
||||
# If cleanup fails, don't fail the test
|
||||
pass
|
||||
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
def patch_httpx_for_vcr():
|
||||
"""Patch httpx Response to handle VCR replay issues."""
|
||||
try:
|
||||
import httpx
|
||||
from unittest.mock import patch
|
||||
|
||||
original_content = httpx.Response.content
|
||||
|
||||
def patched_content(self):
|
||||
try:
|
||||
return original_content.fget(self)
|
||||
except httpx.ResponseNotRead:
|
||||
# If content hasn't been read, try to read it from _content or return empty
|
||||
if hasattr(self, '_content') and self._content is not None:
|
||||
return self._content
|
||||
# For VCR-replayed responses, try to get content from text if available
|
||||
try:
|
||||
return self.text.encode(self.encoding or 'utf-8')
|
||||
except Exception:
|
||||
return b''
|
||||
|
||||
with patch.object(httpx.Response, 'content', property(patched_content)):
|
||||
yield
|
||||
|
||||
except ImportError:
|
||||
# httpx not available, skip patching
|
||||
yield
|
||||
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
def disable_telemetry():
|
||||
"""Disable telemetry during tests to prevent background threads."""
|
||||
original_telemetry = os.environ.get("CREWAI_TELEMETRY", None)
|
||||
os.environ["CREWAI_TELEMETRY"] = "false"
|
||||
|
||||
yield
|
||||
|
||||
if original_telemetry is not None:
|
||||
os.environ["CREWAI_TELEMETRY"] = original_telemetry
|
||||
else:
|
||||
os.environ.pop("CREWAI_TELEMETRY", None)
|
||||
|
||||
|
||||
@pytest.fixture(scope="module")
|
||||
@@ -97,3 +295,29 @@ def vcr_config(request) -> dict:
|
||||
"filter_headers": [("authorization", "AUTHORIZATION-XXX")],
|
||||
"before_record_request": filter_stream_parameter,
|
||||
}
|
||||
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
def patch_rpm_controller():
|
||||
"""Patch RPMController to prevent recurring timers during tests."""
|
||||
try:
|
||||
from crewai.utilities.rpm_controller import RPMController
|
||||
from unittest.mock import patch
|
||||
|
||||
original_reset_request_count = RPMController._reset_request_count
|
||||
|
||||
def mock_reset_request_count(self):
|
||||
"""Mock that prevents the recurring timer from being set up."""
|
||||
if self._lock:
|
||||
with self._lock:
|
||||
self._current_rpm = 0
|
||||
# Don't start a new timer during tests
|
||||
else:
|
||||
self._current_rpm = 0
|
||||
|
||||
with patch.object(RPMController, '_reset_request_count', mock_reset_request_count):
|
||||
yield
|
||||
|
||||
except ImportError:
|
||||
# If RPMController is not available, just yield
|
||||
yield
|
||||
|
||||
Reference in New Issue
Block a user