cleaner code

This commit is contained in:
Lorenze Jay
2024-07-11 17:22:42 -07:00
parent a9873ff90d
commit e1589befb4
2 changed files with 148 additions and 65 deletions

View File

@@ -87,6 +87,9 @@ class Crew(BaseModel):
_train: Optional[bool] = PrivateAttr(default=False)
_train_iteration: Optional[int] = PrivateAttr()
_inputs: Optional[Dict[str, Any]] = PrivateAttr(default=None)
_logging_color: str = PrivateAttr(
default="bold_purple",
)
cache: bool = Field(default=True)
model_config = ConfigDict(arbitrary_types_allowed=True)
@@ -539,11 +542,11 @@ class Crew(BaseModel):
"inputs": inputs,
"was_replayed": was_replayed,
}
if task_index < len(self.execution_logs):
self.execution_logs[task_index] = log
else:
self.execution_logs.append(log)
TaskOutputJsonHandler(CREW_TASKS_OUTPUT_FILE).update(task_index, log)
def _run_sequential_process(self) -> CrewOutput:
@@ -578,6 +581,8 @@ class Crew(BaseModel):
self,
tasks: List[Task],
manager: Optional[BaseAgent] = None,
start_index: Optional[int] = 0,
was_replayed: bool = False,
) -> CrewOutput:
"""Executes tasks sequentially and returns the final output.
@@ -590,8 +595,18 @@ class Crew(BaseModel):
"""
task_outputs: List[TaskOutput] = []
futures: List[Tuple[Task, Future[TaskOutput], int]] = []
self.execution_logs = []
last_sync_output: Optional[TaskOutput] = None
for task_index, task in enumerate(tasks):
if start_index is not None and task_index < start_index:
if task.output:
if task.async_execution:
task_outputs.append(task.output)
else:
task_outputs = [task.output]
last_sync_output = task.output
continue
self._prepare_task(task, manager)
agent_to_use = task.agent if task.agent else manager
if agent_to_use is None:
@@ -599,9 +614,10 @@ class Crew(BaseModel):
f"No agent available for task: {task.description}. Ensure that either the task has an assigned agent or a manager agent is provided."
)
self._log_task_start(task, agent_to_use)
if task.async_execution:
context = self._set_context(task, task_outputs)
context = self._set_context(
task, [last_sync_output] if last_sync_output else []
)
future = task.execute_async(
agent=agent_to_use,
context=context,
@@ -610,7 +626,9 @@ class Crew(BaseModel):
futures.append((task, future, task_index))
else:
if futures:
task_outputs = self._process_async_tasks(futures)
task_outputs.extend(
self._process_async_tasks(futures, was_replayed)
)
futures.clear()
context = self._set_context(task, task_outputs)
@@ -621,10 +639,10 @@ class Crew(BaseModel):
)
task_outputs = [task_output]
self._process_task_result(task, task_output)
self._store_execution_log(task, task_output, task_index)
self._store_execution_log(task, task_output, task_index, was_replayed)
if futures:
task_outputs = self._process_async_tasks(futures)
task_outputs = self._process_async_tasks(futures, was_replayed)
return self._create_crew_output(task_outputs)
@@ -637,13 +655,10 @@ class Crew(BaseModel):
def _add_delegation_tools(self, task: Task):
agents_for_delegation = [agent for agent in self.agents if agent != task.agent]
if len(self.agents) > 1 and agents_for_delegation:
task.tools += task.agent.get_delegation_tools(
agents_for_delegation
) # TODO: FIX TYPE ERROR HERE
task.tools += task.agent.get_delegation_tools(agents_for_delegation) # type: ignore
def _log_task_start(
self, task: Task, agent: Optional[BaseAgent], color: str = "bold_purple"
):
def _log_task_start(self, task: Task, agent: Optional[BaseAgent]):
color = self._logging_color
role = agent.role if agent else "None"
self._logger.log("debug", f"== Working Agent: {role}", color=color)
self._logger.log("info", f"== Starting Task: {task.description}", color=color)
@@ -724,70 +739,35 @@ class Crew(BaseModel):
if start_index is None:
raise ValueError(f"Task with id {task_id} not found in the crew's tasks.")
task_outputs: List[TaskOutput] = []
futures: List[Tuple[Task, Future[TaskOutput], int]] = []
# inputs can be overrided with new passed inputs
replay_inputs = (
inputs
if inputs is not None
else stored_outputs[start_index].get("inputs", {})
)
self._inputs = replay_inputs
if replay_inputs:
self._interpolate_inputs(replay_inputs)
if self.process == Process.hierarchical:
self._create_manager_agent()
for task_index, task in enumerate(self.tasks):
if task_index < start_index: # we are skipping this task
stored_output = stored_outputs[task_index]["output"]
task_output = TaskOutput(
description=stored_output["description"],
agent=stored_output["agent"],
raw=stored_output["raw"],
pydantic=stored_output["pydantic"],
json_dict=stored_output["json_dict"],
output_format=stored_output["output_format"],
)
self.tasks[task_index].output = task_output
task_outputs = [task_output]
else:
self._prepare_task(task, self.manager_agent)
agent_to_use = task.agent if task.agent else self.manager_agent
if agent_to_use is None:
raise ValueError(
f"No agent available for task: {task.description}. Ensure that either the task has an assigned agent or a manager agent is provided."
)
self._log_task_start(task, agent_to_use, "bold_blue")
if task.async_execution:
context = self._set_context(task, task_outputs)
future = task.execute_async(
agent=agent_to_use, context=context, tools=task.tools
)
futures.append((task, future, task_index))
else:
if futures:
task_outputs = self._process_async_tasks(futures, True)
futures.clear()
for i in range(start_index):
stored_output = stored_outputs[i]["output"]
task_output = TaskOutput(
description=stored_output["description"],
agent=stored_output["agent"],
raw=stored_output["raw"],
pydantic=stored_output["pydantic"],
json_dict=stored_output["json_dict"],
output_format=stored_output["output_format"],
)
self.tasks[i].output = task_output
context = self._set_context(task, task_outputs)
task_output = task.execute_sync(
agent=agent_to_use, context=context, tools=task.tools
)
task_outputs = [task_output]
self._process_task_result(task, task_output)
self._store_execution_log(
task, task_output, task_index, was_replayed=True
)
# Process any remaining async tasks
if futures:
task_outputs = self._process_async_tasks(futures, True)
return self._create_crew_output(task_outputs)
self._logging_color = "bold_blue"
result = self._execute_tasks(self.tasks, self.manager_agent, start_index)
self._logging_color = "bold_purple"
return result
def copy(self):
"""Create a deep copy of the Crew."""

