mirror of
https://github.com/crewAIInc/crewAI.git
synced 2025-12-16 04:18:35 +00:00
Compare commits
4 Commits
1.2.1
...
devin/1759
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
1de7dcd3c2 | ||
|
|
ed95f47b80 | ||
|
|
c467c96e9f | ||
|
|
e1c2c08bba |
113
examples/dynamic_task_ordering_example.py
Normal file
113
examples/dynamic_task_ordering_example.py
Normal file
@@ -0,0 +1,113 @@
|
||||
"""
|
||||
Example demonstrating dynamic task ordering in CrewAI.
|
||||
|
||||
This example shows how to use the task_ordering_callback to dynamically
|
||||
determine the execution order of tasks based on runtime conditions.
|
||||
"""
|
||||
|
||||
from crewai import Agent, Crew, Task
|
||||
from crewai.process import Process
|
||||
|
||||
|
||||
def priority_based_ordering(all_tasks, completed_outputs, current_index):
|
||||
"""
|
||||
Order tasks by priority (lower number = higher priority).
|
||||
|
||||
Args:
|
||||
all_tasks: List of all tasks in the crew
|
||||
completed_outputs: List of TaskOutput objects for completed tasks
|
||||
current_index: Current task index (for default ordering)
|
||||
|
||||
Returns:
|
||||
int: Index of next task to execute
|
||||
Task: Task object to execute next
|
||||
None: Use default ordering
|
||||
"""
|
||||
completed_tasks = {id(task) for task in all_tasks if task.output is not None}
|
||||
|
||||
remaining_tasks = [
|
||||
(i, task) for i, task in enumerate(all_tasks)
|
||||
if id(task) not in completed_tasks
|
||||
]
|
||||
|
||||
if not remaining_tasks:
|
||||
return None
|
||||
|
||||
remaining_tasks.sort(key=lambda x: getattr(x[1], 'priority', 999))
|
||||
|
||||
return remaining_tasks[0][0]
|
||||
|
||||
|
||||
def conditional_ordering(all_tasks, completed_outputs, current_index):
|
||||
"""
|
||||
Order tasks based on previous task outputs.
|
||||
|
||||
This example shows how to make task ordering decisions based on
|
||||
the results of previously completed tasks.
|
||||
"""
|
||||
if len(completed_outputs) == 0:
|
||||
return 0
|
||||
|
||||
last_output = completed_outputs[-1]
|
||||
|
||||
if "urgent" in last_output.raw.lower():
|
||||
completed_tasks = {id(task) for task in all_tasks if task.output is not None}
|
||||
for i, task in enumerate(all_tasks):
|
||||
if (hasattr(task, 'priority') and task.priority == 1 and
|
||||
id(task) not in completed_tasks):
|
||||
return i
|
||||
|
||||
return None
|
||||
|
||||
|
||||
researcher = Agent(
|
||||
role="Research Analyst",
|
||||
goal="Gather and analyze information",
|
||||
backstory="Expert at finding and synthesizing information"
|
||||
)
|
||||
|
||||
writer = Agent(
|
||||
role="Content Writer",
|
||||
goal="Create compelling content",
|
||||
backstory="Skilled at crafting engaging narratives"
|
||||
)
|
||||
|
||||
reviewer = Agent(
|
||||
role="Quality Reviewer",
|
||||
goal="Ensure content quality",
|
||||
backstory="Meticulous attention to detail"
|
||||
)
|
||||
|
||||
research_task = Task(
|
||||
description="Research the latest trends in AI",
|
||||
expected_output="Comprehensive research report",
|
||||
agent=researcher
|
||||
)
|
||||
research_task.priority = 2
|
||||
|
||||
urgent_task = Task(
|
||||
description="Write urgent press release",
|
||||
expected_output="Press release draft",
|
||||
agent=writer
|
||||
)
|
||||
urgent_task.priority = 1
|
||||
|
||||
review_task = Task(
|
||||
description="Review and edit content",
|
||||
expected_output="Polished final content",
|
||||
agent=reviewer
|
||||
)
|
||||
review_task.priority = 3
|
||||
|
||||
crew = Crew(
|
||||
agents=[researcher, writer, reviewer],
|
||||
tasks=[research_task, urgent_task, review_task],
|
||||
process=Process.sequential,
|
||||
task_ordering_callback=priority_based_ordering,
|
||||
verbose=True
|
||||
)
|
||||
|
||||
if __name__ == "__main__":
|
||||
print("Starting crew with dynamic task ordering...")
|
||||
result = crew.kickoff()
|
||||
print(f"Completed {len(result.tasks_output)} tasks")
|
||||
@@ -1,4 +1,5 @@
|
||||
import asyncio
|
||||
import inspect
|
||||
import json
|
||||
import re
|
||||
import uuid
|
||||
@@ -113,6 +114,9 @@ class Crew(FlowTrackable, BaseModel):
|
||||
execution.
|
||||
step_callback: Callback to be executed after each step for every agents
|
||||
execution.
|
||||
task_ordering_callback: Callback to determine the next task to execute
|
||||
dynamically. Receives (all_tasks, completed_outputs, current_index)
|
||||
and returns next task index, Task object, or None for default ordering.
|
||||
share_crew: Whether you want to share the complete crew information and
|
||||
execution with crewAI to make the library better, and allow us to
|
||||
train models.
|
||||
@@ -213,6 +217,12 @@ class Crew(FlowTrackable, BaseModel):
|
||||
"It may be used to adjust the output of the crew."
|
||||
),
|
||||
)
|
||||
task_ordering_callback: Callable[
|
||||
[list[Task], list[TaskOutput], int], int | Task | None
|
||||
] | None = Field(
|
||||
default=None,
|
||||
description="Callback to determine the next task to execute. Receives (all_tasks, completed_outputs, current_index) and returns next task index, Task object, or None for default ordering.",
|
||||
)
|
||||
max_rpm: int | None = Field(
|
||||
default=None,
|
||||
description=(
|
||||
@@ -535,6 +545,25 @@ class Crew(FlowTrackable, BaseModel):
|
||||
)
|
||||
return self
|
||||
|
||||
@model_validator(mode="after")
|
||||
def validate_task_ordering_callback(self):
|
||||
"""Validates that the task ordering callback has the correct signature."""
|
||||
if self.task_ordering_callback is not None:
|
||||
if not callable(self.task_ordering_callback):
|
||||
raise ValueError("task_ordering_callback must be callable")
|
||||
|
||||
try:
|
||||
sig = inspect.signature(self.task_ordering_callback)
|
||||
except (ValueError, TypeError):
|
||||
pass
|
||||
else:
|
||||
if len(sig.parameters) != 3:
|
||||
raise ValueError(
|
||||
"task_ordering_callback must accept exactly 3 parameters: (tasks, outputs, current_index)"
|
||||
)
|
||||
|
||||
return self
|
||||
|
||||
@property
|
||||
def key(self) -> str:
|
||||
source: list[str] = [agent.key for agent in self.agents] + [
|
||||
@@ -847,12 +876,12 @@ class Crew(FlowTrackable, BaseModel):
|
||||
start_index: int | None = 0,
|
||||
was_replayed: bool = False,
|
||||
) -> CrewOutput:
|
||||
"""Executes tasks sequentially and returns the final output.
|
||||
"""Executes tasks with optional dynamic ordering and returns the final output.
|
||||
|
||||
Args:
|
||||
tasks (List[Task]): List of tasks to execute
|
||||
manager (Optional[BaseAgent], optional): Manager agent to use for
|
||||
delegation. Defaults to None.
|
||||
start_index (int | None): Starting index for task execution
|
||||
was_replayed (bool): Whether this is a replay execution
|
||||
|
||||
Returns:
|
||||
CrewOutput: Final output of the crew
|
||||
@@ -861,7 +890,8 @@ class Crew(FlowTrackable, BaseModel):
|
||||
task_outputs: list[TaskOutput] = []
|
||||
futures: list[tuple[Task, Future[TaskOutput], int]] = []
|
||||
last_sync_output: TaskOutput | None = None
|
||||
|
||||
executed_task_indices: set[int] = set()
|
||||
|
||||
for task_index, task in enumerate(tasks):
|
||||
if start_index is not None and task_index < start_index:
|
||||
if task.output:
|
||||
@@ -870,7 +900,66 @@ class Crew(FlowTrackable, BaseModel):
|
||||
else:
|
||||
task_outputs = [task.output]
|
||||
last_sync_output = task.output
|
||||
continue
|
||||
executed_task_indices.add(task_index)
|
||||
|
||||
while len(executed_task_indices) < len(tasks):
|
||||
# Find next task to execute
|
||||
if self.task_ordering_callback:
|
||||
try:
|
||||
next_task_result = self.task_ordering_callback(
|
||||
tasks, task_outputs, len(executed_task_indices)
|
||||
)
|
||||
|
||||
if next_task_result is None:
|
||||
task_index = next(i for i in range(len(tasks)) if i not in executed_task_indices)
|
||||
elif isinstance(next_task_result, int):
|
||||
if 0 <= next_task_result < len(tasks) and next_task_result not in executed_task_indices:
|
||||
task_index = next_task_result
|
||||
else:
|
||||
self._logger.log(
|
||||
"warning",
|
||||
f"Invalid or already executed task index {next_task_result} from ordering callback, using default",
|
||||
color="yellow"
|
||||
)
|
||||
task_index = next(i for i in range(len(tasks)) if i not in executed_task_indices)
|
||||
elif isinstance(next_task_result, Task):
|
||||
try:
|
||||
candidate_index = tasks.index(next_task_result)
|
||||
if candidate_index not in executed_task_indices:
|
||||
task_index = candidate_index
|
||||
else:
|
||||
self._logger.log(
|
||||
"warning",
|
||||
"Task from ordering callback already executed, using default",
|
||||
color="yellow"
|
||||
)
|
||||
task_index = next(i for i in range(len(tasks)) if i not in executed_task_indices)
|
||||
except ValueError:
|
||||
self._logger.log(
|
||||
"warning",
|
||||
"Task from ordering callback not found in tasks list, using default",
|
||||
color="yellow"
|
||||
)
|
||||
task_index = next(i for i in range(len(tasks)) if i not in executed_task_indices)
|
||||
else:
|
||||
self._logger.log(
|
||||
"warning",
|
||||
f"Invalid return type from ordering callback: {type(next_task_result)}, using default",
|
||||
color="yellow"
|
||||
)
|
||||
task_index = next(i for i in range(len(tasks)) if i not in executed_task_indices)
|
||||
except Exception as e:
|
||||
self._logger.log(
|
||||
"warning",
|
||||
f"Error in task ordering callback: {e}, using default ordering",
|
||||
color="yellow"
|
||||
)
|
||||
task_index = next(i for i in range(len(tasks)) if i not in executed_task_indices)
|
||||
else:
|
||||
task_index = next(i for i in range(len(tasks)) if i not in executed_task_indices)
|
||||
|
||||
task = tasks[task_index]
|
||||
executed_task_indices.add(task_index)
|
||||
|
||||
agent_to_use = self._get_agent_to_use(task)
|
||||
if agent_to_use is None:
|
||||
@@ -880,9 +969,7 @@ class Crew(FlowTrackable, BaseModel):
|
||||
f"or a manager agent is provided."
|
||||
)
|
||||
|
||||
# Determine which tools to use - task tools take precedence over agent tools
|
||||
tools_for_task = task.tools or agent_to_use.tools or []
|
||||
# Prepare tools and ensure they're compatible with task execution
|
||||
tools_for_task = self._prepare_tools(
|
||||
agent_to_use,
|
||||
task,
|
||||
@@ -923,6 +1010,7 @@ class Crew(FlowTrackable, BaseModel):
|
||||
task_outputs.append(task_output)
|
||||
self._process_task_result(task, task_output)
|
||||
self._store_execution_log(task, task_output, task_index, was_replayed)
|
||||
last_sync_output = task_output
|
||||
|
||||
if futures:
|
||||
task_outputs = self._process_async_tasks(futures, was_replayed)
|
||||
|
||||
302
tests/test_dynamic_task_ordering.py
Normal file
302
tests/test_dynamic_task_ordering.py
Normal file
@@ -0,0 +1,302 @@
|
||||
import pytest
|
||||
from unittest.mock import Mock
|
||||
|
||||
from crewai import Agent, Crew, Task
|
||||
from crewai.process import Process
|
||||
from crewai.task import TaskOutput
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def agents():
|
||||
return [
|
||||
Agent(role="Agent 1", goal="Goal 1", backstory="Backstory 1"),
|
||||
Agent(role="Agent 2", goal="Goal 2", backstory="Backstory 2"),
|
||||
Agent(role="Agent 3", goal="Goal 3", backstory="Backstory 3"),
|
||||
]
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def tasks(agents):
|
||||
return [
|
||||
Task(description="Task 1", expected_output="Output 1", agent=agents[0]),
|
||||
Task(description="Task 2", expected_output="Output 2", agent=agents[1]),
|
||||
Task(description="Task 3", expected_output="Output 3", agent=agents[2]),
|
||||
]
|
||||
|
||||
|
||||
def test_sequential_process_with_reverse_ordering(agents, tasks):
|
||||
"""Test sequential process with reverse task ordering."""
|
||||
execution_order = []
|
||||
|
||||
def reverse_ordering_callback(all_tasks, completed_outputs, current_index):
|
||||
completed_tasks = {id(task) for task in all_tasks if task.output is not None}
|
||||
remaining_indices = [i for i in range(len(all_tasks))
|
||||
if id(all_tasks[i]) not in completed_tasks]
|
||||
if remaining_indices:
|
||||
next_index = max(remaining_indices)
|
||||
execution_order.append(next_index)
|
||||
return next_index
|
||||
return None
|
||||
|
||||
crew = Crew(
|
||||
agents=agents,
|
||||
tasks=tasks,
|
||||
process=Process.sequential,
|
||||
task_ordering_callback=reverse_ordering_callback,
|
||||
verbose=False
|
||||
)
|
||||
|
||||
result = crew.kickoff()
|
||||
|
||||
assert len(result.tasks_output) == 3
|
||||
assert execution_order == [2, 1, 0]
|
||||
|
||||
|
||||
def test_hierarchical_process_with_priority_ordering(agents, tasks):
|
||||
"""Test hierarchical process with priority-based task ordering."""
|
||||
|
||||
tasks[0].priority = 3
|
||||
tasks[1].priority = 1
|
||||
tasks[2].priority = 2
|
||||
|
||||
execution_order = []
|
||||
|
||||
def priority_ordering_callback(all_tasks, completed_outputs, current_index):
|
||||
completed_tasks = {id(task) for task in all_tasks if task.output is not None}
|
||||
remaining_tasks = [
|
||||
(i, task) for i, task in enumerate(all_tasks)
|
||||
if id(task) not in completed_tasks
|
||||
]
|
||||
|
||||
if remaining_tasks:
|
||||
remaining_tasks.sort(key=lambda x: getattr(x[1], 'priority', 999))
|
||||
next_index = remaining_tasks[0][0]
|
||||
execution_order.append(next_index)
|
||||
return next_index
|
||||
|
||||
return None
|
||||
|
||||
crew = Crew(
|
||||
agents=agents,
|
||||
tasks=tasks,
|
||||
process=Process.hierarchical,
|
||||
manager_llm="gpt-4o",
|
||||
task_ordering_callback=priority_ordering_callback,
|
||||
verbose=False
|
||||
)
|
||||
|
||||
result = crew.kickoff()
|
||||
assert len(result.tasks_output) == 3
|
||||
assert execution_order == [1, 2, 0]
|
||||
|
||||
|
||||
def test_task_ordering_callback_with_task_object_return():
|
||||
"""Test callback returning Task object instead of index."""
|
||||
|
||||
agents = [Agent(role="Agent", goal="Goal", backstory="Backstory")]
|
||||
tasks = [
|
||||
Task(description="Task A", expected_output="Output A", agent=agents[0]),
|
||||
Task(description="Task B", expected_output="Output B", agent=agents[0]),
|
||||
]
|
||||
|
||||
execution_order = []
|
||||
|
||||
def task_object_callback(all_tasks, completed_outputs, current_index):
|
||||
if len(completed_outputs) == 0:
|
||||
execution_order.append(1)
|
||||
return all_tasks[1]
|
||||
elif len(completed_outputs) == 1:
|
||||
execution_order.append(0)
|
||||
return all_tasks[0]
|
||||
return None
|
||||
|
||||
crew = Crew(
|
||||
agents=agents,
|
||||
tasks=tasks,
|
||||
process=Process.sequential,
|
||||
task_ordering_callback=task_object_callback,
|
||||
verbose=False
|
||||
)
|
||||
|
||||
result = crew.kickoff()
|
||||
assert len(result.tasks_output) == 2
|
||||
assert execution_order == [1, 0]
|
||||
|
||||
|
||||
def test_invalid_task_ordering_callback_index():
|
||||
"""Test handling of invalid task index from callback."""
|
||||
|
||||
agents = [Agent(role="Agent", goal="Goal", backstory="Backstory")]
|
||||
tasks = [Task(description="Task", expected_output="Output", agent=agents[0])]
|
||||
|
||||
def invalid_callback(all_tasks, completed_outputs, current_index):
|
||||
return 999
|
||||
|
||||
crew = Crew(
|
||||
agents=agents,
|
||||
tasks=tasks,
|
||||
process=Process.sequential,
|
||||
task_ordering_callback=invalid_callback,
|
||||
verbose=False
|
||||
)
|
||||
|
||||
result = crew.kickoff()
|
||||
assert len(result.tasks_output) == 1
|
||||
|
||||
|
||||
def test_task_ordering_callback_exception_handling():
|
||||
"""Test handling of exceptions in task ordering callback."""
|
||||
|
||||
agents = [Agent(role="Agent", goal="Goal", backstory="Backstory")]
|
||||
tasks = [Task(description="Task", expected_output="Output", agent=agents[0])]
|
||||
|
||||
def failing_callback(all_tasks, completed_outputs, current_index):
|
||||
raise ValueError("Callback error")
|
||||
|
||||
crew = Crew(
|
||||
agents=agents,
|
||||
tasks=tasks,
|
||||
process=Process.sequential,
|
||||
task_ordering_callback=failing_callback,
|
||||
verbose=False
|
||||
)
|
||||
|
||||
result = crew.kickoff()
|
||||
assert len(result.tasks_output) == 1
|
||||
|
||||
|
||||
def test_task_ordering_callback_validation():
|
||||
"""Test validation of task ordering callback signature."""
|
||||
|
||||
agents = [Agent(role="Agent", goal="Goal", backstory="Backstory")]
|
||||
tasks = [Task(description="Task", expected_output="Output", agent=agents[0])]
|
||||
|
||||
def invalid_signature_callback(only_one_param):
|
||||
return 0
|
||||
|
||||
with pytest.raises(ValueError, match="task_ordering_callback must accept exactly 3 parameters"):
|
||||
Crew(
|
||||
agents=agents,
|
||||
tasks=tasks,
|
||||
process=Process.sequential,
|
||||
task_ordering_callback=invalid_signature_callback
|
||||
)
|
||||
|
||||
|
||||
def test_no_task_ordering_callback_default_behavior():
|
||||
"""Test that default behavior is unchanged when no callback is provided."""
|
||||
|
||||
agents = [Agent(role="Agent", goal="Goal", backstory="Backstory")]
|
||||
tasks = [
|
||||
Task(description="Task 1", expected_output="Output 1", agent=agents[0]),
|
||||
Task(description="Task 2", expected_output="Output 2", agent=agents[0]),
|
||||
]
|
||||
|
||||
crew = Crew(
|
||||
agents=agents,
|
||||
tasks=tasks,
|
||||
process=Process.sequential,
|
||||
verbose=False
|
||||
)
|
||||
|
||||
result = crew.kickoff()
|
||||
assert len(result.tasks_output) == 2
|
||||
|
||||
|
||||
def test_task_ordering_callback_with_none_return():
|
||||
"""Test callback returning None for default ordering."""
|
||||
|
||||
agents = [Agent(role="Agent", goal="Goal", backstory="Backstory")]
|
||||
tasks = [
|
||||
Task(description="Task 1", expected_output="Output 1", agent=agents[0]),
|
||||
Task(description="Task 2", expected_output="Output 2", agent=agents[0]),
|
||||
]
|
||||
|
||||
def none_callback(all_tasks, completed_outputs, current_index):
|
||||
return None
|
||||
|
||||
crew = Crew(
|
||||
agents=agents,
|
||||
tasks=tasks,
|
||||
process=Process.sequential,
|
||||
task_ordering_callback=none_callback,
|
||||
verbose=False
|
||||
)
|
||||
|
||||
result = crew.kickoff()
|
||||
assert len(result.tasks_output) == 2
|
||||
|
||||
|
||||
def test_task_ordering_callback_invalid_task_object():
|
||||
"""Test handling of invalid Task object from callback."""
|
||||
|
||||
agents = [Agent(role="Agent", goal="Goal", backstory="Backstory")]
|
||||
tasks = [Task(description="Task", expected_output="Output", agent=agents[0])]
|
||||
|
||||
invalid_task = Task(description="Invalid", expected_output="Invalid", agent=agents[0])
|
||||
|
||||
def invalid_task_callback(all_tasks, completed_outputs, current_index):
|
||||
return invalid_task
|
||||
|
||||
crew = Crew(
|
||||
agents=agents,
|
||||
tasks=tasks,
|
||||
process=Process.sequential,
|
||||
task_ordering_callback=invalid_task_callback,
|
||||
verbose=False
|
||||
)
|
||||
|
||||
result = crew.kickoff()
|
||||
assert len(result.tasks_output) == 1
|
||||
|
||||
|
||||
def test_task_ordering_callback_invalid_return_type():
|
||||
"""Test handling of invalid return type from callback."""
|
||||
|
||||
agents = [Agent(role="Agent", goal="Goal", backstory="Backstory")]
|
||||
tasks = [Task(description="Task", expected_output="Output", agent=agents[0])]
|
||||
|
||||
def invalid_type_callback(all_tasks, completed_outputs, current_index):
|
||||
return "invalid"
|
||||
|
||||
crew = Crew(
|
||||
agents=agents,
|
||||
tasks=tasks,
|
||||
process=Process.sequential,
|
||||
task_ordering_callback=invalid_type_callback,
|
||||
verbose=False
|
||||
)
|
||||
|
||||
result = crew.kickoff()
|
||||
assert len(result.tasks_output) == 1
|
||||
|
||||
|
||||
def test_task_ordering_prevents_infinite_loops():
|
||||
"""Test that task ordering prevents infinite loops by tracking executed tasks."""
|
||||
|
||||
agents = [Agent(role="Agent", goal="Goal", backstory="Backstory")]
|
||||
tasks = [
|
||||
Task(description="Task 1", expected_output="Output 1", agent=agents[0]),
|
||||
Task(description="Task 2", expected_output="Output 2", agent=agents[0]),
|
||||
]
|
||||
|
||||
call_count = 0
|
||||
|
||||
def loop_callback(all_tasks, completed_outputs, current_index):
|
||||
nonlocal call_count
|
||||
call_count += 1
|
||||
if call_count > 10:
|
||||
pytest.fail("Callback called too many times, possible infinite loop")
|
||||
return 0
|
||||
|
||||
crew = Crew(
|
||||
agents=agents,
|
||||
tasks=tasks,
|
||||
process=Process.sequential,
|
||||
task_ordering_callback=loop_callback,
|
||||
verbose=False
|
||||
)
|
||||
|
||||
result = crew.kickoff()
|
||||
assert len(result.tasks_output) == 2
|
||||
assert call_count <= 4
|
||||
Reference in New Issue
Block a user