Compare commits


3 Commits

| Author | SHA1 | Message | Date |
| --- | --- | --- | --- |
| Devin AI | 214357c482 | Refactor: Improve test organization and add edge case tests (Co-Authored-By: Joe Moura <joao@crewai.com>) | 2025-04-07 11:11:45 +00:00 |
| Devin AI | cbac6a5534 | Refactor: Improve code organization and maintainability in conditional task handling (Co-Authored-By: Joe Moura <joao@crewai.com>) | 2025-04-07 11:09:06 +00:00 |
| Devin AI | fac958dd0b | Fix issue #2530: Multiple conditional tasks using correct previous output (Co-Authored-By: Joe Moura <joao@crewai.com>) | 2025-04-07 11:03:53 +00:00 |
4 changed files with 402 additions and 55 deletions

View File

```diff
@@ -771,6 +771,65 @@ class Crew(BaseModel):
         return self._create_crew_output(task_outputs)
 
+    def _get_context_based_output(
+        self,
+        task: ConditionalTask,
+        task_outputs: List[TaskOutput],
+        task_index: int,
+    ) -> Optional[TaskOutput]:
+        """Get the output from explicit context tasks."""
+        context_task_outputs = []
+        for context_task in task.context:
+            context_task_index = self._find_task_index(context_task)
+            if context_task_index != -1 and context_task_index < task_index:
+                for output in task_outputs:
+                    if output.description == context_task.description:
+                        context_task_outputs.append(output)
+                        break
+        return context_task_outputs[-1] if context_task_outputs else None
+
+    def _get_non_conditional_output(
+        self,
+        task_outputs: List[TaskOutput],
+        task_index: int,
+    ) -> Optional[TaskOutput]:
+        """Get the output from the most recent non-conditional task."""
+        non_conditional_outputs = []
+        for i in range(task_index):
+            if i < len(self.tasks) and not isinstance(self.tasks[i], ConditionalTask):
+                for output in task_outputs:
+                    if output.description == self.tasks[i].description:
+                        non_conditional_outputs.append(output)
+                        break
+        return non_conditional_outputs[-1] if non_conditional_outputs else None
+
+    def _get_previous_output(
+        self,
+        task: ConditionalTask,
+        task_outputs: List[TaskOutput],
+        task_index: int,
+    ) -> Optional[TaskOutput]:
+        """Get the previous output for a conditional task.
+
+        The order of precedence is:
+        1. Output from explicit context tasks
+        2. Output from the most recent non-conditional task
+        3. Output from the immediately preceding task
+        """
+        if task.context and len(task.context) > 0:
+            previous_output = self._get_context_based_output(task, task_outputs, task_index)
+            if previous_output:
+                return previous_output
+        previous_output = self._get_non_conditional_output(task_outputs, task_index)
+        if previous_output:
+            return previous_output
+        if task_outputs and task_index > 0 and task_index <= len(task_outputs):
+            return task_outputs[task_index - 1]
+        return None
+
     def _handle_conditional_task(
         self,
         task: ConditionalTask,
@@ -779,11 +838,17 @@ class Crew(BaseModel):
         task_index: int,
         was_replayed: bool,
     ) -> Optional[TaskOutput]:
+        """Handle a conditional task.
+
+        Determines whether a conditional task should be executed based on the output
+        of previous tasks. If the task should not be executed, returns a skipped
+        task output.
+        """
         if futures:
             task_outputs = self._process_async_tasks(futures, was_replayed)
             futures.clear()
-        previous_output = task_outputs[task_index - 1] if task_outputs else None
+        previous_output = self._get_previous_output(task, task_outputs, task_index)
         if previous_output is not None and not task.should_execute(previous_output):
             self._logger.log(
                 "debug",
```

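To make the new precedence rules concrete, here is a minimal runnable sketch of the same resolution logic as a standalone function. The function name and the plain strings standing in for `TaskOutput` objects are illustrative, not part of the diff:

```python
from typing import List, Optional


def resolve_previous_output(
    explicit_context_output: Optional[str],
    last_non_conditional_output: Optional[str],
    outputs_so_far: List[str],
) -> Optional[str]:
    # 1. Prefer the output of an explicit context task, if one resolved.
    if explicit_context_output is not None:
        return explicit_context_output
    # 2. Fall back to the most recent non-conditional task's output.
    if last_non_conditional_output is not None:
        return last_non_conditional_output
    # 3. Last resort: the immediately preceding task's output.
    return outputs_so_far[-1] if outputs_so_far else None


# Two back-to-back conditional tasks both resolve to "Task 1 output" via
# rule 2; under the old single-line lookup the second one would have seen
# the first conditional task's output instead, which was the bug in #2530.
assert resolve_previous_output(None, "Task 1 output", ["Task 1 output", "Task 2 output"]) == "Task 1 output"
```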
View File

```diff
@@ -92,8 +92,6 @@ def suppress_warnings():
 class LLM:
-    MODELS_WITHOUT_STOP_SUPPORT = ["o3", "o3-mini", "o4-mini"]
-
     def __init__(
         self,
         model: str,
@@ -157,7 +155,7 @@ class LLM:
             "temperature": self.temperature,
             "top_p": self.top_p,
             "n": self.n,
-            "stop": self.stop if self.supports_stop_words() else None,
+            "stop": self.stop,
             "max_tokens": self.max_tokens or self.max_completion_tokens,
             "presence_penalty": self.presence_penalty,
             "frequency_penalty": self.frequency_penalty,
@@ -195,19 +193,6 @@ class LLM:
         return False
-
-    def supports_stop_words(self) -> bool:
-        """
-        Determines whether the current model supports the 'stop' parameter.
-
-        This method checks if the model is in the list of models known not to support
-        stop words, and if not, it queries the litellm library to determine if the
-        model supports the 'stop' parameter.
-
-        Returns:
-            bool: True if the model supports stop words, False otherwise.
-        """
-        if any(self.model.startswith(model) for model in self.MODELS_WITHOUT_STOP_SUPPORT):
-            return False
-        try:
-            params = get_supported_openai_params(model=self.model)
-            return "stop" in params
```

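The three hunks above remove the stop-word capability gate entirely: `stop` is now forwarded to the completion call unconditionally, and model-specific handling is left to litellm. For reference, the deleted guard reconstructed as a free function looks roughly like this; the hunk is truncated before the `except` clause, so the fallback below is an assumption:

```python
from litellm import get_supported_openai_params

# Mirrors the deleted MODELS_WITHOUT_STOP_SUPPORT denylist.
MODELS_WITHOUT_STOP_SUPPORT = ["o3", "o3-mini", "o4-mini"]


def supports_stop_words(model: str) -> bool:
    # Hard-coded denylist first, then a litellm capability lookup.
    if any(model.startswith(m) for m in MODELS_WITHOUT_STOP_SUPPORT):
        return False
    try:
        params = get_supported_openai_params(model=model)
        return params is not None and "stop" in params
    except Exception:
        # Assumption: treat lookup failures as "supported", since the
        # truncated hunk does not show the original error handling.
        return True
```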
View File

```diff
@@ -28,41 +28,3 @@ def test_llm_callback_replacement():
     assert usage_metrics_1.successful_requests == 1
     assert usage_metrics_2.successful_requests == 1
     assert usage_metrics_1 == calc_handler_1.token_cost_process.get_summary()
-
-
-class TestLLMStopWords:
-    """Tests for LLM stop words functionality."""
-
-    def test_supports_stop_words_for_o3_model(self):
-        """Test that supports_stop_words returns False for o3 model."""
-        llm = LLM(model="o3")
-        assert not llm.supports_stop_words()
-
-    def test_supports_stop_words_for_o4_mini_model(self):
-        """Test that supports_stop_words returns False for o4-mini model."""
-        llm = LLM(model="o4-mini")
-        assert not llm.supports_stop_words()
-
-    def test_supports_stop_words_for_supported_model(self):
-        """Test that supports_stop_words returns True for models that support stop words."""
-        llm = LLM(model="gpt-4")
-        assert llm.supports_stop_words()
-
-    @pytest.mark.vcr(filter_headers=["authorization"])
-    def test_llm_call_excludes_stop_parameter_for_unsupported_models(self, monkeypatch):
-        """Test that the LLM.call method excludes the stop parameter for models that don't support it."""
-
-        def mock_completion(**kwargs):
-            assert 'stop' not in kwargs, "Stop parameter should be excluded for o3 model"
-            assert 'model' in kwargs, "Model parameter should be included"
-            assert 'messages' in kwargs, "Messages parameter should be included"
-            return {"choices": [{"message": {"content": "Hello, World!"}}]}
-
-        monkeypatch.setattr("litellm.completion", mock_completion)
-        llm = LLM(model="o3")
-        llm.stop = ["STOP"]
-        messages = [{"role": "user", "content": "Say 'Hello, World!'"}]
-        response = llm.call(messages)
-        assert response == "Hello, World!"
```

View File

```diff
@@ -0,0 +1,335 @@
+"""Test for multiple conditional tasks."""
+
+from unittest.mock import MagicMock, patch
+
+import pytest
+
+from crewai.agent import Agent
+from crewai.crew import Crew
+from crewai.task import Task
+from crewai.tasks.conditional_task import ConditionalTask
+from crewai.tasks.output_format import OutputFormat
+from crewai.tasks.task_output import TaskOutput
+
+
+class TestMultipleConditionalTasks:
+    """Test class for multiple conditional tasks scenarios."""
+
+    @pytest.fixture
+    def setup_agents(self):
+        """Set up agents for the tests."""
+        agent1 = Agent(
+            role="Research Analyst",
+            goal="Find information",
+            backstory="You're a researcher",
+            verbose=True,
+        )
+        agent2 = Agent(
+            role="Data Analyst",
+            goal="Process information",
+            backstory="You process data",
+            verbose=True,
+        )
+        agent3 = Agent(
+            role="Report Writer",
+            goal="Write reports",
+            backstory="You write reports",
+            verbose=True,
+        )
+        return agent1, agent2, agent3
+
+    @pytest.fixture
+    def setup_tasks(self, setup_agents):
+        """Set up tasks for the tests."""
+        agent1, agent2, agent3 = setup_agents
+        # Create tasks
+        task1 = Task(
+            description="Task 1",
+            expected_output="Output 1",
+            agent=agent1,
+        )
+        # First conditional task should check task1's output
+        condition1_mock = MagicMock()
+        task2 = ConditionalTask(
+            description="Conditional Task 2",
+            expected_output="Output 2",
+            agent=agent2,
+            condition=condition1_mock,
+        )
+        # Second conditional task should check task1's output, not task2's
+        condition2_mock = MagicMock()
+        task3 = ConditionalTask(
+            description="Conditional Task 3",
+            expected_output="Output 3",
+            agent=agent3,
+            condition=condition2_mock,
+        )
+        return task1, task2, task3, condition1_mock, condition2_mock
+
+    @pytest.fixture
+    def setup_crew(self, setup_agents, setup_tasks):
+        """Set up crew for the tests."""
+        agent1, agent2, agent3 = setup_agents
+        task1, task2, task3, _, _ = setup_tasks
+        crew = Crew(
+            agents=[agent1, agent2, agent3],
+            tasks=[task1, task2, task3],
+            verbose=True,
+        )
+        return crew
+
+    @pytest.fixture
+    def setup_task_outputs(self, setup_agents):
+        """Set up task outputs for the tests."""
+        agent1, agent2, _ = setup_agents
+        task1_output = TaskOutput(
+            description="Task 1",
+            raw="Task 1 output",
+            agent=agent1.role,
+            output_format=OutputFormat.RAW,
+        )
+        task2_output = TaskOutput(
+            description="Conditional Task 2",
+            raw="Task 2 output",
+            agent=agent2.role,
+            output_format=OutputFormat.RAW,
+        )
+        return task1_output, task2_output
+
+    def test_first_conditional_task_execution(self, setup_crew, setup_tasks, setup_task_outputs):
+        """Test that the first conditional task is evaluated correctly."""
+        crew = setup_crew
+        _, task2, _, condition1_mock, _ = setup_tasks
+        task1_output, _ = setup_task_outputs
+        condition1_mock.return_value = True  # Task should execute
+        result = crew._handle_conditional_task(
+            task=task2,
+            task_outputs=[task1_output],
+            futures=[],
+            task_index=1,
+            was_replayed=False,
+        )
+        # Verify the condition was called with task1's output
+        condition1_mock.assert_called_once()
+        args = condition1_mock.call_args[0][0]
+        assert args.raw == "Task 1 output"
+        assert result is None  # Task should execute, so no skipped output
+
+    def test_second_conditional_task_execution(self, setup_crew, setup_tasks, setup_task_outputs):
+        """Test that the second conditional task is evaluated correctly."""
+        crew = setup_crew
+        _, _, task3, _, condition2_mock = setup_tasks
+        task1_output, task2_output = setup_task_outputs
+        condition2_mock.return_value = True  # Task should execute
+        result = crew._handle_conditional_task(
+            task=task3,
+            task_outputs=[task1_output, task2_output],
+            futures=[],
+            task_index=2,
+            was_replayed=False,
+        )
+        # Verify the condition was called with task1's output, not task2's
+        condition2_mock.assert_called_once()
+        args = condition2_mock.call_args[0][0]
+        assert args.raw == "Task 1 output"  # Should be task1's output
+        assert args.raw != "Task 2 output"  # Should not be task2's output
+        assert result is None  # Task should execute, so no skipped output
+
+    def test_conditional_task_skipping(self, setup_crew, setup_tasks, setup_task_outputs):
+        """Test that conditional tasks are skipped when the condition returns False."""
+        crew = setup_crew
+        _, task2, _, condition1_mock, _ = setup_tasks
+        task1_output, _ = setup_task_outputs
+        condition1_mock.return_value = False  # Task should be skipped
+        result = crew._handle_conditional_task(
+            task=task2,
+            task_outputs=[task1_output],
+            futures=[],
+            task_index=1,
+            was_replayed=False,
+        )
+        # Verify the condition was called with task1's output
+        condition1_mock.assert_called_once()
+        args = condition1_mock.call_args[0][0]
+        assert args.raw == "Task 1 output"
+        assert result is not None  # Task should be skipped, so there should be a skipped output
+        assert result.description == task2.description
+
+    def test_conditional_task_with_explicit_context(self, setup_crew, setup_agents, setup_task_outputs):
+        """Test conditional task with explicit context tasks."""
+        crew = setup_crew
+        agent1, agent2, _ = setup_agents
+        task1_output, _ = setup_task_outputs
+        with patch.object(crew, '_find_task_index', return_value=0):
+            context_task = Task(
+                description="Task 1",
+                expected_output="Output 1",
+                agent=agent1,
+            )
+            condition_mock = MagicMock(return_value=True)
+            task_with_context = ConditionalTask(
+                description="Task with Context",
+                expected_output="Output with Context",
+                agent=agent2,
+                condition=condition_mock,
+                context=[context_task],
+            )
+            crew.tasks.append(task_with_context)
+            result = crew._handle_conditional_task(
+                task=task_with_context,
+                task_outputs=[task1_output],
+                futures=[],
+                task_index=3,  # This would be the 4th task
+                was_replayed=False,
+            )
+            # Verify the condition was called with task1's output
+            condition_mock.assert_called_once()
+            args = condition_mock.call_args[0][0]
+            assert args.raw == "Task 1 output"
+            assert result is None  # Task should execute, so no skipped output
+
+    def test_conditional_task_with_empty_task_outputs(self, setup_crew, setup_tasks):
+        """Test conditional task with empty task outputs."""
+        crew = setup_crew
+        _, task2, _, condition1_mock, _ = setup_tasks
+        result = crew._handle_conditional_task(
+            task=task2,
+            task_outputs=[],
+            futures=[],
+            task_index=1,
+            was_replayed=False,
+        )
+        condition1_mock.assert_not_called()
+        assert result is None  # Task should execute, so no skipped output
+
+
+def test_multiple_conditional_tasks():
+    """Test that multiple conditional tasks are evaluated correctly.
+
+    This is a legacy test that's kept for backward compatibility.
+    The actual tests are now in the TestMultipleConditionalTasks class.
+    """
+    agent1 = Agent(
+        role="Research Analyst",
+        goal="Find information",
+        backstory="You're a researcher",
+        verbose=True,
+    )
+    agent2 = Agent(
+        role="Data Analyst",
+        goal="Process information",
+        backstory="You process data",
+        verbose=True,
+    )
+    agent3 = Agent(
+        role="Report Writer",
+        goal="Write reports",
+        backstory="You write reports",
+        verbose=True,
+    )
+    # Create tasks
+    task1 = Task(
+        description="Task 1",
+        expected_output="Output 1",
+        agent=agent1,
+    )
+    # First conditional task should check task1's output
+    condition1_mock = MagicMock()
+    task2 = ConditionalTask(
+        description="Conditional Task 2",
+        expected_output="Output 2",
+        agent=agent2,
+        condition=condition1_mock,
+    )
+    # Second conditional task should check task1's output, not task2's
+    condition2_mock = MagicMock()
+    task3 = ConditionalTask(
+        description="Conditional Task 3",
+        expected_output="Output 3",
+        agent=agent3,
+        condition=condition2_mock,
+    )
+    crew = Crew(
+        agents=[agent1, agent2, agent3],
+        tasks=[task1, task2, task3],
+        verbose=True,
+    )
+    with patch.object(crew, '_find_task_index', return_value=0):
+        task1_output = TaskOutput(
+            description="Task 1",
+            raw="Task 1 output",
+            agent=agent1.role,
+            output_format=OutputFormat.RAW,
+        )
+        condition1_mock.return_value = True  # Task should execute
+        result1 = crew._handle_conditional_task(
+            task=task2,
+            task_outputs=[task1_output],
+            futures=[],
+            task_index=1,
+            was_replayed=False,
+        )
+        # Verify the condition was called with task1's output
+        condition1_mock.assert_called_once()
+        args1 = condition1_mock.call_args[0][0]
+        assert args1.raw == "Task 1 output"
+        assert result1 is None  # Task should execute, so no skipped output
+
+        condition1_mock.reset_mock()
+        task2_output = TaskOutput(
+            description="Conditional Task 2",
+            raw="Task 2 output",
+            agent=agent2.role,
+            output_format=OutputFormat.RAW,
+        )
+        condition2_mock.return_value = True  # Task should execute
+        result2 = crew._handle_conditional_task(
+            task=task3,
+            task_outputs=[task1_output, task2_output],
+            futures=[],
+            task_index=2,
+            was_replayed=False,
+        )
+        # Verify the condition was called with task1's output, not task2's
+        condition2_mock.assert_called_once()
+        args2 = condition2_mock.call_args[0][0]
+        assert args2.raw == "Task 1 output"  # Should be task1's output
+        assert args2.raw != "Task 2 output"  # Should not be task2's output
+        assert result2 is None  # Task should execute, so no skipped output
```
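The tests above stand in `MagicMock` objects for the conditions. In real usage a `ConditionalTask` condition is just a callable that receives the resolved previous `TaskOutput` and returns a bool, so the fix changes which output that callable sees. A minimal sketch (the function name is illustrative):

```python
from crewai.tasks.task_output import TaskOutput


def should_write_report(previous: TaskOutput) -> bool:
    # With the fix, `previous` is resolved by _get_previous_output: explicit
    # context first, then the last non-conditional output, then the task
    # immediately before this one.
    return bool(previous.raw.strip())
```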