Lj/optional agent in task bug (#843)

* fixed bug for manager overriding task agent and then added pydanic valditors to sequential when no agent is added to task

* better test and fixed task.agent logic

* fixed tests and better validator message

* added validator for async_execution true in tasks whenever in hierarchical run
This commit is contained in:
Lorenze Jay
2024-07-03 14:45:53 -07:00
committed by GitHub
parent 57fc079267
commit 5d18f73654
5 changed files with 3904 additions and 7 deletions

View File

@@ -1,7 +1,7 @@
import asyncio
import json
import uuid
from typing import Any, Dict, List, Optional, Union
from typing import Any, Dict, List, Optional, Union, Tuple
from langchain_core.callbacks import BaseCallbackHandler
from pydantic import (
@@ -224,6 +224,33 @@ class Crew(BaseModel):
agent.set_rpm_controller(self._rpm_controller)
return self
@model_validator(mode="after")
def validate_tasks(self):
if self.process == Process.sequential:
for task in self.tasks:
if task.agent is None:
raise PydanticCustomError(
"missing_agent_in_task",
f"Sequential process error: Agent is missing in the task with the following description: {task.description}", # type: ignore Argument of type "str" cannot be assigned to parameter "message_template" of type "LiteralString"
{},
)
return self
@model_validator(mode="after")
def check_tasks_in_hierarchical_process_not_async(self):
"""Validates that the tasks in hierarchical process are not flagged with async_execution."""
if self.process == Process.hierarchical:
for task in self.tasks:
if task.async_execution:
raise PydanticCustomError(
"async_execution_in_hierarchical_process",
"Hierarchical process error: Tasks cannot be flagged with async_execution.",
{},
)
return self
def _setup_from_config(self):
assert self.config is not None, "Config should not be None."
@@ -315,9 +342,7 @@ class Crew(BaseModel):
if self.process == Process.sequential:
result = self._run_sequential_process()
elif self.process == Process.hierarchical:
# type: ignore # Unpacking a string is disallowed
result, manager_metrics = self._run_hierarchical_process()
# type: ignore # Cannot determine type of "manager_metrics"
metrics.append(manager_metrics)
else:
raise NotImplementedError(
@@ -433,14 +458,16 @@ class Crew(BaseModel):
# type: ignore # Incompatible return value type (got "tuple[str, Any]", expected "str")
return self._format_output(task_output, token_usage)
def _run_hierarchical_process(self) -> Union[str, Dict[str, Any]]:
def _run_hierarchical_process(
self,
) -> Tuple[Union[str, Dict[str, Any]], Dict[str, Any]]:
"""Creates and assigns a manager agent to make sure the crew completes the tasks."""
i18n = I18N(prompt_file=self.prompt_file)
if self.manager_agent is not None:
self.manager_agent.allow_delegation = True
manager = self.manager_agent
if len(manager.tools) > 0:
if manager.tools is not None and len(manager.tools) > 0:
raise Exception("Manager agent should not have tools")
manager.tools = self.manager_agent.get_delegation_tools(self.agents)
else:
@@ -465,6 +492,10 @@ class Crew(BaseModel):
agent=manager.role, task=task.description, status="started"
)
if task.agent:
manager.tools = task.agent.get_delegation_tools([task.agent])
else:
manager.tools = manager.get_delegation_tools(self.agents)
task_output = task.execute(
agent=manager, context=task_output, tools=manager.tools
)
@@ -554,7 +585,9 @@ class Crew(BaseModel):
self._rpm_controller.stop_rpm_counter()
if agentops:
agentops.end_session(
end_state="Success", end_state_reason="Finished Execution", is_auto_end=True
end_state="Success",
end_state_reason="Finished Execution",
is_auto_end=True,
)
self._telemetry.end_crew(self, output)