Mirror of https://github.com/crewAIInc/crewAI.git, synced 2026-01-21 13:58:15 +00:00

Compare commits: devin/1767...devin/1767 (1 commit)

Commit 0976c42c6b
@@ -531,20 +531,10 @@ def _delegate_to_a2a(
    agent_config = next(filter(lambda x: x.endpoint == agent_id, a2a_agents))
    task_config = task.config or {}
    context_id = task_config.get("context_id")
    task_id_config = task_config.get("task_id")
    metadata = task_config.get("metadata")
    extensions = task_config.get("extensions")

    # Use endpoint-scoped task IDs to prevent reusing task IDs across different A2A agents
    # This fixes the issue where delegating to a second A2A agent fails because the task_id
    # from the first agent is in "completed" state
    # Make a defensive copy to avoid in-place mutation of task.config
    # Handle case where value is explicitly None (e.g., from JSON/YAML config)
    existing_task_ids = task_config.get("a2a_task_ids_by_endpoint")
    a2a_task_ids_by_endpoint: dict[str, str] = (
        dict(existing_task_ids) if existing_task_ids else {}
    )
    task_id_config = a2a_task_ids_by_endpoint.get(agent_id)

    reference_task_ids = task_config.get("reference_task_ids", [])

    if original_task_description is None:
@@ -585,8 +575,6 @@ def _delegate_to_a2a(
        if conversation_history:
            latest_message = conversation_history[-1]
            if latest_message.task_id is not None:
                # Update task_id_config for the current loop iteration only
                # Don't persist to a2a_task_ids_by_endpoint yet - wait until we know the status
                task_id_config = latest_message.task_id
            if latest_message.context_id is not None:
                context_id = latest_message.context_id
@@ -596,16 +584,13 @@ def _delegate_to_a2a(
            a2a_result["status"] == "completed"
            and agent_config.trust_remote_completion_status
        ):
            # Don't persist completed task IDs - they can't be reused
            # (A2A protocol rejects task IDs in terminal state)
            # Only add to reference_task_ids for tracking purposes
            if task.config is None:
                task.config = {}
            if (
                task_id_config is not None
                and task_id_config not in reference_task_ids
            ):
                reference_task_ids.append(task_id_config)
            if task.config is None:
                task.config = {}
            task.config["reference_task_ids"] = reference_task_ids

            result_text = a2a_result.get("result", "")

@@ -928,7 +928,17 @@ class LLM(BaseLLM):
        if not tool_calls or not available_functions:
            # Track token usage and log callbacks if available in streaming mode
            if usage_info:
                self._track_token_usage_internal(usage_info)
                # Convert usage object to dict if needed
                if hasattr(usage_info, "__dict__"):
                    usage_dict = {
                        "prompt_tokens": getattr(usage_info, "prompt_tokens", 0),
                        "completion_tokens": getattr(usage_info, "completion_tokens", 0),
                        "total_tokens": getattr(usage_info, "total_tokens", 0),
                        "cached_tokens": getattr(usage_info, "cached_tokens", 0),
                    }
                else:
                    usage_dict = usage_info
                self._track_token_usage_internal(usage_dict)
            self._handle_streaming_callbacks(callbacks, usage_info, last_chunk)

            if response_model and self.is_litellm:
@@ -964,7 +974,17 @@ class LLM(BaseLLM):

        # --- 10) Track token usage and log callbacks if available in streaming mode
        if usage_info:
            self._track_token_usage_internal(usage_info)
            # Convert usage object to dict if needed
            if hasattr(usage_info, "__dict__"):
                usage_dict = {
                    "prompt_tokens": getattr(usage_info, "prompt_tokens", 0),
                    "completion_tokens": getattr(usage_info, "completion_tokens", 0),
                    "total_tokens": getattr(usage_info, "total_tokens", 0),
                    "cached_tokens": getattr(usage_info, "cached_tokens", 0),
                }
            else:
                usage_dict = usage_info
            self._track_token_usage_internal(usage_dict)
        self._handle_streaming_callbacks(callbacks, usage_info, last_chunk)

        # --- 11) Emit completion event and return response
@@ -1173,7 +1193,23 @@ class LLM(BaseLLM):
                0
            ].message
            text_response = response_message.content or ""
            # --- 3) Handle callbacks with usage info

            # --- 3a) Track token usage internally
            usage_info = getattr(response, "usage", None)
            if usage_info:
                # Convert usage object to dict if needed
                if hasattr(usage_info, "__dict__"):
                    usage_dict = {
                        "prompt_tokens": getattr(usage_info, "prompt_tokens", 0),
                        "completion_tokens": getattr(usage_info, "completion_tokens", 0),
                        "total_tokens": getattr(usage_info, "total_tokens", 0),
                        "cached_tokens": getattr(usage_info, "cached_tokens", 0),
                    }
                else:
                    usage_dict = usage_info
                self._track_token_usage_internal(usage_dict)

            # --- 3b) Handle callbacks with usage info
            if callbacks and len(callbacks) > 0:
                for callback in callbacks:
                    if hasattr(callback, "log_success_event"):
@@ -1293,10 +1329,24 @@ class LLM(BaseLLM):
            ].message
            text_response = response_message.content or ""

            # Track token usage internally
            usage_info = getattr(response, "usage", None)
            if usage_info:
                # Convert usage object to dict if needed
                if hasattr(usage_info, "__dict__"):
                    usage_dict = {
                        "prompt_tokens": getattr(usage_info, "prompt_tokens", 0),
                        "completion_tokens": getattr(usage_info, "completion_tokens", 0),
                        "total_tokens": getattr(usage_info, "total_tokens", 0),
                        "cached_tokens": getattr(usage_info, "cached_tokens", 0),
                    }
                else:
                    usage_dict = usage_info
                self._track_token_usage_internal(usage_dict)

            if callbacks and len(callbacks) > 0:
                for callback in callbacks:
                    if hasattr(callback, "log_success_event"):
                        usage_info = getattr(response, "usage", None)
                        if usage_info:
                            callback.log_success_event(
                                kwargs=params,
@@ -1381,7 +1431,10 @@ class LLM(BaseLLM):
            if not isinstance(chunk.choices, type):
                choices = chunk.choices

            if hasattr(chunk, "usage") and chunk.usage is not None:
            # Try to extract usage information if available
            if isinstance(chunk, dict) and "usage" in chunk:
                usage_info = chunk["usage"]
            elif hasattr(chunk, "usage") and chunk.usage is not None:
                usage_info = chunk.usage

            if choices and len(choices) > 0:
@@ -1434,6 +1487,20 @@ class LLM(BaseLLM):
                ),
            )

            # Track token usage internally
            if usage_info:
                # Convert usage object to dict if needed
                if hasattr(usage_info, "__dict__"):
                    usage_dict = {
                        "prompt_tokens": getattr(usage_info, "prompt_tokens", 0),
                        "completion_tokens": getattr(usage_info, "completion_tokens", 0),
                        "total_tokens": getattr(usage_info, "total_tokens", 0),
                        "cached_tokens": getattr(usage_info, "cached_tokens", 0),
                    }
                else:
                    usage_dict = usage_info
                self._track_token_usage_internal(usage_dict)

            if callbacks and len(callbacks) > 0 and usage_info:
                for callback in callbacks:
                    if hasattr(callback, "log_success_event"):

