mirror of
https://github.com/crewAIInc/crewAI.git
synced 2026-01-28 17:48:13 +00:00
fix: resolve additional mypy type annotation issues
- Fixed file_handler.py PickleHandler type annotations - Fixed task_events.py None checks before accessing task.fingerprint - Added type annotations to memory_listener.py event handlers
This commit is contained in:
@@ -1,6 +1,7 @@
|
|||||||
from typing import Any
|
from typing import Any
|
||||||
|
|
||||||
from crewai.events.base_event_listener import BaseEventListener
|
from crewai.events.base_event_listener import BaseEventListener
|
||||||
|
from crewai.events.event_bus import CrewAIEventsBus
|
||||||
from crewai.events.types.memory_events import (
|
from crewai.events.types.memory_events import (
|
||||||
MemoryQueryCompletedEvent,
|
MemoryQueryCompletedEvent,
|
||||||
MemoryQueryFailedEvent,
|
MemoryQueryFailedEvent,
|
||||||
@@ -19,9 +20,11 @@ class MemoryListener(BaseEventListener):
|
|||||||
self.memory_retrieval_in_progress = False
|
self.memory_retrieval_in_progress = False
|
||||||
self.memory_save_in_progress = False
|
self.memory_save_in_progress = False
|
||||||
|
|
||||||
def setup_listeners(self, crewai_event_bus):
|
def setup_listeners(self, crewai_event_bus: CrewAIEventsBus) -> None:
|
||||||
@crewai_event_bus.on(MemoryRetrievalStartedEvent)
|
@crewai_event_bus.on(MemoryRetrievalStartedEvent)
|
||||||
def on_memory_retrieval_started(source, event: MemoryRetrievalStartedEvent):
|
def on_memory_retrieval_started(
|
||||||
|
source: Any, event: MemoryRetrievalStartedEvent
|
||||||
|
) -> None:
|
||||||
if self.memory_retrieval_in_progress:
|
if self.memory_retrieval_in_progress:
|
||||||
return
|
return
|
||||||
|
|
||||||
@@ -33,7 +36,9 @@ class MemoryListener(BaseEventListener):
|
|||||||
)
|
)
|
||||||
|
|
||||||
@crewai_event_bus.on(MemoryRetrievalCompletedEvent)
|
@crewai_event_bus.on(MemoryRetrievalCompletedEvent)
|
||||||
def on_memory_retrieval_completed(source, event: MemoryRetrievalCompletedEvent):
|
def on_memory_retrieval_completed(
|
||||||
|
source: Any, event: MemoryRetrievalCompletedEvent
|
||||||
|
) -> None:
|
||||||
if not self.memory_retrieval_in_progress:
|
if not self.memory_retrieval_in_progress:
|
||||||
return
|
return
|
||||||
|
|
||||||
@@ -46,7 +51,9 @@ class MemoryListener(BaseEventListener):
|
|||||||
)
|
)
|
||||||
|
|
||||||
@crewai_event_bus.on(MemoryQueryCompletedEvent)
|
@crewai_event_bus.on(MemoryQueryCompletedEvent)
|
||||||
def on_memory_query_completed(source, event: MemoryQueryCompletedEvent):
|
def on_memory_query_completed(
|
||||||
|
source: Any, event: MemoryQueryCompletedEvent
|
||||||
|
) -> None:
|
||||||
if not self.memory_retrieval_in_progress:
|
if not self.memory_retrieval_in_progress:
|
||||||
return
|
return
|
||||||
|
|
||||||
@@ -58,7 +65,7 @@ class MemoryListener(BaseEventListener):
|
|||||||
)
|
)
|
||||||
|
|
||||||
@crewai_event_bus.on(MemoryQueryFailedEvent)
|
@crewai_event_bus.on(MemoryQueryFailedEvent)
|
||||||
def on_memory_query_failed(source, event: MemoryQueryFailedEvent):
|
def on_memory_query_failed(source: Any, event: MemoryQueryFailedEvent) -> None:
|
||||||
if not self.memory_retrieval_in_progress:
|
if not self.memory_retrieval_in_progress:
|
||||||
return
|
return
|
||||||
|
|
||||||
@@ -70,7 +77,7 @@ class MemoryListener(BaseEventListener):
|
|||||||
)
|
)
|
||||||
|
|
||||||
@crewai_event_bus.on(MemorySaveStartedEvent)
|
@crewai_event_bus.on(MemorySaveStartedEvent)
|
||||||
def on_memory_save_started(source, event: MemorySaveStartedEvent):
|
def on_memory_save_started(source: Any, event: MemorySaveStartedEvent) -> None:
|
||||||
if self.memory_save_in_progress:
|
if self.memory_save_in_progress:
|
||||||
return
|
return
|
||||||
|
|
||||||
@@ -82,7 +89,9 @@ class MemoryListener(BaseEventListener):
|
|||||||
)
|
)
|
||||||
|
|
||||||
@crewai_event_bus.on(MemorySaveCompletedEvent)
|
@crewai_event_bus.on(MemorySaveCompletedEvent)
|
||||||
def on_memory_save_completed(source, event: MemorySaveCompletedEvent):
|
def on_memory_save_completed(
|
||||||
|
source: Any, event: MemorySaveCompletedEvent
|
||||||
|
) -> None:
|
||||||
if not self.memory_save_in_progress:
|
if not self.memory_save_in_progress:
|
||||||
return
|
return
|
||||||
|
|
||||||
@@ -96,7 +105,7 @@ class MemoryListener(BaseEventListener):
|
|||||||
)
|
)
|
||||||
|
|
||||||
@crewai_event_bus.on(MemorySaveFailedEvent)
|
@crewai_event_bus.on(MemorySaveFailedEvent)
|
||||||
def on_memory_save_failed(source, event: MemorySaveFailedEvent):
|
def on_memory_save_failed(source: Any, event: MemorySaveFailedEvent) -> None:
|
||||||
if not self.memory_save_in_progress:
|
if not self.memory_save_in_progress:
|
||||||
return
|
return
|
||||||
|
|
||||||
|
|||||||
@@ -14,11 +14,12 @@ class TaskStartedEvent(BaseEvent):
|
|||||||
def __init__(self, **data: Any) -> None:
|
def __init__(self, **data: Any) -> None:
|
||||||
super().__init__(**data)
|
super().__init__(**data)
|
||||||
# Set fingerprint data from the task
|
# Set fingerprint data from the task
|
||||||
if hasattr(self.task, "fingerprint") and self.task.fingerprint:
|
if self.task and hasattr(self.task, "fingerprint") and self.task.fingerprint:
|
||||||
self.source_fingerprint = self.task.fingerprint.uuid_str
|
self.source_fingerprint = self.task.fingerprint.uuid_str
|
||||||
self.source_type = "task"
|
self.source_type = "task"
|
||||||
if (
|
if (
|
||||||
hasattr(self.task.fingerprint, "metadata")
|
self.task
|
||||||
|
and hasattr(self.task.fingerprint, "metadata")
|
||||||
and self.task.fingerprint.metadata
|
and self.task.fingerprint.metadata
|
||||||
):
|
):
|
||||||
self.fingerprint_metadata = self.task.fingerprint.metadata
|
self.fingerprint_metadata = self.task.fingerprint.metadata
|
||||||
@@ -34,11 +35,12 @@ class TaskCompletedEvent(BaseEvent):
|
|||||||
def __init__(self, **data: Any) -> None:
|
def __init__(self, **data: Any) -> None:
|
||||||
super().__init__(**data)
|
super().__init__(**data)
|
||||||
# Set fingerprint data from the task
|
# Set fingerprint data from the task
|
||||||
if hasattr(self.task, "fingerprint") and self.task.fingerprint:
|
if self.task and hasattr(self.task, "fingerprint") and self.task.fingerprint:
|
||||||
self.source_fingerprint = self.task.fingerprint.uuid_str
|
self.source_fingerprint = self.task.fingerprint.uuid_str
|
||||||
self.source_type = "task"
|
self.source_type = "task"
|
||||||
if (
|
if (
|
||||||
hasattr(self.task.fingerprint, "metadata")
|
self.task
|
||||||
|
and hasattr(self.task.fingerprint, "metadata")
|
||||||
and self.task.fingerprint.metadata
|
and self.task.fingerprint.metadata
|
||||||
):
|
):
|
||||||
self.fingerprint_metadata = self.task.fingerprint.metadata
|
self.fingerprint_metadata = self.task.fingerprint.metadata
|
||||||
@@ -54,11 +56,12 @@ class TaskFailedEvent(BaseEvent):
|
|||||||
def __init__(self, **data: Any) -> None:
|
def __init__(self, **data: Any) -> None:
|
||||||
super().__init__(**data)
|
super().__init__(**data)
|
||||||
# Set fingerprint data from the task
|
# Set fingerprint data from the task
|
||||||
if hasattr(self.task, "fingerprint") and self.task.fingerprint:
|
if self.task and hasattr(self.task, "fingerprint") and self.task.fingerprint:
|
||||||
self.source_fingerprint = self.task.fingerprint.uuid_str
|
self.source_fingerprint = self.task.fingerprint.uuid_str
|
||||||
self.source_type = "task"
|
self.source_type = "task"
|
||||||
if (
|
if (
|
||||||
hasattr(self.task.fingerprint, "metadata")
|
self.task
|
||||||
|
and hasattr(self.task.fingerprint, "metadata")
|
||||||
and self.task.fingerprint.metadata
|
and self.task.fingerprint.metadata
|
||||||
):
|
):
|
||||||
self.fingerprint_metadata = self.task.fingerprint.metadata
|
self.fingerprint_metadata = self.task.fingerprint.metadata
|
||||||
@@ -74,11 +77,12 @@ class TaskEvaluationEvent(BaseEvent):
|
|||||||
def __init__(self, **data: Any) -> None:
|
def __init__(self, **data: Any) -> None:
|
||||||
super().__init__(**data)
|
super().__init__(**data)
|
||||||
# Set fingerprint data from the task
|
# Set fingerprint data from the task
|
||||||
if hasattr(self.task, "fingerprint") and self.task.fingerprint:
|
if self.task and hasattr(self.task, "fingerprint") and self.task.fingerprint:
|
||||||
self.source_fingerprint = self.task.fingerprint.uuid_str
|
self.source_fingerprint = self.task.fingerprint.uuid_str
|
||||||
self.source_type = "task"
|
self.source_type = "task"
|
||||||
if (
|
if (
|
||||||
hasattr(self.task.fingerprint, "metadata")
|
self.task
|
||||||
|
and hasattr(self.task.fingerprint, "metadata")
|
||||||
and self.task.fingerprint.metadata
|
and self.task.fingerprint.metadata
|
||||||
):
|
):
|
||||||
self.fingerprint_metadata = self.task.fingerprint.metadata
|
self.fingerprint_metadata = self.task.fingerprint.metadata
|
||||||
|
|||||||
@@ -90,7 +90,7 @@ class PickleHandler:
|
|||||||
"""
|
"""
|
||||||
self.save({})
|
self.save({})
|
||||||
|
|
||||||
def save(self, data) -> None:
|
def save(self, data: Any) -> None:
|
||||||
"""
|
"""
|
||||||
Save the data to the specified file using pickle.
|
Save the data to the specified file using pickle.
|
||||||
|
|
||||||
@@ -100,12 +100,12 @@ class PickleHandler:
|
|||||||
with open(self.file_path, "wb") as file:
|
with open(self.file_path, "wb") as file:
|
||||||
pickle.dump(data, file)
|
pickle.dump(data, file)
|
||||||
|
|
||||||
def load(self) -> dict:
|
def load(self) -> Any:
|
||||||
"""
|
"""
|
||||||
Load the data from the specified file using pickle.
|
Load the data from the specified file using pickle.
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
- dict: The data loaded from the file.
|
- The data loaded from the file.
|
||||||
"""
|
"""
|
||||||
if not os.path.exists(self.file_path) or os.path.getsize(self.file_path) == 0:
|
if not os.path.exists(self.file_path) or os.path.getsize(self.file_path) == 0:
|
||||||
return {} # Return an empty dictionary if the file does not exist or is empty
|
return {} # Return an empty dictionary if the file does not exist or is empty
|
||||||
|
|||||||
Reference in New Issue
Block a user