mirror of
https://github.com/crewAIInc/crewAI.git
synced 2026-01-10 08:38:30 +00:00
refactor: simplify rag client initialization (#3401)
* Simplified Qdrant and ChromaDB client initialization * Refactored factory structure and updated tests accordingly
This commit is contained in:
@@ -40,8 +40,19 @@ class ChromaDBClient(BaseClient):
|
|||||||
embedding_function: Function to generate embeddings for documents.
|
embedding_function: Function to generate embeddings for documents.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
client: ChromaDBClientType
|
def __init__(
|
||||||
embedding_function: ChromaEmbeddingFunction[Embeddable]
|
self,
|
||||||
|
client: ChromaDBClientType,
|
||||||
|
embedding_function: ChromaEmbeddingFunction[Embeddable],
|
||||||
|
) -> None:
|
||||||
|
"""Initialize ChromaDBClient with client and embedding function.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
client: Pre-configured ChromaDB client instance.
|
||||||
|
embedding_function: Embedding function for text to vector conversion.
|
||||||
|
"""
|
||||||
|
self.client = client
|
||||||
|
self.embedding_function = embedding_function
|
||||||
|
|
||||||
def create_collection(
|
def create_collection(
|
||||||
self, **kwargs: Unpack[ChromaDBCollectionCreateParams]
|
self, **kwargs: Unpack[ChromaDBCollectionCreateParams]
|
||||||
|
|||||||
@@ -15,12 +15,10 @@ def create_client(config: ChromaDBConfig) -> ChromaDBClient:
|
|||||||
Returns:
|
Returns:
|
||||||
Configured ChromaDBClient instance.
|
Configured ChromaDBClient instance.
|
||||||
"""
|
"""
|
||||||
chromadb_client = Client(
|
|
||||||
settings=config.settings, tenant=config.tenant, database=config.database
|
return ChromaDBClient(
|
||||||
|
client=Client(
|
||||||
|
settings=config.settings, tenant=config.tenant, database=config.database
|
||||||
|
),
|
||||||
|
embedding_function=config.embedding_function,
|
||||||
)
|
)
|
||||||
|
|
||||||
client = ChromaDBClient()
|
|
||||||
client.client = chromadb_client
|
|
||||||
client.embedding_function = config.embedding_function
|
|
||||||
|
|
||||||
return client
|
|
||||||
|
|||||||
@@ -11,7 +11,7 @@ from crewai.rag.config.constants import (
|
|||||||
DEFAULT_RAG_CONFIG_PATH,
|
DEFAULT_RAG_CONFIG_PATH,
|
||||||
DEFAULT_RAG_CONFIG_CLASS,
|
DEFAULT_RAG_CONFIG_CLASS,
|
||||||
)
|
)
|
||||||
from crewai.rag.config.factory import create_client
|
from crewai.rag.factory import create_client
|
||||||
|
|
||||||
|
|
||||||
class RagContext(BaseModel):
|
class RagContext(BaseModel):
|
||||||
|
|||||||
@@ -2,8 +2,6 @@
|
|||||||
|
|
||||||
from typing import Any, cast
|
from typing import Any, cast
|
||||||
|
|
||||||
from fastembed import TextEmbedding
|
|
||||||
from qdrant_client import QdrantClient as SyncQdrantClientBase
|
|
||||||
from typing_extensions import Unpack
|
from typing_extensions import Unpack
|
||||||
|
|
||||||
from crewai.rag.core.base_client import (
|
from crewai.rag.core.base_client import (
|
||||||
@@ -16,7 +14,6 @@ from crewai.rag.core.exceptions import ClientMethodMismatchError
|
|||||||
from crewai.rag.qdrant.types import (
|
from crewai.rag.qdrant.types import (
|
||||||
AsyncEmbeddingFunction,
|
AsyncEmbeddingFunction,
|
||||||
EmbeddingFunction,
|
EmbeddingFunction,
|
||||||
QdrantClientParams,
|
|
||||||
QdrantClientType,
|
QdrantClientType,
|
||||||
QdrantCollectionCreateParams,
|
QdrantCollectionCreateParams,
|
||||||
)
|
)
|
||||||
@@ -43,40 +40,19 @@ class QdrantClient(BaseClient):
|
|||||||
embedding_function: Function to generate embeddings for documents.
|
embedding_function: Function to generate embeddings for documents.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
client: QdrantClientType
|
|
||||||
embedding_function: EmbeddingFunction | AsyncEmbeddingFunction
|
|
||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
client: QdrantClientType | None = None,
|
client: QdrantClientType,
|
||||||
embedding_function: EmbeddingFunction | AsyncEmbeddingFunction | None = None,
|
embedding_function: EmbeddingFunction | AsyncEmbeddingFunction,
|
||||||
**kwargs: Unpack[QdrantClientParams],
|
|
||||||
) -> None:
|
) -> None:
|
||||||
"""Initialize QdrantClient with optional client and embedding function.
|
"""Initialize QdrantClient with client and embedding function.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
client: Optional pre-configured Qdrant client instance.
|
client: Pre-configured Qdrant client instance.
|
||||||
embedding_function: Optional embedding function. If not provided,
|
embedding_function: Embedding function for text to vector conversion.
|
||||||
uses FastEmbed's BAAI/bge-small-en-v1.5 model.
|
|
||||||
**kwargs: Additional arguments for QdrantClient creation.
|
|
||||||
"""
|
"""
|
||||||
if client is not None:
|
self.client = client
|
||||||
self.client = client
|
self.embedding_function = embedding_function
|
||||||
else:
|
|
||||||
location = kwargs.get("location", ":memory:")
|
|
||||||
client_kwargs = {k: v for k, v in kwargs.items() if k != "location"}
|
|
||||||
self.client = SyncQdrantClientBase(location, **cast(Any, client_kwargs))
|
|
||||||
|
|
||||||
if embedding_function is not None:
|
|
||||||
self.embedding_function = embedding_function
|
|
||||||
else:
|
|
||||||
_embedder = TextEmbedding("BAAI/bge-small-en-v1.5")
|
|
||||||
|
|
||||||
def _embed_fn(text: str) -> list[float]:
|
|
||||||
embeddings = list(_embedder.embed([text]))
|
|
||||||
return [float(x) for x in embeddings[0]] if embeddings else []
|
|
||||||
|
|
||||||
self.embedding_function = _embed_fn
|
|
||||||
|
|
||||||
def create_collection(self, **kwargs: Unpack[QdrantCollectionCreateParams]) -> None:
|
def create_collection(self, **kwargs: Unpack[QdrantCollectionCreateParams]) -> None:
|
||||||
"""Create a new collection in Qdrant.
|
"""Create a new collection in Qdrant.
|
||||||
@@ -284,7 +260,7 @@ class QdrantClient(BaseClient):
|
|||||||
point = _create_point_from_document(doc, embedding)
|
point = _create_point_from_document(doc, embedding)
|
||||||
points.append(point)
|
points.append(point)
|
||||||
|
|
||||||
self.client.upsert(collection_name=collection_name, points=points, wait=True)
|
self.client.upsert(collection_name=collection_name, points=points)
|
||||||
|
|
||||||
async def aadd_documents(self, **kwargs: Unpack[BaseCollectionAddParams]) -> None:
|
async def aadd_documents(self, **kwargs: Unpack[BaseCollectionAddParams]) -> None:
|
||||||
"""Add documents with their embeddings to a collection asynchronously.
|
"""Add documents with their embeddings to a collection asynchronously.
|
||||||
@@ -325,9 +301,7 @@ class QdrantClient(BaseClient):
|
|||||||
point = _create_point_from_document(doc, embedding)
|
point = _create_point_from_document(doc, embedding)
|
||||||
points.append(point)
|
points.append(point)
|
||||||
|
|
||||||
await self.client.upsert(
|
await self.client.upsert(collection_name=collection_name, points=points)
|
||||||
collection_name=collection_name, points=points, wait=True
|
|
||||||
)
|
|
||||||
|
|
||||||
def search(
|
def search(
|
||||||
self, **kwargs: Unpack[BaseCollectionSearchParams]
|
self, **kwargs: Unpack[BaseCollectionSearchParams]
|
||||||
|
|||||||
@@ -86,7 +86,11 @@ class AsyncEmbeddingFunction(Protocol):
|
|||||||
|
|
||||||
|
|
||||||
class QdrantClientParams(TypedDict, total=False):
|
class QdrantClientParams(TypedDict, total=False):
|
||||||
"""Parameters for QdrantClient initialization."""
|
"""Parameters for QdrantClient initialization.
|
||||||
|
|
||||||
|
Notes:
|
||||||
|
Need to implement in factory or remove.
|
||||||
|
"""
|
||||||
|
|
||||||
location: str | None
|
location: str | None
|
||||||
url: str | None
|
url: str | None
|
||||||
|
|||||||
@@ -27,18 +27,20 @@ def mock_async_chromadb_client():
|
|||||||
@pytest.fixture
|
@pytest.fixture
|
||||||
def client(mock_chromadb_client) -> ChromaDBClient:
|
def client(mock_chromadb_client) -> ChromaDBClient:
|
||||||
"""Create a ChromaDBClient instance for testing."""
|
"""Create a ChromaDBClient instance for testing."""
|
||||||
client = ChromaDBClient()
|
mock_embedding = Mock()
|
||||||
client.client = mock_chromadb_client
|
client = ChromaDBClient(
|
||||||
client.embedding_function = Mock()
|
client=mock_chromadb_client, embedding_function=mock_embedding
|
||||||
|
)
|
||||||
return client
|
return client
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture
|
@pytest.fixture
|
||||||
def async_client(mock_async_chromadb_client) -> ChromaDBClient:
|
def async_client(mock_async_chromadb_client) -> ChromaDBClient:
|
||||||
"""Create a ChromaDBClient instance with async client for testing."""
|
"""Create a ChromaDBClient instance with async client for testing."""
|
||||||
client = ChromaDBClient()
|
mock_embedding = Mock()
|
||||||
client.client = mock_async_chromadb_client
|
client = ChromaDBClient(
|
||||||
client.embedding_function = Mock()
|
client=mock_async_chromadb_client, embedding_function=mock_embedding
|
||||||
|
)
|
||||||
return client
|
return client
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@@ -2,7 +2,7 @@
|
|||||||
|
|
||||||
from unittest.mock import Mock, patch
|
from unittest.mock import Mock, patch
|
||||||
|
|
||||||
from crewai.rag.config.factory import create_client
|
from crewai.rag.factory import create_client
|
||||||
|
|
||||||
|
|
||||||
def test_create_client_chromadb():
|
def test_create_client_chromadb():
|
||||||
@@ -10,7 +10,7 @@ def test_create_client_chromadb():
|
|||||||
mock_config = Mock()
|
mock_config = Mock()
|
||||||
mock_config.provider = "chromadb"
|
mock_config.provider = "chromadb"
|
||||||
|
|
||||||
with patch("crewai.rag.config.factory.require") as mock_require:
|
with patch("crewai.rag.factory.require") as mock_require:
|
||||||
mock_module = Mock()
|
mock_module = Mock()
|
||||||
mock_client = Mock()
|
mock_client = Mock()
|
||||||
mock_module.create_client.return_value = mock_client
|
mock_module.create_client.return_value = mock_client
|
||||||
|
|||||||
@@ -236,7 +236,6 @@ class TestQdrantClient:
|
|||||||
# Check upsert was called with correct parameters
|
# Check upsert was called with correct parameters
|
||||||
call_args = mock_qdrant_client.upsert.call_args
|
call_args = mock_qdrant_client.upsert.call_args
|
||||||
assert call_args.kwargs["collection_name"] == "test_collection"
|
assert call_args.kwargs["collection_name"] == "test_collection"
|
||||||
assert call_args.kwargs["wait"] is True
|
|
||||||
assert len(call_args.kwargs["points"]) == 1
|
assert len(call_args.kwargs["points"]) == 1
|
||||||
point = call_args.kwargs["points"][0]
|
point = call_args.kwargs["points"][0]
|
||||||
assert point.vector == [0.1, 0.2, 0.3]
|
assert point.vector == [0.1, 0.2, 0.3]
|
||||||
@@ -330,7 +329,6 @@ class TestQdrantClient:
|
|||||||
# Check upsert was called with correct parameters
|
# Check upsert was called with correct parameters
|
||||||
call_args = mock_async_qdrant_client.upsert.call_args
|
call_args = mock_async_qdrant_client.upsert.call_args
|
||||||
assert call_args.kwargs["collection_name"] == "test_collection"
|
assert call_args.kwargs["collection_name"] == "test_collection"
|
||||||
assert call_args.kwargs["wait"] is True
|
|
||||||
assert len(call_args.kwargs["points"]) == 1
|
assert len(call_args.kwargs["points"]) == 1
|
||||||
point = call_args.kwargs["points"][0]
|
point = call_args.kwargs["points"][0]
|
||||||
assert point.vector == [0.1, 0.2, 0.3]
|
assert point.vector == [0.1, 0.2, 0.3]
|
||||||
|
|||||||
Reference in New Issue
Block a user