fix: resolve mypy errors in storage and tracing modules

This commit is contained in:
Greyson LaLonde
2025-09-04 15:39:01 -04:00
parent 221bfcccce
commit bcee792390
2 changed files with 80 additions and 44 deletions

View File

@@ -8,6 +8,7 @@ from crewai.cli.authentication.token import AuthError, get_auth_token
from crewai.cli.version import get_crewai_version from crewai.cli.version import get_crewai_version
from crewai.events.base_event_listener import BaseEventListener from crewai.events.base_event_listener import BaseEventListener
from crewai.events.event_bus import CrewAIEventsBus from crewai.events.event_bus import CrewAIEventsBus
from crewai.events.listeners.tracing.trace_batch_manager import TraceBatchManager
from crewai.events.listeners.tracing.types import TraceEvent from crewai.events.listeners.tracing.types import TraceEvent
from crewai.events.types.agent_events import ( from crewai.events.types.agent_events import (
AgentExecutionCompletedEvent, AgentExecutionCompletedEvent,
@@ -65,8 +66,6 @@ from crewai.events.types.tool_usage_events import (
) )
from crewai.utilities.serialization import to_serializable from crewai.utilities.serialization import to_serializable
from .trace_batch_manager import TraceBatchManager
class TraceCollectionListener(BaseEventListener): class TraceCollectionListener(BaseEventListener):
""" """
@@ -86,7 +85,7 @@ class TraceCollectionListener(BaseEventListener):
_initialized = False _initialized = False
_listeners_setup = False _listeners_setup = False
def __new__(cls, batch_manager: Optional[Any] = None) -> "TraceCollectionListener": def __new__(cls, batch_manager: Optional[Any] = None) -> Self:
if cls._instance is None: if cls._instance is None:
cls._instance = super().__new__(cls) cls._instance = super().__new__(cls)
return cls._instance return cls._instance
@@ -99,10 +98,11 @@ class TraceCollectionListener(BaseEventListener):
return return
super().__init__() super().__init__()
self.batch_manager = batch_manager or TraceBatchManager() # type: ignore[call-arg] self.batch_manager = batch_manager or TraceBatchManager() # type: ignore
self._initialized = True self._initialized = True
def _check_authenticated(self) -> bool: @staticmethod
def _check_authenticated() -> bool:
"""Check if tracing should be enabled""" """Check if tracing should be enabled"""
try: try:
res = bool(get_auth_token()) res = bool(get_auth_token())
@@ -110,7 +110,8 @@ class TraceCollectionListener(BaseEventListener):
except AuthError: except AuthError:
return False return False
def _get_user_context(self) -> dict[str, str]: @staticmethod
def _get_user_context() -> dict[str, str]:
"""Extract user context for tracing""" """Extract user context for tracing"""
return { return {
"user_id": os.getenv("CREWAI_USER_ID", "anonymous"), "user_id": os.getenv("CREWAI_USER_ID", "anonymous"),
@@ -331,9 +332,7 @@ class TraceCollectionListener(BaseEventListener):
user_context, execution_metadata, use_ephemeral=True user_context, execution_metadata, use_ephemeral=True
) )
else: else:
self.batch_manager.initialize_batch( self.batch_manager.initialize_batch(user_context, execution_metadata)
user_context, execution_metadata, use_ephemeral=False
)
def _handle_trace_event(self, event_type: str, source: Any, event: Any) -> None: def _handle_trace_event(self, event_type: str, source: Any, event: Any) -> None:
"""Generic handler for context end events""" """Generic handler for context end events"""
@@ -424,11 +423,19 @@ class TraceCollectionListener(BaseEventListener):
"source": source, "source": source,
} }
# TODO: move to utils @staticmethod
def _safe_serialize_to_dict( def _safe_serialize_to_dict(
self, obj: Any, exclude: set[str] | None = None obj: Any, exclude: set[str] | None = None
) -> dict[str, Any]: ) -> dict[str, Any]:
"""Safely serialize an object to a dictionary for event data.""" """Safely serialize an object to a dictionary for event data.
Args:
obj: The object to serialize.
exclude: Optional set of attribute names to exclude from serialization.
Notes:
- TODO: refactor to utilities function.
"""
try: try:
serialized = to_serializable(obj, exclude) serialized = to_serializable(obj, exclude)
if isinstance(serialized, dict): if isinstance(serialized, dict):
@@ -438,11 +445,20 @@ class TraceCollectionListener(BaseEventListener):
except Exception as e: except Exception as e:
return {"serialization_error": str(e), "object_type": type(obj).__name__} return {"serialization_error": str(e), "object_type": type(obj).__name__}
# TODO: move to utils @staticmethod
def _truncate_messages( def _truncate_messages(
self, messages: Any, max_content_length: int = 500, max_messages: int = 5 messages: Any, max_content_length: int = 500, max_messages: int = 5
) -> Any: ) -> Any:
"""Truncate message content and limit number of messages""" """Truncate message content and limit number of messages
Args:
messages: List of message dicts with 'content' keys.
max_content_length: Max length of each message content.
max_messages: Max number of messages to retain.
Notes:
- TODO: refactor to utilities function.
"""
if not messages or not isinstance(messages, list): if not messages or not isinstance(messages, list):
return messages return messages

