cleaner code

This commit is contained in:
Lorenze Jay
2024-07-11 17:22:42 -07:00
parent a9873ff90d
commit e1589befb4
2 changed files with 148 additions and 65 deletions

View File

@@ -87,6 +87,9 @@ class Crew(BaseModel):
_train: Optional[bool] = PrivateAttr(default=False)
_train_iteration: Optional[int] = PrivateAttr()
_inputs: Optional[Dict[str, Any]] = PrivateAttr(default=None)
_logging_color: str = PrivateAttr(
default="bold_purple",
)
cache: bool = Field(default=True)
model_config = ConfigDict(arbitrary_types_allowed=True)
@@ -539,11 +542,11 @@ class Crew(BaseModel):
"inputs": inputs,
"was_replayed": was_replayed,
}
if task_index < len(self.execution_logs):
self.execution_logs[task_index] = log
else:
self.execution_logs.append(log)
TaskOutputJsonHandler(CREW_TASKS_OUTPUT_FILE).update(task_index, log)
def _run_sequential_process(self) -> CrewOutput:
@@ -578,6 +581,8 @@ class Crew(BaseModel):
self,
tasks: List[Task],
manager: Optional[BaseAgent] = None,
start_index: Optional[int] = 0,
was_replayed: bool = False,
) -> CrewOutput:
"""Executes tasks sequentially and returns the final output.
@@ -590,8 +595,18 @@ class Crew(BaseModel):
"""
task_outputs: List[TaskOutput] = []
futures: List[Tuple[Task, Future[TaskOutput], int]] = []
self.execution_logs = []
last_sync_output: Optional[TaskOutput] = None
for task_index, task in enumerate(tasks):
if start_index is not None and task_index < start_index:
if task.output:
if task.async_execution:
task_outputs.append(task.output)
else:
task_outputs = [task.output]
last_sync_output = task.output
continue
self._prepare_task(task, manager)
agent_to_use = task.agent if task.agent else manager
if agent_to_use is None:
@@ -599,9 +614,10 @@ class Crew(BaseModel):
f"No agent available for task: {task.description}. Ensure that either the task has an assigned agent or a manager agent is provided."
)
self._log_task_start(task, agent_to_use)
if task.async_execution:
context = self._set_context(task, task_outputs)
context = self._set_context(
task, [last_sync_output] if last_sync_output else []
)
future = task.execute_async(
agent=agent_to_use,
context=context,
@@ -610,7 +626,9 @@ class Crew(BaseModel):
futures.append((task, future, task_index))
else:
if futures:
task_outputs = self._process_async_tasks(futures)
task_outputs.extend(
self._process_async_tasks(futures, was_replayed)
)
futures.clear()
context = self._set_context(task, task_outputs)
@@ -621,10 +639,10 @@ class Crew(BaseModel):
)
task_outputs = [task_output]
self._process_task_result(task, task_output)
self._store_execution_log(task, task_output, task_index)
self._store_execution_log(task, task_output, task_index, was_replayed)
if futures:
task_outputs = self._process_async_tasks(futures)
task_outputs = self._process_async_tasks(futures, was_replayed)
return self._create_crew_output(task_outputs)
@@ -637,13 +655,10 @@ class Crew(BaseModel):
def _add_delegation_tools(self, task: Task):
agents_for_delegation = [agent for agent in self.agents if agent != task.agent]
if len(self.agents) > 1 and agents_for_delegation:
task.tools += task.agent.get_delegation_tools(
agents_for_delegation
) # TODO: FIX TYPE ERROR HERE
task.tools += task.agent.get_delegation_tools(agents_for_delegation) # type: ignore
def _log_task_start(
self, task: Task, agent: Optional[BaseAgent], color: str = "bold_purple"
):
def _log_task_start(self, task: Task, agent: Optional[BaseAgent]):
color = self._logging_color
role = agent.role if agent else "None"
self._logger.log("debug", f"== Working Agent: {role}", color=color)
self._logger.log("info", f"== Starting Task: {task.description}", color=color)
@@ -724,70 +739,35 @@ class Crew(BaseModel):
if start_index is None:
raise ValueError(f"Task with id {task_id} not found in the crew's tasks.")
task_outputs: List[TaskOutput] = []
futures: List[Tuple[Task, Future[TaskOutput], int]] = []
# inputs can be overrided with new passed inputs
replay_inputs = (
inputs
if inputs is not None
else stored_outputs[start_index].get("inputs", {})
)
self._inputs = replay_inputs
if replay_inputs:
self._interpolate_inputs(replay_inputs)
if self.process == Process.hierarchical:
self._create_manager_agent()
for task_index, task in enumerate(self.tasks):
if task_index < start_index: # we are skipping this task
stored_output = stored_outputs[task_index]["output"]
task_output = TaskOutput(
description=stored_output["description"],
agent=stored_output["agent"],
raw=stored_output["raw"],
pydantic=stored_output["pydantic"],
json_dict=stored_output["json_dict"],
output_format=stored_output["output_format"],
)
self.tasks[task_index].output = task_output
task_outputs = [task_output]
else:
self._prepare_task(task, self.manager_agent)
agent_to_use = task.agent if task.agent else self.manager_agent
if agent_to_use is None:
raise ValueError(
f"No agent available for task: {task.description}. Ensure that either the task has an assigned agent or a manager agent is provided."
)
self._log_task_start(task, agent_to_use, "bold_blue")
if task.async_execution:
context = self._set_context(task, task_outputs)
future = task.execute_async(
agent=agent_to_use, context=context, tools=task.tools
)
futures.append((task, future, task_index))
else:
if futures:
task_outputs = self._process_async_tasks(futures, True)
futures.clear()
for i in range(start_index):
stored_output = stored_outputs[i]["output"]
task_output = TaskOutput(
description=stored_output["description"],
agent=stored_output["agent"],
raw=stored_output["raw"],
pydantic=stored_output["pydantic"],
json_dict=stored_output["json_dict"],
output_format=stored_output["output_format"],
)
self.tasks[i].output = task_output
context = self._set_context(task, task_outputs)
task_output = task.execute_sync(
agent=agent_to_use, context=context, tools=task.tools
)
task_outputs = [task_output]
self._process_task_result(task, task_output)
self._store_execution_log(
task, task_output, task_index, was_replayed=True
)
# Process any remaining async tasks
if futures:
task_outputs = self._process_async_tasks(futures, True)
return self._create_crew_output(task_outputs)
self._logging_color = "bold_blue"
result = self._execute_tasks(self.tasks, self.manager_agent, start_index)
self._logging_color = "bold_purple"
return result
def copy(self):
"""Create a deep copy of the Crew."""