From dbea3758ebd80c8aab9a024138963a25322a4b6b Mon Sep 17 00:00:00 2001 From: Devin AI <158243242+devin-ai-integration[bot]@users.noreply.github.com> Date: Sun, 9 Feb 2025 23:43:39 +0000 Subject: [PATCH] test: Add proper environment variable cleanup in memory reset tests Co-Authored-By: Joe Moura --- tests/memory/test_memory_reset.py | 108 +++++++++++++++++++++--------- 1 file changed, 77 insertions(+), 31 deletions(-) diff --git a/tests/memory/test_memory_reset.py b/tests/memory/test_memory_reset.py index eadab6b58..b1c8591df 100644 --- a/tests/memory/test_memory_reset.py +++ b/tests/memory/test_memory_reset.py @@ -24,16 +24,35 @@ def temp_db_dir() -> Generator[Path, None, None]: def test_memory_reset_with_openai(temp_db_dir): """Test memory reset with default OpenAI provider.""" - os.environ["OPENAI_API_KEY"] = "test-key" - memory = ShortTermMemory(path=str(temp_db_dir)) - memory.reset() # Should work with OpenAI as default + original_key = os.environ.get("OPENAI_API_KEY") + original_provider = os.environ.get("CREWAI_EMBEDDING_PROVIDER") + try: + os.environ["OPENAI_API_KEY"] = "test-key" + if "CREWAI_EMBEDDING_PROVIDER" in os.environ: + del os.environ["CREWAI_EMBEDDING_PROVIDER"] + memory = ShortTermMemory(path=str(temp_db_dir)) + memory.reset() # Should work with OpenAI as default + finally: + if original_key: + os.environ["OPENAI_API_KEY"] = original_key + else: + del os.environ["OPENAI_API_KEY"] + if original_provider: + os.environ["CREWAI_EMBEDDING_PROVIDER"] = original_provider def test_memory_reset_with_ollama(temp_db_dir): """Test memory reset with Ollama provider.""" - os.environ["CREWAI_EMBEDDING_PROVIDER"] = "ollama" - memory = ShortTermMemory(path=str(temp_db_dir)) - memory.reset() # Should not raise any OpenAI-related errors + original_provider = os.environ.get("CREWAI_EMBEDDING_PROVIDER") + try: + os.environ["CREWAI_EMBEDDING_PROVIDER"] = "ollama" + memory = ShortTermMemory(path=str(temp_db_dir)) + memory.reset() # Should not raise any OpenAI-related errors + finally: + if original_provider: + os.environ["CREWAI_EMBEDDING_PROVIDER"] = original_provider + elif "CREWAI_EMBEDDING_PROVIDER" in os.environ: + del os.environ["CREWAI_EMBEDDING_PROVIDER"] def test_memory_reset_with_custom_provider(temp_db_dir): @@ -53,37 +72,64 @@ def test_memory_reset_with_custom_provider(temp_db_dir): def test_memory_reset_with_invalid_provider(temp_db_dir): """Test memory reset with invalid provider raises appropriate error.""" - os.environ["CREWAI_EMBEDDING_PROVIDER"] = "invalid_provider" - with pytest.raises(Exception) as exc_info: - memory = ShortTermMemory(path=str(temp_db_dir)) - memory.reset() - assert "Unsupported embedding provider" in str(exc_info.value) + original_provider = os.environ.get("CREWAI_EMBEDDING_PROVIDER") + original_key = os.environ.get("OPENAI_API_KEY") + try: + os.environ["CREWAI_EMBEDDING_PROVIDER"] = "invalid_provider" + with pytest.raises(Exception) as exc_info: + memory = ShortTermMemory(path=str(temp_db_dir)) + memory.reset() + assert "Unsupported embedding provider" in str(exc_info.value) + finally: + if original_provider: + os.environ["CREWAI_EMBEDDING_PROVIDER"] = original_provider + elif "CREWAI_EMBEDDING_PROVIDER" in os.environ: + del os.environ["CREWAI_EMBEDDING_PROVIDER"] + if original_key: + os.environ["OPENAI_API_KEY"] = original_key def test_memory_reset_with_missing_api_key(temp_db_dir): """Test memory reset with missing API key raises appropriate error.""" - os.environ.pop("OPENAI_API_KEY", None) # Ensure key is not set - os.environ["CREWAI_EMBEDDING_PROVIDER"] = "openai" - with pytest.raises(ValueError) as exc_info: - memory = ShortTermMemory(path=str(temp_db_dir)) - memory.reset() - assert "openai api key" in str(exc_info.value).lower() + original_key = os.environ.get("OPENAI_API_KEY") + original_provider = os.environ.get("CREWAI_EMBEDDING_PROVIDER") + try: + if "OPENAI_API_KEY" in os.environ: + del os.environ["OPENAI_API_KEY"] + os.environ["CREWAI_EMBEDDING_PROVIDER"] = "openai" + with pytest.raises(ValueError) as exc_info: + memory = ShortTermMemory(path=str(temp_db_dir)) + memory.reset() + assert "openai api key" in str(exc_info.value).lower() + finally: + if original_key: + os.environ["OPENAI_API_KEY"] = original_key + if original_provider: + os.environ["CREWAI_EMBEDDING_PROVIDER"] = original_provider + elif "CREWAI_EMBEDDING_PROVIDER" in os.environ: + del os.environ["CREWAI_EMBEDDING_PROVIDER"] def test_memory_reset_cleans_up_files(temp_db_dir): """Test that memory reset properly cleans up database files.""" + original_provider = os.environ.get("CREWAI_EMBEDDING_PROVIDER") + try: + class TestEmbedder(EmbeddingFunction): + def __call__(self, input: Documents) -> Embeddings: + if isinstance(input, str): + input = [input] + return [[0.5] * 10] * len(input) - class TestEmbedder(EmbeddingFunction): - def __call__(self, input: Documents) -> Embeddings: - if isinstance(input, str): - input = [input] - return [[0.5] * 10] * len(input) - - memory = ShortTermMemory( - path=str(temp_db_dir), embedder_config={"provider": TestEmbedder()} - ) - memory.save("test memory", {"test": "metadata"}) - assert any(temp_db_dir.iterdir()) # Directory should have files - memory.reset() - # After reset, directory should either not exist or be empty - assert not os.path.exists(temp_db_dir) or not any(temp_db_dir.iterdir()) + if "CREWAI_EMBEDDING_PROVIDER" in os.environ: + del os.environ["CREWAI_EMBEDDING_PROVIDER"] + memory = ShortTermMemory( + path=str(temp_db_dir), embedder_config={"provider": TestEmbedder()} + ) + memory.save("test memory", {"test": "metadata"}) + assert any(temp_db_dir.iterdir()) # Directory should have files + memory.reset() + # After reset, directory should either not exist or be empty + assert not os.path.exists(temp_db_dir) or not any(temp_db_dir.iterdir()) + finally: + if original_provider: + os.environ["CREWAI_EMBEDDING_PROVIDER"] = original_provider