Fix token tracking in LangChainAgentAdapter and refactor token_process attribute to be public

Brandon Hancock
2025-02-28 11:31:08 -05:00
parent 33ef612cd5
commit 88d8079dcd
6 changed files with 552 additions and 44 deletions

View File

@@ -82,13 +82,18 @@ class BaseAgent(ABC, BaseModel):
"""
__hash__ = object.__hash__ # type: ignore
model_config = {
"arbitrary_types_allowed": True,
}
_logger: Logger = PrivateAttr(default_factory=lambda: Logger(verbose=False))
_rpm_controller: Optional[RPMController] = PrivateAttr(default=None)
_request_within_rpm_limit: Any = PrivateAttr(default=None)
_original_role: Optional[str] = PrivateAttr(default=None)
_original_goal: Optional[str] = PrivateAttr(default=None)
_original_backstory: Optional[str] = PrivateAttr(default=None)
_token_process: TokenProcess = PrivateAttr(default_factory=TokenProcess)
token_process: TokenProcess = Field(default_factory=TokenProcess, exclude=True)
id: UUID4 = Field(default_factory=uuid.uuid4, frozen=True)
formatting_errors: int = Field(
default=0, description="Number of formatting errors."
@@ -198,8 +203,6 @@ class BaseAgent(ABC, BaseModel):
self._rpm_controller = RPMController(
max_rpm=self.max_rpm, logger=self._logger
)
if not self._token_process:
self._token_process = TokenProcess()
return self
@@ -219,8 +222,7 @@ class BaseAgent(ABC, BaseModel):
self._rpm_controller = RPMController(
max_rpm=self.max_rpm, logger=self._logger
)
if not self._token_process:
self._token_process = TokenProcess()
return self
@property
@@ -268,7 +270,7 @@ class BaseAgent(ABC, BaseModel):
"_logger",
"_rpm_controller",
"_request_within_rpm_limit",
"_token_process",
"token_process",
"agent_executor",
"tools",
"tools_handler",

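With token_process now a public field on BaseAgent, callers can read an agent's counters directly instead of reaching into a private attribute. For reference, a minimal standalone sketch of the TokenProcess API that field exposes (the same calls the unit tests further down exercise):

from crewai.agents.agent_builder.utilities.base_token_process import TokenProcess

process = TokenProcess()
process.sum_prompt_tokens(100)       # tokens sent to the model
process.sum_completion_tokens(50)    # tokens generated by the model
process.sum_successful_requests(1)   # one completed LLM call

summary = process.get_summary()      # snapshot with prompt/completion/total counts
print(summary.prompt_tokens, summary.completion_tokens, summary.total_tokens)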
View File

@@ -1,18 +1,17 @@
from typing import Any, List, Optional, Type, Union, cast
from crewai.tools.base_tool import Tool
try:
from langchain_core.tools import Tool as LangChainTool # type: ignore
except ImportError:
LangChainTool = None
from pydantic import Field, field_validator
from crewai.agents.agent_builder.base_agent import BaseAgent
from crewai.agents.agent_builder.utilities.base_token_process import TokenProcess
from crewai.task import Task
from crewai.tools import BaseTool
from crewai.tools.base_tool import Tool
from crewai.utilities.converter import Converter, generate_model_description
from crewai.utilities.token_counter_callback import (
LangChainTokenCounter,
LiteLLMTokenCounter,
)
class LangChainAgentAdapter(BaseAgent):
@@ -51,6 +50,8 @@ class LangChainAgentAdapter(BaseAgent):
i18n: Any = None
crew: Any = None
knowledge: Any = None
token_process: TokenProcess = Field(default_factory=TokenProcess, exclude=True)
token_callback: Optional[Any] = None
class Config:
arbitrary_types_allowed = True
@@ -72,16 +73,35 @@ class LangChainAgentAdapter(BaseAgent):
def _extract_text(self, message: Any) -> str:
"""
Helper to extract plain text from a message object.
This checks if the message is a dict with a "content" key, or has a "content" attribute.
This checks if the message is a dict with a "content" key, or has a "content" attribute,
or if it's a tuple from LangGraph's message format.
"""
if isinstance(message, dict) and "content" in message:
return message["content"]
# Handle LangGraph message tuple format (role, content)
if isinstance(message, tuple) and len(message) == 2:
return str(message[1])
# Handle dictionary with content key
elif isinstance(message, dict):
if "content" in message:
return message["content"]
# Handle LangGraph message format with additional metadata
elif "messages" in message and message["messages"]:
last_message = message["messages"][-1]
if isinstance(last_message, tuple) and len(last_message) == 2:
return str(last_message[1])
return self._extract_text(last_message)
# Handle object with content attribute
elif hasattr(message, "content") and isinstance(
getattr(message, "content"), str
):
return getattr(message, "content")
# Handle string directly
elif isinstance(message, str):
return message
# Default fallback
return str(message)
def execute_task(
@@ -161,19 +181,77 @@ class LangChainAgentAdapter(BaseAgent):
else:
task_prompt = self._use_trained_data(task_prompt=task_prompt)
# Initialize token tracking callback if needed
if hasattr(self, "token_process") and self.token_callback is None:
# Determine if we're using LangChain or LiteLLM based on the agent type
if hasattr(self.langchain_agent, "client") and hasattr(
self.langchain_agent.client, "callbacks"
):
# This is likely a LiteLLM-based agent
self.token_callback = LiteLLMTokenCounter(self.token_process)
# Add our callback to the LLM directly
if isinstance(self.langchain_agent.client.callbacks, list):
self.langchain_agent.client.callbacks.append(self.token_callback)
else:
self.langchain_agent.client.callbacks = [self.token_callback]
else:
# This is likely a LangChain-based agent
self.token_callback = LangChainTokenCounter(self.token_process)
# Add callback to the LangChain model
if hasattr(self.langchain_agent, "callbacks"):
if self.langchain_agent.callbacks is None:
self.langchain_agent.callbacks = [self.token_callback]
elif isinstance(self.langchain_agent.callbacks, list):
self.langchain_agent.callbacks.append(self.token_callback)
# For direct LLM models
elif hasattr(self.langchain_agent, "llm") and hasattr(
self.langchain_agent.llm, "callbacks"
):
if self.langchain_agent.llm.callbacks is None:
self.langchain_agent.llm.callbacks = [self.token_callback]
elif isinstance(self.langchain_agent.llm.callbacks, list):
self.langchain_agent.llm.callbacks.append(self.token_callback)
# Direct LLM case
elif not hasattr(self.langchain_agent, "agent"):
# This might be a direct LLM, not an agent
if (
not hasattr(self.langchain_agent, "callbacks")
or self.langchain_agent.callbacks is None
):
self.langchain_agent.callbacks = [self.token_callback]
elif isinstance(self.langchain_agent.callbacks, list):
self.langchain_agent.callbacks.append(self.token_callback)
init_state = {"messages": [("user", task_prompt)]}
# Estimate input tokens for tracking
if hasattr(self, "token_process"):
# Rough estimate based on characters (better than word count)
estimated_prompt_tokens = len(task_prompt) // 4 # ~4 chars per token
self.token_process.sum_prompt_tokens(estimated_prompt_tokens)
state = self.agent_executor.invoke(init_state)
# Extract output from state based on its structure
if "structured_response" in state:
current_output = state["structured_response"]
elif "messages" in state and state["messages"]:
last_message = state["messages"][-1]
if isinstance(last_message, tuple):
current_output = last_message[1]
else:
current_output = self._extract_text(last_message)
current_output = self._extract_text(last_message)
elif "output" in state:
current_output = str(state["output"])
else:
current_output = ""
# Fallback to extracting text from the entire state
current_output = self._extract_text(state)
# Estimate completion tokens for tracking if we don't have actual counts
if hasattr(self, "token_process"):
# Rough estimate based on characters
estimated_completion_tokens = len(current_output) // 4 # ~4 chars per token
self.token_process.sum_completion_tokens(estimated_completion_tokens)
self.token_process.sum_successful_requests(1)
if task.human_input:
current_output = self._handle_human_feedback(current_output)
@@ -203,20 +281,40 @@ class LangChainAgentAdapter(BaseAgent):
f"Specifically, display 10 bullet points in each section. Provide the complete updated answer below.\n\n"
f"Updated answer:"
)
# Estimate input tokens for tracking
if hasattr(self, "token_process"):
# Rough estimate based on characters
estimated_prompt_tokens = len(new_prompt) // 4 # ~4 chars per token
self.token_process.sum_prompt_tokens(estimated_prompt_tokens)
try:
new_state = self.agent_executor.invoke(
{"messages": [("user", new_prompt)]}
)
# Extract output from state based on its structure
if "structured_response" in new_state:
new_output = new_state["structured_response"]
elif "messages" in new_state and new_state["messages"]:
last_message = new_state["messages"][-1]
if isinstance(last_message, tuple):
new_output = last_message[1]
else:
new_output = self._extract_text(last_message)
new_output = self._extract_text(last_message)
elif "output" in new_state:
new_output = str(new_state["output"])
else:
new_output = ""
# Fallback to extracting text from the entire state
new_output = self._extract_text(new_state)
# Estimate completion tokens for tracking
if hasattr(self, "token_process"):
# Rough estimate based on characters
estimated_completion_tokens = (
len(new_output) // 4
) # ~4 chars per token
self.token_process.sum_completion_tokens(
estimated_completion_tokens
)
self.token_process.sum_successful_requests(1)
current_output = new_output
except Exception as e:
print("Error during re-invocation with feedback:", e)
@@ -310,6 +408,52 @@ class LangChainAgentAdapter(BaseAgent):
agent_role = getattr(self, "role", "agent")
sanitized_role = re.sub(r"\s+", "_", agent_role)
# Initialize token tracking callback if needed
if hasattr(self, "token_process") and self.token_callback is None:
# Determine if we're using LangChain or LiteLLM based on the agent type
if hasattr(self.langchain_agent, "client") and hasattr(
self.langchain_agent.client, "callbacks"
):
# This is likely a LiteLLM-based agent
self.token_callback = LiteLLMTokenCounter(self.token_process)
# Add our callback to the LLM directly
if isinstance(self.langchain_agent.client.callbacks, list):
if self.token_callback not in self.langchain_agent.client.callbacks:
self.langchain_agent.client.callbacks.append(
self.token_callback
)
else:
self.langchain_agent.client.callbacks = [self.token_callback]
else:
# This is likely a LangChain-based agent
self.token_callback = LangChainTokenCounter(self.token_process)
# Add callback to the LangChain model
if hasattr(self.langchain_agent, "callbacks"):
if self.langchain_agent.callbacks is None:
self.langchain_agent.callbacks = [self.token_callback]
elif isinstance(self.langchain_agent.callbacks, list):
self.langchain_agent.callbacks.append(self.token_callback)
# For direct LLM models
elif hasattr(self.langchain_agent, "llm") and hasattr(
self.langchain_agent.llm, "callbacks"
):
if self.langchain_agent.llm.callbacks is None:
self.langchain_agent.llm.callbacks = [self.token_callback]
elif isinstance(self.langchain_agent.llm.callbacks, list):
self.langchain_agent.llm.callbacks.append(self.token_callback)
# Direct LLM case
elif not hasattr(self.langchain_agent, "agent"):
# This might be a direct LLM, not an agent
if (
not hasattr(self.langchain_agent, "callbacks")
or self.langchain_agent.callbacks is None
):
self.langchain_agent.callbacks = [self.token_callback]
elif isinstance(self.langchain_agent.callbacks, list):
self.langchain_agent.callbacks.append(self.token_callback)
self.agent_executor = create_react_agent(
model=self.langchain_agent,
tools=used_tools,

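Alongside the callback-based counts, the adapter books a rough character-based estimate (about four characters per token) for both the prompt it sends and the output it extracts, so usage is never reported as zero when the provider returns no usage data. A standalone sketch of that heuristic, with an illustrative helper name that is not part of the adapter itself:

from crewai.agents.agent_builder.utilities.base_token_process import TokenProcess

def estimate_tokens(text: str) -> int:
    # Same rule of thumb the adapter uses: roughly 4 characters per token.
    return len(text) // 4

process = TokenProcess()
prompt = "Fetch the current weather for San Francisco."
process.sum_prompt_tokens(estimate_tokens(prompt))   # 44 chars -> 11 estimated tokens
process.sum_successful_requests(1)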
View File

@@ -641,7 +641,7 @@ class Crew(BaseModel):
for after_callback in self.after_kickoff_callbacks:
result = after_callback(result)
metrics += [agent._token_process.get_summary() for agent in self.agents]
metrics += [agent.token_process.get_summary() for agent in self.agents]
self.usage_metrics = UsageMetrics()
for metric in metrics:
@@ -1195,12 +1195,15 @@ class Crew(BaseModel):
"""Calculates and returns the usage metrics."""
total_usage_metrics = UsageMetrics()
for agent in self.agents:
if hasattr(agent, "_token_process"):
token_sum = agent._token_process.get_summary()
total_usage_metrics.add_usage_metrics(token_sum)
if self.manager_agent and hasattr(self.manager_agent, "_token_process"):
token_sum = self.manager_agent._token_process.get_summary()
# Directly access token_process since it's now a field in BaseAgent
token_sum = agent.token_process.get_summary()
total_usage_metrics.add_usage_metrics(token_sum)
if self.manager_agent:
# Directly access token_process since it's now a field in BaseAgent
token_sum = self.manager_agent.token_process.get_summary()
total_usage_metrics.add_usage_metrics(token_sum)
self.usage_metrics = total_usage_metrics
return total_usage_metrics

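A condensed view of the aggregation this hunk performs: each agent's summary, plus the manager's when one is set, is folded into a single UsageMetrics object. The sketch below mirrors the hunk above; the UsageMetrics import path is an assumption.

from crewai.types.usage_metrics import UsageMetrics  # import path assumed

def aggregate_usage(agents, manager=None) -> UsageMetrics:
    total = UsageMetrics()
    for agent in agents:
        # token_process is now a public field on every BaseAgent
        total.add_usage_metrics(agent.token_process.get_summary())
    if manager is not None:
        total.add_usage_metrics(manager.token_process.get_summary())
    return total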
View File

@@ -1,15 +1,52 @@
import warnings
from typing import Any, Dict, Optional
from abc import ABC, abstractmethod
from typing import Any, Dict, List, Optional, Union
from langchain_core.callbacks.base import BaseCallbackHandler
from litellm.integrations.custom_logger import CustomLogger
from litellm.types.utils import Usage
from crewai.agents.agent_builder.utilities.base_token_process import TokenProcess
class TokenCalcHandler(CustomLogger):
def __init__(self, token_cost_process: Optional[TokenProcess]):
self.token_cost_process = token_cost_process
class AbstractTokenCounter(ABC):
"""
Abstract base class for token counting callbacks.
Implementations should track token usage from different LLM providers.
"""
def __init__(self, token_process: Optional[TokenProcess] = None):
"""Initialize with a TokenProcess instance to track tokens."""
self.token_process = token_process
@abstractmethod
def update_token_usage(self, prompt_tokens: int, completion_tokens: int) -> None:
"""Update token usage counts in the token process."""
pass
class LiteLLMTokenCounter(CustomLogger, AbstractTokenCounter):
"""
Token counter implementation for LiteLLM.
Uses LiteLLM's CustomLogger interface to track token usage.
"""
def __init__(self, token_process: Optional[TokenProcess] = None):
AbstractTokenCounter.__init__(self, token_process)
CustomLogger.__init__(self)
def update_token_usage(self, prompt_tokens: int, completion_tokens: int) -> None:
"""Update token usage counts in the token process."""
if self.token_process is None:
return
if prompt_tokens > 0:
self.token_process.sum_prompt_tokens(prompt_tokens)
if completion_tokens > 0:
self.token_process.sum_completion_tokens(completion_tokens)
self.token_process.sum_successful_requests(1)
def log_success_event(
self,
@@ -18,7 +55,11 @@ class TokenCalcHandler(CustomLogger):
start_time: float,
end_time: float,
) -> None:
if self.token_cost_process is None:
"""
Process successful LLM call and extract token usage information.
This method is called by LiteLLM after a successful completion.
"""
if self.token_process is None:
return
with warnings.catch_warnings():
@@ -26,18 +67,154 @@ class TokenCalcHandler(CustomLogger):
if isinstance(response_obj, dict) and "usage" in response_obj:
usage: Usage = response_obj["usage"]
if usage:
self.token_cost_process.sum_successful_requests(1)
prompt_tokens = 0
completion_tokens = 0
if hasattr(usage, "prompt_tokens"):
self.token_cost_process.sum_prompt_tokens(usage.prompt_tokens)
prompt_tokens = usage.prompt_tokens
elif isinstance(usage, dict) and "prompt_tokens" in usage:
prompt_tokens = usage["prompt_tokens"]
if hasattr(usage, "completion_tokens"):
self.token_cost_process.sum_completion_tokens(
usage.completion_tokens
)
completion_tokens = usage.completion_tokens
elif isinstance(usage, dict) and "completion_tokens" in usage:
completion_tokens = usage["completion_tokens"]
self.update_token_usage(prompt_tokens, completion_tokens)
# Handle cached tokens if available
if (
hasattr(usage, "prompt_tokens_details")
and usage.prompt_tokens_details
and usage.prompt_tokens_details.cached_tokens
):
self.token_cost_process.sum_cached_prompt_tokens(
self.token_process.sum_cached_prompt_tokens(
usage.prompt_tokens_details.cached_tokens
)
class LangChainTokenCounter(BaseCallbackHandler, AbstractTokenCounter):
"""
Token counter implementation for LangChain.
Implements the necessary callback methods to track token usage from LangChain responses.
"""
def __init__(self, token_process: Optional[TokenProcess] = None):
BaseCallbackHandler.__init__(self)
AbstractTokenCounter.__init__(self, token_process)
def update_token_usage(self, prompt_tokens: int, completion_tokens: int) -> None:
"""Update token usage counts in the token process."""
if self.token_process is None:
return
if prompt_tokens > 0:
self.token_process.sum_prompt_tokens(prompt_tokens)
if completion_tokens > 0:
self.token_process.sum_completion_tokens(completion_tokens)
self.token_process.sum_successful_requests(1)
@property
def ignore_llm(self) -> bool:
return False
@property
def ignore_chain(self) -> bool:
return True
@property
def ignore_agent(self) -> bool:
return False
@property
def ignore_chat_model(self) -> bool:
return False
@property
def ignore_retriever(self) -> bool:
return True
@property
def ignore_tools(self) -> bool:
return True
def on_llm_start(
self, serialized: Dict[str, Any], prompts: List[str], **kwargs: Any
) -> None:
"""Called when LLM starts processing."""
pass
def on_llm_new_token(self, token: str, **kwargs: Any) -> None:
"""Called when LLM generates a new token."""
pass
def on_llm_end(self, response: Any, **kwargs: Any) -> None:
"""
Called when LLM ends processing.
Extracts token usage from LangChain response objects.
"""
if self.token_process is None:
return
# Handle LangChain response format
if hasattr(response, "llm_output") and isinstance(response.llm_output, dict):
token_usage = response.llm_output.get("token_usage", {})
prompt_tokens = token_usage.get("prompt_tokens", 0)
completion_tokens = token_usage.get("completion_tokens", 0)
self.update_token_usage(prompt_tokens, completion_tokens)
def on_llm_error(self, error: BaseException, **kwargs: Any) -> None:
"""Called when LLM errors."""
pass
def on_chain_start(
self, serialized: Dict[str, Any], inputs: Dict[str, Any], **kwargs: Any
) -> None:
"""Called when a chain starts."""
pass
def on_chain_end(self, outputs: Dict[str, Any], **kwargs: Any) -> None:
"""Called when a chain ends."""
pass
def on_chain_error(self, error: BaseException, **kwargs: Any) -> None:
"""Called when a chain errors."""
pass
def on_tool_start(
self, serialized: Dict[str, Any], input_str: str, **kwargs: Any
) -> None:
"""Called when a tool starts."""
pass
def on_tool_end(self, output: str, **kwargs: Any) -> None:
"""Called when a tool ends."""
pass
def on_tool_error(self, error: BaseException, **kwargs: Any) -> None:
"""Called when a tool errors."""
pass
def on_text(self, text: str, **kwargs: Any) -> None:
"""Called when text is generated."""
pass
def on_agent_start(self, serialized: Dict[str, Any], **kwargs: Any) -> None:
"""Called when an agent starts."""
pass
def on_agent_end(self, output: Any, **kwargs: Any) -> None:
"""Called when an agent ends."""
pass
def on_agent_error(self, error: BaseException, **kwargs: Any) -> None:
"""Called when an agent errors."""
pass
# For backward compatibility
TokenCalcHandler = LiteLLMTokenCounter

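Outside the adapter, the new LangChain counter can be attached to any LangChain chat model as an ordinary callback handler. A minimal sketch, assuming an OpenAI key is configured and langchain_openai is installed:

from langchain_openai import ChatOpenAI

from crewai.agents.agent_builder.utilities.base_token_process import TokenProcess
from crewai.utilities.token_counter_callback import LangChainTokenCounter

process = TokenProcess()
counter = LangChainTokenCounter(process)

# on_llm_end reads token_usage from the response's llm_output, so real
# provider-reported counts (not estimates) land in the TokenProcess.
llm = ChatOpenAI(model="gpt-3.5-turbo", callbacks=[counter])
llm.invoke("Reply with a single word.")

print(process.get_summary())   # prompt, completion, and total token counts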
View File

@@ -0,0 +1,183 @@
#!/usr/bin/env python
"""
Test module for token tracking functionality in CrewAI.
This tests both direct LangChain models and LiteLLM integration.
"""
import os
from typing import Any, Dict
from unittest.mock import MagicMock, patch
import pytest
from langchain_core.tools import Tool
from langchain_openai import ChatOpenAI
from crewai import Crew, Process, Task
from crewai.agents.agent_builder.utilities.base_token_process import TokenProcess
from crewai.agents.langchain_agent_adapter import LangChainAgentAdapter
from crewai.utilities.token_counter_callback import (
LangChainTokenCounter,
LiteLLMTokenCounter,
)
def get_weather(location: str = "San Francisco"):
"""Simulates fetching current weather data for a given location."""
# In a real implementation, you could replace this with an API call.
return f"Current weather in {location}: Sunny, 25°C"
class TestTokenTracking:
"""Test suite for token tracking functionality."""
@pytest.fixture
def weather_tool(self):
"""Create a simple weather tool for testing."""
return Tool(
name="Weather",
func=get_weather,
description="Useful for fetching current weather information for a given location.",
)
@pytest.fixture
def mock_openai_response(self):
"""Create a mock OpenAI response with token usage information."""
return {
"usage": {
"prompt_tokens": 100,
"completion_tokens": 50,
"total_tokens": 150,
}
}
def test_token_process_basic(self):
"""Test basic functionality of TokenProcess class."""
token_process = TokenProcess()
# Test adding prompt tokens
token_process.sum_prompt_tokens(100)
assert token_process.prompt_tokens == 100
# Test adding completion tokens
token_process.sum_completion_tokens(50)
assert token_process.completion_tokens == 50
# Test adding successful requests
token_process.sum_successful_requests(1)
assert token_process.successful_requests == 1
# Test getting summary
summary = token_process.get_summary()
assert summary.prompt_tokens == 100
assert summary.completion_tokens == 50
assert summary.total_tokens == 150
assert summary.successful_requests == 1
@patch("litellm.completion")
def test_litellm_token_counter(self, mock_completion):
"""Test LiteLLMTokenCounter with a mock response."""
# Setup
token_process = TokenProcess()
counter = LiteLLMTokenCounter(token_process)
# Mock the response
mock_completion.return_value = {
"usage": {
"prompt_tokens": 100,
"completion_tokens": 50,
}
}
# Simulate a successful LLM call
counter.log_success_event(
kwargs={},
response_obj=mock_completion.return_value,
start_time=0,
end_time=1,
)
# Verify token counts were updated
assert token_process.prompt_tokens == 100
assert token_process.completion_tokens == 50
assert token_process.successful_requests == 1
def test_langchain_token_counter(self):
"""Test LangChainTokenCounter with a mock response."""
# Setup
token_process = TokenProcess()
counter = LangChainTokenCounter(token_process)
# Create a mock LangChain response
mock_response = MagicMock()
mock_response.llm_output = {
"token_usage": {
"prompt_tokens": 100,
"completion_tokens": 50,
}
}
# Simulate a successful LLM call
counter.on_llm_end(mock_response)
# Verify token counts were updated
assert token_process.prompt_tokens == 100
assert token_process.completion_tokens == 50
assert token_process.successful_requests == 1
@pytest.mark.skipif(
not os.environ.get("OPENAI_API_KEY"),
reason="OPENAI_API_KEY environment variable not set",
)
def test_langchain_agent_adapter_token_tracking(self, weather_tool):
"""
Integration test for token tracking with LangChainAgentAdapter.
This test requires an OpenAI API key.
"""
# Initialize a ChatOpenAI model
llm = ChatOpenAI(model="gpt-3.5-turbo")
# Create a LangChainAgentAdapter with the direct LLM
agent = LangChainAgentAdapter(
langchain_agent=llm,
tools=[weather_tool],
role="Weather Agent",
goal="Provide current weather information for the requested location.",
backstory="An expert weather provider that fetches current weather information using simulated data.",
verbose=True,
)
# Create a weather task for the agent
task = Task(
description="Fetch the current weather for San Francisco.",
expected_output="A weather report showing current conditions in San Francisco.",
agent=agent,
)
# Create a crew with the single agent and task
crew = Crew(
agents=[agent],
tasks=[task],
verbose=True,
process=Process.sequential,
)
# Execute the crew
result = crew.kickoff()
# Verify token usage was tracked
assert result.token_usage is not None
assert result.token_usage.total_tokens > 0
assert result.token_usage.prompt_tokens > 0
assert result.token_usage.completion_tokens > 0
assert result.token_usage.successful_requests > 0
# Also verify token usage directly from the agent
usage = agent.token_process.get_summary()
assert usage.prompt_tokens > 0
assert usage.completion_tokens > 0
assert usage.total_tokens > 0
assert usage.successful_requests > 0
if __name__ == "__main__":
pytest.main(["-xvs", __file__])