mirror of
https://github.com/crewAIInc/crewAI.git
synced 2026-05-01 07:13:00 +00:00
fix: ensure openai tool call stream is finalized
This commit is contained in:
@@ -1696,6 +1696,99 @@ class OpenAICompletion(BaseLLM):
|
|||||||
|
|
||||||
return content
|
return content
|
||||||
|
|
||||||
|
def _finalize_streaming_response(
    self,
    full_response: str,
    tool_calls: dict[int, dict[str, Any]],
    usage_data: dict[str, int],
    params: dict[str, Any],
    available_functions: dict[str, Any] | None = None,
    from_task: Any | None = None,
    from_agent: Any | None = None,
) -> str | list[dict[str, Any]]:
    """Wrap up a completed stream: record usage, resolve tool calls, emit events.

    Shared tail of the streaming handlers. Token usage is always tracked
    first. After that, exactly one of three outcomes applies:

    * tool calls were streamed and no ``available_functions`` were given:
      the calls are returned as a list of OpenAI-style tool call dicts;
    * tool calls were streamed and ``available_functions`` were given:
      the calls are tried in order and the first non-``None`` execution
      result is returned;
    * otherwise (or when no call produced a result): the accumulated
      text is passed through stop-word handling and returned.

    Args:
        full_response: Text accumulated from the stream.
        tool_calls: Tool calls accumulated from the stream, keyed by index.
            Each value is read via its "id", "name", "arguments" and
            "index" keys.
        usage_data: Token usage gathered from the stream.
        params: Completion parameters; ``params["messages"]`` is attached
            to the emitted completion event.
        available_functions: Mapping of callable tools by name, if any.
        from_task: Task that initiated the call.
        from_agent: Agent that initiated the call.

    Returns:
        The list of tool call dicts, a tool execution result, or the
        final text response, as described above.
    """
    self._track_token_usage_internal(usage_data)

    if tool_calls:
        if not available_functions:
            # Nothing to execute locally: hand the raw calls back to the caller.
            pending_calls: list[dict[str, Any]] = []
            for entry in tool_calls.values():
                pending_calls.append(
                    {
                        "id": entry["id"],
                        "type": "function",
                        "function": {
                            "name": entry["name"],
                            "arguments": entry["arguments"],
                        },
                        "index": entry["index"],
                    }
                )

            self._emit_call_completed_event(
                response=pending_calls,
                call_type=LLMCallType.TOOL_CALL,
                from_task=from_task,
                from_agent=from_agent,
                messages=params["messages"],
            )
            return pending_calls

        for entry in tool_calls.values():
            function_name = entry["name"]
            raw_arguments = entry["arguments"]

            # Incomplete calls (no name or no argument payload) are ignored.
            if not (function_name and raw_arguments):
                continue

            if function_name not in available_functions:
                logging.warning(
                    f"Function '{function_name}' not found in available functions"
                )
                continue

            try:
                parsed_args = json.loads(raw_arguments)
            except json.JSONDecodeError as e:
                logging.error(f"Failed to parse streamed tool arguments: {e}")
                continue

            outcome = self._handle_tool_execution(
                function_name=function_name,
                function_args=parsed_args,
                available_functions=available_functions,
                from_task=from_task,
                from_agent=from_agent,
            )

            # The first call that yields a result short-circuits the rest.
            if outcome is not None:
                return outcome

    # Plain-text path, also reached when no tool call produced a result.
    final_text = self._apply_stop_words(full_response)

    self._emit_call_completed_event(
        response=final_text,
        call_type=LLMCallType.LLM_CALL,
        from_task=from_task,
        from_agent=from_agent,
        messages=params["messages"],
    )

    return final_text
