from typing import Any, Dict, List

import tiktoken
from langchain.callbacks.base import BaseCallbackHandler
from langchain.schema import LLMResult


class TokenProcess:
    """Accumulates token usage and request counts across LLM calls."""

    total_tokens: int = 0
    prompt_tokens: int = 0
    completion_tokens: int = 0
    successful_requests: int = 0

    def sum_prompt_tokens(self, tokens: int) -> None:
        self.prompt_tokens += tokens
        self.total_tokens += tokens

    def sum_completion_tokens(self, tokens: int) -> None:
        self.completion_tokens += tokens
        self.total_tokens += tokens

    def sum_successful_requests(self, requests: int) -> None:
        self.successful_requests += requests

    def get_summary(self) -> Dict[str, int]:
        return {
            "total_tokens": self.total_tokens,
            "prompt_tokens": self.prompt_tokens,
            "completion_tokens": self.completion_tokens,
            "successful_requests": self.successful_requests,
        }


class TokenCalcHandler(BaseCallbackHandler):
    """LangChain callback handler that counts tokens with tiktoken."""

    model: str = ""
    token_cost_process: TokenProcess

    def __init__(self, model: str, token_cost_process: TokenProcess):
        self.model = model
        self.token_cost_process = token_cost_process

    def on_llm_start(
        self, serialized: Dict[str, Any], prompts: List[str], **kwargs: Any
    ) -> None:
        if self.token_cost_process is None:
            return
        # Use the model-specific encoding for GPT models; fall back to
        # cl100k_base for everything else.
        if "gpt" in self.model:
            encoding = tiktoken.encoding_for_model(self.model)
        else:
            encoding = tiktoken.get_encoding("cl100k_base")
        for prompt in prompts:
            self.token_cost_process.sum_prompt_tokens(len(encoding.encode(prompt)))

    async def on_llm_new_token(self, token: str, **kwargs: Any) -> None:
        # Each streamed token counts as one completion token, so streaming
        # must be enabled on the LLM for completion counts to accumulate.
        self.token_cost_process.sum_completion_tokens(1)

    def on_llm_end(self, response: LLMResult, **kwargs: Any) -> None:
        self.token_cost_process.sum_successful_requests(1)
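
# Usage sketch (an assumption, not part of the original file): one way to
# wire the handler into a streaming ChatOpenAI call so on_llm_new_token
# fires once per generated token. The model name and prompt are placeholders.
if __name__ == "__main__":
    from langchain.chat_models import ChatOpenAI

    token_process = TokenProcess()
    handler = TokenCalcHandler("gpt-3.5-turbo", token_process)
    llm = ChatOpenAI(
        model_name="gpt-3.5-turbo",
        streaming=True,  # required for per-token completion counting
        callbacks=[handler],
    )
    llm.predict("Say hello in five words.")
    print(token_process.get_summary())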