fix: batch entity memory items to reduce redundant operations (#3409)
Some checks failed
Notify Downstream / notify-downstream (push) Has been cancelled

* fix: batch save entity memory items to reduce redundant operations

* test: update memory event count after entity batch save implementation
This commit is contained in:
Greyson LaLonde
2025-08-27 10:47:20 -04:00
committed by GitHub
parent 92b70e652d
commit 109de91d08
3 changed files with 157 additions and 108 deletions

View File

@@ -98,8 +98,8 @@ class CrewAgentExecutorMixin:
) )
self.crew._long_term_memory.save(long_term_memory) self.crew._long_term_memory.save(long_term_memory)
for entity in evaluation.entities: entity_memories = [
entity_memory = EntityMemoryItem( EntityMemoryItem(
name=entity.name, name=entity.name,
type=entity.type, type=entity.type,
description=entity.description, description=entity.description,
@@ -107,7 +107,10 @@ class CrewAgentExecutorMixin:
[f"- {r}" for r in entity.relationships] [f"- {r}" for r in entity.relationships]
), ),
) )
self.crew._entity_memory.save(entity_memory) for entity in evaluation.entities
]
if entity_memories:
self.crew._entity_memory.save(entity_memories)
except AttributeError as e: except AttributeError as e:
print(f"Missing attributes for long term memory: {e}") print(f"Missing attributes for long term memory: {e}")
pass pass

View File

