feat: Add dynamic task ordering capability to Sequential and Hierarchical processes

- Add task_ordering_callback field to Crew class with proper validation
- Implement dynamic task selection in _execute_tasks method
- Add comprehensive error handling and validation for callback
- Include tests for various ordering scenarios and edge cases
- Maintain backward compatibility with existing code
- Support both task index and Task object returns from callback
- Add example demonstrating priority-based and conditional ordering

Fixes #3620

Co-Authored-By: João <joao@crewai.com>
This commit is contained in:
Devin AI
2025-09-29 10:50:28 +00:00
parent 7d5cd4d3e2
commit e1c2c08bba
3 changed files with 518 additions and 7 deletions

View File

@@ -0,0 +1,112 @@
"""
Example demonstrating dynamic task ordering in CrewAI.
This example shows how to use the task_ordering_callback to dynamically
determine the execution order of tasks based on runtime conditions.
"""
from crewai import Agent, Crew, Task
from crewai.process import Process
def priority_based_ordering(all_tasks, completed_outputs, current_index):
"""
Order tasks by priority (lower number = higher priority).
Args:
all_tasks: List of all tasks in the crew
completed_outputs: List of TaskOutput objects for completed tasks
current_index: Current task index (for default ordering)
Returns:
int: Index of next task to execute
Task: Task object to execute next
None: Use default ordering
"""
completed_task_ids = {output.task_id for output in completed_outputs}
remaining_tasks = [
(i, task) for i, task in enumerate(all_tasks)
if task.id not in completed_task_ids
]
if not remaining_tasks:
return None
remaining_tasks.sort(key=lambda x: getattr(x[1], 'priority', 999))
return remaining_tasks[0][0]
def conditional_ordering(all_tasks, completed_outputs, current_index):
"""
Order tasks based on previous task outputs.
This example shows how to make task ordering decisions based on
the results of previously completed tasks.
"""
if len(completed_outputs) == 0:
return 0
last_output = completed_outputs[-1]
if "urgent" in last_output.raw.lower():
for i, task in enumerate(all_tasks):
if (hasattr(task, 'priority') and task.priority == 1 and
task.id not in {out.task_id for out in completed_outputs}):
return i
return None
researcher = Agent(
role="Research Analyst",
goal="Gather and analyze information",
backstory="Expert at finding and synthesizing information"
)
writer = Agent(
role="Content Writer",
goal="Create compelling content",
backstory="Skilled at crafting engaging narratives"
)
reviewer = Agent(
role="Quality Reviewer",
goal="Ensure content quality",
backstory="Meticulous attention to detail"
)
research_task = Task(
description="Research the latest trends in AI",
expected_output="Comprehensive research report",
agent=researcher
)
research_task.priority = 2
urgent_task = Task(
description="Write urgent press release",
expected_output="Press release draft",
agent=writer
)
urgent_task.priority = 1
review_task = Task(
description="Review and edit content",
expected_output="Polished final content",
agent=reviewer
)
review_task.priority = 3
crew = Crew(
agents=[researcher, writer, reviewer],
tasks=[research_task, urgent_task, review_task],
process=Process.sequential,
task_ordering_callback=priority_based_ordering,
verbose=True
)
if __name__ == "__main__":
print("Starting crew with dynamic task ordering...")
result = crew.kickoff()
print(f"Completed {len(result.tasks_output)} tasks")

View File

