From 6c08e6062a7bc116de95a0704c0601ab201d8a6f Mon Sep 17 00:00:00 2001
From: Devin AI <158243242+devin-ai-integration[bot]@users.noreply.github.com>
Date: Wed, 23 Apr 2025 05:27:53 +0000
Subject: [PATCH] Add Elasticsearch integration for RAG storage

Co-Authored-By: Joe Moura
---
 docs/how-to/elasticsearch-integration.md      | 117 +++++++++
 src/crewai/knowledge/knowledge.py             |  18 +-
 .../elasticsearch_knowledge_storage.py        | 246 +++++++++++++++++
 src/crewai/memory/entity/entity_memory.py     |  34 ++-
 .../memory/short_term/short_term_memory.py    |  31 ++-
 .../memory/storage/elasticsearch_storage.py   | 248 ++++++++++++++++++
 src/crewai/memory/storage/storage_factory.py  |  75 ++++++
 .../elasticsearch_integration_test.py         |  91 +++++++
 .../elasticsearch_knowledge_storage_test.py   |  92 +++++++
 tests/memory/elasticsearch_storage_test.py    |  91 +++++++
 10 files changed, 1019 insertions(+), 24 deletions(-)
 create mode 100644 docs/how-to/elasticsearch-integration.md
 create mode 100644 src/crewai/knowledge/storage/elasticsearch_knowledge_storage.py
 create mode 100644 src/crewai/memory/storage/elasticsearch_storage.py
 create mode 100644 src/crewai/memory/storage/storage_factory.py
 create mode 100644 tests/integration/elasticsearch_integration_test.py
 create mode 100644 tests/knowledge/elasticsearch_knowledge_storage_test.py
 create mode 100644 tests/memory/elasticsearch_storage_test.py

diff --git a/docs/how-to/elasticsearch-integration.md b/docs/how-to/elasticsearch-integration.md
new file mode 100644
index 000000000..c71143fc9
--- /dev/null
+++ b/docs/how-to/elasticsearch-integration.md
@@ -0,0 +1,117 @@
+# Elasticsearch Integration
+
+CrewAI supports using Elasticsearch as an alternative to ChromaDB for RAG (Retrieval Augmented Generation) storage. This allows you to leverage Elasticsearch's powerful search capabilities and scalability for your AI agents.
+
+## Installation
+
+To use Elasticsearch with CrewAI, you need to install the Elasticsearch Python client:
+
+```bash
+pip install elasticsearch
+```
+
+## Using Elasticsearch for Memory
+
+You can configure your crew to use Elasticsearch for memory storage:
+
+```python
+from crewai import Agent, Crew, Task
+
+# Create agents and tasks
+agent = Agent(
+    role="Researcher",
+    goal="Research a topic",
+    backstory="You are a researcher who loves to find information.",
+)
+
+task = Task(
+    description="Research about AI",
+    expected_output="Information about AI",
+    agent=agent,
+)
+
+# Create a crew with Elasticsearch memory
+crew = Crew(
+    agents=[agent],
+    tasks=[task],
+    memory_config={
+        "provider": "elasticsearch",
+        "host": "localhost",  # Optional, defaults to localhost
+        "port": 9200,  # Optional, defaults to 9200
+        "username": "user",  # Optional
+        "password": "pass",  # Optional
+    },
+)
+
+# Execute the crew
+result = crew.kickoff()
+```
+
+## Using Elasticsearch for Knowledge
+
+You can also use Elasticsearch for knowledge storage:
+
+```python
+from crewai import Agent, Crew, Task
+from crewai.knowledge import Knowledge
+from crewai.knowledge.source.string_knowledge_source import StringKnowledgeSource
+
+# Create knowledge with Elasticsearch storage
+content = "AI is a field of computer science that focuses on creating machines that can perform tasks that typically require human intelligence."
+string_source = StringKnowledgeSource(
+    content=content, metadata={"topic": "AI"}
+)
+
+knowledge = Knowledge(
+    collection_name="test",
+    sources=[string_source],
+    storage_provider="elasticsearch",  # Use Elasticsearch
+    # Optional Elasticsearch configuration
+    host="localhost",
+    port=9200,
+    username="user",
+    password="pass",
+)
+
+# Create an agent with the knowledge
+agent = Agent(
+    role="AI Expert",
+    goal="Explain AI",
+    backstory="You are an AI expert who loves to explain AI concepts.",
+    knowledge=[knowledge],
+)
+
+# Create a task
+task = Task(
+    description="Explain what AI is",
+    expected_output="Explanation of AI",
+    agent=agent,
+)
+
+# Create a crew
+crew = Crew(
+    agents=[agent],
+    tasks=[task],
+)
+
+# Execute the crew
+result = crew.kickoff()
+```
+
+## Configuration Options
+
+The Elasticsearch integration supports the following configuration options:
+
+- `host`: Elasticsearch host (default: "localhost")
+- `port`: Elasticsearch port (default: 9200)
+- `username`: Elasticsearch username (optional)
+- `password`: Elasticsearch password (optional)
+- Additional keyword arguments are passed directly to the Elasticsearch client
+
+## Running Tests
+
+To run the Elasticsearch tests, you need to set the `RUN_ELASTICSEARCH_TESTS` environment variable to `true`:
+
+```bash
+RUN_ELASTICSEARCH_TESTS=true pytest tests/memory/elasticsearch_storage_test.py tests/knowledge/elasticsearch_knowledge_storage_test.py tests/integration/elasticsearch_integration_test.py
+```
diff --git a/src/crewai/knowledge/knowledge.py b/src/crewai/knowledge/knowledge.py
index 824325d12..d0f8514d4 100644
--- a/src/crewai/knowledge/knowledge.py
+++ b/src/crewai/knowledge/knowledge.py
@@ -30,15 +30,27 @@ class Knowledge(BaseModel):
         sources: List[BaseKnowledgeSource],
         embedder: Optional[Dict[str, Any]] = None,
         storage: Optional[KnowledgeStorage] = None,
