mirror of
https://github.com/crewAIInc/crewAI.git
synced 2026-01-26 16:48:13 +00:00
test: enhance cache call assertions in crew tests
* Improved the test for cache hitting between agents by filtering mock calls to ensure they include the expected 'tool' and 'input' keywords. * Added assertions to verify the number of cache calls and their expected arguments, enhancing the reliability of the test. * Cleaned up whitespace and improved readability in various test cases for better maintainability.
This commit is contained in:
File diff suppressed because one or more lines are too long
@@ -52,6 +52,7 @@ from crewai.utilities.events.memory_events import (
|
|||||||
MemoryRetrievalCompletedEvent,
|
MemoryRetrievalCompletedEvent,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture
|
@pytest.fixture
|
||||||
def ceo():
|
def ceo():
|
||||||
return Agent(
|
return Agent(
|
||||||
@@ -935,12 +936,27 @@ def test_cache_hitting_between_agents(researcher, writer, ceo):
|
|||||||
read.return_value = "12"
|
read.return_value = "12"
|
||||||
crew.kickoff()
|
crew.kickoff()
|
||||||
assert read.call_count == 2, "read was not called exactly twice"
|
assert read.call_count == 2, "read was not called exactly twice"
|
||||||
# Check if read was called with the expected arguments
|
|
||||||
expected_calls = [
|
# Filter the mock calls to only include the ones with 'tool' and 'input' keywords
|
||||||
call(tool="multiplier", input={"first_number": 2, "second_number": 6}),
|
cache_calls = [
|
||||||
call(tool="multiplier", input={"first_number": 2, "second_number": 6}),
|
call
|
||||||
|
for call in read.call_args_list
|
||||||
|
if len(call.kwargs) == 2
|
||||||
|
and "tool" in call.kwargs
|
||||||
|
and "input" in call.kwargs
|
||||||
]
|
]
|
||||||
read.assert_has_calls(expected_calls, any_order=False)
|
|
||||||
|
# Check if we have the expected number of cache calls
|
||||||
|
assert len(cache_calls) == 2, f"Expected 2 cache calls, got {len(cache_calls)}"
|
||||||
|
|
||||||
|
# Check if both calls were made with the expected arguments
|
||||||
|
expected_call = call(
|
||||||
|
tool="multiplier", input={"first_number": 2, "second_number": 6}
|
||||||
|
)
|
||||||
|
assert cache_calls[0] == expected_call, f"First call mismatch: {cache_calls[0]}"
|
||||||
|
assert (
|
||||||
|
cache_calls[1] == expected_call
|
||||||
|
), f"Second call mismatch: {cache_calls[1]}"
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.vcr(filter_headers=["authorization"])
|
@pytest.mark.vcr(filter_headers=["authorization"])
|
||||||
@@ -1797,7 +1813,7 @@ def test_hierarchical_kickoff_usage_metrics_include_manager(researcher):
|
|||||||
agent=researcher, # *regular* agent
|
agent=researcher, # *regular* agent
|
||||||
)
|
)
|
||||||
|
|
||||||
# ── 2. Stub out each agent’s _token_process.get_summary() ───────────────────
|
# ── 2. Stub out each agent's _token_process.get_summary() ───────────────────
|
||||||
researcher_metrics = UsageMetrics(
|
researcher_metrics = UsageMetrics(
|
||||||
total_tokens=120, prompt_tokens=80, completion_tokens=40, successful_requests=2
|
total_tokens=120, prompt_tokens=80, completion_tokens=40, successful_requests=2
|
||||||
)
|
)
|
||||||
@@ -1821,7 +1837,7 @@ def test_hierarchical_kickoff_usage_metrics_include_manager(researcher):
|
|||||||
process=Process.hierarchical,
|
process=Process.hierarchical,
|
||||||
)
|
)
|
||||||
|
|
||||||
# We don’t care about LLM output here; patch execute_sync to avoid network
|
# We don't care about LLM output here; patch execute_sync to avoid network
|
||||||
with patch.object(
|
with patch.object(
|
||||||
Task,
|
Task,
|
||||||
"execute_sync",
|
"execute_sync",
|
||||||
@@ -2489,17 +2505,19 @@ def test_using_contextual_memory():
|
|||||||
memory=True,
|
memory=True,
|
||||||
)
|
)
|
||||||
|
|
||||||
with patch.object(ContextualMemory, "build_context_for_task", return_value="") as contextual_mem:
|
with patch.object(
|
||||||
|
ContextualMemory, "build_context_for_task", return_value=""
|
||||||
|
) as contextual_mem:
|
||||||
crew.kickoff()
|
crew.kickoff()
|
||||||
contextual_mem.assert_called_once()
|
contextual_mem.assert_called_once()
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.vcr(filter_headers=["authorization"])
|
@pytest.mark.vcr(filter_headers=["authorization"])
|
||||||
def test_memory_events_are_emitted():
|
def test_memory_events_are_emitted():
|
||||||
events = defaultdict(list)
|
events = defaultdict(list)
|
||||||
|
|
||||||
with crewai_event_bus.scoped_handlers():
|
with crewai_event_bus.scoped_handlers():
|
||||||
|
|
||||||
@crewai_event_bus.on(MemorySaveStartedEvent)
|
@crewai_event_bus.on(MemorySaveStartedEvent)
|
||||||
def handle_memory_save_started(source, event):
|
def handle_memory_save_started(source, event):
|
||||||
events["MemorySaveStartedEvent"].append(event)
|
events["MemorySaveStartedEvent"].append(event)
|
||||||
@@ -2562,6 +2580,7 @@ def test_memory_events_are_emitted():
|
|||||||
assert len(events["MemoryRetrievalStartedEvent"]) == 1
|
assert len(events["MemoryRetrievalStartedEvent"]) == 1
|
||||||
assert len(events["MemoryRetrievalCompletedEvent"]) == 1
|
assert len(events["MemoryRetrievalCompletedEvent"]) == 1
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.vcr(filter_headers=["authorization"])
|
@pytest.mark.vcr(filter_headers=["authorization"])
|
||||||
def test_using_contextual_memory_with_long_term_memory():
|
def test_using_contextual_memory_with_long_term_memory():
|
||||||
from unittest.mock import patch
|
from unittest.mock import patch
|
||||||
@@ -2585,7 +2604,9 @@ def test_using_contextual_memory_with_long_term_memory():
|
|||||||
long_term_memory=LongTermMemory(),
|
long_term_memory=LongTermMemory(),
|
||||||
)
|
)
|
||||||
|
|
||||||
with patch.object(ContextualMemory, "build_context_for_task", return_value="") as contextual_mem:
|
with patch.object(
|
||||||
|
ContextualMemory, "build_context_for_task", return_value=""
|
||||||
|
) as contextual_mem:
|
||||||
crew.kickoff()
|
crew.kickoff()
|
||||||
contextual_mem.assert_called_once()
|
contextual_mem.assert_called_once()
|
||||||
assert crew.memory is False
|
assert crew.memory is False
|
||||||
@@ -2686,7 +2707,9 @@ def test_using_contextual_memory_with_short_term_memory():
|
|||||||
short_term_memory=ShortTermMemory(),
|
short_term_memory=ShortTermMemory(),
|
||||||
)
|
)
|
||||||
|
|
||||||
with patch.object(ContextualMemory, "build_context_for_task", return_value="") as contextual_mem:
|
with patch.object(
|
||||||
|
ContextualMemory, "build_context_for_task", return_value=""
|
||||||
|
) as contextual_mem:
|
||||||
crew.kickoff()
|
crew.kickoff()
|
||||||
contextual_mem.assert_called_once()
|
contextual_mem.assert_called_once()
|
||||||
assert crew.memory is False
|
assert crew.memory is False
|
||||||
@@ -2715,7 +2738,9 @@ def test_disabled_memory_using_contextual_memory():
|
|||||||
memory=False,
|
memory=False,
|
||||||
)
|
)
|
||||||
|
|
||||||
with patch.object(ContextualMemory, "build_context_for_task", return_value="") as contextual_mem:
|
with patch.object(
|
||||||
|
ContextualMemory, "build_context_for_task", return_value=""
|
||||||
|
) as contextual_mem:
|
||||||
crew.kickoff()
|
crew.kickoff()
|
||||||
contextual_mem.assert_not_called()
|
contextual_mem.assert_not_called()
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user