refactoring task execution

This commit is contained in:
João Moura
2024-02-10 11:28:08 -08:00
parent 00206a62ab
commit e79da7bc05
5 changed files with 62 additions and 26 deletions

View File

@@ -125,7 +125,7 @@ class Agent(BaseModel):
def execute_task( def execute_task(
self, self,
task: str, task: Any,
context: Optional[str] = None, context: Optional[str] = None,
tools: Optional[List[Any]] = None, tools: Optional[List[Any]] = None,
) -> str: ) -> str:
@@ -140,17 +140,20 @@ class Agent(BaseModel):
Output of the agent Output of the agent
""" """
task_prompt = task.prompt()
if context: if context:
task = self.i18n.slice("task_with_context").format( task_prompt = self.i18n.slice("task_with_context").format(
task=task, context=context task=task_prompt, context=context
) )
tools = tools or self.tools tools = tools or self.tools
self.agent_executor.tools = tools self.agent_executor.tools = tools
self.agent_executor.task = task
result = self.agent_executor.invoke( result = self.agent_executor.invoke(
{ {
"input": task, "input": task_prompt,
"tool_names": self.__tools_names(tools), "tool_names": self.__tools_names(tools),
"tools": render_text_description(tools), "tools": render_text_description(tools),
} }

View File

@@ -17,6 +17,7 @@ class Task(BaseModel):
arbitrary_types_allowed = True arbitrary_types_allowed = True
__hash__ = object.__hash__ # type: ignore __hash__ = object.__hash__ # type: ignore
used_tools: int = 0
i18n: I18N = I18N() i18n: I18N = I18N()
thread: threading.Thread = None thread: threading.Thread = None
description: str = Field(description="Description of the actual task.") description: str = Field(description="Description of the actual task.")
@@ -96,25 +97,29 @@ class Task(BaseModel):
if self.async_execution: if self.async_execution:
self.thread = threading.Thread( self.thread = threading.Thread(
target=self._execute, args=(agent, self._prompt(), context, tools) target=self._execute, args=(agent, self, context, tools)
) )
self.thread.start() self.thread.start()
else: else:
result = self._execute( result = self._execute(
task=self,
agent=agent, agent=agent,
task_prompt=self._prompt(),
context=context, context=context,
tools=tools, tools=tools,
) )
return result return result
def _execute(self, agent, task_prompt, context, tools): def _execute(self, agent, task, context, tools):
result = agent.execute_task(task=task_prompt, context=context, tools=tools) result = agent.execute_task(
task=task,
context=context,
tools=tools,
)
self.output = TaskOutput(description=self.description, result=result) self.output = TaskOutput(description=self.description, result=result)
self.callback(self.output) if self.callback else None self.callback(self.output) if self.callback else None
return result return result
def _prompt(self) -> str: def prompt(self) -> str:
"""Prompt the task. """Prompt the task.
Returns: Returns:

View File

@@ -4,6 +4,7 @@ from langchain.tools import StructuredTool
from pydantic import BaseModel, Field from pydantic import BaseModel, Field
from crewai.agent import Agent from crewai.agent import Agent
from crewai.task import Task
from crewai.utilities import I18N from crewai.utilities import I18N
@@ -53,4 +54,5 @@ class AgentTools(BaseModel):
) )
agent = agent[0] agent = agent[0]
task = Task(description=task, agent=agent)
return agent.execute_task(task, context) return agent.execute_task(task, context)

View File

