mirror of
https://github.com/crewAIInc/crewAI.git
synced 2026-01-09 08:08:32 +00:00
fix: resolve mypy errors in storage and tracing modules
This commit is contained in:
@@ -8,6 +8,7 @@ from crewai.cli.authentication.token import AuthError, get_auth_token
|
|||||||
from crewai.cli.version import get_crewai_version
|
from crewai.cli.version import get_crewai_version
|
||||||
from crewai.events.base_event_listener import BaseEventListener
|
from crewai.events.base_event_listener import BaseEventListener
|
||||||
from crewai.events.event_bus import CrewAIEventsBus
|
from crewai.events.event_bus import CrewAIEventsBus
|
||||||
|
from crewai.events.listeners.tracing.trace_batch_manager import TraceBatchManager
|
||||||
from crewai.events.listeners.tracing.types import TraceEvent
|
from crewai.events.listeners.tracing.types import TraceEvent
|
||||||
from crewai.events.types.agent_events import (
|
from crewai.events.types.agent_events import (
|
||||||
AgentExecutionCompletedEvent,
|
AgentExecutionCompletedEvent,
|
||||||
@@ -65,8 +66,6 @@ from crewai.events.types.tool_usage_events import (
|
|||||||
)
|
)
|
||||||
from crewai.utilities.serialization import to_serializable
|
from crewai.utilities.serialization import to_serializable
|
||||||
|
|
||||||
from .trace_batch_manager import TraceBatchManager
|
|
||||||
|
|
||||||
|
|
||||||
class TraceCollectionListener(BaseEventListener):
|
class TraceCollectionListener(BaseEventListener):
|
||||||
"""
|
"""
|
||||||
@@ -86,7 +85,7 @@ class TraceCollectionListener(BaseEventListener):
|
|||||||
_initialized = False
|
_initialized = False
|
||||||
_listeners_setup = False
|
_listeners_setup = False
|
||||||
|
|
||||||
def __new__(cls, batch_manager: Optional[Any] = None) -> "TraceCollectionListener":
|
def __new__(cls, batch_manager: Optional[Any] = None) -> Self:
|
||||||
if cls._instance is None:
|
if cls._instance is None:
|
||||||
cls._instance = super().__new__(cls)
|
cls._instance = super().__new__(cls)
|
||||||
return cls._instance
|
return cls._instance
|
||||||
@@ -99,10 +98,11 @@ class TraceCollectionListener(BaseEventListener):
|
|||||||
return
|
return
|
||||||
|
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self.batch_manager = batch_manager or TraceBatchManager() # type: ignore[call-arg]
|
self.batch_manager = batch_manager or TraceBatchManager() # type: ignore
|
||||||
self._initialized = True
|
self._initialized = True
|
||||||
|
|
||||||
def _check_authenticated(self) -> bool:
|
@staticmethod
|
||||||
|
def _check_authenticated() -> bool:
|
||||||
"""Check if tracing should be enabled"""
|
"""Check if tracing should be enabled"""
|
||||||
try:
|
try:
|
||||||
res = bool(get_auth_token())
|
res = bool(get_auth_token())
|
||||||
@@ -110,7 +110,8 @@ class TraceCollectionListener(BaseEventListener):
|
|||||||
except AuthError:
|
except AuthError:
|
||||||
return False
|
return False
|
||||||
|
|
||||||
def _get_user_context(self) -> dict[str, str]:
|
@staticmethod
|
||||||
|
def _get_user_context() -> dict[str, str]:
|
||||||
"""Extract user context for tracing"""
|
"""Extract user context for tracing"""
|
||||||
return {
|
return {
|
||||||
"user_id": os.getenv("CREWAI_USER_ID", "anonymous"),
|
"user_id": os.getenv("CREWAI_USER_ID", "anonymous"),
|
||||||
@@ -331,9 +332,7 @@ class TraceCollectionListener(BaseEventListener):
|
|||||||
user_context, execution_metadata, use_ephemeral=True
|
user_context, execution_metadata, use_ephemeral=True
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
self.batch_manager.initialize_batch(
|
self.batch_manager.initialize_batch(user_context, execution_metadata)
|
||||||
user_context, execution_metadata, use_ephemeral=False
|
|
||||||
)
|
|
||||||
|
|
||||||
def _handle_trace_event(self, event_type: str, source: Any, event: Any) -> None:
|
def _handle_trace_event(self, event_type: str, source: Any, event: Any) -> None:
|
||||||
"""Generic handler for context end events"""
|
"""Generic handler for context end events"""
|
||||||
@@ -424,11 +423,19 @@ class TraceCollectionListener(BaseEventListener):
|
|||||||
"source": source,
|
"source": source,
|
||||||
}
|
}
|
||||||
|
|
||||||
# TODO: move to utils
|
@staticmethod
|
||||||
def _safe_serialize_to_dict(
|
def _safe_serialize_to_dict(
|
||||||
self, obj: Any, exclude: set[str] | None = None
|
obj: Any, exclude: set[str] | None = None
|
||||||
) -> dict[str, Any]:
|
) -> dict[str, Any]:
|
||||||
"""Safely serialize an object to a dictionary for event data."""
|
"""Safely serialize an object to a dictionary for event data.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
obj: The object to serialize.
|
||||||
|
exclude: Optional set of attribute names to exclude from serialization.
|
||||||
|
|
||||||
|
Notes:
|
||||||
|
- TODO: refactor to utilities function.
|
||||||
|
"""
|
||||||
try:
|
try:
|
||||||
serialized = to_serializable(obj, exclude)
|
serialized = to_serializable(obj, exclude)
|
||||||
if isinstance(serialized, dict):
|
if isinstance(serialized, dict):
|
||||||
@@ -438,11 +445,20 @@ class TraceCollectionListener(BaseEventListener):
|
|||||||
except Exception as e:
|
except Exception as e:
|
||||||
return {"serialization_error": str(e), "object_type": type(obj).__name__}
|
return {"serialization_error": str(e), "object_type": type(obj).__name__}
|
||||||
|
|
||||||
# TODO: move to utils
|
@staticmethod
|
||||||
def _truncate_messages(
|
def _truncate_messages(
|
||||||
self, messages: Any, max_content_length: int = 500, max_messages: int = 5
|
messages: Any, max_content_length: int = 500, max_messages: int = 5
|
||||||
) -> Any:
|
) -> Any:
|
||||||
"""Truncate message content and limit number of messages"""
|
"""Truncate message content and limit number of messages
|
||||||
|
|
||||||
|
Args:
|
||||||
|
messages: List of message dicts with 'content' keys.
|
||||||
|
max_content_length: Max length of each message content.
|
||||||
|
max_messages: Max number of messages to retain.
|
||||||
|
|
||||||
|
Notes:
|
||||||
|
- TODO: refactor to utilities function.
|
||||||
|
"""
|
||||||
if not messages or not isinstance(messages, list):
|
if not messages or not isinstance(messages, list):
|
||||||
return messages
|
return messages
|
||||||
|
|
||||||
|
|||||||
@@ -22,6 +22,38 @@ from crewai.utilities.logger_utils import suppress_logging
|
|||||||
from crewai.utilities.paths import db_storage_path
|
from crewai.utilities.paths import db_storage_path
|
||||||
|
|
||||||
|
|
||||||
|
def _extract_chromadb_response_item(
|
||||||
|
response_data: Any,
|
||||||
|
index: int,
|
||||||
|
expected_type: type[Any] | tuple[type[Any], ...],
|
||||||
|
) -> Any | None:
|
||||||
|
"""Extract an item from ChromaDB response data at the given index.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
response_data: The response data from ChromaDB query (e.g., documents, metadatas).
|
||||||
|
index: The index of the item to extract.
|
||||||
|
expected_type: The expected type(s) of the item.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
The extracted item if it exists and matches the expected type, otherwise None.
|
||||||
|
"""
|
||||||
|
if response_data is None or not response_data:
|
||||||
|
return None
|
||||||
|
|
||||||
|
# ChromaDB sometimes returns nested lists, handle both cases
|
||||||
|
data_list = (
|
||||||
|
response_data[0]
|
||||||
|
if response_data and isinstance(response_data[0], list)
|
||||||
|
else response_data
|
||||||
|
)
|
||||||
|
|
||||||
|
if index < len(data_list):
|
||||||
|
item = data_list[index]
|
||||||
|
if isinstance(item, expected_type):
|
||||||
|
return item
|
||||||
|
return None
|
||||||
|
|
||||||
|
|
||||||
class KnowledgeStorage(BaseKnowledgeStorage):
|
class KnowledgeStorage(BaseKnowledgeStorage):
|
||||||
"""
|
"""
|
||||||
Extends Storage to handle embeddings for memory entries, improving
|
Extends Storage to handle embeddings for memory entries, improving
|
||||||
@@ -71,37 +103,22 @@ class KnowledgeStorage(BaseKnowledgeStorage):
|
|||||||
)
|
)
|
||||||
for i in range(len(ids_list)):
|
for i in range(len(ids_list)):
|
||||||
# Handle metadatas
|
# Handle metadatas
|
||||||
metadata = {}
|
meta_item = _extract_chromadb_response_item(
|
||||||
if fetched.get("metadatas") and len(fetched["metadatas"]) > 0:
|
fetched.get("metadatas"), i, dict
|
||||||
metadata_list = (
|
)
|
||||||
fetched["metadatas"][0]
|
metadata: dict[str, Any] = meta_item if meta_item else {}
|
||||||
if isinstance(fetched["metadatas"][0], list)
|
|
||||||
else fetched["metadatas"]
|
|
||||||
)
|
|
||||||
if i < len(metadata_list):
|
|
||||||
metadata = metadata_list[i]
|
|
||||||
|
|
||||||
# Handle documents
|
# Handle documents
|
||||||
context = ""
|
doc_item = _extract_chromadb_response_item(
|
||||||
if fetched.get("documents") and len(fetched["documents"]) > 0:
|
fetched.get("documents"), i, str
|
||||||
docs_list = (
|
)
|
||||||
fetched["documents"][0]
|
context = doc_item if doc_item else ""
|
||||||
if isinstance(fetched["documents"][0], list)
|
|
||||||
else fetched["documents"]
|
|
||||||
)
|
|
||||||
if i < len(docs_list):
|
|
||||||
context = docs_list[i]
|
|
||||||
|
|
||||||
# Handle distances
|
# Handle distances
|
||||||
score = 1.0
|
dist_item = _extract_chromadb_response_item(
|
||||||
if fetched.get("distances") and len(fetched["distances"]) > 0:
|
fetched.get("distances"), i, (int, float)
|
||||||
dist_list = (
|
)
|
||||||
fetched["distances"][0]
|
score = dist_item if dist_item is not None else 1.0
|
||||||
if isinstance(fetched["distances"][0], list)
|
|
||||||
else fetched["distances"]
|
|
||||||
)
|
|
||||||
if i < len(dist_list):
|
|
||||||
score = dist_list[i]
|
|
||||||
|
|
||||||
result = {
|
result = {
|
||||||
"id": ids_list[i],
|
"id": ids_list[i],
|
||||||
@@ -231,11 +248,14 @@ class KnowledgeStorage(BaseKnowledgeStorage):
|
|||||||
"""Set the embedding configuration for the knowledge storage.
|
"""Set the embedding configuration for the knowledge storage.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
embedder_config (Optional[Dict[str, Any]]): Configuration dictionary for the embedder.
|
embedder (Optional[Dict[str, Any]]): Configuration dictionary for the embedder.
|
||||||
If None or empty, defaults to the default embedding function.
|
If None or empty, defaults to the default embedding function.
|
||||||
|
|
||||||
|
Notes:
|
||||||
|
- TODO: Improve typing for embedder configuration, remove type: ignore
|
||||||
"""
|
"""
|
||||||
self.embedder = (
|
self.embedder = (
|
||||||
EmbeddingConfigurator().configure_embedder(embedder)
|
EmbeddingConfigurator().configure_embedder(embedder) # type: ignore
|
||||||
if embedder
|
if embedder
|
||||||
else self._create_default_embedding_function()
|
else self._create_default_embedding_function()
|
||||||
)
|
)
|
||||||
|
|||||||
Reference in New Issue
Block a user