mirror of
https://github.com/crewAIInc/crewAI.git
synced 2026-07-02 13:48:09 +00:00
fix: surface Anthropic stop_reason to detect truncation (#5148)
- Add stop_reason field to LLMCallCompletedEvent - Update BaseLLM._emit_call_completed_event to accept and pass stop_reason - Add _warn_if_truncated helper to AnthropicCompletion that logs a warning when stop_reason='max_tokens' - Apply fix to all 6 Anthropic completion methods (sync and async): _handle_completion, _handle_streaming_completion, _handle_tool_use_conversation, _ahandle_completion, _ahandle_streaming_completion, _ahandle_tool_use_conversation - Add 7 tests covering truncation warning, event field, and tool use paths Co-Authored-By: João <joao@crewai.com>
This commit is contained in:
@@ -57,6 +57,7 @@ class LLMCallCompletedEvent(LLMEventBase):
|
||||
messages: str | list[dict[str, Any]] | None = None
|
||||
response: Any
|
||||
call_type: LLMCallType
|
||||
stop_reason: str | None = None
|
||||
|
||||
|
||||
class LLMCallFailedEvent(LLMEventBase):
|
||||
|
||||
@@ -412,6 +412,7 @@ class BaseLLM(ABC):
|
||||
from_task: Task | None = None,
|
||||
from_agent: Agent | None = None,
|
||||
messages: str | list[LLMMessage] | None = None,
|
||||
stop_reason: str | None = None,
|
||||
) -> None:
|
||||
"""Emit LLM call completed event."""
|
||||
from crewai.utilities.serialization import to_serializable
|
||||
@@ -426,6 +427,7 @@ class BaseLLM(ABC):
|
||||
from_agent=from_agent,
|
||||
model=self.model,
|
||||
call_id=get_current_call_id(),
|
||||
stop_reason=stop_reason,
|
||||
),
|
||||
)
|
||||
|
||||
|
||||
@@ -676,6 +676,20 @@ class AnthropicCompletion(BaseLLM):
|
||||
|
||||
return converted
|
||||
|
||||
def _warn_if_truncated(
    self,
    response: Message | BetaMessage,
    from_agent: Any | None = None,
) -> None:
    """Log a warning if the response was truncated due to max_tokens.

    Args:
        response: Anthropic message whose ``stop_reason`` is inspected. A
            response without that attribute is treated as not truncated.
        from_agent: Optional agent; when it exposes a non-empty ``role``
            the role is included in the warning to attribute the call.
    """
    if getattr(response, "stop_reason", None) != "max_tokens":
        return

    # This helper is purely diagnostic and must never raise: tolerate agent
    # objects that lack ``role`` (or have an empty one) instead of failing
    # the surrounding completion call with AttributeError.
    role = getattr(from_agent, "role", None)
    agent_hint = f" [{role}]" if role else ""
    logging.warning(
        f"Truncated response{agent_hint}: stop_reason='max_tokens'. "
        f"Consider increasing max_tokens (current: {self.max_tokens})."
    )
