mirror of https://github.com/crewAIInc/crewAI.git (synced 2026-01-10 16:48:30 +00:00)

import contextlib
import io
import logging
import os
import uuid
from typing import Any, Dict, List, Optional

from crewai.memory.storage.base_rag_storage import BaseRAGStorage
from crewai.utilities import EmbeddingConfigurator
from crewai.utilities.constants import MAX_FILE_NAME_LENGTH
from crewai.utilities.logger import Logger
from crewai.utilities.paths import db_storage_path


@contextlib.contextmanager
def suppress_logging(logger_name="elasticsearch", level=logging.ERROR):
    """Temporarily set the named logger to `level` while silencing stdout, stderr, and UserWarnings."""
    logger = logging.getLogger(logger_name)
    original_level = logger.getEffectiveLevel()
    logger.setLevel(level)
    try:
        with (
            contextlib.redirect_stdout(io.StringIO()),
            contextlib.redirect_stderr(io.StringIO()),
            contextlib.suppress(UserWarning),
        ):
            yield
    finally:
        logger.setLevel(original_level)


class ElasticsearchStorage(BaseRAGStorage):
    """
    Extends BaseRAGStorage to use Elasticsearch for storing embeddings
    and improving search efficiency.
    """

    app: Any | None = None

    def __init__(
        self,
        type,
        allow_reset=True,
        embedder_config=None,
        crew=None,
        path=None,
        host="localhost",
        port=9200,
        username=None,
        password=None,
        **kwargs,
    ):
        super().__init__(type, allow_reset, embedder_config, crew)
        agents = crew.agents if crew else []
        agents = [self._sanitize_role(agent.role) for agent in agents]
        agents = "_".join(agents)
        self.agents = agents
        self.storage_file_name = self._build_storage_file_name(type, agents)

        self.type = type
        self.allow_reset = allow_reset
        self.path = path

        self.host = host
        self.port = port
        self.username = username
        self.password = password
        self.index_name = f"crewai_{type}".lower()
        self.additional_config = kwargs

        self._initialize_app()

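    # Example construction (illustrative sketch; assumes crewAI's usual embedder
    # config shape, which EmbeddingConfigurator resolves into an embedding function):
    #
    #     storage = ElasticsearchStorage(
    #         type="short_term",
    #         embedder_config={
    #             "provider": "openai",
    #             "config": {"model": "text-embedding-3-small"},
    #         },
    #         host="localhost",
    #         port=9200,
    #     )
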
    def _sanitize_role(self, role: str) -> str:
        """
        Sanitizes agent roles to ensure valid directory and index names.
        """
        return role.replace("\n", "").replace(" ", "_").replace("/", "_")

    def _build_storage_file_name(self, type: str, file_name: str) -> str:
        """
        Ensures the file name does not exceed the maximum length allowed by the OS.
        """
        base_path = f"{db_storage_path()}/{type}"

        if len(file_name) > MAX_FILE_NAME_LENGTH:
            logging.warning(
                f"Trimming file name from {len(file_name)} to {MAX_FILE_NAME_LENGTH} characters."
            )
            file_name = file_name[:MAX_FILE_NAME_LENGTH]

        return f"{base_path}/{file_name}"

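    # Note: storage_file_name appears to be kept for parity with crewAI's
    # file-backed RAG storages; this backend does not read or write that path,
    # since documents live in the Elasticsearch index named by self.index_name.
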
    def _set_embedder_config(self):
        configurator = EmbeddingConfigurator()
        self.embedder_config = configurator.configure_embedder(self.embedder_config)

    def _initialize_app(self):
        try:
            from elasticsearch import Elasticsearch

            self._set_embedder_config()

            es_auth = {}
            if self.username and self.password:
                es_auth = {"basic_auth": (self.username, self.password)}

            self.app = Elasticsearch(
                [f"http://{self.host}:{self.port}"],
                **es_auth,
                **self.additional_config,
            )

            if not self.app.indices.exists(index=self.index_name):
                self.app.indices.create(
                    index=self.index_name,
                    body={
                        "mappings": {
                            "properties": {
                                "text": {"type": "text"},
                                "embedding": {
                                    "type": "dense_vector",
                                    "dims": 1536,  # Default for OpenAI embeddings
                                    "index": True,
                                    "similarity": "cosine",
                                },
                                "metadata": {"type": "object"},
                            }
                        }
                    },
                )

        except ImportError:
            raise ImportError(
                "Elasticsearch is not installed. Please install it with `pip install elasticsearch`."
            )
        except Exception as e:
            Logger(verbose=True).log(
                "error",
                f"Error initializing Elasticsearch: {str(e)}",
                "red",
            )
            raise Exception(f"Error initializing Elasticsearch: {str(e)}")

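    # Note: "dims": 1536 in the mapping created by _initialize_app matches the
    # output size of OpenAI's text-embedding-ada-002 / text-embedding-3-small
    # models. If a different embedder is configured, the mapped dimension must
    # equal that embedder's vector size, or Elasticsearch will reject the
    # documents indexed by _generate_embedding below.
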
    def save(self, value: Any, metadata: Dict[str, Any]) -> None:
        if not hasattr(self, "app") or self.app is None:
            self._initialize_app()

        try:
            self._generate_embedding(value, metadata)
        except Exception as e:
            logging.error(f"Error during {self.type} save: {str(e)}")

    def search(
        self,
        query: str,
        limit: int = 3,
        filter: Optional[dict] = None,
        score_threshold: float = 0.35,
    ) -> List[Any]:
        if not hasattr(self, "app") or self.app is None:
            self._initialize_app()

        try:
            embedding = self._get_embedding_for_text(query)

            search_query = {
                "size": limit,
                "query": {
                    "script_score": {
                        "query": {"match_all": {}},
                        "script": {
                            "source": "cosineSimilarity(params.query_vector, 'embedding') + 1.0",
                            "params": {"query_vector": embedding},
                        },
                    }
                },
            }

            if filter:
                # Wrap the current query in a bool "must" for each metadata key/value pair.
                for key, value in filter.items():
                    search_query["query"]["script_score"]["query"] = {
                        "bool": {
                            "must": [
                                search_query["query"]["script_score"]["query"],
                                {"match": {f"metadata.{key}": value}},
                            ]
                        }
                    }

            with suppress_logging():
                response = self.app.search(
                    index=self.index_name,
                    body=search_query,
                )

            results = []
            for hit in response["hits"]["hits"]:
                # The script adds 1.0 to keep scores non-negative; subtract it to
                # recover the raw cosine similarity before applying the threshold.
                adjusted_score = hit["_score"] - 1.0

                if adjusted_score >= score_threshold:
                    results.append({
                        "id": hit["_id"],
                        "metadata": hit["_source"]["metadata"],
                        "context": hit["_source"]["text"],
                        "score": adjusted_score,
                    })

            return results
        except Exception as e:
            logging.error(f"Error during {self.type} search: {str(e)}")
            return []

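    # Scoring example for search above: a stored vector with cosine similarity
    # 0.42 to the query comes back from Elasticsearch with _score 1.42; after
    # subtracting 1.0 it scores 0.42 and passes the default score_threshold of
    # 0.35, while a similarity of 0.2 (_score 1.2) is filtered out.
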
    def _get_embedding_for_text(self, text: str) -> List[float]:
        """Get embedding for text using the configured embedder."""
        if hasattr(self.embedder_config, "embed_documents"):
            return self.embedder_config.embed_documents([text])[0]
        elif hasattr(self.embedder_config, "embed"):
            return self.embedder_config.embed(text)
        else:
            raise ValueError("Invalid embedding function configuration")

    def _generate_embedding(self, text: str, metadata: Dict[str, Any]) -> None:
        if not hasattr(self, "app") or self.app is None:
            self._initialize_app()

        embedding = self._get_embedding_for_text(text)

        doc = {
            "text": text,
            "embedding": embedding,
            "metadata": metadata or {},
        }

        self.app.index(
            index=self.index_name,
            id=str(uuid.uuid4()),
            document=doc,
            refresh=True,  # Make the document immediately available for search
        )

    def reset(self) -> None:
        try:
            if self.app:
                if self.app.indices.exists(index=self.index_name):
                    self.app.indices.delete(index=self.index_name)

            self._initialize_app()
        except Exception as e:
            raise Exception(
                f"An error occurred while resetting the {self.type} memory: {e}"
            )
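

# Minimal usage sketch (illustrative only): assumes an Elasticsearch instance is
# reachable at localhost:9200 and that a default embedder (e.g. OpenAI with
# OPENAI_API_KEY set) can be resolved by EmbeddingConfigurator.
if __name__ == "__main__":
    storage = ElasticsearchStorage(type="short_term")
    storage.save("The deployment checklist was approved.", {"agent": "researcher"})
    for result in storage.search("deployment checklist", limit=3):
        print(result["score"], result["context"])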