mirror of
https://github.com/crewAIInc/crewAI.git
synced 2026-01-09 16:18:30 +00:00
Adding initial formatting error counting and token counter
This commit is contained in:
@@ -23,6 +23,7 @@ from pydantic_core import PydanticCustomError
|
||||
|
||||
from crewai.agents import CacheHandler, CrewAgentExecutor, CrewAgentParser, ToolsHandler
|
||||
from crewai.utilities import I18N, Logger, Prompts, RPMController
|
||||
from crewai.utilities.token_counter_callback import TokenCalcHandler, TokenProcess
|
||||
|
||||
|
||||
class Agent(BaseModel):
|
||||
@@ -51,7 +52,9 @@ class Agent(BaseModel):
|
||||
_logger: Logger = PrivateAttr()
|
||||
_rpm_controller: RPMController = PrivateAttr(default=None)
|
||||
_request_within_rpm_limit: Any = PrivateAttr(default=None)
|
||||
_token_process: TokenProcess = TokenProcess()
|
||||
|
||||
formatting_errors: int = 0
|
||||
model_config = ConfigDict(arbitrary_types_allowed=True)
|
||||
id: UUID4 = Field(
|
||||
default_factory=uuid.uuid4,
|
||||
@@ -123,8 +126,12 @@ class Agent(BaseModel):
|
||||
return self
|
||||
|
||||
@model_validator(mode="after")
|
||||
def check_agent_executor(self) -> "Agent":
|
||||
"""Check if the agent executor is set."""
|
||||
def set_agent_executor(self) -> "Agent":
|
||||
"""set agent executor is set."""
|
||||
if hasattr(self.llm, "model_name"):
|
||||
self.llm.callbacks = [
|
||||
TokenCalcHandler(self.llm.model_name, self._token_process)
|
||||
]
|
||||
if not self.agent_executor:
|
||||
self.set_cache_handler(self.cache_handler)
|
||||
return self
|
||||
@@ -243,20 +250,14 @@ class Agent(BaseModel):
|
||||
)
|
||||
|
||||
bind = self.llm.bind(stop=[self.i18n.slice("observation")])
|
||||
inner_agent = agent_args | execution_prompt | bind | CrewAgentParser()
|
||||
inner_agent = agent_args | execution_prompt | bind | CrewAgentParser(agent=self)
|
||||
self.agent_executor = CrewAgentExecutor(
|
||||
agent=RunnableAgent(runnable=inner_agent), **executor_args
|
||||
)
|
||||
|
||||
def _parse_tools(self, tools: List[Any]) -> List[LangChainTool]:
|
||||
"""Parse tools to be used for the task."""
|
||||
tools_list = []
|
||||
for tool in tools:
|
||||
if isinstance(tool, CrewAITool):
|
||||
tools_list.append(tool.to_langchain())
|
||||
else:
|
||||
tools_list.append(tool)
|
||||
return tools_list
|
||||
def count_formatting_errors(self) -> None:
|
||||
"""Count the formatting errors of the agent."""
|
||||
self.formatting_errors += 1
|
||||
|
||||
def format_log_to_str(
|
||||
self,
|
||||
@@ -271,6 +272,16 @@ class Agent(BaseModel):
|
||||
thoughts += f"\n{observation_prefix}{observation}\n{llm_prefix}"
|
||||
return thoughts
|
||||
|
||||
def _parse_tools(self, tools: List[Any]) -> List[LangChainTool]:
|
||||
"""Parse tools to be used for the task."""
|
||||
tools_list = []
|
||||
for tool in tools:
|
||||
if isinstance(tool, CrewAITool):
|
||||
tools_list.append(tool.to_langchain())
|
||||
else:
|
||||
tools_list.append(tool)
|
||||
return tools_list
|
||||
|
||||
@staticmethod
|
||||
def __tools_names(tools) -> str:
|
||||
return ", ".join([t.name for t in tools])
|
||||
|
||||
Reference in New Issue
Block a user