|
||||
|
||||
def _format_messages_for_anthropic(
|
||||
self, messages: str | list[LLMMessage]
|
||||
) -> tuple[list[LLMMessage], str | None]:
|
||||
@@ -858,6 +872,9 @@ class AnthropicCompletion(BaseLLM):
|
||||
usage = self._extract_anthropic_token_usage(response)
|
||||
self._track_token_usage_internal(usage)
|
||||
|
||||
stop_reason = getattr(response, "stop_reason", None)
|
||||
self._warn_if_truncated(response, from_agent)
|
||||
|
||||
if _is_pydantic_model_class(response_model) and response.content:
|
||||
if use_native_structured_output:
|
||||
for block in response.content:
|
||||
@@ -869,6 +886,7 @@ class AnthropicCompletion(BaseLLM):
|
||||
from_task=from_task,
|
||||
from_agent=from_agent,
|
||||
messages=params["messages"],
|
||||
stop_reason=stop_reason,
|
||||
)
|
||||
return structured_data
|
||||
else:
|
||||
@@ -884,6 +902,7 @@ class AnthropicCompletion(BaseLLM):
|
||||
from_task=from_task,
|
||||
from_agent=from_agent,
|
||||
messages=params["messages"],
|
||||
stop_reason=stop_reason,
|
||||
)
|
||||
return structured_data
|
||||
|
||||
@@ -906,6 +925,7 @@ class AnthropicCompletion(BaseLLM):
|
||||
from_task=from_task,
|
||||
from_agent=from_agent,
|
||||
messages=params["messages"],
|
||||
stop_reason=stop_reason,
|
||||
)
|
||||
return list(tool_uses)
|
||||
|
||||
@@ -937,6 +957,7 @@ class AnthropicCompletion(BaseLLM):
|
||||
from_task=from_task,
|
||||
from_agent=from_agent,
|
||||
messages=params["messages"],
|
||||
stop_reason=stop_reason,
|
||||
)
|
||||
|
||||
if usage.get("total_tokens", 0) > 0:
|
||||
@@ -1077,6 +1098,9 @@ class AnthropicCompletion(BaseLLM):
|
||||
usage = self._extract_anthropic_token_usage(final_message)
|
||||
self._track_token_usage_internal(usage)
|
||||
|
||||
stop_reason = getattr(final_message, "stop_reason", None)
|
||||
self._warn_if_truncated(final_message, from_agent)
|
||||
|
||||
if _is_pydantic_model_class(response_model):
|
||||
if use_native_structured_output:
|
||||
structured_data = response_model.model_validate_json(full_response)
|
||||
@@ -1086,6 +1110,7 @@ class AnthropicCompletion(BaseLLM):
|
||||
from_task=from_task,
|
||||
from_agent=from_agent,
|
||||
messages=params["messages"],
|
||||
stop_reason=stop_reason,
|
||||
)
|
||||
return structured_data
|
||||
for block in final_message.content:
|
||||
@@ -1100,6 +1125,7 @@ class AnthropicCompletion(BaseLLM):
|
||||
from_task=from_task,
|
||||
from_agent=from_agent,
|
||||
messages=params["messages"],
|
||||
stop_reason=stop_reason,
|
||||
)
|
||||
return structured_data
|
||||
|
||||
@@ -1129,6 +1155,7 @@ class AnthropicCompletion(BaseLLM):
|
||||
from_task=from_task,
|
||||
from_agent=from_agent,
|
||||
messages=params["messages"],
|
||||
stop_reason=stop_reason,
|
||||
)
|
||||
|
||||
return self._invoke_after_llm_call_hooks(
|
||||
@@ -1275,6 +1302,9 @@ class AnthropicCompletion(BaseLLM):
|
||||
follow_up_usage = self._extract_anthropic_token_usage(final_response)
|
||||
self._track_token_usage_internal(follow_up_usage)
|
||||
|
||||
stop_reason = getattr(final_response, "stop_reason", None)
|
||||
self._warn_if_truncated(final_response, from_agent)
|
||||
|
||||
final_content = ""
|
||||
thinking_blocks: list[ThinkingBlock] = []
|
||||
|
||||
@@ -1299,6 +1329,7 @@ class AnthropicCompletion(BaseLLM):
|
||||
from_task=from_task,
|
||||
from_agent=from_agent,
|
||||
messages=follow_up_params["messages"],
|
||||
stop_reason=stop_reason,
|
||||
)
|
||||
|
||||
# Log combined token usage
|
||||
@@ -1379,6 +1410,9 @@ class AnthropicCompletion(BaseLLM):
|
||||
usage = self._extract_anthropic_token_usage(response)
|
||||
self._track_token_usage_internal(usage)
|
||||
|
||||
stop_reason = getattr(response, "stop_reason", None)
|
||||
self._warn_if_truncated(response, from_agent)
|
||||
|
||||
if _is_pydantic_model_class(response_model) and response.content:
|
||||
if use_native_structured_output:
|
||||
for block in response.content:
|
||||
@@ -1390,6 +1424,7 @@ class AnthropicCompletion(BaseLLM):
|
||||
from_task=from_task,
|
||||
from_agent=from_agent,
|
||||
messages=params["messages"],
|
||||
stop_reason=stop_reason,
|
||||
)
|
||||
return structured_data
|
||||
else:
|
||||
@@ -1405,6 +1440,7 @@ class AnthropicCompletion(BaseLLM):
|
||||
from_task=from_task,
|
||||
from_agent=from_agent,
|
||||
messages=params["messages"],
|
||||
stop_reason=stop_reason,
|
||||
)
|
||||
return structured_data
|
||||
|
||||
@@ -1425,6 +1461,7 @@ class AnthropicCompletion(BaseLLM):
|
||||
from_task=from_task,
|
||||
from_agent=from_agent,
|
||||
messages=params["messages"],
|
||||
stop_reason=stop_reason,
|
||||
)
|
||||
return list(tool_uses)
|
||||
|
||||
@@ -1448,6 +1485,7 @@ class AnthropicCompletion(BaseLLM):
|
||||
from_task=from_task,
|
||||
from_agent=from_agent,
|
||||
messages=params["messages"],
|
||||
stop_reason=stop_reason,
|
||||
)
|
||||
|
||||
if usage.get("total_tokens", 0) > 0:
|
||||
@@ -1576,6 +1614,9 @@ class AnthropicCompletion(BaseLLM):
|
||||
usage = self._extract_anthropic_token_usage(final_message)
|
||||
self._track_token_usage_internal(usage)
|
||||
|
||||
stop_reason = getattr(final_message, "stop_reason", None)
|
||||
self._warn_if_truncated(final_message, from_agent)
|
||||
|
||||
if _is_pydantic_model_class(response_model):
|
||||
if use_native_structured_output:
|
||||
structured_data = response_model.model_validate_json(full_response)
|
||||
@@ -1585,6 +1626,7 @@ class AnthropicCompletion(BaseLLM):
|
||||
from_task=from_task,
|
||||
from_agent=from_agent,
|
||||
messages=params["messages"],
|
||||
stop_reason=stop_reason,
|
||||
)
|
||||
return structured_data
|
||||
for block in final_message.content:
|
||||
@@ -1599,6 +1641,7 @@ class AnthropicCompletion(BaseLLM):
|
||||
from_task=from_task,
|
||||
from_agent=from_agent,
|
||||
messages=params["messages"],
|
||||
stop_reason=stop_reason,
|
||||
)
|
||||
return structured_data
|
||||
|
||||
@@ -1627,6 +1670,7 @@ class AnthropicCompletion(BaseLLM):
|
||||
from_task=from_task,
|
||||
from_agent=from_agent,
|
||||
messages=params["messages"],
|
||||
stop_reason=stop_reason,
|
||||
)
|
||||
|
||||
return full_response
|
||||
@@ -1671,6 +1715,9 @@ class AnthropicCompletion(BaseLLM):
|
||||
follow_up_usage = self._extract_anthropic_token_usage(final_response)
|
||||
self._track_token_usage_internal(follow_up_usage)
|
||||
|
||||
stop_reason = getattr(final_response, "stop_reason", None)
|
||||
self._warn_if_truncated(final_response, from_agent)
|
||||
|
||||
final_content = ""
|
||||
if final_response.content:
|
||||
for content_block in final_response.content:
|
||||
@@ -1685,6 +1732,7 @@ class AnthropicCompletion(BaseLLM):
|
||||
from_task=from_task,
|
||||
from_agent=from_agent,
|
||||
messages=follow_up_params["messages"],
|
||||
stop_reason=stop_reason,
|
||||
)
|
||||
|
||||
total_usage = {
|
||||
|
||||
@@ -1463,3 +1463,217 @@ def test_tool_search_saves_input_tokens():
|
||||
f"Expected tool_search ({usage_search.prompt_tokens}) to use fewer input tokens "
|
||||
f"than no search ({usage_no_search.prompt_tokens})"
|
||||
)
|
||||
|
||||
|
||||
def test_anthropic_warns_on_max_tokens_truncation():
    """Test that a warning is logged when Anthropic response has stop_reason='max_tokens'."""
    llm = LLM(model="anthropic/claude-3-5-sonnet-20241022")

    with patch.object(llm.client.messages, "create") as mock_create:
        # Fake Anthropic message that stopped because it hit the token limit.
        mock_response = MagicMock()
        mock_response.content = [MagicMock(text="truncated response")]
        mock_response.content[0].type = "text"
        mock_response.stop_reason = "max_tokens"
        mock_response.usage = MagicMock(
            input_tokens=100, output_tokens=4096, cache_read_input_tokens=0
        )
        mock_create.return_value = mock_response

        # The provider module's ``logging`` is replaced by its import path, so
        # no local ``import logging`` is needed in this test.
        with patch("crewai.llms.providers.anthropic.completion.logging") as mock_logging:
            result = llm.call("Tell me a very long story")

        assert result == "truncated response"
        mock_logging.warning.assert_called_once()
        warning_msg = mock_logging.warning.call_args[0][0]
        assert "stop_reason='max_tokens'" in warning_msg
        assert "Consider increasing max_tokens" in warning_msg