@@ -1,4 +1,5 @@
import asyncio import asyncio
import inspect
import json import json
import re import re
import uuid import uuid
@@ -113,6 +114,9 @@ class Crew(FlowTrackable, BaseModel):
execution. execution.
step_callback: Callback to be executed after each step for every agents step_callback: Callback to be executed after each step for every agents
execution. execution.
task_ordering_callback: Callback to determine the next task to execute
dynamically. Receives (all_tasks, completed_outputs, current_index)
and returns next task index, Task object, or None for default ordering.
share_crew: Whether you want to share the complete crew information and share_crew: Whether you want to share the complete crew information and
execution with crewAI to make the library better, and allow us to execution with crewAI to make the library better, and allow us to
train models. train models.
@@ -213,6 +217,12 @@ class Crew(FlowTrackable, BaseModel):
"It may be used to adjust the output of the crew." "It may be used to adjust the output of the crew."
), ),
) )
task_ordering_callback: Callable[
[list[Task], list[TaskOutput], int], int | Task | None
] | None = Field(
default=None,
description="Callback to determine the next task to execute. Receives (all_tasks, completed_outputs, current_index) and returns next task index, Task object, or None for default ordering.",
)
max_rpm: int | None = Field( max_rpm: int | None = Field(
default=None, default=None,
description=( description=(
@@ -535,6 +545,30 @@ class Crew(FlowTrackable, BaseModel):
) )
return self return self
@model_validator(mode="after")
def validate_task_ordering_callback(self):
"""Validates that the task ordering callback has the correct signature."""
if self.task_ordering_callback is not None:
if not callable(self.task_ordering_callback):
raise PydanticCustomError(
"invalid_task_ordering_callback",
"task_ordering_callback must be callable",
{},
)
try:
sig = inspect.signature(self.task_ordering_callback)
if len(sig.parameters) != 3:
raise PydanticCustomError(
"invalid_task_ordering_callback_signature",
"task_ordering_callback must accept exactly 3 parameters: (tasks, outputs, current_index)",
{},
)
except (ValueError, TypeError):
pass
return self
@property @property
def key(self) -> str: def key(self) -> str:
source: list[str] = [agent.key for agent in self.agents] + [ source: list[str] = [agent.key for agent in self.agents] + [
@@ -847,12 +881,12 @@ class Crew(FlowTrackable, BaseModel):
start_index: int | None = 0, start_index: int | None = 0,
was_replayed: bool = False, was_replayed: bool = False,
) -> CrewOutput: ) -> CrewOutput:
"""Executes tasks sequentially and returns the final output. """Executes tasks with optional dynamic ordering and returns the final output.
Args: Args:
tasks (List[Task]): List of tasks to execute tasks (List[Task]): List of tasks to execute
manager (Optional[BaseAgent], optional): Manager agent to use for start_index (int | None): Starting index for task execution
delegation. Defaults to None. was_replayed (bool): Whether this is a replay execution
Returns: Returns:
CrewOutput: Final output of the crew CrewOutput: Final output of the crew
@@ -861,17 +895,78 @@ class Crew(FlowTrackable, BaseModel):
task_outputs: list[TaskOutput] = [] task_outputs: list[TaskOutput] = []
futures: list[tuple[Task, Future[TaskOutput], int]] = [] futures: list[tuple[Task, Future[TaskOutput], int]] = []
last_sync_output: TaskOutput | None = None last_sync_output: TaskOutput | None = None
executed_task_indices: set[int] = set()
current_index = start_index or 0
for task_index, task in enumerate(tasks): while current_index < len(tasks):
if start_index is not None and task_index < start_index: if current_index in executed_task_indices:
current_index += 1
continue
if start_index is not None and current_index < start_index:
task = tasks[current_index]
if task.output: if task.output:
if task.async_execution: if task.async_execution:
task_outputs.append(task.output) task_outputs.append(task.output)
else: else:
task_outputs = [task.output] task_outputs = [task.output]
last_sync_output = task.output last_sync_output = task.output
executed_task_indices.add(current_index)
current_index += 1
continue continue
if self.task_ordering_callback:
try:
next_task_result = self.task_ordering_callback(
tasks, task_outputs, current_index
)
if next_task_result is None:
task_index = current_index
elif isinstance(next_task_result, int):
if 0 <= next_task_result < len(tasks):
task_index = next_task_result
else:
self._logger.log(
"warning",
f"Invalid task index {next_task_result} from ordering callback, using default",
color="yellow"
)
task_index = current_index
elif isinstance(next_task_result, Task):
try:
task_index = tasks.index(next_task_result)
except ValueError:
self._logger.log(
"warning",
"Task from ordering callback not found in tasks list, using default",
color="yellow"
)
task_index = current_index
else:
self._logger.log(
"warning",
f"Invalid return type from ordering callback: {type(next_task_result)}, using default",
color="yellow"
)
task_index = current_index
except Exception as e:
self._logger.log(
"warning",
f"Error in task ordering callback: {e}, using default ordering",
color="yellow"
)
task_index = current_index
else:
task_index = current_index
if task_index in executed_task_indices:
current_index += 1
continue
task = tasks[task_index]
executed_task_indices.add(task_index)
agent_to_use = self._get_agent_to_use(task) agent_to_use = self._get_agent_to_use(task)
if agent_to_use is None: if agent_to_use is None:
raise ValueError( raise ValueError(
@@ -880,9 +975,7 @@ class Crew(FlowTrackable, BaseModel):
f"or a manager agent is provided." f"or a manager agent is provided."
) )
# Determine which tools to use - task tools take precedence over agent tools
tools_for_task = task.tools or agent_to_use.tools or [] tools_for_task = task.tools or agent_to_use.tools or []
# Prepare tools and ensure they're compatible with task execution
tools_for_task = self._prepare_tools( tools_for_task = self._prepare_tools(
agent_to_use, agent_to_use,
task, task,
@@ -897,6 +990,7 @@ class Crew(FlowTrackable, BaseModel):
) )
if skipped_task_output: if skipped_task_output:
task_outputs.append(skipped_task_output) task_outputs.append(skipped_task_output)
current_index += 1
continue continue
if task.async_execution: if task.async_execution:
@@ -923,6 +1017,9 @@ class Crew(FlowTrackable, BaseModel):
task_outputs.append(task_output) task_outputs.append(task_output)
self._process_task_result(task, task_output) self._process_task_result(task, task_output)
self._store_execution_log(task, task_output, task_index, was_replayed) self._store_execution_log(task, task_output, task_index, was_replayed)
last_sync_output = task_output
current_index += 1
if futures: if futures:
task_outputs = self._process_async_tasks(futures, was_replayed) task_outputs = self._process_async_tasks(futures, was_replayed)

View File

@@ -0,0 +1,302 @@
import pytest
from unittest.mock import Mock
from crewai import Agent, Crew, Task
from crewai.process import Process
from crewai.task import TaskOutput
@pytest.fixture
def agents():
return [
Agent(role="Agent 1", goal="Goal 1", backstory="Backstory 1"),
Agent(role="Agent 2", goal="Goal 2", backstory="Backstory 2"),
Agent(role="Agent 3", goal="Goal 3", backstory="Backstory 3"),
]
@pytest.fixture
def tasks(agents):
return [
Task(description="Task 1", expected_output="Output 1", agent=agents[0]),
Task(description="Task 2", expected_output="Output 2", agent=agents[1]),
Task(description="Task 3", expected_output="Output 3", agent=agents[2]),
]
def test_sequential_process_with_reverse_ordering(agents, tasks):
"""Test sequential process with reverse task ordering."""
execution_order = []
def reverse_ordering_callback(all_tasks, completed_outputs, current_index):
completed_task_ids = {output.task_id for output in completed_outputs}
remaining_indices = [i for i in range(len(all_tasks))
if all_tasks[i].id not in completed_task_ids]
if remaining_indices:
next_index = max(remaining_indices)
execution_order.append(next_index)
return next_index
return None
crew = Crew(
agents=agents,
tasks=tasks,
process=Process.sequential,
task_ordering_callback=reverse_ordering_callback,
verbose=False
)
result = crew.kickoff()
assert len(result.tasks_output) == 3
assert execution_order == [2, 1, 0]
def test_hierarchical_process_with_priority_ordering(agents, tasks):
"""Test hierarchical process with priority-based task ordering."""
tasks[0].priority = 3
tasks[1].priority = 1
tasks[2].priority = 2
execution_order = []
def priority_ordering_callback(all_tasks, completed_outputs, current_index):
completed_task_ids = {output.task_id for output in completed_outputs}
remaining_tasks = [
(i, task) for i, task in enumerate(all_tasks)
if task.id not in completed_task_ids
]
if remaining_tasks:
remaining_tasks.sort(key=lambda x: getattr(x[1], 'priority', 999))
next_index = remaining_tasks[0][0]
execution_order.append(next_index)
return next_index
return None
crew = Crew(
agents=agents,
tasks=tasks,
process=Process.hierarchical,
manager_llm="gpt-4o",
task_ordering_callback=priority_ordering_callback,
verbose=False
)
result = crew.kickoff()
assert len(result.tasks_output) == 3
assert execution_order == [1, 2, 0]
def test_task_ordering_callback_with_task_object_return():
"""Test callback returning Task object instead of index."""
agents = [Agent(role="Agent", goal="Goal", backstory="Backstory")]
tasks = [
Task(description="Task A", expected_output="Output A", agent=agents[0]),
Task(description="Task B", expected_output="Output B", agent=agents[0]),
]
execution_order = []
def task_object_callback(all_tasks, completed_outputs, current_index):
if len(completed_outputs) == 0:
execution_order.append(1)
return all_tasks[1]
elif len(completed_outputs) == 1:
execution_order.append(0)
return all_tasks[0]
return None
crew = Crew(
agents=agents,
tasks=tasks,
process=Process.sequential,
task_ordering_callback=task_object_callback,
verbose=False
)
result = crew.kickoff()
assert len(result.tasks_output) == 2
assert execution_order == [1, 0]
def test_invalid_task_ordering_callback_index():
"""Test handling of invalid task index from callback."""
agents = [Agent(role="Agent", goal="Goal", backstory="Backstory")]
tasks = [Task(description="Task", expected_output="Output", agent=agents[0])]
def invalid_callback(all_tasks, completed_outputs, current_index):
return 999
crew = Crew(
agents=agents,
tasks=tasks,
process=Process.sequential,
task_ordering_callback=invalid_callback,
verbose=False
)
result = crew.kickoff()
assert len(result.tasks_output) == 1
def test_task_ordering_callback_exception_handling():
"""Test handling of exceptions in task ordering callback."""
agents = [Agent(role="Agent", goal="Goal", backstory="Backstory")]
tasks = [Task(description="Task", expected_output="Output", agent=agents[0])]
def failing_callback(all_tasks, completed_outputs, current_index):
raise ValueError("Callback error")
crew = Crew(
agents=agents,
tasks=tasks,
process=Process.sequential,
task_ordering_callback=failing_callback,
verbose=False
)
result = crew.kickoff()
assert len(result.tasks_output) == 1
def test_task_ordering_callback_validation():
"""Test validation of task ordering callback signature."""
agents = [Agent(role="Agent", goal="Goal", backstory="Backstory")]
tasks = [Task(description="Task", expected_output="Output", agent=agents[0])]
def invalid_signature_callback(only_one_param):
return 0
with pytest.raises(ValueError, match="task_ordering_callback must accept exactly 3 parameters"):
Crew(
agents=agents,
tasks=tasks,
process=Process.sequential,
task_ordering_callback=invalid_signature_callback
)
def test_no_task_ordering_callback_default_behavior():
"""Test that default behavior is unchanged when no callback is provided."""
agents = [Agent(role="Agent", goal="Goal", backstory="Backstory")]
tasks = [
Task(description="Task 1", expected_output="Output 1", agent=agents[0]),
Task(description="Task 2", expected_output="Output 2", agent=agents[0]),
]
crew = Crew(
agents=agents,
tasks=tasks,
process=Process.sequential,
verbose=False
)
result = crew.kickoff()
assert len(result.tasks_output) == 2
def test_task_ordering_callback_with_none_return():
"""Test callback returning None for default ordering."""
agents = [Agent(role="Agent", goal="Goal", backstory="Backstory")]
tasks = [
Task(description="Task 1", expected_output="Output 1", agent=agents[0]),
Task(description="Task 2", expected_output="Output 2", agent=agents[0]),
]
def none_callback(all_tasks, completed_outputs, current_index):
return None
crew = Crew(
agents=agents,
tasks=tasks,
process=Process.sequential,
task_ordering_callback=none_callback,
verbose=False
)
result = crew.kickoff()
assert len(result.tasks_output) == 2
def test_task_ordering_callback_invalid_task_object():
"""Test handling of invalid Task object from callback."""
agents = [Agent(role="Agent", goal="Goal", backstory="Backstory")]
tasks = [Task(description="Task", expected_output="Output", agent=agents[0])]
invalid_task = Task(description="Invalid", expected_output="Invalid", agent=agents[0])
def invalid_task_callback(all_tasks, completed_outputs, current_index):
return invalid_task
crew = Crew(
agents=agents,
tasks=tasks,
process=Process.sequential,
task_ordering_callback=invalid_task_callback,
verbose=False
)
result = crew.kickoff()
assert len(result.tasks_output) == 1
def test_task_ordering_callback_invalid_return_type():
"""Test handling of invalid return type from callback."""
agents = [Agent(role="Agent", goal="Goal", backstory="Backstory")]
tasks = [Task(description="Task", expected_output="Output", agent=agents[0])]
def invalid_type_callback(all_tasks, completed_outputs, current_index):
return "invalid"
crew = Crew(
agents=agents,
tasks=tasks,
process=Process.sequential,
task_ordering_callback=invalid_type_callback,
verbose=False
)
result = crew.kickoff()
assert len(result.tasks_output) == 1
def test_task_ordering_prevents_infinite_loops():
"""Test that task ordering prevents infinite loops by tracking executed tasks."""
agents = [Agent(role="Agent", goal="Goal", backstory="Backstory")]
tasks = [
Task(description="Task 1", expected_output="Output 1", agent=agents[0]),
Task(description="Task 2", expected_output="Output 2", agent=agents[0]),
]
call_count = 0
def loop_callback(all_tasks, completed_outputs, current_index):
nonlocal call_count
call_count += 1
if call_count > 10:
pytest.fail("Callback called too many times, possible infinite loop")
return 0
crew = Crew(
agents=agents,
tasks=tasks,
process=Process.sequential,
task_ordering_callback=loop_callback,
verbose=False
)
result = crew.kickoff()
assert len(result.tasks_output) == 2
assert call_count <= 4