mirror of https://github.com/crewAIInc/crewAI.git
synced 2026-01-05 22:28:29 +00:00

Compare commits: devin/1756...gl/feat/sy (4 commits)

| Author | SHA1 | Date |
|---|---|---|
| | d4c27d22cf | |
| | 109de91d08 | |
| | 92b70e652d | |
| | fc3f2c49d2 | |
@@ -1,13 +1,13 @@
 ---
 title: Weaviate Vector Search
-description: The `WeaviateVectorSearchTool` is designed to search a Weaviate vector database for semantically similar documents.
+description: The `WeaviateVectorSearchTool` is designed to search a Weaviate vector database for semantically similar documents using hybrid search.
 icon: network-wired
 ---

 ## Overview

-The `WeaviateVectorSearchTool` is specifically crafted for conducting semantic searches within documents stored in a Weaviate vector database. This tool allows you to find semantically similar documents to a given query, leveraging the power of vector embeddings for more accurate and contextually relevant search results.
+The `WeaviateVectorSearchTool` is specifically crafted for conducting semantic searches within documents stored in a Weaviate vector database. This tool allows you to find semantically similar documents to a given query, leveraging the power of vector and keyword search for more accurate and contextually relevant search results.

 [Weaviate](https://weaviate.io/) is a vector database that stores and queries vector embeddings, enabling semantic search capabilities.

@@ -39,6 +39,7 @@ from crewai_tools import WeaviateVectorSearchTool
 tool = WeaviateVectorSearchTool(
     collection_name='example_collections',
     limit=3,
+    alpha=0.75,
     weaviate_cluster_url="https://your-weaviate-cluster-url.com",
     weaviate_api_key="your-weaviate-api-key",
 )
@@ -63,6 +64,7 @@ The `WeaviateVectorSearchTool` accepts the following parameters:
 - **weaviate_cluster_url**: Required. The URL of the Weaviate cluster.
 - **weaviate_api_key**: Required. The API key for the Weaviate cluster.
 - **limit**: Optional. The number of results to return. Default is `3`.
+- **alpha**: Optional. Controls the weighting between vector and keyword (BM25) search. alpha = 0 -> BM25 only, alpha = 1 -> vector search only. Default is `0.75`.
 - **vectorizer**: Optional. The vectorizer to use. If not provided, it will use `text2vec_openai` with the `nomic-embed-text` model.
 - **generative_model**: Optional. The generative model to use. If not provided, it will use OpenAI's `gpt-4o`.

@@ -78,6 +80,7 @@ from weaviate.classes.config import Configure
 tool = WeaviateVectorSearchTool(
     collection_name='example_collections',
     limit=3,
+    alpha=0.75,
     vectorizer=Configure.Vectorizer.text2vec_openai(model="nomic-embed-text"),
     generative_model=Configure.Generative.openai(model="gpt-4o-mini"),
     weaviate_cluster_url="https://your-weaviate-cluster-url.com",
@@ -128,6 +131,7 @@ with test_docs.batch.dynamic() as batch:
 tool = WeaviateVectorSearchTool(
     collection_name='example_collections',
     limit=3,
+    alpha=0.75,
     weaviate_cluster_url="https://your-weaviate-cluster-url.com",
     weaviate_api_key="your-weaviate-api-key",
 )
@@ -145,6 +149,7 @@ from crewai_tools import WeaviateVectorSearchTool
 weaviate_tool = WeaviateVectorSearchTool(
     collection_name='example_collections',
     limit=3,
+    alpha=0.75,
     weaviate_cluster_url="https://your-weaviate-cluster-url.com",
     weaviate_api_key="your-weaviate-api-key",
 )
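Reviewer note: the new `alpha` parameter is the usual hybrid-search weighting, where 0 means pure BM25 keyword matching and 1 means pure vector similarity. A short sketch contrasting two weightings, built from the same snippets as the docs above (cluster URL, API key, and collection name are placeholders):

from crewai_tools import WeaviateVectorSearchTool

# Keyword-leaning retrieval: useful for exact terms, IDs, and rare tokens.
keyword_heavy = WeaviateVectorSearchTool(
    collection_name="example_collections",
    limit=3,
    alpha=0.25,
    weaviate_cluster_url="https://your-weaviate-cluster-url.com",
    weaviate_api_key="your-weaviate-api-key",
)

# Semantic-leaning retrieval: useful for paraphrased or conceptual queries.
vector_heavy = WeaviateVectorSearchTool(
    collection_name="example_collections",
    limit=3,
    alpha=0.9,
    weaviate_cluster_url="https://your-weaviate-cluster-url.com",
    weaviate_api_key="your-weaviate-api-key",
)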
@@ -98,8 +98,8 @@ class CrewAgentExecutorMixin:
                 )
                 self.crew._long_term_memory.save(long_term_memory)

-                for entity in evaluation.entities:
-                    entity_memory = EntityMemoryItem(
+                entity_memories = [
+                    EntityMemoryItem(
                         name=entity.name,
                         type=entity.type,
                         description=entity.description,
@@ -107,7 +107,10 @@ class CrewAgentExecutorMixin:
                             [f"- {r}" for r in entity.relationships]
                         ),
                     )
-                    self.crew._entity_memory.save(entity_memory)
+                    for entity in evaluation.entities
+                ]
+                if entity_memories:
+                    self.crew._entity_memory.save(entity_memories)
             except AttributeError as e:
                 print(f"Missing attributes for long term memory: {e}")
                 pass
@@ -1,6 +1 @@
 ALGORITHMS = ["RS256"]
-
-#TODO: The AUTH0 constants should be removed after WorkOS migration is completed
-AUTH0_DOMAIN = "crewai.us.auth0.com"
-AUTH0_CLIENT_ID = "DEVC5Fw6NlRoSzmDCcOhVq85EfLBjKa8"
-AUTH0_AUDIENCE = "https://crewai.us.auth0.com/api/v2/"
@@ -9,14 +9,7 @@ from pydantic import BaseModel, Field

 from .utils import validate_jwt_token
 from crewai.cli.shared.token_manager import TokenManager
-from urllib.parse import quote
-from crewai.cli.plus_api import PlusAPI
 from crewai.cli.config import Settings
-from crewai.cli.authentication.constants import (
-    AUTH0_AUDIENCE,
-    AUTH0_CLIENT_ID,
-    AUTH0_DOMAIN,
-)

 console = Console()

@@ -72,18 +65,6 @@ class AuthenticationCommand:
         """Sign up to CrewAI+"""
         console.print("Signing in to CrewAI Enterprise...\n", style="bold blue")

-        # TODO: WORKOS - Next line and conditional are temporary until migration to WorkOS is complete.
-        user_provider = self._determine_user_provider()
-        if user_provider == "auth0":
-            settings = Oauth2Settings(
-                provider="auth0",
-                client_id=AUTH0_CLIENT_ID,
-                domain=AUTH0_DOMAIN,
-                audience=AUTH0_AUDIENCE,
-            )
-            self.oauth2_provider = ProviderFactory.from_settings(settings)
-        # End of temporary code.
-
         device_code_data = self._get_device_code()
         self._display_auth_instructions(device_code_data)

@@ -206,30 +187,3 @@ class AuthenticationCommand:
             "\nRun [bold]crewai login[/bold] to try logging in again.\n",
             style="yellow",
         )
-
-    # TODO: WORKOS - This method is temporary until migration to WorkOS is complete.
-    def _determine_user_provider(self) -> str:
-        """Determine which provider to use for authentication."""
-
-        console.print(
-            "Enter your CrewAI Enterprise account email: ", style="bold blue", end=""
-        )
-        email = input()
-        email_encoded = quote(email)
-
-        # It's not correct to call this method directly, but it's temporary until migration is complete.
-        response = PlusAPI("")._make_request(
-            "GET", f"/crewai_plus/api/v1/me/provider?email={email_encoded}"
-        )
-
-        if response.status_code == 200:
-            if response.json().get("provider") == "auth0":
-                return "auth0"
-            else:
-                return "workos"
-        else:
-            console.print(
-                "Error: Failed to authenticate with crewai enterprise. Ensure that you are using the latest crewai version and please try again. If the problem persists, contact support@crewai.com.",
-                style="red",
-            )
-            raise SystemExit
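For context on what the removed `_determine_user_provider` relied on: it issued a plain GET and branched on the JSON body. This is a sketch of the expected response shape inferred from the removed code, not documented API behavior:

# GET /crewai_plus/api/v1/me/provider?email=<url-encoded email>
#
# 200 OK with {"provider": "auth0"}  -> legacy Auth0 login flow
# 200 OK with any other provider     -> WorkOS login flow ("workos")
# any non-200 status                 -> the CLI printed an error and raised SystemExit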
@@ -1,4 +1,4 @@
-from typing import Optional
+from typing import Any
 import time

 from pydantic import PrivateAttr
@@ -24,7 +24,7 @@ class EntityMemory(Memory):
     Inherits from the Memory class.
     """

-    _memory_provider: Optional[str] = PrivateAttr()
+    _memory_provider: str | None = PrivateAttr()

     def __init__(self, crew=None, embedder_config=None, storage=None, path=None):
         memory_provider = embedder_config.get("provider") if embedder_config else None
@@ -53,12 +53,33 @@ class EntityMemory(Memory):
         super().__init__(storage=storage)
         self._memory_provider = memory_provider

-    def save(self, item: EntityMemoryItem) -> None:  # type: ignore # BUG?: Signature of "save" incompatible with supertype "Memory"
-        """Saves an entity item into the SQLite storage."""
+    def save(
+        self,
+        value: EntityMemoryItem | list[EntityMemoryItem],
+        metadata: dict[str, Any] | None = None,
+    ) -> None:
+        """Saves one or more entity items into the SQLite storage.
+
+        Args:
+            value: Single EntityMemoryItem or list of EntityMemoryItems to save.
+            metadata: Optional metadata dict (included for supertype compatibility but not used).
+
+        Notes:
+            The metadata parameter is included to satisfy the supertype signature but is not
+            used - entity metadata is extracted from the EntityMemoryItem objects themselves.
+        """
+
+        if not value:
+            return
+
+        items = value if isinstance(value, list) else [value]
+        is_batch = len(items) > 1
+
+        metadata = {"entity_count": len(items)} if is_batch else items[0].metadata
         crewai_event_bus.emit(
             self,
             event=MemorySaveStartedEvent(
-                metadata=item.metadata,
+                metadata=metadata,
                 source_type="entity_memory",
                 from_agent=self.agent,
                 from_task=self.task,
@@ -66,36 +87,61 @@ class EntityMemory(Memory):
         )

         start_time = time.time()
+        saved_count = 0
+        errors = []
+
         try:
-            if self._memory_provider == "mem0":
-                data = f"""
-                Remember details about the following entity:
-                Name: {item.name}
-                Type: {item.type}
-                Entity Description: {item.description}
-                """
+            for item in items:
+                try:
+                    if self._memory_provider == "mem0":
+                        data = f"""
+                        Remember details about the following entity:
+                        Name: {item.name}
+                        Type: {item.type}
+                        Entity Description: {item.description}
+                        """
+                    else:
+                        data = f"{item.name}({item.type}): {item.description}"
+
+                    super().save(data, item.metadata)
+                    saved_count += 1
+                except Exception as e:
+                    errors.append(f"{item.name}: {str(e)}")
+
+            if is_batch:
+                emit_value = f"Saved {saved_count} entities"
+                metadata = {"entity_count": saved_count, "errors": errors}
-            else:
-                data = f"{item.name}({item.type}): {item.description}"
+            else:
+                emit_value = f"{items[0].name}({items[0].type}): {items[0].description}"
+                metadata = items[0].metadata

-            super().save(data, item.metadata)
-
             # Emit memory save completed event
             crewai_event_bus.emit(
                 self,
                 event=MemorySaveCompletedEvent(
-                    value=data,
-                    metadata=item.metadata,
+                    value=emit_value,
+                    metadata=metadata,
                     save_time_ms=(time.time() - start_time) * 1000,
                     source_type="entity_memory",
                     from_agent=self.agent,
                     from_task=self.task,
                 ),
             )
+
+            if errors:
+                raise Exception(
+                    f"Partial save: {len(errors)} failed out of {len(items)}"
+                )
+
         except Exception as e:
+            fail_metadata = (
+                {"entity_count": len(items), "saved": saved_count}
+                if is_batch
+                else items[0].metadata
+            )
             crewai_event_bus.emit(
                 self,
                 event=MemorySaveFailedEvent(
-                    metadata=item.metadata,
+                    metadata=fail_metadata,
                     error=str(e),
                     source_type="entity_memory",
                     from_agent=self.agent,
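With the widened signature, callers can pass either a single item or a list in one call; the per-item loop, error collection, and the single completed/failed event are handled inside `save`. A hedged usage sketch (field values are placeholders, `entity_memory` is an existing EntityMemory instance, and the import path may differ by version):

from crewai.memory.entity.entity_memory_item import EntityMemoryItem

items = [
    EntityMemoryItem(
        name="Acme Corp",
        type="organization",
        description="Customer mentioned in the task output",
        relationships="- supplies widgets to Globex",
    ),
    EntityMemoryItem(
        name="Globex",
        type="organization",
        description="Partner company",
        relationships="- buys widgets from Acme Corp",
    ),
]

entity_memory.save(items)      # batch: one save call, one completed event for all items
entity_memory.save(items[0])   # a single item still works as before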
@@ -14,7 +14,7 @@ class _MissingProvider:
     Raises RuntimeError when instantiated to indicate missing dependencies.
     """

-    provider: Literal["chromadb", "qdrant", "elasticsearch", "__missing__"] = field(
+    provider: Literal["chromadb", "qdrant", "__missing__"] = field(
         default="__missing__"
     )

@@ -9,8 +9,6 @@ if TYPE_CHECKING:
     from crewai.rag.chromadb.config import ChromaDBConfig
     from crewai.rag.qdrant.client import QdrantClient
     from crewai.rag.qdrant.config import QdrantConfig
-    from crewai.rag.elasticsearch.client import ElasticsearchClient
-    from crewai.rag.elasticsearch.config import ElasticsearchConfig


 class ChromaFactoryModule(Protocol):
@@ -27,11 +25,3 @@ class QdrantFactoryModule(Protocol):
     def create_client(self, config: QdrantConfig) -> QdrantClient:
         """Creates a Qdrant client from configuration."""
         ...
-
-
-class ElasticsearchFactoryModule(Protocol):
-    """Protocol for Elasticsearch factory module."""
-
-    def create_client(self, config: ElasticsearchConfig) -> ElasticsearchClient:
-        """Creates an Elasticsearch client from configuration."""
-        ...
@@ -20,10 +20,3 @@ class MissingQdrantConfig(_MissingProvider):
     """Placeholder for missing Qdrant configuration."""

     provider: Literal["qdrant"] = field(default="qdrant")
-
-
-@pyd_dataclass(config=ConfigDict(extra="forbid"))
-class MissingElasticsearchConfig(_MissingProvider):
-    """Placeholder for missing Elasticsearch configuration."""
-
-    provider: Literal["elasticsearch"] = field(default="elasticsearch")
@@ -3,6 +3,6 @@
 from typing import Annotated, Literal

 SupportedProvider = Annotated[
-    Literal["chromadb", "qdrant", "elasticsearch"],
+    Literal["chromadb", "qdrant"],
     "Supported RAG provider types, add providers here as they become available",
 ]
@@ -13,9 +13,6 @@ if TYPE_CHECKING:
     from crewai.rag.qdrant.config import QdrantConfig as QdrantConfig_

     QdrantConfig = QdrantConfig_
-    from crewai.rag.elasticsearch.config import ElasticsearchConfig as ElasticsearchConfig_
-
-    ElasticsearchConfig = ElasticsearchConfig_
 else:
     try:
         from crewai.rag.chromadb.config import ChromaDBConfig
@@ -31,14 +28,7 @@ else:
             MissingQdrantConfig as QdrantConfig,
         )

-    try:
-        from crewai.rag.elasticsearch.config import ElasticsearchConfig
-    except ImportError:
-        from crewai.rag.config.optional_imports.providers import (
-            MissingElasticsearchConfig as ElasticsearchConfig,
-        )
-
-SupportedProviderConfig: TypeAlias = ChromaDBConfig | QdrantConfig | ElasticsearchConfig
+SupportedProviderConfig: TypeAlias = ChromaDBConfig | QdrantConfig
 RagConfigType: TypeAlias = Annotated[
     SupportedProviderConfig, Field(discriminator=DISCRIMINATOR)
 ]
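The optional-import pattern above means the config names always resolve: either to the real config class or, when the matching extra is not installed, to a `_MissingProvider` placeholder that raises as soon as it is instantiated. A sketch of the resulting behavior (assuming the qdrant extra is absent; the error text is illustrative):

from crewai.rag.config.types import QdrantConfig

try:
    config = QdrantConfig()  # placeholder raises because qdrant support is not installed
except RuntimeError as exc:
    print(f"Qdrant provider unavailable: {exc}")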
@@ -1 +0,0 @@
-"""Elasticsearch RAG implementation."""
@@ -1,502 +0,0 @@
|
||||
"""Elasticsearch client implementation."""
|
||||
|
||||
from typing import Any, cast
|
||||
|
||||
from typing_extensions import Unpack
|
||||
|
||||
from crewai.rag.core.base_client import (
|
||||
BaseClient,
|
||||
BaseCollectionParams,
|
||||
BaseCollectionAddParams,
|
||||
BaseCollectionSearchParams,
|
||||
)
|
||||
from crewai.rag.core.exceptions import ClientMethodMismatchError
|
||||
from crewai.rag.elasticsearch.types import (
|
||||
AsyncEmbeddingFunction,
|
||||
EmbeddingFunction,
|
||||
ElasticsearchClientType,
|
||||
ElasticsearchCollectionCreateParams,
|
||||
)
|
||||
from crewai.rag.elasticsearch.utils import (
|
||||
_is_async_client,
|
||||
_is_async_embedding_function,
|
||||
_is_sync_client,
|
||||
_prepare_document_for_elasticsearch,
|
||||
_process_search_results,
|
||||
_build_vector_search_query,
|
||||
_get_index_mapping,
|
||||
)
|
||||
from crewai.rag.types import SearchResult
|
||||
|
||||
|
||||
class ElasticsearchClient(BaseClient):
|
||||
"""Elasticsearch implementation of the BaseClient protocol.
|
||||
|
||||
Provides vector database operations for Elasticsearch, supporting both
|
||||
synchronous and asynchronous clients.
|
||||
|
||||
Attributes:
|
||||
client: Elasticsearch client instance (Elasticsearch or AsyncElasticsearch).
|
||||
embedding_function: Function to generate embeddings for documents.
|
||||
vector_dimension: Dimension of the embedding vectors.
|
||||
similarity: Similarity function to use for vector search.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
client: ElasticsearchClientType,
|
||||
embedding_function: EmbeddingFunction | AsyncEmbeddingFunction,
|
||||
vector_dimension: int = 384,
|
||||
similarity: str = "cosine",
|
||||
) -> None:
|
||||
"""Initialize ElasticsearchClient with client and embedding function.
|
||||
|
||||
Args:
|
||||
client: Pre-configured Elasticsearch client instance.
|
||||
embedding_function: Embedding function for text to vector conversion.
|
||||
vector_dimension: Dimension of the embedding vectors.
|
||||
similarity: Similarity function to use for vector search.
|
||||
"""
|
||||
self.client = client
|
||||
self.embedding_function = embedding_function
|
||||
self.vector_dimension = vector_dimension
|
||||
self.similarity = similarity
|
||||
|
||||
def create_collection(self, **kwargs: Unpack[ElasticsearchCollectionCreateParams]) -> None:
|
||||
"""Create a new index in Elasticsearch.
|
||||
|
||||
Keyword Args:
|
||||
collection_name: Name of the index to create. Must be unique.
|
||||
index_settings: Optional index settings.
|
||||
vector_dimension: Optional vector dimension override.
|
||||
similarity: Optional similarity function override.
|
||||
|
||||
Raises:
|
||||
ValueError: If index with the same name already exists.
|
||||
ConnectionError: If unable to connect to Elasticsearch server.
|
||||
"""
|
||||
if not _is_sync_client(self.client):
|
||||
raise ClientMethodMismatchError(
|
||||
method_name="create_collection",
|
||||
expected_client="Elasticsearch",
|
||||
alt_method="acreate_collection",
|
||||
alt_client="AsyncElasticsearch",
|
||||
)
|
||||
|
||||
collection_name = kwargs["collection_name"]
|
||||
|
||||
if self.client.indices.exists(index=collection_name):
|
||||
raise ValueError(f"Index '{collection_name}' already exists")
|
||||
|
||||
vector_dim = kwargs.get("vector_dimension", self.vector_dimension)
|
||||
similarity = kwargs.get("similarity", self.similarity)
|
||||
|
||||
mapping = _get_index_mapping(vector_dim, similarity)
|
||||
|
||||
index_settings = kwargs.get("index_settings", {})
|
||||
if index_settings:
|
||||
mapping["settings"] = index_settings
|
||||
|
||||
self.client.indices.create(index=collection_name, body=mapping)
|
||||
|
||||
async def acreate_collection(self, **kwargs: Unpack[ElasticsearchCollectionCreateParams]) -> None:
|
||||
"""Create a new index in Elasticsearch asynchronously.
|
||||
|
||||
Keyword Args:
|
||||
collection_name: Name of the index to create. Must be unique.
|
||||
index_settings: Optional index settings.
|
||||
vector_dimension: Optional vector dimension override.
|
||||
similarity: Optional similarity function override.
|
||||
|
||||
Raises:
|
||||
ValueError: If index with the same name already exists.
|
||||
ConnectionError: If unable to connect to Elasticsearch server.
|
||||
"""
|
||||
if not _is_async_client(self.client):
|
||||
raise ClientMethodMismatchError(
|
||||
method_name="acreate_collection",
|
||||
expected_client="AsyncElasticsearch",
|
||||
alt_method="create_collection",
|
||||
alt_client="Elasticsearch",
|
||||
)
|
||||
|
||||
collection_name = kwargs["collection_name"]
|
||||
|
||||
if await self.client.indices.exists(index=collection_name):
|
||||
raise ValueError(f"Index '{collection_name}' already exists")
|
||||
|
||||
vector_dim = kwargs.get("vector_dimension", self.vector_dimension)
|
||||
similarity = kwargs.get("similarity", self.similarity)
|
||||
|
||||
mapping = _get_index_mapping(vector_dim, similarity)
|
||||
|
||||
index_settings = kwargs.get("index_settings", {})
|
||||
if index_settings:
|
||||
mapping["settings"] = index_settings
|
||||
|
||||
await self.client.indices.create(index=collection_name, body=mapping)
|
||||
|
||||
def get_or_create_collection(self, **kwargs: Unpack[ElasticsearchCollectionCreateParams]) -> Any:
|
||||
"""Get an existing index or create it if it doesn't exist.
|
||||
|
||||
Keyword Args:
|
||||
collection_name: Name of the index to get or create.
|
||||
index_settings: Optional index settings.
|
||||
vector_dimension: Optional vector dimension override.
|
||||
similarity: Optional similarity function override.
|
||||
|
||||
Returns:
|
||||
Index info dict with name and other metadata.
|
||||
|
||||
Raises:
|
||||
ConnectionError: If unable to connect to Elasticsearch server.
|
||||
"""
|
||||
if not _is_sync_client(self.client):
|
||||
raise ClientMethodMismatchError(
|
||||
method_name="get_or_create_collection",
|
||||
expected_client="Elasticsearch",
|
||||
alt_method="aget_or_create_collection",
|
||||
alt_client="AsyncElasticsearch",
|
||||
)
|
||||
|
||||
collection_name = kwargs["collection_name"]
|
||||
|
||||
if self.client.indices.exists(index=collection_name):
|
||||
return self.client.indices.get(index=collection_name)
|
||||
|
||||
vector_dim = kwargs.get("vector_dimension", self.vector_dimension)
|
||||
similarity = kwargs.get("similarity", self.similarity)
|
||||
|
||||
mapping = _get_index_mapping(vector_dim, similarity)
|
||||
|
||||
index_settings = kwargs.get("index_settings", {})
|
||||
if index_settings:
|
||||
mapping["settings"] = index_settings
|
||||
|
||||
self.client.indices.create(index=collection_name, body=mapping)
|
||||
return self.client.indices.get(index=collection_name)
|
||||
|
||||
async def aget_or_create_collection(self, **kwargs: Unpack[ElasticsearchCollectionCreateParams]) -> Any:
|
||||
"""Get an existing index or create it if it doesn't exist asynchronously.
|
||||
|
||||
Keyword Args:
|
||||
collection_name: Name of the index to get or create.
|
||||
index_settings: Optional index settings.
|
||||
vector_dimension: Optional vector dimension override.
|
||||
similarity: Optional similarity function override.
|
||||
|
||||
Returns:
|
||||
Index info dict with name and other metadata.
|
||||
|
||||
Raises:
|
||||
ConnectionError: If unable to connect to Elasticsearch server.
|
||||
"""
|
||||
if not _is_async_client(self.client):
|
||||
raise ClientMethodMismatchError(
|
||||
method_name="aget_or_create_collection",
|
||||
expected_client="AsyncElasticsearch",
|
||||
alt_method="get_or_create_collection",
|
||||
alt_client="Elasticsearch",
|
||||
)
|
||||
|
||||
collection_name = kwargs["collection_name"]
|
||||
|
||||
if await self.client.indices.exists(index=collection_name):
|
||||
return await self.client.indices.get(index=collection_name)
|
||||
|
||||
vector_dim = kwargs.get("vector_dimension", self.vector_dimension)
|
||||
similarity = kwargs.get("similarity", self.similarity)
|
||||
|
||||
mapping = _get_index_mapping(vector_dim, similarity)
|
||||
|
||||
index_settings = kwargs.get("index_settings", {})
|
||||
if index_settings:
|
||||
mapping["settings"] = index_settings
|
||||
|
||||
await self.client.indices.create(index=collection_name, body=mapping)
|
||||
return await self.client.indices.get(index=collection_name)
|
||||
|
||||
def add_documents(self, **kwargs: Unpack[BaseCollectionAddParams]) -> None:
|
||||
"""Add documents with their embeddings to an index.
|
||||
|
||||
Keyword Args:
|
||||
collection_name: The name of the index to add documents to.
|
||||
documents: List of BaseRecord dicts containing document data.
|
||||
|
||||
Raises:
|
||||
ValueError: If index doesn't exist or documents list is empty.
|
||||
ConnectionError: If unable to connect to Elasticsearch server.
|
||||
"""
|
||||
if not _is_sync_client(self.client):
|
||||
raise ClientMethodMismatchError(
|
||||
method_name="add_documents",
|
||||
expected_client="Elasticsearch",
|
||||
alt_method="aadd_documents",
|
||||
alt_client="AsyncElasticsearch",
|
||||
)
|
||||
|
||||
collection_name = kwargs["collection_name"]
|
||||
documents = kwargs["documents"]
|
||||
|
||||
if not documents:
|
||||
raise ValueError("Documents list cannot be empty")
|
||||
|
||||
if not self.client.indices.exists(index=collection_name):
|
||||
raise ValueError(f"Index '{collection_name}' does not exist")
|
||||
|
||||
for doc in documents:
|
||||
if _is_async_embedding_function(self.embedding_function):
|
||||
raise TypeError(
|
||||
"Async embedding function cannot be used with sync add_documents. "
|
||||
"Use aadd_documents instead."
|
||||
)
|
||||
sync_fn = cast(EmbeddingFunction, self.embedding_function)
|
||||
embedding = sync_fn(doc["content"])
|
||||
prepared_doc = _prepare_document_for_elasticsearch(doc, embedding)
|
||||
|
||||
self.client.index(
|
||||
index=collection_name,
|
||||
id=prepared_doc["id"],
|
||||
body=prepared_doc["body"]
|
||||
)
|
||||
|
||||
async def aadd_documents(self, **kwargs: Unpack[BaseCollectionAddParams]) -> None:
|
||||
"""Add documents with their embeddings to an index asynchronously.
|
||||
|
||||
Keyword Args:
|
||||
collection_name: The name of the index to add documents to.
|
||||
documents: List of BaseRecord dicts containing document data.
|
||||
|
||||
Raises:
|
||||
ValueError: If index doesn't exist or documents list is empty.
|
||||
ConnectionError: If unable to connect to Elasticsearch server.
|
||||
"""
|
||||
if not _is_async_client(self.client):
|
||||
raise ClientMethodMismatchError(
|
||||
method_name="aadd_documents",
|
||||
expected_client="AsyncElasticsearch",
|
||||
alt_method="add_documents",
|
||||
alt_client="Elasticsearch",
|
||||
)
|
||||
|
||||
collection_name = kwargs["collection_name"]
|
||||
documents = kwargs["documents"]
|
||||
|
||||
if not documents:
|
||||
raise ValueError("Documents list cannot be empty")
|
||||
|
||||
if not await self.client.indices.exists(index=collection_name):
|
||||
raise ValueError(f"Index '{collection_name}' does not exist")
|
||||
|
||||
for doc in documents:
|
||||
if _is_async_embedding_function(self.embedding_function):
|
||||
async_fn = cast(AsyncEmbeddingFunction, self.embedding_function)
|
||||
embedding = await async_fn(doc["content"])
|
||||
else:
|
||||
sync_fn = cast(EmbeddingFunction, self.embedding_function)
|
||||
embedding = sync_fn(doc["content"])
|
||||
|
||||
prepared_doc = _prepare_document_for_elasticsearch(doc, embedding)
|
||||
|
||||
await self.client.index(
|
||||
index=collection_name,
|
||||
id=prepared_doc["id"],
|
||||
body=prepared_doc["body"]
|
||||
)
|
||||
|
||||
def search(
|
||||
self, **kwargs: Unpack[BaseCollectionSearchParams]
|
||||
) -> list[SearchResult]:
|
||||
"""Search for similar documents using a query.
|
||||
|
||||
Keyword Args:
|
||||
collection_name: Name of the index to search in.
|
||||
query: The text query to search for.
|
||||
limit: Maximum number of results to return (default: 10).
|
||||
metadata_filter: Optional filter for metadata fields.
|
||||
score_threshold: Optional minimum similarity score (0-1) for results.
|
||||
|
||||
Returns:
|
||||
List of SearchResult dicts containing id, content, metadata, and score.
|
||||
|
||||
Raises:
|
||||
ValueError: If index doesn't exist.
|
||||
ConnectionError: If unable to connect to Elasticsearch server.
|
||||
"""
|
||||
if not _is_sync_client(self.client):
|
||||
raise ClientMethodMismatchError(
|
||||
method_name="search",
|
||||
expected_client="Elasticsearch",
|
||||
alt_method="asearch",
|
||||
alt_client="AsyncElasticsearch",
|
||||
)
|
||||
|
||||
collection_name = kwargs["collection_name"]
|
||||
query = kwargs["query"]
|
||||
limit = kwargs.get("limit", 10)
|
||||
metadata_filter = kwargs.get("metadata_filter")
|
||||
score_threshold = kwargs.get("score_threshold")
|
||||
|
||||
if not self.client.indices.exists(index=collection_name):
|
||||
raise ValueError(f"Index '{collection_name}' does not exist")
|
||||
|
||||
if _is_async_embedding_function(self.embedding_function):
|
||||
raise TypeError(
|
||||
"Async embedding function cannot be used with sync search. "
|
||||
"Use asearch instead."
|
||||
)
|
||||
sync_fn = cast(EmbeddingFunction, self.embedding_function)
|
||||
query_embedding = sync_fn(query)
|
||||
|
||||
search_query = _build_vector_search_query(
|
||||
query_vector=query_embedding,
|
||||
limit=limit,
|
||||
metadata_filter=metadata_filter,
|
||||
score_threshold=score_threshold,
|
||||
)
|
||||
|
||||
response = self.client.search(index=collection_name, body=search_query)
|
||||
return _process_search_results(response, score_threshold)
|
||||
|
||||
async def asearch(
|
||||
self, **kwargs: Unpack[BaseCollectionSearchParams]
|
||||
) -> list[SearchResult]:
|
||||
"""Search for similar documents using a query asynchronously.
|
||||
|
||||
Keyword Args:
|
||||
collection_name: Name of the index to search in.
|
||||
query: The text query to search for.
|
||||
limit: Maximum number of results to return (default: 10).
|
||||
metadata_filter: Optional filter for metadata fields.
|
||||
score_threshold: Optional minimum similarity score (0-1) for results.
|
||||
|
||||
Returns:
|
||||
List of SearchResult dicts containing id, content, metadata, and score.
|
||||
|
||||
Raises:
|
||||
ValueError: If index doesn't exist.
|
||||
ConnectionError: If unable to connect to Elasticsearch server.
|
||||
"""
|
||||
if not _is_async_client(self.client):
|
||||
raise ClientMethodMismatchError(
|
||||
method_name="asearch",
|
||||
expected_client="AsyncElasticsearch",
|
||||
alt_method="search",
|
||||
alt_client="Elasticsearch",
|
||||
)
|
||||
|
||||
collection_name = kwargs["collection_name"]
|
||||
query = kwargs["query"]
|
||||
limit = kwargs.get("limit", 10)
|
||||
metadata_filter = kwargs.get("metadata_filter")
|
||||
score_threshold = kwargs.get("score_threshold")
|
||||
|
||||
if not await self.client.indices.exists(index=collection_name):
|
||||
raise ValueError(f"Index '{collection_name}' does not exist")
|
||||
|
||||
if _is_async_embedding_function(self.embedding_function):
|
||||
async_fn = cast(AsyncEmbeddingFunction, self.embedding_function)
|
||||
query_embedding = await async_fn(query)
|
||||
else:
|
||||
sync_fn = cast(EmbeddingFunction, self.embedding_function)
|
||||
query_embedding = sync_fn(query)
|
||||
|
||||
search_query = _build_vector_search_query(
|
||||
query_vector=query_embedding,
|
||||
limit=limit,
|
||||
metadata_filter=metadata_filter,
|
||||
score_threshold=score_threshold,
|
||||
)
|
||||
|
||||
response = await self.client.search(index=collection_name, body=search_query)
|
||||
return _process_search_results(response, score_threshold)
|
||||
|
||||
def delete_collection(self, **kwargs: Unpack[BaseCollectionParams]) -> None:
|
||||
"""Delete an index and all its data.
|
||||
|
||||
Keyword Args:
|
||||
collection_name: Name of the index to delete.
|
||||
|
||||
Raises:
|
||||
ValueError: If index doesn't exist.
|
||||
ConnectionError: If unable to connect to Elasticsearch server.
|
||||
"""
|
||||
if not _is_sync_client(self.client):
|
||||
raise ClientMethodMismatchError(
|
||||
method_name="delete_collection",
|
||||
expected_client="Elasticsearch",
|
||||
alt_method="adelete_collection",
|
||||
alt_client="AsyncElasticsearch",
|
||||
)
|
||||
|
||||
collection_name = kwargs["collection_name"]
|
||||
|
||||
if not self.client.indices.exists(index=collection_name):
|
||||
raise ValueError(f"Index '{collection_name}' does not exist")
|
||||
|
||||
self.client.indices.delete(index=collection_name)
|
||||
|
||||
async def adelete_collection(self, **kwargs: Unpack[BaseCollectionParams]) -> None:
|
||||
"""Delete an index and all its data asynchronously.
|
||||
|
||||
Keyword Args:
|
||||
collection_name: Name of the index to delete.
|
||||
|
||||
Raises:
|
||||
ValueError: If index doesn't exist.
|
||||
ConnectionError: If unable to connect to Elasticsearch server.
|
||||
"""
|
||||
if not _is_async_client(self.client):
|
||||
raise ClientMethodMismatchError(
|
||||
method_name="adelete_collection",
|
||||
expected_client="AsyncElasticsearch",
|
||||
alt_method="delete_collection",
|
||||
alt_client="Elasticsearch",
|
||||
)
|
||||
|
||||
collection_name = kwargs["collection_name"]
|
||||
|
||||
if not await self.client.indices.exists(index=collection_name):
|
||||
raise ValueError(f"Index '{collection_name}' does not exist")
|
||||
|
||||
await self.client.indices.delete(index=collection_name)
|
||||
|
||||
def reset(self) -> None:
|
||||
"""Reset the vector database by deleting all indices and data.
|
||||
|
||||
Raises:
|
||||
ConnectionError: If unable to connect to Elasticsearch server.
|
||||
"""
|
||||
if not _is_sync_client(self.client):
|
||||
raise ClientMethodMismatchError(
|
||||
method_name="reset",
|
||||
expected_client="Elasticsearch",
|
||||
alt_method="areset",
|
||||
alt_client="AsyncElasticsearch",
|
||||
)
|
||||
|
||||
indices_response = self.client.indices.get(index="*")
|
||||
|
||||
for index_name in indices_response.keys():
|
||||
if not index_name.startswith("."):
|
||||
self.client.indices.delete(index=index_name)
|
||||
|
||||
async def areset(self) -> None:
|
||||
"""Reset the vector database by deleting all indices and data asynchronously.
|
||||
|
||||
Raises:
|
||||
ConnectionError: If unable to connect to Elasticsearch server.
|
||||
"""
|
||||
if not _is_async_client(self.client):
|
||||
raise ClientMethodMismatchError(
|
||||
method_name="areset",
|
||||
expected_client="AsyncElasticsearch",
|
||||
alt_method="reset",
|
||||
alt_client="Elasticsearch",
|
||||
)
|
||||
|
||||
indices_response = await self.client.indices.get(index="*")
|
||||
|
||||
for index_name in indices_response.keys():
|
||||
if not index_name.startswith("."):
|
||||
await self.client.indices.delete(index=index_name)
|
||||
@@ -1,92 +0,0 @@
|
||||
"""Elasticsearch configuration model."""
|
||||
|
||||
from dataclasses import field
|
||||
from typing import Literal, cast
|
||||
from pydantic.dataclasses import dataclass as pyd_dataclass
|
||||
|
||||
from crewai.rag.config.base import BaseRagConfig
|
||||
from crewai.rag.elasticsearch.types import (
|
||||
ElasticsearchClientParams,
|
||||
ElasticsearchEmbeddingFunctionWrapper,
|
||||
)
|
||||
from crewai.rag.elasticsearch.constants import (
|
||||
DEFAULT_HOST,
|
||||
DEFAULT_PORT,
|
||||
DEFAULT_EMBEDDING_MODEL,
|
||||
DEFAULT_VECTOR_DIMENSION,
|
||||
)
|
||||
|
||||
|
||||
def _default_options() -> ElasticsearchClientParams:
|
||||
"""Create default Elasticsearch client options.
|
||||
|
||||
Returns:
|
||||
Default options with local Elasticsearch connection.
|
||||
"""
|
||||
return ElasticsearchClientParams(
|
||||
hosts=[f"http://{DEFAULT_HOST}:{DEFAULT_PORT}"],
|
||||
use_ssl=False,
|
||||
verify_certs=False,
|
||||
timeout=30,
|
||||
)
|
||||
|
||||
|
||||
def _default_embedding_function() -> ElasticsearchEmbeddingFunctionWrapper:
|
||||
"""Create default Elasticsearch embedding function.
|
||||
|
||||
Returns:
|
||||
Default embedding function using sentence-transformers.
|
||||
"""
|
||||
try:
|
||||
from sentence_transformers import SentenceTransformer
|
||||
|
||||
model = SentenceTransformer(DEFAULT_EMBEDDING_MODEL)
|
||||
|
||||
def embed_fn(text: str) -> list[float]:
|
||||
"""Embed a single text string.
|
||||
|
||||
Args:
|
||||
text: Text to embed.
|
||||
|
||||
Returns:
|
||||
Embedding vector as list of floats.
|
||||
"""
|
||||
embedding = model.encode(text, convert_to_tensor=False)
|
||||
return embedding.tolist() if hasattr(embedding, 'tolist') else list(embedding)
|
||||
|
||||
return cast(ElasticsearchEmbeddingFunctionWrapper, embed_fn)
|
||||
except ImportError:
|
||||
def fallback_embed_fn(text: str) -> list[float]:
|
||||
"""Fallback embedding function when sentence-transformers is not available."""
|
||||
import hashlib
|
||||
import struct
|
||||
|
||||
hash_obj = hashlib.md5(text.encode(), usedforsecurity=False)
|
||||
hash_bytes = hash_obj.digest()
|
||||
|
||||
vector = []
|
||||
for i in range(0, len(hash_bytes), 4):
|
||||
chunk = hash_bytes[i:i+4]
|
||||
if len(chunk) == 4:
|
||||
value = struct.unpack('f', chunk)[0]
|
||||
vector.append(float(value))
|
||||
|
||||
while len(vector) < DEFAULT_VECTOR_DIMENSION:
|
||||
vector.extend(vector[:DEFAULT_VECTOR_DIMENSION - len(vector)])
|
||||
|
||||
return vector[:DEFAULT_VECTOR_DIMENSION]
|
||||
|
||||
return cast(ElasticsearchEmbeddingFunctionWrapper, fallback_embed_fn)
|
||||
|
||||
|
||||
@pyd_dataclass(frozen=True)
|
||||
class ElasticsearchConfig(BaseRagConfig):
|
||||
"""Configuration for Elasticsearch client."""
|
||||
|
||||
provider: Literal["elasticsearch"] = field(default="elasticsearch", init=False)
|
||||
options: ElasticsearchClientParams = field(default_factory=_default_options)
|
||||
vector_dimension: int = DEFAULT_VECTOR_DIMENSION
|
||||
similarity: str = "cosine"
|
||||
embedding_function: ElasticsearchEmbeddingFunctionWrapper = field(
|
||||
default_factory=_default_embedding_function
|
||||
)
|
||||
@@ -1,12 +0,0 @@
|
||||
"""Constants for Elasticsearch RAG implementation."""
|
||||
|
||||
from typing import Final
|
||||
|
||||
DEFAULT_HOST: Final[str] = "localhost"
|
||||
DEFAULT_PORT: Final[int] = 9200
|
||||
DEFAULT_INDEX_SETTINGS: Final[dict] = {
|
||||
"number_of_shards": 1,
|
||||
"number_of_replicas": 0,
|
||||
}
|
||||
DEFAULT_EMBEDDING_MODEL: Final[str] = "sentence-transformers/all-MiniLM-L6-v2"
|
||||
DEFAULT_VECTOR_DIMENSION: Final[int] = 384
|
||||
@@ -1,31 +0,0 @@
|
||||
"""Factory functions for creating Elasticsearch clients."""
|
||||
|
||||
from crewai.rag.elasticsearch.config import ElasticsearchConfig
|
||||
from crewai.rag.elasticsearch.client import ElasticsearchClient
|
||||
|
||||
|
||||
def create_client(config: ElasticsearchConfig) -> ElasticsearchClient:
|
||||
"""Create an ElasticsearchClient from configuration.
|
||||
|
||||
Args:
|
||||
config: Elasticsearch configuration object.
|
||||
|
||||
Returns:
|
||||
Configured ElasticsearchClient instance.
|
||||
"""
|
||||
try:
|
||||
from elasticsearch import Elasticsearch
|
||||
except ImportError as e:
|
||||
raise ImportError(
|
||||
"elasticsearch package is required for Elasticsearch support. "
|
||||
"Install it with: pip install elasticsearch"
|
||||
) from e
|
||||
|
||||
client = Elasticsearch(**config.options)
|
||||
|
||||
return ElasticsearchClient(
|
||||
client=client,
|
||||
embedding_function=config.embedding_function,
|
||||
vector_dimension=config.vector_dimension,
|
||||
similarity=config.similarity,
|
||||
)
|
||||
@@ -1,93 +0,0 @@
|
||||
"""Type definitions for Elasticsearch RAG implementation."""
|
||||
|
||||
from typing import Any, Protocol, Union, TYPE_CHECKING
|
||||
from typing_extensions import NotRequired, TypedDict
|
||||
from pydantic import GetCoreSchemaHandler
|
||||
from pydantic_core import CoreSchema, core_schema
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from typing import TypeAlias
|
||||
from elasticsearch import Elasticsearch, AsyncElasticsearch
|
||||
ElasticsearchClientType: TypeAlias = Union[Elasticsearch, AsyncElasticsearch]
|
||||
else:
|
||||
try:
|
||||
from elasticsearch import Elasticsearch, AsyncElasticsearch
|
||||
ElasticsearchClientType = Union[Elasticsearch, AsyncElasticsearch]
|
||||
except ImportError:
|
||||
ElasticsearchClientType = Any
|
||||
|
||||
|
||||
class ElasticsearchClientParams(TypedDict, total=False):
|
||||
"""Parameters for Elasticsearch client initialization."""
|
||||
|
||||
hosts: NotRequired[list[str]]
|
||||
cloud_id: NotRequired[str]
|
||||
username: NotRequired[str]
|
||||
password: NotRequired[str]
|
||||
api_key: NotRequired[str]
|
||||
use_ssl: NotRequired[bool]
|
||||
verify_certs: NotRequired[bool]
|
||||
ca_certs: NotRequired[str]
|
||||
timeout: NotRequired[int]
|
||||
|
||||
|
||||
class ElasticsearchIndexSettings(TypedDict, total=False):
|
||||
"""Settings for Elasticsearch index creation."""
|
||||
|
||||
number_of_shards: NotRequired[int]
|
||||
number_of_replicas: NotRequired[int]
|
||||
refresh_interval: NotRequired[str]
|
||||
|
||||
|
||||
class ElasticsearchCollectionCreateParams(TypedDict, total=False):
|
||||
"""Parameters for creating Elasticsearch collections/indices."""
|
||||
|
||||
collection_name: str
|
||||
index_settings: NotRequired[ElasticsearchIndexSettings]
|
||||
vector_dimension: NotRequired[int]
|
||||
similarity: NotRequired[str]
|
||||
|
||||
|
||||
class EmbeddingFunction(Protocol):
|
||||
"""Protocol for embedding functions that convert text to vectors."""
|
||||
|
||||
def __call__(self, text: str) -> list[float]:
|
||||
"""Convert text to embedding vector.
|
||||
|
||||
Args:
|
||||
text: Input text to embed.
|
||||
|
||||
Returns:
|
||||
Embedding vector as list of floats.
|
||||
"""
|
||||
...
|
||||
|
||||
|
||||
class AsyncEmbeddingFunction(Protocol):
|
||||
"""Protocol for async embedding functions that convert text to vectors."""
|
||||
|
||||
async def __call__(self, text: str) -> list[float]:
|
||||
"""Convert text to embedding vector asynchronously.
|
||||
|
||||
Args:
|
||||
text: Input text to embed.
|
||||
|
||||
Returns:
|
||||
Embedding vector as list of floats.
|
||||
"""
|
||||
...
|
||||
|
||||
|
||||
class ElasticsearchEmbeddingFunctionWrapper(EmbeddingFunction):
|
||||
"""Base class for Elasticsearch EmbeddingFunction to work with Pydantic validation."""
|
||||
|
||||
@classmethod
|
||||
def __get_pydantic_core_schema__(
|
||||
cls, _source_type: Any, _handler: GetCoreSchemaHandler
|
||||
) -> CoreSchema:
|
||||
"""Generate Pydantic core schema for Elasticsearch EmbeddingFunction.
|
||||
|
||||
This allows Pydantic to handle Elasticsearch's EmbeddingFunction type
|
||||
without requiring arbitrary_types_allowed=True.
|
||||
"""
|
||||
return core_schema.any_schema()
|
||||
@@ -1,186 +0,0 @@
|
||||
"""Utility functions for Elasticsearch RAG implementation."""
|
||||
|
||||
import hashlib
|
||||
from typing import Any, TypeGuard
|
||||
|
||||
from crewai.rag.elasticsearch.types import (
|
||||
AsyncEmbeddingFunction,
|
||||
EmbeddingFunction,
|
||||
ElasticsearchClientType,
|
||||
)
|
||||
from crewai.rag.types import BaseRecord, SearchResult
|
||||
|
||||
try:
|
||||
from elasticsearch import Elasticsearch, AsyncElasticsearch
|
||||
except ImportError:
|
||||
Elasticsearch = None
|
||||
AsyncElasticsearch = None
|
||||
|
||||
|
||||
def _is_sync_client(client: ElasticsearchClientType) -> TypeGuard[Any]:
|
||||
"""Type guard to check if the client is a sync Elasticsearch client."""
|
||||
if Elasticsearch is None:
|
||||
return False
|
||||
return isinstance(client, Elasticsearch)
|
||||
|
||||
|
||||
def _is_async_client(client: ElasticsearchClientType) -> TypeGuard[Any]:
|
||||
"""Type guard to check if the client is an async Elasticsearch client."""
|
||||
if AsyncElasticsearch is None:
|
||||
return False
|
||||
return isinstance(client, AsyncElasticsearch)
|
||||
|
||||
|
||||
def _is_async_embedding_function(
|
||||
func: EmbeddingFunction | AsyncEmbeddingFunction,
|
||||
) -> TypeGuard[AsyncEmbeddingFunction]:
|
||||
"""Type guard to check if the embedding function is async."""
|
||||
import inspect
|
||||
return inspect.iscoroutinefunction(func)
|
||||
|
||||
|
||||
def _generate_doc_id(content: str) -> str:
|
||||
"""Generate a document ID from content using SHA256 hash."""
|
||||
return hashlib.sha256(content.encode()).hexdigest()
|
||||
|
||||
|
||||
def _prepare_document_for_elasticsearch(
|
||||
doc: BaseRecord, embedding: list[float]
|
||||
) -> dict[str, Any]:
|
||||
"""Prepare a document for Elasticsearch indexing.
|
||||
|
||||
Args:
|
||||
doc: Document record to prepare.
|
||||
embedding: Embedding vector for the document.
|
||||
|
||||
Returns:
|
||||
Document formatted for Elasticsearch.
|
||||
"""
|
||||
doc_id = doc.get("doc_id") or _generate_doc_id(doc["content"])
|
||||
|
||||
es_doc = {
|
||||
"content": doc["content"],
|
||||
"content_vector": embedding,
|
||||
"metadata": doc.get("metadata", {}),
|
||||
}
|
||||
|
||||
return {"id": doc_id, "body": es_doc}
|
||||
|
||||
|
||||
def _process_search_results(
|
||||
response: dict[str, Any], score_threshold: float | None = None
|
||||
) -> list[SearchResult]:
|
||||
"""Process Elasticsearch search response into SearchResult format.
|
||||
|
||||
Args:
|
||||
response: Raw Elasticsearch search response.
|
||||
score_threshold: Optional minimum score threshold.
|
||||
|
||||
Returns:
|
||||
List of SearchResult dictionaries.
|
||||
"""
|
||||
results = []
|
||||
|
||||
hits = response.get("hits", {}).get("hits", [])
|
||||
|
||||
for hit in hits:
|
||||
score = hit.get("_score", 0.0)
|
||||
|
||||
if score_threshold is not None and score < score_threshold:
|
||||
continue
|
||||
|
||||
source = hit.get("_source", {})
|
||||
|
||||
result = SearchResult(
|
||||
id=hit.get("_id", ""),
|
||||
content=source.get("content", ""),
|
||||
metadata=source.get("metadata", {}),
|
||||
score=score,
|
||||
)
|
||||
results.append(result)
|
||||
|
||||
return results
|
||||
|
||||
|
||||
def _build_vector_search_query(
|
||||
query_vector: list[float],
|
||||
limit: int = 10,
|
||||
metadata_filter: dict[str, Any] | None = None,
|
||||
score_threshold: float | None = None,
|
||||
) -> dict[str, Any]:
|
||||
"""Build Elasticsearch query for vector similarity search.
|
||||
|
||||
Args:
|
||||
query_vector: Query embedding vector.
|
||||
limit: Maximum number of results.
|
||||
metadata_filter: Optional metadata filter.
|
||||
score_threshold: Optional minimum score threshold.
|
||||
|
||||
Returns:
|
||||
Elasticsearch query dictionary.
|
||||
"""
|
||||
query = {
|
||||
"size": limit,
|
||||
"query": {
|
||||
"script_score": {
|
||||
"query": {"match_all": {}},
|
||||
"script": {
|
||||
"source": "cosineSimilarity(params.query_vector, 'content_vector') + 1.0",
|
||||
"params": {"query_vector": query_vector}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if metadata_filter:
|
||||
bool_query = {
|
||||
"bool": {
|
||||
"must": [
|
||||
query["query"]
|
||||
],
|
||||
"filter": []
|
||||
}
|
||||
}
|
||||
|
||||
for key, value in metadata_filter.items():
|
||||
bool_query["bool"]["filter"].append({
|
||||
"term": {f"metadata.{key}": value}
|
||||
})
|
||||
|
||||
query["query"] = bool_query
|
||||
|
||||
if score_threshold is not None:
|
||||
query["min_score"] = score_threshold
|
||||
|
||||
return query
|
||||
|
||||
|
||||
def _get_index_mapping(vector_dimension: int, similarity: str = "cosine") -> dict[str, Any]:
|
||||
"""Get Elasticsearch index mapping for vector search.
|
||||
|
||||
Args:
|
||||
vector_dimension: Dimension of the embedding vectors.
|
||||
similarity: Similarity function to use.
|
||||
|
||||
Returns:
|
||||
Elasticsearch mapping dictionary.
|
||||
"""
|
||||
return {
|
||||
"mappings": {
|
||||
"properties": {
|
||||
"content": {
|
||||
"type": "text",
|
||||
"analyzer": "standard"
|
||||
},
|
||||
"content_vector": {
|
||||
"type": "dense_vector",
|
||||
"dims": vector_dimension,
|
||||
"similarity": similarity
|
||||
},
|
||||
"metadata": {
|
||||
"type": "object",
|
||||
"dynamic": True
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -5,7 +5,6 @@ from typing import cast
 from crewai.rag.config.optional_imports.protocols import (
     ChromaFactoryModule,
     QdrantFactoryModule,
-    ElasticsearchFactoryModule,
 )
 from crewai.rag.core.base_client import BaseClient
 from crewai.rag.config.types import RagConfigType
@@ -44,15 +43,3 @@ def create_client(config: RagConfigType) -> BaseClient:
             ),
         )
         return qdrant_mod.create_client(config)

-    if config.provider == "elasticsearch":
-        elasticsearch_mod = cast(
-            ElasticsearchFactoryModule,
-            require(
-                "crewai.rag.elasticsearch.factory",
-                purpose="The 'elasticsearch' provider",
-            ),
-        )
-        return elasticsearch_mod.create_client(config)
-
     raise ValueError(f"Unsupported provider: {config.provider}")
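`create_client` dispatches purely on `config.provider` and lazily imports the matching factory module via `require(...)`, so only the installed backend's dependencies are loaded. A minimal usage sketch, assuming the ChromaDB extra is installed; the factory module path and the no-argument config constructor are assumptions, not taken from this diff:

from crewai.rag.factory import create_client  # assumed module path
from crewai.rag.chromadb.config import ChromaDBConfig

config = ChromaDBConfig()              # provider="chromadb" acts as the discriminator
client = create_client(config)         # routed through the "chromadb" branch above
client.get_or_create_collection(collection_name="docs")  # BaseClient-style call, illustrative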
src/crewai/rag/utils/synchronized.py (new file, 622 lines)
@@ -0,0 +1,622 @@
from __future__ import annotations

import asyncio
import contextvars
import functools
import hashlib
import inspect
import os
import re
import sys
import threading
import time
import weakref
from pathlib import Path
from collections.abc import Callable, Coroutine
from typing import Any, TypeVar, ParamSpec, Concatenate, TypedDict

import portalocker
from portalocker import constants
from typing_extensions import NotRequired, Self

P = ParamSpec("P")
R = TypeVar("R")
T = TypeVar("T")

_STATE: dict[str, int] = {"pid": os.getpid()}


def _reset_after_fork() -> None:
    """Reset in-process state in the child after a fork.

    Resets all locks and thread-local storage after a process fork
    to prevent lock contamination across processes.
    """
    global _sync_rlocks, _async_locks_by_loop, _tls, _task_depths_var, _STATE
    _sync_rlocks = {}
    _async_locks_by_loop = weakref.WeakKeyDictionary()
    _tls = threading.local()
    # Reset task-local depths for async
    _task_depths_var = contextvars.ContextVar("locked_task_depths", default=None)
    _STATE["pid"] = os.getpid()


def _ensure_same_process() -> None:
    """Ensure we're in the same process, reset if forked.

    Checks if the current PID matches the stored PID and resets
    state if a fork has occurred.
    """
    if _STATE["pid"] != os.getpid():
        _reset_after_fork()


# Automatically reset in a forked child on POSIX
_register_at_fork = getattr(os, "register_at_fork", None)
if _register_at_fork is not None:
    _register_at_fork(after_in_child=_reset_after_fork)

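The fork handling above matters because lock objects inherited across os.fork() keep the state they had in the parent; a lock held by a parent thread stays "held" in the child even though that thread does not exist there. A minimal sketch of the failure mode this guards against (illustrative only, not part of the PR):

import os
import threading

lock = threading.Lock()
lock.acquire()  # parent thread holds the lock

pid = os.fork()
if pid == 0:
    # The child inherits a copy of the already-held lock. Without a post-fork
    # reset (os.register_at_fork(after_in_child=...)), this acquire blocks forever.
    lock.acquire()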
class LockConfig(TypedDict):
|
||||
"""Configuration for portalocker locks.
|
||||
|
||||
Attributes:
|
||||
mode: File open mode.
|
||||
timeout: Optional lock timeout.
|
||||
check_interval: Optional check interval.
|
||||
fail_when_locked: Whether to fail if already locked.
|
||||
flags: Portalocker lock flags.
|
||||
"""
|
||||
|
||||
mode: str
|
||||
timeout: NotRequired[float]
|
||||
check_interval: NotRequired[float]
|
||||
fail_when_locked: bool
|
||||
flags: portalocker.LockFlags
|
||||
|
||||
|
||||
def _get_platform_lock_flags() -> portalocker.LockFlags:
|
||||
"""Get platform-appropriate lock flags.
|
||||
|
||||
Returns:
|
||||
LockFlags.EXCLUSIVE for exclusive file locking.
|
||||
"""
|
||||
# Use EXCLUSIVE flag only - let portalocker handle blocking/non-blocking internally
|
||||
return constants.LockFlags.EXCLUSIVE
|
||||
|
||||
|
||||
def _get_lock_config() -> LockConfig:
|
||||
"""Get lock configuration appropriate for the platform.
|
||||
|
||||
Returns:
|
||||
LockConfig dict with mode, flags, and fail_when_locked settings.
|
||||
"""
|
||||
config: LockConfig = {
|
||||
"mode": "a+",
|
||||
"fail_when_locked": False,
|
||||
"flags": _get_platform_lock_flags(),
|
||||
}
|
||||
return config
|
||||
|
||||
|
||||
LOCK_CONFIG: LockConfig = _get_lock_config()
|
||||
LOCK_STALE_SECONDS = 120
|
||||
|
||||
|
||||
def _default_lock_dir() -> Path:
|
||||
"""Get or create the default lock directory.
|
||||
|
||||
Returns:
|
||||
Path to ~/.crewai/locks directory, created if necessary.
|
||||
"""
|
||||
lock_dir = Path.home() / ".crewai" / "locks"
|
||||
lock_dir.mkdir(parents=True, exist_ok=True)
|
||||
# Best-effort: restrict perms on POSIX
|
||||
try:
|
||||
if os.name == "posix":
|
||||
lock_dir.chmod(0o700)
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
# Clean up old lock files
|
||||
_cleanup_stale_locks(lock_dir)
|
||||
return lock_dir
|
||||
|
||||
|
||||
def _cleanup_stale_locks(lock_dir: Path, max_age_seconds: int = 86400) -> None:
|
||||
"""Remove lock files older than max_age_seconds.
|
||||
|
||||
Args:
|
||||
lock_dir: Directory containing lock files.
|
||||
max_age_seconds: Maximum age before considering a lock stale (default 24 hours).
|
||||
"""
|
||||
try:
|
||||
current_time = time.time()
|
||||
for lock_file in lock_dir.glob("*.lock"):
|
||||
try:
|
||||
# Check if file is old and not currently locked
|
||||
file_age = current_time - lock_file.stat().st_mtime
|
||||
if file_age > max_age_seconds:
|
||||
# Try to acquire exclusive lock - if successful, file is not in use
|
||||
try:
|
||||
with portalocker.Lock(
|
||||
str(lock_file),
|
||||
mode="a+",
|
||||
timeout=0.01, # Very short timeout
|
||||
fail_when_locked=True,
|
||||
flags=constants.LockFlags.EXCLUSIVE,
|
||||
):
|
||||
pass # We got the lock, file is not in use
|
||||
# Safe to remove
|
||||
lock_file.unlink(missing_ok=True)
|
||||
except (portalocker.LockException, OSError):
|
||||
# File is locked or can't be accessed, skip it
|
||||
pass
|
||||
except (OSError, IOError):
|
||||
# Skip files we can't stat or process
|
||||
pass
|
||||
except Exception:
|
||||
# Cleanup is best-effort, don't fail on errors
|
||||
pass
|
||||
|
||||
|
||||
def _hash_str(value: str) -> str:
|
||||
"""Generate a short hash of a string.
|
||||
|
||||
Args:
|
||||
value: String to hash.
|
||||
|
||||
Returns:
|
||||
First 10 characters of SHA256 hash.
|
||||
"""
|
||||
return hashlib.sha256(value.encode()).hexdigest()[:10]
|
||||
|
||||
|
||||
def _qualname_for(func: Callable[..., Any], owner: object | None = None) -> str:
|
||||
"""Get qualified name for a function.
|
||||
|
||||
Args:
|
||||
func: Function to get qualified name for.
|
||||
owner: Optional owner object for the function.
|
||||
|
||||
Returns:
|
||||
Fully qualified name including module and class.
|
||||
"""
|
||||
target = inspect.unwrap(func)
|
||||
|
||||
if inspect.ismethod(func) and getattr(func, "__self__", None) is not None:
|
||||
owner_obj = func.__self__
|
||||
cls = owner_obj if inspect.isclass(owner_obj) else owner_obj.__class__
|
||||
return f"{target.__module__}.{cls.__qualname__}.{getattr(target, '__name__', '<?>')}"
|
||||
|
||||
if owner is not None:
|
||||
cls = owner if inspect.isclass(owner) else owner.__class__
|
||||
return f"{target.__module__}.{cls.__qualname__}.{getattr(target, '__name__', '<?>')}"
|
||||
|
||||
qn = getattr(target, "__qualname__", None)
|
||||
if qn is not None:
|
||||
return f"{getattr(target, '__module__', target.__class__.__module__)}.{qn}"
|
||||
|
||||
if isinstance(target, functools.partial):
|
||||
f = inspect.unwrap(target.func)
|
||||
return f"{getattr(f, '__module__', 'builtins')}.{getattr(f, '__qualname__', getattr(f, '__name__', '<?>'))}"
|
||||
|
||||
cls = target.__class__
|
||||
return f"{cls.__module__}.{cls.__qualname__}.__call__"
|
||||
|
||||
|
||||
def _get_lock_context(
|
||||
instance: Any | None,
|
||||
func: Callable[..., Any],
|
||||
kwargs: dict[str, Any],
|
||||
) -> tuple[Path, str | None]:
|
||||
"""Extract lock context from function call.
|
||||
|
||||
Args:
|
||||
instance: Instance the function is called on.
|
||||
func: Function being called.
|
||||
kwargs: Keyword arguments passed to function.
|
||||
|
||||
Returns:
|
||||
Tuple of (lock_file_path, collection_name).
|
||||
"""
|
||||
collection_name = (
|
||||
str(kwargs.get("collection_name")) if "collection_name" in kwargs else None
|
||||
)
|
||||
lock_dir = _default_lock_dir()
|
||||
base = _qualname_for(func, owner=instance)
|
||||
safe_base = re.sub(r"[^\w.\-]+", "_", base)
|
||||
suffix = f"_{_hash_str(collection_name)}" if collection_name else ""
|
||||
path = lock_dir / f"{safe_base}{suffix}.lock"
|
||||
return path, collection_name
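For illustration, this is roughly what the derived lock path looks like for a hypothetical client method and collection; the class and collection names are made up, only the naming scheme (sanitized qualname plus a 10-character SHA256 prefix of the collection name) comes from the code above:

import hashlib
import re
from pathlib import Path

base = "crewai.rag.chromadb.client.ChromaDBClient.add_documents"  # hypothetical qualname
collection = "my_collection"                                      # hypothetical collection_name kwarg

safe_base = re.sub(r"[^\w.\-]+", "_", base)
suffix = "_" + hashlib.sha256(collection.encode()).hexdigest()[:10]
print(Path.home() / ".crewai" / "locks" / f"{safe_base}{suffix}.lock")
# -> ~/.crewai/locks/crewai.rag.chromadb.client.ChromaDBClient.add_documents_<hash>.lock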


def _write_lock_metadata(lock_file_path: Path) -> None:
    """Write metadata to lock file for staleness detection.

    Args:
        lock_file_path: Path to the lock file.
    """
    try:
        with open(lock_file_path, "w") as f:
            f.write(f"{os.getpid()}\n{time.time()}\n")
            f.flush()
            os.fsync(f.fileno())  # Ensure data is written to disk

        # Set restrictive permissions on lock file (Unix only)
        if sys.platform not in ("win32", "cygwin"):
            try:
                lock_file_path.chmod(0o600)
            except Exception:
                pass
    except Exception:
        # Best effort - don't fail if we can't write metadata
        pass


def _check_lock_staleness(lock_file_path: Path) -> bool:
    """Check if a lock file is stale.

    Args:
        lock_file_path: Path to the lock file.

    Returns:
        True if lock is stale, False otherwise.
    """
    try:
        if not lock_file_path.exists():
            return False

        with open(lock_file_path) as f:
            lines = f.readlines()
            if len(lines) < 2:
                return True  # unreadable metadata

            pid = int(lines[0].strip())
            ts = float(lines[1].strip())

        # If the process is alive, do NOT treat as stale based on time alone.
        if sys.platform not in ("win32", "cygwin"):
            try:
                os.kill(pid, 0)
                return False  # alive -> not stale
            except (OSError, ProcessLookupError):
                pass  # dead process -> proceed to time check

        # Process dead: time window can be small; consider stale now
        return (time.time() - ts) > 1.0  # essentially "dead means stale"

    except Exception:
        return True


_sync_rlocks: dict[Path, threading.RLock] = {}
_sync_rlocks_guard = threading.Lock()
_tls = threading.local()


def _get_sync_rlock(path: Path) -> threading.RLock:
    """Get or create a reentrant lock for a path.

    Args:
        path: Path to get lock for.

    Returns:
        Threading RLock for the given path.
    """
    with _sync_rlocks_guard:
        lk = _sync_rlocks.get(path)
        if lk is None:
            lk = threading.RLock()
            _sync_rlocks[path] = lk
        return lk


class _SyncDepthManager:
    """Context manager for sync depth tracking.

    Tracks reentrancy depth for synchronous locks to determine
    when to acquire/release file locks.
    """

    def __init__(self, path: Path):
        """Initialize depth manager.

        Args:
            path: Path to track depth for.
        """
        self.path = path
        self.depth = 0

    def __enter__(self) -> int:
        """Enter context and increment depth.

        Returns:
            Current depth after increment.
        """
        depths = getattr(_tls, "depths", None)
        if depths is None:
            depths = {}
            _tls.depths = depths
        self.depth = depths.get(self.path, 0) + 1
        depths[self.path] = self.depth
        return self.depth

    def __exit__(self, *args: Any) -> None:
        """Exit context and decrement depth.

        Args:
            *args: Exception information if any.
        """
        depths = getattr(_tls, "depths", {})
        v = depths.get(self.path, 1) - 1
        if v <= 0:
            depths.pop(self.path, None)
        else:
            depths[self.path] = v


def _safe_to_delete(path: Path) -> bool:
    """Check if a lock file can be safely deleted.

    Args:
        path: Path to the lock file.

    Returns:
        True if file can be deleted safely, False otherwise.
    """
    try:
        with portalocker.Lock(
            str(path),
            mode="a+",
            timeout=0.01,  # very short, non-blocking-ish
            fail_when_locked=True,  # fail if someone holds it
            flags=constants.LockFlags.EXCLUSIVE,
        ):
            return True
    except Exception:
        return False


def with_lock(func: Callable[Concatenate[T, P], R]) -> Callable[Concatenate[T, P], R]:
    """Decorator for file-based cross-process locking.

    Args:
        func: Function to wrap with locking.

    Returns:
        Wrapped function with locking behavior.
    """

    @functools.wraps(func)
    def wrapper(self, *args, **kwargs):
        _ensure_same_process()

        path, _ = _get_lock_context(self, func, kwargs)
        local_lock = _get_sync_rlock(path)

        prune_after = False
        with local_lock:
            with _SyncDepthManager(path) as depth:
                if depth == 1:
                    # stale handling
                    if _check_lock_staleness(path) and _safe_to_delete(path):
                        try:
                            path.unlink(missing_ok=True)
                        except Exception:
                            pass

                    # acquire file lock
                    lock_config = LockConfig(
                        mode=LOCK_CONFIG["mode"],
                        fail_when_locked=LOCK_CONFIG["fail_when_locked"],
                        flags=LOCK_CONFIG["flags"],
                    )
                    with portalocker.Lock(str(path), **lock_config) as _fh:
                        _write_lock_metadata(path)
                        result = func(self, *args, **kwargs)
                    try:
                        path.unlink(missing_ok=True)
                    except Exception:
                        pass

                    prune_after = True
                else:
                    result = func(self, *args, **kwargs)

        # <-- NOW it's safe to remove the entry
        if prune_after:
            with _sync_rlocks_guard:
                _sync_rlocks.pop(path, None)

        return result

    return wrapper
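

# --- Usage sketch (illustrative only; not part of the original change) -------
# The decorator is intended for instance methods; when a collection_name kwarg
# is present, the lock file is scoped to that collection. The class and method
# below are hypothetical.
class _ExampleDocumentStore:
    """Hypothetical class showing how @with_lock is meant to be applied."""

    @with_lock
    def add_documents(self, *, collection_name: str) -> None:
        # Only one thread/process at a time runs this for a given
        # collection_name; reentrant calls in the same thread skip the
        # file lock thanks to _SyncDepthManager.
        ...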


# Use weak references to avoid keeping event loops alive
_async_locks_by_loop: weakref.WeakKeyDictionary[
    asyncio.AbstractEventLoop, dict[Path, asyncio.Lock]
] = weakref.WeakKeyDictionary()
_async_locks_guard = threading.Lock()
_task_depths_var: contextvars.ContextVar[dict[Path, int] | None] = (
    contextvars.ContextVar("locked_task_depths", default=None)
)


def _get_async_lock(path: Path) -> asyncio.Lock:
    """Get or create an async lock for the current event loop.

    Args:
        path: Path to get lock for.

    Returns:
        Asyncio Lock for the given path in current event loop.
    """
    loop = asyncio.get_running_loop()

    with _async_locks_guard:
        # Get locks dict for this event loop
        loop_locks = _async_locks_by_loop.get(loop)
        if loop_locks is None:
            loop_locks = {}
            _async_locks_by_loop[loop] = loop_locks

        # Get or create lock for this path
        lk = loop_locks.get(path)
        if lk is None:
            # Create lock in the context of the running loop
            lk = asyncio.Lock()
            loop_locks[path] = lk
        return lk


class _AsyncDepthManager:
    """Context manager for async task-local depth tracking.

    Tracks reentrancy depth for async locks to determine
    when to acquire/release file locks.
    """

    def __init__(self, path: Path):
        """Initialize async depth manager.

        Args:
            path: Path to track depth for.
        """
        self.path = path
        self.depths: dict[Path, int] | None = None
        self.is_reentrant = False

    def __enter__(self) -> Self:
        """Enter context and track async task depth.

        Returns:
            Self for context management.
        """
        d = _task_depths_var.get()
        if d is None:
            d = {}
            _task_depths_var.set(d)
        self.depths = d

        cur_depth = self.depths.get(self.path, 0)
        if cur_depth > 0:
            self.is_reentrant = True
            self.depths[self.path] = cur_depth + 1
        else:
            self.depths[self.path] = 1
        return self

    def __exit__(self, *args: Any) -> None:
        """Exit context and update task depth.

        Args:
            *args: Exception information if any.
        """
        if self.depths is not None:
            new_depth = self.depths.get(self.path, 1) - 1
            if new_depth <= 0:
                self.depths.pop(self.path, None)
            else:
                self.depths[self.path] = new_depth


async def _safe_to_delete_async(path: Path) -> bool:
    """Check if a lock file can be safely deleted (async).

    Args:
        path: Path to the lock file.

    Returns:
        True if file can be deleted safely, False otherwise.
    """

    def _try_lock() -> bool:
        try:
            with portalocker.Lock(
                str(path),
                mode="a+",
                timeout=0.01,  # very short, effectively non-blocking
                fail_when_locked=True,  # fail if another process holds it
                flags=constants.LockFlags.EXCLUSIVE,
            ):
                return True
        except Exception:
            return False

    return await asyncio.to_thread(_try_lock)


def async_with_lock(
    func: Callable[Concatenate[T, P], Coroutine[Any, Any, R]],
) -> Callable[Concatenate[T, P], Coroutine[Any, Any, R]]:
    """Async decorator for file-based cross-process locking.

    Args:
        func: Async function to wrap with locking.

    Returns:
        Wrapped async function with locking behavior.
    """

    @functools.wraps(func)
    async def wrapper(self: T, *args: P.args, **kwargs: P.kwargs) -> R:
        _ensure_same_process()

        path, _ = _get_lock_context(self, func, kwargs)

        with _AsyncDepthManager(path) as depth_mgr:
            if depth_mgr.is_reentrant:
                # Re-entrant within the same task: skip file lock
                return await func(self, *args, **kwargs)

            # Safer stale handling: only unlink if we can lock it first
            if _check_lock_staleness(path) and await _safe_to_delete_async(path):
                try:
                    await asyncio.to_thread(lambda: path.unlink(missing_ok=True))
                except Exception:
                    pass

            # Acquire per-loop async lock to serialize within this loop
            async_lock = _get_async_lock(path)
            await async_lock.acquire()
            try:
                # Acquire cross-process file lock in a thread
                lock_config = LockConfig(
                    mode=LOCK_CONFIG["mode"],
                    fail_when_locked=LOCK_CONFIG["fail_when_locked"],
                    flags=LOCK_CONFIG["flags"],
                )
                file_lock = portalocker.Lock(str(path), **lock_config)

                await asyncio.to_thread(file_lock.acquire)
                try:
                    # Write/refresh metadata while lock is held
                    await asyncio.to_thread(lambda: _write_lock_metadata(path))

                    result = await func(self, *args, **kwargs)
                finally:
                    # Release file lock before unlink to avoid inode race
                    try:
                        await asyncio.to_thread(file_lock.release)
                    finally:
                        # Now it's safe to unlink the path
                        try:
                            await asyncio.to_thread(
                                lambda: path.unlink(missing_ok=True)
                            )
                        except Exception:
                            pass

                return result
            finally:
                async_lock.release()

                with _async_locks_guard:
                    loop = asyncio.get_running_loop()
                    loop_locks = _async_locks_by_loop.get(loop)
                    if loop_locks is not None:
                        loop_locks.pop(path, None)

    return wrapper
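

# --- Usage sketch (illustrative only; not part of the original change) -------
# A hypothetical async counterpart: reentrant calls within the same asyncio
# task skip the file lock, while concurrent tasks and processes are serialized.
class _ExampleAsyncDocumentStore:
    """Hypothetical class showing how @async_with_lock is meant to be applied."""

    @async_with_lock
    async def aadd_documents(self, *, collection_name: str) -> None:
        # Serialized within an event loop via the per-path asyncio.Lock and
        # across processes via the portalocker file lock acquired in a
        # worker thread.
        ...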
@@ -3,11 +3,6 @@ from datetime import datetime, timedelta
|
||||
import requests
|
||||
from unittest.mock import MagicMock, patch, call
|
||||
from crewai.cli.authentication.main import AuthenticationCommand
|
||||
from crewai.cli.authentication.constants import (
|
||||
AUTH0_AUDIENCE,
|
||||
AUTH0_CLIENT_ID,
|
||||
AUTH0_DOMAIN
|
||||
)
|
||||
from crewai.cli.constants import (
|
||||
CREWAI_ENTERPRISE_DEFAULT_OAUTH2_DOMAIN,
|
||||
CREWAI_ENTERPRISE_DEFAULT_OAUTH2_CLIENT_ID,
|
||||
@@ -22,16 +17,6 @@ class TestAuthenticationCommand:
|
||||
@pytest.mark.parametrize(
|
||||
"user_provider,expected_urls",
|
||||
[
|
||||
(
|
||||
"auth0",
|
||||
{
|
||||
"device_code_url": f"https://{AUTH0_DOMAIN}/oauth/device/code",
|
||||
"token_url": f"https://{AUTH0_DOMAIN}/oauth/token",
|
||||
"client_id": AUTH0_CLIENT_ID,
|
||||
"audience": AUTH0_AUDIENCE,
|
||||
"domain": AUTH0_DOMAIN,
|
||||
},
|
||||
),
|
||||
(
|
||||
"workos",
|
||||
{
|
||||
@@ -44,9 +29,6 @@ class TestAuthenticationCommand:
|
||||
),
|
||||
],
|
||||
)
|
||||
@patch(
|
||||
"crewai.cli.authentication.main.AuthenticationCommand._determine_user_provider"
|
||||
)
|
||||
@patch("crewai.cli.authentication.main.AuthenticationCommand._get_device_code")
|
||||
@patch(
|
||||
"crewai.cli.authentication.main.AuthenticationCommand._display_auth_instructions"
|
||||
@@ -59,11 +41,9 @@ class TestAuthenticationCommand:
|
||||
mock_poll,
|
||||
mock_display,
|
||||
mock_get_device,
|
||||
mock_determine_provider,
|
||||
user_provider,
|
||||
expected_urls,
|
||||
):
|
||||
mock_determine_provider.return_value = user_provider
|
||||
mock_get_device.return_value = {
|
||||
"device_code": "test_code",
|
||||
"user_code": "123456",
|
||||
@@ -74,7 +54,6 @@ class TestAuthenticationCommand:
|
||||
mock_console_print.assert_called_once_with(
|
||||
"Signing in to CrewAI Enterprise...\n", style="bold blue"
|
||||
)
|
||||
mock_determine_provider.assert_called_once()
|
||||
mock_get_device.assert_called_once()
|
||||
mock_display.assert_called_once_with(
|
||||
{"device_code": "test_code", "user_code": "123456"}
|
||||
@@ -82,9 +61,17 @@ class TestAuthenticationCommand:
|
||||
mock_poll.assert_called_once_with(
|
||||
{"device_code": "test_code", "user_code": "123456"},
|
||||
)
|
||||
assert self.auth_command.oauth2_provider.get_client_id() == expected_urls["client_id"]
|
||||
assert self.auth_command.oauth2_provider.get_audience() == expected_urls["audience"]
|
||||
assert self.auth_command.oauth2_provider._get_domain() == expected_urls["domain"]
|
||||
assert (
|
||||
self.auth_command.oauth2_provider.get_client_id()
|
||||
== expected_urls["client_id"]
|
||||
)
|
||||
assert (
|
||||
self.auth_command.oauth2_provider.get_audience()
|
||||
== expected_urls["audience"]
|
||||
)
|
||||
assert (
|
||||
self.auth_command.oauth2_provider._get_domain() == expected_urls["domain"]
|
||||
)
|
||||
|
||||
@patch("crewai.cli.authentication.main.webbrowser")
|
||||
@patch("crewai.cli.authentication.main.console.print")
|
||||
@@ -106,14 +93,6 @@ class TestAuthenticationCommand:
|
||||
@pytest.mark.parametrize(
|
||||
"user_provider,jwt_config",
|
||||
[
|
||||
(
|
||||
"auth0",
|
||||
{
|
||||
"jwks_url": f"https://{AUTH0_DOMAIN}/.well-known/jwks.json",
|
||||
"issuer": f"https://{AUTH0_DOMAIN}/",
|
||||
"audience": AUTH0_AUDIENCE,
|
||||
},
|
||||
),
|
||||
(
|
||||
"workos",
|
||||
{
|
||||
@@ -135,14 +114,18 @@ class TestAuthenticationCommand:
|
||||
jwt_config,
|
||||
has_expiration,
|
||||
):
|
||||
from crewai.cli.authentication.providers.auth0 import Auth0Provider
|
||||
from crewai.cli.authentication.providers.workos import WorkosProvider
|
||||
from crewai.cli.authentication.main import Oauth2Settings
|
||||
|
||||
if user_provider == "auth0":
|
||||
self.auth_command.oauth2_provider = Auth0Provider(settings=Oauth2Settings(provider=user_provider, client_id="test-client-id", domain=AUTH0_DOMAIN, audience=jwt_config["audience"]))
|
||||
elif user_provider == "workos":
|
||||
self.auth_command.oauth2_provider = WorkosProvider(settings=Oauth2Settings(provider=user_provider, client_id="test-client-id", domain=CREWAI_ENTERPRISE_DEFAULT_OAUTH2_DOMAIN, audience=jwt_config["audience"]))
|
||||
if user_provider == "workos":
|
||||
self.auth_command.oauth2_provider = WorkosProvider(
|
||||
settings=Oauth2Settings(
|
||||
provider=user_provider,
|
||||
client_id="test-client-id",
|
||||
domain=CREWAI_ENTERPRISE_DEFAULT_OAUTH2_DOMAIN,
|
||||
audience=jwt_config["audience"],
|
||||
)
|
||||
)
|
||||
|
||||
token_data = {"access_token": "test_access_token", "id_token": "test_id_token"}
|
||||
|
||||
@@ -234,83 +217,6 @@ class TestAuthenticationCommand:
|
||||
]
|
||||
mock_console_print.assert_has_calls(expected_calls)
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"api_response,expected_provider",
|
||||
[
|
||||
({"provider": "auth0"}, "auth0"),
|
||||
({"provider": "workos"}, "workos"),
|
||||
({"provider": "none"}, "workos"), # Default to workos for any other value
|
||||
(
|
||||
{},
|
||||
"workos",
|
||||
), # Default to workos if no provider key is sent in the response
|
||||
],
|
||||
)
|
||||
@patch("crewai.cli.authentication.main.PlusAPI")
|
||||
@patch("crewai.cli.authentication.main.console.print")
|
||||
@patch("builtins.input", return_value="test@example.com")
|
||||
def test_determine_user_provider_success(
|
||||
self,
|
||||
mock_input,
|
||||
mock_console_print,
|
||||
mock_plus_api,
|
||||
api_response,
|
||||
expected_provider,
|
||||
):
|
||||
mock_api_instance = MagicMock()
|
||||
mock_response = MagicMock()
|
||||
mock_response.status_code = 200
|
||||
mock_response.json.return_value = api_response
|
||||
mock_api_instance._make_request.return_value = mock_response
|
||||
mock_plus_api.return_value = mock_api_instance
|
||||
|
||||
result = self.auth_command._determine_user_provider()
|
||||
|
||||
mock_input.assert_called_once()
|
||||
|
||||
mock_plus_api.assert_called_once_with("")
|
||||
mock_api_instance._make_request.assert_called_once_with(
|
||||
"GET", "/crewai_plus/api/v1/me/provider?email=test%40example.com"
|
||||
)
|
||||
|
||||
assert result == expected_provider
|
||||
|
||||
@patch("crewai.cli.authentication.main.PlusAPI")
|
||||
@patch("crewai.cli.authentication.main.console.print")
|
||||
@patch("builtins.input", return_value="test@example.com")
|
||||
def test_determine_user_provider_error(
|
||||
self, mock_input, mock_console_print, mock_plus_api
|
||||
):
|
||||
mock_api_instance = MagicMock()
|
||||
mock_response = MagicMock()
|
||||
mock_response.status_code = 500
|
||||
mock_api_instance._make_request.return_value = mock_response
|
||||
mock_plus_api.return_value = mock_api_instance
|
||||
|
||||
with pytest.raises(SystemExit):
|
||||
self.auth_command._determine_user_provider()
|
||||
|
||||
mock_input.assert_called_once()
|
||||
|
||||
mock_plus_api.assert_called_once_with("")
|
||||
mock_api_instance._make_request.assert_called_once_with(
|
||||
"GET", "/crewai_plus/api/v1/me/provider?email=test%40example.com"
|
||||
)
|
||||
|
||||
mock_console_print.assert_has_calls(
|
||||
[
|
||||
call(
|
||||
"Enter your CrewAI Enterprise account email: ",
|
||||
style="bold blue",
|
||||
end="",
|
||||
),
|
||||
call(
|
||||
"Error: Failed to authenticate with crewai enterprise. Ensure that you are using the latest crewai version and please try again. If the problem persists, contact support@crewai.com.",
|
||||
style="red",
|
||||
),
|
||||
]
|
||||
)
|
||||
|
||||
@patch("requests.post")
|
||||
def test_get_device_code(self, mock_post):
|
||||
mock_response = MagicMock()
|
||||
@@ -323,7 +229,9 @@ class TestAuthenticationCommand:
|
||||
|
||||
self.auth_command.oauth2_provider = MagicMock()
|
||||
self.auth_command.oauth2_provider.get_client_id.return_value = "test_client"
|
||||
self.auth_command.oauth2_provider.get_authorize_url.return_value = "https://example.com/device"
|
||||
self.auth_command.oauth2_provider.get_authorize_url.return_value = (
|
||||
"https://example.com/device"
|
||||
)
|
||||
self.auth_command.oauth2_provider.get_audience.return_value = "test_audience"
|
||||
|
||||
result = self.auth_command._get_device_code()
|
||||
@@ -366,12 +274,12 @@ class TestAuthenticationCommand:
|
||||
) as mock_tool_login,
|
||||
):
|
||||
self.auth_command.oauth2_provider = MagicMock()
|
||||
self.auth_command.oauth2_provider.get_token_url.return_value = "https://example.com/token"
|
||||
self.auth_command.oauth2_provider.get_token_url.return_value = (
|
||||
"https://example.com/token"
|
||||
)
|
||||
self.auth_command.oauth2_provider.get_client_id.return_value = "test_client"
|
||||
|
||||
self.auth_command._poll_for_token(
|
||||
device_code_data
|
||||
)
|
||||
self.auth_command._poll_for_token(device_code_data)
|
||||
|
||||
mock_post.assert_called_once_with(
|
||||
"https://example.com/token",
|
||||
@@ -406,9 +314,7 @@ class TestAuthenticationCommand:
|
||||
"interval": 0.1, # Short interval for testing
|
||||
}
|
||||
|
||||
self.auth_command._poll_for_token(
|
||||
device_code_data
|
||||
)
|
||||
self.auth_command._poll_for_token(device_code_data)
|
||||
|
||||
mock_console_print.assert_any_call(
|
||||
"Timeout: Failed to get the token. Please try again.", style="bold red"
|
||||
@@ -429,15 +335,4 @@ class TestAuthenticationCommand:
|
||||
device_code_data = {"device_code": "test_device_code", "interval": 1}
|
||||
|
||||
with pytest.raises(requests.HTTPError):
|
||||
self.auth_command._poll_for_token(
|
||||
device_code_data
|
||||
)
|
||||
# @patch(
|
||||
# "crewai.cli.authentication.main.AuthenticationCommand._determine_user_provider"
|
||||
# )
|
||||
# def test_login_with_auth0(self, mock_determine_provider):
|
||||
# from crewai.cli.authentication.providers.auth0 import Auth0Provider
|
||||
# from crewai.cli.authentication.main import Oauth2Settings
|
||||
|
||||
# self.auth_command.oauth2_provider = Auth0Provider(settings=Oauth2Settings(provider="auth0", client_id=AUTH0_CLIENT_ID, domain=AUTH0_DOMAIN, audience=AUTH0_AUDIENCE))
|
||||
# self.auth_command.login()
|
||||
self.auth_command._poll_for_token(device_code_data)
|
||||
|
||||
@@ -2,8 +2,6 @@
|
||||
|
||||
from unittest.mock import Mock, patch
|
||||
|
||||
import pytest
|
||||
|
||||
from crewai.rag.factory import create_client
|
||||
|
||||
|
||||
@@ -27,50 +25,10 @@ def test_create_client_chromadb():
|
||||
mock_module.create_client.assert_called_once_with(mock_config)
|
||||
|
||||
|
||||
def test_create_client_qdrant():
|
||||
"""Test Qdrant client creation."""
|
||||
mock_config = Mock()
|
||||
mock_config.provider = "qdrant"
|
||||
|
||||
with patch("crewai.rag.factory.require") as mock_require:
|
||||
mock_module = Mock()
|
||||
mock_client = Mock()
|
||||
mock_module.create_client.return_value = mock_client
|
||||
mock_require.return_value = mock_module
|
||||
|
||||
result = create_client(mock_config)
|
||||
|
||||
assert result == mock_client
|
||||
mock_require.assert_called_once_with(
|
||||
"crewai.rag.qdrant.factory", purpose="The 'qdrant' provider"
|
||||
)
|
||||
mock_module.create_client.assert_called_once_with(mock_config)
|
||||
|
||||
|
||||
def test_create_client_elasticsearch():
|
||||
"""Test Elasticsearch client creation."""
|
||||
mock_config = Mock()
|
||||
mock_config.provider = "elasticsearch"
|
||||
|
||||
with patch("crewai.rag.factory.require") as mock_require:
|
||||
mock_module = Mock()
|
||||
mock_client = Mock()
|
||||
mock_module.create_client.return_value = mock_client
|
||||
mock_require.return_value = mock_module
|
||||
|
||||
result = create_client(mock_config)
|
||||
|
||||
assert result == mock_client
|
||||
mock_require.assert_called_once_with(
|
||||
"crewai.rag.elasticsearch.factory", purpose="The 'elasticsearch' provider"
|
||||
)
|
||||
mock_module.create_client.assert_called_once_with(mock_config)
|
||||
|
||||
|
||||
def test_create_client_unsupported_provider():
|
||||
"""Test that unsupported provider raises ValueError."""
|
||||
"""Test unsupported provider returns None for now."""
|
||||
mock_config = Mock()
|
||||
mock_config.provider = "unsupported"
|
||||
|
||||
with pytest.raises(ValueError, match="Unsupported provider: unsupported"):
|
||||
create_client(mock_config)
|
||||
result = create_client(mock_config)
|
||||
assert result is None
|
||||
|
||||
@@ -3,10 +3,7 @@
|
||||
import pytest
|
||||
|
||||
from crewai.rag.config.optional_imports.base import _MissingProvider
|
||||
from crewai.rag.config.optional_imports.providers import (
|
||||
MissingChromaDBConfig,
|
||||
MissingElasticsearchConfig,
|
||||
)
|
||||
from crewai.rag.config.optional_imports.providers import MissingChromaDBConfig
|
||||
|
||||
|
||||
def test_missing_provider_raises_runtime_error():
|
||||
@@ -23,11 +20,3 @@ def test_missing_chromadb_config_raises_runtime_error():
|
||||
RuntimeError, match="provider 'chromadb' requested but not installed"
|
||||
):
|
||||
MissingChromaDBConfig()
|
||||
|
||||
|
||||
def test_missing_elasticsearch_config_raises_runtime_error():
|
||||
"""Test that MissingElasticsearchConfig raises RuntimeError on instantiation."""
|
||||
with pytest.raises(
|
||||
RuntimeError, match="provider 'elasticsearch' requested but not installed"
|
||||
):
|
||||
MissingElasticsearchConfig()
|
||||
|
||||
@@ -1 +0,0 @@
|
||||
"""Tests for Elasticsearch RAG implementation."""
|
||||
@@ -1,397 +0,0 @@
|
||||
"""Tests for ElasticsearchClient implementation."""
|
||||
|
||||
from unittest.mock import AsyncMock, Mock, patch
|
||||
|
||||
import pytest
|
||||
|
||||
from crewai.rag.elasticsearch.client import ElasticsearchClient
|
||||
from crewai.rag.types import BaseRecord
|
||||
from crewai.rag.core.exceptions import ClientMethodMismatchError
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_elasticsearch_client():
|
||||
"""Create a mock Elasticsearch client."""
|
||||
mock_client = Mock()
|
||||
mock_client.indices = Mock()
|
||||
mock_client.indices.exists.return_value = False
|
||||
mock_client.indices.create.return_value = {"acknowledged": True}
|
||||
mock_client.indices.get.return_value = {"test_index": {"mappings": {}}}
|
||||
mock_client.indices.delete.return_value = {"acknowledged": True}
|
||||
mock_client.index.return_value = {"_id": "test_id", "result": "created"}
|
||||
mock_client.search.return_value = {
|
||||
"hits": {
|
||||
"hits": [
|
||||
{
|
||||
"_id": "doc1",
|
||||
"_score": 0.9,
|
||||
"_source": {
|
||||
"content": "test content",
|
||||
"metadata": {"key": "value"}
|
||||
}
|
||||
}
|
||||
]
|
||||
}
|
||||
}
|
||||
return mock_client
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_async_elasticsearch_client():
|
||||
"""Create a mock async Elasticsearch client."""
|
||||
mock_client = Mock()
|
||||
mock_client.indices = Mock()
|
||||
mock_client.indices.exists = AsyncMock(return_value=False)
|
||||
mock_client.indices.create = AsyncMock(return_value={"acknowledged": True})
|
||||
mock_client.indices.get = AsyncMock(return_value={"test_index": {"mappings": {}}})
|
||||
mock_client.indices.delete = AsyncMock(return_value={"acknowledged": True})
|
||||
mock_client.index = AsyncMock(return_value={"_id": "test_id", "result": "created"})
|
||||
mock_client.search = AsyncMock(return_value={
|
||||
"hits": {
|
||||
"hits": [
|
||||
{
|
||||
"_id": "doc1",
|
||||
"_score": 0.9,
|
||||
"_source": {
|
||||
"content": "test content",
|
||||
"metadata": {"key": "value"}
|
||||
}
|
||||
}
|
||||
]
|
||||
}
|
||||
})
|
||||
return mock_client
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def client(mock_elasticsearch_client) -> ElasticsearchClient:
|
||||
"""Create an ElasticsearchClient instance for testing."""
|
||||
mock_embedding = Mock()
|
||||
mock_embedding.return_value = [0.1, 0.2, 0.3]
|
||||
|
||||
client = ElasticsearchClient(
|
||||
client=mock_elasticsearch_client,
|
||||
embedding_function=mock_embedding,
|
||||
vector_dimension=3,
|
||||
similarity="cosine"
|
||||
)
|
||||
return client
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def async_client(mock_async_elasticsearch_client) -> ElasticsearchClient:
|
||||
"""Create an ElasticsearchClient instance with async client for testing."""
|
||||
mock_embedding = Mock()
|
||||
mock_embedding.return_value = [0.1, 0.2, 0.3]
|
||||
|
||||
client = ElasticsearchClient(
|
||||
client=mock_async_elasticsearch_client,
|
||||
embedding_function=mock_embedding,
|
||||
vector_dimension=3,
|
||||
similarity="cosine"
|
||||
)
|
||||
return client
|
||||
|
||||
|
||||
class TestElasticsearchClient:
|
||||
"""Test suite for ElasticsearchClient."""
|
||||
|
||||
@patch("crewai.rag.elasticsearch.client._is_sync_client", return_value=True)
|
||||
@patch("crewai.rag.elasticsearch.client._is_async_client", return_value=False)
|
||||
def test_create_collection(self, mock_is_async, mock_is_sync, client, mock_elasticsearch_client):
|
||||
"""Test that create_collection creates a new index."""
|
||||
mock_elasticsearch_client.indices.exists.return_value = False
|
||||
|
||||
client.create_collection(collection_name="test_index")
|
||||
|
||||
mock_elasticsearch_client.indices.exists.assert_called_once_with(index="test_index")
|
||||
mock_elasticsearch_client.indices.create.assert_called_once()
|
||||
call_args = mock_elasticsearch_client.indices.create.call_args
|
||||
assert call_args.kwargs["index"] == "test_index"
|
||||
assert "mappings" in call_args.kwargs["body"]
|
||||
|
||||
@patch("crewai.rag.elasticsearch.client._is_sync_client", return_value=True)
|
||||
@patch("crewai.rag.elasticsearch.client._is_async_client", return_value=False)
|
||||
def test_create_collection_already_exists(self, mock_is_async, mock_is_sync, client, mock_elasticsearch_client):
|
||||
"""Test that create_collection raises error if index exists."""
|
||||
mock_elasticsearch_client.indices.exists.return_value = True
|
||||
|
||||
with pytest.raises(
|
||||
ValueError, match="Index 'test_index' already exists"
|
||||
):
|
||||
client.create_collection(collection_name="test_index")
|
||||
|
||||
@patch("crewai.rag.elasticsearch.client._is_sync_client", return_value=False)
|
||||
@patch("crewai.rag.elasticsearch.client._is_async_client", return_value=True)
|
||||
def test_create_collection_wrong_client_type(self, mock_is_async, mock_is_sync, mock_async_elasticsearch_client):
|
||||
"""Test that create_collection raises error with async client."""
|
||||
mock_embedding = Mock()
|
||||
client = ElasticsearchClient(
|
||||
client=mock_async_elasticsearch_client,
|
||||
embedding_function=mock_embedding
|
||||
)
|
||||
|
||||
with pytest.raises(ClientMethodMismatchError):
|
||||
client.create_collection(collection_name="test_index")
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@patch("crewai.rag.elasticsearch.client._is_sync_client", return_value=False)
|
||||
@patch("crewai.rag.elasticsearch.client._is_async_client", return_value=True)
|
||||
async def test_acreate_collection(self, mock_is_async, mock_is_sync, async_client, mock_async_elasticsearch_client):
|
||||
"""Test that acreate_collection creates a new index asynchronously."""
|
||||
mock_async_elasticsearch_client.indices.exists.return_value = False
|
||||
|
||||
await async_client.acreate_collection(collection_name="test_index")
|
||||
|
||||
mock_async_elasticsearch_client.indices.exists.assert_called_once_with(index="test_index")
|
||||
mock_async_elasticsearch_client.indices.create.assert_called_once()
|
||||
call_args = mock_async_elasticsearch_client.indices.create.call_args
|
||||
assert call_args.kwargs["index"] == "test_index"
|
||||
assert "mappings" in call_args.kwargs["body"]
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@patch("crewai.rag.elasticsearch.client._is_sync_client", return_value=False)
|
||||
@patch("crewai.rag.elasticsearch.client._is_async_client", return_value=True)
|
||||
async def test_acreate_collection_already_exists(self, mock_is_async, mock_is_sync, async_client, mock_async_elasticsearch_client):
|
||||
"""Test that acreate_collection raises error if index exists."""
|
||||
mock_async_elasticsearch_client.indices.exists.return_value = True
|
||||
|
||||
with pytest.raises(
|
||||
ValueError, match="Index 'test_index' already exists"
|
||||
):
|
||||
await async_client.acreate_collection(collection_name="test_index")
|
||||
|
||||
@patch("crewai.rag.elasticsearch.client._is_sync_client", return_value=True)
|
||||
@patch("crewai.rag.elasticsearch.client._is_async_client", return_value=False)
|
||||
def test_get_or_create_collection(self, mock_is_async, mock_is_sync, client, mock_elasticsearch_client):
|
||||
"""Test that get_or_create_collection returns existing index."""
|
||||
mock_elasticsearch_client.indices.exists.return_value = True
|
||||
|
||||
result = client.get_or_create_collection(collection_name="test_index")
|
||||
|
||||
mock_elasticsearch_client.indices.exists.assert_called_once_with(index="test_index")
|
||||
mock_elasticsearch_client.indices.get.assert_called_once_with(index="test_index")
|
||||
assert result == {"test_index": {"mappings": {}}}
|
||||
|
||||
@patch("crewai.rag.elasticsearch.client._is_sync_client", return_value=True)
|
||||
@patch("crewai.rag.elasticsearch.client._is_async_client", return_value=False)
|
||||
def test_get_or_create_collection_creates_new(self, mock_is_async, mock_is_sync, client, mock_elasticsearch_client):
|
||||
"""Test that get_or_create_collection creates new index if not exists."""
|
||||
mock_elasticsearch_client.indices.exists.return_value = False
|
||||
|
||||
client.get_or_create_collection(collection_name="test_index")
|
||||
|
||||
mock_elasticsearch_client.indices.exists.assert_called_once_with(index="test_index")
|
||||
mock_elasticsearch_client.indices.create.assert_called_once()
|
||||
mock_elasticsearch_client.indices.get.assert_called_once_with(index="test_index")
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@patch("crewai.rag.elasticsearch.client._is_sync_client", return_value=False)
|
||||
@patch("crewai.rag.elasticsearch.client._is_async_client", return_value=True)
|
||||
async def test_aget_or_create_collection(self, mock_is_async, mock_is_sync, async_client, mock_async_elasticsearch_client):
|
||||
"""Test that aget_or_create_collection returns existing index asynchronously."""
|
||||
mock_async_elasticsearch_client.indices.exists.return_value = True
|
||||
|
||||
result = await async_client.aget_or_create_collection(collection_name="test_index")
|
||||
|
||||
mock_async_elasticsearch_client.indices.exists.assert_called_once_with(index="test_index")
|
||||
mock_async_elasticsearch_client.indices.get.assert_called_once_with(index="test_index")
|
||||
assert result == {"test_index": {"mappings": {}}}
|
||||
|
||||
@patch("crewai.rag.elasticsearch.client._is_sync_client", return_value=True)
|
||||
@patch("crewai.rag.elasticsearch.client._is_async_client", return_value=False)
|
||||
def test_add_documents(self, mock_is_async, mock_is_sync, client, mock_elasticsearch_client):
|
||||
"""Test that add_documents indexes documents correctly."""
|
||||
mock_elasticsearch_client.indices.exists.return_value = True
|
||||
|
||||
documents: list[BaseRecord] = [
|
||||
{
|
||||
"content": "test content",
|
||||
"metadata": {"key": "value"}
|
||||
}
|
||||
]
|
||||
|
||||
client.add_documents(collection_name="test_index", documents=documents)
|
||||
|
||||
mock_elasticsearch_client.indices.exists.assert_called_once_with(index="test_index")
|
||||
mock_elasticsearch_client.index.assert_called_once()
|
||||
call_args = mock_elasticsearch_client.index.call_args
|
||||
assert call_args.kwargs["index"] == "test_index"
|
||||
assert "body" in call_args.kwargs
|
||||
assert call_args.kwargs["body"]["content"] == "test content"
|
||||
|
||||
@patch("crewai.rag.elasticsearch.client._is_sync_client", return_value=True)
|
||||
@patch("crewai.rag.elasticsearch.client._is_async_client", return_value=False)
|
||||
def test_add_documents_empty_list_raises_error(self, mock_is_async, mock_is_sync, client):
|
||||
"""Test that add_documents raises error with empty documents list."""
|
||||
with pytest.raises(ValueError, match="Documents list cannot be empty"):
|
||||
client.add_documents(collection_name="test_index", documents=[])
|
||||
|
||||
@patch("crewai.rag.elasticsearch.client._is_sync_client", return_value=True)
|
||||
@patch("crewai.rag.elasticsearch.client._is_async_client", return_value=False)
|
||||
def test_add_documents_index_not_exists(self, mock_is_async, mock_is_sync, client, mock_elasticsearch_client):
|
||||
"""Test that add_documents raises error if index doesn't exist."""
|
||||
mock_elasticsearch_client.indices.exists.return_value = False
|
||||
|
||||
documents: list[BaseRecord] = [{"content": "test content"}]
|
||||
|
||||
with pytest.raises(ValueError, match="Index 'test_index' does not exist"):
|
||||
client.add_documents(collection_name="test_index", documents=documents)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@patch("crewai.rag.elasticsearch.client._is_sync_client", return_value=False)
|
||||
@patch("crewai.rag.elasticsearch.client._is_async_client", return_value=True)
|
||||
async def test_aadd_documents(self, mock_is_async, mock_is_sync, async_client, mock_async_elasticsearch_client):
|
||||
"""Test that aadd_documents indexes documents correctly asynchronously."""
|
||||
mock_async_elasticsearch_client.indices.exists.return_value = True
|
||||
|
||||
documents: list[BaseRecord] = [
|
||||
{
|
||||
"content": "test content",
|
||||
"metadata": {"key": "value"}
|
||||
}
|
||||
]
|
||||
|
||||
await async_client.aadd_documents(collection_name="test_index", documents=documents)
|
||||
|
||||
mock_async_elasticsearch_client.indices.exists.assert_called_once_with(index="test_index")
|
||||
mock_async_elasticsearch_client.index.assert_called_once()
|
||||
call_args = mock_async_elasticsearch_client.index.call_args
|
||||
assert call_args.kwargs["index"] == "test_index"
|
||||
assert "body" in call_args.kwargs
|
||||
|
||||
@patch("crewai.rag.elasticsearch.client._is_sync_client", return_value=True)
|
||||
@patch("crewai.rag.elasticsearch.client._is_async_client", return_value=False)
|
||||
def test_search(self, mock_is_async, mock_is_sync, client, mock_elasticsearch_client):
|
||||
"""Test that search performs vector similarity search."""
|
||||
mock_elasticsearch_client.indices.exists.return_value = True
|
||||
|
||||
results = client.search(
|
||||
collection_name="test_index",
|
||||
query="test query",
|
||||
limit=5
|
||||
)
|
||||
|
||||
mock_elasticsearch_client.indices.exists.assert_called_once_with(index="test_index")
|
||||
mock_elasticsearch_client.search.assert_called_once()
|
||||
call_args = mock_elasticsearch_client.search.call_args
|
||||
assert call_args.kwargs["index"] == "test_index"
|
||||
assert "body" in call_args.kwargs
|
||||
|
||||
assert len(results) == 1
|
||||
assert results[0]["id"] == "doc1"
|
||||
assert results[0]["content"] == "test content"
|
||||
assert results[0]["score"] == 0.9
|
||||
|
||||
@patch("crewai.rag.elasticsearch.client._is_sync_client", return_value=True)
|
||||
@patch("crewai.rag.elasticsearch.client._is_async_client", return_value=False)
|
||||
def test_search_with_metadata_filter(self, mock_is_async, mock_is_sync, client, mock_elasticsearch_client):
|
||||
"""Test that search applies metadata filter correctly."""
|
||||
mock_elasticsearch_client.indices.exists.return_value = True
|
||||
|
||||
client.search(
|
||||
collection_name="test_index",
|
||||
query="test query",
|
||||
metadata_filter={"key": "value"}
|
||||
)
|
||||
|
||||
mock_elasticsearch_client.search.assert_called_once()
|
||||
call_args = mock_elasticsearch_client.search.call_args
|
||||
query_body = call_args.kwargs["body"]
|
||||
assert "bool" in query_body["query"]
|
||||
|
||||
@patch("crewai.rag.elasticsearch.client._is_sync_client", return_value=True)
|
||||
@patch("crewai.rag.elasticsearch.client._is_async_client", return_value=False)
|
||||
def test_search_index_not_exists(self, mock_is_async, mock_is_sync, client, mock_elasticsearch_client):
|
||||
"""Test that search raises error if index doesn't exist."""
|
||||
mock_elasticsearch_client.indices.exists.return_value = False
|
||||
|
||||
with pytest.raises(ValueError, match="Index 'test_index' does not exist"):
|
||||
client.search(collection_name="test_index", query="test query")
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@patch("crewai.rag.elasticsearch.client._is_sync_client", return_value=False)
|
||||
@patch("crewai.rag.elasticsearch.client._is_async_client", return_value=True)
|
||||
async def test_asearch(self, mock_is_async, mock_is_sync, async_client, mock_async_elasticsearch_client):
|
||||
"""Test that asearch performs vector similarity search asynchronously."""
|
||||
mock_async_elasticsearch_client.indices.exists.return_value = True
|
||||
|
||||
results = await async_client.asearch(
|
||||
collection_name="test_index",
|
||||
query="test query",
|
||||
limit=5
|
||||
)
|
||||
|
||||
mock_async_elasticsearch_client.indices.exists.assert_called_once_with(index="test_index")
|
||||
mock_async_elasticsearch_client.search.assert_called_once()
|
||||
|
||||
assert len(results) == 1
|
||||
assert results[0]["id"] == "doc1"
|
||||
assert results[0]["content"] == "test content"
|
||||
|
||||
@patch("crewai.rag.elasticsearch.client._is_sync_client", return_value=True)
|
||||
@patch("crewai.rag.elasticsearch.client._is_async_client", return_value=False)
|
||||
def test_delete_collection(self, mock_is_async, mock_is_sync, client, mock_elasticsearch_client):
|
||||
"""Test that delete_collection deletes the index."""
|
||||
mock_elasticsearch_client.indices.exists.return_value = True
|
||||
|
||||
client.delete_collection(collection_name="test_index")
|
||||
|
||||
mock_elasticsearch_client.indices.exists.assert_called_once_with(index="test_index")
|
||||
mock_elasticsearch_client.indices.delete.assert_called_once_with(index="test_index")
|
||||
|
||||
@patch("crewai.rag.elasticsearch.client._is_sync_client", return_value=True)
|
||||
@patch("crewai.rag.elasticsearch.client._is_async_client", return_value=False)
|
||||
def test_delete_collection_not_exists(self, mock_is_async, mock_is_sync, client, mock_elasticsearch_client):
|
||||
"""Test that delete_collection raises error if index doesn't exist."""
|
||||
mock_elasticsearch_client.indices.exists.return_value = False
|
||||
|
||||
with pytest.raises(ValueError, match="Index 'test_index' does not exist"):
|
||||
client.delete_collection(collection_name="test_index")
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@patch("crewai.rag.elasticsearch.client._is_sync_client", return_value=False)
|
||||
@patch("crewai.rag.elasticsearch.client._is_async_client", return_value=True)
|
||||
async def test_adelete_collection(self, mock_is_async, mock_is_sync, async_client, mock_async_elasticsearch_client):
|
||||
"""Test that adelete_collection deletes the index asynchronously."""
|
||||
mock_async_elasticsearch_client.indices.exists.return_value = True
|
||||
|
||||
await async_client.adelete_collection(collection_name="test_index")
|
||||
|
||||
mock_async_elasticsearch_client.indices.exists.assert_called_once_with(index="test_index")
|
||||
mock_async_elasticsearch_client.indices.delete.assert_called_once_with(index="test_index")
|
||||
|
||||
@patch("crewai.rag.elasticsearch.client._is_sync_client", return_value=True)
|
||||
@patch("crewai.rag.elasticsearch.client._is_async_client", return_value=False)
|
||||
def test_reset(self, mock_is_async, mock_is_sync, client, mock_elasticsearch_client):
|
||||
"""Test that reset deletes all non-system indices."""
|
||||
mock_elasticsearch_client.indices.get.return_value = {
|
||||
"test_index": {},
|
||||
".system_index": {},
|
||||
"another_index": {}
|
||||
}
|
||||
|
||||
client.reset()
|
||||
|
||||
mock_elasticsearch_client.indices.get.assert_called_once_with(index="*")
|
||||
assert mock_elasticsearch_client.indices.delete.call_count == 2
|
||||
delete_calls = [call.kwargs["index"] for call in mock_elasticsearch_client.indices.delete.call_args_list]
|
||||
assert "test_index" in delete_calls
|
||||
assert "another_index" in delete_calls
|
||||
assert ".system_index" not in delete_calls
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@patch("crewai.rag.elasticsearch.client._is_sync_client", return_value=False)
|
||||
@patch("crewai.rag.elasticsearch.client._is_async_client", return_value=True)
|
||||
async def test_areset(self, mock_is_async, mock_is_sync, async_client, mock_async_elasticsearch_client):
|
||||
"""Test that areset deletes all non-system indices asynchronously."""
|
||||
mock_async_elasticsearch_client.indices.get.return_value = {
|
||||
"test_index": {},
|
||||
".system_index": {},
|
||||
"another_index": {}
|
||||
}
|
||||
|
||||
await async_client.areset()
|
||||
|
||||
mock_async_elasticsearch_client.indices.get.assert_called_once_with(index="*")
|
||||
assert mock_async_elasticsearch_client.indices.delete.call_count == 2
|
||||
@@ -1,49 +0,0 @@
|
||||
"""Tests for Elasticsearch configuration."""
|
||||
|
||||
from crewai.rag.elasticsearch.config import ElasticsearchConfig
|
||||
|
||||
|
||||
def test_elasticsearch_config_defaults():
|
||||
"""Test that ElasticsearchConfig has correct defaults."""
|
||||
config = ElasticsearchConfig()
|
||||
|
||||
assert config.provider == "elasticsearch"
|
||||
assert config.vector_dimension == 384
|
||||
assert config.similarity == "cosine"
|
||||
assert config.embedding_function is not None
|
||||
assert config.options["hosts"] == ["http://localhost:9200"]
|
||||
assert config.options["use_ssl"] is False
|
||||
|
||||
|
||||
def test_elasticsearch_config_custom_options():
|
||||
"""Test that ElasticsearchConfig accepts custom options."""
|
||||
custom_options = {
|
||||
"hosts": ["https://elastic.example.com:9200"],
|
||||
"username": "user",
|
||||
"password": "pass",
|
||||
"use_ssl": True,
|
||||
}
|
||||
|
||||
config = ElasticsearchConfig(
|
||||
options=custom_options,
|
||||
vector_dimension=768,
|
||||
similarity="dot_product"
|
||||
)
|
||||
|
||||
assert config.provider == "elasticsearch"
|
||||
assert config.vector_dimension == 768
|
||||
assert config.similarity == "dot_product"
|
||||
assert config.options["hosts"] == ["https://elastic.example.com:9200"]
|
||||
assert config.options["username"] == "user"
|
||||
assert config.options["use_ssl"] is True
|
||||
|
||||
|
||||
def test_elasticsearch_config_embedding_function():
|
||||
"""Test that embedding function works correctly."""
|
||||
config = ElasticsearchConfig()
|
||||
|
||||
embedding = config.embedding_function("test text")
|
||||
|
||||
assert isinstance(embedding, list)
|
||||
assert len(embedding) == config.vector_dimension
|
||||
assert all(isinstance(x, float) for x in embedding)
|
||||
@@ -1,40 +0,0 @@
|
||||
"""Tests for Elasticsearch factory."""
|
||||
|
||||
from unittest.mock import Mock, patch
|
||||
|
||||
import pytest
|
||||
|
||||
from crewai.rag.elasticsearch.config import ElasticsearchConfig
|
||||
|
||||
|
||||
def test_create_client():
|
||||
"""Test that create_client creates an ElasticsearchClient."""
|
||||
config = ElasticsearchConfig()
|
||||
|
||||
with patch.dict('sys.modules', {'elasticsearch': Mock()}):
|
||||
mock_elasticsearch_module = Mock()
|
||||
mock_client_instance = Mock()
|
||||
mock_elasticsearch_module.Elasticsearch.return_value = mock_client_instance
|
||||
|
||||
with patch.dict('sys.modules', {'elasticsearch': mock_elasticsearch_module}):
|
||||
from crewai.rag.elasticsearch.factory import create_client
|
||||
client = create_client(config)
|
||||
|
||||
mock_elasticsearch_module.Elasticsearch.assert_called_once_with(**config.options)
|
||||
assert client.client == mock_client_instance
|
||||
assert client.embedding_function == config.embedding_function
|
||||
assert client.vector_dimension == config.vector_dimension
|
||||
assert client.similarity == config.similarity
|
||||
|
||||
|
||||
def test_create_client_missing_elasticsearch():
|
||||
"""Test that create_client raises ImportError when elasticsearch is not installed."""
|
||||
config = ElasticsearchConfig()
|
||||
|
||||
with patch.dict('sys.modules', {}, clear=False):
|
||||
if 'elasticsearch' in __import__('sys').modules:
|
||||
del __import__('sys').modules['elasticsearch']
|
||||
|
||||
from crewai.rag.elasticsearch.factory import create_client
|
||||
with pytest.raises(ImportError, match="elasticsearch package is required"):
|
||||
create_client(config)
|
||||
@@ -624,12 +624,12 @@ def test_crew_with_delegating_agents_should_not_override_task_tools(ceo, writer)
|
||||
_, kwargs = mock_execute_sync.call_args
|
||||
tools = kwargs["tools"]
|
||||
|
||||
assert any(isinstance(tool, TestTool) for tool in tools), (
|
||||
"TestTool should be present"
|
||||
)
|
||||
assert any("delegate" in tool.name.lower() for tool in tools), (
|
||||
"Delegation tool should be present"
|
||||
)
|
||||
assert any(
|
||||
isinstance(tool, TestTool) for tool in tools
|
||||
), "TestTool should be present"
|
||||
assert any(
|
||||
"delegate" in tool.name.lower() for tool in tools
|
||||
), "Delegation tool should be present"
|
||||
|
||||
|
||||
@pytest.mark.vcr(filter_headers=["authorization"])
|
||||
@@ -688,12 +688,12 @@ def test_crew_with_delegating_agents_should_not_override_agent_tools(ceo, writer
|
||||
_, kwargs = mock_execute_sync.call_args
|
||||
tools = kwargs["tools"]
|
||||
|
||||
assert any(isinstance(tool, TestTool) for tool in new_ceo.tools), (
|
||||
"TestTool should be present"
|
||||
)
|
||||
assert any("delegate" in tool.name.lower() for tool in tools), (
|
||||
"Delegation tool should be present"
|
||||
)
|
||||
assert any(
|
||||
isinstance(tool, TestTool) for tool in new_ceo.tools
|
||||
), "TestTool should be present"
|
||||
assert any(
|
||||
"delegate" in tool.name.lower() for tool in tools
|
||||
), "Delegation tool should be present"
|
||||
|
||||
|
||||
@pytest.mark.vcr(filter_headers=["authorization"])
|
||||
@@ -817,17 +817,17 @@ def test_task_tools_override_agent_tools_with_allow_delegation(researcher, write
|
||||
used_tools = kwargs["tools"]
|
||||
|
||||
# Confirm AnotherTestTool is present but TestTool is not
|
||||
assert any(isinstance(tool, AnotherTestTool) for tool in used_tools), (
|
||||
"AnotherTestTool should be present"
|
||||
)
|
||||
assert not any(isinstance(tool, TestTool) for tool in used_tools), (
|
||||
"TestTool should not be present among used tools"
|
||||
)
|
||||
assert any(
|
||||
isinstance(tool, AnotherTestTool) for tool in used_tools
|
||||
), "AnotherTestTool should be present"
|
||||
assert not any(
|
||||
isinstance(tool, TestTool) for tool in used_tools
|
||||
), "TestTool should not be present among used tools"
|
||||
|
||||
# Confirm delegation tool(s) are present
|
||||
assert any("delegate" in tool.name.lower() for tool in used_tools), (
|
||||
"Delegation tool should be present"
|
||||
)
|
||||
assert any(
|
||||
"delegate" in tool.name.lower() for tool in used_tools
|
||||
), "Delegation tool should be present"
|
||||
|
||||
# Finally, make sure the agent's original tools remain unchanged
|
||||
assert len(researcher_with_delegation.tools) == 1
|
||||
@@ -931,9 +931,9 @@ def test_cache_hitting_between_agents(researcher, writer, ceo):
|
||||
tool="multiplier", input={"first_number": 2, "second_number": 6}
|
||||
)
|
||||
assert cache_calls[0] == expected_call, f"First call mismatch: {cache_calls[0]}"
|
||||
assert cache_calls[1] == expected_call, (
|
||||
f"Second call mismatch: {cache_calls[1]}"
|
||||
)
|
||||
assert (
|
||||
cache_calls[1] == expected_call
|
||||
), f"Second call mismatch: {cache_calls[1]}"
|
||||
|
||||
|
||||
@pytest.mark.vcr(filter_headers=["authorization"])
|
||||
@@ -1676,9 +1676,9 @@ def test_code_execution_flag_adds_code_tool_upon_kickoff():
|
||||
|
||||
# Verify that exactly one tool was used and it was a CodeInterpreterTool
|
||||
assert len(used_tools) == 1, "Should have exactly one tool"
|
||||
assert isinstance(used_tools[0], CodeInterpreterTool), (
|
||||
"Tool should be CodeInterpreterTool"
|
||||
)
|
||||
assert isinstance(
|
||||
used_tools[0], CodeInterpreterTool
|
||||
), "Tool should be CodeInterpreterTool"
|
||||
|
||||
|
||||
@pytest.mark.vcr(filter_headers=["authorization"])
|
||||
@@ -2537,8 +2537,8 @@ def test_memory_events_are_emitted():
|
||||
|
||||
crew.kickoff()
|
||||
|
||||
assert len(events["MemorySaveStartedEvent"]) == 6
|
||||
assert len(events["MemorySaveCompletedEvent"]) == 6
|
||||
assert len(events["MemorySaveStartedEvent"]) == 3
|
||||
assert len(events["MemorySaveCompletedEvent"]) == 3
|
||||
assert len(events["MemorySaveFailedEvent"]) == 0
|
||||
assert len(events["MemoryQueryStartedEvent"]) == 3
|
||||
assert len(events["MemoryQueryCompletedEvent"]) == 3
|
||||
@@ -3817,9 +3817,9 @@ def test_fetch_inputs():
|
||||
expected_placeholders = {"role_detail", "topic", "field"}
|
||||
actual_placeholders = crew.fetch_inputs()
|
||||
|
||||
assert actual_placeholders == expected_placeholders, (
|
||||
f"Expected {expected_placeholders}, but got {actual_placeholders}"
|
||||
)
|
||||
assert (
|
||||
actual_placeholders == expected_placeholders
|
||||
), f"Expected {expected_placeholders}, but got {actual_placeholders}"
|
||||
|
||||
|
||||
def test_task_tools_preserve_code_execution_tools():
|
||||
@@ -3894,20 +3894,20 @@ def test_task_tools_preserve_code_execution_tools():
|
||||
used_tools = kwargs["tools"]
|
||||
|
||||
# Verify all expected tools are present
|
||||
assert any(isinstance(tool, TestTool) for tool in used_tools), (
|
||||
"Task's TestTool should be present"
|
||||
)
|
||||
assert any(isinstance(tool, CodeInterpreterTool) for tool in used_tools), (
|
||||
"CodeInterpreterTool should be present"
|
||||
)
|
||||
assert any("delegate" in tool.name.lower() for tool in used_tools), (
|
||||
"Delegation tool should be present"
|
||||
)
|
||||
assert any(
|
||||
isinstance(tool, TestTool) for tool in used_tools
|
||||
), "Task's TestTool should be present"
|
||||
assert any(
|
||||
isinstance(tool, CodeInterpreterTool) for tool in used_tools
|
||||
), "CodeInterpreterTool should be present"
|
||||
assert any(
|
||||
"delegate" in tool.name.lower() for tool in used_tools
|
||||
), "Delegation tool should be present"
|
||||
|
||||
# Verify the total number of tools (TestTool + CodeInterpreter + 2 delegation tools)
|
||||
assert len(used_tools) == 4, (
|
||||
"Should have TestTool, CodeInterpreter, and 2 delegation tools"
|
||||
)
|
||||
assert (
|
||||
len(used_tools) == 4
|
||||
), "Should have TestTool, CodeInterpreter, and 2 delegation tools"
|
||||
|
||||
|
||||
@pytest.mark.vcr(filter_headers=["authorization"])
|
||||
@@ -3951,9 +3951,9 @@ def test_multimodal_flag_adds_multimodal_tools():
|
||||
used_tools = kwargs["tools"]
|
||||
|
||||
# Check that the multimodal tool was added
|
||||
assert any(isinstance(tool, AddImageTool) for tool in used_tools), (
|
||||
"AddImageTool should be present when agent is multimodal"
|
||||
)
|
||||
assert any(
|
||||
isinstance(tool, AddImageTool) for tool in used_tools
|
||||
), "AddImageTool should be present when agent is multimodal"
|
||||
|
||||
# Verify we have exactly one tool (just the AddImageTool)
|
||||
assert len(used_tools) == 1, "Should only have the AddImageTool"
|
||||
@@ -4217,9 +4217,9 @@ def test_crew_guardrail_feedback_in_context():
|
||||
assert len(execution_contexts) > 1, "Task should have been executed multiple times"
|
||||
|
||||
# Verify that the second execution included the guardrail feedback
|
||||
assert "Output must contain the keyword 'IMPORTANT'" in execution_contexts[1], (
|
||||
"Guardrail feedback should be included in retry context"
|
||||
)
|
||||
assert (
|
||||
"Output must contain the keyword 'IMPORTANT'" in execution_contexts[1]
|
||||
), "Guardrail feedback should be included in retry context"
|
||||
|
||||
# Verify final output meets guardrail requirements
|
||||
assert "IMPORTANT" in result.raw, "Final output should contain required keyword"
|
||||
@@ -4435,46 +4435,46 @@ def test_crew_copy_with_memory():
|
||||
try:
|
||||
crew_copy = crew.copy()
|
||||
|
||||
assert hasattr(crew_copy, "_short_term_memory"), (
|
||||
"Copied crew should have _short_term_memory"
|
||||
)
|
||||
assert crew_copy._short_term_memory is not None, (
|
||||
"Copied _short_term_memory should not be None"
|
||||
)
|
||||
assert id(crew_copy._short_term_memory) != original_short_term_id, (
|
||||
"Copied _short_term_memory should be a new object"
|
||||
)
|
||||
assert hasattr(
|
||||
crew_copy, "_short_term_memory"
|
||||
), "Copied crew should have _short_term_memory"
|
||||
assert (
|
||||
crew_copy._short_term_memory is not None
|
||||
), "Copied _short_term_memory should not be None"
|
||||
assert (
|
||||
id(crew_copy._short_term_memory) != original_short_term_id
|
||||
), "Copied _short_term_memory should be a new object"
|
||||
|
||||
assert hasattr(crew_copy, "_long_term_memory"), (
|
||||
"Copied crew should have _long_term_memory"
|
||||
)
|
||||
assert crew_copy._long_term_memory is not None, (
|
||||
"Copied _long_term_memory should not be None"
|
||||
)
|
||||
assert id(crew_copy._long_term_memory) != original_long_term_id, (
|
||||
"Copied _long_term_memory should be a new object"
|
||||
)
|
||||
assert hasattr(
|
||||
crew_copy, "_long_term_memory"
|
||||
), "Copied crew should have _long_term_memory"
|
||||
assert (
|
||||
crew_copy._long_term_memory is not None
|
||||
), "Copied _long_term_memory should not be None"
|
||||
assert (
|
||||
id(crew_copy._long_term_memory) != original_long_term_id
|
||||
), "Copied _long_term_memory should be a new object"
|
||||
|
||||
assert hasattr(crew_copy, "_entity_memory"), (
|
||||
"Copied crew should have _entity_memory"
|
||||
)
|
||||
assert crew_copy._entity_memory is not None, (
|
||||
"Copied _entity_memory should not be None"
|
||||
)
|
||||
assert id(crew_copy._entity_memory) != original_entity_id, (
|
||||
"Copied _entity_memory should be a new object"
|
||||
)
|
||||
assert hasattr(
|
||||
crew_copy, "_entity_memory"
|
||||
), "Copied crew should have _entity_memory"
|
||||
assert (
|
||||
crew_copy._entity_memory is not None
|
||||
), "Copied _entity_memory should not be None"
|
||||
assert (
|
||||
id(crew_copy._entity_memory) != original_entity_id
|
||||
), "Copied _entity_memory should be a new object"
|
||||
|
||||
if original_external_id:
|
||||
assert hasattr(crew_copy, "_external_memory"), (
|
||||
"Copied crew should have _external_memory"
|
||||
)
|
||||
assert crew_copy._external_memory is not None, (
|
||||
"Copied _external_memory should not be None"
|
||||
)
|
||||
assert id(crew_copy._external_memory) != original_external_id, (
|
||||
"Copied _external_memory should be a new object"
|
||||
)
|
||||
assert hasattr(
|
||||
crew_copy, "_external_memory"
|
||||
), "Copied crew should have _external_memory"
|
||||
assert (
|
||||
crew_copy._external_memory is not None
|
||||
), "Copied _external_memory should not be None"
|
||||
assert (
|
||||
id(crew_copy._external_memory) != original_external_id
|
||||
), "Copied _external_memory should be a new object"
|
||||
else:
|
||||
assert (
|
||||
not hasattr(crew_copy, "_external_memory")
|
||||
|
||||