mirror of
https://github.com/crewAIInc/crewAI.git
synced 2026-01-11 09:08:31 +00:00
Refactor token tracking: Remove token_cost_process parameter for cleaner code
This commit is contained in:
@@ -217,4 +217,9 @@ class LangChainTokenCounter(BaseCallbackHandler, AbstractTokenCounter):
|
||||
|
||||
|
||||
# For backward compatibility
|
||||
TokenCalcHandler = LiteLLMTokenCounter
|
||||
class TokenCalcHandler(LiteLLMTokenCounter):
|
||||
"""
|
||||
Alias for LiteLLMTokenCounter.
|
||||
"""
|
||||
|
||||
pass
|
||||
|
||||
@@ -18,15 +18,15 @@ def test_llm_callback_replacement():
|
||||
llm1 = LLM(model="gpt-4o-mini")
|
||||
llm2 = LLM(model="gpt-4o-mini")
|
||||
|
||||
calc_handler_1 = TokenCalcHandler(token_cost_process=TokenProcess())
|
||||
calc_handler_2 = TokenCalcHandler(token_cost_process=TokenProcess())
|
||||
calc_handler_1 = TokenCalcHandler(token_process=TokenProcess())
|
||||
calc_handler_2 = TokenCalcHandler(token_process=TokenProcess())
|
||||
|
||||
result1 = llm1.call(
|
||||
messages=[{"role": "user", "content": "Hello, world!"}],
|
||||
callbacks=[calc_handler_1],
|
||||
)
|
||||
print("result1:", result1)
|
||||
usage_metrics_1 = calc_handler_1.token_cost_process.get_summary()
|
||||
usage_metrics_1 = calc_handler_1.token_process.get_summary()
|
||||
print("usage_metrics_1:", usage_metrics_1)
|
||||
|
||||
result2 = llm2.call(
|
||||
@@ -35,13 +35,13 @@ def test_llm_callback_replacement():
|
||||
)
|
||||
sleep(5)
|
||||
print("result2:", result2)
|
||||
usage_metrics_2 = calc_handler_2.token_cost_process.get_summary()
|
||||
usage_metrics_2 = calc_handler_2.token_process.get_summary()
|
||||
print("usage_metrics_2:", usage_metrics_2)
|
||||
|
||||
# The first handler should not have been updated
|
||||
assert usage_metrics_1.successful_requests == 1
|
||||
assert usage_metrics_2.successful_requests == 1
|
||||
assert usage_metrics_1 == calc_handler_1.token_cost_process.get_summary()
|
||||
assert usage_metrics_1 == calc_handler_1.token_process.get_summary()
|
||||
|
||||
|
||||
@pytest.mark.vcr(filter_headers=["authorization"])
|
||||
@@ -57,14 +57,14 @@ def test_llm_call_with_string_input():
|
||||
@pytest.mark.vcr(filter_headers=["authorization"])
|
||||
def test_llm_call_with_string_input_and_callbacks():
|
||||
llm = LLM(model="gpt-4o-mini")
|
||||
calc_handler = TokenCalcHandler(token_cost_process=TokenProcess())
|
||||
calc_handler = TokenCalcHandler(token_process=TokenProcess())
|
||||
|
||||
# Test the call method with a string input and callbacks
|
||||
result = llm.call(
|
||||
"Tell me a joke.",
|
||||
callbacks=[calc_handler],
|
||||
)
|
||||
usage_metrics = calc_handler.token_cost_process.get_summary()
|
||||
usage_metrics = calc_handler.token_process.get_summary()
|
||||
|
||||
assert isinstance(result, str)
|
||||
assert len(result.strip()) > 0
|
||||
@@ -285,6 +285,7 @@ def test_o3_mini_reasoning_effort_medium():
|
||||
assert isinstance(result, str)
|
||||
assert "Paris" in result
|
||||
|
||||
|
||||
def test_context_window_validation():
|
||||
"""Test that context window validation works correctly."""
|
||||
# Test valid window size
|
||||
|
||||
@@ -133,8 +133,14 @@ class TestTokenTracking:
|
||||
Integration test for token tracking with LangChainAgentAdapter.
|
||||
This test requires an OpenAI API key.
|
||||
"""
|
||||
# Skip if LangGraph is not installed
|
||||
try:
|
||||
from langgraph.prebuilt import ToolNode
|
||||
except ImportError:
|
||||
pytest.skip("LangGraph is not installed. Install it with: uv add langgraph")
|
||||
|
||||
# Initialize a ChatOpenAI model
|
||||
llm = ChatOpenAI(model="gpt-3.5-turbo")
|
||||
llm = ChatOpenAI(model="gpt-4o")
|
||||
|
||||
# Create a LangChainAgentAdapter with the direct LLM
|
||||
agent = LangChainAgentAdapter(
|
||||
|
||||
Reference in New Issue
Block a user