From bcee792390a49c01ce51fc884244925ec96e5299 Mon Sep 17 00:00:00 2001 From: Greyson LaLonde Date: Thu, 4 Sep 2025 15:39:01 -0400 Subject: [PATCH] fix: resolve mypy errors in storage and tracing modules --- .../listeners/tracing/trace_listener.py | 46 +++++++---- .../knowledge/storage/knowledge_storage.py | 78 ++++++++++++------- 2 files changed, 80 insertions(+), 44 deletions(-) diff --git a/src/crewai/events/listeners/tracing/trace_listener.py b/src/crewai/events/listeners/tracing/trace_listener.py index e76e005ef..097f98a7e 100644 --- a/src/crewai/events/listeners/tracing/trace_listener.py +++ b/src/crewai/events/listeners/tracing/trace_listener.py @@ -8,6 +8,7 @@ from crewai.cli.authentication.token import AuthError, get_auth_token from crewai.cli.version import get_crewai_version from crewai.events.base_event_listener import BaseEventListener from crewai.events.event_bus import CrewAIEventsBus +from crewai.events.listeners.tracing.trace_batch_manager import TraceBatchManager from crewai.events.listeners.tracing.types import TraceEvent from crewai.events.types.agent_events import ( AgentExecutionCompletedEvent, @@ -65,8 +66,6 @@ from crewai.events.types.tool_usage_events import ( ) from crewai.utilities.serialization import to_serializable -from .trace_batch_manager import TraceBatchManager - class TraceCollectionListener(BaseEventListener): """ @@ -86,7 +85,7 @@ class TraceCollectionListener(BaseEventListener): _initialized = False _listeners_setup = False - def __new__(cls, batch_manager: Optional[Any] = None) -> "TraceCollectionListener": + def __new__(cls, batch_manager: Optional[Any] = None) -> Self: if cls._instance is None: cls._instance = super().__new__(cls) return cls._instance @@ -99,10 +98,11 @@ class TraceCollectionListener(BaseEventListener): return super().__init__() - self.batch_manager = batch_manager or TraceBatchManager() # type: ignore[call-arg] + self.batch_manager = batch_manager or TraceBatchManager() # type: ignore self._initialized = True - def _check_authenticated(self) -> bool: + @staticmethod + def _check_authenticated() -> bool: """Check if tracing should be enabled""" try: res = bool(get_auth_token()) @@ -110,7 +110,8 @@ class TraceCollectionListener(BaseEventListener): except AuthError: return False - def _get_user_context(self) -> dict[str, str]: + @staticmethod + def _get_user_context() -> dict[str, str]: """Extract user context for tracing""" return { "user_id": os.getenv("CREWAI_USER_ID", "anonymous"), @@ -331,9 +332,7 @@ class TraceCollectionListener(BaseEventListener): user_context, execution_metadata, use_ephemeral=True ) else: - self.batch_manager.initialize_batch( - user_context, execution_metadata, use_ephemeral=False - ) + self.batch_manager.initialize_batch(user_context, execution_metadata) def _handle_trace_event(self, event_type: str, source: Any, event: Any) -> None: """Generic handler for context end events""" @@ -424,11 +423,19 @@ class TraceCollectionListener(BaseEventListener): "source": source, } - # TODO: move to utils + @staticmethod def _safe_serialize_to_dict( - self, obj: Any, exclude: set[str] | None = None + obj: Any, exclude: set[str] | None = None ) -> dict[str, Any]: - """Safely serialize an object to a dictionary for event data.""" + """Safely serialize an object to a dictionary for event data. + + Args: + obj: The object to serialize. + exclude: Optional set of attribute names to exclude from serialization. + + Notes: + - TODO: refactor to utilities function. + """ try: serialized = to_serializable(obj, exclude) if isinstance(serialized, dict): @@ -438,11 +445,20 @@ class TraceCollectionListener(BaseEventListener): except Exception as e: return {"serialization_error": str(e), "object_type": type(obj).__name__} - # TODO: move to utils + @staticmethod def _truncate_messages( - self, messages: Any, max_content_length: int = 500, max_messages: int = 5 + messages: Any, max_content_length: int = 500, max_messages: int = 5 ) -> Any: - """Truncate message content and limit number of messages""" + """Truncate message content and limit number of messages + + Args: + messages: List of message dicts with 'content' keys. + max_content_length: Max length of each message content. + max_messages: Max number of messages to retain. + + Notes: + - TODO: refactor to utilities function. + """ if not messages or not isinstance(messages, list): return messages diff --git a/src/crewai/knowledge/storage/knowledge_storage.py b/src/crewai/knowledge/storage/knowledge_storage.py index 04e36b283..5a811694c 100644 --- a/src/crewai/knowledge/storage/knowledge_storage.py +++ b/src/crewai/knowledge/storage/knowledge_storage.py @@ -22,6 +22,38 @@ from crewai.utilities.logger_utils import suppress_logging from crewai.utilities.paths import db_storage_path +def _extract_chromadb_response_item( + response_data: Any, + index: int, + expected_type: type[Any] | tuple[type[Any], ...], +) -> Any | None: + """Extract an item from ChromaDB response data at the given index. + + Args: + response_data: The response data from ChromaDB query (e.g., documents, metadatas). + index: The index of the item to extract. + expected_type: The expected type(s) of the item. + + Returns: + The extracted item if it exists and matches the expected type, otherwise None. + """ + if response_data is None or not response_data: + return None + + # ChromaDB sometimes returns nested lists, handle both cases + data_list = ( + response_data[0] + if response_data and isinstance(response_data[0], list) + else response_data + ) + + if index < len(data_list): + item = data_list[index] + if isinstance(item, expected_type): + return item + return None + + class KnowledgeStorage(BaseKnowledgeStorage): """ Extends Storage to handle embeddings for memory entries, improving @@ -71,37 +103,22 @@ class KnowledgeStorage(BaseKnowledgeStorage): ) for i in range(len(ids_list)): # Handle metadatas - metadata = {} - if fetched.get("metadatas") and len(fetched["metadatas"]) > 0: - metadata_list = ( - fetched["metadatas"][0] - if isinstance(fetched["metadatas"][0], list) - else fetched["metadatas"] - ) - if i < len(metadata_list): - metadata = metadata_list[i] + meta_item = _extract_chromadb_response_item( + fetched.get("metadatas"), i, dict + ) + metadata: dict[str, Any] = meta_item if meta_item else {} # Handle documents - context = "" - if fetched.get("documents") and len(fetched["documents"]) > 0: - docs_list = ( - fetched["documents"][0] - if isinstance(fetched["documents"][0], list) - else fetched["documents"] - ) - if i < len(docs_list): - context = docs_list[i] + doc_item = _extract_chromadb_response_item( + fetched.get("documents"), i, str + ) + context = doc_item if doc_item else "" # Handle distances - score = 1.0 - if fetched.get("distances") and len(fetched["distances"]) > 0: - dist_list = ( - fetched["distances"][0] - if isinstance(fetched["distances"][0], list) - else fetched["distances"] - ) - if i < len(dist_list): - score = dist_list[i] + dist_item = _extract_chromadb_response_item( + fetched.get("distances"), i, (int, float) + ) + score = dist_item if dist_item is not None else 1.0 result = { "id": ids_list[i], @@ -231,11 +248,14 @@ class KnowledgeStorage(BaseKnowledgeStorage): """Set the embedding configuration for the knowledge storage. Args: - embedder_config (Optional[Dict[str, Any]]): Configuration dictionary for the embedder. + embedder (Optional[Dict[str, Any]]): Configuration dictionary for the embedder. If None or empty, defaults to the default embedding function. + + Notes: + - TODO: Improve typing for embedder configuration, remove type: ignore """ self.embedder = ( - EmbeddingConfigurator().configure_embedder(embedder) + EmbeddingConfigurator().configure_embedder(embedder) # type: ignore if embedder else self._create_default_embedding_function() )