Deleted file, 305 lines:
@@ -1,305 +0,0 @@
"""Test A2A delegation to multiple endpoints sequentially.

This test file covers the bug fix for issue #4166 where delegating to a second
A2A agent fails because the task_id from the first agent is in "completed" state.
"""

from unittest.mock import MagicMock, patch

import pytest

from crewai.a2a.config import A2AConfig

try:
    from a2a.types import Message, Part, Role, TextPart

    A2A_SDK_INSTALLED = True
except ImportError:
    A2A_SDK_INSTALLED = False


@pytest.mark.skipif(not A2A_SDK_INSTALLED, reason="Requires a2a-sdk to be installed")
def test_sequential_delegation_to_multiple_endpoints_uses_separate_task_ids():
    """When delegating to multiple A2A endpoints sequentially, each should get a unique task_id.

    This test verifies the fix for issue #4166 where the second A2A delegation
    fails with 'Task is in terminal state: completed' because the task_id from
    the first delegation was being reused.
    """
    from crewai.a2a.wrapper import _delegate_to_a2a
    from crewai import Agent, Task

    # Configure agent with two A2A endpoints
    a2a_configs = [
        A2AConfig(
            endpoint="http://endpoint-a.com",
            trust_remote_completion_status=True,
        ),
        A2AConfig(
            endpoint="http://endpoint-b.com",
            trust_remote_completion_status=True,
        ),
    ]

    agent = Agent(
        role="test manager",
        goal="coordinate",
        backstory="test",
        a2a=a2a_configs,
    )

    task = Task(description="test", expected_output="test", agent=agent)

    # First delegation to endpoint A
    class MockResponseA:
        is_a2a = True
        message = "Please help with task A"
        a2a_ids = ["http://endpoint-a.com/"]

    # Second delegation to endpoint B
    class MockResponseB:
        is_a2a = True
        message = "Please help with task B"
        a2a_ids = ["http://endpoint-b.com/"]

    task_ids_used = []

    def mock_execute_a2a_delegation(**kwargs):
        """Track the task_id used for each delegation."""
        task_ids_used.append(kwargs.get("task_id"))
        endpoint = kwargs.get("endpoint")

        # Create a mock message with a task_id
        mock_message = MagicMock()
        mock_message.task_id = f"task-id-for-{endpoint}"
        mock_message.context_id = None

        return {
            "status": "completed",
            "result": f"Done by {endpoint}",
            "history": [mock_message],
        }

    with (
        patch(
            "crewai.a2a.wrapper.execute_a2a_delegation",
            side_effect=mock_execute_a2a_delegation,
        ) as mock_execute,
        patch("crewai.a2a.wrapper._fetch_agent_cards_concurrently") as mock_fetch,
    ):
        mock_card_a = MagicMock()
        mock_card_a.name = "Agent A"
        mock_card_b = MagicMock()
        mock_card_b.name = "Agent B"
        mock_fetch.return_value = (
            {
                "http://endpoint-a.com/": mock_card_a,
                "http://endpoint-b.com/": mock_card_b,
            },
            {},
        )

        # First delegation to endpoint A
        result_a = _delegate_to_a2a(
            self=agent,
            agent_response=MockResponseA(),
            task=task,
            original_fn=lambda *args, **kwargs: "fallback",
            context=None,
            tools=None,
            agent_cards={
                "http://endpoint-a.com/": mock_card_a,
                "http://endpoint-b.com/": mock_card_b,
            },
            original_task_description="test",
        )

        assert result_a == "Done by http://endpoint-a.com/"

        # Second delegation to endpoint B
        result_b = _delegate_to_a2a(
            self=agent,
            agent_response=MockResponseB(),
            task=task,
            original_fn=lambda *args, **kwargs: "fallback",
            context=None,
            tools=None,
            agent_cards={
                "http://endpoint-a.com/": mock_card_a,
                "http://endpoint-b.com/": mock_card_b,
            },
            original_task_description="test",
        )

        assert result_b == "Done by http://endpoint-b.com/"

        # Verify that the second delegation used a different (None) task_id
        # The first call should have task_id=None (no prior task_id for endpoint A)
        # The second call should also have task_id=None (no prior task_id for endpoint B)
        assert len(task_ids_used) == 2
        assert task_ids_used[0] is None  # First delegation to endpoint A
        assert task_ids_used[1] is None  # Second delegation to endpoint B (not reusing A's task_id)


