mirror of
https://github.com/crewAIInc/crewAI.git
synced 2026-01-10 08:38:30 +00:00
conditional task feat
This commit is contained in:
@@ -2,5 +2,6 @@ from crewai.agent import Agent
|
||||
from crewai.crew import Crew
|
||||
from crewai.process import Process
|
||||
from crewai.task import Task
|
||||
from crewai.conditional_task import ConditionalTask
|
||||
|
||||
__all__ = ["Agent", "Crew", "Process", "Task"]
|
||||
__all__ = ["Agent", "Crew", "Process", "Task", "ConditionalTask"]
|
||||
|
||||
35
src/crewai/conditional_task.py
Normal file
35
src/crewai/conditional_task.py
Normal file
@@ -0,0 +1,35 @@
|
||||
from typing import Callable, Optional, Any
|
||||
from crewai.task import Task
|
||||
from crewai.tasks.task_output import TaskOutput
|
||||
|
||||
|
||||
class ConditionalTask(Task):
|
||||
"""
|
||||
A task that can be conditionally executed based on the output of another task.
|
||||
Note: This cannot be the only task you have in your crew and cannot be the first since its needs context from the previous task.
|
||||
"""
|
||||
|
||||
condition: Optional[Callable[[TaskOutput], bool]] = None
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
*args,
|
||||
condition: Optional[Callable[[TaskOutput], bool]] = None,
|
||||
**kwargs,
|
||||
):
|
||||
super().__init__(*args, **kwargs)
|
||||
self.condition = condition
|
||||
|
||||
def should_execute(self, context: Any) -> bool:
|
||||
"""
|
||||
Determines whether the conditional task should be executed based on the provided context.
|
||||
|
||||
Args:
|
||||
context (Any): The context or output from the previous task that will be evaluated by the condition.
|
||||
|
||||
Returns:
|
||||
bool: True if the task should be executed, False otherwise.
|
||||
"""
|
||||
if self.condition:
|
||||
return self.condition(context)
|
||||
return True
|
||||
@@ -28,6 +28,8 @@ from crewai.memory.long_term.long_term_memory import LongTermMemory
|
||||
from crewai.memory.short_term.short_term_memory import ShortTermMemory
|
||||
from crewai.process import Process
|
||||
from crewai.task import Task
|
||||
from crewai.conditional_task import ConditionalTask
|
||||
from crewai.tasks.output_format import OutputFormat
|
||||
from crewai.tasks.task_output import TaskOutput
|
||||
from crewai.telemetry import Telemetry
|
||||
from crewai.tools.agent_tools import AgentTools
|
||||
@@ -295,6 +297,29 @@ class Crew(BaseModel):
|
||||
|
||||
return self
|
||||
|
||||
@model_validator(mode="after")
|
||||
def validate_first_task(self) -> "Crew":
|
||||
"""Ensure the first task is not a ConditionalTask."""
|
||||
if self.tasks and isinstance(self.tasks[0], ConditionalTask):
|
||||
raise PydanticCustomError(
|
||||
"invalid_first_task",
|
||||
"The first task cannot be a ConditionalTask.",
|
||||
{},
|
||||
)
|
||||
return self
|
||||
|
||||
@model_validator(mode="after")
|
||||
def validate_async_tasks_not_async(self) -> "Crew":
|
||||
"""Ensure the first task is not a ConditionalTask."""
|
||||
for task in self.tasks:
|
||||
if task.async_execution and isinstance(task, ConditionalTask):
|
||||
raise PydanticCustomError(
|
||||
"invalid_async_conditional_task",
|
||||
f"Conditional Task: {task.description} , cannot be executed asynchronously.", # type: ignore # Argument of type "str" cannot be assigned to parameter "message_template" of type "LiteralString"
|
||||
{},
|
||||
)
|
||||
return self
|
||||
|
||||
@model_validator(mode="after")
|
||||
def validate_async_task_cannot_include_sequential_async_tasks_in_context(self):
|
||||
"""
|
||||
@@ -622,7 +647,35 @@ class Crew(BaseModel):
|
||||
f"No agent available for task: {task.description}. Ensure that either the task has an assigned agent or a manager agent is provided."
|
||||
)
|
||||
self._log_task_start(task, agent_to_use)
|
||||
if isinstance(task, ConditionalTask):
|
||||
if futures:
|
||||
task_outputs.extend(
|
||||
self._process_async_tasks(futures, was_replayed)
|
||||
)
|
||||
futures.clear()
|
||||
|
||||
previous_output = task_outputs[task_index - 1] if task_outputs else None
|
||||
if previous_output is not None and not task.should_execute(
|
||||
previous_output
|
||||
):
|
||||
self._logger.log(
|
||||
"info",
|
||||
f"Skipping conditional task: {task.description}",
|
||||
color="yellow",
|
||||
)
|
||||
skipped_task_output = TaskOutput(
|
||||
description=task.description,
|
||||
raw="",
|
||||
agent=task.agent.role if task.agent else "",
|
||||
output_format=OutputFormat.RAW,
|
||||
)
|
||||
if not was_replayed:
|
||||
self._store_execution_log(
|
||||
task,
|
||||
skipped_task_output,
|
||||
task_index,
|
||||
)
|
||||
continue
|
||||
if task.async_execution:
|
||||
context = self._get_context(
|
||||
task, [last_sync_output] if last_sync_output else []
|
||||
@@ -687,6 +740,34 @@ class Crew(BaseModel):
|
||||
# Add the new tool
|
||||
task.tools.append(new_tool)
|
||||
|
||||
def _handle_conditional_task(
|
||||
self,
|
||||
task: ConditionalTask,
|
||||
futures: List[Tuple[Task, Future[TaskOutput], int]],
|
||||
task_outputs: List[TaskOutput],
|
||||
task_index: int,
|
||||
was_replayed: bool,
|
||||
) -> bool:
|
||||
"""
|
||||
Handle conditional task execution.
|
||||
|
||||
Returns:
|
||||
bool: True if the task should be executed, False if it should be skipped.
|
||||
"""
|
||||
if futures:
|
||||
task_outputs.extend(self._process_async_tasks(futures, was_replayed))
|
||||
futures.clear()
|
||||
|
||||
previous_output = task_outputs[task_index - 1] if task_outputs else None
|
||||
if previous_output is not None and not task.should_execute(previous_output):
|
||||
self._logger.log(
|
||||
"info",
|
||||
f"Skipping conditional task: {task.description}",
|
||||
color="yellow",
|
||||
)
|
||||
return False
|
||||
return True
|
||||
|
||||
def _log_task_start(self, task: Task, agent: Optional[BaseAgent]):
|
||||
color = self._logging_color
|
||||
role = agent.role if agent else "None"
|
||||
|
||||
@@ -10,6 +10,8 @@ class Printer:
|
||||
self._print_bold_purple(content)
|
||||
elif color == "bold_blue":
|
||||
self._print_bold_blue(content)
|
||||
elif color == "yellow":
|
||||
self._print_yellow(content)
|
||||
else:
|
||||
print(content)
|
||||
|
||||
@@ -27,3 +29,6 @@ class Printer:
|
||||
|
||||
def _print_bold_blue(self, content):
|
||||
print("\033[1m\033[94m {}\033[00m".format(content))
|
||||
|
||||
def _print_yellow(self, content):
|
||||
print("\033[93m {}\033[00m".format(content))
|
||||
|
||||
Reference in New Issue
Block a user