fix: resolve additional mypy type annotation issues

- Fixed file_handler.py PickleHandler type annotations
- Fixed task_events.py None checks before accessing task.fingerprint
- Added type annotations to memory_listener.py event handlers
This commit is contained in:
Greyson LaLonde
2025-09-04 13:07:37 -04:00
parent 9306d889a7
commit 8dd3493e9c
3 changed files with 32 additions and 19 deletions

View File

@@ -1,6 +1,7 @@
from typing import Any
from crewai.events.base_event_listener import BaseEventListener
from crewai.events.event_bus import CrewAIEventsBus
from crewai.events.types.memory_events import (
MemoryQueryCompletedEvent,
MemoryQueryFailedEvent,
@@ -19,9 +20,11 @@ class MemoryListener(BaseEventListener):
self.memory_retrieval_in_progress = False
self.memory_save_in_progress = False
def setup_listeners(self, crewai_event_bus):
def setup_listeners(self, crewai_event_bus: CrewAIEventsBus) -> None:
@crewai_event_bus.on(MemoryRetrievalStartedEvent)
def on_memory_retrieval_started(source, event: MemoryRetrievalStartedEvent):
def on_memory_retrieval_started(
source: Any, event: MemoryRetrievalStartedEvent
) -> None:
if self.memory_retrieval_in_progress:
return
@@ -33,7 +36,9 @@ class MemoryListener(BaseEventListener):
)
@crewai_event_bus.on(MemoryRetrievalCompletedEvent)
def on_memory_retrieval_completed(source, event: MemoryRetrievalCompletedEvent):
def on_memory_retrieval_completed(
source: Any, event: MemoryRetrievalCompletedEvent
) -> None:
if not self.memory_retrieval_in_progress:
return
@@ -46,7 +51,9 @@ class MemoryListener(BaseEventListener):
)
@crewai_event_bus.on(MemoryQueryCompletedEvent)
def on_memory_query_completed(source, event: MemoryQueryCompletedEvent):
def on_memory_query_completed(
source: Any, event: MemoryQueryCompletedEvent
) -> None:
if not self.memory_retrieval_in_progress:
return
@@ -58,7 +65,7 @@ class MemoryListener(BaseEventListener):
)
@crewai_event_bus.on(MemoryQueryFailedEvent)
def on_memory_query_failed(source, event: MemoryQueryFailedEvent):
def on_memory_query_failed(source: Any, event: MemoryQueryFailedEvent) -> None:
if not self.memory_retrieval_in_progress:
return
@@ -70,7 +77,7 @@ class MemoryListener(BaseEventListener):
)
@crewai_event_bus.on(MemorySaveStartedEvent)
def on_memory_save_started(source, event: MemorySaveStartedEvent):
def on_memory_save_started(source: Any, event: MemorySaveStartedEvent) -> None:
if self.memory_save_in_progress:
return
@@ -82,7 +89,9 @@ class MemoryListener(BaseEventListener):
)
@crewai_event_bus.on(MemorySaveCompletedEvent)
def on_memory_save_completed(source, event: MemorySaveCompletedEvent):
def on_memory_save_completed(
source: Any, event: MemorySaveCompletedEvent
) -> None:
if not self.memory_save_in_progress:
return
@@ -96,7 +105,7 @@ class MemoryListener(BaseEventListener):
)
@crewai_event_bus.on(MemorySaveFailedEvent)
def on_memory_save_failed(source, event: MemorySaveFailedEvent):
def on_memory_save_failed(source: Any, event: MemorySaveFailedEvent) -> None:
if not self.memory_save_in_progress:
return

View File

@@ -14,11 +14,12 @@ class TaskStartedEvent(BaseEvent):
def __init__(self, **data: Any) -> None:
super().__init__(**data)
# Set fingerprint data from the task
if hasattr(self.task, "fingerprint") and self.task.fingerprint:
if self.task and hasattr(self.task, "fingerprint") and self.task.fingerprint:
self.source_fingerprint = self.task.fingerprint.uuid_str
self.source_type = "task"
if (
hasattr(self.task.fingerprint, "metadata")
self.task
and hasattr(self.task.fingerprint, "metadata")
and self.task.fingerprint.metadata
):
self.fingerprint_metadata = self.task.fingerprint.metadata
@@ -34,11 +35,12 @@ class TaskCompletedEvent(BaseEvent):
def __init__(self, **data: Any) -> None:
super().__init__(**data)
# Set fingerprint data from the task
if hasattr(self.task, "fingerprint") and self.task.fingerprint:
if self.task and hasattr(self.task, "fingerprint") and self.task.fingerprint:
self.source_fingerprint = self.task.fingerprint.uuid_str
self.source_type = "task"
if (
hasattr(self.task.fingerprint, "metadata")
self.task
and hasattr(self.task.fingerprint, "metadata")
and self.task.fingerprint.metadata
):
self.fingerprint_metadata = self.task.fingerprint.metadata
@@ -54,11 +56,12 @@ class TaskFailedEvent(BaseEvent):
def __init__(self, **data: Any) -> None:
super().__init__(**data)
# Set fingerprint data from the task
if hasattr(self.task, "fingerprint") and self.task.fingerprint:
if self.task and hasattr(self.task, "fingerprint") and self.task.fingerprint:
self.source_fingerprint = self.task.fingerprint.uuid_str
self.source_type = "task"
if (
hasattr(self.task.fingerprint, "metadata")
self.task
and hasattr(self.task.fingerprint, "metadata")
and self.task.fingerprint.metadata
):
self.fingerprint_metadata = self.task.fingerprint.metadata
@@ -74,11 +77,12 @@ class TaskEvaluationEvent(BaseEvent):
def __init__(self, **data: Any) -> None:
super().__init__(**data)
# Set fingerprint data from the task
if hasattr(self.task, "fingerprint") and self.task.fingerprint:
if self.task and hasattr(self.task, "fingerprint") and self.task.fingerprint:
self.source_fingerprint = self.task.fingerprint.uuid_str
self.source_type = "task"
if (
hasattr(self.task.fingerprint, "metadata")
self.task
and hasattr(self.task.fingerprint, "metadata")
and self.task.fingerprint.metadata
):
self.fingerprint_metadata = self.task.fingerprint.metadata

View File

@@ -90,7 +90,7 @@ class PickleHandler:
"""
self.save({})
def save(self, data) -> None:
def save(self, data: Any) -> None:
"""
Save the data to the specified file using pickle.
@@ -100,12 +100,12 @@ class PickleHandler:
with open(self.file_path, "wb") as file:
pickle.dump(data, file)
def load(self) -> dict:
def load(self) -> Any:
"""
Load the data from the specified file using pickle.
Returns:
- dict: The data loaded from the file.
- The data loaded from the file.
"""
if not os.path.exists(self.file_path) or os.path.getsize(self.file_path) == 0:
return {} # Return an empty dictionary if the file does not exist or is empty