mirror of
https://github.com/crewAIInc/crewAI.git
synced 2026-05-03 00:02:36 +00:00
Compare commits
5 Commits
devin/1756
...
devin/1756
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
b42d2e8cf0 | ||
|
|
36de68ecd4 | ||
|
|
109de91d08 | ||
|
|
92b70e652d | ||
|
|
fc3f2c49d2 |
@@ -1,13 +1,13 @@
|
||||
---
|
||||
title: Weaviate Vector Search
|
||||
description: The `WeaviateVectorSearchTool` is designed to search a Weaviate vector database for semantically similar documents.
|
||||
description: The `WeaviateVectorSearchTool` is designed to search a Weaviate vector database for semantically similar documents using hybrid search.
|
||||
icon: network-wired
|
||||
---
|
||||
|
||||
## Overview
|
||||
|
||||
|
||||
The `WeaviateVectorSearchTool` is specifically crafted for conducting semantic searches within documents stored in a Weaviate vector database. This tool allows you to find semantically similar documents to a given query, leveraging the power of vector embeddings for more accurate and contextually relevant search results.
|
||||
The `WeaviateVectorSearchTool` is specifically crafted for conducting semantic searches within documents stored in a Weaviate vector database. This tool allows you to find semantically similar documents to a given query, leveraging the power of vector and keyword search for more accurate and contextually relevant search results.
|
||||
|
||||
[Weaviate](https://weaviate.io/) is a vector database that stores and queries vector embeddings, enabling semantic search capabilities.
|
||||
|
||||
@@ -39,6 +39,7 @@ from crewai_tools import WeaviateVectorSearchTool
|
||||
tool = WeaviateVectorSearchTool(
|
||||
collection_name='example_collections',
|
||||
limit=3,
|
||||
alpha=0.75,
|
||||
weaviate_cluster_url="https://your-weaviate-cluster-url.com",
|
||||
weaviate_api_key="your-weaviate-api-key",
|
||||
)
|
||||
@@ -63,6 +64,7 @@ The `WeaviateVectorSearchTool` accepts the following parameters:
|
||||
- **weaviate_cluster_url**: Required. The URL of the Weaviate cluster.
|
||||
- **weaviate_api_key**: Required. The API key for the Weaviate cluster.
|
||||
- **limit**: Optional. The number of results to return. Default is `3`.
|
||||
- **alpha**: Optional. Controls the weighting between vector and keyword (BM25) search. alpha = 0 -> BM25 only, alpha = 1 -> vector search only. Default is `0.75`.
|
||||
- **vectorizer**: Optional. The vectorizer to use. If not provided, it will use `text2vec_openai` with the `nomic-embed-text` model.
|
||||
- **generative_model**: Optional. The generative model to use. If not provided, it will use OpenAI's `gpt-4o`.
|
||||
|
||||
@@ -78,6 +80,7 @@ from weaviate.classes.config import Configure
|
||||
tool = WeaviateVectorSearchTool(
|
||||
collection_name='example_collections',
|
||||
limit=3,
|
||||
alpha=0.75,
|
||||
vectorizer=Configure.Vectorizer.text2vec_openai(model="nomic-embed-text"),
|
||||
generative_model=Configure.Generative.openai(model="gpt-4o-mini"),
|
||||
weaviate_cluster_url="https://your-weaviate-cluster-url.com",
|
||||
@@ -128,6 +131,7 @@ with test_docs.batch.dynamic() as batch:
|
||||
tool = WeaviateVectorSearchTool(
|
||||
collection_name='example_collections',
|
||||
limit=3,
|
||||
alpha=0.75,
|
||||
weaviate_cluster_url="https://your-weaviate-cluster-url.com",
|
||||
weaviate_api_key="your-weaviate-api-key",
|
||||
)
|
||||
@@ -145,6 +149,7 @@ from crewai_tools import WeaviateVectorSearchTool
|
||||
weaviate_tool = WeaviateVectorSearchTool(
|
||||
collection_name='example_collections',
|
||||
limit=3,
|
||||
alpha=0.75,
|
||||
weaviate_cluster_url="https://your-weaviate-cluster-url.com",
|
||||
weaviate_api_key="your-weaviate-api-key",
|
||||
)
|
||||
|
||||
@@ -17,9 +17,9 @@ dependencies = [
|
||||
"pdfplumber>=0.11.4",
|
||||
"regex>=2024.9.11",
|
||||
# Telemetry and Monitoring
|
||||
"opentelemetry-api>=1.30.0",
|
||||
"opentelemetry-sdk>=1.30.0",
|
||||
"opentelemetry-exporter-otlp-proto-http>=1.30.0",
|
||||
"opentelemetry-api>=1.27.0,<1.28.0",
|
||||
"opentelemetry-sdk>=1.27.0,<1.28.0",
|
||||
"opentelemetry-exporter-otlp-proto-http>=1.27.0,<1.28.0",
|
||||
# Data Handling
|
||||
"chromadb>=0.5.23",
|
||||
"tokenizers>=0.20.3",
|
||||
|
||||
@@ -98,8 +98,8 @@ class CrewAgentExecutorMixin:
|
||||
)
|
||||
self.crew._long_term_memory.save(long_term_memory)
|
||||
|
||||
for entity in evaluation.entities:
|
||||
entity_memory = EntityMemoryItem(
|
||||
entity_memories = [
|
||||
EntityMemoryItem(
|
||||
name=entity.name,
|
||||
type=entity.type,
|
||||
description=entity.description,
|
||||
@@ -107,7 +107,10 @@ class CrewAgentExecutorMixin:
|
||||
[f"- {r}" for r in entity.relationships]
|
||||
),
|
||||
)
|
||||
self.crew._entity_memory.save(entity_memory)
|
||||
for entity in evaluation.entities
|
||||
]
|
||||
if entity_memories:
|
||||
self.crew._entity_memory.save(entity_memories)
|
||||
except AttributeError as e:
|
||||
print(f"Missing attributes for long term memory: {e}")
|
||||
pass
|
||||
|
||||
@@ -1,6 +1 @@
|
||||
ALGORITHMS = ["RS256"]
|
||||
|
||||
#TODO: The AUTH0 constants should be removed after WorkOS migration is completed
|
||||
AUTH0_DOMAIN = "crewai.us.auth0.com"
|
||||
AUTH0_CLIENT_ID = "DEVC5Fw6NlRoSzmDCcOhVq85EfLBjKa8"
|
||||
AUTH0_AUDIENCE = "https://crewai.us.auth0.com/api/v2/"
|
||||
|
||||
@@ -9,14 +9,7 @@ from pydantic import BaseModel, Field
|
||||
|
||||
from .utils import validate_jwt_token
|
||||
from crewai.cli.shared.token_manager import TokenManager
|
||||
from urllib.parse import quote
|
||||
from crewai.cli.plus_api import PlusAPI
|
||||
from crewai.cli.config import Settings
|
||||
from crewai.cli.authentication.constants import (
|
||||
AUTH0_AUDIENCE,
|
||||
AUTH0_CLIENT_ID,
|
||||
AUTH0_DOMAIN,
|
||||
)
|
||||
|
||||
console = Console()
|
||||
|
||||
@@ -72,18 +65,6 @@ class AuthenticationCommand:
|
||||
"""Sign up to CrewAI+"""
|
||||
console.print("Signing in to CrewAI Enterprise...\n", style="bold blue")
|
||||
|
||||
# TODO: WORKOS - Next line and conditional are temporary until migration to WorkOS is complete.
|
||||
user_provider = self._determine_user_provider()
|
||||
if user_provider == "auth0":
|
||||
settings = Oauth2Settings(
|
||||
provider="auth0",
|
||||
client_id=AUTH0_CLIENT_ID,
|
||||
domain=AUTH0_DOMAIN,
|
||||
audience=AUTH0_AUDIENCE,
|
||||
)
|
||||
self.oauth2_provider = ProviderFactory.from_settings(settings)
|
||||
# End of temporary code.
|
||||
|
||||
device_code_data = self._get_device_code()
|
||||
self._display_auth_instructions(device_code_data)
|
||||
|
||||
@@ -206,30 +187,3 @@ class AuthenticationCommand:
|
||||
"\nRun [bold]crewai login[/bold] to try logging in again.\n",
|
||||
style="yellow",
|
||||
)
|
||||
|
||||
# TODO: WORKOS - This method is temporary until migration to WorkOS is complete.
|
||||
def _determine_user_provider(self) -> str:
|
||||
"""Determine which provider to use for authentication."""
|
||||
|
||||
console.print(
|
||||
"Enter your CrewAI Enterprise account email: ", style="bold blue", end=""
|
||||
)
|
||||
email = input()
|
||||
email_encoded = quote(email)
|
||||
|
||||
# It's not correct to call this method directly, but it's temporary until migration is complete.
|
||||
response = PlusAPI("")._make_request(
|
||||
"GET", f"/crewai_plus/api/v1/me/provider?email={email_encoded}"
|
||||
)
|
||||
|
||||
if response.status_code == 200:
|
||||
if response.json().get("provider") == "auth0":
|
||||
return "auth0"
|
||||
else:
|
||||
return "workos"
|
||||
else:
|
||||
console.print(
|
||||
"Error: Failed to authenticate with crewai enterprise. Ensure that you are using the latest crewai version and please try again. If the problem persists, contact support@crewai.com.",
|
||||
style="red",
|
||||
)
|
||||
raise SystemExit
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
from typing import Optional
|
||||
from typing import Any
|
||||
import time
|
||||
|
||||
from pydantic import PrivateAttr
|
||||
@@ -24,7 +24,7 @@ class EntityMemory(Memory):
|
||||
Inherits from the Memory class.
|
||||
"""
|
||||
|
||||
_memory_provider: Optional[str] = PrivateAttr()
|
||||
_memory_provider: str | None = PrivateAttr()
|
||||
|
||||
def __init__(self, crew=None, embedder_config=None, storage=None, path=None):
|
||||
memory_provider = embedder_config.get("provider") if embedder_config else None
|
||||
@@ -53,12 +53,33 @@ class EntityMemory(Memory):
|
||||
super().__init__(storage=storage)
|
||||
self._memory_provider = memory_provider
|
||||
|
||||
def save(self, item: EntityMemoryItem) -> None: # type: ignore # BUG?: Signature of "save" incompatible with supertype "Memory"
|
||||
"""Saves an entity item into the SQLite storage."""
|
||||
def save(
|
||||
self,
|
||||
value: EntityMemoryItem | list[EntityMemoryItem],
|
||||
metadata: dict[str, Any] | None = None,
|
||||
) -> None:
|
||||
"""Saves one or more entity items into the SQLite storage.
|
||||
|
||||
Args:
|
||||
value: Single EntityMemoryItem or list of EntityMemoryItems to save.
|
||||
metadata: Optional metadata dict (included for supertype compatibility but not used).
|
||||
|
||||
Notes:
|
||||
The metadata parameter is included to satisfy the supertype signature but is not
|
||||
used - entity metadata is extracted from the EntityMemoryItem objects themselves.
|
||||
"""
|
||||
|
||||
if not value:
|
||||
return
|
||||
|
||||
items = value if isinstance(value, list) else [value]
|
||||
is_batch = len(items) > 1
|
||||
|
||||
metadata = {"entity_count": len(items)} if is_batch else items[0].metadata
|
||||
crewai_event_bus.emit(
|
||||
self,
|
||||
event=MemorySaveStartedEvent(
|
||||
metadata=item.metadata,
|
||||
metadata=metadata,
|
||||
source_type="entity_memory",
|
||||
from_agent=self.agent,
|
||||
from_task=self.task,
|
||||
@@ -66,36 +87,61 @@ class EntityMemory(Memory):
|
||||
)
|
||||
|
||||
start_time = time.time()
|
||||
saved_count = 0
|
||||
errors = []
|
||||
|
||||
try:
|
||||
if self._memory_provider == "mem0":
|
||||
data = f"""
|
||||
Remember details about the following entity:
|
||||
Name: {item.name}
|
||||
Type: {item.type}
|
||||
Entity Description: {item.description}
|
||||
"""
|
||||
for item in items:
|
||||
try:
|
||||
if self._memory_provider == "mem0":
|
||||
data = f"""
|
||||
Remember details about the following entity:
|
||||
Name: {item.name}
|
||||
Type: {item.type}
|
||||
Entity Description: {item.description}
|
||||
"""
|
||||
else:
|
||||
data = f"{item.name}({item.type}): {item.description}"
|
||||
|
||||
super().save(data, item.metadata)
|
||||
saved_count += 1
|
||||
except Exception as e:
|
||||
errors.append(f"{item.name}: {str(e)}")
|
||||
|
||||
if is_batch:
|
||||
emit_value = f"Saved {saved_count} entities"
|
||||
metadata = {"entity_count": saved_count, "errors": errors}
|
||||
else:
|
||||
data = f"{item.name}({item.type}): {item.description}"
|
||||
emit_value = f"{items[0].name}({items[0].type}): {items[0].description}"
|
||||
metadata = items[0].metadata
|
||||
|
||||
super().save(data, item.metadata)
|
||||
|
||||
# Emit memory save completed event
|
||||
crewai_event_bus.emit(
|
||||
self,
|
||||
event=MemorySaveCompletedEvent(
|
||||
value=data,
|
||||
metadata=item.metadata,
|
||||
value=emit_value,
|
||||
metadata=metadata,
|
||||
save_time_ms=(time.time() - start_time) * 1000,
|
||||
source_type="entity_memory",
|
||||
from_agent=self.agent,
|
||||
from_task=self.task,
|
||||
),
|
||||
)
|
||||
|
||||
if errors:
|
||||
raise Exception(
|
||||
f"Partial save: {len(errors)} failed out of {len(items)}"
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
fail_metadata = (
|
||||
{"entity_count": len(items), "saved": saved_count}
|
||||
if is_batch
|
||||
else items[0].metadata
|
||||
)
|
||||
crewai_event_bus.emit(
|
||||
self,
|
||||
event=MemorySaveFailedEvent(
|
||||
metadata=item.metadata,
|
||||
metadata=fail_metadata,
|
||||
error=str(e),
|
||||
source_type="entity_memory",
|
||||
from_agent=self.agent,
|
||||
|
||||
@@ -14,7 +14,7 @@ class _MissingProvider:
|
||||
Raises RuntimeError when instantiated to indicate missing dependencies.
|
||||
"""
|
||||
|
||||
provider: Literal["chromadb", "qdrant", "elasticsearch", "__missing__"] = field(
|
||||
provider: Literal["chromadb", "qdrant", "__missing__"] = field(
|
||||
default="__missing__"
|
||||
)
|
||||
|
||||
|
||||
@@ -9,8 +9,6 @@ if TYPE_CHECKING:
|
||||
from crewai.rag.chromadb.config import ChromaDBConfig
|
||||
from crewai.rag.qdrant.client import QdrantClient
|
||||
from crewai.rag.qdrant.config import QdrantConfig
|
||||
from crewai.rag.elasticsearch.client import ElasticsearchClient
|
||||
from crewai.rag.elasticsearch.config import ElasticsearchConfig
|
||||
|
||||
|
||||
class ChromaFactoryModule(Protocol):
|
||||
@@ -27,11 +25,3 @@ class QdrantFactoryModule(Protocol):
|
||||
def create_client(self, config: QdrantConfig) -> QdrantClient:
|
||||
"""Creates a Qdrant client from configuration."""
|
||||
...
|
||||
|
||||
|
||||
class ElasticsearchFactoryModule(Protocol):
|
||||
"""Protocol for Elasticsearch factory module."""
|
||||
|
||||
def create_client(self, config: ElasticsearchConfig) -> ElasticsearchClient:
|
||||
"""Creates an Elasticsearch client from configuration."""
|
||||
...
|
||||
|
||||
@@ -20,10 +20,3 @@ class MissingQdrantConfig(_MissingProvider):
|
||||
"""Placeholder for missing Qdrant configuration."""
|
||||
|
||||
provider: Literal["qdrant"] = field(default="qdrant")
|
||||
|
||||
|
||||
@pyd_dataclass(config=ConfigDict(extra="forbid"))
|
||||
class MissingElasticsearchConfig(_MissingProvider):
|
||||
"""Placeholder for missing Elasticsearch configuration."""
|
||||
|
||||
provider: Literal["elasticsearch"] = field(default="elasticsearch")
|
||||
|
||||
@@ -3,6 +3,6 @@
|
||||
from typing import Annotated, Literal
|
||||
|
||||
SupportedProvider = Annotated[
|
||||
Literal["chromadb", "qdrant", "elasticsearch"],
|
||||
Literal["chromadb", "qdrant"],
|
||||
"Supported RAG provider types, add providers here as they become available",
|
||||
]
|
||||
|
||||
@@ -13,9 +13,6 @@ if TYPE_CHECKING:
|
||||
from crewai.rag.qdrant.config import QdrantConfig as QdrantConfig_
|
||||
|
||||
QdrantConfig = QdrantConfig_
|
||||
from crewai.rag.elasticsearch.config import ElasticsearchConfig as ElasticsearchConfig_
|
||||
|
||||
ElasticsearchConfig = ElasticsearchConfig_
|
||||
else:
|
||||
try:
|
||||
from crewai.rag.chromadb.config import ChromaDBConfig
|
||||
@@ -31,14 +28,7 @@ else:
|
||||
MissingQdrantConfig as QdrantConfig,
|
||||
)
|
||||
|
||||
try:
|
||||
from crewai.rag.elasticsearch.config import ElasticsearchConfig
|
||||
except ImportError:
|
||||
from crewai.rag.config.optional_imports.providers import (
|
||||
MissingElasticsearchConfig as ElasticsearchConfig,
|
||||
)
|
||||
|
||||
SupportedProviderConfig: TypeAlias = ChromaDBConfig | QdrantConfig | ElasticsearchConfig
|
||||
SupportedProviderConfig: TypeAlias = ChromaDBConfig | QdrantConfig
|
||||
RagConfigType: TypeAlias = Annotated[
|
||||
SupportedProviderConfig, Field(discriminator=DISCRIMINATOR)
|
||||
]
|
||||
|
||||
@@ -1 +0,0 @@
|
||||
"""Elasticsearch RAG implementation."""
|
||||
@@ -1,502 +0,0 @@
|
||||
"""Elasticsearch client implementation."""
|
||||
|
||||
from typing import Any, cast
|
||||
|
||||
from typing_extensions import Unpack
|
||||
|
||||
from crewai.rag.core.base_client import (
|
||||
BaseClient,
|
||||
BaseCollectionParams,
|
||||
BaseCollectionAddParams,
|
||||
BaseCollectionSearchParams,
|
||||
)
|
||||
from crewai.rag.core.exceptions import ClientMethodMismatchError
|
||||
from crewai.rag.elasticsearch.types import (
|
||||
AsyncEmbeddingFunction,
|
||||
EmbeddingFunction,
|
||||
ElasticsearchClientType,
|
||||
ElasticsearchCollectionCreateParams,
|
||||
)
|
||||
from crewai.rag.elasticsearch.utils import (
|
||||
_is_async_client,
|
||||
_is_async_embedding_function,
|
||||
_is_sync_client,
|
||||
_prepare_document_for_elasticsearch,
|
||||
_process_search_results,
|
||||
_build_vector_search_query,
|
||||
_get_index_mapping,
|
||||
)
|
||||
from crewai.rag.types import SearchResult
|
||||
|
||||
|
||||
class ElasticsearchClient(BaseClient):
|
||||
"""Elasticsearch implementation of the BaseClient protocol.
|
||||
|
||||
Provides vector database operations for Elasticsearch, supporting both
|
||||
synchronous and asynchronous clients.
|
||||
|
||||
Attributes:
|
||||
client: Elasticsearch client instance (Elasticsearch or AsyncElasticsearch).
|
||||
embedding_function: Function to generate embeddings for documents.
|
||||
vector_dimension: Dimension of the embedding vectors.
|
||||
similarity: Similarity function to use for vector search.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
client: ElasticsearchClientType,
|
||||
embedding_function: EmbeddingFunction | AsyncEmbeddingFunction,
|
||||
vector_dimension: int = 384,
|
||||
similarity: str = "cosine",
|
||||
) -> None:
|
||||
"""Initialize ElasticsearchClient with client and embedding function.
|
||||
|
||||
Args:
|
||||
client: Pre-configured Elasticsearch client instance.
|
||||
embedding_function: Embedding function for text to vector conversion.
|
||||
vector_dimension: Dimension of the embedding vectors.
|
||||
similarity: Similarity function to use for vector search.
|
||||
"""
|
||||
self.client = client
|
||||
self.embedding_function = embedding_function
|
||||
self.vector_dimension = vector_dimension
|
||||
self.similarity = similarity
|
||||
|
||||
def create_collection(self, **kwargs: Unpack[ElasticsearchCollectionCreateParams]) -> None:
|
||||
"""Create a new index in Elasticsearch.
|
||||
|
||||
Keyword Args:
|
||||
collection_name: Name of the index to create. Must be unique.
|
||||
index_settings: Optional index settings.
|
||||
vector_dimension: Optional vector dimension override.
|
||||
similarity: Optional similarity function override.
|
||||
|
||||
Raises:
|
||||
ValueError: If index with the same name already exists.
|
||||
ConnectionError: If unable to connect to Elasticsearch server.
|
||||
"""
|
||||
if not _is_sync_client(self.client):
|
||||
raise ClientMethodMismatchError(
|
||||
method_name="create_collection",
|
||||
expected_client="Elasticsearch",
|
||||
alt_method="acreate_collection",
|
||||
alt_client="AsyncElasticsearch",
|
||||
)
|
||||
|
||||
collection_name = kwargs["collection_name"]
|
||||
|
||||
if self.client.indices.exists(index=collection_name):
|
||||
raise ValueError(f"Index '{collection_name}' already exists")
|
||||
|
||||
vector_dim = kwargs.get("vector_dimension", self.vector_dimension)
|
||||
similarity = kwargs.get("similarity", self.similarity)
|
||||
|
||||
mapping = _get_index_mapping(vector_dim, similarity)
|
||||
|
||||
index_settings = kwargs.get("index_settings", {})
|
||||
if index_settings:
|
||||
mapping["settings"] = index_settings
|
||||
|
||||
self.client.indices.create(index=collection_name, body=mapping)
|
||||
|
||||
async def acreate_collection(self, **kwargs: Unpack[ElasticsearchCollectionCreateParams]) -> None:
|
||||
"""Create a new index in Elasticsearch asynchronously.
|
||||
|
||||
Keyword Args:
|
||||
collection_name: Name of the index to create. Must be unique.
|
||||
index_settings: Optional index settings.
|
||||
vector_dimension: Optional vector dimension override.
|
||||
similarity: Optional similarity function override.
|
||||
|
||||
Raises:
|
||||
ValueError: If index with the same name already exists.
|
||||
ConnectionError: If unable to connect to Elasticsearch server.
|
||||
"""
|
||||
if not _is_async_client(self.client):
|
||||
raise ClientMethodMismatchError(
|
||||
method_name="acreate_collection",
|
||||
expected_client="AsyncElasticsearch",
|
||||
alt_method="create_collection",
|
||||
alt_client="Elasticsearch",
|
||||
)
|
||||
|
||||
collection_name = kwargs["collection_name"]
|
||||
|
||||
if await self.client.indices.exists(index=collection_name):
|
||||
raise ValueError(f"Index '{collection_name}' already exists")
|
||||
|
||||
vector_dim = kwargs.get("vector_dimension", self.vector_dimension)
|
||||
similarity = kwargs.get("similarity", self.similarity)
|
||||
|
||||
mapping = _get_index_mapping(vector_dim, similarity)
|
||||
|
||||
index_settings = kwargs.get("index_settings", {})
|
||||
if index_settings:
|
||||
mapping["settings"] = index_settings
|
||||
|
||||
await self.client.indices.create(index=collection_name, body=mapping)
|
||||
|
||||
def get_or_create_collection(self, **kwargs: Unpack[ElasticsearchCollectionCreateParams]) -> Any:
|
||||
"""Get an existing index or create it if it doesn't exist.
|
||||
|
||||
Keyword Args:
|
||||
collection_name: Name of the index to get or create.
|
||||
index_settings: Optional index settings.
|
||||
vector_dimension: Optional vector dimension override.
|
||||
similarity: Optional similarity function override.
|
||||
|
||||
Returns:
|
||||
Index info dict with name and other metadata.
|
||||
|
||||
Raises:
|
||||
ConnectionError: If unable to connect to Elasticsearch server.
|
||||
"""
|
||||
if not _is_sync_client(self.client):
|
||||
raise ClientMethodMismatchError(
|
||||
method_name="get_or_create_collection",
|
||||
expected_client="Elasticsearch",
|
||||
alt_method="aget_or_create_collection",
|
||||
alt_client="AsyncElasticsearch",
|
||||
)
|
||||
|
||||
collection_name = kwargs["collection_name"]
|
||||
|
||||
if self.client.indices.exists(index=collection_name):
|
||||
return self.client.indices.get(index=collection_name)
|
||||
|
||||
vector_dim = kwargs.get("vector_dimension", self.vector_dimension)
|
||||
similarity = kwargs.get("similarity", self.similarity)
|
||||
|
||||
mapping = _get_index_mapping(vector_dim, similarity)
|
||||
|
||||
index_settings = kwargs.get("index_settings", {})
|
||||
if index_settings:
|
||||
mapping["settings"] = index_settings
|
||||
|
||||
self.client.indices.create(index=collection_name, body=mapping)
|
||||
return self.client.indices.get(index=collection_name)
|
||||
|
||||
async def aget_or_create_collection(self, **kwargs: Unpack[ElasticsearchCollectionCreateParams]) -> Any:
|
||||
"""Get an existing index or create it if it doesn't exist asynchronously.
|
||||
|
||||
Keyword Args:
|
||||
collection_name: Name of the index to get or create.
|
||||
index_settings: Optional index settings.
|
||||
vector_dimension: Optional vector dimension override.
|
||||
similarity: Optional similarity function override.
|
||||
|
||||
Returns:
|
||||
Index info dict with name and other metadata.
|
||||
|
||||
Raises:
|
||||
ConnectionError: If unable to connect to Elasticsearch server.
|
||||
"""
|
||||
if not _is_async_client(self.client):
|
||||
raise ClientMethodMismatchError(
|
||||
method_name="aget_or_create_collection",
|
||||
expected_client="AsyncElasticsearch",
|
||||
alt_method="get_or_create_collection",
|
||||
alt_client="Elasticsearch",
|
||||
)
|
||||
|
||||
collection_name = kwargs["collection_name"]
|
||||
|
||||
if await self.client.indices.exists(index=collection_name):
|
||||
return await self.client.indices.get(index=collection_name)
|
||||
|
||||
vector_dim = kwargs.get("vector_dimension", self.vector_dimension)
|
||||
similarity = kwargs.get("similarity", self.similarity)
|
||||
|
||||
mapping = _get_index_mapping(vector_dim, similarity)
|
||||
|
||||
index_settings = kwargs.get("index_settings", {})
|
||||
if index_settings:
|
||||
mapping["settings"] = index_settings
|
||||
|
||||
await self.client.indices.create(index=collection_name, body=mapping)
|
||||
return await self.client.indices.get(index=collection_name)
|
||||
|
||||
def add_documents(self, **kwargs: Unpack[BaseCollectionAddParams]) -> None:
|
||||
"""Add documents with their embeddings to an index.
|
||||
|
||||
Keyword Args:
|
||||
collection_name: The name of the index to add documents to.
|
||||
documents: List of BaseRecord dicts containing document data.
|
||||
|
||||
Raises:
|
||||
ValueError: If index doesn't exist or documents list is empty.
|
||||
ConnectionError: If unable to connect to Elasticsearch server.
|
||||
"""
|
||||
if not _is_sync_client(self.client):
|
||||
raise ClientMethodMismatchError(
|
||||
method_name="add_documents",
|
||||
expected_client="Elasticsearch",
|
||||
alt_method="aadd_documents",
|
||||
alt_client="AsyncElasticsearch",
|
||||
)
|
||||
|
||||
collection_name = kwargs["collection_name"]
|
||||
documents = kwargs["documents"]
|
||||
|
||||
if not documents:
|
||||
raise ValueError("Documents list cannot be empty")
|
||||
|
||||
if not self.client.indices.exists(index=collection_name):
|
||||
raise ValueError(f"Index '{collection_name}' does not exist")
|
||||
|
||||
for doc in documents:
|
||||
if _is_async_embedding_function(self.embedding_function):
|
||||
raise TypeError(
|
||||
"Async embedding function cannot be used with sync add_documents. "
|
||||
"Use aadd_documents instead."
|
||||
)
|
||||
sync_fn = cast(EmbeddingFunction, self.embedding_function)
|
||||
embedding = sync_fn(doc["content"])
|
||||
prepared_doc = _prepare_document_for_elasticsearch(doc, embedding)
|
||||
|
||||
self.client.index(
|
||||
index=collection_name,
|
||||
id=prepared_doc["id"],
|
||||
body=prepared_doc["body"]
|
||||
)
|
||||
|
||||
async def aadd_documents(self, **kwargs: Unpack[BaseCollectionAddParams]) -> None:
|
||||
"""Add documents with their embeddings to an index asynchronously.
|
||||
|
||||
Keyword Args:
|
||||
collection_name: The name of the index to add documents to.
|
||||
documents: List of BaseRecord dicts containing document data.
|
||||
|
||||
Raises:
|
||||
ValueError: If index doesn't exist or documents list is empty.
|
||||
ConnectionError: If unable to connect to Elasticsearch server.
|
||||
"""
|
||||
if not _is_async_client(self.client):
|
||||
raise ClientMethodMismatchError(
|
||||
method_name="aadd_documents",
|
||||
expected_client="AsyncElasticsearch",
|
||||
alt_method="add_documents",
|
||||
alt_client="Elasticsearch",
|
||||
)
|
||||
|
||||
collection_name = kwargs["collection_name"]
|
||||
documents = kwargs["documents"]
|
||||
|
||||
if not documents:
|
||||
raise ValueError("Documents list cannot be empty")
|
||||
|
||||
if not await self.client.indices.exists(index=collection_name):
|
||||
raise ValueError(f"Index '{collection_name}' does not exist")
|
||||
|
||||
for doc in documents:
|
||||
if _is_async_embedding_function(self.embedding_function):
|
||||
async_fn = cast(AsyncEmbeddingFunction, self.embedding_function)
|
||||
embedding = await async_fn(doc["content"])
|
||||
else:
|
||||
sync_fn = cast(EmbeddingFunction, self.embedding_function)
|
||||
embedding = sync_fn(doc["content"])
|
||||
|
||||
prepared_doc = _prepare_document_for_elasticsearch(doc, embedding)
|
||||
|
||||
await self.client.index(
|
||||
index=collection_name,
|
||||
id=prepared_doc["id"],
|
||||
body=prepared_doc["body"]
|
||||
)
|
||||
|
||||
def search(
|
||||
self, **kwargs: Unpack[BaseCollectionSearchParams]
|
||||
) -> list[SearchResult]:
|
||||
"""Search for similar documents using a query.
|
||||
|
||||
Keyword Args:
|
||||
collection_name: Name of the index to search in.
|
||||
query: The text query to search for.
|
||||
limit: Maximum number of results to return (default: 10).
|
||||
metadata_filter: Optional filter for metadata fields.
|
||||
score_threshold: Optional minimum similarity score (0-1) for results.
|
||||
|
||||
Returns:
|
||||
List of SearchResult dicts containing id, content, metadata, and score.
|
||||
|
||||
Raises:
|
||||
ValueError: If index doesn't exist.
|
||||
ConnectionError: If unable to connect to Elasticsearch server.
|
||||
"""
|
||||
if not _is_sync_client(self.client):
|
||||
raise ClientMethodMismatchError(
|
||||
method_name="search",
|
||||
expected_client="Elasticsearch",
|
||||
alt_method="asearch",
|
||||
alt_client="AsyncElasticsearch",
|
||||
)
|
||||
|
||||
collection_name = kwargs["collection_name"]
|
||||
query = kwargs["query"]
|
||||
limit = kwargs.get("limit", 10)
|
||||
metadata_filter = kwargs.get("metadata_filter")
|
||||
score_threshold = kwargs.get("score_threshold")
|
||||
|
||||
if not self.client.indices.exists(index=collection_name):
|
||||
raise ValueError(f"Index '{collection_name}' does not exist")
|
||||
|
||||
if _is_async_embedding_function(self.embedding_function):
|
||||
raise TypeError(
|
||||
"Async embedding function cannot be used with sync search. "
|
||||
"Use asearch instead."
|
||||
)
|
||||
sync_fn = cast(EmbeddingFunction, self.embedding_function)
|
||||
query_embedding = sync_fn(query)
|
||||
|
||||
search_query = _build_vector_search_query(
|
||||
query_vector=query_embedding,
|
||||
limit=limit,
|
||||
metadata_filter=metadata_filter,
|
||||
score_threshold=score_threshold,
|
||||
)
|
||||
|
||||
response = self.client.search(index=collection_name, body=search_query)
|
||||
return _process_search_results(response, score_threshold)
|
||||
|
||||
async def asearch(
|
||||
self, **kwargs: Unpack[BaseCollectionSearchParams]
|
||||
) -> list[SearchResult]:
|
||||
"""Search for similar documents using a query asynchronously.
|
||||
|
||||
Keyword Args:
|
||||
collection_name: Name of the index to search in.
|
||||
query: The text query to search for.
|
||||
limit: Maximum number of results to return (default: 10).
|
||||
metadata_filter: Optional filter for metadata fields.
|
||||
score_threshold: Optional minimum similarity score (0-1) for results.
|
||||
|
||||
Returns:
|
||||
List of SearchResult dicts containing id, content, metadata, and score.
|
||||
|
||||
Raises:
|
||||
ValueError: If index doesn't exist.
|
||||
ConnectionError: If unable to connect to Elasticsearch server.
|
||||
"""
|
||||
if not _is_async_client(self.client):
|
||||
raise ClientMethodMismatchError(
|
||||
method_name="asearch",
|
||||
expected_client="AsyncElasticsearch",
|
||||
alt_method="search",
|
||||
alt_client="Elasticsearch",
|
||||
)
|
||||
|
||||
collection_name = kwargs["collection_name"]
|
||||
query = kwargs["query"]
|
||||
limit = kwargs.get("limit", 10)
|
||||
metadata_filter = kwargs.get("metadata_filter")
|
||||
score_threshold = kwargs.get("score_threshold")
|
||||
|
||||
if not await self.client.indices.exists(index=collection_name):
|
||||
raise ValueError(f"Index '{collection_name}' does not exist")
|
||||
|
||||
if _is_async_embedding_function(self.embedding_function):
|
||||
async_fn = cast(AsyncEmbeddingFunction, self.embedding_function)
|
||||
query_embedding = await async_fn(query)
|
||||
else:
|
||||
sync_fn = cast(EmbeddingFunction, self.embedding_function)
|
||||
query_embedding = sync_fn(query)
|
||||
|
||||
search_query = _build_vector_search_query(
|
||||
query_vector=query_embedding,
|
||||
limit=limit,
|
||||
metadata_filter=metadata_filter,
|
||||
score_threshold=score_threshold,
|
||||
)
|
||||
|
||||
response = await self.client.search(index=collection_name, body=search_query)
|
||||
return _process_search_results(response, score_threshold)
|
||||
|
||||
def delete_collection(self, **kwargs: Unpack[BaseCollectionParams]) -> None:
|
||||
"""Delete an index and all its data.
|
||||
|
||||
Keyword Args:
|
||||
collection_name: Name of the index to delete.
|
||||
|
||||
Raises:
|
||||
ValueError: If index doesn't exist.
|
||||
ConnectionError: If unable to connect to Elasticsearch server.
|
||||
"""
|
||||
if not _is_sync_client(self.client):
|
||||
raise ClientMethodMismatchError(
|
||||
method_name="delete_collection",
|
||||
expected_client="Elasticsearch",
|
||||
alt_method="adelete_collection",
|
||||
alt_client="AsyncElasticsearch",
|
||||
)
|
||||
|
||||
collection_name = kwargs["collection_name"]
|
||||
|
||||
if not self.client.indices.exists(index=collection_name):
|
||||
raise ValueError(f"Index '{collection_name}' does not exist")
|
||||
|
||||
self.client.indices.delete(index=collection_name)
|
||||
|
||||
async def adelete_collection(self, **kwargs: Unpack[BaseCollectionParams]) -> None:
    """Asynchronously remove an index, and with it every document it holds.

    Keyword Args:
        collection_name: Name of the index to remove.

    Raises:
        ClientMethodMismatchError: If the wrapped client is not an
            asynchronous Elasticsearch client.
        ValueError: If the named index does not exist.
    """
    es = self.client
    if not _is_async_client(es):
        raise ClientMethodMismatchError(
            method_name="adelete_collection",
            expected_client="AsyncElasticsearch",
            alt_method="delete_collection",
            alt_client="Elasticsearch",
        )

    index_name = kwargs["collection_name"]

    # Fail with a descriptive error instead of letting the delete call
    # report a missing index on its own.
    index_present = await es.indices.exists(index=index_name)
    if not index_present:
        raise ValueError(f"Index '{index_name}' does not exist")

    await es.indices.delete(index=index_name)
|
||||
|
||||
def reset(self) -> None:
    """Wipe the vector database by deleting every non-dot-prefixed index.

    Raises:
        ClientMethodMismatchError: If the wrapped client is not a
            synchronous Elasticsearch client.
    """
    es = self.client
    if not _is_sync_client(es):
        raise ClientMethodMismatchError(
            method_name="reset",
            expected_client="Elasticsearch",
            alt_method="areset",
            alt_client="AsyncElasticsearch",
        )

    all_indices = es.indices.get(index="*")

    for name in all_indices.keys():
        # Names starting with "." are left alone (presumably
        # Elasticsearch-internal indices).
        if name.startswith("."):
            continue
        es.indices.delete(index=name)
|
||||
|
||||
async def areset(self) -> None:
    """Asynchronously wipe the database by deleting every non-dot-prefixed index.

    Raises:
        ClientMethodMismatchError: If the wrapped client is not an
            asynchronous Elasticsearch client.
    """
    es = self.client
    if not _is_async_client(es):
        raise ClientMethodMismatchError(
            method_name="areset",
            expected_client="AsyncElasticsearch",
            alt_method="reset",
            alt_client="Elasticsearch",
        )

    all_indices = await es.indices.get(index="*")

    for name in all_indices.keys():
        # Names starting with "." are left alone (presumably
        # Elasticsearch-internal indices).
        if name.startswith("."):
            continue
        await es.indices.delete(index=name)
|
||||
@@ -1,92 +0,0 @@
|
||||
"""Elasticsearch configuration model."""
|
||||
|
||||
from dataclasses import field
|
||||
from typing import Literal, cast
|
||||
from pydantic.dataclasses import dataclass as pyd_dataclass
|
||||
|
||||
from crewai.rag.config.base import BaseRagConfig
|
||||
from crewai.rag.elasticsearch.types import (
|
||||
ElasticsearchClientParams,
|
||||
ElasticsearchEmbeddingFunctionWrapper,
|
||||
)
|
||||
from crewai.rag.elasticsearch.constants import (
|
||||
DEFAULT_HOST,
|
||||
DEFAULT_PORT,
|
||||
DEFAULT_EMBEDDING_MODEL,
|
||||
DEFAULT_VECTOR_DIMENSION,
|
||||
)
|
||||
|
||||
|
||||
def _default_options() -> ElasticsearchClientParams:
    """Build the client options used when a config supplies none.

    Returns:
        Options targeting ``http://DEFAULT_HOST:DEFAULT_PORT`` with SSL and
        certificate verification switched off and a timeout of 30.
    """
    local_url = f"http://{DEFAULT_HOST}:{DEFAULT_PORT}"
    defaults = ElasticsearchClientParams(
        hosts=[local_url],
        use_ssl=False,
        verify_certs=False,
        timeout=30,
    )
    return defaults
|
||||
|
||||
|
||||
def _default_embedding_function() -> ElasticsearchEmbeddingFunctionWrapper:
    """Create the default embedding function for Elasticsearch.

    When ``sentence-transformers`` can be imported, the returned callable
    encodes text with ``SentenceTransformer(DEFAULT_EMBEDDING_MODEL)``.
    Otherwise a deterministic hash-based stand-in is returned; it produces
    vectors of ``DEFAULT_VECTOR_DIMENSION`` finite floats but carries no
    semantic meaning, so a warning is logged when it is selected.

    Returns:
        A callable mapping a text string to a list of floats.
    """
    try:
        from sentence_transformers import SentenceTransformer

        model = SentenceTransformer(DEFAULT_EMBEDDING_MODEL)

        def embed_fn(text: str) -> list[float]:
            """Embed a single text string.

            Args:
                text: Text to embed.

            Returns:
                Embedding vector as list of floats.
            """
            embedding = model.encode(text, convert_to_tensor=False)
            # encode() may hand back an array-like object; normalise to a list.
            return embedding.tolist() if hasattr(embedding, "tolist") else list(embedding)

        return cast(ElasticsearchEmbeddingFunctionWrapper, embed_fn)
    except ImportError:
        import logging

        logging.getLogger(__name__).warning(
            "sentence-transformers is not installed; falling back to a "
            "hash-based embedding that carries no semantic meaning."
        )

        def fallback_embed_fn(text: str) -> list[float]:
            """Fallback embedding function when sentence-transformers is not available."""
            import hashlib
            import struct

            digest = hashlib.md5(text.encode(), usedforsecurity=False).digest()

            # Read the digest as little-endian unsigned 32-bit integers and
            # scale them into [-1.0, 1.0]. Reinterpreting the raw bytes as
            # float32 can yield NaN/Inf and depends on platform byte order.
            words = struct.unpack(f"<{len(digest) // 4}I", digest)
            seed = [word / 0xFFFFFFFF * 2.0 - 1.0 for word in words]

            # Tile the seed values until the target dimension is covered.
            repeats = -(-DEFAULT_VECTOR_DIMENSION // len(seed))
            return (seed * repeats)[:DEFAULT_VECTOR_DIMENSION]

        return cast(ElasticsearchEmbeddingFunctionWrapper, fallback_embed_fn)
|
||||
|
||||
|
||||
@pyd_dataclass(frozen=True)
class ElasticsearchConfig(BaseRagConfig):
    """Frozen settings bundle describing an Elasticsearch RAG backend."""

    # Fixed provider tag; excluded from the generated __init__.
    provider: Literal["elasticsearch"] = field(init=False, default="elasticsearch")
    # Client options; produced by _default_options when not supplied.
    options: ElasticsearchClientParams = field(default_factory=_default_options)
    # Length of the embedding vectors.
    vector_dimension: int = DEFAULT_VECTOR_DIMENSION
    # Name of the vector similarity function.
    similarity: str = "cosine"
    # Text-to-vector callable; produced by _default_embedding_function
    # when not supplied.
    embedding_function: ElasticsearchEmbeddingFunctionWrapper = field(
        default_factory=_default_embedding_function,
    )
|
||||
@@ -1,12 +0,0 @@
|
||||
"""Constants for Elasticsearch RAG implementation."""
|
||||
|
||||
from typing import Final
|
||||
|
||||
# Address of the Elasticsearch node used when no hosts are configured.
DEFAULT_HOST: Final[str] = "localhost"
DEFAULT_PORT: Final[int] = 9200

# Index settings applied by default: one shard, no replicas.
DEFAULT_INDEX_SETTINGS: Final[dict] = dict(
    number_of_shards=1,
    number_of_replicas=0,
)

# Default embedding model name and the vector length used with it.
DEFAULT_EMBEDDING_MODEL: Final[str] = "sentence-transformers/all-MiniLM-L6-v2"
DEFAULT_VECTOR_DIMENSION: Final[int] = 384
|
||||
@@ -1,31 +0,0 @@
|
||||
"""Factory functions for creating Elasticsearch clients."""
|
||||
|
||||
from crewai.rag.elasticsearch.config import ElasticsearchConfig
|
||||
from crewai.rag.elasticsearch.client import ElasticsearchClient
|
||||
|
||||
|
||||
def create_client(config: ElasticsearchConfig) -> ElasticsearchClient:
    """Build an ElasticsearchClient from a configuration object.

    Args:
        config: Elasticsearch configuration; ``config.options`` is expanded
            into the ``Elasticsearch`` constructor.

    Returns:
        An ElasticsearchClient wrapping a newly created ``Elasticsearch``
        client together with the config's embedding function, vector
        dimension and similarity.

    Raises:
        ImportError: If the ``elasticsearch`` package is not installed.
    """
    try:
        from elasticsearch import Elasticsearch
    except ImportError as e:
        message = (
            "elasticsearch package is required for Elasticsearch support. "
            "Install it with: pip install elasticsearch"
        )
        raise ImportError(message) from e

    wrapper_kwargs = {
        "client": Elasticsearch(**config.options),
        "embedding_function": config.embedding_function,
        "vector_dimension": config.vector_dimension,
        "similarity": config.similarity,
    }
    return ElasticsearchClient(**wrapper_kwargs)
|
||||
@@ -1,93 +0,0 @@
|
||||
"""Type definitions for Elasticsearch RAG implementation."""
|
||||
|
||||
from typing import Any, Protocol, Union, TYPE_CHECKING
|
||||
from typing_extensions import NotRequired, TypedDict
|
||||
from pydantic import GetCoreSchemaHandler
|
||||
from pydantic_core import CoreSchema, core_schema
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from typing import TypeAlias
|
||||
from elasticsearch import Elasticsearch, AsyncElasticsearch
|
||||
ElasticsearchClientType: TypeAlias = Union[Elasticsearch, AsyncElasticsearch]
|
||||
else:
|
||||
try:
|
||||
from elasticsearch import Elasticsearch, AsyncElasticsearch
|
||||
ElasticsearchClientType = Union[Elasticsearch, AsyncElasticsearch]
|
||||
except ImportError:
|
||||
ElasticsearchClientType = Any
|
||||
|
||||
|
||||
class ElasticsearchClientParams(TypedDict, total=False):
|
||||
"""Parameters for Elasticsearch client initialization."""
|
||||
|
||||
hosts: NotRequired[list[str]]
|
||||
cloud_id: NotRequired[str]
|
||||
username: NotRequired[str]
|
||||
password: NotRequired[str]
|
||||
api_key: NotRequired[str]
|
||||
use_ssl: NotRequired[bool]
|
||||
verify_certs: NotRequired[bool]
|
||||
ca_certs: NotRequired[str]
|
||||
timeout: NotRequired[int]
|
||||
|
||||
|
||||
class ElasticsearchIndexSettings(TypedDict, total=False):
|
||||
"""Settings for Elasticsearch index creation."""
|
||||
|
||||
number_of_shards: NotRequired[int]
|
||||
number_of_replicas: NotRequired[int]
|
||||
refresh_interval: NotRequired[str]
|
||||
|
||||
|
||||
class ElasticsearchCollectionCreateParams(TypedDict, total=False):
|
||||
"""Parameters for creating Elasticsearch collections/indices."""
|
||||
|
||||
collection_name: str
|
||||
index_settings: NotRequired[ElasticsearchIndexSettings]
|
||||
vector_dimension: NotRequired[int]
|
||||
similarity: NotRequired[str]
|
||||
|
||||
|
||||
class EmbeddingFunction(Protocol):
    """Structural type for synchronous callables that turn text into a vector."""

    def __call__(self, text: str) -> list[float]:
        """Embed one piece of text.

        Args:
            text: The text to embed.

        Returns:
            The embedding as a list of floats.
        """
        ...
|
||||
|
||||
|
||||
class AsyncEmbeddingFunction(Protocol):
    """Structural type for awaitable callables that turn text into a vector."""

    async def __call__(self, text: str) -> list[float]:
        """Embed one piece of text asynchronously.

        Args:
            text: The text to embed.

        Returns:
            The embedding as a list of floats.
        """
        ...
|
||||
|
||||
|
||||
class ElasticsearchEmbeddingFunctionWrapper(EmbeddingFunction):
    """EmbeddingFunction variant that Pydantic models can use as a field type."""

    @classmethod
    def __get_pydantic_core_schema__(
        cls,
        _source_type: Any,
        _handler: GetCoreSchemaHandler,
    ) -> CoreSchema:
        """Tell Pydantic to accept any value for fields of this type.

        Returning an "any" schema lets embedding functions appear in
        validated models without requiring arbitrary_types_allowed=True.
        """
        accept_anything = core_schema.any_schema()
        return accept_anything
|
||||
@@ -1,186 +0,0 @@
|
||||
"""Utility functions for Elasticsearch RAG implementation."""
|
||||
|
||||
import hashlib
|
||||
from typing import Any, TypeGuard
|
||||
|
||||
from crewai.rag.elasticsearch.types import (
|
||||
AsyncEmbeddingFunction,
|
||||
EmbeddingFunction,
|
||||
ElasticsearchClientType,
|
||||
)
|
||||
from crewai.rag.types import BaseRecord, SearchResult
|
||||
|
||||
try:
|
||||
from elasticsearch import Elasticsearch, AsyncElasticsearch
|
||||
except ImportError:
|
||||
Elasticsearch = None
|
||||
AsyncElasticsearch = None
|
||||
|
||||
|
||||
def _is_sync_client(client: ElasticsearchClientType) -> TypeGuard[Any]:
    """Report whether ``client`` is a synchronous ``Elasticsearch`` instance.

    Always False when the module-level ``Elasticsearch`` name is None
    (the value it is given when the package import failed).
    """
    return Elasticsearch is not None and isinstance(client, Elasticsearch)
|
||||
|
||||
|
||||
def _is_async_client(client: ElasticsearchClientType) -> TypeGuard[Any]:
    """Report whether ``client`` is an ``AsyncElasticsearch`` instance.

    Always False when the module-level ``AsyncElasticsearch`` name is None
    (the value it is given when the package import failed).
    """
    return AsyncElasticsearch is not None and isinstance(client, AsyncElasticsearch)
|
||||
|
||||
|
||||
def _is_async_embedding_function(
|
||||
func: EmbeddingFunction | AsyncEmbeddingFunction,
|
||||
) -> TypeGuard[AsyncEmbeddingFunction]:
|
||||
"""Type guard to check if the embedding function is async."""
|
||||
import inspect
|
||||
return inspect.iscoroutinefunction(func)
|
||||
|
||||
|
||||
def _generate_doc_id(content: str) -> str:
|
||||
"""Generate a document ID from content using SHA256 hash."""
|
||||
return hashlib.sha256(content.encode()).hexdigest()
|
||||
|
||||
|
||||
def _prepare_document_for_elasticsearch(
|
||||
doc: BaseRecord, embedding: list[float]
|
||||
) -> dict[str, Any]:
|
||||
"""Prepare a document for Elasticsearch indexing.
|
||||
|
||||
Args:
|
||||
doc: Document record to prepare.
|
||||
embedding: Embedding vector for the document.
|
||||
|
||||
Returns:
|
||||
Document formatted for Elasticsearch.
|
||||
"""
|
||||
doc_id = doc.get("doc_id") or _generate_doc_id(doc["content"])
|
||||
|
||||
es_doc = {
|
||||
"content": doc["content"],
|
||||
"content_vector": embedding,
|
||||
"metadata": doc.get("metadata", {}),
|
||||
}
|
||||
|
||||
return {"id": doc_id, "body": es_doc}
|
||||
|
||||
|
||||
def _process_search_results(
    response: dict[str, Any], score_threshold: float | None = None
) -> list[SearchResult]:
    """Convert a raw Elasticsearch search response into SearchResult items.

    Args:
        response: Raw search response; hits are read from
            ``response["hits"]["hits"]`` when present.
        score_threshold: When given, hits whose score is below this value
            are dropped.

    Returns:
        SearchResult dictionaries in the order the hits were returned.
    """
    raw_hits = response.get("hits", {}).get("hits", [])
    apply_threshold = score_threshold is not None

    collected = []
    for entry in raw_hits:
        relevance = entry.get("_score", 0.0)
        if apply_threshold and relevance < score_threshold:
            continue

        payload = entry.get("_source", {})
        collected.append(
            SearchResult(
                id=entry.get("_id", ""),
                content=payload.get("content", ""),
                metadata=payload.get("metadata", {}),
                score=relevance,
            )
        )

    return collected
|
||||
|
||||
|
||||
def _build_vector_search_query(
|
||||
query_vector: list[float],
|
||||
limit: int = 10,
|
||||
metadata_filter: dict[str, Any] | None = None,
|
||||
score_threshold: float | None = None,
|
||||
) -> dict[str, Any]:
|
||||
"""Build Elasticsearch query for vector similarity search.
|
||||
|
||||
Args:
|
||||
query_vector: Query embedding vector.
|
||||
limit: Maximum number of results.
|
||||
metadata_filter: Optional metadata filter.
|
||||
score_threshold: Optional minimum score threshold.
|
||||
|
||||
Returns:
|
||||
Elasticsearch query dictionary.
|
||||
"""
|
||||
query = {
|
||||
"size": limit,
|
||||
"query": {
|
||||
"script_score": {
|
||||
"query": {"match_all": {}},
|
||||
"script": {
|
||||
"source": "cosineSimilarity(params.query_vector, 'content_vector') + 1.0",
|
||||
"params": {"query_vector": query_vector}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if metadata_filter:
|
||||
bool_query = {
|
||||
"bool": {
|
||||
"must": [
|
||||
query["query"]
|
||||
],
|
||||
"filter": []
|
||||
}
|
||||
}
|
||||
|
||||
for key, value in metadata_filter.items():
|
||||
bool_query["bool"]["filter"].append({
|
||||
"term": {f"metadata.{key}": value}
|
||||
})
|
||||
|
||||
query["query"] = bool_query
|
||||
|
||||
if score_threshold is not None:
|
||||
query["min_score"] = score_threshold
|
||||
|
||||
return query
|
||||
|
||||
|
||||
def _get_index_mapping(vector_dimension: int, similarity: str = "cosine") -> dict[str, Any]:
|
||||
"""Get Elasticsearch index mapping for vector search.
|
||||
|
||||
Args:
|
||||
vector_dimension: Dimension of the embedding vectors.
|
||||
similarity: Similarity function to use.
|
||||
|
||||
Returns:
|
||||
Elasticsearch mapping dictionary.
|
||||
"""
|
||||
return {
|
||||
"mappings": {
|
||||
"properties": {
|
||||
"content": {
|
||||
"type": "text",
|
||||
"analyzer": "standard"
|
||||
},
|
||||
"content_vector": {
|
||||
"type": "dense_vector",
|
||||
"dims": vector_dimension,
|
||||
"similarity": similarity
|
||||
},
|
||||
"metadata": {
|
||||
"type": "object",
|
||||
"dynamic": True
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -5,7 +5,6 @@ from typing import cast
|
||||
from crewai.rag.config.optional_imports.protocols import (
|
||||
ChromaFactoryModule,
|
||||
QdrantFactoryModule,
|
||||
ElasticsearchFactoryModule,
|
||||
)
|
||||
from crewai.rag.core.base_client import BaseClient
|
||||
from crewai.rag.config.types import RagConfigType
|
||||
@@ -44,15 +43,3 @@ def create_client(config: RagConfigType) -> BaseClient:
|
||||
),
|
||||
)
|
||||
return qdrant_mod.create_client(config)
|
||||
|
||||
if config.provider == "elasticsearch":
|
||||
elasticsearch_mod = cast(
|
||||
ElasticsearchFactoryModule,
|
||||
require(
|
||||
"crewai.rag.elasticsearch.factory",
|
||||
purpose="The 'elasticsearch' provider",
|
||||
),
|
||||
)
|
||||
return elasticsearch_mod.create_client(config)
|
||||
|
||||
raise ValueError(f"Unsupported provider: {config.provider}")
|
||||
|
||||
@@ -3,11 +3,6 @@ from datetime import datetime, timedelta
|
||||
import requests
|
||||
from unittest.mock import MagicMock, patch, call
|
||||
from crewai.cli.authentication.main import AuthenticationCommand
|
||||
from crewai.cli.authentication.constants import (
|
||||
AUTH0_AUDIENCE,
|
||||
AUTH0_CLIENT_ID,
|
||||
AUTH0_DOMAIN
|
||||
)
|
||||
from crewai.cli.constants import (
|
||||
CREWAI_ENTERPRISE_DEFAULT_OAUTH2_DOMAIN,
|
||||
CREWAI_ENTERPRISE_DEFAULT_OAUTH2_CLIENT_ID,
|
||||
@@ -22,16 +17,6 @@ class TestAuthenticationCommand:
|
||||
@pytest.mark.parametrize(
|
||||
"user_provider,expected_urls",
|
||||
[
|
||||
(
|
||||
"auth0",
|
||||
{
|
||||
"device_code_url": f"https://{AUTH0_DOMAIN}/oauth/device/code",
|
||||
"token_url": f"https://{AUTH0_DOMAIN}/oauth/token",
|
||||
"client_id": AUTH0_CLIENT_ID,
|
||||
"audience": AUTH0_AUDIENCE,
|
||||
"domain": AUTH0_DOMAIN,
|
||||
},
|
||||
),
|
||||
(
|
||||
"workos",
|
||||
{
|
||||
@@ -44,9 +29,6 @@ class TestAuthenticationCommand:
|
||||
),
|
||||
],
|
||||
)
|
||||
@patch(
|
||||
"crewai.cli.authentication.main.AuthenticationCommand._determine_user_provider"
|
||||
)
|
||||
@patch("crewai.cli.authentication.main.AuthenticationCommand._get_device_code")
|
||||
@patch(
|
||||
"crewai.cli.authentication.main.AuthenticationCommand._display_auth_instructions"
|
||||
@@ -59,11 +41,9 @@ class TestAuthenticationCommand:
|
||||
mock_poll,
|
||||
mock_display,
|
||||
mock_get_device,
|
||||
mock_determine_provider,
|
||||
user_provider,
|
||||
expected_urls,
|
||||
):
|
||||
mock_determine_provider.return_value = user_provider
|
||||
mock_get_device.return_value = {
|
||||
"device_code": "test_code",
|
||||
"user_code": "123456",
|
||||
@@ -74,7 +54,6 @@ class TestAuthenticationCommand:
|
||||
mock_console_print.assert_called_once_with(
|
||||
"Signing in to CrewAI Enterprise...\n", style="bold blue"
|
||||
)
|
||||
mock_determine_provider.assert_called_once()
|
||||
mock_get_device.assert_called_once()
|
||||
mock_display.assert_called_once_with(
|
||||
{"device_code": "test_code", "user_code": "123456"}
|
||||
@@ -82,9 +61,17 @@ class TestAuthenticationCommand:
|
||||
mock_poll.assert_called_once_with(
|
||||
{"device_code": "test_code", "user_code": "123456"},
|
||||
)
|
||||
assert self.auth_command.oauth2_provider.get_client_id() == expected_urls["client_id"]
|
||||
assert self.auth_command.oauth2_provider.get_audience() == expected_urls["audience"]
|
||||
assert self.auth_command.oauth2_provider._get_domain() == expected_urls["domain"]
|
||||
assert (
|
||||
self.auth_command.oauth2_provider.get_client_id()
|
||||
== expected_urls["client_id"]
|
||||
)
|
||||
assert (
|
||||
self.auth_command.oauth2_provider.get_audience()
|
||||
== expected_urls["audience"]
|
||||
)
|
||||
assert (
|
||||
self.auth_command.oauth2_provider._get_domain() == expected_urls["domain"]
|
||||
)
|
||||
|
||||
@patch("crewai.cli.authentication.main.webbrowser")
|
||||
@patch("crewai.cli.authentication.main.console.print")
|
||||
@@ -106,14 +93,6 @@ class TestAuthenticationCommand:
|
||||
@pytest.mark.parametrize(
|
||||
"user_provider,jwt_config",
|
||||
[
|
||||
(
|
||||
"auth0",
|
||||
{
|
||||
"jwks_url": f"https://{AUTH0_DOMAIN}/.well-known/jwks.json",
|
||||
"issuer": f"https://{AUTH0_DOMAIN}/",
|
||||
"audience": AUTH0_AUDIENCE,
|
||||
},
|
||||
),
|
||||
(
|
||||
"workos",
|
||||
{
|
||||
@@ -135,14 +114,18 @@ class TestAuthenticationCommand:
|
||||
jwt_config,
|
||||
has_expiration,
|
||||
):
|
||||
from crewai.cli.authentication.providers.auth0 import Auth0Provider
|
||||
from crewai.cli.authentication.providers.workos import WorkosProvider
|
||||
from crewai.cli.authentication.main import Oauth2Settings
|
||||
|
||||
if user_provider == "auth0":
|
||||
self.auth_command.oauth2_provider = Auth0Provider(settings=Oauth2Settings(provider=user_provider, client_id="test-client-id", domain=AUTH0_DOMAIN, audience=jwt_config["audience"]))
|
||||
elif user_provider == "workos":
|
||||
self.auth_command.oauth2_provider = WorkosProvider(settings=Oauth2Settings(provider=user_provider, client_id="test-client-id", domain=CREWAI_ENTERPRISE_DEFAULT_OAUTH2_DOMAIN, audience=jwt_config["audience"]))
|
||||
if user_provider == "workos":
|
||||
self.auth_command.oauth2_provider = WorkosProvider(
|
||||
settings=Oauth2Settings(
|
||||
provider=user_provider,
|
||||
client_id="test-client-id",
|
||||
domain=CREWAI_ENTERPRISE_DEFAULT_OAUTH2_DOMAIN,
|
||||
audience=jwt_config["audience"],
|
||||
)
|
||||
)
|
||||
|
||||
token_data = {"access_token": "test_access_token", "id_token": "test_id_token"}
|
||||
|
||||
@@ -234,83 +217,6 @@ class TestAuthenticationCommand:
|
||||
]
|
||||
mock_console_print.assert_has_calls(expected_calls)
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"api_response,expected_provider",
|
||||
[
|
||||
({"provider": "auth0"}, "auth0"),
|
||||
({"provider": "workos"}, "workos"),
|
||||
({"provider": "none"}, "workos"), # Default to workos for any other value
|
||||
(
|
||||
{},
|
||||
"workos",
|
||||
), # Default to workos if no provider key is sent in the response
|
||||
],
|
||||
)
|
||||
@patch("crewai.cli.authentication.main.PlusAPI")
|
||||
@patch("crewai.cli.authentication.main.console.print")
|
||||
@patch("builtins.input", return_value="test@example.com")
|
||||
def test_determine_user_provider_success(
|
||||
self,
|
||||
mock_input,
|
||||
mock_console_print,
|
||||
mock_plus_api,
|
||||
api_response,
|
||||
expected_provider,
|
||||
):
|
||||
mock_api_instance = MagicMock()
|
||||
mock_response = MagicMock()
|
||||
mock_response.status_code = 200
|
||||
mock_response.json.return_value = api_response
|
||||
mock_api_instance._make_request.return_value = mock_response
|
||||
mock_plus_api.return_value = mock_api_instance
|
||||
|
||||
result = self.auth_command._determine_user_provider()
|
||||
|
||||
mock_input.assert_called_once()
|
||||
|
||||
mock_plus_api.assert_called_once_with("")
|
||||
mock_api_instance._make_request.assert_called_once_with(
|
||||
"GET", "/crewai_plus/api/v1/me/provider?email=test%40example.com"
|
||||
)
|
||||
|
||||
assert result == expected_provider
|
||||
|
||||
@patch("crewai.cli.authentication.main.PlusAPI")
|
||||
@patch("crewai.cli.authentication.main.console.print")
|
||||
@patch("builtins.input", return_value="test@example.com")
|
||||
def test_determine_user_provider_error(
|
||||
self, mock_input, mock_console_print, mock_plus_api
|
||||
):
|
||||
mock_api_instance = MagicMock()
|
||||
mock_response = MagicMock()
|
||||
mock_response.status_code = 500
|
||||
mock_api_instance._make_request.return_value = mock_response
|
||||
mock_plus_api.return_value = mock_api_instance
|
||||
|
||||
with pytest.raises(SystemExit):
|
||||
self.auth_command._determine_user_provider()
|
||||
|
||||
mock_input.assert_called_once()
|
||||
|
||||
mock_plus_api.assert_called_once_with("")
|
||||
mock_api_instance._make_request.assert_called_once_with(
|
||||
"GET", "/crewai_plus/api/v1/me/provider?email=test%40example.com"
|
||||
)
|
||||
|
||||
mock_console_print.assert_has_calls(
|
||||
[
|
||||
call(
|
||||
"Enter your CrewAI Enterprise account email: ",
|
||||
style="bold blue",
|
||||
end="",
|
||||
),
|
||||
call(
|
||||
"Error: Failed to authenticate with crewai enterprise. Ensure that you are using the latest crewai version and please try again. If the problem persists, contact support@crewai.com.",
|
||||
style="red",
|
||||
),
|
||||
]
|
||||
)
|
||||
|
||||
@patch("requests.post")
|
||||
def test_get_device_code(self, mock_post):
|
||||
mock_response = MagicMock()
|
||||
@@ -323,7 +229,9 @@ class TestAuthenticationCommand:
|
||||
|
||||
self.auth_command.oauth2_provider = MagicMock()
|
||||
self.auth_command.oauth2_provider.get_client_id.return_value = "test_client"
|
||||
self.auth_command.oauth2_provider.get_authorize_url.return_value = "https://example.com/device"
|
||||
self.auth_command.oauth2_provider.get_authorize_url.return_value = (
|
||||
"https://example.com/device"
|
||||
)
|
||||
self.auth_command.oauth2_provider.get_audience.return_value = "test_audience"
|
||||
|
||||
result = self.auth_command._get_device_code()
|
||||
@@ -366,12 +274,12 @@ class TestAuthenticationCommand:
|
||||
) as mock_tool_login,
|
||||
):
|
||||
self.auth_command.oauth2_provider = MagicMock()
|
||||
self.auth_command.oauth2_provider.get_token_url.return_value = "https://example.com/token"
|
||||
self.auth_command.oauth2_provider.get_token_url.return_value = (
|
||||
"https://example.com/token"
|
||||
)
|
||||
self.auth_command.oauth2_provider.get_client_id.return_value = "test_client"
|
||||
|
||||
self.auth_command._poll_for_token(
|
||||
device_code_data
|
||||
)
|
||||
self.auth_command._poll_for_token(device_code_data)
|
||||
|
||||
mock_post.assert_called_once_with(
|
||||
"https://example.com/token",
|
||||
@@ -406,9 +314,7 @@ class TestAuthenticationCommand:
|
||||
"interval": 0.1, # Short interval for testing
|
||||
}
|
||||
|
||||
self.auth_command._poll_for_token(
|
||||
device_code_data
|
||||
)
|
||||
self.auth_command._poll_for_token(device_code_data)
|
||||
|
||||
mock_console_print.assert_any_call(
|
||||
"Timeout: Failed to get the token. Please try again.", style="bold red"
|
||||
@@ -429,15 +335,4 @@ class TestAuthenticationCommand:
|
||||
device_code_data = {"device_code": "test_device_code", "interval": 1}
|
||||
|
||||
with pytest.raises(requests.HTTPError):
|
||||
self.auth_command._poll_for_token(
|
||||
device_code_data
|
||||
)
|
||||
# @patch(
|
||||
# "crewai.cli.authentication.main.AuthenticationCommand._determine_user_provider"
|
||||
# )
|
||||
# def test_login_with_auth0(self, mock_determine_provider):
|
||||
# from crewai.cli.authentication.providers.auth0 import Auth0Provider
|
||||
# from crewai.cli.authentication.main import Oauth2Settings
|
||||
|
||||
# self.auth_command.oauth2_provider = Auth0Provider(settings=Oauth2Settings(provider="auth0", client_id=AUTH0_CLIENT_ID, domain=AUTH0_DOMAIN, audience=AUTH0_AUDIENCE))
|
||||
# self.auth_command.login()
|
||||
self.auth_command._poll_for_token(device_code_data)
|
||||
|
||||
@@ -2,8 +2,6 @@
|
||||
|
||||
from unittest.mock import Mock, patch
|
||||
|
||||
import pytest
|
||||
|
||||
from crewai.rag.factory import create_client
|
||||
|
||||
|
||||
@@ -27,50 +25,10 @@ def test_create_client_chromadb():
|
||||
mock_module.create_client.assert_called_once_with(mock_config)
|
||||
|
||||
|
||||
def test_create_client_qdrant():
|
||||
"""Test Qdrant client creation."""
|
||||
mock_config = Mock()
|
||||
mock_config.provider = "qdrant"
|
||||
|
||||
with patch("crewai.rag.factory.require") as mock_require:
|
||||
mock_module = Mock()
|
||||
mock_client = Mock()
|
||||
mock_module.create_client.return_value = mock_client
|
||||
mock_require.return_value = mock_module
|
||||
|
||||
result = create_client(mock_config)
|
||||
|
||||
assert result == mock_client
|
||||
mock_require.assert_called_once_with(
|
||||
"crewai.rag.qdrant.factory", purpose="The 'qdrant' provider"
|
||||
)
|
||||
mock_module.create_client.assert_called_once_with(mock_config)
|
||||
|
||||
|
||||
def test_create_client_elasticsearch():
|
||||
"""Test Elasticsearch client creation."""
|
||||
mock_config = Mock()
|
||||
mock_config.provider = "elasticsearch"
|
||||
|
||||
with patch("crewai.rag.factory.require") as mock_require:
|
||||
mock_module = Mock()
|
||||
mock_client = Mock()
|
||||
mock_module.create_client.return_value = mock_client
|
||||
mock_require.return_value = mock_module
|
||||
|
||||
result = create_client(mock_config)
|
||||
|
||||
assert result == mock_client
|
||||
mock_require.assert_called_once_with(
|
||||
"crewai.rag.elasticsearch.factory", purpose="The 'elasticsearch' provider"
|
||||
)
|
||||
mock_module.create_client.assert_called_once_with(mock_config)
|
||||
|
||||
|
||||
def test_create_client_unsupported_provider():
|
||||
"""Test that unsupported provider raises ValueError."""
|
||||
"""Test unsupported provider returns None for now."""
|
||||
mock_config = Mock()
|
||||
mock_config.provider = "unsupported"
|
||||
|
||||
with pytest.raises(ValueError, match="Unsupported provider: unsupported"):
|
||||
create_client(mock_config)
|
||||
result = create_client(mock_config)
|
||||
assert result is None
|
||||
|
||||
@@ -3,10 +3,7 @@
|
||||
import pytest
|
||||
|
||||
from crewai.rag.config.optional_imports.base import _MissingProvider
|
||||
from crewai.rag.config.optional_imports.providers import (
|
||||
MissingChromaDBConfig,
|
||||
MissingElasticsearchConfig,
|
||||
)
|
||||
from crewai.rag.config.optional_imports.providers import MissingChromaDBConfig
|
||||
|
||||
|
||||
def test_missing_provider_raises_runtime_error():
|
||||
@@ -23,11 +20,3 @@ def test_missing_chromadb_config_raises_runtime_error():
|
||||
RuntimeError, match="provider 'chromadb' requested but not installed"
|
||||
):
|
||||
MissingChromaDBConfig()
|
||||
|
||||
|
||||
def test_missing_elasticsearch_config_raises_runtime_error():
|
||||
"""Test that MissingElasticsearchConfig raises RuntimeError on instantiation."""
|
||||
with pytest.raises(
|
||||
RuntimeError, match="provider 'elasticsearch' requested but not installed"
|
||||
):
|
||||
MissingElasticsearchConfig()
|
||||
|
||||
@@ -1 +0,0 @@
|
||||
"""Tests for Elasticsearch RAG implementation."""
|
||||
@@ -1,397 +0,0 @@
|
||||
"""Tests for ElasticsearchClient implementation."""
|
||||
|
||||
from unittest.mock import AsyncMock, Mock, patch
|
||||
|
||||
import pytest
|
||||
|
||||
from crewai.rag.elasticsearch.client import ElasticsearchClient
|
||||
from crewai.rag.types import BaseRecord
|
||||
from crewai.rag.core.exceptions import ClientMethodMismatchError
|
||||
|
||||
|
||||
@pytest.fixture
def mock_elasticsearch_client():
    """Provide a Mock standing in for a synchronous Elasticsearch client.

    Index operations report success, no index exists by default, and
    ``search`` returns a single canned hit.
    """
    canned_hit = {
        "_id": "doc1",
        "_score": 0.9,
        "_source": {
            "content": "test content",
            "metadata": {"key": "value"},
        },
    }

    es = Mock()
    es.indices = Mock()
    es.indices.exists.return_value = False
    es.indices.create.return_value = {"acknowledged": True}
    es.indices.get.return_value = {"test_index": {"mappings": {}}}
    es.indices.delete.return_value = {"acknowledged": True}
    es.index.return_value = {"_id": "test_id", "result": "created"}
    es.search.return_value = {"hits": {"hits": [canned_hit]}}
    return es
|
||||
|
||||
|
||||
@pytest.fixture
def mock_async_elasticsearch_client():
    """Provide a Mock standing in for an asynchronous Elasticsearch client.

    Every client operation is an AsyncMock: index operations report
    success, no index exists by default, and ``search`` resolves to a
    single canned hit.
    """
    canned_hit = {
        "_id": "doc1",
        "_score": 0.9,
        "_source": {
            "content": "test content",
            "metadata": {"key": "value"},
        },
    }

    es = Mock()
    es.indices = Mock()
    es.indices.exists = AsyncMock(return_value=False)
    es.indices.create = AsyncMock(return_value={"acknowledged": True})
    es.indices.get = AsyncMock(return_value={"test_index": {"mappings": {}}})
    es.indices.delete = AsyncMock(return_value={"acknowledged": True})
    es.index = AsyncMock(return_value={"_id": "test_id", "result": "created"})
    es.search = AsyncMock(return_value={"hits": {"hits": [canned_hit]}})
    return es
|
||||
|
||||
|
||||
@pytest.fixture
def client(mock_elasticsearch_client) -> ElasticsearchClient:
    """Provide an ElasticsearchClient wired to the mocked sync client.

    The embedding function is a Mock that always returns a fixed
    three-element vector, matching ``vector_dimension=3``.
    """
    fake_embedder = Mock()
    fake_embedder.return_value = [0.1, 0.2, 0.3]

    return ElasticsearchClient(
        client=mock_elasticsearch_client,
        embedding_function=fake_embedder,
        vector_dimension=3,
        similarity="cosine",
    )
|
||||
|
||||
|
||||
@pytest.fixture
def async_client(mock_async_elasticsearch_client) -> ElasticsearchClient:
    """Build an ``ElasticsearchClient`` wired to the asynchronous mock backend.

    The embedding function is a ``Mock`` that always returns a fixed
    three-element vector, matching ``vector_dimension=3``.
    """
    fake_embedder = Mock(return_value=[0.1, 0.2, 0.3])

    return ElasticsearchClient(
        client=mock_async_elasticsearch_client,
        embedding_function=fake_embedder,
        vector_dimension=3,
        similarity="cosine",
    )
|
||||
|
||||
|
||||
class TestElasticsearchClient:
    """Unit tests for ``ElasticsearchClient`` against mocked backends.

    Each test patches ``_is_sync_client`` / ``_is_async_client`` in
    ``crewai.rag.elasticsearch.client`` so the mocked backend is treated as
    the sync or async flavour required by the method under test.
    """

    @patch("crewai.rag.elasticsearch.client._is_sync_client", return_value=True)
    @patch("crewai.rag.elasticsearch.client._is_async_client", return_value=False)
    def test_create_collection(
        self, mock_is_async, mock_is_sync, client, mock_elasticsearch_client
    ):
        """create_collection checks for the index and then creates it."""
        indices = mock_elasticsearch_client.indices
        indices.exists.return_value = False

        client.create_collection(collection_name="test_index")

        indices.exists.assert_called_once_with(index="test_index")
        indices.create.assert_called_once()
        create_kwargs = indices.create.call_args.kwargs
        assert create_kwargs["index"] == "test_index"
        assert "mappings" in create_kwargs["body"]

    @patch("crewai.rag.elasticsearch.client._is_sync_client", return_value=True)
    @patch("crewai.rag.elasticsearch.client._is_async_client", return_value=False)
    def test_create_collection_already_exists(
        self, mock_is_async, mock_is_sync, client, mock_elasticsearch_client
    ):
        """create_collection raises ValueError when the index already exists."""
        mock_elasticsearch_client.indices.exists.return_value = True

        with pytest.raises(ValueError, match="Index 'test_index' already exists"):
            client.create_collection(collection_name="test_index")

    @patch("crewai.rag.elasticsearch.client._is_sync_client", return_value=False)
    @patch("crewai.rag.elasticsearch.client._is_async_client", return_value=True)
    def test_create_collection_wrong_client_type(
        self, mock_is_async, mock_is_sync, mock_async_elasticsearch_client
    ):
        """The sync create_collection rejects a backend flagged as async."""
        es_client = ElasticsearchClient(
            client=mock_async_elasticsearch_client,
            embedding_function=Mock(),
        )

        with pytest.raises(ClientMethodMismatchError):
            es_client.create_collection(collection_name="test_index")

    @pytest.mark.asyncio
    @patch("crewai.rag.elasticsearch.client._is_sync_client", return_value=False)
    @patch("crewai.rag.elasticsearch.client._is_async_client", return_value=True)
    async def test_acreate_collection(
        self, mock_is_async, mock_is_sync, async_client, mock_async_elasticsearch_client
    ):
        """acreate_collection checks for the index and then creates it."""
        indices = mock_async_elasticsearch_client.indices
        indices.exists.return_value = False

        await async_client.acreate_collection(collection_name="test_index")

        indices.exists.assert_called_once_with(index="test_index")
        indices.create.assert_called_once()
        create_kwargs = indices.create.call_args.kwargs
        assert create_kwargs["index"] == "test_index"
        assert "mappings" in create_kwargs["body"]

    @pytest.mark.asyncio
    @patch("crewai.rag.elasticsearch.client._is_sync_client", return_value=False)
    @patch("crewai.rag.elasticsearch.client._is_async_client", return_value=True)
    async def test_acreate_collection_already_exists(
        self, mock_is_async, mock_is_sync, async_client, mock_async_elasticsearch_client
    ):
        """acreate_collection raises ValueError when the index already exists."""
        mock_async_elasticsearch_client.indices.exists.return_value = True

        with pytest.raises(ValueError, match="Index 'test_index' already exists"):
            await async_client.acreate_collection(collection_name="test_index")

    @patch("crewai.rag.elasticsearch.client._is_sync_client", return_value=True)
    @patch("crewai.rag.elasticsearch.client._is_async_client", return_value=False)
    def test_get_or_create_collection(
        self, mock_is_async, mock_is_sync, client, mock_elasticsearch_client
    ):
        """get_or_create_collection returns the existing index description."""
        indices = mock_elasticsearch_client.indices
        indices.exists.return_value = True

        result = client.get_or_create_collection(collection_name="test_index")

        indices.exists.assert_called_once_with(index="test_index")
        indices.get.assert_called_once_with(index="test_index")
        assert result == {"test_index": {"mappings": {}}}

    @patch("crewai.rag.elasticsearch.client._is_sync_client", return_value=True)
    @patch("crewai.rag.elasticsearch.client._is_async_client", return_value=False)
    def test_get_or_create_collection_creates_new(
        self, mock_is_async, mock_is_sync, client, mock_elasticsearch_client
    ):
        """get_or_create_collection creates the index when it is missing."""
        indices = mock_elasticsearch_client.indices
        indices.exists.return_value = False

        client.get_or_create_collection(collection_name="test_index")

        indices.exists.assert_called_once_with(index="test_index")
        indices.create.assert_called_once()
        indices.get.assert_called_once_with(index="test_index")

    @pytest.mark.asyncio
    @patch("crewai.rag.elasticsearch.client._is_sync_client", return_value=False)
    @patch("crewai.rag.elasticsearch.client._is_async_client", return_value=True)
    async def test_aget_or_create_collection(
        self, mock_is_async, mock_is_sync, async_client, mock_async_elasticsearch_client
    ):
        """aget_or_create_collection returns the existing index description."""
        indices = mock_async_elasticsearch_client.indices
        indices.exists.return_value = True

        result = await async_client.aget_or_create_collection(
            collection_name="test_index"
        )

        indices.exists.assert_called_once_with(index="test_index")
        indices.get.assert_called_once_with(index="test_index")
        assert result == {"test_index": {"mappings": {}}}

    @patch("crewai.rag.elasticsearch.client._is_sync_client", return_value=True)
    @patch("crewai.rag.elasticsearch.client._is_async_client", return_value=False)
    def test_add_documents(
        self, mock_is_async, mock_is_sync, client, mock_elasticsearch_client
    ):
        """add_documents indexes each record into the target index."""
        mock_elasticsearch_client.indices.exists.return_value = True
        documents: list[BaseRecord] = [
            {"content": "test content", "metadata": {"key": "value"}}
        ]

        client.add_documents(collection_name="test_index", documents=documents)

        mock_elasticsearch_client.indices.exists.assert_called_once_with(
            index="test_index"
        )
        mock_elasticsearch_client.index.assert_called_once()
        index_kwargs = mock_elasticsearch_client.index.call_args.kwargs
        assert index_kwargs["index"] == "test_index"
        assert "body" in index_kwargs
        assert index_kwargs["body"]["content"] == "test content"

    @patch("crewai.rag.elasticsearch.client._is_sync_client", return_value=True)
    @patch("crewai.rag.elasticsearch.client._is_async_client", return_value=False)
    def test_add_documents_empty_list_raises_error(
        self, mock_is_async, mock_is_sync, client
    ):
        """add_documents rejects an empty documents list with ValueError."""
        with pytest.raises(ValueError, match="Documents list cannot be empty"):
            client.add_documents(collection_name="test_index", documents=[])

    @patch("crewai.rag.elasticsearch.client._is_sync_client", return_value=True)
    @patch("crewai.rag.elasticsearch.client._is_async_client", return_value=False)
    def test_add_documents_index_not_exists(
        self, mock_is_async, mock_is_sync, client, mock_elasticsearch_client
    ):
        """add_documents raises ValueError when the index is missing."""
        mock_elasticsearch_client.indices.exists.return_value = False
        documents: list[BaseRecord] = [{"content": "test content"}]

        with pytest.raises(ValueError, match="Index 'test_index' does not exist"):
            client.add_documents(collection_name="test_index", documents=documents)

    @pytest.mark.asyncio
    @patch("crewai.rag.elasticsearch.client._is_sync_client", return_value=False)
    @patch("crewai.rag.elasticsearch.client._is_async_client", return_value=True)
    async def test_aadd_documents(
        self, mock_is_async, mock_is_sync, async_client, mock_async_elasticsearch_client
    ):
        """aadd_documents indexes each record into the target index."""
        mock_async_elasticsearch_client.indices.exists.return_value = True
        documents: list[BaseRecord] = [
            {"content": "test content", "metadata": {"key": "value"}}
        ]

        await async_client.aadd_documents(
            collection_name="test_index", documents=documents
        )

        mock_async_elasticsearch_client.indices.exists.assert_called_once_with(
            index="test_index"
        )
        mock_async_elasticsearch_client.index.assert_called_once()
        index_kwargs = mock_async_elasticsearch_client.index.call_args.kwargs
        assert index_kwargs["index"] == "test_index"
        assert "body" in index_kwargs

    @patch("crewai.rag.elasticsearch.client._is_sync_client", return_value=True)
    @patch("crewai.rag.elasticsearch.client._is_async_client", return_value=False)
    def test_search(
        self, mock_is_async, mock_is_sync, client, mock_elasticsearch_client
    ):
        """search queries the index and maps the hit into a result record."""
        mock_elasticsearch_client.indices.exists.return_value = True

        results = client.search(
            collection_name="test_index",
            query="test query",
            limit=5,
        )

        mock_elasticsearch_client.indices.exists.assert_called_once_with(
            index="test_index"
        )
        mock_elasticsearch_client.search.assert_called_once()
        search_kwargs = mock_elasticsearch_client.search.call_args.kwargs
        assert search_kwargs["index"] == "test_index"
        assert "body" in search_kwargs

        assert len(results) == 1
        first_hit = results[0]
        assert first_hit["id"] == "doc1"
        assert first_hit["content"] == "test content"
        assert first_hit["score"] == 0.9

    @patch("crewai.rag.elasticsearch.client._is_sync_client", return_value=True)
    @patch("crewai.rag.elasticsearch.client._is_async_client", return_value=False)
    def test_search_with_metadata_filter(
        self, mock_is_async, mock_is_sync, client, mock_elasticsearch_client
    ):
        """A metadata filter produces a ``bool`` clause in the query body."""
        mock_elasticsearch_client.indices.exists.return_value = True

        client.search(
            collection_name="test_index",
            query="test query",
            metadata_filter={"key": "value"},
        )

        mock_elasticsearch_client.search.assert_called_once()
        query_body = mock_elasticsearch_client.search.call_args.kwargs["body"]
        assert "bool" in query_body["query"]

    @patch("crewai.rag.elasticsearch.client._is_sync_client", return_value=True)
    @patch("crewai.rag.elasticsearch.client._is_async_client", return_value=False)
    def test_search_index_not_exists(
        self, mock_is_async, mock_is_sync, client, mock_elasticsearch_client
    ):
        """search raises ValueError when the index is missing."""
        mock_elasticsearch_client.indices.exists.return_value = False

        with pytest.raises(ValueError, match="Index 'test_index' does not exist"):
            client.search(collection_name="test_index", query="test query")

    @pytest.mark.asyncio
    @patch("crewai.rag.elasticsearch.client._is_sync_client", return_value=False)
    @patch("crewai.rag.elasticsearch.client._is_async_client", return_value=True)
    async def test_asearch(
        self, mock_is_async, mock_is_sync, async_client, mock_async_elasticsearch_client
    ):
        """asearch queries the index and maps the hit into a result record."""
        mock_async_elasticsearch_client.indices.exists.return_value = True

        results = await async_client.asearch(
            collection_name="test_index",
            query="test query",
            limit=5,
        )

        mock_async_elasticsearch_client.indices.exists.assert_called_once_with(
            index="test_index"
        )
        mock_async_elasticsearch_client.search.assert_called_once()

        assert len(results) == 1
        first_hit = results[0]
        assert first_hit["id"] == "doc1"
        assert first_hit["content"] == "test content"

    @patch("crewai.rag.elasticsearch.client._is_sync_client", return_value=True)
    @patch("crewai.rag.elasticsearch.client._is_async_client", return_value=False)
    def test_delete_collection(
        self, mock_is_async, mock_is_sync, client, mock_elasticsearch_client
    ):
        """delete_collection removes an existing index."""
        indices = mock_elasticsearch_client.indices
        indices.exists.return_value = True

        client.delete_collection(collection_name="test_index")

        indices.exists.assert_called_once_with(index="test_index")
        indices.delete.assert_called_once_with(index="test_index")

    @patch("crewai.rag.elasticsearch.client._is_sync_client", return_value=True)
    @patch("crewai.rag.elasticsearch.client._is_async_client", return_value=False)
    def test_delete_collection_not_exists(
        self, mock_is_async, mock_is_sync, client, mock_elasticsearch_client
    ):
        """delete_collection raises ValueError when the index is missing."""
        mock_elasticsearch_client.indices.exists.return_value = False

        with pytest.raises(ValueError, match="Index 'test_index' does not exist"):
            client.delete_collection(collection_name="test_index")

    @pytest.mark.asyncio
    @patch("crewai.rag.elasticsearch.client._is_sync_client", return_value=False)
    @patch("crewai.rag.elasticsearch.client._is_async_client", return_value=True)
    async def test_adelete_collection(
        self, mock_is_async, mock_is_sync, async_client, mock_async_elasticsearch_client
    ):
        """adelete_collection removes an existing index."""
        indices = mock_async_elasticsearch_client.indices
        indices.exists.return_value = True

        await async_client.adelete_collection(collection_name="test_index")

        indices.exists.assert_called_once_with(index="test_index")
        indices.delete.assert_called_once_with(index="test_index")

    @patch("crewai.rag.elasticsearch.client._is_sync_client", return_value=True)
    @patch("crewai.rag.elasticsearch.client._is_async_client", return_value=False)
    def test_reset(
        self, mock_is_async, mock_is_sync, client, mock_elasticsearch_client
    ):
        """reset deletes every index except dot-prefixed system indices."""
        indices = mock_elasticsearch_client.indices
        indices.get.return_value = {
            "test_index": {},
            ".system_index": {},
            "another_index": {},
        }

        client.reset()

        indices.get.assert_called_once_with(index="*")
        assert indices.delete.call_count == 2
        deleted_names = [
            recorded.kwargs["index"] for recorded in indices.delete.call_args_list
        ]
        assert "test_index" in deleted_names
        assert "another_index" in deleted_names
        assert ".system_index" not in deleted_names

    @pytest.mark.asyncio
    @patch("crewai.rag.elasticsearch.client._is_sync_client", return_value=False)
    @patch("crewai.rag.elasticsearch.client._is_async_client", return_value=True)
    async def test_areset(
        self, mock_is_async, mock_is_sync, async_client, mock_async_elasticsearch_client
    ):
        """areset issues two deletes when one of three indices is a system index."""
        indices = mock_async_elasticsearch_client.indices
        indices.get.return_value = {
            "test_index": {},
            ".system_index": {},
            "another_index": {},
        }

        await async_client.areset()

        indices.get.assert_called_once_with(index="*")
        assert indices.delete.call_count == 2
|
||||
@@ -1,49 +0,0 @@
|
||||
"""Tests for Elasticsearch configuration."""
|
||||
|
||||
from crewai.rag.elasticsearch.config import ElasticsearchConfig
|
||||
|
||||
|
||||
def test_elasticsearch_config_defaults():
    """A default-constructed ElasticsearchConfig exposes the expected defaults."""
    cfg = ElasticsearchConfig()

    # Scalar settings.
    assert cfg.provider == "elasticsearch"
    assert cfg.vector_dimension == 384
    assert cfg.similarity == "cosine"
    assert cfg.embedding_function is not None

    # Connection options.
    options = cfg.options
    assert options["hosts"] == ["http://localhost:9200"]
    assert options["use_ssl"] is False
|
||||
|
||||
|
||||
def test_elasticsearch_config_custom_options():
    """Explicit options, dimension and similarity are kept by the config."""
    cfg = ElasticsearchConfig(
        options={
            "hosts": ["https://elastic.example.com:9200"],
            "username": "user",
            "password": "pass",
            "use_ssl": True,
        },
        vector_dimension=768,
        similarity="dot_product",
    )

    assert cfg.provider == "elasticsearch"
    assert cfg.vector_dimension == 768
    assert cfg.similarity == "dot_product"

    options = cfg.options
    assert options["hosts"] == ["https://elastic.example.com:9200"]
    assert options["username"] == "user"
    assert options["use_ssl"] is True
|
||||
|
||||
|
||||
def test_elasticsearch_config_embedding_function():
    """The default embedding function returns a float list of the configured size."""
    cfg = ElasticsearchConfig()

    vector = cfg.embedding_function("test text")

    assert isinstance(vector, list)
    assert len(vector) == cfg.vector_dimension
    assert all(isinstance(component, float) for component in vector)
|
||||
@@ -1,40 +0,0 @@
|
||||
"""Tests for Elasticsearch factory."""
|
||||
|
||||
from unittest.mock import Mock, patch
|
||||
|
||||
import pytest
|
||||
|
||||
from crewai.rag.elasticsearch.config import ElasticsearchConfig
|
||||
|
||||
|
||||
def test_create_client():
    """Test that create_client builds an ElasticsearchClient from the config.

    A stand-in ``elasticsearch`` module is placed in ``sys.modules`` for the
    duration of the call so no real package or server is needed. The previous
    version nested two ``patch.dict`` contexts for the same key; the outer
    one was overridden before anything ran, so a single patch is sufficient.
    """
    config = ElasticsearchConfig()

    fake_es_module = Mock()
    fake_backend = Mock()
    fake_es_module.Elasticsearch.return_value = fake_backend

    with patch.dict("sys.modules", {"elasticsearch": fake_es_module}):
        # Imported under the patch, as in the original test.
        from crewai.rag.elasticsearch.factory import create_client

        client = create_client(config)

    # The backend is constructed from the config options exactly once...
    fake_es_module.Elasticsearch.assert_called_once_with(**config.options)
    # ...and the wrapper carries over the remaining config fields.
    assert client.client == fake_backend
    assert client.embedding_function == config.embedding_function
    assert client.vector_dimension == config.vector_dimension
    assert client.similarity == config.similarity
|
||||
|
||||
|
||||
def test_create_client_missing_elasticsearch():
    """Test that create_client raises ImportError when elasticsearch is unavailable.

    Removing the ``sys.modules`` entry (as this test used to do) does not stop
    Python from importing the real package again if it is installed, so the
    test only passed on machines without ``elasticsearch``. Mapping the name
    to ``None`` in ``sys.modules`` makes any ``import elasticsearch`` raise
    ``ImportError`` regardless of what is installed, and ``patch.dict``
    restores the original entry afterwards.
    """
    config = ElasticsearchConfig()

    with patch.dict("sys.modules", {"elasticsearch": None}):
        from crewai.rag.elasticsearch.factory import create_client

        # NOTE(review): assumes create_client imports elasticsearch lazily
        # inside the function body, as the original test also implied.
        with pytest.raises(ImportError, match="elasticsearch package is required"):
            create_client(config)
|
||||
@@ -624,12 +624,12 @@ def test_crew_with_delegating_agents_should_not_override_task_tools(ceo, writer)
|
||||
_, kwargs = mock_execute_sync.call_args
|
||||
tools = kwargs["tools"]
|
||||
|
||||
assert any(isinstance(tool, TestTool) for tool in tools), (
|
||||
"TestTool should be present"
|
||||
)
|
||||
assert any("delegate" in tool.name.lower() for tool in tools), (
|
||||
"Delegation tool should be present"
|
||||
)
|
||||
assert any(
|
||||
isinstance(tool, TestTool) for tool in tools
|
||||
), "TestTool should be present"
|
||||
assert any(
|
||||
"delegate" in tool.name.lower() for tool in tools
|
||||
), "Delegation tool should be present"
|
||||
|
||||
|
||||
@pytest.mark.vcr(filter_headers=["authorization"])
|
||||
@@ -688,12 +688,12 @@ def test_crew_with_delegating_agents_should_not_override_agent_tools(ceo, writer
|
||||
_, kwargs = mock_execute_sync.call_args
|
||||
tools = kwargs["tools"]
|
||||
|
||||
assert any(isinstance(tool, TestTool) for tool in new_ceo.tools), (
|
||||
"TestTool should be present"
|
||||
)
|
||||
assert any("delegate" in tool.name.lower() for tool in tools), (
|
||||
"Delegation tool should be present"
|
||||
)
|
||||
assert any(
|
||||
isinstance(tool, TestTool) for tool in new_ceo.tools
|
||||
), "TestTool should be present"
|
||||
assert any(
|
||||
"delegate" in tool.name.lower() for tool in tools
|
||||
), "Delegation tool should be present"
|
||||
|
||||
|
||||
@pytest.mark.vcr(filter_headers=["authorization"])
|
||||
@@ -817,17 +817,17 @@ def test_task_tools_override_agent_tools_with_allow_delegation(researcher, write
|
||||
used_tools = kwargs["tools"]
|
||||
|
||||
# Confirm AnotherTestTool is present but TestTool is not
|
||||
assert any(isinstance(tool, AnotherTestTool) for tool in used_tools), (
|
||||
"AnotherTestTool should be present"
|
||||
)
|
||||
assert not any(isinstance(tool, TestTool) for tool in used_tools), (
|
||||
"TestTool should not be present among used tools"
|
||||
)
|
||||
assert any(
|
||||
isinstance(tool, AnotherTestTool) for tool in used_tools
|
||||
), "AnotherTestTool should be present"
|
||||
assert not any(
|
||||
isinstance(tool, TestTool) for tool in used_tools
|
||||
), "TestTool should not be present among used tools"
|
||||
|
||||
# Confirm delegation tool(s) are present
|
||||
assert any("delegate" in tool.name.lower() for tool in used_tools), (
|
||||
"Delegation tool should be present"
|
||||
)
|
||||
assert any(
|
||||
"delegate" in tool.name.lower() for tool in used_tools
|
||||
), "Delegation tool should be present"
|
||||
|
||||
# Finally, make sure the agent's original tools remain unchanged
|
||||
assert len(researcher_with_delegation.tools) == 1
|
||||
@@ -931,9 +931,9 @@ def test_cache_hitting_between_agents(researcher, writer, ceo):
|
||||
tool="multiplier", input={"first_number": 2, "second_number": 6}
|
||||
)
|
||||
assert cache_calls[0] == expected_call, f"First call mismatch: {cache_calls[0]}"
|
||||
assert cache_calls[1] == expected_call, (
|
||||
f"Second call mismatch: {cache_calls[1]}"
|
||||
)
|
||||
assert (
|
||||
cache_calls[1] == expected_call
|
||||
), f"Second call mismatch: {cache_calls[1]}"
|
||||
|
||||
|
||||
@pytest.mark.vcr(filter_headers=["authorization"])
|
||||
@@ -1676,9 +1676,9 @@ def test_code_execution_flag_adds_code_tool_upon_kickoff():
|
||||
|
||||
# Verify that exactly one tool was used and it was a CodeInterpreterTool
|
||||
assert len(used_tools) == 1, "Should have exactly one tool"
|
||||
assert isinstance(used_tools[0], CodeInterpreterTool), (
|
||||
"Tool should be CodeInterpreterTool"
|
||||
)
|
||||
assert isinstance(
|
||||
used_tools[0], CodeInterpreterTool
|
||||
), "Tool should be CodeInterpreterTool"
|
||||
|
||||
|
||||
@pytest.mark.vcr(filter_headers=["authorization"])
|
||||
@@ -2537,8 +2537,8 @@ def test_memory_events_are_emitted():
|
||||
|
||||
crew.kickoff()
|
||||
|
||||
assert len(events["MemorySaveStartedEvent"]) == 6
|
||||
assert len(events["MemorySaveCompletedEvent"]) == 6
|
||||
assert len(events["MemorySaveStartedEvent"]) == 3
|
||||
assert len(events["MemorySaveCompletedEvent"]) == 3
|
||||
assert len(events["MemorySaveFailedEvent"]) == 0
|
||||
assert len(events["MemoryQueryStartedEvent"]) == 3
|
||||
assert len(events["MemoryQueryCompletedEvent"]) == 3
|
||||
@@ -3817,9 +3817,9 @@ def test_fetch_inputs():
|
||||
expected_placeholders = {"role_detail", "topic", "field"}
|
||||
actual_placeholders = crew.fetch_inputs()
|
||||
|
||||
assert actual_placeholders == expected_placeholders, (
|
||||
f"Expected {expected_placeholders}, but got {actual_placeholders}"
|
||||
)
|
||||
assert (
|
||||
actual_placeholders == expected_placeholders
|
||||
), f"Expected {expected_placeholders}, but got {actual_placeholders}"
|
||||
|
||||
|
||||
def test_task_tools_preserve_code_execution_tools():
|
||||
@@ -3894,20 +3894,20 @@ def test_task_tools_preserve_code_execution_tools():
|
||||
used_tools = kwargs["tools"]
|
||||
|
||||
# Verify all expected tools are present
|
||||
assert any(isinstance(tool, TestTool) for tool in used_tools), (
|
||||
"Task's TestTool should be present"
|
||||
)
|
||||
assert any(isinstance(tool, CodeInterpreterTool) for tool in used_tools), (
|
||||
"CodeInterpreterTool should be present"
|
||||
)
|
||||
assert any("delegate" in tool.name.lower() for tool in used_tools), (
|
||||
"Delegation tool should be present"
|
||||
)
|
||||
assert any(
|
||||
isinstance(tool, TestTool) for tool in used_tools
|
||||
), "Task's TestTool should be present"
|
||||
assert any(
|
||||
isinstance(tool, CodeInterpreterTool) for tool in used_tools
|
||||
), "CodeInterpreterTool should be present"
|
||||
assert any(
|
||||
"delegate" in tool.name.lower() for tool in used_tools
|
||||
), "Delegation tool should be present"
|
||||
|
||||
# Verify the total number of tools (TestTool + CodeInterpreter + 2 delegation tools)
|
||||
assert len(used_tools) == 4, (
|
||||
"Should have TestTool, CodeInterpreter, and 2 delegation tools"
|
||||
)
|
||||
assert (
|
||||
len(used_tools) == 4
|
||||
), "Should have TestTool, CodeInterpreter, and 2 delegation tools"
|
||||
|
||||
|
||||
@pytest.mark.vcr(filter_headers=["authorization"])
|
||||
@@ -3951,9 +3951,9 @@ def test_multimodal_flag_adds_multimodal_tools():
|
||||
used_tools = kwargs["tools"]
|
||||
|
||||
# Check that the multimodal tool was added
|
||||
assert any(isinstance(tool, AddImageTool) for tool in used_tools), (
|
||||
"AddImageTool should be present when agent is multimodal"
|
||||
)
|
||||
assert any(
|
||||
isinstance(tool, AddImageTool) for tool in used_tools
|
||||
), "AddImageTool should be present when agent is multimodal"
|
||||
|
||||
# Verify we have exactly one tool (just the AddImageTool)
|
||||
assert len(used_tools) == 1, "Should only have the AddImageTool"
|
||||
@@ -4217,9 +4217,9 @@ def test_crew_guardrail_feedback_in_context():
|
||||
assert len(execution_contexts) > 1, "Task should have been executed multiple times"
|
||||
|
||||
# Verify that the second execution included the guardrail feedback
|
||||
assert "Output must contain the keyword 'IMPORTANT'" in execution_contexts[1], (
|
||||
"Guardrail feedback should be included in retry context"
|
||||
)
|
||||
assert (
|
||||
"Output must contain the keyword 'IMPORTANT'" in execution_contexts[1]
|
||||
), "Guardrail feedback should be included in retry context"
|
||||
|
||||
# Verify final output meets guardrail requirements
|
||||
assert "IMPORTANT" in result.raw, "Final output should contain required keyword"
|
||||
@@ -4435,46 +4435,46 @@ def test_crew_copy_with_memory():
|
||||
try:
|
||||
crew_copy = crew.copy()
|
||||
|
||||
assert hasattr(crew_copy, "_short_term_memory"), (
|
||||
"Copied crew should have _short_term_memory"
|
||||
)
|
||||
assert crew_copy._short_term_memory is not None, (
|
||||
"Copied _short_term_memory should not be None"
|
||||
)
|
||||
assert id(crew_copy._short_term_memory) != original_short_term_id, (
|
||||
"Copied _short_term_memory should be a new object"
|
||||
)
|
||||
assert hasattr(
|
||||
crew_copy, "_short_term_memory"
|
||||
), "Copied crew should have _short_term_memory"
|
||||
assert (
|
||||
crew_copy._short_term_memory is not None
|
||||
), "Copied _short_term_memory should not be None"
|
||||
assert (
|
||||
id(crew_copy._short_term_memory) != original_short_term_id
|
||||
), "Copied _short_term_memory should be a new object"
|
||||
|
||||
assert hasattr(crew_copy, "_long_term_memory"), (
|
||||
"Copied crew should have _long_term_memory"
|
||||
)
|
||||
assert crew_copy._long_term_memory is not None, (
|
||||
"Copied _long_term_memory should not be None"
|
||||
)
|
||||
assert id(crew_copy._long_term_memory) != original_long_term_id, (
|
||||
"Copied _long_term_memory should be a new object"
|
||||
)
|
||||
assert hasattr(
|
||||
crew_copy, "_long_term_memory"
|
||||
), "Copied crew should have _long_term_memory"
|
||||
assert (
|
||||
crew_copy._long_term_memory is not None
|
||||
), "Copied _long_term_memory should not be None"
|
||||
assert (
|
||||
id(crew_copy._long_term_memory) != original_long_term_id
|
||||
), "Copied _long_term_memory should be a new object"
|
||||
|
||||
assert hasattr(crew_copy, "_entity_memory"), (
|
||||
"Copied crew should have _entity_memory"
|
||||
)
|
||||
assert crew_copy._entity_memory is not None, (
|
||||
"Copied _entity_memory should not be None"
|
||||
)
|
||||
assert id(crew_copy._entity_memory) != original_entity_id, (
|
||||
"Copied _entity_memory should be a new object"
|
||||
)
|
||||
assert hasattr(
|
||||
crew_copy, "_entity_memory"
|
||||
), "Copied crew should have _entity_memory"
|
||||
assert (
|
||||
crew_copy._entity_memory is not None
|
||||
), "Copied _entity_memory should not be None"
|
||||
assert (
|
||||
id(crew_copy._entity_memory) != original_entity_id
|
||||
), "Copied _entity_memory should be a new object"
|
||||
|
||||
if original_external_id:
|
||||
assert hasattr(crew_copy, "_external_memory"), (
|
||||
"Copied crew should have _external_memory"
|
||||
)
|
||||
assert crew_copy._external_memory is not None, (
|
||||
"Copied _external_memory should not be None"
|
||||
)
|
||||
assert id(crew_copy._external_memory) != original_external_id, (
|
||||
"Copied _external_memory should be a new object"
|
||||
)
|
||||
assert hasattr(
|
||||
crew_copy, "_external_memory"
|
||||
), "Copied crew should have _external_memory"
|
||||
assert (
|
||||
crew_copy._external_memory is not None
|
||||
), "Copied _external_memory should not be None"
|
||||
assert (
|
||||
id(crew_copy._external_memory) != original_external_id
|
||||
), "Copied _external_memory should be a new object"
|
||||
else:
|
||||
assert (
|
||||
not hasattr(crew_copy, "_external_memory")
|
||||
|
||||
91
tests/test_dependency_compatibility.py
Normal file
91
tests/test_dependency_compatibility.py
Normal file
@@ -0,0 +1,91 @@
|
||||
"""Tests for dependency compatibility, specifically for issue #3413.
|
||||
|
||||
This module tests that CrewAI can be installed alongside Google Cloud SDKs
|
||||
without protobuf dependency conflicts.
|
||||
"""
|
||||
|
||||
import subprocess
|
||||
import sys
|
||||
import tempfile
|
||||
import pytest
|
||||
|
||||
|
||||
class TestDependencyCompatibility:
    """Test dependency compatibility with Google Cloud SDKs.

    The checks cover three things: the opentelemetry SDK can be imported and
    used, the installed protobuf falls inside the ``>=3.19,<5.0`` window, and
    pip can resolve the opentelemetry pins together with
    ``google-cloud-storage`` in a dry run.
    """

    # Version window applied to every opentelemetry package in the dry run.
    _OTEL_VERSION_SPEC = ">=1.27.0,<1.30.0"

    # Upper bound (seconds) for the pip dry run, which needs network access.
    _PIP_TIMEOUT_SECONDS = 300

    @staticmethod
    def _parse_major_minor(version):
        """Return ``(major, minor)`` parsed from a dotted version string.

        Only the first two components are converted to ``int``; later
        components may carry pre-release suffixes (e.g. ``"3.20.0rc1"``)
        that are not plain integers. A missing minor component is treated
        as ``0``.

        Raises:
            ValueError: If the major or minor component is not an integer.
        """
        parts = version.split(".")
        major = int(parts[0])
        minor = int(parts[1]) if len(parts) > 1 else 0
        return major, minor

    def test_opentelemetry_protobuf_compatibility(self):
        """Test that opentelemetry versions work with protobuf<5.0."""
        # Guard only the import: a failure while using the tracer should
        # surface with its own traceback, not as an import failure.
        try:
            from opentelemetry.sdk.trace import TracerProvider
        except ImportError as e:
            pytest.fail(f"Failed to import opentelemetry modules: {e}")

        tracer = TracerProvider().get_tracer("test")

        with tracer.start_as_current_span("test_span") as span:
            span.set_attribute("test", "value")
            assert span is not None

    def test_google_cloud_sdk_compatibility_simulation(self):
        """Simulate Google Cloud SDK protobuf requirements."""
        try:
            import google.protobuf
        except ImportError:
            pytest.skip("protobuf not installed, skipping compatibility test")

        version = google.protobuf.__version__
        major, _ = self._parse_major_minor(version)

        assert major < 5, (
            f"protobuf version {version} should be <5.0 for Google Cloud SDK compatibility"
        )

    def test_crewai_telemetry_functionality(self):
        """Test that CrewAI telemetry functionality works with downgraded opentelemetry."""
        try:
            from crewai.telemetry.telemetry import Telemetry
            from crewai.utilities.crew.crew_context import get_crew_context
        except ImportError as e:
            pytest.fail(f"Failed to import CrewAI telemetry modules: {e}")

        telemetry = Telemetry()
        assert telemetry is not None

        # Only checks that the call completes; the return value is not inspected.
        get_crew_context()

    def test_dry_run_installation_compatibility(self):
        """Test that CrewAI and Google Cloud SDKs can be installed together."""
        spec = self._OTEL_VERSION_SPEC
        requirements = [
            f"opentelemetry-api{spec}",
            f"opentelemetry-sdk{spec}",
            f"opentelemetry-exporter-otlp-proto-http{spec}",
            "google-cloud-storage",
        ]
        # No ``--quiet`` here: it raises pip's log level to WARNING, which
        # hides the INFO-level resolution output the protobuf assertion
        # below inspects.
        command = [
            sys.executable, "-m", "pip", "install", "--dry-run",
            *requirements,
        ]

        with tempfile.TemporaryDirectory() as temp_dir:
            # Only launching/timeout problems are converted to a failure
            # here; assertions stay outside the ``try`` so pytest reports
            # them with full detail.
            try:
                result = subprocess.run(
                    command,
                    capture_output=True,
                    text=True,
                    cwd=temp_dir,
                    timeout=self._PIP_TIMEOUT_SECONDS,
                )
            except (OSError, subprocess.SubprocessError) as e:
                pytest.fail(f"Dry run installation test failed: {e}")

        assert result.returncode == 0, f"Dry run installation failed: {result.stderr}"

        assert "protobuf" in result.stdout.lower(), "protobuf should be in installation plan"

    def test_protobuf_version_constraint_resolution(self):
        """Test that protobuf version constraints are properly resolved."""
        try:
            import google.protobuf
        except ImportError:
            pytest.skip("protobuf not installed, skipping version constraint test")

        version = google.protobuf.__version__
        major, minor = self._parse_major_minor(version)

        assert major >= 3, f"protobuf version {version} should be >=3.19"
        if major == 3:
            assert minor >= 19, f"protobuf version {version} should be >=3.19"
        assert major < 5, (
            f"protobuf version {version} should be <5.0 for Google Cloud SDK compatibility"
        )
|
||||
Reference in New Issue
Block a user