fix: resolve mypy type annotation issues in storage and telemetry modules

- Add proper type parameters for EmbeddingFunction generics
- Fix ChromaDB query response handling with proper type checking
- Add missing return type annotations to telemetry methods
- Fix trace listener type annotations and imports
- Handle potential None values in nested list indexing
- Improve type safety in RAG and knowledge storage modules
This commit is contained in:
Greyson LaLonde
2025-09-04 14:58:28 -04:00
parent 23c60befd8
commit 4812986f58
4 changed files with 130 additions and 39 deletions

View File

@@ -99,7 +99,7 @@ class TraceCollectionListener(BaseEventListener):
return return
super().__init__() super().__init__()
self.batch_manager = batch_manager or TraceBatchManager() self.batch_manager = batch_manager or TraceBatchManager() # type: ignore[call-arg]
self._initialized = True self._initialized = True
def _check_authenticated(self) -> bool: def _check_authenticated(self) -> bool:
@@ -324,7 +324,7 @@ class TraceCollectionListener(BaseEventListener):
def _initialize_batch( def _initialize_batch(
self, user_context: dict[str, str], execution_metadata: dict[str, Any] self, user_context: dict[str, str], execution_metadata: dict[str, Any]
): ) -> None:
"""Initialize trace batch if ephemeral""" """Initialize trace batch if ephemeral"""
if not self._check_authenticated(): if not self._check_authenticated():
self.batch_manager.initialize_batch( self.batch_manager.initialize_batch(
@@ -426,7 +426,7 @@ class TraceCollectionListener(BaseEventListener):
# TODO: move to utils # TODO: move to utils
def _safe_serialize_to_dict( def _safe_serialize_to_dict(
self, obj, exclude: set[str] | None = None self, obj: Any, exclude: set[str] | None = None
) -> dict[str, Any]: ) -> dict[str, Any]:
"""Safely serialize an object to a dictionary for event data.""" """Safely serialize an object to a dictionary for event data."""
try: try:

View File

@@ -3,6 +3,7 @@ import logging
import os import os
import shutil import shutil
import warnings import warnings
from collections.abc import Mapping
from typing import Any, Optional, Union from typing import Any, Optional, Union
import chromadb import chromadb
@@ -30,7 +31,7 @@ class KnowledgeStorage(BaseKnowledgeStorage):
collection: Optional[chromadb.Collection] = None collection: Optional[chromadb.Collection] = None
collection_name: Optional[str] = "knowledge" collection_name: Optional[str] = "knowledge"
app: Optional[ClientAPI] = None app: Optional[ClientAPI] = None
embedder: Optional[EmbeddingFunction] = None embedder: Optional[EmbeddingFunction[Any]] = None
def __init__( def __init__(
self, self,
@@ -57,17 +58,61 @@ class KnowledgeStorage(BaseKnowledgeStorage):
where=filter, where=filter,
) )
results = [] results = []
for i in range(len(fetched["ids"][0])): if (
result = { fetched
"id": fetched["ids"][0][i], and "ids" in fetched
"metadata": fetched["metadatas"][0][i], and fetched["ids"]
"context": fetched["documents"][0][i], and len(fetched["ids"]) > 0
"score": fetched["distances"][0][i], ):
} ids_list = (
if ( fetched["ids"][0]
result["score"] <= score_threshold if isinstance(fetched["ids"][0], list)
): # Note: distances are smaller when more similar else fetched["ids"]
results.append(result) )
for i in range(len(ids_list)):
# Handle metadatas
metadata = {}
if fetched.get("metadatas") and len(fetched["metadatas"]) > 0:
metadata_list = (
fetched["metadatas"][0]
if isinstance(fetched["metadatas"][0], list)
else fetched["metadatas"]
)
if i < len(metadata_list):
metadata = metadata_list[i]
# Handle documents
context = ""
if fetched.get("documents") and len(fetched["documents"]) > 0:
docs_list = (
fetched["documents"][0]
if isinstance(fetched["documents"][0], list)
else fetched["documents"]
)
if i < len(docs_list):
context = docs_list[i]
# Handle distances
score = 1.0
if fetched.get("distances") and len(fetched["distances"]) > 0:
dist_list = (
fetched["distances"][0]
if isinstance(fetched["distances"][0], list)
else fetched["distances"]
)
if i < len(dist_list):
score = dist_list[i]
result = {
"id": ids_list[i],
"metadata": metadata,
"context": context,
"score": score,
}
# Check score threshold - distances are smaller when more similar
if isinstance(score, (int, float)) and score <= score_threshold:
results.append(result)
return results return results
else: else:
raise Exception("Collection not initialized") raise Exception("Collection not initialized")
@@ -150,7 +195,7 @@ class KnowledgeStorage(BaseKnowledgeStorage):
# If we have no metadata at all, set it to None # If we have no metadata at all, set it to None
final_metadata: Optional[OneOrMany[chromadb.Metadata]] = ( final_metadata: Optional[OneOrMany[chromadb.Metadata]] = (
None if all(m is None for m in filtered_metadata) else filtered_metadata None if all(m is None for m in filtered_metadata) else filtered_metadata # type: ignore[assignment]
) )
self.collection.upsert( self.collection.upsert(

View File

@@ -23,7 +23,7 @@ class RAGStorage(BaseRAGStorage):
""" """
app: ClientAPI | None = None app: ClientAPI | None = None
embedder_config: EmbeddingFunction | None = None # type: ignore[assignment] embedder_config: EmbeddingFunction[Any] | None = None
def __init__( def __init__(
self, self,
@@ -41,14 +41,22 @@ class RAGStorage(BaseRAGStorage):
self.storage_file_name = self._build_storage_file_name(type, agents) self.storage_file_name = self._build_storage_file_name(type, agents)
self.type = type self.type = type
self._original_embedder_config = (
embedder_config # Store for later use in _set_embedder_config
)
self.allow_reset = allow_reset self.allow_reset = allow_reset
self.path = path self.path = path
self._initialize_app() self._initialize_app()
def _set_embedder_config(self) -> None: def _set_embedder_config(self) -> None:
configurator = EmbeddingConfigurator() configurator = EmbeddingConfigurator()
self.embedder_config = configurator.configure_embedder(self.embedder_config) # Pass the original embedder_config from __init__, not self.embedder_config
if hasattr(self, "_original_embedder_config"):
self.embedder_config = configurator.configure_embedder(
self._original_embedder_config
)
else:
self.embedder_config = configurator.configure_embedder(None)
def _initialize_app(self) -> None: def _initialize_app(self) -> None:
from chromadb.config import Settings from chromadb.config import Settings
@@ -118,23 +126,60 @@ class RAGStorage(BaseRAGStorage):
response = self.collection.query(query_texts=query, n_results=limit) response = self.collection.query(query_texts=query, n_results=limit)
results = [] results = []
if response and "ids" in response and response["ids"]: if (
for i in range(len(response["ids"][0])): response
and "ids" in response
and response["ids"]
and len(response["ids"]) > 0
):
ids_list = (
response["ids"][0]
if isinstance(response["ids"][0], list)
else response["ids"]
)
for i in range(len(ids_list)):
# Handle metadatas
metadata = {}
if response.get("metadatas") and len(response["metadatas"]) > 0:
metadata_list = (
response["metadatas"][0]
if isinstance(response["metadatas"][0], list)
else response["metadatas"]
)
if i < len(metadata_list):
metadata = metadata_list[i]
# Handle documents
context = ""
if response.get("documents") and len(response["documents"]) > 0:
docs_list = (
response["documents"][0]
if isinstance(response["documents"][0], list)
else response["documents"]
)
if i < len(docs_list):
context = docs_list[i]
# Handle distances
score = 1.0
if response.get("distances") and len(response["distances"]) > 0:
dist_list = (
response["distances"][0]
if isinstance(response["distances"][0], list)
else response["distances"]
)
if i < len(dist_list):
score = dist_list[i]
result = { result = {
"id": response["ids"][0][i], "id": ids_list[i],
"metadata": response["metadatas"][0][i] "metadata": metadata,
if response.get("metadatas") "context": context,
else {}, "score": score,
"context": response["documents"][0][i]
if response.get("documents")
else "",
"score": response["distances"][0][i]
if response.get("distances")
else 1.0,
} }
if (
result["score"] <= score_threshold # Check score threshold - distances are smaller when more similar
): # Note: distances are smaller when more similar if isinstance(score, (int, float)) and score <= score_threshold:
results.append(result) results.append(result)
return results return results
@@ -168,7 +213,7 @@ class RAGStorage(BaseRAGStorage):
f"An error occurred while resetting the {self.type} memory: {e}" f"An error occurred while resetting the {self.type} memory: {e}"
) )
def _create_default_embedding_function(self): def _create_default_embedding_function(self) -> EmbeddingFunction[Any]:
from chromadb.utils.embedding_functions.openai_embedding_function import ( from chromadb.utils.embedding_functions.openai_embedding_function import (
OpenAIEmbeddingFunction, OpenAIEmbeddingFunction,
) )

View File

@@ -534,7 +534,7 @@ class Telemetry:
def tool_usage_error( def tool_usage_error(
self, llm: Any, agent: Any = None, tool_name: Optional[str] = None self, llm: Any, agent: Any = None, tool_name: Optional[str] = None
): ) -> None:
"""Records when a tool usage results in an error. """Records when a tool usage results in an error.
Args: Args:
@@ -572,7 +572,7 @@ class Telemetry:
def individual_test_result_span( def individual_test_result_span(
self, crew: Crew, quality: float, exec_time: int, model_name: str self, crew: Crew, quality: float, exec_time: int, model_name: str
): ) -> None:
"""Records individual test results for a crew execution. """Records individual test results for a crew execution.
Args: Args:
@@ -607,7 +607,7 @@ class Telemetry:
iterations: int, iterations: int,
inputs: dict[str, Any] | None, inputs: dict[str, Any] | None,
model_name: str, model_name: str,
): ) -> None:
"""Records the execution of a test suite for a crew. """Records the execution of a test suite for a crew.
Args: Args:
@@ -792,7 +792,8 @@ class Telemetry:
if crew.share_crew: if crew.share_crew:
self._safe_telemetry_operation(operation) self._safe_telemetry_operation(operation)
return operation() result = operation()
return result # type: ignore[no-any-return]
return None return None
def end_crew(self, crew: Any, final_string_output: str) -> None: def end_crew(self, crew: Any, final_string_output: str) -> None: