refactor: unify rag storage with instance-specific client support (#3455)
Some checks failed
Notify Downstream / notify-downstream (push) Has been cancelled
Update Test Durations / update-durations (3.10) (push) Has been cancelled
Update Test Durations / update-durations (3.11) (push) Has been cancelled
Update Test Durations / update-durations (3.12) (push) Has been cancelled
Update Test Durations / update-durations (3.13) (push) Has been cancelled
Build uv cache / build-cache (3.10) (push) Has been cancelled
Build uv cache / build-cache (3.11) (push) Has been cancelled
Build uv cache / build-cache (3.12) (push) Has been cancelled
Build uv cache / build-cache (3.13) (push) Has been cancelled

- ignore line length errors globally
- migrate knowledge/memory and crew query_knowledge to `SearchResult`
- remove legacy chromadb utils; fix empty metadata handling
- restore openai as default embedding provider; support instance-specific clients
- update and fix tests for `SearchResult` migration and rag changes
This commit is contained in:
Greyson LaLonde
2025-09-17 14:46:54 -04:00
committed by GitHub
parent 81bd81e5f5
commit f28e78c5ba
30 changed files with 1956 additions and 976 deletions

View File

@@ -1,7 +1,6 @@
"""Test Knowledge creation and querying functionality."""
from pathlib import Path
from typing import List, Union
from unittest.mock import patch
import pytest
@@ -23,7 +22,7 @@ def mock_vector_db():
instance = mock.return_value
instance.query.return_value = [
{
"context": "Brandon's favorite color is blue and he likes Mexican food.",
"content": "Brandon's favorite color is blue and he likes Mexican food.",
"score": 0.9,
}
]
@@ -44,13 +43,13 @@ def test_single_short_string(mock_vector_db):
content=content, metadata={"preference": "personal"}
)
mock_vector_db.sources = [string_source]
mock_vector_db.query.return_value = [{"context": content, "score": 0.9}]
mock_vector_db.query.return_value = [{"content": content, "score": 0.9}]
# Perform a query
query = "What is Brandon's favorite color?"
results = mock_vector_db.query(query)
# Assert that the results contain the expected information
assert any("blue" in result["context"].lower() for result in results)
assert any("blue" in result["content"].lower() for result in results)
# Verify the mock was called
mock_vector_db.query.assert_called_once()
@@ -84,14 +83,14 @@ def test_single_2k_character_string(mock_vector_db):
content=content, metadata={"preference": "personal"}
)
mock_vector_db.sources = [string_source]
mock_vector_db.query.return_value = [{"context": content, "score": 0.9}]
mock_vector_db.query.return_value = [{"content": content, "score": 0.9}]
# Perform a query
query = "What is Brandon's favorite movie?"
results = mock_vector_db.query(query)
# Assert that the results contain the expected information
assert any("inception" in result["context"].lower() for result in results)
assert any("inception" in result["content"].lower() for result in results)
mock_vector_db.query.assert_called_once()
@@ -109,7 +108,7 @@ def test_multiple_short_strings(mock_vector_db):
# Mock the vector db query response
mock_vector_db.query.return_value = [
{"context": "Brandon has a dog named Max.", "score": 0.9}
{"content": "Brandon has a dog named Max.", "score": 0.9}
]
mock_vector_db.sources = string_sources
@@ -119,7 +118,7 @@ def test_multiple_short_strings(mock_vector_db):
results = mock_vector_db.query(query)
# Assert that the correct information is retrieved
assert any("max" in result["context"].lower() for result in results)
assert any("max" in result["content"].lower() for result in results)
# Verify the mock was called
mock_vector_db.query.assert_called_once()
@@ -180,7 +179,7 @@ def test_multiple_2k_character_strings(mock_vector_db):
]
mock_vector_db.sources = string_sources
mock_vector_db.query.return_value = [{"context": contents[1], "score": 0.9}]
mock_vector_db.query.return_value = [{"content": contents[1], "score": 0.9}]
# Perform a query
query = "What is Brandon's favorite book?"
@@ -188,7 +187,7 @@ def test_multiple_2k_character_strings(mock_vector_db):
# Assert that the correct information is retrieved
assert any(
"the hitchhiker's guide to the galaxy" in result["context"].lower()
"the hitchhiker's guide to the galaxy" in result["content"].lower()
for result in results
)
mock_vector_db.query.assert_called_once()
@@ -205,13 +204,13 @@ def test_single_short_file(mock_vector_db, tmpdir):
file_paths=[file_path], metadata={"preference": "personal"}
)
mock_vector_db.sources = [file_source]
mock_vector_db.query.return_value = [{"context": content, "score": 0.9}]
mock_vector_db.query.return_value = [{"content": content, "score": 0.9}]
# Perform a query
query = "What sport does Brandon like?"
results = mock_vector_db.query(query)
# Assert that the results contain the expected information
assert any("basketball" in result["context"].lower() for result in results)
assert any("basketball" in result["content"].lower() for result in results)
mock_vector_db.query.assert_called_once()
@@ -247,13 +246,13 @@ def test_single_2k_character_file(mock_vector_db, tmpdir):
file_paths=[file_path], metadata={"preference": "personal"}
)
mock_vector_db.sources = [file_source]
mock_vector_db.query.return_value = [{"context": content, "score": 0.9}]
mock_vector_db.query.return_value = [{"content": content, "score": 0.9}]
# Perform a query
query = "What is Brandon's favorite movie?"
results = mock_vector_db.query(query)
# Assert that the results contain the expected information
assert any("inception" in result["context"].lower() for result in results)
assert any("inception" in result["content"].lower() for result in results)
mock_vector_db.query.assert_called_once()
@@ -286,13 +285,13 @@ def test_multiple_short_files(mock_vector_db, tmpdir):
]
mock_vector_db.sources = file_sources
mock_vector_db.query.return_value = [
{"context": "Brandon lives in New York.", "score": 0.9}
{"content": "Brandon lives in New York.", "score": 0.9}
]
# Perform a query
query = "What city does he reside in?"
results = mock_vector_db.query(query)
# Assert that the correct information is retrieved
assert any("new york" in result["context"].lower() for result in results)
assert any("new york" in result["content"].lower() for result in results)
mock_vector_db.query.assert_called_once()
@@ -360,7 +359,7 @@ def test_multiple_2k_character_files(mock_vector_db, tmpdir):
mock_vector_db.sources = file_sources
mock_vector_db.query.return_value = [
{
"context": "Brandon's favorite book is 'The Hitchhiker's Guide to the Galaxy'.",
"content": "Brandon's favorite book is 'The Hitchhiker's Guide to the Galaxy'.",
"score": 0.9,
}
]
@@ -370,7 +369,7 @@ def test_multiple_2k_character_files(mock_vector_db, tmpdir):
# Assert that the correct information is retrieved
assert any(
"the hitchhiker's guide to the galaxy" in result["context"].lower()
"the hitchhiker's guide to the galaxy" in result["content"].lower()
for result in results
)
mock_vector_db.query.assert_called_once()
@@ -407,14 +406,14 @@ def test_hybrid_string_and_files(mock_vector_db, tmpdir):
# Combine string and file sources
mock_vector_db.sources = string_sources + file_sources
mock_vector_db.query.return_value = [{"context": file_contents[1], "score": 0.9}]
mock_vector_db.query.return_value = [{"content": file_contents[1], "score": 0.9}]
# Perform a query
query = "What is Brandon's favorite book?"
results = mock_vector_db.query(query)
# Assert that the correct information is retrieved
assert any("the alchemist" in result["context"].lower() for result in results)
assert any("the alchemist" in result["content"].lower() for result in results)
mock_vector_db.query.assert_called_once()
@@ -430,7 +429,7 @@ def test_pdf_knowledge_source(mock_vector_db):
)
mock_vector_db.sources = [pdf_source]
mock_vector_db.query.return_value = [
{"context": "crewai create crew latest-ai-development", "score": 0.9}
{"content": "crewai create crew latest-ai-development", "score": 0.9}
]
# Perform a query
@@ -439,7 +438,7 @@ def test_pdf_knowledge_source(mock_vector_db):
# Assert that the correct information is retrieved
assert any(
"crewai create crew latest-ai-development" in result["context"].lower()
"crewai create crew latest-ai-development" in result["content"].lower()
for result in results
)
mock_vector_db.query.assert_called_once()
@@ -467,7 +466,7 @@ def test_csv_knowledge_source(mock_vector_db, tmpdir):
)
mock_vector_db.sources = [csv_source]
mock_vector_db.query.return_value = [
{"context": "Brandon is 30 years old.", "score": 0.9}
{"content": "Brandon is 30 years old.", "score": 0.9}
]
# Perform a query
@@ -475,7 +474,7 @@ def test_csv_knowledge_source(mock_vector_db, tmpdir):
results = mock_vector_db.query(query)
# Assert that the correct information is retrieved
assert any("30" in result["context"] for result in results)
assert any("30" in result["content"] for result in results)
mock_vector_db.query.assert_called_once()
@@ -502,7 +501,7 @@ def test_json_knowledge_source(mock_vector_db, tmpdir):
)
mock_vector_db.sources = [json_source]
mock_vector_db.query.return_value = [
{"context": "Alice lives in Los Angeles.", "score": 0.9}
{"content": "Alice lives in Los Angeles.", "score": 0.9}
]
# Perform a query
@@ -510,7 +509,7 @@ def test_json_knowledge_source(mock_vector_db, tmpdir):
results = mock_vector_db.query(query)
# Assert that the correct information is retrieved
assert any("los angeles" in result["context"].lower() for result in results)
assert any("los angeles" in result["content"].lower() for result in results)
mock_vector_db.query.assert_called_once()
@@ -518,7 +517,7 @@ def test_excel_knowledge_source(mock_vector_db, tmpdir):
"""Test ExcelKnowledgeSource with a simple Excel file."""
# Create an Excel file with sample data
import pandas as pd
import pandas as pd # type: ignore[import-untyped]
excel_data = {
"Name": ["Brandon", "Alice", "Bob"],
@@ -535,7 +534,7 @@ def test_excel_knowledge_source(mock_vector_db, tmpdir):
)
mock_vector_db.sources = [excel_source]
mock_vector_db.query.return_value = [
{"context": "Brandon is 30 years old.", "score": 0.9}
{"content": "Brandon is 30 years old.", "score": 0.9}
]
# Perform a query
@@ -543,7 +542,7 @@ def test_excel_knowledge_source(mock_vector_db, tmpdir):
results = mock_vector_db.query(query)
# Assert that the correct information is retrieved
assert any("30" in result["context"] for result in results)
assert any("30" in result["content"] for result in results)
mock_vector_db.query.assert_called_once()
@@ -557,20 +556,20 @@ def test_docling_source(mock_vector_db):
mock_vector_db.sources = [docling_source]
mock_vector_db.query.return_value = [
{
"context": "Reward hacking is a technique used to improve the performance of reinforcement learning agents.",
"content": "Reward hacking is a technique used to improve the performance of reinforcement learning agents.",
"score": 0.9,
}
]
# Perform a query
query = "What is reward hacking?"
results = mock_vector_db.query(query)
assert any("reward hacking" in result["context"].lower() for result in results)
assert any("reward hacking" in result["content"].lower() for result in results)
mock_vector_db.query.assert_called_once()
@pytest.mark.vcr
def test_multiple_docling_sources():
urls: List[Union[Path, str]] = [
def test_multiple_docling_sources() -> None:
urls: list[Path | str] = [
"https://lilianweng.github.io/posts/2024-11-28-reward-hacking/",
"https://lilianweng.github.io/posts/2024-07-07-hallucination/",
]

View File

@@ -0,0 +1,191 @@
"""Tests for Knowledge SearchResult type conversion and integration."""
from typing import Any
from unittest.mock import MagicMock, patch
import pytest
from crewai.knowledge.knowledge import Knowledge # type: ignore[import-untyped]
from crewai.knowledge.source.string_knowledge_source import ( # type: ignore[import-untyped]
StringKnowledgeSource,
)
from crewai.knowledge.utils.knowledge_utils import ( # type: ignore[import-untyped]
extract_knowledge_context,
)
def test_knowledge_query_returns_searchresult() -> None:
"""Test that Knowledge.query returns SearchResult format."""
with patch("crewai.knowledge.knowledge.KnowledgeStorage") as mock_storage_class:
mock_storage = MagicMock()
mock_storage_class.return_value = mock_storage
mock_storage.search.return_value = [
{
"content": "AI is fascinating",
"score": 0.9,
"metadata": {"source": "doc1"},
},
{
"content": "Machine learning rocks",
"score": 0.8,
"metadata": {"source": "doc2"},
},
]
sources = [StringKnowledgeSource(content="Test knowledge content")]
knowledge = Knowledge(collection_name="test_collection", sources=sources)
results = knowledge.query(
["AI technology"], results_limit=5, score_threshold=0.3
)
mock_storage.search.assert_called_once_with(
["AI technology"], limit=5, score_threshold=0.3
)
assert isinstance(results, list)
assert len(results) == 2
for result in results:
assert isinstance(result, dict)
assert "content" in result
assert "score" in result
assert "metadata" in result
assert results[0]["content"] == "AI is fascinating"
assert results[0]["score"] == 0.9
assert results[1]["content"] == "Machine learning rocks"
assert results[1]["score"] == 0.8
def test_knowledge_query_with_empty_results() -> None:
"""Test Knowledge.query with empty search results."""
with patch("crewai.knowledge.knowledge.KnowledgeStorage") as mock_storage_class:
mock_storage = MagicMock()
mock_storage_class.return_value = mock_storage
mock_storage.search.return_value = []
sources = [StringKnowledgeSource(content="Test content")]
knowledge = Knowledge(collection_name="empty_test", sources=sources)
results = knowledge.query(["nonexistent query"])
assert isinstance(results, list)
assert len(results) == 0
def test_extract_knowledge_context_with_searchresult() -> None:
"""Test extract_knowledge_context works with SearchResult format."""
search_results = [
{"content": "Python is great for AI", "score": 0.95, "metadata": {}},
{"content": "Machine learning algorithms", "score": 0.88, "metadata": {}},
{"content": "Deep learning frameworks", "score": 0.82, "metadata": {}},
]
context = extract_knowledge_context(search_results)
assert "Additional Information:" in context
assert "Python is great for AI" in context
assert "Machine learning algorithms" in context
assert "Deep learning frameworks" in context
expected_content = (
"Python is great for AI\nMachine learning algorithms\nDeep learning frameworks"
)
assert expected_content in context
def test_extract_knowledge_context_with_empty_content() -> None:
"""Test extract_knowledge_context handles empty or invalid content."""
search_results = [
{"content": "", "score": 0.5, "metadata": {}},
{"content": None, "score": 0.4, "metadata": {}},
{"score": 0.3, "metadata": {}},
]
context = extract_knowledge_context(search_results)
assert context == ""
def test_extract_knowledge_context_filters_invalid_results() -> None:
"""Test that extract_knowledge_context filters out invalid results."""
search_results: list[dict[str, Any] | None] = [
{"content": "Valid content 1", "score": 0.9, "metadata": {}},
{"content": "", "score": 0.8, "metadata": {}},
{"content": "Valid content 2", "score": 0.7, "metadata": {}},
None,
{"content": None, "score": 0.6, "metadata": {}},
]
context = extract_knowledge_context(search_results)
assert "Additional Information:" in context
assert "Valid content 1" in context
assert "Valid content 2" in context
assert context.count("\n") == 1
@patch("crewai.rag.config.utils.get_rag_client")
@patch("crewai.knowledge.storage.knowledge_storage.KnowledgeStorage")
def test_knowledge_storage_exception_handling(
mock_storage_class: MagicMock, mock_get_client: MagicMock
) -> None:
"""Test Knowledge handles storage exceptions gracefully."""
mock_storage = MagicMock()
mock_storage_class.return_value = mock_storage
mock_storage.search.side_effect = Exception("Storage error")
sources = [StringKnowledgeSource(content="Test content")]
knowledge = Knowledge(collection_name="error_test", sources=sources)
with pytest.raises(ValueError, match="Storage is not initialized"):
knowledge.storage = None
knowledge.query(["test query"])
def test_knowledge_add_sources_integration() -> None:
"""Test Knowledge.add_sources integrates properly with storage."""
with patch("crewai.knowledge.knowledge.KnowledgeStorage") as mock_storage_class:
mock_storage = MagicMock()
mock_storage_class.return_value = mock_storage
sources = [
StringKnowledgeSource(content="Content 1"),
StringKnowledgeSource(content="Content 2"),
]
knowledge = Knowledge(collection_name="add_sources_test", sources=sources)
knowledge.add_sources()
for source in sources:
assert source.storage == mock_storage
def test_knowledge_reset_integration() -> None:
"""Test Knowledge.reset integrates with storage."""
with patch("crewai.knowledge.knowledge.KnowledgeStorage") as mock_storage_class:
mock_storage = MagicMock()
mock_storage_class.return_value = mock_storage
sources = [StringKnowledgeSource(content="Test content")]
knowledge = Knowledge(collection_name="reset_test", sources=sources)
knowledge.reset()
mock_storage.reset.assert_called_once()
@patch("crewai.rag.config.utils.get_rag_client")
@patch("crewai.knowledge.storage.knowledge_storage.KnowledgeStorage")
def test_knowledge_reset_without_storage(
mock_storage_class: MagicMock, mock_get_client: MagicMock
) -> None:
"""Test Knowledge.reset raises error when storage is None."""
sources = [StringKnowledgeSource(content="Test content")]
knowledge = Knowledge(collection_name="no_storage_test", sources=sources)
knowledge.storage = None
with pytest.raises(ValueError, match="Storage is not initialized"):
knowledge.reset()

View File

@@ -0,0 +1,196 @@
"""Integration tests for KnowledgeStorage RAG client migration."""
from unittest.mock import MagicMock, patch
import pytest
from crewai.knowledge.storage.knowledge_storage import ( # type: ignore[import-untyped]
KnowledgeStorage,
)
@patch("crewai.knowledge.storage.knowledge_storage.get_rag_client")
@patch("crewai.knowledge.storage.knowledge_storage.create_client")
@patch("crewai.knowledge.storage.knowledge_storage.get_embedding_function")
def test_knowledge_storage_uses_rag_client(
mock_get_embedding: MagicMock,
mock_create_client: MagicMock,
mock_get_client: MagicMock,
) -> None:
"""Test that KnowledgeStorage properly integrates with RAG client."""
mock_client = MagicMock()
mock_create_client.return_value = mock_client
mock_get_client.return_value = mock_client
mock_client.search.return_value = [
{"content": "test content", "score": 0.9, "metadata": {"source": "test"}}
]
embedder_config = {"provider": "openai", "model": "text-embedding-3-small"}
storage = KnowledgeStorage(
embedder=embedder_config, collection_name="test_knowledge"
)
mock_create_client.assert_called_once()
results = storage.search(["test query"], limit=5, score_threshold=0.3)
mock_get_client.assert_not_called()
mock_client.search.assert_called_once_with(
collection_name="knowledge_test_knowledge",
query="test query",
limit=5,
metadata_filter=None,
score_threshold=0.3,
)
assert isinstance(results, list)
assert len(results) == 1
assert isinstance(results[0], dict)
assert "content" in results[0]
@patch("crewai.knowledge.storage.knowledge_storage.get_rag_client")
def test_collection_name_prefixing(mock_get_client: MagicMock) -> None:
"""Test that collection names are properly prefixed."""
mock_client = MagicMock()
mock_get_client.return_value = mock_client
mock_client.search.return_value = []
storage = KnowledgeStorage(collection_name="custom_knowledge")
storage.search(["test"], limit=1)
mock_client.search.assert_called_once()
call_kwargs = mock_client.search.call_args.kwargs
assert call_kwargs["collection_name"] == "knowledge_custom_knowledge"
mock_client.reset_mock()
storage_default = KnowledgeStorage()
storage_default.search(["test"], limit=1)
call_kwargs = mock_client.search.call_args.kwargs
assert call_kwargs["collection_name"] == "knowledge"
@patch("crewai.knowledge.storage.knowledge_storage.get_rag_client")
def test_save_documents_integration(mock_get_client: MagicMock) -> None:
"""Test document saving through RAG client."""
mock_client = MagicMock()
mock_get_client.return_value = mock_client
storage = KnowledgeStorage(collection_name="test_docs")
documents = ["Document 1 content", "Document 2 content"]
storage.save(documents)
mock_client.get_or_create_collection.assert_called_once_with(
collection_name="knowledge_test_docs"
)
mock_client.add_documents.assert_called_once()
call_kwargs = mock_client.add_documents.call_args.kwargs
added_docs = call_kwargs["documents"]
assert len(added_docs) == 2
assert added_docs[0]["content"] == "Document 1 content"
assert added_docs[1]["content"] == "Document 2 content"
@patch("crewai.knowledge.storage.knowledge_storage.get_rag_client")
def test_reset_integration(mock_get_client: MagicMock) -> None:
"""Test collection reset through RAG client."""
mock_client = MagicMock()
mock_get_client.return_value = mock_client
storage = KnowledgeStorage(collection_name="test_reset")
storage.reset()
mock_client.delete_collection.assert_called_once_with(
collection_name="knowledge_test_reset"
)
@patch("crewai.knowledge.storage.knowledge_storage.get_rag_client")
def test_search_error_handling(mock_get_client: MagicMock) -> None:
"""Test error handling during search operations."""
mock_client = MagicMock()
mock_get_client.return_value = mock_client
mock_client.search.side_effect = Exception("RAG client error")
storage = KnowledgeStorage(collection_name="error_test")
results = storage.search(["test query"])
assert results == []
@patch("crewai.knowledge.storage.knowledge_storage.get_rag_client")
@patch("crewai.knowledge.storage.knowledge_storage.get_embedding_function")
def test_embedding_configuration_flow(
mock_get_embedding: MagicMock, mock_get_client: MagicMock
) -> None:
"""Test that embedding configuration flows properly to RAG client."""
mock_embedding_func = MagicMock()
mock_get_embedding.return_value = mock_embedding_func
mock_get_client.return_value = MagicMock()
embedder_config = {
"provider": "sentence-transformer",
"model_name": "all-MiniLM-L6-v2",
}
KnowledgeStorage(embedder=embedder_config, collection_name="embedding_test")
mock_get_embedding.assert_called_once_with(embedder_config)
@patch("crewai.knowledge.storage.knowledge_storage.get_rag_client")
def test_query_list_conversion(mock_get_client: MagicMock) -> None:
"""Test that query list is properly converted to string."""
mock_client = MagicMock()
mock_get_client.return_value = mock_client
mock_client.search.return_value = []
storage = KnowledgeStorage()
storage.search(["single query"])
call_kwargs = mock_client.search.call_args.kwargs
assert call_kwargs["query"] == "single query"
mock_client.reset_mock()
storage.search(["query one", "query two"])
call_kwargs = mock_client.search.call_args.kwargs
assert call_kwargs["query"] == "query one query two"
@patch("crewai.knowledge.storage.knowledge_storage.get_rag_client")
def test_metadata_filter_handling(mock_get_client: MagicMock) -> None:
"""Test metadata filter parameter handling."""
mock_client = MagicMock()
mock_get_client.return_value = mock_client
mock_client.search.return_value = []
storage = KnowledgeStorage()
metadata_filter = {"category": "technical", "priority": "high"}
storage.search(["test"], metadata_filter=metadata_filter)
call_kwargs = mock_client.search.call_args.kwargs
assert call_kwargs["metadata_filter"] == metadata_filter
mock_client.reset_mock()
storage.search(["test"], metadata_filter=None)
call_kwargs = mock_client.search.call_args.kwargs
assert call_kwargs["metadata_filter"] is None
@patch("crewai.knowledge.storage.knowledge_storage.get_rag_client")
def test_dimension_mismatch_error_handling(mock_get_client: MagicMock) -> None:
"""Test specific handling of dimension mismatch errors."""
mock_client = MagicMock()
mock_get_client.return_value = mock_client
mock_client.get_or_create_collection.return_value = None
mock_client.add_documents.side_effect = Exception("dimension mismatch detected")
storage = KnowledgeStorage(collection_name="dimension_test")
with pytest.raises(ValueError, match="Embedding dimension mismatch"):
storage.save(["test document"])