Fix type-checking errors in Elasticsearch integration

Devin AI  2025-04-23 05:44:47 +00:00
Co-Authored-By: Joe Moura <joao@crewai.com>
parent 3c838f16ff
commit 958751fe36
4 changed files with 145 additions and 106 deletions
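The diffs below repeat one mypy-friendly pattern: annotate the lazily created Elasticsearch client so it may be None, guard every use of it, and cast() optional-dependency backends to the base class the caller expects. A minimal standalone sketch of that pattern, using placeholder names (BaseStorage, ESStorage, build_storage) rather than the repository's real classes:

    from typing import Any, Optional, cast


    class BaseStorage:
        def search(self, query: str) -> list:
            raise NotImplementedError


    class ESStorage(BaseStorage):
        # The real client is created lazily, so the attribute must tolerate None.
        app: Any = None

        def search(self, query: str) -> list:
            # Guard: only call the client once it has actually been initialized.
            if self.app is not None and callable(getattr(self.app, "search", None)):
                return list(self.app.search(query))
            return []


    def build_storage(provider: str) -> Optional[BaseStorage]:
        if provider == "elasticsearch":
            # cast() tells the type checker this optional backend satisfies the base type.
            return cast(BaseStorage, ESStorage())
        return None

The same three moves appear in the Knowledge class, both Elasticsearch storage backends, and the storage factory below.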

View File

@@ -1,5 +1,5 @@
 import os
-from typing import Any, Dict, List, Optional
+from typing import Any, Dict, List, Optional, cast
 from pydantic import BaseModel, ConfigDict, Field
@@ -40,9 +40,9 @@ class Knowledge(BaseModel):
         if storage_provider == "elasticsearch":
             try:
                 from crewai.knowledge.storage.elasticsearch_knowledge_storage import ElasticsearchKnowledgeStorage
-                self.storage = ElasticsearchKnowledgeStorage(
-                    embedder=embedder, collection_name=collection_name
-                )
+                self.storage = cast(KnowledgeStorage, ElasticsearchKnowledgeStorage(
+                    embedder_config=embedder, collection_name=collection_name
+                ))
             except ImportError:
                 raise ImportError(
                     "Elasticsearch is not installed. Please install it with `pip install elasticsearch`."
@@ -52,7 +52,8 @@ class Knowledge(BaseModel):
                 embedder=embedder, collection_name=collection_name
             )
         self.sources = sources
-        self.storage.initialize_knowledge_storage()
+        if self.storage is not None:
+            self.storage.initialize_knowledge_storage()
         self._add_sources()

     def query(

View File

@@ -31,21 +31,21 @@ class ElasticsearchKnowledgeStorage(BaseKnowledgeStorage):
     and improving search efficiency.
     """
-    app = None
+    app: Any = None
     collection_name: Optional[str] = "knowledge"

     def __init__(
         self,
-        embedder: Optional[Dict[str, Any]] = None,
+        embedder_config: Optional[Dict[str, Any]] = None,
         collection_name: Optional[str] = None,
-        host="localhost",
-        port=9200,
-        username=None,
-        password=None,
-        **kwargs
+        host: str = "localhost",
+        port: int = 9200,
+        username: Optional[str] = None,
+        password: Optional[str] = None,
+        **kwargs: Any
     ):
         self.collection_name = collection_name
-        self._set_embedder_config(embedder)
+        self._set_embedder_config(embedder_config)
         self.host = host
         self.port = port
@@ -67,7 +67,7 @@ class ElasticsearchKnowledgeStorage(BaseKnowledgeStorage):
         try:
             embedding = self._get_embedding_for_text(query[0])

-            search_query = {
+            search_query: Dict[str, Any] = {
                 "size": limit,
                 "query": {
                     "script_score": {
@@ -81,35 +81,45 @@ class ElasticsearchKnowledgeStorage(BaseKnowledgeStorage):
             }

             if filter:
-                for key, value in filter.items():
-                    search_query["query"]["script_score"]["query"] = {
-                        "bool": {
-                            "must": [
-                                search_query["query"]["script_score"]["query"],
-                                {"match": {f"metadata.{key}": value}}
-                            ]
-                        }
-                    }
+                query_obj = search_query.get("query", {})
+                if isinstance(query_obj, dict):
+                    script_score_obj = query_obj.get("script_score", {})
+                    if isinstance(script_score_obj, dict):
+                        query_part = script_score_obj.get("query", {})
+                        if isinstance(query_part, dict):
+                            for key, value in filter.items():
+                                script_score_obj["query"] = {
+                                    "bool": {
+                                        "must": [
+                                            query_part,
+                                            {"match": {f"metadata.{key}": value}}
+                                        ]
+                                    }
+                                }

             with suppress_logging():
-                response = self.app.search(
-                    index=self.index_name,
-                    body=search_query
-                )
-                results = []
-                for hit in response["hits"]["hits"]:
-                    adjusted_score = (hit["_score"] - 1.0)
-                    if adjusted_score >= score_threshold:
-                        results.append({
-                            "id": hit["_id"],
-                            "metadata": hit["_source"]["metadata"],
-                            "context": hit["_source"]["text"],
-                            "score": adjusted_score,
-                        })
-                return results
+                if self.app is not None and hasattr(self.app, "search") and callable(getattr(self.app, "search")):
+                    response = self.app.search(
+                        index=self.index_name,
+                        body=search_query
+                    )
+                    results = []
+                    for hit in response["hits"]["hits"]:
+                        adjusted_score = (hit["_score"] - 1.0)
+                        if adjusted_score >= score_threshold:
+                            results.append({
+                                "id": hit["_id"],
+                                "metadata": hit["_source"]["metadata"],
+                                "context": hit["_source"]["text"],
+                                "score": adjusted_score,
+                            })
+                    return results
+                else:
+                    Logger(verbose=True).log("error", "Elasticsearch client is not initialized", "red")
+                    return []
         except Exception as e:
             Logger(verbose=True).log("error", f"Search error: {e}", "red")
             raise Exception(f"Error during knowledge search: {str(e)}")
@@ -159,9 +169,9 @@ class ElasticsearchKnowledgeStorage(BaseKnowledgeStorage):
             )
             raise Exception(f"Error initializing Elasticsearch: {str(e)}")

-    def reset(self):
+    def reset(self) -> None:
         try:
-            if self.app:
+            if self.app is not None:
                 if self.app.indices.exists(index=self.index_name):
                     self.app.indices.delete(index=self.index_name)
@@ -175,7 +185,7 @@ class ElasticsearchKnowledgeStorage(BaseKnowledgeStorage):
         self,
         documents: List[str],
         metadata: Optional[Union[Dict[str, Any], List[Dict[str, Any]]]] = None,
-    ):
+    ) -> None:
         if not self.app:
             self.initialize_knowledge_storage()
@@ -201,12 +211,15 @@ class ElasticsearchKnowledgeStorage(BaseKnowledgeStorage):
                     "metadata": meta or {},
                 }

-                self.app.index(
-                    index=self.index_name,
-                    id=doc_id,
-                    document=doc_body,
-                    refresh=True # Make the document immediately available for search
-                )
+                if self.app is not None and hasattr(self.app, "index") and callable(getattr(self.app, "index")):
+                    self.app.index(
+                        index=self.index_name,
+                        id=doc_id,
+                        document=doc_body,
+                        refresh=True # Make the document immediately available for search
+                    )
+                else:
+                    Logger(verbose=True).log("error", "Elasticsearch client is not initialized", "red")
         except Exception as e:
             Logger(verbose=True).log("error", f"Save error: {e}", "red")
@@ -214,10 +227,14 @@ class ElasticsearchKnowledgeStorage(BaseKnowledgeStorage):
     def _get_embedding_for_text(self, text: str) -> List[float]:
         """Get embedding for text using the configured embedder."""
-        if hasattr(self.embedder_config, "embed_documents"):
-            return self.embedder_config.embed_documents([text])[0]
-        elif hasattr(self.embedder_config, "embed"):
-            return self.embedder_config.embed(text)
+        if self.embedder_config is None:
+            raise ValueError("Embedder configuration is not set")
+        embedder = self.embedder_config
+        if hasattr(embedder, "embed_documents") and callable(getattr(embedder, "embed_documents")):
+            return embedder.embed_documents([text])[0]
+        elif hasattr(embedder, "embed") and callable(getattr(embedder, "embed")):
+            return embedder.embed(text)
         else:
             raise ValueError("Invalid embedding function configuration")

View File

@@ -3,7 +3,7 @@ import io
 import logging
 import os
 import uuid
-from typing import Any, Dict, List, Optional
+from typing import Any, Dict, List, Optional, cast

 from crewai.memory.storage.base_rag_storage import BaseRAGStorage
 from crewai.utilities import EmbeddingConfigurator
@@ -32,20 +32,20 @@ class ElasticsearchStorage(BaseRAGStorage):
     and improving search efficiency.
     """
-    app: Any | None = None
+    app: Any = None

     def __init__(
         self,
-        type,
-        allow_reset=True,
-        embedder_config=None,
-        crew=None,
-        path=None,
-        host="localhost",
-        port=9200,
-        username=None,
-        password=None,
-        **kwargs
+        type: str,
+        allow_reset: bool = True,
+        embedder_config: Any = None,
+        crew: Any = None,
+        path: Optional[str] = None,
+        host: str = "localhost",
+        port: int = 9200,
+        username: Optional[str] = None,
+        password: Optional[str] = None,
+        **kwargs: Any
     ):
         super().__init__(type, allow_reset, embedder_config, crew)
         agents = crew.agents if crew else []
@@ -160,7 +160,7 @@ class ElasticsearchStorage(BaseRAGStorage):
         try:
             embedding = self._get_embedding_for_text(query)

-            search_query = {
+            search_query: Dict[str, Any] = {
                 "size": limit,
                 "query": {
                     "script_score": {
@@ -174,49 +174,67 @@ class ElasticsearchStorage(BaseRAGStorage):
             }

             if filter:
-                for key, value in filter.items():
-                    search_query["query"]["script_score"]["query"] = {
-                        "bool": {
-                            "must": [
-                                search_query["query"]["script_score"]["query"],
-                                {"match": {f"metadata.{key}": value}}
-                            ]
-                        }
-                    }
+                query_obj = search_query.get("query", {})
+                if isinstance(query_obj, dict):
+                    script_score_obj = query_obj.get("script_score", {})
+                    if isinstance(script_score_obj, dict):
+                        query_part = script_score_obj.get("query", {})
+                        if isinstance(query_part, dict):
+                            for key, value in filter.items():
+                                script_score_obj["query"] = {
+                                    "bool": {
+                                        "must": [
+                                            query_part,
+                                            {"match": {f"metadata.{key}": value}}
+                                        ]
+                                    }
+                                }

             with suppress_logging():
-                response = self.app.search(
-                    index=self.index_name,
-                    body=search_query
-                )
-                results = []
-                for hit in response["hits"]["hits"]:
-                    adjusted_score = (hit["_score"] - 1.0)
-                    if adjusted_score >= score_threshold:
-                        results.append({
-                            "id": hit["_id"],
-                            "metadata": hit["_source"]["metadata"],
-                            "context": hit["_source"]["text"],
-                            "score": adjusted_score,
-                        })
-                return results
+                if self.app is not None:
+                    response = self.app.search(
+                        index=self.index_name,
+                        body=search_query
+                    )
+                    results = []
+                    for hit in response["hits"]["hits"]:
+                        adjusted_score = (hit["_score"] - 1.0)
+                        if adjusted_score >= score_threshold:
+                            results.append({
+                                "id": hit["_id"],
+                                "metadata": hit["_source"]["metadata"],
+                                "context": hit["_source"]["text"],
+                                "score": adjusted_score,
+                            })
+                    return results
+                else:
+                    logging.error("Elasticsearch client is not initialized")
+                    return []
         except Exception as e:
             logging.error(f"Error during {self.type} search: {str(e)}")
             return []

     def _get_embedding_for_text(self, text: str) -> List[float]:
         """Get embedding for text using the configured embedder."""
-        if hasattr(self.embedder_config, "embed_documents"):
-            return self.embedder_config.embed_documents([text])[0]
-        elif hasattr(self.embedder_config, "embed"):
-            return self.embedder_config.embed(text)
+        if self.embedder_config is None:
+            raise ValueError("Embedder configuration is not set")
+        embedder = self.embedder_config
+        if hasattr(embedder, "embed_documents") and callable(getattr(embedder, "embed_documents")):
+            return embedder.embed_documents([text])[0]
+        elif hasattr(embedder, "embed") and callable(getattr(embedder, "embed")):
+            return embedder.embed(text)
         else:
             raise ValueError("Invalid embedding function configuration")

-    def _generate_embedding(self, text: str, metadata: Dict[str, Any]) -> None:
+    def _generate_embedding(self, text: str, metadata: Optional[Dict[str, Any]] = None) -> Any:
         """Generate embedding for text and save to Elasticsearch.

         This method overrides the BaseRAGStorage method to use Elasticsearch.
         """
         if not hasattr(self, "app") or self.app is None:
             self._initialize_app()
@@ -228,16 +246,19 @@ class ElasticsearchStorage(BaseRAGStorage):
             "metadata": metadata or {},
         }

-        self.app.index(
-            index=self.index_name,
-            id=str(uuid.uuid4()),
-            document=doc,
-            refresh=True # Make the document immediately available for search
-        )
+        if self.app is not None and hasattr(self.app, "index") and callable(getattr(self.app, "index")):
+            result = self.app.index(
+                index=self.index_name,
+                id=str(uuid.uuid4()),
+                document=doc,
+                refresh=True # Make the document immediately available for search
+            )
+            return result
+        return None

     def reset(self) -> None:
         try:
-            if self.app:
+            if self.app is not None:
                 if self.app.indices.exists(index=self.index_name):
                     self.app.indices.delete(index=self.index_name)

View File

@@ -1,4 +1,4 @@
-from typing import Any, Dict, Optional, Type
+from typing import Any, Dict, Optional, Type, cast

 from crewai.memory.storage.base_rag_storage import BaseRAGStorage
 from crewai.memory.storage.rag_storage import RAGStorage
@@ -56,7 +56,7 @@ class StorageFactory:
         elif provider == "mem0":
             try:
                 from crewai.memory.storage.mem0_storage import Mem0Storage
-                return Mem0Storage(type=type, crew=crew)
+                return cast(BaseRAGStorage, Mem0Storage(type=type, crew=crew))
             except ImportError:
                 Logger(verbose=True).log(
                     "error",