mirror of
https://github.com/crewAIInc/crewAI.git
synced 2026-01-09 08:08:32 +00:00
feat: async memory support
Adds async support for tools with tests, async execution in the agent executor, and async operations for memory (with aiosqlite). Improves tool decorator typing, ensures _run backward compatibility, updates docs and docstrings, adds tests, and regenerates lockfiles.
This commit is contained in:
496
lib/crewai/tests/memory/test_async_memory.py
Normal file
496
lib/crewai/tests/memory/test_async_memory.py
Normal file
@@ -0,0 +1,496 @@
|
||||
"""Tests for async memory operations."""
|
||||
|
||||
import threading
|
||||
from collections import defaultdict
|
||||
from unittest.mock import ANY, AsyncMock, MagicMock, patch
|
||||
|
||||
import pytest
|
||||
|
||||
from crewai.agent import Agent
|
||||
from crewai.crew import Crew
|
||||
from crewai.events.event_bus import crewai_event_bus
|
||||
from crewai.events.types.memory_events import (
|
||||
MemoryQueryCompletedEvent,
|
||||
MemoryQueryStartedEvent,
|
||||
MemorySaveCompletedEvent,
|
||||
MemorySaveStartedEvent,
|
||||
)
|
||||
from crewai.memory.contextual.contextual_memory import ContextualMemory
|
||||
from crewai.memory.entity.entity_memory import EntityMemory
|
||||
from crewai.memory.entity.entity_memory_item import EntityMemoryItem
|
||||
from crewai.memory.external.external_memory import ExternalMemory
|
||||
from crewai.memory.long_term.long_term_memory import LongTermMemory
|
||||
from crewai.memory.long_term.long_term_memory_item import LongTermMemoryItem
|
||||
from crewai.memory.short_term.short_term_memory import ShortTermMemory
|
||||
from crewai.task import Task
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_agent():
|
||||
"""Fixture to create a mock agent."""
|
||||
return Agent(
|
||||
role="Researcher",
|
||||
goal="Search relevant data and provide results",
|
||||
backstory="You are a researcher at a leading tech think tank.",
|
||||
tools=[],
|
||||
verbose=True,
|
||||
)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_task(mock_agent):
|
||||
"""Fixture to create a mock task."""
|
||||
return Task(
|
||||
description="Perform a search on specific topics.",
|
||||
expected_output="A list of relevant URLs based on the search query.",
|
||||
agent=mock_agent,
|
||||
)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def short_term_memory(mock_agent, mock_task):
|
||||
"""Fixture to create a ShortTermMemory instance."""
|
||||
return ShortTermMemory(crew=Crew(agents=[mock_agent], tasks=[mock_task]))
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def long_term_memory(tmp_path):
|
||||
"""Fixture to create a LongTermMemory instance."""
|
||||
db_path = str(tmp_path / "test_ltm.db")
|
||||
return LongTermMemory(path=db_path)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def entity_memory(tmp_path, mock_agent, mock_task):
|
||||
"""Fixture to create an EntityMemory instance."""
|
||||
return EntityMemory(
|
||||
crew=Crew(agents=[mock_agent], tasks=[mock_task]),
|
||||
path=str(tmp_path / "test_entities"),
|
||||
)
|
||||
|
||||
|
||||
class TestAsyncShortTermMemory:
|
||||
"""Tests for async ShortTermMemory operations."""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_asave_emits_events(self, short_term_memory):
|
||||
"""Test that asave emits the correct events."""
|
||||
events: dict[str, list] = defaultdict(list)
|
||||
condition = threading.Condition()
|
||||
|
||||
@crewai_event_bus.on(MemorySaveStartedEvent)
|
||||
def on_save_started(source, event):
|
||||
with condition:
|
||||
events["MemorySaveStartedEvent"].append(event)
|
||||
condition.notify()
|
||||
|
||||
@crewai_event_bus.on(MemorySaveCompletedEvent)
|
||||
def on_save_completed(source, event):
|
||||
with condition:
|
||||
events["MemorySaveCompletedEvent"].append(event)
|
||||
condition.notify()
|
||||
|
||||
await short_term_memory.asave(
|
||||
value="async test value",
|
||||
metadata={"task": "async_test_task"},
|
||||
)
|
||||
|
||||
with condition:
|
||||
success = condition.wait_for(
|
||||
lambda: len(events["MemorySaveStartedEvent"]) >= 1
|
||||
and len(events["MemorySaveCompletedEvent"]) >= 1,
|
||||
timeout=5,
|
||||
)
|
||||
assert success, "Timeout waiting for async save events"
|
||||
|
||||
assert len(events["MemorySaveStartedEvent"]) >= 1
|
||||
assert len(events["MemorySaveCompletedEvent"]) >= 1
|
||||
assert events["MemorySaveStartedEvent"][-1].value == "async test value"
|
||||
assert events["MemorySaveStartedEvent"][-1].source_type == "short_term_memory"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_asearch_emits_events(self, short_term_memory):
|
||||
"""Test that asearch emits the correct events."""
|
||||
events: dict[str, list] = defaultdict(list)
|
||||
search_started = threading.Event()
|
||||
search_completed = threading.Event()
|
||||
|
||||
with patch.object(short_term_memory.storage, "asearch", new_callable=AsyncMock, return_value=[]):
|
||||
|
||||
@crewai_event_bus.on(MemoryQueryStartedEvent)
|
||||
def on_search_started(source, event):
|
||||
events["MemoryQueryStartedEvent"].append(event)
|
||||
search_started.set()
|
||||
|
||||
@crewai_event_bus.on(MemoryQueryCompletedEvent)
|
||||
def on_search_completed(source, event):
|
||||
events["MemoryQueryCompletedEvent"].append(event)
|
||||
search_completed.set()
|
||||
|
||||
await short_term_memory.asearch(
|
||||
query="async test query",
|
||||
limit=3,
|
||||
score_threshold=0.35,
|
||||
)
|
||||
|
||||
assert search_started.wait(timeout=2), "Timeout waiting for search started event"
|
||||
assert search_completed.wait(timeout=2), "Timeout waiting for search completed event"
|
||||
|
||||
assert len(events["MemoryQueryStartedEvent"]) >= 1
|
||||
assert len(events["MemoryQueryCompletedEvent"]) >= 1
|
||||
assert events["MemoryQueryStartedEvent"][-1].query == "async test query"
|
||||
assert events["MemoryQueryStartedEvent"][-1].source_type == "short_term_memory"
|
||||
|
||||
|
||||
class TestAsyncLongTermMemory:
|
||||
"""Tests for async LongTermMemory operations."""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_asave_emits_events(self, long_term_memory):
|
||||
"""Test that asave emits the correct events."""
|
||||
events: dict[str, list] = defaultdict(list)
|
||||
condition = threading.Condition()
|
||||
|
||||
@crewai_event_bus.on(MemorySaveStartedEvent)
|
||||
def on_save_started(source, event):
|
||||
with condition:
|
||||
events["MemorySaveStartedEvent"].append(event)
|
||||
condition.notify()
|
||||
|
||||
@crewai_event_bus.on(MemorySaveCompletedEvent)
|
||||
def on_save_completed(source, event):
|
||||
with condition:
|
||||
events["MemorySaveCompletedEvent"].append(event)
|
||||
condition.notify()
|
||||
|
||||
item = LongTermMemoryItem(
|
||||
task="async test task",
|
||||
agent="test_agent",
|
||||
expected_output="test output",
|
||||
datetime="2024-01-01T00:00:00",
|
||||
quality=0.9,
|
||||
metadata={"task": "async test task", "quality": 0.9},
|
||||
)
|
||||
|
||||
await long_term_memory.asave(item)
|
||||
|
||||
with condition:
|
||||
success = condition.wait_for(
|
||||
lambda: len(events["MemorySaveStartedEvent"]) >= 1
|
||||
and len(events["MemorySaveCompletedEvent"]) >= 1,
|
||||
timeout=5,
|
||||
)
|
||||
assert success, "Timeout waiting for async save events"
|
||||
|
||||
assert len(events["MemorySaveStartedEvent"]) >= 1
|
||||
assert len(events["MemorySaveCompletedEvent"]) >= 1
|
||||
assert events["MemorySaveStartedEvent"][-1].source_type == "long_term_memory"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_asearch_emits_events(self, long_term_memory):
|
||||
"""Test that asearch emits the correct events."""
|
||||
events: dict[str, list] = defaultdict(list)
|
||||
search_started = threading.Event()
|
||||
search_completed = threading.Event()
|
||||
|
||||
@crewai_event_bus.on(MemoryQueryStartedEvent)
|
||||
def on_search_started(source, event):
|
||||
events["MemoryQueryStartedEvent"].append(event)
|
||||
search_started.set()
|
||||
|
||||
@crewai_event_bus.on(MemoryQueryCompletedEvent)
|
||||
def on_search_completed(source, event):
|
||||
events["MemoryQueryCompletedEvent"].append(event)
|
||||
search_completed.set()
|
||||
|
||||
await long_term_memory.asearch(task="async test task", latest_n=3)
|
||||
|
||||
assert search_started.wait(timeout=2), "Timeout waiting for search started event"
|
||||
assert search_completed.wait(timeout=2), "Timeout waiting for search completed event"
|
||||
|
||||
assert len(events["MemoryQueryStartedEvent"]) >= 1
|
||||
assert len(events["MemoryQueryCompletedEvent"]) >= 1
|
||||
assert events["MemoryQueryStartedEvent"][-1].source_type == "long_term_memory"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_asave_and_asearch_integration(self, long_term_memory):
|
||||
"""Test that asave followed by asearch works correctly."""
|
||||
item = LongTermMemoryItem(
|
||||
task="integration test task",
|
||||
agent="test_agent",
|
||||
expected_output="test output",
|
||||
datetime="2024-01-01T00:00:00",
|
||||
quality=0.9,
|
||||
metadata={"task": "integration test task", "quality": 0.9},
|
||||
)
|
||||
|
||||
await long_term_memory.asave(item)
|
||||
results = await long_term_memory.asearch(task="integration test task", latest_n=1)
|
||||
|
||||
assert results is not None
|
||||
assert len(results) == 1
|
||||
assert results[0]["metadata"]["agent"] == "test_agent"
|
||||
|
||||
|
||||
class TestAsyncEntityMemory:
|
||||
"""Tests for async EntityMemory operations."""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_asave_single_item_emits_events(self, entity_memory):
|
||||
"""Test that asave with a single item emits the correct events."""
|
||||
events: dict[str, list] = defaultdict(list)
|
||||
condition = threading.Condition()
|
||||
|
||||
@crewai_event_bus.on(MemorySaveStartedEvent)
|
||||
def on_save_started(source, event):
|
||||
with condition:
|
||||
events["MemorySaveStartedEvent"].append(event)
|
||||
condition.notify()
|
||||
|
||||
@crewai_event_bus.on(MemorySaveCompletedEvent)
|
||||
def on_save_completed(source, event):
|
||||
with condition:
|
||||
events["MemorySaveCompletedEvent"].append(event)
|
||||
condition.notify()
|
||||
|
||||
item = EntityMemoryItem(
|
||||
name="TestEntity",
|
||||
type="Person",
|
||||
description="A test entity for async operations",
|
||||
relationships="Related to other test entities",
|
||||
)
|
||||
|
||||
await entity_memory.asave(item)
|
||||
|
||||
with condition:
|
||||
success = condition.wait_for(
|
||||
lambda: len(events["MemorySaveStartedEvent"]) >= 1
|
||||
and len(events["MemorySaveCompletedEvent"]) >= 1,
|
||||
timeout=5,
|
||||
)
|
||||
assert success, "Timeout waiting for async save events"
|
||||
|
||||
assert len(events["MemorySaveStartedEvent"]) >= 1
|
||||
assert len(events["MemorySaveCompletedEvent"]) >= 1
|
||||
assert events["MemorySaveStartedEvent"][-1].source_type == "entity_memory"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_asearch_emits_events(self, entity_memory):
|
||||
"""Test that asearch emits the correct events."""
|
||||
events: dict[str, list] = defaultdict(list)
|
||||
search_started = threading.Event()
|
||||
search_completed = threading.Event()
|
||||
|
||||
@crewai_event_bus.on(MemoryQueryStartedEvent)
|
||||
def on_search_started(source, event):
|
||||
events["MemoryQueryStartedEvent"].append(event)
|
||||
search_started.set()
|
||||
|
||||
@crewai_event_bus.on(MemoryQueryCompletedEvent)
|
||||
def on_search_completed(source, event):
|
||||
events["MemoryQueryCompletedEvent"].append(event)
|
||||
search_completed.set()
|
||||
|
||||
await entity_memory.asearch(query="TestEntity", limit=5, score_threshold=0.6)
|
||||
|
||||
assert search_started.wait(timeout=2), "Timeout waiting for search started event"
|
||||
assert search_completed.wait(timeout=2), "Timeout waiting for search completed event"
|
||||
|
||||
assert len(events["MemoryQueryStartedEvent"]) >= 1
|
||||
assert len(events["MemoryQueryCompletedEvent"]) >= 1
|
||||
assert events["MemoryQueryStartedEvent"][-1].source_type == "entity_memory"
|
||||
|
||||
|
||||
class TestAsyncContextualMemory:
|
||||
"""Tests for async ContextualMemory operations."""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_abuild_context_for_task_with_empty_query(self, mock_task):
|
||||
"""Test that abuild_context_for_task returns empty string for empty query."""
|
||||
mock_task.description = ""
|
||||
contextual_memory = ContextualMemory(
|
||||
stm=None,
|
||||
ltm=None,
|
||||
em=None,
|
||||
exm=None,
|
||||
)
|
||||
|
||||
result = await contextual_memory.abuild_context_for_task(mock_task, "")
|
||||
assert result == ""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_abuild_context_for_task_with_none_memories(self, mock_task):
|
||||
"""Test that abuild_context_for_task handles None memory sources."""
|
||||
contextual_memory = ContextualMemory(
|
||||
stm=None,
|
||||
ltm=None,
|
||||
em=None,
|
||||
exm=None,
|
||||
)
|
||||
|
||||
result = await contextual_memory.abuild_context_for_task(mock_task, "some context")
|
||||
assert result == ""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_abuild_context_for_task_aggregates_results(self, mock_agent, mock_task):
|
||||
"""Test that abuild_context_for_task aggregates results from all memory sources."""
|
||||
mock_stm = MagicMock(spec=ShortTermMemory)
|
||||
mock_stm.asearch = AsyncMock(return_value=[{"content": "STM insight"}])
|
||||
|
||||
mock_ltm = MagicMock(spec=LongTermMemory)
|
||||
mock_ltm.asearch = AsyncMock(
|
||||
return_value=[{"metadata": {"suggestions": ["LTM suggestion"]}}]
|
||||
)
|
||||
|
||||
mock_em = MagicMock(spec=EntityMemory)
|
||||
mock_em.asearch = AsyncMock(return_value=[{"content": "Entity info"}])
|
||||
|
||||
mock_exm = MagicMock(spec=ExternalMemory)
|
||||
mock_exm.asearch = AsyncMock(return_value=[{"content": "External memory"}])
|
||||
|
||||
contextual_memory = ContextualMemory(
|
||||
stm=mock_stm,
|
||||
ltm=mock_ltm,
|
||||
em=mock_em,
|
||||
exm=mock_exm,
|
||||
agent=mock_agent,
|
||||
task=mock_task,
|
||||
)
|
||||
|
||||
result = await contextual_memory.abuild_context_for_task(mock_task, "additional context")
|
||||
|
||||
assert "Recent Insights:" in result
|
||||
assert "STM insight" in result
|
||||
assert "Historical Data:" in result
|
||||
assert "LTM suggestion" in result
|
||||
assert "Entities:" in result
|
||||
assert "Entity info" in result
|
||||
assert "External memories:" in result
|
||||
assert "External memory" in result
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_afetch_stm_context_returns_formatted_results(self, mock_agent, mock_task):
|
||||
"""Test that _afetch_stm_context returns properly formatted results."""
|
||||
mock_stm = MagicMock(spec=ShortTermMemory)
|
||||
mock_stm.asearch = AsyncMock(
|
||||
return_value=[
|
||||
{"content": "First insight"},
|
||||
{"content": "Second insight"},
|
||||
]
|
||||
)
|
||||
|
||||
contextual_memory = ContextualMemory(
|
||||
stm=mock_stm,
|
||||
ltm=None,
|
||||
em=None,
|
||||
exm=None,
|
||||
)
|
||||
|
||||
result = await contextual_memory._afetch_stm_context("test query")
|
||||
|
||||
assert "Recent Insights:" in result
|
||||
assert "- First insight" in result
|
||||
assert "- Second insight" in result
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_afetch_ltm_context_returns_formatted_results(self, mock_agent, mock_task):
|
||||
"""Test that _afetch_ltm_context returns properly formatted results."""
|
||||
mock_ltm = MagicMock(spec=LongTermMemory)
|
||||
mock_ltm.asearch = AsyncMock(
|
||||
return_value=[
|
||||
{"metadata": {"suggestions": ["Suggestion 1", "Suggestion 2"]}},
|
||||
]
|
||||
)
|
||||
|
||||
contextual_memory = ContextualMemory(
|
||||
stm=None,
|
||||
ltm=mock_ltm,
|
||||
em=None,
|
||||
exm=None,
|
||||
)
|
||||
|
||||
result = await contextual_memory._afetch_ltm_context("test task")
|
||||
|
||||
assert "Historical Data:" in result
|
||||
assert "- Suggestion 1" in result
|
||||
assert "- Suggestion 2" in result
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_afetch_entity_context_returns_formatted_results(self, mock_agent, mock_task):
|
||||
"""Test that _afetch_entity_context returns properly formatted results."""
|
||||
mock_em = MagicMock(spec=EntityMemory)
|
||||
mock_em.asearch = AsyncMock(
|
||||
return_value=[
|
||||
{"content": "Entity A details"},
|
||||
{"content": "Entity B details"},
|
||||
]
|
||||
)
|
||||
|
||||
contextual_memory = ContextualMemory(
|
||||
stm=None,
|
||||
ltm=None,
|
||||
em=mock_em,
|
||||
exm=None,
|
||||
)
|
||||
|
||||
result = await contextual_memory._afetch_entity_context("test query")
|
||||
|
||||
assert "Entities:" in result
|
||||
assert "- Entity A details" in result
|
||||
assert "- Entity B details" in result
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_afetch_external_context_returns_formatted_results(self):
|
||||
"""Test that _afetch_external_context returns properly formatted results."""
|
||||
mock_exm = MagicMock(spec=ExternalMemory)
|
||||
mock_exm.asearch = AsyncMock(
|
||||
return_value=[
|
||||
{"content": "External data 1"},
|
||||
{"content": "External data 2"},
|
||||
]
|
||||
)
|
||||
|
||||
contextual_memory = ContextualMemory(
|
||||
stm=None,
|
||||
ltm=None,
|
||||
em=None,
|
||||
exm=mock_exm,
|
||||
)
|
||||
|
||||
result = await contextual_memory._afetch_external_context("test query")
|
||||
|
||||
assert "External memories:" in result
|
||||
assert "- External data 1" in result
|
||||
assert "- External data 2" in result
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_afetch_methods_return_empty_for_empty_results(self):
|
||||
"""Test that async fetch methods return empty string for no results."""
|
||||
mock_stm = MagicMock(spec=ShortTermMemory)
|
||||
mock_stm.asearch = AsyncMock(return_value=[])
|
||||
|
||||
mock_ltm = MagicMock(spec=LongTermMemory)
|
||||
mock_ltm.asearch = AsyncMock(return_value=[])
|
||||
|
||||
mock_em = MagicMock(spec=EntityMemory)
|
||||
mock_em.asearch = AsyncMock(return_value=[])
|
||||
|
||||
mock_exm = MagicMock(spec=ExternalMemory)
|
||||
mock_exm.asearch = AsyncMock(return_value=[])
|
||||
|
||||
contextual_memory = ContextualMemory(
|
||||
stm=mock_stm,
|
||||
ltm=mock_ltm,
|
||||
em=mock_em,
|
||||
exm=mock_exm,
|
||||
)
|
||||
|
||||
stm_result = await contextual_memory._afetch_stm_context("query")
|
||||
ltm_result = await contextual_memory._afetch_ltm_context("task")
|
||||
em_result = await contextual_memory._afetch_entity_context("query")
|
||||
exm_result = await contextual_memory._afetch_external_context("query")
|
||||
|
||||
assert stm_result == ""
|
||||
assert ltm_result is None
|
||||
assert em_result == ""
|
||||
assert exm_result == ""
|
||||
Reference in New Issue
Block a user