mirror of
https://github.com/crewAIInc/crewAI.git
synced 2026-05-01 23:32:39 +00:00
fix: resolve additional mypy type annotation issues
- Fixed rag_storage.py embedder type compatibility and query response handling - Fixed knowledge_storage.py dict type parameters and return types - Added comprehensive type annotations to telemetry.py methods - Added type annotations to trace_listener.py event handlers and methods - Fixed ChromaDB response indexing safety checks
This commit is contained in:
@@ -2,6 +2,8 @@ import os
|
|||||||
import uuid
|
import uuid
|
||||||
from typing import Any, Optional
|
from typing import Any, Optional
|
||||||
|
|
||||||
|
from typing_extensions import Self
|
||||||
|
|
||||||
from crewai.cli.authentication.token import AuthError, get_auth_token
|
from crewai.cli.authentication.token import AuthError, get_auth_token
|
||||||
from crewai.cli.version import get_crewai_version
|
from crewai.cli.version import get_crewai_version
|
||||||
from crewai.events.base_event_listener import BaseEventListener
|
from crewai.events.base_event_listener import BaseEventListener
|
||||||
@@ -84,7 +86,7 @@ class TraceCollectionListener(BaseEventListener):
|
|||||||
_initialized = False
|
_initialized = False
|
||||||
_listeners_setup = False
|
_listeners_setup = False
|
||||||
|
|
||||||
def __new__(cls, batch_manager=None):
|
def __new__(cls, batch_manager: Optional[Any] = None) -> "TraceCollectionListener":
|
||||||
if cls._instance is None:
|
if cls._instance is None:
|
||||||
cls._instance = super().__new__(cls)
|
cls._instance = super().__new__(cls)
|
||||||
return cls._instance
|
return cls._instance
|
||||||
@@ -129,169 +131,169 @@ class TraceCollectionListener(BaseEventListener):
|
|||||||
|
|
||||||
self._listeners_setup = True
|
self._listeners_setup = True
|
||||||
|
|
||||||
def _register_flow_event_handlers(self, event_bus):
|
def _register_flow_event_handlers(self, event_bus: CrewAIEventsBus) -> None:
|
||||||
"""Register handlers for flow events"""
|
"""Register handlers for flow events"""
|
||||||
|
|
||||||
@event_bus.on(FlowCreatedEvent)
|
@event_bus.on(FlowCreatedEvent)
|
||||||
def on_flow_created(source, event):
|
def on_flow_created(source: Any, event: Any) -> None:
|
||||||
pass
|
pass
|
||||||
|
|
||||||
@event_bus.on(FlowStartedEvent)
|
@event_bus.on(FlowStartedEvent)
|
||||||
def on_flow_started(source, event):
|
def on_flow_started(source: Any, event: Any) -> None:
|
||||||
if not self.batch_manager.is_batch_initialized():
|
if not self.batch_manager.is_batch_initialized():
|
||||||
self._initialize_flow_batch(source, event)
|
self._initialize_flow_batch(source, event)
|
||||||
self._handle_trace_event("flow_started", source, event)
|
self._handle_trace_event("flow_started", source, event)
|
||||||
|
|
||||||
@event_bus.on(MethodExecutionStartedEvent)
|
@event_bus.on(MethodExecutionStartedEvent)
|
||||||
def on_method_started(source, event):
|
def on_method_started(source: Any, event: Any) -> None:
|
||||||
self._handle_trace_event("method_execution_started", source, event)
|
self._handle_trace_event("method_execution_started", source, event)
|
||||||
|
|
||||||
@event_bus.on(MethodExecutionFinishedEvent)
|
@event_bus.on(MethodExecutionFinishedEvent)
|
||||||
def on_method_finished(source, event):
|
def on_method_finished(source: Any, event: Any) -> None:
|
||||||
self._handle_trace_event("method_execution_finished", source, event)
|
self._handle_trace_event("method_execution_finished", source, event)
|
||||||
|
|
||||||
@event_bus.on(MethodExecutionFailedEvent)
|
@event_bus.on(MethodExecutionFailedEvent)
|
||||||
def on_method_failed(source, event):
|
def on_method_failed(source: Any, event: Any) -> None:
|
||||||
self._handle_trace_event("method_execution_failed", source, event)
|
self._handle_trace_event("method_execution_failed", source, event)
|
||||||
|
|
||||||
@event_bus.on(FlowFinishedEvent)
|
@event_bus.on(FlowFinishedEvent)
|
||||||
def on_flow_finished(source, event):
|
def on_flow_finished(source: Any, event: Any) -> None:
|
||||||
self._handle_trace_event("flow_finished", source, event)
|
self._handle_trace_event("flow_finished", source, event)
|
||||||
if self.batch_manager.batch_owner_type == "flow":
|
if self.batch_manager.batch_owner_type == "flow":
|
||||||
self.batch_manager.finalize_batch()
|
self.batch_manager.finalize_batch()
|
||||||
|
|
||||||
@event_bus.on(FlowPlotEvent)
|
@event_bus.on(FlowPlotEvent)
|
||||||
def on_flow_plot(source, event):
|
def on_flow_plot(source: Any, event: Any) -> None:
|
||||||
self._handle_action_event("flow_plot", source, event)
|
self._handle_action_event("flow_plot", source, event)
|
||||||
|
|
||||||
def _register_context_event_handlers(self, event_bus):
|
def _register_context_event_handlers(self, event_bus: CrewAIEventsBus) -> None:
|
||||||
"""Register handlers for context events (start/end)"""
|
"""Register handlers for context events (start/end)"""
|
||||||
|
|
||||||
@event_bus.on(CrewKickoffStartedEvent)
|
@event_bus.on(CrewKickoffStartedEvent)
|
||||||
def on_crew_started(source, event):
|
def on_crew_started(source: Any, event: Any) -> None:
|
||||||
if not self.batch_manager.is_batch_initialized():
|
if not self.batch_manager.is_batch_initialized():
|
||||||
self._initialize_crew_batch(source, event)
|
self._initialize_crew_batch(source, event)
|
||||||
self._handle_trace_event("crew_kickoff_started", source, event)
|
self._handle_trace_event("crew_kickoff_started", source, event)
|
||||||
|
|
||||||
@event_bus.on(CrewKickoffCompletedEvent)
|
@event_bus.on(CrewKickoffCompletedEvent)
|
||||||
def on_crew_completed(source, event):
|
def on_crew_completed(source: Any, event: Any) -> None:
|
||||||
self._handle_trace_event("crew_kickoff_completed", source, event)
|
self._handle_trace_event("crew_kickoff_completed", source, event)
|
||||||
if self.batch_manager.batch_owner_type == "crew":
|
if self.batch_manager.batch_owner_type == "crew":
|
||||||
self.batch_manager.finalize_batch()
|
self.batch_manager.finalize_batch()
|
||||||
|
|
||||||
@event_bus.on(CrewKickoffFailedEvent)
|
@event_bus.on(CrewKickoffFailedEvent)
|
||||||
def on_crew_failed(source, event):
|
def on_crew_failed(source: Any, event: Any) -> None:
|
||||||
self._handle_trace_event("crew_kickoff_failed", source, event)
|
self._handle_trace_event("crew_kickoff_failed", source, event)
|
||||||
self.batch_manager.finalize_batch()
|
self.batch_manager.finalize_batch()
|
||||||
|
|
||||||
@event_bus.on(TaskStartedEvent)
|
@event_bus.on(TaskStartedEvent)
|
||||||
def on_task_started(source, event):
|
def on_task_started(source: Any, event: Any) -> None:
|
||||||
self._handle_trace_event("task_started", source, event)
|
self._handle_trace_event("task_started", source, event)
|
||||||
|
|
||||||
@event_bus.on(TaskCompletedEvent)
|
@event_bus.on(TaskCompletedEvent)
|
||||||
def on_task_completed(source, event):
|
def on_task_completed(source: Any, event: Any) -> None:
|
||||||
self._handle_trace_event("task_completed", source, event)
|
self._handle_trace_event("task_completed", source, event)
|
||||||
|
|
||||||
@event_bus.on(TaskFailedEvent)
|
@event_bus.on(TaskFailedEvent)
|
||||||
def on_task_failed(source, event):
|
def on_task_failed(source: Any, event: Any) -> None:
|
||||||
self._handle_trace_event("task_failed", source, event)
|
self._handle_trace_event("task_failed", source, event)
|
||||||
|
|
||||||
@event_bus.on(AgentExecutionStartedEvent)
|
@event_bus.on(AgentExecutionStartedEvent)
|
||||||
def on_agent_started(source, event):
|
def on_agent_started(source: Any, event: Any) -> None:
|
||||||
self._handle_trace_event("agent_execution_started", source, event)
|
self._handle_trace_event("agent_execution_started", source, event)
|
||||||
|
|
||||||
@event_bus.on(AgentExecutionCompletedEvent)
|
@event_bus.on(AgentExecutionCompletedEvent)
|
||||||
def on_agent_completed(source, event):
|
def on_agent_completed(source: Any, event: Any) -> None:
|
||||||
self._handle_trace_event("agent_execution_completed", source, event)
|
self._handle_trace_event("agent_execution_completed", source, event)
|
||||||
|
|
||||||
@event_bus.on(LiteAgentExecutionStartedEvent)
|
@event_bus.on(LiteAgentExecutionStartedEvent)
|
||||||
def on_lite_agent_started(source, event):
|
def on_lite_agent_started(source: Any, event: Any) -> None:
|
||||||
self._handle_trace_event("lite_agent_execution_started", source, event)
|
self._handle_trace_event("lite_agent_execution_started", source, event)
|
||||||
|
|
||||||
@event_bus.on(LiteAgentExecutionCompletedEvent)
|
@event_bus.on(LiteAgentExecutionCompletedEvent)
|
||||||
def on_lite_agent_completed(source, event):
|
def on_lite_agent_completed(source: Any, event: Any) -> None:
|
||||||
self._handle_trace_event("lite_agent_execution_completed", source, event)
|
self._handle_trace_event("lite_agent_execution_completed", source, event)
|
||||||
|
|
||||||
@event_bus.on(LiteAgentExecutionErrorEvent)
|
@event_bus.on(LiteAgentExecutionErrorEvent)
|
||||||
def on_lite_agent_error(source, event):
|
def on_lite_agent_error(source: Any, event: Any) -> None:
|
||||||
self._handle_trace_event("lite_agent_execution_error", source, event)
|
self._handle_trace_event("lite_agent_execution_error", source, event)
|
||||||
|
|
||||||
@event_bus.on(AgentExecutionErrorEvent)
|
@event_bus.on(AgentExecutionErrorEvent)
|
||||||
def on_agent_error(source, event):
|
def on_agent_error(source: Any, event: Any) -> None:
|
||||||
self._handle_trace_event("agent_execution_error", source, event)
|
self._handle_trace_event("agent_execution_error", source, event)
|
||||||
|
|
||||||
@event_bus.on(LLMGuardrailStartedEvent)
|
@event_bus.on(LLMGuardrailStartedEvent)
|
||||||
def on_guardrail_started(source, event):
|
def on_guardrail_started(source: Any, event: Any) -> None:
|
||||||
self._handle_trace_event("llm_guardrail_started", source, event)
|
self._handle_trace_event("llm_guardrail_started", source, event)
|
||||||
|
|
||||||
@event_bus.on(LLMGuardrailCompletedEvent)
|
@event_bus.on(LLMGuardrailCompletedEvent)
|
||||||
def on_guardrail_completed(source, event):
|
def on_guardrail_completed(source: Any, event: Any) -> None:
|
||||||
self._handle_trace_event("llm_guardrail_completed", source, event)
|
self._handle_trace_event("llm_guardrail_completed", source, event)
|
||||||
|
|
||||||
def _register_action_event_handlers(self, event_bus):
|
def _register_action_event_handlers(self, event_bus: CrewAIEventsBus) -> None:
|
||||||
"""Register handlers for action events (LLM calls, tool usage)"""
|
"""Register handlers for action events (LLM calls, tool usage)"""
|
||||||
|
|
||||||
@event_bus.on(LLMCallStartedEvent)
|
@event_bus.on(LLMCallStartedEvent)
|
||||||
def on_llm_call_started(source, event):
|
def on_llm_call_started(source: Any, event: Any) -> None:
|
||||||
self._handle_action_event("llm_call_started", source, event)
|
self._handle_action_event("llm_call_started", source, event)
|
||||||
|
|
||||||
@event_bus.on(LLMCallCompletedEvent)
|
@event_bus.on(LLMCallCompletedEvent)
|
||||||
def on_llm_call_completed(source, event):
|
def on_llm_call_completed(source: Any, event: Any) -> None:
|
||||||
self._handle_action_event("llm_call_completed", source, event)
|
self._handle_action_event("llm_call_completed", source, event)
|
||||||
|
|
||||||
@event_bus.on(LLMCallFailedEvent)
|
@event_bus.on(LLMCallFailedEvent)
|
||||||
def on_llm_call_failed(source, event):
|
def on_llm_call_failed(source: Any, event: Any) -> None:
|
||||||
self._handle_action_event("llm_call_failed", source, event)
|
self._handle_action_event("llm_call_failed", source, event)
|
||||||
|
|
||||||
@event_bus.on(ToolUsageStartedEvent)
|
@event_bus.on(ToolUsageStartedEvent)
|
||||||
def on_tool_started(source, event):
|
def on_tool_started(source: Any, event: Any) -> None:
|
||||||
self._handle_action_event("tool_usage_started", source, event)
|
self._handle_action_event("tool_usage_started", source, event)
|
||||||
|
|
||||||
@event_bus.on(ToolUsageFinishedEvent)
|
@event_bus.on(ToolUsageFinishedEvent)
|
||||||
def on_tool_finished(source, event):
|
def on_tool_finished(source: Any, event: Any) -> None:
|
||||||
self._handle_action_event("tool_usage_finished", source, event)
|
self._handle_action_event("tool_usage_finished", source, event)
|
||||||
|
|
||||||
@event_bus.on(ToolUsageErrorEvent)
|
@event_bus.on(ToolUsageErrorEvent)
|
||||||
def on_tool_error(source, event):
|
def on_tool_error(source: Any, event: Any) -> None:
|
||||||
self._handle_action_event("tool_usage_error", source, event)
|
self._handle_action_event("tool_usage_error", source, event)
|
||||||
|
|
||||||
@event_bus.on(MemoryQueryStartedEvent)
|
@event_bus.on(MemoryQueryStartedEvent)
|
||||||
def on_memory_query_started(source, event):
|
def on_memory_query_started(source: Any, event: Any) -> None:
|
||||||
self._handle_action_event("memory_query_started", source, event)
|
self._handle_action_event("memory_query_started", source, event)
|
||||||
|
|
||||||
@event_bus.on(MemoryQueryCompletedEvent)
|
@event_bus.on(MemoryQueryCompletedEvent)
|
||||||
def on_memory_query_completed(source, event):
|
def on_memory_query_completed(source: Any, event: Any) -> None:
|
||||||
self._handle_action_event("memory_query_completed", source, event)
|
self._handle_action_event("memory_query_completed", source, event)
|
||||||
|
|
||||||
@event_bus.on(MemoryQueryFailedEvent)
|
@event_bus.on(MemoryQueryFailedEvent)
|
||||||
def on_memory_query_failed(source, event):
|
def on_memory_query_failed(source: Any, event: Any) -> None:
|
||||||
self._handle_action_event("memory_query_failed", source, event)
|
self._handle_action_event("memory_query_failed", source, event)
|
||||||
|
|
||||||
@event_bus.on(MemorySaveStartedEvent)
|
@event_bus.on(MemorySaveStartedEvent)
|
||||||
def on_memory_save_started(source, event):
|
def on_memory_save_started(source: Any, event: Any) -> None:
|
||||||
self._handle_action_event("memory_save_started", source, event)
|
self._handle_action_event("memory_save_started", source, event)
|
||||||
|
|
||||||
@event_bus.on(MemorySaveCompletedEvent)
|
@event_bus.on(MemorySaveCompletedEvent)
|
||||||
def on_memory_save_completed(source, event):
|
def on_memory_save_completed(source: Any, event: Any) -> None:
|
||||||
self._handle_action_event("memory_save_completed", source, event)
|
self._handle_action_event("memory_save_completed", source, event)
|
||||||
|
|
||||||
@event_bus.on(MemorySaveFailedEvent)
|
@event_bus.on(MemorySaveFailedEvent)
|
||||||
def on_memory_save_failed(source, event):
|
def on_memory_save_failed(source: Any, event: Any) -> None:
|
||||||
self._handle_action_event("memory_save_failed", source, event)
|
self._handle_action_event("memory_save_failed", source, event)
|
||||||
|
|
||||||
@event_bus.on(AgentReasoningStartedEvent)
|
@event_bus.on(AgentReasoningStartedEvent)
|
||||||
def on_agent_reasoning_started(source, event):
|
def on_agent_reasoning_started(source: Any, event: Any) -> None:
|
||||||
self._handle_action_event("agent_reasoning_started", source, event)
|
self._handle_action_event("agent_reasoning_started", source, event)
|
||||||
|
|
||||||
@event_bus.on(AgentReasoningCompletedEvent)
|
@event_bus.on(AgentReasoningCompletedEvent)
|
||||||
def on_agent_reasoning_completed(source, event):
|
def on_agent_reasoning_completed(source: Any, event: Any) -> None:
|
||||||
self._handle_action_event("agent_reasoning_completed", source, event)
|
self._handle_action_event("agent_reasoning_completed", source, event)
|
||||||
|
|
||||||
@event_bus.on(AgentReasoningFailedEvent)
|
@event_bus.on(AgentReasoningFailedEvent)
|
||||||
def on_agent_reasoning_failed(source, event):
|
def on_agent_reasoning_failed(source: Any, event: Any) -> None:
|
||||||
self._handle_action_event("agent_reasoning_failed", source, event)
|
self._handle_action_event("agent_reasoning_failed", source, event)
|
||||||
|
|
||||||
def _initialize_crew_batch(self, source: Any, event: Any):
|
def _initialize_crew_batch(self, source: Any, event: Any) -> None:
|
||||||
"""Initialize trace batch"""
|
"""Initialize trace batch"""
|
||||||
user_context = self._get_user_context()
|
user_context = self._get_user_context()
|
||||||
execution_metadata = {
|
execution_metadata = {
|
||||||
@@ -305,7 +307,7 @@ class TraceCollectionListener(BaseEventListener):
|
|||||||
|
|
||||||
self._initialize_batch(user_context, execution_metadata)
|
self._initialize_batch(user_context, execution_metadata)
|
||||||
|
|
||||||
def _initialize_flow_batch(self, source: Any, event: Any):
|
def _initialize_flow_batch(self, source: Any, event: Any) -> None:
|
||||||
"""Initialize trace batch for Flow execution"""
|
"""Initialize trace batch for Flow execution"""
|
||||||
user_context = self._get_user_context()
|
user_context = self._get_user_context()
|
||||||
execution_metadata = {
|
execution_metadata = {
|
||||||
@@ -333,14 +335,14 @@ class TraceCollectionListener(BaseEventListener):
|
|||||||
user_context, execution_metadata, use_ephemeral=False
|
user_context, execution_metadata, use_ephemeral=False
|
||||||
)
|
)
|
||||||
|
|
||||||
def _handle_trace_event(self, event_type: str, source: Any, event: Any):
|
def _handle_trace_event(self, event_type: str, source: Any, event: Any) -> None:
|
||||||
"""Generic handler for context end events"""
|
"""Generic handler for context end events"""
|
||||||
|
|
||||||
trace_event = self._create_trace_event(event_type, source, event)
|
trace_event = self._create_trace_event(event_type, source, event)
|
||||||
|
|
||||||
self.batch_manager.add_event(trace_event)
|
self.batch_manager.add_event(trace_event)
|
||||||
|
|
||||||
def _handle_action_event(self, event_type: str, source: Any, event: Any):
|
def _handle_action_event(self, event_type: str, source: Any, event: Any) -> None:
|
||||||
"""Generic handler for action events (LLM calls, tool usage)"""
|
"""Generic handler for action events (LLM calls, tool usage)"""
|
||||||
|
|
||||||
if not self.batch_manager.is_batch_initialized():
|
if not self.batch_manager.is_batch_initialized():
|
||||||
@@ -437,7 +439,9 @@ class TraceCollectionListener(BaseEventListener):
|
|||||||
return {"serialization_error": str(e), "object_type": type(obj).__name__}
|
return {"serialization_error": str(e), "object_type": type(obj).__name__}
|
||||||
|
|
||||||
# TODO: move to utils
|
# TODO: move to utils
|
||||||
def _truncate_messages(self, messages, max_content_length=500, max_messages=5):
|
def _truncate_messages(
|
||||||
|
self, messages: Any, max_content_length: int = 500, max_messages: int = 5
|
||||||
|
) -> Any:
|
||||||
"""Truncate message content and limit number of messages"""
|
"""Truncate message content and limit number of messages"""
|
||||||
if not messages or not isinstance(messages, list):
|
if not messages or not isinstance(messages, list):
|
||||||
return messages
|
return messages
|
||||||
|
|||||||
@@ -7,6 +7,7 @@ from typing import Any, Optional, Union
|
|||||||
|
|
||||||
import chromadb
|
import chromadb
|
||||||
import chromadb.errors
|
import chromadb.errors
|
||||||
|
from chromadb import EmbeddingFunction
|
||||||
from chromadb.api import ClientAPI
|
from chromadb.api import ClientAPI
|
||||||
from chromadb.api.types import OneOrMany
|
from chromadb.api.types import OneOrMany
|
||||||
from chromadb.config import Settings
|
from chromadb.config import Settings
|
||||||
@@ -29,6 +30,7 @@ class KnowledgeStorage(BaseKnowledgeStorage):
|
|||||||
collection: Optional[chromadb.Collection] = None
|
collection: Optional[chromadb.Collection] = None
|
||||||
collection_name: Optional[str] = "knowledge"
|
collection_name: Optional[str] = "knowledge"
|
||||||
app: Optional[ClientAPI] = None
|
app: Optional[ClientAPI] = None
|
||||||
|
embedder: Optional[EmbeddingFunction] = None
|
||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
@@ -42,7 +44,7 @@ class KnowledgeStorage(BaseKnowledgeStorage):
|
|||||||
self,
|
self,
|
||||||
query: list[str],
|
query: list[str],
|
||||||
limit: int = 3,
|
limit: int = 3,
|
||||||
filter: Optional[dict] = None,
|
filter: Optional[dict[str, Any]] = None,
|
||||||
score_threshold: float = 0.35,
|
score_threshold: float = 0.35,
|
||||||
) -> list[dict[str, Any]]:
|
) -> list[dict[str, Any]]:
|
||||||
with suppress_logging(
|
with suppress_logging(
|
||||||
@@ -55,14 +57,16 @@ class KnowledgeStorage(BaseKnowledgeStorage):
|
|||||||
where=filter,
|
where=filter,
|
||||||
)
|
)
|
||||||
results = []
|
results = []
|
||||||
for i in range(len(fetched["ids"][0])): # type: ignore
|
for i in range(len(fetched["ids"][0])):
|
||||||
result = {
|
result = {
|
||||||
"id": fetched["ids"][0][i], # type: ignore
|
"id": fetched["ids"][0][i],
|
||||||
"metadata": fetched["metadatas"][0][i], # type: ignore
|
"metadata": fetched["metadatas"][0][i],
|
||||||
"context": fetched["documents"][0][i], # type: ignore
|
"context": fetched["documents"][0][i],
|
||||||
"score": fetched["distances"][0][i], # type: ignore
|
"score": fetched["distances"][0][i],
|
||||||
}
|
}
|
||||||
if result["score"] >= score_threshold:
|
if (
|
||||||
|
result["score"] <= score_threshold
|
||||||
|
): # Note: distances are smaller when more similar
|
||||||
results.append(result)
|
results.append(result)
|
||||||
return results
|
return results
|
||||||
else:
|
else:
|
||||||
@@ -114,7 +118,7 @@ class KnowledgeStorage(BaseKnowledgeStorage):
|
|||||||
self,
|
self,
|
||||||
documents: list[str],
|
documents: list[str],
|
||||||
metadata: Optional[dict[str, Any] | list[dict[str, Any]]] = None,
|
metadata: Optional[dict[str, Any] | list[dict[str, Any]]] = None,
|
||||||
):
|
) -> None:
|
||||||
if not self.collection:
|
if not self.collection:
|
||||||
raise Exception("Collection not initialized")
|
raise Exception("Collection not initialized")
|
||||||
|
|
||||||
@@ -169,7 +173,7 @@ class KnowledgeStorage(BaseKnowledgeStorage):
|
|||||||
Logger(verbose=True).log("error", f"Failed to upsert documents: {e}", "red")
|
Logger(verbose=True).log("error", f"Failed to upsert documents: {e}", "red")
|
||||||
raise
|
raise
|
||||||
|
|
||||||
def _create_default_embedding_function(self):
|
def _create_default_embedding_function(self) -> Any:
|
||||||
from chromadb.utils.embedding_functions.openai_embedding_function import (
|
from chromadb.utils.embedding_functions.openai_embedding_function import (
|
||||||
OpenAIEmbeddingFunction,
|
OpenAIEmbeddingFunction,
|
||||||
)
|
)
|
||||||
|
|||||||
@@ -5,6 +5,7 @@ import uuid
|
|||||||
import warnings
|
import warnings
|
||||||
from typing import Any, Optional
|
from typing import Any, Optional
|
||||||
|
|
||||||
|
from chromadb import EmbeddingFunction
|
||||||
from chromadb.api import ClientAPI
|
from chromadb.api import ClientAPI
|
||||||
|
|
||||||
from crewai.rag.embeddings.configurator import EmbeddingConfigurator
|
from crewai.rag.embeddings.configurator import EmbeddingConfigurator
|
||||||
@@ -22,6 +23,7 @@ class RAGStorage(BaseRAGStorage):
|
|||||||
"""
|
"""
|
||||||
|
|
||||||
app: ClientAPI | None = None
|
app: ClientAPI | None = None
|
||||||
|
embedder_config: EmbeddingFunction | None = None # type: ignore[assignment]
|
||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
@@ -44,11 +46,11 @@ class RAGStorage(BaseRAGStorage):
|
|||||||
self.path = path
|
self.path = path
|
||||||
self._initialize_app()
|
self._initialize_app()
|
||||||
|
|
||||||
def _set_embedder_config(self):
|
def _set_embedder_config(self) -> None:
|
||||||
configurator = EmbeddingConfigurator()
|
configurator = EmbeddingConfigurator()
|
||||||
self.embedder_config = configurator.configure_embedder(self.embedder_config)
|
self.embedder_config = configurator.configure_embedder(self.embedder_config)
|
||||||
|
|
||||||
def _initialize_app(self):
|
def _initialize_app(self) -> None:
|
||||||
from chromadb.config import Settings
|
from chromadb.config import Settings
|
||||||
|
|
||||||
# Suppress deprecation warnings from chromadb, which are not relevant to us
|
# Suppress deprecation warnings from chromadb, which are not relevant to us
|
||||||
@@ -103,7 +105,7 @@ class RAGStorage(BaseRAGStorage):
|
|||||||
self,
|
self,
|
||||||
query: str,
|
query: str,
|
||||||
limit: int = 3,
|
limit: int = 3,
|
||||||
filter: Optional[dict] = None,
|
filter: Optional[dict[str, Any]] = None,
|
||||||
score_threshold: float = 0.35,
|
score_threshold: float = 0.35,
|
||||||
) -> list[Any]:
|
) -> list[Any]:
|
||||||
if not hasattr(self, "app"):
|
if not hasattr(self, "app"):
|
||||||
@@ -116,15 +118,24 @@ class RAGStorage(BaseRAGStorage):
|
|||||||
response = self.collection.query(query_texts=query, n_results=limit)
|
response = self.collection.query(query_texts=query, n_results=limit)
|
||||||
|
|
||||||
results = []
|
results = []
|
||||||
for i in range(len(response["ids"][0])):
|
if response and "ids" in response and response["ids"]:
|
||||||
result = {
|
for i in range(len(response["ids"][0])):
|
||||||
"id": response["ids"][0][i],
|
result = {
|
||||||
"metadata": response["metadatas"][0][i],
|
"id": response["ids"][0][i],
|
||||||
"context": response["documents"][0][i],
|
"metadata": response["metadatas"][0][i]
|
||||||
"score": response["distances"][0][i],
|
if response.get("metadatas")
|
||||||
}
|
else {},
|
||||||
if result["score"] >= score_threshold:
|
"context": response["documents"][0][i]
|
||||||
results.append(result)
|
if response.get("documents")
|
||||||
|
else "",
|
||||||
|
"score": response["distances"][0][i]
|
||||||
|
if response.get("distances")
|
||||||
|
else 1.0,
|
||||||
|
}
|
||||||
|
if (
|
||||||
|
result["score"] <= score_threshold
|
||||||
|
): # Note: distances are smaller when more similar
|
||||||
|
results.append(result)
|
||||||
|
|
||||||
return results
|
return results
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
|
|||||||
@@ -33,7 +33,7 @@ logger = logging.getLogger(__name__)
|
|||||||
|
|
||||||
|
|
||||||
@contextmanager
|
@contextmanager
|
||||||
def suppress_warnings():
|
def suppress_warnings() -> Any:
|
||||||
with warnings.catch_warnings():
|
with warnings.catch_warnings():
|
||||||
warnings.filterwarnings("ignore")
|
warnings.filterwarnings("ignore")
|
||||||
yield
|
yield
|
||||||
@@ -45,7 +45,7 @@ if TYPE_CHECKING:
|
|||||||
|
|
||||||
|
|
||||||
class SafeOTLPSpanExporter(OTLPSpanExporter):
|
class SafeOTLPSpanExporter(OTLPSpanExporter):
|
||||||
def export(self, spans) -> SpanExportResult:
|
def export(self, spans: Any) -> SpanExportResult:
|
||||||
try:
|
try:
|
||||||
return super().export(spans)
|
return super().export(spans)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
@@ -69,7 +69,7 @@ class Telemetry:
|
|||||||
_instance = None
|
_instance = None
|
||||||
_lock = threading.Lock()
|
_lock = threading.Lock()
|
||||||
|
|
||||||
def __new__(cls):
|
def __new__(cls) -> Telemetry:
|
||||||
if cls._instance is None:
|
if cls._instance is None:
|
||||||
with cls._lock:
|
with cls._lock:
|
||||||
if cls._instance is None:
|
if cls._instance is None:
|
||||||
@@ -144,10 +144,10 @@ class Telemetry:
|
|||||||
except Exception:
|
except Exception:
|
||||||
pass
|
pass
|
||||||
|
|
||||||
def crew_creation(self, crew: Crew, inputs: dict[str, Any] | None):
|
def crew_creation(self, crew: Crew, inputs: dict[str, Any] | None) -> None:
|
||||||
"""Records the creation of a crew."""
|
"""Records the creation of a crew."""
|
||||||
|
|
||||||
def operation():
|
def operation() -> Any:
|
||||||
tracer = trace.get_tracer("crewai.telemetry")
|
tracer = trace.get_tracer("crewai.telemetry")
|
||||||
span = tracer.start_span("Crew Created")
|
span = tracer.start_span("Crew Created")
|
||||||
self._add_attribute(
|
self._add_attribute(
|
||||||
@@ -352,7 +352,7 @@ class Telemetry:
|
|||||||
def task_started(self, crew: Crew, task: Task) -> Span | None:
|
def task_started(self, crew: Crew, task: Task) -> Span | None:
|
||||||
"""Records task started in a crew."""
|
"""Records task started in a crew."""
|
||||||
|
|
||||||
def operation():
|
def operation() -> Any:
|
||||||
tracer = trace.get_tracer("crewai.telemetry")
|
tracer = trace.get_tracer("crewai.telemetry")
|
||||||
|
|
||||||
created_span = tracer.start_span("Task Created")
|
created_span = tracer.start_span("Task Created")
|
||||||
@@ -439,7 +439,7 @@ class Telemetry:
|
|||||||
self._safe_telemetry_operation(operation)
|
self._safe_telemetry_operation(operation)
|
||||||
return None
|
return None
|
||||||
|
|
||||||
def task_ended(self, span: Span, task: Task, crew: Crew):
|
def task_ended(self, span: Span, task: Task, crew: Crew) -> None:
|
||||||
"""Records the completion of a task execution in a crew.
|
"""Records the completion of a task execution in a crew.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
@@ -451,7 +451,7 @@ class Telemetry:
|
|||||||
If share_crew is enabled, this will also record the task output
|
If share_crew is enabled, this will also record the task output
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def operation():
|
def operation() -> Any:
|
||||||
# Ensure fingerprint data is present on completion span
|
# Ensure fingerprint data is present on completion span
|
||||||
if hasattr(task, "fingerprint") and task.fingerprint:
|
if hasattr(task, "fingerprint") and task.fingerprint:
|
||||||
self._add_attribute(span, "task_fingerprint", task.fingerprint.uuid_str)
|
self._add_attribute(span, "task_fingerprint", task.fingerprint.uuid_str)
|
||||||
@@ -468,7 +468,7 @@ class Telemetry:
|
|||||||
|
|
||||||
self._safe_telemetry_operation(operation)
|
self._safe_telemetry_operation(operation)
|
||||||
|
|
||||||
def tool_repeated_usage(self, llm: Any, tool_name: str, attempts: int):
|
def tool_repeated_usage(self, llm: Any, tool_name: str, attempts: int) -> None:
|
||||||
"""Records when a tool is used repeatedly, which might indicate an issue.
|
"""Records when a tool is used repeatedly, which might indicate an issue.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
@@ -477,7 +477,7 @@ class Telemetry:
|
|||||||
attempts (int): Number of attempts made with this tool
|
attempts (int): Number of attempts made with this tool
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def operation():
|
def operation() -> Any:
|
||||||
tracer = trace.get_tracer("crewai.telemetry")
|
tracer = trace.get_tracer("crewai.telemetry")
|
||||||
span = tracer.start_span("Tool Repeated Usage")
|
span = tracer.start_span("Tool Repeated Usage")
|
||||||
self._add_attribute(
|
self._add_attribute(
|
||||||
@@ -494,7 +494,9 @@ class Telemetry:
|
|||||||
|
|
||||||
self._safe_telemetry_operation(operation)
|
self._safe_telemetry_operation(operation)
|
||||||
|
|
||||||
def tool_usage(self, llm: Any, tool_name: str, attempts: int, agent: Any = None):
|
def tool_usage(
|
||||||
|
self, llm: Any, tool_name: str, attempts: int, agent: Any = None
|
||||||
|
) -> None:
|
||||||
"""Records the usage of a tool by an agent.
|
"""Records the usage of a tool by an agent.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
@@ -504,7 +506,7 @@ class Telemetry:
|
|||||||
agent (Any, optional): The agent using the tool
|
agent (Any, optional): The agent using the tool
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def operation():
|
def operation() -> Any:
|
||||||
tracer = trace.get_tracer("crewai.telemetry")
|
tracer = trace.get_tracer("crewai.telemetry")
|
||||||
span = tracer.start_span("Tool Usage")
|
span = tracer.start_span("Tool Usage")
|
||||||
self._add_attribute(
|
self._add_attribute(
|
||||||
@@ -541,7 +543,7 @@ class Telemetry:
|
|||||||
tool_name (str, optional): Name of the tool that caused the error
|
tool_name (str, optional): Name of the tool that caused the error
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def operation():
|
def operation() -> Any:
|
||||||
tracer = trace.get_tracer("crewai.telemetry")
|
tracer = trace.get_tracer("crewai.telemetry")
|
||||||
span = tracer.start_span("Tool Usage Error")
|
span = tracer.start_span("Tool Usage Error")
|
||||||
self._add_attribute(
|
self._add_attribute(
|
||||||
@@ -580,7 +582,7 @@ class Telemetry:
|
|||||||
model_name (str): Name of the model used
|
model_name (str): Name of the model used
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def operation():
|
def operation() -> Any:
|
||||||
tracer = trace.get_tracer("crewai.telemetry")
|
tracer = trace.get_tracer("crewai.telemetry")
|
||||||
span = tracer.start_span("Crew Individual Test Result")
|
span = tracer.start_span("Crew Individual Test Result")
|
||||||
|
|
||||||
@@ -615,7 +617,7 @@ class Telemetry:
|
|||||||
model_name (str): Name of the model used in testing
|
model_name (str): Name of the model used in testing
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def operation():
|
def operation() -> Any:
|
||||||
tracer = trace.get_tracer("crewai.telemetry")
|
tracer = trace.get_tracer("crewai.telemetry")
|
||||||
span = tracer.start_span("Crew Test Execution")
|
span = tracer.start_span("Crew Test Execution")
|
||||||
|
|
||||||
@@ -639,10 +641,10 @@ class Telemetry:
|
|||||||
|
|
||||||
self._safe_telemetry_operation(operation)
|
self._safe_telemetry_operation(operation)
|
||||||
|
|
||||||
def deploy_signup_error_span(self):
|
def deploy_signup_error_span(self) -> None:
|
||||||
"""Records when an error occurs during the deployment signup process."""
|
"""Records when an error occurs during the deployment signup process."""
|
||||||
|
|
||||||
def operation():
|
def operation() -> Any:
|
||||||
tracer = trace.get_tracer("crewai.telemetry")
|
tracer = trace.get_tracer("crewai.telemetry")
|
||||||
span = tracer.start_span("Deploy Signup Error")
|
span = tracer.start_span("Deploy Signup Error")
|
||||||
span.set_status(Status(StatusCode.OK))
|
span.set_status(Status(StatusCode.OK))
|
||||||
@@ -650,14 +652,14 @@ class Telemetry:
|
|||||||
|
|
||||||
self._safe_telemetry_operation(operation)
|
self._safe_telemetry_operation(operation)
|
||||||
|
|
||||||
def start_deployment_span(self, uuid: Optional[str] = None):
|
def start_deployment_span(self, uuid: Optional[str] = None) -> None:
|
||||||
"""Records the start of a deployment process.
|
"""Records the start of a deployment process.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
uuid (Optional[str]): Unique identifier for the deployment
|
uuid (Optional[str]): Unique identifier for the deployment
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def operation():
|
def operation() -> Any:
|
||||||
tracer = trace.get_tracer("crewai.telemetry")
|
tracer = trace.get_tracer("crewai.telemetry")
|
||||||
span = tracer.start_span("Start Deployment")
|
span = tracer.start_span("Start Deployment")
|
||||||
if uuid:
|
if uuid:
|
||||||
@@ -667,10 +669,10 @@ class Telemetry:
|
|||||||
|
|
||||||
self._safe_telemetry_operation(operation)
|
self._safe_telemetry_operation(operation)
|
||||||
|
|
||||||
def create_crew_deployment_span(self):
|
def create_crew_deployment_span(self) -> None:
|
||||||
"""Records the creation of a new crew deployment."""
|
"""Records the creation of a new crew deployment."""
|
||||||
|
|
||||||
def operation():
|
def operation() -> Any:
|
||||||
tracer = trace.get_tracer("crewai.telemetry")
|
tracer = trace.get_tracer("crewai.telemetry")
|
||||||
span = tracer.start_span("Create Crew Deployment")
|
span = tracer.start_span("Create Crew Deployment")
|
||||||
span.set_status(Status(StatusCode.OK))
|
span.set_status(Status(StatusCode.OK))
|
||||||
@@ -678,7 +680,9 @@ class Telemetry:
|
|||||||
|
|
||||||
self._safe_telemetry_operation(operation)
|
self._safe_telemetry_operation(operation)
|
||||||
|
|
||||||
def get_crew_logs_span(self, uuid: Optional[str], log_type: str = "deployment"):
|
def get_crew_logs_span(
|
||||||
|
self, uuid: Optional[str], log_type: str = "deployment"
|
||||||
|
) -> None:
|
||||||
"""Records the retrieval of crew logs.
|
"""Records the retrieval of crew logs.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
@@ -686,7 +690,7 @@ class Telemetry:
|
|||||||
log_type (str, optional): Type of logs being retrieved. Defaults to "deployment".
|
log_type (str, optional): Type of logs being retrieved. Defaults to "deployment".
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def operation():
|
def operation() -> Any:
|
||||||
tracer = trace.get_tracer("crewai.telemetry")
|
tracer = trace.get_tracer("crewai.telemetry")
|
||||||
span = tracer.start_span("Get Crew Logs")
|
span = tracer.start_span("Get Crew Logs")
|
||||||
self._add_attribute(span, "log_type", log_type)
|
self._add_attribute(span, "log_type", log_type)
|
||||||
@@ -697,14 +701,14 @@ class Telemetry:
|
|||||||
|
|
||||||
self._safe_telemetry_operation(operation)
|
self._safe_telemetry_operation(operation)
|
||||||
|
|
||||||
def remove_crew_span(self, uuid: Optional[str] = None):
|
def remove_crew_span(self, uuid: Optional[str] = None) -> None:
|
||||||
"""Records the removal of a crew.
|
"""Records the removal of a crew.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
uuid (Optional[str]): Unique identifier for the crew being removed
|
uuid (Optional[str]): Unique identifier for the crew being removed
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def operation():
|
def operation() -> Any:
|
||||||
tracer = trace.get_tracer("crewai.telemetry")
|
tracer = trace.get_tracer("crewai.telemetry")
|
||||||
span = tracer.start_span("Remove Crew")
|
span = tracer.start_span("Remove Crew")
|
||||||
if uuid:
|
if uuid:
|
||||||
@@ -714,13 +718,13 @@ class Telemetry:
|
|||||||
|
|
||||||
self._safe_telemetry_operation(operation)
|
self._safe_telemetry_operation(operation)
|
||||||
|
|
||||||
def crew_execution_span(self, crew: Crew, inputs: dict[str, Any] | None):
|
def crew_execution_span(self, crew: Crew, inputs: dict[str, Any] | None) -> None:
|
||||||
"""Records the complete execution of a crew.
|
"""Records the complete execution of a crew.
|
||||||
This is only collected if the user has opted-in to share the crew.
|
This is only collected if the user has opted-in to share the crew.
|
||||||
"""
|
"""
|
||||||
self.crew_creation(crew, inputs)
|
self.crew_creation(crew, inputs)
|
||||||
|
|
||||||
def operation():
|
def operation() -> Any:
|
||||||
tracer = trace.get_tracer("crewai.telemetry")
|
tracer = trace.get_tracer("crewai.telemetry")
|
||||||
span = tracer.start_span("Crew Execution")
|
span = tracer.start_span("Crew Execution")
|
||||||
self._add_attribute(
|
self._add_attribute(
|
||||||
@@ -792,7 +796,7 @@ class Telemetry:
|
|||||||
return None
|
return None
|
||||||
|
|
||||||
def end_crew(self, crew: Any, final_string_output: str) -> None:
|
def end_crew(self, crew: Any, final_string_output: str) -> None:
|
||||||
def operation():
|
def operation() -> Any:
|
||||||
self._add_attribute(
|
self._add_attribute(
|
||||||
crew._execution_span,
|
crew._execution_span,
|
||||||
"crewai_version",
|
"crewai_version",
|
||||||
@@ -821,22 +825,22 @@ class Telemetry:
|
|||||||
if crew.share_crew:
|
if crew.share_crew:
|
||||||
self._safe_telemetry_operation(operation)
|
self._safe_telemetry_operation(operation)
|
||||||
|
|
||||||
def _add_attribute(self, span, key, value):
|
def _add_attribute(self, span: Any, key: str, value: Any) -> None:
|
||||||
"""Add an attribute to a span."""
|
"""Add an attribute to a span."""
|
||||||
|
|
||||||
def operation():
|
def operation() -> Any:
|
||||||
return span.set_attribute(key, value)
|
return span.set_attribute(key, value)
|
||||||
|
|
||||||
self._safe_telemetry_operation(operation)
|
self._safe_telemetry_operation(operation)
|
||||||
|
|
||||||
def flow_creation_span(self, flow_name: str):
|
def flow_creation_span(self, flow_name: str) -> None:
|
||||||
"""Records the creation of a new flow.
|
"""Records the creation of a new flow.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
flow_name (str): Name of the flow being created
|
flow_name (str): Name of the flow being created
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def operation():
|
def operation() -> Any:
|
||||||
tracer = trace.get_tracer("crewai.telemetry")
|
tracer = trace.get_tracer("crewai.telemetry")
|
||||||
span = tracer.start_span("Flow Creation")
|
span = tracer.start_span("Flow Creation")
|
||||||
self._add_attribute(span, "flow_name", flow_name)
|
self._add_attribute(span, "flow_name", flow_name)
|
||||||
@@ -845,7 +849,7 @@ class Telemetry:
|
|||||||
|
|
||||||
self._safe_telemetry_operation(operation)
|
self._safe_telemetry_operation(operation)
|
||||||
|
|
||||||
def flow_plotting_span(self, flow_name: str, node_names: list[str]):
|
def flow_plotting_span(self, flow_name: str, node_names: list[str]) -> None:
|
||||||
"""Records flow visualization/plotting activity.
|
"""Records flow visualization/plotting activity.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
@@ -853,7 +857,7 @@ class Telemetry:
|
|||||||
node_names (list[str]): List of node names in the flow
|
node_names (list[str]): List of node names in the flow
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def operation():
|
def operation() -> Any:
|
||||||
tracer = trace.get_tracer("crewai.telemetry")
|
tracer = trace.get_tracer("crewai.telemetry")
|
||||||
span = tracer.start_span("Flow Plotting")
|
span = tracer.start_span("Flow Plotting")
|
||||||
self._add_attribute(span, "flow_name", flow_name)
|
self._add_attribute(span, "flow_name", flow_name)
|
||||||
@@ -863,7 +867,7 @@ class Telemetry:
|
|||||||
|
|
||||||
self._safe_telemetry_operation(operation)
|
self._safe_telemetry_operation(operation)
|
||||||
|
|
||||||
def flow_execution_span(self, flow_name: str, node_names: list[str]):
|
def flow_execution_span(self, flow_name: str, node_names: list[str]) -> None:
|
||||||
"""Records the execution of a flow.
|
"""Records the execution of a flow.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
@@ -871,7 +875,7 @@ class Telemetry:
|
|||||||
node_names (list[str]): List of nodes being executed in the flow
|
node_names (list[str]): List of nodes being executed in the flow
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def operation():
|
def operation() -> Any:
|
||||||
tracer = trace.get_tracer("crewai.telemetry")
|
tracer = trace.get_tracer("crewai.telemetry")
|
||||||
span = tracer.start_span("Flow Execution")
|
span = tracer.start_span("Flow Execution")
|
||||||
self._add_attribute(span, "flow_name", flow_name)
|
self._add_attribute(span, "flow_name", flow_name)
|
||||||
|
|||||||
Reference in New Issue
Block a user