Files
crewAI/src/crewai/utilities/token_counter_callback.py
2024-02-28 03:44:23 -03:00

61 lines
1.9 KiB
Python

from typing import Any, Dict, List
import tiktoken
from langchain.callbacks.base import BaseCallbackHandler
from langchain.schema import LLMResult
class TokenProcess:
    """Accumulate token usage and request counts.

    The counters start from class-level defaults of zero. The first update
    on an instance rebinds the attribute on that instance, so separate
    instances do not share counts.
    """

    # Sum of every value passed to sum_prompt_tokens and sum_completion_tokens.
    total_tokens: int = 0
    # Sum of every value passed to sum_prompt_tokens.
    prompt_tokens: int = 0
    # Sum of every value passed to sum_completion_tokens.
    completion_tokens: int = 0
    # Sum of every value passed to sum_successful_requests.
    successful_requests: int = 0

    def sum_prompt_tokens(self, tokens: int) -> None:
        """Add ``tokens`` to both the prompt counter and the total counter."""
        self.prompt_tokens += tokens
        self.total_tokens += tokens

    def sum_completion_tokens(self, tokens: int) -> None:
        """Add ``tokens`` to both the completion counter and the total counter."""
        self.completion_tokens += tokens
        self.total_tokens += tokens

    def sum_successful_requests(self, requests: int) -> None:
        """Add ``requests`` to the successful-request counter."""
        self.successful_requests += requests

    def get_summary(self) -> Dict[str, int]:
        """Return the current counters as a dict keyed by counter name.

        Returns:
            A new dict with the keys ``total_tokens``, ``prompt_tokens``,
            ``completion_tokens`` and ``successful_requests``.
        """
        return {
            "total_tokens": self.total_tokens,
            "prompt_tokens": self.prompt_tokens,
            "completion_tokens": self.completion_tokens,
            "successful_requests": self.successful_requests,
        }
class TokenCalcHandler(BaseCallbackHandler):
    """Callback handler that records token usage in a ``TokenProcess``.

    Prompt tokens are counted with tiktoken when an LLM call starts, every
    new-token callback counts as one completion token, and every LLM-end
    callback counts as one successful request. When ``token_cost_process``
    is ``None`` all callbacks are no-ops.
    """

    model: str = ""
    token_cost_process: TokenProcess

    def __init__(self, model, token_cost_process):
        """Store the model name and the counter object to update.

        Args:
            model: Model name used to select the tiktoken encoding.
            token_cost_process: Object that receives the counts, or ``None``
                to disable counting.
        """
        self.model = model
        self.token_cost_process = token_cost_process

    def _get_encoding(self):
        """Return the tiktoken encoding used to count prompt tokens.

        Model names containing ``"gpt"`` are looked up through tiktoken;
        every other name, and any ``"gpt"`` name tiktoken cannot map to an
        encoding, uses ``cl100k_base``.
        """
        if isinstance(self.model, str) and "gpt" in self.model:
            try:
                return tiktoken.encoding_for_model(self.model)
            except KeyError:
                # tiktoken has no mapping for this name; fall through to the
                # default encoding instead of failing the whole LLM call.
                pass
        return tiktoken.get_encoding("cl100k_base")

    def on_llm_start(
        self, serialized: Dict[str, Any], prompts: List[str], **kwargs: Any
    ) -> None:
        """Add the token count of every prompt to the prompt counter."""
        # Check first so no encoding lookup happens when nothing is recorded.
        if self.token_cost_process is None:
            return

        encoding = self._get_encoding()
        for prompt in prompts:
            self.token_cost_process.sum_prompt_tokens(len(encoding.encode(prompt)))

    async def on_llm_new_token(self, token: str, **kwargs) -> None:
        """Count one completion token for each new-token callback."""
        if self.token_cost_process is None:
            return
        self.token_cost_process.sum_completion_tokens(1)

    def on_llm_end(self, response: LLMResult, **kwargs: Any) -> None:
        """Count one successful request for each LLM-end callback."""
        if self.token_cost_process is None:
            return
        self.token_cost_process.sum_successful_requests(1)