Compare commits

..

5 Commits

Author SHA1 Message Date
Devin AI
b42d2e8cf0 fix: resolve lint issues in dependency compatibility tests
- Remove unused imports: pathlib.Path, opentelemetry modules
- Remove unused variable assignment for context
- All ruff checks now pass locally

Co-Authored-By: João <joao@crewai.com>
2025-08-27 19:31:32 +00:00
Devin AI
36de68ecd4 fix: resolve opentelemetry protobuf dependency conflict with Google Cloud SDKs
- Downgrade opentelemetry requirements from >=1.30.0 to >=1.27.0,<1.28.0
- This resolves protobuf version conflict where opentelemetry 1.30.0+ requires protobuf>=5.0
  but Google Cloud SDKs require protobuf<5.0
- Now uses protobuf 4.25.8 which satisfies both requirements
- Add comprehensive dependency compatibility tests
- Fixes issue #3413

Co-Authored-By: João <joao@crewai.com>
2025-08-27 19:28:05 +00:00
Greyson LaLonde
109de91d08 fix: batch entity memory items to reduce redundant operations (#3409)
Some checks failed
Notify Downstream / notify-downstream (push) Has been cancelled
* fix: batch save entity memory items to reduce redundant operations

* test: update memory event count after entity batch save implementation
2025-08-27 10:47:20 -04:00
Erika Shorten
92b70e652d Add hybrid search alpha parameter to the docs (#3397)
Co-authored-by: Tony Kipkemboi <iamtonykipkemboi@gmail.com>
2025-08-27 10:36:39 -04:00
Heitor Carvalho
fc3f2c49d2 chore: remove auth0 and the need of typing the email on 'crewai login' (#3408)
* Remove the need of typing the email on 'crewai login'

* Remove auth0 constants, update tests
2025-08-27 10:12:57 -04:00
29 changed files with 3578 additions and 5446 deletions

View File

@@ -1,13 +1,13 @@
---
title: Weaviate Vector Search
description: The `WeaviateVectorSearchTool` is designed to search a Weaviate vector database for semantically similar documents.
description: The `WeaviateVectorSearchTool` is designed to search a Weaviate vector database for semantically similar documents using hybrid search.
icon: network-wired
---
## Overview
The `WeaviateVectorSearchTool` is specifically crafted for conducting semantic searches within documents stored in a Weaviate vector database. This tool allows you to find semantically similar documents to a given query, leveraging the power of vector embeddings for more accurate and contextually relevant search results.
The `WeaviateVectorSearchTool` is specifically crafted for conducting semantic searches within documents stored in a Weaviate vector database. This tool allows you to find semantically similar documents to a given query, leveraging the power of vector and keyword search for more accurate and contextually relevant search results.
[Weaviate](https://weaviate.io/) is a vector database that stores and queries vector embeddings, enabling semantic search capabilities.
@@ -39,6 +39,7 @@ from crewai_tools import WeaviateVectorSearchTool
tool = WeaviateVectorSearchTool(
collection_name='example_collections',
limit=3,
alpha=0.75,
weaviate_cluster_url="https://your-weaviate-cluster-url.com",
weaviate_api_key="your-weaviate-api-key",
)
@@ -63,6 +64,7 @@ The `WeaviateVectorSearchTool` accepts the following parameters:
- **weaviate_cluster_url**: Required. The URL of the Weaviate cluster.
- **weaviate_api_key**: Required. The API key for the Weaviate cluster.
- **limit**: Optional. The number of results to return. Default is `3`.
- **alpha**: Optional. Controls the weighting between vector and keyword (BM25) search. alpha = 0 -> BM25 only, alpha = 1 -> vector search only. Default is `0.75`.
- **vectorizer**: Optional. The vectorizer to use. If not provided, it will use `text2vec_openai` with the `nomic-embed-text` model.
- **generative_model**: Optional. The generative model to use. If not provided, it will use OpenAI's `gpt-4o`.
@@ -78,6 +80,7 @@ from weaviate.classes.config import Configure
tool = WeaviateVectorSearchTool(
collection_name='example_collections',
limit=3,
alpha=0.75,
vectorizer=Configure.Vectorizer.text2vec_openai(model="nomic-embed-text"),
generative_model=Configure.Generative.openai(model="gpt-4o-mini"),
weaviate_cluster_url="https://your-weaviate-cluster-url.com",
@@ -128,6 +131,7 @@ with test_docs.batch.dynamic() as batch:
tool = WeaviateVectorSearchTool(
collection_name='example_collections',
limit=3,
alpha=0.75,
weaviate_cluster_url="https://your-weaviate-cluster-url.com",
weaviate_api_key="your-weaviate-api-key",
)
@@ -145,6 +149,7 @@ from crewai_tools import WeaviateVectorSearchTool
weaviate_tool = WeaviateVectorSearchTool(
collection_name='example_collections',
limit=3,
alpha=0.75,
weaviate_cluster_url="https://your-weaviate-cluster-url.com",
weaviate_api_key="your-weaviate-api-key",
)

View File

@@ -17,9 +17,9 @@ dependencies = [
"pdfplumber>=0.11.4",
"regex>=2024.9.11",
# Telemetry and Monitoring
"opentelemetry-api>=1.30.0",
"opentelemetry-sdk>=1.30.0",
"opentelemetry-exporter-otlp-proto-http>=1.30.0",
"opentelemetry-api>=1.27.0,<1.28.0",
"opentelemetry-sdk>=1.27.0,<1.28.0",
"opentelemetry-exporter-otlp-proto-http>=1.27.0,<1.28.0",
# Data Handling
"chromadb>=0.5.23",
"tokenizers>=0.20.3",

View File

@@ -98,8 +98,8 @@ class CrewAgentExecutorMixin:
)
self.crew._long_term_memory.save(long_term_memory)
for entity in evaluation.entities:
entity_memory = EntityMemoryItem(
entity_memories = [
EntityMemoryItem(
name=entity.name,
type=entity.type,
description=entity.description,
@@ -107,7 +107,10 @@ class CrewAgentExecutorMixin:
[f"- {r}" for r in entity.relationships]
),
)
self.crew._entity_memory.save(entity_memory)
for entity in evaluation.entities
]
if entity_memories:
self.crew._entity_memory.save(entity_memories)
except AttributeError as e:
print(f"Missing attributes for long term memory: {e}")
pass

View File

@@ -1,6 +1 @@
ALGORITHMS = ["RS256"]
#TODO: The AUTH0 constants should be removed after WorkOS migration is completed
AUTH0_DOMAIN = "crewai.us.auth0.com"
AUTH0_CLIENT_ID = "DEVC5Fw6NlRoSzmDCcOhVq85EfLBjKa8"
AUTH0_AUDIENCE = "https://crewai.us.auth0.com/api/v2/"

View File

@@ -9,14 +9,7 @@ from pydantic import BaseModel, Field
from .utils import validate_jwt_token
from crewai.cli.shared.token_manager import TokenManager
from urllib.parse import quote
from crewai.cli.plus_api import PlusAPI
from crewai.cli.config import Settings
from crewai.cli.authentication.constants import (
AUTH0_AUDIENCE,
AUTH0_CLIENT_ID,
AUTH0_DOMAIN,
)
console = Console()
@@ -72,18 +65,6 @@ class AuthenticationCommand:
"""Sign up to CrewAI+"""
console.print("Signing in to CrewAI Enterprise...\n", style="bold blue")
# TODO: WORKOS - Next line and conditional are temporary until migration to WorkOS is complete.
user_provider = self._determine_user_provider()
if user_provider == "auth0":
settings = Oauth2Settings(
provider="auth0",
client_id=AUTH0_CLIENT_ID,
domain=AUTH0_DOMAIN,
audience=AUTH0_AUDIENCE,
)
self.oauth2_provider = ProviderFactory.from_settings(settings)
# End of temporary code.
device_code_data = self._get_device_code()
self._display_auth_instructions(device_code_data)
@@ -206,30 +187,3 @@ class AuthenticationCommand:
"\nRun [bold]crewai login[/bold] to try logging in again.\n",
style="yellow",
)
# TODO: WORKOS - This method is temporary until migration to WorkOS is complete.
def _determine_user_provider(self) -> str:
"""Determine which provider to use for authentication."""
console.print(
"Enter your CrewAI Enterprise account email: ", style="bold blue", end=""
)
email = input()
email_encoded = quote(email)
# It's not correct to call this method directly, but it's temporary until migration is complete.
response = PlusAPI("")._make_request(
"GET", f"/crewai_plus/api/v1/me/provider?email={email_encoded}"
)
if response.status_code == 200:
if response.json().get("provider") == "auth0":
return "auth0"
else:
return "workos"
else:
console.print(
"Error: Failed to authenticate with crewai enterprise. Ensure that you are using the latest crewai version and please try again. If the problem persists, contact support@crewai.com.",
style="red",
)
raise SystemExit

View File

@@ -1,4 +1,4 @@
from typing import Optional
from typing import Any
import time
from pydantic import PrivateAttr
@@ -24,7 +24,7 @@ class EntityMemory(Memory):
Inherits from the Memory class.
"""
_memory_provider: Optional[str] = PrivateAttr()
_memory_provider: str | None = PrivateAttr()
def __init__(self, crew=None, embedder_config=None, storage=None, path=None):
memory_provider = embedder_config.get("provider") if embedder_config else None
@@ -53,12 +53,33 @@ class EntityMemory(Memory):
super().__init__(storage=storage)
self._memory_provider = memory_provider
def save(self, item: EntityMemoryItem) -> None: # type: ignore # BUG?: Signature of "save" incompatible with supertype "Memory"
"""Saves an entity item into the SQLite storage."""
def save(
self,
value: EntityMemoryItem | list[EntityMemoryItem],
metadata: dict[str, Any] | None = None,
) -> None:
"""Saves one or more entity items into the SQLite storage.
Args:
value: Single EntityMemoryItem or list of EntityMemoryItems to save.
metadata: Optional metadata dict (included for supertype compatibility but not used).
Notes:
The metadata parameter is included to satisfy the supertype signature but is not
used - entity metadata is extracted from the EntityMemoryItem objects themselves.
"""
if not value:
return
items = value if isinstance(value, list) else [value]
is_batch = len(items) > 1
metadata = {"entity_count": len(items)} if is_batch else items[0].metadata
crewai_event_bus.emit(
self,
event=MemorySaveStartedEvent(
metadata=item.metadata,
metadata=metadata,
source_type="entity_memory",
from_agent=self.agent,
from_task=self.task,
@@ -66,36 +87,61 @@ class EntityMemory(Memory):
)
start_time = time.time()
saved_count = 0
errors = []
try:
if self._memory_provider == "mem0":
data = f"""
Remember details about the following entity:
Name: {item.name}
Type: {item.type}
Entity Description: {item.description}
"""
for item in items:
try:
if self._memory_provider == "mem0":
data = f"""
Remember details about the following entity:
Name: {item.name}
Type: {item.type}
Entity Description: {item.description}
"""
else:
data = f"{item.name}({item.type}): {item.description}"
super().save(data, item.metadata)
saved_count += 1
except Exception as e:
errors.append(f"{item.name}: {str(e)}")
if is_batch:
emit_value = f"Saved {saved_count} entities"
metadata = {"entity_count": saved_count, "errors": errors}
else:
data = f"{item.name}({item.type}): {item.description}"
emit_value = f"{items[0].name}({items[0].type}): {items[0].description}"
metadata = items[0].metadata
super().save(data, item.metadata)
# Emit memory save completed event
crewai_event_bus.emit(
self,
event=MemorySaveCompletedEvent(
value=data,
metadata=item.metadata,
value=emit_value,
metadata=metadata,
save_time_ms=(time.time() - start_time) * 1000,
source_type="entity_memory",
from_agent=self.agent,
from_task=self.task,
),
)
if errors:
raise Exception(
f"Partial save: {len(errors)} failed out of {len(items)}"
)
except Exception as e:
fail_metadata = (
{"entity_count": len(items), "saved": saved_count}
if is_batch
else items[0].metadata
)
crewai_event_bus.emit(
self,
event=MemorySaveFailedEvent(
metadata=item.metadata,
metadata=fail_metadata,
error=str(e),
source_type="entity_memory",
from_agent=self.agent,

View File

@@ -14,7 +14,7 @@ class _MissingProvider:
Raises RuntimeError when instantiated to indicate missing dependencies.
"""
provider: Literal["chromadb", "qdrant", "elasticsearch", "__missing__"] = field(
provider: Literal["chromadb", "qdrant", "__missing__"] = field(
default="__missing__"
)

View File

@@ -9,8 +9,6 @@ if TYPE_CHECKING:
from crewai.rag.chromadb.config import ChromaDBConfig
from crewai.rag.qdrant.client import QdrantClient
from crewai.rag.qdrant.config import QdrantConfig
from crewai.rag.elasticsearch.client import ElasticsearchClient
from crewai.rag.elasticsearch.config import ElasticsearchConfig
class ChromaFactoryModule(Protocol):
@@ -27,11 +25,3 @@ class QdrantFactoryModule(Protocol):
def create_client(self, config: QdrantConfig) -> QdrantClient:
"""Creates a Qdrant client from configuration."""
...
class ElasticsearchFactoryModule(Protocol):
"""Protocol for Elasticsearch factory module."""
def create_client(self, config: ElasticsearchConfig) -> ElasticsearchClient:
"""Creates an Elasticsearch client from configuration."""
...

View File

@@ -20,10 +20,3 @@ class MissingQdrantConfig(_MissingProvider):
"""Placeholder for missing Qdrant configuration."""
provider: Literal["qdrant"] = field(default="qdrant")
@pyd_dataclass(config=ConfigDict(extra="forbid"))
class MissingElasticsearchConfig(_MissingProvider):
"""Placeholder for missing Elasticsearch configuration."""
provider: Literal["elasticsearch"] = field(default="elasticsearch")

View File

@@ -3,6 +3,6 @@
from typing import Annotated, Literal
SupportedProvider = Annotated[
Literal["chromadb", "qdrant", "elasticsearch"],
Literal["chromadb", "qdrant"],
"Supported RAG provider types, add providers here as they become available",
]

View File

@@ -13,9 +13,6 @@ if TYPE_CHECKING:
from crewai.rag.qdrant.config import QdrantConfig as QdrantConfig_
QdrantConfig = QdrantConfig_
from crewai.rag.elasticsearch.config import ElasticsearchConfig as ElasticsearchConfig_
ElasticsearchConfig = ElasticsearchConfig_
else:
try:
from crewai.rag.chromadb.config import ChromaDBConfig
@@ -31,14 +28,7 @@ else:
MissingQdrantConfig as QdrantConfig,
)
try:
from crewai.rag.elasticsearch.config import ElasticsearchConfig
except ImportError:
from crewai.rag.config.optional_imports.providers import (
MissingElasticsearchConfig as ElasticsearchConfig,
)
SupportedProviderConfig: TypeAlias = ChromaDBConfig | QdrantConfig | ElasticsearchConfig
SupportedProviderConfig: TypeAlias = ChromaDBConfig | QdrantConfig
RagConfigType: TypeAlias = Annotated[
SupportedProviderConfig, Field(discriminator=DISCRIMINATOR)
]

View File

@@ -1 +0,0 @@
"""Elasticsearch RAG implementation."""

View File

@@ -1,502 +0,0 @@
"""Elasticsearch client implementation."""
from typing import Any, cast
from typing_extensions import Unpack
from crewai.rag.core.base_client import (
BaseClient,
BaseCollectionParams,
BaseCollectionAddParams,
BaseCollectionSearchParams,
)
from crewai.rag.core.exceptions import ClientMethodMismatchError
from crewai.rag.elasticsearch.types import (
AsyncEmbeddingFunction,
EmbeddingFunction,
ElasticsearchClientType,
ElasticsearchCollectionCreateParams,
)
from crewai.rag.elasticsearch.utils import (
_is_async_client,
_is_async_embedding_function,
_is_sync_client,
_prepare_document_for_elasticsearch,
_process_search_results,
_build_vector_search_query,
_get_index_mapping,
)
from crewai.rag.types import SearchResult
class ElasticsearchClient(BaseClient):
    """Elasticsearch implementation of the BaseClient protocol.

    Provides vector database operations for Elasticsearch, supporting both
    synchronous and asynchronous clients. Every public operation has a sync
    variant and an ``a``-prefixed async variant; calling a variant with the
    wrong client flavour raises ``ClientMethodMismatchError``.

    Attributes:
        client: Elasticsearch client instance (Elasticsearch or AsyncElasticsearch).
        embedding_function: Function to generate embeddings for documents.
        vector_dimension: Dimension of the embedding vectors.
        similarity: Similarity function to use for vector search.
    """

    def __init__(
        self,
        client: ElasticsearchClientType,
        embedding_function: EmbeddingFunction | AsyncEmbeddingFunction,
        vector_dimension: int = 384,
        similarity: str = "cosine",
    ) -> None:
        """Initialize ElasticsearchClient with client and embedding function.

        Args:
            client: Pre-configured Elasticsearch client instance.
            embedding_function: Embedding function for text to vector conversion.
            vector_dimension: Dimension of the embedding vectors.
            similarity: Similarity function to use for vector search.
        """
        self.client = client
        self.embedding_function = embedding_function
        self.vector_dimension = vector_dimension
        self.similarity = similarity

    # ------------------------------------------------------------------
    # Private helpers shared by the sync and async variants
    # ------------------------------------------------------------------

    def _build_index_mapping(
        self, params: ElasticsearchCollectionCreateParams
    ) -> dict[str, Any]:
        """Build the index-creation body from create-collection kwargs.

        Per-call ``vector_dimension`` / ``similarity`` override the instance
        defaults. A non-empty ``index_settings`` replaces the ``"settings"``
        entry of the generated mapping.
        """
        mapping = _get_index_mapping(
            params.get("vector_dimension", self.vector_dimension),
            params.get("similarity", self.similarity),
        )
        index_settings = params.get("index_settings", {})
        if index_settings:
            mapping["settings"] = index_settings
        return mapping

    def _sync_embedding_function(
        self, method_name: str, alt_method: str
    ) -> EmbeddingFunction:
        """Return the embedding function for sync use, rejecting async ones.

        Raises:
            TypeError: If the configured embedding function is asynchronous.
        """
        if _is_async_embedding_function(self.embedding_function):
            raise TypeError(
                f"Async embedding function cannot be used with sync {method_name}. "
                f"Use {alt_method} instead."
            )
        return cast(EmbeddingFunction, self.embedding_function)

    async def _aembed(self, text: str) -> Any:
        """Embed ``text``, awaiting the embedding function when it is async."""
        if _is_async_embedding_function(self.embedding_function):
            async_fn = cast(AsyncEmbeddingFunction, self.embedding_function)
            return await async_fn(text)
        sync_fn = cast(EmbeddingFunction, self.embedding_function)
        return sync_fn(text)

    # ------------------------------------------------------------------
    # Collection (index) creation
    # ------------------------------------------------------------------

    def create_collection(
        self, **kwargs: Unpack[ElasticsearchCollectionCreateParams]
    ) -> None:
        """Create a new index in Elasticsearch.

        Keyword Args:
            collection_name: Name of the index to create. Must be unique.
            index_settings: Optional index settings.
            vector_dimension: Optional vector dimension override.
            similarity: Optional similarity function override.

        Raises:
            ClientMethodMismatchError: If the client is not a sync client.
            ValueError: If index with the same name already exists.
            ConnectionError: If unable to connect to Elasticsearch server.
        """
        if not _is_sync_client(self.client):
            raise ClientMethodMismatchError(
                method_name="create_collection",
                expected_client="Elasticsearch",
                alt_method="acreate_collection",
                alt_client="AsyncElasticsearch",
            )

        collection_name = kwargs["collection_name"]
        if self.client.indices.exists(index=collection_name):
            raise ValueError(f"Index '{collection_name}' already exists")

        mapping = self._build_index_mapping(kwargs)
        self.client.indices.create(index=collection_name, body=mapping)

    async def acreate_collection(
        self, **kwargs: Unpack[ElasticsearchCollectionCreateParams]
    ) -> None:
        """Create a new index in Elasticsearch asynchronously.

        Keyword Args:
            collection_name: Name of the index to create. Must be unique.
            index_settings: Optional index settings.
            vector_dimension: Optional vector dimension override.
            similarity: Optional similarity function override.

        Raises:
            ClientMethodMismatchError: If the client is not an async client.
            ValueError: If index with the same name already exists.
            ConnectionError: If unable to connect to Elasticsearch server.
        """
        if not _is_async_client(self.client):
            raise ClientMethodMismatchError(
                method_name="acreate_collection",
                expected_client="AsyncElasticsearch",
                alt_method="create_collection",
                alt_client="Elasticsearch",
            )

        collection_name = kwargs["collection_name"]
        if await self.client.indices.exists(index=collection_name):
            raise ValueError(f"Index '{collection_name}' already exists")

        mapping = self._build_index_mapping(kwargs)
        await self.client.indices.create(index=collection_name, body=mapping)

    def get_or_create_collection(
        self, **kwargs: Unpack[ElasticsearchCollectionCreateParams]
    ) -> Any:
        """Get an existing index or create it if it doesn't exist.

        Keyword Args:
            collection_name: Name of the index to get or create.
            index_settings: Optional index settings.
            vector_dimension: Optional vector dimension override.
            similarity: Optional similarity function override.

        Returns:
            Index info dict with name and other metadata.

        Raises:
            ClientMethodMismatchError: If the client is not a sync client.
            ConnectionError: If unable to connect to Elasticsearch server.
        """
        if not _is_sync_client(self.client):
            raise ClientMethodMismatchError(
                method_name="get_or_create_collection",
                expected_client="Elasticsearch",
                alt_method="aget_or_create_collection",
                alt_client="AsyncElasticsearch",
            )

        collection_name = kwargs["collection_name"]
        if not self.client.indices.exists(index=collection_name):
            mapping = self._build_index_mapping(kwargs)
            self.client.indices.create(index=collection_name, body=mapping)
        return self.client.indices.get(index=collection_name)

    async def aget_or_create_collection(
        self, **kwargs: Unpack[ElasticsearchCollectionCreateParams]
    ) -> Any:
        """Get an existing index or create it if it doesn't exist asynchronously.

        Keyword Args:
            collection_name: Name of the index to get or create.
            index_settings: Optional index settings.
            vector_dimension: Optional vector dimension override.
            similarity: Optional similarity function override.

        Returns:
            Index info dict with name and other metadata.

        Raises:
            ClientMethodMismatchError: If the client is not an async client.
            ConnectionError: If unable to connect to Elasticsearch server.
        """
        if not _is_async_client(self.client):
            raise ClientMethodMismatchError(
                method_name="aget_or_create_collection",
                expected_client="AsyncElasticsearch",
                alt_method="get_or_create_collection",
                alt_client="Elasticsearch",
            )

        collection_name = kwargs["collection_name"]
        if not await self.client.indices.exists(index=collection_name):
            mapping = self._build_index_mapping(kwargs)
            await self.client.indices.create(index=collection_name, body=mapping)
        return await self.client.indices.get(index=collection_name)

    # ------------------------------------------------------------------
    # Document ingestion
    # ------------------------------------------------------------------

    def add_documents(self, **kwargs: Unpack[BaseCollectionAddParams]) -> None:
        """Add documents with their embeddings to an index.

        Keyword Args:
            collection_name: The name of the index to add documents to.
            documents: List of BaseRecord dicts containing document data.

        Raises:
            ClientMethodMismatchError: If the client is not a sync client.
            ValueError: If index doesn't exist or documents list is empty.
            TypeError: If the embedding function is asynchronous.
            ConnectionError: If unable to connect to Elasticsearch server.
        """
        if not _is_sync_client(self.client):
            raise ClientMethodMismatchError(
                method_name="add_documents",
                expected_client="Elasticsearch",
                alt_method="aadd_documents",
                alt_client="AsyncElasticsearch",
            )

        collection_name = kwargs["collection_name"]
        documents = kwargs["documents"]

        if not documents:
            raise ValueError("Documents list cannot be empty")
        if not self.client.indices.exists(index=collection_name):
            raise ValueError(f"Index '{collection_name}' does not exist")

        # The sync/async nature of the embedding function cannot change
        # between documents, so validate it once instead of per document.
        # `documents` is non-empty here, so the error still surfaces before
        # anything is indexed, exactly as before.
        embed = self._sync_embedding_function("add_documents", "aadd_documents")

        for doc in documents:
            prepared_doc = _prepare_document_for_elasticsearch(
                doc, embed(doc["content"])
            )
            self.client.index(
                index=collection_name,
                id=prepared_doc["id"],
                body=prepared_doc["body"],
            )

    async def aadd_documents(self, **kwargs: Unpack[BaseCollectionAddParams]) -> None:
        """Add documents with their embeddings to an index asynchronously.

        Accepts either a sync or an async embedding function.

        Keyword Args:
            collection_name: The name of the index to add documents to.
            documents: List of BaseRecord dicts containing document data.

        Raises:
            ClientMethodMismatchError: If the client is not an async client.
            ValueError: If index doesn't exist or documents list is empty.
            ConnectionError: If unable to connect to Elasticsearch server.
        """
        if not _is_async_client(self.client):
            raise ClientMethodMismatchError(
                method_name="aadd_documents",
                expected_client="AsyncElasticsearch",
                alt_method="add_documents",
                alt_client="Elasticsearch",
            )

        collection_name = kwargs["collection_name"]
        documents = kwargs["documents"]

        if not documents:
            raise ValueError("Documents list cannot be empty")
        if not await self.client.indices.exists(index=collection_name):
            raise ValueError(f"Index '{collection_name}' does not exist")

        for doc in documents:
            embedding = await self._aembed(doc["content"])
            prepared_doc = _prepare_document_for_elasticsearch(doc, embedding)
            await self.client.index(
                index=collection_name,
                id=prepared_doc["id"],
                body=prepared_doc["body"],
            )

    # ------------------------------------------------------------------
    # Search
    # ------------------------------------------------------------------

    def search(
        self, **kwargs: Unpack[BaseCollectionSearchParams]
    ) -> list[SearchResult]:
        """Search for similar documents using a query.

        Keyword Args:
            collection_name: Name of the index to search in.
            query: The text query to search for.
            limit: Maximum number of results to return (default: 10).
            metadata_filter: Optional filter for metadata fields.
            score_threshold: Optional minimum similarity score (0-1) for results.

        Returns:
            List of SearchResult dicts containing id, content, metadata, and score.

        Raises:
            ClientMethodMismatchError: If the client is not a sync client.
            ValueError: If index doesn't exist.
            TypeError: If the embedding function is asynchronous.
            ConnectionError: If unable to connect to Elasticsearch server.
        """
        if not _is_sync_client(self.client):
            raise ClientMethodMismatchError(
                method_name="search",
                expected_client="Elasticsearch",
                alt_method="asearch",
                alt_client="AsyncElasticsearch",
            )

        collection_name = kwargs["collection_name"]
        query = kwargs["query"]
        limit = kwargs.get("limit", 10)
        metadata_filter = kwargs.get("metadata_filter")
        score_threshold = kwargs.get("score_threshold")

        if not self.client.indices.exists(index=collection_name):
            raise ValueError(f"Index '{collection_name}' does not exist")

        embed = self._sync_embedding_function("search", "asearch")
        search_query = _build_vector_search_query(
            query_vector=embed(query),
            limit=limit,
            metadata_filter=metadata_filter,
            score_threshold=score_threshold,
        )

        response = self.client.search(index=collection_name, body=search_query)
        return _process_search_results(response, score_threshold)

    async def asearch(
        self, **kwargs: Unpack[BaseCollectionSearchParams]
    ) -> list[SearchResult]:
        """Search for similar documents using a query asynchronously.

        Accepts either a sync or an async embedding function.

        Keyword Args:
            collection_name: Name of the index to search in.
            query: The text query to search for.
            limit: Maximum number of results to return (default: 10).
            metadata_filter: Optional filter for metadata fields.
            score_threshold: Optional minimum similarity score (0-1) for results.

        Returns:
            List of SearchResult dicts containing id, content, metadata, and score.

        Raises:
            ClientMethodMismatchError: If the client is not an async client.
            ValueError: If index doesn't exist.
            ConnectionError: If unable to connect to Elasticsearch server.
        """
        if not _is_async_client(self.client):
            raise ClientMethodMismatchError(
                method_name="asearch",
                expected_client="AsyncElasticsearch",
                alt_method="search",
                alt_client="Elasticsearch",
            )

        collection_name = kwargs["collection_name"]
        query = kwargs["query"]
        limit = kwargs.get("limit", 10)
        metadata_filter = kwargs.get("metadata_filter")
        score_threshold = kwargs.get("score_threshold")

        if not await self.client.indices.exists(index=collection_name):
            raise ValueError(f"Index '{collection_name}' does not exist")

        query_embedding = await self._aembed(query)
        search_query = _build_vector_search_query(
            query_vector=query_embedding,
            limit=limit,
            metadata_filter=metadata_filter,
            score_threshold=score_threshold,
        )

        response = await self.client.search(index=collection_name, body=search_query)
        return _process_search_results(response, score_threshold)

    # ------------------------------------------------------------------
    # Deletion / reset
    # ------------------------------------------------------------------

    def delete_collection(self, **kwargs: Unpack[BaseCollectionParams]) -> None:
        """Delete an index and all its data.

        Keyword Args:
            collection_name: Name of the index to delete.

        Raises:
            ClientMethodMismatchError: If the client is not a sync client.
            ValueError: If index doesn't exist.
            ConnectionError: If unable to connect to Elasticsearch server.
        """
        if not _is_sync_client(self.client):
            raise ClientMethodMismatchError(
                method_name="delete_collection",
                expected_client="Elasticsearch",
                alt_method="adelete_collection",
                alt_client="AsyncElasticsearch",
            )

        collection_name = kwargs["collection_name"]
        if not self.client.indices.exists(index=collection_name):
            raise ValueError(f"Index '{collection_name}' does not exist")

        self.client.indices.delete(index=collection_name)

    async def adelete_collection(self, **kwargs: Unpack[BaseCollectionParams]) -> None:
        """Delete an index and all its data asynchronously.

        Keyword Args:
            collection_name: Name of the index to delete.

        Raises:
            ClientMethodMismatchError: If the client is not an async client.
            ValueError: If index doesn't exist.
            ConnectionError: If unable to connect to Elasticsearch server.
        """
        if not _is_async_client(self.client):
            raise ClientMethodMismatchError(
                method_name="adelete_collection",
                expected_client="AsyncElasticsearch",
                alt_method="delete_collection",
                alt_client="Elasticsearch",
            )

        collection_name = kwargs["collection_name"]
        if not await self.client.indices.exists(index=collection_name):
            raise ValueError(f"Index '{collection_name}' does not exist")

        await self.client.indices.delete(index=collection_name)

    def reset(self) -> None:
        """Reset the vector database by deleting all indices and data.

        Every index returned for the ``*`` pattern is deleted except those
        whose name starts with ``.``.

        Raises:
            ClientMethodMismatchError: If the client is not a sync client.
            ConnectionError: If unable to connect to Elasticsearch server.
        """
        if not _is_sync_client(self.client):
            raise ClientMethodMismatchError(
                method_name="reset",
                expected_client="Elasticsearch",
                alt_method="areset",
                alt_client="AsyncElasticsearch",
            )

        indices_response = self.client.indices.get(index="*")
        for index_name in indices_response.keys():
            if not index_name.startswith("."):
                self.client.indices.delete(index=index_name)

    async def areset(self) -> None:
        """Reset the vector database by deleting all indices and data asynchronously.

        Every index returned for the ``*`` pattern is deleted except those
        whose name starts with ``.``.

        Raises:
            ClientMethodMismatchError: If the client is not an async client.
            ConnectionError: If unable to connect to Elasticsearch server.
        """
        if not _is_async_client(self.client):
            raise ClientMethodMismatchError(
                method_name="areset",
                expected_client="AsyncElasticsearch",
                alt_method="reset",
                alt_client="Elasticsearch",
            )

        indices_response = await self.client.indices.get(index="*")
        for index_name in indices_response.keys():
            if not index_name.startswith("."):
                await self.client.indices.delete(index=index_name)

View File

@@ -1,92 +0,0 @@
"""Elasticsearch configuration model."""
from dataclasses import field
from typing import Literal, cast
from pydantic.dataclasses import dataclass as pyd_dataclass
from crewai.rag.config.base import BaseRagConfig
from crewai.rag.elasticsearch.types import (
ElasticsearchClientParams,
ElasticsearchEmbeddingFunctionWrapper,
)
from crewai.rag.elasticsearch.constants import (
DEFAULT_HOST,
DEFAULT_PORT,
DEFAULT_EMBEDDING_MODEL,
DEFAULT_VECTOR_DIMENSION,
)
def _default_options() -> ElasticsearchClientParams:
    """Build the client options used when the config supplies none.

    Returns:
        Options pointing at ``http://DEFAULT_HOST:DEFAULT_PORT`` with SSL and
        certificate verification turned off and a timeout value of 30.
    """
    local_node_url = f"http://{DEFAULT_HOST}:{DEFAULT_PORT}"
    defaults = ElasticsearchClientParams(
        hosts=[local_node_url],
        use_ssl=False,
        verify_certs=False,
        timeout=30,
    )
    return defaults
def _default_embedding_function() -> ElasticsearchEmbeddingFunctionWrapper:
    """Create default Elasticsearch embedding function.

    Uses a sentence-transformers model when that package can be imported.
    Otherwise falls back to a deterministic, hash-derived pseudo-embedding
    that carries no semantic meaning but keeps the client usable.

    Returns:
        Default embedding function using sentence-transformers, or the
        hash-based fallback when sentence-transformers is unavailable.
    """
    try:
        from sentence_transformers import SentenceTransformer

        model = SentenceTransformer(DEFAULT_EMBEDDING_MODEL)

        def embed_fn(text: str) -> list[float]:
            """Embed a single text string.

            Args:
                text: Text to embed.

            Returns:
                Embedding vector as list of floats.
            """
            embedding = model.encode(text, convert_to_tensor=False)
            return embedding.tolist() if hasattr(embedding, "tolist") else list(embedding)

        return cast(ElasticsearchEmbeddingFunctionWrapper, embed_fn)
    except ImportError:

        def fallback_embed_fn(text: str) -> list[float]:
            """Fallback embedding function when sentence-transformers is not available.

            Reinterprets the MD5 digest of ``text`` as four native-order
            float32 values and repeats them cyclically up to
            ``DEFAULT_VECTOR_DIMENSION`` entries.
            """
            import hashlib
            import math
            import struct

            digest = hashlib.md5(text.encode(), usedforsecurity=False).digest()

            seed_values: list[float] = []
            for offset in range(0, len(digest), 4):
                chunk = digest[offset : offset + 4]
                if len(chunk) != 4:
                    continue
                value = float(struct.unpack("f", chunk)[0])
                if not math.isfinite(value):
                    # Arbitrary bytes can decode to NaN or +/-inf, which are
                    # not usable vector components. Substitute a
                    # deterministic finite value derived from the same bytes.
                    value = int.from_bytes(chunk, "little") / 0xFFFFFFFF
                seed_values.append(value)

            # Cyclic repetition; produces the same sequence as repeatedly
            # doubling the list and truncating to the target dimension.
            period = len(seed_values)
            return [seed_values[i % period] for i in range(DEFAULT_VECTOR_DIMENSION)]

        return cast(ElasticsearchEmbeddingFunctionWrapper, fallback_embed_fn)
@pyd_dataclass(frozen=True)
class ElasticsearchConfig(BaseRagConfig):
    """Immutable settings bundle for building an Elasticsearch RAG client.

    The ``provider`` tag is fixed to ``"elasticsearch"`` and is not accepted
    as a constructor argument.
    """

    # Fixed provider tag; excluded from __init__.
    provider: Literal["elasticsearch"] = field(default="elasticsearch", init=False)
    # Connection options; defaults are produced by _default_options().
    options: ElasticsearchClientParams = field(default_factory=_default_options)
    # Size of the embedding vectors.
    vector_dimension: int = field(default=DEFAULT_VECTOR_DIMENSION)
    # Name of the similarity function.
    similarity: str = field(default="cosine")
    # Text-to-vector callable; defaults are produced by
    # _default_embedding_function().
    embedding_function: ElasticsearchEmbeddingFunctionWrapper = field(
        default_factory=_default_embedding_function
    )

View File

@@ -1,12 +0,0 @@
"""Constants for Elasticsearch RAG implementation."""
from typing import Final
# Host and port used to build the default local connection URL.
DEFAULT_HOST: Final[str] = "localhost"
DEFAULT_PORT: Final[int] = 9200
# Index-level settings: one primary shard, no replicas.
# NOTE(review): presumably sized for a single-node development cluster — confirm.
DEFAULT_INDEX_SETTINGS: Final[dict] = {
    "number_of_shards": 1,
    "number_of_replicas": 0,
}
# sentence-transformers model identifier loaded by the default embedding function.
DEFAULT_EMBEDDING_MODEL: Final[str] = "sentence-transformers/all-MiniLM-L6-v2"
# Default embedding vector length.
# NOTE(review): presumably matches the output size of DEFAULT_EMBEDDING_MODEL — confirm.
DEFAULT_VECTOR_DIMENSION: Final[int] = 384

View File

@@ -1,31 +0,0 @@
"""Factory functions for creating Elasticsearch clients."""
from crewai.rag.elasticsearch.config import ElasticsearchConfig
from crewai.rag.elasticsearch.client import ElasticsearchClient
def create_client(config: ElasticsearchConfig) -> ElasticsearchClient:
    """Build an ElasticsearchClient from a configuration object.

    The elasticsearch package is imported lazily so that it stays an optional
    dependency.

    Args:
        config: Elasticsearch configuration object.

    Returns:
        ElasticsearchClient wrapping a freshly constructed Elasticsearch
        instance together with the configured embedding settings.

    Raises:
        ImportError: If the elasticsearch package is not installed.
    """
    try:
        from elasticsearch import Elasticsearch
    except ImportError as exc:
        message = (
            "elasticsearch package is required for Elasticsearch support. "
            "Install it with: pip install elasticsearch"
        )
        raise ImportError(message) from exc

    # The connection options are forwarded verbatim as keyword arguments.
    return ElasticsearchClient(
        client=Elasticsearch(**config.options),
        embedding_function=config.embedding_function,
        vector_dimension=config.vector_dimension,
        similarity=config.similarity,
    )

View File

@@ -1,93 +0,0 @@
"""Type definitions for Elasticsearch RAG implementation."""
from typing import Any, Protocol, Union, TYPE_CHECKING
from typing_extensions import NotRequired, TypedDict
from pydantic import GetCoreSchemaHandler
from pydantic_core import CoreSchema, core_schema
# ElasticsearchClientType is the union of the sync and async client classes.
# Static type checkers always see the precise union; at runtime the
# elasticsearch package is optional, so the alias degrades to Any when the
# import fails.
if TYPE_CHECKING:
    from typing import TypeAlias
    from elasticsearch import Elasticsearch, AsyncElasticsearch

    ElasticsearchClientType: TypeAlias = Union[Elasticsearch, AsyncElasticsearch]
else:
    try:
        from elasticsearch import Elasticsearch, AsyncElasticsearch

        ElasticsearchClientType = Union[Elasticsearch, AsyncElasticsearch]
    except ImportError:
        # Package not installed: keep the module importable.
        ElasticsearchClientType = Any
class ElasticsearchClientParams(TypedDict, total=False):
    """Parameters for Elasticsearch client initialization.

    Every key is optional (total=False, and each key is also NotRequired).
    NOTE(review): key names presumably mirror keyword arguments of the
    Elasticsearch client constructor — confirm against the installed
    elasticsearch version.
    """

    # Node URLs.
    hosts: NotRequired[list[str]]
    cloud_id: NotRequired[str]
    # Credentials.
    username: NotRequired[str]
    password: NotRequired[str]
    api_key: NotRequired[str]
    # TLS-related switches.
    use_ssl: NotRequired[bool]
    verify_certs: NotRequired[bool]
    ca_certs: NotRequired[str]
    # NOTE(review): presumably seconds — confirm.
    timeout: NotRequired[int]
class ElasticsearchIndexSettings(TypedDict, total=False):
    """Settings for Elasticsearch index creation.

    All keys are optional.
    """

    number_of_shards: NotRequired[int]
    number_of_replicas: NotRequired[int]
    # Given as a string; NOTE(review): presumably an Elasticsearch time
    # value such as "1s" — confirm.
    refresh_interval: NotRequired[str]
class ElasticsearchCollectionCreateParams(TypedDict, total=False):
    """Parameters for creating Elasticsearch collections/indices."""

    # Name of the index to create.
    # NOTE(review): because the class is declared total=False, type checkers
    # treat this key as optional too — confirm whether it should be required.
    collection_name: str
    index_settings: NotRequired[ElasticsearchIndexSettings]
    # Length of the embedding vectors stored in the index.
    vector_dimension: NotRequired[int]
    # Name of the similarity function for the vector field.
    similarity: NotRequired[str]
class EmbeddingFunction(Protocol):
    """Protocol for embedding functions that convert text to vectors.

    Structural type: any callable with a matching synchronous signature
    satisfies it.
    """

    def __call__(self, text: str) -> list[float]:
        """Convert text to embedding vector.

        Args:
            text: Input text to embed.

        Returns:
            Embedding vector as list of floats.
        """
        ...
class AsyncEmbeddingFunction(Protocol):
    """Protocol for async embedding functions that convert text to vectors.

    Structural type: calling a conforming object yields an awaitable that
    resolves to the embedding.
    """

    async def __call__(self, text: str) -> list[float]:
        """Convert text to embedding vector asynchronously.

        Args:
            text: Input text to embed.

        Returns:
            Embedding vector as list of floats.
        """
        ...
class ElasticsearchEmbeddingFunctionWrapper(EmbeddingFunction):
    """Base class for Elasticsearch EmbeddingFunction to work with Pydantic validation.

    Used as the annotation for embedding-function fields so Pydantic can
    build a schema for them.
    """

    @classmethod
    def __get_pydantic_core_schema__(
        cls, _source_type: Any, _handler: GetCoreSchemaHandler
    ) -> CoreSchema:
        """Generate Pydantic core schema for Elasticsearch EmbeddingFunction.

        This allows Pydantic to handle Elasticsearch's EmbeddingFunction type
        without requiring arbitrary_types_allowed=True.

        Args:
            _source_type: Unused.
            _handler: Unused.

        Returns:
            An "any" schema, i.e. the value is accepted without validation.
        """
        return core_schema.any_schema()

View File

@@ -1,186 +0,0 @@
"""Utility functions for Elasticsearch RAG implementation."""
import hashlib
from typing import Any, TypeGuard
from crewai.rag.elasticsearch.types import (
AsyncEmbeddingFunction,
EmbeddingFunction,
ElasticsearchClientType,
)
from crewai.rag.types import BaseRecord, SearchResult
try:
from elasticsearch import Elasticsearch, AsyncElasticsearch
except ImportError:
Elasticsearch = None
AsyncElasticsearch = None
def _is_sync_client(client: ElasticsearchClientType) -> TypeGuard[Any]:
    """Report whether *client* is a synchronous Elasticsearch instance.

    Always False when the elasticsearch package could not be imported.
    """
    return Elasticsearch is not None and isinstance(client, Elasticsearch)
def _is_async_client(client: ElasticsearchClientType) -> TypeGuard[Any]:
    """Report whether *client* is an AsyncElasticsearch instance.

    Always False when the elasticsearch package could not be imported.
    """
    return AsyncElasticsearch is not None and isinstance(client, AsyncElasticsearch)
def _is_async_embedding_function(
    func: EmbeddingFunction | AsyncEmbeddingFunction,
) -> TypeGuard[AsyncEmbeddingFunction]:
    """Type guard to check if the embedding function is async.

    Args:
        func: Embedding callable, either a function or a callable object.

    Returns:
        True when calling *func* produces a coroutine, i.e. *func* is an
        ``async def`` function or an object whose ``__call__`` is ``async def``.
    """
    import inspect

    if inspect.iscoroutinefunction(func):
        return True
    # iscoroutinefunction() does not look through callable instances, so an
    # object with ``async def __call__`` needs its __call__ checked explicitly.
    return inspect.iscoroutinefunction(getattr(func, "__call__", None))
def _generate_doc_id(content: str) -> str:
    """Derive a document ID as the hex SHA256 digest of the encoded content."""
    digest = hashlib.sha256(content.encode())
    return digest.hexdigest()
def _prepare_document_for_elasticsearch(
    doc: BaseRecord, embedding: list[float]
) -> dict[str, Any]:
    """Shape a record and its embedding into an indexable document.

    Args:
        doc: Document record to prepare; must contain "content".
        embedding: Embedding vector for the document.

    Returns:
        Dict with "id" (the record's doc_id, or a content hash when absent
        or empty) and "body" (content, vector and metadata).
    """
    text = doc["content"]
    identifier = doc.get("doc_id") or _generate_doc_id(text)
    body = {
        "content": text,
        "content_vector": embedding,
        "metadata": doc.get("metadata", {}),
    }
    return {"id": identifier, "body": body}
def _process_search_results(
    response: dict[str, Any], score_threshold: float | None = None
) -> list[SearchResult]:
    """Convert a raw Elasticsearch search response into SearchResult items.

    Args:
        response: Raw Elasticsearch search response.
        score_threshold: Optional minimum score; hits scoring below it are
            dropped.

    Returns:
        List of SearchResult dictionaries, in response order.
    """
    collected: list[SearchResult] = []
    for hit in response.get("hits", {}).get("hits", []):
        hit_score = hit.get("_score", 0.0)
        too_low = score_threshold is not None and hit_score < score_threshold
        if too_low:
            continue

        payload = hit.get("_source", {})
        collected.append(
            SearchResult(
                id=hit.get("_id", ""),
                content=payload.get("content", ""),
                metadata=payload.get("metadata", {}),
                score=hit_score,
            )
        )
    return collected
def _build_vector_search_query(
query_vector: list[float],
limit: int = 10,
metadata_filter: dict[str, Any] | None = None,
score_threshold: float | None = None,
) -> dict[str, Any]:
"""Build Elasticsearch query for vector similarity search.
Args:
query_vector: Query embedding vector.
limit: Maximum number of results.
metadata_filter: Optional metadata filter.
score_threshold: Optional minimum score threshold.
Returns:
Elasticsearch query dictionary.
"""
query = {
"size": limit,
"query": {
"script_score": {
"query": {"match_all": {}},
"script": {
"source": "cosineSimilarity(params.query_vector, 'content_vector') + 1.0",
"params": {"query_vector": query_vector}
}
}
}
}
if metadata_filter:
bool_query = {
"bool": {
"must": [
query["query"]
],
"filter": []
}
}
for key, value in metadata_filter.items():
bool_query["bool"]["filter"].append({
"term": {f"metadata.{key}": value}
})
query["query"] = bool_query
if score_threshold is not None:
query["min_score"] = score_threshold
return query
def _get_index_mapping(vector_dimension: int, similarity: str = "cosine") -> dict[str, Any]:
    """Build the index mapping used for vector search.

    Args:
        vector_dimension: Dimension of the embedding vectors.
        similarity: Similarity function to use.

    Returns:
        Mapping dictionary declaring the content, content_vector and
        metadata fields.
    """
    text_field = {"type": "text", "analyzer": "standard"}
    vector_field = {
        "type": "dense_vector",
        "dims": vector_dimension,
        "similarity": similarity,
    }
    metadata_field = {"type": "object", "dynamic": True}

    properties = {
        "content": text_field,
        "content_vector": vector_field,
        "metadata": metadata_field,
    }
    return {"mappings": {"properties": properties}}

View File

@@ -5,7 +5,6 @@ from typing import cast
from crewai.rag.config.optional_imports.protocols import (
ChromaFactoryModule,
QdrantFactoryModule,
ElasticsearchFactoryModule,
)
from crewai.rag.core.base_client import BaseClient
from crewai.rag.config.types import RagConfigType
@@ -44,15 +43,3 @@ def create_client(config: RagConfigType) -> BaseClient:
),
)
return qdrant_mod.create_client(config)
if config.provider == "elasticsearch":
elasticsearch_mod = cast(
ElasticsearchFactoryModule,
require(
"crewai.rag.elasticsearch.factory",
purpose="The 'elasticsearch' provider",
),
)
return elasticsearch_mod.create_client(config)
raise ValueError(f"Unsupported provider: {config.provider}")

View File

@@ -3,11 +3,6 @@ from datetime import datetime, timedelta
import requests
from unittest.mock import MagicMock, patch, call
from crewai.cli.authentication.main import AuthenticationCommand
from crewai.cli.authentication.constants import (
AUTH0_AUDIENCE,
AUTH0_CLIENT_ID,
AUTH0_DOMAIN
)
from crewai.cli.constants import (
CREWAI_ENTERPRISE_DEFAULT_OAUTH2_DOMAIN,
CREWAI_ENTERPRISE_DEFAULT_OAUTH2_CLIENT_ID,
@@ -22,16 +17,6 @@ class TestAuthenticationCommand:
@pytest.mark.parametrize(
"user_provider,expected_urls",
[
(
"auth0",
{
"device_code_url": f"https://{AUTH0_DOMAIN}/oauth/device/code",
"token_url": f"https://{AUTH0_DOMAIN}/oauth/token",
"client_id": AUTH0_CLIENT_ID,
"audience": AUTH0_AUDIENCE,
"domain": AUTH0_DOMAIN,
},
),
(
"workos",
{
@@ -44,9 +29,6 @@ class TestAuthenticationCommand:
),
],
)
@patch(
"crewai.cli.authentication.main.AuthenticationCommand._determine_user_provider"
)
@patch("crewai.cli.authentication.main.AuthenticationCommand._get_device_code")
@patch(
"crewai.cli.authentication.main.AuthenticationCommand._display_auth_instructions"
@@ -59,11 +41,9 @@ class TestAuthenticationCommand:
mock_poll,
mock_display,
mock_get_device,
mock_determine_provider,
user_provider,
expected_urls,
):
mock_determine_provider.return_value = user_provider
mock_get_device.return_value = {
"device_code": "test_code",
"user_code": "123456",
@@ -74,7 +54,6 @@ class TestAuthenticationCommand:
mock_console_print.assert_called_once_with(
"Signing in to CrewAI Enterprise...\n", style="bold blue"
)
mock_determine_provider.assert_called_once()
mock_get_device.assert_called_once()
mock_display.assert_called_once_with(
{"device_code": "test_code", "user_code": "123456"}
@@ -82,9 +61,17 @@ class TestAuthenticationCommand:
mock_poll.assert_called_once_with(
{"device_code": "test_code", "user_code": "123456"},
)
assert self.auth_command.oauth2_provider.get_client_id() == expected_urls["client_id"]
assert self.auth_command.oauth2_provider.get_audience() == expected_urls["audience"]
assert self.auth_command.oauth2_provider._get_domain() == expected_urls["domain"]
assert (
self.auth_command.oauth2_provider.get_client_id()
== expected_urls["client_id"]
)
assert (
self.auth_command.oauth2_provider.get_audience()
== expected_urls["audience"]
)
assert (
self.auth_command.oauth2_provider._get_domain() == expected_urls["domain"]
)
@patch("crewai.cli.authentication.main.webbrowser")
@patch("crewai.cli.authentication.main.console.print")
@@ -106,14 +93,6 @@ class TestAuthenticationCommand:
@pytest.mark.parametrize(
"user_provider,jwt_config",
[
(
"auth0",
{
"jwks_url": f"https://{AUTH0_DOMAIN}/.well-known/jwks.json",
"issuer": f"https://{AUTH0_DOMAIN}/",
"audience": AUTH0_AUDIENCE,
},
),
(
"workos",
{
@@ -135,14 +114,18 @@ class TestAuthenticationCommand:
jwt_config,
has_expiration,
):
from crewai.cli.authentication.providers.auth0 import Auth0Provider
from crewai.cli.authentication.providers.workos import WorkosProvider
from crewai.cli.authentication.main import Oauth2Settings
if user_provider == "auth0":
self.auth_command.oauth2_provider = Auth0Provider(settings=Oauth2Settings(provider=user_provider, client_id="test-client-id", domain=AUTH0_DOMAIN, audience=jwt_config["audience"]))
elif user_provider == "workos":
self.auth_command.oauth2_provider = WorkosProvider(settings=Oauth2Settings(provider=user_provider, client_id="test-client-id", domain=CREWAI_ENTERPRISE_DEFAULT_OAUTH2_DOMAIN, audience=jwt_config["audience"]))
if user_provider == "workos":
self.auth_command.oauth2_provider = WorkosProvider(
settings=Oauth2Settings(
provider=user_provider,
client_id="test-client-id",
domain=CREWAI_ENTERPRISE_DEFAULT_OAUTH2_DOMAIN,
audience=jwt_config["audience"],
)
)
token_data = {"access_token": "test_access_token", "id_token": "test_id_token"}
@@ -234,83 +217,6 @@ class TestAuthenticationCommand:
]
mock_console_print.assert_has_calls(expected_calls)
@pytest.mark.parametrize(
    "api_response,expected_provider",
    [
        ({"provider": "auth0"}, "auth0"),
        ({"provider": "workos"}, "workos"),
        # Any other provider value falls back to workos.
        ({"provider": "none"}, "workos"),
        # A response without a provider key falls back to workos as well.
        ({}, "workos"),
    ],
)
@patch("crewai.cli.authentication.main.PlusAPI")
@patch("crewai.cli.authentication.main.console.print")
@patch("builtins.input", return_value="test@example.com")
def test_determine_user_provider_success(
    self,
    mock_input,
    mock_console_print,
    mock_plus_api,
    api_response,
    expected_provider,
):
    """A 200 response maps the returned provider to the expected one."""
    ok_response = MagicMock()
    ok_response.status_code = 200
    ok_response.json.return_value = api_response
    plus_api = MagicMock()
    plus_api._make_request.return_value = ok_response
    mock_plus_api.return_value = plus_api

    provider = self.auth_command._determine_user_provider()

    mock_input.assert_called_once()
    mock_plus_api.assert_called_once_with("")
    # The email typed at the prompt is URL-encoded into the query string.
    plus_api._make_request.assert_called_once_with(
        "GET", "/crewai_plus/api/v1/me/provider?email=test%40example.com"
    )
    assert provider == expected_provider
@patch("crewai.cli.authentication.main.PlusAPI")
@patch("crewai.cli.authentication.main.console.print")
@patch("builtins.input", return_value="test@example.com")
def test_determine_user_provider_error(
    self, mock_input, mock_console_print, mock_plus_api
):
    """A 500 response exits after printing the prompt and an error message."""
    failed_response = MagicMock()
    failed_response.status_code = 500
    plus_api = MagicMock()
    plus_api._make_request.return_value = failed_response
    mock_plus_api.return_value = plus_api

    with pytest.raises(SystemExit):
        self.auth_command._determine_user_provider()

    mock_input.assert_called_once()
    mock_plus_api.assert_called_once_with("")
    plus_api._make_request.assert_called_once_with(
        "GET", "/crewai_plus/api/v1/me/provider?email=test%40example.com"
    )

    prompt_call = call(
        "Enter your CrewAI Enterprise account email: ",
        style="bold blue",
        end="",
    )
    error_call = call(
        "Error: Failed to authenticate with crewai enterprise. Ensure that you are using the latest crewai version and please try again. If the problem persists, contact support@crewai.com.",
        style="red",
    )
    mock_console_print.assert_has_calls([prompt_call, error_call])
@patch("requests.post")
def test_get_device_code(self, mock_post):
mock_response = MagicMock()
@@ -323,7 +229,9 @@ class TestAuthenticationCommand:
self.auth_command.oauth2_provider = MagicMock()
self.auth_command.oauth2_provider.get_client_id.return_value = "test_client"
self.auth_command.oauth2_provider.get_authorize_url.return_value = "https://example.com/device"
self.auth_command.oauth2_provider.get_authorize_url.return_value = (
"https://example.com/device"
)
self.auth_command.oauth2_provider.get_audience.return_value = "test_audience"
result = self.auth_command._get_device_code()
@@ -366,12 +274,12 @@ class TestAuthenticationCommand:
) as mock_tool_login,
):
self.auth_command.oauth2_provider = MagicMock()
self.auth_command.oauth2_provider.get_token_url.return_value = "https://example.com/token"
self.auth_command.oauth2_provider.get_token_url.return_value = (
"https://example.com/token"
)
self.auth_command.oauth2_provider.get_client_id.return_value = "test_client"
self.auth_command._poll_for_token(
device_code_data
)
self.auth_command._poll_for_token(device_code_data)
mock_post.assert_called_once_with(
"https://example.com/token",
@@ -406,9 +314,7 @@ class TestAuthenticationCommand:
"interval": 0.1, # Short interval for testing
}
self.auth_command._poll_for_token(
device_code_data
)
self.auth_command._poll_for_token(device_code_data)
mock_console_print.assert_any_call(
"Timeout: Failed to get the token. Please try again.", style="bold red"
@@ -429,15 +335,4 @@ class TestAuthenticationCommand:
device_code_data = {"device_code": "test_device_code", "interval": 1}
with pytest.raises(requests.HTTPError):
self.auth_command._poll_for_token(
device_code_data
)
# @patch(
# "crewai.cli.authentication.main.AuthenticationCommand._determine_user_provider"
# )
# def test_login_with_auth0(self, mock_determine_provider):
# from crewai.cli.authentication.providers.auth0 import Auth0Provider
# from crewai.cli.authentication.main import Oauth2Settings
# self.auth_command.oauth2_provider = Auth0Provider(settings=Oauth2Settings(provider="auth0", client_id=AUTH0_CLIENT_ID, domain=AUTH0_DOMAIN, audience=AUTH0_AUDIENCE))
# self.auth_command.login()
self.auth_command._poll_for_token(device_code_data)

View File

@@ -2,8 +2,6 @@
from unittest.mock import Mock, patch
import pytest
from crewai.rag.factory import create_client
@@ -27,50 +25,10 @@ def test_create_client_chromadb():
mock_module.create_client.assert_called_once_with(mock_config)
def test_create_client_qdrant():
    """The qdrant provider is routed to the qdrant factory module."""
    config = Mock()
    config.provider = "qdrant"

    with patch("crewai.rag.factory.require") as mock_require:
        expected_client = Mock()
        factory_module = Mock()
        factory_module.create_client.return_value = expected_client
        mock_require.return_value = factory_module

        created = create_client(config)

        assert created == expected_client
        mock_require.assert_called_once_with(
            "crewai.rag.qdrant.factory", purpose="The 'qdrant' provider"
        )
        factory_module.create_client.assert_called_once_with(config)
def test_create_client_elasticsearch():
    """The elasticsearch provider is routed to the elasticsearch factory module."""
    config = Mock()
    config.provider = "elasticsearch"

    with patch("crewai.rag.factory.require") as mock_require:
        expected_client = Mock()
        factory_module = Mock()
        factory_module.create_client.return_value = expected_client
        mock_require.return_value = factory_module

        created = create_client(config)

        assert created == expected_client
        mock_require.assert_called_once_with(
            "crewai.rag.elasticsearch.factory", purpose="The 'elasticsearch' provider"
        )
        factory_module.create_client.assert_called_once_with(config)
def test_create_client_unsupported_provider():
"""Test that unsupported provider raises ValueError."""
"""Test unsupported provider returns None for now."""
mock_config = Mock()
mock_config.provider = "unsupported"
with pytest.raises(ValueError, match="Unsupported provider: unsupported"):
create_client(mock_config)
result = create_client(mock_config)
assert result is None

View File

@@ -3,10 +3,7 @@
import pytest
from crewai.rag.config.optional_imports.base import _MissingProvider
from crewai.rag.config.optional_imports.providers import (
MissingChromaDBConfig,
MissingElasticsearchConfig,
)
from crewai.rag.config.optional_imports.providers import MissingChromaDBConfig
def test_missing_provider_raises_runtime_error():
@@ -23,11 +20,3 @@ def test_missing_chromadb_config_raises_runtime_error():
RuntimeError, match="provider 'chromadb' requested but not installed"
):
MissingChromaDBConfig()
def test_missing_elasticsearch_config_raises_runtime_error():
    """Instantiating MissingElasticsearchConfig must fail with RuntimeError."""
    expected_message = "provider 'elasticsearch' requested but not installed"
    with pytest.raises(RuntimeError, match=expected_message):
        MissingElasticsearchConfig()

View File

@@ -1 +0,0 @@
"""Tests for Elasticsearch RAG implementation."""

View File

@@ -1,397 +0,0 @@
"""Tests for ElasticsearchClient implementation."""
from unittest.mock import AsyncMock, Mock, patch
import pytest
from crewai.rag.elasticsearch.client import ElasticsearchClient
from crewai.rag.types import BaseRecord
from crewai.rag.core.exceptions import ClientMethodMismatchError
@pytest.fixture
def mock_elasticsearch_client():
    """Provide a Mock standing in for a synchronous Elasticsearch client.

    Index operations report a non-existent index by default, and search
    returns a single canned hit.
    """
    canned_hit = {
        "_id": "doc1",
        "_score": 0.9,
        "_source": {
            "content": "test content",
            "metadata": {"key": "value"},
        },
    }

    es = Mock()
    es.indices = Mock()
    es.indices.exists.return_value = False
    es.indices.create.return_value = {"acknowledged": True}
    es.indices.get.return_value = {"test_index": {"mappings": {}}}
    es.indices.delete.return_value = {"acknowledged": True}
    es.index.return_value = {"_id": "test_id", "result": "created"}
    es.search.return_value = {"hits": {"hits": [canned_hit]}}
    return es
@pytest.fixture
def mock_async_elasticsearch_client():
    """Provide a Mock standing in for an async Elasticsearch client.

    Every client operation is an AsyncMock so it can be awaited; defaults
    mirror the synchronous fixture.
    """
    canned_hit = {
        "_id": "doc1",
        "_score": 0.9,
        "_source": {
            "content": "test content",
            "metadata": {"key": "value"},
        },
    }

    es = Mock()
    es.indices = Mock()
    es.indices.exists = AsyncMock(return_value=False)
    es.indices.create = AsyncMock(return_value={"acknowledged": True})
    es.indices.get = AsyncMock(return_value={"test_index": {"mappings": {}}})
    es.indices.delete = AsyncMock(return_value={"acknowledged": True})
    es.index = AsyncMock(return_value={"_id": "test_id", "result": "created"})
    es.search = AsyncMock(return_value={"hits": {"hits": [canned_hit]}})
    return es
@pytest.fixture
def client(mock_elasticsearch_client) -> ElasticsearchClient:
    """Provide an ElasticsearchClient backed by the synchronous mock.

    The embedding function is a Mock that always yields a 3-element vector.
    """
    embedder = Mock(return_value=[0.1, 0.2, 0.3])
    return ElasticsearchClient(
        client=mock_elasticsearch_client,
        embedding_function=embedder,
        vector_dimension=3,
        similarity="cosine",
    )
@pytest.fixture
def async_client(mock_async_elasticsearch_client) -> ElasticsearchClient:
    """Provide an ElasticsearchClient backed by the async mock.

    The embedding function is a Mock that always yields a 3-element vector.
    """
    embedder = Mock(return_value=[0.1, 0.2, 0.3])
    return ElasticsearchClient(
        client=mock_async_elasticsearch_client,
        embedding_function=embedder,
        vector_dimension=3,
        similarity="cosine",
    )
class TestElasticsearchClient:
"""Test suite for ElasticsearchClient."""
@patch("crewai.rag.elasticsearch.client._is_sync_client", return_value=True)
@patch("crewai.rag.elasticsearch.client._is_async_client", return_value=False)
def test_create_collection(
    self, mock_is_async, mock_is_sync, client, mock_elasticsearch_client
):
    """A missing index is created with a body that contains mappings."""
    indices = mock_elasticsearch_client.indices
    indices.exists.return_value = False

    client.create_collection(collection_name="test_index")

    indices.exists.assert_called_once_with(index="test_index")
    indices.create.assert_called_once()
    create_kwargs = indices.create.call_args.kwargs
    assert create_kwargs["index"] == "test_index"
    assert "mappings" in create_kwargs["body"]
@patch("crewai.rag.elasticsearch.client._is_sync_client", return_value=True)
@patch("crewai.rag.elasticsearch.client._is_async_client", return_value=False)
def test_create_collection_already_exists(
    self, mock_is_async, mock_is_sync, client, mock_elasticsearch_client
):
    """Creating an index that already exists raises ValueError."""
    mock_elasticsearch_client.indices.exists.return_value = True

    expected_message = "Index 'test_index' already exists"
    with pytest.raises(ValueError, match=expected_message):
        client.create_collection(collection_name="test_index")
@patch("crewai.rag.elasticsearch.client._is_sync_client", return_value=False)
@patch("crewai.rag.elasticsearch.client._is_async_client", return_value=True)
def test_create_collection_wrong_client_type(
    self, mock_is_async, mock_is_sync, mock_async_elasticsearch_client
):
    """The sync create_collection rejects a client detected as async."""
    async_backed = ElasticsearchClient(
        client=mock_async_elasticsearch_client,
        embedding_function=Mock(),
    )

    with pytest.raises(ClientMethodMismatchError):
        async_backed.create_collection(collection_name="test_index")
@pytest.mark.asyncio
@patch("crewai.rag.elasticsearch.client._is_sync_client", return_value=False)
@patch("crewai.rag.elasticsearch.client._is_async_client", return_value=True)
async def test_acreate_collection(
    self, mock_is_async, mock_is_sync, async_client, mock_async_elasticsearch_client
):
    """A missing index is created asynchronously with a mappings body."""
    indices = mock_async_elasticsearch_client.indices
    indices.exists.return_value = False

    await async_client.acreate_collection(collection_name="test_index")

    indices.exists.assert_called_once_with(index="test_index")
    indices.create.assert_called_once()
    create_kwargs = indices.create.call_args.kwargs
    assert create_kwargs["index"] == "test_index"
    assert "mappings" in create_kwargs["body"]
@pytest.mark.asyncio
@patch("crewai.rag.elasticsearch.client._is_sync_client", return_value=False)
@patch("crewai.rag.elasticsearch.client._is_async_client", return_value=True)
async def test_acreate_collection_already_exists(
    self, mock_is_async, mock_is_sync, async_client, mock_async_elasticsearch_client
):
    """Asynchronously creating an existing index raises ValueError."""
    mock_async_elasticsearch_client.indices.exists.return_value = True

    expected_message = "Index 'test_index' already exists"
    with pytest.raises(ValueError, match=expected_message):
        await async_client.acreate_collection(collection_name="test_index")
@patch("crewai.rag.elasticsearch.client._is_sync_client", return_value=True)
@patch("crewai.rag.elasticsearch.client._is_async_client", return_value=False)
def test_get_or_create_collection(
    self, mock_is_async, mock_is_sync, client, mock_elasticsearch_client
):
    """An existing index is fetched and returned as-is."""
    indices = mock_elasticsearch_client.indices
    indices.exists.return_value = True

    fetched = client.get_or_create_collection(collection_name="test_index")

    indices.exists.assert_called_once_with(index="test_index")
    indices.get.assert_called_once_with(index="test_index")
    assert fetched == {"test_index": {"mappings": {}}}
@patch("crewai.rag.elasticsearch.client._is_sync_client", return_value=True)
@patch("crewai.rag.elasticsearch.client._is_async_client", return_value=False)
def test_get_or_create_collection_creates_new(
    self, mock_is_async, mock_is_sync, client, mock_elasticsearch_client
):
    """A missing index is created and then fetched."""
    indices = mock_elasticsearch_client.indices
    indices.exists.return_value = False

    client.get_or_create_collection(collection_name="test_index")

    indices.exists.assert_called_once_with(index="test_index")
    indices.create.assert_called_once()
    indices.get.assert_called_once_with(index="test_index")
@pytest.mark.asyncio
@patch("crewai.rag.elasticsearch.client._is_sync_client", return_value=False)
@patch("crewai.rag.elasticsearch.client._is_async_client", return_value=True)
async def test_aget_or_create_collection(
    self, mock_is_async, mock_is_sync, async_client, mock_async_elasticsearch_client
):
    """An existing index is fetched asynchronously and returned as-is."""
    indices = mock_async_elasticsearch_client.indices
    indices.exists.return_value = True

    fetched = await async_client.aget_or_create_collection(
        collection_name="test_index"
    )

    indices.exists.assert_called_once_with(index="test_index")
    indices.get.assert_called_once_with(index="test_index")
    assert fetched == {"test_index": {"mappings": {}}}
@patch("crewai.rag.elasticsearch.client._is_sync_client", return_value=True)
@patch("crewai.rag.elasticsearch.client._is_async_client", return_value=False)
def test_add_documents(
    self, mock_is_async, mock_is_sync, client, mock_elasticsearch_client
):
    """A single document is indexed into the existing index with its content."""
    mock_elasticsearch_client.indices.exists.return_value = True
    record: BaseRecord = {"content": "test content", "metadata": {"key": "value"}}
    documents: list[BaseRecord] = [record]

    client.add_documents(collection_name="test_index", documents=documents)

    mock_elasticsearch_client.indices.exists.assert_called_once_with(
        index="test_index"
    )
    mock_elasticsearch_client.index.assert_called_once()
    index_kwargs = mock_elasticsearch_client.index.call_args.kwargs
    assert index_kwargs["index"] == "test_index"
    assert "body" in index_kwargs
    assert index_kwargs["body"]["content"] == "test content"
@patch("crewai.rag.elasticsearch.client._is_sync_client", return_value=True)
@patch("crewai.rag.elasticsearch.client._is_async_client", return_value=False)
def test_add_documents_empty_list_raises_error(
    self, mock_is_async, mock_is_sync, client
):
    """Passing no documents raises ValueError."""
    expected_message = "Documents list cannot be empty"
    with pytest.raises(ValueError, match=expected_message):
        client.add_documents(collection_name="test_index", documents=[])
@patch("crewai.rag.elasticsearch.client._is_sync_client", return_value=True)
@patch("crewai.rag.elasticsearch.client._is_async_client", return_value=False)
def test_add_documents_index_not_exists(
    self, mock_is_async, mock_is_sync, client, mock_elasticsearch_client
):
    """Adding documents to a missing index raises ValueError."""
    mock_elasticsearch_client.indices.exists.return_value = False
    documents: list[BaseRecord] = [{"content": "test content"}]

    expected_message = "Index 'test_index' does not exist"
    with pytest.raises(ValueError, match=expected_message):
        client.add_documents(collection_name="test_index", documents=documents)
@pytest.mark.asyncio
@patch("crewai.rag.elasticsearch.client._is_sync_client", return_value=False)
@patch("crewai.rag.elasticsearch.client._is_async_client", return_value=True)
async def test_aadd_documents(
    self, mock_is_async, mock_is_sync, async_client, mock_async_elasticsearch_client
):
    """A single document is indexed asynchronously into the existing index."""
    mock_async_elasticsearch_client.indices.exists.return_value = True
    record: BaseRecord = {"content": "test content", "metadata": {"key": "value"}}
    documents: list[BaseRecord] = [record]

    await async_client.aadd_documents(
        collection_name="test_index", documents=documents
    )

    mock_async_elasticsearch_client.indices.exists.assert_called_once_with(
        index="test_index"
    )
    mock_async_elasticsearch_client.index.assert_called_once()
    index_kwargs = mock_async_elasticsearch_client.index.call_args.kwargs
    assert index_kwargs["index"] == "test_index"
    assert "body" in index_kwargs
@patch("crewai.rag.elasticsearch.client._is_sync_client", return_value=True)
@patch("crewai.rag.elasticsearch.client._is_async_client", return_value=False)
def test_search(
    self, mock_is_async, mock_is_sync, client, mock_elasticsearch_client
):
    """Searching an existing index returns the canned hit as a result."""
    mock_elasticsearch_client.indices.exists.return_value = True

    hits = client.search(collection_name="test_index", query="test query", limit=5)

    mock_elasticsearch_client.indices.exists.assert_called_once_with(
        index="test_index"
    )
    mock_elasticsearch_client.search.assert_called_once()
    search_kwargs = mock_elasticsearch_client.search.call_args.kwargs
    assert search_kwargs["index"] == "test_index"
    assert "body" in search_kwargs

    assert len(hits) == 1
    first = hits[0]
    assert first["id"] == "doc1"
    assert first["content"] == "test content"
    assert first["score"] == 0.9
@patch("crewai.rag.elasticsearch.client._is_sync_client", return_value=True)
@patch("crewai.rag.elasticsearch.client._is_async_client", return_value=False)
def test_search_with_metadata_filter(
    self, mock_is_async, mock_is_sync, client, mock_elasticsearch_client
):
    """A metadata filter turns the search query into a bool query."""
    mock_elasticsearch_client.indices.exists.return_value = True

    client.search(
        collection_name="test_index",
        query="test query",
        metadata_filter={"key": "value"},
    )

    mock_elasticsearch_client.search.assert_called_once()
    sent_body = mock_elasticsearch_client.search.call_args.kwargs["body"]
    assert "bool" in sent_body["query"]
@patch("crewai.rag.elasticsearch.client._is_sync_client", return_value=True)
@patch("crewai.rag.elasticsearch.client._is_async_client", return_value=False)
def test_search_index_not_exists(
    self, mock_is_async, mock_is_sync, client, mock_elasticsearch_client
):
    """Searching a missing index raises ValueError."""
    mock_elasticsearch_client.indices.exists.return_value = False

    expected_message = "Index 'test_index' does not exist"
    with pytest.raises(ValueError, match=expected_message):
        client.search(collection_name="test_index", query="test query")
@pytest.mark.asyncio
@patch("crewai.rag.elasticsearch.client._is_sync_client", return_value=False)
@patch("crewai.rag.elasticsearch.client._is_async_client", return_value=True)
async def test_asearch(self, mock_is_async, mock_is_sync, async_client, mock_async_elasticsearch_client):
    """Check that the awaited search hits the expected index and returns one result."""
    es = mock_async_elasticsearch_client
    es.indices.exists.return_value = True

    hits = await async_client.asearch(
        collection_name="test_index", query="test query", limit=5
    )

    es.indices.exists.assert_called_once_with(index="test_index")
    es.search.assert_called_once()

    # Expected values presumably mirror the fixture's canned search response.
    assert len(hits) == 1
    first_hit = hits[0]
    assert first_hit["id"] == "doc1"
    assert first_hit["content"] == "test content"
@patch("crewai.rag.elasticsearch.client._is_sync_client", return_value=True)
@patch("crewai.rag.elasticsearch.client._is_async_client", return_value=False)
def test_delete_collection(self, mock_is_async, mock_is_sync, client, mock_elasticsearch_client):
    """Check that deleting an existing collection deletes the matching index."""
    index_api = mock_elasticsearch_client.indices
    index_api.exists.return_value = True

    client.delete_collection(collection_name="test_index")

    # One existence check and one delete, both addressed to the same index.
    index_api.exists.assert_called_once_with(index="test_index")
    index_api.delete.assert_called_once_with(index="test_index")
@patch("crewai.rag.elasticsearch.client._is_sync_client", return_value=True)
@patch("crewai.rag.elasticsearch.client._is_async_client", return_value=False)
def test_delete_collection_not_exists(self, mock_is_async, mock_is_sync, client, mock_elasticsearch_client):
    """Check that deleting an index reported as absent raises ``ValueError``."""
    index_api = mock_elasticsearch_client.indices
    index_api.exists.return_value = False

    expected_message = "Index 'test_index' does not exist"
    with pytest.raises(ValueError, match=expected_message):
        client.delete_collection(collection_name="test_index")
@pytest.mark.asyncio
@patch("crewai.rag.elasticsearch.client._is_sync_client", return_value=False)
@patch("crewai.rag.elasticsearch.client._is_async_client", return_value=True)
async def test_adelete_collection(self, mock_is_async, mock_is_sync, async_client, mock_async_elasticsearch_client):
    """Check that the awaited delete removes the matching index."""
    index_api = mock_async_elasticsearch_client.indices
    index_api.exists.return_value = True

    await async_client.adelete_collection(collection_name="test_index")

    # One existence check and one delete, both addressed to the same index.
    index_api.exists.assert_called_once_with(index="test_index")
    index_api.delete.assert_called_once_with(index="test_index")
@patch("crewai.rag.elasticsearch.client._is_sync_client", return_value=True)
@patch("crewai.rag.elasticsearch.client._is_async_client", return_value=False)
def test_reset(self, mock_is_async, mock_is_sync, client, mock_elasticsearch_client):
    """Check that reset deletes every listed index except the dot-prefixed one."""
    index_api = mock_elasticsearch_client.indices
    listed_indices = {
        "test_index": {},
        ".system_index": {},
        "another_index": {},
    }
    index_api.get.return_value = listed_indices

    client.reset()

    index_api.get.assert_called_once_with(index="*")
    assert index_api.delete.call_count == 2

    # Collect the index name passed (by keyword) to each delete call.
    deleted_names = [
        recorded.kwargs["index"] for recorded in index_api.delete.call_args_list
    ]
    assert "test_index" in deleted_names
    assert "another_index" in deleted_names
    assert ".system_index" not in deleted_names
@pytest.mark.asyncio
@patch("crewai.rag.elasticsearch.client._is_sync_client", return_value=False)
@patch("crewai.rag.elasticsearch.client._is_async_client", return_value=True)
async def test_areset(self, mock_is_async, mock_is_sync, async_client, mock_async_elasticsearch_client):
    """Check that the awaited reset lists all indices and issues two deletes."""
    index_api = mock_async_elasticsearch_client.indices
    listed_indices = {
        "test_index": {},
        ".system_index": {},
        "another_index": {},
    }
    index_api.get.return_value = listed_indices

    await async_client.areset()

    index_api.get.assert_called_once_with(index="*")
    # Three indices are listed; only two deletions are expected.
    assert index_api.delete.call_count == 2

View File

@@ -1,49 +0,0 @@
"""Tests for Elasticsearch configuration."""
from crewai.rag.elasticsearch.config import ElasticsearchConfig
def test_elasticsearch_config_defaults():
    """Check the values a default-constructed ElasticsearchConfig exposes."""
    cfg = ElasticsearchConfig()

    # Scalar settings.
    assert cfg.provider == "elasticsearch"
    assert cfg.vector_dimension == 384
    assert cfg.similarity == "cosine"
    assert cfg.embedding_function is not None

    # Connection options.
    assert cfg.options["hosts"] == ["http://localhost:9200"]
    assert cfg.options["use_ssl"] is False
def test_elasticsearch_config_custom_options():
    """Check that explicitly supplied options and settings are kept on the config."""
    overrides = {
        "hosts": ["https://elastic.example.com:9200"],
        "username": "user",
        "password": "pass",
        "use_ssl": True,
    }

    cfg = ElasticsearchConfig(
        options=overrides,
        vector_dimension=768,
        similarity="dot_product",
    )

    # The provider is not overridden and keeps its expected value.
    assert cfg.provider == "elasticsearch"
    assert cfg.vector_dimension == 768
    assert cfg.similarity == "dot_product"
    assert cfg.options["hosts"] == ["https://elastic.example.com:9200"]
    assert cfg.options["username"] == "user"
    assert cfg.options["use_ssl"] is True
def test_elasticsearch_config_embedding_function():
    """Check that the default embedding function returns a float list of the configured size."""
    cfg = ElasticsearchConfig()

    vector = cfg.embedding_function("test text")

    assert isinstance(vector, list)
    assert len(vector) == cfg.vector_dimension
    assert all(isinstance(component, float) for component in vector)

View File

@@ -1,40 +0,0 @@
"""Tests for Elasticsearch factory."""
from unittest.mock import Mock, patch
import pytest
from crewai.rag.elasticsearch.config import ElasticsearchConfig
def test_create_client():
    """Check that create_client builds a client from a stubbed ``elasticsearch`` module."""
    config = ElasticsearchConfig()

    # Stand-in for the real package: its ``Elasticsearch`` class returns a known mock.
    fake_es_module = Mock()
    fake_es_instance = Mock()
    fake_es_module.Elasticsearch.return_value = fake_es_instance

    with patch.dict("sys.modules", {"elasticsearch": fake_es_module}):
        from crewai.rag.elasticsearch.factory import create_client

        client = create_client(config)

        fake_es_module.Elasticsearch.assert_called_once_with(**config.options)
        assert client.client == fake_es_instance
        assert client.embedding_function == config.embedding_function
        assert client.vector_dimension == config.vector_dimension
        assert client.similarity == config.similarity
def test_create_client_missing_elasticsearch():
    """Check that create_client raises ImportError when ``elasticsearch`` cannot be imported.

    The missing package is simulated by mapping ``elasticsearch`` to ``None``
    in ``sys.modules``, which makes any ``import elasticsearch`` raise
    ``ImportError``. Merely deleting the entry is not enough: if the package
    is installed, the next import would load it again and the test would not
    exercise the error path. ``patch.dict`` restores ``sys.modules`` on exit.
    """
    config = ElasticsearchConfig()

    with patch.dict("sys.modules", {"elasticsearch": None}):
        from crewai.rag.elasticsearch.factory import create_client

        with pytest.raises(ImportError, match="elasticsearch package is required"):
            create_client(config)

View File

@@ -624,12 +624,12 @@ def test_crew_with_delegating_agents_should_not_override_task_tools(ceo, writer)
_, kwargs = mock_execute_sync.call_args
tools = kwargs["tools"]
assert any(isinstance(tool, TestTool) for tool in tools), (
"TestTool should be present"
)
assert any("delegate" in tool.name.lower() for tool in tools), (
"Delegation tool should be present"
)
assert any(
isinstance(tool, TestTool) for tool in tools
), "TestTool should be present"
assert any(
"delegate" in tool.name.lower() for tool in tools
), "Delegation tool should be present"
@pytest.mark.vcr(filter_headers=["authorization"])
@@ -688,12 +688,12 @@ def test_crew_with_delegating_agents_should_not_override_agent_tools(ceo, writer
_, kwargs = mock_execute_sync.call_args
tools = kwargs["tools"]
assert any(isinstance(tool, TestTool) for tool in new_ceo.tools), (
"TestTool should be present"
)
assert any("delegate" in tool.name.lower() for tool in tools), (
"Delegation tool should be present"
)
assert any(
isinstance(tool, TestTool) for tool in new_ceo.tools
), "TestTool should be present"
assert any(
"delegate" in tool.name.lower() for tool in tools
), "Delegation tool should be present"
@pytest.mark.vcr(filter_headers=["authorization"])
@@ -817,17 +817,17 @@ def test_task_tools_override_agent_tools_with_allow_delegation(researcher, write
used_tools = kwargs["tools"]
# Confirm AnotherTestTool is present but TestTool is not
assert any(isinstance(tool, AnotherTestTool) for tool in used_tools), (
"AnotherTestTool should be present"
)
assert not any(isinstance(tool, TestTool) for tool in used_tools), (
"TestTool should not be present among used tools"
)
assert any(
isinstance(tool, AnotherTestTool) for tool in used_tools
), "AnotherTestTool should be present"
assert not any(
isinstance(tool, TestTool) for tool in used_tools
), "TestTool should not be present among used tools"
# Confirm delegation tool(s) are present
assert any("delegate" in tool.name.lower() for tool in used_tools), (
"Delegation tool should be present"
)
assert any(
"delegate" in tool.name.lower() for tool in used_tools
), "Delegation tool should be present"
# Finally, make sure the agent's original tools remain unchanged
assert len(researcher_with_delegation.tools) == 1
@@ -931,9 +931,9 @@ def test_cache_hitting_between_agents(researcher, writer, ceo):
tool="multiplier", input={"first_number": 2, "second_number": 6}
)
assert cache_calls[0] == expected_call, f"First call mismatch: {cache_calls[0]}"
assert cache_calls[1] == expected_call, (
f"Second call mismatch: {cache_calls[1]}"
)
assert (
cache_calls[1] == expected_call
), f"Second call mismatch: {cache_calls[1]}"
@pytest.mark.vcr(filter_headers=["authorization"])
@@ -1676,9 +1676,9 @@ def test_code_execution_flag_adds_code_tool_upon_kickoff():
# Verify that exactly one tool was used and it was a CodeInterpreterTool
assert len(used_tools) == 1, "Should have exactly one tool"
assert isinstance(used_tools[0], CodeInterpreterTool), (
"Tool should be CodeInterpreterTool"
)
assert isinstance(
used_tools[0], CodeInterpreterTool
), "Tool should be CodeInterpreterTool"
@pytest.mark.vcr(filter_headers=["authorization"])
@@ -2537,8 +2537,8 @@ def test_memory_events_are_emitted():
crew.kickoff()
assert len(events["MemorySaveStartedEvent"]) == 6
assert len(events["MemorySaveCompletedEvent"]) == 6
assert len(events["MemorySaveStartedEvent"]) == 3
assert len(events["MemorySaveCompletedEvent"]) == 3
assert len(events["MemorySaveFailedEvent"]) == 0
assert len(events["MemoryQueryStartedEvent"]) == 3
assert len(events["MemoryQueryCompletedEvent"]) == 3
@@ -3817,9 +3817,9 @@ def test_fetch_inputs():
expected_placeholders = {"role_detail", "topic", "field"}
actual_placeholders = crew.fetch_inputs()
assert actual_placeholders == expected_placeholders, (
f"Expected {expected_placeholders}, but got {actual_placeholders}"
)
assert (
actual_placeholders == expected_placeholders
), f"Expected {expected_placeholders}, but got {actual_placeholders}"
def test_task_tools_preserve_code_execution_tools():
@@ -3894,20 +3894,20 @@ def test_task_tools_preserve_code_execution_tools():
used_tools = kwargs["tools"]
# Verify all expected tools are present
assert any(isinstance(tool, TestTool) for tool in used_tools), (
"Task's TestTool should be present"
)
assert any(isinstance(tool, CodeInterpreterTool) for tool in used_tools), (
"CodeInterpreterTool should be present"
)
assert any("delegate" in tool.name.lower() for tool in used_tools), (
"Delegation tool should be present"
)
assert any(
isinstance(tool, TestTool) for tool in used_tools
), "Task's TestTool should be present"
assert any(
isinstance(tool, CodeInterpreterTool) for tool in used_tools
), "CodeInterpreterTool should be present"
assert any(
"delegate" in tool.name.lower() for tool in used_tools
), "Delegation tool should be present"
# Verify the total number of tools (TestTool + CodeInterpreter + 2 delegation tools)
assert len(used_tools) == 4, (
"Should have TestTool, CodeInterpreter, and 2 delegation tools"
)
assert (
len(used_tools) == 4
), "Should have TestTool, CodeInterpreter, and 2 delegation tools"
@pytest.mark.vcr(filter_headers=["authorization"])
@@ -3951,9 +3951,9 @@ def test_multimodal_flag_adds_multimodal_tools():
used_tools = kwargs["tools"]
# Check that the multimodal tool was added
assert any(isinstance(tool, AddImageTool) for tool in used_tools), (
"AddImageTool should be present when agent is multimodal"
)
assert any(
isinstance(tool, AddImageTool) for tool in used_tools
), "AddImageTool should be present when agent is multimodal"
# Verify we have exactly one tool (just the AddImageTool)
assert len(used_tools) == 1, "Should only have the AddImageTool"
@@ -4217,9 +4217,9 @@ def test_crew_guardrail_feedback_in_context():
assert len(execution_contexts) > 1, "Task should have been executed multiple times"
# Verify that the second execution included the guardrail feedback
assert "Output must contain the keyword 'IMPORTANT'" in execution_contexts[1], (
"Guardrail feedback should be included in retry context"
)
assert (
"Output must contain the keyword 'IMPORTANT'" in execution_contexts[1]
), "Guardrail feedback should be included in retry context"
# Verify final output meets guardrail requirements
assert "IMPORTANT" in result.raw, "Final output should contain required keyword"
@@ -4435,46 +4435,46 @@ def test_crew_copy_with_memory():
try:
crew_copy = crew.copy()
assert hasattr(crew_copy, "_short_term_memory"), (
"Copied crew should have _short_term_memory"
)
assert crew_copy._short_term_memory is not None, (
"Copied _short_term_memory should not be None"
)
assert id(crew_copy._short_term_memory) != original_short_term_id, (
"Copied _short_term_memory should be a new object"
)
assert hasattr(
crew_copy, "_short_term_memory"
), "Copied crew should have _short_term_memory"
assert (
crew_copy._short_term_memory is not None
), "Copied _short_term_memory should not be None"
assert (
id(crew_copy._short_term_memory) != original_short_term_id
), "Copied _short_term_memory should be a new object"
assert hasattr(crew_copy, "_long_term_memory"), (
"Copied crew should have _long_term_memory"
)
assert crew_copy._long_term_memory is not None, (
"Copied _long_term_memory should not be None"
)
assert id(crew_copy._long_term_memory) != original_long_term_id, (
"Copied _long_term_memory should be a new object"
)
assert hasattr(
crew_copy, "_long_term_memory"
), "Copied crew should have _long_term_memory"
assert (
crew_copy._long_term_memory is not None
), "Copied _long_term_memory should not be None"
assert (
id(crew_copy._long_term_memory) != original_long_term_id
), "Copied _long_term_memory should be a new object"
assert hasattr(crew_copy, "_entity_memory"), (
"Copied crew should have _entity_memory"
)
assert crew_copy._entity_memory is not None, (
"Copied _entity_memory should not be None"
)
assert id(crew_copy._entity_memory) != original_entity_id, (
"Copied _entity_memory should be a new object"
)
assert hasattr(
crew_copy, "_entity_memory"
), "Copied crew should have _entity_memory"
assert (
crew_copy._entity_memory is not None
), "Copied _entity_memory should not be None"
assert (
id(crew_copy._entity_memory) != original_entity_id
), "Copied _entity_memory should be a new object"
if original_external_id:
assert hasattr(crew_copy, "_external_memory"), (
"Copied crew should have _external_memory"
)
assert crew_copy._external_memory is not None, (
"Copied _external_memory should not be None"
)
assert id(crew_copy._external_memory) != original_external_id, (
"Copied _external_memory should be a new object"
)
assert hasattr(
crew_copy, "_external_memory"
), "Copied crew should have _external_memory"
assert (
crew_copy._external_memory is not None
), "Copied _external_memory should not be None"
assert (
id(crew_copy._external_memory) != original_external_id
), "Copied _external_memory should be a new object"
else:
assert (
not hasattr(crew_copy, "_external_memory")

View File

@@ -0,0 +1,91 @@
"""Tests for dependency compatibility, specifically for issue #3413.
This module tests that CrewAI can be installed alongside Google Cloud SDKs
without protobuf dependency conflicts.
"""
import subprocess
import sys
import tempfile
import pytest
class TestDependencyCompatibility:
    """Test dependency compatibility with Google Cloud SDKs."""

    def test_opentelemetry_protobuf_compatibility(self):
        """Check that the opentelemetry SDK imports and can record a span."""
        # Only the import is guarded, so a failure inside the span logic is
        # reported as itself rather than as an import problem.
        try:
            from opentelemetry.sdk.trace import TracerProvider
        except ImportError as e:
            pytest.fail(f"Failed to import opentelemetry modules: {e}")

        tracer = TracerProvider().get_tracer("test")
        with tracer.start_as_current_span("test_span") as span:
            span.set_attribute("test", "value")
            assert span is not None

    def test_google_cloud_sdk_compatibility_simulation(self):
        """Check that the installed protobuf major version is below 5.

        Skipped when protobuf is not installed.
        """
        try:
            import google.protobuf
        except ImportError:
            pytest.skip("protobuf not installed, skipping compatibility test")

        major_version = int(google.protobuf.__version__.split(".")[0])
        assert major_version < 5, (
            f"protobuf version {google.protobuf.__version__} should be <5.0 for Google Cloud SDK compatibility"
        )

    def test_crewai_telemetry_functionality(self):
        """Check that CrewAI telemetry imports, instantiates, and the crew context is readable."""
        try:
            from crewai.telemetry.telemetry import Telemetry
            from crewai.utilities.crew.crew_context import get_crew_context
        except ImportError as e:
            pytest.fail(f"Failed to import CrewAI telemetry modules: {e}")

        telemetry = Telemetry()
        assert telemetry is not None
        # Called only to confirm it does not raise; the return value is unused.
        get_crew_context()

    def test_dry_run_installation_compatibility(self):
        """Check that pip can resolve opentelemetry together with google-cloud-storage.

        Runs ``pip install --dry-run`` in a temporary working directory and
        asserts that it succeeds and that its output mentions protobuf.
        """
        # No ``--quiet``: quiet mode suppresses the informational lines in
        # which pip names the resolved packages, and the stdout assertion
        # below depends on them.
        # NOTE(review): the upper bound here is <1.30.0 - confirm it matches
        # the range pinned in the project's dependency metadata.
        command = [
            sys.executable, "-m", "pip", "install", "--dry-run",
            "opentelemetry-api>=1.27.0,<1.30.0",
            "opentelemetry-sdk>=1.27.0,<1.30.0",
            "opentelemetry-exporter-otlp-proto-http>=1.27.0,<1.30.0",
            "google-cloud-storage",
        ]

        with tempfile.TemporaryDirectory() as temp_dir:
            # Only failures to launch pip are converted here; assertion
            # failures below keep their own messages.
            try:
                result = subprocess.run(
                    command, capture_output=True, text=True, cwd=temp_dir, check=False
                )
            except (OSError, subprocess.SubprocessError) as e:
                pytest.fail(f"Dry run installation test failed: {e}")

        assert result.returncode == 0, f"Dry run installation failed: {result.stderr}"
        assert "protobuf" in result.stdout.lower(), "protobuf should be in installation plan"

    def test_protobuf_version_constraint_resolution(self):
        """Check that the installed protobuf version is >=3.19 and <5.0.

        Skipped when protobuf is not installed.
        """
        try:
            import google.protobuf
        except ImportError:
            pytest.skip("protobuf not installed, skipping version constraint test")

        version = google.protobuf.__version__
        # Parse only major and minor: later components may carry suffixes
        # (e.g. "0rc1") that int() cannot convert.
        major, minor = (int(part) for part in version.split(".")[:2])

        assert major >= 3, f"protobuf version {version} should be >=3.19"
        if major == 3:
            assert minor >= 19, f"protobuf version {version} should be >=3.19"
        assert major < 5, f"protobuf version {version} should be <5.0 for Google Cloud SDK compatibility"

6928
uv.lock generated

File diff suppressed because it is too large Load Diff