fix: resolve additional mypy type annotation issues

- Fixed rag_storage.py embedder type compatibility and query response handling
- Fixed knowledge_storage.py dict type parameters and return types
- Added comprehensive type annotations to telemetry.py methods
- Added type annotations to trace_listener.py event handlers and methods
- Fixed ChromaDB response indexing safety checks
This commit is contained in:
Greyson LaLonde
2025-09-04 13:22:36 -04:00
parent 8dd3493e9c
commit 23c60befd8
4 changed files with 125 additions and 102 deletions

View File

@@ -2,6 +2,8 @@ import os
import uuid import uuid
from typing import Any, Optional from typing import Any, Optional
from typing_extensions import Self
from crewai.cli.authentication.token import AuthError, get_auth_token from crewai.cli.authentication.token import AuthError, get_auth_token
from crewai.cli.version import get_crewai_version from crewai.cli.version import get_crewai_version
from crewai.events.base_event_listener import BaseEventListener from crewai.events.base_event_listener import BaseEventListener
@@ -84,7 +86,7 @@ class TraceCollectionListener(BaseEventListener):
_initialized = False _initialized = False
_listeners_setup = False _listeners_setup = False
def __new__(cls, batch_manager=None): def __new__(cls, batch_manager: Optional[Any] = None) -> "TraceCollectionListener":
if cls._instance is None: if cls._instance is None:
cls._instance = super().__new__(cls) cls._instance = super().__new__(cls)
return cls._instance return cls._instance
@@ -129,169 +131,169 @@ class TraceCollectionListener(BaseEventListener):
self._listeners_setup = True self._listeners_setup = True
def _register_flow_event_handlers(self, event_bus): def _register_flow_event_handlers(self, event_bus: CrewAIEventsBus) -> None:
"""Register handlers for flow events""" """Register handlers for flow events"""
@event_bus.on(FlowCreatedEvent) @event_bus.on(FlowCreatedEvent)
def on_flow_created(source, event): def on_flow_created(source: Any, event: Any) -> None:
pass pass
@event_bus.on(FlowStartedEvent) @event_bus.on(FlowStartedEvent)
def on_flow_started(source, event): def on_flow_started(source: Any, event: Any) -> None:
if not self.batch_manager.is_batch_initialized(): if not self.batch_manager.is_batch_initialized():
self._initialize_flow_batch(source, event) self._initialize_flow_batch(source, event)
self._handle_trace_event("flow_started", source, event) self._handle_trace_event("flow_started", source, event)
@event_bus.on(MethodExecutionStartedEvent) @event_bus.on(MethodExecutionStartedEvent)
def on_method_started(source, event): def on_method_started(source: Any, event: Any) -> None:
self._handle_trace_event("method_execution_started", source, event) self._handle_trace_event("method_execution_started", source, event)
@event_bus.on(MethodExecutionFinishedEvent) @event_bus.on(MethodExecutionFinishedEvent)
def on_method_finished(source, event): def on_method_finished(source: Any, event: Any) -> None:
self._handle_trace_event("method_execution_finished", source, event) self._handle_trace_event("method_execution_finished", source, event)
@event_bus.on(MethodExecutionFailedEvent) @event_bus.on(MethodExecutionFailedEvent)
def on_method_failed(source, event): def on_method_failed(source: Any, event: Any) -> None:
self._handle_trace_event("method_execution_failed", source, event) self._handle_trace_event("method_execution_failed", source, event)
@event_bus.on(FlowFinishedEvent) @event_bus.on(FlowFinishedEvent)
def on_flow_finished(source, event): def on_flow_finished(source: Any, event: Any) -> None:
self._handle_trace_event("flow_finished", source, event) self._handle_trace_event("flow_finished", source, event)
if self.batch_manager.batch_owner_type == "flow": if self.batch_manager.batch_owner_type == "flow":
self.batch_manager.finalize_batch() self.batch_manager.finalize_batch()
@event_bus.on(FlowPlotEvent) @event_bus.on(FlowPlotEvent)
def on_flow_plot(source, event): def on_flow_plot(source: Any, event: Any) -> None:
self._handle_action_event("flow_plot", source, event) self._handle_action_event("flow_plot", source, event)
def _register_context_event_handlers(self, event_bus): def _register_context_event_handlers(self, event_bus: CrewAIEventsBus) -> None:
"""Register handlers for context events (start/end)""" """Register handlers for context events (start/end)"""
@event_bus.on(CrewKickoffStartedEvent) @event_bus.on(CrewKickoffStartedEvent)
def on_crew_started(source, event): def on_crew_started(source: Any, event: Any) -> None:
if not self.batch_manager.is_batch_initialized(): if not self.batch_manager.is_batch_initialized():
self._initialize_crew_batch(source, event) self._initialize_crew_batch(source, event)
self._handle_trace_event("crew_kickoff_started", source, event) self._handle_trace_event("crew_kickoff_started", source, event)
@event_bus.on(CrewKickoffCompletedEvent) @event_bus.on(CrewKickoffCompletedEvent)
def on_crew_completed(source, event): def on_crew_completed(source: Any, event: Any) -> None:
self._handle_trace_event("crew_kickoff_completed", source, event) self._handle_trace_event("crew_kickoff_completed", source, event)
if self.batch_manager.batch_owner_type == "crew": if self.batch_manager.batch_owner_type == "crew":
self.batch_manager.finalize_batch() self.batch_manager.finalize_batch()
@event_bus.on(CrewKickoffFailedEvent) @event_bus.on(CrewKickoffFailedEvent)
def on_crew_failed(source, event): def on_crew_failed(source: Any, event: Any) -> None:
self._handle_trace_event("crew_kickoff_failed", source, event) self._handle_trace_event("crew_kickoff_failed", source, event)
self.batch_manager.finalize_batch() self.batch_manager.finalize_batch()
@event_bus.on(TaskStartedEvent) @event_bus.on(TaskStartedEvent)
def on_task_started(source, event): def on_task_started(source: Any, event: Any) -> None:
self._handle_trace_event("task_started", source, event) self._handle_trace_event("task_started", source, event)
@event_bus.on(TaskCompletedEvent) @event_bus.on(TaskCompletedEvent)
def on_task_completed(source, event): def on_task_completed(source: Any, event: Any) -> None:
self._handle_trace_event("task_completed", source, event) self._handle_trace_event("task_completed", source, event)
@event_bus.on(TaskFailedEvent) @event_bus.on(TaskFailedEvent)
def on_task_failed(source, event): def on_task_failed(source: Any, event: Any) -> None:
self._handle_trace_event("task_failed", source, event) self._handle_trace_event("task_failed", source, event)
@event_bus.on(AgentExecutionStartedEvent) @event_bus.on(AgentExecutionStartedEvent)
def on_agent_started(source, event): def on_agent_started(source: Any, event: Any) -> None:
self._handle_trace_event("agent_execution_started", source, event) self._handle_trace_event("agent_execution_started", source, event)
@event_bus.on(AgentExecutionCompletedEvent) @event_bus.on(AgentExecutionCompletedEvent)
def on_agent_completed(source, event): def on_agent_completed(source: Any, event: Any) -> None:
self._handle_trace_event("agent_execution_completed", source, event) self._handle_trace_event("agent_execution_completed", source, event)
@event_bus.on(LiteAgentExecutionStartedEvent) @event_bus.on(LiteAgentExecutionStartedEvent)
def on_lite_agent_started(source, event): def on_lite_agent_started(source: Any, event: Any) -> None:
self._handle_trace_event("lite_agent_execution_started", source, event) self._handle_trace_event("lite_agent_execution_started", source, event)
@event_bus.on(LiteAgentExecutionCompletedEvent) @event_bus.on(LiteAgentExecutionCompletedEvent)
def on_lite_agent_completed(source, event): def on_lite_agent_completed(source: Any, event: Any) -> None:
self._handle_trace_event("lite_agent_execution_completed", source, event) self._handle_trace_event("lite_agent_execution_completed", source, event)
@event_bus.on(LiteAgentExecutionErrorEvent) @event_bus.on(LiteAgentExecutionErrorEvent)
def on_lite_agent_error(source, event): def on_lite_agent_error(source: Any, event: Any) -> None:
self._handle_trace_event("lite_agent_execution_error", source, event) self._handle_trace_event("lite_agent_execution_error", source, event)
@event_bus.on(AgentExecutionErrorEvent) @event_bus.on(AgentExecutionErrorEvent)
def on_agent_error(source, event): def on_agent_error(source: Any, event: Any) -> None:
self._handle_trace_event("agent_execution_error", source, event) self._handle_trace_event("agent_execution_error", source, event)
@event_bus.on(LLMGuardrailStartedEvent) @event_bus.on(LLMGuardrailStartedEvent)
def on_guardrail_started(source, event): def on_guardrail_started(source: Any, event: Any) -> None:
self._handle_trace_event("llm_guardrail_started", source, event) self._handle_trace_event("llm_guardrail_started", source, event)
@event_bus.on(LLMGuardrailCompletedEvent) @event_bus.on(LLMGuardrailCompletedEvent)
def on_guardrail_completed(source, event): def on_guardrail_completed(source: Any, event: Any) -> None:
self._handle_trace_event("llm_guardrail_completed", source, event) self._handle_trace_event("llm_guardrail_completed", source, event)
def _register_action_event_handlers(self, event_bus): def _register_action_event_handlers(self, event_bus: CrewAIEventsBus) -> None:
"""Register handlers for action events (LLM calls, tool usage)""" """Register handlers for action events (LLM calls, tool usage)"""
@event_bus.on(LLMCallStartedEvent) @event_bus.on(LLMCallStartedEvent)
def on_llm_call_started(source, event): def on_llm_call_started(source: Any, event: Any) -> None:
self._handle_action_event("llm_call_started", source, event) self._handle_action_event("llm_call_started", source, event)
@event_bus.on(LLMCallCompletedEvent) @event_bus.on(LLMCallCompletedEvent)
def on_llm_call_completed(source, event): def on_llm_call_completed(source: Any, event: Any) -> None:
self._handle_action_event("llm_call_completed", source, event) self._handle_action_event("llm_call_completed", source, event)
@event_bus.on(LLMCallFailedEvent) @event_bus.on(LLMCallFailedEvent)
def on_llm_call_failed(source, event): def on_llm_call_failed(source: Any, event: Any) -> None:
self._handle_action_event("llm_call_failed", source, event) self._handle_action_event("llm_call_failed", source, event)
@event_bus.on(ToolUsageStartedEvent) @event_bus.on(ToolUsageStartedEvent)
def on_tool_started(source, event): def on_tool_started(source: Any, event: Any) -> None:
self._handle_action_event("tool_usage_started", source, event) self._handle_action_event("tool_usage_started", source, event)
@event_bus.on(ToolUsageFinishedEvent) @event_bus.on(ToolUsageFinishedEvent)
def on_tool_finished(source, event): def on_tool_finished(source: Any, event: Any) -> None:
self._handle_action_event("tool_usage_finished", source, event) self._handle_action_event("tool_usage_finished", source, event)
@event_bus.on(ToolUsageErrorEvent) @event_bus.on(ToolUsageErrorEvent)
def on_tool_error(source, event): def on_tool_error(source: Any, event: Any) -> None:
self._handle_action_event("tool_usage_error", source, event) self._handle_action_event("tool_usage_error", source, event)
@event_bus.on(MemoryQueryStartedEvent) @event_bus.on(MemoryQueryStartedEvent)
def on_memory_query_started(source, event): def on_memory_query_started(source: Any, event: Any) -> None:
self._handle_action_event("memory_query_started", source, event) self._handle_action_event("memory_query_started", source, event)
@event_bus.on(MemoryQueryCompletedEvent) @event_bus.on(MemoryQueryCompletedEvent)
def on_memory_query_completed(source, event): def on_memory_query_completed(source: Any, event: Any) -> None:
self._handle_action_event("memory_query_completed", source, event) self._handle_action_event("memory_query_completed", source, event)
@event_bus.on(MemoryQueryFailedEvent) @event_bus.on(MemoryQueryFailedEvent)
def on_memory_query_failed(source, event): def on_memory_query_failed(source: Any, event: Any) -> None:
self._handle_action_event("memory_query_failed", source, event) self._handle_action_event("memory_query_failed", source, event)
@event_bus.on(MemorySaveStartedEvent) @event_bus.on(MemorySaveStartedEvent)
def on_memory_save_started(source, event): def on_memory_save_started(source: Any, event: Any) -> None:
self._handle_action_event("memory_save_started", source, event) self._handle_action_event("memory_save_started", source, event)
@event_bus.on(MemorySaveCompletedEvent) @event_bus.on(MemorySaveCompletedEvent)
def on_memory_save_completed(source, event): def on_memory_save_completed(source: Any, event: Any) -> None:
self._handle_action_event("memory_save_completed", source, event) self._handle_action_event("memory_save_completed", source, event)
@event_bus.on(MemorySaveFailedEvent) @event_bus.on(MemorySaveFailedEvent)
def on_memory_save_failed(source, event): def on_memory_save_failed(source: Any, event: Any) -> None:
self._handle_action_event("memory_save_failed", source, event) self._handle_action_event("memory_save_failed", source, event)
@event_bus.on(AgentReasoningStartedEvent) @event_bus.on(AgentReasoningStartedEvent)
def on_agent_reasoning_started(source, event): def on_agent_reasoning_started(source: Any, event: Any) -> None:
self._handle_action_event("agent_reasoning_started", source, event) self._handle_action_event("agent_reasoning_started", source, event)
@event_bus.on(AgentReasoningCompletedEvent) @event_bus.on(AgentReasoningCompletedEvent)
def on_agent_reasoning_completed(source, event): def on_agent_reasoning_completed(source: Any, event: Any) -> None:
self._handle_action_event("agent_reasoning_completed", source, event) self._handle_action_event("agent_reasoning_completed", source, event)
@event_bus.on(AgentReasoningFailedEvent) @event_bus.on(AgentReasoningFailedEvent)
def on_agent_reasoning_failed(source, event): def on_agent_reasoning_failed(source: Any, event: Any) -> None:
self._handle_action_event("agent_reasoning_failed", source, event) self._handle_action_event("agent_reasoning_failed", source, event)
def _initialize_crew_batch(self, source: Any, event: Any): def _initialize_crew_batch(self, source: Any, event: Any) -> None:
"""Initialize trace batch""" """Initialize trace batch"""
user_context = self._get_user_context() user_context = self._get_user_context()
execution_metadata = { execution_metadata = {
@@ -305,7 +307,7 @@ class TraceCollectionListener(BaseEventListener):
self._initialize_batch(user_context, execution_metadata) self._initialize_batch(user_context, execution_metadata)
def _initialize_flow_batch(self, source: Any, event: Any): def _initialize_flow_batch(self, source: Any, event: Any) -> None:
"""Initialize trace batch for Flow execution""" """Initialize trace batch for Flow execution"""
user_context = self._get_user_context() user_context = self._get_user_context()
execution_metadata = { execution_metadata = {
@@ -333,14 +335,14 @@ class TraceCollectionListener(BaseEventListener):
user_context, execution_metadata, use_ephemeral=False user_context, execution_metadata, use_ephemeral=False
) )
def _handle_trace_event(self, event_type: str, source: Any, event: Any): def _handle_trace_event(self, event_type: str, source: Any, event: Any) -> None:
"""Generic handler for context end events""" """Generic handler for context end events"""
trace_event = self._create_trace_event(event_type, source, event) trace_event = self._create_trace_event(event_type, source, event)
self.batch_manager.add_event(trace_event) self.batch_manager.add_event(trace_event)
def _handle_action_event(self, event_type: str, source: Any, event: Any): def _handle_action_event(self, event_type: str, source: Any, event: Any) -> None:
"""Generic handler for action events (LLM calls, tool usage)""" """Generic handler for action events (LLM calls, tool usage)"""
if not self.batch_manager.is_batch_initialized(): if not self.batch_manager.is_batch_initialized():
@@ -437,7 +439,9 @@ class TraceCollectionListener(BaseEventListener):
return {"serialization_error": str(e), "object_type": type(obj).__name__} return {"serialization_error": str(e), "object_type": type(obj).__name__}
# TODO: move to utils # TODO: move to utils
def _truncate_messages(self, messages, max_content_length=500, max_messages=5): def _truncate_messages(
self, messages: Any, max_content_length: int = 500, max_messages: int = 5
) -> Any:
"""Truncate message content and limit number of messages""" """Truncate message content and limit number of messages"""
if not messages or not isinstance(messages, list): if not messages or not isinstance(messages, list):
return messages return messages

View File

@@ -7,6 +7,7 @@ from typing import Any, Optional, Union
import chromadb import chromadb
import chromadb.errors import chromadb.errors
from chromadb import EmbeddingFunction
from chromadb.api import ClientAPI from chromadb.api import ClientAPI
from chromadb.api.types import OneOrMany from chromadb.api.types import OneOrMany
from chromadb.config import Settings from chromadb.config import Settings
@@ -29,6 +30,7 @@ class KnowledgeStorage(BaseKnowledgeStorage):
collection: Optional[chromadb.Collection] = None collection: Optional[chromadb.Collection] = None
collection_name: Optional[str] = "knowledge" collection_name: Optional[str] = "knowledge"
app: Optional[ClientAPI] = None app: Optional[ClientAPI] = None
embedder: Optional[EmbeddingFunction] = None
def __init__( def __init__(
self, self,
@@ -42,7 +44,7 @@ class KnowledgeStorage(BaseKnowledgeStorage):
self, self,
query: list[str], query: list[str],
limit: int = 3, limit: int = 3,
filter: Optional[dict] = None, filter: Optional[dict[str, Any]] = None,
score_threshold: float = 0.35, score_threshold: float = 0.35,
) -> list[dict[str, Any]]: ) -> list[dict[str, Any]]:
with suppress_logging( with suppress_logging(
@@ -55,14 +57,16 @@ class KnowledgeStorage(BaseKnowledgeStorage):
where=filter, where=filter,
) )
results = [] results = []
for i in range(len(fetched["ids"][0])): # type: ignore for i in range(len(fetched["ids"][0])):
result = { result = {
"id": fetched["ids"][0][i], # type: ignore "id": fetched["ids"][0][i],
"metadata": fetched["metadatas"][0][i], # type: ignore "metadata": fetched["metadatas"][0][i],
"context": fetched["documents"][0][i], # type: ignore "context": fetched["documents"][0][i],
"score": fetched["distances"][0][i], # type: ignore "score": fetched["distances"][0][i],
} }
if result["score"] >= score_threshold: if (
result["score"] <= score_threshold
): # Note: distances are smaller when more similar
results.append(result) results.append(result)
return results return results
else: else:
@@ -114,7 +118,7 @@ class KnowledgeStorage(BaseKnowledgeStorage):
self, self,
documents: list[str], documents: list[str],
metadata: Optional[dict[str, Any] | list[dict[str, Any]]] = None, metadata: Optional[dict[str, Any] | list[dict[str, Any]]] = None,
): ) -> None:
if not self.collection: if not self.collection:
raise Exception("Collection not initialized") raise Exception("Collection not initialized")
@@ -169,7 +173,7 @@ class KnowledgeStorage(BaseKnowledgeStorage):
Logger(verbose=True).log("error", f"Failed to upsert documents: {e}", "red") Logger(verbose=True).log("error", f"Failed to upsert documents: {e}", "red")
raise raise
def _create_default_embedding_function(self): def _create_default_embedding_function(self) -> Any:
from chromadb.utils.embedding_functions.openai_embedding_function import ( from chromadb.utils.embedding_functions.openai_embedding_function import (
OpenAIEmbeddingFunction, OpenAIEmbeddingFunction,
) )