View File

@@ -22,6 +22,38 @@ from crewai.utilities.logger_utils import suppress_logging
from crewai.utilities.paths import db_storage_path from crewai.utilities.paths import db_storage_path
def _extract_chromadb_response_item(
response_data: Any,
index: int,
expected_type: type[Any] | tuple[type[Any], ...],
) -> Any | None:
"""Extract an item from ChromaDB response data at the given index.
Args:
response_data: The response data from ChromaDB query (e.g., documents, metadatas).
index: The index of the item to extract.
expected_type: The expected type(s) of the item.
Returns:
The extracted item if it exists and matches the expected type, otherwise None.
"""
if response_data is None or not response_data:
return None
# ChromaDB sometimes returns nested lists, handle both cases
data_list = (
response_data[0]
if response_data and isinstance(response_data[0], list)
else response_data
)
if index < len(data_list):
item = data_list[index]
if isinstance(item, expected_type):
return item
return None
class KnowledgeStorage(BaseKnowledgeStorage): class KnowledgeStorage(BaseKnowledgeStorage):
""" """
Extends Storage to handle embeddings for memory entries, improving Extends Storage to handle embeddings for memory entries, improving
@@ -71,37 +103,22 @@ class KnowledgeStorage(BaseKnowledgeStorage):
) )
for i in range(len(ids_list)): for i in range(len(ids_list)):
# Handle metadatas # Handle metadatas
metadata = {} meta_item = _extract_chromadb_response_item(
if fetched.get("metadatas") and len(fetched["metadatas"]) > 0: fetched.get("metadatas"), i, dict
metadata_list = ( )
fetched["metadatas"][0] metadata: dict[str, Any] = meta_item if meta_item else {}
if isinstance(fetched["metadatas"][0], list)
else fetched["metadatas"]
)
if i < len(metadata_list):
metadata = metadata_list[i]
# Handle documents # Handle documents
context = "" doc_item = _extract_chromadb_response_item(
if fetched.get("documents") and len(fetched["documents"]) > 0: fetched.get("documents"), i, str
docs_list = ( )
fetched["documents"][0] context = doc_item if doc_item else ""
if isinstance(fetched["documents"][0], list)
else fetched["documents"]
)
if i < len(docs_list):
context = docs_list[i]
# Handle distances # Handle distances
score = 1.0 dist_item = _extract_chromadb_response_item(
if fetched.get("distances") and len(fetched["distances"]) > 0: fetched.get("distances"), i, (int, float)
dist_list = ( )
fetched["distances"][0] score = dist_item if dist_item is not None else 1.0
if isinstance(fetched["distances"][0], list)
else fetched["distances"]
)
if i < len(dist_list):
score = dist_list[i]
result = { result = {
"id": ids_list[i], "id": ids_list[i],
@@ -231,11 +248,14 @@ class KnowledgeStorage(BaseKnowledgeStorage):
"""Set the embedding configuration for the knowledge storage. """Set the embedding configuration for the knowledge storage.
Args: Args:
embedder_config (Optional[Dict[str, Any]]): Configuration dictionary for the embedder. embedder (Optional[Dict[str, Any]]): Configuration dictionary for the embedder.
If None or empty, defaults to the default embedding function. If None or empty, defaults to the default embedding function.
Notes:
- TODO: Improve typing for embedder configuration, remove type: ignore
""" """
self.embedder = ( self.embedder = (
EmbeddingConfigurator().configure_embedder(embedder) EmbeddingConfigurator().configure_embedder(embedder) # type: ignore
if embedder if embedder
else self._create_default_embedding_function() else self._create_default_embedding_function()
) )