|
||||
|
||||
|
||||
def test_anthropic_no_warning_on_end_turn():
    """A response that finishes with stop_reason='end_turn' logs no truncation warning."""
    llm = LLM(model="anthropic/claude-3-5-sonnet-20241022")

    # Fake Anthropic message that completed normally.
    text_block = MagicMock(text="complete response")
    text_block.type = "text"
    completed = MagicMock()
    completed.content = [text_block]
    completed.stop_reason = "end_turn"
    completed.usage = MagicMock(
        input_tokens=50, output_tokens=25, cache_read_input_tokens=0
    )

    with patch.object(llm.client.messages, "create", return_value=completed):
        with patch("crewai.llms.providers.anthropic.completion.logging") as mock_logging:
            result = llm.call("Hello")

    assert result == "complete response"
    mock_logging.warning.assert_not_called()
|
||||
|
||||
|
||||
def test_anthropic_truncation_warning_includes_agent_role():
    """The truncation warning names the calling agent's role when an agent is passed."""
    llm = LLM(model="anthropic/claude-3-5-sonnet-20241022")

    analyst = MagicMock()
    analyst.role = "Research Analyst"

    # Fake Anthropic message cut off at the token limit.
    text_block = MagicMock(text="truncated")
    text_block.type = "text"
    cut_off = MagicMock()
    cut_off.content = [text_block]
    cut_off.stop_reason = "max_tokens"
    cut_off.usage = MagicMock(
        input_tokens=100, output_tokens=4096, cache_read_input_tokens=0
    )

    with patch.object(llm.client.messages, "create", return_value=cut_off):
        with patch("crewai.llms.providers.anthropic.completion.logging") as mock_logging:
            llm.call("Tell me everything", from_agent=analyst)

    mock_logging.warning.assert_called_once()
    (logged_text, *_), _ = mock_logging.warning.call_args
    assert "[Research Analyst]" in logged_text
    assert "stop_reason='max_tokens'" in logged_text
