fixed bug for manager overriding task agent and then added pydanic valditors to sequential when no agent is added to task

This commit is contained in:
Lorenze Jay
2024-07-01 09:32:43 -07:00
parent 5b66e87621
commit 9392788ed0
4 changed files with 3724 additions and 13 deletions

View File

@@ -1,7 +1,7 @@
import asyncio
import json
import uuid
from typing import Any, Dict, List, Optional, Union
from typing import Any, Dict, List, Optional, Union, Tuple
from langchain_core.callbacks import BaseCallbackHandler
from pydantic import (
@@ -219,6 +219,22 @@ class Crew(BaseModel):
agent.set_rpm_controller(self._rpm_controller)
return self
@model_validator(mode="after")
def validate_tasks(self):
process = self.process
tasks = self.tasks
if process == Process.sequential:
for task in tasks:
if task.agent is None:
raise PydanticCustomError(
"missing_agent_in_task",
"Agent is missing in the task with the following description: {task.description}",
{},
)
return self
def _setup_from_config(self):
assert self.config is not None, "Config should not be None."
@@ -309,9 +325,7 @@ class Crew(BaseModel):
if self.process == Process.sequential:
result = self._run_sequential_process()
elif self.process == Process.hierarchical:
# type: ignore # Unpacking a string is disallowed
result, manager_metrics = self._run_hierarchical_process()
# type: ignore # Cannot determine type of "manager_metrics"
metrics.append(manager_metrics)
else:
raise NotImplementedError(
@@ -409,14 +423,16 @@ class Crew(BaseModel):
# type: ignore # Incompatible return value type (got "tuple[str, Any]", expected "str")
return self._format_output(task_output, token_usage_formatted)
def _run_hierarchical_process(self) -> Union[str, Dict[str, Any]]:
def _run_hierarchical_process(
self,
) -> Tuple[Union[str, Dict[str, Any]], Dict[str, Any]]:
"""Creates and assigns a manager agent to make sure the crew completes the tasks."""
i18n = I18N(prompt_file=self.prompt_file)
if self.manager_agent is not None:
self.manager_agent.allow_delegation = True
manager = self.manager_agent
if len(manager.tools) > 0:
if manager.tools is not None and len(manager.tools) > 0:
raise Exception("Manager agent should not have tools")
manager.tools = self.manager_agent.get_delegation_tools(self.agents)
else:
@@ -428,7 +444,7 @@ class Crew(BaseModel):
llm=self.manager_llm,
verbose=True,
)
self.manager_agent = manager
task_output = ""
token_usage = []
for task in self.tasks:
@@ -439,15 +455,18 @@ class Crew(BaseModel):
self._file_handler.log(
agent=manager.role, task=task.description, status="started"
)
if task.agent is not None:
manager.tools = task.agent.get_delegation_tools([task.agent])
else:
manager.tools = manager.get_delegation_tools(self.agents)
task_output = task.execute(
agent=manager, context=task_output, tools=manager.tools
)
if hasattr(manager, "_token_process"):
token_summ = manager._token_process.get_summary()
token_usage.append(token_summ)
self._logger.log("debug", f"[{manager.role}] Task output: {task_output}")
if hasattr(task, "agent._token_process"):
token_summ = task.agent._token_process.get_summary()
token_usage.append(token_summ)
if self.output_log_file:
self._file_handler.log(
agent=manager.role, task=task_output, status="completed"
@@ -455,13 +474,13 @@ class Crew(BaseModel):
self._finish_execution(task_output)
# type: ignore # Incompatible return value type (got "tuple[str, Any]", expected "str")
manager_token_usage = manager._token_process.get_summary()
token_usage.append(manager_token_usage)
token_usage_formatted = self.aggregate_token_usage(token_usage)
return self._format_output(
task_output, token_usage_formatted
task_output,
token_usage_formatted,
), manager_token_usage
def copy(self):