@pytest.mark.skipif(not A2A_SDK_INSTALLED, reason="Requires a2a-sdk to be installed")
def test_completed_task_ids_are_not_persisted_for_reuse():
    """Completed task IDs should NOT be persisted for reuse.

    The A2A protocol rejects task IDs that are in terminal state (completed/failed).
    This test verifies that completed task IDs are not stored in task.config
    for future delegations, so each new delegation gets a fresh task_id.
    """
    from crewai.a2a.wrapper import _delegate_to_a2a
    from crewai import Agent, Task

    a2a_config = A2AConfig(
        endpoint="http://test-endpoint.com",
        trust_remote_completion_status=True,
    )

    agent = Agent(
        role="test manager",
        goal="coordinate",
        backstory="test",
        a2a=a2a_config,
    )

    task = Task(description="test", expected_output="test", agent=agent)

    class MockResponse:
        is_a2a = True
        message = "Please help"
        a2a_ids = ["http://test-endpoint.com/"]

    task_ids_used = []

    def mock_execute_a2a_delegation(**kwargs):
        """Track the task_id used for each call."""
        task_ids_used.append(kwargs.get("task_id"))

        # Create a mock message with a task_id
        mock_message = MagicMock()
        mock_message.task_id = "completed-task-id"
        mock_message.context_id = None

        return {
            "status": "completed",
            "result": "Done",
            "history": [mock_message],
        }

    with (
        patch(
            "crewai.a2a.wrapper.execute_a2a_delegation",
            side_effect=mock_execute_a2a_delegation,
        ),
        patch("crewai.a2a.wrapper._fetch_agent_cards_concurrently") as mock_fetch,
    ):
        mock_card = MagicMock()
        mock_card.name = "Test"
        mock_fetch.return_value = ({"http://test-endpoint.com/": mock_card}, {})

        # First delegation
        _delegate_to_a2a(
            self=agent,
            agent_response=MockResponse(),
            task=task,
            original_fn=lambda *args, **kwargs: "fallback",
            context=None,
            tools=None,
            agent_cards={"http://test-endpoint.com/": mock_card},
            original_task_description="test",
        )

        # Verify that completed task IDs are NOT stored in a2a_task_ids_by_endpoint
        # because they can't be reused (A2A protocol rejects terminal state task IDs)
        if task.config is not None:
            a2a_task_ids = task.config.get("a2a_task_ids_by_endpoint", {})
            # The endpoint should NOT have a stored task_id since it completed
            assert "http://test-endpoint.com/" not in a2a_task_ids

        # Second delegation to the SAME endpoint should also get a fresh task_id
        _delegate_to_a2a(
            self=agent,
            agent_response=MockResponse(),
            task=task,
            original_fn=lambda *args, **kwargs: "fallback",
            context=None,
            tools=None,
            agent_cards={"http://test-endpoint.com/": mock_card},
            original_task_description="test",
        )

        # Verify that BOTH calls used None as task_id (fresh task for each)
        # because completed task IDs are not persisted
        assert len(task_ids_used) == 2
        assert task_ids_used[0] is None  # First call - new conversation
        assert task_ids_used[1] is None  # Second call - also new (completed IDs not reused)


