mirror of https://github.com/crewAIInc/crewAI.git
synced 2026-01-10 00:28:31 +00:00

Commit: Adding initial formatting error counting and token counter
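This commit threads two counters through the agent stack: a TokenProcess that a TokenCalcHandler LLM callback feeds token usage into, and a public formatting_errors counter that the output parser bumps whenever an LLM response fails to parse. The diff imports TokenCalcHandler and TokenProcess from crewai.utilities.token_counter_callback, a module this commit does not show. Below is a minimal sketch of the shape that module would need for the wiring in the diff to work; the model and token_cost_process attribute names are confirmed by the new test assertions, while the summing methods and the tiktoken-based counting are assumptions.

# Hypothetical sketch of crewai/utilities/token_counter_callback.py;
# the real module is not included in this diff.
import tiktoken
from langchain_core.callbacks import BaseCallbackHandler


class TokenProcess:
    """Accumulates token counts across LLM calls (assumed field names)."""

    def __init__(self) -> None:
        self.total_tokens = 0
        self.prompt_tokens = 0
        self.completion_tokens = 0

    def sum_prompt_tokens(self, tokens: int) -> None:
        self.prompt_tokens += tokens
        self.total_tokens += tokens

    def sum_completion_tokens(self, tokens: int) -> None:
        self.completion_tokens += tokens
        self.total_tokens += tokens


class TokenCalcHandler(BaseCallbackHandler):
    """LangChain callback that tallies tokens per LLM request."""

    def __init__(self, model: str, token_cost_process: TokenProcess) -> None:
        # Attribute names match the assertions in the new tests below.
        self.model = model
        self.token_cost_process = token_cost_process

    def on_llm_start(self, serialized: dict, prompts: list, **kwargs) -> None:
        # Counting prompt tokens with tiktoken is an assumption,
        # not something this commit shows.
        encoding = tiktoken.encoding_for_model(self.model)
        for prompt in prompts:
            self.token_cost_process.sum_prompt_tokens(len(encoding.encode(prompt)))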
File: src/crewai/agent.py (inferred from hunk context)

@@ -23,6 +23,7 @@ from pydantic_core import PydanticCustomError
 
 from crewai.agents import CacheHandler, CrewAgentExecutor, CrewAgentParser, ToolsHandler
 from crewai.utilities import I18N, Logger, Prompts, RPMController
+from crewai.utilities.token_counter_callback import TokenCalcHandler, TokenProcess
 
 
 class Agent(BaseModel):
@@ -51,7 +52,9 @@ class Agent(BaseModel):
     _logger: Logger = PrivateAttr()
     _rpm_controller: RPMController = PrivateAttr(default=None)
     _request_within_rpm_limit: Any = PrivateAttr(default=None)
+    _token_process: TokenProcess = TokenProcess()
 
+    formatting_errors: int = 0
     model_config = ConfigDict(arbitrary_types_allowed=True)
     id: UUID4 = Field(
         default_factory=uuid.uuid4,
@@ -123,8 +126,12 @@ class Agent(BaseModel):
         return self
 
     @model_validator(mode="after")
-    def check_agent_executor(self) -> "Agent":
-        """Check if the agent executor is set."""
+    def set_agent_executor(self) -> "Agent":
+        """set agent executor is set."""
+        if hasattr(self.llm, "model_name"):
+            self.llm.callbacks = [
+                TokenCalcHandler(self.llm.model_name, self._token_process)
+            ]
         if not self.agent_executor:
             self.set_cache_handler(self.cache_handler)
         return self
@@ -243,20 +250,14 @@ class Agent(BaseModel):
         )
 
         bind = self.llm.bind(stop=[self.i18n.slice("observation")])
-        inner_agent = agent_args | execution_prompt | bind | CrewAgentParser()
+        inner_agent = agent_args | execution_prompt | bind | CrewAgentParser(agent=self)
         self.agent_executor = CrewAgentExecutor(
             agent=RunnableAgent(runnable=inner_agent), **executor_args
         )
 
-    def _parse_tools(self, tools: List[Any]) -> List[LangChainTool]:
-        """Parse tools to be used for the task."""
-        tools_list = []
-        for tool in tools:
-            if isinstance(tool, CrewAITool):
-                tools_list.append(tool.to_langchain())
-            else:
-                tools_list.append(tool)
-        return tools_list
+    def count_formatting_errors(self) -> None:
+        """Count the formatting errors of the agent."""
+        self.formatting_errors += 1
 
     def format_log_to_str(
         self,
@@ -271,6 +272,16 @@ class Agent(BaseModel):
             thoughts += f"\n{observation_prefix}{observation}\n{llm_prefix}"
         return thoughts
 
+    def _parse_tools(self, tools: List[Any]) -> List[LangChainTool]:
+        """Parse tools to be used for the task."""
+        tools_list = []
+        for tool in tools:
+            if isinstance(tool, CrewAITool):
+                tools_list.append(tool.to_langchain())
+            else:
+                tools_list.append(tool)
+        return tools_list
+
     @staticmethod
     def __tools_names(tools) -> str:
         return ", ".join([t.name for t in tools])
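Net effect of the agent.py changes: every Agent now owns a TokenProcess that its LLM callback feeds, plus a public formatting_errors counter that the parser can bump through the new agent=self back-reference. A hedged usage sketch follows; reading the counters this way is inferred use, not an API documented by this commit.

from crewai import Agent

# Assumes OPENAI_API_KEY is set, since the default llm is a ChatOpenAI.
agent = Agent(role="test role", goal="test goal", backstory="test backstory")

# CrewAgentParser normally calls this when a response fails to parse;
# the method is public, so it can also be invoked directly.
agent.count_formatting_errors()
assert agent.formatting_errors == 1

# Token usage accumulates on the private TokenProcess as the LLM runs
# (total_tokens is an assumed field name; see the sketch above).
print(agent._token_process.total_tokens)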
File: src/crewai/agents/parser.py (inferred from hunk context)

@@ -1,4 +1,4 @@
-from typing import Union
+from typing import Any, Union
 
 from langchain.agents.output_parsers import ReActSingleInputOutputParser
 from langchain_core.agents import AgentAction, AgentFinish
@@ -34,6 +34,7 @@ class CrewAgentParser(ReActSingleInputOutputParser):
     """
 
     _i18n: I18N = I18N()
+    agent: Any = None
 
     def parse(self, text: str) -> Union[AgentAction, AgentFinish]:
         includes_answer = FINAL_ANSWER_ACTION in text
@@ -41,6 +42,7 @@ class CrewAgentParser(ReActSingleInputOutputParser):
 
         if includes_tool:
             if includes_answer:
+                self.agent.count_formatting_errors()
                 raise OutputParserException(f"{FINAL_ANSWER_AND_TOOL_ERROR_MESSAGE}")
 
             return AgentAction("", "", text)
@@ -52,6 +54,7 @@ class CrewAgentParser(ReActSingleInputOutputParser):
 
             format = self._i18n.slice("format_without_tools")
             error = f"{format}"
+            self.agent.count_formatting_errors()
             raise OutputParserException(
                 error,
                 observation=error,
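The parser side adds a loosely typed agent: Any = None back-reference (typed Any presumably to avoid a circular import with agent.py) and calls count_formatting_errors() on both failure paths just before raising. A condensed, unmocked version of what the new test exercises, assuming the same environment the tests run in:

import pytest
from langchain_core.exceptions import OutputParserException

from crewai import Agent
from crewai.agents.parser import CrewAgentParser

agent = Agent(role="test role", goal="test goal", backstory="test backstory")
parser = CrewAgentParser(agent=agent)

# Text with neither an Action: block nor a Final Answer: trips the
# format_without_tools error path, which bumps the counter and raises.
with pytest.raises(OutputParserException):
    parser.parse("This text does not match expected formats.")

assert agent.formatting_errors == 1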
File: tests/agent_test.py (inferred from hunk context)

@@ -4,11 +4,13 @@ from unittest.mock import patch
 
 import pytest
 from langchain.tools import tool
+from langchain_core.exceptions import OutputParserException
 from langchain_openai import ChatOpenAI
 
 from crewai import Agent, Crew, Task
 from crewai.agents.cache import CacheHandler
 from crewai.agents.executor import CrewAgentExecutor
+from crewai.agents.parser import CrewAgentParser
 from crewai.tools.tool_calling import InstructorToolCalling
 from crewai.tools.tool_usage import ToolUsage
 from crewai.utilities import RPMController
@@ -576,3 +578,39 @@ def test_agent_function_calling_llm():
 
     crew.kickoff()
     private_mock.assert_called()
+
+
+def test_agent_count_formatting_error():
+    from unittest.mock import patch
+
+    agent1 = Agent(
+        role="test role",
+        goal="test goal",
+        backstory="test backstory",
+        verbose=True,
+    )
+
+    parser = CrewAgentParser()
+    parser.agent = agent1
+
+    with patch.object(Agent, "count_formatting_errors") as mock_count_errors:
+        test_text = "This text does not match expected formats."
+        with pytest.raises(OutputParserException):
+            parser.parse(test_text)
+        mock_count_errors.assert_called_once()
+
+
+def test_agent_llm_uses_token_calc_handler_with_llm_has_model_name():
+    agent1 = Agent(
+        role="test role",
+        goal="test goal",
+        backstory="test backstory",
+        verbose=True,
+    )
+
+    assert len(agent1.llm.callbacks) == 1
+    assert agent1.llm.callbacks[0].__class__.__name__ == "TokenCalcHandler"
+    assert agent1.llm.callbacks[0].model == "gpt-4"
+    assert (
+        agent1.llm.callbacks[0].token_cost_process.__class__.__name__ == "TokenProcess"
+    )
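The second test also pins down an implicit default: with no llm argument, Agent evidently falls back to a ChatOpenAI whose model_name is "gpt-4", which is why set_agent_executor attaches exactly one TokenCalcHandler with model == "gpt-4". If that inference holds, the wiring can be spot-checked without running a crew:

from crewai import Agent

agent = Agent(role="test role", goal="test goal", backstory="test backstory")
handler = agent.llm.callbacks[0]
print(type(handler).__name__)  # TokenCalcHandler
print(handler.model)           # gpt-4, per the test's assertion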