mirror of
https://github.com/crewAIInc/crewAI.git
synced 2026-01-08 23:58:34 +00:00
Fix import sorting issues in Elasticsearch integration
Co-Authored-By: Joe Moura <joao@crewai.com>
This commit is contained in:
@@ -6,6 +6,13 @@ from pydantic import BaseModel, ConfigDict, Field
|
||||
from crewai.knowledge.source.base_knowledge_source import BaseKnowledgeSource
|
||||
from crewai.knowledge.storage.knowledge_storage import KnowledgeStorage
|
||||
|
||||
try:
|
||||
from crewai.knowledge.storage.elasticsearch_knowledge_storage import (
|
||||
ElasticsearchKnowledgeStorage,
|
||||
)
|
||||
except ImportError:
|
||||
ElasticsearchKnowledgeStorage = None
|
||||
|
||||
os.environ["TOKENIZERS_PARALLELISM"] = "false" # removes logging from fastembed
|
||||
|
||||
|
||||
@@ -39,9 +46,8 @@ class Knowledge(BaseModel):
|
||||
else:
|
||||
if storage_provider == "elasticsearch":
|
||||
try:
|
||||
from crewai.knowledge.storage.elasticsearch_knowledge_storage import ElasticsearchKnowledgeStorage
|
||||
self.storage = cast(KnowledgeStorage, ElasticsearchKnowledgeStorage(
|
||||
embedder_config=embedder, collection_name=collection_name
|
||||
self.storage = cast(KnowledgeStorage, self._create_elasticsearch_storage(
|
||||
embedder, collection_name
|
||||
))
|
||||
except ImportError:
|
||||
raise ImportError(
|
||||
@@ -84,6 +90,16 @@ class Knowledge(BaseModel):
|
||||
except Exception as e:
|
||||
raise e
|
||||
|
||||
def _create_elasticsearch_storage(self, embedder, collection_name):
|
||||
"""Create an Elasticsearch storage instance."""
|
||||
if ElasticsearchKnowledgeStorage is None:
|
||||
raise ImportError(
|
||||
"Elasticsearch is not installed. Please install it with `pip install elasticsearch`."
|
||||
)
|
||||
return ElasticsearchKnowledgeStorage(
|
||||
embedder_config=embedder, collection_name=collection_name
|
||||
)
|
||||
|
||||
def reset(self) -> None:
|
||||
if self.storage:
|
||||
self.storage.reset()
|
||||
|
||||
@@ -88,7 +88,7 @@ class ElasticsearchKnowledgeStorage(BaseKnowledgeStorage):
|
||||
query_part = script_score_obj.get("query", {})
|
||||
if isinstance(query_part, dict):
|
||||
for key, value in filter.items():
|
||||
script_score_obj["query"] = {
|
||||
new_query = {
|
||||
"bool": {
|
||||
"must": [
|
||||
query_part,
|
||||
@@ -96,6 +96,8 @@ class ElasticsearchKnowledgeStorage(BaseKnowledgeStorage):
|
||||
]
|
||||
}
|
||||
}
|
||||
if isinstance(script_score_obj, dict):
|
||||
script_score_obj["query"] = new_query
|
||||
|
||||
with suppress_logging():
|
||||
if self.app is not None and hasattr(self.app, "search") and callable(getattr(self.app, "search")):
|
||||
@@ -212,7 +214,8 @@ class ElasticsearchKnowledgeStorage(BaseKnowledgeStorage):
|
||||
}
|
||||
|
||||
if self.app is not None and hasattr(self.app, "index") and callable(getattr(self.app, "index")):
|
||||
self.app.index(
|
||||
index_func = getattr(self.app, "index")
|
||||
index_func(
|
||||
index=self.index_name,
|
||||
id=doc_id,
|
||||
document=doc_body,
|
||||
@@ -232,9 +235,11 @@ class ElasticsearchKnowledgeStorage(BaseKnowledgeStorage):
|
||||
|
||||
embedder = self.embedder_config
|
||||
if hasattr(embedder, "embed_documents") and callable(getattr(embedder, "embed_documents")):
|
||||
return embedder.embed_documents([text])[0]
|
||||
embed_func = getattr(embedder, "embed_documents")
|
||||
return embed_func([text])[0]
|
||||
elif hasattr(embedder, "embed") and callable(getattr(embedder, "embed")):
|
||||
return embedder.embed(text)
|
||||
embed_func = getattr(embedder, "embed")
|
||||
return embed_func(text)
|
||||
else:
|
||||
raise ValueError("Invalid embedding function configuration")
|
||||
|
||||
|
||||
@@ -34,18 +34,17 @@ class EntityMemory(Memory):
|
||||
storage = Mem0Storage(type="entities", crew=crew)
|
||||
elif memory_provider == "elasticsearch":
|
||||
try:
|
||||
from crewai.memory.storage.elasticsearch_storage import ElasticsearchStorage
|
||||
storage = self._create_elasticsearch_storage(
|
||||
type="entities",
|
||||
allow_reset=True,
|
||||
embedder_config=embedder_config,
|
||||
crew=crew,
|
||||
path=path,
|
||||
)
|
||||
except ImportError:
|
||||
raise ImportError(
|
||||
"Elasticsearch is not installed. Please install it with `pip install elasticsearch`."
|
||||
)
|
||||
storage = ElasticsearchStorage(
|
||||
type="entities",
|
||||
allow_reset=True,
|
||||
embedder_config=embedder_config,
|
||||
crew=crew,
|
||||
path=path,
|
||||
)
|
||||
else:
|
||||
storage = RAGStorage(
|
||||
type="entities",
|
||||
@@ -71,6 +70,11 @@ class EntityMemory(Memory):
|
||||
data = f"{item.name}({item.type}): {item.description}"
|
||||
super().save(data, item.metadata)
|
||||
|
||||
def _create_elasticsearch_storage(self, **kwargs):
|
||||
"""Create an Elasticsearch storage instance."""
|
||||
from crewai.memory.storage.elasticsearch_storage import ElasticsearchStorage
|
||||
return ElasticsearchStorage(**kwargs)
|
||||
|
||||
def reset(self) -> None:
|
||||
try:
|
||||
self.storage.reset()
|
||||
|
||||
@@ -36,17 +36,16 @@ class ShortTermMemory(Memory):
|
||||
storage = Mem0Storage(type="short_term", crew=crew)
|
||||
elif memory_provider == "elasticsearch":
|
||||
try:
|
||||
from crewai.memory.storage.elasticsearch_storage import ElasticsearchStorage
|
||||
storage = self._create_elasticsearch_storage(
|
||||
type="short_term",
|
||||
embedder_config=embedder_config,
|
||||
crew=crew,
|
||||
path=path,
|
||||
)
|
||||
except ImportError:
|
||||
raise ImportError(
|
||||
"Elasticsearch is not installed. Please install it with `pip install elasticsearch`."
|
||||
)
|
||||
storage = ElasticsearchStorage(
|
||||
type="short_term",
|
||||
embedder_config=embedder_config,
|
||||
crew=crew,
|
||||
path=path,
|
||||
)
|
||||
else:
|
||||
storage = RAGStorage(
|
||||
type="short_term",
|
||||
@@ -79,6 +78,11 @@ class ShortTermMemory(Memory):
|
||||
query=query, limit=limit, score_threshold=score_threshold
|
||||
) # type: ignore # BUG? The reference is to the parent class, but the parent class does not have this parameters
|
||||
|
||||
def _create_elasticsearch_storage(self, **kwargs):
|
||||
"""Create an Elasticsearch storage instance."""
|
||||
from crewai.memory.storage.elasticsearch_storage import ElasticsearchStorage
|
||||
return ElasticsearchStorage(**kwargs)
|
||||
|
||||
def reset(self) -> None:
|
||||
try:
|
||||
self.storage.reset()
|
||||
|
||||
@@ -181,7 +181,7 @@ class ElasticsearchStorage(BaseRAGStorage):
|
||||
query_part = script_score_obj.get("query", {})
|
||||
if isinstance(query_part, dict):
|
||||
for key, value in filter.items():
|
||||
script_score_obj["query"] = {
|
||||
new_query = {
|
||||
"bool": {
|
||||
"must": [
|
||||
query_part,
|
||||
@@ -189,10 +189,13 @@ class ElasticsearchStorage(BaseRAGStorage):
|
||||
]
|
||||
}
|
||||
}
|
||||
if isinstance(script_score_obj, dict):
|
||||
script_score_obj["query"] = new_query
|
||||
|
||||
with suppress_logging():
|
||||
if self.app is not None:
|
||||
response = self.app.search(
|
||||
if self.app is not None and hasattr(self.app, "search") and callable(getattr(self.app, "search")):
|
||||
search_func = getattr(self.app, "search")
|
||||
response = search_func(
|
||||
index=self.index_name,
|
||||
body=search_query
|
||||
)
|
||||
@@ -224,9 +227,11 @@ class ElasticsearchStorage(BaseRAGStorage):
|
||||
|
||||
embedder = self.embedder_config
|
||||
if hasattr(embedder, "embed_documents") and callable(getattr(embedder, "embed_documents")):
|
||||
return embedder.embed_documents([text])[0]
|
||||
embed_func = getattr(embedder, "embed_documents")
|
||||
return embed_func([text])[0]
|
||||
elif hasattr(embedder, "embed") and callable(getattr(embedder, "embed")):
|
||||
return embedder.embed(text)
|
||||
embed_func = getattr(embedder, "embed")
|
||||
return embed_func(text)
|
||||
else:
|
||||
raise ValueError("Invalid embedding function configuration")
|
||||
|
||||
@@ -247,7 +252,8 @@ class ElasticsearchStorage(BaseRAGStorage):
|
||||
}
|
||||
|
||||
if self.app is not None and hasattr(self.app, "index") and callable(getattr(self.app, "index")):
|
||||
result = self.app.index(
|
||||
index_func = getattr(self.app, "index")
|
||||
result = index_func(
|
||||
index=self.index_name,
|
||||
id=str(uuid.uuid4()),
|
||||
document=doc,
|
||||
|
||||
@@ -4,6 +4,16 @@ from crewai.memory.storage.base_rag_storage import BaseRAGStorage
|
||||
from crewai.memory.storage.rag_storage import RAGStorage
|
||||
from crewai.utilities.logger import Logger
|
||||
|
||||
try:
|
||||
from crewai.memory.storage.elasticsearch_storage import ElasticsearchStorage
|
||||
except ImportError:
|
||||
ElasticsearchStorage = None
|
||||
|
||||
try:
|
||||
from crewai.memory.storage.mem0_storage import Mem0Storage
|
||||
except ImportError:
|
||||
Mem0Storage = None
|
||||
|
||||
|
||||
class StorageFactory:
|
||||
"""Factory for creating storage instances based on provider type."""
|
||||
@@ -34,17 +44,7 @@ class StorageFactory:
|
||||
Storage instance.
|
||||
"""
|
||||
if provider == "elasticsearch":
|
||||
try:
|
||||
from crewai.memory.storage.elasticsearch_storage import ElasticsearchStorage
|
||||
return ElasticsearchStorage(
|
||||
type=type,
|
||||
allow_reset=allow_reset,
|
||||
embedder_config=embedder_config,
|
||||
crew=crew,
|
||||
path=path,
|
||||
**kwargs,
|
||||
)
|
||||
except ImportError:
|
||||
if ElasticsearchStorage is None:
|
||||
Logger(verbose=True).log(
|
||||
"error",
|
||||
"Elasticsearch is not installed. Please install it with `pip install elasticsearch`.",
|
||||
@@ -53,11 +53,16 @@ class StorageFactory:
|
||||
raise ImportError(
|
||||
"Elasticsearch is not installed. Please install it with `pip install elasticsearch`."
|
||||
)
|
||||
return ElasticsearchStorage(
|
||||
type=type,
|
||||
allow_reset=allow_reset,
|
||||
embedder_config=embedder_config,
|
||||
crew=crew,
|
||||
path=path,
|
||||
**kwargs,
|
||||
)
|
||||
elif provider == "mem0":
|
||||
try:
|
||||
from crewai.memory.storage.mem0_storage import Mem0Storage
|
||||
return cast(BaseRAGStorage, Mem0Storage(type=type, crew=crew))
|
||||
except ImportError:
|
||||
if Mem0Storage is None:
|
||||
Logger(verbose=True).log(
|
||||
"error",
|
||||
"Mem0 is not installed. Please install it with `pip install mem0ai`.",
|
||||
@@ -66,6 +71,7 @@ class StorageFactory:
|
||||
raise ImportError(
|
||||
"Mem0 is not installed. Please install it with `pip install mem0ai`."
|
||||
)
|
||||
return cast(BaseRAGStorage, Mem0Storage(type=type, crew=crew))
|
||||
return RAGStorage(
|
||||
type=type,
|
||||
allow_reset=allow_reset,
|
||||
|
||||
@@ -6,6 +6,8 @@ import unittest
|
||||
import pytest
|
||||
|
||||
from crewai import Agent, Crew, Task
|
||||
from crewai.knowledge import Knowledge
|
||||
from crewai.knowledge.source.string_knowledge_source import StringKnowledgeSource
|
||||
|
||||
|
||||
@pytest.mark.skipif(
|
||||
@@ -54,9 +56,6 @@ class TestElasticsearchIntegration(unittest.TestCase):
|
||||
|
||||
def test_crew_with_elasticsearch_knowledge(self):
|
||||
"""Test a crew with Elasticsearch knowledge."""
|
||||
from crewai.knowledge import Knowledge
|
||||
from crewai.knowledge.source.string_knowledge_source import StringKnowledgeSource
|
||||
|
||||
content = "AI is a field of computer science that focuses on creating machines that can perform tasks that typically require human intelligence."
|
||||
string_source = StringKnowledgeSource(
|
||||
content=content, metadata={"topic": "AI"}
|
||||
|
||||
@@ -6,7 +6,9 @@ from unittest.mock import MagicMock, patch
|
||||
|
||||
import pytest
|
||||
|
||||
from crewai.knowledge.storage.elasticsearch_knowledge_storage import ElasticsearchKnowledgeStorage
|
||||
from crewai.knowledge.storage.elasticsearch_knowledge_storage import (
|
||||
ElasticsearchKnowledgeStorage,
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.skipif(
|
||||
|
||||
Reference in New Issue
Block a user