diff --git a/src/crewai/rag/chromadb/client.py b/src/crewai/rag/chromadb/client.py index ea67cd2fb..b61f85e36 100644 --- a/src/crewai/rag/chromadb/client.py +++ b/src/crewai/rag/chromadb/client.py @@ -40,8 +40,19 @@ class ChromaDBClient(BaseClient): embedding_function: Function to generate embeddings for documents. """ - client: ChromaDBClientType - embedding_function: ChromaEmbeddingFunction[Embeddable] + def __init__( + self, + client: ChromaDBClientType, + embedding_function: ChromaEmbeddingFunction[Embeddable], + ) -> None: + """Initialize ChromaDBClient with client and embedding function. + + Args: + client: Pre-configured ChromaDB client instance. + embedding_function: Embedding function for text to vector conversion. + """ + self.client = client + self.embedding_function = embedding_function def create_collection( self, **kwargs: Unpack[ChromaDBCollectionCreateParams] diff --git a/src/crewai/rag/chromadb/factory.py b/src/crewai/rag/chromadb/factory.py index fff9f2dc1..4d3844910 100644 --- a/src/crewai/rag/chromadb/factory.py +++ b/src/crewai/rag/chromadb/factory.py @@ -15,12 +15,10 @@ def create_client(config: ChromaDBConfig) -> ChromaDBClient: Returns: Configured ChromaDBClient instance. """ - chromadb_client = Client( - settings=config.settings, tenant=config.tenant, database=config.database + + return ChromaDBClient( + client=Client( + settings=config.settings, tenant=config.tenant, database=config.database + ), + embedding_function=config.embedding_function, ) - - client = ChromaDBClient() - client.client = chromadb_client - client.embedding_function = config.embedding_function - - return client diff --git a/src/crewai/rag/config/utils.py b/src/crewai/rag/config/utils.py index 0eaef87f1..9db9fa732 100644 --- a/src/crewai/rag/config/utils.py +++ b/src/crewai/rag/config/utils.py @@ -11,7 +11,7 @@ from crewai.rag.config.constants import ( DEFAULT_RAG_CONFIG_PATH, DEFAULT_RAG_CONFIG_CLASS, ) -from crewai.rag.config.factory import create_client +from crewai.rag.factory import create_client class RagContext(BaseModel): diff --git a/src/crewai/rag/config/factory.py b/src/crewai/rag/factory.py similarity index 100% rename from src/crewai/rag/config/factory.py rename to src/crewai/rag/factory.py diff --git a/src/crewai/rag/qdrant/client.py b/src/crewai/rag/qdrant/client.py index 9c9bd5f00..3386d3411 100644 --- a/src/crewai/rag/qdrant/client.py +++ b/src/crewai/rag/qdrant/client.py @@ -2,8 +2,6 @@ from typing import Any, cast -from fastembed import TextEmbedding -from qdrant_client import QdrantClient as SyncQdrantClientBase from typing_extensions import Unpack from crewai.rag.core.base_client import ( @@ -16,7 +14,6 @@ from crewai.rag.core.exceptions import ClientMethodMismatchError from crewai.rag.qdrant.types import ( AsyncEmbeddingFunction, EmbeddingFunction, - QdrantClientParams, QdrantClientType, QdrantCollectionCreateParams, ) @@ -43,40 +40,19 @@ class QdrantClient(BaseClient): embedding_function: Function to generate embeddings for documents. """ - client: QdrantClientType - embedding_function: EmbeddingFunction | AsyncEmbeddingFunction - def __init__( self, - client: QdrantClientType | None = None, - embedding_function: EmbeddingFunction | AsyncEmbeddingFunction | None = None, - **kwargs: Unpack[QdrantClientParams], + client: QdrantClientType, + embedding_function: EmbeddingFunction | AsyncEmbeddingFunction, ) -> None: - """Initialize QdrantClient with optional client and embedding function. + """Initialize QdrantClient with client and embedding function. Args: - client: Optional pre-configured Qdrant client instance. - embedding_function: Optional embedding function. If not provided, - uses FastEmbed's BAAI/bge-small-en-v1.5 model. - **kwargs: Additional arguments for QdrantClient creation. + client: Pre-configured Qdrant client instance. + embedding_function: Embedding function for text to vector conversion. """ - if client is not None: - self.client = client - else: - location = kwargs.get("location", ":memory:") - client_kwargs = {k: v for k, v in kwargs.items() if k != "location"} - self.client = SyncQdrantClientBase(location, **cast(Any, client_kwargs)) - - if embedding_function is not None: - self.embedding_function = embedding_function - else: - _embedder = TextEmbedding("BAAI/bge-small-en-v1.5") - - def _embed_fn(text: str) -> list[float]: - embeddings = list(_embedder.embed([text])) - return [float(x) for x in embeddings[0]] if embeddings else [] - - self.embedding_function = _embed_fn + self.client = client + self.embedding_function = embedding_function def create_collection(self, **kwargs: Unpack[QdrantCollectionCreateParams]) -> None: """Create a new collection in Qdrant. @@ -284,7 +260,7 @@ class QdrantClient(BaseClient): point = _create_point_from_document(doc, embedding) points.append(point) - self.client.upsert(collection_name=collection_name, points=points, wait=True) + self.client.upsert(collection_name=collection_name, points=points) async def aadd_documents(self, **kwargs: Unpack[BaseCollectionAddParams]) -> None: """Add documents with their embeddings to a collection asynchronously. @@ -325,9 +301,7 @@ class QdrantClient(BaseClient): point = _create_point_from_document(doc, embedding) points.append(point) - await self.client.upsert( - collection_name=collection_name, points=points, wait=True - ) + await self.client.upsert(collection_name=collection_name, points=points) def search( self, **kwargs: Unpack[BaseCollectionSearchParams] diff --git a/src/crewai/rag/qdrant/types.py b/src/crewai/rag/qdrant/types.py index 706a4f155..1ed523e6a 100644 --- a/src/crewai/rag/qdrant/types.py +++ b/src/crewai/rag/qdrant/types.py @@ -86,7 +86,11 @@ class AsyncEmbeddingFunction(Protocol): class QdrantClientParams(TypedDict, total=False): - """Parameters for QdrantClient initialization.""" + """Parameters for QdrantClient initialization. + + Notes: + Need to implement in factory or remove. + """ location: str | None url: str | None diff --git a/tests/rag/chromadb/test_client.py b/tests/rag/chromadb/test_client.py index 67d58614e..4006b981e 100644 --- a/tests/rag/chromadb/test_client.py +++ b/tests/rag/chromadb/test_client.py @@ -27,18 +27,20 @@ def mock_async_chromadb_client(): @pytest.fixture def client(mock_chromadb_client) -> ChromaDBClient: """Create a ChromaDBClient instance for testing.""" - client = ChromaDBClient() - client.client = mock_chromadb_client - client.embedding_function = Mock() + mock_embedding = Mock() + client = ChromaDBClient( + client=mock_chromadb_client, embedding_function=mock_embedding + ) return client @pytest.fixture def async_client(mock_async_chromadb_client) -> ChromaDBClient: """Create a ChromaDBClient instance with async client for testing.""" - client = ChromaDBClient() - client.client = mock_async_chromadb_client - client.embedding_function = Mock() + mock_embedding = Mock() + client = ChromaDBClient( + client=mock_async_chromadb_client, embedding_function=mock_embedding + ) return client diff --git a/tests/rag/config/test_factory.py b/tests/rag/config/test_factory.py index 91f30329e..1482f1d41 100644 --- a/tests/rag/config/test_factory.py +++ b/tests/rag/config/test_factory.py @@ -2,7 +2,7 @@ from unittest.mock import Mock, patch -from crewai.rag.config.factory import create_client +from crewai.rag.factory import create_client def test_create_client_chromadb(): @@ -10,7 +10,7 @@ def test_create_client_chromadb(): mock_config = Mock() mock_config.provider = "chromadb" - with patch("crewai.rag.config.factory.require") as mock_require: + with patch("crewai.rag.factory.require") as mock_require: mock_module = Mock() mock_client = Mock() mock_module.create_client.return_value = mock_client diff --git a/tests/rag/qdrant/test_client.py b/tests/rag/qdrant/test_client.py index 3eaed7921..a1c16e9bc 100644 --- a/tests/rag/qdrant/test_client.py +++ b/tests/rag/qdrant/test_client.py @@ -236,7 +236,6 @@ class TestQdrantClient: # Check upsert was called with correct parameters call_args = mock_qdrant_client.upsert.call_args assert call_args.kwargs["collection_name"] == "test_collection" - assert call_args.kwargs["wait"] is True assert len(call_args.kwargs["points"]) == 1 point = call_args.kwargs["points"][0] assert point.vector == [0.1, 0.2, 0.3] @@ -330,7 +329,6 @@ class TestQdrantClient: # Check upsert was called with correct parameters call_args = mock_async_qdrant_client.upsert.call_args assert call_args.kwargs["collection_name"] == "test_collection" - assert call_args.kwargs["wait"] is True assert len(call_args.kwargs["points"]) == 1 point = call_args.kwargs["points"][0] assert point.vector == [0.1, 0.2, 0.3]