Compare commits


4 Commits

Author SHA1 Message Date
Greyson LaLonde
d4c27d22cf feat: implement file-based locking decorator for concurrent RAG client access 2025-08-27 12:03:12 -04:00
Greyson LaLonde
109de91d08 fix: batch entity memory items to reduce redundant operations (#3409)
* fix: batch save entity memory items to reduce redundant operations

* test: update memory event count after entity batch save implementation
2025-08-27 10:47:20 -04:00
Erika Shorten
92b70e652d Add hybrid search alpha parameter to the docs (#3397)
Co-authored-by: Tony Kipkemboi <iamtonykipkemboi@gmail.com>
2025-08-27 10:36:39 -04:00
Heitor Carvalho
fc3f2c49d2 chore: remove auth0 and the need of typing the email on 'crewai login' (#3408)
* Remove the need of typing the email on 'crewai login'

* Remove auth0 constants, update tests
2025-08-27 10:12:57 -04:00
27 changed files with 822 additions and 1799 deletions

View File

@@ -1,13 +1,13 @@
---
title: Weaviate Vector Search
description: The `WeaviateVectorSearchTool` is designed to search a Weaviate vector database for semantically similar documents.
description: The `WeaviateVectorSearchTool` is designed to search a Weaviate vector database for semantically similar documents using hybrid search.
icon: network-wired
---
## Overview
The `WeaviateVectorSearchTool` is specifically crafted for conducting semantic searches within documents stored in a Weaviate vector database. This tool allows you to find semantically similar documents to a given query, leveraging the power of vector embeddings for more accurate and contextually relevant search results.
The `WeaviateVectorSearchTool` is specifically crafted for conducting semantic searches within documents stored in a Weaviate vector database. This tool allows you to find semantically similar documents to a given query, leveraging the power of vector and keyword search for more accurate and contextually relevant search results.
[Weaviate](https://weaviate.io/) is a vector database that stores and queries vector embeddings, enabling semantic search capabilities.
@@ -39,6 +39,7 @@ from crewai_tools import WeaviateVectorSearchTool
tool = WeaviateVectorSearchTool(
collection_name='example_collections',
limit=3,
alpha=0.75,
weaviate_cluster_url="https://your-weaviate-cluster-url.com",
weaviate_api_key="your-weaviate-api-key",
)
@@ -63,6 +64,7 @@ The `WeaviateVectorSearchTool` accepts the following parameters:
- **weaviate_cluster_url**: Required. The URL of the Weaviate cluster.
- **weaviate_api_key**: Required. The API key for the Weaviate cluster.
- **limit**: Optional. The number of results to return. Default is `3`.
- **alpha**: Optional. Controls the weighting between vector and keyword (BM25) search: `alpha=0` uses keyword (BM25) search only, `alpha=1` uses vector search only. Default is `0.75` (see the sketch after this parameter list).
- **vectorizer**: Optional. The vectorizer to use. If not provided, it will use `text2vec_openai` with the `nomic-embed-text` model.
- **generative_model**: Optional. The generative model to use. If not provided, it will use OpenAI's `gpt-4o`.
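A minimal sketch of the two `alpha` extremes, reusing the placeholder collection name, cluster URL, and API key from the examples above:

```python
from crewai_tools import WeaviateVectorSearchTool

# alpha=0: keyword (BM25) search only; the vector component is ignored.
keyword_only_tool = WeaviateVectorSearchTool(
    collection_name='example_collections',
    limit=3,
    alpha=0.0,
    weaviate_cluster_url="https://your-weaviate-cluster-url.com",
    weaviate_api_key="your-weaviate-api-key",
)

# alpha=1: vector search only; the BM25 keyword component is ignored.
vector_only_tool = WeaviateVectorSearchTool(
    collection_name='example_collections',
    limit=3,
    alpha=1.0,
    weaviate_cluster_url="https://your-weaviate-cluster-url.com",
    weaviate_api_key="your-weaviate-api-key",
)
```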
@@ -78,6 +80,7 @@ from weaviate.classes.config import Configure
tool = WeaviateVectorSearchTool(
collection_name='example_collections',
limit=3,
alpha=0.75,
vectorizer=Configure.Vectorizer.text2vec_openai(model="nomic-embed-text"),
generative_model=Configure.Generative.openai(model="gpt-4o-mini"),
weaviate_cluster_url="https://your-weaviate-cluster-url.com",
@@ -128,6 +131,7 @@ with test_docs.batch.dynamic() as batch:
tool = WeaviateVectorSearchTool(
collection_name='example_collections',
limit=3,
alpha=0.75,
weaviate_cluster_url="https://your-weaviate-cluster-url.com",
weaviate_api_key="your-weaviate-api-key",
)
@@ -145,6 +149,7 @@ from crewai_tools import WeaviateVectorSearchTool
weaviate_tool = WeaviateVectorSearchTool(
collection_name='example_collections',
limit=3,
alpha=0.75,
weaviate_cluster_url="https://your-weaviate-cluster-url.com",
weaviate_api_key="your-weaviate-api-key",
)

View File

@@ -98,8 +98,8 @@ class CrewAgentExecutorMixin:
)
self.crew._long_term_memory.save(long_term_memory)
for entity in evaluation.entities:
entity_memory = EntityMemoryItem(
entity_memories = [
EntityMemoryItem(
name=entity.name,
type=entity.type,
description=entity.description,
@@ -107,7 +107,10 @@ class CrewAgentExecutorMixin:
[f"- {r}" for r in entity.relationships]
),
)
self.crew._entity_memory.save(entity_memory)
for entity in evaluation.entities
]
if entity_memories:
self.crew._entity_memory.save(entity_memories)
except AttributeError as e:
print(f"Missing attributes for long term memory: {e}")
pass

View File

@@ -1,6 +1 @@
ALGORITHMS = ["RS256"]
#TODO: The AUTH0 constants should be removed after WorkOS migration is completed
AUTH0_DOMAIN = "crewai.us.auth0.com"
AUTH0_CLIENT_ID = "DEVC5Fw6NlRoSzmDCcOhVq85EfLBjKa8"
AUTH0_AUDIENCE = "https://crewai.us.auth0.com/api/v2/"

View File

@@ -9,14 +9,7 @@ from pydantic import BaseModel, Field
from .utils import validate_jwt_token
from crewai.cli.shared.token_manager import TokenManager
from urllib.parse import quote
from crewai.cli.plus_api import PlusAPI
from crewai.cli.config import Settings
from crewai.cli.authentication.constants import (
AUTH0_AUDIENCE,
AUTH0_CLIENT_ID,
AUTH0_DOMAIN,
)
console = Console()
@@ -72,18 +65,6 @@ class AuthenticationCommand:
"""Sign up to CrewAI+"""
console.print("Signing in to CrewAI Enterprise...\n", style="bold blue")
# TODO: WORKOS - Next line and conditional are temporary until migration to WorkOS is complete.
user_provider = self._determine_user_provider()
if user_provider == "auth0":
settings = Oauth2Settings(
provider="auth0",
client_id=AUTH0_CLIENT_ID,
domain=AUTH0_DOMAIN,
audience=AUTH0_AUDIENCE,
)
self.oauth2_provider = ProviderFactory.from_settings(settings)
# End of temporary code.
device_code_data = self._get_device_code()
self._display_auth_instructions(device_code_data)
@@ -206,30 +187,3 @@ class AuthenticationCommand:
"\nRun [bold]crewai login[/bold] to try logging in again.\n",
style="yellow",
)
# TODO: WORKOS - This method is temporary until migration to WorkOS is complete.
def _determine_user_provider(self) -> str:
"""Determine which provider to use for authentication."""
console.print(
"Enter your CrewAI Enterprise account email: ", style="bold blue", end=""
)
email = input()
email_encoded = quote(email)
# It's not correct to call this method directly, but it's temporary until migration is complete.
response = PlusAPI("")._make_request(
"GET", f"/crewai_plus/api/v1/me/provider?email={email_encoded}"
)
if response.status_code == 200:
if response.json().get("provider") == "auth0":
return "auth0"
else:
return "workos"
else:
console.print(
"Error: Failed to authenticate with crewai enterprise. Ensure that you are using the latest crewai version and please try again. If the problem persists, contact support@crewai.com.",
style="red",
)
raise SystemExit

View File

@@ -1,4 +1,4 @@
from typing import Optional
from typing import Any
import time
from pydantic import PrivateAttr
@@ -24,7 +24,7 @@ class EntityMemory(Memory):
Inherits from the Memory class.
"""
_memory_provider: Optional[str] = PrivateAttr()
_memory_provider: str | None = PrivateAttr()
def __init__(self, crew=None, embedder_config=None, storage=None, path=None):
memory_provider = embedder_config.get("provider") if embedder_config else None
@@ -53,12 +53,33 @@ class EntityMemory(Memory):
super().__init__(storage=storage)
self._memory_provider = memory_provider
def save(self, item: EntityMemoryItem) -> None: # type: ignore # BUG?: Signature of "save" incompatible with supertype "Memory"
"""Saves an entity item into the SQLite storage."""
def save(
self,
value: EntityMemoryItem | list[EntityMemoryItem],
metadata: dict[str, Any] | None = None,
) -> None:
"""Saves one or more entity items into the SQLite storage.
Args:
value: Single EntityMemoryItem or list of EntityMemoryItems to save.
metadata: Optional metadata dict (included for supertype compatibility but not used).
Notes:
The metadata parameter is included to satisfy the supertype signature but is not
used - entity metadata is extracted from the EntityMemoryItem objects themselves.
"""
if not value:
return
items = value if isinstance(value, list) else [value]
is_batch = len(items) > 1
metadata = {"entity_count": len(items)} if is_batch else items[0].metadata
crewai_event_bus.emit(
self,
event=MemorySaveStartedEvent(
metadata=item.metadata,
metadata=metadata,
source_type="entity_memory",
from_agent=self.agent,
from_task=self.task,
@@ -66,36 +87,61 @@ class EntityMemory(Memory):
)
start_time = time.time()
saved_count = 0
errors = []
try:
if self._memory_provider == "mem0":
data = f"""
Remember details about the following entity:
Name: {item.name}
Type: {item.type}
Entity Description: {item.description}
"""
for item in items:
try:
if self._memory_provider == "mem0":
data = f"""
Remember details about the following entity:
Name: {item.name}
Type: {item.type}
Entity Description: {item.description}
"""
else:
data = f"{item.name}({item.type}): {item.description}"
super().save(data, item.metadata)
saved_count += 1
except Exception as e:
errors.append(f"{item.name}: {str(e)}")
if is_batch:
emit_value = f"Saved {saved_count} entities"
metadata = {"entity_count": saved_count, "errors": errors}
else:
data = f"{item.name}({item.type}): {item.description}"
emit_value = f"{items[0].name}({items[0].type}): {items[0].description}"
metadata = items[0].metadata
super().save(data, item.metadata)
# Emit memory save completed event
crewai_event_bus.emit(
self,
event=MemorySaveCompletedEvent(
value=data,
metadata=item.metadata,
value=emit_value,
metadata=metadata,
save_time_ms=(time.time() - start_time) * 1000,
source_type="entity_memory",
from_agent=self.agent,
from_task=self.task,
),
)
if errors:
raise Exception(
f"Partial save: {len(errors)} failed out of {len(items)}"
)
except Exception as e:
fail_metadata = (
{"entity_count": len(items), "saved": saved_count}
if is_batch
else items[0].metadata
)
crewai_event_bus.emit(
self,
event=MemorySaveFailedEvent(
metadata=item.metadata,
metadata=fail_metadata,
error=str(e),
source_type="entity_memory",
from_agent=self.agent,

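A minimal usage sketch of the batched save path above. The `EntityMemory` instance (`entity_memory`), the import path, and the item values are assumptions for illustration; the constructor fields match those used in the executor mixin diff:

```python
from crewai.memory.entity.entity_memory_item import EntityMemoryItem  # assumed import path

# Placeholder items; relationships is a pre-joined bullet string, as in the mixin.
items = [
    EntityMemoryItem(
        name="Acme Corp",
        type="organization",
        description="Customer referenced in the task output",
        relationships="- supplies Widgets Inc",
    ),
    EntityMemoryItem(
        name="Widgets Inc",
        type="organization",
        description="Downstream partner of Acme Corp",
        relationships="- supplied by Acme Corp",
    ),
]

# One call persists the whole batch and emits a single started/completed event pair.
entity_memory.save(items)

# A single item still works and keeps that item's metadata on the emitted events.
entity_memory.save(items[0])
```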
View File

@@ -14,7 +14,7 @@ class _MissingProvider:
Raises RuntimeError when instantiated to indicate missing dependencies.
"""
provider: Literal["chromadb", "qdrant", "elasticsearch", "__missing__"] = field(
provider: Literal["chromadb", "qdrant", "__missing__"] = field(
default="__missing__"
)

View File

@@ -9,8 +9,6 @@ if TYPE_CHECKING:
from crewai.rag.chromadb.config import ChromaDBConfig
from crewai.rag.qdrant.client import QdrantClient
from crewai.rag.qdrant.config import QdrantConfig
from crewai.rag.elasticsearch.client import ElasticsearchClient
from crewai.rag.elasticsearch.config import ElasticsearchConfig
class ChromaFactoryModule(Protocol):
@@ -27,11 +25,3 @@ class QdrantFactoryModule(Protocol):
def create_client(self, config: QdrantConfig) -> QdrantClient:
"""Creates a Qdrant client from configuration."""
...
class ElasticsearchFactoryModule(Protocol):
"""Protocol for Elasticsearch factory module."""
def create_client(self, config: ElasticsearchConfig) -> ElasticsearchClient:
"""Creates an Elasticsearch client from configuration."""
...

View File

@@ -20,10 +20,3 @@ class MissingQdrantConfig(_MissingProvider):
"""Placeholder for missing Qdrant configuration."""
provider: Literal["qdrant"] = field(default="qdrant")
@pyd_dataclass(config=ConfigDict(extra="forbid"))
class MissingElasticsearchConfig(_MissingProvider):
"""Placeholder for missing Elasticsearch configuration."""
provider: Literal["elasticsearch"] = field(default="elasticsearch")

View File

@@ -3,6 +3,6 @@
from typing import Annotated, Literal
SupportedProvider = Annotated[
Literal["chromadb", "qdrant", "elasticsearch"],
Literal["chromadb", "qdrant"],
"Supported RAG provider types, add providers here as they become available",
]

View File

@@ -13,9 +13,6 @@ if TYPE_CHECKING:
from crewai.rag.qdrant.config import QdrantConfig as QdrantConfig_
QdrantConfig = QdrantConfig_
from crewai.rag.elasticsearch.config import ElasticsearchConfig as ElasticsearchConfig_
ElasticsearchConfig = ElasticsearchConfig_
else:
try:
from crewai.rag.chromadb.config import ChromaDBConfig
@@ -31,14 +28,7 @@ else:
MissingQdrantConfig as QdrantConfig,
)
try:
from crewai.rag.elasticsearch.config import ElasticsearchConfig
except ImportError:
from crewai.rag.config.optional_imports.providers import (
MissingElasticsearchConfig as ElasticsearchConfig,
)
SupportedProviderConfig: TypeAlias = ChromaDBConfig | QdrantConfig | ElasticsearchConfig
SupportedProviderConfig: TypeAlias = ChromaDBConfig | QdrantConfig
RagConfigType: TypeAlias = Annotated[
SupportedProviderConfig, Field(discriminator=DISCRIMINATOR)
]

View File

@@ -1 +0,0 @@
"""Elasticsearch RAG implementation."""

View File

@@ -1,502 +0,0 @@
"""Elasticsearch client implementation."""
from typing import Any, cast
from typing_extensions import Unpack
from crewai.rag.core.base_client import (
BaseClient,
BaseCollectionParams,
BaseCollectionAddParams,
BaseCollectionSearchParams,
)
from crewai.rag.core.exceptions import ClientMethodMismatchError
from crewai.rag.elasticsearch.types import (
AsyncEmbeddingFunction,
EmbeddingFunction,
ElasticsearchClientType,
ElasticsearchCollectionCreateParams,
)
from crewai.rag.elasticsearch.utils import (
_is_async_client,
_is_async_embedding_function,
_is_sync_client,
_prepare_document_for_elasticsearch,
_process_search_results,
_build_vector_search_query,
_get_index_mapping,
)
from crewai.rag.types import SearchResult
class ElasticsearchClient(BaseClient):
"""Elasticsearch implementation of the BaseClient protocol.
Provides vector database operations for Elasticsearch, supporting both
synchronous and asynchronous clients.
Attributes:
client: Elasticsearch client instance (Elasticsearch or AsyncElasticsearch).
embedding_function: Function to generate embeddings for documents.
vector_dimension: Dimension of the embedding vectors.
similarity: Similarity function to use for vector search.
"""
def __init__(
self,
client: ElasticsearchClientType,
embedding_function: EmbeddingFunction | AsyncEmbeddingFunction,
vector_dimension: int = 384,
similarity: str = "cosine",
) -> None:
"""Initialize ElasticsearchClient with client and embedding function.
Args:
client: Pre-configured Elasticsearch client instance.
embedding_function: Embedding function for text to vector conversion.
vector_dimension: Dimension of the embedding vectors.
similarity: Similarity function to use for vector search.
"""
self.client = client
self.embedding_function = embedding_function
self.vector_dimension = vector_dimension
self.similarity = similarity
def create_collection(self, **kwargs: Unpack[ElasticsearchCollectionCreateParams]) -> None:
"""Create a new index in Elasticsearch.
Keyword Args:
collection_name: Name of the index to create. Must be unique.
index_settings: Optional index settings.
vector_dimension: Optional vector dimension override.
similarity: Optional similarity function override.
Raises:
ValueError: If index with the same name already exists.
ConnectionError: If unable to connect to Elasticsearch server.
"""
if not _is_sync_client(self.client):
raise ClientMethodMismatchError(
method_name="create_collection",
expected_client="Elasticsearch",
alt_method="acreate_collection",
alt_client="AsyncElasticsearch",
)
collection_name = kwargs["collection_name"]
if self.client.indices.exists(index=collection_name):
raise ValueError(f"Index '{collection_name}' already exists")
vector_dim = kwargs.get("vector_dimension", self.vector_dimension)
similarity = kwargs.get("similarity", self.similarity)
mapping = _get_index_mapping(vector_dim, similarity)
index_settings = kwargs.get("index_settings", {})
if index_settings:
mapping["settings"] = index_settings
self.client.indices.create(index=collection_name, body=mapping)
async def acreate_collection(self, **kwargs: Unpack[ElasticsearchCollectionCreateParams]) -> None:
"""Create a new index in Elasticsearch asynchronously.
Keyword Args:
collection_name: Name of the index to create. Must be unique.
index_settings: Optional index settings.
vector_dimension: Optional vector dimension override.
similarity: Optional similarity function override.
Raises:
ValueError: If index with the same name already exists.
ConnectionError: If unable to connect to Elasticsearch server.
"""
if not _is_async_client(self.client):
raise ClientMethodMismatchError(
method_name="acreate_collection",
expected_client="AsyncElasticsearch",
alt_method="create_collection",
alt_client="Elasticsearch",
)
collection_name = kwargs["collection_name"]
if await self.client.indices.exists(index=collection_name):
raise ValueError(f"Index '{collection_name}' already exists")
vector_dim = kwargs.get("vector_dimension", self.vector_dimension)
similarity = kwargs.get("similarity", self.similarity)
mapping = _get_index_mapping(vector_dim, similarity)
index_settings = kwargs.get("index_settings", {})
if index_settings:
mapping["settings"] = index_settings
await self.client.indices.create(index=collection_name, body=mapping)
def get_or_create_collection(self, **kwargs: Unpack[ElasticsearchCollectionCreateParams]) -> Any:
"""Get an existing index or create it if it doesn't exist.
Keyword Args:
collection_name: Name of the index to get or create.
index_settings: Optional index settings.
vector_dimension: Optional vector dimension override.
similarity: Optional similarity function override.
Returns:
Index info dict with name and other metadata.
Raises:
ConnectionError: If unable to connect to Elasticsearch server.
"""
if not _is_sync_client(self.client):
raise ClientMethodMismatchError(
method_name="get_or_create_collection",
expected_client="Elasticsearch",
alt_method="aget_or_create_collection",
alt_client="AsyncElasticsearch",
)
collection_name = kwargs["collection_name"]
if self.client.indices.exists(index=collection_name):
return self.client.indices.get(index=collection_name)
vector_dim = kwargs.get("vector_dimension", self.vector_dimension)
similarity = kwargs.get("similarity", self.similarity)
mapping = _get_index_mapping(vector_dim, similarity)
index_settings = kwargs.get("index_settings", {})
if index_settings:
mapping["settings"] = index_settings
self.client.indices.create(index=collection_name, body=mapping)
return self.client.indices.get(index=collection_name)
async def aget_or_create_collection(self, **kwargs: Unpack[ElasticsearchCollectionCreateParams]) -> Any:
"""Get an existing index or create it if it doesn't exist asynchronously.
Keyword Args:
collection_name: Name of the index to get or create.
index_settings: Optional index settings.
vector_dimension: Optional vector dimension override.
similarity: Optional similarity function override.
Returns:
Index info dict with name and other metadata.
Raises:
ConnectionError: If unable to connect to Elasticsearch server.
"""
if not _is_async_client(self.client):
raise ClientMethodMismatchError(
method_name="aget_or_create_collection",
expected_client="AsyncElasticsearch",
alt_method="get_or_create_collection",
alt_client="Elasticsearch",
)
collection_name = kwargs["collection_name"]
if await self.client.indices.exists(index=collection_name):
return await self.client.indices.get(index=collection_name)
vector_dim = kwargs.get("vector_dimension", self.vector_dimension)
similarity = kwargs.get("similarity", self.similarity)
mapping = _get_index_mapping(vector_dim, similarity)
index_settings = kwargs.get("index_settings", {})
if index_settings:
mapping["settings"] = index_settings
await self.client.indices.create(index=collection_name, body=mapping)
return await self.client.indices.get(index=collection_name)
def add_documents(self, **kwargs: Unpack[BaseCollectionAddParams]) -> None:
"""Add documents with their embeddings to an index.
Keyword Args:
collection_name: The name of the index to add documents to.
documents: List of BaseRecord dicts containing document data.
Raises:
ValueError: If index doesn't exist or documents list is empty.
ConnectionError: If unable to connect to Elasticsearch server.
"""
if not _is_sync_client(self.client):
raise ClientMethodMismatchError(
method_name="add_documents",
expected_client="Elasticsearch",
alt_method="aadd_documents",
alt_client="AsyncElasticsearch",
)
collection_name = kwargs["collection_name"]
documents = kwargs["documents"]
if not documents:
raise ValueError("Documents list cannot be empty")
if not self.client.indices.exists(index=collection_name):
raise ValueError(f"Index '{collection_name}' does not exist")
for doc in documents:
if _is_async_embedding_function(self.embedding_function):
raise TypeError(
"Async embedding function cannot be used with sync add_documents. "
"Use aadd_documents instead."
)
sync_fn = cast(EmbeddingFunction, self.embedding_function)
embedding = sync_fn(doc["content"])
prepared_doc = _prepare_document_for_elasticsearch(doc, embedding)
self.client.index(
index=collection_name,
id=prepared_doc["id"],
body=prepared_doc["body"]
)
async def aadd_documents(self, **kwargs: Unpack[BaseCollectionAddParams]) -> None:
"""Add documents with their embeddings to an index asynchronously.
Keyword Args:
collection_name: The name of the index to add documents to.
documents: List of BaseRecord dicts containing document data.
Raises:
ValueError: If index doesn't exist or documents list is empty.
ConnectionError: If unable to connect to Elasticsearch server.
"""
if not _is_async_client(self.client):
raise ClientMethodMismatchError(
method_name="aadd_documents",
expected_client="AsyncElasticsearch",
alt_method="add_documents",
alt_client="Elasticsearch",
)
collection_name = kwargs["collection_name"]
documents = kwargs["documents"]
if not documents:
raise ValueError("Documents list cannot be empty")
if not await self.client.indices.exists(index=collection_name):
raise ValueError(f"Index '{collection_name}' does not exist")
for doc in documents:
if _is_async_embedding_function(self.embedding_function):
async_fn = cast(AsyncEmbeddingFunction, self.embedding_function)
embedding = await async_fn(doc["content"])
else:
sync_fn = cast(EmbeddingFunction, self.embedding_function)
embedding = sync_fn(doc["content"])
prepared_doc = _prepare_document_for_elasticsearch(doc, embedding)
await self.client.index(
index=collection_name,
id=prepared_doc["id"],
body=prepared_doc["body"]
)
def search(
self, **kwargs: Unpack[BaseCollectionSearchParams]
) -> list[SearchResult]:
"""Search for similar documents using a query.
Keyword Args:
collection_name: Name of the index to search in.
query: The text query to search for.
limit: Maximum number of results to return (default: 10).
metadata_filter: Optional filter for metadata fields.
score_threshold: Optional minimum similarity score (0-1) for results.
Returns:
List of SearchResult dicts containing id, content, metadata, and score.
Raises:
ValueError: If index doesn't exist.
ConnectionError: If unable to connect to Elasticsearch server.
"""
if not _is_sync_client(self.client):
raise ClientMethodMismatchError(
method_name="search",
expected_client="Elasticsearch",
alt_method="asearch",
alt_client="AsyncElasticsearch",
)
collection_name = kwargs["collection_name"]
query = kwargs["query"]
limit = kwargs.get("limit", 10)
metadata_filter = kwargs.get("metadata_filter")
score_threshold = kwargs.get("score_threshold")
if not self.client.indices.exists(index=collection_name):
raise ValueError(f"Index '{collection_name}' does not exist")
if _is_async_embedding_function(self.embedding_function):
raise TypeError(
"Async embedding function cannot be used with sync search. "
"Use asearch instead."
)
sync_fn = cast(EmbeddingFunction, self.embedding_function)
query_embedding = sync_fn(query)
search_query = _build_vector_search_query(
query_vector=query_embedding,
limit=limit,
metadata_filter=metadata_filter,
score_threshold=score_threshold,
)
response = self.client.search(index=collection_name, body=search_query)
return _process_search_results(response, score_threshold)
async def asearch(
self, **kwargs: Unpack[BaseCollectionSearchParams]
) -> list[SearchResult]:
"""Search for similar documents using a query asynchronously.
Keyword Args:
collection_name: Name of the index to search in.
query: The text query to search for.
limit: Maximum number of results to return (default: 10).
metadata_filter: Optional filter for metadata fields.
score_threshold: Optional minimum similarity score (0-1) for results.
Returns:
List of SearchResult dicts containing id, content, metadata, and score.
Raises:
ValueError: If index doesn't exist.
ConnectionError: If unable to connect to Elasticsearch server.
"""
if not _is_async_client(self.client):
raise ClientMethodMismatchError(
method_name="asearch",
expected_client="AsyncElasticsearch",
alt_method="search",
alt_client="Elasticsearch",
)
collection_name = kwargs["collection_name"]
query = kwargs["query"]
limit = kwargs.get("limit", 10)
metadata_filter = kwargs.get("metadata_filter")
score_threshold = kwargs.get("score_threshold")
if not await self.client.indices.exists(index=collection_name):
raise ValueError(f"Index '{collection_name}' does not exist")
if _is_async_embedding_function(self.embedding_function):
async_fn = cast(AsyncEmbeddingFunction, self.embedding_function)
query_embedding = await async_fn(query)
else:
sync_fn = cast(EmbeddingFunction, self.embedding_function)
query_embedding = sync_fn(query)
search_query = _build_vector_search_query(
query_vector=query_embedding,
limit=limit,
metadata_filter=metadata_filter,
score_threshold=score_threshold,
)
response = await self.client.search(index=collection_name, body=search_query)
return _process_search_results(response, score_threshold)
def delete_collection(self, **kwargs: Unpack[BaseCollectionParams]) -> None:
"""Delete an index and all its data.
Keyword Args:
collection_name: Name of the index to delete.
Raises:
ValueError: If index doesn't exist.
ConnectionError: If unable to connect to Elasticsearch server.
"""
if not _is_sync_client(self.client):
raise ClientMethodMismatchError(
method_name="delete_collection",
expected_client="Elasticsearch",
alt_method="adelete_collection",
alt_client="AsyncElasticsearch",
)
collection_name = kwargs["collection_name"]
if not self.client.indices.exists(index=collection_name):
raise ValueError(f"Index '{collection_name}' does not exist")
self.client.indices.delete(index=collection_name)
async def adelete_collection(self, **kwargs: Unpack[BaseCollectionParams]) -> None:
"""Delete an index and all its data asynchronously.
Keyword Args:
collection_name: Name of the index to delete.
Raises:
ValueError: If index doesn't exist.
ConnectionError: If unable to connect to Elasticsearch server.
"""
if not _is_async_client(self.client):
raise ClientMethodMismatchError(
method_name="adelete_collection",
expected_client="AsyncElasticsearch",
alt_method="delete_collection",
alt_client="Elasticsearch",
)
collection_name = kwargs["collection_name"]
if not await self.client.indices.exists(index=collection_name):
raise ValueError(f"Index '{collection_name}' does not exist")
await self.client.indices.delete(index=collection_name)
def reset(self) -> None:
"""Reset the vector database by deleting all indices and data.
Raises:
ConnectionError: If unable to connect to Elasticsearch server.
"""
if not _is_sync_client(self.client):
raise ClientMethodMismatchError(
method_name="reset",
expected_client="Elasticsearch",
alt_method="areset",
alt_client="AsyncElasticsearch",
)
indices_response = self.client.indices.get(index="*")
for index_name in indices_response.keys():
if not index_name.startswith("."):
self.client.indices.delete(index=index_name)
async def areset(self) -> None:
"""Reset the vector database by deleting all indices and data asynchronously.
Raises:
ConnectionError: If unable to connect to Elasticsearch server.
"""
if not _is_async_client(self.client):
raise ClientMethodMismatchError(
method_name="areset",
expected_client="AsyncElasticsearch",
alt_method="reset",
alt_client="Elasticsearch",
)
indices_response = await self.client.indices.get(index="*")
for index_name in indices_response.keys():
if not index_name.startswith("."):
await self.client.indices.delete(index=index_name)

View File

@@ -1,92 +0,0 @@
"""Elasticsearch configuration model."""
from dataclasses import field
from typing import Literal, cast
from pydantic.dataclasses import dataclass as pyd_dataclass
from crewai.rag.config.base import BaseRagConfig
from crewai.rag.elasticsearch.types import (
ElasticsearchClientParams,
ElasticsearchEmbeddingFunctionWrapper,
)
from crewai.rag.elasticsearch.constants import (
DEFAULT_HOST,
DEFAULT_PORT,
DEFAULT_EMBEDDING_MODEL,
DEFAULT_VECTOR_DIMENSION,
)
def _default_options() -> ElasticsearchClientParams:
"""Create default Elasticsearch client options.
Returns:
Default options with local Elasticsearch connection.
"""
return ElasticsearchClientParams(
hosts=[f"http://{DEFAULT_HOST}:{DEFAULT_PORT}"],
use_ssl=False,
verify_certs=False,
timeout=30,
)
def _default_embedding_function() -> ElasticsearchEmbeddingFunctionWrapper:
"""Create default Elasticsearch embedding function.
Returns:
Default embedding function using sentence-transformers.
"""
try:
from sentence_transformers import SentenceTransformer
model = SentenceTransformer(DEFAULT_EMBEDDING_MODEL)
def embed_fn(text: str) -> list[float]:
"""Embed a single text string.
Args:
text: Text to embed.
Returns:
Embedding vector as list of floats.
"""
embedding = model.encode(text, convert_to_tensor=False)
return embedding.tolist() if hasattr(embedding, 'tolist') else list(embedding)
return cast(ElasticsearchEmbeddingFunctionWrapper, embed_fn)
except ImportError:
def fallback_embed_fn(text: str) -> list[float]:
"""Fallback embedding function when sentence-transformers is not available."""
import hashlib
import struct
hash_obj = hashlib.md5(text.encode(), usedforsecurity=False)
hash_bytes = hash_obj.digest()
vector = []
for i in range(0, len(hash_bytes), 4):
chunk = hash_bytes[i:i+4]
if len(chunk) == 4:
value = struct.unpack('f', chunk)[0]
vector.append(float(value))
while len(vector) < DEFAULT_VECTOR_DIMENSION:
vector.extend(vector[:DEFAULT_VECTOR_DIMENSION - len(vector)])
return vector[:DEFAULT_VECTOR_DIMENSION]
return cast(ElasticsearchEmbeddingFunctionWrapper, fallback_embed_fn)
@pyd_dataclass(frozen=True)
class ElasticsearchConfig(BaseRagConfig):
"""Configuration for Elasticsearch client."""
provider: Literal["elasticsearch"] = field(default="elasticsearch", init=False)
options: ElasticsearchClientParams = field(default_factory=_default_options)
vector_dimension: int = DEFAULT_VECTOR_DIMENSION
similarity: str = "cosine"
embedding_function: ElasticsearchEmbeddingFunctionWrapper = field(
default_factory=_default_embedding_function
)

View File

@@ -1,12 +0,0 @@
"""Constants for Elasticsearch RAG implementation."""
from typing import Final
DEFAULT_HOST: Final[str] = "localhost"
DEFAULT_PORT: Final[int] = 9200
DEFAULT_INDEX_SETTINGS: Final[dict] = {
"number_of_shards": 1,
"number_of_replicas": 0,
}
DEFAULT_EMBEDDING_MODEL: Final[str] = "sentence-transformers/all-MiniLM-L6-v2"
DEFAULT_VECTOR_DIMENSION: Final[int] = 384

View File

@@ -1,31 +0,0 @@
"""Factory functions for creating Elasticsearch clients."""
from crewai.rag.elasticsearch.config import ElasticsearchConfig
from crewai.rag.elasticsearch.client import ElasticsearchClient
def create_client(config: ElasticsearchConfig) -> ElasticsearchClient:
"""Create an ElasticsearchClient from configuration.
Args:
config: Elasticsearch configuration object.
Returns:
Configured ElasticsearchClient instance.
"""
try:
from elasticsearch import Elasticsearch
except ImportError as e:
raise ImportError(
"elasticsearch package is required for Elasticsearch support. "
"Install it with: pip install elasticsearch"
) from e
client = Elasticsearch(**config.options)
return ElasticsearchClient(
client=client,
embedding_function=config.embedding_function,
vector_dimension=config.vector_dimension,
similarity=config.similarity,
)

View File

@@ -1,93 +0,0 @@
"""Type definitions for Elasticsearch RAG implementation."""
from typing import Any, Protocol, Union, TYPE_CHECKING
from typing_extensions import NotRequired, TypedDict
from pydantic import GetCoreSchemaHandler
from pydantic_core import CoreSchema, core_schema
if TYPE_CHECKING:
from typing import TypeAlias
from elasticsearch import Elasticsearch, AsyncElasticsearch
ElasticsearchClientType: TypeAlias = Union[Elasticsearch, AsyncElasticsearch]
else:
try:
from elasticsearch import Elasticsearch, AsyncElasticsearch
ElasticsearchClientType = Union[Elasticsearch, AsyncElasticsearch]
except ImportError:
ElasticsearchClientType = Any
class ElasticsearchClientParams(TypedDict, total=False):
"""Parameters for Elasticsearch client initialization."""
hosts: NotRequired[list[str]]
cloud_id: NotRequired[str]
username: NotRequired[str]
password: NotRequired[str]
api_key: NotRequired[str]
use_ssl: NotRequired[bool]
verify_certs: NotRequired[bool]
ca_certs: NotRequired[str]
timeout: NotRequired[int]
class ElasticsearchIndexSettings(TypedDict, total=False):
"""Settings for Elasticsearch index creation."""
number_of_shards: NotRequired[int]
number_of_replicas: NotRequired[int]
refresh_interval: NotRequired[str]
class ElasticsearchCollectionCreateParams(TypedDict, total=False):
"""Parameters for creating Elasticsearch collections/indices."""
collection_name: str
index_settings: NotRequired[ElasticsearchIndexSettings]
vector_dimension: NotRequired[int]
similarity: NotRequired[str]
class EmbeddingFunction(Protocol):
"""Protocol for embedding functions that convert text to vectors."""
def __call__(self, text: str) -> list[float]:
"""Convert text to embedding vector.
Args:
text: Input text to embed.
Returns:
Embedding vector as list of floats.
"""
...
class AsyncEmbeddingFunction(Protocol):
"""Protocol for async embedding functions that convert text to vectors."""
async def __call__(self, text: str) -> list[float]:
"""Convert text to embedding vector asynchronously.
Args:
text: Input text to embed.
Returns:
Embedding vector as list of floats.
"""
...
class ElasticsearchEmbeddingFunctionWrapper(EmbeddingFunction):
"""Base class for Elasticsearch EmbeddingFunction to work with Pydantic validation."""
@classmethod
def __get_pydantic_core_schema__(
cls, _source_type: Any, _handler: GetCoreSchemaHandler
) -> CoreSchema:
"""Generate Pydantic core schema for Elasticsearch EmbeddingFunction.
This allows Pydantic to handle Elasticsearch's EmbeddingFunction type
without requiring arbitrary_types_allowed=True.
"""
return core_schema.any_schema()

View File

@@ -1,186 +0,0 @@
"""Utility functions for Elasticsearch RAG implementation."""
import hashlib
from typing import Any, TypeGuard
from crewai.rag.elasticsearch.types import (
AsyncEmbeddingFunction,
EmbeddingFunction,
ElasticsearchClientType,
)
from crewai.rag.types import BaseRecord, SearchResult
try:
from elasticsearch import Elasticsearch, AsyncElasticsearch
except ImportError:
Elasticsearch = None
AsyncElasticsearch = None
def _is_sync_client(client: ElasticsearchClientType) -> TypeGuard[Any]:
"""Type guard to check if the client is a sync Elasticsearch client."""
if Elasticsearch is None:
return False
return isinstance(client, Elasticsearch)
def _is_async_client(client: ElasticsearchClientType) -> TypeGuard[Any]:
"""Type guard to check if the client is an async Elasticsearch client."""
if AsyncElasticsearch is None:
return False
return isinstance(client, AsyncElasticsearch)
def _is_async_embedding_function(
func: EmbeddingFunction | AsyncEmbeddingFunction,
) -> TypeGuard[AsyncEmbeddingFunction]:
"""Type guard to check if the embedding function is async."""
import inspect
return inspect.iscoroutinefunction(func)
def _generate_doc_id(content: str) -> str:
"""Generate a document ID from content using SHA256 hash."""
return hashlib.sha256(content.encode()).hexdigest()
def _prepare_document_for_elasticsearch(
doc: BaseRecord, embedding: list[float]
) -> dict[str, Any]:
"""Prepare a document for Elasticsearch indexing.
Args:
doc: Document record to prepare.
embedding: Embedding vector for the document.
Returns:
Document formatted for Elasticsearch.
"""
doc_id = doc.get("doc_id") or _generate_doc_id(doc["content"])
es_doc = {
"content": doc["content"],
"content_vector": embedding,
"metadata": doc.get("metadata", {}),
}
return {"id": doc_id, "body": es_doc}
def _process_search_results(
response: dict[str, Any], score_threshold: float | None = None
) -> list[SearchResult]:
"""Process Elasticsearch search response into SearchResult format.
Args:
response: Raw Elasticsearch search response.
score_threshold: Optional minimum score threshold.
Returns:
List of SearchResult dictionaries.
"""
results = []
hits = response.get("hits", {}).get("hits", [])
for hit in hits:
score = hit.get("_score", 0.0)
if score_threshold is not None and score < score_threshold:
continue
source = hit.get("_source", {})
result = SearchResult(
id=hit.get("_id", ""),
content=source.get("content", ""),
metadata=source.get("metadata", {}),
score=score,
)
results.append(result)
return results
def _build_vector_search_query(
query_vector: list[float],
limit: int = 10,
metadata_filter: dict[str, Any] | None = None,
score_threshold: float | None = None,
) -> dict[str, Any]:
"""Build Elasticsearch query for vector similarity search.
Args:
query_vector: Query embedding vector.
limit: Maximum number of results.
metadata_filter: Optional metadata filter.
score_threshold: Optional minimum score threshold.
Returns:
Elasticsearch query dictionary.
"""
query = {
"size": limit,
"query": {
"script_score": {
"query": {"match_all": {}},
"script": {
"source": "cosineSimilarity(params.query_vector, 'content_vector') + 1.0",
"params": {"query_vector": query_vector}
}
}
}
}
if metadata_filter:
bool_query = {
"bool": {
"must": [
query["query"]
],
"filter": []
}
}
for key, value in metadata_filter.items():
bool_query["bool"]["filter"].append({
"term": {f"metadata.{key}": value}
})
query["query"] = bool_query
if score_threshold is not None:
query["min_score"] = score_threshold
return query
def _get_index_mapping(vector_dimension: int, similarity: str = "cosine") -> dict[str, Any]:
"""Get Elasticsearch index mapping for vector search.
Args:
vector_dimension: Dimension of the embedding vectors.
similarity: Similarity function to use.
Returns:
Elasticsearch mapping dictionary.
"""
return {
"mappings": {
"properties": {
"content": {
"type": "text",
"analyzer": "standard"
},
"content_vector": {
"type": "dense_vector",
"dims": vector_dimension,
"similarity": similarity
},
"metadata": {
"type": "object",
"dynamic": True
}
}
}
}

View File

@@ -5,7 +5,6 @@ from typing import cast
from crewai.rag.config.optional_imports.protocols import (
ChromaFactoryModule,
QdrantFactoryModule,
ElasticsearchFactoryModule,
)
from crewai.rag.core.base_client import BaseClient
from crewai.rag.config.types import RagConfigType
@@ -44,15 +43,3 @@ def create_client(config: RagConfigType) -> BaseClient:
),
)
return qdrant_mod.create_client(config)
if config.provider == "elasticsearch":
elasticsearch_mod = cast(
ElasticsearchFactoryModule,
require(
"crewai.rag.elasticsearch.factory",
purpose="The 'elasticsearch' provider",
),
)
return elasticsearch_mod.create_client(config)
raise ValueError(f"Unsupported provider: {config.provider}")

View File

@@ -0,0 +1,622 @@
from __future__ import annotations
import asyncio
import contextvars
import functools
import hashlib
import inspect
import os
import re
import sys
import threading
import time
import weakref
from pathlib import Path
from collections.abc import Callable, Coroutine
from typing import Any, TypeVar, ParamSpec, Concatenate, TypedDict
import portalocker
from portalocker import constants
from typing_extensions import NotRequired, Self
P = ParamSpec("P")
R = TypeVar("R")
T = TypeVar("T")
_STATE: dict[str, int] = {"pid": os.getpid()}
def _reset_after_fork() -> None:
"""Reset in-process state in the child after a fork.
Resets all locks and thread-local storage after a process fork
to prevent lock contamination across processes.
"""
global _sync_rlocks, _async_locks_by_loop, _tls, _task_depths_var, _STATE
_sync_rlocks = {}
_async_locks_by_loop = weakref.WeakKeyDictionary()
_tls = threading.local()
# Reset task-local depths for async
_task_depths_var = contextvars.ContextVar("locked_task_depths", default=None)
_STATE["pid"] = os.getpid()
def _ensure_same_process() -> None:
"""Ensure we're in the same process, reset if forked.
Checks if the current PID matches the stored PID and resets
state if a fork has occurred.
"""
if _STATE["pid"] != os.getpid():
_reset_after_fork()
# Automatically reset in a forked child on POSIX
_register_at_fork = getattr(os, "register_at_fork", None)
if _register_at_fork is not None:
_register_at_fork(after_in_child=_reset_after_fork)
class LockConfig(TypedDict):
"""Configuration for portalocker locks.
Attributes:
mode: File open mode.
timeout: Optional lock timeout.
check_interval: Optional check interval.
fail_when_locked: Whether to fail if already locked.
flags: Portalocker lock flags.
"""
mode: str
timeout: NotRequired[float]
check_interval: NotRequired[float]
fail_when_locked: bool
flags: portalocker.LockFlags
def _get_platform_lock_flags() -> portalocker.LockFlags:
"""Get platform-appropriate lock flags.
Returns:
LockFlags.EXCLUSIVE for exclusive file locking.
"""
# Use EXCLUSIVE flag only - let portalocker handle blocking/non-blocking internally
return constants.LockFlags.EXCLUSIVE
def _get_lock_config() -> LockConfig:
"""Get lock configuration appropriate for the platform.
Returns:
LockConfig dict with mode, flags, and fail_when_locked settings.
"""
config: LockConfig = {
"mode": "a+",
"fail_when_locked": False,
"flags": _get_platform_lock_flags(),
}
return config
LOCK_CONFIG: LockConfig = _get_lock_config()
LOCK_STALE_SECONDS = 120
def _default_lock_dir() -> Path:
"""Get or create the default lock directory.
Returns:
Path to ~/.crewai/locks directory, created if necessary.
"""
lock_dir = Path.home() / ".crewai" / "locks"
lock_dir.mkdir(parents=True, exist_ok=True)
# Best-effort: restrict perms on POSIX
try:
if os.name == "posix":
lock_dir.chmod(0o700)
except Exception:
pass
# Clean up old lock files
_cleanup_stale_locks(lock_dir)
return lock_dir
def _cleanup_stale_locks(lock_dir: Path, max_age_seconds: int = 86400) -> None:
"""Remove lock files older than max_age_seconds.
Args:
lock_dir: Directory containing lock files.
max_age_seconds: Maximum age before considering a lock stale (default 24 hours).
"""
try:
current_time = time.time()
for lock_file in lock_dir.glob("*.lock"):
try:
# Check if file is old and not currently locked
file_age = current_time - lock_file.stat().st_mtime
if file_age > max_age_seconds:
# Try to acquire exclusive lock - if successful, file is not in use
try:
with portalocker.Lock(
str(lock_file),
mode="a+",
timeout=0.01, # Very short timeout
fail_when_locked=True,
flags=constants.LockFlags.EXCLUSIVE,
):
pass # We got the lock, file is not in use
# Safe to remove
lock_file.unlink(missing_ok=True)
except (portalocker.LockException, OSError):
# File is locked or can't be accessed, skip it
pass
except (OSError, IOError):
# Skip files we can't stat or process
pass
except Exception:
# Cleanup is best-effort, don't fail on errors
pass
def _hash_str(value: str) -> str:
"""Generate a short hash of a string.
Args:
value: String to hash.
Returns:
First 10 characters of SHA256 hash.
"""
return hashlib.sha256(value.encode()).hexdigest()[:10]
def _qualname_for(func: Callable[..., Any], owner: object | None = None) -> str:
"""Get qualified name for a function.
Args:
func: Function to get qualified name for.
owner: Optional owner object for the function.
Returns:
Fully qualified name including module and class.
"""
target = inspect.unwrap(func)
if inspect.ismethod(func) and getattr(func, "__self__", None) is not None:
owner_obj = func.__self__
cls = owner_obj if inspect.isclass(owner_obj) else owner_obj.__class__
return f"{target.__module__}.{cls.__qualname__}.{getattr(target, '__name__', '<?>')}"
if owner is not None:
cls = owner if inspect.isclass(owner) else owner.__class__
return f"{target.__module__}.{cls.__qualname__}.{getattr(target, '__name__', '<?>')}"
qn = getattr(target, "__qualname__", None)
if qn is not None:
return f"{getattr(target, '__module__', target.__class__.__module__)}.{qn}"
if isinstance(target, functools.partial):
f = inspect.unwrap(target.func)
return f"{getattr(f, '__module__', 'builtins')}.{getattr(f, '__qualname__', getattr(f, '__name__', '<?>'))}"
cls = target.__class__
return f"{cls.__module__}.{cls.__qualname__}.__call__"
def _get_lock_context(
instance: Any | None,
func: Callable[..., Any],
kwargs: dict[str, Any],
) -> tuple[Path, str | None]:
"""Extract lock context from function call.
Args:
instance: Instance the function is called on.
func: Function being called.
kwargs: Keyword arguments passed to function.
Returns:
Tuple of (lock_file_path, collection_name).
"""
collection_name = (
str(kwargs.get("collection_name")) if "collection_name" in kwargs else None
)
lock_dir = _default_lock_dir()
base = _qualname_for(func, owner=instance)
safe_base = re.sub(r"[^\w.\-]+", "_", base)
suffix = f"_{_hash_str(collection_name)}" if collection_name else ""
path = lock_dir / f"{safe_base}{suffix}.lock"
return path, collection_name
def _write_lock_metadata(lock_file_path: Path) -> None:
"""Write metadata to lock file for staleness detection.
Args:
lock_file_path: Path to the lock file.
"""
try:
with open(lock_file_path, "w") as f:
f.write(f"{os.getpid()}\n{time.time()}\n")
f.flush()
os.fsync(f.fileno()) # Ensure data is written to disk
# Set restrictive permissions on lock file (Unix only)
if sys.platform not in ("win32", "cygwin"):
try:
lock_file_path.chmod(0o600)
except Exception:
pass
except Exception:
# Best effort - don't fail if we can't write metadata
pass
def _check_lock_staleness(lock_file_path: Path) -> bool:
"""Check if a lock file is stale.
Args:
lock_file_path: Path to the lock file.
Returns:
True if lock is stale, False otherwise.
"""
try:
if not lock_file_path.exists():
return False
with open(lock_file_path) as f:
lines = f.readlines()
if len(lines) < 2:
return True # unreadable metadata
pid = int(lines[0].strip())
ts = float(lines[1].strip())
# If the process is alive, do NOT treat as stale based on time alone.
if sys.platform not in ("win32", "cygwin"):
try:
os.kill(pid, 0)
return False # alive → not stale
except (OSError, ProcessLookupError):
pass # dead process → proceed to time check
# Process dead: time window can be small; consider stale now
return (time.time() - ts) > 1.0 # essentially “dead means stale”
except Exception:
return True
_sync_rlocks: dict[Path, threading.RLock] = {}
_sync_rlocks_guard = threading.Lock()
_tls = threading.local()
def _get_sync_rlock(path: Path) -> threading.RLock:
"""Get or create a reentrant lock for a path.
Args:
path: Path to get lock for.
Returns:
Threading RLock for the given path.
"""
with _sync_rlocks_guard:
lk = _sync_rlocks.get(path)
if lk is None:
lk = threading.RLock()
_sync_rlocks[path] = lk
return lk
class _SyncDepthManager:
"""Context manager for sync depth tracking.
Tracks reentrancy depth for synchronous locks to determine
when to acquire/release file locks.
"""
def __init__(self, path: Path):
"""Initialize depth manager.
Args:
path: Path to track depth for.
"""
self.path = path
self.depth = 0
def __enter__(self) -> int:
"""Enter context and increment depth.
Returns:
Current depth after increment.
"""
depths = getattr(_tls, "depths", None)
if depths is None:
depths = {}
_tls.depths = depths
self.depth = depths.get(self.path, 0) + 1
depths[self.path] = self.depth
return self.depth
def __exit__(self, *args: Any) -> None:
"""Exit context and decrement depth.
Args:
*args: Exception information if any.
"""
depths = getattr(_tls, "depths", {})
v = depths.get(self.path, 1) - 1
if v <= 0:
depths.pop(self.path, None)
else:
depths[self.path] = v
def _safe_to_delete(path: Path) -> bool:
"""Check if a lock file can be safely deleted.
Args:
path: Path to the lock file.
Returns:
True if file can be deleted safely, False otherwise.
"""
try:
with portalocker.Lock(
str(path),
mode="a+",
timeout=0.01, # very short, non-blocking-ish
fail_when_locked=True, # fail if someone holds it
flags=constants.LockFlags.EXCLUSIVE,
):
return True
except Exception:
return False
def with_lock(func: Callable[Concatenate[T, P], R]) -> Callable[Concatenate[T, P], R]:
"""Decorator for file-based cross-process locking.
Args:
func: Function to wrap with locking.
Returns:
Wrapped function with locking behavior.
"""
@functools.wraps(func)
def wrapper(self, *args, **kwargs):
_ensure_same_process()
path, _ = _get_lock_context(self, func, kwargs)
local_lock = _get_sync_rlock(path)
prune_after = False
with local_lock:
with _SyncDepthManager(path) as depth:
if depth == 1:
# stale handling
if _check_lock_staleness(path) and _safe_to_delete(path):
try:
path.unlink(missing_ok=True)
except Exception:
pass
# acquire file lock
lock_config = LockConfig(
mode=LOCK_CONFIG["mode"],
fail_when_locked=LOCK_CONFIG["fail_when_locked"],
flags=LOCK_CONFIG["flags"],
)
with portalocker.Lock(str(path), **lock_config) as _fh:
_write_lock_metadata(path)
result = func(self, *args, **kwargs)
try:
path.unlink(missing_ok=True)
except Exception:
pass
prune_after = True
else:
result = func(self, *args, **kwargs)
# <-- NOW it's safe to remove the entry
if prune_after:
with _sync_rlocks_guard:
_sync_rlocks.pop(path, None)
return result
return wrapper
# Use weak references to avoid keeping event loops alive
_async_locks_by_loop: weakref.WeakKeyDictionary[
asyncio.AbstractEventLoop, dict[Path, asyncio.Lock]
] = weakref.WeakKeyDictionary()
_async_locks_guard = threading.Lock()
_task_depths_var: contextvars.ContextVar[dict[Path, int] | None] = (
contextvars.ContextVar("locked_task_depths", default=None)
)
def _get_async_lock(path: Path) -> asyncio.Lock:
"""Get or create an async lock for the current event loop.
Args:
path: Path to get lock for.
Returns:
Asyncio Lock for the given path in current event loop.
"""
loop = asyncio.get_running_loop()
with _async_locks_guard:
# Get locks dict for this event loop
loop_locks = _async_locks_by_loop.get(loop)
if loop_locks is None:
loop_locks = {}
_async_locks_by_loop[loop] = loop_locks
# Get or create lock for this path
lk = loop_locks.get(path)
if lk is None:
# Create lock in the context of the running loop
lk = asyncio.Lock()
loop_locks[path] = lk
return lk
class _AsyncDepthManager:
"""Context manager for async task-local depth tracking.
Tracks reentrancy depth for async locks to determine
when to acquire/release file locks.
"""
def __init__(self, path: Path):
"""Initialize async depth manager.
Args:
path: Path to track depth for.
"""
self.path = path
self.depths: dict[Path, int] | None = None
self.is_reentrant = False
def __enter__(self) -> Self:
"""Enter context and track async task depth.
Returns:
Self for context management.
"""
d = _task_depths_var.get()
if d is None:
d = {}
_task_depths_var.set(d)
self.depths = d
cur_depth = self.depths.get(self.path, 0)
if cur_depth > 0:
self.is_reentrant = True
self.depths[self.path] = cur_depth + 1
else:
self.depths[self.path] = 1
return self
def __exit__(self, *args: Any) -> None:
"""Exit context and update task depth.
Args:
*args: Exception information if any.
"""
if self.depths is not None:
new_depth = self.depths.get(self.path, 1) - 1
if new_depth <= 0:
self.depths.pop(self.path, None)
else:
self.depths[self.path] = new_depth
async def _safe_to_delete_async(path: Path) -> bool:
"""Check if a lock file can be safely deleted (async).
Args:
path: Path to the lock file.
Returns:
True if file can be deleted safely, False otherwise.
"""
def _try_lock() -> bool:
try:
with portalocker.Lock(
str(path),
mode="a+",
timeout=0.01, # very short, effectively non-blocking
fail_when_locked=True, # fail if another process holds it
flags=constants.LockFlags.EXCLUSIVE,
):
return True
except Exception:
return False
return await asyncio.to_thread(_try_lock)
def async_with_lock(
func: Callable[Concatenate[T, P], Coroutine[Any, Any, R]],
) -> Callable[Concatenate[T, P], Coroutine[Any, Any, R]]:
"""Async decorator for file-based cross-process locking.
Args:
func: Async function to wrap with locking.
Returns:
Wrapped async function with locking behavior.
"""
@functools.wraps(func)
async def wrapper(self: T, *args: P.args, **kwargs: P.kwargs) -> R:
_ensure_same_process()
path, _ = _get_lock_context(self, func, kwargs)
with _AsyncDepthManager(path) as depth_mgr:
if depth_mgr.is_reentrant:
# Re-entrant within the same task: skip file lock
return await func(self, *args, **kwargs)
# Safer stale handling: only unlink if we can lock it first
if _check_lock_staleness(path) and await _safe_to_delete_async(path):
try:
await asyncio.to_thread(lambda: path.unlink(missing_ok=True))
except Exception:
pass
# Acquire per-loop async lock to serialize within this loop
async_lock = _get_async_lock(path)
await async_lock.acquire()
try:
# Acquire cross-process file lock in a thread
lock_config = LockConfig(
mode=LOCK_CONFIG["mode"],
fail_when_locked=LOCK_CONFIG["fail_when_locked"],
flags=LOCK_CONFIG["flags"],
)
file_lock = portalocker.Lock(str(path), **lock_config)
await asyncio.to_thread(file_lock.acquire)
try:
# Write/refresh metadata while lock is held
await asyncio.to_thread(lambda: _write_lock_metadata(path))
result = await func(self, *args, **kwargs)
finally:
# Release file lock before unlink to avoid inode race
try:
await asyncio.to_thread(file_lock.release)
finally:
# Now it's safe to unlink the path
try:
await asyncio.to_thread(
lambda: path.unlink(missing_ok=True)
)
except Exception:
pass
return result
finally:
async_lock.release()
with _async_locks_guard:
loop = asyncio.get_running_loop()
loop_locks = _async_locks_by_loop.get(loop)
if loop_locks is not None:
loop_locks.pop(path, None)
return wrapper
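A minimal sketch of how `with_lock` and `async_with_lock` are intended to be applied to RAG client methods; the class and method bodies below are hypothetical:

```python
# Hypothetical client; with_lock / async_with_lock come from the module above.
class ExampleRagClient:
    @with_lock
    def add_documents(self, *, collection_name: str, documents: list[dict]) -> None:
        # Serialized across threads and processes: the lock file under
        # ~/.crewai/locks is named from this method's qualified name plus a
        # hash of collection_name (passed as a keyword argument).
        ...

    @async_with_lock
    async def aadd_documents(self, *, collection_name: str, documents: list[dict]) -> None:
        # Same cross-process file lock, acquired via asyncio.to_thread so the
        # event loop is not blocked; re-entrant calls within one task skip it.
        ...
```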

View File

@@ -3,11 +3,6 @@ from datetime import datetime, timedelta
import requests
from unittest.mock import MagicMock, patch, call
from crewai.cli.authentication.main import AuthenticationCommand
from crewai.cli.authentication.constants import (
AUTH0_AUDIENCE,
AUTH0_CLIENT_ID,
AUTH0_DOMAIN
)
from crewai.cli.constants import (
CREWAI_ENTERPRISE_DEFAULT_OAUTH2_DOMAIN,
CREWAI_ENTERPRISE_DEFAULT_OAUTH2_CLIENT_ID,
@@ -22,16 +17,6 @@ class TestAuthenticationCommand:
@pytest.mark.parametrize(
"user_provider,expected_urls",
[
(
"auth0",
{
"device_code_url": f"https://{AUTH0_DOMAIN}/oauth/device/code",
"token_url": f"https://{AUTH0_DOMAIN}/oauth/token",
"client_id": AUTH0_CLIENT_ID,
"audience": AUTH0_AUDIENCE,
"domain": AUTH0_DOMAIN,
},
),
(
"workos",
{
@@ -44,9 +29,6 @@ class TestAuthenticationCommand:
),
],
)
@patch(
"crewai.cli.authentication.main.AuthenticationCommand._determine_user_provider"
)
@patch("crewai.cli.authentication.main.AuthenticationCommand._get_device_code")
@patch(
"crewai.cli.authentication.main.AuthenticationCommand._display_auth_instructions"
@@ -59,11 +41,9 @@ class TestAuthenticationCommand:
mock_poll,
mock_display,
mock_get_device,
mock_determine_provider,
user_provider,
expected_urls,
):
mock_determine_provider.return_value = user_provider
mock_get_device.return_value = {
"device_code": "test_code",
"user_code": "123456",
@@ -74,7 +54,6 @@ class TestAuthenticationCommand:
mock_console_print.assert_called_once_with(
"Signing in to CrewAI Enterprise...\n", style="bold blue"
)
mock_determine_provider.assert_called_once()
mock_get_device.assert_called_once()
mock_display.assert_called_once_with(
{"device_code": "test_code", "user_code": "123456"}
@@ -82,9 +61,17 @@ class TestAuthenticationCommand:
mock_poll.assert_called_once_with(
{"device_code": "test_code", "user_code": "123456"},
)
assert self.auth_command.oauth2_provider.get_client_id() == expected_urls["client_id"]
assert self.auth_command.oauth2_provider.get_audience() == expected_urls["audience"]
assert self.auth_command.oauth2_provider._get_domain() == expected_urls["domain"]
assert (
self.auth_command.oauth2_provider.get_client_id()
== expected_urls["client_id"]
)
assert (
self.auth_command.oauth2_provider.get_audience()
== expected_urls["audience"]
)
assert (
self.auth_command.oauth2_provider._get_domain() == expected_urls["domain"]
)
@patch("crewai.cli.authentication.main.webbrowser")
@patch("crewai.cli.authentication.main.console.print")
@@ -106,14 +93,6 @@ class TestAuthenticationCommand:
@pytest.mark.parametrize(
"user_provider,jwt_config",
[
(
"auth0",
{
"jwks_url": f"https://{AUTH0_DOMAIN}/.well-known/jwks.json",
"issuer": f"https://{AUTH0_DOMAIN}/",
"audience": AUTH0_AUDIENCE,
},
),
(
"workos",
{
@@ -135,14 +114,18 @@ class TestAuthenticationCommand:
jwt_config,
has_expiration,
):
from crewai.cli.authentication.providers.auth0 import Auth0Provider
from crewai.cli.authentication.providers.workos import WorkosProvider
from crewai.cli.authentication.main import Oauth2Settings
if user_provider == "auth0":
self.auth_command.oauth2_provider = Auth0Provider(settings=Oauth2Settings(provider=user_provider, client_id="test-client-id", domain=AUTH0_DOMAIN, audience=jwt_config["audience"]))
elif user_provider == "workos":
self.auth_command.oauth2_provider = WorkosProvider(settings=Oauth2Settings(provider=user_provider, client_id="test-client-id", domain=CREWAI_ENTERPRISE_DEFAULT_OAUTH2_DOMAIN, audience=jwt_config["audience"]))
if user_provider == "workos":
self.auth_command.oauth2_provider = WorkosProvider(
settings=Oauth2Settings(
provider=user_provider,
client_id="test-client-id",
domain=CREWAI_ENTERPRISE_DEFAULT_OAUTH2_DOMAIN,
audience=jwt_config["audience"],
)
)
token_data = {"access_token": "test_access_token", "id_token": "test_id_token"}
@@ -234,83 +217,6 @@ class TestAuthenticationCommand:
]
mock_console_print.assert_has_calls(expected_calls)
@pytest.mark.parametrize(
"api_response,expected_provider",
[
({"provider": "auth0"}, "auth0"),
({"provider": "workos"}, "workos"),
({"provider": "none"}, "workos"), # Default to workos for any other value
(
{},
"workos",
), # Default to workos if no provider key is sent in the response
],
)
@patch("crewai.cli.authentication.main.PlusAPI")
@patch("crewai.cli.authentication.main.console.print")
@patch("builtins.input", return_value="test@example.com")
def test_determine_user_provider_success(
self,
mock_input,
mock_console_print,
mock_plus_api,
api_response,
expected_provider,
):
mock_api_instance = MagicMock()
mock_response = MagicMock()
mock_response.status_code = 200
mock_response.json.return_value = api_response
mock_api_instance._make_request.return_value = mock_response
mock_plus_api.return_value = mock_api_instance
result = self.auth_command._determine_user_provider()
mock_input.assert_called_once()
mock_plus_api.assert_called_once_with("")
mock_api_instance._make_request.assert_called_once_with(
"GET", "/crewai_plus/api/v1/me/provider?email=test%40example.com"
)
assert result == expected_provider
@patch("crewai.cli.authentication.main.PlusAPI")
@patch("crewai.cli.authentication.main.console.print")
@patch("builtins.input", return_value="test@example.com")
def test_determine_user_provider_error(
self, mock_input, mock_console_print, mock_plus_api
):
mock_api_instance = MagicMock()
mock_response = MagicMock()
mock_response.status_code = 500
mock_api_instance._make_request.return_value = mock_response
mock_plus_api.return_value = mock_api_instance
with pytest.raises(SystemExit):
self.auth_command._determine_user_provider()
mock_input.assert_called_once()
mock_plus_api.assert_called_once_with("")
mock_api_instance._make_request.assert_called_once_with(
"GET", "/crewai_plus/api/v1/me/provider?email=test%40example.com"
)
mock_console_print.assert_has_calls(
[
call(
"Enter your CrewAI Enterprise account email: ",
style="bold blue",
end="",
),
call(
"Error: Failed to authenticate with crewai enterprise. Ensure that you are using the latest crewai version and please try again. If the problem persists, contact support@crewai.com.",
style="red",
),
]
)
@patch("requests.post")
def test_get_device_code(self, mock_post):
mock_response = MagicMock()
@@ -323,7 +229,9 @@ class TestAuthenticationCommand:
self.auth_command.oauth2_provider = MagicMock()
self.auth_command.oauth2_provider.get_client_id.return_value = "test_client"
self.auth_command.oauth2_provider.get_authorize_url.return_value = "https://example.com/device"
self.auth_command.oauth2_provider.get_authorize_url.return_value = (
"https://example.com/device"
)
self.auth_command.oauth2_provider.get_audience.return_value = "test_audience"
result = self.auth_command._get_device_code()
@@ -366,12 +274,12 @@ class TestAuthenticationCommand:
) as mock_tool_login,
):
self.auth_command.oauth2_provider = MagicMock()
self.auth_command.oauth2_provider.get_token_url.return_value = "https://example.com/token"
self.auth_command.oauth2_provider.get_token_url.return_value = (
"https://example.com/token"
)
self.auth_command.oauth2_provider.get_client_id.return_value = "test_client"
self.auth_command._poll_for_token(
device_code_data
)
self.auth_command._poll_for_token(device_code_data)
mock_post.assert_called_once_with(
"https://example.com/token",
@@ -406,9 +314,7 @@ class TestAuthenticationCommand:
"interval": 0.1, # Short interval for testing
}
self.auth_command._poll_for_token(
device_code_data
)
self.auth_command._poll_for_token(device_code_data)
mock_console_print.assert_any_call(
"Timeout: Failed to get the token. Please try again.", style="bold red"
@@ -429,15 +335,4 @@ class TestAuthenticationCommand:
device_code_data = {"device_code": "test_device_code", "interval": 1}
with pytest.raises(requests.HTTPError):
self.auth_command._poll_for_token(
device_code_data
)
# @patch(
# "crewai.cli.authentication.main.AuthenticationCommand._determine_user_provider"
# )
# def test_login_with_auth0(self, mock_determine_provider):
# from crewai.cli.authentication.providers.auth0 import Auth0Provider
# from crewai.cli.authentication.main import Oauth2Settings
# self.auth_command.oauth2_provider = Auth0Provider(settings=Oauth2Settings(provider="auth0", client_id=AUTH0_CLIENT_ID, domain=AUTH0_DOMAIN, audience=AUTH0_AUDIENCE))
# self.auth_command.login()
self.auth_command._poll_for_token(device_code_data)

@@ -2,8 +2,6 @@
from unittest.mock import Mock, patch
import pytest
from crewai.rag.factory import create_client
@@ -27,50 +25,10 @@ def test_create_client_chromadb():
mock_module.create_client.assert_called_once_with(mock_config)
def test_create_client_qdrant():
"""Test Qdrant client creation."""
mock_config = Mock()
mock_config.provider = "qdrant"
with patch("crewai.rag.factory.require") as mock_require:
mock_module = Mock()
mock_client = Mock()
mock_module.create_client.return_value = mock_client
mock_require.return_value = mock_module
result = create_client(mock_config)
assert result == mock_client
mock_require.assert_called_once_with(
"crewai.rag.qdrant.factory", purpose="The 'qdrant' provider"
)
mock_module.create_client.assert_called_once_with(mock_config)
def test_create_client_elasticsearch():
"""Test Elasticsearch client creation."""
mock_config = Mock()
mock_config.provider = "elasticsearch"
with patch("crewai.rag.factory.require") as mock_require:
mock_module = Mock()
mock_client = Mock()
mock_module.create_client.return_value = mock_client
mock_require.return_value = mock_module
result = create_client(mock_config)
assert result == mock_client
mock_require.assert_called_once_with(
"crewai.rag.elasticsearch.factory", purpose="The 'elasticsearch' provider"
)
mock_module.create_client.assert_called_once_with(mock_config)
def test_create_client_unsupported_provider():
"""Test that unsupported provider raises ValueError."""
"""Test unsupported provider returns None for now."""
mock_config = Mock()
mock_config.provider = "unsupported"
with pytest.raises(ValueError, match="Unsupported provider: unsupported"):
create_client(mock_config)
result = create_client(mock_config)
assert result is None
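
A hedged sketch of the behaviour the updated test pins down: dispatch known providers and return `None` for anything else, rather than raising `ValueError`. `RagConfig` and the `_FACTORIES` table below are illustrative stand-ins, not crewai's actual factory code.

```python
from dataclasses import dataclass
from typing import Any, Callable, Optional


@dataclass
class RagConfig:
    provider: str


# Known providers map to factory callables; anything else falls through to None.
_FACTORIES: dict[str, Callable[[RagConfig], Any]] = {
    "chromadb": lambda config: object(),  # stand-in for a real client factory
}


def create_client(config: RagConfig) -> Optional[Any]:
    factory = _FACTORIES.get(config.provider)
    if factory is None:
        return None  # unsupported providers now return None instead of raising
    return factory(config)


assert create_client(RagConfig(provider="unsupported")) is None
assert create_client(RagConfig(provider="chromadb")) is not None
```
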

@@ -3,10 +3,7 @@
import pytest
from crewai.rag.config.optional_imports.base import _MissingProvider
from crewai.rag.config.optional_imports.providers import (
MissingChromaDBConfig,
MissingElasticsearchConfig,
)
from crewai.rag.config.optional_imports.providers import MissingChromaDBConfig
def test_missing_provider_raises_runtime_error():
@@ -23,11 +20,3 @@ def test_missing_chromadb_config_raises_runtime_error():
RuntimeError, match="provider 'chromadb' requested but not installed"
):
MissingChromaDBConfig()
def test_missing_elasticsearch_config_raises_runtime_error():
"""Test that MissingElasticsearchConfig raises RuntimeError on instantiation."""
with pytest.raises(
RuntimeError, match="provider 'elasticsearch' requested but not installed"
):
MissingElasticsearchConfig()

View File

@@ -1 +0,0 @@
"""Tests for Elasticsearch RAG implementation."""

@@ -1,397 +0,0 @@
"""Tests for ElasticsearchClient implementation."""
from unittest.mock import AsyncMock, Mock, patch
import pytest
from crewai.rag.elasticsearch.client import ElasticsearchClient
from crewai.rag.types import BaseRecord
from crewai.rag.core.exceptions import ClientMethodMismatchError
@pytest.fixture
def mock_elasticsearch_client():
"""Create a mock Elasticsearch client."""
mock_client = Mock()
mock_client.indices = Mock()
mock_client.indices.exists.return_value = False
mock_client.indices.create.return_value = {"acknowledged": True}
mock_client.indices.get.return_value = {"test_index": {"mappings": {}}}
mock_client.indices.delete.return_value = {"acknowledged": True}
mock_client.index.return_value = {"_id": "test_id", "result": "created"}
mock_client.search.return_value = {
"hits": {
"hits": [
{
"_id": "doc1",
"_score": 0.9,
"_source": {
"content": "test content",
"metadata": {"key": "value"}
}
}
]
}
}
return mock_client
@pytest.fixture
def mock_async_elasticsearch_client():
"""Create a mock async Elasticsearch client."""
mock_client = Mock()
mock_client.indices = Mock()
mock_client.indices.exists = AsyncMock(return_value=False)
mock_client.indices.create = AsyncMock(return_value={"acknowledged": True})
mock_client.indices.get = AsyncMock(return_value={"test_index": {"mappings": {}}})
mock_client.indices.delete = AsyncMock(return_value={"acknowledged": True})
mock_client.index = AsyncMock(return_value={"_id": "test_id", "result": "created"})
mock_client.search = AsyncMock(return_value={
"hits": {
"hits": [
{
"_id": "doc1",
"_score": 0.9,
"_source": {
"content": "test content",
"metadata": {"key": "value"}
}
}
]
}
})
return mock_client
@pytest.fixture
def client(mock_elasticsearch_client) -> ElasticsearchClient:
"""Create an ElasticsearchClient instance for testing."""
mock_embedding = Mock()
mock_embedding.return_value = [0.1, 0.2, 0.3]
client = ElasticsearchClient(
client=mock_elasticsearch_client,
embedding_function=mock_embedding,
vector_dimension=3,
similarity="cosine"
)
return client
@pytest.fixture
def async_client(mock_async_elasticsearch_client) -> ElasticsearchClient:
"""Create an ElasticsearchClient instance with async client for testing."""
mock_embedding = Mock()
mock_embedding.return_value = [0.1, 0.2, 0.3]
client = ElasticsearchClient(
client=mock_async_elasticsearch_client,
embedding_function=mock_embedding,
vector_dimension=3,
similarity="cosine"
)
return client
class TestElasticsearchClient:
"""Test suite for ElasticsearchClient."""
@patch("crewai.rag.elasticsearch.client._is_sync_client", return_value=True)
@patch("crewai.rag.elasticsearch.client._is_async_client", return_value=False)
def test_create_collection(self, mock_is_async, mock_is_sync, client, mock_elasticsearch_client):
"""Test that create_collection creates a new index."""
mock_elasticsearch_client.indices.exists.return_value = False
client.create_collection(collection_name="test_index")
mock_elasticsearch_client.indices.exists.assert_called_once_with(index="test_index")
mock_elasticsearch_client.indices.create.assert_called_once()
call_args = mock_elasticsearch_client.indices.create.call_args
assert call_args.kwargs["index"] == "test_index"
assert "mappings" in call_args.kwargs["body"]
@patch("crewai.rag.elasticsearch.client._is_sync_client", return_value=True)
@patch("crewai.rag.elasticsearch.client._is_async_client", return_value=False)
def test_create_collection_already_exists(self, mock_is_async, mock_is_sync, client, mock_elasticsearch_client):
"""Test that create_collection raises error if index exists."""
mock_elasticsearch_client.indices.exists.return_value = True
with pytest.raises(
ValueError, match="Index 'test_index' already exists"
):
client.create_collection(collection_name="test_index")
@patch("crewai.rag.elasticsearch.client._is_sync_client", return_value=False)
@patch("crewai.rag.elasticsearch.client._is_async_client", return_value=True)
def test_create_collection_wrong_client_type(self, mock_is_async, mock_is_sync, mock_async_elasticsearch_client):
"""Test that create_collection raises error with async client."""
mock_embedding = Mock()
client = ElasticsearchClient(
client=mock_async_elasticsearch_client,
embedding_function=mock_embedding
)
with pytest.raises(ClientMethodMismatchError):
client.create_collection(collection_name="test_index")
@pytest.mark.asyncio
@patch("crewai.rag.elasticsearch.client._is_sync_client", return_value=False)
@patch("crewai.rag.elasticsearch.client._is_async_client", return_value=True)
async def test_acreate_collection(self, mock_is_async, mock_is_sync, async_client, mock_async_elasticsearch_client):
"""Test that acreate_collection creates a new index asynchronously."""
mock_async_elasticsearch_client.indices.exists.return_value = False
await async_client.acreate_collection(collection_name="test_index")
mock_async_elasticsearch_client.indices.exists.assert_called_once_with(index="test_index")
mock_async_elasticsearch_client.indices.create.assert_called_once()
call_args = mock_async_elasticsearch_client.indices.create.call_args
assert call_args.kwargs["index"] == "test_index"
assert "mappings" in call_args.kwargs["body"]
@pytest.mark.asyncio
@patch("crewai.rag.elasticsearch.client._is_sync_client", return_value=False)
@patch("crewai.rag.elasticsearch.client._is_async_client", return_value=True)
async def test_acreate_collection_already_exists(self, mock_is_async, mock_is_sync, async_client, mock_async_elasticsearch_client):
"""Test that acreate_collection raises error if index exists."""
mock_async_elasticsearch_client.indices.exists.return_value = True
with pytest.raises(
ValueError, match="Index 'test_index' already exists"
):
await async_client.acreate_collection(collection_name="test_index")
@patch("crewai.rag.elasticsearch.client._is_sync_client", return_value=True)
@patch("crewai.rag.elasticsearch.client._is_async_client", return_value=False)
def test_get_or_create_collection(self, mock_is_async, mock_is_sync, client, mock_elasticsearch_client):
"""Test that get_or_create_collection returns existing index."""
mock_elasticsearch_client.indices.exists.return_value = True
result = client.get_or_create_collection(collection_name="test_index")
mock_elasticsearch_client.indices.exists.assert_called_once_with(index="test_index")
mock_elasticsearch_client.indices.get.assert_called_once_with(index="test_index")
assert result == {"test_index": {"mappings": {}}}
@patch("crewai.rag.elasticsearch.client._is_sync_client", return_value=True)
@patch("crewai.rag.elasticsearch.client._is_async_client", return_value=False)
def test_get_or_create_collection_creates_new(self, mock_is_async, mock_is_sync, client, mock_elasticsearch_client):
"""Test that get_or_create_collection creates new index if not exists."""
mock_elasticsearch_client.indices.exists.return_value = False
client.get_or_create_collection(collection_name="test_index")
mock_elasticsearch_client.indices.exists.assert_called_once_with(index="test_index")
mock_elasticsearch_client.indices.create.assert_called_once()
mock_elasticsearch_client.indices.get.assert_called_once_with(index="test_index")
@pytest.mark.asyncio
@patch("crewai.rag.elasticsearch.client._is_sync_client", return_value=False)
@patch("crewai.rag.elasticsearch.client._is_async_client", return_value=True)
async def test_aget_or_create_collection(self, mock_is_async, mock_is_sync, async_client, mock_async_elasticsearch_client):
"""Test that aget_or_create_collection returns existing index asynchronously."""
mock_async_elasticsearch_client.indices.exists.return_value = True
result = await async_client.aget_or_create_collection(collection_name="test_index")
mock_async_elasticsearch_client.indices.exists.assert_called_once_with(index="test_index")
mock_async_elasticsearch_client.indices.get.assert_called_once_with(index="test_index")
assert result == {"test_index": {"mappings": {}}}
@patch("crewai.rag.elasticsearch.client._is_sync_client", return_value=True)
@patch("crewai.rag.elasticsearch.client._is_async_client", return_value=False)
def test_add_documents(self, mock_is_async, mock_is_sync, client, mock_elasticsearch_client):
"""Test that add_documents indexes documents correctly."""
mock_elasticsearch_client.indices.exists.return_value = True
documents: list[BaseRecord] = [
{
"content": "test content",
"metadata": {"key": "value"}
}
]
client.add_documents(collection_name="test_index", documents=documents)
mock_elasticsearch_client.indices.exists.assert_called_once_with(index="test_index")
mock_elasticsearch_client.index.assert_called_once()
call_args = mock_elasticsearch_client.index.call_args
assert call_args.kwargs["index"] == "test_index"
assert "body" in call_args.kwargs
assert call_args.kwargs["body"]["content"] == "test content"
@patch("crewai.rag.elasticsearch.client._is_sync_client", return_value=True)
@patch("crewai.rag.elasticsearch.client._is_async_client", return_value=False)
def test_add_documents_empty_list_raises_error(self, mock_is_async, mock_is_sync, client):
"""Test that add_documents raises error with empty documents list."""
with pytest.raises(ValueError, match="Documents list cannot be empty"):
client.add_documents(collection_name="test_index", documents=[])
@patch("crewai.rag.elasticsearch.client._is_sync_client", return_value=True)
@patch("crewai.rag.elasticsearch.client._is_async_client", return_value=False)
def test_add_documents_index_not_exists(self, mock_is_async, mock_is_sync, client, mock_elasticsearch_client):
"""Test that add_documents raises error if index doesn't exist."""
mock_elasticsearch_client.indices.exists.return_value = False
documents: list[BaseRecord] = [{"content": "test content"}]
with pytest.raises(ValueError, match="Index 'test_index' does not exist"):
client.add_documents(collection_name="test_index", documents=documents)
@pytest.mark.asyncio
@patch("crewai.rag.elasticsearch.client._is_sync_client", return_value=False)
@patch("crewai.rag.elasticsearch.client._is_async_client", return_value=True)
async def test_aadd_documents(self, mock_is_async, mock_is_sync, async_client, mock_async_elasticsearch_client):
"""Test that aadd_documents indexes documents correctly asynchronously."""
mock_async_elasticsearch_client.indices.exists.return_value = True
documents: list[BaseRecord] = [
{
"content": "test content",
"metadata": {"key": "value"}
}
]
await async_client.aadd_documents(collection_name="test_index", documents=documents)
mock_async_elasticsearch_client.indices.exists.assert_called_once_with(index="test_index")
mock_async_elasticsearch_client.index.assert_called_once()
call_args = mock_async_elasticsearch_client.index.call_args
assert call_args.kwargs["index"] == "test_index"
assert "body" in call_args.kwargs
@patch("crewai.rag.elasticsearch.client._is_sync_client", return_value=True)
@patch("crewai.rag.elasticsearch.client._is_async_client", return_value=False)
def test_search(self, mock_is_async, mock_is_sync, client, mock_elasticsearch_client):
"""Test that search performs vector similarity search."""
mock_elasticsearch_client.indices.exists.return_value = True
results = client.search(
collection_name="test_index",
query="test query",
limit=5
)
mock_elasticsearch_client.indices.exists.assert_called_once_with(index="test_index")
mock_elasticsearch_client.search.assert_called_once()
call_args = mock_elasticsearch_client.search.call_args
assert call_args.kwargs["index"] == "test_index"
assert "body" in call_args.kwargs
assert len(results) == 1
assert results[0]["id"] == "doc1"
assert results[0]["content"] == "test content"
assert results[0]["score"] == 0.9
@patch("crewai.rag.elasticsearch.client._is_sync_client", return_value=True)
@patch("crewai.rag.elasticsearch.client._is_async_client", return_value=False)
def test_search_with_metadata_filter(self, mock_is_async, mock_is_sync, client, mock_elasticsearch_client):
"""Test that search applies metadata filter correctly."""
mock_elasticsearch_client.indices.exists.return_value = True
client.search(
collection_name="test_index",
query="test query",
metadata_filter={"key": "value"}
)
mock_elasticsearch_client.search.assert_called_once()
call_args = mock_elasticsearch_client.search.call_args
query_body = call_args.kwargs["body"]
assert "bool" in query_body["query"]
@patch("crewai.rag.elasticsearch.client._is_sync_client", return_value=True)
@patch("crewai.rag.elasticsearch.client._is_async_client", return_value=False)
def test_search_index_not_exists(self, mock_is_async, mock_is_sync, client, mock_elasticsearch_client):
"""Test that search raises error if index doesn't exist."""
mock_elasticsearch_client.indices.exists.return_value = False
with pytest.raises(ValueError, match="Index 'test_index' does not exist"):
client.search(collection_name="test_index", query="test query")
@pytest.mark.asyncio
@patch("crewai.rag.elasticsearch.client._is_sync_client", return_value=False)
@patch("crewai.rag.elasticsearch.client._is_async_client", return_value=True)
async def test_asearch(self, mock_is_async, mock_is_sync, async_client, mock_async_elasticsearch_client):
"""Test that asearch performs vector similarity search asynchronously."""
mock_async_elasticsearch_client.indices.exists.return_value = True
results = await async_client.asearch(
collection_name="test_index",
query="test query",
limit=5
)
mock_async_elasticsearch_client.indices.exists.assert_called_once_with(index="test_index")
mock_async_elasticsearch_client.search.assert_called_once()
assert len(results) == 1
assert results[0]["id"] == "doc1"
assert results[0]["content"] == "test content"
@patch("crewai.rag.elasticsearch.client._is_sync_client", return_value=True)
@patch("crewai.rag.elasticsearch.client._is_async_client", return_value=False)
def test_delete_collection(self, mock_is_async, mock_is_sync, client, mock_elasticsearch_client):
"""Test that delete_collection deletes the index."""
mock_elasticsearch_client.indices.exists.return_value = True
client.delete_collection(collection_name="test_index")
mock_elasticsearch_client.indices.exists.assert_called_once_with(index="test_index")
mock_elasticsearch_client.indices.delete.assert_called_once_with(index="test_index")
@patch("crewai.rag.elasticsearch.client._is_sync_client", return_value=True)
@patch("crewai.rag.elasticsearch.client._is_async_client", return_value=False)
def test_delete_collection_not_exists(self, mock_is_async, mock_is_sync, client, mock_elasticsearch_client):
"""Test that delete_collection raises error if index doesn't exist."""
mock_elasticsearch_client.indices.exists.return_value = False
with pytest.raises(ValueError, match="Index 'test_index' does not exist"):
client.delete_collection(collection_name="test_index")
@pytest.mark.asyncio
@patch("crewai.rag.elasticsearch.client._is_sync_client", return_value=False)
@patch("crewai.rag.elasticsearch.client._is_async_client", return_value=True)
async def test_adelete_collection(self, mock_is_async, mock_is_sync, async_client, mock_async_elasticsearch_client):
"""Test that adelete_collection deletes the index asynchronously."""
mock_async_elasticsearch_client.indices.exists.return_value = True
await async_client.adelete_collection(collection_name="test_index")
mock_async_elasticsearch_client.indices.exists.assert_called_once_with(index="test_index")
mock_async_elasticsearch_client.indices.delete.assert_called_once_with(index="test_index")
@patch("crewai.rag.elasticsearch.client._is_sync_client", return_value=True)
@patch("crewai.rag.elasticsearch.client._is_async_client", return_value=False)
def test_reset(self, mock_is_async, mock_is_sync, client, mock_elasticsearch_client):
"""Test that reset deletes all non-system indices."""
mock_elasticsearch_client.indices.get.return_value = {
"test_index": {},
".system_index": {},
"another_index": {}
}
client.reset()
mock_elasticsearch_client.indices.get.assert_called_once_with(index="*")
assert mock_elasticsearch_client.indices.delete.call_count == 2
delete_calls = [call.kwargs["index"] for call in mock_elasticsearch_client.indices.delete.call_args_list]
assert "test_index" in delete_calls
assert "another_index" in delete_calls
assert ".system_index" not in delete_calls
@pytest.mark.asyncio
@patch("crewai.rag.elasticsearch.client._is_sync_client", return_value=False)
@patch("crewai.rag.elasticsearch.client._is_async_client", return_value=True)
async def test_areset(self, mock_is_async, mock_is_sync, async_client, mock_async_elasticsearch_client):
"""Test that areset deletes all non-system indices asynchronously."""
mock_async_elasticsearch_client.indices.get.return_value = {
"test_index": {},
".system_index": {},
"another_index": {}
}
await async_client.areset()
mock_async_elasticsearch_client.indices.get.assert_called_once_with(index="*")
assert mock_async_elasticsearch_client.indices.delete.call_count == 2

@@ -1,49 +0,0 @@
"""Tests for Elasticsearch configuration."""
from crewai.rag.elasticsearch.config import ElasticsearchConfig
def test_elasticsearch_config_defaults():
"""Test that ElasticsearchConfig has correct defaults."""
config = ElasticsearchConfig()
assert config.provider == "elasticsearch"
assert config.vector_dimension == 384
assert config.similarity == "cosine"
assert config.embedding_function is not None
assert config.options["hosts"] == ["http://localhost:9200"]
assert config.options["use_ssl"] is False
def test_elasticsearch_config_custom_options():
"""Test that ElasticsearchConfig accepts custom options."""
custom_options = {
"hosts": ["https://elastic.example.com:9200"],
"username": "user",
"password": "pass",
"use_ssl": True,
}
config = ElasticsearchConfig(
options=custom_options,
vector_dimension=768,
similarity="dot_product"
)
assert config.provider == "elasticsearch"
assert config.vector_dimension == 768
assert config.similarity == "dot_product"
assert config.options["hosts"] == ["https://elastic.example.com:9200"]
assert config.options["username"] == "user"
assert config.options["use_ssl"] is True
def test_elasticsearch_config_embedding_function():
"""Test that embedding function works correctly."""
config = ElasticsearchConfig()
embedding = config.embedding_function("test text")
assert isinstance(embedding, list)
assert len(embedding) == config.vector_dimension
assert all(isinstance(x, float) for x in embedding)

@@ -1,40 +0,0 @@
"""Tests for Elasticsearch factory."""
from unittest.mock import Mock, patch
import pytest
from crewai.rag.elasticsearch.config import ElasticsearchConfig
def test_create_client():
"""Test that create_client creates an ElasticsearchClient."""
config = ElasticsearchConfig()
with patch.dict('sys.modules', {'elasticsearch': Mock()}):
mock_elasticsearch_module = Mock()
mock_client_instance = Mock()
mock_elasticsearch_module.Elasticsearch.return_value = mock_client_instance
with patch.dict('sys.modules', {'elasticsearch': mock_elasticsearch_module}):
from crewai.rag.elasticsearch.factory import create_client
client = create_client(config)
mock_elasticsearch_module.Elasticsearch.assert_called_once_with(**config.options)
assert client.client == mock_client_instance
assert client.embedding_function == config.embedding_function
assert client.vector_dimension == config.vector_dimension
assert client.similarity == config.similarity
def test_create_client_missing_elasticsearch():
"""Test that create_client raises ImportError when elasticsearch is not installed."""
config = ElasticsearchConfig()
with patch.dict('sys.modules', {}, clear=False):
if 'elasticsearch' in __import__('sys').modules:
del __import__('sys').modules['elasticsearch']
from crewai.rag.elasticsearch.factory import create_client
with pytest.raises(ImportError, match="elasticsearch package is required"):
create_client(config)

@@ -624,12 +624,12 @@ def test_crew_with_delegating_agents_should_not_override_task_tools(ceo, writer)
_, kwargs = mock_execute_sync.call_args
tools = kwargs["tools"]
assert any(isinstance(tool, TestTool) for tool in tools), (
"TestTool should be present"
)
assert any("delegate" in tool.name.lower() for tool in tools), (
"Delegation tool should be present"
)
assert any(
isinstance(tool, TestTool) for tool in tools
), "TestTool should be present"
assert any(
"delegate" in tool.name.lower() for tool in tools
), "Delegation tool should be present"
@pytest.mark.vcr(filter_headers=["authorization"])
@@ -688,12 +688,12 @@ def test_crew_with_delegating_agents_should_not_override_agent_tools(ceo, writer
_, kwargs = mock_execute_sync.call_args
tools = kwargs["tools"]
assert any(isinstance(tool, TestTool) for tool in new_ceo.tools), (
"TestTool should be present"
)
assert any("delegate" in tool.name.lower() for tool in tools), (
"Delegation tool should be present"
)
assert any(
isinstance(tool, TestTool) for tool in new_ceo.tools
), "TestTool should be present"
assert any(
"delegate" in tool.name.lower() for tool in tools
), "Delegation tool should be present"
@pytest.mark.vcr(filter_headers=["authorization"])
@@ -817,17 +817,17 @@ def test_task_tools_override_agent_tools_with_allow_delegation(researcher, write
used_tools = kwargs["tools"]
# Confirm AnotherTestTool is present but TestTool is not
assert any(isinstance(tool, AnotherTestTool) for tool in used_tools), (
"AnotherTestTool should be present"
)
assert not any(isinstance(tool, TestTool) for tool in used_tools), (
"TestTool should not be present among used tools"
)
assert any(
isinstance(tool, AnotherTestTool) for tool in used_tools
), "AnotherTestTool should be present"
assert not any(
isinstance(tool, TestTool) for tool in used_tools
), "TestTool should not be present among used tools"
# Confirm delegation tool(s) are present
assert any("delegate" in tool.name.lower() for tool in used_tools), (
"Delegation tool should be present"
)
assert any(
"delegate" in tool.name.lower() for tool in used_tools
), "Delegation tool should be present"
# Finally, make sure the agent's original tools remain unchanged
assert len(researcher_with_delegation.tools) == 1
@@ -931,9 +931,9 @@ def test_cache_hitting_between_agents(researcher, writer, ceo):
tool="multiplier", input={"first_number": 2, "second_number": 6}
)
assert cache_calls[0] == expected_call, f"First call mismatch: {cache_calls[0]}"
assert cache_calls[1] == expected_call, (
f"Second call mismatch: {cache_calls[1]}"
)
assert (
cache_calls[1] == expected_call
), f"Second call mismatch: {cache_calls[1]}"
@pytest.mark.vcr(filter_headers=["authorization"])
@@ -1676,9 +1676,9 @@ def test_code_execution_flag_adds_code_tool_upon_kickoff():
# Verify that exactly one tool was used and it was a CodeInterpreterTool
assert len(used_tools) == 1, "Should have exactly one tool"
assert isinstance(used_tools[0], CodeInterpreterTool), (
"Tool should be CodeInterpreterTool"
)
assert isinstance(
used_tools[0], CodeInterpreterTool
), "Tool should be CodeInterpreterTool"
@pytest.mark.vcr(filter_headers=["authorization"])
@@ -2537,8 +2537,8 @@ def test_memory_events_are_emitted():
crew.kickoff()
assert len(events["MemorySaveStartedEvent"]) == 6
assert len(events["MemorySaveCompletedEvent"]) == 6
assert len(events["MemorySaveStartedEvent"]) == 3
assert len(events["MemorySaveCompletedEvent"]) == 3
assert len(events["MemorySaveFailedEvent"]) == 0
assert len(events["MemoryQueryStartedEvent"]) == 3
assert len(events["MemoryQueryCompletedEvent"]) == 3
@@ -3817,9 +3817,9 @@ def test_fetch_inputs():
expected_placeholders = {"role_detail", "topic", "field"}
actual_placeholders = crew.fetch_inputs()
assert actual_placeholders == expected_placeholders, (
f"Expected {expected_placeholders}, but got {actual_placeholders}"
)
assert (
actual_placeholders == expected_placeholders
), f"Expected {expected_placeholders}, but got {actual_placeholders}"
def test_task_tools_preserve_code_execution_tools():
@@ -3894,20 +3894,20 @@ def test_task_tools_preserve_code_execution_tools():
used_tools = kwargs["tools"]
# Verify all expected tools are present
assert any(isinstance(tool, TestTool) for tool in used_tools), (
"Task's TestTool should be present"
)
assert any(isinstance(tool, CodeInterpreterTool) for tool in used_tools), (
"CodeInterpreterTool should be present"
)
assert any("delegate" in tool.name.lower() for tool in used_tools), (
"Delegation tool should be present"
)
assert any(
isinstance(tool, TestTool) for tool in used_tools
), "Task's TestTool should be present"
assert any(
isinstance(tool, CodeInterpreterTool) for tool in used_tools
), "CodeInterpreterTool should be present"
assert any(
"delegate" in tool.name.lower() for tool in used_tools
), "Delegation tool should be present"
# Verify the total number of tools (TestTool + CodeInterpreter + 2 delegation tools)
assert len(used_tools) == 4, (
"Should have TestTool, CodeInterpreter, and 2 delegation tools"
)
assert (
len(used_tools) == 4
), "Should have TestTool, CodeInterpreter, and 2 delegation tools"
@pytest.mark.vcr(filter_headers=["authorization"])
@@ -3951,9 +3951,9 @@ def test_multimodal_flag_adds_multimodal_tools():
used_tools = kwargs["tools"]
# Check that the multimodal tool was added
assert any(isinstance(tool, AddImageTool) for tool in used_tools), (
"AddImageTool should be present when agent is multimodal"
)
assert any(
isinstance(tool, AddImageTool) for tool in used_tools
), "AddImageTool should be present when agent is multimodal"
# Verify we have exactly one tool (just the AddImageTool)
assert len(used_tools) == 1, "Should only have the AddImageTool"
@@ -4217,9 +4217,9 @@ def test_crew_guardrail_feedback_in_context():
assert len(execution_contexts) > 1, "Task should have been executed multiple times"
# Verify that the second execution included the guardrail feedback
assert "Output must contain the keyword 'IMPORTANT'" in execution_contexts[1], (
"Guardrail feedback should be included in retry context"
)
assert (
"Output must contain the keyword 'IMPORTANT'" in execution_contexts[1]
), "Guardrail feedback should be included in retry context"
# Verify final output meets guardrail requirements
assert "IMPORTANT" in result.raw, "Final output should contain required keyword"
@@ -4435,46 +4435,46 @@ def test_crew_copy_with_memory():
try:
crew_copy = crew.copy()
assert hasattr(crew_copy, "_short_term_memory"), (
"Copied crew should have _short_term_memory"
)
assert crew_copy._short_term_memory is not None, (
"Copied _short_term_memory should not be None"
)
assert id(crew_copy._short_term_memory) != original_short_term_id, (
"Copied _short_term_memory should be a new object"
)
assert hasattr(
crew_copy, "_short_term_memory"
), "Copied crew should have _short_term_memory"
assert (
crew_copy._short_term_memory is not None
), "Copied _short_term_memory should not be None"
assert (
id(crew_copy._short_term_memory) != original_short_term_id
), "Copied _short_term_memory should be a new object"
assert hasattr(crew_copy, "_long_term_memory"), (
"Copied crew should have _long_term_memory"
)
assert crew_copy._long_term_memory is not None, (
"Copied _long_term_memory should not be None"
)
assert id(crew_copy._long_term_memory) != original_long_term_id, (
"Copied _long_term_memory should be a new object"
)
assert hasattr(
crew_copy, "_long_term_memory"
), "Copied crew should have _long_term_memory"
assert (
crew_copy._long_term_memory is not None
), "Copied _long_term_memory should not be None"
assert (
id(crew_copy._long_term_memory) != original_long_term_id
), "Copied _long_term_memory should be a new object"
assert hasattr(crew_copy, "_entity_memory"), (
"Copied crew should have _entity_memory"
)
assert crew_copy._entity_memory is not None, (
"Copied _entity_memory should not be None"
)
assert id(crew_copy._entity_memory) != original_entity_id, (
"Copied _entity_memory should be a new object"
)
assert hasattr(
crew_copy, "_entity_memory"
), "Copied crew should have _entity_memory"
assert (
crew_copy._entity_memory is not None
), "Copied _entity_memory should not be None"
assert (
id(crew_copy._entity_memory) != original_entity_id
), "Copied _entity_memory should be a new object"
if original_external_id:
assert hasattr(crew_copy, "_external_memory"), (
"Copied crew should have _external_memory"
)
assert crew_copy._external_memory is not None, (
"Copied _external_memory should not be None"
)
assert id(crew_copy._external_memory) != original_external_id, (
"Copied _external_memory should be a new object"
)
assert hasattr(
crew_copy, "_external_memory"
), "Copied crew should have _external_memory"
assert (
crew_copy._external_memory is not None
), "Copied _external_memory should not be None"
assert (
id(crew_copy._external_memory) != original_external_id
), "Copied _external_memory should be a new object"
else:
assert (
not hasattr(crew_copy, "_external_memory")