Adding initial formatting error counting and token counter

commit 7b49b4e985
parent 577db88f8e
Author: João Moura
Date:   2024-02-28 01:57:04 -03:00

3 changed files with 65 additions and 13 deletions

src/crewai/agent.py

@@ -23,6 +23,7 @@ from pydantic_core import PydanticCustomError
 from crewai.agents import CacheHandler, CrewAgentExecutor, CrewAgentParser, ToolsHandler
 from crewai.utilities import I18N, Logger, Prompts, RPMController
+from crewai.utilities.token_counter_callback import TokenCalcHandler, TokenProcess


 class Agent(BaseModel):
@@ -51,7 +52,9 @@ class Agent(BaseModel):
     _logger: Logger = PrivateAttr()
     _rpm_controller: RPMController = PrivateAttr(default=None)
     _request_within_rpm_limit: Any = PrivateAttr(default=None)
+    _token_process: TokenProcess = TokenProcess()
+    formatting_errors: int = 0

     model_config = ConfigDict(arbitrary_types_allowed=True)
     id: UUID4 = Field(
         default_factory=uuid.uuid4,
@@ -123,8 +126,12 @@ class Agent(BaseModel):
         return self

     @model_validator(mode="after")
-    def check_agent_executor(self) -> "Agent":
-        """Check if the agent executor is set."""
+    def set_agent_executor(self) -> "Agent":
+        """Set up the agent executor and the token-counting callback."""
+        if hasattr(self.llm, "model_name"):
+            self.llm.callbacks = [
+                TokenCalcHandler(self.llm.model_name, self._token_process)
+            ]
         if not self.agent_executor:
             self.set_cache_handler(self.cache_handler)
         return self
@@ -243,20 +250,14 @@ class Agent(BaseModel):
         )
         bind = self.llm.bind(stop=[self.i18n.slice("observation")])
-        inner_agent = agent_args | execution_prompt | bind | CrewAgentParser()
+        inner_agent = agent_args | execution_prompt | bind | CrewAgentParser(agent=self)
         self.agent_executor = CrewAgentExecutor(
             agent=RunnableAgent(runnable=inner_agent), **executor_args
         )

-    def _parse_tools(self, tools: List[Any]) -> List[LangChainTool]:
-        """Parse tools to be used for the task."""
-        tools_list = []
-        for tool in tools:
-            if isinstance(tool, CrewAITool):
-                tools_list.append(tool.to_langchain())
-            else:
-                tools_list.append(tool)
-        return tools_list
+    def count_formatting_errors(self) -> None:
+        """Count the formatting errors of the agent."""
+        self.formatting_errors += 1

     def format_log_to_str(
         self,
@@ -271,6 +272,16 @@ class Agent(BaseModel):
             thoughts += f"\n{observation_prefix}{observation}\n{llm_prefix}"
         return thoughts

+    def _parse_tools(self, tools: List[Any]) -> List[LangChainTool]:
+        """Parse tools to be used for the task."""
+        tools_list = []
+        for tool in tools:
+            if isinstance(tool, CrewAITool):
+                tools_list.append(tool.to_langchain())
+            else:
+                tools_list.append(tool)
+        return tools_list

     @staticmethod
     def __tools_names(tools) -> str:
         return ", ".join([t.name for t in tools])

src/crewai/agents/parser.py

@@ -1,4 +1,4 @@
-from typing import Union
+from typing import Any, Union

 from langchain.agents.output_parsers import ReActSingleInputOutputParser
 from langchain_core.agents import AgentAction, AgentFinish
@@ -34,6 +34,7 @@ class CrewAgentParser(ReActSingleInputOutputParser):
     """

     _i18n: I18N = I18N()
+    agent: Any = None

     def parse(self, text: str) -> Union[AgentAction, AgentFinish]:
         includes_answer = FINAL_ANSWER_ACTION in text
@@ -41,6 +42,7 @@ class CrewAgentParser(ReActSingleInputOutputParser):
         if includes_tool:
             if includes_answer:
+                self.agent.count_formatting_errors()
                 raise OutputParserException(f"{FINAL_ANSWER_AND_TOOL_ERROR_MESSAGE}")
             return AgentAction("", "", text)
@@ -52,6 +54,7 @@ class CrewAgentParser(ReActSingleInputOutputParser):
             format = self._i18n.slice("format_without_tools")
             error = f"{format}"
+            self.agent.count_formatting_errors()
             raise OutputParserException(
                 error,
                 observation=error,
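
With the parser now holding a reference back to its agent, every malformed LLM response increments formatting_errors on the agent before the OutputParserException propagates. An illustrative snippet, not part of the commit; it assumes an OPENAI_API_KEY is available, since Agent defaults to a ChatOpenAI LLM.

# Illustrative only: exercising the new hook the same way the test below
# does, but without patching, so the real counter increments.
import pytest
from langchain_core.exceptions import OutputParserException

from crewai import Agent
from crewai.agents.parser import CrewAgentParser

agent = Agent(role="test role", goal="test goal", backstory="test backstory")
parser = CrewAgentParser(agent=agent)

# Text that matches neither the tool format nor the final-answer format.
with pytest.raises(OutputParserException):
    parser.parse("This text does not match expected formats.")

assert agent.formatting_errors == 1  # count_formatting_errors() fired once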

tests/agent_test.py

@@ -4,11 +4,13 @@ from unittest.mock import patch

 import pytest
 from langchain.tools import tool
+from langchain_core.exceptions import OutputParserException
 from langchain_openai import ChatOpenAI

 from crewai import Agent, Crew, Task
 from crewai.agents.cache import CacheHandler
 from crewai.agents.executor import CrewAgentExecutor
+from crewai.agents.parser import CrewAgentParser
 from crewai.tools.tool_calling import InstructorToolCalling
 from crewai.tools.tool_usage import ToolUsage
 from crewai.utilities import RPMController
@@ -576,3 +578,39 @@ def test_agent_function_calling_llm():
     crew.kickoff()
     private_mock.assert_called()
+
+
+def test_agent_count_formatting_error():
+    from unittest.mock import patch
+
+    agent1 = Agent(
+        role="test role",
+        goal="test goal",
+        backstory="test backstory",
+        verbose=True,
+    )
+
+    parser = CrewAgentParser()
+    parser.agent = agent1
+
+    with patch.object(Agent, "count_formatting_errors") as mock_count_errors:
+        test_text = "This text does not match expected formats."
+        with pytest.raises(OutputParserException):
+            parser.parse(test_text)
+        mock_count_errors.assert_called_once()
+
+
+def test_agent_llm_uses_token_calc_handler_with_llm_has_model_name():
+    agent1 = Agent(
+        role="test role",
+        goal="test goal",
+        backstory="test backstory",
+        verbose=True,
+    )
+
+    assert len(agent1.llm.callbacks) == 1
+    assert agent1.llm.callbacks[0].__class__.__name__ == "TokenCalcHandler"
+    assert agent1.llm.callbacks[0].model == "gpt-4"
+    assert (
+        agent1.llm.callbacks[0].token_cost_process.__class__.__name__ == "TokenProcess"
+    )
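
Taken together, the two new counters are readable straight off the agent once a crew has run. A hypothetical usage sketch follows: _token_process is a private attribute, and total_tokens is an assumed field name, since the TokenProcess internals are not part of this diff.

# Hypothetical read-back of the new counters after a run. The
# total_tokens attribute name is an assumption -- the TokenProcess
# implementation is not shown in this commit.
from crewai import Agent, Crew, Task

agent = Agent(role="test role", goal="test goal", backstory="test backstory")
task = Task(description="Say hello in one short sentence.", agent=agent)
crew = Crew(agents=[agent], tasks=[task])
crew.kickoff()

print(agent.formatting_errors)            # retries caused by bad formatting
print(agent._token_process.total_tokens)  # assumed attribute name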