|
||||
|
||||
|
||||
def test_anthropic_stop_reason_in_completed_event():
    """The emitted LLMCallCompletedEvent carries the response's stop_reason."""
    from crewai.events.types.llm_events import LLMCallCompletedEvent
    from crewai.events import crewai_event_bus

    llm = LLM(model="anthropic/claude-3-5-sonnet-20241022")

    completed_events = []

    def record_completed(source, event):
        # Keep only completion events; anything else on the bus is ignored.
        if isinstance(event, LLMCallCompletedEvent):
            completed_events.append(event)

    # Fake Anthropic message cut off at the token limit.
    text_block = MagicMock(text="truncated response")
    text_block.type = "text"
    cut_off = MagicMock()
    cut_off.content = [text_block]
    cut_off.stop_reason = "max_tokens"
    cut_off.usage = MagicMock(
        input_tokens=100, output_tokens=4096, cache_read_input_tokens=0
    )

    crewai_event_bus.register_handler(LLMCallCompletedEvent, record_completed)

    try:
        with patch.object(llm.client.messages, "create", return_value=cut_off):
            llm.call("Tell me a long story")

        # Wait for pending handlers before inspecting what was captured.
        crewai_event_bus.flush(timeout=5.0)

        assert len(completed_events) >= 1
        latest = completed_events[-1]
        assert latest.stop_reason == "max_tokens"
    finally:
        # Always unregister so the handler does not leak into other tests.
        crewai_event_bus.off(LLMCallCompletedEvent, record_completed)
