Add addtiional validation and tests. Also, if memory was set to True on a Crew, you were required to have a mem0 key. Fixed this direct dependency.

This commit is contained in:
Brandon Hancock
2024-10-09 09:52:56 -04:00
parent 5d6eb6e9c1
commit 0354ad378b
11 changed files with 490 additions and 43 deletions

View File

@@ -1,18 +1,19 @@
import os
from inspect import signature
from typing import Any, List, Optional, Union
from pydantic import Field, InstanceOf, PrivateAttr, model_validator
from crewai.agents import CacheHandler
from crewai.utilities import Converter, Prompts
from crewai.tools.agent_tools import AgentTools
from crewai.agents.crew_agent_executor import CrewAgentExecutor
from crewai.agents.agent_builder.base_agent import BaseAgent
from crewai.memory.contextual.contextual_memory import ContextualMemory
from crewai.utilities.constants import TRAINED_AGENTS_DATA_FILE, TRAINING_DATA_FILE
from crewai.utilities.training_handler import CrewTrainingHandler
from crewai.utilities.token_counter_callback import TokenCalcHandler
from crewai.agents.crew_agent_executor import CrewAgentExecutor
from crewai.llm import LLM
from crewai.memory.contextual.contextual_memory import ContextualMemory
from crewai.tools.agent_tools import AgentTools
from crewai.utilities import Converter, Prompts
from crewai.utilities.constants import TRAINED_AGENTS_DATA_FILE, TRAINING_DATA_FILE
from crewai.utilities.token_counter_callback import TokenCalcHandler
from crewai.utilities.training_handler import CrewTrainingHandler
def mock_agent_ops_provider():
@@ -207,11 +208,11 @@ class Agent(BaseAgent):
if self.crew and self.crew.memory:
contextual_memory = ContextualMemory(
self.crew.memory_provider,
self.crew._short_term_memory,
self.crew._long_term_memory,
self.crew._entity_memory,
self.crew._user_memory,
self.crew.memory_provider,
)
memory = contextual_memory.build_context_for_task(task, context)
if memory.strip() != "":

View File

