test: enhance cache call assertions in crew tests

* Improved the test for cache hitting between agents by filtering mock calls to ensure they include the expected 'tool' and 'input' keywords.
* Added assertions to verify the number of cache calls and their expected arguments, enhancing the reliability of the test.
* Cleaned up whitespace and improved readability in various test cases for better maintainability.
This commit is contained in:
lorenzejay
2025-07-02 15:04:36 -07:00
parent 20a517591d
commit 06ea9dcfe6
2 changed files with 158 additions and 12 deletions

File diff suppressed because one or more lines are too long

View File

@@ -52,6 +52,7 @@ from crewai.utilities.events.memory_events import (
MemoryRetrievalCompletedEvent,
)
@pytest.fixture
def ceo():
return Agent(
@@ -935,12 +936,27 @@ def test_cache_hitting_between_agents(researcher, writer, ceo):
read.return_value = "12"
crew.kickoff()
assert read.call_count == 2, "read was not called exactly twice"
# Check if read was called with the expected arguments
expected_calls = [
call(tool="multiplier", input={"first_number": 2, "second_number": 6}),
call(tool="multiplier", input={"first_number": 2, "second_number": 6}),
# Filter the mock calls to only include the ones with 'tool' and 'input' keywords
cache_calls = [
call
for call in read.call_args_list
if len(call.kwargs) == 2
and "tool" in call.kwargs
and "input" in call.kwargs
]
read.assert_has_calls(expected_calls, any_order=False)
# Check if we have the expected number of cache calls
assert len(cache_calls) == 2, f"Expected 2 cache calls, got {len(cache_calls)}"
# Check if both calls were made with the expected arguments
expected_call = call(
tool="multiplier", input={"first_number": 2, "second_number": 6}
)
assert cache_calls[0] == expected_call, f"First call mismatch: {cache_calls[0]}"
assert (
cache_calls[1] == expected_call
), f"Second call mismatch: {cache_calls[1]}"
@pytest.mark.vcr(filter_headers=["authorization"])
@@ -1797,7 +1813,7 @@ def test_hierarchical_kickoff_usage_metrics_include_manager(researcher):
agent=researcher, # *regular* agent
)
# ── 2. Stub out each agents _token_process.get_summary() ───────────────────
# ── 2. Stub out each agent's _token_process.get_summary() ───────────────────
researcher_metrics = UsageMetrics(
total_tokens=120, prompt_tokens=80, completion_tokens=40, successful_requests=2
)
@@ -1821,7 +1837,7 @@ def test_hierarchical_kickoff_usage_metrics_include_manager(researcher):
process=Process.hierarchical,
)
# We dont care about LLM output here; patch execute_sync to avoid network
# We don't care about LLM output here; patch execute_sync to avoid network
with patch.object(
Task,
"execute_sync",
@@ -2489,17 +2505,19 @@ def test_using_contextual_memory():
memory=True,
)
with patch.object(ContextualMemory, "build_context_for_task", return_value="") as contextual_mem:
with patch.object(
ContextualMemory, "build_context_for_task", return_value=""
) as contextual_mem:
crew.kickoff()
contextual_mem.assert_called_once()
@pytest.mark.vcr(filter_headers=["authorization"])
def test_memory_events_are_emitted():
events = defaultdict(list)
with crewai_event_bus.scoped_handlers():
@crewai_event_bus.on(MemorySaveStartedEvent)
def handle_memory_save_started(source, event):
events["MemorySaveStartedEvent"].append(event)
@@ -2562,6 +2580,7 @@ def test_memory_events_are_emitted():
assert len(events["MemoryRetrievalStartedEvent"]) == 1
assert len(events["MemoryRetrievalCompletedEvent"]) == 1
@pytest.mark.vcr(filter_headers=["authorization"])
def test_using_contextual_memory_with_long_term_memory():
from unittest.mock import patch
@@ -2585,7 +2604,9 @@ def test_using_contextual_memory_with_long_term_memory():
long_term_memory=LongTermMemory(),
)
with patch.object(ContextualMemory, "build_context_for_task", return_value="") as contextual_mem:
with patch.object(
ContextualMemory, "build_context_for_task", return_value=""
) as contextual_mem:
crew.kickoff()
contextual_mem.assert_called_once()
assert crew.memory is False
@@ -2686,7 +2707,9 @@ def test_using_contextual_memory_with_short_term_memory():
short_term_memory=ShortTermMemory(),
)
with patch.object(ContextualMemory, "build_context_for_task", return_value="") as contextual_mem:
with patch.object(
ContextualMemory, "build_context_for_task", return_value=""
) as contextual_mem:
crew.kickoff()
contextual_mem.assert_called_once()
assert crew.memory is False
@@ -2715,7 +2738,9 @@ def test_disabled_memory_using_contextual_memory():
memory=False,
)
with patch.object(ContextualMemory, "build_context_for_task", return_value="") as contextual_mem:
with patch.object(
ContextualMemory, "build_context_for_task", return_value=""
) as contextual_mem:
crew.kickoff()
contextual_mem.assert_not_called()