View File

@@ -1929,3 +1929,106 @@ def test_replay_without_output_tasks_json():
with pytest.raises(ValueError):
crew.replay_from_task(str(task.id))
@pytest.mark.vcr(filter_headers=["authorization"])
def test_replay_task_with_context():
agent1 = Agent(
role="Researcher",
goal="Research AI advancements.",
backstory="You are an expert in AI research.",
)
agent2 = Agent(
role="Writer",
goal="Write detailed articles on AI.",
backstory="You have a background in journalism and AI.",
)
task1 = Task(
description="Research the latest advancements in AI.",
expected_output="A detailed report on AI advancements.",
agent=agent1,
)
task2 = Task(
description="Summarize the AI advancements report.",
expected_output="A summary of the AI advancements report.",
agent=agent2,
)
task3 = Task(
description="Write an article based on the AI advancements summary.",
expected_output="An article on AI advancements.",
agent=agent2,
)
task4 = Task(
description="Create a presentation based on the AI advancements article.",
expected_output="A presentation on AI advancements.",
agent=agent2,
context=[task1],
)
crew = Crew(
agents=[agent1, agent2],
tasks=[task1, task2, task3, task4],
process=Process.sequential,
)
mock_task_output1 = TaskOutput(
description="Research the latest advancements in AI.",
raw="Detailed report on AI advancements...",
agent="Researcher",
json_dict=None,
output_format=OutputFormat.RAW,
pydantic=None,
summary="Detailed report on AI advancements...",
)
mock_task_output2 = TaskOutput(
description="Summarize the AI advancements report.",
raw="Summary of the AI advancements report...",
agent="Writer",
json_dict=None,
output_format=OutputFormat.RAW,
pydantic=None,
summary="Summary of the AI advancements report...",
)
mock_task_output3 = TaskOutput(
description="Write an article based on the AI advancements summary.",
raw="Article on AI advancements...",
agent="Writer",
json_dict=None,
output_format=OutputFormat.RAW,
pydantic=None,
summary="Article on AI advancements...",
)
mock_task_output4 = TaskOutput(
description="Create a presentation based on the AI advancements article.",
raw="Presentation on AI advancements...",
agent="Writer",
json_dict=None,
output_format=OutputFormat.RAW,
pydantic=None,
summary="Presentation on AI advancements...",
)
with patch.object(Task, "execute_sync") as mock_execute_task:
mock_execute_task.side_effect = [
mock_task_output1,
mock_task_output2,
mock_task_output3,
mock_task_output4,
]
crew.kickoff()
# Check if the crew_tasks_output.json file is created
assert os.path.exists(CREW_TASKS_OUTPUT_FILE)
# Replay task4 and ensure it uses task1's context properly
with patch.object(Task, "execute_sync") as mock_replay_task:
mock_replay_task.return_value = mock_task_output4
replayed_output = crew.replay_from_task(str(task4.id))
assert replayed_output.raw == "Presentation on AI advancements..."
# Clean up the file after test
if os.path.exists(CREW_TASKS_OUTPUT_FILE):
os.remove(CREW_TASKS_OUTPUT_FILE)