@pytest.mark.skipif(not A2A_SDK_INSTALLED, reason="Requires a2a-sdk to be installed")
def test_reference_task_ids_are_tracked_for_completed_tasks():
    """Completed task IDs should be added to reference_task_ids for tracking.

    While completed task IDs can't be reused for new delegations, they should
    still be tracked in reference_task_ids for context/history purposes.
    """
    from crewai.a2a.wrapper import _delegate_to_a2a
    from crewai import Agent, Task

    a2a_config = A2AConfig(
        endpoint="http://test-endpoint.com",
        trust_remote_completion_status=True,
    )

    agent = Agent(
        role="test manager",
        goal="coordinate",
        backstory="test",
        a2a=a2a_config,
    )

    task = Task(description="test", expected_output="test", agent=agent)

    class MockResponse:
        is_a2a = True
        message = "Please help"
        a2a_ids = ["http://test-endpoint.com/"]

    def mock_execute_a2a_delegation(**kwargs):
        mock_message = MagicMock()
        mock_message.task_id = "unique-task-id-123"
        mock_message.context_id = None

        return {
            "status": "completed",
            "result": "Done",
            "history": [mock_message],
        }

    with (
        patch(
            "crewai.a2a.wrapper.execute_a2a_delegation",
            side_effect=mock_execute_a2a_delegation,
        ),
        patch("crewai.a2a.wrapper._fetch_agent_cards_concurrently") as mock_fetch,
    ):
        mock_card = MagicMock()
        mock_card.name = "Test"
        mock_fetch.return_value = ({"http://test-endpoint.com/": mock_card}, {})

        _delegate_to_a2a(
            self=agent,
            agent_response=MockResponse(),
            task=task,
            original_fn=lambda *args, **kwargs: "fallback",
            context=None,
            tools=None,
            agent_cards={"http://test-endpoint.com/": mock_card},
            original_task_description="test",
        )

        # Verify the completed task_id is tracked in reference_task_ids
        assert task.config is not None
        assert "reference_task_ids" in task.config
        assert "unique-task-id-123" in task.config["reference_task_ids"]

