mirror of
https://github.com/crewAIInc/crewAI.git
synced 2026-01-08 23:58:34 +00:00
Adding initial formatting error counting and token counter
This commit is contained in:
@@ -23,6 +23,7 @@ from pydantic_core import PydanticCustomError
|
||||
|
||||
from crewai.agents import CacheHandler, CrewAgentExecutor, CrewAgentParser, ToolsHandler
|
||||
from crewai.utilities import I18N, Logger, Prompts, RPMController
|
||||
from crewai.utilities.token_counter_callback import TokenCalcHandler, TokenProcess
|
||||
|
||||
|
||||
class Agent(BaseModel):
|
||||
@@ -51,7 +52,9 @@ class Agent(BaseModel):
|
||||
_logger: Logger = PrivateAttr()
|
||||
_rpm_controller: RPMController = PrivateAttr(default=None)
|
||||
_request_within_rpm_limit: Any = PrivateAttr(default=None)
|
||||
_token_process: TokenProcess = TokenProcess()
|
||||
|
||||
formatting_errors: int = 0
|
||||
model_config = ConfigDict(arbitrary_types_allowed=True)
|
||||
id: UUID4 = Field(
|
||||
default_factory=uuid.uuid4,
|
||||
@@ -123,8 +126,12 @@ class Agent(BaseModel):
|
||||
return self
|
||||
|
||||
@model_validator(mode="after")
|
||||
def check_agent_executor(self) -> "Agent":
|
||||
"""Check if the agent executor is set."""
|
||||
def set_agent_executor(self) -> "Agent":
|
||||
"""set agent executor is set."""
|
||||
if hasattr(self.llm, "model_name"):
|
||||
self.llm.callbacks = [
|
||||
TokenCalcHandler(self.llm.model_name, self._token_process)
|
||||
]
|
||||
if not self.agent_executor:
|
||||
self.set_cache_handler(self.cache_handler)
|
||||
return self
|
||||
@@ -243,20 +250,14 @@ class Agent(BaseModel):
|
||||
)
|
||||
|
||||
bind = self.llm.bind(stop=[self.i18n.slice("observation")])
|
||||
inner_agent = agent_args | execution_prompt | bind | CrewAgentParser()
|
||||
inner_agent = agent_args | execution_prompt | bind | CrewAgentParser(agent=self)
|
||||
self.agent_executor = CrewAgentExecutor(
|
||||
agent=RunnableAgent(runnable=inner_agent), **executor_args
|
||||
)
|
||||
|
||||
def _parse_tools(self, tools: List[Any]) -> List[LangChainTool]:
|
||||
"""Parse tools to be used for the task."""
|
||||
tools_list = []
|
||||
for tool in tools:
|
||||
if isinstance(tool, CrewAITool):
|
||||
tools_list.append(tool.to_langchain())
|
||||
else:
|
||||
tools_list.append(tool)
|
||||
return tools_list
|
||||
def count_formatting_errors(self) -> None:
|
||||
"""Count the formatting errors of the agent."""
|
||||
self.formatting_errors += 1
|
||||
|
||||
def format_log_to_str(
|
||||
self,
|
||||
@@ -271,6 +272,16 @@ class Agent(BaseModel):
|
||||
thoughts += f"\n{observation_prefix}{observation}\n{llm_prefix}"
|
||||
return thoughts
|
||||
|
||||
def _parse_tools(self, tools: List[Any]) -> List[LangChainTool]:
|
||||
"""Parse tools to be used for the task."""
|
||||
tools_list = []
|
||||
for tool in tools:
|
||||
if isinstance(tool, CrewAITool):
|
||||
tools_list.append(tool.to_langchain())
|
||||
else:
|
||||
tools_list.append(tool)
|
||||
return tools_list
|
||||
|
||||
@staticmethod
|
||||
def __tools_names(tools) -> str:
|
||||
return ", ".join([t.name for t in tools])
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
from typing import Union
|
||||
from typing import Any, Union
|
||||
|
||||
from langchain.agents.output_parsers import ReActSingleInputOutputParser
|
||||
from langchain_core.agents import AgentAction, AgentFinish
|
||||
@@ -34,6 +34,7 @@ class CrewAgentParser(ReActSingleInputOutputParser):
|
||||
"""
|
||||
|
||||
_i18n: I18N = I18N()
|
||||
agent: Any = None
|
||||
|
||||
def parse(self, text: str) -> Union[AgentAction, AgentFinish]:
|
||||
includes_answer = FINAL_ANSWER_ACTION in text
|
||||
@@ -41,6 +42,7 @@ class CrewAgentParser(ReActSingleInputOutputParser):
|
||||
|
||||
if includes_tool:
|
||||
if includes_answer:
|
||||
self.agent.count_formatting_errors()
|
||||
raise OutputParserException(f"{FINAL_ANSWER_AND_TOOL_ERROR_MESSAGE}")
|
||||
|
||||
return AgentAction("", "", text)
|
||||
@@ -52,6 +54,7 @@ class CrewAgentParser(ReActSingleInputOutputParser):
|
||||
|
||||
format = self._i18n.slice("format_without_tools")
|
||||
error = f"{format}"
|
||||
self.agent.count_formatting_errors()
|
||||
raise OutputParserException(
|
||||
error,
|
||||
observation=error,
|
||||
|
||||
@@ -4,11 +4,13 @@ from unittest.mock import patch
|
||||
|
||||
import pytest
|
||||
from langchain.tools import tool
|
||||
from langchain_core.exceptions import OutputParserException
|
||||
from langchain_openai import ChatOpenAI
|
||||
|
||||
from crewai import Agent, Crew, Task
|
||||
from crewai.agents.cache import CacheHandler
|
||||
from crewai.agents.executor import CrewAgentExecutor
|
||||
from crewai.agents.parser import CrewAgentParser
|
||||
from crewai.tools.tool_calling import InstructorToolCalling
|
||||
from crewai.tools.tool_usage import ToolUsage
|
||||
from crewai.utilities import RPMController
|
||||
@@ -576,3 +578,39 @@ def test_agent_function_calling_llm():
|
||||
|
||||
crew.kickoff()
|
||||
private_mock.assert_called()
|
||||
|
||||
|
||||
def test_agent_count_formatting_error():
|
||||
from unittest.mock import patch
|
||||
|
||||
agent1 = Agent(
|
||||
role="test role",
|
||||
goal="test goal",
|
||||
backstory="test backstory",
|
||||
verbose=True,
|
||||
)
|
||||
|
||||
parser = CrewAgentParser()
|
||||
parser.agent = agent1
|
||||
|
||||
with patch.object(Agent, "count_formatting_errors") as mock_count_errors:
|
||||
test_text = "This text does not match expected formats."
|
||||
with pytest.raises(OutputParserException):
|
||||
parser.parse(test_text)
|
||||
mock_count_errors.assert_called_once()
|
||||
|
||||
|
||||
def test_agent_llm_uses_token_calc_handler_with_llm_has_model_name():
|
||||
agent1 = Agent(
|
||||
role="test role",
|
||||
goal="test goal",
|
||||
backstory="test backstory",
|
||||
verbose=True,
|
||||
)
|
||||
|
||||
assert len(agent1.llm.callbacks) == 1
|
||||
assert agent1.llm.callbacks[0].__class__.__name__ == "TokenCalcHandler"
|
||||
assert agent1.llm.callbacks[0].model == "gpt-4"
|
||||
assert (
|
||||
agent1.llm.callbacks[0].token_cost_process.__class__.__name__ == "TokenProcess"
|
||||
)
|
||||
|
||||
Reference in New Issue
Block a user