@@ -210,6 +210,14 @@ class Crew(BaseModel):
# TODO: Improve typing
return json.loads(v) if isinstance(v, Json) else v # type: ignore
@field_validator("memory_provider", mode="before")
@classmethod
def validate_memory_provider(cls, v: Optional[str]) -> Optional[str]:
"""Ensure memory provider is either None or 'mem0'."""
if v not in (None, "mem0"):
raise ValueError("Memory provider must be either None or 'mem0'.")
return v
@model_validator(mode="after")
def set_private_attrs(self) -> "Crew":
"""Set private attributes."""
@@ -247,12 +255,18 @@ class Crew(BaseModel):
embedder_config=self.embedder,
)
)
self._entity_memory = EntityMemory(
memory_provider=self.memory_provider,
crew=self,
embedder_config=self.embedder,
self._entity_memory = (
self.entity_memory
if self.entity_memory
else EntityMemory(
memory_provider=self.memory_provider,
crew=self,
embedder_config=self.embedder,
)
)
self._user_memory = (
UserMemory(crew=self) if self.memory_provider == "mem0" else None
)
self._user_memory = UserMemory(crew=self)
return self
@model_validator(mode="after")
@@ -905,6 +919,7 @@ class Crew(BaseModel):
"_short_term_memory",
"_long_term_memory",
"_entity_memory",
"_user_memory",
"_telemetry",
"agents",
"tasks",

View File

@@ -6,17 +6,17 @@ from crewai.memory import EntityMemory, LongTermMemory, ShortTermMemory, UserMem
class ContextualMemory:
def __init__(
self,
memory_provider: str,
stm: ShortTermMemory,
ltm: LongTermMemory,
em: EntityMemory,
um: UserMemory,
memory_provider: Optional[str] = None, # Default value added
):
self.memory_provider = memory_provider
self.stm = stm
self.ltm = ltm
self.em = em
self.um = um
self.memory_provider = memory_provider
def build_context_for_task(self, task, context) -> str:
"""

View File

@@ -1,4 +1,4 @@
from typing import Any, Dict
from typing import Any, Dict, List
from crewai.memory.long_term.long_term_memory_item import LongTermMemoryItem
from crewai.memory.memory import Memory
@@ -18,18 +18,25 @@ class LongTermMemory(Memory):
storage = storage if storage else LTMSQLiteStorage()
super().__init__(storage)
def save(self, item: LongTermMemoryItem) -> None: # type: ignore # BUG?: Signature of "save" incompatible with supertype "Memory"
metadata = item.metadata
metadata.update({"agent": item.agent, "expected_output": item.expected_output})
def save(self, item: LongTermMemoryItem) -> None:
metadata = item.metadata.copy() # Create a copy to avoid modifying the original
metadata.update(
{
"agent": item.agent,
"expected_output": item.expected_output,
"quality": item.quality, # Add quality to metadata
}
)
self.storage.save( # type: ignore # BUG?: Unexpected keyword argument "task_description","score","datetime" for "save" of "Storage"
task_description=item.task,
score=metadata["quality"],
score=item.quality,
metadata=metadata,
datetime=item.datetime,
)
def search(self, task: str, latest_n: int = 3) -> Dict[str, Any]:
return self.storage.load(task, latest_n) # type: ignore # BUG?: "Storage" has no attribute "load"
def search(self, task: str, latest_n: int = 3) -> List[Dict[str, Any]]:
results = self.storage.load(task, latest_n)
return results
def reset(self) -> None:
self.storage.reset()

View File

@@ -1,8 +1,8 @@
import os
from typing import Any, Dict, List, Optional
from mem0 import MemoryClient
from crewai.memory.storage.interface import Storage
from mem0 import MemoryClient
class Mem0Storage(Storage):
@@ -18,6 +18,9 @@ class Mem0Storage(Storage):
):
os.environ["OPENAI_API_KEY"] = "fake"
if not os.getenv("MEM0_API_KEY"):
raise EnvironmentError("MEM0_API_KEY is not set.")
agents = crew.agents if crew else []
agents = [agent.role for agent in agents]
agents = "_".join(agents)
@@ -39,4 +42,4 @@ class Mem0Storage(Storage):
if filters:
params["filters"] = filters
results = self.memory.search(**params)
return [r for r in results if r["score"] >= score_threshold]
return [r for r in results if float(r["score"]) >= score_threshold]

View File

@@ -5,14 +5,13 @@ import os
import shutil
from typing import Any, Dict, List, Optional
from crewai.memory.storage.interface import Storage
from crewai.utilities.paths import db_storage_path
from embedchain import App
from embedchain.llm.base import BaseLlm
from embedchain.models.data_type import DataType
from embedchain.vectordb.chroma import InvalidDimensionException
from crewai.memory.storage.interface import Storage
from crewai.utilities.paths import db_storage_path
@contextlib.contextmanager
def suppress_logging(

View File

@@ -32,6 +32,12 @@ class UserMemory(Memory):
filters: dict = {},
score_threshold: float = 0.35,
):
return super().search(
query=query, limit=limit, filters=filters, score_threshold=score_threshold
print("SEARCHING USER MEMORY", query, limit, filters, score_threshold)
result = super().search(
query=query,
limit=limit,
filters=filters,
score_threshold=score_threshold,
)
print("USER MEMORY SEARCH RESULT:", result)
return result

View File

@@ -23,6 +23,7 @@ from crewai.types.usage_metrics import UsageMetrics
from crewai.utilities import Logger
from crewai.utilities.rpm_controller import RPMController
from crewai.utilities.task_output_storage_handler import TaskOutputStorageHandler
from pydantic_core import ValidationError
ceo = Agent(
role="CEO",
@@ -173,6 +174,57 @@ def test_context_no_future_tasks():
Crew(tasks=[task1, task2, task3, task4], agents=[researcher, writer])
def test_memory_provider_validation():
# Create mock agents
agent1 = Agent(
role="Researcher",
goal="Conduct research on AI",
backstory="An experienced AI researcher",
allow_delegation=False,
)
agent2 = Agent(
role="Writer",
goal="Write articles on AI",
backstory="A seasoned writer with a focus on technology",
allow_delegation=False,
)
# Create mock tasks
task1 = Task(
description="Research the latest trends in AI",
expected_output="A report on AI trends",
agent=agent1,
)
task2 = Task(
description="Write an article based on the research",
expected_output="An article on AI trends",
agent=agent2,
)
# Test with valid memory provider values
try:
crew_with_none = Crew(
agents=[agent1, agent2], tasks=[task1, task2], memory_provider=None
)
crew_with_mem0 = Crew(
agents=[agent1, agent2], tasks=[task1, task2], memory_provider="mem0"
)
except ValidationError:
pytest.fail(
"Unexpected ValidationError raised for valid memory provider values"
)
# Test with an invalid memory provider value
with pytest.raises(ValidationError) as excinfo:
Crew(
agents=[agent1, agent2],
tasks=[task1, task2],
memory_provider="invalid_provider",
)
assert "Memory provider must be either None or 'mem0'." in str(excinfo.value)
def test_crew_config_with_wrong_keys():
no_tasks_config = json.dumps(
{
@@ -497,6 +549,7 @@ def test_cache_hitting_between_agents():
@pytest.mark.vcr(filter_headers=["authorization"])
def test_api_calls_throttling(capsys):
from unittest.mock import patch
from crewai_tools import tool
@tool
@@ -1105,6 +1158,7 @@ def test_dont_set_agents_step_callback_if_already_set():
@pytest.mark.vcr(filter_headers=["authorization"])
def test_crew_function_calling_llm():
from unittest.mock import patch
from crewai_tools import tool
llm = "gpt-4o"

View File

@@ -0,0 +1,147 @@
from unittest.mock import MagicMock, patch
import pytest
from crewai.memory import EntityMemory, LongTermMemory, ShortTermMemory, UserMemory
from crewai.memory.contextual.contextual_memory import ContextualMemory
@pytest.fixture
def mock_memories():
return {
"stm": MagicMock(spec=ShortTermMemory),
"ltm": MagicMock(spec=LongTermMemory),
"em": MagicMock(spec=EntityMemory),
"um": MagicMock(spec=UserMemory),
}
@pytest.fixture
def contextual_memory_mem0(mock_memories):
return ContextualMemory(
memory_provider="mem0",
stm=mock_memories["stm"],
ltm=mock_memories["ltm"],
em=mock_memories["em"],
um=mock_memories["um"],
)
@pytest.fixture
def contextual_memory_other(mock_memories):
return ContextualMemory(
memory_provider="other",
stm=mock_memories["stm"],
ltm=mock_memories["ltm"],
em=mock_memories["em"],
um=mock_memories["um"],
)
@pytest.fixture
def contextual_memory_none(mock_memories):
return ContextualMemory(
memory_provider=None,
stm=mock_memories["stm"],
ltm=mock_memories["ltm"],
em=mock_memories["em"],
um=mock_memories["um"],
)
def test_build_context_for_task_mem0(contextual_memory_mem0, mock_memories):
task = MagicMock(description="Test task")
context = "Additional context"
mock_memories["stm"].search.return_value = ["Recent insight"]
mock_memories["ltm"].search.return_value = [
{"metadata": {"suggestions": ["Historical data"]}}
]
mock_memories["em"].search.return_value = [{"memory": "Entity memory"}]
mock_memories["um"].search.return_value = [{"memory": "User memory"}]
result = contextual_memory_mem0.build_context_for_task(task, context)
assert "Recent Insights:" in result
assert "Historical Data:" in result
assert "Entities:" in result
assert "User memories/preferences:" in result
def test_build_context_for_task_other_provider(contextual_memory_other, mock_memories):
task = MagicMock(description="Test task")
context = "Additional context"
mock_memories["stm"].search.return_value = ["Recent insight"]
mock_memories["ltm"].search.return_value = [
{"metadata": {"suggestions": ["Historical data"]}}
]
mock_memories["em"].search.return_value = [{"context": "Entity context"}]
mock_memories["um"].search.return_value = [{"memory": "User memory"}]
result = contextual_memory_other.build_context_for_task(task, context)
assert "Recent Insights:" in result
assert "Historical Data:" in result
assert "Entities:" in result
assert "User memories/preferences:" not in result
def test_build_context_for_task_none_provider(contextual_memory_none, mock_memories):
task = MagicMock(description="Test task")
context = "Additional context"
mock_memories["stm"].search.return_value = ["Recent insight"]
mock_memories["ltm"].search.return_value = [
{"metadata": {"suggestions": ["Historical data"]}}
]
mock_memories["em"].search.return_value = [{"context": "Entity context"}]
mock_memories["um"].search.return_value = [{"memory": "User memory"}]
result = contextual_memory_none.build_context_for_task(task, context)
assert "Recent Insights:" in result
assert "Historical Data:" in result
assert "Entities:" in result
assert "User memories/preferences:" not in result
def test_fetch_entity_context_mem0(contextual_memory_mem0, mock_memories):
mock_memories["em"].search.return_value = [
{"memory": "Entity 1"},
{"memory": "Entity 2"},
]
result = contextual_memory_mem0._fetch_entity_context("query")
expected_result = "Entities:\n- Entity 1\n- Entity 2"
assert result == expected_result
def test_fetch_entity_context_other_provider(contextual_memory_other, mock_memories):
mock_memories["em"].search.return_value = [
{"context": "Entity 1"},
{"context": "Entity 2"},
]
result = contextual_memory_other._fetch_entity_context("query")
expected_result = "Entities:\n- Entity 1\n- Entity 2"
assert result == expected_result
def test_user_memories_only_for_mem0(contextual_memory_mem0, mock_memories):
mock_memories["um"].search.return_value = [{"memory": "User memory"}]
# Test for mem0 provider
result_mem0 = contextual_memory_mem0._fetch_user_memories("query")
assert "User memories/preferences:" in result_mem0
assert "User memory" in result_mem0
# Additional test to ensure user memories are included/excluded in the full context
task = MagicMock(description="Test task")
context = "Additional context"
mock_memories["stm"].search.return_value = ["Recent insight"]
mock_memories["ltm"].search.return_value = [
{"metadata": {"suggestions": ["Historical data"]}}
]
mock_memories["em"].search.return_value = [{"memory": "Entity memory"}]
full_context_mem0 = contextual_memory_mem0.build_context_for_task(task, context)
assert "User memories/preferences:" in full_context_mem0
assert "User memory" in full_context_mem0

View File

@@ -0,0 +1,119 @@
# tests/memory/test_entity_memory.py
from unittest.mock import MagicMock, patch
import pytest
from crewai.memory.entity.entity_memory import EntityMemory
from crewai.memory.entity.entity_memory_item import EntityMemoryItem
from crewai.memory.storage.mem0_storage import Mem0Storage
from crewai.memory.storage.rag_storage import RAGStorage
@pytest.fixture
def mock_rag_storage():
"""Fixture to create a mock RAGStorage instance"""
return MagicMock(spec=RAGStorage)
@pytest.fixture
def mock_mem0_storage():
"""Fixture to create a mock Mem0Storage instance"""
return MagicMock(spec=Mem0Storage)
@pytest.fixture
def entity_memory_rag(mock_rag_storage):
"""Fixture to create an EntityMemory instance with RAGStorage"""
with patch(
"crewai.memory.entity.entity_memory.RAGStorage", return_value=mock_rag_storage
):
return EntityMemory()
@pytest.fixture
def entity_memory_mem0(mock_mem0_storage):
"""Fixture to create an EntityMemory instance with Mem0Storage"""
with patch(
"crewai.memory.entity.entity_memory.Mem0Storage", return_value=mock_mem0_storage
):
return EntityMemory(memory_provider="mem0")
def test_save_rag_storage(entity_memory_rag, mock_rag_storage):
item = EntityMemoryItem(
name="John Doe",
type="Person",
description="A software engineer",
relationships="Works at TechCorp",
)
entity_memory_rag.save(item)
expected_data = "John Doe(Person): A software engineer"
mock_rag_storage.save.assert_called_once_with(expected_data, item.metadata)
def test_save_mem0_storage(entity_memory_mem0, mock_mem0_storage):
item = EntityMemoryItem(
name="John Doe",
type="Person",
description="A software engineer",
relationships="Works at TechCorp",
)
entity_memory_mem0.save(item)
expected_data = """
Remember details about the following entity:
Name: John Doe
Type: Person
Entity Description: A software engineer
"""
mock_mem0_storage.save.assert_called_once_with(expected_data, item.metadata)
def test_search(entity_memory_rag, mock_rag_storage):
query = "software engineer"
limit = 5
filters = {"type": "Person"}
score_threshold = 0.7
entity_memory_rag.search(query, limit, filters, score_threshold)
mock_rag_storage.search.assert_called_once_with(
query=query, limit=limit, filters=filters, score_threshold=score_threshold
)
def test_reset(entity_memory_rag, mock_rag_storage):
entity_memory_rag.reset()
mock_rag_storage.reset.assert_called_once()
def test_reset_error(entity_memory_rag, mock_rag_storage):
mock_rag_storage.reset.side_effect = Exception("Reset error")
with pytest.raises(Exception) as exc_info:
entity_memory_rag.reset()
assert (
str(exc_info.value)
== "An error occurred while resetting the entity memory: Reset error"
)
@pytest.mark.parametrize("memory_provider", [None, "other"])
def test_init_with_rag_storage(memory_provider):
with patch("crewai.memory.entity.entity_memory.RAGStorage") as mock_rag_storage:
EntityMemory(memory_provider=memory_provider)
mock_rag_storage.assert_called_once()
def test_init_with_mem0_storage():
with patch("crewai.memory.entity.entity_memory.Mem0Storage") as mock_mem0_storage:
EntityMemory(memory_provider="mem0")
mock_mem0_storage.assert_called_once()
def test_init_with_custom_storage():
custom_storage = MagicMock()
entity_memory = EntityMemory(storage=custom_storage)
assert entity_memory.storage == custom_storage

View File

@@ -1,29 +1,125 @@
import pytest
# tests/memory/long_term_memory_test.py
from datetime import datetime
from unittest.mock import MagicMock, patch
import pytest
from crewai.memory.long_term.long_term_memory import LongTermMemory
from crewai.memory.long_term.long_term_memory_item import LongTermMemoryItem
from crewai.memory.storage.ltm_sqlite_storage import LTMSQLiteStorage
@pytest.fixture
def long_term_memory():
"""Fixture to create a LongTermMemory instance"""
return LongTermMemory()
def mock_storage():
"""Fixture to create a mock LTMSQLiteStorage instance"""
return MagicMock(spec=LTMSQLiteStorage)
def test_save_and_search(long_term_memory):
@pytest.fixture
def long_term_memory(mock_storage):
"""Fixture to create a LongTermMemory instance with mock storage"""
return LongTermMemory(storage=mock_storage)
def test_save(long_term_memory, mock_storage):
memory = LongTermMemoryItem(
agent="test_agent",
task="test_task",
expected_output="test_output",
datetime="test_datetime",
datetime="2023-01-01 12:00:00",
quality=0.5,
metadata={"task": "test_task", "quality": 0.5},
metadata={"additional_info": "test_info"},
)
long_term_memory.save(memory)
find = long_term_memory.search("test_task", latest_n=5)[0]
assert find["score"] == 0.5
assert find["datetime"] == "test_datetime"
assert find["metadata"]["agent"] == "test_agent"
assert find["metadata"]["quality"] == 0.5
assert find["metadata"]["task"] == "test_task"
assert find["metadata"]["expected_output"] == "test_output"
expected_metadata = {
"additional_info": "test_info",
"agent": "test_agent",
"expected_output": "test_output",
"quality": 0.5, # Include quality in expected metadata
}
mock_storage.save.assert_called_once_with(
task_description="test_task",
score=0.5,
metadata=expected_metadata,
datetime="2023-01-01 12:00:00",
)
def test_search(long_term_memory, mock_storage):
mock_storage.load.return_value = [
{
"metadata": {
"agent": "test_agent",
"expected_output": "test_output",
"task": "test_task",
},
"datetime": "2023-01-01 12:00:00",
"score": 0.5,
}
]
result = long_term_memory.search("test_task", latest_n=5)
mock_storage.load.assert_called_once_with("test_task", 5)
assert len(result) == 1
assert result[0]["metadata"]["agent"] == "test_agent"
assert result[0]["metadata"]["expected_output"] == "test_output"
assert result[0]["metadata"]["task"] == "test_task"
assert result[0]["datetime"] == "2023-01-01 12:00:00"
assert result[0]["score"] == 0.5
def test_save_with_minimal_metadata(long_term_memory, mock_storage):
memory = LongTermMemoryItem(
agent="minimal_agent",
task="minimal_task",
expected_output="minimal_output",
datetime="2023-01-01 12:00:00",
quality=0.3,
metadata={},
)
long_term_memory.save(memory)
expected_metadata = {
"agent": "minimal_agent",
"expected_output": "minimal_output",
"quality": 0.3, # Include quality in expected metadata
}
mock_storage.save.assert_called_once_with(
task_description="minimal_task",
score=0.3,
metadata=expected_metadata,
datetime="2023-01-01 12:00:00",
)
def test_reset(long_term_memory, mock_storage):
long_term_memory.reset()
mock_storage.reset.assert_called_once()
def test_search_with_no_results(long_term_memory, mock_storage):
mock_storage.load.return_value = []
result = long_term_memory.search("nonexistent_task")
assert result == []
def test_init_with_default_storage():
with patch(
"crewai.memory.long_term.long_term_memory.LTMSQLiteStorage"
) as mock_storage_class:
LongTermMemory()
mock_storage_class.assert_called_once()
def test_init_with_custom_storage():
custom_storage = MagicMock()
memory = LongTermMemory(storage=custom_storage)
assert memory.storage == custom_storage
@pytest.mark.parametrize("latest_n", [1, 3, 5, 10])
def test_search_with_different_latest_n(long_term_memory, mock_storage, latest_n):
long_term_memory.search("test_task", latest_n=latest_n)
mock_storage.load.assert_called_once_with("test_task", latest_n)