mirror of
https://github.com/crewAIInc/crewAI.git
synced 2026-01-06 22:58:30 +00:00
Compare commits
4 Commits
lg-update-
...
devin/1745
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
f4e61ae714 | ||
|
|
958751fe36 | ||
|
|
3c838f16ff | ||
|
|
6c08e6062a |
117
docs/how-to/elasticsearch-integration.md
Normal file
117
docs/how-to/elasticsearch-integration.md
Normal file
@@ -0,0 +1,117 @@
|
||||
# Elasticsearch Integration
|
||||
|
||||
CrewAI supports using Elasticsearch as an alternative to ChromaDB for RAG (Retrieval Augmented Generation) storage. This allows you to leverage Elasticsearch's powerful search capabilities and scalability for your AI agents.
|
||||
|
||||
## Installation
|
||||
|
||||
To use Elasticsearch with CrewAI, you need to install the Elasticsearch Python client:
|
||||
|
||||
```bash
|
||||
pip install elasticsearch
|
||||
```
|
||||
|
||||
## Using Elasticsearch for Memory
|
||||
|
||||
You can configure your crew to use Elasticsearch for memory storage:
|
||||
|
||||
```python
|
||||
from crewai import Agent, Crew, Task
|
||||
|
||||
# Create agents and tasks
|
||||
agent = Agent(
|
||||
role="Researcher",
|
||||
goal="Research a topic",
|
||||
backstory="You are a researcher who loves to find information.",
|
||||
)
|
||||
|
||||
task = Task(
|
||||
description="Research about AI",
|
||||
expected_output="Information about AI",
|
||||
agent=agent,
|
||||
)
|
||||
|
||||
# Create a crew with Elasticsearch memory
|
||||
crew = Crew(
|
||||
agents=[agent],
|
||||
tasks=[task],
|
||||
memory_config={
|
||||
"provider": "elasticsearch",
|
||||
"host": "localhost", # Optional, defaults to localhost
|
||||
"port": 9200, # Optional, defaults to 9200
|
||||
"username": "user", # Optional
|
||||
"password": "pass", # Optional
|
||||
},
|
||||
)
|
||||
|
||||
# Execute the crew
|
||||
result = crew.kickoff()
|
||||
```
|
||||
|
||||
## Using Elasticsearch for Knowledge
|
||||
|
||||
You can also use Elasticsearch for knowledge storage:
|
||||
|
||||
```python
|
||||
from crewai import Agent, Crew, Task
|
||||
from crewai.knowledge import Knowledge
|
||||
from crewai.knowledge.source.string_knowledge_source import StringKnowledgeSource
|
||||
|
||||
# Create knowledge with Elasticsearch storage
|
||||
content = "AI is a field of computer science that focuses on creating machines that can perform tasks that typically require human intelligence."
|
||||
string_source = StringKnowledgeSource(
|
||||
content=content, metadata={"topic": "AI"}
|
||||
)
|
||||
|
||||
knowledge = Knowledge(
|
||||
collection_name="test",
|
||||
sources=[string_source],
|
||||
storage_provider="elasticsearch", # Use Elasticsearch
|
||||
# Optional Elasticsearch configuration
|
||||
host="localhost",
|
||||
port=9200,
|
||||
username="user",
|
||||
password="pass",
|
||||
)
|
||||
|
||||
# Create an agent with the knowledge
|
||||
agent = Agent(
|
||||
role="AI Expert",
|
||||
goal="Explain AI",
|
||||
backstory="You are an AI expert who loves to explain AI concepts.",
|
||||
knowledge=[knowledge],
|
||||
)
|
||||
|
||||
# Create a task
|
||||
task = Task(
|
||||
description="Explain what AI is",
|
||||
expected_output="Explanation of AI",
|
||||
agent=agent,
|
||||
)
|
||||
|
||||
# Create a crew
|
||||
crew = Crew(
|
||||
agents=[agent],
|
||||
tasks=[task],
|
||||
)
|
||||
|
||||
# Execute the crew
|
||||
result = crew.kickoff()
|
||||
```
|
||||
|
||||
## Configuration Options
|
||||
|
||||
The Elasticsearch integration supports the following configuration options:
|
||||
|
||||
- `host`: Elasticsearch host (default: "localhost")
|
||||
- `port`: Elasticsearch port (default: 9200)
|
||||
- `username`: Elasticsearch username (optional)
|
||||
- `password`: Elasticsearch password (optional)
|
||||
- Additional keyword arguments are passed directly to the Elasticsearch client
|
||||
|
||||
## Running Tests
|
||||
|
||||
To run the Elasticsearch tests, you need to set the `RUN_ELASTICSEARCH_TESTS` environment variable to `true`:
|
||||
|
||||
```bash
|
||||
RUN_ELASTICSEARCH_TESTS=true pytest tests/memory/elasticsearch_storage_test.py tests/knowledge/elasticsearch_knowledge_storage_test.py tests/integration/elasticsearch_integration_test.py
|
||||
```
|
||||
@@ -8,29 +8,11 @@ icon: code-simple
|
||||
|
||||
## Description
|
||||
|
||||
The `CodeInterpreterTool` enables CrewAI agents to execute Python 3 code that they generate autonomously. This functionality is particularly valuable as it allows agents to create code, execute it, obtain the results, and utilize that information to inform subsequent decisions and actions.
|
||||
The `CodeInterpreterTool` enables CrewAI agents to execute Python 3 code that they generate autonomously. The code is run in a secure, isolated Docker container, ensuring safety regardless of the content. This functionality is particularly valuable as it allows agents to create code, execute it, obtain the results, and utilize that information to inform subsequent decisions and actions.
|
||||
|
||||
There are several ways to use this tool:
|
||||
|
||||
### Docker Container (Recommended)
|
||||
|
||||
This is the primary option. The code runs in a secure, isolated Docker container, ensuring safety regardless of its content.
|
||||
Make sure Docker is installed and running on your system. If you don’t have it, you can install it from [here](https://docs.docker.com/get-docker/).
|
||||
|
||||
### Sandbox environment
|
||||
|
||||
If Docker is unavailable — either not installed or not accessible for any reason — the code will be executed in a restricted Python environment - called sandbox.
|
||||
This environment is very limited, with strict restrictions on many modules and built-in functions.
|
||||
|
||||
### Unsafe Execution
|
||||
|
||||
**NOT RECOMMENDED FOR PRODUCTION**
|
||||
This mode allows execution of any Python code, including dangerous calls to `sys, os..` and similar modules. [Check out](/tools/codeinterpretertool#enabling-unsafe-mode) how to enable this mode
|
||||
|
||||
## Logging
|
||||
|
||||
The `CodeInterpreterTool` logs the selected execution strategy to STDOUT
|
||||
## Requirements
|
||||
|
||||
- Docker must be installed and running on your system. If you don't have it, you can install it from [here](https://docs.docker.com/get-docker/).
|
||||
|
||||
## Installation
|
||||
|
||||
@@ -92,32 +74,18 @@ programmer_agent = Agent(
|
||||
)
|
||||
```
|
||||
|
||||
### Enabling `unsafe_mode`
|
||||
|
||||
```python Code
|
||||
from crewai_tools import CodeInterpreterTool
|
||||
|
||||
code = """
|
||||
import os
|
||||
os.system("ls -la")
|
||||
"""
|
||||
|
||||
CodeInterpreterTool(unsafe_mode=True).run(code=code)
|
||||
```
|
||||
|
||||
## Parameters
|
||||
|
||||
The `CodeInterpreterTool` accepts the following parameters during initialization:
|
||||
|
||||
- **user_dockerfile_path**: Optional. Path to a custom Dockerfile to use for the code interpreter container.
|
||||
- **user_docker_base_url**: Optional. URL to the Docker daemon to use for running the container.
|
||||
- **unsafe_mode**: Optional. Whether to run code directly on the host machine instead of in a Docker container or sandbox. Default is `False`. Use with caution!
|
||||
- **default_image_tag**: Optional. Default Docker image tag. Default is `code-interpreter:latest`
|
||||
- **unsafe_mode**: Optional. Whether to run code directly on the host machine instead of in a Docker container. Default is `False`. Use with caution!
|
||||
|
||||
When using the tool with an agent, the agent will need to provide:
|
||||
|
||||
- **code**: Required. The Python 3 code to execute.
|
||||
- **libraries_used**: Optional. A list of libraries used in the code that need to be installed. Default is `[]`
|
||||
- **libraries_used**: Required. A list of libraries used in the code that need to be installed.
|
||||
|
||||
## Agent Integration Example
|
||||
|
||||
@@ -184,7 +152,7 @@ class CodeInterpreterTool(BaseTool):
|
||||
if self.unsafe_mode:
|
||||
return self.run_code_unsafe(code, libraries_used)
|
||||
else:
|
||||
return self.run_code_safety(code, libraries_used)
|
||||
return self.run_code_in_docker(code, libraries_used)
|
||||
```
|
||||
|
||||
The tool performs the following steps:
|
||||
@@ -200,9 +168,8 @@ The tool performs the following steps:
|
||||
By default, the `CodeInterpreterTool` runs code in an isolated Docker container, which provides a layer of security. However, there are still some security considerations to keep in mind:
|
||||
|
||||
1. The Docker container has access to the current working directory, so sensitive files could potentially be accessed.
|
||||
2. If the Docker container is unavailable and the code needs to run safely, it will be executed in a sandbox environment. For security reasons, installing arbitrary libraries is not allowed
|
||||
3. The `unsafe_mode` parameter allows code to be executed directly on the host machine, which should only be used in trusted environments.
|
||||
4. Be cautious when allowing agents to install arbitrary libraries, as they could potentially include malicious code.
|
||||
2. The `unsafe_mode` parameter allows code to be executed directly on the host machine, which should only be used in trusted environments.
|
||||
3. Be cautious when allowing agents to install arbitrary libraries, as they could potentially include malicious code.
|
||||
|
||||
## Conclusion
|
||||
|
||||
|
||||
@@ -37,6 +37,7 @@ dependencies = [
|
||||
"tomli>=2.0.2",
|
||||
"blinker>=1.9.0",
|
||||
"json5>=0.10.0",
|
||||
"elasticsearch>=9.0.0",
|
||||
]
|
||||
|
||||
[project.urls]
|
||||
|
||||
@@ -1,11 +1,18 @@
|
||||
import os
|
||||
from typing import Any, Dict, List, Optional
|
||||
from typing import Any, Dict, List, Optional, cast
|
||||
|
||||
from pydantic import BaseModel, ConfigDict, Field
|
||||
|
||||
from crewai.knowledge.source.base_knowledge_source import BaseKnowledgeSource
|
||||
from crewai.knowledge.storage.knowledge_storage import KnowledgeStorage
|
||||
|
||||
try:
|
||||
from crewai.knowledge.storage.elasticsearch_knowledge_storage import (
|
||||
ElasticsearchKnowledgeStorage,
|
||||
)
|
||||
except ImportError:
|
||||
ElasticsearchKnowledgeStorage = None
|
||||
|
||||
os.environ["TOKENIZERS_PARALLELISM"] = "false" # removes logging from fastembed
|
||||
|
||||
|
||||
@@ -30,17 +37,29 @@ class Knowledge(BaseModel):
|
||||
sources: List[BaseKnowledgeSource],
|
||||
embedder: Optional[Dict[str, Any]] = None,
|
||||
storage: Optional[KnowledgeStorage] = None,
|
||||
storage_provider: str = "chromadb",
|
||||
**data,
|
||||
):
|
||||
super().__init__(**data)
|
||||
if storage:
|
||||
self.storage = storage
|
||||
else:
|
||||
self.storage = KnowledgeStorage(
|
||||
embedder=embedder, collection_name=collection_name
|
||||
)
|
||||
if storage_provider == "elasticsearch":
|
||||
try:
|
||||
self.storage = cast(KnowledgeStorage, self._create_elasticsearch_storage(
|
||||
embedder, collection_name
|
||||
))
|
||||
except ImportError:
|
||||
raise ImportError(
|
||||
"Elasticsearch is not installed. Please install it with `pip install elasticsearch`."
|
||||
)
|
||||
else:
|
||||
self.storage = KnowledgeStorage(
|
||||
embedder=embedder, collection_name=collection_name
|
||||
)
|
||||
self.sources = sources
|
||||
self.storage.initialize_knowledge_storage()
|
||||
if self.storage is not None:
|
||||
self.storage.initialize_knowledge_storage()
|
||||
self._add_sources()
|
||||
|
||||
def query(
|
||||
@@ -71,6 +90,16 @@ class Knowledge(BaseModel):
|
||||
except Exception as e:
|
||||
raise e
|
||||
|
||||
def _create_elasticsearch_storage(self, embedder, collection_name):
|
||||
"""Create an Elasticsearch storage instance."""
|
||||
if ElasticsearchKnowledgeStorage is None:
|
||||
raise ImportError(
|
||||
"Elasticsearch is not installed. Please install it with `pip install elasticsearch`."
|
||||
)
|
||||
return ElasticsearchKnowledgeStorage(
|
||||
embedder_config=embedder, collection_name=collection_name
|
||||
)
|
||||
|
||||
def reset(self) -> None:
|
||||
if self.storage:
|
||||
self.storage.reset()
|
||||
|
||||
268
src/crewai/knowledge/storage/elasticsearch_knowledge_storage.py
Normal file
268
src/crewai/knowledge/storage/elasticsearch_knowledge_storage.py
Normal file
@@ -0,0 +1,268 @@
|
||||
import contextlib
|
||||
import hashlib
|
||||
import io
|
||||
import logging
|
||||
import os
|
||||
from typing import Any, Dict, List, Optional, Union, cast
|
||||
|
||||
from crewai.knowledge.storage.base_knowledge_storage import BaseKnowledgeStorage
|
||||
from crewai.utilities import EmbeddingConfigurator
|
||||
from crewai.utilities.logger import Logger
|
||||
from crewai.utilities.paths import db_storage_path
|
||||
|
||||
|
||||
@contextlib.contextmanager
|
||||
def suppress_logging(logger_name="elasticsearch", level=logging.ERROR):
|
||||
logger = logging.getLogger(logger_name)
|
||||
original_level = logger.getEffectiveLevel()
|
||||
logger.setLevel(level)
|
||||
with (
|
||||
contextlib.redirect_stdout(io.StringIO()),
|
||||
contextlib.redirect_stderr(io.StringIO()),
|
||||
contextlib.suppress(UserWarning),
|
||||
):
|
||||
yield
|
||||
logger.setLevel(original_level)
|
||||
|
||||
|
||||
class ElasticsearchKnowledgeStorage(BaseKnowledgeStorage):
|
||||
"""
|
||||
Extends BaseKnowledgeStorage to use Elasticsearch for storing embeddings
|
||||
and improving search efficiency.
|
||||
"""
|
||||
|
||||
app: Any = None
|
||||
collection_name: Optional[str] = "knowledge"
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
embedder_config: Optional[Dict[str, Any]] = None,
|
||||
collection_name: Optional[str] = None,
|
||||
host: str = "localhost",
|
||||
port: int = 9200,
|
||||
username: Optional[str] = None,
|
||||
password: Optional[str] = None,
|
||||
**kwargs: Any
|
||||
):
|
||||
self.collection_name = collection_name
|
||||
self._set_embedder_config(embedder_config)
|
||||
|
||||
self.host = host
|
||||
self.port = port
|
||||
self.username = username
|
||||
self.password = password
|
||||
self.index_name = f"crewai_knowledge_{collection_name if collection_name else 'default'}".lower()
|
||||
self.additional_config = kwargs
|
||||
|
||||
def search(
|
||||
self,
|
||||
query: List[str],
|
||||
limit: int = 3,
|
||||
filter: Optional[dict] = None,
|
||||
score_threshold: float = 0.35,
|
||||
) -> List[Dict[str, Any]]:
|
||||
if not self.app:
|
||||
self.initialize_knowledge_storage()
|
||||
|
||||
try:
|
||||
embedding = self._get_embedding_for_text(query[0])
|
||||
|
||||
search_query: Dict[str, Any] = {
|
||||
"size": limit,
|
||||
"query": {
|
||||
"script_score": {
|
||||
"query": {"match_all": {}},
|
||||
"script": {
|
||||
"source": "cosineSimilarity(params.query_vector, 'embedding') + 1.0",
|
||||
"params": {"query_vector": embedding}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if filter:
|
||||
query_obj = search_query.get("query", {})
|
||||
if isinstance(query_obj, dict):
|
||||
script_score_obj = query_obj.get("script_score", {})
|
||||
if isinstance(script_score_obj, dict):
|
||||
query_part = script_score_obj.get("query", {})
|
||||
if isinstance(query_part, dict):
|
||||
for key, value in filter.items():
|
||||
new_query = {
|
||||
"bool": {
|
||||
"must": [
|
||||
query_part,
|
||||
{"match": {f"metadata.{key}": value}}
|
||||
]
|
||||
}
|
||||
}
|
||||
if isinstance(script_score_obj, dict):
|
||||
script_score_obj["query"] = new_query
|
||||
|
||||
with suppress_logging():
|
||||
if self.app is not None and hasattr(self.app, "search") and callable(getattr(self.app, "search")):
|
||||
response = self.app.search(
|
||||
index=self.index_name,
|
||||
body=search_query
|
||||
)
|
||||
|
||||
results = []
|
||||
for hit in response["hits"]["hits"]:
|
||||
adjusted_score = (hit["_score"] - 1.0)
|
||||
|
||||
if adjusted_score >= score_threshold:
|
||||
results.append({
|
||||
"id": hit["_id"],
|
||||
"metadata": hit["_source"]["metadata"],
|
||||
"context": hit["_source"]["text"],
|
||||
"score": adjusted_score,
|
||||
})
|
||||
|
||||
return results
|
||||
else:
|
||||
Logger(verbose=True).log("error", "Elasticsearch client is not initialized", "red")
|
||||
return []
|
||||
except Exception as e:
|
||||
Logger(verbose=True).log("error", f"Search error: {e}", "red")
|
||||
raise Exception(f"Error during knowledge search: {str(e)}")
|
||||
|
||||
def initialize_knowledge_storage(self):
|
||||
try:
|
||||
from elasticsearch import Elasticsearch
|
||||
|
||||
es_auth = {}
|
||||
if self.username and self.password:
|
||||
es_auth = {"basic_auth": (self.username, self.password)}
|
||||
|
||||
self.app = Elasticsearch(
|
||||
[f"http://{self.host}:{self.port}"],
|
||||
**es_auth,
|
||||
**self.additional_config
|
||||
)
|
||||
|
||||
if not self.app.indices.exists(index=self.index_name):
|
||||
self.app.indices.create(
|
||||
index=self.index_name,
|
||||
body={
|
||||
"mappings": {
|
||||
"properties": {
|
||||
"text": {"type": "text"},
|
||||
"embedding": {
|
||||
"type": "dense_vector",
|
||||
"dims": 1536, # Default for OpenAI embeddings
|
||||
"index": True,
|
||||
"similarity": "cosine"
|
||||
},
|
||||
"metadata": {"type": "object"}
|
||||
}
|
||||
}
|
||||
}
|
||||
)
|
||||
|
||||
except ImportError:
|
||||
raise ImportError(
|
||||
"Elasticsearch is not installed. Please install it with `pip install elasticsearch`."
|
||||
)
|
||||
except Exception as e:
|
||||
Logger(verbose=True).log(
|
||||
"error",
|
||||
f"Error initializing Elasticsearch: {str(e)}",
|
||||
"red"
|
||||
)
|
||||
raise Exception(f"Error initializing Elasticsearch: {str(e)}")
|
||||
|
||||
def reset(self) -> None:
|
||||
try:
|
||||
if self.app is not None:
|
||||
if self.app.indices.exists(index=self.index_name):
|
||||
self.app.indices.delete(index=self.index_name)
|
||||
|
||||
self.initialize_knowledge_storage()
|
||||
except Exception as e:
|
||||
raise Exception(
|
||||
f"An error occurred while resetting the knowledge storage: {e}"
|
||||
)
|
||||
|
||||
def save(
|
||||
self,
|
||||
documents: List[str],
|
||||
metadata: Optional[Union[Dict[str, Any], List[Dict[str, Any]]]] = None,
|
||||
) -> None:
|
||||
if not self.app:
|
||||
self.initialize_knowledge_storage()
|
||||
|
||||
try:
|
||||
unique_docs = {}
|
||||
|
||||
for idx, doc in enumerate(documents):
|
||||
doc_id = hashlib.sha256(doc.encode("utf-8")).hexdigest()
|
||||
doc_metadata = None
|
||||
if metadata is not None:
|
||||
if isinstance(metadata, list):
|
||||
doc_metadata = metadata[idx]
|
||||
else:
|
||||
doc_metadata = metadata
|
||||
unique_docs[doc_id] = (doc, doc_metadata)
|
||||
|
||||
for doc_id, (doc, meta) in unique_docs.items():
|
||||
embedding = self._get_embedding_for_text(doc)
|
||||
|
||||
doc_body = {
|
||||
"text": doc,
|
||||
"embedding": embedding,
|
||||
"metadata": meta or {},
|
||||
}
|
||||
|
||||
if self.app is not None and hasattr(self.app, "index") and callable(getattr(self.app, "index")):
|
||||
index_func = getattr(self.app, "index")
|
||||
index_func(
|
||||
index=self.index_name,
|
||||
id=doc_id,
|
||||
document=doc_body,
|
||||
refresh=True # Make the document immediately available for search
|
||||
)
|
||||
else:
|
||||
Logger(verbose=True).log("error", "Elasticsearch client is not initialized", "red")
|
||||
|
||||
except Exception as e:
|
||||
Logger(verbose=True).log("error", f"Save error: {e}", "red")
|
||||
raise Exception(f"Error during knowledge save: {str(e)}")
|
||||
|
||||
def _get_embedding_for_text(self, text: str) -> List[float]:
|
||||
"""Get embedding for text using the configured embedder."""
|
||||
if self.embedder_config is None:
|
||||
raise ValueError("Embedder configuration is not set")
|
||||
|
||||
embedder = self.embedder_config
|
||||
if hasattr(embedder, "embed_documents") and callable(getattr(embedder, "embed_documents")):
|
||||
embed_func = getattr(embedder, "embed_documents")
|
||||
return embed_func([text])[0]
|
||||
elif hasattr(embedder, "embed") and callable(getattr(embedder, "embed")):
|
||||
embed_func = getattr(embedder, "embed")
|
||||
return embed_func(text)
|
||||
else:
|
||||
raise ValueError("Invalid embedding function configuration")
|
||||
|
||||
def _create_default_embedding_function(self):
|
||||
from chromadb.utils.embedding_functions.openai_embedding_function import (
|
||||
OpenAIEmbeddingFunction,
|
||||
)
|
||||
|
||||
return OpenAIEmbeddingFunction(
|
||||
api_key=os.getenv("OPENAI_API_KEY"), model_name="text-embedding-3-small"
|
||||
)
|
||||
|
||||
def _set_embedder_config(
|
||||
self, embedder: Optional[Dict[str, Any]] = None
|
||||
) -> None:
|
||||
"""Set the embedding configuration for the knowledge storage.
|
||||
|
||||
Args:
|
||||
embedder (Optional[Dict[str, Any]]): Configuration dictionary for the embedder.
|
||||
If None or empty, defaults to the default embedding function.
|
||||
"""
|
||||
self.embedder_config = (
|
||||
EmbeddingConfigurator().configure_embedder(embedder)
|
||||
if embedder
|
||||
else self._create_default_embedding_function()
|
||||
)
|
||||
@@ -22,7 +22,9 @@ class EntityMemory(Memory):
|
||||
else:
|
||||
memory_provider = None
|
||||
|
||||
if memory_provider == "mem0":
|
||||
if storage:
|
||||
pass
|
||||
elif memory_provider == "mem0":
|
||||
try:
|
||||
from crewai.memory.storage.mem0_storage import Mem0Storage
|
||||
except ImportError:
|
||||
@@ -30,17 +32,26 @@ class EntityMemory(Memory):
|
||||
"Mem0 is not installed. Please install it with `pip install mem0ai`."
|
||||
)
|
||||
storage = Mem0Storage(type="entities", crew=crew)
|
||||
else:
|
||||
storage = (
|
||||
storage
|
||||
if storage
|
||||
else RAGStorage(
|
||||
elif memory_provider == "elasticsearch":
|
||||
try:
|
||||
storage = self._create_elasticsearch_storage(
|
||||
type="entities",
|
||||
allow_reset=True,
|
||||
embedder_config=embedder_config,
|
||||
crew=crew,
|
||||
path=path,
|
||||
)
|
||||
except ImportError:
|
||||
raise ImportError(
|
||||
"Elasticsearch is not installed. Please install it with `pip install elasticsearch`."
|
||||
)
|
||||
else:
|
||||
storage = RAGStorage(
|
||||
type="entities",
|
||||
allow_reset=True,
|
||||
embedder_config=embedder_config,
|
||||
crew=crew,
|
||||
path=path,
|
||||
)
|
||||
|
||||
super().__init__(storage=storage)
|
||||
@@ -59,6 +70,11 @@ class EntityMemory(Memory):
|
||||
data = f"{item.name}({item.type}): {item.description}"
|
||||
super().save(data, item.metadata)
|
||||
|
||||
def _create_elasticsearch_storage(self, **kwargs):
|
||||
"""Create an Elasticsearch storage instance."""
|
||||
from crewai.memory.storage.elasticsearch_storage import ElasticsearchStorage
|
||||
return ElasticsearchStorage(**kwargs)
|
||||
|
||||
def reset(self) -> None:
|
||||
try:
|
||||
self.storage.reset()
|
||||
|
||||
@@ -24,7 +24,9 @@ class ShortTermMemory(Memory):
|
||||
else:
|
||||
memory_provider = None
|
||||
|
||||
if memory_provider == "mem0":
|
||||
if storage:
|
||||
pass
|
||||
elif memory_provider == "mem0":
|
||||
try:
|
||||
from crewai.memory.storage.mem0_storage import Mem0Storage
|
||||
except ImportError:
|
||||
@@ -32,16 +34,24 @@ class ShortTermMemory(Memory):
|
||||
"Mem0 is not installed. Please install it with `pip install mem0ai`."
|
||||
)
|
||||
storage = Mem0Storage(type="short_term", crew=crew)
|
||||
else:
|
||||
storage = (
|
||||
storage
|
||||
if storage
|
||||
else RAGStorage(
|
||||
elif memory_provider == "elasticsearch":
|
||||
try:
|
||||
storage = self._create_elasticsearch_storage(
|
||||
type="short_term",
|
||||
embedder_config=embedder_config,
|
||||
crew=crew,
|
||||
path=path,
|
||||
)
|
||||
except ImportError:
|
||||
raise ImportError(
|
||||
"Elasticsearch is not installed. Please install it with `pip install elasticsearch`."
|
||||
)
|
||||
else:
|
||||
storage = RAGStorage(
|
||||
type="short_term",
|
||||
embedder_config=embedder_config,
|
||||
crew=crew,
|
||||
path=path,
|
||||
)
|
||||
super().__init__(storage=storage)
|
||||
self._memory_provider = memory_provider
|
||||
@@ -68,6 +78,11 @@ class ShortTermMemory(Memory):
|
||||
query=query, limit=limit, score_threshold=score_threshold
|
||||
) # type: ignore # BUG? The reference is to the parent class, but the parent class does not have this parameters
|
||||
|
||||
def _create_elasticsearch_storage(self, **kwargs):
|
||||
"""Create an Elasticsearch storage instance."""
|
||||
from crewai.memory.storage.elasticsearch_storage import ElasticsearchStorage
|
||||
return ElasticsearchStorage(**kwargs)
|
||||
|
||||
def reset(self) -> None:
|
||||
try:
|
||||
self.storage.reset()
|
||||
|
||||
275
src/crewai/memory/storage/elasticsearch_storage.py
Normal file
275
src/crewai/memory/storage/elasticsearch_storage.py
Normal file
@@ -0,0 +1,275 @@
|
||||
import contextlib
|
||||
import io
|
||||
import logging
|
||||
import os
|
||||
import uuid
|
||||
from typing import Any, Dict, List, Optional, cast
|
||||
|
||||
from crewai.memory.storage.base_rag_storage import BaseRAGStorage
|
||||
from crewai.utilities import EmbeddingConfigurator
|
||||
from crewai.utilities.constants import MAX_FILE_NAME_LENGTH
|
||||
from crewai.utilities.logger import Logger
|
||||
from crewai.utilities.paths import db_storage_path
|
||||
|
||||
|
||||
@contextlib.contextmanager
|
||||
def suppress_logging(logger_name="elasticsearch", level=logging.ERROR):
|
||||
logger = logging.getLogger(logger_name)
|
||||
original_level = logger.getEffectiveLevel()
|
||||
logger.setLevel(level)
|
||||
with (
|
||||
contextlib.redirect_stdout(io.StringIO()),
|
||||
contextlib.redirect_stderr(io.StringIO()),
|
||||
contextlib.suppress(UserWarning),
|
||||
):
|
||||
yield
|
||||
logger.setLevel(original_level)
|
||||
|
||||
|
||||
class ElasticsearchStorage(BaseRAGStorage):
|
||||
"""
|
||||
Extends BaseRAGStorage to use Elasticsearch for storing embeddings
|
||||
and improving search efficiency.
|
||||
"""
|
||||
|
||||
app: Any = None
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
type: str,
|
||||
allow_reset: bool = True,
|
||||
embedder_config: Any = None,
|
||||
crew: Any = None,
|
||||
path: Optional[str] = None,
|
||||
host: str = "localhost",
|
||||
port: int = 9200,
|
||||
username: Optional[str] = None,
|
||||
password: Optional[str] = None,
|
||||
**kwargs: Any
|
||||
):
|
||||
super().__init__(type, allow_reset, embedder_config, crew)
|
||||
agents = crew.agents if crew else []
|
||||
agents = [self._sanitize_role(agent.role) for agent in agents]
|
||||
agents = "_".join(agents)
|
||||
self.agents = agents
|
||||
self.storage_file_name = self._build_storage_file_name(type, agents)
|
||||
|
||||
self.type = type
|
||||
self.allow_reset = allow_reset
|
||||
self.path = path
|
||||
|
||||
self.host = host
|
||||
self.port = port
|
||||
self.username = username
|
||||
self.password = password
|
||||
self.index_name = f"crewai_{type}".lower()
|
||||
self.additional_config = kwargs
|
||||
|
||||
self._initialize_app()
|
||||
|
||||
def _sanitize_role(self, role: str) -> str:
|
||||
"""
|
||||
Sanitizes agent roles to ensure valid directory and index names.
|
||||
"""
|
||||
return role.replace("\n", "").replace(" ", "_").replace("/", "_")
|
||||
|
||||
def _build_storage_file_name(self, type: str, file_name: str) -> str:
|
||||
"""
|
||||
Ensures file name does not exceed max allowed by OS
|
||||
"""
|
||||
base_path = f"{db_storage_path()}/{type}"
|
||||
|
||||
if len(file_name) > MAX_FILE_NAME_LENGTH:
|
||||
logging.warning(
|
||||
f"Trimming file name from {len(file_name)} to {MAX_FILE_NAME_LENGTH} characters."
|
||||
)
|
||||
file_name = file_name[:MAX_FILE_NAME_LENGTH]
|
||||
|
||||
return f"{base_path}/{file_name}"
|
||||
|
||||
def _set_embedder_config(self):
|
||||
configurator = EmbeddingConfigurator()
|
||||
self.embedder_config = configurator.configure_embedder(self.embedder_config)
|
||||
|
||||
def _initialize_app(self):
|
||||
try:
|
||||
from elasticsearch import Elasticsearch
|
||||
|
||||
self._set_embedder_config()
|
||||
|
||||
es_auth = {}
|
||||
if self.username and self.password:
|
||||
es_auth = {"basic_auth": (self.username, self.password)}
|
||||
|
||||
self.app = Elasticsearch(
|
||||
[f"http://{self.host}:{self.port}"],
|
||||
**es_auth,
|
||||
**self.additional_config
|
||||
)
|
||||
|
||||
if not self.app.indices.exists(index=self.index_name):
|
||||
self.app.indices.create(
|
||||
index=self.index_name,
|
||||
body={
|
||||
"mappings": {
|
||||
"properties": {
|
||||
"text": {"type": "text"},
|
||||
"embedding": {
|
||||
"type": "dense_vector",
|
||||
"dims": 1536, # Default for OpenAI embeddings
|
||||
"index": True,
|
||||
"similarity": "cosine"
|
||||
},
|
||||
"metadata": {"type": "object"}
|
||||
}
|
||||
}
|
||||
}
|
||||
)
|
||||
|
||||
except ImportError:
|
||||
raise ImportError(
|
||||
"Elasticsearch is not installed. Please install it with `pip install elasticsearch`."
|
||||
)
|
||||
except Exception as e:
|
||||
Logger(verbose=True).log(
|
||||
"error",
|
||||
f"Error initializing Elasticsearch: {str(e)}",
|
||||
"red"
|
||||
)
|
||||
raise Exception(f"Error initializing Elasticsearch: {str(e)}")
|
||||
|
||||
def save(self, value: Any, metadata: Dict[str, Any]) -> None:
|
||||
if not hasattr(self, "app"):
|
||||
self._initialize_app()
|
||||
|
||||
try:
|
||||
self._generate_embedding(value, metadata)
|
||||
except Exception as e:
|
||||
logging.error(f"Error during {self.type} save: {str(e)}")
|
||||
|
||||
def search(
|
||||
self,
|
||||
query: str,
|
||||
limit: int = 3,
|
||||
filter: Optional[dict] = None,
|
||||
score_threshold: float = 0.35,
|
||||
) -> List[Any]:
|
||||
if not hasattr(self, "app") or self.app is None:
|
||||
self._initialize_app()
|
||||
|
||||
try:
|
||||
embedding = self._get_embedding_for_text(query)
|
||||
|
||||
search_query: Dict[str, Any] = {
|
||||
"size": limit,
|
||||
"query": {
|
||||
"script_score": {
|
||||
"query": {"match_all": {}},
|
||||
"script": {
|
||||
"source": "cosineSimilarity(params.query_vector, 'embedding') + 1.0",
|
||||
"params": {"query_vector": embedding}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if filter:
|
||||
query_obj = search_query.get("query", {})
|
||||
if isinstance(query_obj, dict):
|
||||
script_score_obj = query_obj.get("script_score", {})
|
||||
if isinstance(script_score_obj, dict):
|
||||
query_part = script_score_obj.get("query", {})
|
||||
if isinstance(query_part, dict):
|
||||
for key, value in filter.items():
|
||||
new_query = {
|
||||
"bool": {
|
||||
"must": [
|
||||
query_part,
|
||||
{"match": {f"metadata.{key}": value}}
|
||||
]
|
||||
}
|
||||
}
|
||||
if isinstance(script_score_obj, dict):
|
||||
script_score_obj["query"] = new_query
|
||||
|
||||
with suppress_logging():
|
||||
if self.app is not None and hasattr(self.app, "search") and callable(getattr(self.app, "search")):
|
||||
search_func = getattr(self.app, "search")
|
||||
response = search_func(
|
||||
index=self.index_name,
|
||||
body=search_query
|
||||
)
|
||||
|
||||
results = []
|
||||
for hit in response["hits"]["hits"]:
|
||||
adjusted_score = (hit["_score"] - 1.0)
|
||||
|
||||
if adjusted_score >= score_threshold:
|
||||
results.append({
|
||||
"id": hit["_id"],
|
||||
"metadata": hit["_source"]["metadata"],
|
||||
"context": hit["_source"]["text"],
|
||||
"score": adjusted_score,
|
||||
})
|
||||
|
||||
return results
|
||||
else:
|
||||
logging.error("Elasticsearch client is not initialized")
|
||||
return []
|
||||
except Exception as e:
|
||||
logging.error(f"Error during {self.type} search: {str(e)}")
|
||||
return []
|
||||
|
||||
def _get_embedding_for_text(self, text: str) -> List[float]:
|
||||
"""Get embedding for text using the configured embedder."""
|
||||
if self.embedder_config is None:
|
||||
raise ValueError("Embedder configuration is not set")
|
||||
|
||||
embedder = self.embedder_config
|
||||
if hasattr(embedder, "embed_documents") and callable(getattr(embedder, "embed_documents")):
|
||||
embed_func = getattr(embedder, "embed_documents")
|
||||
return embed_func([text])[0]
|
||||
elif hasattr(embedder, "embed") and callable(getattr(embedder, "embed")):
|
||||
embed_func = getattr(embedder, "embed")
|
||||
return embed_func(text)
|
||||
else:
|
||||
raise ValueError("Invalid embedding function configuration")
|
||||
|
||||
def _generate_embedding(self, text: str, metadata: Optional[Dict[str, Any]] = None) -> Any:
|
||||
"""Generate embedding for text and save to Elasticsearch.
|
||||
|
||||
This method overrides the BaseRAGStorage method to use Elasticsearch.
|
||||
"""
|
||||
if not hasattr(self, "app") or self.app is None:
|
||||
self._initialize_app()
|
||||
|
||||
embedding = self._get_embedding_for_text(text)
|
||||
|
||||
doc = {
|
||||
"text": text,
|
||||
"embedding": embedding,
|
||||
"metadata": metadata or {},
|
||||
}
|
||||
|
||||
if self.app is not None and hasattr(self.app, "index") and callable(getattr(self.app, "index")):
|
||||
index_func = getattr(self.app, "index")
|
||||
result = index_func(
|
||||
index=self.index_name,
|
||||
id=str(uuid.uuid4()),
|
||||
document=doc,
|
||||
refresh=True # Make the document immediately available for search
|
||||
)
|
||||
return result
|
||||
return None
|
||||
|
||||
def reset(self) -> None:
|
||||
try:
|
||||
if self.app is not None:
|
||||
if self.app.indices.exists(index=self.index_name):
|
||||
self.app.indices.delete(index=self.index_name)
|
||||
|
||||
self._initialize_app()
|
||||
except Exception as e:
|
||||
raise Exception(
|
||||
f"An error occurred while resetting the {self.type} memory: {e}"
|
||||
)
|
||||
81
src/crewai/memory/storage/storage_factory.py
Normal file
81
src/crewai/memory/storage/storage_factory.py
Normal file
@@ -0,0 +1,81 @@
|
||||
from typing import Any, Dict, Optional, Type, cast
|
||||
|
||||
from crewai.memory.storage.base_rag_storage import BaseRAGStorage
|
||||
from crewai.memory.storage.rag_storage import RAGStorage
|
||||
from crewai.utilities.logger import Logger
|
||||
|
||||
try:
|
||||
from crewai.memory.storage.elasticsearch_storage import ElasticsearchStorage
|
||||
except ImportError:
|
||||
ElasticsearchStorage = None
|
||||
|
||||
try:
|
||||
from crewai.memory.storage.mem0_storage import Mem0Storage
|
||||
except ImportError:
|
||||
Mem0Storage = None
|
||||
|
||||
|
||||
class StorageFactory:
|
||||
"""Factory for creating storage instances based on provider type."""
|
||||
|
||||
@classmethod
|
||||
def create_storage(
|
||||
cls,
|
||||
provider: str,
|
||||
type: str,
|
||||
allow_reset: bool = True,
|
||||
embedder_config: Optional[Any] = None,
|
||||
crew: Any = None,
|
||||
path: Optional[str] = None,
|
||||
**kwargs,
|
||||
) -> BaseRAGStorage:
|
||||
"""Create a storage instance based on the provider type.
|
||||
|
||||
Args:
|
||||
provider: Type of storage provider ("chromadb", "elasticsearch", "mem0").
|
||||
type: Type of memory storage (e.g., "short_term", "entity").
|
||||
allow_reset: Whether to allow resetting the storage.
|
||||
embedder_config: Configuration for the embedder.
|
||||
crew: Crew instance.
|
||||
path: Path to the storage.
|
||||
**kwargs: Additional arguments.
|
||||
|
||||
Returns:
|
||||
Storage instance.
|
||||
"""
|
||||
if provider == "elasticsearch":
|
||||
if ElasticsearchStorage is None:
|
||||
Logger(verbose=True).log(
|
||||
"error",
|
||||
"Elasticsearch is not installed. Please install it with `pip install elasticsearch`.",
|
||||
"red",
|
||||
)
|
||||
raise ImportError(
|
||||
"Elasticsearch is not installed. Please install it with `pip install elasticsearch`."
|
||||
)
|
||||
return ElasticsearchStorage(
|
||||
type=type,
|
||||
allow_reset=allow_reset,
|
||||
embedder_config=embedder_config,
|
||||
crew=crew,
|
||||
path=path,
|
||||
**kwargs,
|
||||
)
|
||||
elif provider == "mem0":
|
||||
if Mem0Storage is None:
|
||||
Logger(verbose=True).log(
|
||||
"error",
|
||||
"Mem0 is not installed. Please install it with `pip install mem0ai`.",
|
||||
"red",
|
||||
)
|
||||
raise ImportError(
|
||||
"Mem0 is not installed. Please install it with `pip install mem0ai`."
|
||||
)
|
||||
return cast(BaseRAGStorage, Mem0Storage(type=type, crew=crew))
|
||||
return RAGStorage(
|
||||
type=type,
|
||||
allow_reset=allow_reset,
|
||||
embedder_config=embedder_config,
|
||||
crew=crew,
|
||||
path=path,
|
||||
)
|
||||
90
tests/integration/elasticsearch_integration_test.py
Normal file
90
tests/integration/elasticsearch_integration_test.py
Normal file
@@ -0,0 +1,90 @@
|
||||
"""Integration test for Elasticsearch with CrewAI."""
|
||||
|
||||
import os
|
||||
import unittest
|
||||
|
||||
import pytest
|
||||
|
||||
from crewai import Agent, Crew, Task
|
||||
from crewai.knowledge import Knowledge
|
||||
from crewai.knowledge.source.string_knowledge_source import StringKnowledgeSource
|
||||
|
||||
|
||||
@pytest.mark.skipif(
|
||||
os.environ.get("RUN_ELASTICSEARCH_TESTS") != "true",
|
||||
reason="Elasticsearch tests require RUN_ELASTICSEARCH_TESTS=true"
|
||||
)
|
||||
class TestElasticsearchIntegration(unittest.TestCase):
|
||||
"""Integration test for Elasticsearch with CrewAI."""
|
||||
|
||||
def test_crew_with_elasticsearch_memory(self):
|
||||
"""Test a crew with Elasticsearch memory."""
|
||||
researcher = Agent(
|
||||
role="Researcher",
|
||||
goal="Research a topic",
|
||||
backstory="You are a researcher who loves to find information.",
|
||||
)
|
||||
|
||||
writer = Agent(
|
||||
role="Writer",
|
||||
goal="Write about a topic",
|
||||
backstory="You are a writer who loves to write about topics.",
|
||||
)
|
||||
|
||||
research_task = Task(
|
||||
description="Research about AI",
|
||||
expected_output="Information about AI",
|
||||
agent=researcher,
|
||||
)
|
||||
|
||||
write_task = Task(
|
||||
description="Write about AI",
|
||||
expected_output="Article about AI",
|
||||
agent=writer,
|
||||
context=[research_task],
|
||||
)
|
||||
|
||||
crew = Crew(
|
||||
agents=[researcher, writer],
|
||||
tasks=[research_task, write_task],
|
||||
memory_config={"provider": "elasticsearch"},
|
||||
)
|
||||
|
||||
result = crew.kickoff()
|
||||
|
||||
self.assertIsNotNone(result)
|
||||
|
||||
def test_crew_with_elasticsearch_knowledge(self):
|
||||
"""Test a crew with Elasticsearch knowledge."""
|
||||
content = "AI is a field of computer science that focuses on creating machines that can perform tasks that typically require human intelligence."
|
||||
string_source = StringKnowledgeSource(
|
||||
content=content, metadata={"topic": "AI"}
|
||||
)
|
||||
|
||||
knowledge = Knowledge(
|
||||
collection_name="test",
|
||||
sources=[string_source],
|
||||
storage_provider="elasticsearch",
|
||||
)
|
||||
|
||||
agent = Agent(
|
||||
role="AI Expert",
|
||||
goal="Explain AI",
|
||||
backstory="You are an AI expert who loves to explain AI concepts.",
|
||||
knowledge=[knowledge],
|
||||
)
|
||||
|
||||
task = Task(
|
||||
description="Explain what AI is",
|
||||
expected_output="Explanation of AI",
|
||||
agent=agent,
|
||||
)
|
||||
|
||||
crew = Crew(
|
||||
agents=[agent],
|
||||
tasks=[task],
|
||||
)
|
||||
|
||||
result = crew.kickoff()
|
||||
|
||||
self.assertIsNotNone(result)
|
||||
94
tests/knowledge/elasticsearch_knowledge_storage_test.py
Normal file
94
tests/knowledge/elasticsearch_knowledge_storage_test.py
Normal file
@@ -0,0 +1,94 @@
|
||||
"""Test Elasticsearch knowledge storage functionality."""
|
||||
|
||||
import os
|
||||
import unittest
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
import pytest
|
||||
|
||||
from crewai.knowledge.storage.elasticsearch_knowledge_storage import (
|
||||
ElasticsearchKnowledgeStorage,
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.skipif(
|
||||
os.environ.get("RUN_ELASTICSEARCH_TESTS") != "true",
|
||||
reason="Elasticsearch tests require RUN_ELASTICSEARCH_TESTS=true"
|
||||
)
|
||||
class TestElasticsearchKnowledgeStorage(unittest.TestCase):
|
||||
"""Test Elasticsearch knowledge storage functionality."""
|
||||
|
||||
def setUp(self):
|
||||
"""Set up test fixtures."""
|
||||
self.es_mock = MagicMock()
|
||||
self.es_mock.indices.exists.return_value = False
|
||||
|
||||
self.embedder_mock = MagicMock()
|
||||
self.embedder_mock.embed_documents.return_value = [[0.1, 0.2, 0.3]]
|
||||
|
||||
self.es_patcher = patch(
|
||||
"crewai.knowledge.storage.elasticsearch_knowledge_storage.Elasticsearch",
|
||||
return_value=self.es_mock
|
||||
)
|
||||
self.es_class_mock = self.es_patcher.start()
|
||||
|
||||
self.storage = ElasticsearchKnowledgeStorage(
|
||||
embedder_config=self.embedder_mock,
|
||||
collection_name="test"
|
||||
)
|
||||
self.storage.initialize_knowledge_storage()
|
||||
|
||||
def tearDown(self):
|
||||
"""Tear down test fixtures."""
|
||||
self.es_patcher.stop()
|
||||
|
||||
def test_initialization(self):
|
||||
"""Test initialization of Elasticsearch knowledge storage."""
|
||||
self.es_class_mock.assert_called_once()
|
||||
|
||||
self.es_mock.indices.create.assert_called_once()
|
||||
|
||||
def test_save(self):
|
||||
"""Test saving to Elasticsearch knowledge storage."""
|
||||
self.storage.save(["Test document 1", "Test document 2"], {"source": "test"})
|
||||
|
||||
self.assertEqual(self.es_mock.index.call_count, 2)
|
||||
|
||||
self.assertEqual(self.embedder_mock.embed_documents.call_count, 2)
|
||||
|
||||
def test_search(self):
|
||||
"""Test searching in Elasticsearch knowledge storage."""
|
||||
self.es_mock.search.return_value = {
|
||||
"hits": {
|
||||
"hits": [
|
||||
{
|
||||
"_id": "test_id",
|
||||
"_score": 1.5, # Score between 1-2 (Elasticsearch range)
|
||||
"_source": {
|
||||
"text": "Test document",
|
||||
"metadata": {"source": "test"},
|
||||
}
|
||||
}
|
||||
]
|
||||
}
|
||||
}
|
||||
|
||||
results = self.storage.search(["test query"])
|
||||
|
||||
self.es_mock.search.assert_called_once()
|
||||
|
||||
self.assertEqual(len(results), 1)
|
||||
self.assertEqual(results[0]["id"], "test_id")
|
||||
self.assertEqual(results[0]["context"], "Test document")
|
||||
self.assertEqual(results[0]["metadata"], {"source": "test"})
|
||||
self.assertEqual(results[0]["score"], 0.5) # Adjusted to 0-1 range
|
||||
|
||||
def test_reset(self):
|
||||
"""Test resetting Elasticsearch knowledge storage."""
|
||||
self.es_mock.indices.exists.return_value = True
|
||||
|
||||
self.storage.reset()
|
||||
|
||||
self.es_mock.indices.delete.assert_called_once()
|
||||
|
||||
self.assertEqual(self.es_mock.indices.create.call_count, 2)
|
||||
91
tests/memory/elasticsearch_storage_test.py
Normal file
91
tests/memory/elasticsearch_storage_test.py
Normal file
@@ -0,0 +1,91 @@
|
||||
"""Test Elasticsearch storage functionality."""
|
||||
|
||||
import os
|
||||
import unittest
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
import pytest
|
||||
|
||||
from crewai.memory.storage.elasticsearch_storage import ElasticsearchStorage
|
||||
|
||||
|
||||
@pytest.mark.skipif(
|
||||
os.environ.get("RUN_ELASTICSEARCH_TESTS") != "true",
|
||||
reason="Elasticsearch tests require RUN_ELASTICSEARCH_TESTS=true"
|
||||
)
|
||||
class TestElasticsearchStorage(unittest.TestCase):
|
||||
"""Test Elasticsearch storage functionality."""
|
||||
|
||||
def setUp(self):
|
||||
"""Set up test fixtures."""
|
||||
self.es_mock = MagicMock()
|
||||
self.es_mock.indices.exists.return_value = False
|
||||
|
||||
self.embedder_mock = MagicMock()
|
||||
self.embedder_mock.embed_documents.return_value = [[0.1, 0.2, 0.3]]
|
||||
|
||||
self.es_patcher = patch(
|
||||
"crewai.memory.storage.elasticsearch_storage.Elasticsearch",
|
||||
return_value=self.es_mock
|
||||
)
|
||||
self.es_class_mock = self.es_patcher.start()
|
||||
|
||||
self.storage = ElasticsearchStorage(
|
||||
type="test",
|
||||
embedder_config=self.embedder_mock
|
||||
)
|
||||
|
||||
def tearDown(self):
|
||||
"""Tear down test fixtures."""
|
||||
self.es_patcher.stop()
|
||||
|
||||
def test_initialization(self):
|
||||
"""Test initialization of Elasticsearch storage."""
|
||||
self.es_class_mock.assert_called_once()
|
||||
|
||||
self.es_mock.indices.create.assert_called_once()
|
||||
|
||||
def test_save(self):
|
||||
"""Test saving to Elasticsearch storage."""
|
||||
self.storage.save("Test document", {"source": "test"})
|
||||
|
||||
self.es_mock.index.assert_called_once()
|
||||
|
||||
self.embedder_mock.embed_documents.assert_called_once_with(["Test document"])
|
||||
|
||||
def test_search(self):
|
||||
"""Test searching in Elasticsearch storage."""
|
||||
self.es_mock.search.return_value = {
|
||||
"hits": {
|
||||
"hits": [
|
||||
{
|
||||
"_id": "test_id",
|
||||
"_score": 1.5, # Score between 1-2 (Elasticsearch range)
|
||||
"_source": {
|
||||
"text": "Test document",
|
||||
"metadata": {"source": "test"},
|
||||
}
|
||||
}
|
||||
]
|
||||
}
|
||||
}
|
||||
|
||||
results = self.storage.search("test query")
|
||||
|
||||
self.es_mock.search.assert_called_once()
|
||||
|
||||
self.assertEqual(len(results), 1)
|
||||
self.assertEqual(results[0]["id"], "test_id")
|
||||
self.assertEqual(results[0]["context"], "Test document")
|
||||
self.assertEqual(results[0]["metadata"], {"source": "test"})
|
||||
self.assertEqual(results[0]["score"], 0.5) # Adjusted to 0-1 range
|
||||
|
||||
def test_reset(self):
|
||||
"""Test resetting Elasticsearch storage."""
|
||||
self.es_mock.indices.exists.return_value = True
|
||||
|
||||
self.storage.reset()
|
||||
|
||||
self.es_mock.indices.delete.assert_called_once()
|
||||
|
||||
self.assertEqual(self.es_mock.indices.create.call_count, 2)
|
||||
33
uv.lock
generated
33
uv.lock
generated
@@ -1,5 +1,4 @@
|
||||
version = 1
|
||||
revision = 1
|
||||
requires-python = ">=3.10, <3.13"
|
||||
resolution-markers = [
|
||||
"python_full_version < '3.11' and platform_python_implementation == 'PyPy' and sys_platform == 'darwin'",
|
||||
@@ -628,6 +627,7 @@ dependencies = [
|
||||
{ name = "blinker" },
|
||||
{ name = "chromadb" },
|
||||
{ name = "click" },
|
||||
{ name = "elasticsearch" },
|
||||
{ name = "instructor" },
|
||||
{ name = "json-repair" },
|
||||
{ name = "json5" },
|
||||
@@ -710,6 +710,7 @@ requires-dist = [
|
||||
{ name = "click", specifier = ">=8.1.7" },
|
||||
{ name = "crewai-tools", marker = "extra == 'tools'", specifier = "~=0.40.1" },
|
||||
{ name = "docling", marker = "extra == 'docling'", specifier = ">=2.12.0" },
|
||||
{ name = "elasticsearch", specifier = ">=9.0.0" },
|
||||
{ name = "fastembed", marker = "extra == 'fastembed'", specifier = ">=0.4.1" },
|
||||
{ name = "instructor", specifier = ">=1.3.3" },
|
||||
{ name = "json-repair", specifier = ">=0.25.2" },
|
||||
@@ -735,7 +736,6 @@ requires-dist = [
|
||||
{ name = "tomli-w", specifier = ">=1.1.0" },
|
||||
{ name = "uv", specifier = ">=0.4.25" },
|
||||
]
|
||||
provides-extras = ["tools", "embeddings", "agentops", "fastembed", "pdfplumber", "pandas", "openpyxl", "mem0", "docling", "aisuite"]
|
||||
|
||||
[package.metadata.requires-dev]
|
||||
dev = [
|
||||
@@ -1097,6 +1097,33 @@ wheels = [
|
||||
{ url = "https://files.pythonhosted.org/packages/bb/84/4a2cab0e6adde6a85e7ba543862e5fc0250c51f3ac721a078a55cdcff250/easyocr-1.7.2-py3-none-any.whl", hash = "sha256:5be12f9b0e595d443c9c3d10b0542074b50f0ec2d98b141a109cd961fd1c177c", size = 2870178 },
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "elastic-transport"
|
||||
version = "8.17.1"
|
||||
source = { registry = "https://pypi.org/simple" }
|
||||
dependencies = [
|
||||
{ name = "certifi" },
|
||||
{ name = "urllib3" },
|
||||
]
|
||||
sdist = { url = "https://files.pythonhosted.org/packages/6a/54/d498a766ac8fa475f931da85a154666cc81a70f8eb4a780bc8e4e934e9ac/elastic_transport-8.17.1.tar.gz", hash = "sha256:5edef32ac864dca8e2f0a613ef63491ee8d6b8cfb52881fa7313ba9290cac6d2", size = 73425 }
|
||||
wheels = [
|
||||
{ url = "https://files.pythonhosted.org/packages/cf/cd/b71d5bc74cde7fc6fd9b2ff9389890f45d9762cbbbf81dc5e51fd7588c4a/elastic_transport-8.17.1-py3-none-any.whl", hash = "sha256:192718f498f1d10c5e9aa8b9cf32aed405e469a7f0e9d6a8923431dbb2c59fb8", size = 64969 },
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "elasticsearch"
|
||||
version = "9.0.0"
|
||||
source = { registry = "https://pypi.org/simple" }
|
||||
dependencies = [
|
||||
{ name = "elastic-transport" },
|
||||
{ name = "python-dateutil" },
|
||||
{ name = "typing-extensions" },
|
||||
]
|
||||
sdist = { url = "https://files.pythonhosted.org/packages/26/ad/d76e88811e68d7bdd976c0ff6027e7c3b544a949c8d3de052adc5765e1a6/elasticsearch-9.0.0.tar.gz", hash = "sha256:c075ccdc7d5697e2a842a88418efdb6cf6732d7a62c77a25d60184db23fd1464", size = 823636 }
|
||||
wheels = [
|
||||
{ url = "https://files.pythonhosted.org/packages/4b/b7/e85bdb8bed719dbf92780264d6c381186ced8eb7acc88bbe37a996f87b03/elasticsearch-9.0.0-py3-none-any.whl", hash = "sha256:295425172043e5db723d55cb3a5e28622696ca7739b466b812ab12ac938b6249", size = 895793 },
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "embedchain"
|
||||
version = "0.1.125"
|
||||
@@ -2988,6 +3015,7 @@ name = "nvidia-nccl-cu12"
|
||||
version = "2.20.5"
|
||||
source = { registry = "https://pypi.org/simple" }
|
||||
wheels = [
|
||||
{ url = "https://files.pythonhosted.org/packages/c1/bb/d09dda47c881f9ff504afd6f9ca4f502ded6d8fc2f572cacc5e39da91c28/nvidia_nccl_cu12-2.20.5-py3-none-manylinux2014_aarch64.whl", hash = "sha256:1fc150d5c3250b170b29410ba682384b14581db722b2531b0d8d33c595f33d01", size = 176238458 },
|
||||
{ url = "https://files.pythonhosted.org/packages/4b/2a/0a131f572aa09f741c30ccd45a8e56316e8be8dfc7bc19bf0ab7cfef7b19/nvidia_nccl_cu12-2.20.5-py3-none-manylinux2014_x86_64.whl", hash = "sha256:057f6bf9685f75215d0c53bf3ac4a10b3e6578351de307abad9e18a99182af56", size = 176249402 },
|
||||
]
|
||||
|
||||
@@ -2997,6 +3025,7 @@ version = "12.6.85"
|
||||
source = { registry = "https://pypi.org/simple" }
|
||||
wheels = [
|
||||
{ url = "https://files.pythonhosted.org/packages/9d/d7/c5383e47c7e9bf1c99d5bd2a8c935af2b6d705ad831a7ec5c97db4d82f4f/nvidia_nvjitlink_cu12-12.6.85-py3-none-manylinux2010_x86_64.manylinux_2_12_x86_64.whl", hash = "sha256:eedc36df9e88b682efe4309aa16b5b4e78c2407eac59e8c10a6a47535164369a", size = 19744971 },
|
||||
{ url = "https://files.pythonhosted.org/packages/31/db/dc71113d441f208cdfe7ae10d4983884e13f464a6252450693365e166dcf/nvidia_nvjitlink_cu12-12.6.85-py3-none-manylinux2014_aarch64.manylinux_2_17_aarch64.whl", hash = "sha256:cf4eaa7d4b6b543ffd69d6abfb11efdeb2db48270d94dfd3a452c24150829e41", size = 19270338 },
|
||||
]
|
||||
|
||||
[[package]]
|
||||
|
||||
Reference in New Issue
Block a user