Adding initial formatting error counting and token counter

This commit is contained in:
João Moura
2024-02-28 01:57:04 -03:00
parent 577db88f8e
commit 7b49b4e985
3 changed files with 65 additions and 13 deletions

View File

@@ -23,6 +23,7 @@ from pydantic_core import PydanticCustomError
from crewai.agents import CacheHandler, CrewAgentExecutor, CrewAgentParser, ToolsHandler
from crewai.utilities import I18N, Logger, Prompts, RPMController
from crewai.utilities.token_counter_callback import TokenCalcHandler, TokenProcess
class Agent(BaseModel):
@@ -51,7 +52,9 @@ class Agent(BaseModel):
_logger: Logger = PrivateAttr()
_rpm_controller: RPMController = PrivateAttr(default=None)
_request_within_rpm_limit: Any = PrivateAttr(default=None)
_token_process: TokenProcess = TokenProcess()
formatting_errors: int = 0
model_config = ConfigDict(arbitrary_types_allowed=True)
id: UUID4 = Field(
default_factory=uuid.uuid4,
@@ -123,8 +126,12 @@ class Agent(BaseModel):
return self
@model_validator(mode="after")
def check_agent_executor(self) -> "Agent":
"""Check if the agent executor is set."""
def set_agent_executor(self) -> "Agent":
"""set agent executor is set."""
if hasattr(self.llm, "model_name"):
self.llm.callbacks = [
TokenCalcHandler(self.llm.model_name, self._token_process)
]
if not self.agent_executor:
self.set_cache_handler(self.cache_handler)
return self
@@ -243,20 +250,14 @@ class Agent(BaseModel):
)
bind = self.llm.bind(stop=[self.i18n.slice("observation")])
inner_agent = agent_args | execution_prompt | bind | CrewAgentParser()
inner_agent = agent_args | execution_prompt | bind | CrewAgentParser(agent=self)
self.agent_executor = CrewAgentExecutor(
agent=RunnableAgent(runnable=inner_agent), **executor_args
)
def _parse_tools(self, tools: List[Any]) -> List[LangChainTool]:
"""Parse tools to be used for the task."""
tools_list = []
for tool in tools:
if isinstance(tool, CrewAITool):
tools_list.append(tool.to_langchain())
else:
tools_list.append(tool)
return tools_list
def count_formatting_errors(self) -> None:
"""Count the formatting errors of the agent."""
self.formatting_errors += 1
def format_log_to_str(
self,
@@ -271,6 +272,16 @@ class Agent(BaseModel):
thoughts += f"\n{observation_prefix}{observation}\n{llm_prefix}"
return thoughts
def _parse_tools(self, tools: List[Any]) -> List[LangChainTool]:
"""Parse tools to be used for the task."""
tools_list = []
for tool in tools:
if isinstance(tool, CrewAITool):
tools_list.append(tool.to_langchain())
else:
tools_list.append(tool)
return tools_list
@staticmethod
def __tools_names(tools) -> str:
return ", ".join([t.name for t in tools])