feat: add response_id to streaming responses

Author: Vidit Ostwal
Date: 2026-01-26 14:50:04 +05:30
Committed by: GitHub
Parent: 3d771f03fa
Commit: 06a58e463c

10 changed files with 74 additions and 3 deletions

View File

@@ -84,3 +84,4 @@ class LLMStreamChunkEvent(LLMEventBase):
     chunk: str
     tool_call: ToolCall | None = None
     call_type: LLMCallType | None = None
+    response_id: str | None = None

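With this field on the event, a bus subscriber can group interleaved chunks by the response they belong to. A minimal sketch, assuming crewai's decorator-style `crewai_event_bus.on(...)` registration; the handler name and accumulator are illustrative, and import paths may differ across crewai versions:

    from collections import defaultdict

    from crewai.events import crewai_event_bus  # import path varies by crewai version
    from crewai.events import LLMStreamChunkEvent  # assumed re-export; may live in a submodule

    # Illustrative accumulator: chunk text keyed by response_id.
    chunks_by_response: dict[str | None, list[str]] = defaultdict(list)

    @crewai_event_bus.on(LLMStreamChunkEvent)
    def collect_chunk(source, event: LLMStreamChunkEvent) -> None:
        # Chunks of the same LLM response share a response_id, so
        # concurrent streams can be demultiplexed by key.
        chunks_by_response[event.response_id].append(event.chunk)
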
View File

@@ -768,6 +768,10 @@ class LLM(BaseLLM):
             # Extract content from the chunk
             chunk_content = None

+            response_id = None
+            if hasattr(chunk, "id"):
+                response_id = chunk.id
+
             # Safely extract content from various chunk formats
             try:
@@ -823,6 +827,7 @@ class LLM(BaseLLM):
                     available_functions=available_functions,
                     from_task=from_task,
                     from_agent=from_agent,
+                    response_id=response_id,
                 )

                 if result is not None:
@@ -844,6 +849,7 @@ class LLM(BaseLLM):
                         from_task=from_task,
                         from_agent=from_agent,
                         call_type=LLMCallType.LLM_CALL,
+                        response_id=response_id,
                     ),
                 )
         # --- 4) Fallback to non-streaming if no content received
@@ -1021,6 +1027,7 @@ class LLM(BaseLLM):
         available_functions: dict[str, Any] | None = None,
         from_task: Task | None = None,
         from_agent: Agent | None = None,
+        response_id: str | None = None,
     ) -> Any:
         for tool_call in tool_calls:
             current_tool_accumulator = accumulated_tool_args[tool_call.index]
@@ -1041,6 +1048,7 @@ class LLM(BaseLLM):
                         from_task=from_task,
                         from_agent=from_agent,
                         call_type=LLMCallType.TOOL_CALL,
+                        response_id=response_id,
                     ),
                 )

@@ -1402,11 +1410,13 @@ class LLM(BaseLLM):
             params["stream"] = True
             params["stream_options"] = {"include_usage": True}

+            response_id = None
             try:
                 async for chunk in await litellm.acompletion(**params):
                     chunk_count += 1
                     chunk_content = None
+                    response_id = chunk.id if hasattr(chunk, "id") else None

                     try:
                         choices = None
@@ -1466,6 +1476,7 @@ class LLM(BaseLLM):
                             chunk=chunk_content,
                             from_task=from_task,
                             from_agent=from_agent,
+                            response_id=response_id,
                         ),
                     )

@@ -1503,6 +1514,7 @@ class LLM(BaseLLM):
                         available_functions=available_functions,
                         from_task=from_task,
                         from_agent=from_agent,
+                        response_id=response_id,
                     )
                     if result is not None:
                         return result

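The guards above (`response_id = chunk.id if hasattr(chunk, "id") else None`) keep providers whose chunks lack an `id` from raising; the same defensive read can be written in one call, shown here purely as an equivalent form:

    # Equivalent to: chunk.id if hasattr(chunk, "id") else None
    response_id = getattr(chunk, "id", None)
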
View File

@@ -404,6 +404,7 @@ class BaseLLM(ABC):
         from_agent: Agent | None = None,
         tool_call: dict[str, Any] | None = None,
         call_type: LLMCallType | None = None,
+        response_id: str | None = None,
     ) -> None:
         """Emit stream chunk event.

@@ -413,6 +414,7 @@ class BaseLLM(ABC):
             from_agent: The agent that initiated the call.
             tool_call: Tool call information if this is a tool call chunk.
             call_type: The type of LLM call (LLM_CALL or TOOL_CALL).
+            response_id: Unique ID for a particular LLM response; all chunks of one response share the same response_id.
         """
         if not hasattr(crewai_event_bus, "emit"):
             raise ValueError("crewai_event_bus does not have an emit method") from None
@@ -425,6 +427,7 @@ class BaseLLM(ABC):
                 from_task=from_task,
                 from_agent=from_agent,
                 call_type=call_type,
+                response_id=response_id,
             ),
         )

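Callers that emit chunk events directly, as the fallback tests at the end of this commit do, pass the field the same way. A minimal sketch using only constructs that appear in this diff; the `llm` source object and the id value are placeholders:

    crewai_event_bus.emit(
        llm,  # the emitting LLM instance
        event=LLMStreamChunkEvent(chunk="partial text", response_id="resp-123"),
    )
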
View File

@@ -700,7 +700,11 @@ class AnthropicCompletion(BaseLLM):
             # Make streaming API call
             with self.client.messages.stream(**stream_params) as stream:
+                response_id = None
                 for event in stream:
+                    if hasattr(event, "message") and hasattr(event.message, "id"):
+                        response_id = event.message.id
+
                     if hasattr(event, "delta") and hasattr(event.delta, "text"):
                         text_delta = event.delta.text
                         full_response += text_delta
@@ -708,6 +712,7 @@ class AnthropicCompletion(BaseLLM):
                             chunk=text_delta,
                             from_task=from_task,
                             from_agent=from_agent,
+                            response_id=response_id,
                         )

                     if event.type == "content_block_start":
@@ -734,6 +739,7 @@ class AnthropicCompletion(BaseLLM):
                                     "index": block_index,
                                 },
                                 call_type=LLMCallType.TOOL_CALL,
+                                response_id=response_id,
                             )
                     elif event.type == "content_block_delta":
                         if event.delta.type == "input_json_delta":
@@ -757,6 +763,7 @@ class AnthropicCompletion(BaseLLM):
                                     "index": block_index,
                                 },
                                 call_type=LLMCallType.TOOL_CALL,
+                                response_id=response_id,
                             )

                 final_message: Message = stream.get_final_message()
@@ -1114,7 +1121,11 @@ class AnthropicCompletion(BaseLLM):
             current_tool_calls: dict[int, dict[str, Any]] = {}

             async with self.async_client.messages.stream(**stream_params) as stream:
+                response_id = None
                 async for event in stream:
+                    if hasattr(event, "message") and hasattr(event.message, "id"):
+                        response_id = event.message.id
+
                     if hasattr(event, "delta") and hasattr(event.delta, "text"):
                         text_delta = event.delta.text
                         full_response += text_delta
@@ -1122,6 +1133,7 @@ class AnthropicCompletion(BaseLLM):
                             chunk=text_delta,
                             from_task=from_task,
                             from_agent=from_agent,
+                            response_id=response_id,
                         )

                     if event.type == "content_block_start":
@@ -1148,6 +1160,7 @@ class AnthropicCompletion(BaseLLM):
                                     "index": block_index,
                                 },
                                 call_type=LLMCallType.TOOL_CALL,
+                                response_id=response_id,
                             )
                     elif event.type == "content_block_delta":
                         if event.delta.type == "input_json_delta":
@@ -1171,6 +1184,7 @@ class AnthropicCompletion(BaseLLM):
                                     "index": block_index,
                                 },
                                 call_type=LLMCallType.TOOL_CALL,
+                                response_id=response_id,
                             )

                 final_message: Message = await stream.get_final_message()

View File

@@ -726,6 +726,7 @@ class AzureCompletion(BaseLLM):
         """
         if update.choices:
             choice = update.choices[0]
+            response_id = update.id if hasattr(update, "id") else None
             if choice.delta and choice.delta.content:
                 content_delta = choice.delta.content
                 full_response += content_delta
@@ -733,6 +734,7 @@ class AzureCompletion(BaseLLM):
                     chunk=content_delta,
                     from_task=from_task,
                     from_agent=from_agent,
+                    response_id=response_id,
                 )

             if choice.delta and choice.delta.tool_calls:
@@ -767,6 +769,7 @@ class AzureCompletion(BaseLLM):
                             "index": idx,
                         },
                         call_type=LLMCallType.TOOL_CALL,
+                        response_id=response_id,
                     )

         return full_response

View File

@@ -736,6 +736,7 @@ class BedrockCompletion(BaseLLM):
             )

             stream = response.get("stream")
+            response_id = None
             if stream:
                 for event in stream:
                     if "messageStart" in event:
@@ -767,6 +768,7 @@ class BedrockCompletion(BaseLLM):
                                     "index": tool_use_index,
                                 },
                                 call_type=LLMCallType.TOOL_CALL,
+                                response_id=response_id,
                             )
                             logging.debug(
                                 f"Tool use started in stream: {json.dumps(current_tool_use)} (ID: {tool_use_id})"
@@ -782,6 +784,7 @@ class BedrockCompletion(BaseLLM):
                                 chunk=text_chunk,
                                 from_task=from_task,
                                 from_agent=from_agent,
+                                response_id=response_id,
                             )
                         elif "toolUse" in delta and current_tool_use:
                             tool_input = delta["toolUse"].get("input", "")
@@ -802,6 +805,7 @@ class BedrockCompletion(BaseLLM):
                                     "index": tool_use_index,
                                 },
                                 call_type=LLMCallType.TOOL_CALL,
+                                response_id=response_id,
                             )
                     elif "contentBlockStop" in event:
                         logging.debug("Content block stopped in stream")
@@ -1122,6 +1126,7 @@ class BedrockCompletion(BaseLLM):
             )

             stream = response.get("stream")
+            response_id = None
             if stream:
                 async for event in stream:
                     if "messageStart" in event:
@@ -1153,6 +1158,7 @@ class BedrockCompletion(BaseLLM):
                                     "index": tool_use_index,
                                 },
                                 call_type=LLMCallType.TOOL_CALL,
+                                response_id=response_id,
                             )
                             logging.debug(
                                 f"Tool use started in stream: {current_tool_use.get('name')} (ID: {tool_use_id})"
@@ -1168,6 +1174,7 @@ class BedrockCompletion(BaseLLM):
                                 chunk=text_chunk,
                                 from_task=from_task,
                                 from_agent=from_agent,
+                                response_id=response_id,
                             )
                         elif "toolUse" in delta and current_tool_use:
                             tool_input = delta["toolUse"].get("input", "")
@@ -1188,6 +1195,7 @@ class BedrockCompletion(BaseLLM):
                                     "index": tool_use_index,
                                 },
                                 call_type=LLMCallType.TOOL_CALL,
+                                response_id=response_id,
                             )
                     elif "contentBlockStop" in event:

View File

@@ -790,6 +790,7 @@ class GeminiCompletion(BaseLLM):
         Returns:
             Tuple of (updated full_response, updated function_calls, updated usage_data)
         """
+        response_id = chunk.response_id if hasattr(chunk, "response_id") else None

         if chunk.usage_metadata:
             usage_data = self._extract_token_usage(chunk)
@@ -799,6 +800,7 @@ class GeminiCompletion(BaseLLM):
                 chunk=chunk.text,
                 from_task=from_task,
                 from_agent=from_agent,
+                response_id=response_id,
             )

         if chunk.candidates:
@@ -835,6 +837,7 @@ class GeminiCompletion(BaseLLM):
                             "index": call_index,
                         },
                         call_type=LLMCallType.TOOL_CALL,
+                        response_id=response_id,
                     )

         return full_response, function_calls, usage_data

View File

@@ -1047,8 +1047,12 @@ class OpenAICompletion(BaseLLM):
         final_response: Response | None = None

         stream = self.client.responses.create(**params)
+        response_id_stream = None

         for event in stream:
+            if event.type == "response.created":
+                response_id_stream = event.response.id
+
             if event.type == "response.output_text.delta":
                 delta_text = event.delta or ""
                 full_response += delta_text
@@ -1056,6 +1060,7 @@ class OpenAICompletion(BaseLLM):
                     chunk=delta_text,
                     from_task=from_task,
                     from_agent=from_agent,
+                    response_id=response_id_stream,
                 )
             elif event.type == "response.function_call_arguments.delta":
@@ -1170,8 +1175,12 @@ class OpenAICompletion(BaseLLM):
         final_response: Response | None = None

         stream = await self.async_client.responses.create(**params)
+        response_id_stream = None

         async for event in stream:
+            if event.type == "response.created":
+                response_id_stream = event.response.id
+
             if event.type == "response.output_text.delta":
                 delta_text = event.delta or ""
                 full_response += delta_text
@@ -1179,6 +1188,7 @@ class OpenAICompletion(BaseLLM):
                     chunk=delta_text,
                     from_task=from_task,
                     from_agent=from_agent,
+                    response_id=response_id_stream,
                 )
             elif event.type == "response.function_call_arguments.delta":
@@ -1699,6 +1709,8 @@ class OpenAICompletion(BaseLLM):
                 **parse_params, response_format=response_model
             ) as stream:
                 for chunk in stream:
+                    response_id_stream = chunk.id if hasattr(chunk, "id") else None
+
                     if chunk.type == "content.delta":
                         delta_content = chunk.delta
                         if delta_content:
@@ -1706,6 +1718,7 @@ class OpenAICompletion(BaseLLM):
                                 chunk=delta_content,
                                 from_task=from_task,
                                 from_agent=from_agent,
+                                response_id=response_id_stream,
                             )

             final_completion = stream.get_final_completion()
@@ -1735,6 +1748,8 @@ class OpenAICompletion(BaseLLM):
             usage_data = {"total_tokens": 0}

             for completion_chunk in completion_stream:
+                response_id_stream = completion_chunk.id if hasattr(completion_chunk, "id") else None
+
                 if hasattr(completion_chunk, "usage") and completion_chunk.usage:
                     usage_data = self._extract_openai_token_usage(completion_chunk)
                     continue
@@ -1751,6 +1766,7 @@ class OpenAICompletion(BaseLLM):
                         chunk=chunk_delta.content,
                         from_task=from_task,
                         from_agent=from_agent,
+                        response_id=response_id_stream,
                     )

                 if chunk_delta.tool_calls:
@@ -1789,6 +1805,7 @@ class OpenAICompletion(BaseLLM):
                             "index": tool_calls[tool_index]["index"],
                         },
                         call_type=LLMCallType.TOOL_CALL,
+                        response_id=response_id_stream,
                     )

         self._track_token_usage_internal(usage_data)
@@ -2000,6 +2017,8 @@ class OpenAICompletion(BaseLLM):
             accumulated_content = ""
             usage_data = {"total_tokens": 0}
             async for chunk in completion_stream:
+                response_id_stream = chunk.id if hasattr(chunk, "id") else None
+
                 if hasattr(chunk, "usage") and chunk.usage:
                     usage_data = self._extract_openai_token_usage(chunk)
                     continue
@@ -2016,6 +2035,7 @@ class OpenAICompletion(BaseLLM):
                         chunk=delta.content,
                         from_task=from_task,
                        from_agent=from_agent,
+                        response_id=response_id_stream,
                    )

         self._track_token_usage_internal(usage_data)
@@ -2051,6 +2071,8 @@ class OpenAICompletion(BaseLLM):
             usage_data = {"total_tokens": 0}

             async for chunk in stream:
+                response_id_stream = chunk.id if hasattr(chunk, "id") else None
+
                 if hasattr(chunk, "usage") and chunk.usage:
                     usage_data = self._extract_openai_token_usage(chunk)
                     continue
@@ -2067,6 +2089,7 @@ class OpenAICompletion(BaseLLM):
                         chunk=chunk_delta.content,
                         from_task=from_task,
                         from_agent=from_agent,
+                        response_id=response_id_stream,
                     )

                 if chunk_delta.tool_calls:
@@ -2105,6 +2128,7 @@ class OpenAICompletion(BaseLLM):
                             "index": tool_calls[tool_index]["index"],
                         },
                         call_type=LLMCallType.TOOL_CALL,
+                        response_id=response_id_stream,
                     )

         self._track_token_usage_internal(usage_data)

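Note the two extraction patterns in this file: the Responses API carries the id only on the initial `response.created` event, so the handler caches it for later deltas, while Chat Completions chunks each carry their own `id`. Side by side, as a sketch of the patterns the hunks above use:

    # Responses API: the id arrives once, on response.created.
    response_id_stream = None
    for event in stream:
        if event.type == "response.created":
            response_id_stream = event.response.id  # reused for every later delta

    # Chat Completions API: every chunk carries the id.
    for chunk in completion_stream:
        response_id_stream = getattr(chunk, "id", None)
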
View File

@@ -511,10 +511,13 @@ def test_openai_streaming_with_response_model():
     mock_chunk1 = MagicMock()
     mock_chunk1.type = "content.delta"
     mock_chunk1.delta = '{"answer": "test", '
+    mock_chunk1.id = "response-1"

+    # Second chunk
     mock_chunk2 = MagicMock()
     mock_chunk2.type = "content.delta"
     mock_chunk2.delta = '"confidence": 0.95}'
+    mock_chunk2.id = "response-2"

     # Create mock final completion with parsed result
     mock_parsed = TestResponse(answer="test", confidence=0.95)

View File

@@ -984,8 +984,8 @@ def test_streaming_fallback_to_non_streaming():
     def mock_call(messages, tools=None, callbacks=None, available_functions=None):
         nonlocal fallback_called
         # Emit a couple of chunks to simulate partial streaming
-        crewai_event_bus.emit(llm, event=LLMStreamChunkEvent(chunk="Test chunk 1"))
-        crewai_event_bus.emit(llm, event=LLMStreamChunkEvent(chunk="Test chunk 2"))
+        crewai_event_bus.emit(llm, event=LLMStreamChunkEvent(chunk="Test chunk 1", response_id="Id"))
+        crewai_event_bus.emit(llm, event=LLMStreamChunkEvent(chunk="Test chunk 2", response_id="Id"))
         # Mark that fallback would be called
         fallback_called = True
@@ -1041,7 +1041,7 @@ def test_streaming_empty_response_handling():
     def mock_call(messages, tools=None, callbacks=None, available_functions=None):
         # Emit a few empty chunks
         for _ in range(3):
-            crewai_event_bus.emit(llm, event=LLMStreamChunkEvent(chunk=""))
+            crewai_event_bus.emit(llm, event=LLMStreamChunkEvent(chunk="", response_id="id"))

         # Return the default message for empty responses
         return "I apologize, but I couldn't generate a proper response. Please try again or rephrase your request."