refactor: simplify rag client initialization (#3401)

* Simplified Qdrant and ChromaDB client initialization
* Refactored factory structure and updated tests accordingly
This commit is contained in:
Greyson LaLonde
2025-08-26 08:54:51 -04:00
committed by GitHub
parent 869bb115c8
commit 4b4a119a9f
9 changed files with 44 additions and 57 deletions

View File

@@ -27,18 +27,20 @@ def mock_async_chromadb_client():
@pytest.fixture
def client(mock_chromadb_client) -> ChromaDBClient:
"""Create a ChromaDBClient instance for testing."""
client = ChromaDBClient()
client.client = mock_chromadb_client
client.embedding_function = Mock()
mock_embedding = Mock()
client = ChromaDBClient(
client=mock_chromadb_client, embedding_function=mock_embedding
)
return client
@pytest.fixture
def async_client(mock_async_chromadb_client) -> ChromaDBClient:
"""Create a ChromaDBClient instance with async client for testing."""
client = ChromaDBClient()
client.client = mock_async_chromadb_client
client.embedding_function = Mock()
mock_embedding = Mock()
client = ChromaDBClient(
client=mock_async_chromadb_client, embedding_function=mock_embedding
)
return client

View File

@@ -2,7 +2,7 @@
from unittest.mock import Mock, patch
from crewai.rag.config.factory import create_client
from crewai.rag.factory import create_client
def test_create_client_chromadb():
@@ -10,7 +10,7 @@ def test_create_client_chromadb():
mock_config = Mock()
mock_config.provider = "chromadb"
with patch("crewai.rag.config.factory.require") as mock_require:
with patch("crewai.rag.factory.require") as mock_require:
mock_module = Mock()
mock_client = Mock()
mock_module.create_client.return_value = mock_client

View File

@@ -236,7 +236,6 @@ class TestQdrantClient:
# Check upsert was called with correct parameters
call_args = mock_qdrant_client.upsert.call_args
assert call_args.kwargs["collection_name"] == "test_collection"
assert call_args.kwargs["wait"] is True
assert len(call_args.kwargs["points"]) == 1
point = call_args.kwargs["points"][0]
assert point.vector == [0.1, 0.2, 0.3]
@@ -330,7 +329,6 @@ class TestQdrantClient:
# Check upsert was called with correct parameters
call_args = mock_async_qdrant_client.upsert.call_args
assert call_args.kwargs["collection_name"] == "test_collection"
assert call_args.kwargs["wait"] is True
assert len(call_args.kwargs["points"]) == 1
point = call_args.kwargs["points"][0]
assert point.vector == [0.1, 0.2, 0.3]