diff --git a/src/crewai/knowledge/knowledge.py b/src/crewai/knowledge/knowledge.py
index d0f8514d4..b27d3d212 100644
--- a/src/crewai/knowledge/knowledge.py
+++ b/src/crewai/knowledge/knowledge.py
@@ -1,5 +1,5 @@
 import os
-from typing import Any, Dict, List, Optional
+from typing import Any, Dict, List, Optional, cast
 
 from pydantic import BaseModel, ConfigDict, Field
 
@@ -40,9 +40,9 @@ class Knowledge(BaseModel):
         if storage_provider == "elasticsearch":
             try:
                 from crewai.knowledge.storage.elasticsearch_knowledge_storage import ElasticsearchKnowledgeStorage
-                self.storage = ElasticsearchKnowledgeStorage(
-                    embedder=embedder, collection_name=collection_name
-                )
+                self.storage = cast(KnowledgeStorage, ElasticsearchKnowledgeStorage(
+                    embedder_config=embedder, collection_name=collection_name
+                ))
             except ImportError:
                 raise ImportError(
                     "Elasticsearch is not installed. Please install it with `pip install elasticsearch`."
@@ -52,7 +52,8 @@ class Knowledge(BaseModel):
                 embedder=embedder, collection_name=collection_name
             )
         self.sources = sources
-        self.storage.initialize_knowledge_storage()
+        if self.storage is not None:
+            self.storage.initialize_knowledge_storage()
         self._add_sources()
 
     def query(
diff --git a/src/crewai/knowledge/storage/elasticsearch_knowledge_storage.py b/src/crewai/knowledge/storage/elasticsearch_knowledge_storage.py
index 2544eeeea..c4b33636c 100644
--- a/src/crewai/knowledge/storage/elasticsearch_knowledge_storage.py
+++ b/src/crewai/knowledge/storage/elasticsearch_knowledge_storage.py
@@ -31,21 +31,21 @@ class ElasticsearchKnowledgeStorage(BaseKnowledgeStorage):
     and improving search efficiency.
     """
 
-    app = None
+    app: Any = None
     collection_name: Optional[str] = "knowledge"
 
     def __init__(
         self,
-        embedder: Optional[Dict[str, Any]] = None,
+        embedder_config: Optional[Dict[str, Any]] = None,
         collection_name: Optional[str] = None,
-        host="localhost",
-        port=9200,
-        username=None,
-        password=None,
-        **kwargs
+        host: str = "localhost",
+        port: int = 9200,
+        username: Optional[str] = None,
+        password: Optional[str] = None,
+        **kwargs: Any
     ):
         self.collection_name = collection_name
-        self._set_embedder_config(embedder)
+        self._set_embedder_config(embedder_config)
         self.host = host
         self.port = port
@@ -67,7 +67,7 @@ class ElasticsearchKnowledgeStorage(BaseKnowledgeStorage):
         try:
             embedding = self._get_embedding_for_text(query[0])
 
-            search_query = {
+            search_query: Dict[str, Any] = {
                 "size": limit,
                 "query": {
                     "script_score": {
@@ -81,35 +81,45 @@ class ElasticsearchKnowledgeStorage(BaseKnowledgeStorage):
             }
 
             if filter:
-                for key, value in filter.items():
-                    search_query["query"]["script_score"]["query"] = {
-                        "bool": {
-                            "must": [
-                                search_query["query"]["script_score"]["query"],
-                                {"match": {f"metadata.{key}": value}}
-                            ]
-                        }
-                    }
+                query_obj = search_query.get("query", {})
+                if isinstance(query_obj, dict):
+                    script_score_obj = query_obj.get("script_score", {})
+                    if isinstance(script_score_obj, dict):
+                        query_part = script_score_obj.get("query", {})
+                        if isinstance(query_part, dict):
+                            for key, value in filter.items():
+                                script_score_obj["query"] = {
+                                    "bool": {
+                                        "must": [
+                                            query_part,
+                                            {"match": {f"metadata.{key}": value}}
+                                        ]
+                                    }
+                                }
 
             with suppress_logging():
-                response = self.app.search(
-                    index=self.index_name,
-                    body=search_query
-                )
-
-            results = []
-            for hit in response["hits"]["hits"]:
-                adjusted_score = (hit["_score"] - 1.0)
+                if self.app is not None and hasattr(self.app, "search") and callable(getattr(self.app, "search")):
+                    response = self.app.search(
+                        index=self.index_name,
+                        body=search_query
+                    )
 
-                if adjusted_score >= score_threshold:
-                    results.append({
-                        "id": hit["_id"],
-                        "metadata": hit["_source"]["metadata"],
-                        "context": hit["_source"]["text"],
-                        "score": adjusted_score,
-                    })
-
-            return results
+                    results = []
+                    for hit in response["hits"]["hits"]:
+                        adjusted_score = (hit["_score"] - 1.0)
+
+                        if adjusted_score >= score_threshold:
+                            results.append({
+                                "id": hit["_id"],
+                                "metadata": hit["_source"]["metadata"],
+                                "context": hit["_source"]["text"],
+                                "score": adjusted_score,
+                            })
+
+                    return results
+                else:
+                    Logger(verbose=True).log("error", "Elasticsearch client is not initialized", "red")
+                    return []
         except Exception as e:
             Logger(verbose=True).log("error", f"Search error: {e}", "red")
             raise Exception(f"Error during knowledge search: {str(e)}")
@@ -159,9 +169,9 @@ class ElasticsearchKnowledgeStorage(BaseKnowledgeStorage):
             )
             raise Exception(f"Error initializing Elasticsearch: {str(e)}")
 
-    def reset(self):
+    def reset(self) -> None:
         try:
-            if self.app:
+            if self.app is not None:
                 if self.app.indices.exists(index=self.index_name):
                     self.app.indices.delete(index=self.index_name)
 
@@ -175,7 +185,7 @@ class ElasticsearchKnowledgeStorage(BaseKnowledgeStorage):
         self,
         documents: List[str],
         metadata: Optional[Union[Dict[str, Any], List[Dict[str, Any]]]] = None,
-    ):
+    ) -> None:
         if not self.app:
             self.initialize_knowledge_storage()
 
@@ -201,12 +211,15 @@ class ElasticsearchKnowledgeStorage(BaseKnowledgeStorage):
                     "metadata": meta or {},
                 }
 
-                self.app.index(
-                    index=self.index_name,
-                    id=doc_id,
-                    document=doc_body,
-                    refresh=True  # Make the document immediately available for search
-                )
+                if self.app is not None and hasattr(self.app, "index") and callable(getattr(self.app, "index")):
+                    self.app.index(
+                        index=self.index_name,
+                        id=doc_id,
+                        document=doc_body,
+                        refresh=True  # Make the document immediately available for search
+                    )
+                else:
+                    Logger(verbose=True).log("error", "Elasticsearch client is not initialized", "red")
 
         except Exception as e:
             Logger(verbose=True).log("error", f"Save error: {e}", "red")
@@ -214,10 +227,14 @@ class ElasticsearchKnowledgeStorage(BaseKnowledgeStorage):
 
     def _get_embedding_for_text(self, text: str) -> List[float]:
         """Get embedding for text using the configured embedder."""
-        if hasattr(self.embedder_config, "embed_documents"):
-            return self.embedder_config.embed_documents([text])[0]
-        elif hasattr(self.embedder_config, "embed"):
-            return self.embedder_config.embed(text)
+        if self.embedder_config is None:
+            raise ValueError("Embedder configuration is not set")
+
+        embedder = self.embedder_config
+        if hasattr(embedder, "embed_documents") and callable(getattr(embedder, "embed_documents")):
+            return embedder.embed_documents([text])[0]
+        elif hasattr(embedder, "embed") and callable(getattr(embedder, "embed")):
+            return embedder.embed(text)
         else:
             raise ValueError("Invalid embedding function configuration")
diff --git a/src/crewai/memory/storage/elasticsearch_storage.py b/src/crewai/memory/storage/elasticsearch_storage.py
index 467de52a7..25b493622 100644
--- a/src/crewai/memory/storage/elasticsearch_storage.py
+++ b/src/crewai/memory/storage/elasticsearch_storage.py
@@ -3,7 +3,7 @@ import io
 import logging
 import os
 import uuid
-from typing import Any, Dict, List, Optional
+from typing import Any, Dict, List, Optional, cast
 
 from crewai.memory.storage.base_rag_storage import BaseRAGStorage
 from crewai.utilities import EmbeddingConfigurator
@@ -32,20 +32,20 @@ class ElasticsearchStorage(BaseRAGStorage):
     and improving search efficiency.
     """
 
-    app: Any | None = None
+    app: Any = None
 
     def __init__(
         self,
-        type,
-        allow_reset=True,
-        embedder_config=None,
-        crew=None,
-        path=None,
-        host="localhost",
-        port=9200,
-        username=None,
-        password=None,
-        **kwargs
+        type: str,
+        allow_reset: bool = True,
+        embedder_config: Any = None,
+        crew: Any = None,
+        path: Optional[str] = None,
+        host: str = "localhost",
+        port: int = 9200,
+        username: Optional[str] = None,
+        password: Optional[str] = None,
+        **kwargs: Any
     ):
         super().__init__(type, allow_reset, embedder_config, crew)
         agents = crew.agents if crew else []
@@ -160,7 +160,7 @@ class ElasticsearchStorage(BaseRAGStorage):
         try:
             embedding = self._get_embedding_for_text(query)
 
-            search_query = {
+            search_query: Dict[str, Any] = {
                 "size": limit,
                 "query": {
                     "script_score": {
@@ -174,49 +174,67 @@ class ElasticsearchStorage(BaseRAGStorage):
             }
 
             if filter:
-                for key, value in filter.items():
-                    search_query["query"]["script_score"]["query"] = {
-                        "bool": {
-                            "must": [
-                                search_query["query"]["script_score"]["query"],
-                                {"match": {f"metadata.{key}": value}}
-                            ]
-                        }
-                    }
+                query_obj = search_query.get("query", {})
+                if isinstance(query_obj, dict):
+                    script_score_obj = query_obj.get("script_score", {})
+                    if isinstance(script_score_obj, dict):
+                        query_part = script_score_obj.get("query", {})
+                        if isinstance(query_part, dict):
+                            for key, value in filter.items():
+                                script_score_obj["query"] = {
+                                    "bool": {
+                                        "must": [
+                                            query_part,
+                                            {"match": {f"metadata.{key}": value}}
+                                        ]
+                                    }
+                                }
 
             with suppress_logging():
-                response = self.app.search(
-                    index=self.index_name,
-                    body=search_query
-                )
-
-            results = []
-            for hit in response["hits"]["hits"]:
-                adjusted_score = (hit["_score"] - 1.0)
-
-                if adjusted_score >= score_threshold:
-                    results.append({
-                        "id": hit["_id"],
-                        "metadata": hit["_source"]["metadata"],
-                        "context": hit["_source"]["text"],
-                        "score": adjusted_score,
-                    })
+                if self.app is not None:
+                    response = self.app.search(
+                        index=self.index_name,
+                        body=search_query
+                    )
 
-            return results
+                    results = []
+                    for hit in response["hits"]["hits"]:
+                        adjusted_score = (hit["_score"] - 1.0)
+
+                        if adjusted_score >= score_threshold:
+                            results.append({
+                                "id": hit["_id"],
+                                "metadata": hit["_source"]["metadata"],
+                                "context": hit["_source"]["text"],
+                                "score": adjusted_score,
+                            })
+
+                    return results
+                else:
+                    logging.error("Elasticsearch client is not initialized")
+                    return []
         except Exception as e:
             logging.error(f"Error during {self.type} search: {str(e)}")
             return []
 
     def _get_embedding_for_text(self, text: str) -> List[float]:
         """Get embedding for text using the configured embedder."""
-        if hasattr(self.embedder_config, "embed_documents"):
-            return self.embedder_config.embed_documents([text])[0]
-        elif hasattr(self.embedder_config, "embed"):
-            return self.embedder_config.embed(text)
+        if self.embedder_config is None:
+            raise ValueError("Embedder configuration is not set")
+
+        embedder = self.embedder_config
+        if hasattr(embedder, "embed_documents") and callable(getattr(embedder, "embed_documents")):
+            return embedder.embed_documents([text])[0]
+        elif hasattr(embedder, "embed") and callable(getattr(embedder, "embed")):
+            return embedder.embed(text)
         else:
             raise ValueError("Invalid embedding function configuration")
 
-    def _generate_embedding(self, text: str, metadata: Dict[str, Any]) -> None:
+    def _generate_embedding(self, text: str, metadata: Optional[Dict[str, Any]] = None) -> Any:
+        """Generate embedding for text and save to Elasticsearch.
+
+        This method overrides the BaseRAGStorage method to use Elasticsearch.
+        """
         if not hasattr(self, "app") or self.app is None:
             self._initialize_app()
 
@@ -228,16 +246,19 @@ class ElasticsearchStorage(BaseRAGStorage):
             "metadata": metadata or {},
         }
 
-        self.app.index(
-            index=self.index_name,
-            id=str(uuid.uuid4()),
-            document=doc,
-            refresh=True  # Make the document immediately available for search
-        )
+        if self.app is not None and hasattr(self.app, "index") and callable(getattr(self.app, "index")):
+            result = self.app.index(
+                index=self.index_name,
+                id=str(uuid.uuid4()),
+                document=doc,
+                refresh=True  # Make the document immediately available for search
+            )
+            return result
+        return None
 
     def reset(self) -> None:
         try:
-            if self.app:
+            if self.app is not None:
                 if self.app.indices.exists(index=self.index_name):
                     self.app.indices.delete(index=self.index_name)
 
diff --git a/src/crewai/memory/storage/storage_factory.py b/src/crewai/memory/storage/storage_factory.py
index 073ee4f20..0ee73f0ec 100644
--- a/src/crewai/memory/storage/storage_factory.py
+++ b/src/crewai/memory/storage/storage_factory.py
@@ -1,4 +1,4 @@
-from typing import Any, Dict, Optional, Type
+from typing import Any, Dict, Optional, Type, cast
 
 from crewai.memory.storage.base_rag_storage import BaseRAGStorage
 from crewai.memory.storage.rag_storage import RAGStorage
@@ -56,7 +56,7 @@ class StorageFactory:
         elif provider == "mem0":
             try:
                 from crewai.memory.storage.mem0_storage import Mem0Storage
-                return Mem0Storage(type=type, crew=crew)
+                return cast(BaseRAGStorage, Mem0Storage(type=type, crew=crew))
             except ImportError:
                 Logger(verbose=True).log(
                     "error",
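Usage sketch (not part of the patch): the snippet below shows one way the Elasticsearch-backed knowledge storage introduced above might be exercised end to end. It assumes a reachable Elasticsearch node on localhost:9200 and an embedder configuration accepted by _set_embedder_config; the concrete embedder dict is a placeholder, and the search/save keyword arguments mirror the hunks in this diff rather than any separately documented API.

# Sketch only: assumes a local Elasticsearch node and a valid embedder config;
# the embedder dict below is a placeholder, not a documented default.
from crewai.knowledge.storage.elasticsearch_knowledge_storage import (
    ElasticsearchKnowledgeStorage,
)

storage = ElasticsearchKnowledgeStorage(
    embedder_config={"provider": "openai"},  # placeholder embedder configuration
    collection_name="knowledge",
    host="localhost",
    port=9200,
)
storage.initialize_knowledge_storage()

# save() takes a list of documents plus optional metadata (see @@ -175,7 +185,7 @@ above).
storage.save(
    ["Elasticsearch can back CrewAI knowledge storage."],
    metadata={"source": "notes"},
)

# search() takes a list of query strings; each result carries "context", "metadata", and "score".
for hit in storage.search(["which backends store knowledge?"], limit=3):
    print(hit["score"], hit["context"])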