mirror of
https://github.com/crewAIInc/crewAI.git
synced 2026-01-07 15:18:29 +00:00
Compare commits
5 Commits
devin/1745
...
devin/1745
| Author | SHA1 | Date |
|---|---|---|
| | f4e61ae714 | |
| | 958751fe36 | |
| | 3c838f16ff | |
| | 6c08e6062a | |
| | 2e4c97661a | |
@@ -179,7 +179,78 @@ def crew(self) -> Crew:
|
||||
```
|
||||
</Note>
|
||||
|
||||
### 10. API Keys
|
||||
### 10. Deploy
|
||||
|
||||
Deploy the crew or flow to [CrewAI Enterprise](https://app.crewai.com).
|
||||
|
||||
- **Authentication**: You need to be authenticated to deploy to CrewAI Enterprise.
|
||||
```shell Terminal
|
||||
crewai signup
|
||||
```
|
||||
If you already have an account, you can log in with:
|
||||
```shell Terminal
|
||||
crewai login
|
||||
```
|
||||
|
||||
- **Create a deployment**: Once you are authenticated, you can create a deployment for your crew or flow from the root of your local project.
|
||||
```shell Terminal
|
||||
crewai deploy create
|
||||
```
|
||||
- Reads your local project configuration.
|
||||
- Prompts you to confirm the environment variables (like `OPENAI_API_KEY`, `SERPER_API_KEY`) found locally. These will be securely stored with the deployment on the Enterprise platform. Ensure your sensitive keys are correctly configured locally (e.g., in a `.env` file) before running this.
|
||||
- Links the deployment to the corresponding remote GitHub repository (it usually detects this automatically).
|
||||
|
||||
- **Deploy the Crew**: Once you are authenticated, you can deploy your crew or flow to CrewAI Enterprise.
|
||||
```shell Terminal
|
||||
crewai deploy push
|
||||
```
|
||||
- Initiates the deployment process on the CrewAI Enterprise platform.
|
||||
- Once the deployment has been initiated, it prints the `Deployment created successfully!` message along with the deployment name and a unique deployment ID (UUID).
|
||||
|
||||
- **Deployment Status**: You can check the status of your deployment with:
|
||||
```shell Terminal
|
||||
crewai deploy status
|
||||
```
|
||||
This fetches the status of your most recent deployment attempt (e.g., `Building Images for Crew`, `Deploy Enqueued`, `Online`).
|
||||
|
||||
- **Deployment Logs**: You can check the logs of your deployment with:
|
||||
```shell Terminal
|
||||
crewai deploy logs
|
||||
```
|
||||
This streams the deployment logs to your terminal.
|
||||
|
||||
- **List deployments**: You can list all your deployments with:
|
||||
```shell Terminal
|
||||
crewai deploy list
|
||||
```
|
||||
This lists all your deployments.
|
||||
|
||||
- **Delete a deployment**: You can delete a deployment with:
|
||||
```shell Terminal
|
||||
crewai deploy remove
|
||||
```
|
||||
This deletes the deployment from the CrewAI Enterprise platform.
|
||||
|
||||
- **Help Command**: You can get help with the CLI with:
|
||||
```shell Terminal
|
||||
crewai deploy --help
|
||||
```
|
||||
This shows the help message for the CrewAI Deploy CLI.
|
||||
|
||||
Watch this video tutorial for a step-by-step demonstration of deploying your crew to [CrewAI Enterprise](https://app.crewai.com) using the CLI.
|
||||
|
||||
<iframe
|
||||
width="100%"
|
||||
height="400"
|
||||
src="https://www.youtube.com/embed/3EqSV-CYDZA"
|
||||
title="CrewAI Deployment Guide"
|
||||
frameborder="0"
|
||||
style={{ borderRadius: '10px' }}
|
||||
allow="accelerometer; autoplay; clipboard-write; encrypted-media; gyroscope; picture-in-picture"
|
||||
allowfullscreen
|
||||
></iframe>
|
||||
|
||||
### 11. API Keys
|
||||
|
||||
When you run the `crewai create crew` command, the CLI first shows you the top 5 most common LLM providers and asks you to select one.
|
||||
|
||||
|
||||
@@ -790,9 +790,6 @@ Visualizing your AI workflows can provide valuable insights into the structure a
|
||||
|
||||
Plots in CrewAI are graphical representations of your AI workflows. They display the various tasks, their connections, and the flow of data between them. This visualization helps in understanding the sequence of operations, identifying bottlenecks, and ensuring that the workflow logic aligns with your expectations.
|
||||
|
||||

|
||||
*An example visualization of a simple flow with start method, sequential steps, and directional execution.*
|
||||
|
||||
### How to Generate a Plot
|
||||
|
||||
CrewAI provides two convenient methods to generate plots of your flows:
|
||||
|
||||
117
docs/how-to/elasticsearch-integration.md
Normal file
@@ -0,0 +1,117 @@
|
||||
# Elasticsearch Integration
|
||||
|
||||
CrewAI supports using Elasticsearch as an alternative to ChromaDB for RAG (Retrieval Augmented Generation) storage. This allows you to leverage Elasticsearch's powerful search capabilities and scalability for your AI agents.
|
||||
|
||||
## Installation
|
||||
|
||||
To use Elasticsearch with CrewAI, you need to install the Elasticsearch Python client:
|
||||
|
||||
```bash
|
||||
pip install elasticsearch
|
||||
```
|
||||
|
||||
## Using Elasticsearch for Memory
|
||||
|
||||
You can configure your crew to use Elasticsearch for memory storage:
|
||||
|
||||
```python
from crewai import Agent, Crew, Task

# Create agents and tasks
agent = Agent(
    role="Researcher",
    goal="Research a topic",
    backstory="You are a researcher who loves to find information.",
)

task = Task(
    description="Research about AI",
    expected_output="Information about AI",
    agent=agent,
)

# Create a crew with Elasticsearch memory
crew = Crew(
    agents=[agent],
    tasks=[task],
    memory=True,  # Memory must be enabled for memory_config to be used
    memory_config={
        "provider": "elasticsearch",
        "host": "localhost",  # Optional, defaults to localhost
        "port": 9200,  # Optional, defaults to 9200
        "username": "user",  # Optional
        "password": "pass",  # Optional
    },
)

# Execute the crew
result = crew.kickoff()
```
|
||||
|
||||
## Using Elasticsearch for Knowledge
|
||||
|
||||
You can also use Elasticsearch for knowledge storage:
|
||||
|
||||
```python
from crewai import Agent, Crew, Task
from crewai.knowledge import Knowledge
from crewai.knowledge.source.string_knowledge_source import StringKnowledgeSource

# Create knowledge with Elasticsearch storage
content = "AI is a field of computer science that focuses on creating machines that can perform tasks that typically require human intelligence."
string_source = StringKnowledgeSource(
    content=content, metadata={"topic": "AI"}
)

knowledge = Knowledge(
    collection_name="test",
    sources=[string_source],
    storage_provider="elasticsearch",  # Use Elasticsearch
    # Optional Elasticsearch configuration
    host="localhost",
    port=9200,
    username="user",
    password="pass",
)

# Create an agent with the knowledge
agent = Agent(
    role="AI Expert",
    goal="Explain AI",
    backstory="You are an AI expert who loves to explain AI concepts.",
    knowledge=[knowledge],
)

# Create a task
task = Task(
    description="Explain what AI is",
    expected_output="Explanation of AI",
    agent=agent,
)

# Create a crew
crew = Crew(
    agents=[agent],
    tasks=[task],
)

# Execute the crew
result = crew.kickoff()
```
|
||||
|
||||
## Configuration Options
|
||||
|
||||
The Elasticsearch integration supports the following configuration options:
|
||||
|
||||
- `host`: Elasticsearch host (default: "localhost")
|
||||
- `port`: Elasticsearch port (default: 9200)
|
||||
- `username`: Elasticsearch username (optional)
|
||||
- `password`: Elasticsearch password (optional)
|
||||
- Additional keyword arguments are passed directly to the Elasticsearch client
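
As a rough illustration of how these options could translate into client arguments, the sketch below mirrors the pattern used by the storage classes in this change: the host and port are joined into an `http://` URL, and `basic_auth` is only set when both credentials are present. The helper name `build_es_client_args` is hypothetical and is not part of CrewAI; `request_timeout` is used only as an example of an extra keyword argument.

```python
from typing import Any, Dict, List, Optional, Tuple


def build_es_client_args(
    host: str = "localhost",
    port: int = 9200,
    username: Optional[str] = None,
    password: Optional[str] = None,
    **extra: Any,
) -> Tuple[List[str], Dict[str, Any]]:
    """Return (hosts, keyword arguments) for an Elasticsearch client.

    Hypothetical helper for illustration only.
    """
    hosts = [f"http://{host}:{port}"]
    kwargs: Dict[str, Any] = {}
    # Credentials are applied only when both parts are provided.
    if username and password:
        kwargs["basic_auth"] = (username, password)
    # Any additional options are forwarded unchanged.
    kwargs.update(extra)
    return hosts, kwargs


hosts, kwargs = build_es_client_args(
    username="user", password="pass", request_timeout=30
)
print(hosts)
print(kwargs)
```

With the `elasticsearch` package installed, the result could then be passed on as `Elasticsearch(hosts, **kwargs)`.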
|
||||
|
||||
## Running Tests
|
||||
|
||||
To run the Elasticsearch tests, you need to set the `RUN_ELASTICSEARCH_TESTS` environment variable to `true`:
|
||||
|
||||
```bash
|
||||
RUN_ELASTICSEARCH_TESTS=true pytest tests/memory/elasticsearch_storage_test.py tests/knowledge/elasticsearch_knowledge_storage_test.py tests/integration/elasticsearch_integration_test.py
|
||||
```
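
The test files themselves are not shown here, so how they read the variable is an assumption. One common way to implement such a gate is a small helper like the one below; the function name `elasticsearch_tests_enabled` and the case-insensitive comparison are illustrative choices, not taken from the test suite.

```python
import os
from typing import Mapping, Optional


def elasticsearch_tests_enabled(env: Optional[Mapping[str, str]] = None) -> bool:
    """Return True when RUN_ELASTICSEARCH_TESTS is set to "true".

    The comparison ignores case and surrounding whitespace (an assumption).
    """
    env = os.environ if env is None else env
    return env.get("RUN_ELASTICSEARCH_TESTS", "").strip().lower() == "true"


print(elasticsearch_tests_enabled({"RUN_ELASTICSEARCH_TESTS": "true"}))
print(elasticsearch_tests_enabled({}))
```

In a pytest module, a helper like this could feed `pytest.mark.skipif(not elasticsearch_tests_enabled(), reason="...")` so the tests are skipped unless a live Elasticsearch instance is expected.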
|
||||
Binary file not shown (before: 30 KiB).
@@ -37,6 +37,7 @@ dependencies = [
|
||||
"tomli>=2.0.2",
|
||||
"blinker>=1.9.0",
|
||||
"json5>=0.10.0",
|
||||
"elasticsearch>=9.0.0",
|
||||
]
|
||||
|
||||
[project.urls]
|
||||
|
||||
@@ -1,11 +1,18 @@
|
||||
import os
|
||||
from typing import Any, Dict, List, Optional
|
||||
from typing import Any, Dict, List, Optional, cast
|
||||
|
||||
from pydantic import BaseModel, ConfigDict, Field
|
||||
|
||||
from crewai.knowledge.source.base_knowledge_source import BaseKnowledgeSource
|
||||
from crewai.knowledge.storage.knowledge_storage import KnowledgeStorage
|
||||
|
||||
try:
|
||||
from crewai.knowledge.storage.elasticsearch_knowledge_storage import (
|
||||
ElasticsearchKnowledgeStorage,
|
||||
)
|
||||
except ImportError:
|
||||
ElasticsearchKnowledgeStorage = None
|
||||
|
||||
os.environ["TOKENIZERS_PARALLELISM"] = "false" # removes logging from fastembed
|
||||
|
||||
|
||||
@@ -30,17 +37,29 @@ class Knowledge(BaseModel):
|
||||
sources: List[BaseKnowledgeSource],
|
||||
embedder: Optional[Dict[str, Any]] = None,
|
||||
storage: Optional[KnowledgeStorage] = None,
|
||||
storage_provider: str = "chromadb",
|
||||
**data,
|
||||
):
|
||||
super().__init__(**data)
|
||||
if storage:
|
||||
self.storage = storage
|
||||
else:
|
||||
self.storage = KnowledgeStorage(
|
||||
embedder=embedder, collection_name=collection_name
|
||||
)
|
||||
if storage_provider == "elasticsearch":
|
||||
try:
|
||||
self.storage = cast(KnowledgeStorage, self._create_elasticsearch_storage(
|
||||
embedder, collection_name
|
||||
))
|
||||
except ImportError:
|
||||
raise ImportError(
|
||||
"Elasticsearch is not installed. Please install it with `pip install elasticsearch`."
|
||||
)
|
||||
else:
|
||||
self.storage = KnowledgeStorage(
|
||||
embedder=embedder, collection_name=collection_name
|
||||
)
|
||||
self.sources = sources
|
||||
self.storage.initialize_knowledge_storage()
|
||||
if self.storage is not None:
|
||||
self.storage.initialize_knowledge_storage()
|
||||
self._add_sources()
|
||||
|
||||
def query(
|
||||
@@ -71,6 +90,16 @@ class Knowledge(BaseModel):
|
||||
except Exception as e:
|
||||
raise e
|
||||
|
||||
def _create_elasticsearch_storage(self, embedder, collection_name):
|
||||
"""Create an Elasticsearch storage instance."""
|
||||
if ElasticsearchKnowledgeStorage is None:
|
||||
raise ImportError(
|
||||
"Elasticsearch is not installed. Please install it with `pip install elasticsearch`."
|
||||
)
|
||||
return ElasticsearchKnowledgeStorage(
|
||||
embedder_config=embedder, collection_name=collection_name
|
||||
)
|
||||
|
||||
def reset(self) -> None:
|
||||
if self.storage:
|
||||
self.storage.reset()
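
The optional-dependency pattern used in this file (import at module load, fall back to `None`, raise `ImportError` only when the feature is requested) can be sketched in isolation as follows. The helper names `optional_import` and `require` are hypothetical, and the standard-library `json` module stands in for the optional package.

```python
import importlib
from typing import Any, Optional


def optional_import(module_name: str, attr: str) -> Optional[Any]:
    """Return module.attr, or None when the module or attribute is unavailable."""
    try:
        module = importlib.import_module(module_name)
    except ImportError:
        return None
    return getattr(module, attr, None)


def require(obj: Optional[Any], install_hint: str) -> Any:
    """Raise a helpful ImportError at the point of use if the dependency is missing."""
    if obj is None:
        raise ImportError(f"Dependency is not installed. {install_hint}")
    return obj


# Standard-library module: always importable.
JSONDecoder = optional_import("json", "JSONDecoder")
# Deliberately missing module: resolves to None instead of failing at import time.
Missing = optional_import("no_such_module_for_demo", "Thing")

print(JSONDecoder is not None, Missing is None)
```

Deferring the error this way keeps the default ChromaDB path importable even when the Elasticsearch client is absent.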
|
||||
|
||||
268
src/crewai/knowledge/storage/elasticsearch_knowledge_storage.py
Normal file
@@ -0,0 +1,268 @@
|
||||
import contextlib
|
||||
import hashlib
|
||||
import io
|
||||
import logging
|
||||
import os
|
||||
from typing import Any, Dict, List, Optional, Union, cast
|
||||
|
||||
from crewai.knowledge.storage.base_knowledge_storage import BaseKnowledgeStorage
|
||||
from crewai.utilities import EmbeddingConfigurator
|
||||
from crewai.utilities.logger import Logger
|
||||
from crewai.utilities.paths import db_storage_path
|
||||
|
||||
|
||||
@contextlib.contextmanager
|
||||
def suppress_logging(logger_name="elasticsearch", level=logging.ERROR):
    """Temporarily raise the logger level and silence stdout/stderr."""
    logger = logging.getLogger(logger_name)
    # Save the logger's own level (not the effective one) so that restoring
    # it does not pin an inherited level onto this logger.
    original_level = logger.level
    logger.setLevel(level)
    try:
        with (
            contextlib.redirect_stdout(io.StringIO()),
            contextlib.redirect_stderr(io.StringIO()),
            contextlib.suppress(UserWarning),
        ):
            yield
    finally:
        # Restore the level even if the wrapped block raises.
        logger.setLevel(original_level)
|
||||
|
||||
|
||||
class ElasticsearchKnowledgeStorage(BaseKnowledgeStorage):
|
||||
"""
|
||||
Extends BaseKnowledgeStorage to use Elasticsearch for storing embeddings
|
||||
and improving search efficiency.
|
||||
"""
|
||||
|
||||
app: Any = None
|
||||
collection_name: Optional[str] = "knowledge"
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
embedder_config: Optional[Dict[str, Any]] = None,
|
||||
collection_name: Optional[str] = None,
|
||||
host: str = "localhost",
|
||||
port: int = 9200,
|
||||
username: Optional[str] = None,
|
||||
password: Optional[str] = None,
|
||||
**kwargs: Any
|
||||
):
|
||||
self.collection_name = collection_name
|
||||
self._set_embedder_config(embedder_config)
|
||||
|
||||
self.host = host
|
||||
self.port = port
|
||||
self.username = username
|
||||
self.password = password
|
||||
self.index_name = f"crewai_knowledge_{collection_name if collection_name else 'default'}".lower()
|
||||
self.additional_config = kwargs
|
||||
|
||||
def search(
|
||||
self,
|
||||
query: List[str],
|
||||
limit: int = 3,
|
||||
filter: Optional[dict] = None,
|
||||
score_threshold: float = 0.35,
|
||||
) -> List[Dict[str, Any]]:
|
||||
if not self.app:
|
||||
self.initialize_knowledge_storage()
|
||||
|
||||
try:
|
||||
embedding = self._get_embedding_for_text(query[0])
|
||||
|
||||
search_query: Dict[str, Any] = {
|
||||
"size": limit,
|
||||
"query": {
|
||||
"script_score": {
|
||||
"query": {"match_all": {}},
|
||||
"script": {
|
||||
"source": "cosineSimilarity(params.query_vector, 'embedding') + 1.0",
|
||||
"params": {"query_vector": embedding}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
            if filter:
                # Combine every filter key into a single bool query so that
                # all conditions apply, not only the last one in the loop.
                must_clauses: List[Dict[str, Any]] = [{"match_all": {}}]
                for key, value in filter.items():
                    must_clauses.append({"match": {f"metadata.{key}": value}})
                search_query["query"]["script_score"]["query"] = {
                    "bool": {"must": must_clauses}
                }
|
||||
|
||||
with suppress_logging():
|
||||
if self.app is not None and hasattr(self.app, "search") and callable(getattr(self.app, "search")):
|
||||
response = self.app.search(
|
||||
index=self.index_name,
|
||||
body=search_query
|
||||
)
|
||||
|
||||
results = []
|
||||
for hit in response["hits"]["hits"]:
|
||||
adjusted_score = (hit["_score"] - 1.0)
|
||||
|
||||
if adjusted_score >= score_threshold:
|
||||
results.append({
|
||||
"id": hit["_id"],
|
||||
"metadata": hit["_source"]["metadata"],
|
||||
"context": hit["_source"]["text"],
|
||||
"score": adjusted_score,
|
||||
})
|
||||
|
||||
return results
|
||||
else:
|
||||
Logger(verbose=True).log("error", "Elasticsearch client is not initialized", "red")
|
||||
return []
|
||||
except Exception as e:
|
||||
Logger(verbose=True).log("error", f"Search error: {e}", "red")
|
||||
raise Exception(f"Error during knowledge search: {str(e)}")
|
||||
|
||||
def initialize_knowledge_storage(self):
|
||||
try:
|
||||
from elasticsearch import Elasticsearch
|
||||
|
||||
es_auth = {}
|
||||
if self.username and self.password:
|
||||
es_auth = {"basic_auth": (self.username, self.password)}
|
||||
|
||||
self.app = Elasticsearch(
|
||||
[f"http://{self.host}:{self.port}"],
|
||||
**es_auth,
|
||||
**self.additional_config
|
||||
)
|
||||
|
||||
if not self.app.indices.exists(index=self.index_name):
|
||||
self.app.indices.create(
|
||||
index=self.index_name,
|
||||
body={
|
||||
"mappings": {
|
||||
"properties": {
|
||||
"text": {"type": "text"},
|
||||
"embedding": {
|
||||
"type": "dense_vector",
|
||||
"dims": 1536, # Default for OpenAI embeddings
|
||||
"index": True,
|
||||
"similarity": "cosine"
|
||||
},
|
||||
"metadata": {"type": "object"}
|
||||
}
|
||||
}
|
||||
}
|
||||
)
|
||||
|
||||
except ImportError:
|
||||
raise ImportError(
|
||||
"Elasticsearch is not installed. Please install it with `pip install elasticsearch`."
|
||||
)
|
||||
except Exception as e:
|
||||
Logger(verbose=True).log(
|
||||
"error",
|
||||
f"Error initializing Elasticsearch: {str(e)}",
|
||||
"red"
|
||||
)
|
||||
raise Exception(f"Error initializing Elasticsearch: {str(e)}")
|
||||
|
||||
def reset(self) -> None:
|
||||
try:
|
||||
if self.app is not None:
|
||||
if self.app.indices.exists(index=self.index_name):
|
||||
self.app.indices.delete(index=self.index_name)
|
||||
|
||||
self.initialize_knowledge_storage()
|
||||
except Exception as e:
|
||||
raise Exception(
|
||||
f"An error occurred while resetting the knowledge storage: {e}"
|
||||
)
|
||||
|
||||
def save(
|
||||
self,
|
||||
documents: List[str],
|
||||
metadata: Optional[Union[Dict[str, Any], List[Dict[str, Any]]]] = None,
|
||||
) -> None:
|
||||
if not self.app:
|
||||
self.initialize_knowledge_storage()
|
||||
|
||||
try:
|
||||
unique_docs = {}
|
||||
|
||||
for idx, doc in enumerate(documents):
|
||||
doc_id = hashlib.sha256(doc.encode("utf-8")).hexdigest()
|
||||
doc_metadata = None
|
||||
if metadata is not None:
|
||||
if isinstance(metadata, list):
|
||||
doc_metadata = metadata[idx]
|
||||
else:
|
||||
doc_metadata = metadata
|
||||
unique_docs[doc_id] = (doc, doc_metadata)
|
||||
|
||||
for doc_id, (doc, meta) in unique_docs.items():
|
||||
embedding = self._get_embedding_for_text(doc)
|
||||
|
||||
doc_body = {
|
||||
"text": doc,
|
||||
"embedding": embedding,
|
||||
"metadata": meta or {},
|
||||
}
|
||||
|
||||
if self.app is not None and hasattr(self.app, "index") and callable(getattr(self.app, "index")):
|
||||
index_func = getattr(self.app, "index")
|
||||
index_func(
|
||||
index=self.index_name,
|
||||
id=doc_id,
|
||||
document=doc_body,
|
||||
refresh=True # Make the document immediately available for search
|
||||
)
|
||||
else:
|
||||
Logger(verbose=True).log("error", "Elasticsearch client is not initialized", "red")
|
||||
|
||||
except Exception as e:
|
||||
Logger(verbose=True).log("error", f"Save error: {e}", "red")
|
||||
raise Exception(f"Error during knowledge save: {str(e)}")
|
||||
|
||||
def _get_embedding_for_text(self, text: str) -> List[float]:
|
||||
"""Get embedding for text using the configured embedder."""
|
||||
if self.embedder_config is None:
|
||||
raise ValueError("Embedder configuration is not set")
|
||||
|
||||
embedder = self.embedder_config
|
||||
if hasattr(embedder, "embed_documents") and callable(getattr(embedder, "embed_documents")):
|
||||
embed_func = getattr(embedder, "embed_documents")
|
||||
return embed_func([text])[0]
|
||||
elif hasattr(embedder, "embed") and callable(getattr(embedder, "embed")):
|
||||
embed_func = getattr(embedder, "embed")
|
||||
return embed_func(text)
|
||||
else:
|
||||
raise ValueError("Invalid embedding function configuration")
|
||||
|
||||
def _create_default_embedding_function(self):
|
||||
from chromadb.utils.embedding_functions.openai_embedding_function import (
|
||||
OpenAIEmbeddingFunction,
|
||||
)
|
||||
|
||||
return OpenAIEmbeddingFunction(
|
||||
api_key=os.getenv("OPENAI_API_KEY"), model_name="text-embedding-3-small"
|
||||
)
|
||||
|
||||
def _set_embedder_config(
|
||||
self, embedder: Optional[Dict[str, Any]] = None
|
||||
) -> None:
|
||||
"""Set the embedding configuration for the knowledge storage.
|
||||
|
||||
Args:
|
||||
embedder (Optional[Dict[str, Any]]): Configuration dictionary for the embedder.
|
||||
If None or empty, defaults to the default embedding function.
|
||||
"""
|
||||
self.embedder_config = (
|
||||
EmbeddingConfigurator().configure_embedder(embedder)
|
||||
if embedder
|
||||
else self._create_default_embedding_function()
|
||||
)
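
The `search` method adds `1.0` to the cosine similarity inside the script (Elasticsearch does not allow `script_score` to return negative scores) and subtracts it again before comparing against `score_threshold`. The plain-Python sketch below reproduces that arithmetic without a cluster; `cosine_similarity` and `filter_hits` are hypothetical helpers written for this illustration.

```python
import math
from typing import List, Sequence


def cosine_similarity(a: Sequence[float], b: Sequence[float]) -> float:
    """Cosine of the angle between two non-zero vectors."""
    dot = sum(x * y for x, y in zip(a, b))
    norm_a = math.sqrt(sum(x * x for x in a))
    norm_b = math.sqrt(sum(y * y for y in b))
    return dot / (norm_a * norm_b)


def filter_hits(raw_scores: List[float], score_threshold: float = 0.35) -> List[float]:
    """Undo the +1.0 shift and keep scores at or above the threshold."""
    adjusted = [score - 1.0 for score in raw_scores]
    return [score for score in adjusted if score >= score_threshold]


query = [1.0, 0.0]
docs = [[1.0, 0.0], [0.0, 1.0], [-1.0, 0.0]]
# Shifted scores fall in [0, 2], as in the script_score source string.
raw = [cosine_similarity(query, d) + 1.0 for d in docs]
print(filter_hits(raw))  # → [1.0]
```

Only the document pointing in the same direction as the query survives the default threshold of `0.35`; the orthogonal and opposite vectors are dropped.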
|
||||
@@ -22,7 +22,9 @@ class EntityMemory(Memory):
|
||||
else:
|
||||
memory_provider = None
|
||||
|
||||
if memory_provider == "mem0":
|
||||
if storage:
|
||||
pass
|
||||
elif memory_provider == "mem0":
|
||||
try:
|
||||
from crewai.memory.storage.mem0_storage import Mem0Storage
|
||||
except ImportError:
|
||||
@@ -30,17 +32,26 @@ class EntityMemory(Memory):
|
||||
"Mem0 is not installed. Please install it with `pip install mem0ai`."
|
||||
)
|
||||
storage = Mem0Storage(type="entities", crew=crew)
|
||||
else:
|
||||
storage = (
|
||||
storage
|
||||
if storage
|
||||
else RAGStorage(
|
||||
elif memory_provider == "elasticsearch":
|
||||
try:
|
||||
storage = self._create_elasticsearch_storage(
|
||||
type="entities",
|
||||
allow_reset=True,
|
||||
embedder_config=embedder_config,
|
||||
crew=crew,
|
||||
path=path,
|
||||
)
|
||||
except ImportError:
|
||||
raise ImportError(
|
||||
"Elasticsearch is not installed. Please install it with `pip install elasticsearch`."
|
||||
)
|
||||
else:
|
||||
storage = RAGStorage(
|
||||
type="entities",
|
||||
allow_reset=True,
|
||||
embedder_config=embedder_config,
|
||||
crew=crew,
|
||||
path=path,
|
||||
)
|
||||
|
||||
super().__init__(storage=storage)
|
||||
@@ -59,6 +70,11 @@ class EntityMemory(Memory):
|
||||
data = f"{item.name}({item.type}): {item.description}"
|
||||
super().save(data, item.metadata)
|
||||
|
||||
def _create_elasticsearch_storage(self, **kwargs):
|
||||
"""Create an Elasticsearch storage instance."""
|
||||
from crewai.memory.storage.elasticsearch_storage import ElasticsearchStorage
|
||||
return ElasticsearchStorage(**kwargs)
|
||||
|
||||
def reset(self) -> None:
|
||||
try:
|
||||
self.storage.reset()
|
||||
|
||||
@@ -24,7 +24,9 @@ class ShortTermMemory(Memory):
|
||||
else:
|
||||
memory_provider = None
|
||||
|
||||
if memory_provider == "mem0":
|
||||
if storage:
|
||||
pass
|
||||
elif memory_provider == "mem0":
|
||||
try:
|
||||
from crewai.memory.storage.mem0_storage import Mem0Storage
|
||||
except ImportError:
|
||||
@@ -32,16 +34,24 @@ class ShortTermMemory(Memory):
|
||||
"Mem0 is not installed. Please install it with `pip install mem0ai`."
|
||||
)
|
||||
storage = Mem0Storage(type="short_term", crew=crew)
|
||||
else:
|
||||
storage = (
|
||||
storage
|
||||
if storage
|
||||
else RAGStorage(
|
||||
elif memory_provider == "elasticsearch":
|
||||
try:
|
||||
storage = self._create_elasticsearch_storage(
|
||||
type="short_term",
|
||||
embedder_config=embedder_config,
|
||||
crew=crew,
|
||||
path=path,
|
||||
)
|
||||
except ImportError:
|
||||
raise ImportError(
|
||||
"Elasticsearch is not installed. Please install it with `pip install elasticsearch`."
|
||||
)
|
||||
else:
|
||||
storage = RAGStorage(
|
||||
type="short_term",
|
||||
embedder_config=embedder_config,
|
||||
crew=crew,
|
||||
path=path,
|
||||
)
|
||||
super().__init__(storage=storage)
|
||||
self._memory_provider = memory_provider
|
||||
@@ -68,6 +78,11 @@ class ShortTermMemory(Memory):
|
||||
query=query, limit=limit, score_threshold=score_threshold
|
||||
) # type: ignore # BUG? The reference is to the parent class, but the parent class does not have this parameters
|
||||
|
||||
def _create_elasticsearch_storage(self, **kwargs):
|
||||
"""Create an Elasticsearch storage instance."""
|
||||
from crewai.memory.storage.elasticsearch_storage import ElasticsearchStorage
|
||||
return ElasticsearchStorage(**kwargs)
|
||||
|
||||
def reset(self) -> None:
|
||||
try:
|
||||
self.storage.reset()
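
The provider selection in `EntityMemory` and `ShortTermMemory` follows the same precedence: an explicitly passed storage wins, then `mem0`, then `elasticsearch`, with the default RAG storage as the fallback. A minimal sketch of that decision, assuming the provider is read from a `memory_config` dictionary (the part of the constructor that does this is not shown in the hunk) and using the hypothetical function name `choose_storage_kind`:

```python
from typing import Any, Dict, Optional


def choose_storage_kind(
    storage: Optional[Any] = None,
    memory_config: Optional[Dict[str, Any]] = None,
) -> str:
    """Return a label for the storage backend that would be selected."""
    provider = memory_config.get("provider") if memory_config else None
    if storage:
        # A caller-supplied storage instance always takes precedence.
        return "custom"
    if provider == "mem0":
        return "mem0"
    if provider == "elasticsearch":
        return "elasticsearch"
    # Any other value, including no configuration, falls back to RAG storage.
    return "rag"


print(choose_storage_kind(memory_config={"provider": "elasticsearch"}))
print(choose_storage_kind(storage=object(), memory_config={"provider": "mem0"}))
print(choose_storage_kind())
```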
|
||||
|
||||
275
src/crewai/memory/storage/elasticsearch_storage.py
Normal file
@@ -0,0 +1,275 @@
|
||||
import contextlib
|
||||
import io
|
||||
import logging
|
||||
import os
|
||||
import uuid
|
||||
from typing import Any, Dict, List, Optional, cast
|
||||
|
||||
from crewai.memory.storage.base_rag_storage import BaseRAGStorage
|
||||
from crewai.utilities import EmbeddingConfigurator
|
||||
from crewai.utilities.constants import MAX_FILE_NAME_LENGTH
|
||||
from crewai.utilities.logger import Logger
|
||||
from crewai.utilities.paths import db_storage_path
|
||||
|
||||
|
||||
@contextlib.contextmanager
|
||||
def suppress_logging(logger_name="elasticsearch", level=logging.ERROR):
    """Temporarily raise the logger level and silence stdout/stderr."""
    logger = logging.getLogger(logger_name)
    # Save the logger's own level (not the effective one) so that restoring
    # it does not pin an inherited level onto this logger.
    original_level = logger.level
    logger.setLevel(level)
    try:
        with (
            contextlib.redirect_stdout(io.StringIO()),
            contextlib.redirect_stderr(io.StringIO()),
            contextlib.suppress(UserWarning),
        ):
            yield
    finally:
        # Restore the level even if the wrapped block raises.
        logger.setLevel(original_level)
|
||||
|
||||
|
||||
class ElasticsearchStorage(BaseRAGStorage):
|
||||
"""
|
||||
Extends BaseRAGStorage to use Elasticsearch for storing embeddings
|
||||
and improving search efficiency.
|
||||
"""
|
||||
|
||||
app: Any = None
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
type: str,
|
||||
allow_reset: bool = True,
|
||||
embedder_config: Any = None,
|
||||
crew: Any = None,
|
||||
path: Optional[str] = None,
|
||||
host: str = "localhost",
|
||||
port: int = 9200,
|
||||
username: Optional[str] = None,
|
||||
password: Optional[str] = None,
|
||||
**kwargs: Any
|
||||
):
|
||||
super().__init__(type, allow_reset, embedder_config, crew)
|
||||
agents = crew.agents if crew else []
|
||||
agents = [self._sanitize_role(agent.role) for agent in agents]
|
||||
agents = "_".join(agents)
|
||||
self.agents = agents
|
||||
self.storage_file_name = self._build_storage_file_name(type, agents)
|
||||
|
||||
self.type = type
|
||||
self.allow_reset = allow_reset
|
||||
self.path = path
|
||||
|
||||
self.host = host
|
||||
self.port = port
|
||||
self.username = username
|
||||
self.password = password
|
||||
self.index_name = f"crewai_{type}".lower()
|
||||
self.additional_config = kwargs
|
||||
|
||||
self._initialize_app()
|
||||
|
||||
def _sanitize_role(self, role: str) -> str:
|
||||
"""
|
||||
Sanitizes agent roles to ensure valid directory and index names.
|
||||
"""
|
||||
return role.replace("\n", "").replace(" ", "_").replace("/", "_")
|
||||
|
||||
def _build_storage_file_name(self, type: str, file_name: str) -> str:
|
||||
"""
|
||||
Ensures file name does not exceed max allowed by OS
|
||||
"""
|
||||
base_path = f"{db_storage_path()}/{type}"
|
||||
|
||||
if len(file_name) > MAX_FILE_NAME_LENGTH:
|
||||
logging.warning(
|
||||
f"Trimming file name from {len(file_name)} to {MAX_FILE_NAME_LENGTH} characters."
|
||||
)
|
||||
file_name = file_name[:MAX_FILE_NAME_LENGTH]
|
||||
|
||||
return f"{base_path}/{file_name}"
|
||||
|
||||
def _set_embedder_config(self):
|
||||
configurator = EmbeddingConfigurator()
|
||||
self.embedder_config = configurator.configure_embedder(self.embedder_config)
|
||||
|
||||
def _initialize_app(self):
|
||||
try:
|
||||
from elasticsearch import Elasticsearch
|
||||
|
||||
self._set_embedder_config()
|
||||
|
||||
es_auth = {}
|
||||
if self.username and self.password:
|
||||
es_auth = {"basic_auth": (self.username, self.password)}
|
||||
|
||||
self.app = Elasticsearch(
|
||||
[f"http://{self.host}:{self.port}"],
|
||||
**es_auth,
|
||||
**self.additional_config
|
||||
)
|
||||
|
||||
if not self.app.indices.exists(index=self.index_name):
|
||||
self.app.indices.create(
|
||||
index=self.index_name,
|
||||
body={
|
||||
"mappings": {
|
||||
"properties": {
|
||||
"text": {"type": "text"},
|
||||
"embedding": {
|
||||
"type": "dense_vector",
|
||||
"dims": 1536, # Default for OpenAI embeddings
|
||||
"index": True,
|
||||
"similarity": "cosine"
|
||||
},
|
||||
"metadata": {"type": "object"}
|
||||
}
|
||||
}
|
||||
}
|
||||
)
|
||||
|
||||
except ImportError:
|
||||
raise ImportError(
|
||||
"Elasticsearch is not installed. Please install it with `pip install elasticsearch`."
|
||||
)
|
||||
except Exception as e:
|
||||
Logger(verbose=True).log(
|
||||
"error",
|
||||
f"Error initializing Elasticsearch: {str(e)}",
|
||||
"red"
|
||||
)
|
||||
raise Exception(f"Error initializing Elasticsearch: {str(e)}")
|
||||
|
||||
def save(self, value: Any, metadata: Dict[str, Any]) -> None:
|
||||
        if not hasattr(self, "app") or self.app is None:
|
||||
self._initialize_app()
|
||||
|
||||
try:
|
||||
self._generate_embedding(value, metadata)
|
||||
except Exception as e:
|
||||
logging.error(f"Error during {self.type} save: {str(e)}")
|
||||
|
||||
def search(
|
||||
self,
|
||||
query: str,
|
||||
limit: int = 3,
|
||||
filter: Optional[dict] = None,
|
||||
score_threshold: float = 0.35,
|
||||
) -> List[Any]:
|
||||
if not hasattr(self, "app") or self.app is None:
|
||||
self._initialize_app()
|
||||
|
||||
try:
|
||||
embedding = self._get_embedding_for_text(query)
|
||||
|
||||
search_query: Dict[str, Any] = {
|
||||
"size": limit,
|
||||
"query": {
|
||||
"script_score": {
|
||||
"query": {"match_all": {}},
|
||||
"script": {
|
||||
"source": "cosineSimilarity(params.query_vector, 'embedding') + 1.0",
|
||||
"params": {"query_vector": embedding}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
            if filter:
                # Combine every filter key into a single bool query so that
                # all conditions apply, not only the last one in the loop.
                must_clauses: List[Dict[str, Any]] = [{"match_all": {}}]
                for key, value in filter.items():
                    must_clauses.append({"match": {f"metadata.{key}": value}})
                search_query["query"]["script_score"]["query"] = {
                    "bool": {"must": must_clauses}
                }
|
||||
|
||||
with suppress_logging():
|
||||
if self.app is not None and hasattr(self.app, "search") and callable(getattr(self.app, "search")):
|
||||
search_func = getattr(self.app, "search")
|
||||
response = search_func(
|
||||
index=self.index_name,
|
||||
body=search_query
|
||||
)
|
||||
|
||||
results = []
|
||||
for hit in response["hits"]["hits"]:
|
||||
adjusted_score = (hit["_score"] - 1.0)
|
||||
|
||||
if adjusted_score >= score_threshold:
|
||||
                        results.append({
                            "id": hit["_id"],
                            "metadata": hit["_source"]["metadata"],
                            "context": hit["_source"]["text"],
                            "score": adjusted_score,
                        })

                return results
            else:
                logging.error("Elasticsearch client is not initialized")
                return []
        except Exception as e:
            logging.error(f"Error during {self.type} search: {str(e)}")
            return []

    def _get_embedding_for_text(self, text: str) -> List[float]:
        """Get embedding for text using the configured embedder."""
        if self.embedder_config is None:
            raise ValueError("Embedder configuration is not set")

        embedder = self.embedder_config
        if hasattr(embedder, "embed_documents") and callable(getattr(embedder, "embed_documents")):
            embed_func = getattr(embedder, "embed_documents")
            return embed_func([text])[0]
        elif hasattr(embedder, "embed") and callable(getattr(embedder, "embed")):
            embed_func = getattr(embedder, "embed")
            return embed_func(text)
        else:
            raise ValueError("Invalid embedding function configuration")

    def _generate_embedding(self, text: str, metadata: Optional[Dict[str, Any]] = None) -> Any:
        """Generate embedding for text and save to Elasticsearch.

        This method overrides the BaseRAGStorage method to use Elasticsearch.
        """
        if not hasattr(self, "app") or self.app is None:
            self._initialize_app()

        embedding = self._get_embedding_for_text(text)

        doc = {
            "text": text,
            "embedding": embedding,
            "metadata": metadata or {},
        }

        if self.app is not None and hasattr(self.app, "index") and callable(getattr(self.app, "index")):
            index_func = getattr(self.app, "index")
            result = index_func(
                index=self.index_name,
                id=str(uuid.uuid4()),
                document=doc,
                refresh=True  # Make the document immediately available for search
            )
            return result
        return None

    def reset(self) -> None:
        try:
            if self.app is not None:
                if self.app.indices.exists(index=self.index_name):
                    self.app.indices.delete(index=self.index_name)

                self._initialize_app()
        except Exception as e:
            raise Exception(
                f"An error occurred while resetting the {self.type} memory: {e}"
            )
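Reviewer note: the dispatch in `_get_embedding_for_text` can be exercised in isolation. The sketch below is only an illustration under stated assumptions: `BatchEmbedder`, `SingleEmbedder`, and `get_embedding` are hypothetical names, not crewAI classes, and the fake embeddings are just text lengths.

```python
from typing import Any, List


class BatchEmbedder:
    """Hypothetical embedder exposing a batch-style `embed_documents`."""

    def embed_documents(self, texts: List[str]) -> List[List[float]]:
        # Fake embedding: one float per text (its length).
        return [[float(len(t))] for t in texts]


class SingleEmbedder:
    """Hypothetical embedder exposing a single-text `embed`."""

    def embed(self, text: str) -> List[float]:
        return [float(len(text))]


def get_embedding(embedder: Any, text: str) -> List[float]:
    """Prefer `embed_documents` (batch API); fall back to `embed`."""
    batch = getattr(embedder, "embed_documents", None)
    if callable(batch):
        # Wrap the text in a list and unwrap the single result.
        return batch([text])[0]
    single = getattr(embedder, "embed", None)
    if callable(single):
        return single(text)
    raise ValueError("Invalid embedding function configuration")
```

Both hypothetical embedders go through the same `get_embedding` call; an object with neither method raises `ValueError`, mirroring the `else` branch in the diff.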
81
src/crewai/memory/storage/storage_factory.py
Normal file
@@ -0,0 +1,81 @@
from typing import Any, Dict, Optional, Type, cast

from crewai.memory.storage.base_rag_storage import BaseRAGStorage
from crewai.memory.storage.rag_storage import RAGStorage
from crewai.utilities.logger import Logger

try:
    from crewai.memory.storage.elasticsearch_storage import ElasticsearchStorage
except ImportError:
    ElasticsearchStorage = None

try:
    from crewai.memory.storage.mem0_storage import Mem0Storage
except ImportError:
    Mem0Storage = None


class StorageFactory:
    """Factory for creating storage instances based on provider type."""

    @classmethod
    def create_storage(
        cls,
        provider: str,
        type: str,
        allow_reset: bool = True,
        embedder_config: Optional[Any] = None,
        crew: Any = None,
        path: Optional[str] = None,
        **kwargs,
    ) -> BaseRAGStorage:
        """Create a storage instance based on the provider type.

        Args:
            provider: Type of storage provider ("chromadb", "elasticsearch", "mem0").
            type: Type of memory storage (e.g., "short_term", "entity").
            allow_reset: Whether to allow resetting the storage.
            embedder_config: Configuration for the embedder.
            crew: Crew instance.
            path: Path to the storage.
            **kwargs: Additional arguments.

        Returns:
            Storage instance.
        """
        if provider == "elasticsearch":
            if ElasticsearchStorage is None:
                Logger(verbose=True).log(
                    "error",
                    "Elasticsearch is not installed. Please install it with `pip install elasticsearch`.",
                    "red",
                )
                raise ImportError(
                    "Elasticsearch is not installed. Please install it with `pip install elasticsearch`."
                )
            return ElasticsearchStorage(
                type=type,
                allow_reset=allow_reset,
                embedder_config=embedder_config,
                crew=crew,
                path=path,
                **kwargs,
            )
        elif provider == "mem0":
            if Mem0Storage is None:
                Logger(verbose=True).log(
                    "error",
                    "Mem0 is not installed. Please install it with `pip install mem0ai`.",
                    "red",
                )
                raise ImportError(
                    "Mem0 is not installed. Please install it with `pip install mem0ai`."
                )
            return cast(BaseRAGStorage, Mem0Storage(type=type, crew=crew))
        return RAGStorage(
            type=type,
            allow_reset=allow_reset,
            embedder_config=embedder_config,
            crew=crew,
            path=path,
        )
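Reviewer note: the optional-import pattern used by `StorageFactory` can be reduced to a small standalone sketch. Everything named here is hypothetical (`DefaultStorage`, the module `_hypothetical_missing_backend_`, `OptionalStorage`, and the free function `create_storage`); it only illustrates the "import or set to `None`, then fail at creation time" idea and is not the crewAI implementation.

```python
from typing import Any


class DefaultStorage:
    """Hypothetical stand-in for the default backend."""

    def __init__(self, type: str) -> None:
        self.type = type


try:
    # Hypothetical optional dependency; assumed NOT to be installed.
    from _hypothetical_missing_backend_ import OptionalStorage  # type: ignore
except ImportError:
    # Remember that the backend is unavailable instead of failing at import time.
    OptionalStorage = None


def create_storage(provider: str, type: str) -> Any:
    """Return a storage for `provider`, falling back to the default backend."""
    if provider == "optional":
        if OptionalStorage is None:
            # Fail only when the missing backend is actually requested.
            raise ImportError("The optional backend is not installed.")
        return OptionalStorage(type=type)
    return DefaultStorage(type=type)
```

With this shape, importing the factory module never fails because of a missing extra; the error surfaces only when a caller asks for that provider, and any unrecognized provider string falls through to the default backend.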
90
tests/integration/elasticsearch_integration_test.py
Normal file
@@ -0,0 +1,90 @@
"""Integration test for Elasticsearch with CrewAI."""

import os
import unittest

import pytest

from crewai import Agent, Crew, Task
from crewai.knowledge import Knowledge
from crewai.knowledge.source.string_knowledge_source import StringKnowledgeSource


@pytest.mark.skipif(
    os.environ.get("RUN_ELASTICSEARCH_TESTS") != "true",
    reason="Elasticsearch tests require RUN_ELASTICSEARCH_TESTS=true"
)
class TestElasticsearchIntegration(unittest.TestCase):
    """Integration test for Elasticsearch with CrewAI."""

    def test_crew_with_elasticsearch_memory(self):
        """Test a crew with Elasticsearch memory."""
        researcher = Agent(
            role="Researcher",
            goal="Research a topic",
            backstory="You are a researcher who loves to find information.",
        )

        writer = Agent(
            role="Writer",
            goal="Write about a topic",
            backstory="You are a writer who loves to write about topics.",
        )

        research_task = Task(
            description="Research about AI",
            expected_output="Information about AI",
            agent=researcher,
        )

        write_task = Task(
            description="Write about AI",
            expected_output="Article about AI",
            agent=writer,
            context=[research_task],
        )

        crew = Crew(
            agents=[researcher, writer],
            tasks=[research_task, write_task],
            memory_config={"provider": "elasticsearch"},
        )

        result = crew.kickoff()

        self.assertIsNotNone(result)

    def test_crew_with_elasticsearch_knowledge(self):
        """Test a crew with Elasticsearch knowledge."""
        content = "AI is a field of computer science that focuses on creating machines that can perform tasks that typically require human intelligence."
        string_source = StringKnowledgeSource(
            content=content, metadata={"topic": "AI"}
        )

        knowledge = Knowledge(
            collection_name="test",
            sources=[string_source],
            storage_provider="elasticsearch",
        )

        agent = Agent(
            role="AI Expert",
            goal="Explain AI",
            backstory="You are an AI expert who loves to explain AI concepts.",
            knowledge=[knowledge],
        )

        task = Task(
            description="Explain what AI is",
            expected_output="Explanation of AI",
            agent=agent,
        )

        crew = Crew(
            agents=[agent],
            tasks=[task],
        )

        result = crew.kickoff()

        self.assertIsNotNone(result)
94
tests/knowledge/elasticsearch_knowledge_storage_test.py
Normal file
@@ -0,0 +1,94 @@
"""Test Elasticsearch knowledge storage functionality."""

import os
import unittest
from unittest.mock import MagicMock, patch

import pytest

from crewai.knowledge.storage.elasticsearch_knowledge_storage import (
    ElasticsearchKnowledgeStorage,
)


@pytest.mark.skipif(
    os.environ.get("RUN_ELASTICSEARCH_TESTS") != "true",
    reason="Elasticsearch tests require RUN_ELASTICSEARCH_TESTS=true"
)
class TestElasticsearchKnowledgeStorage(unittest.TestCase):
    """Test Elasticsearch knowledge storage functionality."""

    def setUp(self):
        """Set up test fixtures."""
        self.es_mock = MagicMock()
        self.es_mock.indices.exists.return_value = False

        self.embedder_mock = MagicMock()
        self.embedder_mock.embed_documents.return_value = [[0.1, 0.2, 0.3]]

        self.es_patcher = patch(
            "crewai.knowledge.storage.elasticsearch_knowledge_storage.Elasticsearch",
            return_value=self.es_mock
        )
        self.es_class_mock = self.es_patcher.start()

        self.storage = ElasticsearchKnowledgeStorage(
            embedder_config=self.embedder_mock,
            collection_name="test"
        )
        self.storage.initialize_knowledge_storage()

    def tearDown(self):
        """Tear down test fixtures."""
        self.es_patcher.stop()

    def test_initialization(self):
        """Test initialization of Elasticsearch knowledge storage."""
        self.es_class_mock.assert_called_once()

        self.es_mock.indices.create.assert_called_once()

    def test_save(self):
        """Test saving to Elasticsearch knowledge storage."""
        self.storage.save(["Test document 1", "Test document 2"], {"source": "test"})

        self.assertEqual(self.es_mock.index.call_count, 2)

        self.assertEqual(self.embedder_mock.embed_documents.call_count, 2)

    def test_search(self):
        """Test searching in Elasticsearch knowledge storage."""
        self.es_mock.search.return_value = {
            "hits": {
                "hits": [
                    {
                        "_id": "test_id",
                        "_score": 1.5,  # Score between 1-2 (Elasticsearch range)
                        "_source": {
                            "text": "Test document",
                            "metadata": {"source": "test"},
                        }
                    }
                ]
            }
        }

        results = self.storage.search(["test query"])

        self.es_mock.search.assert_called_once()

        self.assertEqual(len(results), 1)
        self.assertEqual(results[0]["id"], "test_id")
        self.assertEqual(results[0]["context"], "Test document")
        self.assertEqual(results[0]["metadata"], {"source": "test"})
        self.assertEqual(results[0]["score"], 0.5)  # Adjusted to 0-1 range

    def test_reset(self):
        """Test resetting Elasticsearch knowledge storage."""
        self.es_mock.indices.exists.return_value = True

        self.storage.reset()

        self.es_mock.indices.delete.assert_called_once()

        self.assertEqual(self.es_mock.indices.create.call_count, 2)
91
tests/memory/elasticsearch_storage_test.py
Normal file
@@ -0,0 +1,91 @@
"""Test Elasticsearch storage functionality."""

import os
import unittest
from unittest.mock import MagicMock, patch

import pytest

from crewai.memory.storage.elasticsearch_storage import ElasticsearchStorage


@pytest.mark.skipif(
    os.environ.get("RUN_ELASTICSEARCH_TESTS") != "true",
    reason="Elasticsearch tests require RUN_ELASTICSEARCH_TESTS=true"
)
class TestElasticsearchStorage(unittest.TestCase):
    """Test Elasticsearch storage functionality."""

    def setUp(self):
        """Set up test fixtures."""
        self.es_mock = MagicMock()
        self.es_mock.indices.exists.return_value = False

        self.embedder_mock = MagicMock()
        self.embedder_mock.embed_documents.return_value = [[0.1, 0.2, 0.3]]

        self.es_patcher = patch(
            "crewai.memory.storage.elasticsearch_storage.Elasticsearch",
            return_value=self.es_mock
        )
        self.es_class_mock = self.es_patcher.start()

        self.storage = ElasticsearchStorage(
            type="test",
            embedder_config=self.embedder_mock
        )

    def tearDown(self):
        """Tear down test fixtures."""
        self.es_patcher.stop()

    def test_initialization(self):
        """Test initialization of Elasticsearch storage."""
        self.es_class_mock.assert_called_once()

        self.es_mock.indices.create.assert_called_once()

    def test_save(self):
        """Test saving to Elasticsearch storage."""
        self.storage.save("Test document", {"source": "test"})

        self.es_mock.index.assert_called_once()

        self.embedder_mock.embed_documents.assert_called_once_with(["Test document"])

    def test_search(self):
        """Test searching in Elasticsearch storage."""
        self.es_mock.search.return_value = {
            "hits": {
                "hits": [
                    {
                        "_id": "test_id",
                        "_score": 1.5,  # Score between 1-2 (Elasticsearch range)
                        "_source": {
                            "text": "Test document",
                            "metadata": {"source": "test"},
                        }
                    }
                ]
            }
        }

        results = self.storage.search("test query")

        self.es_mock.search.assert_called_once()

        self.assertEqual(len(results), 1)
        self.assertEqual(results[0]["id"], "test_id")
        self.assertEqual(results[0]["context"], "Test document")
        self.assertEqual(results[0]["metadata"], {"source": "test"})
        self.assertEqual(results[0]["score"], 0.5)  # Adjusted to 0-1 range

    def test_reset(self):
        """Test resetting Elasticsearch storage."""
        self.es_mock.indices.exists.return_value = True

        self.storage.reset()

        self.es_mock.indices.delete.assert_called_once()

        self.assertEqual(self.es_mock.indices.create.call_count, 2)
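Reviewer note: both `test_search` cases expect a raw `_score` of 1.5 to come back as 0.5. A minimal sketch of a mapping consistent with those assertions is shown below. It assumes the raw score is shifted by exactly 1.0 and that hits under a threshold are dropped; the helper name `normalize_hits` and the 0.35 default are assumptions for illustration, since the actual scoring code is not visible in this part of the diff.

```python
from typing import Any, Dict, List


def normalize_hits(
    response: Dict[str, Any], score_threshold: float = 0.35
) -> List[Dict[str, Any]]:
    """Map an Elasticsearch-style response to the result dicts used in the tests."""
    results: List[Dict[str, Any]] = []
    for hit in response["hits"]["hits"]:
        # Assumption: raw scores lie in [1, 2]; shift them into [0, 1].
        adjusted_score = hit["_score"] - 1.0
        if adjusted_score >= score_threshold:
            results.append({
                "id": hit["_id"],
                "metadata": hit["_source"]["metadata"],
                "context": hit["_source"]["text"],
                "score": adjusted_score,
            })
    return results


# Response shaped like the mocked `search` return value in the tests.
sample_response = {
    "hits": {
        "hits": [
            {
                "_id": "test_id",
                "_score": 1.5,
                "_source": {"text": "Test document", "metadata": {"source": "test"}},
            },
            {
                "_id": "weak_id",
                "_score": 1.2,
                "_source": {"text": "Weak match", "metadata": {}},
            },
        ]
    }
}
normalized = normalize_hits(sample_response)
```

Under these assumptions only the first hit survives, with an adjusted score of 0.5; the second hit falls below the threshold and is filtered out.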
33
uv.lock
generated
@@ -1,5 +1,4 @@
version = 1
revision = 1
requires-python = ">=3.10, <3.13"
resolution-markers = [
    "python_full_version < '3.11' and platform_python_implementation == 'PyPy' and sys_platform == 'darwin'",
@@ -628,6 +627,7 @@ dependencies = [
    { name = "blinker" },
    { name = "chromadb" },
    { name = "click" },
    { name = "elasticsearch" },
    { name = "instructor" },
    { name = "json-repair" },
    { name = "json5" },
@@ -710,6 +710,7 @@ requires-dist = [
    { name = "click", specifier = ">=8.1.7" },
    { name = "crewai-tools", marker = "extra == 'tools'", specifier = "~=0.40.1" },
    { name = "docling", marker = "extra == 'docling'", specifier = ">=2.12.0" },
    { name = "elasticsearch", specifier = ">=9.0.0" },
    { name = "fastembed", marker = "extra == 'fastembed'", specifier = ">=0.4.1" },
    { name = "instructor", specifier = ">=1.3.3" },
    { name = "json-repair", specifier = ">=0.25.2" },
@@ -735,7 +736,6 @@ requires-dist = [
    { name = "tomli-w", specifier = ">=1.1.0" },
    { name = "uv", specifier = ">=0.4.25" },
]
provides-extras = ["tools", "embeddings", "agentops", "fastembed", "pdfplumber", "pandas", "openpyxl", "mem0", "docling", "aisuite"]

[package.metadata.requires-dev]
dev = [
@@ -1097,6 +1097,33 @@ wheels = [
    { url = "https://files.pythonhosted.org/packages/bb/84/4a2cab0e6adde6a85e7ba543862e5fc0250c51f3ac721a078a55cdcff250/easyocr-1.7.2-py3-none-any.whl", hash = "sha256:5be12f9b0e595d443c9c3d10b0542074b50f0ec2d98b141a109cd961fd1c177c", size = 2870178 },
]

[[package]]
name = "elastic-transport"
version = "8.17.1"
source = { registry = "https://pypi.org/simple" }
dependencies = [
    { name = "certifi" },
    { name = "urllib3" },
]
sdist = { url = "https://files.pythonhosted.org/packages/6a/54/d498a766ac8fa475f931da85a154666cc81a70f8eb4a780bc8e4e934e9ac/elastic_transport-8.17.1.tar.gz", hash = "sha256:5edef32ac864dca8e2f0a613ef63491ee8d6b8cfb52881fa7313ba9290cac6d2", size = 73425 }
wheels = [
    { url = "https://files.pythonhosted.org/packages/cf/cd/b71d5bc74cde7fc6fd9b2ff9389890f45d9762cbbbf81dc5e51fd7588c4a/elastic_transport-8.17.1-py3-none-any.whl", hash = "sha256:192718f498f1d10c5e9aa8b9cf32aed405e469a7f0e9d6a8923431dbb2c59fb8", size = 64969 },
]

[[package]]
name = "elasticsearch"
version = "9.0.0"
source = { registry = "https://pypi.org/simple" }
dependencies = [
    { name = "elastic-transport" },
    { name = "python-dateutil" },
    { name = "typing-extensions" },
]
sdist = { url = "https://files.pythonhosted.org/packages/26/ad/d76e88811e68d7bdd976c0ff6027e7c3b544a949c8d3de052adc5765e1a6/elasticsearch-9.0.0.tar.gz", hash = "sha256:c075ccdc7d5697e2a842a88418efdb6cf6732d7a62c77a25d60184db23fd1464", size = 823636 }
wheels = [
    { url = "https://files.pythonhosted.org/packages/4b/b7/e85bdb8bed719dbf92780264d6c381186ced8eb7acc88bbe37a996f87b03/elasticsearch-9.0.0-py3-none-any.whl", hash = "sha256:295425172043e5db723d55cb3a5e28622696ca7739b466b812ab12ac938b6249", size = 895793 },
]

[[package]]
name = "embedchain"
version = "0.1.125"
@@ -2988,6 +3015,7 @@ name = "nvidia-nccl-cu12"
version = "2.20.5"
source = { registry = "https://pypi.org/simple" }
wheels = [
    { url = "https://files.pythonhosted.org/packages/c1/bb/d09dda47c881f9ff504afd6f9ca4f502ded6d8fc2f572cacc5e39da91c28/nvidia_nccl_cu12-2.20.5-py3-none-manylinux2014_aarch64.whl", hash = "sha256:1fc150d5c3250b170b29410ba682384b14581db722b2531b0d8d33c595f33d01", size = 176238458 },
    { url = "https://files.pythonhosted.org/packages/4b/2a/0a131f572aa09f741c30ccd45a8e56316e8be8dfc7bc19bf0ab7cfef7b19/nvidia_nccl_cu12-2.20.5-py3-none-manylinux2014_x86_64.whl", hash = "sha256:057f6bf9685f75215d0c53bf3ac4a10b3e6578351de307abad9e18a99182af56", size = 176249402 },
]

@@ -2997,6 +3025,7 @@ version = "12.6.85"
source = { registry = "https://pypi.org/simple" }
wheels = [
    { url = "https://files.pythonhosted.org/packages/9d/d7/c5383e47c7e9bf1c99d5bd2a8c935af2b6d705ad831a7ec5c97db4d82f4f/nvidia_nvjitlink_cu12-12.6.85-py3-none-manylinux2010_x86_64.manylinux_2_12_x86_64.whl", hash = "sha256:eedc36df9e88b682efe4309aa16b5b4e78c2407eac59e8c10a6a47535164369a", size = 19744971 },
    { url = "https://files.pythonhosted.org/packages/31/db/dc71113d441f208cdfe7ae10d4983884e13f464a6252450693365e166dcf/nvidia_nvjitlink_cu12-12.6.85-py3-none-manylinux2014_aarch64.manylinux_2_17_aarch64.whl", hash = "sha256:cf4eaa7d4b6b543ffd69d6abfb11efdeb2db48270d94dfd3a452c24150829e41", size = 19270338 },
]

[[package]]