View File

@@ -5,6 +5,7 @@ import uuid
import warnings import warnings
from typing import Any, Optional from typing import Any, Optional
from chromadb import EmbeddingFunction
from chromadb.api import ClientAPI from chromadb.api import ClientAPI
from crewai.rag.embeddings.configurator import EmbeddingConfigurator from crewai.rag.embeddings.configurator import EmbeddingConfigurator
@@ -22,6 +23,7 @@ class RAGStorage(BaseRAGStorage):
""" """
app: ClientAPI | None = None app: ClientAPI | None = None
embedder_config: EmbeddingFunction | None = None # type: ignore[assignment]
def __init__( def __init__(
self, self,
@@ -44,11 +46,11 @@ class RAGStorage(BaseRAGStorage):
self.path = path self.path = path
self._initialize_app() self._initialize_app()
def _set_embedder_config(self): def _set_embedder_config(self) -> None:
configurator = EmbeddingConfigurator() configurator = EmbeddingConfigurator()
self.embedder_config = configurator.configure_embedder(self.embedder_config) self.embedder_config = configurator.configure_embedder(self.embedder_config)
def _initialize_app(self): def _initialize_app(self) -> None:
from chromadb.config import Settings from chromadb.config import Settings
# Suppress deprecation warnings from chromadb, which are not relevant to us # Suppress deprecation warnings from chromadb, which are not relevant to us
@@ -103,7 +105,7 @@ class RAGStorage(BaseRAGStorage):
self, self,
query: str, query: str,
limit: int = 3, limit: int = 3,
filter: Optional[dict] = None, filter: Optional[dict[str, Any]] = None,
score_threshold: float = 0.35, score_threshold: float = 0.35,
) -> list[Any]: ) -> list[Any]:
if not hasattr(self, "app"): if not hasattr(self, "app"):
@@ -116,15 +118,24 @@ class RAGStorage(BaseRAGStorage):
response = self.collection.query(query_texts=query, n_results=limit) response = self.collection.query(query_texts=query, n_results=limit)
results = [] results = []
for i in range(len(response["ids"][0])): if response and "ids" in response and response["ids"]:
result = { for i in range(len(response["ids"][0])):
"id": response["ids"][0][i], result = {
"metadata": response["metadatas"][0][i], "id": response["ids"][0][i],
"context": response["documents"][0][i], "metadata": response["metadatas"][0][i]
"score": response["distances"][0][i], if response.get("metadatas")
} else {},
if result["score"] >= score_threshold: "context": response["documents"][0][i]
results.append(result) if response.get("documents")
else "",
"score": response["distances"][0][i]
if response.get("distances")
else 1.0,
}
if (
result["score"] <= score_threshold
): # Note: distances are smaller when more similar
results.append(result)
return results return results
except Exception as e: except Exception as e:

View File

@@ -33,7 +33,7 @@ logger = logging.getLogger(__name__)
@contextmanager @contextmanager
def suppress_warnings(): def suppress_warnings() -> Any:
with warnings.catch_warnings(): with warnings.catch_warnings():
warnings.filterwarnings("ignore") warnings.filterwarnings("ignore")
yield yield
@@ -45,7 +45,7 @@ if TYPE_CHECKING:
class SafeOTLPSpanExporter(OTLPSpanExporter): class SafeOTLPSpanExporter(OTLPSpanExporter):
def export(self, spans) -> SpanExportResult: def export(self, spans: Any) -> SpanExportResult:
try: try:
return super().export(spans) return super().export(spans)
except Exception as e: except Exception as e:
@@ -69,7 +69,7 @@ class Telemetry:
_instance = None _instance = None
_lock = threading.Lock() _lock = threading.Lock()
def __new__(cls): def __new__(cls) -> Telemetry:
if cls._instance is None: if cls._instance is None:
with cls._lock: with cls._lock:
if cls._instance is None: if cls._instance is None:
@@ -144,10 +144,10 @@ class Telemetry:
except Exception: except Exception:
pass pass
def crew_creation(self, crew: Crew, inputs: dict[str, Any] | None): def crew_creation(self, crew: Crew, inputs: dict[str, Any] | None) -> None:
"""Records the creation of a crew.""" """Records the creation of a crew."""
def operation(): def operation() -> Any:
tracer = trace.get_tracer("crewai.telemetry") tracer = trace.get_tracer("crewai.telemetry")
span = tracer.start_span("Crew Created") span = tracer.start_span("Crew Created")
self._add_attribute( self._add_attribute(
@@ -352,7 +352,7 @@ class Telemetry:
def task_started(self, crew: Crew, task: Task) -> Span | None: def task_started(self, crew: Crew, task: Task) -> Span | None:
"""Records task started in a crew.""" """Records task started in a crew."""
def operation(): def operation() -> Any:
tracer = trace.get_tracer("crewai.telemetry") tracer = trace.get_tracer("crewai.telemetry")
created_span = tracer.start_span("Task Created") created_span = tracer.start_span("Task Created")
@@ -439,7 +439,7 @@ class Telemetry:
self._safe_telemetry_operation(operation) self._safe_telemetry_operation(operation)
return None return None
def task_ended(self, span: Span, task: Task, crew: Crew): def task_ended(self, span: Span, task: Task, crew: Crew) -> None:
"""Records the completion of a task execution in a crew. """Records the completion of a task execution in a crew.
Args: Args:
@@ -451,7 +451,7 @@ class Telemetry:
If share_crew is enabled, this will also record the task output If share_crew is enabled, this will also record the task output
""" """
def operation(): def operation() -> Any:
# Ensure fingerprint data is present on completion span # Ensure fingerprint data is present on completion span
if hasattr(task, "fingerprint") and task.fingerprint: if hasattr(task, "fingerprint") and task.fingerprint:
self._add_attribute(span, "task_fingerprint", task.fingerprint.uuid_str) self._add_attribute(span, "task_fingerprint", task.fingerprint.uuid_str)
@@ -468,7 +468,7 @@ class Telemetry:
self._safe_telemetry_operation(operation) self._safe_telemetry_operation(operation)
def tool_repeated_usage(self, llm: Any, tool_name: str, attempts: int): def tool_repeated_usage(self, llm: Any, tool_name: str, attempts: int) -> None:
"""Records when a tool is used repeatedly, which might indicate an issue. """Records when a tool is used repeatedly, which might indicate an issue.
Args: Args:
@@ -477,7 +477,7 @@ class Telemetry:
attempts (int): Number of attempts made with this tool attempts (int): Number of attempts made with this tool
""" """
def operation(): def operation() -> Any:
tracer = trace.get_tracer("crewai.telemetry") tracer = trace.get_tracer("crewai.telemetry")
span = tracer.start_span("Tool Repeated Usage") span = tracer.start_span("Tool Repeated Usage")
self._add_attribute( self._add_attribute(
@@ -494,7 +494,9 @@ class Telemetry:
self._safe_telemetry_operation(operation) self._safe_telemetry_operation(operation)
def tool_usage(self, llm: Any, tool_name: str, attempts: int, agent: Any = None): def tool_usage(
self, llm: Any, tool_name: str, attempts: int, agent: Any = None
) -> None:
"""Records the usage of a tool by an agent. """Records the usage of a tool by an agent.
Args: Args:
@@ -504,7 +506,7 @@ class Telemetry:
agent (Any, optional): The agent using the tool agent (Any, optional): The agent using the tool
""" """
def operation(): def operation() -> Any:
tracer = trace.get_tracer("crewai.telemetry") tracer = trace.get_tracer("crewai.telemetry")
span = tracer.start_span("Tool Usage") span = tracer.start_span("Tool Usage")
self._add_attribute( self._add_attribute(
@@ -541,7 +543,7 @@ class Telemetry:
tool_name (str, optional): Name of the tool that caused the error tool_name (str, optional): Name of the tool that caused the error
""" """
def operation(): def operation() -> Any:
tracer = trace.get_tracer("crewai.telemetry") tracer = trace.get_tracer("crewai.telemetry")
span = tracer.start_span("Tool Usage Error") span = tracer.start_span("Tool Usage Error")
self._add_attribute( self._add_attribute(
@@ -580,7 +582,7 @@ class Telemetry:
model_name (str): Name of the model used model_name (str): Name of the model used
""" """
def operation(): def operation() -> Any:
tracer = trace.get_tracer("crewai.telemetry") tracer = trace.get_tracer("crewai.telemetry")
span = tracer.start_span("Crew Individual Test Result") span = tracer.start_span("Crew Individual Test Result")
@@ -615,7 +617,7 @@ class Telemetry:
model_name (str): Name of the model used in testing model_name (str): Name of the model used in testing
""" """
def operation(): def operation() -> Any:
tracer = trace.get_tracer("crewai.telemetry") tracer = trace.get_tracer("crewai.telemetry")
span = tracer.start_span("Crew Test Execution") span = tracer.start_span("Crew Test Execution")
@@ -639,10 +641,10 @@ class Telemetry:
self._safe_telemetry_operation(operation) self._safe_telemetry_operation(operation)
def deploy_signup_error_span(self): def deploy_signup_error_span(self) -> None:
"""Records when an error occurs during the deployment signup process.""" """Records when an error occurs during the deployment signup process."""
def operation(): def operation() -> Any:
tracer = trace.get_tracer("crewai.telemetry") tracer = trace.get_tracer("crewai.telemetry")
span = tracer.start_span("Deploy Signup Error") span = tracer.start_span("Deploy Signup Error")
span.set_status(Status(StatusCode.OK)) span.set_status(Status(StatusCode.OK))
@@ -650,14 +652,14 @@ class Telemetry:
self._safe_telemetry_operation(operation) self._safe_telemetry_operation(operation)
def start_deployment_span(self, uuid: Optional[str] = None): def start_deployment_span(self, uuid: Optional[str] = None) -> None:
"""Records the start of a deployment process. """Records the start of a deployment process.
Args: Args:
uuid (Optional[str]): Unique identifier for the deployment uuid (Optional[str]): Unique identifier for the deployment
""" """
def operation(): def operation() -> Any:
tracer = trace.get_tracer("crewai.telemetry") tracer = trace.get_tracer("crewai.telemetry")
span = tracer.start_span("Start Deployment") span = tracer.start_span("Start Deployment")
if uuid: if uuid:
@@ -667,10 +669,10 @@ class Telemetry:
self._safe_telemetry_operation(operation) self._safe_telemetry_operation(operation)
def create_crew_deployment_span(self): def create_crew_deployment_span(self) -> None:
"""Records the creation of a new crew deployment.""" """Records the creation of a new crew deployment."""
def operation(): def operation() -> Any:
tracer = trace.get_tracer("crewai.telemetry") tracer = trace.get_tracer("crewai.telemetry")
span = tracer.start_span("Create Crew Deployment") span = tracer.start_span("Create Crew Deployment")
span.set_status(Status(StatusCode.OK)) span.set_status(Status(StatusCode.OK))
@@ -678,7 +680,9 @@ class Telemetry:
self._safe_telemetry_operation(operation) self._safe_telemetry_operation(operation)
def get_crew_logs_span(self, uuid: Optional[str], log_type: str = "deployment"): def get_crew_logs_span(
self, uuid: Optional[str], log_type: str = "deployment"
) -> None:
"""Records the retrieval of crew logs. """Records the retrieval of crew logs.
Args: Args:
@@ -686,7 +690,7 @@ class Telemetry:
log_type (str, optional): Type of logs being retrieved. Defaults to "deployment". log_type (str, optional): Type of logs being retrieved. Defaults to "deployment".
""" """
def operation(): def operation() -> Any:
tracer = trace.get_tracer("crewai.telemetry") tracer = trace.get_tracer("crewai.telemetry")
span = tracer.start_span("Get Crew Logs") span = tracer.start_span("Get Crew Logs")
self._add_attribute(span, "log_type", log_type) self._add_attribute(span, "log_type", log_type)
@@ -697,14 +701,14 @@ class Telemetry:
self._safe_telemetry_operation(operation) self._safe_telemetry_operation(operation)
def remove_crew_span(self, uuid: Optional[str] = None): def remove_crew_span(self, uuid: Optional[str] = None) -> None:
"""Records the removal of a crew. """Records the removal of a crew.
Args: Args:
uuid (Optional[str]): Unique identifier for the crew being removed uuid (Optional[str]): Unique identifier for the crew being removed
""" """
def operation(): def operation() -> Any:
tracer = trace.get_tracer("crewai.telemetry") tracer = trace.get_tracer("crewai.telemetry")
span = tracer.start_span("Remove Crew") span = tracer.start_span("Remove Crew")
if uuid: if uuid:
@@ -714,13 +718,13 @@ class Telemetry:
self._safe_telemetry_operation(operation) self._safe_telemetry_operation(operation)
def crew_execution_span(self, crew: Crew, inputs: dict[str, Any] | None): def crew_execution_span(self, crew: Crew, inputs: dict[str, Any] | None) -> None:
"""Records the complete execution of a crew. """Records the complete execution of a crew.
This is only collected if the user has opted-in to share the crew. This is only collected if the user has opted-in to share the crew.
""" """
self.crew_creation(crew, inputs) self.crew_creation(crew, inputs)
def operation(): def operation() -> Any:
tracer = trace.get_tracer("crewai.telemetry") tracer = trace.get_tracer("crewai.telemetry")
span = tracer.start_span("Crew Execution") span = tracer.start_span("Crew Execution")
self._add_attribute( self._add_attribute(
@@ -792,7 +796,7 @@ class Telemetry:
return None return None
def end_crew(self, crew: Any, final_string_output: str) -> None: def end_crew(self, crew: Any, final_string_output: str) -> None:
def operation(): def operation() -> Any:
self._add_attribute( self._add_attribute(
crew._execution_span, crew._execution_span,
"crewai_version", "crewai_version",
@@ -821,22 +825,22 @@ class Telemetry:
if crew.share_crew: if crew.share_crew:
self._safe_telemetry_operation(operation) self._safe_telemetry_operation(operation)
def _add_attribute(self, span, key, value): def _add_attribute(self, span: Any, key: str, value: Any) -> None:
"""Add an attribute to a span.""" """Add an attribute to a span."""
def operation(): def operation() -> Any:
return span.set_attribute(key, value) return span.set_attribute(key, value)
self._safe_telemetry_operation(operation) self._safe_telemetry_operation(operation)
def flow_creation_span(self, flow_name: str): def flow_creation_span(self, flow_name: str) -> None:
"""Records the creation of a new flow. """Records the creation of a new flow.
Args: Args:
flow_name (str): Name of the flow being created flow_name (str): Name of the flow being created
""" """
def operation(): def operation() -> Any:
tracer = trace.get_tracer("crewai.telemetry") tracer = trace.get_tracer("crewai.telemetry")
span = tracer.start_span("Flow Creation") span = tracer.start_span("Flow Creation")
self._add_attribute(span, "flow_name", flow_name) self._add_attribute(span, "flow_name", flow_name)
@@ -845,7 +849,7 @@ class Telemetry:
self._safe_telemetry_operation(operation) self._safe_telemetry_operation(operation)
def flow_plotting_span(self, flow_name: str, node_names: list[str]): def flow_plotting_span(self, flow_name: str, node_names: list[str]) -> None:
"""Records flow visualization/plotting activity. """Records flow visualization/plotting activity.
Args: Args:
@@ -853,7 +857,7 @@ class Telemetry:
node_names (list[str]): List of node names in the flow node_names (list[str]): List of node names in the flow
""" """
def operation(): def operation() -> Any:
tracer = trace.get_tracer("crewai.telemetry") tracer = trace.get_tracer("crewai.telemetry")
span = tracer.start_span("Flow Plotting") span = tracer.start_span("Flow Plotting")
self._add_attribute(span, "flow_name", flow_name) self._add_attribute(span, "flow_name", flow_name)
@@ -863,7 +867,7 @@ class Telemetry:
self._safe_telemetry_operation(operation) self._safe_telemetry_operation(operation)
def flow_execution_span(self, flow_name: str, node_names: list[str]): def flow_execution_span(self, flow_name: str, node_names: list[str]) -> None:
"""Records the execution of a flow. """Records the execution of a flow.
Args: Args:
@@ -871,7 +875,7 @@ class Telemetry:
node_names (list[str]): List of nodes being executed in the flow node_names (list[str]): List of nodes being executed in the flow
""" """
def operation(): def operation() -> Any:
tracer = trace.get_tracer("crewai.telemetry") tracer = trace.get_tracer("crewai.telemetry")
span = tracer.start_span("Flow Execution") span = tracer.start_span("Flow Execution")
self._add_attribute(span, "flow_name", flow_name) self._add_attribute(span, "flow_name", flow_name)