Mirror of https://github.com/crewAIInc/crewAI.git, synced 2026-01-05 06:08:29 +00:00
Compare commits: feature/re ... lj/conditi (6 commits)
| Author | SHA1 | Date |
|---|---|---|
| | 0cc37e0d72 | |
| | bb33e1813d | |
| | 96dc96d13c | |
| | 6efbe8c5a5 | |
| | 4e8f69a7b0 | |
| | 60d0f56e2d | |
src/crewai/conditional_task.py (new file, 37 lines)
@@ -0,0 +1,37 @@
from typing import Callable, Optional, Any

from pydantic import BaseModel

from crewai.task import Task


class ConditionalTask(Task):
    """
    A task that can be conditionally executed based on the output of another task.
    Note: This cannot be the only task you have in your crew and cannot be the first, since it needs context from the previous task.
    """

    condition: Optional[Callable[[BaseModel], bool]] = None

    def __init__(
        self,
        *args,
        condition: Optional[Callable[[BaseModel], bool]] = None,
        **kwargs,
    ):
        super().__init__(*args, **kwargs)
        self.condition = condition

    def should_execute(self, context: Any) -> bool:
        """
        Determines whether the conditional task should be executed based on the provided context.

        Args:
            context (Any): The context or output from the previous task that will be evaluated by the condition.

        Returns:
            bool: True if the task should be executed, False otherwise.
        """
        if self.condition:
            return self.condition(context)
        return True
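To make the intended wiring concrete, here is a minimal usage sketch based on the class above and the tests further down in this diff. The `is_substantial` helper, the agent definitions, and the task texts are illustrative assumptions, not part of the branch; the condition receives whatever `previous_output.result()` yields in the crew.py hunk below.

```python
from crewai import Agent, Crew, Task
from crewai.conditional_task import ConditionalTask

# Illustrative condition: run the follow-up task only when the previous
# task's materialized result looks substantial enough.
def is_substantial(previous_result) -> bool:
    return len(str(previous_result)) > 100

researcher = Agent(role="Researcher", goal="Research a topic", backstory="...")
writer = Agent(role="Writer", goal="Write about the topic", backstory="...")

research = Task(
    description="Collect notes on the topic.",
    expected_output="A list of notes.",
    agent=researcher,
)
summary = ConditionalTask(
    description="Summarize the notes from the previous task.",
    expected_output="A short summary.",
    condition=is_substantial,  # evaluated through should_execute(...)
    agent=writer,
)

crew = Crew(agents=[researcher, writer], tasks=[research, summary])
```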
@@ -21,6 +21,7 @@ from pydantic_core import PydanticCustomError
 from crewai.agent import Agent
 from crewai.agents.agent_builder.base_agent import BaseAgent
 from crewai.agents.cache import CacheHandler
+from crewai.conditional_task import ConditionalTask
 from crewai.crews.crew_output import CrewOutput
 from crewai.memory.entity.entity_memory import EntityMemory
 from crewai.memory.long_term.long_term_memory import LongTermMemory
@@ -223,6 +224,17 @@ class Crew(BaseModel):
                 agent.set_rpm_controller(self._rpm_controller)
         return self
 
+    @model_validator(mode="after")
+    def validate_first_task(self) -> "Crew":
+        """Ensure the first task is not a ConditionalTask."""
+        if self.tasks and isinstance(self.tasks[0], ConditionalTask):
+            raise PydanticCustomError(
+                "invalid_first_task",
+                "The first task cannot be a ConditionalTask.",
+                {},
+            )
+        return self
+
     def _setup_from_config(self):
         assert self.config is not None, "Config should not be None."
 
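Because `validate_first_task` runs as a Pydantic `model_validator`, the check fires at `Crew` construction time and surfaces as a `ValidationError`, which is what the new test at the bottom of this diff asserts. A small sketch, with illustrative agent and task values:

```python
import pydantic_core
from crewai import Agent, Crew
from crewai.conditional_task import ConditionalTask

agent = Agent(role="Writer", goal="Write", backstory="...")
first = ConditionalTask(
    description="Would need output from an earlier task.",
    expected_output="A paragraph.",
    agent=agent,
)

try:
    Crew(agents=[agent], tasks=[first])
except pydantic_core.ValidationError as err:
    # Wraps the PydanticCustomError raised by validate_first_task.
    print(err)
```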
@@ -397,7 +409,27 @@
         futures: List[Tuple[Task, Future[TaskOutput]]] = []
 
         for task in self.tasks:
-            if task.agent.allow_delegation:  # type: ignore # Item "None" of "Agent | None" has no attribute "allow_delegation"
+            if isinstance(task, ConditionalTask):
+                if futures:
+                    task_outputs = []
+                    for future_task, future in futures:
+                        task_output = future.result()
+                        task_outputs.append(task_output)
+                        self._process_task_result(future_task, task_output)
+                    futures.clear()
+
+                previous_output = task_outputs[-1] if task_outputs else None
+                if previous_output is not None and not task.should_execute(
+                    previous_output.result()
+                ):
+                    self._logger.log(
+                        "info",
+                        f"Skipping conditional task: {task.description}",
+                        color="yellow",
+                    )
+                    continue
+
+            if task.agent and task.agent.allow_delegation:
                 agents_for_delegation = [
                     agent for agent in self.agents if agent != task.agent
                 ]
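In short, any pending async futures are drained first so the condition always sees the most recent completed output, and only then is `should_execute` evaluated against that output's `result()`. A condensed restatement of that decision outside the `Crew` class (the function and parameter names here are illustrative):

```python
# Condensed sketch of the skip decision made in the loop above.
def should_skip(conditional_task, task_outputs) -> bool:
    previous_output = task_outputs[-1] if task_outputs else None
    if previous_output is None:
        return False  # nothing to evaluate against, so the task runs
    return not conditional_task.should_execute(previous_output.result())
```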
@@ -438,9 +470,8 @@
                 task_output = task.execute_sync(
                     agent=task.agent, context=context, tools=task.tools
                 )
-                task_outputs = [task_output]
+                task_outputs.append(task_output)
                 self._process_task_result(task, task_output)
 
         if futures:
             # Clear task_outputs before processing async tasks
             task_outputs = []
@@ -451,8 +482,14 @@
 
         final_string_output = aggregate_raw_outputs_from_task_outputs(task_outputs)
         self._finish_execution(final_string_output)
 
-        token_usage = self.calculate_usage_metrics()
+        # TODO: need to revert
+        # token_usage = self.calculate_usage_metrics()
+        token_usage = {
+            "total_tokens": 0,
+            "prompt_tokens": 0,
+            "completion_tokens": 0,
+            "successful_requests": 0,
+        }
 
         return self._format_output(task_outputs, token_usage)
 
@@ -595,9 +632,17 @@
         """
         Formats the output of the crew execution.
         """
+
+        # breakpoint()
+        task_output = []
+        for task in self.tasks:
+            if task.output:
+                # print("task.output", task.output)
+                task_output.append(task.output.result())
         return CrewOutput(
             output=output,
-            tasks_output=[task.output for task in self.tasks if task],
+            # tasks_output=[task.output for task in self.tasks if task],
+            tasks_output=task_output,
             token_usage=token_usage,
         )
 
@@ -8,7 +8,11 @@ from crewai.utilities.formatter import aggregate_raw_outputs_from_task_outputs
 
 class CrewOutput(BaseModel):
     output: List[TaskOutput] = Field(description="Result of the final task")
-    tasks_output: list[TaskOutput] = Field(
+    # NOTE HERE
+    # tasks_output: list[TaskOutput] = Field(
+    #     description="Output of each task", default=[]
+    # )
+    tasks_output: list[Union[str, BaseModel, Dict[str, Any]]] = Field(
         description="Output of each task", default=[]
     )
     token_usage: Dict[str, Any] = Field(
@@ -18,7 +22,7 @@ class CrewOutput(BaseModel):
     # TODO: Ask @joao what is the desired behavior here
     def result(
         self,
-    ) -> List[str | BaseModel | Dict[str, Any]]]:
+    ) -> List[str | BaseModel | Dict[str, Any]]:
         """Return the result of the task based on the available output."""
         results = [output.result() for output in self.output]
         return results
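With the field change above, a kickoff result on this branch can be unpacked roughly as follows. This is a sketch against the WIP state of `CrewOutput`, assuming `crew` is a `Crew` like the one sketched earlier; the zeroed `token_usage` placeholder comes from the crew.py hunk above.

```python
crew_output = crew.kickoff()

# result() materializes each TaskOutput in `output` into a str, BaseModel, or dict.
for value in crew_output.result():
    print(value)

# On this branch tasks_output already holds the materialized per-task results.
print(crew_output.tasks_output)
print(crew_output.token_usage)  # currently the zeroed placeholder dict
```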
@@ -8,6 +8,8 @@ class Printer:
             self._print_bold_green(content)
         elif color == "bold_purple":
             self._print_bold_purple(content)
+        elif color == "yellow":
+            self._print_yellow(content)
         else:
             print(content)
 
@@ -22,3 +24,6 @@ class Printer:
 
     def _print_red(self, content):
         print("\033[91m {}\033[00m".format(content))
+
+    def _print_yellow(self, content):
+        print("\033[93m {}\033[00m".format(content))
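The new helper mirrors the existing ones: `\033[93m` is the ANSI bright-yellow escape and `\033[00m` resets the color, which is what the crew uses when logging a skipped conditional task. A standalone illustration; the commented `Printer().print(...)` call assumes the surrounding `print` method takes `content` and `color`, as the `elif` chain above suggests.

```python
# Same ANSI escape as _print_yellow: 93 = bright yellow, 00 = reset.
print("\033[93m Skipping conditional task: demo task\033[00m")

# Through the Printer helper (assumed signature based on the elif chain above):
# Printer().print("Skipping conditional task: demo task", color="yellow")
```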
tests/cassettes/test_conditional_skips_task_if_condition_is_met.yaml (new file, 2683 lines)
File diff suppressed because it is too large
@@ -3,13 +3,14 @@
 import json
 from concurrent.futures import Future
 from unittest import mock
-from unittest.mock import patch
+from unittest.mock import patch, MagicMock
 
 import pydantic_core
 import pytest
 
 from crewai.agent import Agent
 from crewai.agents.cache import CacheHandler
+from crewai.conditional_task import ConditionalTask
 from crewai.crew import Crew
 from crewai.crews.crew_output import CrewOutput
 from crewai.memory.contextual.contextual_memory import ContextualMemory
@@ -559,7 +560,6 @@ def test_hierarchical_async_task_execution_completion():
 
 @pytest.mark.vcr(filter_headers=["authorization"])
 def test_single_task_with_async_execution():
-
     researcher_agent = Agent(
         role="Researcher",
         goal="Make the best research and analysis on content about AI and AI agents",
@@ -713,7 +713,6 @@ def test_async_task_execution_call_count():
     ) as mock_execute_sync, patch.object(
         Task, "execute_async", return_value=mock_future
     ) as mock_execute_async:
-
         crew.kickoff()
 
         assert mock_execute_async.call_count == 2
@@ -1135,8 +1134,6 @@ def test_code_execution_flag_adds_code_tool_upon_kickoff():
 
 @pytest.mark.vcr(filter_headers=["authorization"])
 def test_delegation_is_not_enabled_if_there_are_only_one_agent():
-    from unittest.mock import patch
-
     researcher = Agent(
         role="Researcher",
         goal="Make the best research and analysis on content about AI and AI agents",
@@ -1204,6 +1201,82 @@ def test_agent_usage_metrics_are_captured_for_sequential_process():
         assert crew.usage_metrics[key] > 0, f"Value for key '{key}' is zero"
 
 
+def test_conditional_task_requirement_breaks_when_singular_conditional_task():
+    task = ConditionalTask(
+        description="Come up with a list of 5 interesting ideas to explore for an article, then write one amazing paragraph highlight for each idea that showcases how good an article about this topic could be. Return the list of ideas with their paragraph and your notes.",
+        expected_output="5 bullet points with a paragraph for each idea.",
+    )
+
+    with pytest.raises(pydantic_core._pydantic_core.ValidationError):
+        Crew(
+            agents=[researcher, writer],
+            tasks=[task],
+        )
+
+
+@pytest.mark.vcr(filter_headers=["authorization"])
+def test_conditional_should_not_execute():
+    task1 = Task(description="Return hello", expected_output="say hi", agent=researcher)
+
+    condition_mock = MagicMock(return_value=False)
+    task2 = ConditionalTask(
+        description="Come up with a list of 5 interesting ideas to explore for an article, then write one amazing paragraph highlight for each idea that showcases how good an article about this topic could be. Return the list of ideas with their paragraph and your notes.",
+        expected_output="5 bullet points with a paragraph for each idea.",
+        condition=condition_mock,
+        agent=writer,
+    )
+    crew_met = Crew(
+        agents=[researcher, writer],
+        tasks=[task1, task2],
+    )
+    with patch.object(Task, "execute_sync") as mock_execute_sync:
+        mock_execute_sync.return_value = TaskOutput(
+            description="Task 1 description",
+            raw_output="Task 1 output",
+            agent="Researcher",
+        )
+
+        result = crew_met.kickoff()
+        assert mock_execute_sync.call_count == 1
+
+        assert condition_mock.call_count == 1
+        assert condition_mock() is False
+
+        assert task2.output is None
+        assert result.raw_output().startswith("Task 1 output")
+
+
+@pytest.mark.vcr(filter_headers=["authorization"])
+def test_conditional_should_execute():
+    task1 = Task(description="Return hello", expected_output="say hi", agent=researcher)
+
+    condition_mock = MagicMock(
+        return_value=True
+    )  # should execute this conditional task
+    task2 = ConditionalTask(
+        description="Come up with a list of 5 interesting ideas to explore for an article, then write one amazing paragraph highlight for each idea that showcases how good an article about this topic could be. Return the list of ideas with their paragraph and your notes.",
+        expected_output="5 bullet points with a paragraph for each idea.",
+        condition=condition_mock,
+        agent=writer,
+    )
+    crew_met = Crew(
+        agents=[researcher, writer],
+        tasks=[task1, task2],
+    )
+    with patch.object(Task, "execute_sync") as mock_execute_sync:
+        mock_execute_sync.return_value = TaskOutput(
+            description="Task 1 description",
+            raw_output="Task 1 output",
+            agent="Researcher",
+        )
+
+        crew_met.kickoff()
+
+        assert condition_mock.call_count == 1
+        assert condition_mock() is True
+        assert mock_execute_sync.call_count == 2
+
+
 @pytest.mark.vcr(filter_headers=["authorization"])
 def test_agent_usage_metrics_are_captured_for_hierarchical_process():
     from langchain_openai import ChatOpenAI