fix: resolve remaining mypy type annotation issues

- Applied proper decorator typing with ParamSpec and typing_extensions.Self
- Fixed event bus decorator to preserve type information
- Added type annotations to BaseEventListener and TraceCollectionListener
- Fixed LongTermMemory.search to handle None return from storage.load
- Resolved all type errors tracked in strict mode
This commit is contained in:
Greyson LaLonde
2025-09-04 13:00:11 -04:00
parent 8354cdf061
commit 9306d889a7
14 changed files with 207 additions and 174 deletions

View File

@@ -242,7 +242,7 @@ class Agent(BaseAgent):
task: Task, task: Task,
context: Optional[str] = None, context: Optional[str] = None,
tools: Optional[list[BaseTool]] = None, tools: Optional[list[BaseTool]] = None,
) -> str: ) -> Any:
"""Execute a task with the agent. """Execute a task with the agent.
Args: Args:

View File

@@ -6,10 +6,10 @@ from crewai.events.event_bus import CrewAIEventsBus, crewai_event_bus
class BaseEventListener(ABC): class BaseEventListener(ABC):
verbose: bool = False verbose: bool = False
def __init__(self): def __init__(self) -> None:
super().__init__() super().__init__()
self.setup_listeners(crewai_event_bus) self.setup_listeners(crewai_event_bus)
@abstractmethod @abstractmethod
def setup_listeners(self, crewai_event_bus: CrewAIEventsBus): def setup_listeners(self, crewai_event_bus: CrewAIEventsBus) -> None:
pass pass

View File

@@ -1,15 +1,17 @@
from __future__ import annotations from __future__ import annotations
import threading import threading
from collections.abc import Callable, Iterator
from contextlib import contextmanager from contextlib import contextmanager
from typing import Any, Callable, Dict, List, Type, TypeVar, cast from typing import Any, ParamSpec, TypeVar, cast
from blinker import Signal from blinker import Signal
from typing_extensions import Self
from crewai.events.base_events import BaseEvent from crewai.events.base_events import BaseEvent
from crewai.events.event_types import EventTypes
EventT = TypeVar("EventT", bound=BaseEvent) EventT = TypeVar("EventT", bound=BaseEvent)
P = ParamSpec("P")
class CrewAIEventsBus: class CrewAIEventsBus:
@@ -21,21 +23,21 @@ class CrewAIEventsBus:
_instance = None _instance = None
_lock = threading.Lock() _lock = threading.Lock()
def __new__(cls): def __new__(cls) -> Self:
if cls._instance is None: if cls._instance is None:
with cls._lock: with cls._lock:
if cls._instance is None: # prevent race condition if cls._instance is None: # prevent race condition
cls._instance = super(CrewAIEventsBus, cls).__new__(cls) cls._instance = super().__new__(cls)
cls._instance._initialize() cls._instance._initialize()
return cls._instance return cls._instance
def _initialize(self) -> None: def _initialize(self) -> None:
"""Initialize the event bus internal state""" """Initialize the event bus internal state"""
self._signal = Signal("crewai_event_bus") self._signal = Signal("crewai_event_bus")
self._handlers: Dict[Type[BaseEvent], List[Callable]] = {} self._handlers: dict[type[BaseEvent], list[Callable[[Any, Any], None]]] = {}
def on( def on(
self, event_type: Type[EventT] self, event_type: type[EventT]
) -> Callable[[Callable[[Any, EventT], None]], Callable[[Any, EventT], None]]: ) -> Callable[[Callable[[Any, EventT], None]], Callable[[Any, EventT], None]]:
""" """
Decorator to register an event handler for a specific event type. Decorator to register an event handler for a specific event type.
@@ -54,9 +56,7 @@ class CrewAIEventsBus:
) -> Callable[[Any, EventT], None]: ) -> Callable[[Any, EventT], None]:
if event_type not in self._handlers: if event_type not in self._handlers:
self._handlers[event_type] = [] self._handlers[event_type] = []
self._handlers[event_type].append( self._handlers[event_type].append(cast(Callable[[Any, Any], None], handler))
cast(Callable[[Any, EventT], None], handler)
)
return handler return handler
return decorator return decorator
@@ -82,17 +82,15 @@ class CrewAIEventsBus:
self._signal.send(source, event=event) self._signal.send(source, event=event)
def register_handler( def register_handler(
self, event_type: Type[EventTypes], handler: Callable[[Any, EventTypes], None] self, event_type: type[BaseEvent], handler: Callable[[Any, Any], None]
) -> None: ) -> None:
"""Register an event handler for a specific event type""" """Register an event handler for a specific event type"""
if event_type not in self._handlers: if event_type not in self._handlers:
self._handlers[event_type] = [] self._handlers[event_type] = []
self._handlers[event_type].append( self._handlers[event_type].append(handler)
cast(Callable[[Any, EventTypes], None], handler)
)
@contextmanager @contextmanager
def scoped_handlers(self): def scoped_handlers(self) -> Iterator[None]:
""" """
Context manager for temporary event handling scope. Context manager for temporary event handling scope.
Useful for testing or temporary event handling. Useful for testing or temporary event handling.

