mirror of
https://github.com/crewAIInc/crewAI.git
synced 2025-12-16 04:18:35 +00:00
Merge pull request #1064 from crewAIInc/thiago/pipeline-fix
Fix flaky test due to suppressed error on `on_llm_start` callback
This commit is contained in:
@@ -10,24 +10,24 @@ from crewai.agents.agent_builder.utilities.base_token_process import TokenProces
|
||||
class TokenCalcHandler(BaseCallbackHandler):
|
||||
model_name: str = ""
|
||||
token_cost_process: TokenProcess
|
||||
encoding: tiktoken.Encoding
|
||||
|
||||
def __init__(self, model_name, token_cost_process):
|
||||
self.model_name = model_name
|
||||
self.token_cost_process = token_cost_process
|
||||
try:
|
||||
self.encoding = tiktoken.encoding_for_model(self.model_name)
|
||||
except KeyError:
|
||||
self.encoding = tiktoken.get_encoding("cl100k_base")
|
||||
|
||||
def on_llm_start(
|
||||
self, serialized: Dict[str, Any], prompts: List[str], **kwargs: Any
|
||||
) -> None:
|
||||
try:
|
||||
encoding = tiktoken.encoding_for_model(self.model_name)
|
||||
except KeyError:
|
||||
encoding = tiktoken.get_encoding("cl100k_base")
|
||||
|
||||
if self.token_cost_process is None:
|
||||
return
|
||||
|
||||
for prompt in prompts:
|
||||
self.token_cost_process.sum_prompt_tokens(len(encoding.encode(prompt)))
|
||||
self.token_cost_process.sum_prompt_tokens(len(self.encoding.encode(prompt)))
|
||||
|
||||
async def on_llm_new_token(self, token: str, **kwargs) -> None:
|
||||
self.token_cost_process.sum_completion_tokens(1)
|
||||
|
||||
Reference in New Issue
Block a user