lib/crewai/tests/llms/litellm/test_litellm_token_usage.py (new file, 369 lines)
@@ -0,0 +1,369 @@
"""Tests for LiteLLM token usage tracking functionality.

These tests verify that token usage metrics are properly tracked for:
- Non-streaming responses
- Async non-streaming responses
- Async streaming responses

This addresses GitHub issue #4170 where token usage metrics were not being
updated when using litellm with streaming responses and async calls.
"""

from collections.abc import AsyncIterator
from unittest.mock import AsyncMock, MagicMock, patch

import pytest

from crewai.llm import LLM


class MockUsage:
    """Mock usage object that mimics litellm's usage response."""

    def __init__(
        self,
        prompt_tokens: int = 10,
        completion_tokens: int = 20,
        total_tokens: int = 30,
    ):
        self.prompt_tokens = prompt_tokens
        self.completion_tokens = completion_tokens
        self.total_tokens = total_tokens


class MockMessage:
    """Mock message object that mimics litellm's message response."""

    def __init__(self, content: str = "Test response"):
        self.content = content
        self.tool_calls = None


class MockChoice:
    """Mock choice object that mimics litellm's choice response."""

    def __init__(self, content: str = "Test response"):
        self.message = MockMessage(content)


class MockResponse:
    """Mock response object that mimics litellm's completion response."""

    def __init__(
        self,
        content: str = "Test response",
        prompt_tokens: int = 10,
        completion_tokens: int = 20,
    ):
        self.choices = [MockChoice(content)]
        self.usage = MockUsage(
            prompt_tokens=prompt_tokens,
            completion_tokens=completion_tokens,
            total_tokens=prompt_tokens + completion_tokens,
        )


class MockStreamDelta:
    """Mock delta object for streaming responses."""

    def __init__(self, content: str | None = None):
        self.content = content
        self.tool_calls = None


class MockStreamChoice:
    """Mock choice object for streaming responses."""

    def __init__(self, content: str | None = None):
        self.delta = MockStreamDelta(content)


class MockStreamChunk:
    """Mock chunk object for streaming responses."""

    def __init__(
        self,
        content: str | None = None,
        usage: MockUsage | None = None,
    ):
        self.choices = [MockStreamChoice(content)]
        self.usage = usage


def test_non_streaming_response_tracks_token_usage():
    """Test that non-streaming responses properly track token usage."""
    llm = LLM(model="gpt-4o-mini", is_litellm=True, stream=False)

    mock_response = MockResponse(
        content="Hello, world!",
        prompt_tokens=15,
        completion_tokens=25,
    )

    with patch("litellm.completion", return_value=mock_response):
        result = llm.call("Say hello")

    assert result == "Hello, world!"

    # Verify token usage was tracked
    usage_summary = llm.get_token_usage_summary()
    assert usage_summary.prompt_tokens == 15
    assert usage_summary.completion_tokens == 25
    assert usage_summary.total_tokens == 40
    assert usage_summary.successful_requests == 1