View File

@@ -7,6 +7,7 @@ from pydantic import Field, PrivateAttr
from typing_extensions import Self from typing_extensions import Self
from crewai.events.base_event_listener import BaseEventListener from crewai.events.base_event_listener import BaseEventListener
from crewai.events.event_bus import CrewAIEventsBus
from crewai.events.types.agent_events import ( from crewai.events.types.agent_events import (
AgentExecutionCompletedEvent, AgentExecutionCompletedEvent,
AgentExecutionStartedEvent, AgentExecutionStartedEvent,
@@ -79,6 +80,7 @@ from .types.tool_usage_events import (
class EventListener(BaseEventListener): class EventListener(BaseEventListener):
_instance = None _instance = None
_initialized: bool = False
_telemetry: Telemetry = PrivateAttr(default_factory=lambda: Telemetry()) _telemetry: Telemetry = PrivateAttr(default_factory=lambda: Telemetry())
logger = Logger(verbose=True, default_color=EMITTER_COLOR) logger = Logger(verbose=True, default_color=EMITTER_COLOR)
execution_spans: dict[Task, Any] = Field(default_factory=dict) execution_spans: dict[Task, Any] = Field(default_factory=dict)
@@ -106,7 +108,7 @@ class EventListener(BaseEventListener):
# ----------- CREW EVENTS ----------- # ----------- CREW EVENTS -----------
def setup_listeners(self, crewai_event_bus: Any) -> None: def setup_listeners(self, crewai_event_bus: CrewAIEventsBus) -> None:
@crewai_event_bus.on(CrewKickoffStartedEvent) @crewai_event_bus.on(CrewKickoffStartedEvent)
def on_crew_started(source: Any, event: CrewKickoffStartedEvent) -> None: def on_crew_started(source: Any, event: CrewKickoffStartedEvent) -> None:
self.formatter.create_crew_tree(event.crew_name or "Crew", source.id) self.formatter.create_crew_tree(event.crew_name or "Crew", source.id)

View File

@@ -1,17 +1,19 @@
from typing import Any
from crewai.events.base_event_listener import BaseEventListener from crewai.events.base_event_listener import BaseEventListener
from crewai.events.types.memory_events import ( from crewai.events.types.memory_events import (
MemoryQueryCompletedEvent,
MemoryQueryFailedEvent,
MemoryRetrievalCompletedEvent, MemoryRetrievalCompletedEvent,
MemoryRetrievalStartedEvent, MemoryRetrievalStartedEvent,
MemoryQueryFailedEvent,
MemoryQueryCompletedEvent,
MemorySaveStartedEvent,
MemorySaveCompletedEvent, MemorySaveCompletedEvent,
MemorySaveFailedEvent, MemorySaveFailedEvent,
MemorySaveStartedEvent,
) )
class MemoryListener(BaseEventListener): class MemoryListener(BaseEventListener):
def __init__(self, formatter): def __init__(self, formatter: Any) -> None:
super().__init__() super().__init__()
self.formatter = formatter self.formatter = formatter
self.memory_retrieval_in_progress = False self.memory_retrieval_in_progress = False

View File

@@ -1,28 +1,56 @@
import os import os
import uuid import uuid
from typing import Any, Optional
from typing import Dict, Any, Optional from crewai.cli.authentication.token import AuthError, get_auth_token
from crewai.cli.version import get_crewai_version
from crewai.events.base_event_listener import BaseEventListener from crewai.events.base_event_listener import BaseEventListener
from crewai.events.event_bus import CrewAIEventsBus
from crewai.events.listeners.tracing.types import TraceEvent
from crewai.events.types.agent_events import ( from crewai.events.types.agent_events import (
AgentExecutionCompletedEvent, AgentExecutionCompletedEvent,
AgentExecutionErrorEvent,
AgentExecutionStartedEvent, AgentExecutionStartedEvent,
LiteAgentExecutionStartedEvent,
LiteAgentExecutionCompletedEvent, LiteAgentExecutionCompletedEvent,
LiteAgentExecutionErrorEvent, LiteAgentExecutionErrorEvent,
AgentExecutionErrorEvent, LiteAgentExecutionStartedEvent,
)
from crewai.events.listeners.tracing.types import TraceEvent
from crewai.events.types.reasoning_events import (
AgentReasoningStartedEvent,
AgentReasoningCompletedEvent,
AgentReasoningFailedEvent,
) )
from crewai.events.types.crew_events import ( from crewai.events.types.crew_events import (
CrewKickoffCompletedEvent, CrewKickoffCompletedEvent,
CrewKickoffFailedEvent, CrewKickoffFailedEvent,
CrewKickoffStartedEvent, CrewKickoffStartedEvent,
) )
from crewai.events.types.flow_events import (
FlowCreatedEvent,
FlowFinishedEvent,
FlowPlotEvent,
FlowStartedEvent,
MethodExecutionFailedEvent,
MethodExecutionFinishedEvent,
MethodExecutionStartedEvent,
)
from crewai.events.types.llm_events import (
LLMCallCompletedEvent,
LLMCallFailedEvent,
LLMCallStartedEvent,
)
from crewai.events.types.llm_guardrail_events import (
LLMGuardrailCompletedEvent,
LLMGuardrailStartedEvent,
)
from crewai.events.types.memory_events import (
MemoryQueryCompletedEvent,
MemoryQueryFailedEvent,
MemoryQueryStartedEvent,
MemorySaveCompletedEvent,
MemorySaveFailedEvent,
MemorySaveStartedEvent,
)
from crewai.events.types.reasoning_events import (
AgentReasoningCompletedEvent,
AgentReasoningFailedEvent,
AgentReasoningStartedEvent,
)
from crewai.events.types.task_events import ( from crewai.events.types.task_events import (
TaskCompletedEvent, TaskCompletedEvent,
TaskFailedEvent, TaskFailedEvent,
@@ -33,42 +61,10 @@ from crewai.events.types.tool_usage_events import (
ToolUsageFinishedEvent, ToolUsageFinishedEvent,
ToolUsageStartedEvent, ToolUsageStartedEvent,
) )
from crewai.events.types.llm_events import (
LLMCallCompletedEvent,
LLMCallFailedEvent,
LLMCallStartedEvent,
)
from crewai.events.types.flow_events import (
FlowCreatedEvent,
FlowStartedEvent,
FlowFinishedEvent,
MethodExecutionStartedEvent,
MethodExecutionFinishedEvent,
MethodExecutionFailedEvent,
FlowPlotEvent,
)
from crewai.events.types.llm_guardrail_events import (
LLMGuardrailStartedEvent,
LLMGuardrailCompletedEvent,
)
from crewai.utilities.serialization import to_serializable from crewai.utilities.serialization import to_serializable
from .trace_batch_manager import TraceBatchManager from .trace_batch_manager import TraceBatchManager
from crewai.events.types.memory_events import (
MemoryQueryStartedEvent,
MemoryQueryCompletedEvent,
MemoryQueryFailedEvent,
MemorySaveStartedEvent,
MemorySaveCompletedEvent,
MemorySaveFailedEvent,
)
from crewai.cli.authentication.token import AuthError, get_auth_token
from crewai.cli.version import get_crewai_version
class TraceCollectionListener(BaseEventListener): class TraceCollectionListener(BaseEventListener):
""" """
@@ -112,7 +108,7 @@ class TraceCollectionListener(BaseEventListener):
except AuthError: except AuthError:
return False return False
def _get_user_context(self) -> Dict[str, str]: def _get_user_context(self) -> dict[str, str]:
"""Extract user context for tracing""" """Extract user context for tracing"""
return { return {
"user_id": os.getenv("CREWAI_USER_ID", "anonymous"), "user_id": os.getenv("CREWAI_USER_ID", "anonymous"),
@@ -121,7 +117,7 @@ class TraceCollectionListener(BaseEventListener):
"trace_id": str(uuid.uuid4()), "trace_id": str(uuid.uuid4()),
} }
def setup_listeners(self, crewai_event_bus): def setup_listeners(self, crewai_event_bus: CrewAIEventsBus) -> None:
"""Setup event listeners - delegates to specific handlers""" """Setup event listeners - delegates to specific handlers"""
if self._listeners_setup: if self._listeners_setup:
@@ -325,7 +321,7 @@ class TraceCollectionListener(BaseEventListener):
self._initialize_batch(user_context, execution_metadata) self._initialize_batch(user_context, execution_metadata)
def _initialize_batch( def _initialize_batch(
self, user_context: Dict[str, str], execution_metadata: Dict[str, Any] self, user_context: dict[str, str], execution_metadata: dict[str, Any]
): ):
"""Initialize trace batch if ephemeral""" """Initialize trace batch if ephemeral"""
if not self._check_authenticated(): if not self._check_authenticated():
@@ -371,7 +367,7 @@ class TraceCollectionListener(BaseEventListener):
def _build_event_data( def _build_event_data(
self, event_type: str, event: Any, source: Any self, event_type: str, event: Any, source: Any
) -> Dict[str, Any]: ) -> dict[str, Any]:
"""Build event data""" """Build event data"""
if event_type not in self.complex_events: if event_type not in self.complex_events:
return self._safe_serialize_to_dict(event) return self._safe_serialize_to_dict(event)
@@ -429,7 +425,7 @@ class TraceCollectionListener(BaseEventListener):
# TODO: move to utils # TODO: move to utils
def _safe_serialize_to_dict( def _safe_serialize_to_dict(
self, obj, exclude: set[str] | None = None self, obj, exclude: set[str] | None = None
) -> Dict[str, Any]: ) -> dict[str, Any]:
"""Safely serialize an object to a dictionary for event data.""" """Safely serialize an object to a dictionary for event data."""
try: try:
serialized = to_serializable(obj, exclude) serialized = to_serializable(obj, exclude)

View File

@@ -1,7 +1,7 @@
from typing import Any, Optional from typing import Any, Optional
from crewai.tasks.task_output import TaskOutput
from crewai.events.base_events import BaseEvent from crewai.events.base_events import BaseEvent
from crewai.tasks.task_output import TaskOutput
class TaskStartedEvent(BaseEvent): class TaskStartedEvent(BaseEvent):
@@ -11,7 +11,7 @@ class TaskStartedEvent(BaseEvent):
context: Optional[str] context: Optional[str]
task: Optional[Any] = None task: Optional[Any] = None
def __init__(self, **data): def __init__(self, **data: Any) -> None:
super().__init__(**data) super().__init__(**data)
# Set fingerprint data from the task # Set fingerprint data from the task
if hasattr(self.task, "fingerprint") and self.task.fingerprint: if hasattr(self.task, "fingerprint") and self.task.fingerprint:
@@ -31,7 +31,7 @@ class TaskCompletedEvent(BaseEvent):
type: str = "task_completed" type: str = "task_completed"
task: Optional[Any] = None task: Optional[Any] = None
def __init__(self, **data): def __init__(self, **data: Any) -> None:
super().__init__(**data) super().__init__(**data)
# Set fingerprint data from the task # Set fingerprint data from the task
if hasattr(self.task, "fingerprint") and self.task.fingerprint: if hasattr(self.task, "fingerprint") and self.task.fingerprint:
@@ -51,7 +51,7 @@ class TaskFailedEvent(BaseEvent):
type: str = "task_failed" type: str = "task_failed"
task: Optional[Any] = None task: Optional[Any] = None
def __init__(self, **data): def __init__(self, **data: Any) -> None:
super().__init__(**data) super().__init__(**data)
# Set fingerprint data from the task # Set fingerprint data from the task
if hasattr(self.task, "fingerprint") and self.task.fingerprint: if hasattr(self.task, "fingerprint") and self.task.fingerprint:
@@ -71,7 +71,7 @@ class TaskEvaluationEvent(BaseEvent):
evaluation_type: str evaluation_type: str
task: Optional[Any] = None task: Optional[Any] = None
def __init__(self, **data): def __init__(self, **data: Any) -> None:
super().__init__(**data) super().__init__(**data)
# Set fingerprint data from the task # Set fingerprint data from the task
if hasattr(self.task, "fingerprint") and self.task.fingerprint: if hasattr(self.task, "fingerprint") and self.task.fingerprint:

View File

@@ -2,23 +2,22 @@ import hashlib
import logging import logging
import os import os
import shutil import shutil
from typing import Any, Dict, List, Optional, Union import warnings
from typing import Any, Optional, Union
import chromadb import chromadb
import chromadb.errors import chromadb.errors
from chromadb.api import ClientAPI from chromadb.api import ClientAPI
from chromadb.api.types import OneOrMany from chromadb.api.types import OneOrMany
from chromadb.config import Settings from chromadb.config import Settings
import warnings
from crewai.knowledge.storage.base_knowledge_storage import BaseKnowledgeStorage from crewai.knowledge.storage.base_knowledge_storage import BaseKnowledgeStorage
from crewai.rag.embeddings.configurator import EmbeddingConfigurator from crewai.rag.embeddings.configurator import EmbeddingConfigurator
from crewai.utilities.chromadb import sanitize_collection_name from crewai.utilities.chromadb import create_persistent_client, sanitize_collection_name
from crewai.utilities.constants import KNOWLEDGE_DIRECTORY from crewai.utilities.constants import KNOWLEDGE_DIRECTORY
from crewai.utilities.logger import Logger from crewai.utilities.logger import Logger
from crewai.utilities.paths import db_storage_path
from crewai.utilities.chromadb import create_persistent_client
from crewai.utilities.logger_utils import suppress_logging from crewai.utilities.logger_utils import suppress_logging
from crewai.utilities.paths import db_storage_path
class KnowledgeStorage(BaseKnowledgeStorage): class KnowledgeStorage(BaseKnowledgeStorage):
@@ -33,7 +32,7 @@ class KnowledgeStorage(BaseKnowledgeStorage):
def __init__( def __init__(
self, self,
embedder: Optional[Dict[str, Any]] = None, embedder: Optional[dict[str, Any]] = None,
collection_name: Optional[str] = None, collection_name: Optional[str] = None,
): ):
self.collection_name = collection_name self.collection_name = collection_name
@@ -41,11 +40,11 @@ class KnowledgeStorage(BaseKnowledgeStorage):
def search( def search(
self, self,
query: List[str], query: list[str],
limit: int = 3, limit: int = 3,
filter: Optional[dict] = None, filter: Optional[dict] = None,
score_threshold: float = 0.35, score_threshold: float = 0.35,
) -> List[Dict[str, Any]]: ) -> list[dict[str, Any]]:
with suppress_logging( with suppress_logging(
"chromadb.segment.impl.vector.local_persistent_hnsw", logging.ERROR "chromadb.segment.impl.vector.local_persistent_hnsw", logging.ERROR
): ):
@@ -69,7 +68,7 @@ class KnowledgeStorage(BaseKnowledgeStorage):
else: else:
raise Exception("Collection not initialized") raise Exception("Collection not initialized")
def initialize_knowledge_storage(self): def initialize_knowledge_storage(self) -> None:
# Suppress deprecation warnings from chromadb, which are not relevant to us # Suppress deprecation warnings from chromadb, which are not relevant to us
# TODO: Remove this once we upgrade chromadb to at least 1.0.8. # TODO: Remove this once we upgrade chromadb to at least 1.0.8.
warnings.filterwarnings( warnings.filterwarnings(
@@ -99,7 +98,7 @@ class KnowledgeStorage(BaseKnowledgeStorage):
except Exception: except Exception:
raise Exception("Failed to create or get collection") raise Exception("Failed to create or get collection")
def reset(self): def reset(self) -> None:
base_path = os.path.join(db_storage_path(), KNOWLEDGE_DIRECTORY) base_path = os.path.join(db_storage_path(), KNOWLEDGE_DIRECTORY)
if not self.app: if not self.app:
self.app = create_persistent_client( self.app = create_persistent_client(
@@ -113,8 +112,8 @@ class KnowledgeStorage(BaseKnowledgeStorage):
def save( def save(
self, self,
documents: List[str], documents: list[str],
metadata: Optional[Union[Dict[str, Any], List[Dict[str, Any]]]] = None, metadata: Optional[dict[str, Any] | list[dict[str, Any]]] = None,
): ):
if not self.collection: if not self.collection:
raise Exception("Collection not initialized") raise Exception("Collection not initialized")
@@ -179,7 +178,7 @@ class KnowledgeStorage(BaseKnowledgeStorage):
api_key=os.getenv("OPENAI_API_KEY"), model_name="text-embedding-3-small" api_key=os.getenv("OPENAI_API_KEY"), model_name="text-embedding-3-small"
) )
def _set_embedder_config(self, embedder: Optional[Dict[str, Any]] = None) -> None: def _set_embedder_config(self, embedder: Optional[dict[str, Any]] = None) -> None:
"""Set the embedding configuration for the knowledge storage. """Set the embedding configuration for the knowledge storage.
Args: Args:

View File

@@ -1,5 +1,5 @@
import time import time
from typing import Any from typing import Any, Optional
from crewai.events.event_bus import crewai_event_bus from crewai.events.event_bus import crewai_event_bus
from crewai.events.types.memory_events import ( from crewai.events.types.memory_events import (
@@ -29,67 +29,79 @@ class LongTermMemory(Memory):
storage = LTMSQLiteStorage(db_path=path) if path else LTMSQLiteStorage() storage = LTMSQLiteStorage(db_path=path) if path else LTMSQLiteStorage()
super().__init__(storage=storage) super().__init__(storage=storage)
def save(self, item: LongTermMemoryItem) -> None: def save(
crewai_event_bus.emit( self,
self, value: Any,
event=MemorySaveStartedEvent( metadata: Optional[dict[str, Any]] = None,
value=item.task, ) -> None:
metadata=item.metadata, # Handle both LongTermMemoryItem and regular save calls
agent_role=item.agent, if isinstance(value, LongTermMemoryItem):
source_type="long_term_memory", item = value
from_agent=self.agent,
from_task=self.task,
),
)
start_time = time.time()
try:
metadata = item.metadata
metadata.update(
{"agent": item.agent, "expected_output": item.expected_output}
)
self.storage.save(
task_description=item.task,
score=metadata["quality"],
metadata=metadata,
datetime=item.datetime,
)
crewai_event_bus.emit( crewai_event_bus.emit(
self, self,
event=MemorySaveCompletedEvent( event=MemorySaveStartedEvent(
value=item.task, value=item.task,
metadata=item.metadata, metadata=item.metadata,
agent_role=item.agent, agent_role=item.agent,
save_time_ms=(time.time() - start_time) * 1000,
source_type="long_term_memory", source_type="long_term_memory",
from_agent=self.agent, from_agent=self.agent,
from_task=self.task, from_task=self.task,
), ),
) )
except Exception as e:
crewai_event_bus.emit( start_time = time.time()
self, try:
event=MemorySaveFailedEvent( metadata = item.metadata.copy()
value=item.task, metadata.update(
metadata=item.metadata, {"agent": item.agent, "expected_output": item.expected_output}
agent_role=item.agent, )
error=str(e), self.storage.save(
source_type="long_term_memory", task_description=item.task,
), score=metadata["quality"],
) metadata=metadata,
raise datetime=item.datetime,
)
crewai_event_bus.emit(
self,
event=MemorySaveCompletedEvent(
value=item.task,
metadata=item.metadata,
agent_role=item.agent,
save_time_ms=(time.time() - start_time) * 1000,
source_type="long_term_memory",
from_agent=self.agent,
from_task=self.task,
),
)
except Exception as e:
crewai_event_bus.emit(
self,
event=MemorySaveFailedEvent(
value=item.task,
metadata=item.metadata,
agent_role=item.agent,
error=str(e),
source_type="long_term_memory",
),
)
raise
else:
# Regular save for compatibility with parent class
metadata = metadata or {}
self.storage.save(value, metadata)
def search( def search(
self, self,
task: str, query: str,
latest_n: int = 3, limit: int = 3,
) -> list[dict[str, Any]]: score_threshold: float = 0.35,
) -> list[Any]:
crewai_event_bus.emit( crewai_event_bus.emit(
self, self,
event=MemoryQueryStartedEvent( event=MemoryQueryStartedEvent(
query=task, query=query,
limit=latest_n, limit=limit,
source_type="long_term_memory", source_type="long_term_memory",
from_agent=self.agent, from_agent=self.agent,
from_task=self.task, from_task=self.task,
@@ -98,14 +110,16 @@ class LongTermMemory(Memory):
start_time = time.time() start_time = time.time()
try: try:
results = self.storage.load(task, latest_n) # The storage.load method uses different parameter names
# but we'll call it with the aligned names
results = self.storage.load(query, limit)
crewai_event_bus.emit( crewai_event_bus.emit(
self, self,
event=MemoryQueryCompletedEvent( event=MemoryQueryCompletedEvent(
query=task, query=query,
results=results, results=results,
limit=latest_n, limit=limit,
query_time_ms=(time.time() - start_time) * 1000, query_time_ms=(time.time() - start_time) * 1000,
source_type="long_term_memory", source_type="long_term_memory",
from_agent=self.agent, from_agent=self.agent,
@@ -113,15 +127,17 @@ class LongTermMemory(Memory):
), ),
) )
return results return results if results is not None else []
except Exception as e: except Exception as e:
crewai_event_bus.emit( crewai_event_bus.emit(
self, self,
event=MemoryQueryFailedEvent( event=MemoryQueryFailedEvent(
query=task, query=query,
limit=latest_n, limit=limit,
error=str(e), error=str(e),
source_type="long_term_memory", source_type="long_term_memory",
from_agent=self.agent,
from_task=self.task,
), ),
) )
raise raise

View File

@@ -2,16 +2,17 @@ import logging
import os import os
import shutil import shutil
import uuid import uuid
import warnings
from typing import Any, Optional
from typing import Any, Dict, List, Optional
from chromadb.api import ClientAPI from chromadb.api import ClientAPI
from crewai.rag.storage.base_rag_storage import BaseRAGStorage
from crewai.rag.embeddings.configurator import EmbeddingConfigurator from crewai.rag.embeddings.configurator import EmbeddingConfigurator
from crewai.rag.storage.base_rag_storage import BaseRAGStorage
from crewai.utilities.chromadb import create_persistent_client from crewai.utilities.chromadb import create_persistent_client
from crewai.utilities.constants import MAX_FILE_NAME_LENGTH from crewai.utilities.constants import MAX_FILE_NAME_LENGTH
from crewai.utilities.paths import db_storage_path
from crewai.utilities.logger_utils import suppress_logging from crewai.utilities.logger_utils import suppress_logging
import warnings from crewai.utilities.paths import db_storage_path
class RAGStorage(BaseRAGStorage): class RAGStorage(BaseRAGStorage):
@@ -23,8 +24,13 @@ class RAGStorage(BaseRAGStorage):
app: ClientAPI | None = None app: ClientAPI | None = None
def __init__( def __init__(
self, type, allow_reset=True, embedder_config=None, crew=None, path=None self,
): type: str,
allow_reset: bool = True,
embedder_config: Any = None,
crew: Any = None,
path: Optional[str] = None,
) -> None:
super().__init__(type, allow_reset, embedder_config, crew) super().__init__(type, allow_reset, embedder_config, crew)
agents = crew.agents if crew else [] agents = crew.agents if crew else []
agents = [self._sanitize_role(agent.role) for agent in agents] agents = [self._sanitize_role(agent.role) for agent in agents]
@@ -85,7 +91,7 @@ class RAGStorage(BaseRAGStorage):
return f"{base_path}/{file_name}" return f"{base_path}/{file_name}"
def save(self, value: Any, metadata: Dict[str, Any]) -> None: def save(self, value: Any, metadata: dict[str, Any]) -> None:
if not hasattr(self, "app") or not hasattr(self, "collection"): if not hasattr(self, "app") or not hasattr(self, "collection"):
self._initialize_app() self._initialize_app()
try: try:
@@ -99,7 +105,7 @@ class RAGStorage(BaseRAGStorage):
limit: int = 3, limit: int = 3,
filter: Optional[dict] = None, filter: Optional[dict] = None,
score_threshold: float = 0.35, score_threshold: float = 0.35,
) -> List[Any]: ) -> list[Any]:
if not hasattr(self, "app"): if not hasattr(self, "app"):
self._initialize_app() self._initialize_app()
@@ -125,7 +131,7 @@ class RAGStorage(BaseRAGStorage):
logging.error(f"Error during {self.type} search: {str(e)}") logging.error(f"Error during {self.type} search: {str(e)}")
return [] return []
def _generate_embedding(self, text: str, metadata: Dict[str, Any]) -> None: # type: ignore def _generate_embedding(self, text: str, metadata: dict[str, Any]) -> None: # type: ignore
if not hasattr(self, "app") or not hasattr(self, "collection"): if not hasattr(self, "app") or not hasattr(self, "collection"):
self._initialize_app() self._initialize_app()

View File

@@ -37,6 +37,8 @@ class ConditionalTask(Task):
Returns: Returns:
bool: True if the task should be executed, False otherwise. bool: True if the task should be executed, False otherwise.
""" """
if self.condition is None:
return False
return self.condition(context) return self.condition(context)
def get_skipped_task_output(self) -> TaskOutput: def get_skipped_task_output(self) -> TaskOutput:

View File

@@ -5,11 +5,12 @@ import json
import logging import logging
import os import os
import platform import platform
import threading
import warnings import warnings
from collections.abc import Callable
from contextlib import contextmanager from contextlib import contextmanager
from importlib.metadata import version from importlib.metadata import version
from typing import TYPE_CHECKING, Any, Callable, Optional from typing import TYPE_CHECKING, Any, Optional
import threading
from opentelemetry import trace from opentelemetry import trace
from opentelemetry.exporter.otlp.proto.http.trace_exporter import ( from opentelemetry.exporter.otlp.proto.http.trace_exporter import (
@@ -72,12 +73,12 @@ class Telemetry:
if cls._instance is None: if cls._instance is None:
with cls._lock: with cls._lock:
if cls._instance is None: if cls._instance is None:
cls._instance = super(Telemetry, cls).__new__(cls) cls._instance = super().__new__(cls)
cls._instance._initialized = False cls._instance._initialized = False
return cls._instance return cls._instance
def __init__(self) -> None: def __init__(self) -> None:
if hasattr(self, '_initialized') and self._initialized: if hasattr(self, "_initialized") and self._initialized:
return return
self.ready: bool = False self.ready: bool = False
@@ -124,7 +125,7 @@ class Telemetry:
"""Check if telemetry operations should be executed.""" """Check if telemetry operations should be executed."""
return self.ready and not self._is_telemetry_disabled() return self.ready and not self._is_telemetry_disabled()
def set_tracer(self): def set_tracer(self) -> None:
if self.ready and not self.trace_set: if self.ready and not self.trace_set:
try: try:
with suppress_warnings(): with suppress_warnings():
@@ -790,7 +791,7 @@ class Telemetry:
return operation() return operation()
return None return None
def end_crew(self, crew, final_string_output): def end_crew(self, crew: Any, final_string_output: str) -> None:
def operation(): def operation():
self._add_attribute( self._add_attribute(
crew._execution_span, crew._execution_span,

View File

@@ -1,4 +1,4 @@
from typing import Any from typing import Any, cast
from pydantic import BaseModel, Field from pydantic import BaseModel, Field
@@ -73,7 +73,7 @@ class TaskEvaluator:
instructions=instructions, instructions=instructions,
) )
return converter.to_pydantic() return cast(TaskEvaluation, converter.to_pydantic())
def evaluate_training_data( def evaluate_training_data(
self, training_data: dict[str, Any], agent_id: str self, training_data: dict[str, Any], agent_id: str
@@ -143,4 +143,4 @@ class TaskEvaluator:
) )
pydantic_result = converter.to_pydantic() pydantic_result = converter.to_pydantic()
return pydantic_result return cast(TrainingTaskEvaluation, pydantic_result)

View File

@@ -2,7 +2,7 @@ import json
import os import os
import pickle import pickle
from datetime import datetime from datetime import datetime
from typing import Union from typing import Any, Union
class FileHandler: class FileHandler:
@@ -12,23 +12,29 @@ class FileHandler:
file_path (Union[bool, str]): Path to the log file or boolean flag file_path (Union[bool, str]): Path to the log file or boolean flag
""" """
def __init__(self, file_path: Union[bool, str]): def __init__(self, file_path: bool | str):
self._initialize_path(file_path) self._initialize_path(file_path)
def _initialize_path(self, file_path: Union[bool, str]): def _initialize_path(self, file_path: bool | str) -> None:
if file_path is True: # File path is boolean True if file_path is True: # File path is boolean True
self._path = os.path.join(os.curdir, "logs.txt") self._path = os.path.join(os.curdir, "logs.txt")
elif isinstance(file_path, str): # File path is a string elif isinstance(file_path, str): # File path is a string
if file_path.endswith((".json", ".txt")): if file_path.endswith((".json", ".txt")):
self._path = file_path # No modification if the file ends with .json or .txt self._path = (
file_path # No modification if the file ends with .json or .txt
)
else: else:
self._path = file_path + ".txt" # Append .txt if the file doesn't end with .json or .txt self._path = (
file_path + ".txt"
) # Append .txt if the file doesn't end with .json or .txt
else: else:
raise ValueError("file_path must be a string or boolean.") # Handle the case where file_path isn't valid raise ValueError(
"file_path must be a string or boolean."
) # Handle the case where file_path isn't valid
def log(self, **kwargs): def log(self, **kwargs: Any) -> None:
try: try:
now = datetime.now().strftime("%Y-%m-%d %H:%M:%S") now = datetime.now().strftime("%Y-%m-%d %H:%M:%S")
log_entry = {"timestamp": now, **kwargs} log_entry = {"timestamp": now, **kwargs}
@@ -52,13 +58,18 @@ class FileHandler:
else: else:
# Append log in plain text format # Append log in plain text format
message = f"{now}: " + ", ".join([f"{key}=\"{value}\"" for key, value in kwargs.items()]) + "\n" message = (
f"{now}: "
+ ", ".join([f'{key}="{value}"' for key, value in kwargs.items()])
+ "\n"
)
with open(self._path, "a", encoding="utf-8") as file: with open(self._path, "a", encoding="utf-8") as file:
file.write(message) file.write(message)
except Exception as e: except Exception as e:
raise ValueError(f"Failed to log message: {str(e)}") raise ValueError(f"Failed to log message: {str(e)}")
class PickleHandler: class PickleHandler:
def __init__(self, file_name: str) -> None: def __init__(self, file_name: str) -> None:
""" """