From 109de91d0891a6916d6ae9d85caa9348843678de Mon Sep 17 00:00:00 2001 From: Greyson LaLonde Date: Wed, 27 Aug 2025 10:47:20 -0400 Subject: [PATCH] fix: batch entity memory items to reduce redundant operations (#3409) * fix: batch save entity memory items to reduce redundant operations * test: update memory event count after entity batch save implementation --- .../base_agent_executor_mixin.py | 9 +- src/crewai/memory/entity/entity_memory.py | 84 +++++++-- tests/test_crew.py | 172 +++++++++--------- 3 files changed, 157 insertions(+), 108 deletions(-) diff --git a/src/crewai/agents/agent_builder/base_agent_executor_mixin.py b/src/crewai/agents/agent_builder/base_agent_executor_mixin.py index 35e683846..e15144c8f 100644 --- a/src/crewai/agents/agent_builder/base_agent_executor_mixin.py +++ b/src/crewai/agents/agent_builder/base_agent_executor_mixin.py @@ -98,8 +98,8 @@ class CrewAgentExecutorMixin: ) self.crew._long_term_memory.save(long_term_memory) - for entity in evaluation.entities: - entity_memory = EntityMemoryItem( + entity_memories = [ + EntityMemoryItem( name=entity.name, type=entity.type, description=entity.description, @@ -107,7 +107,10 @@ class CrewAgentExecutorMixin: [f"- {r}" for r in entity.relationships] ), ) - self.crew._entity_memory.save(entity_memory) + for entity in evaluation.entities + ] + if entity_memories: + self.crew._entity_memory.save(entity_memories) except AttributeError as e: print(f"Missing attributes for long term memory: {e}") pass diff --git a/src/crewai/memory/entity/entity_memory.py b/src/crewai/memory/entity/entity_memory.py index 3b37a84de..63583665d 100644 --- a/src/crewai/memory/entity/entity_memory.py +++ b/src/crewai/memory/entity/entity_memory.py @@ -1,4 +1,4 @@ -from typing import Optional +from typing import Any import time from pydantic import PrivateAttr @@ -24,7 +24,7 @@ class EntityMemory(Memory): Inherits from the Memory class. """ - _memory_provider: Optional[str] = PrivateAttr() + _memory_provider: str | None = PrivateAttr() def __init__(self, crew=None, embedder_config=None, storage=None, path=None): memory_provider = embedder_config.get("provider") if embedder_config else None @@ -53,12 +53,33 @@ class EntityMemory(Memory): super().__init__(storage=storage) self._memory_provider = memory_provider - def save(self, item: EntityMemoryItem) -> None: # type: ignore # BUG?: Signature of "save" incompatible with supertype "Memory" - """Saves an entity item into the SQLite storage.""" + def save( + self, + value: EntityMemoryItem | list[EntityMemoryItem], + metadata: dict[str, Any] | None = None, + ) -> None: + """Saves one or more entity items into the SQLite storage. + + Args: + value: Single EntityMemoryItem or list of EntityMemoryItems to save. + metadata: Optional metadata dict (included for supertype compatibility but not used). + + Notes: + The metadata parameter is included to satisfy the supertype signature but is not + used - entity metadata is extracted from the EntityMemoryItem objects themselves. + """ + + if not value: + return + + items = value if isinstance(value, list) else [value] + is_batch = len(items) > 1 + + metadata = {"entity_count": len(items)} if is_batch else items[0].metadata crewai_event_bus.emit( self, event=MemorySaveStartedEvent( - metadata=item.metadata, + metadata=metadata, source_type="entity_memory", from_agent=self.agent, from_task=self.task, @@ -66,36 +87,61 @@ class EntityMemory(Memory): ) start_time = time.time() + saved_count = 0 + errors = [] + try: - if self._memory_provider == "mem0": - data = f""" - Remember details about the following entity: - Name: {item.name} - Type: {item.type} - Entity Description: {item.description} - """ + for item in items: + try: + if self._memory_provider == "mem0": + data = f""" + Remember details about the following entity: + Name: {item.name} + Type: {item.type} + Entity Description: {item.description} + """ + else: + data = f"{item.name}({item.type}): {item.description}" + + super().save(data, item.metadata) + saved_count += 1 + except Exception as e: + errors.append(f"{item.name}: {str(e)}") + + if is_batch: + emit_value = f"Saved {saved_count} entities" + metadata = {"entity_count": saved_count, "errors": errors} else: - data = f"{item.name}({item.type}): {item.description}" + emit_value = f"{items[0].name}({items[0].type}): {items[0].description}" + metadata = items[0].metadata - super().save(data, item.metadata) - - # Emit memory save completed event crewai_event_bus.emit( self, event=MemorySaveCompletedEvent( - value=data, - metadata=item.metadata, + value=emit_value, + metadata=metadata, save_time_ms=(time.time() - start_time) * 1000, source_type="entity_memory", from_agent=self.agent, from_task=self.task, ), ) + + if errors: + raise Exception( + f"Partial save: {len(errors)} failed out of {len(items)}" + ) + except Exception as e: + fail_metadata = ( + {"entity_count": len(items), "saved": saved_count} + if is_batch + else items[0].metadata + ) crewai_event_bus.emit( self, event=MemorySaveFailedEvent( - metadata=item.metadata, + metadata=fail_metadata, error=str(e), source_type="entity_memory", from_agent=self.agent, diff --git a/tests/test_crew.py b/tests/test_crew.py index 4c55ee4d5..26f0cf40d 100644 --- a/tests/test_crew.py +++ b/tests/test_crew.py @@ -624,12 +624,12 @@ def test_crew_with_delegating_agents_should_not_override_task_tools(ceo, writer) _, kwargs = mock_execute_sync.call_args tools = kwargs["tools"] - assert any(isinstance(tool, TestTool) for tool in tools), ( - "TestTool should be present" - ) - assert any("delegate" in tool.name.lower() for tool in tools), ( - "Delegation tool should be present" - ) + assert any( + isinstance(tool, TestTool) for tool in tools + ), "TestTool should be present" + assert any( + "delegate" in tool.name.lower() for tool in tools + ), "Delegation tool should be present" @pytest.mark.vcr(filter_headers=["authorization"]) @@ -688,12 +688,12 @@ def test_crew_with_delegating_agents_should_not_override_agent_tools(ceo, writer _, kwargs = mock_execute_sync.call_args tools = kwargs["tools"] - assert any(isinstance(tool, TestTool) for tool in new_ceo.tools), ( - "TestTool should be present" - ) - assert any("delegate" in tool.name.lower() for tool in tools), ( - "Delegation tool should be present" - ) + assert any( + isinstance(tool, TestTool) for tool in new_ceo.tools + ), "TestTool should be present" + assert any( + "delegate" in tool.name.lower() for tool in tools + ), "Delegation tool should be present" @pytest.mark.vcr(filter_headers=["authorization"]) @@ -817,17 +817,17 @@ def test_task_tools_override_agent_tools_with_allow_delegation(researcher, write used_tools = kwargs["tools"] # Confirm AnotherTestTool is present but TestTool is not - assert any(isinstance(tool, AnotherTestTool) for tool in used_tools), ( - "AnotherTestTool should be present" - ) - assert not any(isinstance(tool, TestTool) for tool in used_tools), ( - "TestTool should not be present among used tools" - ) + assert any( + isinstance(tool, AnotherTestTool) for tool in used_tools + ), "AnotherTestTool should be present" + assert not any( + isinstance(tool, TestTool) for tool in used_tools + ), "TestTool should not be present among used tools" # Confirm delegation tool(s) are present - assert any("delegate" in tool.name.lower() for tool in used_tools), ( - "Delegation tool should be present" - ) + assert any( + "delegate" in tool.name.lower() for tool in used_tools + ), "Delegation tool should be present" # Finally, make sure the agent's original tools remain unchanged assert len(researcher_with_delegation.tools) == 1 @@ -931,9 +931,9 @@ def test_cache_hitting_between_agents(researcher, writer, ceo): tool="multiplier", input={"first_number": 2, "second_number": 6} ) assert cache_calls[0] == expected_call, f"First call mismatch: {cache_calls[0]}" - assert cache_calls[1] == expected_call, ( - f"Second call mismatch: {cache_calls[1]}" - ) + assert ( + cache_calls[1] == expected_call + ), f"Second call mismatch: {cache_calls[1]}" @pytest.mark.vcr(filter_headers=["authorization"]) @@ -1676,9 +1676,9 @@ def test_code_execution_flag_adds_code_tool_upon_kickoff(): # Verify that exactly one tool was used and it was a CodeInterpreterTool assert len(used_tools) == 1, "Should have exactly one tool" - assert isinstance(used_tools[0], CodeInterpreterTool), ( - "Tool should be CodeInterpreterTool" - ) + assert isinstance( + used_tools[0], CodeInterpreterTool + ), "Tool should be CodeInterpreterTool" @pytest.mark.vcr(filter_headers=["authorization"]) @@ -2537,8 +2537,8 @@ def test_memory_events_are_emitted(): crew.kickoff() - assert len(events["MemorySaveStartedEvent"]) == 6 - assert len(events["MemorySaveCompletedEvent"]) == 6 + assert len(events["MemorySaveStartedEvent"]) == 3 + assert len(events["MemorySaveCompletedEvent"]) == 3 assert len(events["MemorySaveFailedEvent"]) == 0 assert len(events["MemoryQueryStartedEvent"]) == 3 assert len(events["MemoryQueryCompletedEvent"]) == 3 @@ -3817,9 +3817,9 @@ def test_fetch_inputs(): expected_placeholders = {"role_detail", "topic", "field"} actual_placeholders = crew.fetch_inputs() - assert actual_placeholders == expected_placeholders, ( - f"Expected {expected_placeholders}, but got {actual_placeholders}" - ) + assert ( + actual_placeholders == expected_placeholders + ), f"Expected {expected_placeholders}, but got {actual_placeholders}" def test_task_tools_preserve_code_execution_tools(): @@ -3894,20 +3894,20 @@ def test_task_tools_preserve_code_execution_tools(): used_tools = kwargs["tools"] # Verify all expected tools are present - assert any(isinstance(tool, TestTool) for tool in used_tools), ( - "Task's TestTool should be present" - ) - assert any(isinstance(tool, CodeInterpreterTool) for tool in used_tools), ( - "CodeInterpreterTool should be present" - ) - assert any("delegate" in tool.name.lower() for tool in used_tools), ( - "Delegation tool should be present" - ) + assert any( + isinstance(tool, TestTool) for tool in used_tools + ), "Task's TestTool should be present" + assert any( + isinstance(tool, CodeInterpreterTool) for tool in used_tools + ), "CodeInterpreterTool should be present" + assert any( + "delegate" in tool.name.lower() for tool in used_tools + ), "Delegation tool should be present" # Verify the total number of tools (TestTool + CodeInterpreter + 2 delegation tools) - assert len(used_tools) == 4, ( - "Should have TestTool, CodeInterpreter, and 2 delegation tools" - ) + assert ( + len(used_tools) == 4 + ), "Should have TestTool, CodeInterpreter, and 2 delegation tools" @pytest.mark.vcr(filter_headers=["authorization"]) @@ -3951,9 +3951,9 @@ def test_multimodal_flag_adds_multimodal_tools(): used_tools = kwargs["tools"] # Check that the multimodal tool was added - assert any(isinstance(tool, AddImageTool) for tool in used_tools), ( - "AddImageTool should be present when agent is multimodal" - ) + assert any( + isinstance(tool, AddImageTool) for tool in used_tools + ), "AddImageTool should be present when agent is multimodal" # Verify we have exactly one tool (just the AddImageTool) assert len(used_tools) == 1, "Should only have the AddImageTool" @@ -4217,9 +4217,9 @@ def test_crew_guardrail_feedback_in_context(): assert len(execution_contexts) > 1, "Task should have been executed multiple times" # Verify that the second execution included the guardrail feedback - assert "Output must contain the keyword 'IMPORTANT'" in execution_contexts[1], ( - "Guardrail feedback should be included in retry context" - ) + assert ( + "Output must contain the keyword 'IMPORTANT'" in execution_contexts[1] + ), "Guardrail feedback should be included in retry context" # Verify final output meets guardrail requirements assert "IMPORTANT" in result.raw, "Final output should contain required keyword" @@ -4435,46 +4435,46 @@ def test_crew_copy_with_memory(): try: crew_copy = crew.copy() - assert hasattr(crew_copy, "_short_term_memory"), ( - "Copied crew should have _short_term_memory" - ) - assert crew_copy._short_term_memory is not None, ( - "Copied _short_term_memory should not be None" - ) - assert id(crew_copy._short_term_memory) != original_short_term_id, ( - "Copied _short_term_memory should be a new object" - ) + assert hasattr( + crew_copy, "_short_term_memory" + ), "Copied crew should have _short_term_memory" + assert ( + crew_copy._short_term_memory is not None + ), "Copied _short_term_memory should not be None" + assert ( + id(crew_copy._short_term_memory) != original_short_term_id + ), "Copied _short_term_memory should be a new object" - assert hasattr(crew_copy, "_long_term_memory"), ( - "Copied crew should have _long_term_memory" - ) - assert crew_copy._long_term_memory is not None, ( - "Copied _long_term_memory should not be None" - ) - assert id(crew_copy._long_term_memory) != original_long_term_id, ( - "Copied _long_term_memory should be a new object" - ) + assert hasattr( + crew_copy, "_long_term_memory" + ), "Copied crew should have _long_term_memory" + assert ( + crew_copy._long_term_memory is not None + ), "Copied _long_term_memory should not be None" + assert ( + id(crew_copy._long_term_memory) != original_long_term_id + ), "Copied _long_term_memory should be a new object" - assert hasattr(crew_copy, "_entity_memory"), ( - "Copied crew should have _entity_memory" - ) - assert crew_copy._entity_memory is not None, ( - "Copied _entity_memory should not be None" - ) - assert id(crew_copy._entity_memory) != original_entity_id, ( - "Copied _entity_memory should be a new object" - ) + assert hasattr( + crew_copy, "_entity_memory" + ), "Copied crew should have _entity_memory" + assert ( + crew_copy._entity_memory is not None + ), "Copied _entity_memory should not be None" + assert ( + id(crew_copy._entity_memory) != original_entity_id + ), "Copied _entity_memory should be a new object" if original_external_id: - assert hasattr(crew_copy, "_external_memory"), ( - "Copied crew should have _external_memory" - ) - assert crew_copy._external_memory is not None, ( - "Copied _external_memory should not be None" - ) - assert id(crew_copy._external_memory) != original_external_id, ( - "Copied _external_memory should be a new object" - ) + assert hasattr( + crew_copy, "_external_memory" + ), "Copied crew should have _external_memory" + assert ( + crew_copy._external_memory is not None + ), "Copied _external_memory should not be None" + assert ( + id(crew_copy._external_memory) != original_external_id + ), "Copied _external_memory should be a new object" else: assert ( not hasattr(crew_copy, "_external_memory")