mirror of
https://github.com/crewAIInc/crewAI.git
synced 2026-01-09 16:18:30 +00:00
refactor: unify rag storage with instance-specific client support (#3455)
Some checks failed
Notify Downstream / notify-downstream (push) Has been cancelled
Update Test Durations / update-durations (3.10) (push) Has been cancelled
Update Test Durations / update-durations (3.11) (push) Has been cancelled
Update Test Durations / update-durations (3.12) (push) Has been cancelled
Update Test Durations / update-durations (3.13) (push) Has been cancelled
Build uv cache / build-cache (3.10) (push) Has been cancelled
Build uv cache / build-cache (3.11) (push) Has been cancelled
Build uv cache / build-cache (3.12) (push) Has been cancelled
Build uv cache / build-cache (3.13) (push) Has been cancelled
Some checks failed
Notify Downstream / notify-downstream (push) Has been cancelled
Update Test Durations / update-durations (3.10) (push) Has been cancelled
Update Test Durations / update-durations (3.11) (push) Has been cancelled
Update Test Durations / update-durations (3.12) (push) Has been cancelled
Update Test Durations / update-durations (3.13) (push) Has been cancelled
Build uv cache / build-cache (3.10) (push) Has been cancelled
Build uv cache / build-cache (3.11) (push) Has been cancelled
Build uv cache / build-cache (3.12) (push) Has been cancelled
Build uv cache / build-cache (3.13) (push) Has been cancelled
- ignore line length errors globally - migrate knowledge/memory and crew query_knowledge to `SearchResult` - remove legacy chromadb utils; fix empty metadata handling - restore openai as default embedding provider; support instance-specific clients - update and fix tests for `SearchResult` migration and rag changes
This commit is contained in:
@@ -285,6 +285,43 @@ class TestChromaDBClient:
|
||||
metadatas=[{"source": "test1"}, {"source": "test2"}],
|
||||
)
|
||||
|
||||
def test_add_documents_without_metadata(self, client, mock_chromadb_client) -> None:
|
||||
"""Test add_documents with documents that have no metadata."""
|
||||
mock_collection = Mock()
|
||||
mock_chromadb_client.get_collection.return_value = mock_collection
|
||||
|
||||
documents: list[BaseRecord] = [
|
||||
{"content": "Document without metadata"},
|
||||
{"content": "Another document", "metadata": None},
|
||||
{"content": "Document with metadata", "metadata": {"key": "value"}},
|
||||
]
|
||||
|
||||
client.add_documents(collection_name="test_collection", documents=documents)
|
||||
|
||||
# Verify upsert was called with empty dicts for missing metadata
|
||||
mock_collection.upsert.assert_called_once()
|
||||
call_args = mock_collection.upsert.call_args
|
||||
assert call_args[1]["metadatas"] == [{}, {}, {"key": "value"}]
|
||||
|
||||
def test_add_documents_all_without_metadata(
|
||||
self, client, mock_chromadb_client
|
||||
) -> None:
|
||||
"""Test add_documents when all documents have no metadata."""
|
||||
mock_collection = Mock()
|
||||
mock_chromadb_client.get_collection.return_value = mock_collection
|
||||
|
||||
documents: list[BaseRecord] = [
|
||||
{"content": "Document 1"},
|
||||
{"content": "Document 2"},
|
||||
{"content": "Document 3"},
|
||||
]
|
||||
|
||||
client.add_documents(collection_name="test_collection", documents=documents)
|
||||
|
||||
mock_collection.upsert.assert_called_once()
|
||||
call_args = mock_collection.upsert.call_args
|
||||
assert call_args[1]["metadatas"] is None
|
||||
|
||||
def test_add_documents_empty_list_raises_error(
|
||||
self, client, mock_chromadb_client
|
||||
) -> None:
|
||||
@@ -358,6 +395,31 @@ class TestChromaDBClient:
|
||||
metadatas=[{"source": "test1"}, {"source": "test2"}],
|
||||
)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_aadd_documents_without_metadata(
|
||||
self, async_client, mock_async_chromadb_client
|
||||
) -> None:
|
||||
"""Test aadd_documents with documents that have no metadata."""
|
||||
mock_collection = AsyncMock()
|
||||
mock_async_chromadb_client.get_collection = AsyncMock(
|
||||
return_value=mock_collection
|
||||
)
|
||||
|
||||
documents: list[BaseRecord] = [
|
||||
{"content": "Document without metadata"},
|
||||
{"content": "Another document", "metadata": None},
|
||||
{"content": "Document with metadata", "metadata": {"key": "value"}},
|
||||
]
|
||||
|
||||
await async_client.aadd_documents(
|
||||
collection_name="test_collection", documents=documents
|
||||
)
|
||||
|
||||
# Verify upsert was called with empty dicts for missing metadata
|
||||
mock_collection.upsert.assert_called_once()
|
||||
call_args = mock_collection.upsert.call_args
|
||||
assert call_args[1]["metadatas"] == [{}, {}, {"key": "value"}]
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_aadd_documents_empty_list_raises_error(
|
||||
self, async_client, mock_async_chromadb_client
|
||||
|
||||
95
tests/rag/chromadb/test_utils.py
Normal file
95
tests/rag/chromadb/test_utils.py
Normal file
@@ -0,0 +1,95 @@
|
||||
"""Tests for ChromaDB utility functions."""
|
||||
|
||||
from crewai.rag.chromadb.utils import (
|
||||
MAX_COLLECTION_LENGTH,
|
||||
MIN_COLLECTION_LENGTH,
|
||||
_is_ipv4_pattern,
|
||||
_sanitize_collection_name,
|
||||
)
|
||||
|
||||
|
||||
class TestChromaDBUtils:
|
||||
"""Test suite for ChromaDB utility functions."""
|
||||
|
||||
def test_sanitize_collection_name_long_name(self) -> None:
|
||||
"""Test sanitizing a very long collection name."""
|
||||
long_name = "This is an extremely long role name that will definitely exceed the ChromaDB collection name limit of 63 characters and cause an error when used as a collection name"
|
||||
sanitized = _sanitize_collection_name(long_name)
|
||||
assert len(sanitized) <= MAX_COLLECTION_LENGTH
|
||||
assert sanitized[0].isalnum()
|
||||
assert sanitized[-1].isalnum()
|
||||
assert all(c.isalnum() or c in ["_", "-"] for c in sanitized)
|
||||
|
||||
def test_sanitize_collection_name_special_chars(self) -> None:
|
||||
"""Test sanitizing a name with special characters."""
|
||||
special_chars = "Agent@123!#$%^&*()"
|
||||
sanitized = _sanitize_collection_name(special_chars)
|
||||
assert sanitized[0].isalnum()
|
||||
assert sanitized[-1].isalnum()
|
||||
assert all(c.isalnum() or c in ["_", "-"] for c in sanitized)
|
||||
|
||||
def test_sanitize_collection_name_short_name(self) -> None:
|
||||
"""Test sanitizing a very short name."""
|
||||
short_name = "A"
|
||||
sanitized = _sanitize_collection_name(short_name)
|
||||
assert len(sanitized) >= MIN_COLLECTION_LENGTH
|
||||
assert sanitized[0].isalnum()
|
||||
assert sanitized[-1].isalnum()
|
||||
|
||||
def test_sanitize_collection_name_bad_ends(self) -> None:
|
||||
"""Test sanitizing a name with non-alphanumeric start/end."""
|
||||
bad_ends = "_Agent_"
|
||||
sanitized = _sanitize_collection_name(bad_ends)
|
||||
assert sanitized[0].isalnum()
|
||||
assert sanitized[-1].isalnum()
|
||||
|
||||
def test_sanitize_collection_name_none(self) -> None:
|
||||
"""Test sanitizing a None value."""
|
||||
sanitized = _sanitize_collection_name(None)
|
||||
assert sanitized == "default_collection"
|
||||
|
||||
def test_sanitize_collection_name_ipv4_pattern(self) -> None:
|
||||
"""Test sanitizing an IPv4 address."""
|
||||
ipv4 = "192.168.1.1"
|
||||
sanitized = _sanitize_collection_name(ipv4)
|
||||
assert sanitized.startswith("ip_")
|
||||
assert sanitized[0].isalnum()
|
||||
assert sanitized[-1].isalnum()
|
||||
assert all(c.isalnum() or c in ["_", "-"] for c in sanitized)
|
||||
|
||||
def test_is_ipv4_pattern(self) -> None:
|
||||
"""Test IPv4 pattern detection."""
|
||||
assert _is_ipv4_pattern("192.168.1.1") is True
|
||||
assert _is_ipv4_pattern("not.an.ip.address") is False
|
||||
|
||||
def test_sanitize_collection_name_properties(self) -> None:
|
||||
"""Test that sanitized collection names always meet ChromaDB requirements."""
|
||||
test_cases: list[str] = [
|
||||
"A" * 100, # Very long name
|
||||
"_start_with_underscore",
|
||||
"end_with_underscore_",
|
||||
"contains@special#characters",
|
||||
"192.168.1.1", # IPv4 address
|
||||
"a" * 2, # Too short
|
||||
]
|
||||
for test_case in test_cases:
|
||||
sanitized = _sanitize_collection_name(test_case)
|
||||
assert len(sanitized) >= MIN_COLLECTION_LENGTH
|
||||
assert len(sanitized) <= MAX_COLLECTION_LENGTH
|
||||
assert sanitized[0].isalnum()
|
||||
assert sanitized[-1].isalnum()
|
||||
|
||||
def test_sanitize_collection_name_empty_string(self) -> None:
|
||||
"""Test sanitizing an empty string."""
|
||||
sanitized = _sanitize_collection_name("")
|
||||
assert sanitized == "default_collection"
|
||||
|
||||
def test_sanitize_collection_name_whitespace_only(self) -> None:
|
||||
"""Test sanitizing a string with only whitespace."""
|
||||
sanitized = _sanitize_collection_name(" ")
|
||||
assert (
|
||||
sanitized == "a__z"
|
||||
) # Spaces become underscores, padded to meet requirements
|
||||
assert len(sanitized) >= MIN_COLLECTION_LENGTH
|
||||
assert sanitized[0].isalnum()
|
||||
assert sanitized[-1].isalnum()
|
||||
250
tests/rag/embeddings/test_factory_enhanced.py
Normal file
250
tests/rag/embeddings/test_factory_enhanced.py
Normal file
@@ -0,0 +1,250 @@
|
||||
"""Enhanced tests for embedding function factory."""
|
||||
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
import pytest
|
||||
|
||||
from crewai.rag.embeddings.factory import ( # type: ignore[import-untyped]
|
||||
get_embedding_function,
|
||||
)
|
||||
from crewai.rag.embeddings.types import EmbeddingOptions # type: ignore[import-untyped]
|
||||
|
||||
|
||||
def test_get_embedding_function_default() -> None:
|
||||
"""Test default embedding function when no config provided."""
|
||||
with patch("crewai.rag.embeddings.factory.OpenAIEmbeddingFunction") as mock_openai:
|
||||
mock_instance = MagicMock()
|
||||
mock_openai.return_value = mock_instance
|
||||
|
||||
with patch(
|
||||
"crewai.rag.embeddings.factory.os.getenv", return_value="test-api-key"
|
||||
):
|
||||
result = get_embedding_function()
|
||||
|
||||
mock_openai.assert_called_once_with(
|
||||
api_key="test-api-key", model_name="text-embedding-3-small"
|
||||
)
|
||||
assert result == mock_instance
|
||||
|
||||
|
||||
def test_get_embedding_function_with_embedding_options() -> None:
|
||||
"""Test embedding function creation with EmbeddingOptions object."""
|
||||
with patch("crewai.rag.embeddings.factory.OpenAIEmbeddingFunction") as mock_openai:
|
||||
mock_instance = MagicMock()
|
||||
mock_openai.return_value = mock_instance
|
||||
|
||||
options = EmbeddingOptions(
|
||||
provider="openai", api_key="test-key", model="text-embedding-3-large"
|
||||
)
|
||||
|
||||
result = get_embedding_function(options)
|
||||
|
||||
call_kwargs = mock_openai.call_args.kwargs
|
||||
assert "api_key" in call_kwargs
|
||||
assert call_kwargs["api_key"].get_secret_value() == "test-key"
|
||||
# OpenAI uses model_name parameter, not model
|
||||
assert result == mock_instance
|
||||
|
||||
|
||||
def test_get_embedding_function_sentence_transformer() -> None:
|
||||
"""Test sentence transformer embedding function."""
|
||||
with patch(
|
||||
"crewai.rag.embeddings.factory.SentenceTransformerEmbeddingFunction"
|
||||
) as mock_st:
|
||||
mock_instance = MagicMock()
|
||||
mock_st.return_value = mock_instance
|
||||
|
||||
config = {"provider": "sentence-transformer", "model_name": "all-MiniLM-L6-v2"}
|
||||
|
||||
result = get_embedding_function(config)
|
||||
|
||||
mock_st.assert_called_once_with(model_name="all-MiniLM-L6-v2")
|
||||
assert result == mock_instance
|
||||
|
||||
|
||||
def test_get_embedding_function_ollama() -> None:
|
||||
"""Test Ollama embedding function."""
|
||||
with patch("crewai.rag.embeddings.factory.OllamaEmbeddingFunction") as mock_ollama:
|
||||
mock_instance = MagicMock()
|
||||
mock_ollama.return_value = mock_instance
|
||||
|
||||
config = {
|
||||
"provider": "ollama",
|
||||
"model_name": "nomic-embed-text",
|
||||
"url": "http://localhost:11434",
|
||||
}
|
||||
|
||||
result = get_embedding_function(config)
|
||||
|
||||
mock_ollama.assert_called_once_with(
|
||||
model_name="nomic-embed-text", url="http://localhost:11434"
|
||||
)
|
||||
assert result == mock_instance
|
||||
|
||||
|
||||
def test_get_embedding_function_cohere() -> None:
|
||||
"""Test Cohere embedding function."""
|
||||
with patch("crewai.rag.embeddings.factory.CohereEmbeddingFunction") as mock_cohere:
|
||||
mock_instance = MagicMock()
|
||||
mock_cohere.return_value = mock_instance
|
||||
|
||||
config = {
|
||||
"provider": "cohere",
|
||||
"api_key": "cohere-key",
|
||||
"model_name": "embed-english-v3.0",
|
||||
}
|
||||
|
||||
result = get_embedding_function(config)
|
||||
|
||||
mock_cohere.assert_called_once_with(
|
||||
api_key="cohere-key", model_name="embed-english-v3.0"
|
||||
)
|
||||
assert result == mock_instance
|
||||
|
||||
|
||||
def test_get_embedding_function_huggingface() -> None:
|
||||
"""Test HuggingFace embedding function."""
|
||||
with patch("crewai.rag.embeddings.factory.HuggingFaceEmbeddingFunction") as mock_hf:
|
||||
mock_instance = MagicMock()
|
||||
mock_hf.return_value = mock_instance
|
||||
|
||||
config = {
|
||||
"provider": "huggingface",
|
||||
"api_key": "hf-token",
|
||||
"model_name": "sentence-transformers/all-MiniLM-L6-v2",
|
||||
}
|
||||
|
||||
result = get_embedding_function(config)
|
||||
|
||||
mock_hf.assert_called_once_with(
|
||||
api_key="hf-token", model_name="sentence-transformers/all-MiniLM-L6-v2"
|
||||
)
|
||||
assert result == mock_instance
|
||||
|
||||
|
||||
def test_get_embedding_function_onnx() -> None:
|
||||
"""Test ONNX embedding function."""
|
||||
with patch("crewai.rag.embeddings.factory.ONNXMiniLM_L6_V2") as mock_onnx:
|
||||
mock_instance = MagicMock()
|
||||
mock_onnx.return_value = mock_instance
|
||||
|
||||
config = {"provider": "onnx"}
|
||||
|
||||
result = get_embedding_function(config)
|
||||
|
||||
mock_onnx.assert_called_once()
|
||||
assert result == mock_instance
|
||||
|
||||
|
||||
def test_get_embedding_function_google_palm() -> None:
|
||||
"""Test Google PaLM embedding function."""
|
||||
with patch(
|
||||
"crewai.rag.embeddings.factory.GooglePalmEmbeddingFunction"
|
||||
) as mock_palm:
|
||||
mock_instance = MagicMock()
|
||||
mock_palm.return_value = mock_instance
|
||||
|
||||
config = {"provider": "google-palm", "api_key": "palm-key"}
|
||||
|
||||
result = get_embedding_function(config)
|
||||
|
||||
mock_palm.assert_called_once_with(api_key="palm-key")
|
||||
assert result == mock_instance
|
||||
|
||||
|
||||
def test_get_embedding_function_amazon_bedrock() -> None:
|
||||
"""Test Amazon Bedrock embedding function."""
|
||||
with patch(
|
||||
"crewai.rag.embeddings.factory.AmazonBedrockEmbeddingFunction"
|
||||
) as mock_bedrock:
|
||||
mock_instance = MagicMock()
|
||||
mock_bedrock.return_value = mock_instance
|
||||
|
||||
config = {
|
||||
"provider": "amazon-bedrock",
|
||||
"region_name": "us-west-2",
|
||||
"model_name": "amazon.titan-embed-text-v1",
|
||||
}
|
||||
|
||||
result = get_embedding_function(config)
|
||||
|
||||
mock_bedrock.assert_called_once_with(
|
||||
region_name="us-west-2", model_name="amazon.titan-embed-text-v1"
|
||||
)
|
||||
assert result == mock_instance
|
||||
|
||||
|
||||
def test_get_embedding_function_jina() -> None:
|
||||
"""Test Jina embedding function."""
|
||||
with patch("crewai.rag.embeddings.factory.JinaEmbeddingFunction") as mock_jina:
|
||||
mock_instance = MagicMock()
|
||||
mock_jina.return_value = mock_instance
|
||||
|
||||
config = {
|
||||
"provider": "jina",
|
||||
"api_key": "jina-key",
|
||||
"model_name": "jina-embeddings-v2-base-en",
|
||||
}
|
||||
|
||||
result = get_embedding_function(config)
|
||||
|
||||
mock_jina.assert_called_once_with(
|
||||
api_key="jina-key", model_name="jina-embeddings-v2-base-en"
|
||||
)
|
||||
assert result == mock_instance
|
||||
|
||||
|
||||
def test_get_embedding_function_unsupported_provider() -> None:
|
||||
"""Test handling of unsupported provider."""
|
||||
config = {"provider": "unsupported-provider"}
|
||||
|
||||
with pytest.raises(ValueError, match="Unsupported provider: unsupported-provider"):
|
||||
get_embedding_function(config)
|
||||
|
||||
|
||||
def test_get_embedding_function_config_modification() -> None:
|
||||
"""Test that original config dict is not modified."""
|
||||
original_config = {
|
||||
"provider": "openai",
|
||||
"api_key": "test-key",
|
||||
"model": "text-embedding-3-small",
|
||||
}
|
||||
config_copy = original_config.copy()
|
||||
|
||||
with patch("crewai.rag.embeddings.factory.OpenAIEmbeddingFunction"):
|
||||
get_embedding_function(config_copy)
|
||||
|
||||
assert config_copy == original_config
|
||||
|
||||
|
||||
def test_get_embedding_function_exclude_none_values() -> None:
|
||||
"""Test that None values are excluded from embedding function calls."""
|
||||
with patch("crewai.rag.embeddings.factory.OpenAIEmbeddingFunction") as mock_openai:
|
||||
mock_instance = MagicMock()
|
||||
mock_openai.return_value = mock_instance
|
||||
|
||||
options = EmbeddingOptions(provider="openai", api_key="test-key", model=None)
|
||||
|
||||
result = get_embedding_function(options)
|
||||
|
||||
call_kwargs = mock_openai.call_args.kwargs
|
||||
assert "api_key" in call_kwargs
|
||||
assert call_kwargs["api_key"].get_secret_value() == "test-key"
|
||||
assert "model" not in call_kwargs
|
||||
assert result == mock_instance
|
||||
|
||||
|
||||
def test_get_embedding_function_instructor() -> None:
|
||||
"""Test Instructor embedding function."""
|
||||
with patch(
|
||||
"crewai.rag.embeddings.factory.InstructorEmbeddingFunction"
|
||||
) as mock_instructor:
|
||||
mock_instance = MagicMock()
|
||||
mock_instructor.return_value = mock_instance
|
||||
|
||||
config = {"provider": "instructor", "model_name": "hkunlp/instructor-large"}
|
||||
|
||||
result = get_embedding_function(config)
|
||||
|
||||
mock_instructor.assert_called_once_with(model_name="hkunlp/instructor-large")
|
||||
assert result == mock_instance
|
||||
218
tests/rag/test_error_handling.py
Normal file
218
tests/rag/test_error_handling.py
Normal file
@@ -0,0 +1,218 @@
|
||||
"""Tests for RAG client error handling scenarios."""
|
||||
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
import pytest
|
||||
|
||||
from crewai.knowledge.storage.knowledge_storage import ( # type: ignore[import-untyped]
|
||||
KnowledgeStorage,
|
||||
)
|
||||
from crewai.memory.storage.rag_storage import RAGStorage # type: ignore[import-untyped]
|
||||
|
||||
|
||||
@patch("crewai.knowledge.storage.knowledge_storage.get_rag_client")
|
||||
def test_knowledge_storage_connection_failure(mock_get_client: MagicMock) -> None:
|
||||
"""Test KnowledgeStorage handles RAG client connection failures."""
|
||||
mock_get_client.side_effect = ConnectionError("Unable to connect to ChromaDB")
|
||||
|
||||
storage = KnowledgeStorage(collection_name="connection_test")
|
||||
|
||||
results = storage.search(["test query"])
|
||||
assert results == []
|
||||
|
||||
|
||||
@patch("crewai.knowledge.storage.knowledge_storage.get_rag_client")
|
||||
def test_knowledge_storage_search_timeout(mock_get_client: MagicMock) -> None:
|
||||
"""Test KnowledgeStorage handles search timeouts gracefully."""
|
||||
mock_client = MagicMock()
|
||||
mock_get_client.return_value = mock_client
|
||||
mock_client.search.side_effect = TimeoutError("Search operation timed out")
|
||||
|
||||
storage = KnowledgeStorage(collection_name="timeout_test")
|
||||
|
||||
results = storage.search(["test query"])
|
||||
assert results == []
|
||||
|
||||
|
||||
@patch("crewai.knowledge.storage.knowledge_storage.get_rag_client")
|
||||
def test_knowledge_storage_collection_not_found(mock_get_client: MagicMock) -> None:
|
||||
"""Test KnowledgeStorage handles missing collections."""
|
||||
mock_client = MagicMock()
|
||||
mock_get_client.return_value = mock_client
|
||||
mock_client.search.side_effect = ValueError(
|
||||
"Collection 'knowledge_missing' does not exist"
|
||||
)
|
||||
|
||||
storage = KnowledgeStorage(collection_name="missing_collection")
|
||||
|
||||
results = storage.search(["test query"])
|
||||
assert results == []
|
||||
|
||||
|
||||
@patch("crewai.knowledge.storage.knowledge_storage.get_rag_client")
|
||||
def test_knowledge_storage_invalid_embedding_config(mock_get_client: MagicMock) -> None:
|
||||
"""Test KnowledgeStorage handles invalid embedding configurations."""
|
||||
mock_get_client.return_value = MagicMock()
|
||||
|
||||
with patch(
|
||||
"crewai.knowledge.storage.knowledge_storage.get_embedding_function"
|
||||
) as mock_get_embedding:
|
||||
mock_get_embedding.side_effect = ValueError(
|
||||
"Unsupported provider: invalid_provider"
|
||||
)
|
||||
|
||||
with pytest.raises(ValueError, match="Unsupported provider: invalid_provider"):
|
||||
KnowledgeStorage(
|
||||
embedder={"provider": "invalid_provider"},
|
||||
collection_name="invalid_embedding_test",
|
||||
)
|
||||
|
||||
|
||||
@patch("crewai.memory.storage.rag_storage.get_rag_client")
|
||||
def test_memory_rag_storage_client_failure(mock_get_client: MagicMock) -> None:
|
||||
"""Test RAGStorage handles RAG client failures in memory operations."""
|
||||
mock_client = MagicMock()
|
||||
mock_get_client.return_value = mock_client
|
||||
mock_client.search.side_effect = RuntimeError("ChromaDB server error")
|
||||
|
||||
storage = RAGStorage("short_term", crew=None)
|
||||
|
||||
results = storage.search("test query")
|
||||
assert results == []
|
||||
|
||||
|
||||
@patch("crewai.memory.storage.rag_storage.get_rag_client")
|
||||
def test_memory_rag_storage_save_failure(mock_get_client: MagicMock) -> None:
|
||||
"""Test RAGStorage handles save operation failures."""
|
||||
mock_client = MagicMock()
|
||||
mock_get_client.return_value = mock_client
|
||||
mock_client.add_documents.side_effect = Exception("Failed to add documents")
|
||||
|
||||
storage = RAGStorage("long_term", crew=None)
|
||||
|
||||
storage.save("test memory", {"key": "value"})
|
||||
|
||||
|
||||
@patch("crewai.knowledge.storage.knowledge_storage.get_rag_client")
|
||||
def test_knowledge_storage_reset_readonly_database(mock_get_client: MagicMock) -> None:
|
||||
"""Test KnowledgeStorage reset handles readonly database errors."""
|
||||
mock_client = MagicMock()
|
||||
mock_get_client.return_value = mock_client
|
||||
mock_client.delete_collection.side_effect = Exception(
|
||||
"attempt to write a readonly database"
|
||||
)
|
||||
|
||||
storage = KnowledgeStorage(collection_name="readonly_test")
|
||||
|
||||
storage.reset()
|
||||
|
||||
|
||||
@patch("crewai.knowledge.storage.knowledge_storage.get_rag_client")
|
||||
def test_knowledge_storage_reset_collection_does_not_exist(
|
||||
mock_get_client: MagicMock,
|
||||
) -> None:
|
||||
"""Test KnowledgeStorage reset handles non-existent collections."""
|
||||
mock_client = MagicMock()
|
||||
mock_get_client.return_value = mock_client
|
||||
mock_client.delete_collection.side_effect = Exception("Collection does not exist")
|
||||
|
||||
storage = KnowledgeStorage(collection_name="nonexistent_test")
|
||||
|
||||
storage.reset()
|
||||
|
||||
|
||||
@patch("crewai.memory.storage.rag_storage.get_rag_client")
|
||||
def test_memory_storage_reset_failure_propagation(mock_get_client: MagicMock) -> None:
|
||||
"""Test RAGStorage reset propagates unexpected errors."""
|
||||
mock_client = MagicMock()
|
||||
mock_get_client.return_value = mock_client
|
||||
mock_client.delete_collection.side_effect = Exception("Unexpected database error")
|
||||
|
||||
storage = RAGStorage("entities", crew=None)
|
||||
|
||||
with pytest.raises(
|
||||
Exception, match="An error occurred while resetting the entities memory"
|
||||
):
|
||||
storage.reset()
|
||||
|
||||
|
||||
@patch("crewai.knowledge.storage.knowledge_storage.get_rag_client")
|
||||
def test_knowledge_storage_malformed_search_results(mock_get_client: MagicMock) -> None:
|
||||
"""Test KnowledgeStorage handles malformed search results."""
|
||||
mock_client = MagicMock()
|
||||
mock_get_client.return_value = mock_client
|
||||
mock_client.search.return_value = [
|
||||
{"content": "valid result", "metadata": {"source": "test"}},
|
||||
{"invalid": "missing content field", "metadata": {"source": "test"}},
|
||||
None,
|
||||
{"content": None, "metadata": {"source": "test"}},
|
||||
]
|
||||
|
||||
storage = KnowledgeStorage(collection_name="malformed_test")
|
||||
|
||||
results = storage.search(["test query"])
|
||||
|
||||
assert isinstance(results, list)
|
||||
assert len(results) == 4
|
||||
|
||||
|
||||
@patch("crewai.knowledge.storage.knowledge_storage.get_rag_client")
|
||||
def test_knowledge_storage_network_interruption(mock_get_client: MagicMock) -> None:
|
||||
"""Test KnowledgeStorage handles network interruptions during operations."""
|
||||
mock_client = MagicMock()
|
||||
mock_get_client.return_value = mock_client
|
||||
|
||||
mock_client.search.side_effect = [
|
||||
ConnectionError("Network interruption"),
|
||||
[{"content": "recovered result", "score": 0.8, "metadata": {"source": "test"}}],
|
||||
]
|
||||
|
||||
storage = KnowledgeStorage(collection_name="network_test")
|
||||
|
||||
first_attempt = storage.search(["test query"])
|
||||
assert first_attempt == []
|
||||
|
||||
mock_client.search.side_effect = None
|
||||
mock_client.search.return_value = [
|
||||
{"content": "recovered result", "score": 0.8, "metadata": {"source": "test"}}
|
||||
]
|
||||
|
||||
second_attempt = storage.search(["test query"])
|
||||
assert len(second_attempt) == 1
|
||||
assert second_attempt[0]["content"] == "recovered result"
|
||||
|
||||
|
||||
@patch("crewai.memory.storage.rag_storage.get_rag_client")
|
||||
def test_memory_storage_collection_creation_failure(mock_get_client: MagicMock) -> None:
|
||||
"""Test RAGStorage handles collection creation failures."""
|
||||
mock_client = MagicMock()
|
||||
mock_get_client.return_value = mock_client
|
||||
mock_client.get_or_create_collection.side_effect = Exception(
|
||||
"Failed to create collection"
|
||||
)
|
||||
|
||||
storage = RAGStorage("user_memory", crew=None)
|
||||
|
||||
storage.save("test data", {"metadata": "test"})
|
||||
|
||||
|
||||
@patch("crewai.knowledge.storage.knowledge_storage.get_rag_client")
|
||||
def test_knowledge_storage_embedding_dimension_mismatch_detailed(
|
||||
mock_get_client: MagicMock,
|
||||
) -> None:
|
||||
"""Test detailed handling of embedding dimension mismatch errors."""
|
||||
mock_client = MagicMock()
|
||||
mock_get_client.return_value = mock_client
|
||||
mock_client.get_or_create_collection.return_value = None
|
||||
mock_client.add_documents.side_effect = Exception(
|
||||
"Embedding dimension mismatch: expected 384, got 1536"
|
||||
)
|
||||
|
||||
storage = KnowledgeStorage(collection_name="dimension_detailed_test")
|
||||
|
||||
with pytest.raises(ValueError) as exc_info:
|
||||
storage.save(["test document"])
|
||||
|
||||
assert "Embedding dimension mismatch" in str(exc_info.value)
|
||||
assert "Make sure you're using the same embedding model" in str(exc_info.value)
|
||||
assert "crewai reset-memories -a" in str(exc_info.value)
|
||||
Reference in New Issue
Block a user