def test_non_streaming_response_accumulates_token_usage():
    """Test that multiple non-streaming calls accumulate token usage."""
    llm = LLM(model="gpt-4o-mini", is_litellm=True, stream=False)

    mock_response1 = MockResponse(
        content="First response",
        prompt_tokens=10,
        completion_tokens=20,
    )
    mock_response2 = MockResponse(
        content="Second response",
        prompt_tokens=15,
        completion_tokens=25,
    )

    with patch("litellm.completion") as mock_completion:
        mock_completion.return_value = mock_response1
        llm.call("First call")

        mock_completion.return_value = mock_response2
        llm.call("Second call")

    # Verify accumulated token usage
    usage_summary = llm.get_token_usage_summary()
    assert usage_summary.prompt_tokens == 25  # 10 + 15
    assert usage_summary.completion_tokens == 45  # 20 + 25
    assert usage_summary.total_tokens == 70  # 30 + 40
    assert usage_summary.successful_requests == 2


@pytest.mark.asyncio
async def test_async_non_streaming_response_tracks_token_usage():
    """Test that async non-streaming responses properly track token usage."""
    llm = LLM(model="gpt-4o-mini", is_litellm=True, stream=False)

    mock_response = MockResponse(
        content="Async hello!",
        prompt_tokens=12,
        completion_tokens=18,
    )

    with patch("litellm.acompletion", new_callable=AsyncMock) as mock_acompletion:
        mock_acompletion.return_value = mock_response
        result = await llm.acall("Say hello async")

    assert result == "Async hello!"

    # Verify token usage was tracked
    usage_summary = llm.get_token_usage_summary()
    assert usage_summary.prompt_tokens == 12
    assert usage_summary.completion_tokens == 18
    assert usage_summary.total_tokens == 30
    assert usage_summary.successful_requests == 1


@pytest.mark.asyncio
async def test_async_non_streaming_response_accumulates_token_usage():
    """Test that multiple async non-streaming calls accumulate token usage."""
    llm = LLM(model="gpt-4o-mini", is_litellm=True, stream=False)

    mock_response1 = MockResponse(
        content="First async response",
        prompt_tokens=8,
        completion_tokens=12,
    )
    mock_response2 = MockResponse(
        content="Second async response",
        prompt_tokens=10,
        completion_tokens=15,
    )

    with patch("litellm.acompletion", new_callable=AsyncMock) as mock_acompletion:
        mock_acompletion.return_value = mock_response1
        await llm.acall("First async call")

        mock_acompletion.return_value = mock_response2
        await llm.acall("Second async call")

    # Verify accumulated token usage
    usage_summary = llm.get_token_usage_summary()
    assert usage_summary.prompt_tokens == 18  # 8 + 10
    assert usage_summary.completion_tokens == 27  # 12 + 15
    assert usage_summary.total_tokens == 45  # 20 + 25
    assert usage_summary.successful_requests == 2


@pytest.mark.asyncio
async def test_async_streaming_response_tracks_token_usage():
    """Test that async streaming responses properly track token usage."""
    llm = LLM(model="gpt-4o-mini", is_litellm=True, stream=True)

    # Create mock streaming chunks
    chunks = [
        MockStreamChunk(content="Hello"),
        MockStreamChunk(content=", "),
        MockStreamChunk(content="world"),
        MockStreamChunk(content="!"),
        # Final chunk with usage info (this is how litellm typically sends usage)
        MockStreamChunk(
            content=None,
            usage=MockUsage(prompt_tokens=20, completion_tokens=30, total_tokens=50),
        ),
    ]

    async def mock_async_generator() -> AsyncIterator[MockStreamChunk]:
        for chunk in chunks:
            yield chunk

    with patch("litellm.acompletion", new_callable=AsyncMock) as mock_acompletion:
        mock_acompletion.return_value = mock_async_generator()
        result = await llm.acall("Say hello streaming")

    assert result == "Hello, world!"

    # Verify token usage was tracked
    usage_summary = llm.get_token_usage_summary()
    assert usage_summary.prompt_tokens == 20
    assert usage_summary.completion_tokens == 30
    assert usage_summary.total_tokens == 50
    assert usage_summary.successful_requests == 1


