Fix type-checking errors in Elasticsearch integration

Co-Authored-By: Joe Moura <joao@crewai.com>
Author: Devin AI
Date: 2025-04-23 05:44:47 +00:00
Parent: 3c838f16ff
Commit: 958751fe36
4 changed files with 145 additions and 106 deletions

View File

@@ -1,5 +1,5 @@
 import os
-from typing import Any, Dict, List, Optional
+from typing import Any, Dict, List, Optional, cast

 from pydantic import BaseModel, ConfigDict, Field
@@ -40,9 +40,9 @@ class Knowledge(BaseModel):
         if storage_provider == "elasticsearch":
             try:
                 from crewai.knowledge.storage.elasticsearch_knowledge_storage import ElasticsearchKnowledgeStorage
-                self.storage = ElasticsearchKnowledgeStorage(
-                    embedder=embedder, collection_name=collection_name
-                )
+                self.storage = cast(KnowledgeStorage, ElasticsearchKnowledgeStorage(
+                    embedder_config=embedder, collection_name=collection_name
+                ))
             except ImportError:
                 raise ImportError(
                     "Elasticsearch is not installed. Please install it with `pip install elasticsearch`."
@@ -52,7 +52,8 @@ class Knowledge(BaseModel):
                 embedder=embedder, collection_name=collection_name
             )
         self.sources = sources
-        self.storage.initialize_knowledge_storage()
+        if self.storage is not None:
+            self.storage.initialize_knowledge_storage()
         self._add_sources()

     def query(
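
For reference, a minimal sketch of the two patterns introduced in this file: cast() to the declared KnowledgeStorage field type, and an is-not-None guard before calling into the storage. The class bodies below are placeholder stand-ins for illustration, not the real crewai definitions.

from typing import Optional, cast


class KnowledgeStorage:  # placeholder stand-in for the crewai class
    def initialize_knowledge_storage(self) -> None:
        print("initialized")


class ElasticsearchKnowledgeStorage(KnowledgeStorage):  # placeholder stand-in
    def __init__(self, embedder_config: Optional[dict] = None, collection_name: Optional[str] = None) -> None:
        self.embedder_config = embedder_config
        self.collection_name = collection_name


storage: Optional[KnowledgeStorage] = None
# cast() only informs the type checker; it performs no runtime conversion or validation.
storage = cast(KnowledgeStorage, ElasticsearchKnowledgeStorage(collection_name="knowledge"))
if storage is not None:  # narrows Optional[KnowledgeStorage] for the type checker
    storage.initialize_knowledge_storage()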

View File

@@ -31,21 +31,21 @@ class ElasticsearchKnowledgeStorage(BaseKnowledgeStorage):
     and improving search efficiency.
     """

-    app = None
+    app: Any = None
     collection_name: Optional[str] = "knowledge"

     def __init__(
         self,
-        embedder: Optional[Dict[str, Any]] = None,
+        embedder_config: Optional[Dict[str, Any]] = None,
         collection_name: Optional[str] = None,
-        host="localhost",
-        port=9200,
-        username=None,
-        password=None,
-        **kwargs
+        host: str = "localhost",
+        port: int = 9200,
+        username: Optional[str] = None,
+        password: Optional[str] = None,
+        **kwargs: Any
     ):
         self.collection_name = collection_name
-        self._set_embedder_config(embedder)
+        self._set_embedder_config(embedder_config)
         self.host = host
         self.port = port
@@ -67,7 +67,7 @@ class ElasticsearchKnowledgeStorage(BaseKnowledgeStorage):
         try:
             embedding = self._get_embedding_for_text(query[0])

-            search_query = {
+            search_query: Dict[str, Any] = {
                 "size": limit,
                 "query": {
                     "script_score": {
@@ -81,35 +81,45 @@ class ElasticsearchKnowledgeStorage(BaseKnowledgeStorage):
             }

             if filter:
-                for key, value in filter.items():
-                    search_query["query"]["script_score"]["query"] = {
-                        "bool": {
-                            "must": [
-                                search_query["query"]["script_score"]["query"],
-                                {"match": {f"metadata.{key}": value}}
-                            ]
-                        }
-                    }
+                query_obj = search_query.get("query", {})
+                if isinstance(query_obj, dict):
+                    script_score_obj = query_obj.get("script_score", {})
+                    if isinstance(script_score_obj, dict):
+                        query_part = script_score_obj.get("query", {})
+                        if isinstance(query_part, dict):
+                            for key, value in filter.items():
+                                script_score_obj["query"] = {
+                                    "bool": {
+                                        "must": [
+                                            query_part,
+                                            {"match": {f"metadata.{key}": value}}
+                                        ]
+                                    }
+                                }

             with suppress_logging():
-                response = self.app.search(
-                    index=self.index_name,
-                    body=search_query
-                )
-
-            results = []
-            for hit in response["hits"]["hits"]:
-                adjusted_score = (hit["_score"] - 1.0)
-                if adjusted_score >= score_threshold:
-                    results.append({
-                        "id": hit["_id"],
-                        "metadata": hit["_source"]["metadata"],
-                        "context": hit["_source"]["text"],
-                        "score": adjusted_score,
-                    })
-
-            return results
+                if self.app is not None and hasattr(self.app, "search") and callable(getattr(self.app, "search")):
+                    response = self.app.search(
+                        index=self.index_name,
+                        body=search_query
+                    )
+
+                    results = []
+                    for hit in response["hits"]["hits"]:
+                        adjusted_score = (hit["_score"] - 1.0)
+
+                        if adjusted_score >= score_threshold:
+                            results.append({
+                                "id": hit["_id"],
+                                "metadata": hit["_source"]["metadata"],
+                                "context": hit["_source"]["text"],
+                                "score": adjusted_score,
+                            })
+
+                    return results
+                else:
+                    Logger(verbose=True).log("error", "Elasticsearch client is not initialized", "red")
+                    return []
         except Exception as e:
             Logger(verbose=True).log("error", f"Search error: {e}", "red")
             raise Exception(f"Error during knowledge search: {str(e)}")
@@ -159,9 +169,9 @@ class ElasticsearchKnowledgeStorage(BaseKnowledgeStorage):
             )
             raise Exception(f"Error initializing Elasticsearch: {str(e)}")

-    def reset(self):
+    def reset(self) -> None:
         try:
-            if self.app:
+            if self.app is not None:
                 if self.app.indices.exists(index=self.index_name):
                     self.app.indices.delete(index=self.index_name)
@@ -175,7 +185,7 @@ class ElasticsearchKnowledgeStorage(BaseKnowledgeStorage):
         self,
         documents: List[str],
         metadata: Optional[Union[Dict[str, Any], List[Dict[str, Any]]]] = None,
-    ):
+    ) -> None:
         if not self.app:
             self.initialize_knowledge_storage()
@@ -201,12 +211,15 @@ class ElasticsearchKnowledgeStorage(BaseKnowledgeStorage):
                     "metadata": meta or {},
                 }

-                self.app.index(
-                    index=self.index_name,
-                    id=doc_id,
-                    document=doc_body,
-                    refresh=True  # Make the document immediately available for search
-                )
+                if self.app is not None and hasattr(self.app, "index") and callable(getattr(self.app, "index")):
+                    self.app.index(
+                        index=self.index_name,
+                        id=doc_id,
+                        document=doc_body,
+                        refresh=True  # Make the document immediately available for search
+                    )
+                else:
+                    Logger(verbose=True).log("error", "Elasticsearch client is not initialized", "red")
         except Exception as e:
             Logger(verbose=True).log("error", f"Save error: {e}", "red")
@@ -214,10 +227,14 @@ class ElasticsearchKnowledgeStorage(BaseKnowledgeStorage):
     def _get_embedding_for_text(self, text: str) -> List[float]:
         """Get embedding for text using the configured embedder."""
-        if hasattr(self.embedder_config, "embed_documents"):
-            return self.embedder_config.embed_documents([text])[0]
-        elif hasattr(self.embedder_config, "embed"):
-            return self.embedder_config.embed(text)
+        if self.embedder_config is None:
+            raise ValueError("Embedder configuration is not set")
+
+        embedder = self.embedder_config
+        if hasattr(embedder, "embed_documents") and callable(getattr(embedder, "embed_documents")):
+            return embedder.embed_documents([text])[0]
+        elif hasattr(embedder, "embed") and callable(getattr(embedder, "embed")):
+            return embedder.embed(text)
         else:
             raise ValueError("Invalid embedding function configuration")

View File

@@ -3,7 +3,7 @@ import io
 import logging
 import os
 import uuid
-from typing import Any, Dict, List, Optional
+from typing import Any, Dict, List, Optional, cast

 from crewai.memory.storage.base_rag_storage import BaseRAGStorage
 from crewai.utilities import EmbeddingConfigurator
@@ -32,20 +32,20 @@ class ElasticsearchStorage(BaseRAGStorage):
     and improving search efficiency.
     """

-    app: Any | None = None
+    app: Any = None

     def __init__(
         self,
-        type,
-        allow_reset=True,
-        embedder_config=None,
-        crew=None,
-        path=None,
-        host="localhost",
-        port=9200,
-        username=None,
-        password=None,
-        **kwargs
+        type: str,
+        allow_reset: bool = True,
+        embedder_config: Any = None,
+        crew: Any = None,
+        path: Optional[str] = None,
+        host: str = "localhost",
+        port: int = 9200,
+        username: Optional[str] = None,
+        password: Optional[str] = None,
+        **kwargs: Any
     ):
         super().__init__(type, allow_reset, embedder_config, crew)
         agents = crew.agents if crew else []
@@ -160,7 +160,7 @@ class ElasticsearchStorage(BaseRAGStorage):
         try:
             embedding = self._get_embedding_for_text(query)

-            search_query = {
+            search_query: Dict[str, Any] = {
                 "size": limit,
                 "query": {
                     "script_score": {
@@ -174,49 +174,67 @@ class ElasticsearchStorage(BaseRAGStorage):
             }

             if filter:
-                for key, value in filter.items():
-                    search_query["query"]["script_score"]["query"] = {
-                        "bool": {
-                            "must": [
-                                search_query["query"]["script_score"]["query"],
-                                {"match": {f"metadata.{key}": value}}
-                            ]
-                        }
-                    }
+                query_obj = search_query.get("query", {})
+                if isinstance(query_obj, dict):
+                    script_score_obj = query_obj.get("script_score", {})
+                    if isinstance(script_score_obj, dict):
+                        query_part = script_score_obj.get("query", {})
+                        if isinstance(query_part, dict):
+                            for key, value in filter.items():
+                                script_score_obj["query"] = {
+                                    "bool": {
+                                        "must": [
+                                            query_part,
+                                            {"match": {f"metadata.{key}": value}}
+                                        ]
+                                    }
+                                }

             with suppress_logging():
-                response = self.app.search(
-                    index=self.index_name,
-                    body=search_query
-                )
-
-            results = []
-            for hit in response["hits"]["hits"]:
-                adjusted_score = (hit["_score"] - 1.0)
-                if adjusted_score >= score_threshold:
-                    results.append({
-                        "id": hit["_id"],
-                        "metadata": hit["_source"]["metadata"],
-                        "context": hit["_source"]["text"],
-                        "score": adjusted_score,
-                    })
-
-            return results
+                if self.app is not None:
+                    response = self.app.search(
+                        index=self.index_name,
+                        body=search_query
+                    )
+
+                    results = []
+                    for hit in response["hits"]["hits"]:
+                        adjusted_score = (hit["_score"] - 1.0)
+                        if adjusted_score >= score_threshold:
+                            results.append({
+                                "id": hit["_id"],
+                                "metadata": hit["_source"]["metadata"],
+                                "context": hit["_source"]["text"],
+                                "score": adjusted_score,
+                            })
+
+                    return results
+                else:
+                    logging.error("Elasticsearch client is not initialized")
+                    return []
         except Exception as e:
             logging.error(f"Error during {self.type} search: {str(e)}")
             return []

     def _get_embedding_for_text(self, text: str) -> List[float]:
         """Get embedding for text using the configured embedder."""
-        if hasattr(self.embedder_config, "embed_documents"):
-            return self.embedder_config.embed_documents([text])[0]
-        elif hasattr(self.embedder_config, "embed"):
-            return self.embedder_config.embed(text)
+        if self.embedder_config is None:
+            raise ValueError("Embedder configuration is not set")
+
+        embedder = self.embedder_config
+        if hasattr(embedder, "embed_documents") and callable(getattr(embedder, "embed_documents")):
+            return embedder.embed_documents([text])[0]
+        elif hasattr(embedder, "embed") and callable(getattr(embedder, "embed")):
+            return embedder.embed(text)
         else:
             raise ValueError("Invalid embedding function configuration")

-    def _generate_embedding(self, text: str, metadata: Dict[str, Any]) -> None:
+    def _generate_embedding(self, text: str, metadata: Optional[Dict[str, Any]] = None) -> Any:
+        """Generate embedding for text and save to Elasticsearch.
+        This method overrides the BaseRAGStorage method to use Elasticsearch.
+        """
         if not hasattr(self, "app") or self.app is None:
             self._initialize_app()
@@ -228,16 +246,19 @@ class ElasticsearchStorage(BaseRAGStorage):
                 "metadata": metadata or {},
             }

-            self.app.index(
-                index=self.index_name,
-                id=str(uuid.uuid4()),
-                document=doc,
-                refresh=True  # Make the document immediately available for search
-            )
+            if self.app is not None and hasattr(self.app, "index") and callable(getattr(self.app, "index")):
+                result = self.app.index(
+                    index=self.index_name,
+                    id=str(uuid.uuid4()),
+                    document=doc,
+                    refresh=True  # Make the document immediately available for search
+                )
+                return result
+            return None

     def reset(self) -> None:
         try:
-            if self.app:
+            if self.app is not None:
                 if self.app.indices.exists(index=self.index_name):
                     self.app.indices.delete(index=self.index_name)
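
Both storage classes now guard the embedder the same way before dispatching to it. A self-contained sketch of that dispatch, with a dummy embedder standing in for the configured one:

from typing import Any, List


def embed_text(embedder: Any, text: str) -> List[float]:
    # Same guard order as the diff: missing config first, then only call
    # methods that exist and are callable on the configured object.
    if embedder is None:
        raise ValueError("Embedder configuration is not set")
    if hasattr(embedder, "embed_documents") and callable(getattr(embedder, "embed_documents")):
        return embedder.embed_documents([text])[0]
    if hasattr(embedder, "embed") and callable(getattr(embedder, "embed")):
        return embedder.embed(text)
    raise ValueError("Invalid embedding function configuration")


class DummyEmbedder:
    def embed_documents(self, texts: List[str]) -> List[List[float]]:
        return [[0.0] * 3 for _ in texts]


print(embed_text(DummyEmbedder(), "hello"))  # -> [0.0, 0.0, 0.0]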

View File

@@ -1,4 +1,4 @@
-from typing import Any, Dict, Optional, Type
+from typing import Any, Dict, Optional, Type, cast

 from crewai.memory.storage.base_rag_storage import BaseRAGStorage
 from crewai.memory.storage.rag_storage import RAGStorage
@@ -56,7 +56,7 @@ class StorageFactory:
         elif provider == "mem0":
             try:
                 from crewai.memory.storage.mem0_storage import Mem0Storage
-                return Mem0Storage(type=type, crew=crew)
+                return cast(BaseRAGStorage, Mem0Storage(type=type, crew=crew))
             except ImportError:
                 Logger(verbose=True).log(
                     "error",