test: Add proper environment variable cleanup in memory reset tests

Co-Authored-By: Joe Moura <joao@crewai.com>
This commit is contained in:
Devin AI
2025-02-09 23:43:39 +00:00
parent 528ab0c410
commit dbea3758eb

View File

@@ -24,16 +24,35 @@ def temp_db_dir() -> Generator[Path, None, None]:
def test_memory_reset_with_openai(temp_db_dir): def test_memory_reset_with_openai(temp_db_dir):
"""Test memory reset with default OpenAI provider.""" """Test memory reset with default OpenAI provider."""
original_key = os.environ.get("OPENAI_API_KEY")
original_provider = os.environ.get("CREWAI_EMBEDDING_PROVIDER")
try:
os.environ["OPENAI_API_KEY"] = "test-key" os.environ["OPENAI_API_KEY"] = "test-key"
if "CREWAI_EMBEDDING_PROVIDER" in os.environ:
del os.environ["CREWAI_EMBEDDING_PROVIDER"]
memory = ShortTermMemory(path=str(temp_db_dir)) memory = ShortTermMemory(path=str(temp_db_dir))
memory.reset() # Should work with OpenAI as default memory.reset() # Should work with OpenAI as default
finally:
if original_key:
os.environ["OPENAI_API_KEY"] = original_key
else:
del os.environ["OPENAI_API_KEY"]
if original_provider:
os.environ["CREWAI_EMBEDDING_PROVIDER"] = original_provider
def test_memory_reset_with_ollama(temp_db_dir): def test_memory_reset_with_ollama(temp_db_dir):
"""Test memory reset with Ollama provider.""" """Test memory reset with Ollama provider."""
original_provider = os.environ.get("CREWAI_EMBEDDING_PROVIDER")
try:
os.environ["CREWAI_EMBEDDING_PROVIDER"] = "ollama" os.environ["CREWAI_EMBEDDING_PROVIDER"] = "ollama"
memory = ShortTermMemory(path=str(temp_db_dir)) memory = ShortTermMemory(path=str(temp_db_dir))
memory.reset() # Should not raise any OpenAI-related errors memory.reset() # Should not raise any OpenAI-related errors
finally:
if original_provider:
os.environ["CREWAI_EMBEDDING_PROVIDER"] = original_provider
elif "CREWAI_EMBEDDING_PROVIDER" in os.environ:
del os.environ["CREWAI_EMBEDDING_PROVIDER"]
def test_memory_reset_with_custom_provider(temp_db_dir): def test_memory_reset_with_custom_provider(temp_db_dir):
@@ -53,32 +72,56 @@ def test_memory_reset_with_custom_provider(temp_db_dir):
def test_memory_reset_with_invalid_provider(temp_db_dir): def test_memory_reset_with_invalid_provider(temp_db_dir):
"""Test memory reset with invalid provider raises appropriate error.""" """Test memory reset with invalid provider raises appropriate error."""
original_provider = os.environ.get("CREWAI_EMBEDDING_PROVIDER")
original_key = os.environ.get("OPENAI_API_KEY")
try:
os.environ["CREWAI_EMBEDDING_PROVIDER"] = "invalid_provider" os.environ["CREWAI_EMBEDDING_PROVIDER"] = "invalid_provider"
with pytest.raises(Exception) as exc_info: with pytest.raises(Exception) as exc_info:
memory = ShortTermMemory(path=str(temp_db_dir)) memory = ShortTermMemory(path=str(temp_db_dir))
memory.reset() memory.reset()
assert "Unsupported embedding provider" in str(exc_info.value) assert "Unsupported embedding provider" in str(exc_info.value)
finally:
if original_provider:
os.environ["CREWAI_EMBEDDING_PROVIDER"] = original_provider
elif "CREWAI_EMBEDDING_PROVIDER" in os.environ:
del os.environ["CREWAI_EMBEDDING_PROVIDER"]
if original_key:
os.environ["OPENAI_API_KEY"] = original_key
def test_memory_reset_with_missing_api_key(temp_db_dir): def test_memory_reset_with_missing_api_key(temp_db_dir):
"""Test memory reset with missing API key raises appropriate error.""" """Test memory reset with missing API key raises appropriate error."""
os.environ.pop("OPENAI_API_KEY", None) # Ensure key is not set original_key = os.environ.get("OPENAI_API_KEY")
original_provider = os.environ.get("CREWAI_EMBEDDING_PROVIDER")
try:
if "OPENAI_API_KEY" in os.environ:
del os.environ["OPENAI_API_KEY"]
os.environ["CREWAI_EMBEDDING_PROVIDER"] = "openai" os.environ["CREWAI_EMBEDDING_PROVIDER"] = "openai"
with pytest.raises(ValueError) as exc_info: with pytest.raises(ValueError) as exc_info:
memory = ShortTermMemory(path=str(temp_db_dir)) memory = ShortTermMemory(path=str(temp_db_dir))
memory.reset() memory.reset()
assert "openai api key" in str(exc_info.value).lower() assert "openai api key" in str(exc_info.value).lower()
finally:
if original_key:
os.environ["OPENAI_API_KEY"] = original_key
if original_provider:
os.environ["CREWAI_EMBEDDING_PROVIDER"] = original_provider
elif "CREWAI_EMBEDDING_PROVIDER" in os.environ:
del os.environ["CREWAI_EMBEDDING_PROVIDER"]
def test_memory_reset_cleans_up_files(temp_db_dir): def test_memory_reset_cleans_up_files(temp_db_dir):
"""Test that memory reset properly cleans up database files.""" """Test that memory reset properly cleans up database files."""
original_provider = os.environ.get("CREWAI_EMBEDDING_PROVIDER")
try:
class TestEmbedder(EmbeddingFunction): class TestEmbedder(EmbeddingFunction):
def __call__(self, input: Documents) -> Embeddings: def __call__(self, input: Documents) -> Embeddings:
if isinstance(input, str): if isinstance(input, str):
input = [input] input = [input]
return [[0.5] * 10] * len(input) return [[0.5] * 10] * len(input)
if "CREWAI_EMBEDDING_PROVIDER" in os.environ:
del os.environ["CREWAI_EMBEDDING_PROVIDER"]
memory = ShortTermMemory( memory = ShortTermMemory(
path=str(temp_db_dir), embedder_config={"provider": TestEmbedder()} path=str(temp_db_dir), embedder_config={"provider": TestEmbedder()}
) )
@@ -87,3 +130,6 @@ def test_memory_reset_cleans_up_files(temp_db_dir):
memory.reset() memory.reset()
# After reset, directory should either not exist or be empty # After reset, directory should either not exist or be empty
assert not os.path.exists(temp_db_dir) or not any(temp_db_dir.iterdir()) assert not os.path.exists(temp_db_dir) or not any(temp_db_dir.iterdir())
finally:
if original_provider:
os.environ["CREWAI_EMBEDDING_PROVIDER"] = original_provider