mirror of
https://github.com/crewAIInc/crewAI.git
synced 2026-01-09 08:08:32 +00:00
fix: resolve remaining mypy type annotation issues
- Applied proper decorator typing with ParamSpec and typing_extensions.Self - Fixed event bus decorator to preserve type information - Added type annotations to BaseEventListener and TraceCollectionListener - Fixed LongTermMemory.search to handle None return from storage.load - Resolved all type errors tracked in strict mode
This commit is contained in:
@@ -242,7 +242,7 @@ class Agent(BaseAgent):
|
||||
task: Task,
|
||||
context: Optional[str] = None,
|
||||
tools: Optional[list[BaseTool]] = None,
|
||||
) -> str:
|
||||
) -> Any:
|
||||
"""Execute a task with the agent.
|
||||
|
||||
Args:
|
||||
|
||||
@@ -6,10 +6,10 @@ from crewai.events.event_bus import CrewAIEventsBus, crewai_event_bus
|
||||
class BaseEventListener(ABC):
|
||||
verbose: bool = False
|
||||
|
||||
def __init__(self):
|
||||
def __init__(self) -> None:
|
||||
super().__init__()
|
||||
self.setup_listeners(crewai_event_bus)
|
||||
|
||||
@abstractmethod
|
||||
def setup_listeners(self, crewai_event_bus: CrewAIEventsBus):
|
||||
def setup_listeners(self, crewai_event_bus: CrewAIEventsBus) -> None:
|
||||
pass
|
||||
|
||||
@@ -1,15 +1,17 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import threading
|
||||
from collections.abc import Callable, Iterator
|
||||
from contextlib import contextmanager
|
||||
from typing import Any, Callable, Dict, List, Type, TypeVar, cast
|
||||
from typing import Any, ParamSpec, TypeVar, cast
|
||||
|
||||
from blinker import Signal
|
||||
from typing_extensions import Self
|
||||
|
||||
from crewai.events.base_events import BaseEvent
|
||||
from crewai.events.event_types import EventTypes
|
||||
|
||||
EventT = TypeVar("EventT", bound=BaseEvent)
|
||||
P = ParamSpec("P")
|
||||
|
||||
|
||||
class CrewAIEventsBus:
|
||||
@@ -21,21 +23,21 @@ class CrewAIEventsBus:
|
||||
_instance = None
|
||||
_lock = threading.Lock()
|
||||
|
||||
def __new__(cls):
|
||||
def __new__(cls) -> Self:
|
||||
if cls._instance is None:
|
||||
with cls._lock:
|
||||
if cls._instance is None: # prevent race condition
|
||||
cls._instance = super(CrewAIEventsBus, cls).__new__(cls)
|
||||
cls._instance = super().__new__(cls)
|
||||
cls._instance._initialize()
|
||||
return cls._instance
|
||||
|
||||
def _initialize(self) -> None:
|
||||
"""Initialize the event bus internal state"""
|
||||
self._signal = Signal("crewai_event_bus")
|
||||
self._handlers: Dict[Type[BaseEvent], List[Callable]] = {}
|
||||
self._handlers: dict[type[BaseEvent], list[Callable[[Any, Any], None]]] = {}
|
||||
|
||||
def on(
|
||||
self, event_type: Type[EventT]
|
||||
self, event_type: type[EventT]
|
||||
) -> Callable[[Callable[[Any, EventT], None]], Callable[[Any, EventT], None]]:
|
||||
"""
|
||||
Decorator to register an event handler for a specific event type.
|
||||
@@ -54,9 +56,7 @@ class CrewAIEventsBus:
|
||||
) -> Callable[[Any, EventT], None]:
|
||||
if event_type not in self._handlers:
|
||||
self._handlers[event_type] = []
|
||||
self._handlers[event_type].append(
|
||||
cast(Callable[[Any, EventT], None], handler)
|
||||
)
|
||||
self._handlers[event_type].append(cast(Callable[[Any, Any], None], handler))
|
||||
return handler
|
||||
|
||||
return decorator
|
||||
@@ -82,17 +82,15 @@ class CrewAIEventsBus:
|
||||
self._signal.send(source, event=event)
|
||||
|
||||
def register_handler(
|
||||
self, event_type: Type[EventTypes], handler: Callable[[Any, EventTypes], None]
|
||||
self, event_type: type[BaseEvent], handler: Callable[[Any, Any], None]
|
||||
) -> None:
|
||||
"""Register an event handler for a specific event type"""
|
||||
if event_type not in self._handlers:
|
||||
self._handlers[event_type] = []
|
||||
self._handlers[event_type].append(
|
||||
cast(Callable[[Any, EventTypes], None], handler)
|
||||
)
|
||||
self._handlers[event_type].append(handler)
|
||||
|
||||
@contextmanager
|
||||
def scoped_handlers(self):
|
||||
def scoped_handlers(self) -> Iterator[None]:
|
||||
"""
|
||||
Context manager for temporary event handling scope.
|
||||
Useful for testing or temporary event handling.
|
||||
|
||||
@@ -7,6 +7,7 @@ from pydantic import Field, PrivateAttr
|
||||
from typing_extensions import Self
|
||||
|
||||
from crewai.events.base_event_listener import BaseEventListener
|
||||
from crewai.events.event_bus import CrewAIEventsBus
|
||||
from crewai.events.types.agent_events import (
|
||||
AgentExecutionCompletedEvent,
|
||||
AgentExecutionStartedEvent,
|
||||
@@ -79,6 +80,7 @@ from .types.tool_usage_events import (
|
||||
|
||||
class EventListener(BaseEventListener):
|
||||
_instance = None
|
||||
_initialized: bool = False
|
||||
_telemetry: Telemetry = PrivateAttr(default_factory=lambda: Telemetry())
|
||||
logger = Logger(verbose=True, default_color=EMITTER_COLOR)
|
||||
execution_spans: dict[Task, Any] = Field(default_factory=dict)
|
||||
@@ -106,7 +108,7 @@ class EventListener(BaseEventListener):
|
||||
|
||||
# ----------- CREW EVENTS -----------
|
||||
|
||||
def setup_listeners(self, crewai_event_bus: Any) -> None:
|
||||
def setup_listeners(self, crewai_event_bus: CrewAIEventsBus) -> None:
|
||||
@crewai_event_bus.on(CrewKickoffStartedEvent)
|
||||
def on_crew_started(source: Any, event: CrewKickoffStartedEvent) -> None:
|
||||
self.formatter.create_crew_tree(event.crew_name or "Crew", source.id)
|
||||
|
||||
@@ -1,17 +1,19 @@
|
||||
from typing import Any
|
||||
|
||||
from crewai.events.base_event_listener import BaseEventListener
|
||||
from crewai.events.types.memory_events import (
|
||||
MemoryQueryCompletedEvent,
|
||||
MemoryQueryFailedEvent,
|
||||
MemoryRetrievalCompletedEvent,
|
||||
MemoryRetrievalStartedEvent,
|
||||
MemoryQueryFailedEvent,
|
||||
MemoryQueryCompletedEvent,
|
||||
MemorySaveStartedEvent,
|
||||
MemorySaveCompletedEvent,
|
||||
MemorySaveFailedEvent,
|
||||
MemorySaveStartedEvent,
|
||||
)
|
||||
|
||||
|
||||
class MemoryListener(BaseEventListener):
|
||||
def __init__(self, formatter):
|
||||
def __init__(self, formatter: Any) -> None:
|
||||
super().__init__()
|
||||
self.formatter = formatter
|
||||
self.memory_retrieval_in_progress = False
|
||||
|
||||
@@ -1,28 +1,56 @@
|
||||
import os
|
||||
import uuid
|
||||
from typing import Any, Optional
|
||||
|
||||
from typing import Dict, Any, Optional
|
||||
|
||||
from crewai.cli.authentication.token import AuthError, get_auth_token
|
||||
from crewai.cli.version import get_crewai_version
|
||||
from crewai.events.base_event_listener import BaseEventListener
|
||||
from crewai.events.event_bus import CrewAIEventsBus
|
||||
from crewai.events.listeners.tracing.types import TraceEvent
|
||||
from crewai.events.types.agent_events import (
|
||||
AgentExecutionCompletedEvent,
|
||||
AgentExecutionErrorEvent,
|
||||
AgentExecutionStartedEvent,
|
||||
LiteAgentExecutionStartedEvent,
|
||||
LiteAgentExecutionCompletedEvent,
|
||||
LiteAgentExecutionErrorEvent,
|
||||
AgentExecutionErrorEvent,
|
||||
)
|
||||
from crewai.events.listeners.tracing.types import TraceEvent
|
||||
from crewai.events.types.reasoning_events import (
|
||||
AgentReasoningStartedEvent,
|
||||
AgentReasoningCompletedEvent,
|
||||
AgentReasoningFailedEvent,
|
||||
LiteAgentExecutionStartedEvent,
|
||||
)
|
||||
from crewai.events.types.crew_events import (
|
||||
CrewKickoffCompletedEvent,
|
||||
CrewKickoffFailedEvent,
|
||||
CrewKickoffStartedEvent,
|
||||
)
|
||||
from crewai.events.types.flow_events import (
|
||||
FlowCreatedEvent,
|
||||
FlowFinishedEvent,
|
||||
FlowPlotEvent,
|
||||
FlowStartedEvent,
|
||||
MethodExecutionFailedEvent,
|
||||
MethodExecutionFinishedEvent,
|
||||
MethodExecutionStartedEvent,
|
||||
)
|
||||
from crewai.events.types.llm_events import (
|
||||
LLMCallCompletedEvent,
|
||||
LLMCallFailedEvent,
|
||||
LLMCallStartedEvent,
|
||||
)
|
||||
from crewai.events.types.llm_guardrail_events import (
|
||||
LLMGuardrailCompletedEvent,
|
||||
LLMGuardrailStartedEvent,
|
||||
)
|
||||
from crewai.events.types.memory_events import (
|
||||
MemoryQueryCompletedEvent,
|
||||
MemoryQueryFailedEvent,
|
||||
MemoryQueryStartedEvent,
|
||||
MemorySaveCompletedEvent,
|
||||
MemorySaveFailedEvent,
|
||||
MemorySaveStartedEvent,
|
||||
)
|
||||
from crewai.events.types.reasoning_events import (
|
||||
AgentReasoningCompletedEvent,
|
||||
AgentReasoningFailedEvent,
|
||||
AgentReasoningStartedEvent,
|
||||
)
|
||||
from crewai.events.types.task_events import (
|
||||
TaskCompletedEvent,
|
||||
TaskFailedEvent,
|
||||
@@ -33,42 +61,10 @@ from crewai.events.types.tool_usage_events import (
|
||||
ToolUsageFinishedEvent,
|
||||
ToolUsageStartedEvent,
|
||||
)
|
||||
from crewai.events.types.llm_events import (
|
||||
LLMCallCompletedEvent,
|
||||
LLMCallFailedEvent,
|
||||
LLMCallStartedEvent,
|
||||
)
|
||||
|
||||
from crewai.events.types.flow_events import (
|
||||
FlowCreatedEvent,
|
||||
FlowStartedEvent,
|
||||
FlowFinishedEvent,
|
||||
MethodExecutionStartedEvent,
|
||||
MethodExecutionFinishedEvent,
|
||||
MethodExecutionFailedEvent,
|
||||
FlowPlotEvent,
|
||||
)
|
||||
from crewai.events.types.llm_guardrail_events import (
|
||||
LLMGuardrailStartedEvent,
|
||||
LLMGuardrailCompletedEvent,
|
||||
)
|
||||
from crewai.utilities.serialization import to_serializable
|
||||
|
||||
|
||||
from .trace_batch_manager import TraceBatchManager
|
||||
|
||||
from crewai.events.types.memory_events import (
|
||||
MemoryQueryStartedEvent,
|
||||
MemoryQueryCompletedEvent,
|
||||
MemoryQueryFailedEvent,
|
||||
MemorySaveStartedEvent,
|
||||
MemorySaveCompletedEvent,
|
||||
MemorySaveFailedEvent,
|
||||
)
|
||||
|
||||
from crewai.cli.authentication.token import AuthError, get_auth_token
|
||||
from crewai.cli.version import get_crewai_version
|
||||
|
||||
|
||||
class TraceCollectionListener(BaseEventListener):
|
||||
"""
|
||||
@@ -112,7 +108,7 @@ class TraceCollectionListener(BaseEventListener):
|
||||
except AuthError:
|
||||
return False
|
||||
|
||||
def _get_user_context(self) -> Dict[str, str]:
|
||||
def _get_user_context(self) -> dict[str, str]:
|
||||
"""Extract user context for tracing"""
|
||||
return {
|
||||
"user_id": os.getenv("CREWAI_USER_ID", "anonymous"),
|
||||
@@ -121,7 +117,7 @@ class TraceCollectionListener(BaseEventListener):
|
||||
"trace_id": str(uuid.uuid4()),
|
||||
}
|
||||
|
||||
def setup_listeners(self, crewai_event_bus):
|
||||
def setup_listeners(self, crewai_event_bus: CrewAIEventsBus) -> None:
|
||||
"""Setup event listeners - delegates to specific handlers"""
|
||||
|
||||
if self._listeners_setup:
|
||||
@@ -325,7 +321,7 @@ class TraceCollectionListener(BaseEventListener):
|
||||
self._initialize_batch(user_context, execution_metadata)
|
||||
|
||||
def _initialize_batch(
|
||||
self, user_context: Dict[str, str], execution_metadata: Dict[str, Any]
|
||||
self, user_context: dict[str, str], execution_metadata: dict[str, Any]
|
||||
):
|
||||
"""Initialize trace batch if ephemeral"""
|
||||
if not self._check_authenticated():
|
||||
@@ -371,7 +367,7 @@ class TraceCollectionListener(BaseEventListener):
|
||||
|
||||
def _build_event_data(
|
||||
self, event_type: str, event: Any, source: Any
|
||||
) -> Dict[str, Any]:
|
||||
) -> dict[str, Any]:
|
||||
"""Build event data"""
|
||||
if event_type not in self.complex_events:
|
||||
return self._safe_serialize_to_dict(event)
|
||||
@@ -429,7 +425,7 @@ class TraceCollectionListener(BaseEventListener):
|
||||
# TODO: move to utils
|
||||
def _safe_serialize_to_dict(
|
||||
self, obj, exclude: set[str] | None = None
|
||||
) -> Dict[str, Any]:
|
||||
) -> dict[str, Any]:
|
||||
"""Safely serialize an object to a dictionary for event data."""
|
||||
try:
|
||||
serialized = to_serializable(obj, exclude)
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
from typing import Any, Optional
|
||||
|
||||
from crewai.tasks.task_output import TaskOutput
|
||||
from crewai.events.base_events import BaseEvent
|
||||
from crewai.tasks.task_output import TaskOutput
|
||||
|
||||
|
||||
class TaskStartedEvent(BaseEvent):
|
||||
@@ -11,7 +11,7 @@ class TaskStartedEvent(BaseEvent):
|
||||
context: Optional[str]
|
||||
task: Optional[Any] = None
|
||||
|
||||
def __init__(self, **data):
|
||||
def __init__(self, **data: Any) -> None:
|
||||
super().__init__(**data)
|
||||
# Set fingerprint data from the task
|
||||
if hasattr(self.task, "fingerprint") and self.task.fingerprint:
|
||||
@@ -31,7 +31,7 @@ class TaskCompletedEvent(BaseEvent):
|
||||
type: str = "task_completed"
|
||||
task: Optional[Any] = None
|
||||
|
||||
def __init__(self, **data):
|
||||
def __init__(self, **data: Any) -> None:
|
||||
super().__init__(**data)
|
||||
# Set fingerprint data from the task
|
||||
if hasattr(self.task, "fingerprint") and self.task.fingerprint:
|
||||
@@ -51,7 +51,7 @@ class TaskFailedEvent(BaseEvent):
|
||||
type: str = "task_failed"
|
||||
task: Optional[Any] = None
|
||||
|
||||
def __init__(self, **data):
|
||||
def __init__(self, **data: Any) -> None:
|
||||
super().__init__(**data)
|
||||
# Set fingerprint data from the task
|
||||
if hasattr(self.task, "fingerprint") and self.task.fingerprint:
|
||||
@@ -71,7 +71,7 @@ class TaskEvaluationEvent(BaseEvent):
|
||||
evaluation_type: str
|
||||
task: Optional[Any] = None
|
||||
|
||||
def __init__(self, **data):
|
||||
def __init__(self, **data: Any) -> None:
|
||||
super().__init__(**data)
|
||||
# Set fingerprint data from the task
|
||||
if hasattr(self.task, "fingerprint") and self.task.fingerprint:
|
||||
|
||||
@@ -2,23 +2,22 @@ import hashlib
|
||||
import logging
|
||||
import os
|
||||
import shutil
|
||||
from typing import Any, Dict, List, Optional, Union
|
||||
import warnings
|
||||
from typing import Any, Optional, Union
|
||||
|
||||
import chromadb
|
||||
import chromadb.errors
|
||||
from chromadb.api import ClientAPI
|
||||
from chromadb.api.types import OneOrMany
|
||||
from chromadb.config import Settings
|
||||
import warnings
|
||||
|
||||
from crewai.knowledge.storage.base_knowledge_storage import BaseKnowledgeStorage
|
||||
from crewai.rag.embeddings.configurator import EmbeddingConfigurator
|
||||
from crewai.utilities.chromadb import sanitize_collection_name
|
||||
from crewai.utilities.chromadb import create_persistent_client, sanitize_collection_name
|
||||
from crewai.utilities.constants import KNOWLEDGE_DIRECTORY
|
||||
from crewai.utilities.logger import Logger
|
||||
from crewai.utilities.paths import db_storage_path
|
||||
from crewai.utilities.chromadb import create_persistent_client
|
||||
from crewai.utilities.logger_utils import suppress_logging
|
||||
from crewai.utilities.paths import db_storage_path
|
||||
|
||||
|
||||
class KnowledgeStorage(BaseKnowledgeStorage):
|
||||
@@ -33,7 +32,7 @@ class KnowledgeStorage(BaseKnowledgeStorage):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
embedder: Optional[Dict[str, Any]] = None,
|
||||
embedder: Optional[dict[str, Any]] = None,
|
||||
collection_name: Optional[str] = None,
|
||||
):
|
||||
self.collection_name = collection_name
|
||||
@@ -41,11 +40,11 @@ class KnowledgeStorage(BaseKnowledgeStorage):
|
||||
|
||||
def search(
|
||||
self,
|
||||
query: List[str],
|
||||
query: list[str],
|
||||
limit: int = 3,
|
||||
filter: Optional[dict] = None,
|
||||
score_threshold: float = 0.35,
|
||||
) -> List[Dict[str, Any]]:
|
||||
) -> list[dict[str, Any]]:
|
||||
with suppress_logging(
|
||||
"chromadb.segment.impl.vector.local_persistent_hnsw", logging.ERROR
|
||||
):
|
||||
@@ -69,7 +68,7 @@ class KnowledgeStorage(BaseKnowledgeStorage):
|
||||
else:
|
||||
raise Exception("Collection not initialized")
|
||||
|
||||
def initialize_knowledge_storage(self):
|
||||
def initialize_knowledge_storage(self) -> None:
|
||||
# Suppress deprecation warnings from chromadb, which are not relevant to us
|
||||
# TODO: Remove this once we upgrade chromadb to at least 1.0.8.
|
||||
warnings.filterwarnings(
|
||||
@@ -99,7 +98,7 @@ class KnowledgeStorage(BaseKnowledgeStorage):
|
||||
except Exception:
|
||||
raise Exception("Failed to create or get collection")
|
||||
|
||||
def reset(self):
|
||||
def reset(self) -> None:
|
||||
base_path = os.path.join(db_storage_path(), KNOWLEDGE_DIRECTORY)
|
||||
if not self.app:
|
||||
self.app = create_persistent_client(
|
||||
@@ -113,8 +112,8 @@ class KnowledgeStorage(BaseKnowledgeStorage):
|
||||
|
||||
def save(
|
||||
self,
|
||||
documents: List[str],
|
||||
metadata: Optional[Union[Dict[str, Any], List[Dict[str, Any]]]] = None,
|
||||
documents: list[str],
|
||||
metadata: Optional[dict[str, Any] | list[dict[str, Any]]] = None,
|
||||
):
|
||||
if not self.collection:
|
||||
raise Exception("Collection not initialized")
|
||||
@@ -179,7 +178,7 @@ class KnowledgeStorage(BaseKnowledgeStorage):
|
||||
api_key=os.getenv("OPENAI_API_KEY"), model_name="text-embedding-3-small"
|
||||
)
|
||||
|
||||
def _set_embedder_config(self, embedder: Optional[Dict[str, Any]] = None) -> None:
|
||||
def _set_embedder_config(self, embedder: Optional[dict[str, Any]] = None) -> None:
|
||||
"""Set the embedding configuration for the knowledge storage.
|
||||
|
||||
Args:
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
import time
|
||||
from typing import Any
|
||||
from typing import Any, Optional
|
||||
|
||||
from crewai.events.event_bus import crewai_event_bus
|
||||
from crewai.events.types.memory_events import (
|
||||
@@ -29,67 +29,79 @@ class LongTermMemory(Memory):
|
||||
storage = LTMSQLiteStorage(db_path=path) if path else LTMSQLiteStorage()
|
||||
super().__init__(storage=storage)
|
||||
|
||||
def save(self, item: LongTermMemoryItem) -> None:
|
||||
crewai_event_bus.emit(
|
||||
self,
|
||||
event=MemorySaveStartedEvent(
|
||||
value=item.task,
|
||||
metadata=item.metadata,
|
||||
agent_role=item.agent,
|
||||
source_type="long_term_memory",
|
||||
from_agent=self.agent,
|
||||
from_task=self.task,
|
||||
),
|
||||
)
|
||||
|
||||
start_time = time.time()
|
||||
try:
|
||||
metadata = item.metadata
|
||||
metadata.update(
|
||||
{"agent": item.agent, "expected_output": item.expected_output}
|
||||
)
|
||||
self.storage.save(
|
||||
task_description=item.task,
|
||||
score=metadata["quality"],
|
||||
metadata=metadata,
|
||||
datetime=item.datetime,
|
||||
)
|
||||
|
||||
def save(
|
||||
self,
|
||||
value: Any,
|
||||
metadata: Optional[dict[str, Any]] = None,
|
||||
) -> None:
|
||||
# Handle both LongTermMemoryItem and regular save calls
|
||||
if isinstance(value, LongTermMemoryItem):
|
||||
item = value
|
||||
crewai_event_bus.emit(
|
||||
self,
|
||||
event=MemorySaveCompletedEvent(
|
||||
event=MemorySaveStartedEvent(
|
||||
value=item.task,
|
||||
metadata=item.metadata,
|
||||
agent_role=item.agent,
|
||||
save_time_ms=(time.time() - start_time) * 1000,
|
||||
source_type="long_term_memory",
|
||||
from_agent=self.agent,
|
||||
from_task=self.task,
|
||||
),
|
||||
)
|
||||
except Exception as e:
|
||||
crewai_event_bus.emit(
|
||||
self,
|
||||
event=MemorySaveFailedEvent(
|
||||
value=item.task,
|
||||
metadata=item.metadata,
|
||||
agent_role=item.agent,
|
||||
error=str(e),
|
||||
source_type="long_term_memory",
|
||||
),
|
||||
)
|
||||
raise
|
||||
|
||||
start_time = time.time()
|
||||
try:
|
||||
metadata = item.metadata.copy()
|
||||
metadata.update(
|
||||
{"agent": item.agent, "expected_output": item.expected_output}
|
||||
)
|
||||
self.storage.save(
|
||||
task_description=item.task,
|
||||
score=metadata["quality"],
|
||||
metadata=metadata,
|
||||
datetime=item.datetime,
|
||||
)
|
||||
|
||||
crewai_event_bus.emit(
|
||||
self,
|
||||
event=MemorySaveCompletedEvent(
|
||||
value=item.task,
|
||||
metadata=item.metadata,
|
||||
agent_role=item.agent,
|
||||
save_time_ms=(time.time() - start_time) * 1000,
|
||||
source_type="long_term_memory",
|
||||
from_agent=self.agent,
|
||||
from_task=self.task,
|
||||
),
|
||||
)
|
||||
except Exception as e:
|
||||
crewai_event_bus.emit(
|
||||
self,
|
||||
event=MemorySaveFailedEvent(
|
||||
value=item.task,
|
||||
metadata=item.metadata,
|
||||
agent_role=item.agent,
|
||||
error=str(e),
|
||||
source_type="long_term_memory",
|
||||
),
|
||||
)
|
||||
raise
|
||||
else:
|
||||
# Regular save for compatibility with parent class
|
||||
metadata = metadata or {}
|
||||
self.storage.save(value, metadata)
|
||||
|
||||
def search(
|
||||
self,
|
||||
task: str,
|
||||
latest_n: int = 3,
|
||||
) -> list[dict[str, Any]]:
|
||||
query: str,
|
||||
limit: int = 3,
|
||||
score_threshold: float = 0.35,
|
||||
) -> list[Any]:
|
||||
crewai_event_bus.emit(
|
||||
self,
|
||||
event=MemoryQueryStartedEvent(
|
||||
query=task,
|
||||
limit=latest_n,
|
||||
query=query,
|
||||
limit=limit,
|
||||
source_type="long_term_memory",
|
||||
from_agent=self.agent,
|
||||
from_task=self.task,
|
||||
@@ -98,14 +110,16 @@ class LongTermMemory(Memory):
|
||||
|
||||
start_time = time.time()
|
||||
try:
|
||||
results = self.storage.load(task, latest_n)
|
||||
# The storage.load method uses different parameter names
|
||||
# but we'll call it with the aligned names
|
||||
results = self.storage.load(query, limit)
|
||||
|
||||
crewai_event_bus.emit(
|
||||
self,
|
||||
event=MemoryQueryCompletedEvent(
|
||||
query=task,
|
||||
query=query,
|
||||
results=results,
|
||||
limit=latest_n,
|
||||
limit=limit,
|
||||
query_time_ms=(time.time() - start_time) * 1000,
|
||||
source_type="long_term_memory",
|
||||
from_agent=self.agent,
|
||||
@@ -113,15 +127,17 @@ class LongTermMemory(Memory):
|
||||
),
|
||||
)
|
||||
|
||||
return results
|
||||
return results if results is not None else []
|
||||
except Exception as e:
|
||||
crewai_event_bus.emit(
|
||||
self,
|
||||
event=MemoryQueryFailedEvent(
|
||||
query=task,
|
||||
limit=latest_n,
|
||||
query=query,
|
||||
limit=limit,
|
||||
error=str(e),
|
||||
source_type="long_term_memory",
|
||||
from_agent=self.agent,
|
||||
from_task=self.task,
|
||||
),
|
||||
)
|
||||
raise
|
||||
|
||||
@@ -2,16 +2,17 @@ import logging
|
||||
import os
|
||||
import shutil
|
||||
import uuid
|
||||
import warnings
|
||||
from typing import Any, Optional
|
||||
|
||||
from typing import Any, Dict, List, Optional
|
||||
from chromadb.api import ClientAPI
|
||||
from crewai.rag.storage.base_rag_storage import BaseRAGStorage
|
||||
|
||||
from crewai.rag.embeddings.configurator import EmbeddingConfigurator
|
||||
from crewai.rag.storage.base_rag_storage import BaseRAGStorage
|
||||
from crewai.utilities.chromadb import create_persistent_client
|
||||
from crewai.utilities.constants import MAX_FILE_NAME_LENGTH
|
||||
from crewai.utilities.paths import db_storage_path
|
||||
from crewai.utilities.logger_utils import suppress_logging
|
||||
import warnings
|
||||
from crewai.utilities.paths import db_storage_path
|
||||
|
||||
|
||||
class RAGStorage(BaseRAGStorage):
|
||||
@@ -23,8 +24,13 @@ class RAGStorage(BaseRAGStorage):
|
||||
app: ClientAPI | None = None
|
||||
|
||||
def __init__(
|
||||
self, type, allow_reset=True, embedder_config=None, crew=None, path=None
|
||||
):
|
||||
self,
|
||||
type: str,
|
||||
allow_reset: bool = True,
|
||||
embedder_config: Any = None,
|
||||
crew: Any = None,
|
||||
path: Optional[str] = None,
|
||||
) -> None:
|
||||
super().__init__(type, allow_reset, embedder_config, crew)
|
||||
agents = crew.agents if crew else []
|
||||
agents = [self._sanitize_role(agent.role) for agent in agents]
|
||||
@@ -85,7 +91,7 @@ class RAGStorage(BaseRAGStorage):
|
||||
|
||||
return f"{base_path}/{file_name}"
|
||||
|
||||
def save(self, value: Any, metadata: Dict[str, Any]) -> None:
|
||||
def save(self, value: Any, metadata: dict[str, Any]) -> None:
|
||||
if not hasattr(self, "app") or not hasattr(self, "collection"):
|
||||
self._initialize_app()
|
||||
try:
|
||||
@@ -99,7 +105,7 @@ class RAGStorage(BaseRAGStorage):
|
||||
limit: int = 3,
|
||||
filter: Optional[dict] = None,
|
||||
score_threshold: float = 0.35,
|
||||
) -> List[Any]:
|
||||
) -> list[Any]:
|
||||
if not hasattr(self, "app"):
|
||||
self._initialize_app()
|
||||
|
||||
@@ -125,7 +131,7 @@ class RAGStorage(BaseRAGStorage):
|
||||
logging.error(f"Error during {self.type} search: {str(e)}")
|
||||
return []
|
||||
|
||||
def _generate_embedding(self, text: str, metadata: Dict[str, Any]) -> None: # type: ignore
|
||||
def _generate_embedding(self, text: str, metadata: dict[str, Any]) -> None: # type: ignore
|
||||
if not hasattr(self, "app") or not hasattr(self, "collection"):
|
||||
self._initialize_app()
|
||||
|
||||
|
||||
@@ -37,6 +37,8 @@ class ConditionalTask(Task):
|
||||
Returns:
|
||||
bool: True if the task should be executed, False otherwise.
|
||||
"""
|
||||
if self.condition is None:
|
||||
return False
|
||||
return self.condition(context)
|
||||
|
||||
def get_skipped_task_output(self) -> TaskOutput:
|
||||
|
||||
@@ -5,11 +5,12 @@ import json
|
||||
import logging
|
||||
import os
|
||||
import platform
|
||||
import threading
|
||||
import warnings
|
||||
from collections.abc import Callable
|
||||
from contextlib import contextmanager
|
||||
from importlib.metadata import version
|
||||
from typing import TYPE_CHECKING, Any, Callable, Optional
|
||||
import threading
|
||||
from typing import TYPE_CHECKING, Any, Optional
|
||||
|
||||
from opentelemetry import trace
|
||||
from opentelemetry.exporter.otlp.proto.http.trace_exporter import (
|
||||
@@ -72,14 +73,14 @@ class Telemetry:
|
||||
if cls._instance is None:
|
||||
with cls._lock:
|
||||
if cls._instance is None:
|
||||
cls._instance = super(Telemetry, cls).__new__(cls)
|
||||
cls._instance = super().__new__(cls)
|
||||
cls._instance._initialized = False
|
||||
return cls._instance
|
||||
|
||||
def __init__(self) -> None:
|
||||
if hasattr(self, '_initialized') and self._initialized:
|
||||
if hasattr(self, "_initialized") and self._initialized:
|
||||
return
|
||||
|
||||
|
||||
self.ready: bool = False
|
||||
self.trace_set: bool = False
|
||||
self._initialized: bool = True
|
||||
@@ -124,7 +125,7 @@ class Telemetry:
|
||||
"""Check if telemetry operations should be executed."""
|
||||
return self.ready and not self._is_telemetry_disabled()
|
||||
|
||||
def set_tracer(self):
|
||||
def set_tracer(self) -> None:
|
||||
if self.ready and not self.trace_set:
|
||||
try:
|
||||
with suppress_warnings():
|
||||
@@ -790,7 +791,7 @@ class Telemetry:
|
||||
return operation()
|
||||
return None
|
||||
|
||||
def end_crew(self, crew, final_string_output):
|
||||
def end_crew(self, crew: Any, final_string_output: str) -> None:
|
||||
def operation():
|
||||
self._add_attribute(
|
||||
crew._execution_span,
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
from typing import Any
|
||||
from typing import Any, cast
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
@@ -73,7 +73,7 @@ class TaskEvaluator:
|
||||
instructions=instructions,
|
||||
)
|
||||
|
||||
return converter.to_pydantic()
|
||||
return cast(TaskEvaluation, converter.to_pydantic())
|
||||
|
||||
def evaluate_training_data(
|
||||
self, training_data: dict[str, Any], agent_id: str
|
||||
@@ -143,4 +143,4 @@ class TaskEvaluator:
|
||||
)
|
||||
|
||||
pydantic_result = converter.to_pydantic()
|
||||
return pydantic_result
|
||||
return cast(TrainingTaskEvaluation, pydantic_result)
|
||||
|
||||
@@ -2,33 +2,39 @@ import json
|
||||
import os
|
||||
import pickle
|
||||
from datetime import datetime
|
||||
from typing import Union
|
||||
from typing import Any, Union
|
||||
|
||||
|
||||
class FileHandler:
|
||||
"""Handler for file operations supporting both JSON and text-based logging.
|
||||
|
||||
|
||||
Args:
|
||||
file_path (Union[bool, str]): Path to the log file or boolean flag
|
||||
"""
|
||||
|
||||
def __init__(self, file_path: Union[bool, str]):
|
||||
def __init__(self, file_path: bool | str):
|
||||
self._initialize_path(file_path)
|
||||
|
||||
def _initialize_path(self, file_path: Union[bool, str]):
|
||||
|
||||
def _initialize_path(self, file_path: bool | str) -> None:
|
||||
if file_path is True: # File path is boolean True
|
||||
self._path = os.path.join(os.curdir, "logs.txt")
|
||||
|
||||
|
||||
elif isinstance(file_path, str): # File path is a string
|
||||
if file_path.endswith((".json", ".txt")):
|
||||
self._path = file_path # No modification if the file ends with .json or .txt
|
||||
self._path = (
|
||||
file_path # No modification if the file ends with .json or .txt
|
||||
)
|
||||
else:
|
||||
self._path = file_path + ".txt" # Append .txt if the file doesn't end with .json or .txt
|
||||
|
||||
self._path = (
|
||||
file_path + ".txt"
|
||||
) # Append .txt if the file doesn't end with .json or .txt
|
||||
|
||||
else:
|
||||
raise ValueError("file_path must be a string or boolean.") # Handle the case where file_path isn't valid
|
||||
|
||||
def log(self, **kwargs):
|
||||
raise ValueError(
|
||||
"file_path must be a string or boolean."
|
||||
) # Handle the case where file_path isn't valid
|
||||
|
||||
def log(self, **kwargs: Any) -> None:
|
||||
try:
|
||||
now = datetime.now().strftime("%Y-%m-%d %H:%M:%S")
|
||||
log_entry = {"timestamp": now, **kwargs}
|
||||
@@ -45,20 +51,25 @@ class FileHandler:
|
||||
except (json.JSONDecodeError, FileNotFoundError):
|
||||
# If no valid JSON or file doesn't exist, start with an empty list
|
||||
existing_data = [log_entry]
|
||||
|
||||
|
||||
with open(self._path, "w", encoding="utf-8") as write_file:
|
||||
json.dump(existing_data, write_file, indent=4)
|
||||
write_file.write("\n")
|
||||
|
||||
|
||||
else:
|
||||
# Append log in plain text format
|
||||
message = f"{now}: " + ", ".join([f"{key}=\"{value}\"" for key, value in kwargs.items()]) + "\n"
|
||||
message = (
|
||||
f"{now}: "
|
||||
+ ", ".join([f'{key}="{value}"' for key, value in kwargs.items()])
|
||||
+ "\n"
|
||||
)
|
||||
with open(self._path, "a", encoding="utf-8") as file:
|
||||
file.write(message)
|
||||
|
||||
except Exception as e:
|
||||
raise ValueError(f"Failed to log message: {str(e)}")
|
||||
|
||||
|
||||
|
||||
class PickleHandler:
|
||||
def __init__(self, file_name: str) -> None:
|
||||
"""
|
||||
|
||||
Reference in New Issue
Block a user