refactor: simplify rag client initialization (#3401)

* Simplified Qdrant and ChromaDB client initialization
* Refactored factory structure and updated tests accordingly
This commit is contained in:
Greyson LaLonde
2025-08-26 08:54:51 -04:00
committed by GitHub
parent 869bb115c8
commit 4b4a119a9f
9 changed files with 44 additions and 57 deletions

View File

@@ -40,8 +40,19 @@ class ChromaDBClient(BaseClient):
embedding_function: Function to generate embeddings for documents.
"""
client: ChromaDBClientType
embedding_function: ChromaEmbeddingFunction[Embeddable]
def __init__(
self,
client: ChromaDBClientType,
embedding_function: ChromaEmbeddingFunction[Embeddable],
) -> None:
"""Initialize ChromaDBClient with a client and an embedding function.

Both arguments are required and are stored on the instance unchanged.

Args:
client: Pre-configured ChromaDB client instance.
embedding_function: Embedding function for text to vector conversion.
"""
self.client = client
self.embedding_function = embedding_function
def create_collection(
self, **kwargs: Unpack[ChromaDBCollectionCreateParams]

View File

@@ -15,12 +15,10 @@ def create_client(config: ChromaDBConfig) -> ChromaDBClient:
Returns:
Configured ChromaDBClient instance.
"""
chromadb_client = Client(
settings=config.settings, tenant=config.tenant, database=config.database
return ChromaDBClient(
client=Client(
settings=config.settings, tenant=config.tenant, database=config.database
),
embedding_function=config.embedding_function,
)
client = ChromaDBClient()
client.client = chromadb_client
client.embedding_function = config.embedding_function
return client

View File

@@ -11,7 +11,7 @@ from crewai.rag.config.constants import (
DEFAULT_RAG_CONFIG_PATH,
DEFAULT_RAG_CONFIG_CLASS,
)
from crewai.rag.config.factory import create_client
from crewai.rag.factory import create_client
class RagContext(BaseModel):

View File

@@ -2,8 +2,6 @@
from typing import Any, cast
from fastembed import TextEmbedding
from qdrant_client import QdrantClient as SyncQdrantClientBase
from typing_extensions import Unpack
from crewai.rag.core.base_client import (
@@ -16,7 +14,6 @@ from crewai.rag.core.exceptions import ClientMethodMismatchError
from crewai.rag.qdrant.types import (
AsyncEmbeddingFunction,
EmbeddingFunction,
QdrantClientParams,
QdrantClientType,
QdrantCollectionCreateParams,
)
@@ -43,40 +40,19 @@ class QdrantClient(BaseClient):
embedding_function: Function to generate embeddings for documents.
"""
client: QdrantClientType
embedding_function: EmbeddingFunction | AsyncEmbeddingFunction
def __init__(
self,
client: QdrantClientType | None = None,
embedding_function: EmbeddingFunction | AsyncEmbeddingFunction | None = None,
**kwargs: Unpack[QdrantClientParams],
client: QdrantClientType,
embedding_function: EmbeddingFunction | AsyncEmbeddingFunction,
) -> None:
"""Initialize QdrantClient with optional client and embedding function.
"""Initialize QdrantClient with client and embedding function.
Args:
client: Optional pre-configured Qdrant client instance.
embedding_function: Optional embedding function. If not provided,
uses FastEmbed's BAAI/bge-small-en-v1.5 model.
**kwargs: Additional arguments for QdrantClient creation.
client: Pre-configured Qdrant client instance.
embedding_function: Embedding function for text to vector conversion.
"""
if client is not None:
self.client = client
else:
location = kwargs.get("location", ":memory:")
client_kwargs = {k: v for k, v in kwargs.items() if k != "location"}
self.client = SyncQdrantClientBase(location, **cast(Any, client_kwargs))
if embedding_function is not None:
self.embedding_function = embedding_function
else:
_embedder = TextEmbedding("BAAI/bge-small-en-v1.5")
def _embed_fn(text: str) -> list[float]:
embeddings = list(_embedder.embed([text]))
return [float(x) for x in embeddings[0]] if embeddings else []
self.embedding_function = _embed_fn
self.client = client
self.embedding_function = embedding_function
def create_collection(self, **kwargs: Unpack[QdrantCollectionCreateParams]) -> None:
"""Create a new collection in Qdrant.
@@ -284,7 +260,7 @@ class QdrantClient(BaseClient):
point = _create_point_from_document(doc, embedding)
points.append(point)
self.client.upsert(collection_name=collection_name, points=points, wait=True)
self.client.upsert(collection_name=collection_name, points=points)
async def aadd_documents(self, **kwargs: Unpack[BaseCollectionAddParams]) -> None:
"""Add documents with their embeddings to a collection asynchronously.
@@ -325,9 +301,7 @@ class QdrantClient(BaseClient):
point = _create_point_from_document(doc, embedding)
points.append(point)
await self.client.upsert(
collection_name=collection_name, points=points, wait=True
)
await self.client.upsert(collection_name=collection_name, points=points)
def search(
self, **kwargs: Unpack[BaseCollectionSearchParams]

View File

@@ -86,7 +86,11 @@ class AsyncEmbeddingFunction(Protocol):
class QdrantClientParams(TypedDict, total=False):
"""Parameters for QdrantClient initialization."""
"""Parameters for QdrantClient initialization.
Notes:
Need to implement in factory or remove.
"""
location: str | None
url: str | None