Refactor token tracking: Rename token_cost_process parameter to token_process for cleaner code

This commit is contained in:
Brandon Hancock
2025-02-28 12:17:14 -05:00
parent ef48cbe971
commit 75d8e086a4
3 changed files with 21 additions and 9 deletions

View File

@@ -217,4 +217,9 @@ class LangChainTokenCounter(BaseCallbackHandler, AbstractTokenCounter):
# For backward compatibility # For backward compatibility
TokenCalcHandler = LiteLLMTokenCounter class TokenCalcHandler(LiteLLMTokenCounter):
"""
Alias for LiteLLMTokenCounter.
"""
pass

View File

@@ -18,15 +18,15 @@ def test_llm_callback_replacement():
llm1 = LLM(model="gpt-4o-mini") llm1 = LLM(model="gpt-4o-mini")
llm2 = LLM(model="gpt-4o-mini") llm2 = LLM(model="gpt-4o-mini")
calc_handler_1 = TokenCalcHandler(token_cost_process=TokenProcess()) calc_handler_1 = TokenCalcHandler(token_process=TokenProcess())
calc_handler_2 = TokenCalcHandler(token_cost_process=TokenProcess()) calc_handler_2 = TokenCalcHandler(token_process=TokenProcess())
result1 = llm1.call( result1 = llm1.call(
messages=[{"role": "user", "content": "Hello, world!"}], messages=[{"role": "user", "content": "Hello, world!"}],
callbacks=[calc_handler_1], callbacks=[calc_handler_1],
) )
print("result1:", result1) print("result1:", result1)
usage_metrics_1 = calc_handler_1.token_cost_process.get_summary() usage_metrics_1 = calc_handler_1.token_process.get_summary()
print("usage_metrics_1:", usage_metrics_1) print("usage_metrics_1:", usage_metrics_1)
result2 = llm2.call( result2 = llm2.call(
@@ -35,13 +35,13 @@ def test_llm_callback_replacement():
) )
sleep(5) sleep(5)
print("result2:", result2) print("result2:", result2)
usage_metrics_2 = calc_handler_2.token_cost_process.get_summary() usage_metrics_2 = calc_handler_2.token_process.get_summary()
print("usage_metrics_2:", usage_metrics_2) print("usage_metrics_2:", usage_metrics_2)
# The first handler should not have been updated # The first handler should not have been updated
assert usage_metrics_1.successful_requests == 1 assert usage_metrics_1.successful_requests == 1
assert usage_metrics_2.successful_requests == 1 assert usage_metrics_2.successful_requests == 1
assert usage_metrics_1 == calc_handler_1.token_cost_process.get_summary() assert usage_metrics_1 == calc_handler_1.token_process.get_summary()
@pytest.mark.vcr(filter_headers=["authorization"]) @pytest.mark.vcr(filter_headers=["authorization"])
@@ -57,14 +57,14 @@ def test_llm_call_with_string_input():
@pytest.mark.vcr(filter_headers=["authorization"]) @pytest.mark.vcr(filter_headers=["authorization"])
def test_llm_call_with_string_input_and_callbacks(): def test_llm_call_with_string_input_and_callbacks():
llm = LLM(model="gpt-4o-mini") llm = LLM(model="gpt-4o-mini")
calc_handler = TokenCalcHandler(token_cost_process=TokenProcess()) calc_handler = TokenCalcHandler(token_process=TokenProcess())
# Test the call method with a string input and callbacks # Test the call method with a string input and callbacks
result = llm.call( result = llm.call(
"Tell me a joke.", "Tell me a joke.",
callbacks=[calc_handler], callbacks=[calc_handler],
) )
usage_metrics = calc_handler.token_cost_process.get_summary() usage_metrics = calc_handler.token_process.get_summary()
assert isinstance(result, str) assert isinstance(result, str)
assert len(result.strip()) > 0 assert len(result.strip()) > 0
@@ -285,6 +285,7 @@ def test_o3_mini_reasoning_effort_medium():
assert isinstance(result, str) assert isinstance(result, str)
assert "Paris" in result assert "Paris" in result
def test_context_window_validation(): def test_context_window_validation():
"""Test that context window validation works correctly.""" """Test that context window validation works correctly."""
# Test valid window size # Test valid window size

View File

@@ -133,8 +133,14 @@ class TestTokenTracking:
Integration test for token tracking with LangChainAgentAdapter. Integration test for token tracking with LangChainAgentAdapter.
This test requires an OpenAI API key. This test requires an OpenAI API key.
""" """
# Skip if LangGraph is not installed
try:
from langgraph.prebuilt import ToolNode
except ImportError:
pytest.skip("LangGraph is not installed. Install it with: uv add langgraph")
# Initialize a ChatOpenAI model # Initialize a ChatOpenAI model
llm = ChatOpenAI(model="gpt-3.5-turbo") llm = ChatOpenAI(model="gpt-4o")
# Create a LangChainAgentAdapter with the direct LLM # Create a LangChainAgentAdapter with the direct LLM
agent = LangChainAgentAdapter( agent = LangChainAgentAdapter(