refactor: simplify rag client initialization (#3401)

* Simplified Qdrant and ChromaDB client initialization
* Refactored factory structure and updated tests accordingly
This commit is contained in:
Greyson LaLonde
2025-08-26 08:54:51 -04:00
committed by GitHub
parent 869bb115c8
commit 4b4a119a9f
9 changed files with 44 additions and 57 deletions

View File

@@ -40,8 +40,19 @@ class ChromaDBClient(BaseClient):
embedding_function: Function to generate embeddings for documents.
"""
client: ChromaDBClientType
embedding_function: ChromaEmbeddingFunction[Embeddable]
def __init__(
    self,
    client: ChromaDBClientType,
    embedding_function: ChromaEmbeddingFunction[Embeddable],
) -> None:
    """Store the supplied ChromaDB client and embedding function.

    Both dependencies are required and are kept on the instance unchanged;
    nothing is constructed or defaulted here.

    Args:
        client: Pre-configured ChromaDB client instance.
        embedding_function: Embedding function for text to vector conversion.
    """
    # Single tuple assignment: both values are read first, then bound in
    # left-to-right order (client, then embedding_function).
    self.client, self.embedding_function = client, embedding_function
def create_collection(
self, **kwargs: Unpack[ChromaDBCollectionCreateParams]

View File

@@ -15,12 +15,10 @@ def create_client(config: ChromaDBConfig) -> ChromaDBClient:
Returns:
Configured ChromaDBClient instance.
"""
chromadb_client = Client(
return ChromaDBClient(
client=Client(
settings=config.settings, tenant=config.tenant, database=config.database
),
embedding_function=config.embedding_function,
)
client = ChromaDBClient()
client.client = chromadb_client
client.embedding_function = config.embedding_function
return client

View File

@@ -11,7 +11,7 @@ from crewai.rag.config.constants import (
DEFAULT_RAG_CONFIG_PATH,
DEFAULT_RAG_CONFIG_CLASS,
)
from crewai.rag.config.factory import create_client
from crewai.rag.factory import create_client
class RagContext(BaseModel):

View File

@@ -2,8 +2,6 @@
from typing import Any, cast
from fastembed import TextEmbedding
from qdrant_client import QdrantClient as SyncQdrantClientBase
from typing_extensions import Unpack
from crewai.rag.core.base_client import (
@@ -16,7 +14,6 @@ from crewai.rag.core.exceptions import ClientMethodMismatchError
from crewai.rag.qdrant.types import (
AsyncEmbeddingFunction,
EmbeddingFunction,
QdrantClientParams,
QdrantClientType,
QdrantCollectionCreateParams,
)
@@ -43,40 +40,19 @@ class QdrantClient(BaseClient):
embedding_function: Function to generate embeddings for documents.
"""
client: QdrantClientType
embedding_function: EmbeddingFunction | AsyncEmbeddingFunction
def __init__(
self,
client: QdrantClientType | None = None,
embedding_function: EmbeddingFunction | AsyncEmbeddingFunction | None = None,
**kwargs: Unpack[QdrantClientParams],
client: QdrantClientType,
embedding_function: EmbeddingFunction | AsyncEmbeddingFunction,
) -> None:
"""Initialize QdrantClient with optional client and embedding function.
"""Initialize QdrantClient with client and embedding function.
Args:
client: Optional pre-configured Qdrant client instance.
embedding_function: Optional embedding function. If not provided,
uses FastEmbed's BAAI/bge-small-en-v1.5 model.
**kwargs: Additional arguments for QdrantClient creation.
client: Pre-configured Qdrant client instance.
embedding_function: Embedding function for text to vector conversion.
"""
if client is not None:
self.client = client
else:
location = kwargs.get("location", ":memory:")
client_kwargs = {k: v for k, v in kwargs.items() if k != "location"}
self.client = SyncQdrantClientBase(location, **cast(Any, client_kwargs))
if embedding_function is not None:
self.embedding_function = embedding_function
else:
_embedder = TextEmbedding("BAAI/bge-small-en-v1.5")
def _embed_fn(text: str) -> list[float]:
embeddings = list(_embedder.embed([text]))
return [float(x) for x in embeddings[0]] if embeddings else []
self.embedding_function = _embed_fn
def create_collection(self, **kwargs: Unpack[QdrantCollectionCreateParams]) -> None:
"""Create a new collection in Qdrant.
@@ -284,7 +260,7 @@ class QdrantClient(BaseClient):
point = _create_point_from_document(doc, embedding)
points.append(point)
self.client.upsert(collection_name=collection_name, points=points, wait=True)
self.client.upsert(collection_name=collection_name, points=points)
async def aadd_documents(self, **kwargs: Unpack[BaseCollectionAddParams]) -> None:
"""Add documents with their embeddings to a collection asynchronously.
@@ -325,9 +301,7 @@ class QdrantClient(BaseClient):
point = _create_point_from_document(doc, embedding)
points.append(point)
await self.client.upsert(
collection_name=collection_name, points=points, wait=True
)
await self.client.upsert(collection_name=collection_name, points=points)
def search(
self, **kwargs: Unpack[BaseCollectionSearchParams]

View File

@@ -86,7 +86,11 @@ class AsyncEmbeddingFunction(Protocol):
class QdrantClientParams(TypedDict, total=False):
"""Parameters for QdrantClient initialization."""
"""Parameters for QdrantClient initialization.
Notes:
Need to implement in factory or remove.
"""
location: str | None
url: str | None

View File

@@ -27,18 +27,20 @@ def mock_async_chromadb_client():
@pytest.fixture
def client(mock_chromadb_client) -> ChromaDBClient:
"""Create a ChromaDBClient instance for testing."""
client = ChromaDBClient()
client.client = mock_chromadb_client
client.embedding_function = Mock()
mock_embedding = Mock()
client = ChromaDBClient(
client=mock_chromadb_client, embedding_function=mock_embedding
)
return client
@pytest.fixture
def async_client(mock_async_chromadb_client) -> ChromaDBClient:
"""Create a ChromaDBClient instance with async client for testing."""
client = ChromaDBClient()
client.client = mock_async_chromadb_client
client.embedding_function = Mock()
mock_embedding = Mock()
client = ChromaDBClient(
client=mock_async_chromadb_client, embedding_function=mock_embedding
)
return client

View File

@@ -2,7 +2,7 @@
from unittest.mock import Mock, patch
from crewai.rag.config.factory import create_client
from crewai.rag.factory import create_client
def test_create_client_chromadb():
@@ -10,7 +10,7 @@ def test_create_client_chromadb():
mock_config = Mock()
mock_config.provider = "chromadb"
with patch("crewai.rag.config.factory.require") as mock_require:
with patch("crewai.rag.factory.require") as mock_require:
mock_module = Mock()
mock_client = Mock()
mock_module.create_client.return_value = mock_client

View File

@@ -236,7 +236,6 @@ class TestQdrantClient:
# Check upsert was called with correct parameters
call_args = mock_qdrant_client.upsert.call_args
assert call_args.kwargs["collection_name"] == "test_collection"
assert call_args.kwargs["wait"] is True
assert len(call_args.kwargs["points"]) == 1
point = call_args.kwargs["points"][0]
assert point.vector == [0.1, 0.2, 0.3]
@@ -330,7 +329,6 @@ class TestQdrantClient:
# Check upsert was called with correct parameters
call_args = mock_async_qdrant_client.upsert.call_args
assert call_args.kwargs["collection_name"] == "test_collection"
assert call_args.kwargs["wait"] is True
assert len(call_args.kwargs["points"]) == 1
point = call_args.kwargs["points"][0]
assert point.vector == [0.1, 0.2, 0.3]