From bbbd976fe3e58bb8080dcdc9bf377b0cf8670706 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Jo=C3=A3o=20Moura?= Date: Sat, 10 Feb 2024 11:28:08 -0800 Subject: [PATCH] refactoring task execution --- src/crewai/agent.py | 11 ++++--- src/crewai/task.py | 15 ++++++--- src/crewai/tools/agent_tools.py | 2 ++ tests/agent_test.py | 54 ++++++++++++++++++++++++--------- tests/task_test.py | 6 ++-- 5 files changed, 62 insertions(+), 26 deletions(-) diff --git a/src/crewai/agent.py b/src/crewai/agent.py index 8da40ee26..9c0dcfe7a 100644 --- a/src/crewai/agent.py +++ b/src/crewai/agent.py @@ -125,7 +125,7 @@ class Agent(BaseModel): def execute_task( self, - task: str, + task: Any, context: Optional[str] = None, tools: Optional[List[Any]] = None, ) -> str: @@ -140,17 +140,20 @@ class Agent(BaseModel): Output of the agent """ + task_prompt = task.prompt() + if context: - task = self.i18n.slice("task_with_context").format( - task=task, context=context + task_prompt = self.i18n.slice("task_with_context").format( + task=task_prompt, context=context ) tools = tools or self.tools self.agent_executor.tools = tools + self.agent_executor.task = task result = self.agent_executor.invoke( { - "input": task, + "input": task_prompt, "tool_names": self.__tools_names(tools), "tools": render_text_description(tools), } diff --git a/src/crewai/task.py b/src/crewai/task.py index 18d90aabf..f6e1ceaed 100644 --- a/src/crewai/task.py +++ b/src/crewai/task.py @@ -17,6 +17,7 @@ class Task(BaseModel): arbitrary_types_allowed = True __hash__ = object.__hash__ # type: ignore + used_tools: int = 0 i18n: I18N = I18N() thread: threading.Thread = None description: str = Field(description="Description of the actual task.") @@ -96,25 +97,29 @@ class Task(BaseModel): if self.async_execution: self.thread = threading.Thread( - target=self._execute, args=(agent, self._prompt(), context, tools) + target=self._execute, args=(agent, self, context, tools) ) self.thread.start() else: result = self._execute( + task=self, agent=agent, - task_prompt=self._prompt(), context=context, tools=tools, ) return result - def _execute(self, agent, task_prompt, context, tools): - result = agent.execute_task(task=task_prompt, context=context, tools=tools) + def _execute(self, agent, task, context, tools): + result = agent.execute_task( + task=task, + context=context, + tools=tools, + ) self.output = TaskOutput(description=self.description, result=result) self.callback(self.output) if self.callback else None return result - def _prompt(self) -> str: + def prompt(self) -> str: """Prompt the task. Returns: diff --git a/src/crewai/tools/agent_tools.py b/src/crewai/tools/agent_tools.py index cb41da04e..b22869c16 100644 --- a/src/crewai/tools/agent_tools.py +++ b/src/crewai/tools/agent_tools.py @@ -4,6 +4,7 @@ from langchain.tools import StructuredTool from pydantic import BaseModel, Field from crewai.agent import Agent +from crewai.task import Task from crewai.utilities import I18N @@ -53,4 +54,5 @@ class AgentTools(BaseModel): ) agent = agent[0] + task = Task(description=task, agent=agent) return agent.execute_task(task, context) diff --git a/tests/agent_test.py b/tests/agent_test.py index cdba0e00b..48165135c 100644 --- a/tests/agent_test.py +++ b/tests/agent_test.py @@ -63,7 +63,8 @@ def test_agent_without_memory(): llm=ChatOpenAI(temperature=0, model="gpt-4"), ) - result = no_memory_agent.execute_task("How much is 1 + 1?") + task = Task(description="How much is 1 + 1?", agent=no_memory_agent) + result = no_memory_agent.execute_task(task) assert result == "1 + 1 equals 2." assert no_memory_agent.agent_executor.memory is None @@ -79,7 +80,9 @@ def test_agent_execution(): allow_delegation=False, ) - output = agent.execute_task("How much is 1 + 1?") + task = Task(description="How much is 1 + 1?", agent=agent) + + output = agent.execute_task(task) assert output == "2" @@ -98,7 +101,8 @@ def test_agent_execution_with_tools(): allow_delegation=False, ) - output = agent.execute_task("What is 3 times 4") + task = Task(description="What is 3 times 4?", agent=agent) + output = agent.execute_task(task) assert output == "3 times 4 is 12." @@ -119,7 +123,8 @@ def test_logging_tool_usage(): ) assert agent.tools_handler.last_used_tool == {} - output = agent.execute_task("What is 3 times 5?") + task = Task(description="What is 3 times 4?", agent=agent) + output = agent.execute_task(task) tool_usage = ToolCalling( function_name=multiplier.name, arguments={"first_number": 3, "second_number": 5} ) @@ -147,22 +152,30 @@ def test_cache_hitting(): verbose=True, ) - output = agent.execute_task("What is 2 times 6 times 3?") - output = agent.execute_task("What is 3 times 3?") + task1 = Task(description="What is 2 times 6?", agent=agent) + task2 = Task(description="What is 3 times 3?", agent=agent) + + output = agent.execute_task(task1) + output = agent.execute_task(task2) assert cache_handler._cache == { "multiplier-{'first_number': 12, 'second_number': 3}": 36, "multiplier-{'first_number': 2, 'second_number': 6}": 12, "multiplier-{'first_number': 3, 'second_number': 3}": 9, } - output = agent.execute_task("What is 2 times 6 times 3? Return only the number") + task = Task( + description="What is 2 times 6 times 3? Return only the number", agent=agent + ) + output = agent.execute_task(task) assert output == "36" with patch.object(CacheHandler, "read") as read: read.return_value = "0" - output = agent.execute_task( - "What is 2 times 6? Ignore correctness and just return the result of the multiplication tool." + task = Task( + description="What is 2 times 6? Ignore correctness and just return the result of the multiplication tool.", + agent=agent, ) + output = agent.execute_task(task) assert output == "0" read.assert_called_with( tool="multiplier", input={"first_number": 2, "second_number": 6} @@ -183,7 +196,8 @@ def test_agent_execution_with_specific_tools(): allow_delegation=False, ) - output = agent.execute_task(task="What is 3 times 4", tools=[multiplier]) + task = Task(description="What is 3 times 4", agent=agent) + output = agent.execute_task(task=task, tools=[multiplier]) assert output == "3 times 4 is 12." @@ -206,8 +220,11 @@ def test_agent_custom_max_iterations(): with patch.object( CrewAgentExecutor, "_iter_next_step", wraps=agent.agent_executor._iter_next_step ) as private_mock: + task = Task( + description="The final answer is 42. But don't give it yet, instead keep using the `get_final_answer` tool.", + ) agent.execute_task( - task="The final answer is 42. But don't give it yet, instead keep using the `get_final_answer` tool.", + task=task, tools=[get_final_answer], ) private_mock.assert_called_once() @@ -229,8 +246,11 @@ def test_agent_repeated_tool_usage(capsys): allow_delegation=False, ) + task = Task( + description="The final answer is 42. But don't give it yet, instead keep using the `get_final_answer` tool." + ) agent.execute_task( - task="The final answer is 42. But don't give it yet, instead keep using the `get_final_answer` tool.", + task=task, tools=[get_final_answer], ) @@ -257,8 +277,11 @@ def test_agent_moved_on_after_max_iterations(): allow_delegation=False, ) + task = Task( + description="The final answer is 42. But don't give it yet, instead keep using the `get_final_answer` tool." + ) output = agent.execute_task( - task="The final answer is 42. But don't give it yet, instead keep using the `get_final_answer` tool.", + task=task, tools=[get_final_answer], ) assert ( @@ -287,8 +310,11 @@ def test_agent_respect_the_max_rpm_set(capsys): with patch.object(RPMController, "_wait_for_next_minute") as moveon: moveon.return_value = True + task = Task( + description="The final answer is 42. But don't give it yet, instead keep using the `get_final_answer` tool." + ) output = agent.execute_task( - task="The final answer is 42. But don't give it yet, instead keep using the `get_final_answer` tool.", + task=task, tools=[get_final_answer], ) assert ( diff --git a/tests/task_test.py b/tests/task_test.py index 1e0ef1b15..1a42044af 100644 --- a/tests/task_test.py +++ b/tests/task_test.py @@ -74,7 +74,7 @@ def test_task_prompt_includes_expected_output(): with patch.object(Agent, "execute_task") as execute: execute.return_value = "ok" task.execute() - execute.assert_called_once_with(task=task._prompt(), context=None, tools=[]) + execute.assert_called_once_with(task=task, context=None, tools=[]) def test_task_callback(): @@ -115,7 +115,7 @@ def test_execute_with_agent(): with patch.object(Agent, "execute_task", return_value="ok") as execute: task.execute(agent=researcher) - execute.assert_called_once_with(task=task._prompt(), context=None, tools=[]) + execute.assert_called_once_with(task=task, context=None, tools=[]) def test_async_execution(): @@ -135,4 +135,4 @@ def test_async_execution(): with patch.object(Agent, "execute_task", return_value="ok") as execute: task.execute(agent=researcher) - execute.assert_called_once_with(task=task._prompt(), context=None, tools=[]) + execute.assert_called_once_with(task=task, context=None, tools=[])