Mirror of https://github.com/crewAIInc/crewAI.git, synced 2026-01-27 09:08:14 +00:00
feat: adding response_id in streaming response
@@ -84,3 +84,4 @@ class LLMStreamChunkEvent(LLMEventBase):
    chunk: str
    tool_call: ToolCall | None = None
    call_type: LLMCallType | None = None
    response_id: str | None = None
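The new field lets event consumers tell which streamed chunks belong to the same LLM response. A minimal listener sketch, assuming the decorator-style crewai_event_bus.on(...) registration and import paths of recent crewAI releases (both may differ in your version):

from collections import defaultdict

from crewai.events import crewai_event_bus  # assumed import path
from crewai.events.types.llm_events import LLMStreamChunkEvent  # assumed import path

# Accumulate chunk text per response_id so interleaved responses stay separated.
chunks_by_response: dict[str | None, list[str]] = defaultdict(list)

@crewai_event_bus.on(LLMStreamChunkEvent)
def collect_chunk(source, event: LLMStreamChunkEvent) -> None:
    # Every chunk of one response carries the same response_id
    # (None when the provider did not supply an id).
    chunks_by_response[event.response_id].append(event.chunk)
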
@@ -768,6 +768,10 @@ class LLM(BaseLLM):

# Extract content from the chunk
chunk_content = None
response_id = None

if hasattr(chunk, "id"):
    response_id = chunk.id

# Safely extract content from various chunk formats
try:
@@ -823,6 +827,7 @@ class LLM(BaseLLM):
available_functions=available_functions,
from_task=from_task,
from_agent=from_agent,
response_id=response_id
)

if result is not None:
@@ -844,6 +849,7 @@ class LLM(BaseLLM):
from_task=from_task,
from_agent=from_agent,
call_type=LLMCallType.LLM_CALL,
response_id=response_id
),
)
# --- 4) Fallback to non-streaming if no content received
@@ -1021,6 +1027,7 @@ class LLM(BaseLLM):
available_functions: dict[str, Any] | None = None,
from_task: Task | None = None,
from_agent: Agent | None = None,
response_id: str | None = None,
) -> Any:
    for tool_call in tool_calls:
        current_tool_accumulator = accumulated_tool_args[tool_call.index]
@@ -1041,6 +1048,7 @@ class LLM(BaseLLM):
from_task=from_task,
from_agent=from_agent,
call_type=LLMCallType.TOOL_CALL,
response_id=response_id
),
)

@@ -1402,11 +1410,13 @@ class LLM(BaseLLM):

params["stream"] = True
params["stream_options"] = {"include_usage": True}
response_id = None

try:
    async for chunk in await litellm.acompletion(**params):
        chunk_count += 1
        chunk_content = None
        response_id = chunk.id if hasattr(chunk, "id") else None

        try:
            choices = None
@@ -1466,6 +1476,7 @@ class LLM(BaseLLM):
chunk=chunk_content,
from_task=from_task,
from_agent=from_agent,
response_id=response_id
),
)

@@ -1503,6 +1514,7 @@ class LLM(BaseLLM):
available_functions=available_functions,
from_task=from_task,
from_agent=from_agent,
response_id=response_id,
)
if result is not None:
    return result

@@ -404,6 +404,7 @@ class BaseLLM(ABC):
from_agent: Agent | None = None,
tool_call: dict[str, Any] | None = None,
call_type: LLMCallType | None = None,
response_id: str | None = None
) -> None:
    """Emit stream chunk event.

@@ -413,6 +414,7 @@ class BaseLLM(ABC):
from_agent: The agent that initiated the call.
tool_call: Tool call information if this is a tool call chunk.
call_type: The type of LLM call (LLM_CALL or TOOL_CALL).
response_id: Unique ID for a particular LLM response; all chunks of one response share the same response_id.
"""
if not hasattr(crewai_event_bus, "emit"):
    raise ValueError("crewai_event_bus does not have an emit method") from None
@@ -425,6 +427,7 @@ class BaseLLM(ABC):
from_task=from_task,
from_agent=from_agent,
call_type=call_type,
response_id=response_id
),
)

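Each provider integration below follows the same pattern inside its streaming loop: capture the id from the first stream event that exposes one, then pass it to every chunk emission. A minimal sketch of that pattern; the emit helper's real name is not visible in this hunk, so self._emit_stream_chunk_event is a stand-in, and provider_stream is a placeholder for the SDK's iterator:

response_id: str | None = None
full_response = ""

for event in provider_stream:  # placeholder for the provider SDK's streaming iterator
    if response_id is None and getattr(event, "id", None):
        response_id = event.id  # reuse the same id for every chunk of this response

    text = getattr(event, "text", "") or ""
    if text:
        full_response += text
        self._emit_stream_chunk_event(  # stand-in name for the BaseLLM helper above
            chunk=text,
            from_task=from_task,
            from_agent=from_agent,
            response_id=response_id,
        )
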
@@ -700,7 +700,11 @@ class AnthropicCompletion(BaseLLM):

# Make streaming API call
with self.client.messages.stream(**stream_params) as stream:
    response_id = None
    for event in stream:
        if hasattr(event, "message") and hasattr(event.message, "id"):
            response_id = event.message.id

        if hasattr(event, "delta") and hasattr(event.delta, "text"):
            text_delta = event.delta.text
            full_response += text_delta
@@ -708,6 +712,7 @@ class AnthropicCompletion(BaseLLM):
chunk=text_delta,
from_task=from_task,
from_agent=from_agent,
response_id=response_id
)

if event.type == "content_block_start":
@@ -734,6 +739,7 @@ class AnthropicCompletion(BaseLLM):
"index": block_index,
},
call_type=LLMCallType.TOOL_CALL,
response_id=response_id
)
elif event.type == "content_block_delta":
    if event.delta.type == "input_json_delta":
@@ -757,6 +763,7 @@ class AnthropicCompletion(BaseLLM):
"index": block_index,
},
call_type=LLMCallType.TOOL_CALL,
response_id=response_id
)

final_message: Message = stream.get_final_message()
@@ -1114,7 +1121,11 @@ class AnthropicCompletion(BaseLLM):
current_tool_calls: dict[int, dict[str, Any]] = {}

async with self.async_client.messages.stream(**stream_params) as stream:
    response_id = None
    async for event in stream:
        if hasattr(event, "message") and hasattr(event.message, "id"):
            response_id = event.message.id

        if hasattr(event, "delta") and hasattr(event.delta, "text"):
            text_delta = event.delta.text
            full_response += text_delta
@@ -1122,6 +1133,7 @@ class AnthropicCompletion(BaseLLM):
chunk=text_delta,
from_task=from_task,
from_agent=from_agent,
response_id=response_id
)

if event.type == "content_block_start":
@@ -1148,6 +1160,7 @@ class AnthropicCompletion(BaseLLM):
"index": block_index,
},
call_type=LLMCallType.TOOL_CALL,
response_id=response_id
)
elif event.type == "content_block_delta":
    if event.delta.type == "input_json_delta":
@@ -1171,6 +1184,7 @@ class AnthropicCompletion(BaseLLM):
"index": block_index,
},
call_type=LLMCallType.TOOL_CALL,
response_id=response_id
)

final_message: Message = await stream.get_final_message()

@@ -726,6 +726,7 @@ class AzureCompletion(BaseLLM):
"""
if update.choices:
    choice = update.choices[0]
    response_id = update.id if hasattr(update, "id") else None
    if choice.delta and choice.delta.content:
        content_delta = choice.delta.content
        full_response += content_delta
@@ -733,6 +734,7 @@ class AzureCompletion(BaseLLM):
chunk=content_delta,
from_task=from_task,
from_agent=from_agent,
response_id=response_id
)

if choice.delta and choice.delta.tool_calls:
@@ -767,6 +769,7 @@ class AzureCompletion(BaseLLM):
"index": idx,
},
call_type=LLMCallType.TOOL_CALL,
response_id=response_id
)

return full_response

@@ -736,6 +736,7 @@ class BedrockCompletion(BaseLLM):
)

stream = response.get("stream")
response_id = None
if stream:
    for event in stream:
        if "messageStart" in event:
@@ -767,6 +768,7 @@ class BedrockCompletion(BaseLLM):
"index": tool_use_index,
},
call_type=LLMCallType.TOOL_CALL,
response_id=response_id,
)
logging.debug(
    f"Tool use started in stream: {json.dumps(current_tool_use)} (ID: {tool_use_id})"
@@ -782,6 +784,7 @@ class BedrockCompletion(BaseLLM):
chunk=text_chunk,
from_task=from_task,
from_agent=from_agent,
response_id=response_id,
)
elif "toolUse" in delta and current_tool_use:
    tool_input = delta["toolUse"].get("input", "")
@@ -802,6 +805,7 @@ class BedrockCompletion(BaseLLM):
"index": tool_use_index,
},
call_type=LLMCallType.TOOL_CALL,
response_id=response_id
)
elif "contentBlockStop" in event:
    logging.debug("Content block stopped in stream")
@@ -1122,6 +1126,7 @@ class BedrockCompletion(BaseLLM):
)

stream = response.get("stream")
response_id = None
if stream:
    async for event in stream:
        if "messageStart" in event:
@@ -1153,6 +1158,7 @@ class BedrockCompletion(BaseLLM):
"index": tool_use_index,
},
call_type=LLMCallType.TOOL_CALL,
response_id=response_id,
)
logging.debug(
    f"Tool use started in stream: {current_tool_use.get('name')} (ID: {tool_use_id})"
@@ -1168,6 +1174,7 @@ class BedrockCompletion(BaseLLM):
chunk=text_chunk,
from_task=from_task,
from_agent=from_agent,
response_id=response_id
)
elif "toolUse" in delta and current_tool_use:
    tool_input = delta["toolUse"].get("input", "")
@@ -1188,6 +1195,7 @@ class BedrockCompletion(BaseLLM):
"index": tool_use_index,
},
call_type=LLMCallType.TOOL_CALL,
response_id=response_id,
)

elif "contentBlockStop" in event:

@@ -790,6 +790,7 @@ class GeminiCompletion(BaseLLM):
Returns:
    Tuple of (updated full_response, updated function_calls, updated usage_data)
"""
response_id = chunk.response_id if hasattr(chunk, "response_id") else None
if chunk.usage_metadata:
    usage_data = self._extract_token_usage(chunk)

@@ -799,6 +800,7 @@ class GeminiCompletion(BaseLLM):
chunk=chunk.text,
from_task=from_task,
from_agent=from_agent,
response_id=response_id
)

if chunk.candidates:
@@ -835,6 +837,7 @@ class GeminiCompletion(BaseLLM):
"index": call_index,
},
call_type=LLMCallType.TOOL_CALL,
response_id=response_id
)

return full_response, function_calls, usage_data

@@ -1047,8 +1047,12 @@ class OpenAICompletion(BaseLLM):
final_response: Response | None = None

stream = self.client.responses.create(**params)
response_id_stream = None

for event in stream:
    if event.type == "response.created":
        response_id_stream = event.response.id

    if event.type == "response.output_text.delta":
        delta_text = event.delta or ""
        full_response += delta_text
@@ -1056,6 +1060,7 @@ class OpenAICompletion(BaseLLM):
chunk=delta_text,
from_task=from_task,
from_agent=from_agent,
response_id=response_id_stream
)

elif event.type == "response.function_call_arguments.delta":
@@ -1170,8 +1175,12 @@ class OpenAICompletion(BaseLLM):
final_response: Response | None = None

stream = await self.async_client.responses.create(**params)
response_id_stream = None

async for event in stream:
    if event.type == "response.created":
        response_id_stream = event.response.id

    if event.type == "response.output_text.delta":
        delta_text = event.delta or ""
        full_response += delta_text
@@ -1179,6 +1188,7 @@ class OpenAICompletion(BaseLLM):
chunk=delta_text,
from_task=from_task,
from_agent=from_agent,
response_id=response_id_stream,
)

elif event.type == "response.function_call_arguments.delta":
@@ -1699,6 +1709,8 @@ class OpenAICompletion(BaseLLM):
**parse_params, response_format=response_model
) as stream:
    for chunk in stream:
        response_id_stream = chunk.id if hasattr(chunk, "id") else None

        if chunk.type == "content.delta":
            delta_content = chunk.delta
            if delta_content:
@@ -1706,6 +1718,7 @@ class OpenAICompletion(BaseLLM):
chunk=delta_content,
from_task=from_task,
from_agent=from_agent,
response_id=response_id_stream
)

final_completion = stream.get_final_completion()
@@ -1735,6 +1748,8 @@ class OpenAICompletion(BaseLLM):
usage_data = {"total_tokens": 0}

for completion_chunk in completion_stream:
    response_id_stream = completion_chunk.id if hasattr(completion_chunk, "id") else None

    if hasattr(completion_chunk, "usage") and completion_chunk.usage:
        usage_data = self._extract_openai_token_usage(completion_chunk)
        continue
@@ -1751,6 +1766,7 @@ class OpenAICompletion(BaseLLM):
chunk=chunk_delta.content,
from_task=from_task,
from_agent=from_agent,
response_id=response_id_stream
)

if chunk_delta.tool_calls:
@@ -1789,6 +1805,7 @@ class OpenAICompletion(BaseLLM):
"index": tool_calls[tool_index]["index"],
},
call_type=LLMCallType.TOOL_CALL,
response_id=response_id_stream
)

self._track_token_usage_internal(usage_data)
@@ -2000,6 +2017,8 @@ class OpenAICompletion(BaseLLM):
accumulated_content = ""
usage_data = {"total_tokens": 0}
async for chunk in completion_stream:
    response_id_stream = chunk.id if hasattr(chunk, "id") else None

    if hasattr(chunk, "usage") and chunk.usage:
        usage_data = self._extract_openai_token_usage(chunk)
        continue
@@ -2016,6 +2035,7 @@ class OpenAICompletion(BaseLLM):
chunk=delta.content,
from_task=from_task,
from_agent=from_agent,
response_id=response_id_stream
)

self._track_token_usage_internal(usage_data)
@@ -2051,6 +2071,8 @@ class OpenAICompletion(BaseLLM):
usage_data = {"total_tokens": 0}

async for chunk in stream:
    response_id_stream = chunk.id if hasattr(chunk, "id") else None

    if hasattr(chunk, "usage") and chunk.usage:
        usage_data = self._extract_openai_token_usage(chunk)
        continue
@@ -2067,6 +2089,7 @@ class OpenAICompletion(BaseLLM):
chunk=chunk_delta.content,
from_task=from_task,
from_agent=from_agent,
response_id=response_id_stream
)

if chunk_delta.tool_calls:
@@ -2105,6 +2128,7 @@ class OpenAICompletion(BaseLLM):
"index": tool_calls[tool_index]["index"],
},
call_type=LLMCallType.TOOL_CALL,
response_id=response_id_stream
)

self._track_token_usage_internal(usage_data)

@@ -511,10 +511,13 @@ def test_openai_streaming_with_response_model():
mock_chunk1 = MagicMock()
mock_chunk1.type = "content.delta"
mock_chunk1.delta = '{"answer": "test", '
mock_chunk1.id = "response-1"

# Second chunk
mock_chunk2 = MagicMock()
mock_chunk2.type = "content.delta"
mock_chunk2.delta = '"confidence": 0.95}'
mock_chunk2.id = "response-2"

# Create mock final completion with parsed result
mock_parsed = TestResponse(answer="test", confidence=0.95)

@@ -984,8 +984,8 @@ def test_streaming_fallback_to_non_streaming():
def mock_call(messages, tools=None, callbacks=None, available_functions=None):
    nonlocal fallback_called
    # Emit a couple of chunks to simulate partial streaming
-    crewai_event_bus.emit(llm, event=LLMStreamChunkEvent(chunk="Test chunk 1"))
-    crewai_event_bus.emit(llm, event=LLMStreamChunkEvent(chunk="Test chunk 2"))
+    crewai_event_bus.emit(llm, event=LLMStreamChunkEvent(chunk="Test chunk 1", response_id="Id"))
+    crewai_event_bus.emit(llm, event=LLMStreamChunkEvent(chunk="Test chunk 2", response_id="Id"))

    # Mark that fallback would be called
    fallback_called = True
@@ -1041,7 +1041,7 @@ def test_streaming_empty_response_handling():
def mock_call(messages, tools=None, callbacks=None, available_functions=None):
    # Emit a few empty chunks
    for _ in range(3):
-        crewai_event_bus.emit(llm, event=LLMStreamChunkEvent(chunk=""))
+        crewai_event_bus.emit(llm, event=LLMStreamChunkEvent(chunk="", response_id="id"))

    # Return the default message for empty responses
    return "I apologize, but I couldn't generate a proper response. Please try again or rephrase your request."

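A natural follow-up check, not part of this commit, is that every chunk emitted for one streaming call carries the same response_id. A sketch of such a test, assuming the decorator-style crewai_event_bus.on(...) registration (registration and import paths may differ by crewAI version):

# Assumes crewai_event_bus and LLMStreamChunkEvent are imported as in the surrounding test module.
def test_stream_chunks_share_response_id():
    received: list[str | None] = []

    @crewai_event_bus.on(LLMStreamChunkEvent)
    def record_chunk(source, event):
        received.append(event.response_id)

    llm = object()  # stand-in event source; the tests above pass an LLM instance

    # Simulate a provider emitting two chunks of the same response.
    crewai_event_bus.emit(llm, event=LLMStreamChunkEvent(chunk="Hello ", response_id="resp-1"))
    crewai_event_bus.emit(llm, event=LLMStreamChunkEvent(chunk="world", response_id="resp-1"))

    assert received == ["resp-1", "resp-1"]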