@@ -63,7 +63,8 @@ def test_agent_without_memory():
llm=ChatOpenAI(temperature=0, model="gpt-4"), llm=ChatOpenAI(temperature=0, model="gpt-4"),
) )
result = no_memory_agent.execute_task("How much is 1 + 1?") task = Task(description="How much is 1 + 1?", agent=no_memory_agent)
result = no_memory_agent.execute_task(task)
assert result == "1 + 1 equals 2." assert result == "1 + 1 equals 2."
assert no_memory_agent.agent_executor.memory is None assert no_memory_agent.agent_executor.memory is None
@@ -79,7 +80,9 @@ def test_agent_execution():
allow_delegation=False, allow_delegation=False,
) )
output = agent.execute_task("How much is 1 + 1?") task = Task(description="How much is 1 + 1?", agent=agent)
output = agent.execute_task(task)
assert output == "2" assert output == "2"
@@ -98,7 +101,8 @@ def test_agent_execution_with_tools():
allow_delegation=False, allow_delegation=False,
) )
output = agent.execute_task("What is 3 times 4") task = Task(description="What is 3 times 4?", agent=agent)
output = agent.execute_task(task)
assert output == "3 times 4 is 12." assert output == "3 times 4 is 12."
@@ -119,7 +123,8 @@ def test_logging_tool_usage():
) )
assert agent.tools_handler.last_used_tool == {} assert agent.tools_handler.last_used_tool == {}
output = agent.execute_task("What is 3 times 5?") task = Task(description="What is 3 times 4?", agent=agent)
output = agent.execute_task(task)
tool_usage = ToolCalling( tool_usage = ToolCalling(
function_name=multiplier.name, arguments={"first_number": 3, "second_number": 5} function_name=multiplier.name, arguments={"first_number": 3, "second_number": 5}
) )
@@ -147,22 +152,30 @@ def test_cache_hitting():
verbose=True, verbose=True,
) )
output = agent.execute_task("What is 2 times 6 times 3?") task1 = Task(description="What is 2 times 6?", agent=agent)
output = agent.execute_task("What is 3 times 3?") task2 = Task(description="What is 3 times 3?", agent=agent)
output = agent.execute_task(task1)
output = agent.execute_task(task2)
assert cache_handler._cache == { assert cache_handler._cache == {
"multiplier-{'first_number': 12, 'second_number': 3}": 36, "multiplier-{'first_number': 12, 'second_number': 3}": 36,
"multiplier-{'first_number': 2, 'second_number': 6}": 12, "multiplier-{'first_number': 2, 'second_number': 6}": 12,
"multiplier-{'first_number': 3, 'second_number': 3}": 9, "multiplier-{'first_number': 3, 'second_number': 3}": 9,
} }
output = agent.execute_task("What is 2 times 6 times 3? Return only the number") task = Task(
description="What is 2 times 6 times 3? Return only the number", agent=agent
)
output = agent.execute_task(task)
assert output == "36" assert output == "36"
with patch.object(CacheHandler, "read") as read: with patch.object(CacheHandler, "read") as read:
read.return_value = "0" read.return_value = "0"
output = agent.execute_task( task = Task(
"What is 2 times 6? Ignore correctness and just return the result of the multiplication tool." description="What is 2 times 6? Ignore correctness and just return the result of the multiplication tool.",
agent=agent,
) )
output = agent.execute_task(task)
assert output == "0" assert output == "0"
read.assert_called_with( read.assert_called_with(
tool="multiplier", input={"first_number": 2, "second_number": 6} tool="multiplier", input={"first_number": 2, "second_number": 6}
@@ -183,7 +196,8 @@ def test_agent_execution_with_specific_tools():
allow_delegation=False, allow_delegation=False,
) )
output = agent.execute_task(task="What is 3 times 4", tools=[multiplier]) task = Task(description="What is 3 times 4", agent=agent)
output = agent.execute_task(task=task, tools=[multiplier])
assert output == "3 times 4 is 12." assert output == "3 times 4 is 12."
@@ -206,8 +220,11 @@ def test_agent_custom_max_iterations():
with patch.object( with patch.object(
CrewAgentExecutor, "_iter_next_step", wraps=agent.agent_executor._iter_next_step CrewAgentExecutor, "_iter_next_step", wraps=agent.agent_executor._iter_next_step
) as private_mock: ) as private_mock:
task = Task(
description="The final answer is 42. But don't give it yet, instead keep using the `get_final_answer` tool.",
)
agent.execute_task( agent.execute_task(
task="The final answer is 42. But don't give it yet, instead keep using the `get_final_answer` tool.", task=task,
tools=[get_final_answer], tools=[get_final_answer],
) )
private_mock.assert_called_once() private_mock.assert_called_once()
@@ -229,8 +246,11 @@ def test_agent_repeated_tool_usage(capsys):
allow_delegation=False, allow_delegation=False,
) )
task = Task(
description="The final answer is 42. But don't give it yet, instead keep using the `get_final_answer` tool."
)
agent.execute_task( agent.execute_task(
task="The final answer is 42. But don't give it yet, instead keep using the `get_final_answer` tool.", task=task,
tools=[get_final_answer], tools=[get_final_answer],
) )
@@ -257,8 +277,11 @@ def test_agent_moved_on_after_max_iterations():
allow_delegation=False, allow_delegation=False,
) )
task = Task(
description="The final answer is 42. But don't give it yet, instead keep using the `get_final_answer` tool."
)
output = agent.execute_task( output = agent.execute_task(
task="The final answer is 42. But don't give it yet, instead keep using the `get_final_answer` tool.", task=task,
tools=[get_final_answer], tools=[get_final_answer],
) )
assert ( assert (
@@ -287,8 +310,11 @@ def test_agent_respect_the_max_rpm_set(capsys):
with patch.object(RPMController, "_wait_for_next_minute") as moveon: with patch.object(RPMController, "_wait_for_next_minute") as moveon:
moveon.return_value = True moveon.return_value = True
task = Task(
description="The final answer is 42. But don't give it yet, instead keep using the `get_final_answer` tool."
)
output = agent.execute_task( output = agent.execute_task(
task="The final answer is 42. But don't give it yet, instead keep using the `get_final_answer` tool.", task=task,
tools=[get_final_answer], tools=[get_final_answer],
) )
assert ( assert (

View File

@@ -74,7 +74,7 @@ def test_task_prompt_includes_expected_output():
with patch.object(Agent, "execute_task") as execute: with patch.object(Agent, "execute_task") as execute:
execute.return_value = "ok" execute.return_value = "ok"
task.execute() task.execute()
execute.assert_called_once_with(task=task._prompt(), context=None, tools=[]) execute.assert_called_once_with(task=task, context=None, tools=[])
def test_task_callback(): def test_task_callback():
@@ -115,7 +115,7 @@ def test_execute_with_agent():
with patch.object(Agent, "execute_task", return_value="ok") as execute: with patch.object(Agent, "execute_task", return_value="ok") as execute:
task.execute(agent=researcher) task.execute(agent=researcher)
execute.assert_called_once_with(task=task._prompt(), context=None, tools=[]) execute.assert_called_once_with(task=task, context=None, tools=[])
def test_async_execution(): def test_async_execution():
@@ -135,4 +135,4 @@ def test_async_execution():
with patch.object(Agent, "execute_task", return_value="ok") as execute: with patch.object(Agent, "execute_task", return_value="ok") as execute:
task.execute(agent=researcher) task.execute(agent=researcher)
execute.assert_called_once_with(task=task._prompt(), context=None, tools=[]) execute.assert_called_once_with(task=task, context=None, tools=[])