Fix token tracking in LangChainAgentAdapter and refactor token_process attribute to be public

This commit is contained in:
Brandon Hancock
2025-02-28 11:31:08 -05:00
parent 33ef612cd5
commit 88d8079dcd
6 changed files with 552 additions and 44 deletions

View File

@@ -82,13 +82,18 @@ class BaseAgent(ABC, BaseModel):
"""
__hash__ = object.__hash__ # type: ignore
model_config = {
"arbitrary_types_allowed": True,
}
_logger: Logger = PrivateAttr(default_factory=lambda: Logger(verbose=False))
_rpm_controller: Optional[RPMController] = PrivateAttr(default=None)
_request_within_rpm_limit: Any = PrivateAttr(default=None)
_original_role: Optional[str] = PrivateAttr(default=None)
_original_goal: Optional[str] = PrivateAttr(default=None)
_original_backstory: Optional[str] = PrivateAttr(default=None)
_token_process: TokenProcess = PrivateAttr(default_factory=TokenProcess)
token_process: TokenProcess = Field(default_factory=TokenProcess, exclude=True)
id: UUID4 = Field(default_factory=uuid.uuid4, frozen=True)
formatting_errors: int = Field(
default=0, description="Number of formatting errors."
@@ -198,8 +203,6 @@ class BaseAgent(ABC, BaseModel):
self._rpm_controller = RPMController(
max_rpm=self.max_rpm, logger=self._logger
)
if not self._token_process:
self._token_process = TokenProcess()
return self
@@ -219,8 +222,7 @@ class BaseAgent(ABC, BaseModel):
self._rpm_controller = RPMController(
max_rpm=self.max_rpm, logger=self._logger
)
if not self._token_process:
self._token_process = TokenProcess()
return self
@property
@@ -268,7 +270,7 @@ class BaseAgent(ABC, BaseModel):
"_logger",
"_rpm_controller",
"_request_within_rpm_limit",
"_token_process",
"token_process",
"agent_executor",
"tools",
"tools_handler",

View File

@@ -1,18 +1,17 @@
from typing import Any, List, Optional, Type, Union, cast
from crewai.tools.base_tool import Tool
try:
from langchain_core.tools import Tool as LangChainTool # type: ignore
except ImportError:
LangChainTool = None
from pydantic import Field, field_validator
from crewai.agents.agent_builder.base_agent import BaseAgent
from crewai.agents.agent_builder.utilities.base_token_process import TokenProcess
from crewai.task import Task
from crewai.tools import BaseTool
from crewai.tools.base_tool import Tool
from crewai.utilities.converter import Converter, generate_model_description
from crewai.utilities.token_counter_callback import (
LangChainTokenCounter,
LiteLLMTokenCounter,
)
class LangChainAgentAdapter(BaseAgent):
@@ -51,6 +50,8 @@ class LangChainAgentAdapter(BaseAgent):
i18n: Any = None
crew: Any = None
knowledge: Any = None
token_process: TokenProcess = Field(default_factory=TokenProcess, exclude=True)
token_callback: Optional[Any] = None
class Config:
arbitrary_types_allowed = True
@@ -72,16 +73,35 @@ class LangChainAgentAdapter(BaseAgent):
def _extract_text(self, message: Any) -> str:
"""
Helper to extract plain text from a message object.
This checks if the message is a dict with a "content" key, or has a "content" attribute.
This checks if the message is a dict with a "content" key, or has a "content" attribute,
or if it's a tuple from LangGraph's message format.
"""
if isinstance(message, dict) and "content" in message:
return message["content"]
# Handle LangGraph message tuple format (role, content)
if isinstance(message, tuple) and len(message) == 2:
return str(message[1])
# Handle dictionary with content key
elif isinstance(message, dict):
if "content" in message:
return message["content"]
# Handle LangGraph message format with additional metadata
elif "messages" in message and message["messages"]:
last_message = message["messages"][-1]
if isinstance(last_message, tuple) and len(last_message) == 2:
return str(last_message[1])
return self._extract_text(last_message)
# Handle object with content attribute
elif hasattr(message, "content") and isinstance(
getattr(message, "content"), str
):
return getattr(message, "content")
# Handle string directly
elif isinstance(message, str):
return message
# Default fallback
return str(message)
def execute_task(
@@ -161,19 +181,77 @@ class LangChainAgentAdapter(BaseAgent):
else:
task_prompt = self._use_trained_data(task_prompt=task_prompt)
# Initialize token tracking callback if needed
if hasattr(self, "token_process") and self.token_callback is None:
# Determine if we're using LangChain or LiteLLM based on the agent type
if hasattr(self.langchain_agent, "client") and hasattr(
self.langchain_agent.client, "callbacks"
):
# This is likely a LiteLLM-based agent
self.token_callback = LiteLLMTokenCounter(self.token_process)
# Add our callback to the LLM directly
if isinstance(self.langchain_agent.client.callbacks, list):
self.langchain_agent.client.callbacks.append(self.token_callback)
else:
self.langchain_agent.client.callbacks = [self.token_callback]
else:
# This is likely a LangChain-based agent
self.token_callback = LangChainTokenCounter(self.token_process)
# Add callback to the LangChain model
if hasattr(self.langchain_agent, "callbacks"):
if self.langchain_agent.callbacks is None:
self.langchain_agent.callbacks = [self.token_callback]
elif isinstance(self.langchain_agent.callbacks, list):
self.langchain_agent.callbacks.append(self.token_callback)
# For direct LLM models
elif hasattr(self.langchain_agent, "llm") and hasattr(
self.langchain_agent.llm, "callbacks"
):
if self.langchain_agent.llm.callbacks is None:
self.langchain_agent.llm.callbacks = [self.token_callback]
elif isinstance(self.langchain_agent.llm.callbacks, list):
self.langchain_agent.llm.callbacks.append(self.token_callback)
# Direct LLM case
elif not hasattr(self.langchain_agent, "agent"):
# This might be a direct LLM, not an agent
if (
not hasattr(self.langchain_agent, "callbacks")
or self.langchain_agent.callbacks is None
):
self.langchain_agent.callbacks = [self.token_callback]
elif isinstance(self.langchain_agent.callbacks, list):
self.langchain_agent.callbacks.append(self.token_callback)
init_state = {"messages": [("user", task_prompt)]}
# Estimate input tokens for tracking
if hasattr(self, "token_process"):
# Rough estimate based on characters (better than word count)
estimated_prompt_tokens = len(task_prompt) // 4 # ~4 chars per token
self.token_process.sum_prompt_tokens(estimated_prompt_tokens)
state = self.agent_executor.invoke(init_state)
# Extract output from state based on its structure
if "structured_response" in state:
current_output = state["structured_response"]
elif "messages" in state and state["messages"]:
last_message = state["messages"][-1]
if isinstance(last_message, tuple):
current_output = last_message[1]
else:
current_output = self._extract_text(last_message)
current_output = self._extract_text(last_message)
elif "output" in state:
current_output = str(state["output"])
else:
current_output = ""
# Fallback to extracting text from the entire state
current_output = self._extract_text(state)
# Estimate completion tokens for tracking if we don't have actual counts
if hasattr(self, "token_process"):
# Rough estimate based on characters
estimated_completion_tokens = len(current_output) // 4 # ~4 chars per token
self.token_process.sum_completion_tokens(estimated_completion_tokens)
self.token_process.sum_successful_requests(1)
if task.human_input:
current_output = self._handle_human_feedback(current_output)
@@ -203,20 +281,40 @@ class LangChainAgentAdapter(BaseAgent):
f"Specifically, display 10 bullet points in each section. Provide the complete updated answer below.\n\n"
f"Updated answer:"
)
# Estimate input tokens for tracking
if hasattr(self, "token_process"):
# Rough estimate based on characters
estimated_prompt_tokens = len(new_prompt) // 4 # ~4 chars per token
self.token_process.sum_prompt_tokens(estimated_prompt_tokens)
try:
new_state = self.agent_executor.invoke(
{"messages": [("user", new_prompt)]}
)
# Extract output from state based on its structure
if "structured_response" in new_state:
new_output = new_state["structured_response"]
elif "messages" in new_state and new_state["messages"]:
last_message = new_state["messages"][-1]
if isinstance(last_message, tuple):
new_output = last_message[1]
else:
new_output = self._extract_text(last_message)
new_output = self._extract_text(last_message)
elif "output" in new_state:
new_output = str(new_state["output"])
else:
new_output = ""
# Fallback to extracting text from the entire state
new_output = self._extract_text(new_state)
# Estimate completion tokens for tracking
if hasattr(self, "token_process"):
# Rough estimate based on characters
estimated_completion_tokens = (
len(new_output) // 4
) # ~4 chars per token
self.token_process.sum_completion_tokens(
estimated_completion_tokens
)
self.token_process.sum_successful_requests(1)
current_output = new_output
except Exception as e:
print("Error during re-invocation with feedback:", e)
@@ -310,6 +408,52 @@ class LangChainAgentAdapter(BaseAgent):
agent_role = getattr(self, "role", "agent")
sanitized_role = re.sub(r"\s+", "_", agent_role)
# Initialize token tracking callback if needed
if hasattr(self, "token_process") and self.token_callback is None:
# Determine if we're using LangChain or LiteLLM based on the agent type
if hasattr(self.langchain_agent, "client") and hasattr(
self.langchain_agent.client, "callbacks"
):
# This is likely a LiteLLM-based agent
self.token_callback = LiteLLMTokenCounter(self.token_process)
# Add our callback to the LLM directly
if isinstance(self.langchain_agent.client.callbacks, list):
if self.token_callback not in self.langchain_agent.client.callbacks:
self.langchain_agent.client.callbacks.append(
self.token_callback
)
else:
self.langchain_agent.client.callbacks = [self.token_callback]
else:
# This is likely a LangChain-based agent
self.token_callback = LangChainTokenCounter(self.token_process)
# Add callback to the LangChain model
if hasattr(self.langchain_agent, "callbacks"):
if self.langchain_agent.callbacks is None:
self.langchain_agent.callbacks = [self.token_callback]
elif isinstance(self.langchain_agent.callbacks, list):
self.langchain_agent.callbacks.append(self.token_callback)
# For direct LLM models
elif hasattr(self.langchain_agent, "llm") and hasattr(
self.langchain_agent.llm, "callbacks"
):
if self.langchain_agent.llm.callbacks is None:
self.langchain_agent.llm.callbacks = [self.token_callback]
elif isinstance(self.langchain_agent.llm.callbacks, list):
self.langchain_agent.llm.callbacks.append(self.token_callback)
# Direct LLM case
elif not hasattr(self.langchain_agent, "agent"):
# This might be a direct LLM, not an agent
if (
not hasattr(self.langchain_agent, "callbacks")
or self.langchain_agent.callbacks is None
):
self.langchain_agent.callbacks = [self.token_callback]
elif isinstance(self.langchain_agent.callbacks, list):
self.langchain_agent.callbacks.append(self.token_callback)
self.agent_executor = create_react_agent(
model=self.langchain_agent,
tools=used_tools,

View File

@@ -641,7 +641,7 @@ class Crew(BaseModel):
for after_callback in self.after_kickoff_callbacks:
result = after_callback(result)
metrics += [agent._token_process.get_summary() for agent in self.agents]
metrics += [agent.token_process.get_summary() for agent in self.agents]
self.usage_metrics = UsageMetrics()
for metric in metrics:
@@ -1195,12 +1195,15 @@ class Crew(BaseModel):
"""Calculates and returns the usage metrics."""
total_usage_metrics = UsageMetrics()
for agent in self.agents:
if hasattr(agent, "_token_process"):
token_sum = agent._token_process.get_summary()
total_usage_metrics.add_usage_metrics(token_sum)
if self.manager_agent and hasattr(self.manager_agent, "_token_process"):
token_sum = self.manager_agent._token_process.get_summary()
# Directly access token_process since it's now a field in BaseAgent
token_sum = agent.token_process.get_summary()
total_usage_metrics.add_usage_metrics(token_sum)
if self.manager_agent:
# Directly access token_process since it's now a field in BaseAgent
token_sum = self.manager_agent.token_process.get_summary()
total_usage_metrics.add_usage_metrics(token_sum)
self.usage_metrics = total_usage_metrics
return total_usage_metrics

View File

@@ -1,15 +1,52 @@
import warnings
from typing import Any, Dict, Optional
from abc import ABC, abstractmethod
from typing import Any, Dict, List, Optional, Union
from langchain_core.callbacks.base import BaseCallbackHandler
from litellm.integrations.custom_logger import CustomLogger
from litellm.types.utils import Usage
from crewai.agents.agent_builder.utilities.base_token_process import TokenProcess
class TokenCalcHandler(CustomLogger):
def __init__(self, token_cost_process: Optional[TokenProcess]):
self.token_cost_process = token_cost_process
class AbstractTokenCounter(ABC):
"""
Abstract base class for token counting callbacks.
Implementations should track token usage from different LLM providers.
"""
def __init__(self, token_process: Optional[TokenProcess] = None):
"""Initialize with a TokenProcess instance to track tokens."""
self.token_process = token_process
@abstractmethod
def update_token_usage(self, prompt_tokens: int, completion_tokens: int) -> None:
"""Update token usage counts in the token process."""
pass
class LiteLLMTokenCounter(CustomLogger, AbstractTokenCounter):
"""
Token counter implementation for LiteLLM.
Uses LiteLLM's CustomLogger interface to track token usage.
"""
def __init__(self, token_process: Optional[TokenProcess] = None):
AbstractTokenCounter.__init__(self, token_process)
CustomLogger.__init__(self)
def update_token_usage(self, prompt_tokens: int, completion_tokens: int) -> None:
"""Update token usage counts in the token process."""
if self.token_process is None:
return
if prompt_tokens > 0:
self.token_process.sum_prompt_tokens(prompt_tokens)
if completion_tokens > 0:
self.token_process.sum_completion_tokens(completion_tokens)
self.token_process.sum_successful_requests(1)
def log_success_event(
self,
@@ -18,7 +55,11 @@ class TokenCalcHandler(CustomLogger):
start_time: float,
end_time: float,
) -> None:
if self.token_cost_process is None:
"""
Process successful LLM call and extract token usage information.
This method is called by LiteLLM after a successful completion.
"""
if self.token_process is None:
return
with warnings.catch_warnings():
@@ -26,18 +67,154 @@ class TokenCalcHandler(CustomLogger):
if isinstance(response_obj, dict) and "usage" in response_obj:
usage: Usage = response_obj["usage"]
if usage:
self.token_cost_process.sum_successful_requests(1)
prompt_tokens = 0
completion_tokens = 0
if hasattr(usage, "prompt_tokens"):
self.token_cost_process.sum_prompt_tokens(usage.prompt_tokens)
prompt_tokens = usage.prompt_tokens
elif isinstance(usage, dict) and "prompt_tokens" in usage:
prompt_tokens = usage["prompt_tokens"]
if hasattr(usage, "completion_tokens"):
self.token_cost_process.sum_completion_tokens(
usage.completion_tokens
)
completion_tokens = usage.completion_tokens
elif isinstance(usage, dict) and "completion_tokens" in usage:
completion_tokens = usage["completion_tokens"]
self.update_token_usage(prompt_tokens, completion_tokens)
# Handle cached tokens if available
if (
hasattr(usage, "prompt_tokens_details")
and usage.prompt_tokens_details
and usage.prompt_tokens_details.cached_tokens
):
self.token_cost_process.sum_cached_prompt_tokens(
self.token_process.sum_cached_prompt_tokens(
usage.prompt_tokens_details.cached_tokens
)
class LangChainTokenCounter(BaseCallbackHandler, AbstractTokenCounter):
"""
Token counter implementation for LangChain.
Implements the necessary callback methods to track token usage from LangChain responses.
"""
def __init__(self, token_process: Optional[TokenProcess] = None):
BaseCallbackHandler.__init__(self)
AbstractTokenCounter.__init__(self, token_process)
def update_token_usage(self, prompt_tokens: int, completion_tokens: int) -> None:
"""Update token usage counts in the token process."""
if self.token_process is None:
return
if prompt_tokens > 0:
self.token_process.sum_prompt_tokens(prompt_tokens)
if completion_tokens > 0:
self.token_process.sum_completion_tokens(completion_tokens)
self.token_process.sum_successful_requests(1)
@property
def ignore_llm(self) -> bool:
return False
@property
def ignore_chain(self) -> bool:
return True
@property
def ignore_agent(self) -> bool:
return False
@property
def ignore_chat_model(self) -> bool:
return False
@property
def ignore_retriever(self) -> bool:
return True
@property
def ignore_tools(self) -> bool:
return True
def on_llm_start(
self, serialized: Dict[str, Any], prompts: List[str], **kwargs: Any
) -> None:
"""Called when LLM starts processing."""
pass
def on_llm_new_token(self, token: str, **kwargs: Any) -> None:
"""Called when LLM generates a new token."""
pass
def on_llm_end(self, response: Any, **kwargs: Any) -> None:
"""
Called when LLM ends processing.
Extracts token usage from LangChain response objects.
"""
if self.token_process is None:
return
# Handle LangChain response format
if hasattr(response, "llm_output") and isinstance(response.llm_output, dict):
token_usage = response.llm_output.get("token_usage", {})
prompt_tokens = token_usage.get("prompt_tokens", 0)
completion_tokens = token_usage.get("completion_tokens", 0)
self.update_token_usage(prompt_tokens, completion_tokens)
def on_llm_error(self, error: BaseException, **kwargs: Any) -> None:
"""Called when LLM errors."""
pass
def on_chain_start(
self, serialized: Dict[str, Any], inputs: Dict[str, Any], **kwargs: Any
) -> None:
"""Called when a chain starts."""
pass
def on_chain_end(self, outputs: Dict[str, Any], **kwargs: Any) -> None:
"""Called when a chain ends."""
pass
def on_chain_error(self, error: BaseException, **kwargs: Any) -> None:
"""Called when a chain errors."""
pass
def on_tool_start(
self, serialized: Dict[str, Any], input_str: str, **kwargs: Any
) -> None:
"""Called when a tool starts."""
pass
def on_tool_end(self, output: str, **kwargs: Any) -> None:
"""Called when a tool ends."""
pass
def on_tool_error(self, error: BaseException, **kwargs: Any) -> None:
"""Called when a tool errors."""
pass
def on_text(self, text: str, **kwargs: Any) -> None:
"""Called when text is generated."""
pass
def on_agent_start(self, serialized: Dict[str, Any], **kwargs: Any) -> None:
"""Called when an agent starts."""
pass
def on_agent_end(self, output: Any, **kwargs: Any) -> None:
"""Called when an agent ends."""
pass
def on_agent_error(self, error: BaseException, **kwargs: Any) -> None:
"""Called when an agent errors."""
pass
# For backward compatibility
TokenCalcHandler = LiteLLMTokenCounter