Refactor token tracking: Remove token_cost_process parameter for cleaner code

This commit is contained in:
Brandon Hancock
2025-02-28 12:17:14 -05:00
parent ef48cbe971
commit 75d8e086a4
3 changed files with 21 additions and 9 deletions

View File

@@ -217,4 +217,9 @@ class LangChainTokenCounter(BaseCallbackHandler, AbstractTokenCounter):
# For backward compatibility
TokenCalcHandler = LiteLLMTokenCounter
class TokenCalcHandler(LiteLLMTokenCounter):
"""
Alias for LiteLLMTokenCounter.
"""
pass

View File

@@ -18,15 +18,15 @@ def test_llm_callback_replacement():
llm1 = LLM(model="gpt-4o-mini")
llm2 = LLM(model="gpt-4o-mini")
calc_handler_1 = TokenCalcHandler(token_cost_process=TokenProcess())
calc_handler_2 = TokenCalcHandler(token_cost_process=TokenProcess())
calc_handler_1 = TokenCalcHandler(token_process=TokenProcess())
calc_handler_2 = TokenCalcHandler(token_process=TokenProcess())
result1 = llm1.call(
messages=[{"role": "user", "content": "Hello, world!"}],
callbacks=[calc_handler_1],
)
print("result1:", result1)
usage_metrics_1 = calc_handler_1.token_cost_process.get_summary()
usage_metrics_1 = calc_handler_1.token_process.get_summary()
print("usage_metrics_1:", usage_metrics_1)
result2 = llm2.call(
@@ -35,13 +35,13 @@ def test_llm_callback_replacement():
)
sleep(5)
print("result2:", result2)
usage_metrics_2 = calc_handler_2.token_cost_process.get_summary()
usage_metrics_2 = calc_handler_2.token_process.get_summary()
print("usage_metrics_2:", usage_metrics_2)
# The first handler should not have been updated
assert usage_metrics_1.successful_requests == 1
assert usage_metrics_2.successful_requests == 1
assert usage_metrics_1 == calc_handler_1.token_cost_process.get_summary()
assert usage_metrics_1 == calc_handler_1.token_process.get_summary()
@pytest.mark.vcr(filter_headers=["authorization"])
@@ -57,14 +57,14 @@ def test_llm_call_with_string_input():
@pytest.mark.vcr(filter_headers=["authorization"])
def test_llm_call_with_string_input_and_callbacks():
llm = LLM(model="gpt-4o-mini")
calc_handler = TokenCalcHandler(token_cost_process=TokenProcess())
calc_handler = TokenCalcHandler(token_process=TokenProcess())
# Test the call method with a string input and callbacks
result = llm.call(
"Tell me a joke.",
callbacks=[calc_handler],
)
usage_metrics = calc_handler.token_cost_process.get_summary()
usage_metrics = calc_handler.token_process.get_summary()
assert isinstance(result, str)
assert len(result.strip()) > 0
@@ -285,6 +285,7 @@ def test_o3_mini_reasoning_effort_medium():
assert isinstance(result, str)
assert "Paris" in result
def test_context_window_validation():
"""Test that context window validation works correctly."""
# Test valid window size

View File

@@ -133,8 +133,14 @@ class TestTokenTracking:
Integration test for token tracking with LangChainAgentAdapter.
This test requires an OpenAI API key.
"""
# Skip if LangGraph is not installed
try:
from langgraph.prebuilt import ToolNode
except ImportError:
pytest.skip("LangGraph is not installed. Install it with: uv add langgraph")
# Initialize a ChatOpenAI model
llm = ChatOpenAI(model="gpt-3.5-turbo")
llm = ChatOpenAI(model="gpt-4o")
# Create a LangChainAgentAdapter with the direct LLM
agent = LangChainAgentAdapter(