mirror of
https://github.com/crewAIInc/crewAI.git
synced 2026-01-08 07:38:29 +00:00
Adds async support for tools with tests, async execution in the agent executor, and async operations for memory (with aiosqlite). Improves tool decorator typing, ensures _run backward compatibility, updates docs and docstrings, adds tests, and regenerates lockfiles.
496 lines
18 KiB
Python
496 lines
18 KiB
Python
"""Tests for async memory operations."""
|
|
|
|
import threading
|
|
from collections import defaultdict
|
|
from unittest.mock import ANY, AsyncMock, MagicMock, patch
|
|
|
|
import pytest
|
|
|
|
from crewai.agent import Agent
|
|
from crewai.crew import Crew
|
|
from crewai.events.event_bus import crewai_event_bus
|
|
from crewai.events.types.memory_events import (
|
|
MemoryQueryCompletedEvent,
|
|
MemoryQueryStartedEvent,
|
|
MemorySaveCompletedEvent,
|
|
MemorySaveStartedEvent,
|
|
)
|
|
from crewai.memory.contextual.contextual_memory import ContextualMemory
|
|
from crewai.memory.entity.entity_memory import EntityMemory
|
|
from crewai.memory.entity.entity_memory_item import EntityMemoryItem
|
|
from crewai.memory.external.external_memory import ExternalMemory
|
|
from crewai.memory.long_term.long_term_memory import LongTermMemory
|
|
from crewai.memory.long_term.long_term_memory_item import LongTermMemoryItem
|
|
from crewai.memory.short_term.short_term_memory import ShortTermMemory
|
|
from crewai.task import Task
|
|
|
|
|
|
@pytest.fixture
|
|
def mock_agent():
|
|
"""Fixture to create a mock agent."""
|
|
return Agent(
|
|
role="Researcher",
|
|
goal="Search relevant data and provide results",
|
|
backstory="You are a researcher at a leading tech think tank.",
|
|
tools=[],
|
|
verbose=True,
|
|
)
|
|
|
|
|
|
@pytest.fixture
|
|
def mock_task(mock_agent):
|
|
"""Fixture to create a mock task."""
|
|
return Task(
|
|
description="Perform a search on specific topics.",
|
|
expected_output="A list of relevant URLs based on the search query.",
|
|
agent=mock_agent,
|
|
)
|
|
|
|
|
|
@pytest.fixture
|
|
def short_term_memory(mock_agent, mock_task):
|
|
"""Fixture to create a ShortTermMemory instance."""
|
|
return ShortTermMemory(crew=Crew(agents=[mock_agent], tasks=[mock_task]))
|
|
|
|
|
|
@pytest.fixture
|
|
def long_term_memory(tmp_path):
|
|
"""Fixture to create a LongTermMemory instance."""
|
|
db_path = str(tmp_path / "test_ltm.db")
|
|
return LongTermMemory(path=db_path)
|
|
|
|
|
|
@pytest.fixture
|
|
def entity_memory(tmp_path, mock_agent, mock_task):
|
|
"""Fixture to create an EntityMemory instance."""
|
|
return EntityMemory(
|
|
crew=Crew(agents=[mock_agent], tasks=[mock_task]),
|
|
path=str(tmp_path / "test_entities"),
|
|
)
|
|
|
|
|
|
class TestAsyncShortTermMemory:
|
|
"""Tests for async ShortTermMemory operations."""
|
|
|
|
@pytest.mark.asyncio
|
|
async def test_asave_emits_events(self, short_term_memory):
|
|
"""Test that asave emits the correct events."""
|
|
events: dict[str, list] = defaultdict(list)
|
|
condition = threading.Condition()
|
|
|
|
@crewai_event_bus.on(MemorySaveStartedEvent)
|
|
def on_save_started(source, event):
|
|
with condition:
|
|
events["MemorySaveStartedEvent"].append(event)
|
|
condition.notify()
|
|
|
|
@crewai_event_bus.on(MemorySaveCompletedEvent)
|
|
def on_save_completed(source, event):
|
|
with condition:
|
|
events["MemorySaveCompletedEvent"].append(event)
|
|
condition.notify()
|
|
|
|
await short_term_memory.asave(
|
|
value="async test value",
|
|
metadata={"task": "async_test_task"},
|
|
)
|
|
|
|
with condition:
|
|
success = condition.wait_for(
|
|
lambda: len(events["MemorySaveStartedEvent"]) >= 1
|
|
and len(events["MemorySaveCompletedEvent"]) >= 1,
|
|
timeout=5,
|
|
)
|
|
assert success, "Timeout waiting for async save events"
|
|
|
|
assert len(events["MemorySaveStartedEvent"]) >= 1
|
|
assert len(events["MemorySaveCompletedEvent"]) >= 1
|
|
assert events["MemorySaveStartedEvent"][-1].value == "async test value"
|
|
assert events["MemorySaveStartedEvent"][-1].source_type == "short_term_memory"
|
|
|
|
@pytest.mark.asyncio
|
|
async def test_asearch_emits_events(self, short_term_memory):
|
|
"""Test that asearch emits the correct events."""
|
|
events: dict[str, list] = defaultdict(list)
|
|
search_started = threading.Event()
|
|
search_completed = threading.Event()
|
|
|
|
with patch.object(short_term_memory.storage, "asearch", new_callable=AsyncMock, return_value=[]):
|
|
|
|
@crewai_event_bus.on(MemoryQueryStartedEvent)
|
|
def on_search_started(source, event):
|
|
events["MemoryQueryStartedEvent"].append(event)
|
|
search_started.set()
|
|
|
|
@crewai_event_bus.on(MemoryQueryCompletedEvent)
|
|
def on_search_completed(source, event):
|
|
events["MemoryQueryCompletedEvent"].append(event)
|
|
search_completed.set()
|
|
|
|
await short_term_memory.asearch(
|
|
query="async test query",
|
|
limit=3,
|
|
score_threshold=0.35,
|
|
)
|
|
|
|
assert search_started.wait(timeout=2), "Timeout waiting for search started event"
|
|
assert search_completed.wait(timeout=2), "Timeout waiting for search completed event"
|
|
|
|
assert len(events["MemoryQueryStartedEvent"]) >= 1
|
|
assert len(events["MemoryQueryCompletedEvent"]) >= 1
|
|
assert events["MemoryQueryStartedEvent"][-1].query == "async test query"
|
|
assert events["MemoryQueryStartedEvent"][-1].source_type == "short_term_memory"
|
|
|
|
|
|
class TestAsyncLongTermMemory:
|
|
"""Tests for async LongTermMemory operations."""
|
|
|
|
@pytest.mark.asyncio
|
|
async def test_asave_emits_events(self, long_term_memory):
|
|
"""Test that asave emits the correct events."""
|
|
events: dict[str, list] = defaultdict(list)
|
|
condition = threading.Condition()
|
|
|
|
@crewai_event_bus.on(MemorySaveStartedEvent)
|
|
def on_save_started(source, event):
|
|
with condition:
|
|
events["MemorySaveStartedEvent"].append(event)
|
|
condition.notify()
|
|
|
|
@crewai_event_bus.on(MemorySaveCompletedEvent)
|
|
def on_save_completed(source, event):
|
|
with condition:
|
|
events["MemorySaveCompletedEvent"].append(event)
|
|
condition.notify()
|
|
|
|
item = LongTermMemoryItem(
|
|
task="async test task",
|
|
agent="test_agent",
|
|
expected_output="test output",
|
|
datetime="2024-01-01T00:00:00",
|
|
quality=0.9,
|
|
metadata={"task": "async test task", "quality": 0.9},
|
|
)
|
|
|
|
await long_term_memory.asave(item)
|
|
|
|
with condition:
|
|
success = condition.wait_for(
|
|
lambda: len(events["MemorySaveStartedEvent"]) >= 1
|
|
and len(events["MemorySaveCompletedEvent"]) >= 1,
|
|
timeout=5,
|
|
)
|
|
assert success, "Timeout waiting for async save events"
|
|
|
|
assert len(events["MemorySaveStartedEvent"]) >= 1
|
|
assert len(events["MemorySaveCompletedEvent"]) >= 1
|
|
assert events["MemorySaveStartedEvent"][-1].source_type == "long_term_memory"
|
|
|
|
@pytest.mark.asyncio
|
|
async def test_asearch_emits_events(self, long_term_memory):
|
|
"""Test that asearch emits the correct events."""
|
|
events: dict[str, list] = defaultdict(list)
|
|
search_started = threading.Event()
|
|
search_completed = threading.Event()
|
|
|
|
@crewai_event_bus.on(MemoryQueryStartedEvent)
|
|
def on_search_started(source, event):
|
|
events["MemoryQueryStartedEvent"].append(event)
|
|
search_started.set()
|
|
|
|
@crewai_event_bus.on(MemoryQueryCompletedEvent)
|
|
def on_search_completed(source, event):
|
|
events["MemoryQueryCompletedEvent"].append(event)
|
|
search_completed.set()
|
|
|
|
await long_term_memory.asearch(task="async test task", latest_n=3)
|
|
|
|
assert search_started.wait(timeout=2), "Timeout waiting for search started event"
|
|
assert search_completed.wait(timeout=2), "Timeout waiting for search completed event"
|
|
|
|
assert len(events["MemoryQueryStartedEvent"]) >= 1
|
|
assert len(events["MemoryQueryCompletedEvent"]) >= 1
|
|
assert events["MemoryQueryStartedEvent"][-1].source_type == "long_term_memory"
|
|
|
|
@pytest.mark.asyncio
|
|
async def test_asave_and_asearch_integration(self, long_term_memory):
|
|
"""Test that asave followed by asearch works correctly."""
|
|
item = LongTermMemoryItem(
|
|
task="integration test task",
|
|
agent="test_agent",
|
|
expected_output="test output",
|
|
datetime="2024-01-01T00:00:00",
|
|
quality=0.9,
|
|
metadata={"task": "integration test task", "quality": 0.9},
|
|
)
|
|
|
|
await long_term_memory.asave(item)
|
|
results = await long_term_memory.asearch(task="integration test task", latest_n=1)
|
|
|
|
assert results is not None
|
|
assert len(results) == 1
|
|
assert results[0]["metadata"]["agent"] == "test_agent"
|
|
|
|
|
|
class TestAsyncEntityMemory:
|
|
"""Tests for async EntityMemory operations."""
|
|
|
|
@pytest.mark.asyncio
|
|
async def test_asave_single_item_emits_events(self, entity_memory):
|
|
"""Test that asave with a single item emits the correct events."""
|
|
events: dict[str, list] = defaultdict(list)
|
|
condition = threading.Condition()
|
|
|
|
@crewai_event_bus.on(MemorySaveStartedEvent)
|
|
def on_save_started(source, event):
|
|
with condition:
|
|
events["MemorySaveStartedEvent"].append(event)
|
|
condition.notify()
|
|
|
|
@crewai_event_bus.on(MemorySaveCompletedEvent)
|
|
def on_save_completed(source, event):
|
|
with condition:
|
|
events["MemorySaveCompletedEvent"].append(event)
|
|
condition.notify()
|
|
|
|
item = EntityMemoryItem(
|
|
name="TestEntity",
|
|
type="Person",
|
|
description="A test entity for async operations",
|
|
relationships="Related to other test entities",
|
|
)
|
|
|
|
await entity_memory.asave(item)
|
|
|
|
with condition:
|
|
success = condition.wait_for(
|
|
lambda: len(events["MemorySaveStartedEvent"]) >= 1
|
|
and len(events["MemorySaveCompletedEvent"]) >= 1,
|
|
timeout=5,
|
|
)
|
|
assert success, "Timeout waiting for async save events"
|
|
|
|
assert len(events["MemorySaveStartedEvent"]) >= 1
|
|
assert len(events["MemorySaveCompletedEvent"]) >= 1
|
|
assert events["MemorySaveStartedEvent"][-1].source_type == "entity_memory"
|
|
|
|
@pytest.mark.asyncio
|
|
async def test_asearch_emits_events(self, entity_memory):
|
|
"""Test that asearch emits the correct events."""
|
|
events: dict[str, list] = defaultdict(list)
|
|
search_started = threading.Event()
|
|
search_completed = threading.Event()
|
|
|
|
@crewai_event_bus.on(MemoryQueryStartedEvent)
|
|
def on_search_started(source, event):
|
|
events["MemoryQueryStartedEvent"].append(event)
|
|
search_started.set()
|
|
|
|
@crewai_event_bus.on(MemoryQueryCompletedEvent)
|
|
def on_search_completed(source, event):
|
|
events["MemoryQueryCompletedEvent"].append(event)
|
|
search_completed.set()
|
|
|
|
await entity_memory.asearch(query="TestEntity", limit=5, score_threshold=0.6)
|
|
|
|
assert search_started.wait(timeout=2), "Timeout waiting for search started event"
|
|
assert search_completed.wait(timeout=2), "Timeout waiting for search completed event"
|
|
|
|
assert len(events["MemoryQueryStartedEvent"]) >= 1
|
|
assert len(events["MemoryQueryCompletedEvent"]) >= 1
|
|
assert events["MemoryQueryStartedEvent"][-1].source_type == "entity_memory"
|
|
|
|
|
|
class TestAsyncContextualMemory:
|
|
"""Tests for async ContextualMemory operations."""
|
|
|
|
@pytest.mark.asyncio
|
|
async def test_abuild_context_for_task_with_empty_query(self, mock_task):
|
|
"""Test that abuild_context_for_task returns empty string for empty query."""
|
|
mock_task.description = ""
|
|
contextual_memory = ContextualMemory(
|
|
stm=None,
|
|
ltm=None,
|
|
em=None,
|
|
exm=None,
|
|
)
|
|
|
|
result = await contextual_memory.abuild_context_for_task(mock_task, "")
|
|
assert result == ""
|
|
|
|
@pytest.mark.asyncio
|
|
async def test_abuild_context_for_task_with_none_memories(self, mock_task):
|
|
"""Test that abuild_context_for_task handles None memory sources."""
|
|
contextual_memory = ContextualMemory(
|
|
stm=None,
|
|
ltm=None,
|
|
em=None,
|
|
exm=None,
|
|
)
|
|
|
|
result = await contextual_memory.abuild_context_for_task(mock_task, "some context")
|
|
assert result == ""
|
|
|
|
@pytest.mark.asyncio
|
|
async def test_abuild_context_for_task_aggregates_results(self, mock_agent, mock_task):
|
|
"""Test that abuild_context_for_task aggregates results from all memory sources."""
|
|
mock_stm = MagicMock(spec=ShortTermMemory)
|
|
mock_stm.asearch = AsyncMock(return_value=[{"content": "STM insight"}])
|
|
|
|
mock_ltm = MagicMock(spec=LongTermMemory)
|
|
mock_ltm.asearch = AsyncMock(
|
|
return_value=[{"metadata": {"suggestions": ["LTM suggestion"]}}]
|
|
)
|
|
|
|
mock_em = MagicMock(spec=EntityMemory)
|
|
mock_em.asearch = AsyncMock(return_value=[{"content": "Entity info"}])
|
|
|
|
mock_exm = MagicMock(spec=ExternalMemory)
|
|
mock_exm.asearch = AsyncMock(return_value=[{"content": "External memory"}])
|
|
|
|
contextual_memory = ContextualMemory(
|
|
stm=mock_stm,
|
|
ltm=mock_ltm,
|
|
em=mock_em,
|
|
exm=mock_exm,
|
|
agent=mock_agent,
|
|
task=mock_task,
|
|
)
|
|
|
|
result = await contextual_memory.abuild_context_for_task(mock_task, "additional context")
|
|
|
|
assert "Recent Insights:" in result
|
|
assert "STM insight" in result
|
|
assert "Historical Data:" in result
|
|
assert "LTM suggestion" in result
|
|
assert "Entities:" in result
|
|
assert "Entity info" in result
|
|
assert "External memories:" in result
|
|
assert "External memory" in result
|
|
|
|
@pytest.mark.asyncio
|
|
async def test_afetch_stm_context_returns_formatted_results(self, mock_agent, mock_task):
|
|
"""Test that _afetch_stm_context returns properly formatted results."""
|
|
mock_stm = MagicMock(spec=ShortTermMemory)
|
|
mock_stm.asearch = AsyncMock(
|
|
return_value=[
|
|
{"content": "First insight"},
|
|
{"content": "Second insight"},
|
|
]
|
|
)
|
|
|
|
contextual_memory = ContextualMemory(
|
|
stm=mock_stm,
|
|
ltm=None,
|
|
em=None,
|
|
exm=None,
|
|
)
|
|
|
|
result = await contextual_memory._afetch_stm_context("test query")
|
|
|
|
assert "Recent Insights:" in result
|
|
assert "- First insight" in result
|
|
assert "- Second insight" in result
|
|
|
|
@pytest.mark.asyncio
|
|
async def test_afetch_ltm_context_returns_formatted_results(self, mock_agent, mock_task):
|
|
"""Test that _afetch_ltm_context returns properly formatted results."""
|
|
mock_ltm = MagicMock(spec=LongTermMemory)
|
|
mock_ltm.asearch = AsyncMock(
|
|
return_value=[
|
|
{"metadata": {"suggestions": ["Suggestion 1", "Suggestion 2"]}},
|
|
]
|
|
)
|
|
|
|
contextual_memory = ContextualMemory(
|
|
stm=None,
|
|
ltm=mock_ltm,
|
|
em=None,
|
|
exm=None,
|
|
)
|
|
|
|
result = await contextual_memory._afetch_ltm_context("test task")
|
|
|
|
assert "Historical Data:" in result
|
|
assert "- Suggestion 1" in result
|
|
assert "- Suggestion 2" in result
|
|
|
|
@pytest.mark.asyncio
|
|
async def test_afetch_entity_context_returns_formatted_results(self, mock_agent, mock_task):
|
|
"""Test that _afetch_entity_context returns properly formatted results."""
|
|
mock_em = MagicMock(spec=EntityMemory)
|
|
mock_em.asearch = AsyncMock(
|
|
return_value=[
|
|
{"content": "Entity A details"},
|
|
{"content": "Entity B details"},
|
|
]
|
|
)
|
|
|
|
contextual_memory = ContextualMemory(
|
|
stm=None,
|
|
ltm=None,
|
|
em=mock_em,
|
|
exm=None,
|
|
)
|
|
|
|
result = await contextual_memory._afetch_entity_context("test query")
|
|
|
|
assert "Entities:" in result
|
|
assert "- Entity A details" in result
|
|
assert "- Entity B details" in result
|
|
|
|
@pytest.mark.asyncio
|
|
async def test_afetch_external_context_returns_formatted_results(self):
|
|
"""Test that _afetch_external_context returns properly formatted results."""
|
|
mock_exm = MagicMock(spec=ExternalMemory)
|
|
mock_exm.asearch = AsyncMock(
|
|
return_value=[
|
|
{"content": "External data 1"},
|
|
{"content": "External data 2"},
|
|
]
|
|
)
|
|
|
|
contextual_memory = ContextualMemory(
|
|
stm=None,
|
|
ltm=None,
|
|
em=None,
|
|
exm=mock_exm,
|
|
)
|
|
|
|
result = await contextual_memory._afetch_external_context("test query")
|
|
|
|
assert "External memories:" in result
|
|
assert "- External data 1" in result
|
|
assert "- External data 2" in result
|
|
|
|
@pytest.mark.asyncio
|
|
async def test_afetch_methods_return_empty_for_empty_results(self):
|
|
"""Test that async fetch methods return empty string for no results."""
|
|
mock_stm = MagicMock(spec=ShortTermMemory)
|
|
mock_stm.asearch = AsyncMock(return_value=[])
|
|
|
|
mock_ltm = MagicMock(spec=LongTermMemory)
|
|
mock_ltm.asearch = AsyncMock(return_value=[])
|
|
|
|
mock_em = MagicMock(spec=EntityMemory)
|
|
mock_em.asearch = AsyncMock(return_value=[])
|
|
|
|
mock_exm = MagicMock(spec=ExternalMemory)
|
|
mock_exm.asearch = AsyncMock(return_value=[])
|
|
|
|
contextual_memory = ContextualMemory(
|
|
stm=mock_stm,
|
|
ltm=mock_ltm,
|
|
em=mock_em,
|
|
exm=mock_exm,
|
|
)
|
|
|
|
stm_result = await contextual_memory._afetch_stm_context("query")
|
|
ltm_result = await contextual_memory._afetch_ltm_context("task")
|
|
em_result = await contextual_memory._afetch_entity_context("query")
|
|
exm_result = await contextual_memory._afetch_external_context("query")
|
|
|
|
assert stm_result == ""
|
|
assert ltm_result is None
|
|
assert em_result == ""
|
|
assert exm_result == "" |