mirror of
https://github.com/crewAIInc/crewAI.git
synced 2026-01-12 01:28:30 +00:00
Fix type-checking errors in Elasticsearch integration
Co-Authored-By: Joe Moura <joao@crewai.com>
This commit is contained in:
@@ -1,5 +1,5 @@
|
||||
import os
|
||||
from typing import Any, Dict, List, Optional
|
||||
from typing import Any, Dict, List, Optional, cast
|
||||
|
||||
from pydantic import BaseModel, ConfigDict, Field
|
||||
|
||||
@@ -40,9 +40,9 @@ class Knowledge(BaseModel):
|
||||
if storage_provider == "elasticsearch":
|
||||
try:
|
||||
from crewai.knowledge.storage.elasticsearch_knowledge_storage import ElasticsearchKnowledgeStorage
|
||||
self.storage = ElasticsearchKnowledgeStorage(
|
||||
embedder=embedder, collection_name=collection_name
|
||||
)
|
||||
self.storage = cast(KnowledgeStorage, ElasticsearchKnowledgeStorage(
|
||||
embedder_config=embedder, collection_name=collection_name
|
||||
))
|
||||
except ImportError:
|
||||
raise ImportError(
|
||||
"Elasticsearch is not installed. Please install it with `pip install elasticsearch`."
|
||||
@@ -52,7 +52,8 @@ class Knowledge(BaseModel):
|
||||
embedder=embedder, collection_name=collection_name
|
||||
)
|
||||
self.sources = sources
|
||||
self.storage.initialize_knowledge_storage()
|
||||
if self.storage is not None:
|
||||
self.storage.initialize_knowledge_storage()
|
||||
self._add_sources()
|
||||
|
||||
def query(
|
||||
|
||||
@@ -31,21 +31,21 @@ class ElasticsearchKnowledgeStorage(BaseKnowledgeStorage):
|
||||
and improving search efficiency.
|
||||
"""
|
||||
|
||||
app = None
|
||||
app: Any = None
|
||||
collection_name: Optional[str] = "knowledge"
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
embedder: Optional[Dict[str, Any]] = None,
|
||||
embedder_config: Optional[Dict[str, Any]] = None,
|
||||
collection_name: Optional[str] = None,
|
||||
host="localhost",
|
||||
port=9200,
|
||||
username=None,
|
||||
password=None,
|
||||
**kwargs
|
||||
host: str = "localhost",
|
||||
port: int = 9200,
|
||||
username: Optional[str] = None,
|
||||
password: Optional[str] = None,
|
||||
**kwargs: Any
|
||||
):
|
||||
self.collection_name = collection_name
|
||||
self._set_embedder_config(embedder)
|
||||
self._set_embedder_config(embedder_config)
|
||||
|
||||
self.host = host
|
||||
self.port = port
|
||||
@@ -67,7 +67,7 @@ class ElasticsearchKnowledgeStorage(BaseKnowledgeStorage):
|
||||
try:
|
||||
embedding = self._get_embedding_for_text(query[0])
|
||||
|
||||
search_query = {
|
||||
search_query: Dict[str, Any] = {
|
||||
"size": limit,
|
||||
"query": {
|
||||
"script_score": {
|
||||
@@ -81,35 +81,45 @@ class ElasticsearchKnowledgeStorage(BaseKnowledgeStorage):
|
||||
}
|
||||
|
||||
if filter:
|
||||
for key, value in filter.items():
|
||||
search_query["query"]["script_score"]["query"] = {
|
||||
"bool": {
|
||||
"must": [
|
||||
search_query["query"]["script_score"]["query"],
|
||||
{"match": {f"metadata.{key}": value}}
|
||||
]
|
||||
}
|
||||
}
|
||||
query_obj = search_query.get("query", {})
|
||||
if isinstance(query_obj, dict):
|
||||
script_score_obj = query_obj.get("script_score", {})
|
||||
if isinstance(script_score_obj, dict):
|
||||
query_part = script_score_obj.get("query", {})
|
||||
if isinstance(query_part, dict):
|
||||
for key, value in filter.items():
|
||||
script_score_obj["query"] = {
|
||||
"bool": {
|
||||
"must": [
|
||||
query_part,
|
||||
{"match": {f"metadata.{key}": value}}
|
||||
]
|
||||
}
|
||||
}
|
||||
|
||||
with suppress_logging():
|
||||
response = self.app.search(
|
||||
index=self.index_name,
|
||||
body=search_query
|
||||
)
|
||||
|
||||
results = []
|
||||
for hit in response["hits"]["hits"]:
|
||||
adjusted_score = (hit["_score"] - 1.0)
|
||||
if self.app is not None and hasattr(self.app, "search") and callable(getattr(self.app, "search")):
|
||||
response = self.app.search(
|
||||
index=self.index_name,
|
||||
body=search_query
|
||||
)
|
||||
|
||||
if adjusted_score >= score_threshold:
|
||||
results.append({
|
||||
"id": hit["_id"],
|
||||
"metadata": hit["_source"]["metadata"],
|
||||
"context": hit["_source"]["text"],
|
||||
"score": adjusted_score,
|
||||
})
|
||||
|
||||
return results
|
||||
results = []
|
||||
for hit in response["hits"]["hits"]:
|
||||
adjusted_score = (hit["_score"] - 1.0)
|
||||
|
||||
if adjusted_score >= score_threshold:
|
||||
results.append({
|
||||
"id": hit["_id"],
|
||||
"metadata": hit["_source"]["metadata"],
|
||||
"context": hit["_source"]["text"],
|
||||
"score": adjusted_score,
|
||||
})
|
||||
|
||||
return results
|
||||
else:
|
||||
Logger(verbose=True).log("error", "Elasticsearch client is not initialized", "red")
|
||||
return []
|
||||
except Exception as e:
|
||||
Logger(verbose=True).log("error", f"Search error: {e}", "red")
|
||||
raise Exception(f"Error during knowledge search: {str(e)}")
|
||||
@@ -159,9 +169,9 @@ class ElasticsearchKnowledgeStorage(BaseKnowledgeStorage):
|
||||
)
|
||||
raise Exception(f"Error initializing Elasticsearch: {str(e)}")
|
||||
|
||||
def reset(self):
|
||||
def reset(self) -> None:
|
||||
try:
|
||||
if self.app:
|
||||
if self.app is not None:
|
||||
if self.app.indices.exists(index=self.index_name):
|
||||
self.app.indices.delete(index=self.index_name)
|
||||
|
||||
@@ -175,7 +185,7 @@ class ElasticsearchKnowledgeStorage(BaseKnowledgeStorage):
|
||||
self,
|
||||
documents: List[str],
|
||||
metadata: Optional[Union[Dict[str, Any], List[Dict[str, Any]]]] = None,
|
||||
):
|
||||
) -> None:
|
||||
if not self.app:
|
||||
self.initialize_knowledge_storage()
|
||||
|
||||
@@ -201,12 +211,15 @@ class ElasticsearchKnowledgeStorage(BaseKnowledgeStorage):
|
||||
"metadata": meta or {},
|
||||
}
|
||||
|
||||
self.app.index(
|
||||
index=self.index_name,
|
||||
id=doc_id,
|
||||
document=doc_body,
|
||||
refresh=True # Make the document immediately available for search
|
||||
)
|
||||
if self.app is not None and hasattr(self.app, "index") and callable(getattr(self.app, "index")):
|
||||
self.app.index(
|
||||
index=self.index_name,
|
||||
id=doc_id,
|
||||
document=doc_body,
|
||||
refresh=True # Make the document immediately available for search
|
||||
)
|
||||
else:
|
||||
Logger(verbose=True).log("error", "Elasticsearch client is not initialized", "red")
|
||||
|
||||
except Exception as e:
|
||||
Logger(verbose=True).log("error", f"Save error: {e}", "red")
|
||||
@@ -214,10 +227,14 @@ class ElasticsearchKnowledgeStorage(BaseKnowledgeStorage):
|
||||
|
||||
def _get_embedding_for_text(self, text: str) -> List[float]:
|
||||
"""Get embedding for text using the configured embedder."""
|
||||
if hasattr(self.embedder_config, "embed_documents"):
|
||||
return self.embedder_config.embed_documents([text])[0]
|
||||
elif hasattr(self.embedder_config, "embed"):
|
||||
return self.embedder_config.embed(text)
|
||||
if self.embedder_config is None:
|
||||
raise ValueError("Embedder configuration is not set")
|
||||
|
||||
embedder = self.embedder_config
|
||||
if hasattr(embedder, "embed_documents") and callable(getattr(embedder, "embed_documents")):
|
||||
return embedder.embed_documents([text])[0]
|
||||
elif hasattr(embedder, "embed") and callable(getattr(embedder, "embed")):
|
||||
return embedder.embed(text)
|
||||
else:
|
||||
raise ValueError("Invalid embedding function configuration")
|
||||
|
||||
|
||||
@@ -3,7 +3,7 @@ import io
|
||||
import logging
|
||||
import os
|
||||
import uuid
|
||||
from typing import Any, Dict, List, Optional
|
||||
from typing import Any, Dict, List, Optional, cast
|
||||
|
||||
from crewai.memory.storage.base_rag_storage import BaseRAGStorage
|
||||
from crewai.utilities import EmbeddingConfigurator
|
||||
@@ -32,20 +32,20 @@ class ElasticsearchStorage(BaseRAGStorage):
|
||||
and improving search efficiency.
|
||||
"""
|
||||
|
||||
app: Any | None = None
|
||||
app: Any = None
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
type,
|
||||
allow_reset=True,
|
||||
embedder_config=None,
|
||||
crew=None,
|
||||
path=None,
|
||||
host="localhost",
|
||||
port=9200,
|
||||
username=None,
|
||||
password=None,
|
||||
**kwargs
|
||||
type: str,
|
||||
allow_reset: bool = True,
|
||||
embedder_config: Any = None,
|
||||
crew: Any = None,
|
||||
path: Optional[str] = None,
|
||||
host: str = "localhost",
|
||||
port: int = 9200,
|
||||
username: Optional[str] = None,
|
||||
password: Optional[str] = None,
|
||||
**kwargs: Any
|
||||
):
|
||||
super().__init__(type, allow_reset, embedder_config, crew)
|
||||
agents = crew.agents if crew else []
|
||||
@@ -160,7 +160,7 @@ class ElasticsearchStorage(BaseRAGStorage):
|
||||
try:
|
||||
embedding = self._get_embedding_for_text(query)
|
||||
|
||||
search_query = {
|
||||
search_query: Dict[str, Any] = {
|
||||
"size": limit,
|
||||
"query": {
|
||||
"script_score": {
|
||||
@@ -174,49 +174,67 @@ class ElasticsearchStorage(BaseRAGStorage):
|
||||
}
|
||||
|
||||
if filter:
|
||||
for key, value in filter.items():
|
||||
search_query["query"]["script_score"]["query"] = {
|
||||
"bool": {
|
||||
"must": [
|
||||
search_query["query"]["script_score"]["query"],
|
||||
{"match": {f"metadata.{key}": value}}
|
||||
]
|
||||
}
|
||||
}
|
||||
query_obj = search_query.get("query", {})
|
||||
if isinstance(query_obj, dict):
|
||||
script_score_obj = query_obj.get("script_score", {})
|
||||
if isinstance(script_score_obj, dict):
|
||||
query_part = script_score_obj.get("query", {})
|
||||
if isinstance(query_part, dict):
|
||||
for key, value in filter.items():
|
||||
script_score_obj["query"] = {
|
||||
"bool": {
|
||||
"must": [
|
||||
query_part,
|
||||
{"match": {f"metadata.{key}": value}}
|
||||
]
|
||||
}
|
||||
}
|
||||
|
||||
with suppress_logging():
|
||||
response = self.app.search(
|
||||
index=self.index_name,
|
||||
body=search_query
|
||||
)
|
||||
|
||||
results = []
|
||||
for hit in response["hits"]["hits"]:
|
||||
adjusted_score = (hit["_score"] - 1.0)
|
||||
|
||||
if adjusted_score >= score_threshold:
|
||||
results.append({
|
||||
"id": hit["_id"],
|
||||
"metadata": hit["_source"]["metadata"],
|
||||
"context": hit["_source"]["text"],
|
||||
"score": adjusted_score,
|
||||
})
|
||||
if self.app is not None:
|
||||
response = self.app.search(
|
||||
index=self.index_name,
|
||||
body=search_query
|
||||
)
|
||||
|
||||
return results
|
||||
results = []
|
||||
for hit in response["hits"]["hits"]:
|
||||
adjusted_score = (hit["_score"] - 1.0)
|
||||
|
||||
if adjusted_score >= score_threshold:
|
||||
results.append({
|
||||
"id": hit["_id"],
|
||||
"metadata": hit["_source"]["metadata"],
|
||||
"context": hit["_source"]["text"],
|
||||
"score": adjusted_score,
|
||||
})
|
||||
|
||||
return results
|
||||
else:
|
||||
logging.error("Elasticsearch client is not initialized")
|
||||
return []
|
||||
except Exception as e:
|
||||
logging.error(f"Error during {self.type} search: {str(e)}")
|
||||
return []
|
||||
|
||||
def _get_embedding_for_text(self, text: str) -> List[float]:
|
||||
"""Get embedding for text using the configured embedder."""
|
||||
if hasattr(self.embedder_config, "embed_documents"):
|
||||
return self.embedder_config.embed_documents([text])[0]
|
||||
elif hasattr(self.embedder_config, "embed"):
|
||||
return self.embedder_config.embed(text)
|
||||
if self.embedder_config is None:
|
||||
raise ValueError("Embedder configuration is not set")
|
||||
|
||||
embedder = self.embedder_config
|
||||
if hasattr(embedder, "embed_documents") and callable(getattr(embedder, "embed_documents")):
|
||||
return embedder.embed_documents([text])[0]
|
||||
elif hasattr(embedder, "embed") and callable(getattr(embedder, "embed")):
|
||||
return embedder.embed(text)
|
||||
else:
|
||||
raise ValueError("Invalid embedding function configuration")
|
||||
|
||||
def _generate_embedding(self, text: str, metadata: Dict[str, Any]) -> None:
|
||||
def _generate_embedding(self, text: str, metadata: Optional[Dict[str, Any]] = None) -> Any:
|
||||
"""Generate embedding for text and save to Elasticsearch.
|
||||
|
||||
This method overrides the BaseRAGStorage method to use Elasticsearch.
|
||||
"""
|
||||
if not hasattr(self, "app") or self.app is None:
|
||||
self._initialize_app()
|
||||
|
||||
@@ -228,16 +246,19 @@ class ElasticsearchStorage(BaseRAGStorage):
|
||||
"metadata": metadata or {},
|
||||
}
|
||||
|
||||
self.app.index(
|
||||
index=self.index_name,
|
||||
id=str(uuid.uuid4()),
|
||||
document=doc,
|
||||
refresh=True # Make the document immediately available for search
|
||||
)
|
||||
if self.app is not None and hasattr(self.app, "index") and callable(getattr(self.app, "index")):
|
||||
result = self.app.index(
|
||||
index=self.index_name,
|
||||
id=str(uuid.uuid4()),
|
||||
document=doc,
|
||||
refresh=True # Make the document immediately available for search
|
||||
)
|
||||
return result
|
||||
return None
|
||||
|
||||
def reset(self) -> None:
|
||||
try:
|
||||
if self.app:
|
||||
if self.app is not None:
|
||||
if self.app.indices.exists(index=self.index_name):
|
||||
self.app.indices.delete(index=self.index_name)
|
||||
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
from typing import Any, Dict, Optional, Type
|
||||
from typing import Any, Dict, Optional, Type, cast
|
||||
|
||||
from crewai.memory.storage.base_rag_storage import BaseRAGStorage
|
||||
from crewai.memory.storage.rag_storage import RAGStorage
|
||||
@@ -56,7 +56,7 @@ class StorageFactory:
|
||||
elif provider == "mem0":
|
||||
try:
|
||||
from crewai.memory.storage.mem0_storage import Mem0Storage
|
||||
return Mem0Storage(type=type, crew=crew)
|
||||
return cast(BaseRAGStorage, Mem0Storage(type=type, crew=crew))
|
||||
except ImportError:
|
||||
Logger(verbose=True).log(
|
||||
"error",
|
||||
|
||||
Reference in New Issue
Block a user