mirror of
https://github.com/crewAIInc/crewAI.git
synced 2026-01-10 00:28:31 +00:00
fix: resolve additional mypy type annotation issues
- Fixed file_handler.py PickleHandler type annotations - Fixed task_events.py None checks before accessing task.fingerprint - Added type annotations to memory_listener.py event handlers
This commit is contained in:
@@ -1,6 +1,7 @@
|
||||
from typing import Any
|
||||
|
||||
from crewai.events.base_event_listener import BaseEventListener
|
||||
from crewai.events.event_bus import CrewAIEventsBus
|
||||
from crewai.events.types.memory_events import (
|
||||
MemoryQueryCompletedEvent,
|
||||
MemoryQueryFailedEvent,
|
||||
@@ -19,9 +20,11 @@ class MemoryListener(BaseEventListener):
|
||||
self.memory_retrieval_in_progress = False
|
||||
self.memory_save_in_progress = False
|
||||
|
||||
def setup_listeners(self, crewai_event_bus):
|
||||
def setup_listeners(self, crewai_event_bus: CrewAIEventsBus) -> None:
|
||||
@crewai_event_bus.on(MemoryRetrievalStartedEvent)
|
||||
def on_memory_retrieval_started(source, event: MemoryRetrievalStartedEvent):
|
||||
def on_memory_retrieval_started(
|
||||
source: Any, event: MemoryRetrievalStartedEvent
|
||||
) -> None:
|
||||
if self.memory_retrieval_in_progress:
|
||||
return
|
||||
|
||||
@@ -33,7 +36,9 @@ class MemoryListener(BaseEventListener):
|
||||
)
|
||||
|
||||
@crewai_event_bus.on(MemoryRetrievalCompletedEvent)
|
||||
def on_memory_retrieval_completed(source, event: MemoryRetrievalCompletedEvent):
|
||||
def on_memory_retrieval_completed(
|
||||
source: Any, event: MemoryRetrievalCompletedEvent
|
||||
) -> None:
|
||||
if not self.memory_retrieval_in_progress:
|
||||
return
|
||||
|
||||
@@ -46,7 +51,9 @@ class MemoryListener(BaseEventListener):
|
||||
)
|
||||
|
||||
@crewai_event_bus.on(MemoryQueryCompletedEvent)
|
||||
def on_memory_query_completed(source, event: MemoryQueryCompletedEvent):
|
||||
def on_memory_query_completed(
|
||||
source: Any, event: MemoryQueryCompletedEvent
|
||||
) -> None:
|
||||
if not self.memory_retrieval_in_progress:
|
||||
return
|
||||
|
||||
@@ -58,7 +65,7 @@ class MemoryListener(BaseEventListener):
|
||||
)
|
||||
|
||||
@crewai_event_bus.on(MemoryQueryFailedEvent)
|
||||
def on_memory_query_failed(source, event: MemoryQueryFailedEvent):
|
||||
def on_memory_query_failed(source: Any, event: MemoryQueryFailedEvent) -> None:
|
||||
if not self.memory_retrieval_in_progress:
|
||||
return
|
||||
|
||||
@@ -70,7 +77,7 @@ class MemoryListener(BaseEventListener):
|
||||
)
|
||||
|
||||
@crewai_event_bus.on(MemorySaveStartedEvent)
|
||||
def on_memory_save_started(source, event: MemorySaveStartedEvent):
|
||||
def on_memory_save_started(source: Any, event: MemorySaveStartedEvent) -> None:
|
||||
if self.memory_save_in_progress:
|
||||
return
|
||||
|
||||
@@ -82,7 +89,9 @@ class MemoryListener(BaseEventListener):
|
||||
)
|
||||
|
||||
@crewai_event_bus.on(MemorySaveCompletedEvent)
|
||||
def on_memory_save_completed(source, event: MemorySaveCompletedEvent):
|
||||
def on_memory_save_completed(
|
||||
source: Any, event: MemorySaveCompletedEvent
|
||||
) -> None:
|
||||
if not self.memory_save_in_progress:
|
||||
return
|
||||
|
||||
@@ -96,7 +105,7 @@ class MemoryListener(BaseEventListener):
|
||||
)
|
||||
|
||||
@crewai_event_bus.on(MemorySaveFailedEvent)
|
||||
def on_memory_save_failed(source, event: MemorySaveFailedEvent):
|
||||
def on_memory_save_failed(source: Any, event: MemorySaveFailedEvent) -> None:
|
||||
if not self.memory_save_in_progress:
|
||||
return
|
||||
|
||||
|
||||
@@ -14,11 +14,12 @@ class TaskStartedEvent(BaseEvent):
|
||||
def __init__(self, **data: Any) -> None:
|
||||
super().__init__(**data)
|
||||
# Set fingerprint data from the task
|
||||
if hasattr(self.task, "fingerprint") and self.task.fingerprint:
|
||||
if self.task and hasattr(self.task, "fingerprint") and self.task.fingerprint:
|
||||
self.source_fingerprint = self.task.fingerprint.uuid_str
|
||||
self.source_type = "task"
|
||||
if (
|
||||
hasattr(self.task.fingerprint, "metadata")
|
||||
self.task
|
||||
and hasattr(self.task.fingerprint, "metadata")
|
||||
and self.task.fingerprint.metadata
|
||||
):
|
||||
self.fingerprint_metadata = self.task.fingerprint.metadata
|
||||
@@ -34,11 +35,12 @@ class TaskCompletedEvent(BaseEvent):
|
||||
def __init__(self, **data: Any) -> None:
|
||||
super().__init__(**data)
|
||||
# Set fingerprint data from the task
|
||||
if hasattr(self.task, "fingerprint") and self.task.fingerprint:
|
||||
if self.task and hasattr(self.task, "fingerprint") and self.task.fingerprint:
|
||||
self.source_fingerprint = self.task.fingerprint.uuid_str
|
||||
self.source_type = "task"
|
||||
if (
|
||||
hasattr(self.task.fingerprint, "metadata")
|
||||
self.task
|
||||
and hasattr(self.task.fingerprint, "metadata")
|
||||
and self.task.fingerprint.metadata
|
||||
):
|
||||
self.fingerprint_metadata = self.task.fingerprint.metadata
|
||||
@@ -54,11 +56,12 @@ class TaskFailedEvent(BaseEvent):
|
||||
def __init__(self, **data: Any) -> None:
|
||||
super().__init__(**data)
|
||||
# Set fingerprint data from the task
|
||||
if hasattr(self.task, "fingerprint") and self.task.fingerprint:
|
||||
if self.task and hasattr(self.task, "fingerprint") and self.task.fingerprint:
|
||||
self.source_fingerprint = self.task.fingerprint.uuid_str
|
||||
self.source_type = "task"
|
||||
if (
|
||||
hasattr(self.task.fingerprint, "metadata")
|
||||
self.task
|
||||
and hasattr(self.task.fingerprint, "metadata")
|
||||
and self.task.fingerprint.metadata
|
||||
):
|
||||
self.fingerprint_metadata = self.task.fingerprint.metadata
|
||||
@@ -74,11 +77,12 @@ class TaskEvaluationEvent(BaseEvent):
|
||||
def __init__(self, **data: Any) -> None:
|
||||
super().__init__(**data)
|
||||
# Set fingerprint data from the task
|
||||
if hasattr(self.task, "fingerprint") and self.task.fingerprint:
|
||||
if self.task and hasattr(self.task, "fingerprint") and self.task.fingerprint:
|
||||
self.source_fingerprint = self.task.fingerprint.uuid_str
|
||||
self.source_type = "task"
|
||||
if (
|
||||
hasattr(self.task.fingerprint, "metadata")
|
||||
self.task
|
||||
and hasattr(self.task.fingerprint, "metadata")
|
||||
and self.task.fingerprint.metadata
|
||||
):
|
||||
self.fingerprint_metadata = self.task.fingerprint.metadata
|
||||
|
||||
@@ -90,7 +90,7 @@ class PickleHandler:
|
||||
"""
|
||||
self.save({})
|
||||
|
||||
def save(self, data) -> None:
|
||||
def save(self, data: Any) -> None:
|
||||
"""
|
||||
Save the data to the specified file using pickle.
|
||||
|
||||
@@ -100,12 +100,12 @@ class PickleHandler:
|
||||
with open(self.file_path, "wb") as file:
|
||||
pickle.dump(data, file)
|
||||
|
||||
def load(self) -> dict:
|
||||
def load(self) -> Any:
|
||||
"""
|
||||
Load the data from the specified file using pickle.
|
||||
|
||||
Returns:
|
||||
- dict: The data loaded from the file.
|
||||
- The data loaded from the file.
|
||||
"""
|
||||
if not os.path.exists(self.file_path) or os.path.getsize(self.file_path) == 0:
|
||||
return {} # Return an empty dictionary if the file does not exist or is empty
|
||||
|
||||
Reference in New Issue
Block a user