mirror of
https://github.com/crewAIInc/crewAI.git
synced 2026-01-10 16:48:30 +00:00
refactor: simplify rag client initialization (#3401)
* Simplified Qdrant and ChromaDB client initialization
* Refactored factory structure and updated tests accordingly
This commit is contained in:
@@ -40,8 +40,19 @@ class ChromaDBClient(BaseClient):
|
||||
embedding_function: Function to generate embeddings for documents.
|
||||
"""
|
||||
|
||||
client: ChromaDBClientType
|
||||
embedding_function: ChromaEmbeddingFunction[Embeddable]
|
||||
def __init__(
|
||||
self,
|
||||
client: ChromaDBClientType,
|
||||
embedding_function: ChromaEmbeddingFunction[Embeddable],
|
||||
) -> None:
|
||||
"""Initialize ChromaDBClient with client and embedding function.
|
||||
|
||||
Args:
|
||||
client: Pre-configured ChromaDB client instance.
|
||||
embedding_function: Embedding function for text to vector conversion.
|
||||
"""
|
||||
self.client = client
|
||||
self.embedding_function = embedding_function
|
||||
|
||||
def create_collection(
|
||||
self, **kwargs: Unpack[ChromaDBCollectionCreateParams]
|
||||
|
||||
@@ -15,12 +15,10 @@ def create_client(config: ChromaDBConfig) -> ChromaDBClient:
|
||||
Returns:
|
||||
Configured ChromaDBClient instance.
|
||||
"""
|
||||
chromadb_client = Client(
|
||||
settings=config.settings, tenant=config.tenant, database=config.database
|
||||
|
||||
return ChromaDBClient(
|
||||
client=Client(
|
||||
settings=config.settings, tenant=config.tenant, database=config.database
|
||||
),
|
||||
embedding_function=config.embedding_function,
|
||||
)
|
||||
|
||||
client = ChromaDBClient()
|
||||
client.client = chromadb_client
|
||||
client.embedding_function = config.embedding_function
|
||||
|
||||
return client
|
||||
|
||||
@@ -11,7 +11,7 @@ from crewai.rag.config.constants import (
|
||||
DEFAULT_RAG_CONFIG_PATH,
|
||||
DEFAULT_RAG_CONFIG_CLASS,
|
||||
)
|
||||
from crewai.rag.config.factory import create_client
|
||||
from crewai.rag.factory import create_client
|
||||
|
||||
|
||||
class RagContext(BaseModel):
|
||||
|
||||
@@ -2,8 +2,6 @@
|
||||
|
||||
from typing import Any, cast
|
||||
|
||||
from fastembed import TextEmbedding
|
||||
from qdrant_client import QdrantClient as SyncQdrantClientBase
|
||||
from typing_extensions import Unpack
|
||||
|
||||
from crewai.rag.core.base_client import (
|
||||
@@ -16,7 +14,6 @@ from crewai.rag.core.exceptions import ClientMethodMismatchError
|
||||
from crewai.rag.qdrant.types import (
|
||||
AsyncEmbeddingFunction,
|
||||
EmbeddingFunction,
|
||||
QdrantClientParams,
|
||||
QdrantClientType,
|
||||
QdrantCollectionCreateParams,
|
||||
)
|
||||
@@ -43,40 +40,19 @@ class QdrantClient(BaseClient):
|
||||
embedding_function: Function to generate embeddings for documents.
|
||||
"""
|
||||
|
||||
client: QdrantClientType
|
||||
embedding_function: EmbeddingFunction | AsyncEmbeddingFunction
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
client: QdrantClientType | None = None,
|
||||
embedding_function: EmbeddingFunction | AsyncEmbeddingFunction | None = None,
|
||||
**kwargs: Unpack[QdrantClientParams],
|
||||
client: QdrantClientType,
|
||||
embedding_function: EmbeddingFunction | AsyncEmbeddingFunction,
|
||||
) -> None:
|
||||
"""Initialize QdrantClient with optional client and embedding function.
|
||||
"""Initialize QdrantClient with client and embedding function.
|
||||
|
||||
Args:
|
||||
client: Optional pre-configured Qdrant client instance.
|
||||
embedding_function: Optional embedding function. If not provided,
|
||||
uses FastEmbed's BAAI/bge-small-en-v1.5 model.
|
||||
**kwargs: Additional arguments for QdrantClient creation.
|
||||
client: Pre-configured Qdrant client instance.
|
||||
embedding_function: Embedding function for text to vector conversion.
|
||||
"""
|
||||
if client is not None:
|
||||
self.client = client
|
||||
else:
|
||||
location = kwargs.get("location", ":memory:")
|
||||
client_kwargs = {k: v for k, v in kwargs.items() if k != "location"}
|
||||
self.client = SyncQdrantClientBase(location, **cast(Any, client_kwargs))
|
||||
|
||||
if embedding_function is not None:
|
||||
self.embedding_function = embedding_function
|
||||
else:
|
||||
_embedder = TextEmbedding("BAAI/bge-small-en-v1.5")
|
||||
|
||||
def _embed_fn(text: str) -> list[float]:
|
||||
embeddings = list(_embedder.embed([text]))
|
||||
return [float(x) for x in embeddings[0]] if embeddings else []
|
||||
|
||||
self.embedding_function = _embed_fn
|
||||
self.client = client
|
||||
self.embedding_function = embedding_function
|
||||
|
||||
def create_collection(self, **kwargs: Unpack[QdrantCollectionCreateParams]) -> None:
|
||||
"""Create a new collection in Qdrant.
|
||||
@@ -284,7 +260,7 @@ class QdrantClient(BaseClient):
|
||||
point = _create_point_from_document(doc, embedding)
|
||||
points.append(point)
|
||||
|
||||
self.client.upsert(collection_name=collection_name, points=points, wait=True)
|
||||
self.client.upsert(collection_name=collection_name, points=points)
|
||||
|
||||
async def aadd_documents(self, **kwargs: Unpack[BaseCollectionAddParams]) -> None:
|
||||
"""Add documents with their embeddings to a collection asynchronously.
|
||||
@@ -325,9 +301,7 @@ class QdrantClient(BaseClient):
|
||||
point = _create_point_from_document(doc, embedding)
|
||||
points.append(point)
|
||||
|
||||
await self.client.upsert(
|
||||
collection_name=collection_name, points=points, wait=True
|
||||
)
|
||||
await self.client.upsert(collection_name=collection_name, points=points)
|
||||
|
||||
def search(
|
||||
self, **kwargs: Unpack[BaseCollectionSearchParams]
|
||||
|
||||
@@ -86,7 +86,11 @@ class AsyncEmbeddingFunction(Protocol):
|
||||
|
||||
|
||||
class QdrantClientParams(TypedDict, total=False):
|
||||
"""Parameters for QdrantClient initialization."""
|
||||
"""Parameters for QdrantClient initialization.
|
||||
|
||||
Notes:
|
||||
Need to implement in factory or remove.
|
||||
"""
|
||||
|
||||
location: str | None
|
||||
url: str | None
|
||||
|
||||
Reference in New Issue
Block a user