mirror of
https://github.com/crewAIInc/crewAI.git
synced 2026-01-23 07:08:14 +00:00
Fix type-checking errors in Elasticsearch integration
Co-Authored-By: Joe Moura <joao@crewai.com>
This commit is contained in:
@@ -1,5 +1,5 @@
|
|||||||
import os
|
import os
|
||||||
from typing import Any, Dict, List, Optional
|
from typing import Any, Dict, List, Optional, cast
|
||||||
|
|
||||||
from pydantic import BaseModel, ConfigDict, Field
|
from pydantic import BaseModel, ConfigDict, Field
|
||||||
|
|
||||||
@@ -40,9 +40,9 @@ class Knowledge(BaseModel):
|
|||||||
if storage_provider == "elasticsearch":
|
if storage_provider == "elasticsearch":
|
||||||
try:
|
try:
|
||||||
from crewai.knowledge.storage.elasticsearch_knowledge_storage import ElasticsearchKnowledgeStorage
|
from crewai.knowledge.storage.elasticsearch_knowledge_storage import ElasticsearchKnowledgeStorage
|
||||||
self.storage = ElasticsearchKnowledgeStorage(
|
self.storage = cast(KnowledgeStorage, ElasticsearchKnowledgeStorage(
|
||||||
embedder=embedder, collection_name=collection_name
|
embedder_config=embedder, collection_name=collection_name
|
||||||
)
|
))
|
||||||
except ImportError:
|
except ImportError:
|
||||||
raise ImportError(
|
raise ImportError(
|
||||||
"Elasticsearch is not installed. Please install it with `pip install elasticsearch`."
|
"Elasticsearch is not installed. Please install it with `pip install elasticsearch`."
|
||||||
@@ -52,7 +52,8 @@ class Knowledge(BaseModel):
|
|||||||
embedder=embedder, collection_name=collection_name
|
embedder=embedder, collection_name=collection_name
|
||||||
)
|
)
|
||||||
self.sources = sources
|
self.sources = sources
|
||||||
self.storage.initialize_knowledge_storage()
|
if self.storage is not None:
|
||||||
|
self.storage.initialize_knowledge_storage()
|
||||||
self._add_sources()
|
self._add_sources()
|
||||||
|
|
||||||
def query(
|
def query(
|
||||||
|
|||||||
@@ -31,21 +31,21 @@ class ElasticsearchKnowledgeStorage(BaseKnowledgeStorage):
|
|||||||
and improving search efficiency.
|
and improving search efficiency.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
app = None
|
app: Any = None
|
||||||
collection_name: Optional[str] = "knowledge"
|
collection_name: Optional[str] = "knowledge"
|
||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
embedder: Optional[Dict[str, Any]] = None,
|
embedder_config: Optional[Dict[str, Any]] = None,
|
||||||
collection_name: Optional[str] = None,
|
collection_name: Optional[str] = None,
|
||||||
host="localhost",
|
host: str = "localhost",
|
||||||
port=9200,
|
port: int = 9200,
|
||||||
username=None,
|
username: Optional[str] = None,
|
||||||
password=None,
|
password: Optional[str] = None,
|
||||||
**kwargs
|
**kwargs: Any
|
||||||
):
|
):
|
||||||
self.collection_name = collection_name
|
self.collection_name = collection_name
|
||||||
self._set_embedder_config(embedder)
|
self._set_embedder_config(embedder_config)
|
||||||
|
|
||||||
self.host = host
|
self.host = host
|
||||||
self.port = port
|
self.port = port
|
||||||
@@ -67,7 +67,7 @@ class ElasticsearchKnowledgeStorage(BaseKnowledgeStorage):
|
|||||||
try:
|
try:
|
||||||
embedding = self._get_embedding_for_text(query[0])
|
embedding = self._get_embedding_for_text(query[0])
|
||||||
|
|
||||||
search_query = {
|
search_query: Dict[str, Any] = {
|
||||||
"size": limit,
|
"size": limit,
|
||||||
"query": {
|
"query": {
|
||||||
"script_score": {
|
"script_score": {
|
||||||
@@ -81,35 +81,45 @@ class ElasticsearchKnowledgeStorage(BaseKnowledgeStorage):
|
|||||||
}
|
}
|
||||||
|
|
||||||
if filter:
|
if filter:
|
||||||
for key, value in filter.items():
|
query_obj = search_query.get("query", {})
|
||||||
search_query["query"]["script_score"]["query"] = {
|
if isinstance(query_obj, dict):
|
||||||
"bool": {
|
script_score_obj = query_obj.get("script_score", {})
|
||||||
"must": [
|
if isinstance(script_score_obj, dict):
|
||||||
search_query["query"]["script_score"]["query"],
|
query_part = script_score_obj.get("query", {})
|
||||||
{"match": {f"metadata.{key}": value}}
|
if isinstance(query_part, dict):
|
||||||
]
|
for key, value in filter.items():
|
||||||
}
|
script_score_obj["query"] = {
|
||||||
}
|
"bool": {
|
||||||
|
"must": [
|
||||||
|
query_part,
|
||||||
|
{"match": {f"metadata.{key}": value}}
|
||||||
|
]
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
with suppress_logging():
|
with suppress_logging():
|
||||||
response = self.app.search(
|
if self.app is not None and hasattr(self.app, "search") and callable(getattr(self.app, "search")):
|
||||||
index=self.index_name,
|
response = self.app.search(
|
||||||
body=search_query
|
index=self.index_name,
|
||||||
)
|
body=search_query
|
||||||
|
)
|
||||||
|
|
||||||
results = []
|
results = []
|
||||||
for hit in response["hits"]["hits"]:
|
for hit in response["hits"]["hits"]:
|
||||||
adjusted_score = (hit["_score"] - 1.0)
|
adjusted_score = (hit["_score"] - 1.0)
|
||||||
|
|
||||||
if adjusted_score >= score_threshold:
|
if adjusted_score >= score_threshold:
|
||||||
results.append({
|
results.append({
|
||||||
"id": hit["_id"],
|
"id": hit["_id"],
|
||||||
"metadata": hit["_source"]["metadata"],
|
"metadata": hit["_source"]["metadata"],
|
||||||
"context": hit["_source"]["text"],
|
"context": hit["_source"]["text"],
|
||||||
"score": adjusted_score,
|
"score": adjusted_score,
|
||||||
})
|
})
|
||||||
|
|
||||||
return results
|
return results
|
||||||
|
else:
|
||||||
|
Logger(verbose=True).log("error", "Elasticsearch client is not initialized", "red")
|
||||||
|
return []
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
Logger(verbose=True).log("error", f"Search error: {e}", "red")
|
Logger(verbose=True).log("error", f"Search error: {e}", "red")
|
||||||
raise Exception(f"Error during knowledge search: {str(e)}")
|
raise Exception(f"Error during knowledge search: {str(e)}")
|
||||||
@@ -159,9 +169,9 @@ class ElasticsearchKnowledgeStorage(BaseKnowledgeStorage):
|
|||||||
)
|
)
|
||||||
raise Exception(f"Error initializing Elasticsearch: {str(e)}")
|
raise Exception(f"Error initializing Elasticsearch: {str(e)}")
|
||||||
|
|
||||||
def reset(self):
|
def reset(self) -> None:
|
||||||
try:
|
try:
|
||||||
if self.app:
|
if self.app is not None:
|
||||||
if self.app.indices.exists(index=self.index_name):
|
if self.app.indices.exists(index=self.index_name):
|
||||||
self.app.indices.delete(index=self.index_name)
|
self.app.indices.delete(index=self.index_name)
|
||||||
|
|
||||||
@@ -175,7 +185,7 @@ class ElasticsearchKnowledgeStorage(BaseKnowledgeStorage):
|
|||||||
self,
|
self,
|
||||||
documents: List[str],
|
documents: List[str],
|
||||||
metadata: Optional[Union[Dict[str, Any], List[Dict[str, Any]]]] = None,
|
metadata: Optional[Union[Dict[str, Any], List[Dict[str, Any]]]] = None,
|
||||||
):
|
) -> None:
|
||||||
if not self.app:
|
if not self.app:
|
||||||
self.initialize_knowledge_storage()
|
self.initialize_knowledge_storage()
|
||||||
|
|
||||||
@@ -201,12 +211,15 @@ class ElasticsearchKnowledgeStorage(BaseKnowledgeStorage):
|
|||||||
"metadata": meta or {},
|
"metadata": meta or {},
|
||||||
}
|
}
|
||||||
|
|
||||||
self.app.index(
|
if self.app is not None and hasattr(self.app, "index") and callable(getattr(self.app, "index")):
|
||||||
index=self.index_name,
|
self.app.index(
|
||||||
id=doc_id,
|
index=self.index_name,
|
||||||
document=doc_body,
|
id=doc_id,
|
||||||
refresh=True # Make the document immediately available for search
|
document=doc_body,
|
||||||
)
|
refresh=True # Make the document immediately available for search
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
Logger(verbose=True).log("error", "Elasticsearch client is not initialized", "red")
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
Logger(verbose=True).log("error", f"Save error: {e}", "red")
|
Logger(verbose=True).log("error", f"Save error: {e}", "red")
|
||||||
@@ -214,10 +227,14 @@ class ElasticsearchKnowledgeStorage(BaseKnowledgeStorage):
|
|||||||
|
|
||||||
def _get_embedding_for_text(self, text: str) -> List[float]:
|
def _get_embedding_for_text(self, text: str) -> List[float]:
|
||||||
"""Get embedding for text using the configured embedder."""
|
"""Get embedding for text using the configured embedder."""
|
||||||
if hasattr(self.embedder_config, "embed_documents"):
|
if self.embedder_config is None:
|
||||||
return self.embedder_config.embed_documents([text])[0]
|
raise ValueError("Embedder configuration is not set")
|
||||||
elif hasattr(self.embedder_config, "embed"):
|
|
||||||
return self.embedder_config.embed(text)
|
embedder = self.embedder_config
|
||||||
|
if hasattr(embedder, "embed_documents") and callable(getattr(embedder, "embed_documents")):
|
||||||
|
return embedder.embed_documents([text])[0]
|
||||||
|
elif hasattr(embedder, "embed") and callable(getattr(embedder, "embed")):
|
||||||
|
return embedder.embed(text)
|
||||||
else:
|
else:
|
||||||
raise ValueError("Invalid embedding function configuration")
|
raise ValueError("Invalid embedding function configuration")
|
||||||
|
|
||||||
|
|||||||
@@ -3,7 +3,7 @@ import io
|
|||||||
import logging
|
import logging
|
||||||
import os
|
import os
|
||||||
import uuid
|
import uuid
|
||||||
from typing import Any, Dict, List, Optional
|
from typing import Any, Dict, List, Optional, cast
|
||||||
|
|
||||||
from crewai.memory.storage.base_rag_storage import BaseRAGStorage
|
from crewai.memory.storage.base_rag_storage import BaseRAGStorage
|
||||||
from crewai.utilities import EmbeddingConfigurator
|
from crewai.utilities import EmbeddingConfigurator
|
||||||
@@ -32,20 +32,20 @@ class ElasticsearchStorage(BaseRAGStorage):
|
|||||||
and improving search efficiency.
|
and improving search efficiency.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
app: Any | None = None
|
app: Any = None
|
||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
type,
|
type: str,
|
||||||
allow_reset=True,
|
allow_reset: bool = True,
|
||||||
embedder_config=None,
|
embedder_config: Any = None,
|
||||||
crew=None,
|
crew: Any = None,
|
||||||
path=None,
|
path: Optional[str] = None,
|
||||||
host="localhost",
|
host: str = "localhost",
|
||||||
port=9200,
|
port: int = 9200,
|
||||||
username=None,
|
username: Optional[str] = None,
|
||||||
password=None,
|
password: Optional[str] = None,
|
||||||
**kwargs
|
**kwargs: Any
|
||||||
):
|
):
|
||||||
super().__init__(type, allow_reset, embedder_config, crew)
|
super().__init__(type, allow_reset, embedder_config, crew)
|
||||||
agents = crew.agents if crew else []
|
agents = crew.agents if crew else []
|
||||||
@@ -160,7 +160,7 @@ class ElasticsearchStorage(BaseRAGStorage):
|
|||||||
try:
|
try:
|
||||||
embedding = self._get_embedding_for_text(query)
|
embedding = self._get_embedding_for_text(query)
|
||||||
|
|
||||||
search_query = {
|
search_query: Dict[str, Any] = {
|
||||||
"size": limit,
|
"size": limit,
|
||||||
"query": {
|
"query": {
|
||||||
"script_score": {
|
"script_score": {
|
||||||
@@ -174,49 +174,67 @@ class ElasticsearchStorage(BaseRAGStorage):
|
|||||||
}
|
}
|
||||||
|
|
||||||
if filter:
|
if filter:
|
||||||
for key, value in filter.items():
|
query_obj = search_query.get("query", {})
|
||||||
search_query["query"]["script_score"]["query"] = {
|
if isinstance(query_obj, dict):
|
||||||
"bool": {
|
script_score_obj = query_obj.get("script_score", {})
|
||||||
"must": [
|
if isinstance(script_score_obj, dict):
|
||||||
search_query["query"]["script_score"]["query"],
|
query_part = script_score_obj.get("query", {})
|
||||||
{"match": {f"metadata.{key}": value}}
|
if isinstance(query_part, dict):
|
||||||
]
|
for key, value in filter.items():
|
||||||
}
|
script_score_obj["query"] = {
|
||||||
}
|
"bool": {
|
||||||
|
"must": [
|
||||||
|
query_part,
|
||||||
|
{"match": {f"metadata.{key}": value}}
|
||||||
|
]
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
with suppress_logging():
|
with suppress_logging():
|
||||||
response = self.app.search(
|
if self.app is not None:
|
||||||
index=self.index_name,
|
response = self.app.search(
|
||||||
body=search_query
|
index=self.index_name,
|
||||||
)
|
body=search_query
|
||||||
|
)
|
||||||
|
|
||||||
results = []
|
results = []
|
||||||
for hit in response["hits"]["hits"]:
|
for hit in response["hits"]["hits"]:
|
||||||
adjusted_score = (hit["_score"] - 1.0)
|
adjusted_score = (hit["_score"] - 1.0)
|
||||||
|
|
||||||
if adjusted_score >= score_threshold:
|
if adjusted_score >= score_threshold:
|
||||||
results.append({
|
results.append({
|
||||||
"id": hit["_id"],
|
"id": hit["_id"],
|
||||||
"metadata": hit["_source"]["metadata"],
|
"metadata": hit["_source"]["metadata"],
|
||||||
"context": hit["_source"]["text"],
|
"context": hit["_source"]["text"],
|
||||||
"score": adjusted_score,
|
"score": adjusted_score,
|
||||||
})
|
})
|
||||||
|
|
||||||
return results
|
return results
|
||||||
|
else:
|
||||||
|
logging.error("Elasticsearch client is not initialized")
|
||||||
|
return []
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logging.error(f"Error during {self.type} search: {str(e)}")
|
logging.error(f"Error during {self.type} search: {str(e)}")
|
||||||
return []
|
return []
|
||||||
|
|
||||||
def _get_embedding_for_text(self, text: str) -> List[float]:
|
def _get_embedding_for_text(self, text: str) -> List[float]:
|
||||||
"""Get embedding for text using the configured embedder."""
|
"""Get embedding for text using the configured embedder."""
|
||||||
if hasattr(self.embedder_config, "embed_documents"):
|
if self.embedder_config is None:
|
||||||
return self.embedder_config.embed_documents([text])[0]
|
raise ValueError("Embedder configuration is not set")
|
||||||
elif hasattr(self.embedder_config, "embed"):
|
|
||||||
return self.embedder_config.embed(text)
|
embedder = self.embedder_config
|
||||||
|
if hasattr(embedder, "embed_documents") and callable(getattr(embedder, "embed_documents")):
|
||||||
|
return embedder.embed_documents([text])[0]
|
||||||
|
elif hasattr(embedder, "embed") and callable(getattr(embedder, "embed")):
|
||||||
|
return embedder.embed(text)
|
||||||
else:
|
else:
|
||||||
raise ValueError("Invalid embedding function configuration")
|
raise ValueError("Invalid embedding function configuration")
|
||||||
|
|
||||||
def _generate_embedding(self, text: str, metadata: Dict[str, Any]) -> None:
|
def _generate_embedding(self, text: str, metadata: Optional[Dict[str, Any]] = None) -> Any:
|
||||||
|
"""Generate embedding for text and save to Elasticsearch.
|
||||||
|
|
||||||
|
This method overrides the BaseRAGStorage method to use Elasticsearch.
|
||||||
|
"""
|
||||||
if not hasattr(self, "app") or self.app is None:
|
if not hasattr(self, "app") or self.app is None:
|
||||||
self._initialize_app()
|
self._initialize_app()
|
||||||
|
|
||||||
@@ -228,16 +246,19 @@ class ElasticsearchStorage(BaseRAGStorage):
|
|||||||
"metadata": metadata or {},
|
"metadata": metadata or {},
|
||||||
}
|
}
|
||||||
|
|
||||||
self.app.index(
|
if self.app is not None and hasattr(self.app, "index") and callable(getattr(self.app, "index")):
|
||||||
index=self.index_name,
|
result = self.app.index(
|
||||||
id=str(uuid.uuid4()),
|
index=self.index_name,
|
||||||
document=doc,
|
id=str(uuid.uuid4()),
|
||||||
refresh=True # Make the document immediately available for search
|
document=doc,
|
||||||
)
|
refresh=True # Make the document immediately available for search
|
||||||
|
)
|
||||||
|
return result
|
||||||
|
return None
|
||||||
|
|
||||||
def reset(self) -> None:
|
def reset(self) -> None:
|
||||||
try:
|
try:
|
||||||
if self.app:
|
if self.app is not None:
|
||||||
if self.app.indices.exists(index=self.index_name):
|
if self.app.indices.exists(index=self.index_name):
|
||||||
self.app.indices.delete(index=self.index_name)
|
self.app.indices.delete(index=self.index_name)
|
||||||
|
|
||||||
|
|||||||
@@ -1,4 +1,4 @@
|
|||||||
from typing import Any, Dict, Optional, Type
|
from typing import Any, Dict, Optional, Type, cast
|
||||||
|
|
||||||
from crewai.memory.storage.base_rag_storage import BaseRAGStorage
|
from crewai.memory.storage.base_rag_storage import BaseRAGStorage
|
||||||
from crewai.memory.storage.rag_storage import RAGStorage
|
from crewai.memory.storage.rag_storage import RAGStorage
|
||||||
@@ -56,7 +56,7 @@ class StorageFactory:
|
|||||||
elif provider == "mem0":
|
elif provider == "mem0":
|
||||||
try:
|
try:
|
||||||
from crewai.memory.storage.mem0_storage import Mem0Storage
|
from crewai.memory.storage.mem0_storage import Mem0Storage
|
||||||
return Mem0Storage(type=type, crew=crew)
|
return cast(BaseRAGStorage, Mem0Storage(type=type, crew=crew))
|
||||||
except ImportError:
|
except ImportError:
|
||||||
Logger(verbose=True).log(
|
Logger(verbose=True).log(
|
||||||
"error",
|
"error",
|
||||||
|
|||||||
Reference in New Issue
Block a user