mirror of
https://github.com/crewAIInc/crewAI.git
synced 2026-01-10 00:28:31 +00:00
Adding long term, short term, entity and contextual memory
This commit is contained in:
@@ -48,36 +48,6 @@ def test_custom_llm():
|
||||
assert agent.llm.temperature == 0
|
||||
|
||||
|
||||
@pytest.mark.vcr(filter_headers=["authorization"])
|
||||
def test_agent_without_memory():
|
||||
no_memory_agent = Agent(
|
||||
role="test role",
|
||||
goal="test goal",
|
||||
backstory="test backstory",
|
||||
memory=False,
|
||||
llm=ChatOpenAI(temperature=0, model="gpt-4"),
|
||||
)
|
||||
|
||||
memory_agent = Agent(
|
||||
role="test role",
|
||||
goal="test goal",
|
||||
backstory="test backstory",
|
||||
memory=True,
|
||||
llm=ChatOpenAI(temperature=0, model="gpt-4"),
|
||||
)
|
||||
|
||||
task = Task(
|
||||
description="How much is 1 + 1?",
|
||||
agent=no_memory_agent,
|
||||
expected_output="the result of the math operation.",
|
||||
)
|
||||
result = no_memory_agent.execute_task(task)
|
||||
|
||||
assert result == "The result of the math operation 1 + 1 is 2."
|
||||
assert no_memory_agent.agent_executor.memory is None
|
||||
assert memory_agent.agent_executor.memory is not None
|
||||
|
||||
|
||||
@pytest.mark.vcr(filter_headers=["authorization"])
|
||||
def test_agent_execution():
|
||||
agent = Agent(
|
||||
@@ -403,7 +373,6 @@ def test_agent_repeated_tool_usage_check_even_with_disabled_cache(capsys):
|
||||
)
|
||||
|
||||
captured = capsys.readouterr()
|
||||
print(captured.out)
|
||||
assert (
|
||||
"I tried reusing the same input, I must stop using this action input. I'll try something else instead."
|
||||
in captured.out
|
||||
|
||||
@@ -648,9 +648,9 @@ def test_agent_usage_metrics_are_captured_for_sequential_process():
|
||||
assert result == "Howdy!"
|
||||
assert crew.usage_metrics == {
|
||||
"completion_tokens": 56,
|
||||
"prompt_tokens": 164,
|
||||
"prompt_tokens": 161,
|
||||
"successful_requests": 1,
|
||||
"total_tokens": 220,
|
||||
"total_tokens": 217,
|
||||
}
|
||||
|
||||
|
||||
@@ -677,8 +677,8 @@ def test_agent_usage_metrics_are_captured_for_hierarchical_process():
|
||||
result = crew.kickoff()
|
||||
assert result == "Howdy!"
|
||||
assert crew.usage_metrics == {
|
||||
"total_tokens": 1513,
|
||||
"prompt_tokens": 1299,
|
||||
"total_tokens": 1510,
|
||||
"prompt_tokens": 1296,
|
||||
"completion_tokens": 214,
|
||||
"successful_requests": 3,
|
||||
}
|
||||
@@ -735,6 +735,36 @@ def test_crew_inputs_interpolate_both_agents_and_tasks():
|
||||
interpolate_task_inputs.assert_called()
|
||||
|
||||
|
||||
def test_task_callback_on_crew():
|
||||
from unittest.mock import patch
|
||||
|
||||
researcher_agent = Agent(
|
||||
role="Researcher",
|
||||
goal="Make the best research and analysis on content about AI and AI agents",
|
||||
backstory="You're an expert researcher, specialized in technology, software engineering, AI and startups. You work as a freelancer and is now working on doing research and analysis for a new customer.",
|
||||
allow_delegation=False,
|
||||
)
|
||||
|
||||
list_ideas = Task(
|
||||
description="Give me a list of 5 interesting ideas to explore for na article, what makes them unique and interesting.",
|
||||
expected_output="Bullet point list of 5 important events.",
|
||||
agent=researcher_agent,
|
||||
async_execution=True,
|
||||
)
|
||||
|
||||
crew = Crew(
|
||||
agents=[researcher_agent],
|
||||
process=Process.sequential,
|
||||
tasks=[list_ideas],
|
||||
task_callback=lambda: None,
|
||||
)
|
||||
|
||||
with patch.object(Agent, "execute_task") as execute:
|
||||
execute.return_value = "ok"
|
||||
crew.kickoff()
|
||||
assert list_ideas.callback is not None
|
||||
|
||||
|
||||
@pytest.mark.vcr(filter_headers=["authorization"])
|
||||
def test_tools_with_custom_caching():
|
||||
from unittest.mock import patch
|
||||
@@ -748,7 +778,6 @@ def test_tools_with_custom_caching():
|
||||
|
||||
def cache_func(args, result):
|
||||
cache = result % 2 == 0
|
||||
print(f"cache?: {cache}")
|
||||
return cache
|
||||
|
||||
multiplcation_tool.cache_function = cache_func
|
||||
|
||||
29
tests/memory/long_term_memory_test.py
Normal file
29
tests/memory/long_term_memory_test.py
Normal file
@@ -0,0 +1,29 @@
|
||||
import pytest
|
||||
|
||||
from crewai.memory.long_term.long_term_memory import LongTermMemory
|
||||
from crewai.memory.long_term.long_term_memory_item import LongTermMemoryItem
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def long_term_memory():
|
||||
"""Fixture to create a LongTermMemory instance"""
|
||||
return LongTermMemory()
|
||||
|
||||
|
||||
def test_save_and_search(long_term_memory):
|
||||
memory = LongTermMemoryItem(
|
||||
agent="test_agent",
|
||||
task="test_task",
|
||||
expected_output="test_output",
|
||||
datetime="test_datetime",
|
||||
quality=0.5,
|
||||
metadata={"task": "test_task", "quality": 0.5},
|
||||
)
|
||||
long_term_memory.save(memory)
|
||||
find = long_term_memory.search("test_task")[0]
|
||||
assert find["score"] == 0.5
|
||||
assert find["datetime"] == "test_datetime"
|
||||
assert find["metadata"]["agent"] == "test_agent"
|
||||
assert find["metadata"]["quality"] == 0.5
|
||||
assert find["metadata"]["task"] == "test_task"
|
||||
assert find["metadata"]["expected_output"] == "test_output"
|
||||
24
tests/memory/short_term_memory_test.py
Normal file
24
tests/memory/short_term_memory_test.py
Normal file
@@ -0,0 +1,24 @@
|
||||
import pytest
|
||||
|
||||
from crewai.memory.short_term.short_term_memory import ShortTermMemory
|
||||
from crewai.memory.short_term.short_term_memory_item import ShortTermMemoryItem
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def short_term_memory():
|
||||
"""Fixture to create a ShortTermMemory instance"""
|
||||
return ShortTermMemory()
|
||||
|
||||
|
||||
def test_save_and_search(short_term_memory):
|
||||
memory = ShortTermMemoryItem(
|
||||
data="""test value test value test value test value test value test value
|
||||
test value test value test value test value test value test value
|
||||
test value test value test value test value test value test value""",
|
||||
agent="test_agent",
|
||||
metadata={"task": "test_task"},
|
||||
)
|
||||
short_term_memory.save(memory)
|
||||
find = short_term_memory.search("test value", score_threshold=0.01)[0]
|
||||
assert find["context"] == memory.data, "Data value mismatch."
|
||||
assert find["metadata"]["agent"] == "test_agent", "Agent value mismatch."
|
||||
Reference in New Issue
Block a user