WIP for conditional tasks

This commit is contained in:
Lorenze Jay
2024-07-02 09:06:15 -07:00
parent 053d8a0449
commit 60d0f56e2d
5 changed files with 62 additions and 8 deletions

View File

@@ -0,0 +1,25 @@
from typing import Callable, Optional
from pydantic import BaseModel
from crewai.task import Task
from crewai.tasks.task_output import TaskOutput
class ConditionalTask(Task):
condition: Optional[Callable[[BaseModel], bool]] = None
def __init__(
self,
*args,
condition: Optional[Callable[[BaseModel], bool]] = None,
**kwargs,
):
super().__init__(*args, **kwargs)
self.condition = condition
def should_execute(self, context: TaskOutput) -> bool:
print("TaskOutput", TaskOutput)
if self.condition:
return self.condition(context)
return True

View File

@@ -21,6 +21,7 @@ from pydantic_core import PydanticCustomError
from crewai.agent import Agent from crewai.agent import Agent
from crewai.agents.agent_builder.base_agent import BaseAgent from crewai.agents.agent_builder.base_agent import BaseAgent
from crewai.agents.cache import CacheHandler from crewai.agents.cache import CacheHandler
from crewai.conditional_task import ConditionalTask
from crewai.crews.crew_output import CrewOutput from crewai.crews.crew_output import CrewOutput
from crewai.memory.entity.entity_memory import EntityMemory from crewai.memory.entity.entity_memory import EntityMemory
from crewai.memory.long_term.long_term_memory import LongTermMemory from crewai.memory.long_term.long_term_memory import LongTermMemory
@@ -396,7 +397,19 @@ class Crew(BaseModel):
task_outputs: List[TaskOutput] = [] task_outputs: List[TaskOutput] = []
futures: List[Tuple[Task, Future[TaskOutput]]] = [] futures: List[Tuple[Task, Future[TaskOutput]]] = []
for task in self.tasks: for i, task in enumerate(self.tasks):
if isinstance(task, ConditionalTask):
# print("task_outputs", task_outputs)
previous_output = task_outputs[-1] if task_outputs else None
# print("previous_output type", type(previous_output))
if previous_output is not None:
if not task.should_execute(previous_output):
self._logger.log(
"info",
f"Skipping conditional task: {task.description}",
color="yellow",
)
continue
if task.agent.allow_delegation: # type: ignore # Item "None" of "Agent | None" has no attribute "allow_delegation" if task.agent.allow_delegation: # type: ignore # Item "None" of "Agent | None" has no attribute "allow_delegation"
agents_for_delegation = [ agents_for_delegation = [
agent for agent in self.agents if agent != task.agent agent for agent in self.agents if agent != task.agent
@@ -438,9 +451,9 @@ class Crew(BaseModel):
task_output = task.execute_sync( task_output = task.execute_sync(
agent=task.agent, context=context, tools=task.tools agent=task.agent, context=context, tools=task.tools
) )
task_outputs = [task_output] print("task executed res:", task_output)
task_outputs.append(task_output)
self._process_task_result(task, task_output) self._process_task_result(task, task_output)
if futures: if futures:
# Clear task_outputs before processing async tasks # Clear task_outputs before processing async tasks
task_outputs = [] task_outputs = []
@@ -451,8 +464,14 @@ class Crew(BaseModel):
final_string_output = aggregate_raw_outputs_from_task_outputs(task_outputs) final_string_output = aggregate_raw_outputs_from_task_outputs(task_outputs)
self._finish_execution(final_string_output) self._finish_execution(final_string_output)
# TODO: need to revert
token_usage = self.calculate_usage_metrics() # token_usage = self.calculate_usage_metrics()
token_usage = {
"total_tokens": 0,
"prompt_tokens": 0,
"completion_tokens": 0,
"successful_requests": 0,
}
return self._format_output(task_outputs, token_usage) return self._format_output(task_outputs, token_usage)

View File

@@ -3,7 +3,7 @@ import re
import threading import threading
import uuid import uuid
from concurrent.futures import Future from concurrent.futures import Future
from copy import copy, deepcopy from copy import copy
from typing import Any, Dict, List, Optional, Type, Union from typing import Any, Dict, List, Optional, Type, Union
from langchain_openai import ChatOpenAI from langchain_openai import ChatOpenAI
@@ -224,6 +224,7 @@ class Task(BaseModel):
tools=tools, tools=tools,
) )
exported_output = self._export_output(result) exported_output = self._export_output(result)
print("exported_output", exported_output["pydantic"])
task_output = TaskOutput( task_output = TaskOutput(
description=self.description, description=self.description,
@@ -232,6 +233,7 @@ class Task(BaseModel):
json_output=exported_output["json"], json_output=exported_output["json"],
agent=agent.role, agent=agent.role,
) )
print("task_output", task_output)
self.output = task_output self.output = task_output
if self.callback: if self.callback:
@@ -311,8 +313,8 @@ class Task(BaseModel):
self, result: str self, result: str
) -> Dict[str, Union[BaseModel, Dict[str, Any]]]: ) -> Dict[str, Union[BaseModel, Dict[str, Any]]]:
output = { output = {
"pydantic": None, "pydantic": self.output_pydantic() if self.output_pydantic else None,
"json": None, "json": {},
} }
if self.output_pydantic or self.output_json: if self.output_pydantic or self.output_json:

View File

@@ -29,10 +29,13 @@ class TaskOutput(BaseModel):
def result(self) -> Union[str, BaseModel, Dict[str, Any]]: def result(self) -> Union[str, BaseModel, Dict[str, Any]]:
"""Return the result of the task based on the available output.""" """Return the result of the task based on the available output."""
if self.pydantic_output: if self.pydantic_output:
print("returns pydantic_output", self.pydantic_output)
return self.pydantic_output return self.pydantic_output
elif self.json_output: elif self.json_output:
print("returns json_output", self.json_output)
return self.json_output return self.json_output
else: else:
print("return string out")
return self.raw_output return self.raw_output
def __getitem__(self, key: str) -> Any: def __getitem__(self, key: str) -> Any:

View File

@@ -8,6 +8,8 @@ class Printer:
self._print_bold_green(content) self._print_bold_green(content)
elif color == "bold_purple": elif color == "bold_purple":
self._print_bold_purple(content) self._print_bold_purple(content)
elif color == "yellow":
self._print_yellow(content)
else: else:
print(content) print(content)
@@ -22,3 +24,6 @@ class Printer:
def _print_red(self, content): def _print_red(self, content):
print("\033[91m {}\033[00m".format(content)) print("\033[91m {}\033[00m".format(content))
def _print_yellow(self, content):
print("\033[93m {}\033[00m".format(content))