fix: resolve mypy type annotation issues in storage and telemetry modules

- Add proper type parameters for EmbeddingFunction generics
- Fix ChromaDB query response handling with proper type checking
- Add missing return type annotations to telemetry methods
- Fix trace listener type annotations and imports
- Handle potential None values in nested list indexing
- Improve type safety in RAG and knowledge storage modules
Greyson LaLonde, 2025-09-04 14:58:28 -04:00
parent 23c60befd8
commit 4812986f58
4 changed files with 130 additions and 39 deletions
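
The bullet about handling potential None values in nested list indexing is the heart of the change. ChromaDB's query results are dict-like objects whose optional fields ("metadatas", "documents", "distances") each hold one nested list per query text, so naive chained indexing such as fetched["metadatas"][0][i] can blow up on a missing key, a None field, or a short list. A minimal sketch of the guarded-access pattern the diffs below apply, written as a standalone helper; the helper name extract_rows is an illustrative assumption, while the fallback score of 1.0 mirrors the diff:

from typing import Any


def extract_rows(fetched: dict[str, Any], score_threshold: float) -> list[dict[str, Any]]:
    # Hypothetical helper illustrating the guarded access used in the commit.
    results: list[dict[str, Any]] = []
    if not (fetched and fetched.get("ids")):
        return results

    def first_list(field: str) -> list[Any]:
        # Each field is typically a list of lists (one inner list per query text);
        # fall back to the outer list if it is already flat, or to [] if absent/None.
        value = fetched.get(field) or []
        return value[0] if value and isinstance(value[0], list) else value

    ids = first_list("ids")
    metadatas = first_list("metadatas")
    documents = first_list("documents")
    distances = first_list("distances")

    for i, id_ in enumerate(ids):
        score = distances[i] if i < len(distances) else 1.0
        # Distances are smaller when more similar, so keep rows at or below the threshold.
        if isinstance(score, (int, float)) and score <= score_threshold:
            results.append(
                {
                    "id": id_,
                    "metadata": metadatas[i] if i < len(metadatas) else {},
                    "context": documents[i] if i < len(documents) else "",
                    "score": score,
                }
            )
    return results

Both KnowledgeStorage.search and RAGStorage.search in the diffs below inline this same guard rather than sharing a helper, which is why the two large hunks look almost identical.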

View File

@@ -99,7 +99,7 @@ class TraceCollectionListener(BaseEventListener):
return
super().__init__()
self.batch_manager = batch_manager or TraceBatchManager()
self.batch_manager = batch_manager or TraceBatchManager() # type: ignore[call-arg]
self._initialized = True
def _check_authenticated(self) -> bool:
@@ -324,7 +324,7 @@ class TraceCollectionListener(BaseEventListener):
def _initialize_batch(
self, user_context: dict[str, str], execution_metadata: dict[str, Any]
):
) -> None:
"""Initialize trace batch if ephemeral"""
if not self._check_authenticated():
self.batch_manager.initialize_batch(
@@ -426,7 +426,7 @@ class TraceCollectionListener(BaseEventListener):
# TODO: move to utils
def _safe_serialize_to_dict(
self, obj, exclude: set[str] | None = None
self, obj: Any, exclude: set[str] | None = None
) -> dict[str, Any]:
"""Safely serialize an object to a dictionary for event data."""
try:

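The three hunks above are pure annotation work. A minimal reproduction of what mypy complains about when --disallow-untyped-defs (or --strict) is enabled, using placeholder class names rather than the project's real listener:

from typing import Any


class Before:
    # error: Function is missing a return type annotation  [no-untyped-def]
    def _initialize_batch(self, user_context: dict[str, str]):
        ...

    # error: Function is missing a type annotation for one or more arguments  [no-untyped-def]
    def _safe_serialize_to_dict(self, obj, exclude: set[str] | None = None) -> dict[str, Any]:
        return {}


class After:
    def _initialize_batch(self, user_context: dict[str, str]) -> None:
        ...

    def _safe_serialize_to_dict(self, obj: Any, exclude: set[str] | None = None) -> dict[str, Any]:
        return {}

The # type: ignore[call-arg] on the TraceBatchManager() call is a different kind of fix: call-arg fires when the arguments at a call site do not match the signature mypy sees for the callable, and the commit suppresses it on that one line rather than change the constructor, whose definition sits outside this diff.
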
View File

@@ -3,6 +3,7 @@ import logging
import os
import shutil
import warnings
from collections.abc import Mapping
from typing import Any, Optional, Union
import chromadb
@@ -30,7 +31,7 @@ class KnowledgeStorage(BaseKnowledgeStorage):
collection: Optional[chromadb.Collection] = None
collection_name: Optional[str] = "knowledge"
app: Optional[ClientAPI] = None
embedder: Optional[EmbeddingFunction] = None
embedder: Optional[EmbeddingFunction[Any]] = None
def __init__(
self,
@@ -57,17 +58,61 @@ class KnowledgeStorage(BaseKnowledgeStorage):
where=filter,
)
results = []
for i in range(len(fetched["ids"][0])):
result = {
"id": fetched["ids"][0][i],
"metadata": fetched["metadatas"][0][i],
"context": fetched["documents"][0][i],
"score": fetched["distances"][0][i],
}
if (
result["score"] <= score_threshold
): # Note: distances are smaller when more similar
results.append(result)
if (
fetched
and "ids" in fetched
and fetched["ids"]
and len(fetched["ids"]) > 0
):
ids_list = (
fetched["ids"][0]
if isinstance(fetched["ids"][0], list)
else fetched["ids"]
)
for i in range(len(ids_list)):
# Handle metadatas
metadata = {}
if fetched.get("metadatas") and len(fetched["metadatas"]) > 0:
metadata_list = (
fetched["metadatas"][0]
if isinstance(fetched["metadatas"][0], list)
else fetched["metadatas"]
)
if i < len(metadata_list):
metadata = metadata_list[i]
# Handle documents
context = ""
if fetched.get("documents") and len(fetched["documents"]) > 0:
docs_list = (
fetched["documents"][0]
if isinstance(fetched["documents"][0], list)
else fetched["documents"]
)
if i < len(docs_list):
context = docs_list[i]
# Handle distances
score = 1.0
if fetched.get("distances") and len(fetched["distances"]) > 0:
dist_list = (
fetched["distances"][0]
if isinstance(fetched["distances"][0], list)
else fetched["distances"]
)
if i < len(dist_list):
score = dist_list[i]
result = {
"id": ids_list[i],
"metadata": metadata,
"context": context,
"score": score,
}
# Check score threshold - distances are smaller when more similar
if isinstance(score, (int, float)) and score <= score_threshold:
results.append(result)
return results
else:
raise Exception("Collection not initialized")
@@ -150,7 +195,7 @@ class KnowledgeStorage(BaseKnowledgeStorage):
# If we have no metadata at all, set it to None
final_metadata: Optional[OneOrMany[chromadb.Metadata]] = (
None if all(m is None for m in filtered_metadata) else filtered_metadata
None if all(m is None for m in filtered_metadata) else filtered_metadata # type: ignore[assignment]
)
self.collection.upsert(

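On the embedder annotation in this file (and the matching one in RAGStorage below): chromadb's EmbeddingFunction is a generic protocol over the kind of input it embeds, so a bare EmbeddingFunction is rejected under mypy's --disallow-any-generics with the type-arg error code. A short sketch of the difference, assuming chromadb is installed; Documents is chromadb's alias for a list of text inputs:

from typing import Any, Optional

from chromadb.api.types import Documents, EmbeddingFunction

# Unparameterized generic: under --disallow-any-generics mypy reports
# error: Missing type parameters for generic type "EmbeddingFunction"  [type-arg]
embedder_loose: Optional[EmbeddingFunction] = None  # type: ignore[type-arg]

# What the commit switches to: [Any] accepts whichever concrete embedder the
# configurator produces, while still satisfying the generic check.
embedder: Optional[EmbeddingFunction[Any]] = None

# A tighter alternative if every embedder in play takes text documents:
embedder_text_only: Optional[EmbeddingFunction[Documents]] = None

Using [Any] is the low-friction choice here; narrowing to [Documents] later would only touch the annotations.
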
View File

@@ -23,7 +23,7 @@ class RAGStorage(BaseRAGStorage):
"""
app: ClientAPI | None = None
embedder_config: EmbeddingFunction | None = None # type: ignore[assignment]
embedder_config: EmbeddingFunction[Any] | None = None
def __init__(
self,
@@ -41,14 +41,22 @@ class RAGStorage(BaseRAGStorage):
self.storage_file_name = self._build_storage_file_name(type, agents)
self.type = type
self._original_embedder_config = (
embedder_config # Store for later use in _set_embedder_config
)
self.allow_reset = allow_reset
self.path = path
self._initialize_app()
def _set_embedder_config(self) -> None:
configurator = EmbeddingConfigurator()
self.embedder_config = configurator.configure_embedder(self.embedder_config)
# Pass the original embedder_config from __init__, not self.embedder_config
if hasattr(self, "_original_embedder_config"):
self.embedder_config = configurator.configure_embedder(
self._original_embedder_config
)
else:
self.embedder_config = configurator.configure_embedder(None)
def _initialize_app(self) -> None:
from chromadb.config import Settings
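
The _set_embedder_config hunk does more than add -> None. The class attribute embedder_config is typed for the configured EmbeddingFunction, while what callers pass to __init__ appears to be a plain config mapping, so the raw value is now stashed in _original_embedder_config and handed to the configurator instead of re-reading self.embedder_config (which by then may already hold a configured function, or nothing). A stripped-down sketch of that split; the dict shape of the constructor argument and the module-level placeholder for EmbeddingConfigurator are assumptions:

from typing import Any

from chromadb.api.types import EmbeddingFunction


def configure_embedder(config: dict[str, Any] | None) -> EmbeddingFunction[Any]:
    # Placeholder standing in for EmbeddingConfigurator().configure_embedder(...).
    raise NotImplementedError


class RAGStorageSketch:
    # Typed for the *configured* embedder, never for the raw config dict.
    embedder_config: EmbeddingFunction[Any] | None = None

    def __init__(self, embedder_config: dict[str, Any] | None = None) -> None:
        # Keep the raw value exactly as received; it is consumed lazily below.
        self._original_embedder_config = embedder_config

    def _set_embedder_config(self) -> None:
        self.embedder_config = configure_embedder(
            getattr(self, "_original_embedder_config", None)
        )
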
@@ -118,23 +126,60 @@ class RAGStorage(BaseRAGStorage):
response = self.collection.query(query_texts=query, n_results=limit)
results = []
if response and "ids" in response and response["ids"]:
for i in range(len(response["ids"][0])):
if (
response
and "ids" in response
and response["ids"]
and len(response["ids"]) > 0
):
ids_list = (
response["ids"][0]
if isinstance(response["ids"][0], list)
else response["ids"]
)
for i in range(len(ids_list)):
# Handle metadatas
metadata = {}
if response.get("metadatas") and len(response["metadatas"]) > 0:
metadata_list = (
response["metadatas"][0]
if isinstance(response["metadatas"][0], list)
else response["metadatas"]
)
if i < len(metadata_list):
metadata = metadata_list[i]
# Handle documents
context = ""
if response.get("documents") and len(response["documents"]) > 0:
docs_list = (
response["documents"][0]
if isinstance(response["documents"][0], list)
else response["documents"]
)
if i < len(docs_list):
context = docs_list[i]
# Handle distances
score = 1.0
if response.get("distances") and len(response["distances"]) > 0:
dist_list = (
response["distances"][0]
if isinstance(response["distances"][0], list)
else response["distances"]
)
if i < len(dist_list):
score = dist_list[i]
result = {
"id": response["ids"][0][i],
"metadata": response["metadatas"][0][i]
if response.get("metadatas")
else {},
"context": response["documents"][0][i]
if response.get("documents")
else "",
"score": response["distances"][0][i]
if response.get("distances")
else 1.0,
"id": ids_list[i],
"metadata": metadata,
"context": context,
"score": score,
}
if (
result["score"] <= score_threshold
): # Note: distances are smaller when more similar
# Check score threshold - distances are smaller when more similar
if isinstance(score, (int, float)) and score <= score_threshold:
results.append(result)
return results
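
To make the intent of all these guards concrete, here is the extract_rows sketch from the top of this page applied to a deliberately incomplete response (the values are contrived; real ChromaDB responses usually keep the nested lists aligned, but the commit's checks also survive None fields and short lists):

# A partial query response: no metadatas at all, and fewer distances than ids.
partial_response = {
    "ids": [["doc-1", "doc-2"]],
    "documents": [["first chunk", "second chunk"]],
    "metadatas": None,
    "distances": [[0.12]],
}

rows = extract_rows(partial_response, score_threshold=0.35)

# doc-1 keeps its real distance and passes the threshold; doc-2 falls back to the
# default score of 1.0 and is filtered out. Missing metadatas become {} instead of
# raising on fetched["metadatas"][0][i].
assert rows == [
    {"id": "doc-1", "metadata": {}, "context": "first chunk", "score": 0.12}
]
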
@@ -168,7 +213,7 @@ class RAGStorage(BaseRAGStorage):
f"An error occurred while resetting the {self.type} memory: {e}"
)
def _create_default_embedding_function(self):
def _create_default_embedding_function(self) -> EmbeddingFunction[Any]:
from chromadb.utils.embedding_functions.openai_embedding_function import (
OpenAIEmbeddingFunction,
)
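
The final hunk in this file only annotates the default-embedder factory's return type, but EmbeddingFunction[Any] (rather than the concrete OpenAIEmbeddingFunction) is the natural choice: callers only rely on the embedding-function protocol. A sketch of such a factory, where the environment variable and model name are illustrative assumptions, not values taken from the commit:

import os
from typing import Any

from chromadb.api.types import EmbeddingFunction
from chromadb.utils.embedding_functions.openai_embedding_function import (
    OpenAIEmbeddingFunction,
)


def create_default_embedding_function() -> EmbeddingFunction[Any]:
    # Annotating with the protocol keeps callers decoupled from the concrete
    # OpenAI-backed implementation; swapping providers only changes this body.
    return OpenAIEmbeddingFunction(
        api_key=os.getenv("OPENAI_API_KEY"),  # assumed env var
        model_name="text-embedding-3-small",  # assumed default model
    )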

View File

@@ -534,7 +534,7 @@ class Telemetry:
def tool_usage_error(
self, llm: Any, agent: Any = None, tool_name: Optional[str] = None
):
) -> None:
"""Records when a tool usage results in an error.
Args:
@@ -572,7 +572,7 @@ class Telemetry:
def individual_test_result_span(
self, crew: Crew, quality: float, exec_time: int, model_name: str
):
) -> None:
"""Records individual test results for a crew execution.
Args:
@@ -607,7 +607,7 @@ class Telemetry:
iterations: int,
inputs: dict[str, Any] | None,
model_name: str,
):
) -> None:
"""Records the execution of a test suite for a crew.
Args:
@@ -792,7 +792,8 @@ class Telemetry:
if crew.share_crew:
self._safe_telemetry_operation(operation)
return operation()
result = operation()
return result # type: ignore[no-any-return]
return None
def end_crew(self, crew: Any, final_string_output: str) -> None:
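
The telemetry hunks mirror the earlier no-untyped-def fixes (the three ) -> None: changes), plus one change of shape at the end: operation() is typed as returning Any, and with warn_return_any enabled mypy flags returning an Any value from a method declared with a concrete return type (error code no-any-return). A compressed sketch of that situation; the wrapper's behaviour and the Optional[str] return type are placeholders, not taken from the real Telemetry class:

from typing import Any, Callable, Optional


class TelemetrySketch:
    def _safe_telemetry_operation(self, operation: Callable[[], Any]) -> None:
        # The real wrapper guards the call so telemetry failures never break a run;
        # a bare try/except stands in for that here.
        try:
            operation()
        except Exception:
            pass

    def some_span(self, share_crew: bool) -> Optional[str]:
        def operation() -> Any:
            return "span-handle"  # stands in for an OpenTelemetry span object

        if share_crew:
            self._safe_telemetry_operation(operation)
            result = operation()
            # Without the ignore, mypy (warn_return_any) reports:
            # Returning Any from function declared to return "Optional[str]"  [no-any-return]
            return result  # type: ignore[no-any-return]
        return None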