mirror of
https://github.com/crewAIInc/crewAI.git
synced 2026-05-03 08:12:39 +00:00
Add selective task execution feature for issue #2941
- Add tags field to Task class for categorization - Add task_selector parameter to Crew class - Implement task filtering in _execute_tasks method - Add Process.selective type with validation - Add helper method for tag-based selection - Add comprehensive tests covering all scenarios - Maintain backward compatibility with existing crews Fixes #2941: Users can now run only specific agents/tasks based on input parameters like 'action', rather than executing the entire crew process. Co-Authored-By: João <joao@crewai.com>
This commit is contained in:
@@ -200,6 +200,10 @@ class Crew(FlowTrackable, BaseModel):
|
||||
default_factory=list,
|
||||
description="List of callbacks to be executed after crew kickoff. It may be used to adjust the output of the crew.",
|
||||
)
|
||||
task_selector: Optional[Callable[[Dict[str, Any], Task], bool]] = Field(
|
||||
default=None,
|
||||
description="Function to determine which tasks should execute based on inputs and task properties.",
|
||||
)
|
||||
max_rpm: Optional[int] = Field(
|
||||
default=None,
|
||||
description="Maximum number of requests per minute for the crew execution to be respected.",
|
||||
@@ -504,6 +508,17 @@ class Crew(FlowTrackable, BaseModel):
|
||||
)
|
||||
return self
|
||||
|
||||
@model_validator(mode="after")
|
||||
def validate_selective_process_requirements(self) -> "Crew":
|
||||
"""Ensure selective process has required task_selector."""
|
||||
if self.process == Process.selective and not self.task_selector:
|
||||
raise PydanticCustomError(
|
||||
"missing_task_selector",
|
||||
"Selective process requires a task_selector to be defined.",
|
||||
{},
|
||||
)
|
||||
return self
|
||||
|
||||
@property
|
||||
def key(self) -> str:
|
||||
source: List[str] = [agent.key for agent in self.agents] + [
|
||||
@@ -661,6 +676,8 @@ class Crew(FlowTrackable, BaseModel):
|
||||
result = self._run_sequential_process()
|
||||
elif self.process == Process.hierarchical:
|
||||
result = self._run_hierarchical_process()
|
||||
elif self.process == Process.selective:
|
||||
result = self._run_selective_process()
|
||||
else:
|
||||
raise NotImplementedError(
|
||||
f"The process '{self.process}' is not implemented yet."
|
||||
@@ -777,6 +794,12 @@ class Crew(FlowTrackable, BaseModel):
|
||||
self._create_manager_agent()
|
||||
return self._execute_tasks(self.tasks)
|
||||
|
||||
def _run_selective_process(self) -> CrewOutput:
|
||||
"""Executes tasks selectively based on task_selector and returns the final output."""
|
||||
if not self.task_selector:
|
||||
raise ValueError("Selective process requires a task_selector to be defined.")
|
||||
return self._execute_tasks(self.tasks)
|
||||
|
||||
def _create_manager_agent(self):
|
||||
i18n = I18N(prompt_file=self.prompt_file)
|
||||
if self.manager_agent is not None:
|
||||
@@ -812,12 +835,22 @@ class Crew(FlowTrackable, BaseModel):
|
||||
|
||||
Args:
|
||||
tasks (List[Task]): List of tasks to execute
|
||||
manager (Optional[BaseAgent], optional): Manager agent to use for delegation. Defaults to None.
|
||||
start_index (Optional[int], optional): Starting index for task execution. Defaults to 0.
|
||||
was_replayed (bool, optional): Whether this is a replayed execution. Defaults to False.
|
||||
|
||||
Returns:
|
||||
CrewOutput: Final output of the crew
|
||||
"""
|
||||
|
||||
if self.task_selector and self._inputs:
|
||||
filtered_tasks = [
|
||||
task for task in tasks
|
||||
if self.task_selector(self._inputs, task)
|
||||
]
|
||||
if not filtered_tasks:
|
||||
raise ValueError("No tasks match the selection criteria. At least one task must be selected for execution.")
|
||||
tasks = filtered_tasks
|
||||
|
||||
task_outputs: List[TaskOutput] = []
|
||||
futures: List[Tuple[Task, Future[TaskOutput], int]] = []
|
||||
last_sync_output: Optional[TaskOutput] = None
|
||||
@@ -1506,3 +1539,27 @@ class Crew(FlowTrackable, BaseModel):
|
||||
"""Reset crew and agent knowledge storage."""
|
||||
for ks in knowledges:
|
||||
ks.reset()
|
||||
|
||||
@staticmethod
|
||||
def create_tag_selector(action_key: str = "action", tag_mapping: Optional[Dict[str, List[str]]] = None) -> Callable[[Dict[str, Any], Task], bool]:
|
||||
"""Create a task selector function based on tags and input action.
|
||||
|
||||
Args:
|
||||
action_key: Key in inputs dict that specifies the action (default: "action")
|
||||
tag_mapping: Optional mapping of action values to required tags
|
||||
|
||||
Returns:
|
||||
Function that selects tasks based on tags matching the action
|
||||
"""
|
||||
def selector(inputs: Dict[str, Any], task: Task) -> bool:
|
||||
action = inputs.get(action_key)
|
||||
if not action or not task.tags:
|
||||
return True
|
||||
|
||||
if tag_mapping and action in tag_mapping:
|
||||
required_tags = tag_mapping[action]
|
||||
return any(tag in task.tags for tag in required_tags)
|
||||
else:
|
||||
return action in task.tags
|
||||
|
||||
return selector
|
||||
|
||||
@@ -8,4 +8,5 @@ class Process(str, Enum):
|
||||
|
||||
sequential = "sequential"
|
||||
hierarchical = "hierarchical"
|
||||
selective = "selective"
|
||||
# TODO: consensual = 'consensual'
|
||||
|
||||
@@ -139,6 +139,10 @@ class Task(BaseModel):
|
||||
description="Whether the task should instruct the agent to return the final answer formatted in Markdown",
|
||||
default=False,
|
||||
)
|
||||
tags: Optional[List[str]] = Field(
|
||||
default=None,
|
||||
description="Tags to categorize this task for selective execution.",
|
||||
)
|
||||
converter_cls: Optional[Type[Converter]] = Field(
|
||||
description="A converter class used to export structured output",
|
||||
default=None,
|
||||
|
||||
@@ -1538,6 +1538,172 @@ def test_set_agents_step_callback():
|
||||
assert researcher_agent.step_callback is not None
|
||||
|
||||
|
||||
@pytest.mark.vcr(filter_headers=["authorization"])
|
||||
def test_selective_execution_with_tags(researcher, writer):
|
||||
"""Test selective task execution based on tags and input action."""
|
||||
|
||||
forecast_task = Task(
|
||||
description="Analyze forecast data",
|
||||
expected_output="Forecast analysis",
|
||||
agent=researcher,
|
||||
tags=["forecast", "analysis"]
|
||||
)
|
||||
|
||||
news_task = Task(
|
||||
description="Summarize news",
|
||||
expected_output="News summary",
|
||||
agent=writer,
|
||||
tags=["news", "summary"]
|
||||
)
|
||||
|
||||
crew = Crew(
|
||||
agents=[researcher, writer],
|
||||
tasks=[forecast_task, news_task],
|
||||
task_selector=Crew.create_tag_selector()
|
||||
)
|
||||
|
||||
result = crew.kickoff(inputs={"action": "forecast"})
|
||||
assert result is not None
|
||||
|
||||
|
||||
def test_selective_process_type(researcher):
|
||||
"""Test selective process type."""
|
||||
task = Task(
|
||||
description="Test task",
|
||||
expected_output="Test output",
|
||||
agent=researcher,
|
||||
tags=["test"]
|
||||
)
|
||||
|
||||
crew = Crew(
|
||||
agents=[researcher],
|
||||
tasks=[task],
|
||||
process=Process.selective,
|
||||
task_selector=Crew.create_tag_selector()
|
||||
)
|
||||
|
||||
result = crew.kickoff(inputs={"action": "test"})
|
||||
assert result is not None
|
||||
|
||||
|
||||
def test_selective_execution_no_matching_tasks_error(researcher):
|
||||
"""Test error when no tasks match selection criteria."""
|
||||
task = Task(
|
||||
description="Test task",
|
||||
expected_output="Test output",
|
||||
agent=researcher,
|
||||
tags=["other"]
|
||||
)
|
||||
|
||||
crew = Crew(
|
||||
agents=[researcher],
|
||||
tasks=[task],
|
||||
task_selector=Crew.create_tag_selector()
|
||||
)
|
||||
|
||||
with pytest.raises(ValueError, match="No tasks match the selection criteria"):
|
||||
crew.kickoff(inputs={"action": "nonexistent"})
|
||||
|
||||
|
||||
def test_selective_process_missing_selector_error(researcher):
|
||||
"""Test error when selective process lacks task_selector."""
|
||||
from pydantic import ValidationError
|
||||
|
||||
task = Task(
|
||||
description="Test task",
|
||||
expected_output="Test output",
|
||||
agent=researcher
|
||||
)
|
||||
|
||||
with pytest.raises(ValidationError, match="Selective process requires a task_selector"):
|
||||
Crew(
|
||||
agents=[researcher],
|
||||
tasks=[task],
|
||||
process=Process.selective
|
||||
)
|
||||
|
||||
|
||||
def test_tag_selector_with_mapping(researcher, writer):
|
||||
"""Test tag selector with custom tag mapping."""
|
||||
task1 = Task(
|
||||
description="Task 1",
|
||||
expected_output="Output 1",
|
||||
agent=researcher,
|
||||
tags=["data_analysis"]
|
||||
)
|
||||
|
||||
task2 = Task(
|
||||
description="Task 2",
|
||||
expected_output="Output 2",
|
||||
agent=writer,
|
||||
tags=["reporting"]
|
||||
)
|
||||
|
||||
tag_mapping = {
|
||||
"analyze": ["data_analysis", "research"],
|
||||
"report": ["reporting", "writing"]
|
||||
}
|
||||
|
||||
crew = Crew(
|
||||
agents=[researcher, writer],
|
||||
tasks=[task1, task2],
|
||||
task_selector=Crew.create_tag_selector(tag_mapping=tag_mapping)
|
||||
)
|
||||
|
||||
result = crew.kickoff(inputs={"action": "analyze"})
|
||||
assert result is not None
|
||||
|
||||
|
||||
def test_selective_execution_no_action_executes_all(researcher, writer):
|
||||
"""Test that when no action is specified, all tasks execute."""
|
||||
task1 = Task(
|
||||
description="Task 1",
|
||||
expected_output="Output 1",
|
||||
agent=researcher,
|
||||
tags=["tag1"]
|
||||
)
|
||||
|
||||
task2 = Task(
|
||||
description="Task 2",
|
||||
expected_output="Output 2",
|
||||
agent=writer,
|
||||
tags=["tag2"]
|
||||
)
|
||||
|
||||
crew = Crew(
|
||||
agents=[researcher, writer],
|
||||
tasks=[task1, task2],
|
||||
task_selector=Crew.create_tag_selector()
|
||||
)
|
||||
|
||||
result = crew.kickoff(inputs={})
|
||||
assert result is not None
|
||||
|
||||
|
||||
def test_selective_execution_no_tags_executes_all(researcher, writer):
|
||||
"""Test that tasks without tags execute when using selective execution."""
|
||||
task1 = Task(
|
||||
description="Task 1",
|
||||
expected_output="Output 1",
|
||||
agent=researcher
|
||||
)
|
||||
|
||||
task2 = Task(
|
||||
description="Task 2",
|
||||
expected_output="Output 2",
|
||||
agent=writer
|
||||
)
|
||||
|
||||
crew = Crew(
|
||||
agents=[researcher, writer],
|
||||
tasks=[task1, task2],
|
||||
task_selector=Crew.create_tag_selector()
|
||||
)
|
||||
|
||||
result = crew.kickoff(inputs={"action": "anything"})
|
||||
assert result is not None
|
||||
|
||||
|
||||
def test_dont_set_agents_step_callback_if_already_set():
|
||||
from unittest.mock import patch
|
||||
|
||||
|
||||
66
tests/test_selective_execution_example.py
Normal file
66
tests/test_selective_execution_example.py
Normal file
@@ -0,0 +1,66 @@
|
||||
"""Example demonstrating selective execution for issue #2941."""
|
||||
|
||||
import pytest
|
||||
from crewai import Agent, Crew, Task, Process
|
||||
|
||||
|
||||
@pytest.mark.vcr(filter_headers=["authorization"])
|
||||
def test_issue_2941_example():
|
||||
"""Reproduce and test the exact scenario from issue #2941."""
|
||||
|
||||
holiday_agent = Agent(role="Holiday Researcher", goal="Research holidays", backstory="Expert in holidays")
|
||||
macro_agent = Agent(role="Macro Analyst", goal="Analyze macro data", backstory="Expert in macroeconomics")
|
||||
news_agent = Agent(role="News Summarizer", goal="Summarize news", backstory="Expert in news analysis")
|
||||
forecast_agent = Agent(role="Forecaster", goal="Create forecasts", backstory="Expert in forecasting")
|
||||
query_agent = Agent(role="Query Handler", goal="Handle user queries", backstory="Expert in query processing")
|
||||
|
||||
holiday_task = Task(description="Research holiday information", expected_output="Holiday data", agent=holiday_agent, tags=["holiday"])
|
||||
macro_task = Task(description="Extract macroeconomic data", expected_output="Macro data", agent=macro_agent, tags=["macro"])
|
||||
news_task = Task(description="Summarize relevant news", expected_output="News summary", agent=news_agent, tags=["news"])
|
||||
forecast_task = Task(description="Generate forecast", expected_output="Forecast result", agent=forecast_agent, tags=["forecast"])
|
||||
query_task = Task(description="Handle user query", expected_output="Query response", agent=query_agent, tags=["query"])
|
||||
|
||||
crew = Crew(
|
||||
agents=[holiday_agent, macro_agent, news_agent, forecast_agent, query_agent],
|
||||
tasks=[holiday_task, macro_task, news_task, forecast_task, query_task],
|
||||
process=Process.selective,
|
||||
task_selector=Crew.create_tag_selector()
|
||||
)
|
||||
|
||||
inputs = {
|
||||
'data_file': 'sample.csv',
|
||||
'action': 'forecast',
|
||||
'country_code': 'US',
|
||||
'topic': 'Egg_prices',
|
||||
'query': "Provide forecasted result on the input data"
|
||||
}
|
||||
|
||||
result = crew.kickoff(inputs=inputs)
|
||||
assert result is not None
|
||||
|
||||
|
||||
def test_multiple_actions_example():
|
||||
"""Test crew that can handle multiple different actions."""
|
||||
|
||||
researcher = Agent(role="Researcher", goal="Research topics", backstory="Expert researcher")
|
||||
analyst = Agent(role="Analyst", goal="Analyze data", backstory="Expert analyst")
|
||||
writer = Agent(role="Writer", goal="Write reports", backstory="Expert writer")
|
||||
|
||||
research_task = Task(description="Research the topic", expected_output="Research findings", agent=researcher, tags=["research", "data_gathering"])
|
||||
analysis_task = Task(description="Analyze the data", expected_output="Analysis results", agent=analyst, tags=["analysis", "forecast"])
|
||||
writing_task = Task(description="Write the report", expected_output="Final report", agent=writer, tags=["writing", "summary"])
|
||||
|
||||
crew = Crew(
|
||||
agents=[researcher, analyst, writer],
|
||||
tasks=[research_task, analysis_task, writing_task],
|
||||
task_selector=Crew.create_tag_selector()
|
||||
)
|
||||
|
||||
research_result = crew.kickoff(inputs={"action": "research"})
|
||||
assert research_result is not None
|
||||
|
||||
analysis_result = crew.kickoff(inputs={"action": "analysis"})
|
||||
assert analysis_result is not None
|
||||
|
||||
writing_result = crew.kickoff(inputs={"action": "writing"})
|
||||
assert writing_result is not None
|
||||
Reference in New Issue
Block a user