mirror of
https://github.com/crewAIInc/crewAI.git
synced 2025-12-16 12:28:30 +00:00
Compare commits
3 Commits
main
...
devin/1756
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
411285f5ef | ||
|
|
dce26e8276 | ||
|
|
e3a575920c |
@@ -14,7 +14,7 @@ class _MissingProvider:
|
||||
Raises RuntimeError when instantiated to indicate missing dependencies.
|
||||
"""
|
||||
|
||||
provider: Literal["chromadb", "qdrant", "__missing__"] = field(
|
||||
provider: Literal["chromadb", "qdrant", "elasticsearch", "__missing__"] = field(
|
||||
default="__missing__"
|
||||
)
|
||||
|
||||
|
||||
@@ -9,6 +9,8 @@ if TYPE_CHECKING:
|
||||
from crewai.rag.chromadb.config import ChromaDBConfig
|
||||
from crewai.rag.qdrant.client import QdrantClient
|
||||
from crewai.rag.qdrant.config import QdrantConfig
|
||||
from crewai.rag.elasticsearch.client import ElasticsearchClient
|
||||
from crewai.rag.elasticsearch.config import ElasticsearchConfig
|
||||
|
||||
|
||||
class ChromaFactoryModule(Protocol):
|
||||
@@ -25,3 +27,11 @@ class QdrantFactoryModule(Protocol):
|
||||
def create_client(self, config: QdrantConfig) -> QdrantClient:
|
||||
"""Creates a Qdrant client from configuration."""
|
||||
...
|
||||
|
||||
|
||||
class ElasticsearchFactoryModule(Protocol):
|
||||
"""Protocol for Elasticsearch factory module."""
|
||||
|
||||
def create_client(self, config: ElasticsearchConfig) -> ElasticsearchClient:
|
||||
"""Creates an Elasticsearch client from configuration."""
|
||||
...
|
||||
|
||||
@@ -20,3 +20,10 @@ class MissingQdrantConfig(_MissingProvider):
|
||||
"""Placeholder for missing Qdrant configuration."""
|
||||
|
||||
provider: Literal["qdrant"] = field(default="qdrant")
|
||||
|
||||
|
||||
@pyd_dataclass(config=ConfigDict(extra="forbid"))
|
||||
class MissingElasticsearchConfig(_MissingProvider):
|
||||
"""Placeholder for missing Elasticsearch configuration."""
|
||||
|
||||
provider: Literal["elasticsearch"] = field(default="elasticsearch")
|
||||
|
||||
@@ -3,6 +3,6 @@
|
||||
from typing import Annotated, Literal
|
||||
|
||||
SupportedProvider = Annotated[
|
||||
Literal["chromadb", "qdrant"],
|
||||
Literal["chromadb", "qdrant", "elasticsearch"],
|
||||
"Supported RAG provider types, add providers here as they become available",
|
||||
]
|
||||
|
||||
@@ -13,6 +13,9 @@ if TYPE_CHECKING:
|
||||
from crewai.rag.qdrant.config import QdrantConfig as QdrantConfig_
|
||||
|
||||
QdrantConfig = QdrantConfig_
|
||||
from crewai.rag.elasticsearch.config import ElasticsearchConfig as ElasticsearchConfig_
|
||||
|
||||
ElasticsearchConfig = ElasticsearchConfig_
|
||||
else:
|
||||
try:
|
||||
from crewai.rag.chromadb.config import ChromaDBConfig
|
||||
@@ -28,7 +31,14 @@ else:
|
||||
MissingQdrantConfig as QdrantConfig,
|
||||
)
|
||||
|
||||
SupportedProviderConfig: TypeAlias = ChromaDBConfig | QdrantConfig
|
||||
try:
|
||||
from crewai.rag.elasticsearch.config import ElasticsearchConfig
|
||||
except ImportError:
|
||||
from crewai.rag.config.optional_imports.providers import (
|
||||
MissingElasticsearchConfig as ElasticsearchConfig,
|
||||
)
|
||||
|
||||
SupportedProviderConfig: TypeAlias = ChromaDBConfig | QdrantConfig | ElasticsearchConfig
|
||||
RagConfigType: TypeAlias = Annotated[
|
||||
SupportedProviderConfig, Field(discriminator=DISCRIMINATOR)
|
||||
]
|
||||
|
||||
1
src/crewai/rag/elasticsearch/__init__.py
Normal file
1
src/crewai/rag/elasticsearch/__init__.py
Normal file
@@ -0,0 +1 @@
|
||||
"""Elasticsearch RAG implementation."""
|
||||
502
src/crewai/rag/elasticsearch/client.py
Normal file
502
src/crewai/rag/elasticsearch/client.py
Normal file
@@ -0,0 +1,502 @@
|
||||
"""Elasticsearch client implementation."""
|
||||
|
||||
from typing import Any, cast
|
||||
|
||||
from typing_extensions import Unpack
|
||||
|
||||
from crewai.rag.core.base_client import (
|
||||
BaseClient,
|
||||
BaseCollectionParams,
|
||||
BaseCollectionAddParams,
|
||||
BaseCollectionSearchParams,
|
||||
)
|
||||
from crewai.rag.core.exceptions import ClientMethodMismatchError
|
||||
from crewai.rag.elasticsearch.types import (
|
||||
AsyncEmbeddingFunction,
|
||||
EmbeddingFunction,
|
||||
ElasticsearchClientType,
|
||||
ElasticsearchCollectionCreateParams,
|
||||
)
|
||||
from crewai.rag.elasticsearch.utils import (
|
||||
_is_async_client,
|
||||
_is_async_embedding_function,
|
||||
_is_sync_client,
|
||||
_prepare_document_for_elasticsearch,
|
||||
_process_search_results,
|
||||
_build_vector_search_query,
|
||||
_get_index_mapping,
|
||||
)
|
||||
from crewai.rag.types import SearchResult
|
||||
|
||||
|
||||
class ElasticsearchClient(BaseClient):
|
||||
"""Elasticsearch implementation of the BaseClient protocol.
|
||||
|
||||
Provides vector database operations for Elasticsearch, supporting both
|
||||
synchronous and asynchronous clients.
|
||||
|
||||
Attributes:
|
||||
client: Elasticsearch client instance (Elasticsearch or AsyncElasticsearch).
|
||||
embedding_function: Function to generate embeddings for documents.
|
||||
vector_dimension: Dimension of the embedding vectors.
|
||||
similarity: Similarity function to use for vector search.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
client: ElasticsearchClientType,
|
||||
embedding_function: EmbeddingFunction | AsyncEmbeddingFunction,
|
||||
vector_dimension: int = 384,
|
||||
similarity: str = "cosine",
|
||||
) -> None:
|
||||
"""Initialize ElasticsearchClient with client and embedding function.
|
||||
|
||||
Args:
|
||||
client: Pre-configured Elasticsearch client instance.
|
||||
embedding_function: Embedding function for text to vector conversion.
|
||||
vector_dimension: Dimension of the embedding vectors.
|
||||
similarity: Similarity function to use for vector search.
|
||||
"""
|
||||
self.client = client
|
||||
self.embedding_function = embedding_function
|
||||
self.vector_dimension = vector_dimension
|
||||
self.similarity = similarity
|
||||
|
||||
def create_collection(self, **kwargs: Unpack[ElasticsearchCollectionCreateParams]) -> None:
|
||||
"""Create a new index in Elasticsearch.
|
||||
|
||||
Keyword Args:
|
||||
collection_name: Name of the index to create. Must be unique.
|
||||
index_settings: Optional index settings.
|
||||
vector_dimension: Optional vector dimension override.
|
||||
similarity: Optional similarity function override.
|
||||
|
||||
Raises:
|
||||
ValueError: If index with the same name already exists.
|
||||
ConnectionError: If unable to connect to Elasticsearch server.
|
||||
"""
|
||||
if not _is_sync_client(self.client):
|
||||
raise ClientMethodMismatchError(
|
||||
method_name="create_collection",
|
||||
expected_client="Elasticsearch",
|
||||
alt_method="acreate_collection",
|
||||
alt_client="AsyncElasticsearch",
|
||||
)
|
||||
|
||||
collection_name = kwargs["collection_name"]
|
||||
|
||||
if self.client.indices.exists(index=collection_name):
|
||||
raise ValueError(f"Index '{collection_name}' already exists")
|
||||
|
||||
vector_dim = kwargs.get("vector_dimension", self.vector_dimension)
|
||||
similarity = kwargs.get("similarity", self.similarity)
|
||||
|
||||
mapping = _get_index_mapping(vector_dim, similarity)
|
||||
|
||||
index_settings = kwargs.get("index_settings", {})
|
||||
if index_settings:
|
||||
mapping["settings"] = index_settings
|
||||
|
||||
self.client.indices.create(index=collection_name, body=mapping)
|
||||
|
||||
async def acreate_collection(self, **kwargs: Unpack[ElasticsearchCollectionCreateParams]) -> None:
|
||||
"""Create a new index in Elasticsearch asynchronously.
|
||||
|
||||
Keyword Args:
|
||||
collection_name: Name of the index to create. Must be unique.
|
||||
index_settings: Optional index settings.
|
||||
vector_dimension: Optional vector dimension override.
|
||||
similarity: Optional similarity function override.
|
||||
|
||||
Raises:
|
||||
ValueError: If index with the same name already exists.
|
||||
ConnectionError: If unable to connect to Elasticsearch server.
|
||||
"""
|
||||
if not _is_async_client(self.client):
|
||||
raise ClientMethodMismatchError(
|
||||
method_name="acreate_collection",
|
||||
expected_client="AsyncElasticsearch",
|
||||
alt_method="create_collection",
|
||||
alt_client="Elasticsearch",
|
||||
)
|
||||
|
||||
collection_name = kwargs["collection_name"]
|
||||
|
||||
if await self.client.indices.exists(index=collection_name):
|
||||
raise ValueError(f"Index '{collection_name}' already exists")
|
||||
|
||||
vector_dim = kwargs.get("vector_dimension", self.vector_dimension)
|
||||
similarity = kwargs.get("similarity", self.similarity)
|
||||
|
||||
mapping = _get_index_mapping(vector_dim, similarity)
|
||||
|
||||
index_settings = kwargs.get("index_settings", {})
|
||||
if index_settings:
|
||||
mapping["settings"] = index_settings
|
||||
|
||||
await self.client.indices.create(index=collection_name, body=mapping)
|
||||
|
||||
def get_or_create_collection(self, **kwargs: Unpack[ElasticsearchCollectionCreateParams]) -> Any:
|
||||
"""Get an existing index or create it if it doesn't exist.
|
||||
|
||||
Keyword Args:
|
||||
collection_name: Name of the index to get or create.
|
||||
index_settings: Optional index settings.
|
||||
vector_dimension: Optional vector dimension override.
|
||||
similarity: Optional similarity function override.
|
||||
|
||||
Returns:
|
||||
Index info dict with name and other metadata.
|
||||
|
||||
Raises:
|
||||
ConnectionError: If unable to connect to Elasticsearch server.
|
||||
"""
|
||||
if not _is_sync_client(self.client):
|
||||
raise ClientMethodMismatchError(
|
||||
method_name="get_or_create_collection",
|
||||
expected_client="Elasticsearch",
|
||||
alt_method="aget_or_create_collection",
|
||||
alt_client="AsyncElasticsearch",
|
||||
)
|
||||
|
||||
collection_name = kwargs["collection_name"]
|
||||
|
||||
if self.client.indices.exists(index=collection_name):
|
||||
return self.client.indices.get(index=collection_name)
|
||||
|
||||
vector_dim = kwargs.get("vector_dimension", self.vector_dimension)
|
||||
similarity = kwargs.get("similarity", self.similarity)
|
||||
|
||||
mapping = _get_index_mapping(vector_dim, similarity)
|
||||
|
||||
index_settings = kwargs.get("index_settings", {})
|
||||
if index_settings:
|
||||
mapping["settings"] = index_settings
|
||||
|
||||
self.client.indices.create(index=collection_name, body=mapping)
|
||||
return self.client.indices.get(index=collection_name)
|
||||
|
||||
async def aget_or_create_collection(self, **kwargs: Unpack[ElasticsearchCollectionCreateParams]) -> Any:
|
||||
"""Get an existing index or create it if it doesn't exist asynchronously.
|
||||
|
||||
Keyword Args:
|
||||
collection_name: Name of the index to get or create.
|
||||
index_settings: Optional index settings.
|
||||
vector_dimension: Optional vector dimension override.
|
||||
similarity: Optional similarity function override.
|
||||
|
||||
Returns:
|
||||
Index info dict with name and other metadata.
|
||||
|
||||
Raises:
|
||||
ConnectionError: If unable to connect to Elasticsearch server.
|
||||
"""
|
||||
if not _is_async_client(self.client):
|
||||
raise ClientMethodMismatchError(
|
||||
method_name="aget_or_create_collection",
|
||||
expected_client="AsyncElasticsearch",
|
||||
alt_method="get_or_create_collection",
|
||||
alt_client="Elasticsearch",
|
||||
)
|
||||
|
||||
collection_name = kwargs["collection_name"]
|
||||
|
||||
if await self.client.indices.exists(index=collection_name):
|
||||
return await self.client.indices.get(index=collection_name)
|
||||
|
||||
vector_dim = kwargs.get("vector_dimension", self.vector_dimension)
|
||||
similarity = kwargs.get("similarity", self.similarity)
|
||||
|
||||
mapping = _get_index_mapping(vector_dim, similarity)
|
||||
|
||||
index_settings = kwargs.get("index_settings", {})
|
||||
if index_settings:
|
||||
mapping["settings"] = index_settings
|
||||
|
||||
await self.client.indices.create(index=collection_name, body=mapping)
|
||||
return await self.client.indices.get(index=collection_name)
|
||||
|
||||
def add_documents(self, **kwargs: Unpack[BaseCollectionAddParams]) -> None:
|
||||
"""Add documents with their embeddings to an index.
|
||||
|
||||
Keyword Args:
|
||||
collection_name: The name of the index to add documents to.
|
||||
documents: List of BaseRecord dicts containing document data.
|
||||
|
||||
Raises:
|
||||
ValueError: If index doesn't exist or documents list is empty.
|
||||
ConnectionError: If unable to connect to Elasticsearch server.
|
||||
"""
|
||||
if not _is_sync_client(self.client):
|
||||
raise ClientMethodMismatchError(
|
||||
method_name="add_documents",
|
||||
expected_client="Elasticsearch",
|
||||
alt_method="aadd_documents",
|
||||
alt_client="AsyncElasticsearch",
|
||||
)
|
||||
|
||||
collection_name = kwargs["collection_name"]
|
||||
documents = kwargs["documents"]
|
||||
|
||||
if not documents:
|
||||
raise ValueError("Documents list cannot be empty")
|
||||
|
||||
if not self.client.indices.exists(index=collection_name):
|
||||
raise ValueError(f"Index '{collection_name}' does not exist")
|
||||
|
||||
for doc in documents:
|
||||
if _is_async_embedding_function(self.embedding_function):
|
||||
raise TypeError(
|
||||
"Async embedding function cannot be used with sync add_documents. "
|
||||
"Use aadd_documents instead."
|
||||
)
|
||||
sync_fn = cast(EmbeddingFunction, self.embedding_function)
|
||||
embedding = sync_fn(doc["content"])
|
||||
prepared_doc = _prepare_document_for_elasticsearch(doc, embedding)
|
||||
|
||||
self.client.index(
|
||||
index=collection_name,
|
||||
id=prepared_doc["id"],
|
||||
body=prepared_doc["body"]
|
||||
)
|
||||
|
||||
async def aadd_documents(self, **kwargs: Unpack[BaseCollectionAddParams]) -> None:
|
||||
"""Add documents with their embeddings to an index asynchronously.
|
||||
|
||||
Keyword Args:
|
||||
collection_name: The name of the index to add documents to.
|
||||
documents: List of BaseRecord dicts containing document data.
|
||||
|
||||
Raises:
|
||||
ValueError: If index doesn't exist or documents list is empty.
|
||||
ConnectionError: If unable to connect to Elasticsearch server.
|
||||
"""
|
||||
if not _is_async_client(self.client):
|
||||
raise ClientMethodMismatchError(
|
||||
method_name="aadd_documents",
|
||||
expected_client="AsyncElasticsearch",
|
||||
alt_method="add_documents",
|
||||
alt_client="Elasticsearch",
|
||||
)
|
||||
|
||||
collection_name = kwargs["collection_name"]
|
||||
documents = kwargs["documents"]
|
||||
|
||||
if not documents:
|
||||
raise ValueError("Documents list cannot be empty")
|
||||
|
||||
if not await self.client.indices.exists(index=collection_name):
|
||||
raise ValueError(f"Index '{collection_name}' does not exist")
|
||||
|
||||
for doc in documents:
|
||||
if _is_async_embedding_function(self.embedding_function):
|
||||
async_fn = cast(AsyncEmbeddingFunction, self.embedding_function)
|
||||
embedding = await async_fn(doc["content"])
|
||||
else:
|
||||
sync_fn = cast(EmbeddingFunction, self.embedding_function)
|
||||
embedding = sync_fn(doc["content"])
|
||||
|
||||
prepared_doc = _prepare_document_for_elasticsearch(doc, embedding)
|
||||
|
||||
await self.client.index(
|
||||
index=collection_name,
|
||||
id=prepared_doc["id"],
|
||||
body=prepared_doc["body"]
|
||||
)
|
||||
|
||||
def search(
|
||||
self, **kwargs: Unpack[BaseCollectionSearchParams]
|
||||
) -> list[SearchResult]:
|
||||
"""Search for similar documents using a query.
|
||||
|
||||
Keyword Args:
|
||||
collection_name: Name of the index to search in.
|
||||
query: The text query to search for.
|
||||
limit: Maximum number of results to return (default: 10).
|
||||
metadata_filter: Optional filter for metadata fields.
|
||||
score_threshold: Optional minimum similarity score (0-1) for results.
|
||||
|
||||
Returns:
|
||||
List of SearchResult dicts containing id, content, metadata, and score.
|
||||
|
||||
Raises:
|
||||
ValueError: If index doesn't exist.
|
||||
ConnectionError: If unable to connect to Elasticsearch server.
|
||||
"""
|
||||
if not _is_sync_client(self.client):
|
||||
raise ClientMethodMismatchError(
|
||||
method_name="search",
|
||||
expected_client="Elasticsearch",
|
||||
alt_method="asearch",
|
||||
alt_client="AsyncElasticsearch",
|
||||
)
|
||||
|
||||
collection_name = kwargs["collection_name"]
|
||||
query = kwargs["query"]
|
||||
limit = kwargs.get("limit", 10)
|
||||
metadata_filter = kwargs.get("metadata_filter")
|
||||
score_threshold = kwargs.get("score_threshold")
|
||||
|
||||
if not self.client.indices.exists(index=collection_name):
|
||||
raise ValueError(f"Index '{collection_name}' does not exist")
|
||||
|
||||
if _is_async_embedding_function(self.embedding_function):
|
||||
raise TypeError(
|
||||
"Async embedding function cannot be used with sync search. "
|
||||
"Use asearch instead."
|
||||
)
|
||||
sync_fn = cast(EmbeddingFunction, self.embedding_function)
|
||||
query_embedding = sync_fn(query)
|
||||
|
||||
search_query = _build_vector_search_query(
|
||||
query_vector=query_embedding,
|
||||
limit=limit,
|
||||
metadata_filter=metadata_filter,
|
||||
score_threshold=score_threshold,
|
||||
)
|
||||
|
||||
response = self.client.search(index=collection_name, body=search_query)
|
||||
return _process_search_results(response, score_threshold)
|
||||
|
||||
async def asearch(
|
||||
self, **kwargs: Unpack[BaseCollectionSearchParams]
|
||||
) -> list[SearchResult]:
|
||||
"""Search for similar documents using a query asynchronously.
|
||||
|
||||
Keyword Args:
|
||||
collection_name: Name of the index to search in.
|
||||
query: The text query to search for.
|
||||
limit: Maximum number of results to return (default: 10).
|
||||
metadata_filter: Optional filter for metadata fields.
|
||||
score_threshold: Optional minimum similarity score (0-1) for results.
|
||||
|
||||
Returns:
|
||||
List of SearchResult dicts containing id, content, metadata, and score.
|
||||
|
||||
Raises:
|
||||
ValueError: If index doesn't exist.
|
||||
ConnectionError: If unable to connect to Elasticsearch server.
|
||||
"""
|
||||
if not _is_async_client(self.client):
|
||||
raise ClientMethodMismatchError(
|
||||
method_name="asearch",
|
||||
expected_client="AsyncElasticsearch",
|
||||
alt_method="search",
|
||||
alt_client="Elasticsearch",
|
||||
)
|
||||
|
||||
collection_name = kwargs["collection_name"]
|
||||
query = kwargs["query"]
|
||||
limit = kwargs.get("limit", 10)
|
||||
metadata_filter = kwargs.get("metadata_filter")
|
||||
score_threshold = kwargs.get("score_threshold")
|
||||
|
||||
if not await self.client.indices.exists(index=collection_name):
|
||||
raise ValueError(f"Index '{collection_name}' does not exist")
|
||||
|
||||
if _is_async_embedding_function(self.embedding_function):
|
||||
async_fn = cast(AsyncEmbeddingFunction, self.embedding_function)
|
||||
query_embedding = await async_fn(query)
|
||||
else:
|
||||
sync_fn = cast(EmbeddingFunction, self.embedding_function)
|
||||
query_embedding = sync_fn(query)
|
||||
|
||||
search_query = _build_vector_search_query(
|
||||
query_vector=query_embedding,
|
||||
limit=limit,
|
||||
metadata_filter=metadata_filter,
|
||||
score_threshold=score_threshold,
|
||||
)
|
||||
|
||||
response = await self.client.search(index=collection_name, body=search_query)
|
||||
return _process_search_results(response, score_threshold)
|
||||
|
||||
def delete_collection(self, **kwargs: Unpack[BaseCollectionParams]) -> None:
|
||||
"""Delete an index and all its data.
|
||||
|
||||
Keyword Args:
|
||||
collection_name: Name of the index to delete.
|
||||
|
||||
Raises:
|
||||
ValueError: If index doesn't exist.
|
||||
ConnectionError: If unable to connect to Elasticsearch server.
|
||||
"""
|
||||
if not _is_sync_client(self.client):
|
||||
raise ClientMethodMismatchError(
|
||||
method_name="delete_collection",
|
||||
expected_client="Elasticsearch",
|
||||
alt_method="adelete_collection",
|
||||
alt_client="AsyncElasticsearch",
|
||||
)
|
||||
|
||||
collection_name = kwargs["collection_name"]
|
||||
|
||||
if not self.client.indices.exists(index=collection_name):
|
||||
raise ValueError(f"Index '{collection_name}' does not exist")
|
||||
|
||||
self.client.indices.delete(index=collection_name)
|
||||
|
||||
async def adelete_collection(self, **kwargs: Unpack[BaseCollectionParams]) -> None:
|
||||
"""Delete an index and all its data asynchronously.
|
||||
|
||||
Keyword Args:
|
||||
collection_name: Name of the index to delete.
|
||||
|
||||
Raises:
|
||||
ValueError: If index doesn't exist.
|
||||
ConnectionError: If unable to connect to Elasticsearch server.
|
||||
"""
|
||||
if not _is_async_client(self.client):
|
||||
raise ClientMethodMismatchError(
|
||||
method_name="adelete_collection",
|
||||
expected_client="AsyncElasticsearch",
|
||||
alt_method="delete_collection",
|
||||
alt_client="Elasticsearch",
|
||||
)
|
||||
|
||||
collection_name = kwargs["collection_name"]
|
||||
|
||||
if not await self.client.indices.exists(index=collection_name):
|
||||
raise ValueError(f"Index '{collection_name}' does not exist")
|
||||
|
||||
await self.client.indices.delete(index=collection_name)
|
||||
|
||||
def reset(self) -> None:
|
||||
"""Reset the vector database by deleting all indices and data.
|
||||
|
||||
Raises:
|
||||
ConnectionError: If unable to connect to Elasticsearch server.
|
||||
"""
|
||||
if not _is_sync_client(self.client):
|
||||
raise ClientMethodMismatchError(
|
||||
method_name="reset",
|
||||
expected_client="Elasticsearch",
|
||||
alt_method="areset",
|
||||
alt_client="AsyncElasticsearch",
|
||||
)
|
||||
|
||||
indices_response = self.client.indices.get(index="*")
|
||||
|
||||
for index_name in indices_response.keys():
|
||||
if not index_name.startswith("."):
|
||||
self.client.indices.delete(index=index_name)
|
||||
|
||||
async def areset(self) -> None:
|
||||
"""Reset the vector database by deleting all indices and data asynchronously.
|
||||
|
||||
Raises:
|
||||
ConnectionError: If unable to connect to Elasticsearch server.
|
||||
"""
|
||||
if not _is_async_client(self.client):
|
||||
raise ClientMethodMismatchError(
|
||||
method_name="areset",
|
||||
expected_client="AsyncElasticsearch",
|
||||
alt_method="reset",
|
||||
alt_client="Elasticsearch",
|
||||
)
|
||||
|
||||
indices_response = await self.client.indices.get(index="*")
|
||||
|
||||
for index_name in indices_response.keys():
|
||||
if not index_name.startswith("."):
|
||||
await self.client.indices.delete(index=index_name)
|
||||
92
src/crewai/rag/elasticsearch/config.py
Normal file
92
src/crewai/rag/elasticsearch/config.py
Normal file
@@ -0,0 +1,92 @@
|
||||
"""Elasticsearch configuration model."""
|
||||
|
||||
from dataclasses import field
|
||||
from typing import Literal, cast
|
||||
from pydantic.dataclasses import dataclass as pyd_dataclass
|
||||
|
||||
from crewai.rag.config.base import BaseRagConfig
|
||||
from crewai.rag.elasticsearch.types import (
|
||||
ElasticsearchClientParams,
|
||||
ElasticsearchEmbeddingFunctionWrapper,
|
||||
)
|
||||
from crewai.rag.elasticsearch.constants import (
|
||||
DEFAULT_HOST,
|
||||
DEFAULT_PORT,
|
||||
DEFAULT_EMBEDDING_MODEL,
|
||||
DEFAULT_VECTOR_DIMENSION,
|
||||
)
|
||||
|
||||
|
||||
def _default_options() -> ElasticsearchClientParams:
|
||||
"""Create default Elasticsearch client options.
|
||||
|
||||
Returns:
|
||||
Default options with local Elasticsearch connection.
|
||||
"""
|
||||
return ElasticsearchClientParams(
|
||||
hosts=[f"http://{DEFAULT_HOST}:{DEFAULT_PORT}"],
|
||||
use_ssl=False,
|
||||
verify_certs=False,
|
||||
timeout=30,
|
||||
)
|
||||
|
||||
|
||||
def _default_embedding_function() -> ElasticsearchEmbeddingFunctionWrapper:
|
||||
"""Create default Elasticsearch embedding function.
|
||||
|
||||
Returns:
|
||||
Default embedding function using sentence-transformers.
|
||||
"""
|
||||
try:
|
||||
from sentence_transformers import SentenceTransformer
|
||||
|
||||
model = SentenceTransformer(DEFAULT_EMBEDDING_MODEL)
|
||||
|
||||
def embed_fn(text: str) -> list[float]:
|
||||
"""Embed a single text string.
|
||||
|
||||
Args:
|
||||
text: Text to embed.
|
||||
|
||||
Returns:
|
||||
Embedding vector as list of floats.
|
||||
"""
|
||||
embedding = model.encode(text, convert_to_tensor=False)
|
||||
return embedding.tolist() if hasattr(embedding, 'tolist') else list(embedding)
|
||||
|
||||
return cast(ElasticsearchEmbeddingFunctionWrapper, embed_fn)
|
||||
except ImportError:
|
||||
def fallback_embed_fn(text: str) -> list[float]:
|
||||
"""Fallback embedding function when sentence-transformers is not available."""
|
||||
import hashlib
|
||||
import struct
|
||||
|
||||
hash_obj = hashlib.md5(text.encode(), usedforsecurity=False)
|
||||
hash_bytes = hash_obj.digest()
|
||||
|
||||
vector = []
|
||||
for i in range(0, len(hash_bytes), 4):
|
||||
chunk = hash_bytes[i:i+4]
|
||||
if len(chunk) == 4:
|
||||
value = struct.unpack('f', chunk)[0]
|
||||
vector.append(float(value))
|
||||
|
||||
while len(vector) < DEFAULT_VECTOR_DIMENSION:
|
||||
vector.extend(vector[:DEFAULT_VECTOR_DIMENSION - len(vector)])
|
||||
|
||||
return vector[:DEFAULT_VECTOR_DIMENSION]
|
||||
|
||||
return cast(ElasticsearchEmbeddingFunctionWrapper, fallback_embed_fn)
|
||||
|
||||
|
||||
@pyd_dataclass(frozen=True)
|
||||
class ElasticsearchConfig(BaseRagConfig):
|
||||
"""Configuration for Elasticsearch client."""
|
||||
|
||||
provider: Literal["elasticsearch"] = field(default="elasticsearch", init=False)
|
||||
options: ElasticsearchClientParams = field(default_factory=_default_options)
|
||||
vector_dimension: int = DEFAULT_VECTOR_DIMENSION
|
||||
similarity: str = "cosine"
|
||||
embedding_function: ElasticsearchEmbeddingFunctionWrapper = field(
|
||||
default_factory=_default_embedding_function
|
||||
)
|
||||
12
src/crewai/rag/elasticsearch/constants.py
Normal file
12
src/crewai/rag/elasticsearch/constants.py
Normal file
@@ -0,0 +1,12 @@
|
||||
"""Constants for Elasticsearch RAG implementation."""
|
||||
|
||||
from typing import Final
|
||||
|
||||
DEFAULT_HOST: Final[str] = "localhost"
|
||||
DEFAULT_PORT: Final[int] = 9200
|
||||
DEFAULT_INDEX_SETTINGS: Final[dict] = {
|
||||
"number_of_shards": 1,
|
||||
"number_of_replicas": 0,
|
||||
}
|
||||
DEFAULT_EMBEDDING_MODEL: Final[str] = "sentence-transformers/all-MiniLM-L6-v2"
|
||||
DEFAULT_VECTOR_DIMENSION: Final[int] = 384
|
||||
31
src/crewai/rag/elasticsearch/factory.py
Normal file
31
src/crewai/rag/elasticsearch/factory.py
Normal file
@@ -0,0 +1,31 @@
|
||||
"""Factory functions for creating Elasticsearch clients."""
|
||||
|
||||
from crewai.rag.elasticsearch.config import ElasticsearchConfig
|
||||
from crewai.rag.elasticsearch.client import ElasticsearchClient
|
||||
|
||||
|
||||
def create_client(config: ElasticsearchConfig) -> ElasticsearchClient:
|
||||
"""Create an ElasticsearchClient from configuration.
|
||||
|
||||
Args:
|
||||
config: Elasticsearch configuration object.
|
||||
|
||||
Returns:
|
||||
Configured ElasticsearchClient instance.
|
||||
"""
|
||||
try:
|
||||
from elasticsearch import Elasticsearch
|
||||
except ImportError as e:
|
||||
raise ImportError(
|
||||
"elasticsearch package is required for Elasticsearch support. "
|
||||
"Install it with: pip install elasticsearch"
|
||||
) from e
|
||||
|
||||
client = Elasticsearch(**config.options)
|
||||
|
||||
return ElasticsearchClient(
|
||||
client=client,
|
||||
embedding_function=config.embedding_function,
|
||||
vector_dimension=config.vector_dimension,
|
||||
similarity=config.similarity,
|
||||
)
|
||||
93
src/crewai/rag/elasticsearch/types.py
Normal file
93
src/crewai/rag/elasticsearch/types.py
Normal file
@@ -0,0 +1,93 @@
|
||||
"""Type definitions for Elasticsearch RAG implementation."""
|
||||
|
||||
from typing import Any, Protocol, Union, TYPE_CHECKING
|
||||
from typing_extensions import NotRequired, TypedDict
|
||||
from pydantic import GetCoreSchemaHandler
|
||||
from pydantic_core import CoreSchema, core_schema
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from typing import TypeAlias
|
||||
from elasticsearch import Elasticsearch, AsyncElasticsearch
|
||||
ElasticsearchClientType: TypeAlias = Union[Elasticsearch, AsyncElasticsearch]
|
||||
else:
|
||||
try:
|
||||
from elasticsearch import Elasticsearch, AsyncElasticsearch
|
||||
ElasticsearchClientType = Union[Elasticsearch, AsyncElasticsearch]
|
||||
except ImportError:
|
||||
ElasticsearchClientType = Any
|
||||
|
||||
|
||||
class ElasticsearchClientParams(TypedDict, total=False):
|
||||
"""Parameters for Elasticsearch client initialization."""
|
||||
|
||||
hosts: NotRequired[list[str]]
|
||||
cloud_id: NotRequired[str]
|
||||
username: NotRequired[str]
|
||||
password: NotRequired[str]
|
||||
api_key: NotRequired[str]
|
||||
use_ssl: NotRequired[bool]
|
||||
verify_certs: NotRequired[bool]
|
||||
ca_certs: NotRequired[str]
|
||||
timeout: NotRequired[int]
|
||||
|
||||
|
||||
class ElasticsearchIndexSettings(TypedDict, total=False):
|
||||
"""Settings for Elasticsearch index creation."""
|
||||
|
||||
number_of_shards: NotRequired[int]
|
||||
number_of_replicas: NotRequired[int]
|
||||
refresh_interval: NotRequired[str]
|
||||
|
||||
|
||||
class ElasticsearchCollectionCreateParams(TypedDict, total=False):
|
||||
"""Parameters for creating Elasticsearch collections/indices."""
|
||||
|
||||
collection_name: str
|
||||
index_settings: NotRequired[ElasticsearchIndexSettings]
|
||||
vector_dimension: NotRequired[int]
|
||||
similarity: NotRequired[str]
|
||||
|
||||
|
||||
class EmbeddingFunction(Protocol):
|
||||
"""Protocol for embedding functions that convert text to vectors."""
|
||||
|
||||
def __call__(self, text: str) -> list[float]:
|
||||
"""Convert text to embedding vector.
|
||||
|
||||
Args:
|
||||
text: Input text to embed.
|
||||
|
||||
Returns:
|
||||
Embedding vector as list of floats.
|
||||
"""
|
||||
...
|
||||
|
||||
|
||||
class AsyncEmbeddingFunction(Protocol):
|
||||
"""Protocol for async embedding functions that convert text to vectors."""
|
||||
|
||||
async def __call__(self, text: str) -> list[float]:
|
||||
"""Convert text to embedding vector asynchronously.
|
||||
|
||||
Args:
|
||||
text: Input text to embed.
|
||||
|
||||
Returns:
|
||||
Embedding vector as list of floats.
|
||||
"""
|
||||
...
|
||||
|
||||
|
||||
class ElasticsearchEmbeddingFunctionWrapper(EmbeddingFunction):
|
||||
"""Base class for Elasticsearch EmbeddingFunction to work with Pydantic validation."""
|
||||
|
||||
@classmethod
|
||||
def __get_pydantic_core_schema__(
|
||||
cls, _source_type: Any, _handler: GetCoreSchemaHandler
|
||||
) -> CoreSchema:
|
||||
"""Generate Pydantic core schema for Elasticsearch EmbeddingFunction.
|
||||
|
||||
This allows Pydantic to handle Elasticsearch's EmbeddingFunction type
|
||||
without requiring arbitrary_types_allowed=True.
|
||||
"""
|
||||
return core_schema.any_schema()
|
||||
186
src/crewai/rag/elasticsearch/utils.py
Normal file
186
src/crewai/rag/elasticsearch/utils.py
Normal file
@@ -0,0 +1,186 @@
|
||||
"""Utility functions for Elasticsearch RAG implementation."""
|
||||
|
||||
import hashlib
|
||||
from typing import Any, TypeGuard
|
||||
|
||||
from crewai.rag.elasticsearch.types import (
|
||||
AsyncEmbeddingFunction,
|
||||
EmbeddingFunction,
|
||||
ElasticsearchClientType,
|
||||
)
|
||||
from crewai.rag.types import BaseRecord, SearchResult
|
||||
|
||||
try:
|
||||
from elasticsearch import Elasticsearch, AsyncElasticsearch
|
||||
except ImportError:
|
||||
Elasticsearch = None
|
||||
AsyncElasticsearch = None
|
||||
|
||||
|
||||
def _is_sync_client(client: ElasticsearchClientType) -> TypeGuard[Any]:
|
||||
"""Type guard to check if the client is a sync Elasticsearch client."""
|
||||
if Elasticsearch is None:
|
||||
return False
|
||||
return isinstance(client, Elasticsearch)
|
||||
|
||||
|
||||
def _is_async_client(client: ElasticsearchClientType) -> TypeGuard[Any]:
|
||||
"""Type guard to check if the client is an async Elasticsearch client."""
|
||||
if AsyncElasticsearch is None:
|
||||
return False
|
||||
return isinstance(client, AsyncElasticsearch)
|
||||
|
||||
|
||||
def _is_async_embedding_function(
|
||||
func: EmbeddingFunction | AsyncEmbeddingFunction,
|
||||
) -> TypeGuard[AsyncEmbeddingFunction]:
|
||||
"""Type guard to check if the embedding function is async."""
|
||||
import inspect
|
||||
return inspect.iscoroutinefunction(func)
|
||||
|
||||
|
||||
def _generate_doc_id(content: str) -> str:
|
||||
"""Generate a document ID from content using SHA256 hash."""
|
||||
return hashlib.sha256(content.encode()).hexdigest()
|
||||
|
||||
|
||||
def _prepare_document_for_elasticsearch(
|
||||
doc: BaseRecord, embedding: list[float]
|
||||
) -> dict[str, Any]:
|
||||
"""Prepare a document for Elasticsearch indexing.
|
||||
|
||||
Args:
|
||||
doc: Document record to prepare.
|
||||
embedding: Embedding vector for the document.
|
||||
|
||||
Returns:
|
||||
Document formatted for Elasticsearch.
|
||||
"""
|
||||
doc_id = doc.get("doc_id") or _generate_doc_id(doc["content"])
|
||||
|
||||
es_doc = {
|
||||
"content": doc["content"],
|
||||
"content_vector": embedding,
|
||||
"metadata": doc.get("metadata", {}),
|
||||
}
|
||||
|
||||
return {"id": doc_id, "body": es_doc}
|
||||
|
||||
|
||||
def _process_search_results(
|
||||
response: dict[str, Any], score_threshold: float | None = None
|
||||
) -> list[SearchResult]:
|
||||
"""Process Elasticsearch search response into SearchResult format.
|
||||
|
||||
Args:
|
||||
response: Raw Elasticsearch search response.
|
||||
score_threshold: Optional minimum score threshold.
|
||||
|
||||
Returns:
|
||||
List of SearchResult dictionaries.
|
||||
"""
|
||||
results = []
|
||||
|
||||
hits = response.get("hits", {}).get("hits", [])
|
||||
|
||||
for hit in hits:
|
||||
score = hit.get("_score", 0.0)
|
||||
|
||||
if score_threshold is not None and score < score_threshold:
|
||||
continue
|
||||
|
||||
source = hit.get("_source", {})
|
||||
|
||||
result = SearchResult(
|
||||
id=hit.get("_id", ""),
|
||||
content=source.get("content", ""),
|
||||
metadata=source.get("metadata", {}),
|
||||
score=score,
|
||||
)
|
||||
results.append(result)
|
||||
|
||||
return results
|
||||
|
||||
|
||||
def _build_vector_search_query(
|
||||
query_vector: list[float],
|
||||
limit: int = 10,
|
||||
metadata_filter: dict[str, Any] | None = None,
|
||||
score_threshold: float | None = None,
|
||||
) -> dict[str, Any]:
|
||||
"""Build Elasticsearch query for vector similarity search.
|
||||
|
||||
Args:
|
||||
query_vector: Query embedding vector.
|
||||
limit: Maximum number of results.
|
||||
metadata_filter: Optional metadata filter.
|
||||
score_threshold: Optional minimum score threshold.
|
||||
|
||||
Returns:
|
||||
Elasticsearch query dictionary.
|
||||
"""
|
||||
query = {
|
||||
"size": limit,
|
||||
"query": {
|
||||
"script_score": {
|
||||
"query": {"match_all": {}},
|
||||
"script": {
|
||||
"source": "cosineSimilarity(params.query_vector, 'content_vector') + 1.0",
|
||||
"params": {"query_vector": query_vector}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if metadata_filter:
|
||||
bool_query = {
|
||||
"bool": {
|
||||
"must": [
|
||||
query["query"]
|
||||
],
|
||||
"filter": []
|
||||
}
|
||||
}
|
||||
|
||||
for key, value in metadata_filter.items():
|
||||
bool_query["bool"]["filter"].append({
|
||||
"term": {f"metadata.{key}": value}
|
||||
})
|
||||
|
||||
query["query"] = bool_query
|
||||
|
||||
if score_threshold is not None:
|
||||
query["min_score"] = score_threshold
|
||||
|
||||
return query
|
||||
|
||||
|
||||
def _get_index_mapping(vector_dimension: int, similarity: str = "cosine") -> dict[str, Any]:
|
||||
"""Get Elasticsearch index mapping for vector search.
|
||||
|
||||
Args:
|
||||
vector_dimension: Dimension of the embedding vectors.
|
||||
similarity: Similarity function to use.
|
||||
|
||||
Returns:
|
||||
Elasticsearch mapping dictionary.
|
||||
"""
|
||||
return {
|
||||
"mappings": {
|
||||
"properties": {
|
||||
"content": {
|
||||
"type": "text",
|
||||
"analyzer": "standard"
|
||||
},
|
||||
"content_vector": {
|
||||
"type": "dense_vector",
|
||||
"dims": vector_dimension,
|
||||
"similarity": similarity
|
||||
},
|
||||
"metadata": {
|
||||
"type": "object",
|
||||
"dynamic": True
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -5,6 +5,7 @@ from typing import cast
|
||||
from crewai.rag.config.optional_imports.protocols import (
|
||||
ChromaFactoryModule,
|
||||
QdrantFactoryModule,
|
||||
ElasticsearchFactoryModule,
|
||||
)
|
||||
from crewai.rag.core.base_client import BaseClient
|
||||
from crewai.rag.config.types import RagConfigType
|
||||
@@ -43,3 +44,15 @@ def create_client(config: RagConfigType) -> BaseClient:
|
||||
),
|
||||
)
|
||||
return qdrant_mod.create_client(config)
|
||||
|
||||
if config.provider == "elasticsearch":
|
||||
elasticsearch_mod = cast(
|
||||
ElasticsearchFactoryModule,
|
||||
require(
|
||||
"crewai.rag.elasticsearch.factory",
|
||||
purpose="The 'elasticsearch' provider",
|
||||
),
|
||||
)
|
||||
return elasticsearch_mod.create_client(config)
|
||||
|
||||
raise ValueError(f"Unsupported provider: {config.provider}")
|
||||
|
||||
@@ -2,6 +2,8 @@
|
||||
|
||||
from unittest.mock import Mock, patch
|
||||
|
||||
import pytest
|
||||
|
||||
from crewai.rag.factory import create_client
|
||||
|
||||
|
||||
@@ -25,10 +27,50 @@ def test_create_client_chromadb():
|
||||
mock_module.create_client.assert_called_once_with(mock_config)
|
||||
|
||||
|
||||
def test_create_client_qdrant():
|
||||
"""Test Qdrant client creation."""
|
||||
mock_config = Mock()
|
||||
mock_config.provider = "qdrant"
|
||||
|
||||
with patch("crewai.rag.factory.require") as mock_require:
|
||||
mock_module = Mock()
|
||||
mock_client = Mock()
|
||||
mock_module.create_client.return_value = mock_client
|
||||
mock_require.return_value = mock_module
|
||||
|
||||
result = create_client(mock_config)
|
||||
|
||||
assert result == mock_client
|
||||
mock_require.assert_called_once_with(
|
||||
"crewai.rag.qdrant.factory", purpose="The 'qdrant' provider"
|
||||
)
|
||||
mock_module.create_client.assert_called_once_with(mock_config)
|
||||
|
||||
|
||||
def test_create_client_elasticsearch():
|
||||
"""Test Elasticsearch client creation."""
|
||||
mock_config = Mock()
|
||||
mock_config.provider = "elasticsearch"
|
||||
|
||||
with patch("crewai.rag.factory.require") as mock_require:
|
||||
mock_module = Mock()
|
||||
mock_client = Mock()
|
||||
mock_module.create_client.return_value = mock_client
|
||||
mock_require.return_value = mock_module
|
||||
|
||||
result = create_client(mock_config)
|
||||
|
||||
assert result == mock_client
|
||||
mock_require.assert_called_once_with(
|
||||
"crewai.rag.elasticsearch.factory", purpose="The 'elasticsearch' provider"
|
||||
)
|
||||
mock_module.create_client.assert_called_once_with(mock_config)
|
||||
|
||||
|
||||
def test_create_client_unsupported_provider():
|
||||
"""Test unsupported provider returns None for now."""
|
||||
"""Test that unsupported provider raises ValueError."""
|
||||
mock_config = Mock()
|
||||
mock_config.provider = "unsupported"
|
||||
|
||||
result = create_client(mock_config)
|
||||
assert result is None
|
||||
with pytest.raises(ValueError, match="Unsupported provider: unsupported"):
|
||||
create_client(mock_config)
|
||||
|
||||
@@ -3,7 +3,10 @@
|
||||
import pytest
|
||||
|
||||
from crewai.rag.config.optional_imports.base import _MissingProvider
|
||||
from crewai.rag.config.optional_imports.providers import MissingChromaDBConfig
|
||||
from crewai.rag.config.optional_imports.providers import (
|
||||
MissingChromaDBConfig,
|
||||
MissingElasticsearchConfig,
|
||||
)
|
||||
|
||||
|
||||
def test_missing_provider_raises_runtime_error():
|
||||
@@ -20,3 +23,11 @@ def test_missing_chromadb_config_raises_runtime_error():
|
||||
RuntimeError, match="provider 'chromadb' requested but not installed"
|
||||
):
|
||||
MissingChromaDBConfig()
|
||||
|
||||
|
||||
def test_missing_elasticsearch_config_raises_runtime_error():
|
||||
"""Test that MissingElasticsearchConfig raises RuntimeError on instantiation."""
|
||||
with pytest.raises(
|
||||
RuntimeError, match="provider 'elasticsearch' requested but not installed"
|
||||
):
|
||||
MissingElasticsearchConfig()
|
||||
|
||||
1
tests/rag/elasticsearch/__init__.py
Normal file
1
tests/rag/elasticsearch/__init__.py
Normal file
@@ -0,0 +1 @@
|
||||
"""Tests for Elasticsearch RAG implementation."""
|
||||
397
tests/rag/elasticsearch/test_client.py
Normal file
397
tests/rag/elasticsearch/test_client.py
Normal file
@@ -0,0 +1,397 @@
|
||||
"""Tests for ElasticsearchClient implementation."""
|
||||
|
||||
from unittest.mock import AsyncMock, Mock, patch
|
||||
|
||||
import pytest
|
||||
|
||||
from crewai.rag.elasticsearch.client import ElasticsearchClient
|
||||
from crewai.rag.types import BaseRecord
|
||||
from crewai.rag.core.exceptions import ClientMethodMismatchError
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_elasticsearch_client():
|
||||
"""Create a mock Elasticsearch client."""
|
||||
mock_client = Mock()
|
||||
mock_client.indices = Mock()
|
||||
mock_client.indices.exists.return_value = False
|
||||
mock_client.indices.create.return_value = {"acknowledged": True}
|
||||
mock_client.indices.get.return_value = {"test_index": {"mappings": {}}}
|
||||
mock_client.indices.delete.return_value = {"acknowledged": True}
|
||||
mock_client.index.return_value = {"_id": "test_id", "result": "created"}
|
||||
mock_client.search.return_value = {
|
||||
"hits": {
|
||||
"hits": [
|
||||
{
|
||||
"_id": "doc1",
|
||||
"_score": 0.9,
|
||||
"_source": {
|
||||
"content": "test content",
|
||||
"metadata": {"key": "value"}
|
||||
}
|
||||
}
|
||||
]
|
||||
}
|
||||
}
|
||||
return mock_client
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_async_elasticsearch_client():
|
||||
"""Create a mock async Elasticsearch client."""
|
||||
mock_client = Mock()
|
||||
mock_client.indices = Mock()
|
||||
mock_client.indices.exists = AsyncMock(return_value=False)
|
||||
mock_client.indices.create = AsyncMock(return_value={"acknowledged": True})
|
||||
mock_client.indices.get = AsyncMock(return_value={"test_index": {"mappings": {}}})
|
||||
mock_client.indices.delete = AsyncMock(return_value={"acknowledged": True})
|
||||
mock_client.index = AsyncMock(return_value={"_id": "test_id", "result": "created"})
|
||||
mock_client.search = AsyncMock(return_value={
|
||||
"hits": {
|
||||
"hits": [
|
||||
{
|
||||
"_id": "doc1",
|
||||
"_score": 0.9,
|
||||
"_source": {
|
||||
"content": "test content",
|
||||
"metadata": {"key": "value"}
|
||||
}
|
||||
}
|
||||
]
|
||||
}
|
||||
})
|
||||
return mock_client
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def client(mock_elasticsearch_client) -> ElasticsearchClient:
|
||||
"""Create an ElasticsearchClient instance for testing."""
|
||||
mock_embedding = Mock()
|
||||
mock_embedding.return_value = [0.1, 0.2, 0.3]
|
||||
|
||||
client = ElasticsearchClient(
|
||||
client=mock_elasticsearch_client,
|
||||
embedding_function=mock_embedding,
|
||||
vector_dimension=3,
|
||||
similarity="cosine"
|
||||
)
|
||||
return client
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def async_client(mock_async_elasticsearch_client) -> ElasticsearchClient:
|
||||
"""Create an ElasticsearchClient instance with async client for testing."""
|
||||
mock_embedding = Mock()
|
||||
mock_embedding.return_value = [0.1, 0.2, 0.3]
|
||||
|
||||
client = ElasticsearchClient(
|
||||
client=mock_async_elasticsearch_client,
|
||||
embedding_function=mock_embedding,
|
||||
vector_dimension=3,
|
||||
similarity="cosine"
|
||||
)
|
||||
return client
|
||||
|
||||
|
||||
class TestElasticsearchClient:
|
||||
"""Test suite for ElasticsearchClient."""
|
||||
|
||||
@patch("crewai.rag.elasticsearch.client._is_sync_client", return_value=True)
|
||||
@patch("crewai.rag.elasticsearch.client._is_async_client", return_value=False)
|
||||
def test_create_collection(self, mock_is_async, mock_is_sync, client, mock_elasticsearch_client):
|
||||
"""Test that create_collection creates a new index."""
|
||||
mock_elasticsearch_client.indices.exists.return_value = False
|
||||
|
||||
client.create_collection(collection_name="test_index")
|
||||
|
||||
mock_elasticsearch_client.indices.exists.assert_called_once_with(index="test_index")
|
||||
mock_elasticsearch_client.indices.create.assert_called_once()
|
||||
call_args = mock_elasticsearch_client.indices.create.call_args
|
||||
assert call_args.kwargs["index"] == "test_index"
|
||||
assert "mappings" in call_args.kwargs["body"]
|
||||
|
||||
@patch("crewai.rag.elasticsearch.client._is_sync_client", return_value=True)
|
||||
@patch("crewai.rag.elasticsearch.client._is_async_client", return_value=False)
|
||||
def test_create_collection_already_exists(self, mock_is_async, mock_is_sync, client, mock_elasticsearch_client):
|
||||
"""Test that create_collection raises error if index exists."""
|
||||
mock_elasticsearch_client.indices.exists.return_value = True
|
||||
|
||||
with pytest.raises(
|
||||
ValueError, match="Index 'test_index' already exists"
|
||||
):
|
||||
client.create_collection(collection_name="test_index")
|
||||
|
||||
@patch("crewai.rag.elasticsearch.client._is_sync_client", return_value=False)
|
||||
@patch("crewai.rag.elasticsearch.client._is_async_client", return_value=True)
|
||||
def test_create_collection_wrong_client_type(self, mock_is_async, mock_is_sync, mock_async_elasticsearch_client):
|
||||
"""Test that create_collection raises error with async client."""
|
||||
mock_embedding = Mock()
|
||||
client = ElasticsearchClient(
|
||||
client=mock_async_elasticsearch_client,
|
||||
embedding_function=mock_embedding
|
||||
)
|
||||
|
||||
with pytest.raises(ClientMethodMismatchError):
|
||||
client.create_collection(collection_name="test_index")
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@patch("crewai.rag.elasticsearch.client._is_sync_client", return_value=False)
|
||||
@patch("crewai.rag.elasticsearch.client._is_async_client", return_value=True)
|
||||
async def test_acreate_collection(self, mock_is_async, mock_is_sync, async_client, mock_async_elasticsearch_client):
|
||||
"""Test that acreate_collection creates a new index asynchronously."""
|
||||
mock_async_elasticsearch_client.indices.exists.return_value = False
|
||||
|
||||
await async_client.acreate_collection(collection_name="test_index")
|
||||
|
||||
mock_async_elasticsearch_client.indices.exists.assert_called_once_with(index="test_index")
|
||||
mock_async_elasticsearch_client.indices.create.assert_called_once()
|
||||
call_args = mock_async_elasticsearch_client.indices.create.call_args
|
||||
assert call_args.kwargs["index"] == "test_index"
|
||||
assert "mappings" in call_args.kwargs["body"]
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@patch("crewai.rag.elasticsearch.client._is_sync_client", return_value=False)
|
||||
@patch("crewai.rag.elasticsearch.client._is_async_client", return_value=True)
|
||||
async def test_acreate_collection_already_exists(self, mock_is_async, mock_is_sync, async_client, mock_async_elasticsearch_client):
|
||||
"""Test that acreate_collection raises error if index exists."""
|
||||
mock_async_elasticsearch_client.indices.exists.return_value = True
|
||||
|
||||
with pytest.raises(
|
||||
ValueError, match="Index 'test_index' already exists"
|
||||
):
|
||||
await async_client.acreate_collection(collection_name="test_index")
|
||||
|
||||
@patch("crewai.rag.elasticsearch.client._is_sync_client", return_value=True)
|
||||
@patch("crewai.rag.elasticsearch.client._is_async_client", return_value=False)
|
||||
def test_get_or_create_collection(self, mock_is_async, mock_is_sync, client, mock_elasticsearch_client):
|
||||
"""Test that get_or_create_collection returns existing index."""
|
||||
mock_elasticsearch_client.indices.exists.return_value = True
|
||||
|
||||
result = client.get_or_create_collection(collection_name="test_index")
|
||||
|
||||
mock_elasticsearch_client.indices.exists.assert_called_once_with(index="test_index")
|
||||
mock_elasticsearch_client.indices.get.assert_called_once_with(index="test_index")
|
||||
assert result == {"test_index": {"mappings": {}}}
|
||||
|
||||
@patch("crewai.rag.elasticsearch.client._is_sync_client", return_value=True)
|
||||
@patch("crewai.rag.elasticsearch.client._is_async_client", return_value=False)
|
||||
def test_get_or_create_collection_creates_new(self, mock_is_async, mock_is_sync, client, mock_elasticsearch_client):
|
||||
"""Test that get_or_create_collection creates new index if not exists."""
|
||||
mock_elasticsearch_client.indices.exists.return_value = False
|
||||
|
||||
client.get_or_create_collection(collection_name="test_index")
|
||||
|
||||
mock_elasticsearch_client.indices.exists.assert_called_once_with(index="test_index")
|
||||
mock_elasticsearch_client.indices.create.assert_called_once()
|
||||
mock_elasticsearch_client.indices.get.assert_called_once_with(index="test_index")
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@patch("crewai.rag.elasticsearch.client._is_sync_client", return_value=False)
|
||||
@patch("crewai.rag.elasticsearch.client._is_async_client", return_value=True)
|
||||
async def test_aget_or_create_collection(self, mock_is_async, mock_is_sync, async_client, mock_async_elasticsearch_client):
|
||||
"""Test that aget_or_create_collection returns existing index asynchronously."""
|
||||
mock_async_elasticsearch_client.indices.exists.return_value = True
|
||||
|
||||
result = await async_client.aget_or_create_collection(collection_name="test_index")
|
||||
|
||||
mock_async_elasticsearch_client.indices.exists.assert_called_once_with(index="test_index")
|
||||
mock_async_elasticsearch_client.indices.get.assert_called_once_with(index="test_index")
|
||||
assert result == {"test_index": {"mappings": {}}}
|
||||
|
||||
@patch("crewai.rag.elasticsearch.client._is_sync_client", return_value=True)
|
||||
@patch("crewai.rag.elasticsearch.client._is_async_client", return_value=False)
|
||||
def test_add_documents(self, mock_is_async, mock_is_sync, client, mock_elasticsearch_client):
|
||||
"""Test that add_documents indexes documents correctly."""
|
||||
mock_elasticsearch_client.indices.exists.return_value = True
|
||||
|
||||
documents: list[BaseRecord] = [
|
||||
{
|
||||
"content": "test content",
|
||||
"metadata": {"key": "value"}
|
||||
}
|
||||
]
|
||||
|
||||
client.add_documents(collection_name="test_index", documents=documents)
|
||||
|
||||
mock_elasticsearch_client.indices.exists.assert_called_once_with(index="test_index")
|
||||
mock_elasticsearch_client.index.assert_called_once()
|
||||
call_args = mock_elasticsearch_client.index.call_args
|
||||
assert call_args.kwargs["index"] == "test_index"
|
||||
assert "body" in call_args.kwargs
|
||||
assert call_args.kwargs["body"]["content"] == "test content"
|
||||
|
||||
@patch("crewai.rag.elasticsearch.client._is_sync_client", return_value=True)
|
||||
@patch("crewai.rag.elasticsearch.client._is_async_client", return_value=False)
|
||||
def test_add_documents_empty_list_raises_error(self, mock_is_async, mock_is_sync, client):
|
||||
"""Test that add_documents raises error with empty documents list."""
|
||||
with pytest.raises(ValueError, match="Documents list cannot be empty"):
|
||||
client.add_documents(collection_name="test_index", documents=[])
|
||||
|
||||
@patch("crewai.rag.elasticsearch.client._is_sync_client", return_value=True)
|
||||
@patch("crewai.rag.elasticsearch.client._is_async_client", return_value=False)
|
||||
def test_add_documents_index_not_exists(self, mock_is_async, mock_is_sync, client, mock_elasticsearch_client):
|
||||
"""Test that add_documents raises error if index doesn't exist."""
|
||||
mock_elasticsearch_client.indices.exists.return_value = False
|
||||
|
||||
documents: list[BaseRecord] = [{"content": "test content"}]
|
||||
|
||||
with pytest.raises(ValueError, match="Index 'test_index' does not exist"):
|
||||
client.add_documents(collection_name="test_index", documents=documents)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@patch("crewai.rag.elasticsearch.client._is_sync_client", return_value=False)
|
||||
@patch("crewai.rag.elasticsearch.client._is_async_client", return_value=True)
|
||||
async def test_aadd_documents(self, mock_is_async, mock_is_sync, async_client, mock_async_elasticsearch_client):
|
||||
"""Test that aadd_documents indexes documents correctly asynchronously."""
|
||||
mock_async_elasticsearch_client.indices.exists.return_value = True
|
||||
|
||||
documents: list[BaseRecord] = [
|
||||
{
|
||||
"content": "test content",
|
||||
"metadata": {"key": "value"}
|
||||
}
|
||||
]
|
||||
|
||||
await async_client.aadd_documents(collection_name="test_index", documents=documents)
|
||||
|
||||
mock_async_elasticsearch_client.indices.exists.assert_called_once_with(index="test_index")
|
||||
mock_async_elasticsearch_client.index.assert_called_once()
|
||||
call_args = mock_async_elasticsearch_client.index.call_args
|
||||
assert call_args.kwargs["index"] == "test_index"
|
||||
assert "body" in call_args.kwargs
|
||||
|
||||
@patch("crewai.rag.elasticsearch.client._is_sync_client", return_value=True)
|
||||
@patch("crewai.rag.elasticsearch.client._is_async_client", return_value=False)
|
||||
def test_search(self, mock_is_async, mock_is_sync, client, mock_elasticsearch_client):
|
||||
"""Test that search performs vector similarity search."""
|
||||
mock_elasticsearch_client.indices.exists.return_value = True
|
||||
|
||||
results = client.search(
|
||||
collection_name="test_index",
|
||||
query="test query",
|
||||
limit=5
|
||||
)
|
||||
|
||||
mock_elasticsearch_client.indices.exists.assert_called_once_with(index="test_index")
|
||||
mock_elasticsearch_client.search.assert_called_once()
|
||||
call_args = mock_elasticsearch_client.search.call_args
|
||||
assert call_args.kwargs["index"] == "test_index"
|
||||
assert "body" in call_args.kwargs
|
||||
|
||||
assert len(results) == 1
|
||||
assert results[0]["id"] == "doc1"
|
||||
assert results[0]["content"] == "test content"
|
||||
assert results[0]["score"] == 0.9
|
||||
|
||||
@patch("crewai.rag.elasticsearch.client._is_sync_client", return_value=True)
|
||||
@patch("crewai.rag.elasticsearch.client._is_async_client", return_value=False)
|
||||
def test_search_with_metadata_filter(self, mock_is_async, mock_is_sync, client, mock_elasticsearch_client):
|
||||
"""Test that search applies metadata filter correctly."""
|
||||
mock_elasticsearch_client.indices.exists.return_value = True
|
||||
|
||||
client.search(
|
||||
collection_name="test_index",
|
||||
query="test query",
|
||||
metadata_filter={"key": "value"}
|
||||
)
|
||||
|
||||
mock_elasticsearch_client.search.assert_called_once()
|
||||
call_args = mock_elasticsearch_client.search.call_args
|
||||
query_body = call_args.kwargs["body"]
|
||||
assert "bool" in query_body["query"]
|
||||
|
||||
@patch("crewai.rag.elasticsearch.client._is_sync_client", return_value=True)
|
||||
@patch("crewai.rag.elasticsearch.client._is_async_client", return_value=False)
|
||||
def test_search_index_not_exists(self, mock_is_async, mock_is_sync, client, mock_elasticsearch_client):
|
||||
"""Test that search raises error if index doesn't exist."""
|
||||
mock_elasticsearch_client.indices.exists.return_value = False
|
||||
|
||||
with pytest.raises(ValueError, match="Index 'test_index' does not exist"):
|
||||
client.search(collection_name="test_index", query="test query")
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@patch("crewai.rag.elasticsearch.client._is_sync_client", return_value=False)
|
||||
@patch("crewai.rag.elasticsearch.client._is_async_client", return_value=True)
|
||||
async def test_asearch(self, mock_is_async, mock_is_sync, async_client, mock_async_elasticsearch_client):
|
||||
"""Test that asearch performs vector similarity search asynchronously."""
|
||||
mock_async_elasticsearch_client.indices.exists.return_value = True
|
||||
|
||||
results = await async_client.asearch(
|
||||
collection_name="test_index",
|
||||
query="test query",
|
||||
limit=5
|
||||
)
|
||||
|
||||
mock_async_elasticsearch_client.indices.exists.assert_called_once_with(index="test_index")
|
||||
mock_async_elasticsearch_client.search.assert_called_once()
|
||||
|
||||
assert len(results) == 1
|
||||
assert results[0]["id"] == "doc1"
|
||||
assert results[0]["content"] == "test content"
|
||||
|
||||
@patch("crewai.rag.elasticsearch.client._is_sync_client", return_value=True)
|
||||
@patch("crewai.rag.elasticsearch.client._is_async_client", return_value=False)
|
||||
def test_delete_collection(self, mock_is_async, mock_is_sync, client, mock_elasticsearch_client):
|
||||
"""Test that delete_collection deletes the index."""
|
||||
mock_elasticsearch_client.indices.exists.return_value = True
|
||||
|
||||
client.delete_collection(collection_name="test_index")
|
||||
|
||||
mock_elasticsearch_client.indices.exists.assert_called_once_with(index="test_index")
|
||||
mock_elasticsearch_client.indices.delete.assert_called_once_with(index="test_index")
|
||||
|
||||
@patch("crewai.rag.elasticsearch.client._is_sync_client", return_value=True)
|
||||
@patch("crewai.rag.elasticsearch.client._is_async_client", return_value=False)
|
||||
def test_delete_collection_not_exists(self, mock_is_async, mock_is_sync, client, mock_elasticsearch_client):
|
||||
"""Test that delete_collection raises error if index doesn't exist."""
|
||||
mock_elasticsearch_client.indices.exists.return_value = False
|
||||
|
||||
with pytest.raises(ValueError, match="Index 'test_index' does not exist"):
|
||||
client.delete_collection(collection_name="test_index")
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@patch("crewai.rag.elasticsearch.client._is_sync_client", return_value=False)
|
||||
@patch("crewai.rag.elasticsearch.client._is_async_client", return_value=True)
|
||||
async def test_adelete_collection(self, mock_is_async, mock_is_sync, async_client, mock_async_elasticsearch_client):
|
||||
"""Test that adelete_collection deletes the index asynchronously."""
|
||||
mock_async_elasticsearch_client.indices.exists.return_value = True
|
||||
|
||||
await async_client.adelete_collection(collection_name="test_index")
|
||||
|
||||
mock_async_elasticsearch_client.indices.exists.assert_called_once_with(index="test_index")
|
||||
mock_async_elasticsearch_client.indices.delete.assert_called_once_with(index="test_index")
|
||||
|
||||
@patch("crewai.rag.elasticsearch.client._is_sync_client", return_value=True)
|
||||
@patch("crewai.rag.elasticsearch.client._is_async_client", return_value=False)
|
||||
def test_reset(self, mock_is_async, mock_is_sync, client, mock_elasticsearch_client):
|
||||
"""Test that reset deletes all non-system indices."""
|
||||
mock_elasticsearch_client.indices.get.return_value = {
|
||||
"test_index": {},
|
||||
".system_index": {},
|
||||
"another_index": {}
|
||||
}
|
||||
|
||||
client.reset()
|
||||
|
||||
mock_elasticsearch_client.indices.get.assert_called_once_with(index="*")
|
||||
assert mock_elasticsearch_client.indices.delete.call_count == 2
|
||||
delete_calls = [call.kwargs["index"] for call in mock_elasticsearch_client.indices.delete.call_args_list]
|
||||
assert "test_index" in delete_calls
|
||||
assert "another_index" in delete_calls
|
||||
assert ".system_index" not in delete_calls
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@patch("crewai.rag.elasticsearch.client._is_sync_client", return_value=False)
|
||||
@patch("crewai.rag.elasticsearch.client._is_async_client", return_value=True)
|
||||
async def test_areset(self, mock_is_async, mock_is_sync, async_client, mock_async_elasticsearch_client):
|
||||
"""Test that areset deletes all non-system indices asynchronously."""
|
||||
mock_async_elasticsearch_client.indices.get.return_value = {
|
||||
"test_index": {},
|
||||
".system_index": {},
|
||||
"another_index": {}
|
||||
}
|
||||
|
||||
await async_client.areset()
|
||||
|
||||
mock_async_elasticsearch_client.indices.get.assert_called_once_with(index="*")
|
||||
assert mock_async_elasticsearch_client.indices.delete.call_count == 2
|
||||
49
tests/rag/elasticsearch/test_config.py
Normal file
49
tests/rag/elasticsearch/test_config.py
Normal file
@@ -0,0 +1,49 @@
|
||||
"""Tests for Elasticsearch configuration."""
|
||||
|
||||
from crewai.rag.elasticsearch.config import ElasticsearchConfig
|
||||
|
||||
|
||||
def test_elasticsearch_config_defaults():
|
||||
"""Test that ElasticsearchConfig has correct defaults."""
|
||||
config = ElasticsearchConfig()
|
||||
|
||||
assert config.provider == "elasticsearch"
|
||||
assert config.vector_dimension == 384
|
||||
assert config.similarity == "cosine"
|
||||
assert config.embedding_function is not None
|
||||
assert config.options["hosts"] == ["http://localhost:9200"]
|
||||
assert config.options["use_ssl"] is False
|
||||
|
||||
|
||||
def test_elasticsearch_config_custom_options():
|
||||
"""Test that ElasticsearchConfig accepts custom options."""
|
||||
custom_options = {
|
||||
"hosts": ["https://elastic.example.com:9200"],
|
||||
"username": "user",
|
||||
"password": "pass",
|
||||
"use_ssl": True,
|
||||
}
|
||||
|
||||
config = ElasticsearchConfig(
|
||||
options=custom_options,
|
||||
vector_dimension=768,
|
||||
similarity="dot_product"
|
||||
)
|
||||
|
||||
assert config.provider == "elasticsearch"
|
||||
assert config.vector_dimension == 768
|
||||
assert config.similarity == "dot_product"
|
||||
assert config.options["hosts"] == ["https://elastic.example.com:9200"]
|
||||
assert config.options["username"] == "user"
|
||||
assert config.options["use_ssl"] is True
|
||||
|
||||
|
||||
def test_elasticsearch_config_embedding_function():
|
||||
"""Test that embedding function works correctly."""
|
||||
config = ElasticsearchConfig()
|
||||
|
||||
embedding = config.embedding_function("test text")
|
||||
|
||||
assert isinstance(embedding, list)
|
||||
assert len(embedding) == config.vector_dimension
|
||||
assert all(isinstance(x, float) for x in embedding)
|
||||
40
tests/rag/elasticsearch/test_factory.py
Normal file
40
tests/rag/elasticsearch/test_factory.py
Normal file
@@ -0,0 +1,40 @@
|
||||
"""Tests for Elasticsearch factory."""
|
||||
|
||||
from unittest.mock import Mock, patch
|
||||
|
||||
import pytest
|
||||
|
||||
from crewai.rag.elasticsearch.config import ElasticsearchConfig
|
||||
|
||||
|
||||
def test_create_client():
|
||||
"""Test that create_client creates an ElasticsearchClient."""
|
||||
config = ElasticsearchConfig()
|
||||
|
||||
with patch.dict('sys.modules', {'elasticsearch': Mock()}):
|
||||
mock_elasticsearch_module = Mock()
|
||||
mock_client_instance = Mock()
|
||||
mock_elasticsearch_module.Elasticsearch.return_value = mock_client_instance
|
||||
|
||||
with patch.dict('sys.modules', {'elasticsearch': mock_elasticsearch_module}):
|
||||
from crewai.rag.elasticsearch.factory import create_client
|
||||
client = create_client(config)
|
||||
|
||||
mock_elasticsearch_module.Elasticsearch.assert_called_once_with(**config.options)
|
||||
assert client.client == mock_client_instance
|
||||
assert client.embedding_function == config.embedding_function
|
||||
assert client.vector_dimension == config.vector_dimension
|
||||
assert client.similarity == config.similarity
|
||||
|
||||
|
||||
def test_create_client_missing_elasticsearch():
|
||||
"""Test that create_client raises ImportError when elasticsearch is not installed."""
|
||||
config = ElasticsearchConfig()
|
||||
|
||||
with patch.dict('sys.modules', {}, clear=False):
|
||||
if 'elasticsearch' in __import__('sys').modules:
|
||||
del __import__('sys').modules['elasticsearch']
|
||||
|
||||
from crewai.rag.elasticsearch.factory import create_client
|
||||
with pytest.raises(ImportError, match="elasticsearch package is required"):
|
||||
create_client(config)
|
||||
Reference in New Issue
Block a user