|
||||
|
||||
|
||||
def test_anthropic_stop_reason_none_when_normal_completion():
    """Test that the completed event reports stop_reason='end_turn' on normal completion.

    Despite the test's name, the event's stop_reason is not None here: the
    value set on the mocked response ("end_turn") is what the assertion
    expects on the emitted event.
    """
    from crewai.events.types.llm_events import LLMCallCompletedEvent
    from crewai.events import crewai_event_bus

    llm = LLM(model="anthropic/claude-3-5-sonnet-20241022")

    captured_events = []

    def capture_event(source, event):
        # Keep only completion events; anything else on the bus is ignored.
        if isinstance(event, LLMCallCompletedEvent):
            captured_events.append(event)

    crewai_event_bus.register_handler(LLMCallCompletedEvent, capture_event)

    try:
        with patch.object(llm.client.messages, "create") as mock_create:
            # Fake Anthropic message that completed normally.
            mock_response = MagicMock()
            mock_response.content = [MagicMock(text="complete response")]
            mock_response.content[0].type = "text"
            mock_response.stop_reason = "end_turn"
            mock_response.usage = MagicMock(
                input_tokens=50, output_tokens=25, cache_read_input_tokens=0
            )
            mock_create.return_value = mock_response

            llm.call("Hello")

        # Wait for pending handlers before inspecting what was captured.
        crewai_event_bus.flush(timeout=5.0)

        assert len(captured_events) >= 1
        event = captured_events[-1]
        assert event.stop_reason == "end_turn"
    finally:
        # Always unregister so the handler does not leak into other tests.
        crewai_event_bus.off(LLMCallCompletedEvent, capture_event)
|
||||
|
||||
|
||||
def test_anthropic_tool_use_conversation_warns_on_truncation():
    """_handle_tool_use_conversation warns when the follow-up response hits max_tokens."""
    from anthropic.types import ToolUseBlock, TextBlock

    llm = LLM(model="anthropic/claude-3-5-sonnet-20241022")

    # First turn: the model asks for a tool call.
    tool_request = MagicMock(spec=ToolUseBlock)
    tool_request.type = "tool_use"
    tool_request.id = "tool_123"
    tool_request.name = "get_data"
    tool_request.input = {"query": "test"}
    first_turn = MagicMock()
    first_turn.content = [tool_request]

    # Follow-up turn: the text answer is cut off at the token limit.
    answer_block = MagicMock(spec=TextBlock)
    answer_block.text = "truncated tool result"
    answer_block.type = "text"
    follow_up = MagicMock()
    follow_up.content = [answer_block]
    follow_up.stop_reason = "max_tokens"
    follow_up.usage = MagicMock(
        input_tokens=200, output_tokens=4096, cache_read_input_tokens=0
    )

    def mock_tool_fn(**kwargs):
        return "tool output"

    request_params = {
        "messages": [{"role": "user", "content": "test"}],
        "model": "claude-3-5-sonnet-20241022",
        "max_tokens": 4096,
    }

    with patch.object(llm.client.messages, "create", return_value=follow_up):
        with patch("crewai.llms.providers.anthropic.completion.logging") as mock_logging:
            result = llm._handle_tool_use_conversation(
                initial_response=first_turn,
                tool_uses=[tool_request],
                params=request_params,
                available_functions={"get_data": mock_tool_fn},
            )

    assert result == "truncated tool result"
    mock_logging.warning.assert_called_once()
    logged_text = mock_logging.warning.call_args[0][0]
    assert "stop_reason='max_tokens'" in logged_text
|
||||
|
||||
|
||||
def test_llm_call_completed_event_has_stop_reason_field():
    """LLMCallCompletedEvent.stop_reason defaults to None and accepts an explicit value."""
    from crewai.events.types.llm_events import LLMCallCompletedEvent, LLMCallType

    # Arguments shared by both events; only stop_reason differs between them.
    common_fields = {
        "response": "test",
        "call_type": LLMCallType.LLM_CALL,
        "call_id": "test-id",
    }

    # Omitting stop_reason leaves it at its default.
    default_event = LLMCallCompletedEvent(**common_fields)
    assert default_event.stop_reason is None

    # An explicit stop_reason is stored as given.
    truncated_event = LLMCallCompletedEvent(**common_fields, stop_reason="max_tokens")
    assert truncated_event.stop_reason == "max_tokens"
|
||||
|
||||
Reference in New Issue
Block a user