+        storage_provider: str = "chromadb",
         **data,
     ):
         super().__init__(**data)
         if storage:
             self.storage = storage
         else:
-            self.storage = KnowledgeStorage(
-                embedder=embedder, collection_name=collection_name
-            )
+            if storage_provider == "elasticsearch":
+                try:
+                    from crewai.knowledge.storage.elasticsearch_knowledge_storage import ElasticsearchKnowledgeStorage
+                    self.storage = ElasticsearchKnowledgeStorage(
+                        embedder=embedder, collection_name=collection_name
+                    )
+                except ImportError:
+                    raise ImportError(
+                        "Elasticsearch is not installed. Please install it with `pip install elasticsearch`."
+                    )
+            else:
+                self.storage = KnowledgeStorage(
+                    embedder=embedder, collection_name=collection_name
+                )
         self.sources = sources
         self.storage.initialize_knowledge_storage()
         self._add_sources()
diff --git a/src/crewai/knowledge/storage/elasticsearch_knowledge_storage.py b/src/crewai/knowledge/storage/elasticsearch_knowledge_storage.py
new file mode 100644
index 000000000..2544eeeea
--- /dev/null
+++ b/src/crewai/knowledge/storage/elasticsearch_knowledge_storage.py
@@ -0,0 +1,246 @@
+import contextlib
+import hashlib
+import io
+import logging
+import os
+from typing import Any, Dict, List, Optional, Union, cast
+
+from crewai.knowledge.storage.base_knowledge_storage import BaseKnowledgeStorage
+from crewai.utilities import EmbeddingConfigurator
+from crewai.utilities.logger import Logger
+from crewai.utilities.paths import db_storage_path
+
+
+@contextlib.contextmanager
+def suppress_logging(logger_name="elasticsearch", level=logging.ERROR):
+    logger = logging.getLogger(logger_name)
+    original_level = logger.getEffectiveLevel()
+    logger.setLevel(level)
+    with (
+        contextlib.redirect_stdout(io.StringIO()),
+        contextlib.redirect_stderr(io.StringIO()),
+        contextlib.suppress(UserWarning),
+    ):
+        yield
+    logger.setLevel(original_level)
+
+
+class ElasticsearchKnowledgeStorage(BaseKnowledgeStorage):
+    """
+    Extends BaseKnowledgeStorage to use Elasticsearch for storing embeddings
+    and improving search efficiency.
+    """
+
+    app = None
+    collection_name: Optional[str] = "knowledge"
+
+    def __init__(
+        self,
+        embedder: Optional[Dict[str, Any]] = None,
+        collection_name: Optional[str] = None,
+        host="localhost",
+        port=9200,
+        username=None,
+        password=None,
+        **kwargs
+    ):
+        self.collection_name = collection_name
+        self._set_embedder_config(embedder)
+
+        self.host = host
+        self.port = port
+        self.username = username
+        self.password = password
+        self.index_name = f"crewai_knowledge_{collection_name if collection_name else 'default'}".lower()
+        self.additional_config = kwargs
+
+    def search(
+        self,
+        query: List[str],
+        limit: int = 3,
+        filter: Optional[dict] = None,
+        score_threshold: float = 0.35,
+    ) -> List[Dict[str, Any]]:
+        if not self.app:
+            self.initialize_knowledge_storage()
+
+        try:
+            embedding = self._get_embedding_for_text(query[0])
+
+            search_query = {
+                "size": limit,
+                "query": {
+                    "script_score": {
+                        "query": {"match_all": {}},
+                        "script": {
+                            "source": "cosineSimilarity(params.query_vector, 'embedding') + 1.0",
+                            "params": {"query_vector": embedding}
+                        }
+                    }
+                }
+            }
+
+            if filter:
+                for key, value in filter.items():
+                    search_query["query"]["script_score"]["query"] = {
+                        "bool": {
+                            "must": [
+                                search_query["query"]["script_score"]["query"],
+                                {"match": {f"metadata.{key}": value}}
+                            ]
+                        }
+                    }
+
+            with suppress_logging():
+                response = self.app.search(
+                    index=self.index_name,
+                    body=search_query
+                )
+
+            results = []
+            for hit in response["hits"]["hits"]:
+                adjusted_score = (hit["_score"] - 1.0)
+
+                if adjusted_score >= score_threshold:
+                    results.append({
+                        "id": hit["_id"],
+                        "metadata": hit["_source"]["metadata"],
+                        "context": hit["_source"]["text"],
+                        "score": adjusted_score,
+                    })
+
+            return results
+        except Exception as e:
+            Logger(verbose=True).log("error", f"Search error: {e}", "red")
+            raise Exception(f"Error during knowledge search: {str(e)}")
+
+    def initialize_knowledge_storage(self):
+        try:
+            from elasticsearch import Elasticsearch
+
+            es_auth = {}
+            if self.username and self.password:
+                es_auth = {"basic_auth": (self.username, self.password)}
+
+            self.app = Elasticsearch(
[f"http://{self.host}:{self.port}"], + **es_auth, + **self.additional_config + ) + + if not self.app.indices.exists(index=self.index_name): + self.app.indices.create( + index=self.index_name, + body={ + "mappings": { + "properties": { + "text": {"type": "text"}, + "embedding": { + "type": "dense_vector", + "dims": 1536, # Default for OpenAI embeddings + "index": True, + "similarity": "cosine" + }, + "metadata": {"type": "object"} + } + } + } + ) + + except ImportError: + raise ImportError( + "Elasticsearch is not installed. Please install it with `pip install elasticsearch`." + ) + except Exception as e: + Logger(verbose=True).log( + "error", + f"Error initializing Elasticsearch: {str(e)}", + "red" + ) + raise Exception(f"Error initializing Elasticsearch: {str(e)}") + + def reset(self): + try: + if self.app: + if self.app.indices.exists(index=self.index_name): + self.app.indices.delete(index=self.index_name) + + self.initialize_knowledge_storage() + except Exception as e: + raise Exception( + f"An error occurred while resetting the knowledge storage: {e}" + ) + + def save( + self, + documents: List[str], + metadata: Optional[Union[Dict[str, Any], List[Dict[str, Any]]]] = None, + ): + if not self.app: + self.initialize_knowledge_storage() + + try: + unique_docs = {} + + for idx, doc in enumerate(documents): + doc_id = hashlib.sha256(doc.encode("utf-8")).hexdigest() + doc_metadata = None + if metadata is not None: + if isinstance(metadata, list): + doc_metadata = metadata[idx] + else: + doc_metadata = metadata + unique_docs[doc_id] = (doc, doc_metadata) + + for doc_id, (doc, meta) in unique_docs.items(): + embedding = self._get_embedding_for_text(doc) + + doc_body = { + "text": doc, + "embedding": embedding, + "metadata": meta or {}, + } + + self.app.index( + index=self.index_name, + id=doc_id, + document=doc_body, + refresh=True # Make the document immediately available for search + ) + + except Exception as e: + Logger(verbose=True).log("error", f"Save error: {e}", "red") + raise Exception(f"Error during knowledge save: {str(e)}") + + def _get_embedding_for_text(self, text: str) -> List[float]: + """Get embedding for text using the configured embedder.""" + if hasattr(self.embedder_config, "embed_documents"): + return self.embedder_config.embed_documents([text])[0] + elif hasattr(self.embedder_config, "embed"): + return self.embedder_config.embed(text) + else: + raise ValueError("Invalid embedding function configuration") + + def _create_default_embedding_function(self): + from chromadb.utils.embedding_functions.openai_embedding_function import ( + OpenAIEmbeddingFunction, + ) + + return OpenAIEmbeddingFunction( + api_key=os.getenv("OPENAI_API_KEY"), model_name="text-embedding-3-small" + ) + + def _set_embedder_config( + self, embedder: Optional[Dict[str, Any]] = None + ) -> None: + """Set the embedding configuration for the knowledge storage. + + Args: + embedder (Optional[Dict[str, Any]]): Configuration dictionary for the embedder. + If None or empty, defaults to the default embedding function. 
+        """
+        self.embedder_config = (
+            EmbeddingConfigurator().configure_embedder(embedder)
+            if embedder
+            else self._create_default_embedding_function()
+        )
diff --git a/src/crewai/memory/entity/entity_memory.py b/src/crewai/memory/entity/entity_memory.py
index 264b64103..07678499f 100644
--- a/src/crewai/memory/entity/entity_memory.py
+++ b/src/crewai/memory/entity/entity_memory.py
@@ -22,7 +22,9 @@ class EntityMemory(Memory):
         else:
             memory_provider = None
 
-        if memory_provider == "mem0":
+        if storage:
+            pass
+        elif memory_provider == "mem0":
             try:
                 from crewai.memory.storage.mem0_storage import Mem0Storage
             except ImportError:
@@ -30,17 +32,27 @@ class EntityMemory(Memory):
                     "Mem0 is not installed. Please install it with `pip install mem0ai`."
                 )
             storage = Mem0Storage(type="entities", crew=crew)
-        else:
-            storage = (
-                storage
-                if storage
-                else RAGStorage(
-                    type="entities",
-                    allow_reset=True,
-                    embedder_config=embedder_config,
-                    crew=crew,
-                    path=path,
+        elif memory_provider == "elasticsearch":
+            try:
+                from crewai.memory.storage.elasticsearch_storage import ElasticsearchStorage
+            except ImportError:
+                raise ImportError(
+                    "Elasticsearch is not installed. Please install it with `pip install elasticsearch`."
                 )
+            storage = ElasticsearchStorage(
+                type="entities",
+                allow_reset=True,
+                embedder_config=embedder_config,
+                crew=crew,
+                path=path,
+            )
+        else:
+            storage = RAGStorage(
+                type="entities",
+                allow_reset=True,
+                embedder_config=embedder_config,
+                crew=crew,
+                path=path,
             )
         super().__init__(storage=storage)
 
diff --git a/src/crewai/memory/short_term/short_term_memory.py b/src/crewai/memory/short_term/short_term_memory.py
index b7581f400..b4ec90c0d 100644
--- a/src/crewai/memory/short_term/short_term_memory.py
+++ b/src/crewai/memory/short_term/short_term_memory.py
@@ -24,7 +24,9 @@ class ShortTermMemory(Memory):
         else:
             memory_provider = None
 
-        if memory_provider == "mem0":
+        if storage:
+            pass
+        elif memory_provider == "mem0":
             try:
                 from crewai.memory.storage.mem0_storage import Mem0Storage
             except ImportError:
@@ -32,16 +34,25 @@ class ShortTermMemory(Memory):
                     "Mem0 is not installed. Please install it with `pip install mem0ai`."
                 )
             storage = Mem0Storage(type="short_term", crew=crew)
-        else:
-            storage = (
-                storage
-                if storage
-                else RAGStorage(
-                    type="short_term",
-                    embedder_config=embedder_config,
-                    crew=crew,
-                    path=path,
+        elif memory_provider == "elasticsearch":
+            try:
+                from crewai.memory.storage.elasticsearch_storage import ElasticsearchStorage
+            except ImportError:
+                raise ImportError(
+                    "Elasticsearch is not installed. Please install it with `pip install elasticsearch`."
                 )
+            storage = ElasticsearchStorage(
+                type="short_term",
+                embedder_config=embedder_config,
+                crew=crew,
+                path=path,
+            )
+        else:
+            storage = RAGStorage(
+                type="short_term",
+                embedder_config=embedder_config,
+                crew=crew,
+                path=path,
             )
         super().__init__(storage=storage)
         self._memory_provider = memory_provider
diff --git a/src/crewai/memory/storage/elasticsearch_storage.py b/src/crewai/memory/storage/elasticsearch_storage.py
new file mode 100644
index 000000000..467de52a7
--- /dev/null
+++ b/src/crewai/memory/storage/elasticsearch_storage.py
@@ -0,0 +1,248 @@
+import contextlib
+import io
+import logging
+import os
+import uuid
+from typing import Any, Dict, List, Optional
+
+from crewai.memory.storage.base_rag_storage import BaseRAGStorage
+from crewai.utilities import EmbeddingConfigurator
+from crewai.utilities.constants import MAX_FILE_NAME_LENGTH
+from crewai.utilities.logger import Logger
+from crewai.utilities.paths import db_storage_path
+
+
+@contextlib.contextmanager
+def suppress_logging(logger_name="elasticsearch", level=logging.ERROR):
+    logger = logging.getLogger(logger_name)
+    original_level = logger.getEffectiveLevel()
+    logger.setLevel(level)
+    with (
+        contextlib.redirect_stdout(io.StringIO()),
+        contextlib.redirect_stderr(io.StringIO()),
+        contextlib.suppress(UserWarning),
+    ):
+        yield
+    logger.setLevel(original_level)
+
+
+class ElasticsearchStorage(BaseRAGStorage):
+    """
+    Extends BaseRAGStorage to use Elasticsearch for storing embeddings
+    and improving search efficiency.
+    """
+
+    app: Any | None = None
+
+    def __init__(
+        self,
+        type,
+        allow_reset=True,
+        embedder_config=None,
+        crew=None,
+        path=None,
+        host="localhost",
+        port=9200,
+        username=None,
+        password=None,
+        **kwargs
+    ):
+        super().__init__(type, allow_reset, embedder_config, crew)
+        agents = crew.agents if crew else []
+        agents = [self._sanitize_role(agent.role) for agent in agents]
+        agents = "_".join(agents)
+        self.agents = agents
+        self.storage_file_name = self._build_storage_file_name(type, agents)
+
+        self.type = type
+        self.allow_reset = allow_reset
+        self.path = path
+
+        self.host = host
+        self.port = port
+        self.username = username
+        self.password = password
+        self.index_name = f"crewai_{type}".lower()
+        self.additional_config = kwargs
+
+        self._initialize_app()
+
+    def _sanitize_role(self, role: str) -> str:
+        """
+        Sanitizes agent roles to ensure valid directory and index names.
+        """
+        return role.replace("\n", "").replace(" ", "_").replace("/", "_")
+
+    def _build_storage_file_name(self, type: str, file_name: str) -> str:
+        """
+        Ensures file name does not exceed max allowed by OS
+        """
+        base_path = f"{db_storage_path()}/{type}"
+
+        if len(file_name) > MAX_FILE_NAME_LENGTH:
+            logging.warning(
+                f"Trimming file name from {len(file_name)} to {MAX_FILE_NAME_LENGTH} characters."
+            )
+            file_name = file_name[:MAX_FILE_NAME_LENGTH]
+
+        return f"{base_path}/{file_name}"
+
+    def _set_embedder_config(self):
+        configurator = EmbeddingConfigurator()
+        self.embedder_config = configurator.configure_embedder(self.embedder_config)
+
+    def _initialize_app(self):
+        try:
+            from elasticsearch import Elasticsearch
+
+            self._set_embedder_config()
+
+            es_auth = {}
+            if self.username and self.password:
+                es_auth = {"basic_auth": (self.username, self.password)}
+
+            self.app = Elasticsearch(
+                [f"http://{self.host}:{self.port}"],
+                **es_auth,
+                **self.additional_config
+            )
+
+            if not self.app.indices.exists(index=self.index_name):
+                self.app.indices.create(
+                    index=self.index_name,
+                    body={
+                        "mappings": {
+                            "properties": {
+                                "text": {"type": "text"},
+                                "embedding": {
+                                    "type": "dense_vector",
+                                    "dims": 1536,  # Default for OpenAI embeddings
+                                    "index": True,
+                                    "similarity": "cosine"
+                                },
+                                "metadata": {"type": "object"}
+                            }
+                        }
+                    }
+                )
+
+        except ImportError:
+            raise ImportError(
+                "Elasticsearch is not installed. Please install it with `pip install elasticsearch`."
+            )
+        except Exception as e:
+            Logger(verbose=True).log(
+                "error",
+                f"Error initializing Elasticsearch: {str(e)}",
+                "red"
+            )
+            raise Exception(f"Error initializing Elasticsearch: {str(e)}")
+
+    def save(self, value: Any, metadata: Dict[str, Any]) -> None:
+        if not hasattr(self, "app"):
+            self._initialize_app()
+
+        try:
+            self._generate_embedding(value, metadata)
+        except Exception as e:
+            logging.error(f"Error during {self.type} save: {str(e)}")
+
+    def search(
+        self,
+        query: str,
+        limit: int = 3,
+        filter: Optional[dict] = None,
+        score_threshold: float = 0.35,
+    ) -> List[Any]:
+        if not hasattr(self, "app") or self.app is None:
+            self._initialize_app()
+
+        try:
+            embedding = self._get_embedding_for_text(query)
+
+            search_query = {
+                "size": limit,
+                "query": {
+                    "script_score": {
+                        "query": {"match_all": {}},
+                        "script": {
+                            "source": "cosineSimilarity(params.query_vector, 'embedding') + 1.0",
+                            "params": {"query_vector": embedding}
+                        }
+                    }
+                }
+            }
+
+            if filter:
+                for key, value in filter.items():
+                    search_query["query"]["script_score"]["query"] = {
+                        "bool": {
+                            "must": [
+                                search_query["query"]["script_score"]["query"],
+                                {"match": {f"metadata.{key}": value}}
+                            ]
+                        }
+                    }
+
+            with suppress_logging():
+                response = self.app.search(
+                    index=self.index_name,
+                    body=search_query
+                )
+
+            results = []
+            for hit in response["hits"]["hits"]:
+                adjusted_score = (hit["_score"] - 1.0)
+
+                if adjusted_score >= score_threshold:
+                    results.append({
+                        "id": hit["_id"],
+                        "metadata": hit["_source"]["metadata"],
+                        "context": hit["_source"]["text"],
+                        "score": adjusted_score,
+                    })
+
+            return results
+        except Exception as e:
+            logging.error(f"Error during {self.type} search: {str(e)}")
+            return []
+
+    def _get_embedding_for_text(self, text: str) -> List[float]:
+        """Get embedding for text using the configured embedder."""
+        if hasattr(self.embedder_config, "embed_documents"):
+            return self.embedder_config.embed_documents([text])[0]
+        elif hasattr(self.embedder_config, "embed"):
+            return self.embedder_config.embed(text)
+        else:
+            raise ValueError("Invalid embedding function configuration")
+
+    def _generate_embedding(self, text: str, metadata: Dict[str, Any]) -> None:
+        if not hasattr(self, "app") or self.app is None:
+            self._initialize_app()
+
+        embedding = self._get_embedding_for_text(text)
+
+        doc = {
+            "text": text,
+            "embedding": embedding,
+            "metadata": metadata or {},
+        }
+
+        self.app.index(
+            index=self.index_name,
+            id=str(uuid.uuid4()),
+            document=doc,
+            refresh=True  # Make the document immediately available for search
+        )
+
+    def reset(self) -> None:
+        try:
+            if self.app:
+                if self.app.indices.exists(index=self.index_name):
+                    self.app.indices.delete(index=self.index_name)
+
+            self._initialize_app()
+        except Exception as e:
+            raise Exception(
+                f"An error occurred while resetting the {self.type} memory: {e}"
+            )
diff --git a/src/crewai/memory/storage/storage_factory.py b/src/crewai/memory/storage/storage_factory.py
new file mode 100644
index 000000000..073ee4f20
--- /dev/null
+++ b/src/crewai/memory/storage/storage_factory.py
@@ -0,0 +1,75 @@
+from typing import Any, Dict, Optional, Type
+
+from crewai.memory.storage.base_rag_storage import BaseRAGStorage
+from crewai.memory.storage.rag_storage import RAGStorage
+from crewai.utilities.logger import Logger
+
+
+class StorageFactory:
+    """Factory for creating storage instances based on provider type."""
+
+    @classmethod
+    def create_storage(
+        cls,
+        provider: str,
+        type: str,
+        allow_reset: bool = True,
+        embedder_config: Optional[Any] = None,
+        crew: Any = None,
+        path: Optional[str] = None,
+        **kwargs,
+    ) -> BaseRAGStorage:
+        """Create a storage instance based on the provider type.
+
+        Args:
+            provider: Type of storage provider ("chromadb", "elasticsearch", "mem0").
+            type: Type of memory storage (e.g., "short_term", "entity").
+            allow_reset: Whether to allow resetting the storage.
+            embedder_config: Configuration for the embedder.
+            crew: Crew instance.
+            path: Path to the storage.
+            **kwargs: Additional arguments.
+
+        Returns:
+            Storage instance.
+        """
+        if provider == "elasticsearch":
+            try:
+                from crewai.memory.storage.elasticsearch_storage import ElasticsearchStorage
+                return ElasticsearchStorage(
+                    type=type,
+                    allow_reset=allow_reset,
+                    embedder_config=embedder_config,
+                    crew=crew,
+                    path=path,
+                    **kwargs,
+                )
+            except ImportError:
+                Logger(verbose=True).log(
+                    "error",
+                    "Elasticsearch is not installed. Please install it with `pip install elasticsearch`.",
+                    "red",
+                )
+                raise ImportError(
+                    "Elasticsearch is not installed. Please install it with `pip install elasticsearch`."
+                )
+        elif provider == "mem0":
+            try:
+                from crewai.memory.storage.mem0_storage import Mem0Storage
+                return Mem0Storage(type=type, crew=crew)
+            except ImportError:
+                Logger(verbose=True).log(
+                    "error",
+                    "Mem0 is not installed. Please install it with `pip install mem0ai`.",
+                    "red",
+                )
+                raise ImportError(
+                    "Mem0 is not installed. Please install it with `pip install mem0ai`."
+                )
+        return RAGStorage(
+            type=type,
+            allow_reset=allow_reset,
+            embedder_config=embedder_config,
+            crew=crew,
+            path=path,
+        )
diff --git a/tests/integration/elasticsearch_integration_test.py b/tests/integration/elasticsearch_integration_test.py
new file mode 100644
index 000000000..1b480d231
--- /dev/null
+++ b/tests/integration/elasticsearch_integration_test.py
@@ -0,0 +1,91 @@
+"""Integration test for Elasticsearch with CrewAI."""
+
+import os
+import unittest
+
+import pytest
+
+from crewai import Agent, Crew, Task
+
+
+@pytest.mark.skipif(
+    os.environ.get("RUN_ELASTICSEARCH_TESTS") != "true",
+    reason="Elasticsearch tests require RUN_ELASTICSEARCH_TESTS=true"
+)
+class TestElasticsearchIntegration(unittest.TestCase):
+    """Integration test for Elasticsearch with CrewAI."""
+
+    def test_crew_with_elasticsearch_memory(self):
+        """Test a crew with Elasticsearch memory."""
+        researcher = Agent(
+            role="Researcher",
+            goal="Research a topic",
+            backstory="You are a researcher who loves to find information.",
+        )
+
+        writer = Agent(
+            role="Writer",
+            goal="Write about a topic",
+            backstory="You are a writer who loves to write about topics.",
+        )
+
+        research_task = Task(
+            description="Research about AI",
+            expected_output="Information about AI",
+            agent=researcher,
+        )
+
+        write_task = Task(
+            description="Write about AI",
+            expected_output="Article about AI",
+            agent=writer,
+            context=[research_task],
+        )
+
+        crew = Crew(
+            agents=[researcher, writer],
+            tasks=[research_task, write_task],
+            memory_config={"provider": "elasticsearch"},
+        )
+
+        result = crew.kickoff()
+
+        self.assertIsNotNone(result)
+
+    def test_crew_with_elasticsearch_knowledge(self):
+        """Test a crew with Elasticsearch knowledge."""
+        from crewai.knowledge import Knowledge
+        from crewai.knowledge.source.string_knowledge_source import StringKnowledgeSource
+
+        content = "AI is a field of computer science that focuses on creating machines that can perform tasks that typically require human intelligence."
+        string_source = StringKnowledgeSource(
+            content=content, metadata={"topic": "AI"}
+        )
+
+        knowledge = Knowledge(
+            collection_name="test",
+            sources=[string_source],
+            storage_provider="elasticsearch",
+        )
+
+        agent = Agent(
+            role="AI Expert",
+            goal="Explain AI",
+            backstory="You are an AI expert who loves to explain AI concepts.",
+            knowledge=[knowledge],
+        )
+
+        task = Task(
+            description="Explain what AI is",
+            expected_output="Explanation of AI",
+            agent=agent,
+        )
+
+        crew = Crew(
+            agents=[agent],
+            tasks=[task],
+        )
+
+        result = crew.kickoff()
+
+        self.assertIsNotNone(result)
diff --git a/tests/knowledge/elasticsearch_knowledge_storage_test.py b/tests/knowledge/elasticsearch_knowledge_storage_test.py
new file mode 100644
index 000000000..f890390da
--- /dev/null
+++ b/tests/knowledge/elasticsearch_knowledge_storage_test.py
@@ -0,0 +1,92 @@
+"""Test Elasticsearch knowledge storage functionality."""
+
+import os
+import unittest
+from unittest.mock import MagicMock, patch
+
+import pytest
+
+from crewai.knowledge.storage.elasticsearch_knowledge_storage import ElasticsearchKnowledgeStorage
+
+
+@pytest.mark.skipif(
+    os.environ.get("RUN_ELASTICSEARCH_TESTS") != "true",
+    reason="Elasticsearch tests require RUN_ELASTICSEARCH_TESTS=true"
+)
+class TestElasticsearchKnowledgeStorage(unittest.TestCase):
+    """Test Elasticsearch knowledge storage functionality."""
+
+    def setUp(self):
+        """Set up test fixtures."""
+        self.es_mock = MagicMock()
+        self.es_mock.indices.exists.return_value = False
+
+        self.embedder_mock = MagicMock()
+        self.embedder_mock.embed_documents.return_value = [[0.1, 0.2, 0.3]]
+
+        self.es_patcher = patch(
+            "crewai.knowledge.storage.elasticsearch_knowledge_storage.Elasticsearch",
+            return_value=self.es_mock
+        )
+        self.es_class_mock = self.es_patcher.start()
+
+        self.storage = ElasticsearchKnowledgeStorage(
+            embedder_config=self.embedder_mock,
+            collection_name="test"
+        )
+        self.storage.initialize_knowledge_storage()
+
+    def tearDown(self):
+        """Tear down test fixtures."""
+        self.es_patcher.stop()
+
+    def test_initialization(self):
+        """Test initialization of Elasticsearch knowledge storage."""
+        self.es_class_mock.assert_called_once()
+
+        self.es_mock.indices.create.assert_called_once()
+
+    def test_save(self):
+        """Test saving to Elasticsearch knowledge storage."""
+        self.storage.save(["Test document 1", "Test document 2"], {"source": "test"})
+
+        self.assertEqual(self.es_mock.index.call_count, 2)
+
+        self.assertEqual(self.embedder_mock.embed_documents.call_count, 2)
+
+    def test_search(self):
+        """Test searching in Elasticsearch knowledge storage."""
+        self.es_mock.search.return_value = {
+            "hits": {
+                "hits": [
+                    {
+                        "_id": "test_id",
+                        "_score": 1.5,  # Score between 1-2 (Elasticsearch range)
+                        "_source": {
+                            "text": "Test document",
+                            "metadata": {"source": "test"},
+                        }
+                    }
+                ]
+            }
+        }
+
+        results = self.storage.search(["test query"])
+
+        self.es_mock.search.assert_called_once()
+
+        self.assertEqual(len(results), 1)
+        self.assertEqual(results[0]["id"], "test_id")
+        self.assertEqual(results[0]["context"], "Test document")
+        self.assertEqual(results[0]["metadata"], {"source": "test"})
+        self.assertEqual(results[0]["score"], 0.5)  # Adjusted to 0-1 range
+
+    def test_reset(self):
+        """Test resetting Elasticsearch knowledge storage."""
+        self.es_mock.indices.exists.return_value = True
+
+        self.storage.reset()
+
+        self.es_mock.indices.delete.assert_called_once()
+
+        self.assertEqual(self.es_mock.indices.create.call_count, 2)
diff --git a/tests/memory/elasticsearch_storage_test.py b/tests/memory/elasticsearch_storage_test.py
new file mode 100644
index 000000000..2f4fae391
--- /dev/null
+++ b/tests/memory/elasticsearch_storage_test.py
@@ -0,0 +1,91 @@
+"""Test Elasticsearch storage functionality."""
+
+import os
+import unittest
+from unittest.mock import MagicMock, patch
+
+import pytest
+
+from crewai.memory.storage.elasticsearch_storage import ElasticsearchStorage
+
+
+@pytest.mark.skipif(
+    os.environ.get("RUN_ELASTICSEARCH_TESTS") != "true",
+    reason="Elasticsearch tests require RUN_ELASTICSEARCH_TESTS=true"
+)
+class TestElasticsearchStorage(unittest.TestCase):
+    """Test Elasticsearch storage functionality."""
+
+    def setUp(self):
+        """Set up test fixtures."""
+        self.es_mock = MagicMock()
+        self.es_mock.indices.exists.return_value = False
+
+        self.embedder_mock = MagicMock()
+        self.embedder_mock.embed_documents.return_value = [[0.1, 0.2, 0.3]]
+
+        self.es_patcher = patch(
+            "crewai.memory.storage.elasticsearch_storage.Elasticsearch",
+            return_value=self.es_mock
+        )
+        self.es_class_mock = self.es_patcher.start()
+
+        self.storage = ElasticsearchStorage(
+            type="test",
+            embedder_config=self.embedder_mock
+        )
+
+    def tearDown(self):
+        """Tear down test fixtures."""
+        self.es_patcher.stop()
+
+    def test_initialization(self):
+        """Test initialization of Elasticsearch storage."""
+        self.es_class_mock.assert_called_once()
+
+        self.es_mock.indices.create.assert_called_once()
+
+    def test_save(self):
+        """Test saving to Elasticsearch storage."""
+        self.storage.save("Test document", {"source": "test"})
+
+        self.es_mock.index.assert_called_once()
+
+        self.embedder_mock.embed_documents.assert_called_once_with(["Test document"])
+
+    def test_search(self):
+        """Test searching in Elasticsearch storage."""
+        self.es_mock.search.return_value = {
+            "hits": {
+                "hits": [
+                    {
+                        "_id": "test_id",
+                        "_score": 1.5,  # Score between 1-2 (Elasticsearch range)
+                        "_source": {
+                            "text": "Test document",
+                            "metadata": {"source": "test"},
+                        }
+                    }
+                ]
+            }
+        }
+
+        results = self.storage.search("test query")
+
+        self.es_mock.search.assert_called_once()
+
+        self.assertEqual(len(results), 1)
+        self.assertEqual(results[0]["id"], "test_id")
+        self.assertEqual(results[0]["context"], "Test document")
+        self.assertEqual(results[0]["metadata"], {"source": "test"})
+        self.assertEqual(results[0]["score"], 0.5)  # Adjusted to 0-1 range
+
+    def test_reset(self):
+        """Test resetting Elasticsearch storage."""
+        self.es_mock.indices.exists.return_value = True
+
+        self.storage.reset()
+
+        self.es_mock.indices.delete.assert_called_once()
+
+        self.assertEqual(self.es_mock.indices.create.call_count, 2)