mirror of
https://github.com/crewAIInc/crewAI.git
synced 2026-04-09 04:28:16 +00:00
Compare commits
1 Commits
fix/issue-
...
devin/1744
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
4748597667 |
@@ -771,65 +771,6 @@ class Crew(BaseModel):
|
||||
|
||||
return self._create_crew_output(task_outputs)
|
||||
|
||||
def _get_context_based_output(
|
||||
self,
|
||||
task: ConditionalTask,
|
||||
task_outputs: List[TaskOutput],
|
||||
task_index: int,
|
||||
) -> Optional[TaskOutput]:
|
||||
"""Get the output from explicit context tasks."""
|
||||
context_task_outputs = []
|
||||
for context_task in task.context:
|
||||
context_task_index = self._find_task_index(context_task)
|
||||
if context_task_index != -1 and context_task_index < task_index:
|
||||
for output in task_outputs:
|
||||
if output.description == context_task.description:
|
||||
context_task_outputs.append(output)
|
||||
break
|
||||
return context_task_outputs[-1] if context_task_outputs else None
|
||||
|
||||
def _get_non_conditional_output(
|
||||
self,
|
||||
task_outputs: List[TaskOutput],
|
||||
task_index: int,
|
||||
) -> Optional[TaskOutput]:
|
||||
"""Get the output from the most recent non-conditional task."""
|
||||
non_conditional_outputs = []
|
||||
for i in range(task_index):
|
||||
if i < len(self.tasks) and not isinstance(self.tasks[i], ConditionalTask):
|
||||
for output in task_outputs:
|
||||
if output.description == self.tasks[i].description:
|
||||
non_conditional_outputs.append(output)
|
||||
break
|
||||
return non_conditional_outputs[-1] if non_conditional_outputs else None
|
||||
|
||||
def _get_previous_output(
|
||||
self,
|
||||
task: ConditionalTask,
|
||||
task_outputs: List[TaskOutput],
|
||||
task_index: int,
|
||||
) -> Optional[TaskOutput]:
|
||||
"""Get the previous output for a conditional task.
|
||||
|
||||
The order of precedence is:
|
||||
1. Output from explicit context tasks
|
||||
2. Output from the most recent non-conditional task
|
||||
3. Output from the immediately preceding task
|
||||
"""
|
||||
if task.context and len(task.context) > 0:
|
||||
previous_output = self._get_context_based_output(task, task_outputs, task_index)
|
||||
if previous_output:
|
||||
return previous_output
|
||||
|
||||
previous_output = self._get_non_conditional_output(task_outputs, task_index)
|
||||
if previous_output:
|
||||
return previous_output
|
||||
|
||||
if task_outputs and task_index > 0 and task_index <= len(task_outputs):
|
||||
return task_outputs[task_index - 1]
|
||||
|
||||
return None
|
||||
|
||||
def _handle_conditional_task(
|
||||
self,
|
||||
task: ConditionalTask,
|
||||
@@ -838,17 +779,11 @@ class Crew(BaseModel):
|
||||
task_index: int,
|
||||
was_replayed: bool,
|
||||
) -> Optional[TaskOutput]:
|
||||
"""Handle a conditional task.
|
||||
|
||||
Determines whether a conditional task should be executed based on the output
|
||||
of previous tasks. If the task should not be executed, returns a skipped task output.
|
||||
"""
|
||||
if futures:
|
||||
task_outputs = self._process_async_tasks(futures, was_replayed)
|
||||
futures.clear()
|
||||
|
||||
previous_output = self._get_previous_output(task, task_outputs, task_index)
|
||||
|
||||
previous_output = task_outputs[task_index - 1] if task_outputs else None
|
||||
if previous_output is not None and not task.should_execute(previous_output):
|
||||
self._logger.log(
|
||||
"debug",
|
||||
|
||||
@@ -1,3 +1,5 @@
|
||||
from typing import Optional
|
||||
|
||||
from crewai.memory.entity.entity_memory_item import EntityMemoryItem
|
||||
from crewai.memory.memory import Memory
|
||||
from crewai.memory.storage.rag_storage import RAGStorage
|
||||
@@ -38,7 +40,7 @@ class EntityMemory(Memory):
|
||||
)
|
||||
super().__init__(storage)
|
||||
|
||||
def save(self, item: EntityMemoryItem) -> None: # type: ignore # BUG?: Signature of "save" incompatible with supertype "Memory"
|
||||
def save(self, item: EntityMemoryItem, custom_key: Optional[str] = None) -> None: # type: ignore # BUG?: Signature of "save" incompatible with supertype "Memory"
|
||||
"""Saves an entity item into the SQLite storage."""
|
||||
if self.memory_provider == "mem0":
|
||||
data = f"""
|
||||
@@ -49,7 +51,7 @@ class EntityMemory(Memory):
|
||||
"""
|
||||
else:
|
||||
data = f"{item.name}({item.type}): {item.description}"
|
||||
super().save(data, item.metadata)
|
||||
super().save(data, item.metadata, custom_key=custom_key)
|
||||
|
||||
def reset(self) -> None:
|
||||
try:
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
from typing import Any, Dict, List
|
||||
from typing import Any, Dict, List, Optional
|
||||
|
||||
from crewai.memory.long_term.long_term_memory_item import LongTermMemoryItem
|
||||
from crewai.memory.memory import Memory
|
||||
@@ -19,9 +19,12 @@ class LongTermMemory(Memory):
|
||||
storage = LTMSQLiteStorage(db_path=path) if path else LTMSQLiteStorage()
|
||||
super().__init__(storage)
|
||||
|
||||
def save(self, item: LongTermMemoryItem) -> None: # type: ignore # BUG?: Signature of "save" incompatible with supertype "Memory"
|
||||
def save(self, item: LongTermMemoryItem, custom_key: Optional[str] = None) -> None: # type: ignore # BUG?: Signature of "save" incompatible with supertype "Memory"
|
||||
metadata = item.metadata
|
||||
metadata.update({"agent": item.agent, "expected_output": item.expected_output})
|
||||
if custom_key:
|
||||
metadata.update({"custom_key": custom_key})
|
||||
|
||||
self.storage.save( # type: ignore # BUG?: Unexpected keyword argument "task_description","score","datetime" for "save" of "Storage"
|
||||
task_description=item.task,
|
||||
score=metadata["quality"],
|
||||
@@ -29,8 +32,8 @@ class LongTermMemory(Memory):
|
||||
datetime=item.datetime,
|
||||
)
|
||||
|
||||
def search(self, task: str, latest_n: int = 3) -> List[Dict[str, Any]]: # type: ignore # signature of "search" incompatible with supertype "Memory"
|
||||
return self.storage.load(task, latest_n) # type: ignore # BUG?: "Storage" has no attribute "load"
|
||||
def search(self, task: str, latest_n: int = 3, custom_key: Optional[str] = None) -> List[Dict[str, Any]]: # type: ignore # signature of "search" incompatible with supertype "Memory"
|
||||
return self.storage.load(task, latest_n, custom_key) # type: ignore # BUG?: "Storage" has no attribute "load"
|
||||
|
||||
def reset(self) -> None:
|
||||
self.storage.reset()
|
||||
|
||||
@@ -5,7 +5,10 @@ from crewai.memory.storage.rag_storage import RAGStorage
|
||||
|
||||
class Memory:
|
||||
"""
|
||||
Base class for memory, now supporting agent tags and generic metadata.
|
||||
Base class for memory, now supporting agent tags, generic metadata, and custom keys.
|
||||
|
||||
Custom keys allow scoping memories to specific entities (users, accounts, sessions),
|
||||
retrieving memories contextually, and preventing data leakage across logical boundaries.
|
||||
"""
|
||||
|
||||
def __init__(self, storage: RAGStorage):
|
||||
@@ -16,10 +19,13 @@ class Memory:
|
||||
value: Any,
|
||||
metadata: Optional[Dict[str, Any]] = None,
|
||||
agent: Optional[str] = None,
|
||||
custom_key: Optional[str] = None,
|
||||
) -> None:
|
||||
metadata = metadata or {}
|
||||
if agent:
|
||||
metadata["agent"] = agent
|
||||
if custom_key:
|
||||
metadata["custom_key"] = custom_key
|
||||
|
||||
self.storage.save(value, metadata)
|
||||
|
||||
@@ -28,7 +34,12 @@ class Memory:
|
||||
query: str,
|
||||
limit: int = 3,
|
||||
score_threshold: float = 0.35,
|
||||
custom_key: Optional[str] = None,
|
||||
) -> List[Any]:
|
||||
filter_dict = None
|
||||
if custom_key:
|
||||
filter_dict = {"custom_key": {"$eq": custom_key}}
|
||||
|
||||
return self.storage.search(
|
||||
query=query, limit=limit, score_threshold=score_threshold
|
||||
query=query, limit=limit, score_threshold=score_threshold, filter=filter_dict
|
||||
)
|
||||
|
||||
@@ -46,22 +46,31 @@ class ShortTermMemory(Memory):
|
||||
value: Any,
|
||||
metadata: Optional[Dict[str, Any]] = None,
|
||||
agent: Optional[str] = None,
|
||||
custom_key: Optional[str] = None,
|
||||
) -> None:
|
||||
item = ShortTermMemoryItem(data=value, metadata=metadata, agent=agent)
|
||||
if self.memory_provider == "mem0":
|
||||
item.data = f"Remember the following insights from Agent run: {item.data}"
|
||||
|
||||
super().save(value=item.data, metadata=item.metadata, agent=item.agent)
|
||||
super().save(value=item.data, metadata=item.metadata, agent=item.agent, custom_key=custom_key)
|
||||
|
||||
def search(
|
||||
self,
|
||||
query: str,
|
||||
limit: int = 3,
|
||||
score_threshold: float = 0.35,
|
||||
custom_key: Optional[str] = None,
|
||||
):
|
||||
filter_dict = None
|
||||
if custom_key:
|
||||
filter_dict = {"custom_key": {"$eq": custom_key}}
|
||||
|
||||
return self.storage.search(
|
||||
query=query, limit=limit, score_threshold=score_threshold
|
||||
) # type: ignore # BUG? The reference is to the parent class, but the parent class does not have this parameters
|
||||
query=query,
|
||||
limit=limit,
|
||||
score_threshold=score_threshold,
|
||||
filter=filter_dict
|
||||
)
|
||||
|
||||
def reset(self) -> None:
|
||||
try:
|
||||
|
||||
@@ -70,22 +70,31 @@ class LTMSQLiteStorage:
|
||||
)
|
||||
|
||||
def load(
|
||||
self, task_description: str, latest_n: int
|
||||
self, task_description: str, latest_n: int, custom_key: Optional[str] = None
|
||||
) -> Optional[List[Dict[str, Any]]]:
|
||||
"""Queries the LTM table by task description with error handling."""
|
||||
try:
|
||||
with sqlite3.connect(self.db_path) as conn:
|
||||
cursor = conn.cursor()
|
||||
cursor.execute(
|
||||
f"""
|
||||
|
||||
query = """
|
||||
SELECT metadata, datetime, score
|
||||
FROM long_term_memories
|
||||
WHERE task_description = ?
|
||||
"""
|
||||
|
||||
params = [task_description]
|
||||
|
||||
if custom_key:
|
||||
query += " AND json_extract(metadata, '$.custom_key') = ?"
|
||||
params.append(custom_key)
|
||||
|
||||
query += f"""
|
||||
ORDER BY datetime DESC, score ASC
|
||||
LIMIT {latest_n}
|
||||
""", # nosec
|
||||
(task_description,),
|
||||
)
|
||||
"""
|
||||
|
||||
cursor.execute(query, params)
|
||||
rows = cursor.fetchall()
|
||||
if rows:
|
||||
return [
|
||||
|
||||
@@ -120,7 +120,11 @@ class RAGStorage(BaseRAGStorage):
|
||||
|
||||
try:
|
||||
with suppress_logging():
|
||||
response = self.collection.query(query_texts=query, n_results=limit)
|
||||
response = self.collection.query(
|
||||
query_texts=query,
|
||||
n_results=limit,
|
||||
where=filter
|
||||
)
|
||||
|
||||
results = []
|
||||
for i in range(len(response["ids"][0])):
|
||||
|
||||
@@ -26,20 +26,27 @@ class UserMemory(Memory):
|
||||
value,
|
||||
metadata: Optional[Dict[str, Any]] = None,
|
||||
agent: Optional[str] = None,
|
||||
custom_key: Optional[str] = None,
|
||||
) -> None:
|
||||
# TODO: Change this function since we want to take care of the case where we save memories for the usr
|
||||
data = f"Remember the details about the user: {value}"
|
||||
super().save(data, metadata)
|
||||
super().save(data, metadata, custom_key=custom_key)
|
||||
|
||||
def search(
|
||||
self,
|
||||
query: str,
|
||||
limit: int = 3,
|
||||
score_threshold: float = 0.35,
|
||||
custom_key: Optional[str] = None,
|
||||
):
|
||||
filter_dict = None
|
||||
if custom_key:
|
||||
filter_dict = {"custom_key": {"$eq": custom_key}}
|
||||
|
||||
results = self.storage.search(
|
||||
query=query,
|
||||
limit=limit,
|
||||
score_threshold=score_threshold,
|
||||
filter=filter_dict,
|
||||
)
|
||||
return results
|
||||
|
||||
57
tests/memory/custom_key_memory_test.py
Normal file
57
tests/memory/custom_key_memory_test.py
Normal file
@@ -0,0 +1,57 @@
|
||||
import pytest
|
||||
from unittest.mock import patch, MagicMock
|
||||
|
||||
from crewai.memory.short_term.short_term_memory import ShortTermMemory
|
||||
from crewai.memory.short_term.short_term_memory_item import ShortTermMemoryItem
|
||||
from crewai.agent import Agent
|
||||
from crewai.crew import Crew
|
||||
from crewai.task import Task
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def short_term_memory():
|
||||
"""Fixture to create a ShortTermMemory instance"""
|
||||
agent = Agent(
|
||||
role="Researcher",
|
||||
goal="Search relevant data and provide results",
|
||||
backstory="You are a researcher at a leading tech think tank.",
|
||||
tools=[],
|
||||
verbose=True,
|
||||
)
|
||||
|
||||
task = Task(
|
||||
description="Perform a search on specific topics.",
|
||||
expected_output="A list of relevant URLs based on the search query.",
|
||||
agent=agent,
|
||||
)
|
||||
return ShortTermMemory(crew=Crew(agents=[agent], tasks=[task]))
|
||||
|
||||
|
||||
def test_save_with_custom_key(short_term_memory):
|
||||
"""Test that save method correctly passes custom_key to storage"""
|
||||
with patch.object(short_term_memory.storage, 'save') as mock_save:
|
||||
short_term_memory.save(
|
||||
value="Test data",
|
||||
metadata={"task": "test_task"},
|
||||
agent="test_agent",
|
||||
custom_key="user123",
|
||||
)
|
||||
|
||||
called_args = mock_save.call_args[0]
|
||||
called_kwargs = mock_save.call_args[1]
|
||||
|
||||
assert "custom_key" in called_args[1]
|
||||
assert called_args[1]["custom_key"] == "user123"
|
||||
|
||||
|
||||
def test_search_with_custom_key(short_term_memory):
|
||||
"""Test that search method correctly passes custom_key to storage"""
|
||||
expected_results = [{"context": "Test data", "metadata": {"custom_key": "user123"}, "score": 0.95}]
|
||||
|
||||
with patch.object(short_term_memory.storage, 'search', return_value=expected_results) as mock_search:
|
||||
results = short_term_memory.search("test query", custom_key="user123")
|
||||
|
||||
mock_search.assert_called_once()
|
||||
filter_arg = mock_search.call_args[1].get('filter')
|
||||
assert filter_arg == {"custom_key": {"$eq": "user123"}}
|
||||
assert results == expected_results
|
||||
@@ -1,335 +0,0 @@
|
||||
"""Test for multiple conditional tasks."""
|
||||
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
import pytest
|
||||
|
||||
from crewai.agent import Agent
|
||||
from crewai.crew import Crew
|
||||
from crewai.task import Task
|
||||
from crewai.tasks.conditional_task import ConditionalTask
|
||||
from crewai.tasks.output_format import OutputFormat
|
||||
from crewai.tasks.task_output import TaskOutput
|
||||
|
||||
|
||||
class TestMultipleConditionalTasks:
|
||||
"""Test class for multiple conditional tasks scenarios."""
|
||||
|
||||
@pytest.fixture
|
||||
def setup_agents(self):
|
||||
"""Set up agents for the tests."""
|
||||
agent1 = Agent(
|
||||
role="Research Analyst",
|
||||
goal="Find information",
|
||||
backstory="You're a researcher",
|
||||
verbose=True,
|
||||
)
|
||||
|
||||
agent2 = Agent(
|
||||
role="Data Analyst",
|
||||
goal="Process information",
|
||||
backstory="You process data",
|
||||
verbose=True,
|
||||
)
|
||||
|
||||
agent3 = Agent(
|
||||
role="Report Writer",
|
||||
goal="Write reports",
|
||||
backstory="You write reports",
|
||||
verbose=True,
|
||||
)
|
||||
|
||||
return agent1, agent2, agent3
|
||||
|
||||
@pytest.fixture
|
||||
def setup_tasks(self, setup_agents):
|
||||
"""Set up tasks for the tests."""
|
||||
agent1, agent2, agent3 = setup_agents
|
||||
|
||||
# Create tasks
|
||||
task1 = Task(
|
||||
description="Task 1",
|
||||
expected_output="Output 1",
|
||||
agent=agent1,
|
||||
)
|
||||
|
||||
# First conditional task should check task1's output
|
||||
condition1_mock = MagicMock()
|
||||
task2 = ConditionalTask(
|
||||
description="Conditional Task 2",
|
||||
expected_output="Output 2",
|
||||
agent=agent2,
|
||||
condition=condition1_mock,
|
||||
)
|
||||
|
||||
# Second conditional task should check task1's output, not task2's
|
||||
condition2_mock = MagicMock()
|
||||
task3 = ConditionalTask(
|
||||
description="Conditional Task 3",
|
||||
expected_output="Output 3",
|
||||
agent=agent3,
|
||||
condition=condition2_mock,
|
||||
)
|
||||
|
||||
return task1, task2, task3, condition1_mock, condition2_mock
|
||||
|
||||
@pytest.fixture
|
||||
def setup_crew(self, setup_agents, setup_tasks):
|
||||
"""Set up crew for the tests."""
|
||||
agent1, agent2, agent3 = setup_agents
|
||||
task1, task2, task3, _, _ = setup_tasks
|
||||
|
||||
crew = Crew(
|
||||
agents=[agent1, agent2, agent3],
|
||||
tasks=[task1, task2, task3],
|
||||
verbose=True,
|
||||
)
|
||||
|
||||
return crew
|
||||
|
||||
@pytest.fixture
|
||||
def setup_task_outputs(self, setup_agents):
|
||||
"""Set up task outputs for the tests."""
|
||||
agent1, agent2, _ = setup_agents
|
||||
|
||||
task1_output = TaskOutput(
|
||||
description="Task 1",
|
||||
raw="Task 1 output",
|
||||
agent=agent1.role,
|
||||
output_format=OutputFormat.RAW,
|
||||
)
|
||||
|
||||
task2_output = TaskOutput(
|
||||
description="Conditional Task 2",
|
||||
raw="Task 2 output",
|
||||
agent=agent2.role,
|
||||
output_format=OutputFormat.RAW,
|
||||
)
|
||||
|
||||
return task1_output, task2_output
|
||||
|
||||
def test_first_conditional_task_execution(self, setup_crew, setup_tasks, setup_task_outputs):
|
||||
"""Test that the first conditional task is evaluated correctly."""
|
||||
crew = setup_crew
|
||||
_, task2, _, condition1_mock, _ = setup_tasks
|
||||
task1_output, _ = setup_task_outputs
|
||||
|
||||
condition1_mock.return_value = True # Task should execute
|
||||
result = crew._handle_conditional_task(
|
||||
task=task2,
|
||||
task_outputs=[task1_output],
|
||||
futures=[],
|
||||
task_index=1,
|
||||
was_replayed=False,
|
||||
)
|
||||
|
||||
# Verify the condition was called with task1's output
|
||||
condition1_mock.assert_called_once()
|
||||
args = condition1_mock.call_args[0][0]
|
||||
assert args.raw == "Task 1 output"
|
||||
assert result is None # Task should execute, so no skipped output
|
||||
|
||||
def test_second_conditional_task_execution(self, setup_crew, setup_tasks, setup_task_outputs):
|
||||
"""Test that the second conditional task is evaluated correctly."""
|
||||
crew = setup_crew
|
||||
_, _, task3, _, condition2_mock = setup_tasks
|
||||
task1_output, task2_output = setup_task_outputs
|
||||
|
||||
condition2_mock.return_value = True # Task should execute
|
||||
result = crew._handle_conditional_task(
|
||||
task=task3,
|
||||
task_outputs=[task1_output, task2_output],
|
||||
futures=[],
|
||||
task_index=2,
|
||||
was_replayed=False,
|
||||
)
|
||||
|
||||
# Verify the condition was called with task1's output, not task2's
|
||||
condition2_mock.assert_called_once()
|
||||
args = condition2_mock.call_args[0][0]
|
||||
assert args.raw == "Task 1 output" # Should be task1's output
|
||||
assert args.raw != "Task 2 output" # Should not be task2's output
|
||||
assert result is None # Task should execute, so no skipped output
|
||||
|
||||
def test_conditional_task_skipping(self, setup_crew, setup_tasks, setup_task_outputs):
|
||||
"""Test that conditional tasks are skipped when the condition returns False."""
|
||||
crew = setup_crew
|
||||
_, task2, _, condition1_mock, _ = setup_tasks
|
||||
task1_output, _ = setup_task_outputs
|
||||
|
||||
condition1_mock.return_value = False # Task should be skipped
|
||||
result = crew._handle_conditional_task(
|
||||
task=task2,
|
||||
task_outputs=[task1_output],
|
||||
futures=[],
|
||||
task_index=1,
|
||||
was_replayed=False,
|
||||
)
|
||||
|
||||
# Verify the condition was called with task1's output
|
||||
condition1_mock.assert_called_once()
|
||||
args = condition1_mock.call_args[0][0]
|
||||
assert args.raw == "Task 1 output"
|
||||
assert result is not None # Task should be skipped, so there should be a skipped output
|
||||
assert result.description == task2.description
|
||||
|
||||
def test_conditional_task_with_explicit_context(self, setup_crew, setup_agents, setup_task_outputs):
|
||||
"""Test conditional task with explicit context tasks."""
|
||||
crew = setup_crew
|
||||
agent1, agent2, _ = setup_agents
|
||||
task1_output, _ = setup_task_outputs
|
||||
|
||||
with patch.object(crew, '_find_task_index', return_value=0):
|
||||
context_task = Task(
|
||||
description="Task 1",
|
||||
expected_output="Output 1",
|
||||
agent=agent1,
|
||||
)
|
||||
|
||||
condition_mock = MagicMock(return_value=True)
|
||||
task_with_context = ConditionalTask(
|
||||
description="Task with Context",
|
||||
expected_output="Output with Context",
|
||||
agent=agent2,
|
||||
condition=condition_mock,
|
||||
context=[context_task],
|
||||
)
|
||||
|
||||
crew.tasks.append(task_with_context)
|
||||
|
||||
result = crew._handle_conditional_task(
|
||||
task=task_with_context,
|
||||
task_outputs=[task1_output],
|
||||
futures=[],
|
||||
task_index=3, # This would be the 4th task
|
||||
was_replayed=False,
|
||||
)
|
||||
|
||||
# Verify the condition was called with task1's output
|
||||
condition_mock.assert_called_once()
|
||||
args = condition_mock.call_args[0][0]
|
||||
assert args.raw == "Task 1 output"
|
||||
assert result is None # Task should execute, so no skipped output
|
||||
|
||||
def test_conditional_task_with_empty_task_outputs(self, setup_crew, setup_tasks):
|
||||
"""Test conditional task with empty task outputs."""
|
||||
crew = setup_crew
|
||||
_, task2, _, condition1_mock, _ = setup_tasks
|
||||
|
||||
result = crew._handle_conditional_task(
|
||||
task=task2,
|
||||
task_outputs=[],
|
||||
futures=[],
|
||||
task_index=1,
|
||||
was_replayed=False,
|
||||
)
|
||||
|
||||
condition1_mock.assert_not_called()
|
||||
assert result is None # Task should execute, so no skipped output
|
||||
|
||||
|
||||
def test_multiple_conditional_tasks():
|
||||
"""Test that multiple conditional tasks are evaluated correctly.
|
||||
|
||||
This is a legacy test that's kept for backward compatibility.
|
||||
The actual tests are now in the TestMultipleConditionalTasks class.
|
||||
"""
|
||||
agent1 = Agent(
|
||||
role="Research Analyst",
|
||||
goal="Find information",
|
||||
backstory="You're a researcher",
|
||||
verbose=True,
|
||||
)
|
||||
|
||||
agent2 = Agent(
|
||||
role="Data Analyst",
|
||||
goal="Process information",
|
||||
backstory="You process data",
|
||||
verbose=True,
|
||||
)
|
||||
|
||||
agent3 = Agent(
|
||||
role="Report Writer",
|
||||
goal="Write reports",
|
||||
backstory="You write reports",
|
||||
verbose=True,
|
||||
)
|
||||
|
||||
# Create tasks
|
||||
task1 = Task(
|
||||
description="Task 1",
|
||||
expected_output="Output 1",
|
||||
agent=agent1,
|
||||
)
|
||||
|
||||
# First conditional task should check task1's output
|
||||
condition1_mock = MagicMock()
|
||||
task2 = ConditionalTask(
|
||||
description="Conditional Task 2",
|
||||
expected_output="Output 2",
|
||||
agent=agent2,
|
||||
condition=condition1_mock,
|
||||
)
|
||||
|
||||
# Second conditional task should check task1's output, not task2's
|
||||
condition2_mock = MagicMock()
|
||||
task3 = ConditionalTask(
|
||||
description="Conditional Task 3",
|
||||
expected_output="Output 3",
|
||||
agent=agent3,
|
||||
condition=condition2_mock,
|
||||
)
|
||||
|
||||
crew = Crew(
|
||||
agents=[agent1, agent2, agent3],
|
||||
tasks=[task1, task2, task3],
|
||||
verbose=True,
|
||||
)
|
||||
|
||||
with patch.object(crew, '_find_task_index', return_value=0):
|
||||
task1_output = TaskOutput(
|
||||
description="Task 1",
|
||||
raw="Task 1 output",
|
||||
agent=agent1.role,
|
||||
output_format=OutputFormat.RAW,
|
||||
)
|
||||
|
||||
condition1_mock.return_value = True # Task should execute
|
||||
result1 = crew._handle_conditional_task(
|
||||
task=task2,
|
||||
task_outputs=[task1_output],
|
||||
futures=[],
|
||||
task_index=1,
|
||||
was_replayed=False,
|
||||
)
|
||||
|
||||
# Verify the condition was called with task1's output
|
||||
condition1_mock.assert_called_once()
|
||||
args1 = condition1_mock.call_args[0][0]
|
||||
assert args1.raw == "Task 1 output"
|
||||
assert result1 is None # Task should execute, so no skipped output
|
||||
|
||||
condition1_mock.reset_mock()
|
||||
|
||||
task2_output = TaskOutput(
|
||||
description="Conditional Task 2",
|
||||
raw="Task 2 output",
|
||||
agent=agent2.role,
|
||||
output_format=OutputFormat.RAW,
|
||||
)
|
||||
|
||||
condition2_mock.return_value = True # Task should execute
|
||||
result2 = crew._handle_conditional_task(
|
||||
task=task3,
|
||||
task_outputs=[task1_output, task2_output],
|
||||
futures=[],
|
||||
task_index=2,
|
||||
was_replayed=False,
|
||||
)
|
||||
|
||||
# Verify the condition was called with task1's output, not task2's
|
||||
condition2_mock.assert_called_once()
|
||||
args2 = condition2_mock.call_args[0][0]
|
||||
assert args2.raw == "Task 1 output" # Should be task1's output
|
||||
assert args2.raw != "Task 2 output" # Should not be task2's output
|
||||
assert result2 is None # Task should execute, so no skipped output
|
||||
Reference in New Issue
Block a user