mirror of
https://github.com/crewAIInc/crewAI.git
synced 2026-07-01 13:18:10 +00:00
fix(bedrock): propagate finish_reason + response_id on async paths
The original commit covered every provider's sync path and Bedrock's sync streaming path, but two Bedrock async paths still emitted LLMCallCompletedEvent without finish_reason/response_id: - _ahandle_converse: the final fallback emit_call_completed_event call was missing both fields. Added stop_reason + response_id matching the other emission sites in the same function. - _ahandle_streaming_converse: response_id was never seeded from the initial response object, and stream_finish_reason wasn't propagated to the structured-output and final-text emissions. Now extracts response_id up front and threads stream_finish_reason through every completion event. Adds a dedicated test file covering the new event fields end-to-end: - LLMCallCompletedEvent.finish_reason / response_id Pydantic validation (string accepted, None default, non-string coerced to None). - LLMCallStartedEvent sampling params (all nine fields accepted, default to None). - BaseLLM._emit_call_started_event introspecting sampling params off self, with explicit kwargs overriding. - BaseLLM._emit_call_completed_event passing finish_reason/response_id through to the event. - LLM._extract_finish_reason_and_response_id across the LiteLLM shapes (non-streaming response, streaming chunk, dict, missing fields, non-string values, unexpected input).
This commit is contained in:
@@ -1414,6 +1414,8 @@ class BedrockCompletion(BaseLLM):
|
||||
from_agent=from_agent,
|
||||
messages=messages,
|
||||
usage=usage,
|
||||
finish_reason=stop_reason,
|
||||
response_id=response_id,
|
||||
)
|
||||
|
||||
return text_content
|
||||
@@ -1548,7 +1550,9 @@ class BedrockCompletion(BaseLLM):
|
||||
)
|
||||
|
||||
stream = response.get("stream")
|
||||
response_id = None
|
||||
_, stream_response_id = self._extract_finish_reason_and_id(response)
|
||||
response_id = stream_response_id
|
||||
stream_finish_reason: str | None = None
|
||||
if stream:
|
||||
async for event in stream:
|
||||
if "messageStart" in event:
|
||||
@@ -1647,6 +1651,8 @@ class BedrockCompletion(BaseLLM):
|
||||
from_agent=from_agent,
|
||||
messages=messages,
|
||||
usage=usage_data,
|
||||
finish_reason=stream_finish_reason,
|
||||
response_id=response_id,
|
||||
)
|
||||
return result # type: ignore[return-value]
|
||||
except Exception as e:
|
||||
@@ -1704,6 +1710,7 @@ class BedrockCompletion(BaseLLM):
|
||||
|
||||
elif "messageStop" in event:
|
||||
stop_reason = event["messageStop"].get("stopReason")
|
||||
stream_finish_reason = stop_reason
|
||||
logging.debug(f"Streaming message stopped: {stop_reason}")
|
||||
if stop_reason == "max_tokens":
|
||||
logging.warning(
|
||||
@@ -1750,6 +1757,8 @@ class BedrockCompletion(BaseLLM):
|
||||
from_agent=from_agent,
|
||||
messages=messages,
|
||||
usage=usage_data,
|
||||
finish_reason=stream_finish_reason,
|
||||
response_id=response_id,
|
||||
)
|
||||
|
||||
return self._invoke_after_llm_call_hooks(
|
||||
|
||||
236
lib/crewai/tests/events/test_llm_finish_reason_response_id.py
Normal file
236
lib/crewai/tests/events/test_llm_finish_reason_response_id.py
Normal file
@@ -0,0 +1,236 @@
|
||||
from types import SimpleNamespace
|
||||
from typing import Any
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
import pytest
|
||||
|
||||
from crewai.events.event_bus import CrewAIEventsBus
|
||||
from crewai.events.types.llm_events import (
|
||||
LLMCallCompletedEvent,
|
||||
LLMCallStartedEvent,
|
||||
LLMCallType,
|
||||
)
|
||||
from crewai.llm import LLM
|
||||
from crewai.llms.base_llm import BaseLLM
|
||||
|
||||
|
||||
class _StubLLM(BaseLLM):
|
||||
model: str = "test-model"
|
||||
|
||||
def call(self, *args: Any, **kwargs: Any) -> str:
|
||||
return ""
|
||||
|
||||
async def acall(self, *args: Any, **kwargs: Any) -> str:
|
||||
return ""
|
||||
|
||||
def supports_function_calling(self) -> bool:
|
||||
return False
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_emit():
|
||||
with patch.object(CrewAIEventsBus, "emit") as mock:
|
||||
yield mock
|
||||
|
||||
|
||||
class TestLLMCallCompletedEventFinishReasonAndResponseId:
|
||||
def test_accepts_string_values(self):
|
||||
event = LLMCallCompletedEvent(
|
||||
response="hi",
|
||||
call_type=LLMCallType.LLM_CALL,
|
||||
call_id="call-1",
|
||||
finish_reason="stop",
|
||||
response_id="resp_123",
|
||||
)
|
||||
assert event.finish_reason == "stop"
|
||||
assert event.response_id == "resp_123"
|
||||
|
||||
def test_defaults_to_none(self):
|
||||
event = LLMCallCompletedEvent(
|
||||
response="hi",
|
||||
call_type=LLMCallType.LLM_CALL,
|
||||
call_id="call-1",
|
||||
)
|
||||
assert event.finish_reason is None
|
||||
assert event.response_id is None
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"value",
|
||||
[MagicMock(), 42, 1.5, ["stop"], {"reason": "stop"}, object()],
|
||||
)
|
||||
def test_coerces_non_string_to_none(self, value):
|
||||
event = LLMCallCompletedEvent(
|
||||
response="hi",
|
||||
call_type=LLMCallType.LLM_CALL,
|
||||
call_id="call-1",
|
||||
finish_reason=value,
|
||||
response_id=value,
|
||||
)
|
||||
assert event.finish_reason is None
|
||||
assert event.response_id is None
|
||||
|
||||
|
||||
class TestLLMCallStartedEventSamplingParams:
|
||||
def test_accepts_all_sampling_params(self):
|
||||
event = LLMCallStartedEvent(
|
||||
call_id="call-1",
|
||||
temperature=0.7,
|
||||
top_p=0.9,
|
||||
max_tokens=512,
|
||||
stream=True,
|
||||
seed=42,
|
||||
stop_sequences=["END"],
|
||||
frequency_penalty=0.1,
|
||||
presence_penalty=0.2,
|
||||
n=3,
|
||||
)
|
||||
assert event.temperature == 0.7
|
||||
assert event.top_p == 0.9
|
||||
assert event.max_tokens == 512
|
||||
assert event.stream is True
|
||||
assert event.seed == 42
|
||||
assert event.stop_sequences == ["END"]
|
||||
assert event.frequency_penalty == 0.1
|
||||
assert event.presence_penalty == 0.2
|
||||
assert event.n == 3
|
||||
|
||||
def test_all_sampling_params_default_to_none(self):
|
||||
event = LLMCallStartedEvent(call_id="call-1")
|
||||
assert event.temperature is None
|
||||
assert event.top_p is None
|
||||
assert event.max_tokens is None
|
||||
assert event.stream is None
|
||||
assert event.seed is None
|
||||
assert event.stop_sequences is None
|
||||
assert event.frequency_penalty is None
|
||||
assert event.presence_penalty is None
|
||||
assert event.n is None
|
||||
|
||||
|
||||
class TestEmitCallStartedEventIntrospectsSamplingParams:
|
||||
def test_reads_sampling_params_off_self(self, mock_emit):
|
||||
llm = _StubLLM(model="test-model", temperature=0.4)
|
||||
llm.top_p = 0.8
|
||||
llm.max_tokens = 256
|
||||
llm.stream = False
|
||||
llm.seed = 7
|
||||
llm.frequency_penalty = 0.5
|
||||
llm.presence_penalty = 0.6
|
||||
llm.n = 2
|
||||
llm.stop = ["STOP"]
|
||||
|
||||
llm._emit_call_started_event(messages="hi")
|
||||
|
||||
event = mock_emit.call_args[1]["event"]
|
||||
assert isinstance(event, LLMCallStartedEvent)
|
||||
assert event.temperature == 0.4
|
||||
assert event.top_p == 0.8
|
||||
assert event.max_tokens == 256
|
||||
assert event.stream is False
|
||||
assert event.seed == 7
|
||||
assert event.stop_sequences == ["STOP"]
|
||||
assert event.frequency_penalty == 0.5
|
||||
assert event.presence_penalty == 0.6
|
||||
assert event.n == 2
|
||||
|
||||
def test_explicit_kwargs_override_introspection(self, mock_emit):
|
||||
llm = _StubLLM(model="test-model", temperature=0.4)
|
||||
|
||||
llm._emit_call_started_event(messages="hi", temperature=0.9)
|
||||
|
||||
event = mock_emit.call_args[1]["event"]
|
||||
assert event.temperature == 0.9
|
||||
|
||||
|
||||
class TestEmitCallCompletedEventPassesFinishReasonAndResponseId:
|
||||
def test_passes_through_to_event(self, mock_emit):
|
||||
llm = _StubLLM(model="test-model")
|
||||
|
||||
llm._emit_call_completed_event(
|
||||
response="hi",
|
||||
call_type=LLMCallType.LLM_CALL,
|
||||
finish_reason="stop",
|
||||
response_id="resp_123",
|
||||
)
|
||||
|
||||
event = mock_emit.call_args[1]["event"]
|
||||
assert isinstance(event, LLMCallCompletedEvent)
|
||||
assert event.finish_reason == "stop"
|
||||
assert event.response_id == "resp_123"
|
||||
|
||||
def test_omitted_defaults_to_none(self, mock_emit):
|
||||
llm = _StubLLM(model="test-model")
|
||||
|
||||
llm._emit_call_completed_event(
|
||||
response="hi",
|
||||
call_type=LLMCallType.LLM_CALL,
|
||||
)
|
||||
|
||||
event = mock_emit.call_args[1]["event"]
|
||||
assert event.finish_reason is None
|
||||
assert event.response_id is None
|
||||
|
||||
|
||||
class TestLLMExtractFinishReasonAndResponseId:
|
||||
def test_non_streaming_litellm_shape(self):
|
||||
response = SimpleNamespace(
|
||||
id="chatcmpl-abc",
|
||||
choices=[SimpleNamespace(finish_reason="stop", message=SimpleNamespace())],
|
||||
)
|
||||
|
||||
finish_reason, response_id = LLM._extract_finish_reason_and_response_id(
|
||||
response
|
||||
)
|
||||
|
||||
assert finish_reason == "stop"
|
||||
assert response_id == "chatcmpl-abc"
|
||||
|
||||
def test_streaming_litellm_chunk_shape(self):
|
||||
last_chunk = SimpleNamespace(
|
||||
id="chatcmpl-stream-xyz",
|
||||
choices=[SimpleNamespace(finish_reason="tool_calls", delta=SimpleNamespace())],
|
||||
)
|
||||
|
||||
finish_reason, response_id = LLM._extract_finish_reason_and_response_id(
|
||||
last_chunk
|
||||
)
|
||||
|
||||
assert finish_reason == "tool_calls"
|
||||
assert response_id == "chatcmpl-stream-xyz"
|
||||
|
||||
def test_dict_shape(self):
|
||||
chunk = {
|
||||
"id": "chatcmpl-dict",
|
||||
"choices": [{"finish_reason": "length", "delta": {}}],
|
||||
}
|
||||
|
||||
finish_reason, response_id = LLM._extract_finish_reason_and_response_id(chunk)
|
||||
|
||||
assert finish_reason == "length"
|
||||
assert response_id == "chatcmpl-dict"
|
||||
|
||||
def test_missing_fields_return_none(self):
|
||||
finish_reason, response_id = LLM._extract_finish_reason_and_response_id(
|
||||
SimpleNamespace()
|
||||
)
|
||||
|
||||
assert finish_reason is None
|
||||
assert response_id is None
|
||||
|
||||
def test_non_string_values_coerced_to_none(self):
|
||||
response = SimpleNamespace(
|
||||
id=12345,
|
||||
choices=[SimpleNamespace(finish_reason=MagicMock(), delta=SimpleNamespace())],
|
||||
)
|
||||
|
||||
finish_reason, response_id = LLM._extract_finish_reason_and_response_id(
|
||||
response
|
||||
)
|
||||
|
||||
assert finish_reason is None
|
||||
assert response_id is None
|
||||
|
||||
def test_never_raises_on_unexpected_input(self):
|
||||
assert LLM._extract_finish_reason_and_response_id(None) == (None, None)
|
||||
assert LLM._extract_finish_reason_and_response_id(42) == (None, None)
|
||||
assert LLM._extract_finish_reason_and_response_id("string") == (None, None)
|
||||
Reference in New Issue
Block a user