WIP for conditional tasks

This commit is contained in:
Lorenze Jay
2024-07-02 09:06:15 -07:00
parent 053d8a0449
commit 60d0f56e2d
5 changed files with 62 additions and 8 deletions

View File

@@ -0,0 +1,25 @@
from typing import Callable, Optional
from pydantic import BaseModel
from crewai.task import Task
from crewai.tasks.task_output import TaskOutput
class ConditionalTask(Task):
condition: Optional[Callable[[BaseModel], bool]] = None
def __init__(
self,
*args,
condition: Optional[Callable[[BaseModel], bool]] = None,
**kwargs,
):
super().__init__(*args, **kwargs)
self.condition = condition
def should_execute(self, context: TaskOutput) -> bool:
print("TaskOutput", TaskOutput)
if self.condition:
return self.condition(context)
return True

View File

@@ -21,6 +21,7 @@ from pydantic_core import PydanticCustomError
from crewai.agent import Agent
from crewai.agents.agent_builder.base_agent import BaseAgent
from crewai.agents.cache import CacheHandler
from crewai.conditional_task import ConditionalTask
from crewai.crews.crew_output import CrewOutput
from crewai.memory.entity.entity_memory import EntityMemory
from crewai.memory.long_term.long_term_memory import LongTermMemory
@@ -396,7 +397,19 @@ class Crew(BaseModel):
task_outputs: List[TaskOutput] = []
futures: List[Tuple[Task, Future[TaskOutput]]] = []
for task in self.tasks:
for i, task in enumerate(self.tasks):
if isinstance(task, ConditionalTask):
# print("task_outputs", task_outputs)
previous_output = task_outputs[-1] if task_outputs else None
# print("previous_output type", type(previous_output))
if previous_output is not None:
if not task.should_execute(previous_output):
self._logger.log(
"info",
f"Skipping conditional task: {task.description}",
color="yellow",
)
continue
if task.agent.allow_delegation: # type: ignore # Item "None" of "Agent | None" has no attribute "allow_delegation"
agents_for_delegation = [
agent for agent in self.agents if agent != task.agent
@@ -438,9 +451,9 @@ class Crew(BaseModel):
task_output = task.execute_sync(
agent=task.agent, context=context, tools=task.tools
)
task_outputs = [task_output]
print("task executed res:", task_output)
task_outputs.append(task_output)
self._process_task_result(task, task_output)
if futures:
# Clear task_outputs before processing async tasks
task_outputs = []
@@ -451,8 +464,14 @@ class Crew(BaseModel):
final_string_output = aggregate_raw_outputs_from_task_outputs(task_outputs)
self._finish_execution(final_string_output)
token_usage = self.calculate_usage_metrics()
# TODO: need to revert
# token_usage = self.calculate_usage_metrics()
token_usage = {
"total_tokens": 0,
"prompt_tokens": 0,
"completion_tokens": 0,
"successful_requests": 0,
}
return self._format_output(task_outputs, token_usage)

View File

@@ -3,7 +3,7 @@ import re
import threading
import uuid
from concurrent.futures import Future
from copy import copy, deepcopy
from copy import copy
from typing import Any, Dict, List, Optional, Type, Union
from langchain_openai import ChatOpenAI
@@ -224,6 +224,7 @@ class Task(BaseModel):
tools=tools,
)
exported_output = self._export_output(result)
print("exported_output", exported_output["pydantic"])
task_output = TaskOutput(
description=self.description,
@@ -232,6 +233,7 @@ class Task(BaseModel):
json_output=exported_output["json"],
agent=agent.role,
)
print("task_output", task_output)
self.output = task_output
if self.callback:
@@ -311,8 +313,8 @@ class Task(BaseModel):
self, result: str
) -> Dict[str, Union[BaseModel, Dict[str, Any]]]:
output = {
"pydantic": None,
"json": None,
"pydantic": self.output_pydantic() if self.output_pydantic else None,
"json": {},
}
if self.output_pydantic or self.output_json:

View File

@@ -29,10 +29,13 @@ class TaskOutput(BaseModel):
def result(self) -> Union[str, BaseModel, Dict[str, Any]]:
"""Return the result of the task based on the available output."""
if self.pydantic_output:
print("returns pydantic_output", self.pydantic_output)
return self.pydantic_output
elif self.json_output:
print("returns json_output", self.json_output)
return self.json_output
else:
print("return string out")
return self.raw_output
def __getitem__(self, key: str) -> Any:

View File

@@ -8,6 +8,8 @@ class Printer:
self._print_bold_green(content)
elif color == "bold_purple":
self._print_bold_purple(content)
elif color == "yellow":
self._print_yellow(content)
else:
print(content)
@@ -22,3 +24,6 @@ class Printer:
def _print_red(self, content):
print("\033[91m {}\033[00m".format(content))
def _print_yellow(self, content):
print("\033[93m {}\033[00m".format(content))