Mirror of https://github.com/crewAIInc/crewAI.git, synced 2026-01-08 15:48:29 +00:00
WIP for conditional tasks
src/crewai/conditional_task.py (normal file, 25 lines added)
@@ -0,0 +1,25 @@
+from typing import Callable, Optional
+
+from pydantic import BaseModel
+
+from crewai.task import Task
+from crewai.tasks.task_output import TaskOutput
+
+
+class ConditionalTask(Task):
+    condition: Optional[Callable[[BaseModel], bool]] = None
+
+    def __init__(
+        self,
+        *args,
+        condition: Optional[Callable[[BaseModel], bool]] = None,
+        **kwargs,
+    ):
+        super().__init__(*args, **kwargs)
+        self.condition = condition
+
+    def should_execute(self, context: TaskOutput) -> bool:
+        print("TaskOutput", TaskOutput)
+        if self.condition:
+            return self.condition(context)
+        return True
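
How the new class might be used from calling code, as a minimal sketch: the condition is any callable that inspects the previous task's output and returns a bool. Everything other than `condition` (description, expected_output) is just the usual Task fields and is assumed here, not part of this diff:

from crewai.conditional_task import ConditionalTask
from crewai.tasks.task_output import TaskOutput

def needs_follow_up(previous: TaskOutput) -> bool:
    # Inspect whatever the previous task produced; result() resolves to the
    # pydantic output, the json dict, or the raw string (see the TaskOutput hunk below).
    return "ERROR" in str(previous.result())

follow_up = ConditionalTask(
    description="Investigate the error reported by the previous task",
    expected_output="A short diagnosis of the failure",
    condition=needs_follow_up,
)
# Crew's sequential loop (next hunk) calls follow_up.should_execute(previous_output)
# and skips the task when that returns False.
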
@@ -21,6 +21,7 @@ from pydantic_core import PydanticCustomError
 from crewai.agent import Agent
 from crewai.agents.agent_builder.base_agent import BaseAgent
 from crewai.agents.cache import CacheHandler
+from crewai.conditional_task import ConditionalTask
 from crewai.crews.crew_output import CrewOutput
 from crewai.memory.entity.entity_memory import EntityMemory
 from crewai.memory.long_term.long_term_memory import LongTermMemory
@@ -396,7 +397,19 @@ class Crew(BaseModel):
         task_outputs: List[TaskOutput] = []
         futures: List[Tuple[Task, Future[TaskOutput]]] = []
 
-        for task in self.tasks:
+        for i, task in enumerate(self.tasks):
+            if isinstance(task, ConditionalTask):
+                # print("task_outputs", task_outputs)
+                previous_output = task_outputs[-1] if task_outputs else None
+                # print("previous_output type", type(previous_output))
+                if previous_output is not None:
+                    if not task.should_execute(previous_output):
+                        self._logger.log(
+                            "info",
+                            f"Skipping conditional task: {task.description}",
+                            color="yellow",
+                        )
+                        continue
             if task.agent.allow_delegation:  # type: ignore # Item "None" of "Agent | None" has no attribute "allow_delegation"
                 agents_for_delegation = [
                     agent for agent in self.agents if agent != task.agent
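
One behavior worth noting in the guard above: when a ConditionalTask is the first task in the crew there is no previous output yet, so its condition is never consulted and the task always runs. A self-contained sketch of that decision path, using a stub in place of the real should_execute:

def should_execute(previous_output):
    # Stub standing in for ConditionalTask.should_execute; pretend the
    # condition would say "skip" if it were ever evaluated.
    return False

task_outputs = []  # nothing has run yet, i.e. the conditional task comes first
previous_output = task_outputs[-1] if task_outputs else None

# Same shape as the guard above: the condition is only checked when a
# previous output exists.
skip = previous_output is not None and not should_execute(previous_output)
print(skip)  # False -> the task is not skipped despite its condition
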
@@ -438,9 +451,9 @@
                 task_output = task.execute_sync(
                     agent=task.agent, context=context, tools=task.tools
                 )
-                task_outputs = [task_output]
+                print("task executed res:", task_output)
+                task_outputs.append(task_output)
                 self._process_task_result(task, task_output)
 
         if futures:
             # Clear task_outputs before processing async tasks
             task_outputs = []
@@ -451,8 +464,14 @@
 
         final_string_output = aggregate_raw_outputs_from_task_outputs(task_outputs)
         self._finish_execution(final_string_output)
 
-        token_usage = self.calculate_usage_metrics()
+        # TODO: need to revert
+        # token_usage = self.calculate_usage_metrics()
+        token_usage = {
+            "total_tokens": 0,
+            "prompt_tokens": 0,
+            "completion_tokens": 0,
+            "successful_requests": 0,
+        }
 
         return self._format_output(task_outputs, token_usage)
@@ -3,7 +3,7 @@ import re
 import threading
 import uuid
 from concurrent.futures import Future
-from copy import copy, deepcopy
+from copy import copy
 from typing import Any, Dict, List, Optional, Type, Union
 
 from langchain_openai import ChatOpenAI
@@ -224,6 +224,7 @@ class Task(BaseModel):
             tools=tools,
         )
         exported_output = self._export_output(result)
+        print("exported_output", exported_output["pydantic"])
 
         task_output = TaskOutput(
             description=self.description,
@@ -232,6 +233,7 @@ class Task(BaseModel):
             json_output=exported_output["json"],
             agent=agent.role,
         )
+        print("task_output", task_output)
         self.output = task_output
 
         if self.callback:
@@ -311,8 +313,8 @@
         self, result: str
     ) -> Dict[str, Union[BaseModel, Dict[str, Any]]]:
         output = {
-            "pydantic": None,
-            "json": None,
+            "pydantic": self.output_pydantic() if self.output_pydantic else None,
+            "json": {},
         }
 
         if self.output_pydantic or self.output_json:
@@ -29,10 +29,13 @@ class TaskOutput(BaseModel):
     def result(self) -> Union[str, BaseModel, Dict[str, Any]]:
         """Return the result of the task based on the available output."""
         if self.pydantic_output:
+            print("returns pydantic_output", self.pydantic_output)
             return self.pydantic_output
         elif self.json_output:
+            print("returns json_output", self.json_output)
             return self.json_output
         else:
+            print("return string out")
             return self.raw_output
 
     def __getitem__(self, key: str) -> Any:
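
For context, result() resolves in a fixed priority order: structured pydantic output first, then the json dict, then the raw string. A self-contained sketch of that order using a stand-in object (not the real TaskOutput model; only the three fields result() reads are mirrored):

from dataclasses import dataclass
from typing import Any, Dict, Optional

from pydantic import BaseModel


@dataclass
class OutputSketch:
    # Mirrors only the fields result() touches: pydantic_output, json_output, raw_output.
    pydantic_output: Optional[BaseModel] = None
    json_output: Optional[Dict[str, Any]] = None
    raw_output: str = ""

    def result(self):
        if self.pydantic_output:
            return self.pydantic_output
        elif self.json_output:
            return self.json_output
        return self.raw_output


print(OutputSketch(raw_output="plain text").result())    # 'plain text'
print(OutputSketch(json_output={"ok": True}).result())   # {'ok': True}
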
@@ -8,6 +8,8 @@ class Printer:
             self._print_bold_green(content)
         elif color == "bold_purple":
             self._print_bold_purple(content)
+        elif color == "yellow":
+            self._print_yellow(content)
         else:
             print(content)
 
@@ -22,3 +24,6 @@ class Printer:
 
     def _print_red(self, content):
         print("\033[91m {}\033[00m".format(content))
+
+    def _print_yellow(self, content):
+        print("\033[93m {}\033[00m".format(content))
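
A small usage sketch for the new color path. The public method name and import path are assumed here from the dispatch body above, not taken from this diff:

# Assumed import path and method name; adjust to wherever Printer actually lives.
from crewai.utilities.printer import Printer

printer = Printer()
printer.print(content="Skipping conditional task: ...", color="yellow")
# Takes the new elif branch and wraps the text in the yellow ANSI escape
# (\033[93m ... \033[00m), matching the log call added in crew.py.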