From fa397d47e3b9b66c81dcb5dec867487351a99e39 Mon Sep 17 00:00:00 2001 From: Brandon Hancock Date: Mon, 2 Dec 2024 12:22:50 -0500 Subject: [PATCH] v1 of fix implemented. Need to confirm with tokens. --- src/crewai/agent.py | 6 +-- src/crewai/agents/crew_agent_executor.py | 53 ++++++++++++++++++------ src/crewai/tools/base_tool.py | 1 + src/crewai/tools/structured_tool.py | 4 ++ 4 files changed, 49 insertions(+), 15 deletions(-) diff --git a/src/crewai/agent.py b/src/crewai/agent.py index 26380ebc2..abe678db1 100644 --- a/src/crewai/agent.py +++ b/src/crewai/agent.py @@ -1,7 +1,7 @@ import os import shutil import subprocess -from typing import Any, List, Literal, Optional, Union, Dict +from typing import Any, Dict, List, Literal, Optional, Union from pydantic import Field, InstanceOf, PrivateAttr, model_validator @@ -9,9 +9,10 @@ from crewai.agents import CacheHandler from crewai.agents.agent_builder.base_agent import BaseAgent from crewai.agents.crew_agent_executor import CrewAgentExecutor from crewai.cli.constants import ENV_VARS -from crewai.llm import LLM from crewai.knowledge.knowledge import Knowledge from crewai.knowledge.source.base_knowledge_source import BaseKnowledgeSource +from crewai.knowledge.utils.knowledge_utils import extract_knowledge_context +from crewai.llm import LLM from crewai.memory.contextual.contextual_memory import ContextualMemory from crewai.task import Task from crewai.tools import BaseTool @@ -21,7 +22,6 @@ from crewai.utilities.constants import TRAINED_AGENTS_DATA_FILE, TRAINING_DATA_F from crewai.utilities.converter import generate_model_description from crewai.utilities.token_counter_callback import TokenCalcHandler from crewai.utilities.training_handler import CrewTrainingHandler -from crewai.knowledge.utils.knowledge_utils import extract_knowledge_context def mock_agent_ops_provider(): diff --git a/src/crewai/agents/crew_agent_executor.py b/src/crewai/agents/crew_agent_executor.py index bf14e6915..a8dc1c8db 100644 --- a/src/crewai/agents/crew_agent_executor.py +++ b/src/crewai/agents/crew_agent_executor.py @@ -1,5 +1,6 @@ import json import re +from dataclasses import dataclass from typing import Any, Dict, List, Union from crewai.agents.agent_builder.base_agent import BaseAgent @@ -12,6 +13,7 @@ from crewai.agents.parser import ( OutputParserException, ) from crewai.agents.tools_handler import ToolsHandler +from crewai.tools.base_tool import BaseTool from crewai.tools.tool_usage import ToolUsage, ToolUsageErrorException from crewai.utilities import I18N, Printer from crewai.utilities.constants import TRAINING_DATA_FILE @@ -22,6 +24,12 @@ from crewai.utilities.logger import Logger from crewai.utilities.training_handler import CrewTrainingHandler +@dataclass +class ToolResult: + result: Any + result_as_answer: bool + + class CrewAgentExecutor(CrewAgentExecutorMixin): _logger: Logger = Logger() @@ -33,7 +41,7 @@ class CrewAgentExecutor(CrewAgentExecutorMixin): agent: BaseAgent, prompt: dict[str, str], max_iter: int, - tools: List[Any], + tools: List[BaseTool], tools_names: str, stop_words: List[str], tools_description: str, @@ -70,7 +78,9 @@ class CrewAgentExecutor(CrewAgentExecutorMixin): self.iterations = 0 self.log_error_after = 3 self.have_forced_answer = False - self.name_to_tool_map = {tool.name: tool for tool in self.tools} + self.tool_name_to_tool_map: Dict[str, BaseTool] = { + tool.name: tool for tool in self.tools + } if self.llm.stop: self.llm.stop = list(set(self.llm.stop + self.stop)) else: @@ -91,6 +101,7 @@ class CrewAgentExecutor(CrewAgentExecutorMixin): self.ask_for_human_input = bool(inputs.get("ask_for_human_input", False)) formatted_answer = self._invoke_loop() + print("FORMATTED ANSWER: ", formatted_answer) if self.ask_for_human_input: human_feedback = self._ask_human_input(formatted_answer.output) @@ -111,7 +122,9 @@ class CrewAgentExecutor(CrewAgentExecutorMixin): def _invoke_loop(self, formatted_answer=None): try: while not isinstance(formatted_answer, AgentFinish): + print("STARTING LOOP") if not self.request_within_rpm_limit or self.request_within_rpm_limit(): + print("MESSAGES: ", self.messages) answer = self.llm.call( self.messages, callbacks=self.callbacks, @@ -140,9 +153,18 @@ class CrewAgentExecutor(CrewAgentExecutorMixin): formatted_answer = self._format_answer(answer) if isinstance(formatted_answer, AgentAction): - action_result = self._use_tool(formatted_answer) - formatted_answer.text += f"\nObservation: {action_result}" - formatted_answer.result = action_result + tool_result = self._execute_tool_and_check_finality( + formatted_answer + ) + formatted_answer.text += f"\nObservation: {tool_result.result}" + formatted_answer.result = tool_result.result + if tool_result.result_as_answer: + print("RESULT AS ANSWER: ", tool_result.result) + return AgentFinish( + thought="", + output=tool_result.result, + text=formatted_answer.text, + ) self._show_logs(formatted_answer) if self.step_callback: @@ -165,6 +187,7 @@ class CrewAgentExecutor(CrewAgentExecutorMixin): self.messages.append( self._format_msg(formatted_answer.text, role="assistant") ) + print("FORMATTED ANSWER in invoke_loop: ", formatted_answer) except OutputParserException as e: self.messages.append({"role": "user", "content": e.error}) @@ -239,7 +262,7 @@ class CrewAgentExecutor(CrewAgentExecutorMixin): content=f"\033[95m## Final Answer:\033[00m \033[92m\n{formatted_answer.output}\033[00m\n\n" ) - def _use_tool(self, agent_action: AgentAction) -> Any: + def _execute_tool_and_check_finality(self, agent_action: AgentAction) -> ToolResult: tool_usage = ToolUsage( tools_handler=self.tools_handler, tools=self.tools, @@ -255,19 +278,25 @@ class CrewAgentExecutor(CrewAgentExecutorMixin): if isinstance(tool_calling, ToolUsageErrorException): tool_result = tool_calling.message + return ToolResult(result=tool_result, result_as_answer=False) else: if tool_calling.tool_name.casefold().strip() in [ - name.casefold().strip() for name in self.name_to_tool_map + name.casefold().strip() for name in self.tool_name_to_tool_map ] or tool_calling.tool_name.casefold().replace("_", " ") in [ - name.casefold().strip() for name in self.name_to_tool_map + name.casefold().strip() for name in self.tool_name_to_tool_map ]: tool_result = tool_usage.use(tool_calling, agent_action.text) + tool = self.tool_name_to_tool_map.get(tool_calling.tool_name) + if tool: + return ToolResult( + result=tool_result, result_as_answer=tool.result_as_answer + ) else: tool_result = self._i18n.errors("wrong_tool_name").format( tool=tool_calling.tool_name, tools=", ".join([tool.name.casefold() for tool in self.tools]), ) - return tool_result + return ToolResult(result=tool_result, result_as_answer=False) def _summarize_messages(self) -> None: messages_groups = [] @@ -333,9 +362,9 @@ class CrewAgentExecutor(CrewAgentExecutorMixin): if self.crew is not None and hasattr(self.crew, "_train_iteration"): train_iteration = self.crew._train_iteration if agent_id in training_data and isinstance(train_iteration, int): - training_data[agent_id][train_iteration]["improved_output"] = ( - result.output - ) + training_data[agent_id][train_iteration][ + "improved_output" + ] = result.output training_handler.save(training_data) else: self._logger.log( diff --git a/src/crewai/tools/base_tool.py b/src/crewai/tools/base_tool.py index 46bfc03a6..c3840d23c 100644 --- a/src/crewai/tools/base_tool.py +++ b/src/crewai/tools/base_tool.py @@ -73,6 +73,7 @@ class BaseTool(BaseModel, ABC): description=self.description, args_schema=self.args_schema, func=self._run, + result_as_answer=self.result_as_answer, ) @classmethod diff --git a/src/crewai/tools/structured_tool.py b/src/crewai/tools/structured_tool.py index bd6818605..dfd23a9cb 100644 --- a/src/crewai/tools/structured_tool.py +++ b/src/crewai/tools/structured_tool.py @@ -22,6 +22,7 @@ class CrewStructuredTool: description: str, args_schema: type[BaseModel], func: Callable[..., Any], + result_as_answer: bool = False, ) -> None: """Initialize the structured tool. @@ -30,12 +31,14 @@ class CrewStructuredTool: description: A description of what the tool does args_schema: The pydantic model for the tool's arguments func: The function to run when the tool is called + result_as_answer: Whether to return the output directly """ self.name = name self.description = description self.args_schema = args_schema self.func = func self._logger = Logger() + self.result_as_answer = result_as_answer # Validate the function signature matches the schema self._validate_function_signature() @@ -98,6 +101,7 @@ class CrewStructuredTool: description=description, args_schema=schema, func=func, + result_as_answer=return_direct, ) @staticmethod