diff --git a/src/crewai/knowledge/knowledge.py b/src/crewai/knowledge/knowledge.py index b27d3d212..180130b3b 100644 --- a/src/crewai/knowledge/knowledge.py +++ b/src/crewai/knowledge/knowledge.py @@ -6,6 +6,13 @@ from pydantic import BaseModel, ConfigDict, Field from crewai.knowledge.source.base_knowledge_source import BaseKnowledgeSource from crewai.knowledge.storage.knowledge_storage import KnowledgeStorage +try: + from crewai.knowledge.storage.elasticsearch_knowledge_storage import ( + ElasticsearchKnowledgeStorage, + ) +except ImportError: + ElasticsearchKnowledgeStorage = None + os.environ["TOKENIZERS_PARALLELISM"] = "false" # removes logging from fastembed @@ -39,9 +46,8 @@ class Knowledge(BaseModel): else: if storage_provider == "elasticsearch": try: - from crewai.knowledge.storage.elasticsearch_knowledge_storage import ElasticsearchKnowledgeStorage - self.storage = cast(KnowledgeStorage, ElasticsearchKnowledgeStorage( - embedder_config=embedder, collection_name=collection_name + self.storage = cast(KnowledgeStorage, self._create_elasticsearch_storage( + embedder, collection_name )) except ImportError: raise ImportError( @@ -84,6 +90,16 @@ class Knowledge(BaseModel): except Exception as e: raise e + def _create_elasticsearch_storage(self, embedder, collection_name): + """Create an Elasticsearch storage instance.""" + if ElasticsearchKnowledgeStorage is None: + raise ImportError( + "Elasticsearch is not installed. Please install it with `pip install elasticsearch`." + ) + return ElasticsearchKnowledgeStorage( + embedder_config=embedder, collection_name=collection_name + ) + def reset(self) -> None: if self.storage: self.storage.reset() diff --git a/src/crewai/knowledge/storage/elasticsearch_knowledge_storage.py b/src/crewai/knowledge/storage/elasticsearch_knowledge_storage.py index c4b33636c..0d4a55abd 100644 --- a/src/crewai/knowledge/storage/elasticsearch_knowledge_storage.py +++ b/src/crewai/knowledge/storage/elasticsearch_knowledge_storage.py @@ -88,7 +88,7 @@ class ElasticsearchKnowledgeStorage(BaseKnowledgeStorage): query_part = script_score_obj.get("query", {}) if isinstance(query_part, dict): for key, value in filter.items(): - script_score_obj["query"] = { + new_query = { "bool": { "must": [ query_part, @@ -96,6 +96,8 @@ class ElasticsearchKnowledgeStorage(BaseKnowledgeStorage): ] } } + if isinstance(script_score_obj, dict): + script_score_obj["query"] = new_query with suppress_logging(): if self.app is not None and hasattr(self.app, "search") and callable(getattr(self.app, "search")): @@ -212,7 +214,8 @@ class ElasticsearchKnowledgeStorage(BaseKnowledgeStorage): } if self.app is not None and hasattr(self.app, "index") and callable(getattr(self.app, "index")): - self.app.index( + index_func = getattr(self.app, "index") + index_func( index=self.index_name, id=doc_id, document=doc_body, @@ -232,9 +235,11 @@ class ElasticsearchKnowledgeStorage(BaseKnowledgeStorage): embedder = self.embedder_config if hasattr(embedder, "embed_documents") and callable(getattr(embedder, "embed_documents")): - return embedder.embed_documents([text])[0] + embed_func = getattr(embedder, "embed_documents") + return embed_func([text])[0] elif hasattr(embedder, "embed") and callable(getattr(embedder, "embed")): - return embedder.embed(text) + embed_func = getattr(embedder, "embed") + return embed_func(text) else: raise ValueError("Invalid embedding function configuration") diff --git a/src/crewai/memory/entity/entity_memory.py b/src/crewai/memory/entity/entity_memory.py index 07678499f..a44cb3d97 100644 --- a/src/crewai/memory/entity/entity_memory.py +++ b/src/crewai/memory/entity/entity_memory.py @@ -34,18 +34,17 @@ class EntityMemory(Memory): storage = Mem0Storage(type="entities", crew=crew) elif memory_provider == "elasticsearch": try: - from crewai.memory.storage.elasticsearch_storage import ElasticsearchStorage + storage = self._create_elasticsearch_storage( + type="entities", + allow_reset=True, + embedder_config=embedder_config, + crew=crew, + path=path, + ) except ImportError: raise ImportError( "Elasticsearch is not installed. Please install it with `pip install elasticsearch`." ) - storage = ElasticsearchStorage( - type="entities", - allow_reset=True, - embedder_config=embedder_config, - crew=crew, - path=path, - ) else: storage = RAGStorage( type="entities", @@ -71,6 +70,11 @@ class EntityMemory(Memory): data = f"{item.name}({item.type}): {item.description}" super().save(data, item.metadata) + def _create_elasticsearch_storage(self, **kwargs): + """Create an Elasticsearch storage instance.""" + from crewai.memory.storage.elasticsearch_storage import ElasticsearchStorage + return ElasticsearchStorage(**kwargs) + def reset(self) -> None: try: self.storage.reset() diff --git a/src/crewai/memory/short_term/short_term_memory.py b/src/crewai/memory/short_term/short_term_memory.py index b4ec90c0d..9df1fc9d8 100644 --- a/src/crewai/memory/short_term/short_term_memory.py +++ b/src/crewai/memory/short_term/short_term_memory.py @@ -36,17 +36,16 @@ class ShortTermMemory(Memory): storage = Mem0Storage(type="short_term", crew=crew) elif memory_provider == "elasticsearch": try: - from crewai.memory.storage.elasticsearch_storage import ElasticsearchStorage + storage = self._create_elasticsearch_storage( + type="short_term", + embedder_config=embedder_config, + crew=crew, + path=path, + ) except ImportError: raise ImportError( "Elasticsearch is not installed. Please install it with `pip install elasticsearch`." ) - storage = ElasticsearchStorage( - type="short_term", - embedder_config=embedder_config, - crew=crew, - path=path, - ) else: storage = RAGStorage( type="short_term", @@ -79,6 +78,11 @@ class ShortTermMemory(Memory): query=query, limit=limit, score_threshold=score_threshold ) # type: ignore # BUG? The reference is to the parent class, but the parent class does not have this parameters + def _create_elasticsearch_storage(self, **kwargs): + """Create an Elasticsearch storage instance.""" + from crewai.memory.storage.elasticsearch_storage import ElasticsearchStorage + return ElasticsearchStorage(**kwargs) + def reset(self) -> None: try: self.storage.reset() diff --git a/src/crewai/memory/storage/elasticsearch_storage.py b/src/crewai/memory/storage/elasticsearch_storage.py index 25b493622..384ffe600 100644 --- a/src/crewai/memory/storage/elasticsearch_storage.py +++ b/src/crewai/memory/storage/elasticsearch_storage.py @@ -181,7 +181,7 @@ class ElasticsearchStorage(BaseRAGStorage): query_part = script_score_obj.get("query", {}) if isinstance(query_part, dict): for key, value in filter.items(): - script_score_obj["query"] = { + new_query = { "bool": { "must": [ query_part, @@ -189,10 +189,13 @@ class ElasticsearchStorage(BaseRAGStorage): ] } } + if isinstance(script_score_obj, dict): + script_score_obj["query"] = new_query with suppress_logging(): - if self.app is not None: - response = self.app.search( + if self.app is not None and hasattr(self.app, "search") and callable(getattr(self.app, "search")): + search_func = getattr(self.app, "search") + response = search_func( index=self.index_name, body=search_query ) @@ -224,9 +227,11 @@ class ElasticsearchStorage(BaseRAGStorage): embedder = self.embedder_config if hasattr(embedder, "embed_documents") and callable(getattr(embedder, "embed_documents")): - return embedder.embed_documents([text])[0] + embed_func = getattr(embedder, "embed_documents") + return embed_func([text])[0] elif hasattr(embedder, "embed") and callable(getattr(embedder, "embed")): - return embedder.embed(text) + embed_func = getattr(embedder, "embed") + return embed_func(text) else: raise ValueError("Invalid embedding function configuration") @@ -247,7 +252,8 @@ class ElasticsearchStorage(BaseRAGStorage): } if self.app is not None and hasattr(self.app, "index") and callable(getattr(self.app, "index")): - result = self.app.index( + index_func = getattr(self.app, "index") + result = index_func( index=self.index_name, id=str(uuid.uuid4()), document=doc, diff --git a/src/crewai/memory/storage/storage_factory.py b/src/crewai/memory/storage/storage_factory.py index 0ee73f0ec..ebda541fd 100644 --- a/src/crewai/memory/storage/storage_factory.py +++ b/src/crewai/memory/storage/storage_factory.py @@ -4,6 +4,16 @@ from crewai.memory.storage.base_rag_storage import BaseRAGStorage from crewai.memory.storage.rag_storage import RAGStorage from crewai.utilities.logger import Logger +try: + from crewai.memory.storage.elasticsearch_storage import ElasticsearchStorage +except ImportError: + ElasticsearchStorage = None + +try: + from crewai.memory.storage.mem0_storage import Mem0Storage +except ImportError: + Mem0Storage = None + class StorageFactory: """Factory for creating storage instances based on provider type.""" @@ -34,17 +44,7 @@ class StorageFactory: Storage instance. """ if provider == "elasticsearch": - try: - from crewai.memory.storage.elasticsearch_storage import ElasticsearchStorage - return ElasticsearchStorage( - type=type, - allow_reset=allow_reset, - embedder_config=embedder_config, - crew=crew, - path=path, - **kwargs, - ) - except ImportError: + if ElasticsearchStorage is None: Logger(verbose=True).log( "error", "Elasticsearch is not installed. Please install it with `pip install elasticsearch`.", @@ -53,11 +53,16 @@ class StorageFactory: raise ImportError( "Elasticsearch is not installed. Please install it with `pip install elasticsearch`." ) + return ElasticsearchStorage( + type=type, + allow_reset=allow_reset, + embedder_config=embedder_config, + crew=crew, + path=path, + **kwargs, + ) elif provider == "mem0": - try: - from crewai.memory.storage.mem0_storage import Mem0Storage - return cast(BaseRAGStorage, Mem0Storage(type=type, crew=crew)) - except ImportError: + if Mem0Storage is None: Logger(verbose=True).log( "error", "Mem0 is not installed. Please install it with `pip install mem0ai`.", @@ -66,6 +71,7 @@ class StorageFactory: raise ImportError( "Mem0 is not installed. Please install it with `pip install mem0ai`." ) + return cast(BaseRAGStorage, Mem0Storage(type=type, crew=crew)) return RAGStorage( type=type, allow_reset=allow_reset, diff --git a/tests/integration/elasticsearch_integration_test.py b/tests/integration/elasticsearch_integration_test.py index 1b480d231..521486f85 100644 --- a/tests/integration/elasticsearch_integration_test.py +++ b/tests/integration/elasticsearch_integration_test.py @@ -6,6 +6,8 @@ import unittest import pytest from crewai import Agent, Crew, Task +from crewai.knowledge import Knowledge +from crewai.knowledge.source.string_knowledge_source import StringKnowledgeSource @pytest.mark.skipif( @@ -54,9 +56,6 @@ class TestElasticsearchIntegration(unittest.TestCase): def test_crew_with_elasticsearch_knowledge(self): """Test a crew with Elasticsearch knowledge.""" - from crewai.knowledge import Knowledge - from crewai.knowledge.source.string_knowledge_source import StringKnowledgeSource - content = "AI is a field of computer science that focuses on creating machines that can perform tasks that typically require human intelligence." string_source = StringKnowledgeSource( content=content, metadata={"topic": "AI"} diff --git a/tests/knowledge/elasticsearch_knowledge_storage_test.py b/tests/knowledge/elasticsearch_knowledge_storage_test.py index f890390da..28945fb9e 100644 --- a/tests/knowledge/elasticsearch_knowledge_storage_test.py +++ b/tests/knowledge/elasticsearch_knowledge_storage_test.py @@ -6,7 +6,9 @@ from unittest.mock import MagicMock, patch import pytest -from crewai.knowledge.storage.elasticsearch_knowledge_storage import ElasticsearchKnowledgeStorage +from crewai.knowledge.storage.elasticsearch_knowledge_storage import ( + ElasticsearchKnowledgeStorage, +) @pytest.mark.skipif(