mirror of
https://github.com/crewAIInc/crewAI.git
synced 2026-01-08 23:58:34 +00:00
fix: resolve mypy type annotation issues in storage and telemetry modules
- Add proper type parameters for EmbeddingFunction generics - Fix ChromaDB query response handling with proper type checking - Add missing return type annotations to telemetry methods - Fix trace listener type annotations and imports - Handle potential None values in nested list indexing - Improve type safety in RAG and knowledge storage modules
This commit is contained in:
@@ -99,7 +99,7 @@ class TraceCollectionListener(BaseEventListener):
|
||||
return
|
||||
|
||||
super().__init__()
|
||||
self.batch_manager = batch_manager or TraceBatchManager()
|
||||
self.batch_manager = batch_manager or TraceBatchManager() # type: ignore[call-arg]
|
||||
self._initialized = True
|
||||
|
||||
def _check_authenticated(self) -> bool:
|
||||
@@ -324,7 +324,7 @@ class TraceCollectionListener(BaseEventListener):
|
||||
|
||||
def _initialize_batch(
|
||||
self, user_context: dict[str, str], execution_metadata: dict[str, Any]
|
||||
):
|
||||
) -> None:
|
||||
"""Initialize trace batch if ephemeral"""
|
||||
if not self._check_authenticated():
|
||||
self.batch_manager.initialize_batch(
|
||||
@@ -426,7 +426,7 @@ class TraceCollectionListener(BaseEventListener):
|
||||
|
||||
# TODO: move to utils
|
||||
def _safe_serialize_to_dict(
|
||||
self, obj, exclude: set[str] | None = None
|
||||
self, obj: Any, exclude: set[str] | None = None
|
||||
) -> dict[str, Any]:
|
||||
"""Safely serialize an object to a dictionary for event data."""
|
||||
try:
|
||||
|
||||
@@ -3,6 +3,7 @@ import logging
|
||||
import os
|
||||
import shutil
|
||||
import warnings
|
||||
from collections.abc import Mapping
|
||||
from typing import Any, Optional, Union
|
||||
|
||||
import chromadb
|
||||
@@ -30,7 +31,7 @@ class KnowledgeStorage(BaseKnowledgeStorage):
|
||||
collection: Optional[chromadb.Collection] = None
|
||||
collection_name: Optional[str] = "knowledge"
|
||||
app: Optional[ClientAPI] = None
|
||||
embedder: Optional[EmbeddingFunction] = None
|
||||
embedder: Optional[EmbeddingFunction[Any]] = None
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
@@ -57,17 +58,61 @@ class KnowledgeStorage(BaseKnowledgeStorage):
|
||||
where=filter,
|
||||
)
|
||||
results = []
|
||||
for i in range(len(fetched["ids"][0])):
|
||||
result = {
|
||||
"id": fetched["ids"][0][i],
|
||||
"metadata": fetched["metadatas"][0][i],
|
||||
"context": fetched["documents"][0][i],
|
||||
"score": fetched["distances"][0][i],
|
||||
}
|
||||
if (
|
||||
result["score"] <= score_threshold
|
||||
): # Note: distances are smaller when more similar
|
||||
results.append(result)
|
||||
if (
|
||||
fetched
|
||||
and "ids" in fetched
|
||||
and fetched["ids"]
|
||||
and len(fetched["ids"]) > 0
|
||||
):
|
||||
ids_list = (
|
||||
fetched["ids"][0]
|
||||
if isinstance(fetched["ids"][0], list)
|
||||
else fetched["ids"]
|
||||
)
|
||||
for i in range(len(ids_list)):
|
||||
# Handle metadatas
|
||||
metadata = {}
|
||||
if fetched.get("metadatas") and len(fetched["metadatas"]) > 0:
|
||||
metadata_list = (
|
||||
fetched["metadatas"][0]
|
||||
if isinstance(fetched["metadatas"][0], list)
|
||||
else fetched["metadatas"]
|
||||
)
|
||||
if i < len(metadata_list):
|
||||
metadata = metadata_list[i]
|
||||
|
||||
# Handle documents
|
||||
context = ""
|
||||
if fetched.get("documents") and len(fetched["documents"]) > 0:
|
||||
docs_list = (
|
||||
fetched["documents"][0]
|
||||
if isinstance(fetched["documents"][0], list)
|
||||
else fetched["documents"]
|
||||
)
|
||||
if i < len(docs_list):
|
||||
context = docs_list[i]
|
||||
|
||||
# Handle distances
|
||||
score = 1.0
|
||||
if fetched.get("distances") and len(fetched["distances"]) > 0:
|
||||
dist_list = (
|
||||
fetched["distances"][0]
|
||||
if isinstance(fetched["distances"][0], list)
|
||||
else fetched["distances"]
|
||||
)
|
||||
if i < len(dist_list):
|
||||
score = dist_list[i]
|
||||
|
||||
result = {
|
||||
"id": ids_list[i],
|
||||
"metadata": metadata,
|
||||
"context": context,
|
||||
"score": score,
|
||||
}
|
||||
|
||||
# Check score threshold - distances are smaller when more similar
|
||||
if isinstance(score, (int, float)) and score <= score_threshold:
|
||||
results.append(result)
|
||||
return results
|
||||
else:
|
||||
raise Exception("Collection not initialized")
|
||||
@@ -150,7 +195,7 @@ class KnowledgeStorage(BaseKnowledgeStorage):
|
||||
|
||||
# If we have no metadata at all, set it to None
|
||||
final_metadata: Optional[OneOrMany[chromadb.Metadata]] = (
|
||||
None if all(m is None for m in filtered_metadata) else filtered_metadata
|
||||
None if all(m is None for m in filtered_metadata) else filtered_metadata # type: ignore[assignment]
|
||||
)
|
||||
|
||||
self.collection.upsert(
|
||||
|
||||
@@ -23,7 +23,7 @@ class RAGStorage(BaseRAGStorage):
|
||||
"""
|
||||
|
||||
app: ClientAPI | None = None
|
||||
embedder_config: EmbeddingFunction | None = None # type: ignore[assignment]
|
||||
embedder_config: EmbeddingFunction[Any] | None = None
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
@@ -41,14 +41,22 @@ class RAGStorage(BaseRAGStorage):
|
||||
self.storage_file_name = self._build_storage_file_name(type, agents)
|
||||
|
||||
self.type = type
|
||||
|
||||
self._original_embedder_config = (
|
||||
embedder_config # Store for later use in _set_embedder_config
|
||||
)
|
||||
self.allow_reset = allow_reset
|
||||
self.path = path
|
||||
self._initialize_app()
|
||||
|
||||
def _set_embedder_config(self) -> None:
|
||||
configurator = EmbeddingConfigurator()
|
||||
self.embedder_config = configurator.configure_embedder(self.embedder_config)
|
||||
# Pass the original embedder_config from __init__, not self.embedder_config
|
||||
if hasattr(self, "_original_embedder_config"):
|
||||
self.embedder_config = configurator.configure_embedder(
|
||||
self._original_embedder_config
|
||||
)
|
||||
else:
|
||||
self.embedder_config = configurator.configure_embedder(None)
|
||||
|
||||
def _initialize_app(self) -> None:
|
||||
from chromadb.config import Settings
|
||||
@@ -118,23 +126,60 @@ class RAGStorage(BaseRAGStorage):
|
||||
response = self.collection.query(query_texts=query, n_results=limit)
|
||||
|
||||
results = []
|
||||
if response and "ids" in response and response["ids"]:
|
||||
for i in range(len(response["ids"][0])):
|
||||
if (
|
||||
response
|
||||
and "ids" in response
|
||||
and response["ids"]
|
||||
and len(response["ids"]) > 0
|
||||
):
|
||||
ids_list = (
|
||||
response["ids"][0]
|
||||
if isinstance(response["ids"][0], list)
|
||||
else response["ids"]
|
||||
)
|
||||
for i in range(len(ids_list)):
|
||||
# Handle metadatas
|
||||
metadata = {}
|
||||
if response.get("metadatas") and len(response["metadatas"]) > 0:
|
||||
metadata_list = (
|
||||
response["metadatas"][0]
|
||||
if isinstance(response["metadatas"][0], list)
|
||||
else response["metadatas"]
|
||||
)
|
||||
if i < len(metadata_list):
|
||||
metadata = metadata_list[i]
|
||||
|
||||
# Handle documents
|
||||
context = ""
|
||||
if response.get("documents") and len(response["documents"]) > 0:
|
||||
docs_list = (
|
||||
response["documents"][0]
|
||||
if isinstance(response["documents"][0], list)
|
||||
else response["documents"]
|
||||
)
|
||||
if i < len(docs_list):
|
||||
context = docs_list[i]
|
||||
|
||||
# Handle distances
|
||||
score = 1.0
|
||||
if response.get("distances") and len(response["distances"]) > 0:
|
||||
dist_list = (
|
||||
response["distances"][0]
|
||||
if isinstance(response["distances"][0], list)
|
||||
else response["distances"]
|
||||
)
|
||||
if i < len(dist_list):
|
||||
score = dist_list[i]
|
||||
|
||||
result = {
|
||||
"id": response["ids"][0][i],
|
||||
"metadata": response["metadatas"][0][i]
|
||||
if response.get("metadatas")
|
||||
else {},
|
||||
"context": response["documents"][0][i]
|
||||
if response.get("documents")
|
||||
else "",
|
||||
"score": response["distances"][0][i]
|
||||
if response.get("distances")
|
||||
else 1.0,
|
||||
"id": ids_list[i],
|
||||
"metadata": metadata,
|
||||
"context": context,
|
||||
"score": score,
|
||||
}
|
||||
if (
|
||||
result["score"] <= score_threshold
|
||||
): # Note: distances are smaller when more similar
|
||||
|
||||
# Check score threshold - distances are smaller when more similar
|
||||
if isinstance(score, (int, float)) and score <= score_threshold:
|
||||
results.append(result)
|
||||
|
||||
return results
|
||||
@@ -168,7 +213,7 @@ class RAGStorage(BaseRAGStorage):
|
||||
f"An error occurred while resetting the {self.type} memory: {e}"
|
||||
)
|
||||
|
||||
def _create_default_embedding_function(self):
|
||||
def _create_default_embedding_function(self) -> EmbeddingFunction[Any]:
|
||||
from chromadb.utils.embedding_functions.openai_embedding_function import (
|
||||
OpenAIEmbeddingFunction,
|
||||
)
|
||||
|
||||
@@ -534,7 +534,7 @@ class Telemetry:
|
||||
|
||||
def tool_usage_error(
|
||||
self, llm: Any, agent: Any = None, tool_name: Optional[str] = None
|
||||
):
|
||||
) -> None:
|
||||
"""Records when a tool usage results in an error.
|
||||
|
||||
Args:
|
||||
@@ -572,7 +572,7 @@ class Telemetry:
|
||||
|
||||
def individual_test_result_span(
|
||||
self, crew: Crew, quality: float, exec_time: int, model_name: str
|
||||
):
|
||||
) -> None:
|
||||
"""Records individual test results for a crew execution.
|
||||
|
||||
Args:
|
||||
@@ -607,7 +607,7 @@ class Telemetry:
|
||||
iterations: int,
|
||||
inputs: dict[str, Any] | None,
|
||||
model_name: str,
|
||||
):
|
||||
) -> None:
|
||||
"""Records the execution of a test suite for a crew.
|
||||
|
||||
Args:
|
||||
@@ -792,7 +792,8 @@ class Telemetry:
|
||||
|
||||
if crew.share_crew:
|
||||
self._safe_telemetry_operation(operation)
|
||||
return operation()
|
||||
result = operation()
|
||||
return result # type: ignore[no-any-return]
|
||||
return None
|
||||
|
||||
def end_crew(self, crew: Any, final_string_output: str) -> None:
|
||||
|
||||
Reference in New Issue
Block a user