|
||||||
|
|
||||||
def _handle_streaming_completion(
|
def _handle_streaming_completion(
|
||||||
self,
|
self,
|
||||||
params: dict[str, Any],
|
params: dict[str, Any],
|
||||||
@@ -1703,7 +1796,7 @@ class OpenAICompletion(BaseLLM):
|
|||||||
from_task: Any | None = None,
|
from_task: Any | None = None,
|
||||||
from_agent: Any | None = None,
|
from_agent: Any | None = None,
|
||||||
response_model: type[BaseModel] | None = None,
|
response_model: type[BaseModel] | None = None,
|
||||||
) -> str | BaseModel:
|
) -> str | list[dict[str, Any]] | BaseModel:
|
||||||
"""Handle streaming chat completion."""
|
"""Handle streaming chat completion."""
|
||||||
full_response = ""
|
full_response = ""
|
||||||
tool_calls: dict[int, dict[str, Any]] = {}
|
tool_calls: dict[int, dict[str, Any]] = {}
|
||||||
@@ -1820,54 +1913,20 @@ class OpenAICompletion(BaseLLM):
|
|||||||
response_id=response_id_stream,
|
response_id=response_id_stream,
|
||||||
)
|
)
|
||||||
|
|
||||||
self._track_token_usage_internal(usage_data)
|
result = self._finalize_streaming_response(
|
||||||
|
full_response=full_response,
|
||||||
if tool_calls and available_functions:
|
tool_calls=tool_calls,
|
||||||
for call_data in tool_calls.values():
|
usage_data=usage_data,
|
||||||
function_name = call_data["name"]
|
params=params,
|
||||||
arguments = call_data["arguments"]
|
available_functions=available_functions,
|
||||||
|
|
||||||
# Skip if function name is empty or arguments are empty
|
|
||||||
if not function_name or not arguments:
|
|
||||||
continue
|
|
||||||
|
|
||||||
# Check if function exists in available functions
|
|
||||||
if function_name not in available_functions:
|
|
||||||
logging.warning(
|
|
||||||
f"Function '{function_name}' not found in available functions"
|
|
||||||
)
|
|
||||||
continue
|
|
||||||
|
|
||||||
try:
|
|
||||||
function_args = json.loads(arguments)
|
|
||||||
except json.JSONDecodeError as e:
|
|
||||||
logging.error(f"Failed to parse streamed tool arguments: {e}")
|
|
||||||
continue
|
|
||||||
|
|
||||||
result = self._handle_tool_execution(
|
|
||||||
function_name=function_name,
|
|
||||||
function_args=function_args,
|
|
||||||
available_functions=available_functions,
|
|
||||||
from_task=from_task,
|
|
||||||
from_agent=from_agent,
|
|
||||||
)
|
|
||||||
|
|
||||||
if result is not None:
|
|
||||||
return result
|
|
||||||
|
|
||||||
full_response = self._apply_stop_words(full_response)
|
|
||||||
|
|
||||||
self._emit_call_completed_event(
|
|
||||||
response=full_response,
|
|
||||||
call_type=LLMCallType.LLM_CALL,
|
|
||||||
from_task=from_task,
|
from_task=from_task,
|
||||||
from_agent=from_agent,
|
from_agent=from_agent,
|
||||||
messages=params["messages"],
|
|
||||||
)
|
|
||||||
|
|
||||||
return self._invoke_after_llm_call_hooks(
|
|
||||||
params["messages"], full_response, from_agent
|
|
||||||
)
|
)
|
||||||
|
if isinstance(result, str):
|
||||||
|
return self._invoke_after_llm_call_hooks(
|
||||||
|
params["messages"], result, from_agent
|
||||||
|
)
|
||||||
|
return result
|
||||||
|
|
||||||
async def _ahandle_completion(
|
async def _ahandle_completion(
|
||||||
self,
|
self,
|
||||||
@@ -2016,7 +2075,7 @@ class OpenAICompletion(BaseLLM):
|
|||||||
from_task: Any | None = None,
|
from_task: Any | None = None,
|
||||||
from_agent: Any | None = None,
|
from_agent: Any | None = None,
|
||||||
response_model: type[BaseModel] | None = None,
|
response_model: type[BaseModel] | None = None,
|
||||||
) -> str | BaseModel:
|
) -> str | list[dict[str, Any]] | BaseModel:
|
||||||
"""Handle async streaming chat completion."""
|
"""Handle async streaming chat completion."""
|
||||||
full_response = ""
|
full_response = ""
|
||||||
tool_calls: dict[int, dict[str, Any]] = {}
|
tool_calls: dict[int, dict[str, Any]] = {}
|
||||||
@@ -2142,51 +2201,16 @@ class OpenAICompletion(BaseLLM):
|
|||||||
response_id=response_id_stream,
|
response_id=response_id_stream,
|
||||||
)
|
)
|
||||||
|
|
||||||
self._track_token_usage_internal(usage_data)
|
return self._finalize_streaming_response(
|
||||||
|
full_response=full_response,
|
||||||
if tool_calls and available_functions:
|
tool_calls=tool_calls,
|
||||||
for call_data in tool_calls.values():
|
usage_data=usage_data,
|
||||||
function_name = call_data["name"]
|
params=params,
|
||||||
arguments = call_data["arguments"]
|
available_functions=available_functions,
|
||||||
|
|
||||||
if not function_name or not arguments:
|
|
||||||
continue
|
|
||||||
|
|
||||||
if function_name not in available_functions:
|
|
||||||
logging.warning(
|
|
||||||
f"Function '{function_name}' not found in available functions"
|
|
||||||
)
|
|
||||||
continue
|
|
||||||
|
|
||||||
try:
|
|
||||||
function_args = json.loads(arguments)
|
|
||||||
except json.JSONDecodeError as e:
|
|
||||||
logging.error(f"Failed to parse streamed tool arguments: {e}")
|
|
||||||
continue
|
|
||||||
|
|
||||||
result = self._handle_tool_execution(
|
|
||||||
function_name=function_name,
|
|
||||||
function_args=function_args,
|
|
||||||
available_functions=available_functions,
|
|
||||||
from_task=from_task,
|
|
||||||
from_agent=from_agent,
|
|
||||||
)
|
|
||||||
|
|
||||||
if result is not None:
|
|
||||||
return result
|
|
||||||
|
|
||||||
full_response = self._apply_stop_words(full_response)
|
|
||||||
|
|
||||||
self._emit_call_completed_event(
|
|
||||||
response=full_response,
|
|
||||||
call_type=LLMCallType.LLM_CALL,
|
|
||||||
from_task=from_task,
|
from_task=from_task,
|
||||||
from_agent=from_agent,
|
from_agent=from_agent,
|
||||||
messages=params["messages"],
|
|
||||||
)
|
)
|
||||||
|
|
||||||
return full_response
|
|
||||||
|
|
||||||
def supports_function_calling(self) -> bool:
|
def supports_function_calling(self) -> bool:
|
||||||
"""Check if the model supports function calling."""
|
"""Check if the model supports function calling."""
|
||||||
return not self.is_o1_model
|
return not self.is_o1_model
|
||||||
|
|||||||
@@ -1,6 +1,7 @@
|
|||||||
import os
|
import os
|
||||||
import sys
|
import sys
|
||||||
import types
|
import types
|
||||||
|
from typing import Any
|
||||||
from unittest.mock import patch, MagicMock
|
from unittest.mock import patch, MagicMock
|
||||||
import openai
|
import openai
|
||||||
import pytest
|
import pytest
|
||||||
@@ -1578,3 +1579,167 @@ def test_openai_structured_output_preserves_json_with_stop_word_patterns():
|
|||||||
assert "Action:" in result.action_taken
|
assert "Action:" in result.action_taken
|
||||||
assert "Observation:" in result.observation_result
|
assert "Observation:" in result.observation_result
|
||||||
assert "Final Answer:" in result.final_answer
|
assert "Final Answer:" in result.final_answer
|
||||||
|
|
||||||
|
|
||||||
|
def test_openai_streaming_returns_tool_calls_without_available_functions():
    """Streaming with tool calls and no available_functions yields the calls.

    Feeds the sync streaming path three mocked chunks: two that split a
    single ``calculator`` tool call's JSON arguments, and a final chunk
    with ``finish_reason="tool_calls"`` and usage. With
    ``available_functions=None`` the call must come back as a one-element
    list of tool call dicts, not as text.
    """
    llm = LLM(model="openai/gpt-4o-mini", stream=True)

    def make_call_delta(call_id, name, arguments):
        # ``name`` must be assigned as an attribute; passing it to the
        # MagicMock constructor would name the mock itself instead.
        call = MagicMock()
        call.index = 0
        call.id = call_id
        call.function = MagicMock()
        call.function.name = name
        call.function.arguments = arguments
        return [call]

    def make_chunk(call_delta, finish_reason, usage):
        choice = MagicMock()
        choice.delta = MagicMock()
        choice.delta.content = None
        choice.delta.tool_calls = call_delta
        choice.finish_reason = finish_reason

        chunk = MagicMock()
        chunk.choices = [choice]
        chunk.usage = usage
        chunk.id = "chatcmpl-1"
        return chunk

    final_usage = MagicMock()
    final_usage.prompt_tokens = 10
    final_usage.completion_tokens = 5

    chunks = [
        make_chunk(make_call_delta("call_abc123", "calculator", '{"expr'), None, None),
        make_chunk(make_call_delta(None, None, 'ession": "1+1"}'), None, None),
        make_chunk(None, "tool_calls", final_usage),
    ]

    calculator_tool = {
        "type": "function",
        "function": {
            "name": "calculator",
            "description": "Calculate expression",
            "parameters": {
                "type": "object",
                "properties": {"expression": {"type": "string"}},
            },
        },
    }

    with patch.object(
        llm.client.chat.completions, "create", return_value=iter(chunks)
    ):
        result = llm.call(
            messages=[{"role": "user", "content": "Calculate 1+1"}],
            tools=[calculator_tool],
            available_functions=None,
        )

    assert isinstance(result, list), f"Expected list of tool calls, got {type(result)}: {result}"
    assert len(result) == 1

    only_call = result[0]
    assert only_call["function"]["name"] == "calculator"
    assert only_call["function"]["arguments"] == '{"expression": "1+1"}'
    assert only_call["id"] == "call_abc123"
    assert only_call["type"] == "function"
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
async def test_openai_async_streaming_returns_tool_calls_without_available_functions():
    """Async streaming with tool calls and no available_functions yields the calls.

    Async counterpart of the sync test: the same three mocked chunks are
    served through an async iterator returned by the patched async
    client's ``create``. ``llm.acall`` must return the single
    ``calculator`` tool call as a list of tool call dicts.
    """
    llm = LLM(model="openai/gpt-4o-mini", stream=True)

    def make_call_delta(call_id, name, arguments):
        # ``name`` must be assigned as an attribute; passing it to the
        # MagicMock constructor would name the mock itself instead.
        call = MagicMock()
        call.index = 0
        call.id = call_id
        call.function = MagicMock()
        call.function.name = name
        call.function.arguments = arguments
        return [call]

    def make_chunk(call_delta, finish_reason, usage):
        choice = MagicMock()
        choice.delta = MagicMock()
        choice.delta.content = None
        choice.delta.tool_calls = call_delta
        choice.finish_reason = finish_reason

        chunk = MagicMock()
        chunk.choices = [choice]
        chunk.usage = usage
        chunk.id = "chatcmpl-1"
        return chunk

    final_usage = MagicMock()
    final_usage.prompt_tokens = 10
    final_usage.completion_tokens = 5

    mock_chunk_1 = make_chunk(
        make_call_delta("call_abc123", "calculator", '{"expr'), None, None
    )
    mock_chunk_2 = make_chunk(
        make_call_delta(None, None, 'ession": "1+1"}'), None, None
    )
    mock_chunk_3 = make_chunk(None, "tool_calls", final_usage)

    class MockAsyncStream:
        """Async iterator that mimics OpenAI's async streaming response."""

        def __init__(self, chunks: list[Any]) -> None:
            self._remaining = iter(chunks)

        def __aiter__(self) -> "MockAsyncStream":
            return self

        async def __anext__(self) -> Any:
            try:
                return next(self._remaining)
            except StopIteration:
                raise StopAsyncIteration from None

    async def mock_create(**kwargs: Any) -> MockAsyncStream:
        return MockAsyncStream([mock_chunk_1, mock_chunk_2, mock_chunk_3])

    calculator_tool = {
        "type": "function",
        "function": {
            "name": "calculator",
            "description": "Calculate expression",
            "parameters": {
                "type": "object",
                "properties": {"expression": {"type": "string"}},
            },
        },
    }

    with patch.object(
        llm.async_client.chat.completions, "create", side_effect=mock_create
    ):
        result = await llm.acall(
            messages=[{"role": "user", "content": "Calculate 1+1"}],
            tools=[calculator_tool],
            available_functions=None,
        )

    assert isinstance(result, list), f"Expected list of tool calls, got {type(result)}: {result}"
    assert len(result) == 1

    only_call = result[0]
    assert only_call["function"]["name"] == "calculator"
    assert only_call["function"]["arguments"] == '{"expression": "1+1"}'
    assert only_call["id"] == "call_abc123"
    assert only_call["type"] == "function"
|
||||||
|
|||||||
Reference in New Issue
Block a user