mirror of
https://github.com/crewAIInc/crewAI.git
synced 2026-05-02 07:42:40 +00:00
fix: batch entity memory items to reduce redundant operations (#3409)
Some checks failed
Notify Downstream / notify-downstream (push) Has been cancelled
Some checks failed
Notify Downstream / notify-downstream (push) Has been cancelled
* fix: batch save entity memory items to reduce redundant operations * test: update memory event count after entity batch save implementation
This commit is contained in:
@@ -98,8 +98,8 @@ class CrewAgentExecutorMixin:
|
|||||||
)
|
)
|
||||||
self.crew._long_term_memory.save(long_term_memory)
|
self.crew._long_term_memory.save(long_term_memory)
|
||||||
|
|
||||||
for entity in evaluation.entities:
|
entity_memories = [
|
||||||
entity_memory = EntityMemoryItem(
|
EntityMemoryItem(
|
||||||
name=entity.name,
|
name=entity.name,
|
||||||
type=entity.type,
|
type=entity.type,
|
||||||
description=entity.description,
|
description=entity.description,
|
||||||
@@ -107,7 +107,10 @@ class CrewAgentExecutorMixin:
|
|||||||
[f"- {r}" for r in entity.relationships]
|
[f"- {r}" for r in entity.relationships]
|
||||||
),
|
),
|
||||||
)
|
)
|
||||||
self.crew._entity_memory.save(entity_memory)
|
for entity in evaluation.entities
|
||||||
|
]
|
||||||
|
if entity_memories:
|
||||||
|
self.crew._entity_memory.save(entity_memories)
|
||||||
except AttributeError as e:
|
except AttributeError as e:
|
||||||
print(f"Missing attributes for long term memory: {e}")
|
print(f"Missing attributes for long term memory: {e}")
|
||||||
pass
|
pass
|
||||||
|
|||||||
@@ -1,4 +1,4 @@
|
|||||||
from typing import Optional
|
from typing import Any
|
||||||
import time
|
import time
|
||||||
|
|
||||||
from pydantic import PrivateAttr
|
from pydantic import PrivateAttr
|
||||||
@@ -24,7 +24,7 @@ class EntityMemory(Memory):
|
|||||||
Inherits from the Memory class.
|
Inherits from the Memory class.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
_memory_provider: Optional[str] = PrivateAttr()
|
_memory_provider: str | None = PrivateAttr()
|
||||||
|
|
||||||
def __init__(self, crew=None, embedder_config=None, storage=None, path=None):
|
def __init__(self, crew=None, embedder_config=None, storage=None, path=None):
|
||||||
memory_provider = embedder_config.get("provider") if embedder_config else None
|
memory_provider = embedder_config.get("provider") if embedder_config else None
|
||||||
@@ -53,12 +53,33 @@ class EntityMemory(Memory):
|
|||||||
super().__init__(storage=storage)
|
super().__init__(storage=storage)
|
||||||
self._memory_provider = memory_provider
|
self._memory_provider = memory_provider
|
||||||
|
|
||||||
def save(self, item: EntityMemoryItem) -> None: # type: ignore # BUG?: Signature of "save" incompatible with supertype "Memory"
|
def save(
|
||||||
"""Saves an entity item into the SQLite storage."""
|
self,
|
||||||
|
value: EntityMemoryItem | list[EntityMemoryItem],
|
||||||
|
metadata: dict[str, Any] | None = None,
|
||||||
|
) -> None:
|
||||||
|
"""Saves one or more entity items into the SQLite storage.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
value: Single EntityMemoryItem or list of EntityMemoryItems to save.
|
||||||
|
metadata: Optional metadata dict (included for supertype compatibility but not used).
|
||||||
|
|
||||||
|
Notes:
|
||||||
|
The metadata parameter is included to satisfy the supertype signature but is not
|
||||||
|
used - entity metadata is extracted from the EntityMemoryItem objects themselves.
|
||||||
|
"""
|
||||||
|
|
||||||
|
if not value:
|
||||||
|
return
|
||||||
|
|
||||||
|
items = value if isinstance(value, list) else [value]
|
||||||
|
is_batch = len(items) > 1
|
||||||
|
|
||||||
|
metadata = {"entity_count": len(items)} if is_batch else items[0].metadata
|
||||||
crewai_event_bus.emit(
|
crewai_event_bus.emit(
|
||||||
self,
|
self,
|
||||||
event=MemorySaveStartedEvent(
|
event=MemorySaveStartedEvent(
|
||||||
metadata=item.metadata,
|
metadata=metadata,
|
||||||
source_type="entity_memory",
|
source_type="entity_memory",
|
||||||
from_agent=self.agent,
|
from_agent=self.agent,
|
||||||
from_task=self.task,
|
from_task=self.task,
|
||||||
@@ -66,36 +87,61 @@ class EntityMemory(Memory):
|
|||||||
)
|
)
|
||||||
|
|
||||||
start_time = time.time()
|
start_time = time.time()
|
||||||
|
saved_count = 0
|
||||||
|
errors = []
|
||||||
|
|
||||||
try:
|
try:
|
||||||
if self._memory_provider == "mem0":
|
for item in items:
|
||||||
data = f"""
|
try:
|
||||||
Remember details about the following entity:
|
if self._memory_provider == "mem0":
|
||||||
Name: {item.name}
|
data = f"""
|
||||||
Type: {item.type}
|
Remember details about the following entity:
|
||||||
Entity Description: {item.description}
|
Name: {item.name}
|
||||||
"""
|
Type: {item.type}
|
||||||
|
Entity Description: {item.description}
|
||||||
|
"""
|
||||||
|
else:
|
||||||
|
data = f"{item.name}({item.type}): {item.description}"
|
||||||
|
|
||||||
|
super().save(data, item.metadata)
|
||||||
|
saved_count += 1
|
||||||
|
except Exception as e:
|
||||||
|
errors.append(f"{item.name}: {str(e)}")
|
||||||
|
|
||||||
|
if is_batch:
|
||||||
|
emit_value = f"Saved {saved_count} entities"
|
||||||
|
metadata = {"entity_count": saved_count, "errors": errors}
|
||||||
else:
|
else:
|
||||||
data = f"{item.name}({item.type}): {item.description}"
|
emit_value = f"{items[0].name}({items[0].type}): {items[0].description}"
|
||||||
|
metadata = items[0].metadata
|
||||||
|
|
||||||
super().save(data, item.metadata)
|
|
||||||
|
|
||||||
# Emit memory save completed event
|
|
||||||
crewai_event_bus.emit(
|
crewai_event_bus.emit(
|
||||||
self,
|
self,
|
||||||
event=MemorySaveCompletedEvent(
|
event=MemorySaveCompletedEvent(
|
||||||
value=data,
|
value=emit_value,
|
||||||
metadata=item.metadata,
|
metadata=metadata,
|
||||||
save_time_ms=(time.time() - start_time) * 1000,
|
save_time_ms=(time.time() - start_time) * 1000,
|
||||||
source_type="entity_memory",
|
source_type="entity_memory",
|
||||||
from_agent=self.agent,
|
from_agent=self.agent,
|
||||||
from_task=self.task,
|
from_task=self.task,
|
||||||
),
|
),
|
||||||
)
|
)
|
||||||
|
|
||||||
|
if errors:
|
||||||
|
raise Exception(
|
||||||
|
f"Partial save: {len(errors)} failed out of {len(items)}"
|
||||||
|
)
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
|
fail_metadata = (
|
||||||
|
{"entity_count": len(items), "saved": saved_count}
|
||||||
|
if is_batch
|
||||||
|
else items[0].metadata
|
||||||
|
)
|
||||||
crewai_event_bus.emit(
|
crewai_event_bus.emit(
|
||||||
self,
|
self,
|
||||||
event=MemorySaveFailedEvent(
|
event=MemorySaveFailedEvent(
|
||||||
metadata=item.metadata,
|
metadata=fail_metadata,
|
||||||
error=str(e),
|
error=str(e),
|
||||||
source_type="entity_memory",
|
source_type="entity_memory",
|
||||||
from_agent=self.agent,
|
from_agent=self.agent,
|
||||||
|
|||||||
@@ -624,12 +624,12 @@ def test_crew_with_delegating_agents_should_not_override_task_tools(ceo, writer)
|
|||||||
_, kwargs = mock_execute_sync.call_args
|
_, kwargs = mock_execute_sync.call_args
|
||||||
tools = kwargs["tools"]
|
tools = kwargs["tools"]
|
||||||
|
|
||||||
assert any(isinstance(tool, TestTool) for tool in tools), (
|
assert any(
|
||||||
"TestTool should be present"
|
isinstance(tool, TestTool) for tool in tools
|
||||||
)
|
), "TestTool should be present"
|
||||||
assert any("delegate" in tool.name.lower() for tool in tools), (
|
assert any(
|
||||||
"Delegation tool should be present"
|
"delegate" in tool.name.lower() for tool in tools
|
||||||
)
|
), "Delegation tool should be present"
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.vcr(filter_headers=["authorization"])
|
@pytest.mark.vcr(filter_headers=["authorization"])
|
||||||
@@ -688,12 +688,12 @@ def test_crew_with_delegating_agents_should_not_override_agent_tools(ceo, writer
|
|||||||
_, kwargs = mock_execute_sync.call_args
|
_, kwargs = mock_execute_sync.call_args
|
||||||
tools = kwargs["tools"]
|
tools = kwargs["tools"]
|
||||||
|
|
||||||
assert any(isinstance(tool, TestTool) for tool in new_ceo.tools), (
|
assert any(
|
||||||
"TestTool should be present"
|
isinstance(tool, TestTool) for tool in new_ceo.tools
|
||||||
)
|
), "TestTool should be present"
|
||||||
assert any("delegate" in tool.name.lower() for tool in tools), (
|
assert any(
|
||||||
"Delegation tool should be present"
|
"delegate" in tool.name.lower() for tool in tools
|
||||||
)
|
), "Delegation tool should be present"
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.vcr(filter_headers=["authorization"])
|
@pytest.mark.vcr(filter_headers=["authorization"])
|
||||||
@@ -817,17 +817,17 @@ def test_task_tools_override_agent_tools_with_allow_delegation(researcher, write
|
|||||||
used_tools = kwargs["tools"]
|
used_tools = kwargs["tools"]
|
||||||
|
|
||||||
# Confirm AnotherTestTool is present but TestTool is not
|
# Confirm AnotherTestTool is present but TestTool is not
|
||||||
assert any(isinstance(tool, AnotherTestTool) for tool in used_tools), (
|
assert any(
|
||||||
"AnotherTestTool should be present"
|
isinstance(tool, AnotherTestTool) for tool in used_tools
|
||||||
)
|
), "AnotherTestTool should be present"
|
||||||
assert not any(isinstance(tool, TestTool) for tool in used_tools), (
|
assert not any(
|
||||||
"TestTool should not be present among used tools"
|
isinstance(tool, TestTool) for tool in used_tools
|
||||||
)
|
), "TestTool should not be present among used tools"
|
||||||
|
|
||||||
# Confirm delegation tool(s) are present
|
# Confirm delegation tool(s) are present
|
||||||
assert any("delegate" in tool.name.lower() for tool in used_tools), (
|
assert any(
|
||||||
"Delegation tool should be present"
|
"delegate" in tool.name.lower() for tool in used_tools
|
||||||
)
|
), "Delegation tool should be present"
|
||||||
|
|
||||||
# Finally, make sure the agent's original tools remain unchanged
|
# Finally, make sure the agent's original tools remain unchanged
|
||||||
assert len(researcher_with_delegation.tools) == 1
|
assert len(researcher_with_delegation.tools) == 1
|
||||||
@@ -931,9 +931,9 @@ def test_cache_hitting_between_agents(researcher, writer, ceo):
|
|||||||
tool="multiplier", input={"first_number": 2, "second_number": 6}
|
tool="multiplier", input={"first_number": 2, "second_number": 6}
|
||||||
)
|
)
|
||||||
assert cache_calls[0] == expected_call, f"First call mismatch: {cache_calls[0]}"
|
assert cache_calls[0] == expected_call, f"First call mismatch: {cache_calls[0]}"
|
||||||
assert cache_calls[1] == expected_call, (
|
assert (
|
||||||
f"Second call mismatch: {cache_calls[1]}"
|
cache_calls[1] == expected_call
|
||||||
)
|
), f"Second call mismatch: {cache_calls[1]}"
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.vcr(filter_headers=["authorization"])
|
@pytest.mark.vcr(filter_headers=["authorization"])
|
||||||
@@ -1676,9 +1676,9 @@ def test_code_execution_flag_adds_code_tool_upon_kickoff():
|
|||||||
|
|
||||||
# Verify that exactly one tool was used and it was a CodeInterpreterTool
|
# Verify that exactly one tool was used and it was a CodeInterpreterTool
|
||||||
assert len(used_tools) == 1, "Should have exactly one tool"
|
assert len(used_tools) == 1, "Should have exactly one tool"
|
||||||
assert isinstance(used_tools[0], CodeInterpreterTool), (
|
assert isinstance(
|
||||||
"Tool should be CodeInterpreterTool"
|
used_tools[0], CodeInterpreterTool
|
||||||
)
|
), "Tool should be CodeInterpreterTool"
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.vcr(filter_headers=["authorization"])
|
@pytest.mark.vcr(filter_headers=["authorization"])
|
||||||
@@ -2537,8 +2537,8 @@ def test_memory_events_are_emitted():
|
|||||||
|
|
||||||
crew.kickoff()
|
crew.kickoff()
|
||||||
|
|
||||||
assert len(events["MemorySaveStartedEvent"]) == 6
|
assert len(events["MemorySaveStartedEvent"]) == 3
|
||||||
assert len(events["MemorySaveCompletedEvent"]) == 6
|
assert len(events["MemorySaveCompletedEvent"]) == 3
|
||||||
assert len(events["MemorySaveFailedEvent"]) == 0
|
assert len(events["MemorySaveFailedEvent"]) == 0
|
||||||
assert len(events["MemoryQueryStartedEvent"]) == 3
|
assert len(events["MemoryQueryStartedEvent"]) == 3
|
||||||
assert len(events["MemoryQueryCompletedEvent"]) == 3
|
assert len(events["MemoryQueryCompletedEvent"]) == 3
|
||||||
@@ -3817,9 +3817,9 @@ def test_fetch_inputs():
|
|||||||
expected_placeholders = {"role_detail", "topic", "field"}
|
expected_placeholders = {"role_detail", "topic", "field"}
|
||||||
actual_placeholders = crew.fetch_inputs()
|
actual_placeholders = crew.fetch_inputs()
|
||||||
|
|
||||||
assert actual_placeholders == expected_placeholders, (
|
assert (
|
||||||
f"Expected {expected_placeholders}, but got {actual_placeholders}"
|
actual_placeholders == expected_placeholders
|
||||||
)
|
), f"Expected {expected_placeholders}, but got {actual_placeholders}"
|
||||||
|
|
||||||
|
|
||||||
def test_task_tools_preserve_code_execution_tools():
|
def test_task_tools_preserve_code_execution_tools():
|
||||||
@@ -3894,20 +3894,20 @@ def test_task_tools_preserve_code_execution_tools():
|
|||||||
used_tools = kwargs["tools"]
|
used_tools = kwargs["tools"]
|
||||||
|
|
||||||
# Verify all expected tools are present
|
# Verify all expected tools are present
|
||||||
assert any(isinstance(tool, TestTool) for tool in used_tools), (
|
assert any(
|
||||||
"Task's TestTool should be present"
|
isinstance(tool, TestTool) for tool in used_tools
|
||||||
)
|
), "Task's TestTool should be present"
|
||||||
assert any(isinstance(tool, CodeInterpreterTool) for tool in used_tools), (
|
assert any(
|
||||||
"CodeInterpreterTool should be present"
|
isinstance(tool, CodeInterpreterTool) for tool in used_tools
|
||||||
)
|
), "CodeInterpreterTool should be present"
|
||||||
assert any("delegate" in tool.name.lower() for tool in used_tools), (
|
assert any(
|
||||||
"Delegation tool should be present"
|
"delegate" in tool.name.lower() for tool in used_tools
|
||||||
)
|
), "Delegation tool should be present"
|
||||||
|
|
||||||
# Verify the total number of tools (TestTool + CodeInterpreter + 2 delegation tools)
|
# Verify the total number of tools (TestTool + CodeInterpreter + 2 delegation tools)
|
||||||
assert len(used_tools) == 4, (
|
assert (
|
||||||
"Should have TestTool, CodeInterpreter, and 2 delegation tools"
|
len(used_tools) == 4
|
||||||
)
|
), "Should have TestTool, CodeInterpreter, and 2 delegation tools"
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.vcr(filter_headers=["authorization"])
|
@pytest.mark.vcr(filter_headers=["authorization"])
|
||||||
@@ -3951,9 +3951,9 @@ def test_multimodal_flag_adds_multimodal_tools():
|
|||||||
used_tools = kwargs["tools"]
|
used_tools = kwargs["tools"]
|
||||||
|
|
||||||
# Check that the multimodal tool was added
|
# Check that the multimodal tool was added
|
||||||
assert any(isinstance(tool, AddImageTool) for tool in used_tools), (
|
assert any(
|
||||||
"AddImageTool should be present when agent is multimodal"
|
isinstance(tool, AddImageTool) for tool in used_tools
|
||||||
)
|
), "AddImageTool should be present when agent is multimodal"
|
||||||
|
|
||||||
# Verify we have exactly one tool (just the AddImageTool)
|
# Verify we have exactly one tool (just the AddImageTool)
|
||||||
assert len(used_tools) == 1, "Should only have the AddImageTool"
|
assert len(used_tools) == 1, "Should only have the AddImageTool"
|
||||||
@@ -4217,9 +4217,9 @@ def test_crew_guardrail_feedback_in_context():
|
|||||||
assert len(execution_contexts) > 1, "Task should have been executed multiple times"
|
assert len(execution_contexts) > 1, "Task should have been executed multiple times"
|
||||||
|
|
||||||
# Verify that the second execution included the guardrail feedback
|
# Verify that the second execution included the guardrail feedback
|
||||||
assert "Output must contain the keyword 'IMPORTANT'" in execution_contexts[1], (
|
assert (
|
||||||
"Guardrail feedback should be included in retry context"
|
"Output must contain the keyword 'IMPORTANT'" in execution_contexts[1]
|
||||||
)
|
), "Guardrail feedback should be included in retry context"
|
||||||
|
|
||||||
# Verify final output meets guardrail requirements
|
# Verify final output meets guardrail requirements
|
||||||
assert "IMPORTANT" in result.raw, "Final output should contain required keyword"
|
assert "IMPORTANT" in result.raw, "Final output should contain required keyword"
|
||||||
@@ -4435,46 +4435,46 @@ def test_crew_copy_with_memory():
|
|||||||
try:
|
try:
|
||||||
crew_copy = crew.copy()
|
crew_copy = crew.copy()
|
||||||
|
|
||||||
assert hasattr(crew_copy, "_short_term_memory"), (
|
assert hasattr(
|
||||||
"Copied crew should have _short_term_memory"
|
crew_copy, "_short_term_memory"
|
||||||
)
|
), "Copied crew should have _short_term_memory"
|
||||||
assert crew_copy._short_term_memory is not None, (
|
assert (
|
||||||
"Copied _short_term_memory should not be None"
|
crew_copy._short_term_memory is not None
|
||||||
)
|
), "Copied _short_term_memory should not be None"
|
||||||
assert id(crew_copy._short_term_memory) != original_short_term_id, (
|
assert (
|
||||||
"Copied _short_term_memory should be a new object"
|
id(crew_copy._short_term_memory) != original_short_term_id
|
||||||
)
|
), "Copied _short_term_memory should be a new object"
|
||||||
|
|
||||||
assert hasattr(crew_copy, "_long_term_memory"), (
|
assert hasattr(
|
||||||
"Copied crew should have _long_term_memory"
|
crew_copy, "_long_term_memory"
|
||||||
)
|
), "Copied crew should have _long_term_memory"
|
||||||
assert crew_copy._long_term_memory is not None, (
|
assert (
|
||||||
"Copied _long_term_memory should not be None"
|
crew_copy._long_term_memory is not None
|
||||||
)
|
), "Copied _long_term_memory should not be None"
|
||||||
assert id(crew_copy._long_term_memory) != original_long_term_id, (
|
assert (
|
||||||
"Copied _long_term_memory should be a new object"
|
id(crew_copy._long_term_memory) != original_long_term_id
|
||||||
)
|
), "Copied _long_term_memory should be a new object"
|
||||||
|
|
||||||
assert hasattr(crew_copy, "_entity_memory"), (
|
assert hasattr(
|
||||||
"Copied crew should have _entity_memory"
|
crew_copy, "_entity_memory"
|
||||||
)
|
), "Copied crew should have _entity_memory"
|
||||||
assert crew_copy._entity_memory is not None, (
|
assert (
|
||||||
"Copied _entity_memory should not be None"
|
crew_copy._entity_memory is not None
|
||||||
)
|
), "Copied _entity_memory should not be None"
|
||||||
assert id(crew_copy._entity_memory) != original_entity_id, (
|
assert (
|
||||||
"Copied _entity_memory should be a new object"
|
id(crew_copy._entity_memory) != original_entity_id
|
||||||
)
|
), "Copied _entity_memory should be a new object"
|
||||||
|
|
||||||
if original_external_id:
|
if original_external_id:
|
||||||
assert hasattr(crew_copy, "_external_memory"), (
|
assert hasattr(
|
||||||
"Copied crew should have _external_memory"
|
crew_copy, "_external_memory"
|
||||||
)
|
), "Copied crew should have _external_memory"
|
||||||
assert crew_copy._external_memory is not None, (
|
assert (
|
||||||
"Copied _external_memory should not be None"
|
crew_copy._external_memory is not None
|
||||||
)
|
), "Copied _external_memory should not be None"
|
||||||
assert id(crew_copy._external_memory) != original_external_id, (
|
assert (
|
||||||
"Copied _external_memory should be a new object"
|
id(crew_copy._external_memory) != original_external_id
|
||||||
)
|
), "Copied _external_memory should be a new object"
|
||||||
else:
|
else:
|
||||||
assert (
|
assert (
|
||||||
not hasattr(crew_copy, "_external_memory")
|
not hasattr(crew_copy, "_external_memory")
|
||||||
|
|||||||
Reference in New Issue
Block a user