@@ -1,4 +1,4 @@
from typing import Optional from typing import Any
import time import time
from pydantic import PrivateAttr from pydantic import PrivateAttr
@@ -24,7 +24,7 @@ class EntityMemory(Memory):
Inherits from the Memory class. Inherits from the Memory class.
""" """
_memory_provider: Optional[str] = PrivateAttr() _memory_provider: str | None = PrivateAttr()
def __init__(self, crew=None, embedder_config=None, storage=None, path=None): def __init__(self, crew=None, embedder_config=None, storage=None, path=None):
memory_provider = embedder_config.get("provider") if embedder_config else None memory_provider = embedder_config.get("provider") if embedder_config else None
@@ -53,12 +53,33 @@ class EntityMemory(Memory):
super().__init__(storage=storage) super().__init__(storage=storage)
self._memory_provider = memory_provider self._memory_provider = memory_provider
def save(self, item: EntityMemoryItem) -> None: # type: ignore # BUG?: Signature of "save" incompatible with supertype "Memory" def save(
"""Saves an entity item into the SQLite storage.""" self,
value: EntityMemoryItem | list[EntityMemoryItem],
metadata: dict[str, Any] | None = None,
) -> None:
"""Saves one or more entity items into the SQLite storage.
Args:
value: Single EntityMemoryItem or list of EntityMemoryItems to save.
metadata: Optional metadata dict (included for supertype compatibility but not used).
Notes:
The metadata parameter is included to satisfy the supertype signature but is not
used - entity metadata is extracted from the EntityMemoryItem objects themselves.
"""
if not value:
return
items = value if isinstance(value, list) else [value]
is_batch = len(items) > 1
metadata = {"entity_count": len(items)} if is_batch else items[0].metadata
crewai_event_bus.emit( crewai_event_bus.emit(
self, self,
event=MemorySaveStartedEvent( event=MemorySaveStartedEvent(
metadata=item.metadata, metadata=metadata,
source_type="entity_memory", source_type="entity_memory",
from_agent=self.agent, from_agent=self.agent,
from_task=self.task, from_task=self.task,
@@ -66,36 +87,61 @@ class EntityMemory(Memory):
) )
start_time = time.time() start_time = time.time()
saved_count = 0
errors = []
try: try:
if self._memory_provider == "mem0": for item in items:
data = f""" try:
Remember details about the following entity: if self._memory_provider == "mem0":
Name: {item.name} data = f"""
Type: {item.type} Remember details about the following entity:
Entity Description: {item.description} Name: {item.name}
""" Type: {item.type}
Entity Description: {item.description}
"""
else:
data = f"{item.name}({item.type}): {item.description}"
super().save(data, item.metadata)
saved_count += 1
except Exception as e:
errors.append(f"{item.name}: {str(e)}")
if is_batch:
emit_value = f"Saved {saved_count} entities"
metadata = {"entity_count": saved_count, "errors": errors}
else: else:
data = f"{item.name}({item.type}): {item.description}" emit_value = f"{items[0].name}({items[0].type}): {items[0].description}"
metadata = items[0].metadata
super().save(data, item.metadata)
# Emit memory save completed event
crewai_event_bus.emit( crewai_event_bus.emit(
self, self,
event=MemorySaveCompletedEvent( event=MemorySaveCompletedEvent(
value=data, value=emit_value,
metadata=item.metadata, metadata=metadata,
save_time_ms=(time.time() - start_time) * 1000, save_time_ms=(time.time() - start_time) * 1000,
source_type="entity_memory", source_type="entity_memory",
from_agent=self.agent, from_agent=self.agent,
from_task=self.task, from_task=self.task,
), ),
) )
if errors:
raise Exception(
f"Partial save: {len(errors)} failed out of {len(items)}"
)
except Exception as e: except Exception as e:
fail_metadata = (
{"entity_count": len(items), "saved": saved_count}
if is_batch
else items[0].metadata
)
crewai_event_bus.emit( crewai_event_bus.emit(
self, self,
event=MemorySaveFailedEvent( event=MemorySaveFailedEvent(
metadata=item.metadata, metadata=fail_metadata,
error=str(e), error=str(e),
source_type="entity_memory", source_type="entity_memory",
from_agent=self.agent, from_agent=self.agent,

View File

@@ -624,12 +624,12 @@ def test_crew_with_delegating_agents_should_not_override_task_tools(ceo, writer)
_, kwargs = mock_execute_sync.call_args _, kwargs = mock_execute_sync.call_args
tools = kwargs["tools"] tools = kwargs["tools"]
assert any(isinstance(tool, TestTool) for tool in tools), ( assert any(
"TestTool should be present" isinstance(tool, TestTool) for tool in tools
) ), "TestTool should be present"
assert any("delegate" in tool.name.lower() for tool in tools), ( assert any(
"Delegation tool should be present" "delegate" in tool.name.lower() for tool in tools
) ), "Delegation tool should be present"
@pytest.mark.vcr(filter_headers=["authorization"]) @pytest.mark.vcr(filter_headers=["authorization"])
@@ -688,12 +688,12 @@ def test_crew_with_delegating_agents_should_not_override_agent_tools(ceo, writer
_, kwargs = mock_execute_sync.call_args _, kwargs = mock_execute_sync.call_args
tools = kwargs["tools"] tools = kwargs["tools"]
assert any(isinstance(tool, TestTool) for tool in new_ceo.tools), ( assert any(
"TestTool should be present" isinstance(tool, TestTool) for tool in new_ceo.tools
) ), "TestTool should be present"
assert any("delegate" in tool.name.lower() for tool in tools), ( assert any(
"Delegation tool should be present" "delegate" in tool.name.lower() for tool in tools
) ), "Delegation tool should be present"
@pytest.mark.vcr(filter_headers=["authorization"]) @pytest.mark.vcr(filter_headers=["authorization"])
@@ -817,17 +817,17 @@ def test_task_tools_override_agent_tools_with_allow_delegation(researcher, write
used_tools = kwargs["tools"] used_tools = kwargs["tools"]
# Confirm AnotherTestTool is present but TestTool is not # Confirm AnotherTestTool is present but TestTool is not
assert any(isinstance(tool, AnotherTestTool) for tool in used_tools), ( assert any(
"AnotherTestTool should be present" isinstance(tool, AnotherTestTool) for tool in used_tools
) ), "AnotherTestTool should be present"
assert not any(isinstance(tool, TestTool) for tool in used_tools), ( assert not any(
"TestTool should not be present among used tools" isinstance(tool, TestTool) for tool in used_tools
) ), "TestTool should not be present among used tools"
# Confirm delegation tool(s) are present # Confirm delegation tool(s) are present
assert any("delegate" in tool.name.lower() for tool in used_tools), ( assert any(
"Delegation tool should be present" "delegate" in tool.name.lower() for tool in used_tools
) ), "Delegation tool should be present"
# Finally, make sure the agent's original tools remain unchanged # Finally, make sure the agent's original tools remain unchanged
assert len(researcher_with_delegation.tools) == 1 assert len(researcher_with_delegation.tools) == 1
@@ -931,9 +931,9 @@ def test_cache_hitting_between_agents(researcher, writer, ceo):
tool="multiplier", input={"first_number": 2, "second_number": 6} tool="multiplier", input={"first_number": 2, "second_number": 6}
) )
assert cache_calls[0] == expected_call, f"First call mismatch: {cache_calls[0]}" assert cache_calls[0] == expected_call, f"First call mismatch: {cache_calls[0]}"
assert cache_calls[1] == expected_call, ( assert (
f"Second call mismatch: {cache_calls[1]}" cache_calls[1] == expected_call
) ), f"Second call mismatch: {cache_calls[1]}"
@pytest.mark.vcr(filter_headers=["authorization"]) @pytest.mark.vcr(filter_headers=["authorization"])
@@ -1676,9 +1676,9 @@ def test_code_execution_flag_adds_code_tool_upon_kickoff():
# Verify that exactly one tool was used and it was a CodeInterpreterTool # Verify that exactly one tool was used and it was a CodeInterpreterTool
assert len(used_tools) == 1, "Should have exactly one tool" assert len(used_tools) == 1, "Should have exactly one tool"
assert isinstance(used_tools[0], CodeInterpreterTool), ( assert isinstance(
"Tool should be CodeInterpreterTool" used_tools[0], CodeInterpreterTool
) ), "Tool should be CodeInterpreterTool"
@pytest.mark.vcr(filter_headers=["authorization"]) @pytest.mark.vcr(filter_headers=["authorization"])
@@ -2537,8 +2537,8 @@ def test_memory_events_are_emitted():
crew.kickoff() crew.kickoff()
assert len(events["MemorySaveStartedEvent"]) == 6 assert len(events["MemorySaveStartedEvent"]) == 3
assert len(events["MemorySaveCompletedEvent"]) == 6 assert len(events["MemorySaveCompletedEvent"]) == 3
assert len(events["MemorySaveFailedEvent"]) == 0 assert len(events["MemorySaveFailedEvent"]) == 0
assert len(events["MemoryQueryStartedEvent"]) == 3 assert len(events["MemoryQueryStartedEvent"]) == 3
assert len(events["MemoryQueryCompletedEvent"]) == 3 assert len(events["MemoryQueryCompletedEvent"]) == 3
@@ -3817,9 +3817,9 @@ def test_fetch_inputs():
expected_placeholders = {"role_detail", "topic", "field"} expected_placeholders = {"role_detail", "topic", "field"}
actual_placeholders = crew.fetch_inputs() actual_placeholders = crew.fetch_inputs()
assert actual_placeholders == expected_placeholders, ( assert (
f"Expected {expected_placeholders}, but got {actual_placeholders}" actual_placeholders == expected_placeholders
) ), f"Expected {expected_placeholders}, but got {actual_placeholders}"
def test_task_tools_preserve_code_execution_tools(): def test_task_tools_preserve_code_execution_tools():
@@ -3894,20 +3894,20 @@ def test_task_tools_preserve_code_execution_tools():
used_tools = kwargs["tools"] used_tools = kwargs["tools"]
# Verify all expected tools are present # Verify all expected tools are present
assert any(isinstance(tool, TestTool) for tool in used_tools), ( assert any(
"Task's TestTool should be present" isinstance(tool, TestTool) for tool in used_tools
) ), "Task's TestTool should be present"
assert any(isinstance(tool, CodeInterpreterTool) for tool in used_tools), ( assert any(
"CodeInterpreterTool should be present" isinstance(tool, CodeInterpreterTool) for tool in used_tools
) ), "CodeInterpreterTool should be present"
assert any("delegate" in tool.name.lower() for tool in used_tools), ( assert any(
"Delegation tool should be present" "delegate" in tool.name.lower() for tool in used_tools
) ), "Delegation tool should be present"
# Verify the total number of tools (TestTool + CodeInterpreter + 2 delegation tools) # Verify the total number of tools (TestTool + CodeInterpreter + 2 delegation tools)
assert len(used_tools) == 4, ( assert (
"Should have TestTool, CodeInterpreter, and 2 delegation tools" len(used_tools) == 4
) ), "Should have TestTool, CodeInterpreter, and 2 delegation tools"
@pytest.mark.vcr(filter_headers=["authorization"]) @pytest.mark.vcr(filter_headers=["authorization"])
@@ -3951,9 +3951,9 @@ def test_multimodal_flag_adds_multimodal_tools():
used_tools = kwargs["tools"] used_tools = kwargs["tools"]
# Check that the multimodal tool was added # Check that the multimodal tool was added
assert any(isinstance(tool, AddImageTool) for tool in used_tools), ( assert any(
"AddImageTool should be present when agent is multimodal" isinstance(tool, AddImageTool) for tool in used_tools
) ), "AddImageTool should be present when agent is multimodal"
# Verify we have exactly one tool (just the AddImageTool) # Verify we have exactly one tool (just the AddImageTool)
assert len(used_tools) == 1, "Should only have the AddImageTool" assert len(used_tools) == 1, "Should only have the AddImageTool"
@@ -4217,9 +4217,9 @@ def test_crew_guardrail_feedback_in_context():
assert len(execution_contexts) > 1, "Task should have been executed multiple times" assert len(execution_contexts) > 1, "Task should have been executed multiple times"
# Verify that the second execution included the guardrail feedback # Verify that the second execution included the guardrail feedback
assert "Output must contain the keyword 'IMPORTANT'" in execution_contexts[1], ( assert (
"Guardrail feedback should be included in retry context" "Output must contain the keyword 'IMPORTANT'" in execution_contexts[1]
) ), "Guardrail feedback should be included in retry context"
# Verify final output meets guardrail requirements # Verify final output meets guardrail requirements
assert "IMPORTANT" in result.raw, "Final output should contain required keyword" assert "IMPORTANT" in result.raw, "Final output should contain required keyword"
@@ -4435,46 +4435,46 @@ def test_crew_copy_with_memory():
try: try:
crew_copy = crew.copy() crew_copy = crew.copy()
assert hasattr(crew_copy, "_short_term_memory"), ( assert hasattr(
"Copied crew should have _short_term_memory" crew_copy, "_short_term_memory"
) ), "Copied crew should have _short_term_memory"
assert crew_copy._short_term_memory is not None, ( assert (
"Copied _short_term_memory should not be None" crew_copy._short_term_memory is not None
) ), "Copied _short_term_memory should not be None"
assert id(crew_copy._short_term_memory) != original_short_term_id, ( assert (
"Copied _short_term_memory should be a new object" id(crew_copy._short_term_memory) != original_short_term_id
) ), "Copied _short_term_memory should be a new object"
assert hasattr(crew_copy, "_long_term_memory"), ( assert hasattr(
"Copied crew should have _long_term_memory" crew_copy, "_long_term_memory"
) ), "Copied crew should have _long_term_memory"
assert crew_copy._long_term_memory is not None, ( assert (
"Copied _long_term_memory should not be None" crew_copy._long_term_memory is not None
) ), "Copied _long_term_memory should not be None"
assert id(crew_copy._long_term_memory) != original_long_term_id, ( assert (
"Copied _long_term_memory should be a new object" id(crew_copy._long_term_memory) != original_long_term_id
) ), "Copied _long_term_memory should be a new object"
assert hasattr(crew_copy, "_entity_memory"), ( assert hasattr(
"Copied crew should have _entity_memory" crew_copy, "_entity_memory"
) ), "Copied crew should have _entity_memory"
assert crew_copy._entity_memory is not None, ( assert (
"Copied _entity_memory should not be None" crew_copy._entity_memory is not None
) ), "Copied _entity_memory should not be None"
assert id(crew_copy._entity_memory) != original_entity_id, ( assert (
"Copied _entity_memory should be a new object" id(crew_copy._entity_memory) != original_entity_id
) ), "Copied _entity_memory should be a new object"
if original_external_id: if original_external_id:
assert hasattr(crew_copy, "_external_memory"), ( assert hasattr(
"Copied crew should have _external_memory" crew_copy, "_external_memory"
) ), "Copied crew should have _external_memory"
assert crew_copy._external_memory is not None, ( assert (
"Copied _external_memory should not be None" crew_copy._external_memory is not None
) ), "Copied _external_memory should not be None"
assert id(crew_copy._external_memory) != original_external_id, ( assert (
"Copied _external_memory should be a new object" id(crew_copy._external_memory) != original_external_id
) ), "Copied _external_memory should be a new object"
else: else:
assert ( assert (
not hasattr(crew_copy, "_external_memory") not hasattr(crew_copy, "_external_memory")