@pytest.mark.asyncio
async def test_async_streaming_response_with_dict_usage():
    """Test that async streaming handles dict-based usage info."""
    llm = LLM(model="gpt-4o-mini", is_litellm=True, stream=True)

    # Create mock streaming chunks using dict format
    class DictStreamChunk:
        def __init__(
            self,
            content: str | None = None,
            usage: dict | None = None,
        ):
            self.choices = [MockStreamChoice(content)]
            # Simulate dict-based usage (some providers return this)
            self._usage = usage

        @property
        def usage(self) -> MockUsage | None:
            if self._usage:
                return MockUsage(**self._usage)
            return None

    chunks = [
        DictStreamChunk(content="Test"),
        DictStreamChunk(content=" response"),
        DictStreamChunk(
            content=None,
            usage={
                "prompt_tokens": 25,
                "completion_tokens": 35,
                "total_tokens": 60,
            },
        ),
    ]

    async def mock_async_generator() -> AsyncIterator[DictStreamChunk]:
        for chunk in chunks:
            yield chunk

    with patch("litellm.acompletion", new_callable=AsyncMock) as mock_acompletion:
        mock_acompletion.return_value = mock_async_generator()
        result = await llm.acall("Test streaming with dict usage")

    assert result == "Test response"

    # Verify token usage was tracked
    usage_summary = llm.get_token_usage_summary()
    assert usage_summary.prompt_tokens == 25
    assert usage_summary.completion_tokens == 35
    assert usage_summary.total_tokens == 60
    assert usage_summary.successful_requests == 1


def test_streaming_response_tracks_token_usage():
    """Test that sync streaming responses properly track token usage."""
    llm = LLM(model="gpt-4o-mini", is_litellm=True, stream=True)

    # Create mock streaming chunks
    chunks = [
        MockStreamChunk(content="Sync"),
        MockStreamChunk(content=" streaming"),
        MockStreamChunk(content=" test"),
        # Final chunk with usage info
        MockStreamChunk(
            content=None,
            usage=MockUsage(prompt_tokens=18, completion_tokens=22, total_tokens=40),
        ),
    ]

    with patch("litellm.completion", return_value=iter(chunks)):
        result = llm.call("Test sync streaming")

    assert result == "Sync streaming test"

    # Verify token usage was tracked
    usage_summary = llm.get_token_usage_summary()
    assert usage_summary.prompt_tokens == 18
    assert usage_summary.completion_tokens == 22
    assert usage_summary.total_tokens == 40
    assert usage_summary.successful_requests == 1


def test_token_usage_with_no_usage_info():
    """Test that token usage tracking handles missing usage info gracefully."""
    llm = LLM(model="gpt-4o-mini", is_litellm=True, stream=False)

    # Create mock response without usage info
    mock_response = MagicMock()
    mock_response.choices = [MockChoice("Response without usage")]
    mock_response.usage = None

    with patch("litellm.completion", return_value=mock_response):
        result = llm.call("Test without usage")

    assert result == "Response without usage"

    # Verify token usage remains at zero
    usage_summary = llm.get_token_usage_summary()
    assert usage_summary.prompt_tokens == 0
    assert usage_summary.completion_tokens == 0
    assert usage_summary.total_tokens == 0
    assert usage_summary.successful_requests == 0


@pytest.mark.asyncio
async def test_async_streaming_with_no_usage_info():
    """Test that async streaming handles missing usage info gracefully."""
    llm = LLM(model="gpt-4o-mini", is_litellm=True, stream=True)

    # Create mock streaming chunks without usage info
    chunks = [
        MockStreamChunk(content="No"),
        MockStreamChunk(content=" usage"),
        MockStreamChunk(content=" info"),
    ]

    async def mock_async_generator() -> AsyncIterator[MockStreamChunk]:
        for chunk in chunks:
            yield chunk

    with patch("litellm.acompletion", new_callable=AsyncMock) as mock_acompletion:
        mock_acompletion.return_value = mock_async_generator()
        result = await llm.acall("Test without usage info")

    assert result == "No usage info"

    # Verify token usage remains at zero
    usage_summary = llm.get_token_usage_summary()
    assert usage_summary.prompt_tokens == 0
    assert usage_summary.completion_tokens == 0
    assert usage_summary.total_tokens == 0
    assert usage_summary.successful_requests == 0