From 8dd3493e9ceec926a7e8d253a7c513ebda1f58c0 Mon Sep 17 00:00:00 2001 From: Greyson LaLonde Date: Thu, 4 Sep 2025 13:07:37 -0400 Subject: [PATCH] fix: resolve additional mypy type annotation issues - Fixed file_handler.py PickleHandler type annotations - Fixed task_events.py None checks before accessing task.fingerprint - Added type annotations to memory_listener.py event handlers --- .../events/listeners/memory_listener.py | 25 +++++++++++++------ src/crewai/events/types/task_events.py | 20 +++++++++------ src/crewai/utilities/file_handler.py | 6 ++--- 3 files changed, 32 insertions(+), 19 deletions(-) diff --git a/src/crewai/events/listeners/memory_listener.py b/src/crewai/events/listeners/memory_listener.py index 0da19ba06..4224e9d44 100644 --- a/src/crewai/events/listeners/memory_listener.py +++ b/src/crewai/events/listeners/memory_listener.py @@ -1,6 +1,7 @@ from typing import Any from crewai.events.base_event_listener import BaseEventListener +from crewai.events.event_bus import CrewAIEventsBus from crewai.events.types.memory_events import ( MemoryQueryCompletedEvent, MemoryQueryFailedEvent, @@ -19,9 +20,11 @@ class MemoryListener(BaseEventListener): self.memory_retrieval_in_progress = False self.memory_save_in_progress = False - def setup_listeners(self, crewai_event_bus): + def setup_listeners(self, crewai_event_bus: CrewAIEventsBus) -> None: @crewai_event_bus.on(MemoryRetrievalStartedEvent) - def on_memory_retrieval_started(source, event: MemoryRetrievalStartedEvent): + def on_memory_retrieval_started( + source: Any, event: MemoryRetrievalStartedEvent + ) -> None: if self.memory_retrieval_in_progress: return @@ -33,7 +36,9 @@ class MemoryListener(BaseEventListener): ) @crewai_event_bus.on(MemoryRetrievalCompletedEvent) - def on_memory_retrieval_completed(source, event: MemoryRetrievalCompletedEvent): + def on_memory_retrieval_completed( + source: Any, event: MemoryRetrievalCompletedEvent + ) -> None: if not self.memory_retrieval_in_progress: return @@ -46,7 +51,9 @@ class MemoryListener(BaseEventListener): ) @crewai_event_bus.on(MemoryQueryCompletedEvent) - def on_memory_query_completed(source, event: MemoryQueryCompletedEvent): + def on_memory_query_completed( + source: Any, event: MemoryQueryCompletedEvent + ) -> None: if not self.memory_retrieval_in_progress: return @@ -58,7 +65,7 @@ class MemoryListener(BaseEventListener): ) @crewai_event_bus.on(MemoryQueryFailedEvent) - def on_memory_query_failed(source, event: MemoryQueryFailedEvent): + def on_memory_query_failed(source: Any, event: MemoryQueryFailedEvent) -> None: if not self.memory_retrieval_in_progress: return @@ -70,7 +77,7 @@ class MemoryListener(BaseEventListener): ) @crewai_event_bus.on(MemorySaveStartedEvent) - def on_memory_save_started(source, event: MemorySaveStartedEvent): + def on_memory_save_started(source: Any, event: MemorySaveStartedEvent) -> None: if self.memory_save_in_progress: return @@ -82,7 +89,9 @@ class MemoryListener(BaseEventListener): ) @crewai_event_bus.on(MemorySaveCompletedEvent) - def on_memory_save_completed(source, event: MemorySaveCompletedEvent): + def on_memory_save_completed( + source: Any, event: MemorySaveCompletedEvent + ) -> None: if not self.memory_save_in_progress: return @@ -96,7 +105,7 @@ class MemoryListener(BaseEventListener): ) @crewai_event_bus.on(MemorySaveFailedEvent) - def on_memory_save_failed(source, event: MemorySaveFailedEvent): + def on_memory_save_failed(source: Any, event: MemorySaveFailedEvent) -> None: if not self.memory_save_in_progress: return diff --git a/src/crewai/events/types/task_events.py b/src/crewai/events/types/task_events.py index 927223694..57750f608 100644 --- a/src/crewai/events/types/task_events.py +++ b/src/crewai/events/types/task_events.py @@ -14,11 +14,12 @@ class TaskStartedEvent(BaseEvent): def __init__(self, **data: Any) -> None: super().__init__(**data) # Set fingerprint data from the task - if hasattr(self.task, "fingerprint") and self.task.fingerprint: + if self.task and hasattr(self.task, "fingerprint") and self.task.fingerprint: self.source_fingerprint = self.task.fingerprint.uuid_str self.source_type = "task" if ( - hasattr(self.task.fingerprint, "metadata") + self.task + and hasattr(self.task.fingerprint, "metadata") and self.task.fingerprint.metadata ): self.fingerprint_metadata = self.task.fingerprint.metadata @@ -34,11 +35,12 @@ class TaskCompletedEvent(BaseEvent): def __init__(self, **data: Any) -> None: super().__init__(**data) # Set fingerprint data from the task - if hasattr(self.task, "fingerprint") and self.task.fingerprint: + if self.task and hasattr(self.task, "fingerprint") and self.task.fingerprint: self.source_fingerprint = self.task.fingerprint.uuid_str self.source_type = "task" if ( - hasattr(self.task.fingerprint, "metadata") + self.task + and hasattr(self.task.fingerprint, "metadata") and self.task.fingerprint.metadata ): self.fingerprint_metadata = self.task.fingerprint.metadata @@ -54,11 +56,12 @@ class TaskFailedEvent(BaseEvent): def __init__(self, **data: Any) -> None: super().__init__(**data) # Set fingerprint data from the task - if hasattr(self.task, "fingerprint") and self.task.fingerprint: + if self.task and hasattr(self.task, "fingerprint") and self.task.fingerprint: self.source_fingerprint = self.task.fingerprint.uuid_str self.source_type = "task" if ( - hasattr(self.task.fingerprint, "metadata") + self.task + and hasattr(self.task.fingerprint, "metadata") and self.task.fingerprint.metadata ): self.fingerprint_metadata = self.task.fingerprint.metadata @@ -74,11 +77,12 @@ class TaskEvaluationEvent(BaseEvent): def __init__(self, **data: Any) -> None: super().__init__(**data) # Set fingerprint data from the task - if hasattr(self.task, "fingerprint") and self.task.fingerprint: + if self.task and hasattr(self.task, "fingerprint") and self.task.fingerprint: self.source_fingerprint = self.task.fingerprint.uuid_str self.source_type = "task" if ( - hasattr(self.task.fingerprint, "metadata") + self.task + and hasattr(self.task.fingerprint, "metadata") and self.task.fingerprint.metadata ): self.fingerprint_metadata = self.task.fingerprint.metadata diff --git a/src/crewai/utilities/file_handler.py b/src/crewai/utilities/file_handler.py index 4dcd5f4b6..241e65715 100644 --- a/src/crewai/utilities/file_handler.py +++ b/src/crewai/utilities/file_handler.py @@ -90,7 +90,7 @@ class PickleHandler: """ self.save({}) - def save(self, data) -> None: + def save(self, data: Any) -> None: """ Save the data to the specified file using pickle. @@ -100,12 +100,12 @@ class PickleHandler: with open(self.file_path, "wb") as file: pickle.dump(data, file) - def load(self) -> dict: + def load(self) -> Any: """ Load the data from the specified file using pickle. Returns: - - dict: The data loaded from the file. + - The data loaded from the file. """ if not os.path.exists(self.file_path) or os.path.getsize(self.file_path) == 0: return {} # Return an empty dictionary if the file does not exist or is empty