Adding initial formatting error counting and token counter

This commit is contained in:
João Moura
2024-02-28 01:57:04 -03:00
parent 577db88f8e
commit 7b49b4e985
3 changed files with 65 additions and 13 deletions

View File

@@ -23,6 +23,7 @@ from pydantic_core import PydanticCustomError
from crewai.agents import CacheHandler, CrewAgentExecutor, CrewAgentParser, ToolsHandler from crewai.agents import CacheHandler, CrewAgentExecutor, CrewAgentParser, ToolsHandler
from crewai.utilities import I18N, Logger, Prompts, RPMController from crewai.utilities import I18N, Logger, Prompts, RPMController
from crewai.utilities.token_counter_callback import TokenCalcHandler, TokenProcess
class Agent(BaseModel): class Agent(BaseModel):
@@ -51,7 +52,9 @@ class Agent(BaseModel):
_logger: Logger = PrivateAttr() _logger: Logger = PrivateAttr()
_rpm_controller: RPMController = PrivateAttr(default=None) _rpm_controller: RPMController = PrivateAttr(default=None)
_request_within_rpm_limit: Any = PrivateAttr(default=None) _request_within_rpm_limit: Any = PrivateAttr(default=None)
_token_process: TokenProcess = TokenProcess()
formatting_errors: int = 0
model_config = ConfigDict(arbitrary_types_allowed=True) model_config = ConfigDict(arbitrary_types_allowed=True)
id: UUID4 = Field( id: UUID4 = Field(
default_factory=uuid.uuid4, default_factory=uuid.uuid4,
@@ -123,8 +126,12 @@ class Agent(BaseModel):
return self return self
@model_validator(mode="after") @model_validator(mode="after")
def check_agent_executor(self) -> "Agent": def set_agent_executor(self) -> "Agent":
"""Check if the agent executor is set.""" """set agent executor is set."""
if hasattr(self.llm, "model_name"):
self.llm.callbacks = [
TokenCalcHandler(self.llm.model_name, self._token_process)
]
if not self.agent_executor: if not self.agent_executor:
self.set_cache_handler(self.cache_handler) self.set_cache_handler(self.cache_handler)
return self return self
@@ -243,20 +250,14 @@ class Agent(BaseModel):
) )
bind = self.llm.bind(stop=[self.i18n.slice("observation")]) bind = self.llm.bind(stop=[self.i18n.slice("observation")])
inner_agent = agent_args | execution_prompt | bind | CrewAgentParser() inner_agent = agent_args | execution_prompt | bind | CrewAgentParser(agent=self)
self.agent_executor = CrewAgentExecutor( self.agent_executor = CrewAgentExecutor(
agent=RunnableAgent(runnable=inner_agent), **executor_args agent=RunnableAgent(runnable=inner_agent), **executor_args
) )
def _parse_tools(self, tools: List[Any]) -> List[LangChainTool]: def count_formatting_errors(self) -> None:
"""Parse tools to be used for the task.""" """Count the formatting errors of the agent."""
tools_list = [] self.formatting_errors += 1
for tool in tools:
if isinstance(tool, CrewAITool):
tools_list.append(tool.to_langchain())
else:
tools_list.append(tool)
return tools_list
def format_log_to_str( def format_log_to_str(
self, self,
@@ -271,6 +272,16 @@ class Agent(BaseModel):
thoughts += f"\n{observation_prefix}{observation}\n{llm_prefix}" thoughts += f"\n{observation_prefix}{observation}\n{llm_prefix}"
return thoughts return thoughts
def _parse_tools(self, tools: List[Any]) -> List[LangChainTool]:
"""Parse tools to be used for the task."""
tools_list = []
for tool in tools:
if isinstance(tool, CrewAITool):
tools_list.append(tool.to_langchain())
else:
tools_list.append(tool)
return tools_list
@staticmethod @staticmethod
def __tools_names(tools) -> str: def __tools_names(tools) -> str:
return ", ".join([t.name for t in tools]) return ", ".join([t.name for t in tools])

View File

@@ -1,4 +1,4 @@
from typing import Union from typing import Any, Union
from langchain.agents.output_parsers import ReActSingleInputOutputParser from langchain.agents.output_parsers import ReActSingleInputOutputParser
from langchain_core.agents import AgentAction, AgentFinish from langchain_core.agents import AgentAction, AgentFinish
@@ -34,6 +34,7 @@ class CrewAgentParser(ReActSingleInputOutputParser):
""" """
_i18n: I18N = I18N() _i18n: I18N = I18N()
agent: Any = None
def parse(self, text: str) -> Union[AgentAction, AgentFinish]: def parse(self, text: str) -> Union[AgentAction, AgentFinish]:
includes_answer = FINAL_ANSWER_ACTION in text includes_answer = FINAL_ANSWER_ACTION in text
@@ -41,6 +42,7 @@ class CrewAgentParser(ReActSingleInputOutputParser):
if includes_tool: if includes_tool:
if includes_answer: if includes_answer:
self.agent.count_formatting_errors()
raise OutputParserException(f"{FINAL_ANSWER_AND_TOOL_ERROR_MESSAGE}") raise OutputParserException(f"{FINAL_ANSWER_AND_TOOL_ERROR_MESSAGE}")
return AgentAction("", "", text) return AgentAction("", "", text)
@@ -52,6 +54,7 @@ class CrewAgentParser(ReActSingleInputOutputParser):
format = self._i18n.slice("format_without_tools") format = self._i18n.slice("format_without_tools")
error = f"{format}" error = f"{format}"
self.agent.count_formatting_errors()
raise OutputParserException( raise OutputParserException(
error, error,
observation=error, observation=error,

View File

@@ -4,11 +4,13 @@ from unittest.mock import patch
import pytest import pytest
from langchain.tools import tool from langchain.tools import tool
from langchain_core.exceptions import OutputParserException
from langchain_openai import ChatOpenAI from langchain_openai import ChatOpenAI
from crewai import Agent, Crew, Task from crewai import Agent, Crew, Task
from crewai.agents.cache import CacheHandler from crewai.agents.cache import CacheHandler
from crewai.agents.executor import CrewAgentExecutor from crewai.agents.executor import CrewAgentExecutor
from crewai.agents.parser import CrewAgentParser
from crewai.tools.tool_calling import InstructorToolCalling from crewai.tools.tool_calling import InstructorToolCalling
from crewai.tools.tool_usage import ToolUsage from crewai.tools.tool_usage import ToolUsage
from crewai.utilities import RPMController from crewai.utilities import RPMController
@@ -576,3 +578,39 @@ def test_agent_function_calling_llm():
crew.kickoff() crew.kickoff()
private_mock.assert_called() private_mock.assert_called()
def test_agent_count_formatting_error():
from unittest.mock import patch
agent1 = Agent(
role="test role",
goal="test goal",
backstory="test backstory",
verbose=True,
)
parser = CrewAgentParser()
parser.agent = agent1
with patch.object(Agent, "count_formatting_errors") as mock_count_errors:
test_text = "This text does not match expected formats."
with pytest.raises(OutputParserException):
parser.parse(test_text)
mock_count_errors.assert_called_once()
def test_agent_llm_uses_token_calc_handler_with_llm_has_model_name():
agent1 = Agent(
role="test role",
goal="test goal",
backstory="test backstory",
verbose=True,
)
assert len(agent1.llm.callbacks) == 1
assert agent1.llm.callbacks[0].__class__.__name__ == "TokenCalcHandler"
assert agent1.llm.callbacks[0].model == "gpt-4"
assert (
agent1.llm.callbacks[0].token_cost_process.__class__.__name__ == "TokenProcess"
)