mirror of
https://github.com/crewAIInc/crewAI.git
synced 2026-01-13 01:58:30 +00:00
fix: resolve type signature incompatibility in LongTermMemory
This commit fixes the type signature incompatibility issue in the LongTermMemory class (issue #4213). The save(), asave(), search(), and asearch() methods now have signatures compatible with the Memory base class, following the Liskov Substitution Principle. Changes: - Renamed 'item' parameter to 'value' in save() and asave() methods - Added 'metadata' parameter to save() and asave() methods for LSP compliance - Renamed 'task' parameter to 'query' in search() and asearch() methods - Renamed 'latest_n' parameter to 'limit' in search() and asearch() methods - Added 'score_threshold' parameter to search() and asearch() methods - Removed '# type: ignore' comments that were suppressing the type errors - Updated existing tests to use new parameter names - Added comprehensive tests to verify type signature compatibility Co-Authored-By: João <joao@crewai.com>
This commit is contained in:
@@ -33,13 +33,24 @@ class LongTermMemory(Memory):
|
||||
storage = LTMSQLiteStorage(db_path=path) if path else LTMSQLiteStorage()
|
||||
super().__init__(storage=storage)
|
||||
|
||||
def save(self, item: LongTermMemoryItem) -> None: # type: ignore # BUG?: Signature of "save" incompatible with supertype "Memory"
|
||||
def save(
|
||||
self,
|
||||
value: LongTermMemoryItem,
|
||||
metadata: dict[str, Any] | None = None,
|
||||
) -> None:
|
||||
"""Save an item to long-term memory.
|
||||
|
||||
Args:
|
||||
value: The LongTermMemoryItem to save.
|
||||
metadata: Optional metadata dict (not used, metadata is extracted from the
|
||||
LongTermMemoryItem). Included for supertype compatibility.
|
||||
"""
|
||||
crewai_event_bus.emit(
|
||||
self,
|
||||
event=MemorySaveStartedEvent(
|
||||
value=item.task,
|
||||
metadata=item.metadata,
|
||||
agent_role=item.agent,
|
||||
value=value.task,
|
||||
metadata=value.metadata,
|
||||
agent_role=value.agent,
|
||||
source_type="long_term_memory",
|
||||
from_agent=self.agent,
|
||||
from_task=self.task,
|
||||
@@ -48,23 +59,23 @@ class LongTermMemory(Memory):
|
||||
|
||||
start_time = time.time()
|
||||
try:
|
||||
metadata = item.metadata
|
||||
metadata.update(
|
||||
{"agent": item.agent, "expected_output": item.expected_output}
|
||||
item_metadata = value.metadata
|
||||
item_metadata.update(
|
||||
{"agent": value.agent, "expected_output": value.expected_output}
|
||||
)
|
||||
self.storage.save(
|
||||
task_description=item.task,
|
||||
score=metadata["quality"],
|
||||
metadata=metadata,
|
||||
datetime=item.datetime,
|
||||
task_description=value.task,
|
||||
score=item_metadata["quality"],
|
||||
metadata=item_metadata,
|
||||
datetime=value.datetime,
|
||||
)
|
||||
|
||||
crewai_event_bus.emit(
|
||||
self,
|
||||
event=MemorySaveCompletedEvent(
|
||||
value=item.task,
|
||||
metadata=item.metadata,
|
||||
agent_role=item.agent,
|
||||
value=value.task,
|
||||
metadata=value.metadata,
|
||||
agent_role=value.agent,
|
||||
save_time_ms=(time.time() - start_time) * 1000,
|
||||
source_type="long_term_memory",
|
||||
from_agent=self.agent,
|
||||
@@ -75,25 +86,28 @@ class LongTermMemory(Memory):
|
||||
crewai_event_bus.emit(
|
||||
self,
|
||||
event=MemorySaveFailedEvent(
|
||||
value=item.task,
|
||||
metadata=item.metadata,
|
||||
agent_role=item.agent,
|
||||
value=value.task,
|
||||
metadata=value.metadata,
|
||||
agent_role=value.agent,
|
||||
error=str(e),
|
||||
source_type="long_term_memory",
|
||||
),
|
||||
)
|
||||
raise
|
||||
|
||||
def search( # type: ignore[override]
|
||||
def search(
|
||||
self,
|
||||
task: str,
|
||||
latest_n: int = 3,
|
||||
query: str,
|
||||
limit: int = 3,
|
||||
score_threshold: float = 0.6,
|
||||
) -> list[dict[str, Any]]:
|
||||
"""Search long-term memory for relevant entries.
|
||||
|
||||
Args:
|
||||
task: The task description to search for.
|
||||
latest_n: Maximum number of results to return.
|
||||
query: The task description to search for.
|
||||
limit: Maximum number of results to return.
|
||||
score_threshold: Minimum similarity score for results (not used for
|
||||
long-term memory, included for supertype compatibility).
|
||||
|
||||
Returns:
|
||||
List of matching memory entries.
|
||||
@@ -101,8 +115,8 @@ class LongTermMemory(Memory):
|
||||
crewai_event_bus.emit(
|
||||
self,
|
||||
event=MemoryQueryStartedEvent(
|
||||
query=task,
|
||||
limit=latest_n,
|
||||
query=query,
|
||||
limit=limit,
|
||||
source_type="long_term_memory",
|
||||
from_agent=self.agent,
|
||||
from_task=self.task,
|
||||
@@ -111,14 +125,14 @@ class LongTermMemory(Memory):
|
||||
|
||||
start_time = time.time()
|
||||
try:
|
||||
results = self.storage.load(task, latest_n)
|
||||
results = self.storage.load(query, limit)
|
||||
|
||||
crewai_event_bus.emit(
|
||||
self,
|
||||
event=MemoryQueryCompletedEvent(
|
||||
query=task,
|
||||
query=query,
|
||||
results=results,
|
||||
limit=latest_n,
|
||||
limit=limit,
|
||||
query_time_ms=(time.time() - start_time) * 1000,
|
||||
source_type="long_term_memory",
|
||||
from_agent=self.agent,
|
||||
@@ -131,26 +145,32 @@ class LongTermMemory(Memory):
|
||||
crewai_event_bus.emit(
|
||||
self,
|
||||
event=MemoryQueryFailedEvent(
|
||||
query=task,
|
||||
limit=latest_n,
|
||||
query=query,
|
||||
limit=limit,
|
||||
error=str(e),
|
||||
source_type="long_term_memory",
|
||||
),
|
||||
)
|
||||
raise
|
||||
|
||||
async def asave(self, item: LongTermMemoryItem) -> None: # type: ignore[override]
|
||||
async def asave(
|
||||
self,
|
||||
value: LongTermMemoryItem,
|
||||
metadata: dict[str, Any] | None = None,
|
||||
) -> None:
|
||||
"""Save an item to long-term memory asynchronously.
|
||||
|
||||
Args:
|
||||
item: The LongTermMemoryItem to save.
|
||||
value: The LongTermMemoryItem to save.
|
||||
metadata: Optional metadata dict (not used, metadata is extracted from the
|
||||
LongTermMemoryItem). Included for supertype compatibility.
|
||||
"""
|
||||
crewai_event_bus.emit(
|
||||
self,
|
||||
event=MemorySaveStartedEvent(
|
||||
value=item.task,
|
||||
metadata=item.metadata,
|
||||
agent_role=item.agent,
|
||||
value=value.task,
|
||||
metadata=value.metadata,
|
||||
agent_role=value.agent,
|
||||
source_type="long_term_memory",
|
||||
from_agent=self.agent,
|
||||
from_task=self.task,
|
||||
@@ -159,23 +179,23 @@ class LongTermMemory(Memory):
|
||||
|
||||
start_time = time.time()
|
||||
try:
|
||||
metadata = item.metadata
|
||||
metadata.update(
|
||||
{"agent": item.agent, "expected_output": item.expected_output}
|
||||
item_metadata = value.metadata
|
||||
item_metadata.update(
|
||||
{"agent": value.agent, "expected_output": value.expected_output}
|
||||
)
|
||||
await self.storage.asave(
|
||||
task_description=item.task,
|
||||
score=metadata["quality"],
|
||||
metadata=metadata,
|
||||
datetime=item.datetime,
|
||||
task_description=value.task,
|
||||
score=item_metadata["quality"],
|
||||
metadata=item_metadata,
|
||||
datetime=value.datetime,
|
||||
)
|
||||
|
||||
crewai_event_bus.emit(
|
||||
self,
|
||||
event=MemorySaveCompletedEvent(
|
||||
value=item.task,
|
||||
metadata=item.metadata,
|
||||
agent_role=item.agent,
|
||||
value=value.task,
|
||||
metadata=value.metadata,
|
||||
agent_role=value.agent,
|
||||
save_time_ms=(time.time() - start_time) * 1000,
|
||||
source_type="long_term_memory",
|
||||
from_agent=self.agent,
|
||||
@@ -186,25 +206,28 @@ class LongTermMemory(Memory):
|
||||
crewai_event_bus.emit(
|
||||
self,
|
||||
event=MemorySaveFailedEvent(
|
||||
value=item.task,
|
||||
metadata=item.metadata,
|
||||
agent_role=item.agent,
|
||||
value=value.task,
|
||||
metadata=value.metadata,
|
||||
agent_role=value.agent,
|
||||
error=str(e),
|
||||
source_type="long_term_memory",
|
||||
),
|
||||
)
|
||||
raise
|
||||
|
||||
async def asearch( # type: ignore[override]
|
||||
async def asearch(
|
||||
self,
|
||||
task: str,
|
||||
latest_n: int = 3,
|
||||
query: str,
|
||||
limit: int = 3,
|
||||
score_threshold: float = 0.6,
|
||||
) -> list[dict[str, Any]]:
|
||||
"""Search long-term memory asynchronously.
|
||||
|
||||
Args:
|
||||
task: The task description to search for.
|
||||
latest_n: Maximum number of results to return.
|
||||
query: The task description to search for.
|
||||
limit: Maximum number of results to return.
|
||||
score_threshold: Minimum similarity score for results (not used for
|
||||
long-term memory, included for supertype compatibility).
|
||||
|
||||
Returns:
|
||||
List of matching memory entries.
|
||||
@@ -212,8 +235,8 @@ class LongTermMemory(Memory):
|
||||
crewai_event_bus.emit(
|
||||
self,
|
||||
event=MemoryQueryStartedEvent(
|
||||
query=task,
|
||||
limit=latest_n,
|
||||
query=query,
|
||||
limit=limit,
|
||||
source_type="long_term_memory",
|
||||
from_agent=self.agent,
|
||||
from_task=self.task,
|
||||
@@ -222,14 +245,14 @@ class LongTermMemory(Memory):
|
||||
|
||||
start_time = time.time()
|
||||
try:
|
||||
results = await self.storage.aload(task, latest_n)
|
||||
results = await self.storage.aload(query, limit)
|
||||
|
||||
crewai_event_bus.emit(
|
||||
self,
|
||||
event=MemoryQueryCompletedEvent(
|
||||
query=task,
|
||||
query=query,
|
||||
results=results,
|
||||
limit=latest_n,
|
||||
limit=limit,
|
||||
query_time_ms=(time.time() - start_time) * 1000,
|
||||
source_type="long_term_memory",
|
||||
from_agent=self.agent,
|
||||
@@ -242,8 +265,8 @@ class LongTermMemory(Memory):
|
||||
crewai_event_bus.emit(
|
||||
self,
|
||||
event=MemoryQueryFailedEvent(
|
||||
query=task,
|
||||
limit=latest_n,
|
||||
query=query,
|
||||
limit=limit,
|
||||
error=str(e),
|
||||
source_type="long_term_memory",
|
||||
),
|
||||
|
||||
@@ -1,5 +1,7 @@
|
||||
import inspect
|
||||
import threading
|
||||
from collections import defaultdict
|
||||
from typing import Any
|
||||
from unittest.mock import ANY
|
||||
|
||||
import pytest
|
||||
@@ -13,6 +15,7 @@ from crewai.events.types.memory_events import (
|
||||
)
|
||||
from crewai.memory.long_term.long_term_memory import LongTermMemory
|
||||
from crewai.memory.long_term.long_term_memory_item import LongTermMemoryItem
|
||||
from crewai.memory.memory import Memory
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
@@ -114,7 +117,7 @@ def test_long_term_memory_search_events(long_term_memory):
|
||||
|
||||
test_query = "test query"
|
||||
|
||||
long_term_memory.search(test_query, latest_n=5)
|
||||
long_term_memory.search(test_query, limit=5)
|
||||
|
||||
with condition:
|
||||
success = condition.wait_for(
|
||||
@@ -174,10 +177,104 @@ def test_save_and_search(long_term_memory):
|
||||
metadata={"task": "test_task", "quality": 0.5},
|
||||
)
|
||||
long_term_memory.save(memory)
|
||||
find = long_term_memory.search("test_task", latest_n=5)[0]
|
||||
find = long_term_memory.search("test_task", limit=5)[0]
|
||||
assert find["score"] == 0.5
|
||||
assert find["datetime"] == "test_datetime"
|
||||
assert find["metadata"]["agent"] == "test_agent"
|
||||
assert find["metadata"]["quality"] == 0.5
|
||||
assert find["metadata"]["task"] == "test_task"
|
||||
assert find["metadata"]["expected_output"] == "test_output"
|
||||
|
||||
|
||||
class TestLongTermMemoryTypeSignatureCompatibility:
|
||||
"""Tests to verify LongTermMemory method signatures are compatible with Memory base class.
|
||||
|
||||
These tests ensure that the Liskov Substitution Principle is maintained and that
|
||||
LongTermMemory can be used polymorphically wherever Memory is expected.
|
||||
"""
|
||||
|
||||
def test_save_signature_has_value_parameter(self):
|
||||
"""Test that save() uses 'value' parameter name matching Memory base class."""
|
||||
sig = inspect.signature(LongTermMemory.save)
|
||||
params = list(sig.parameters.keys())
|
||||
assert "value" in params, "save() should have 'value' parameter for LSP compliance"
|
||||
assert "metadata" in params, "save() should have 'metadata' parameter for LSP compliance"
|
||||
|
||||
def test_save_signature_has_metadata_with_default(self):
|
||||
"""Test that save() has metadata parameter with default value."""
|
||||
sig = inspect.signature(LongTermMemory.save)
|
||||
metadata_param = sig.parameters.get("metadata")
|
||||
assert metadata_param is not None, "save() should have 'metadata' parameter"
|
||||
assert metadata_param.default is None, "metadata should default to None"
|
||||
|
||||
def test_search_signature_has_query_parameter(self):
|
||||
"""Test that search() uses 'query' parameter name matching Memory base class."""
|
||||
sig = inspect.signature(LongTermMemory.search)
|
||||
params = list(sig.parameters.keys())
|
||||
assert "query" in params, "search() should have 'query' parameter for LSP compliance"
|
||||
assert "limit" in params, "search() should have 'limit' parameter for LSP compliance"
|
||||
assert "score_threshold" in params, "search() should have 'score_threshold' parameter for LSP compliance"
|
||||
|
||||
def test_search_signature_has_score_threshold_with_default(self):
|
||||
"""Test that search() has score_threshold parameter with default value."""
|
||||
sig = inspect.signature(LongTermMemory.search)
|
||||
score_threshold_param = sig.parameters.get("score_threshold")
|
||||
assert score_threshold_param is not None, "search() should have 'score_threshold' parameter"
|
||||
assert score_threshold_param.default == 0.6, "score_threshold should default to 0.6"
|
||||
|
||||
def test_asave_signature_has_value_parameter(self):
|
||||
"""Test that asave() uses 'value' parameter name matching Memory base class."""
|
||||
sig = inspect.signature(LongTermMemory.asave)
|
||||
params = list(sig.parameters.keys())
|
||||
assert "value" in params, "asave() should have 'value' parameter for LSP compliance"
|
||||
assert "metadata" in params, "asave() should have 'metadata' parameter for LSP compliance"
|
||||
|
||||
def test_asearch_signature_has_query_parameter(self):
|
||||
"""Test that asearch() uses 'query' parameter name matching Memory base class."""
|
||||
sig = inspect.signature(LongTermMemory.asearch)
|
||||
params = list(sig.parameters.keys())
|
||||
assert "query" in params, "asearch() should have 'query' parameter for LSP compliance"
|
||||
assert "limit" in params, "asearch() should have 'limit' parameter for LSP compliance"
|
||||
assert "score_threshold" in params, "asearch() should have 'score_threshold' parameter for LSP compliance"
|
||||
|
||||
def test_long_term_memory_is_subclass_of_memory(self):
|
||||
"""Test that LongTermMemory is a proper subclass of Memory."""
|
||||
assert issubclass(LongTermMemory, Memory), "LongTermMemory should be a subclass of Memory"
|
||||
|
||||
def test_save_with_metadata_parameter(self, long_term_memory):
|
||||
"""Test that save() can be called with the metadata parameter (even if unused)."""
|
||||
memory_item = LongTermMemoryItem(
|
||||
agent="test_agent",
|
||||
task="test_task_with_metadata",
|
||||
expected_output="test_output",
|
||||
datetime="test_datetime",
|
||||
quality=0.8,
|
||||
metadata={"task": "test_task_with_metadata", "quality": 0.8},
|
||||
)
|
||||
long_term_memory.save(value=memory_item, metadata={"extra": "data"})
|
||||
results = long_term_memory.search(query="test_task_with_metadata", limit=1)
|
||||
assert len(results) > 0
|
||||
assert results[0]["metadata"]["agent"] == "test_agent"
|
||||
|
||||
def test_search_with_score_threshold_parameter(self, long_term_memory):
|
||||
"""Test that search() can be called with the score_threshold parameter."""
|
||||
memory_item = LongTermMemoryItem(
|
||||
agent="test_agent",
|
||||
task="test_task_score_threshold",
|
||||
expected_output="test_output",
|
||||
datetime="test_datetime",
|
||||
quality=0.9,
|
||||
metadata={"task": "test_task_score_threshold", "quality": 0.9},
|
||||
)
|
||||
long_term_memory.save(value=memory_item)
|
||||
results = long_term_memory.search(
|
||||
query="test_task_score_threshold",
|
||||
limit=5,
|
||||
score_threshold=0.5,
|
||||
)
|
||||
assert isinstance(results, list)
|
||||
|
||||
@pytest.fixture
|
||||
def long_term_memory(self):
|
||||
"""Fixture to create a LongTermMemory instance for this test class."""
|
||||
return LongTermMemory()
|
||||
|
||||
Reference in New Issue
Block a user