mirror of
https://github.com/crewAIInc/crewAI.git
synced 2026-05-03 08:12:39 +00:00
fix: resolve mypy type annotation issues in storage and telemetry modules
- Add proper type parameters for EmbeddingFunction generics
- Fix ChromaDB query response handling with proper type checking
- Add missing return type annotations to telemetry methods
- Fix trace listener type annotations and imports
- Handle potential None values in nested list indexing
- Improve type safety in RAG and knowledge storage modules
This commit is contained in:
@@ -99,7 +99,7 @@ class TraceCollectionListener(BaseEventListener):
|
|||||||
return
|
return
|
||||||
|
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self.batch_manager = batch_manager or TraceBatchManager()
|
self.batch_manager = batch_manager or TraceBatchManager() # type: ignore[call-arg]
|
||||||
self._initialized = True
|
self._initialized = True
|
||||||
|
|
||||||
def _check_authenticated(self) -> bool:
|
def _check_authenticated(self) -> bool:
|
||||||
@@ -324,7 +324,7 @@ class TraceCollectionListener(BaseEventListener):
|
|||||||
|
|
||||||
def _initialize_batch(
|
def _initialize_batch(
|
||||||
self, user_context: dict[str, str], execution_metadata: dict[str, Any]
|
self, user_context: dict[str, str], execution_metadata: dict[str, Any]
|
||||||
):
|
) -> None:
|
||||||
"""Initialize trace batch if ephemeral"""
|
"""Initialize trace batch if ephemeral"""
|
||||||
if not self._check_authenticated():
|
if not self._check_authenticated():
|
||||||
self.batch_manager.initialize_batch(
|
self.batch_manager.initialize_batch(
|
||||||
@@ -426,7 +426,7 @@ class TraceCollectionListener(BaseEventListener):
|
|||||||
|
|
||||||
# TODO: move to utils
|
# TODO: move to utils
|
||||||
def _safe_serialize_to_dict(
|
def _safe_serialize_to_dict(
|
||||||
self, obj, exclude: set[str] | None = None
|
self, obj: Any, exclude: set[str] | None = None
|
||||||
) -> dict[str, Any]:
|
) -> dict[str, Any]:
|
||||||
"""Safely serialize an object to a dictionary for event data."""
|
"""Safely serialize an object to a dictionary for event data."""
|
||||||
try:
|
try:
|
||||||
|
|||||||
@@ -3,6 +3,7 @@ import logging
|
|||||||
import os
|
import os
|
||||||
import shutil
|
import shutil
|
||||||
import warnings
|
import warnings
|
||||||
|
from collections.abc import Mapping
|
||||||
from typing import Any, Optional, Union
|
from typing import Any, Optional, Union
|
||||||
|
|
||||||
import chromadb
|
import chromadb
|
||||||
@@ -30,7 +31,7 @@ class KnowledgeStorage(BaseKnowledgeStorage):
|
|||||||
collection: Optional[chromadb.Collection] = None
|
collection: Optional[chromadb.Collection] = None
|
||||||
collection_name: Optional[str] = "knowledge"
|
collection_name: Optional[str] = "knowledge"
|
||||||
app: Optional[ClientAPI] = None
|
app: Optional[ClientAPI] = None
|
||||||
embedder: Optional[EmbeddingFunction] = None
|
embedder: Optional[EmbeddingFunction[Any]] = None
|
||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
@@ -57,17 +58,61 @@ class KnowledgeStorage(BaseKnowledgeStorage):
|
|||||||
where=filter,
|
where=filter,
|
||||||
)
|
)
|
||||||
results = []
|
results = []
|
||||||
for i in range(len(fetched["ids"][0])):
|
if (
|
||||||
result = {
|
fetched
|
||||||
"id": fetched["ids"][0][i],
|
and "ids" in fetched
|
||||||
"metadata": fetched["metadatas"][0][i],
|
and fetched["ids"]
|
||||||
"context": fetched["documents"][0][i],
|
and len(fetched["ids"]) > 0
|
||||||
"score": fetched["distances"][0][i],
|
):
|
||||||
}
|
ids_list = (
|
||||||
if (
|
fetched["ids"][0]
|
||||||
result["score"] <= score_threshold
|
if isinstance(fetched["ids"][0], list)
|
||||||
): # Note: distances are smaller when more similar
|
else fetched["ids"]
|
||||||
results.append(result)
|
)
|
||||||
|
for i in range(len(ids_list)):
|
||||||
|
# Handle metadatas
|
||||||
|
metadata = {}
|
||||||
|
if fetched.get("metadatas") and len(fetched["metadatas"]) > 0:
|
||||||
|
metadata_list = (
|
||||||
|
fetched["metadatas"][0]
|
||||||
|
if isinstance(fetched["metadatas"][0], list)
|
||||||
|
else fetched["metadatas"]
|
||||||
|
)
|
||||||
|
if i < len(metadata_list):
|
||||||
|
metadata = metadata_list[i]
|
||||||
|
|
||||||
|
# Handle documents
|
||||||
|
context = ""
|
||||||
|
if fetched.get("documents") and len(fetched["documents"]) > 0:
|
||||||
|
docs_list = (
|
||||||
|
fetched["documents"][0]
|
||||||
|
if isinstance(fetched["documents"][0], list)
|
||||||
|
else fetched["documents"]
|
||||||
|
)
|
||||||
|
if i < len(docs_list):
|
||||||
|
context = docs_list[i]
|
||||||
|
|
||||||
|
# Handle distances
|
||||||
|
score = 1.0
|
||||||
|
if fetched.get("distances") and len(fetched["distances"]) > 0:
|
||||||
|
dist_list = (
|
||||||
|
fetched["distances"][0]
|
||||||
|
if isinstance(fetched["distances"][0], list)
|
||||||
|
else fetched["distances"]
|
||||||
|
)
|
||||||
|
if i < len(dist_list):
|
||||||
|
score = dist_list[i]
|
||||||
|
|
||||||
|
result = {
|
||||||
|
"id": ids_list[i],
|
||||||
|
"metadata": metadata,
|
||||||
|
"context": context,
|
||||||
|
"score": score,
|
||||||
|
}
|
||||||
|
|
||||||
|
# Check score threshold - distances are smaller when more similar
|
||||||
|
if isinstance(score, (int, float)) and score <= score_threshold:
|
||||||
|
results.append(result)
|
||||||
return results
|
return results
|
||||||
else:
|
else:
|
||||||
raise Exception("Collection not initialized")
|
raise Exception("Collection not initialized")
|
||||||
@@ -150,7 +195,7 @@ class KnowledgeStorage(BaseKnowledgeStorage):
|
|||||||
|
|
||||||
# If we have no metadata at all, set it to None
|
# If we have no metadata at all, set it to None
|
||||||
final_metadata: Optional[OneOrMany[chromadb.Metadata]] = (
|
final_metadata: Optional[OneOrMany[chromadb.Metadata]] = (
|
||||||
None if all(m is None for m in filtered_metadata) else filtered_metadata
|
None if all(m is None for m in filtered_metadata) else filtered_metadata # type: ignore[assignment]
|
||||||
)
|
)
|
||||||
|
|
||||||
self.collection.upsert(
|
self.collection.upsert(
|
||||||
|
|||||||
@@ -23,7 +23,7 @@ class RAGStorage(BaseRAGStorage):
|
|||||||
"""
|
"""
|
||||||
|
|
||||||
app: ClientAPI | None = None
|
app: ClientAPI | None = None
|
||||||
embedder_config: EmbeddingFunction | None = None # type: ignore[assignment]
|
embedder_config: EmbeddingFunction[Any] | None = None
|
||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
@@ -41,14 +41,22 @@ class RAGStorage(BaseRAGStorage):
|
|||||||
self.storage_file_name = self._build_storage_file_name(type, agents)
|
self.storage_file_name = self._build_storage_file_name(type, agents)
|
||||||
|
|
||||||
self.type = type
|
self.type = type
|
||||||
|
self._original_embedder_config = (
|
||||||
|
embedder_config # Store for later use in _set_embedder_config
|
||||||
|
)
|
||||||
self.allow_reset = allow_reset
|
self.allow_reset = allow_reset
|
||||||
self.path = path
|
self.path = path
|
||||||
self._initialize_app()
|
self._initialize_app()
|
||||||
|
|
||||||
def _set_embedder_config(self) -> None:
|
def _set_embedder_config(self) -> None:
|
||||||
configurator = EmbeddingConfigurator()
|
configurator = EmbeddingConfigurator()
|
||||||
self.embedder_config = configurator.configure_embedder(self.embedder_config)
|
# Pass the original embedder_config from __init__, not self.embedder_config
|
||||||
|
if hasattr(self, "_original_embedder_config"):
|
||||||
|
self.embedder_config = configurator.configure_embedder(
|
||||||
|
self._original_embedder_config
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
self.embedder_config = configurator.configure_embedder(None)
|
||||||
|
|
||||||
def _initialize_app(self) -> None:
|
def _initialize_app(self) -> None:
|
||||||
from chromadb.config import Settings
|
from chromadb.config import Settings
|
||||||
@@ -118,23 +126,60 @@ class RAGStorage(BaseRAGStorage):
|
|||||||
response = self.collection.query(query_texts=query, n_results=limit)
|
response = self.collection.query(query_texts=query, n_results=limit)
|
||||||
|
|
||||||
results = []
|
results = []
|
||||||
if response and "ids" in response and response["ids"]:
|
if (
|
||||||
for i in range(len(response["ids"][0])):
|
response
|
||||||
|
and "ids" in response
|
||||||
|
and response["ids"]
|
||||||
|
and len(response["ids"]) > 0
|
||||||
|
):
|
||||||
|
ids_list = (
|
||||||
|
response["ids"][0]
|
||||||
|
if isinstance(response["ids"][0], list)
|
||||||
|
else response["ids"]
|
||||||
|
)
|
||||||
|
for i in range(len(ids_list)):
|
||||||
|
# Handle metadatas
|
||||||
|
metadata = {}
|
||||||
|
if response.get("metadatas") and len(response["metadatas"]) > 0:
|
||||||
|
metadata_list = (
|
||||||
|
response["metadatas"][0]
|
||||||
|
if isinstance(response["metadatas"][0], list)
|
||||||
|
else response["metadatas"]
|
||||||
|
)
|
||||||
|
if i < len(metadata_list):
|
||||||
|
metadata = metadata_list[i]
|
||||||
|
|
||||||
|
# Handle documents
|
||||||
|
context = ""
|
||||||
|
if response.get("documents") and len(response["documents"]) > 0:
|
||||||
|
docs_list = (
|
||||||
|
response["documents"][0]
|
||||||
|
if isinstance(response["documents"][0], list)
|
||||||
|
else response["documents"]
|
||||||
|
)
|
||||||
|
if i < len(docs_list):
|
||||||
|
context = docs_list[i]
|
||||||
|
|
||||||
|
# Handle distances
|
||||||
|
score = 1.0
|
||||||
|
if response.get("distances") and len(response["distances"]) > 0:
|
||||||
|
dist_list = (
|
||||||
|
response["distances"][0]
|
||||||
|
if isinstance(response["distances"][0], list)
|
||||||
|
else response["distances"]
|
||||||
|
)
|
||||||
|
if i < len(dist_list):
|
||||||
|
score = dist_list[i]
|
||||||
|
|
||||||
result = {
|
result = {
|
||||||
"id": response["ids"][0][i],
|
"id": ids_list[i],
|
||||||
"metadata": response["metadatas"][0][i]
|
"metadata": metadata,
|
||||||
if response.get("metadatas")
|
"context": context,
|
||||||
else {},
|
"score": score,
|
||||||
"context": response["documents"][0][i]
|
|
||||||
if response.get("documents")
|
|
||||||
else "",
|
|
||||||
"score": response["distances"][0][i]
|
|
||||||
if response.get("distances")
|
|
||||||
else 1.0,
|
|
||||||
}
|
}
|
||||||
if (
|
|
||||||
result["score"] <= score_threshold
|
# Check score threshold - distances are smaller when more similar
|
||||||
): # Note: distances are smaller when more similar
|
if isinstance(score, (int, float)) and score <= score_threshold:
|
||||||
results.append(result)
|
results.append(result)
|
||||||
|
|
||||||
return results
|
return results
|
||||||
@@ -168,7 +213,7 @@ class RAGStorage(BaseRAGStorage):
|
|||||||
f"An error occurred while resetting the {self.type} memory: {e}"
|
f"An error occurred while resetting the {self.type} memory: {e}"
|
||||||
)
|
)
|
||||||
|
|
||||||
def _create_default_embedding_function(self):
|
def _create_default_embedding_function(self) -> EmbeddingFunction[Any]:
|
||||||
from chromadb.utils.embedding_functions.openai_embedding_function import (
|
from chromadb.utils.embedding_functions.openai_embedding_function import (
|
||||||
OpenAIEmbeddingFunction,
|
OpenAIEmbeddingFunction,
|
||||||
)
|
)
|
||||||
|
|||||||
@@ -534,7 +534,7 @@ class Telemetry:
|
|||||||
|
|
||||||
def tool_usage_error(
|
def tool_usage_error(
|
||||||
self, llm: Any, agent: Any = None, tool_name: Optional[str] = None
|
self, llm: Any, agent: Any = None, tool_name: Optional[str] = None
|
||||||
):
|
) -> None:
|
||||||
"""Records when a tool usage results in an error.
|
"""Records when a tool usage results in an error.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
@@ -572,7 +572,7 @@ class Telemetry:
|
|||||||
|
|
||||||
def individual_test_result_span(
|
def individual_test_result_span(
|
||||||
self, crew: Crew, quality: float, exec_time: int, model_name: str
|
self, crew: Crew, quality: float, exec_time: int, model_name: str
|
||||||
):
|
) -> None:
|
||||||
"""Records individual test results for a crew execution.
|
"""Records individual test results for a crew execution.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
@@ -607,7 +607,7 @@ class Telemetry:
|
|||||||
iterations: int,
|
iterations: int,
|
||||||
inputs: dict[str, Any] | None,
|
inputs: dict[str, Any] | None,
|
||||||
model_name: str,
|
model_name: str,
|
||||||
):
|
) -> None:
|
||||||
"""Records the execution of a test suite for a crew.
|
"""Records the execution of a test suite for a crew.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
@@ -792,7 +792,8 @@ class Telemetry:
|
|||||||
|
|
||||||
if crew.share_crew:
|
if crew.share_crew:
|
||||||
self._safe_telemetry_operation(operation)
|
self._safe_telemetry_operation(operation)
|
||||||
return operation()
|
result = operation()
|
||||||
|
return result # type: ignore[no-any-return]
|
||||||
return None
|
return None
|
||||||
|
|
||||||
def end_crew(self, crew: Any, final_string_output: str) -> None:
|
def end_crew(self, crew: Any, final_string_output: str) -> None:
|
||||||
|
|||||||
Reference in New Issue
Block a user