From 7b49b4e985a55bc54bf9bc542589dbcc8eb2521a Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Jo=C3=A3o=20Moura?= Date: Wed, 28 Feb 2024 01:57:04 -0300 Subject: [PATCH] Adding initial formatting error counting and token counter --- src/crewai/agent.py | 35 ++++++++++++++++++++++------------ src/crewai/agents/parser.py | 5 ++++- tests/agent_test.py | 38 +++++++++++++++++++++++++++++++++++++ 3 files changed, 65 insertions(+), 13 deletions(-) diff --git a/src/crewai/agent.py b/src/crewai/agent.py index 9aa6a3d3a..a0bed66d2 100644 --- a/src/crewai/agent.py +++ b/src/crewai/agent.py @@ -23,6 +23,7 @@ from pydantic_core import PydanticCustomError from crewai.agents import CacheHandler, CrewAgentExecutor, CrewAgentParser, ToolsHandler from crewai.utilities import I18N, Logger, Prompts, RPMController +from crewai.utilities.token_counter_callback import TokenCalcHandler, TokenProcess class Agent(BaseModel): @@ -51,7 +52,9 @@ class Agent(BaseModel): _logger: Logger = PrivateAttr() _rpm_controller: RPMController = PrivateAttr(default=None) _request_within_rpm_limit: Any = PrivateAttr(default=None) + _token_process: TokenProcess = TokenProcess() + formatting_errors: int = 0 model_config = ConfigDict(arbitrary_types_allowed=True) id: UUID4 = Field( default_factory=uuid.uuid4, @@ -123,8 +126,12 @@ class Agent(BaseModel): return self @model_validator(mode="after") - def check_agent_executor(self) -> "Agent": - """Check if the agent executor is set.""" + def set_agent_executor(self) -> "Agent": + """set agent executor is set.""" + if hasattr(self.llm, "model_name"): + self.llm.callbacks = [ + TokenCalcHandler(self.llm.model_name, self._token_process) + ] if not self.agent_executor: self.set_cache_handler(self.cache_handler) return self @@ -243,20 +250,14 @@ class Agent(BaseModel): ) bind = self.llm.bind(stop=[self.i18n.slice("observation")]) - inner_agent = agent_args | execution_prompt | bind | CrewAgentParser() + inner_agent = agent_args | execution_prompt | bind | CrewAgentParser(agent=self) self.agent_executor = CrewAgentExecutor( agent=RunnableAgent(runnable=inner_agent), **executor_args ) - def _parse_tools(self, tools: List[Any]) -> List[LangChainTool]: - """Parse tools to be used for the task.""" - tools_list = [] - for tool in tools: - if isinstance(tool, CrewAITool): - tools_list.append(tool.to_langchain()) - else: - tools_list.append(tool) - return tools_list + def count_formatting_errors(self) -> None: + """Count the formatting errors of the agent.""" + self.formatting_errors += 1 def format_log_to_str( self, @@ -271,6 +272,16 @@ class Agent(BaseModel): thoughts += f"\n{observation_prefix}{observation}\n{llm_prefix}" return thoughts + def _parse_tools(self, tools: List[Any]) -> List[LangChainTool]: + """Parse tools to be used for the task.""" + tools_list = [] + for tool in tools: + if isinstance(tool, CrewAITool): + tools_list.append(tool.to_langchain()) + else: + tools_list.append(tool) + return tools_list + @staticmethod def __tools_names(tools) -> str: return ", ".join([t.name for t in tools]) diff --git a/src/crewai/agents/parser.py b/src/crewai/agents/parser.py index 7fba15b5c..0ec7172b0 100644 --- a/src/crewai/agents/parser.py +++ b/src/crewai/agents/parser.py @@ -1,4 +1,4 @@ -from typing import Union +from typing import Any, Union from langchain.agents.output_parsers import ReActSingleInputOutputParser from langchain_core.agents import AgentAction, AgentFinish @@ -34,6 +34,7 @@ class CrewAgentParser(ReActSingleInputOutputParser): """ _i18n: I18N = I18N() + agent: Any = None def parse(self, text: str) -> Union[AgentAction, AgentFinish]: includes_answer = FINAL_ANSWER_ACTION in text @@ -41,6 +42,7 @@ class CrewAgentParser(ReActSingleInputOutputParser): if includes_tool: if includes_answer: + self.agent.count_formatting_errors() raise OutputParserException(f"{FINAL_ANSWER_AND_TOOL_ERROR_MESSAGE}") return AgentAction("", "", text) @@ -52,6 +54,7 @@ class CrewAgentParser(ReActSingleInputOutputParser): format = self._i18n.slice("format_without_tools") error = f"{format}" + self.agent.count_formatting_errors() raise OutputParserException( error, observation=error, diff --git a/tests/agent_test.py b/tests/agent_test.py index d69ec627e..932dc802a 100644 --- a/tests/agent_test.py +++ b/tests/agent_test.py @@ -4,11 +4,13 @@ from unittest.mock import patch import pytest from langchain.tools import tool +from langchain_core.exceptions import OutputParserException from langchain_openai import ChatOpenAI from crewai import Agent, Crew, Task from crewai.agents.cache import CacheHandler from crewai.agents.executor import CrewAgentExecutor +from crewai.agents.parser import CrewAgentParser from crewai.tools.tool_calling import InstructorToolCalling from crewai.tools.tool_usage import ToolUsage from crewai.utilities import RPMController @@ -576,3 +578,39 @@ def test_agent_function_calling_llm(): crew.kickoff() private_mock.assert_called() + + +def test_agent_count_formatting_error(): + from unittest.mock import patch + + agent1 = Agent( + role="test role", + goal="test goal", + backstory="test backstory", + verbose=True, + ) + + parser = CrewAgentParser() + parser.agent = agent1 + + with patch.object(Agent, "count_formatting_errors") as mock_count_errors: + test_text = "This text does not match expected formats." + with pytest.raises(OutputParserException): + parser.parse(test_text) + mock_count_errors.assert_called_once() + + +def test_agent_llm_uses_token_calc_handler_with_llm_has_model_name(): + agent1 = Agent( + role="test role", + goal="test goal", + backstory="test backstory", + verbose=True, + ) + + assert len(agent1.llm.callbacks) == 1 + assert agent1.llm.callbacks[0].__class__.__name__ == "TokenCalcHandler" + assert agent1.llm.callbacks[0].model == "gpt-4" + assert ( + agent1.llm.callbacks[0].token_cost_process.__class__.__name__ == "TokenProcess" + )