Mirror of https://github.com/crewAIInc/crewAI.git (synced 2026-01-10 00:28:31 +00:00)
feat: add support for streaming tool call events
@@ -354,8 +354,17 @@ class BaseLLM(ABC):
         from_task: Task | None = None,
         from_agent: Agent | None = None,
         tool_call: dict[str, Any] | None = None,
+        call_type: LLMCallType | None = None,
     ) -> None:
-        """Emit stream chunk event."""
+        """Emit stream chunk event.
+
+        Args:
+            chunk: The text content of the chunk.
+            from_task: The task that initiated the call.
+            from_agent: The agent that initiated the call.
+            tool_call: Tool call information if this is a tool call chunk.
+            call_type: The type of LLM call (LLM_CALL or TOOL_CALL).
+        """
         if not hasattr(crewai_event_bus, "emit"):
             raise ValueError("crewai_event_bus does not have an emit method") from None

@@ -366,6 +375,7 @@ class BaseLLM(ABC):
                 tool_call=tool_call,
                 from_task=from_task,
                 from_agent=from_agent,
+                call_type=call_type,
             ),
         )

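For context on what the new fields enable downstream: a listener registered on the event bus now receives `call_type` and `tool_call` alongside the text `chunk`, so consumers can distinguish text deltas from tool-call argument deltas. A minimal consumer sketch (import paths vary across crewAI versions and are assumptions here; the handler body is illustrative):

    # Import paths are assumptions; on newer versions these may live
    # under crewai.events instead of crewai.utilities.events.
    from crewai.utilities.events import crewai_event_bus
    from crewai.utilities.events.llm_events import LLMCallType, LLMStreamChunkEvent


    @crewai_event_bus.on(LLMStreamChunkEvent)
    def on_stream_chunk(source, event):
        # Tool-call argument deltas share the event type with text deltas;
        # the new call_type field tells them apart.
        if event.call_type == LLMCallType.TOOL_CALL and event.tool_call:
            fn = event.tool_call["function"]
            print(f"tool {fn['name']} arguments so far: {fn['arguments']}")
        else:
            print(event.chunk, end="")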
@@ -598,6 +598,8 @@ class AnthropicCompletion(BaseLLM):
         # (the SDK sets it internally)
         stream_params = {k: v for k, v in params.items() if k != "stream"}

+        current_tool_calls: dict[int, dict[str, Any]] = {}
+
         # Make streaming API call
         with self.client.messages.stream(**stream_params) as stream:
             for event in stream:
@@ -610,6 +612,55 @@ class AnthropicCompletion(BaseLLM):
                         from_agent=from_agent,
                     )

+                if event.type == "content_block_start":
+                    block = event.content_block
+                    if block.type == "tool_use":
+                        block_index = event.index
+                        current_tool_calls[block_index] = {
+                            "id": block.id,
+                            "name": block.name,
+                            "arguments": "",
+                            "index": block_index,
+                        }
+                        self._emit_stream_chunk_event(
+                            chunk="",
+                            from_task=from_task,
+                            from_agent=from_agent,
+                            tool_call={
+                                "id": block.id,
+                                "function": {
+                                    "name": block.name,
+                                    "arguments": "",
+                                },
+                                "type": "function",
+                                "index": block_index,
+                            },
+                            call_type=LLMCallType.TOOL_CALL,
+                        )
+                elif event.type == "content_block_delta":
+                    if event.delta.type == "input_json_delta":
+                        block_index = event.index
+                        partial_json = event.delta.partial_json
+                        if block_index in current_tool_calls and partial_json:
+                            current_tool_calls[block_index]["arguments"] += partial_json
+                            self._emit_stream_chunk_event(
+                                chunk=partial_json,
+                                from_task=from_task,
+                                from_agent=from_agent,
+                                tool_call={
+                                    "id": current_tool_calls[block_index]["id"],
+                                    "function": {
+                                        "name": current_tool_calls[block_index]["name"],
+                                        "arguments": current_tool_calls[block_index][
+                                            "arguments"
+                                        ],
+                                    },
+                                    "type": "function",
+                                    "index": block_index,
+                                },
+                                call_type=LLMCallType.TOOL_CALL,
+                            )
+
             final_message: Message = stream.get_final_message()

             thinking_blocks: list[ThinkingBlock] = []
@@ -941,6 +992,8 @@ class AnthropicCompletion(BaseLLM):

         stream_params = {k: v for k, v in params.items() if k != "stream"}

+        current_tool_calls: dict[int, dict[str, Any]] = {}
+
         async with self.async_client.messages.stream(**stream_params) as stream:
             async for event in stream:
                 if hasattr(event, "delta") and hasattr(event.delta, "text"):
@@ -952,6 +1005,55 @@ class AnthropicCompletion(BaseLLM):
                         from_agent=from_agent,
                     )

+                if event.type == "content_block_start":
+                    block = event.content_block
+                    if block.type == "tool_use":
+                        block_index = event.index
+                        current_tool_calls[block_index] = {
+                            "id": block.id,
+                            "name": block.name,
+                            "arguments": "",
+                            "index": block_index,
+                        }
+                        self._emit_stream_chunk_event(
+                            chunk="",
+                            from_task=from_task,
+                            from_agent=from_agent,
+                            tool_call={
+                                "id": block.id,
+                                "function": {
+                                    "name": block.name,
+                                    "arguments": "",
+                                },
+                                "type": "function",
+                                "index": block_index,
+                            },
+                            call_type=LLMCallType.TOOL_CALL,
+                        )
+                elif event.type == "content_block_delta":
+                    if event.delta.type == "input_json_delta":
+                        block_index = event.index
+                        partial_json = event.delta.partial_json
+                        if block_index in current_tool_calls and partial_json:
+                            current_tool_calls[block_index]["arguments"] += partial_json
+                            self._emit_stream_chunk_event(
+                                chunk=partial_json,
+                                from_task=from_task,
+                                from_agent=from_agent,
+                                tool_call={
+                                    "id": current_tool_calls[block_index]["id"],
+                                    "function": {
+                                        "name": current_tool_calls[block_index]["name"],
+                                        "arguments": current_tool_calls[block_index][
+                                            "arguments"
+                                        ],
+                                    },
+                                    "type": "function",
+                                    "index": block_index,
+                                },
+                                call_type=LLMCallType.TOOL_CALL,
+                            )
+
             final_message: Message = await stream.get_final_message()

             usage = self._extract_anthropic_token_usage(final_message)
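A note on the Anthropic wiring above: `chunk` carries only the `input_json_delta` fragment (and is empty when the `tool_use` block starts), while `tool_call["function"]["arguments"]` always holds the JSON accumulated so far. A consumer can therefore keep just the latest snapshot per block index and parse it once the stream ends. A sketch, assuming `events` is a collected list of the chunk events emitted above:

    import json
    from typing import Any


    def collect_tool_calls(events: list[Any]) -> dict[int, dict[str, Any]]:
        # The arguments string in each snapshot is cumulative, so the most
        # recent event per index wins.
        latest: dict[int, dict[str, Any]] = {}
        for event in events:
            if event.tool_call is not None:
                latest[event.tool_call["index"]] = event.tool_call
        return {
            index: {
                "name": tc["function"]["name"],
                "args": json.loads(tc["function"]["arguments"] or "{}"),
            }
            for index, tc in latest.items()
        }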
@@ -674,7 +674,7 @@ class AzureCompletion(BaseLLM):
         self,
         update: StreamingChatCompletionsUpdate,
         full_response: str,
-        tool_calls: dict[str, dict[str, str]],
+        tool_calls: dict[str, dict[str, Any]],
         from_task: Any | None = None,
         from_agent: Any | None = None,
     ) -> str:
@@ -702,12 +702,13 @@ class AzureCompletion(BaseLLM):
             )

         if choice.delta and choice.delta.tool_calls:
-            for tool_call in choice.delta.tool_calls:
+            for idx, tool_call in enumerate(choice.delta.tool_calls):
                 call_id = tool_call.id or "default"
                 if call_id not in tool_calls:
                     tool_calls[call_id] = {
                         "name": "",
                         "arguments": "",
+                        "index": idx,
                     }

                 if tool_call.function and tool_call.function.name:
@@ -715,12 +716,30 @@ class AzureCompletion(BaseLLM):
                 if tool_call.function and tool_call.function.arguments:
                     tool_calls[call_id]["arguments"] += tool_call.function.arguments

+                self._emit_stream_chunk_event(
+                    chunk=tool_call.function.arguments
+                    if tool_call.function and tool_call.function.arguments
+                    else "",
+                    from_task=from_task,
+                    from_agent=from_agent,
+                    tool_call={
+                        "id": call_id,
+                        "function": {
+                            "name": tool_calls[call_id]["name"],
+                            "arguments": tool_calls[call_id]["arguments"],
+                        },
+                        "type": "function",
+                        "index": tool_calls[call_id]["index"],
+                    },
+                    call_type=LLMCallType.TOOL_CALL,
+                )
+
         return full_response

     def _finalize_streaming_response(
         self,
         full_response: str,
-        tool_calls: dict[str, dict[str, str]],
+        tool_calls: dict[str, dict[str, Any]],
         usage_data: dict[str, int],
         params: AzureCompletionParams,
         available_functions: dict[str, Any] | None = None,
@@ -315,9 +315,7 @@ class BedrockCompletion(BaseLLM):
             messages
         )

-        if not self._invoke_before_llm_call_hooks(
-            cast(list[LLMMessage], formatted_messages), from_agent
-        ):
+        if not self._invoke_before_llm_call_hooks(formatted_messages, from_agent):
             raise ValueError("LLM call blocked by before_llm_call hook")

         # Prepare request body
@@ -361,7 +359,7 @@ class BedrockCompletion(BaseLLM):

         if self.stream:
             return self._handle_streaming_converse(
-                cast(list[LLMMessage], formatted_messages),
+                formatted_messages,
                 body,
                 available_functions,
                 from_task,
@@ -369,7 +367,7 @@ class BedrockCompletion(BaseLLM):
             )

         return self._handle_converse(
-            cast(list[LLMMessage], formatted_messages),
+            formatted_messages,
             body,
             available_functions,
             from_task,
@@ -433,7 +431,7 @@ class BedrockCompletion(BaseLLM):
         )

         formatted_messages, system_message = self._format_messages_for_converse(
-            messages  # type: ignore[arg-type]
+            messages
         )

         body: BedrockConverseRequestBody = {
@@ -687,8 +685,10 @@ class BedrockCompletion(BaseLLM):
     ) -> str:
         """Handle streaming converse API call with comprehensive event handling."""
         full_response = ""
-        current_tool_use = None
-        tool_use_id = None
+        current_tool_use: dict[str, Any] | None = None
+        tool_use_id: str | None = None
+        tool_use_index = 0
+        accumulated_tool_input = ""

         try:
             response = self.client.converse_stream(
@@ -709,9 +709,30 @@ class BedrockCompletion(BaseLLM):

                 elif "contentBlockStart" in event:
                     start = event["contentBlockStart"].get("start", {})
+                    content_block_index = event["contentBlockStart"].get(
+                        "contentBlockIndex", 0
+                    )
                     if "toolUse" in start:
-                        current_tool_use = start["toolUse"]
+                        tool_use_block = start["toolUse"]
+                        current_tool_use = cast(dict[str, Any], tool_use_block)
                         tool_use_id = current_tool_use.get("toolUseId")
+                        tool_use_index = content_block_index
+                        accumulated_tool_input = ""
+                        self._emit_stream_chunk_event(
+                            chunk="",
+                            from_task=from_task,
+                            from_agent=from_agent,
+                            tool_call={
+                                "id": tool_use_id or "",
+                                "function": {
+                                    "name": current_tool_use.get("name", ""),
+                                    "arguments": "",
+                                },
+                                "type": "function",
+                                "index": tool_use_index,
+                            },
+                            call_type=LLMCallType.TOOL_CALL,
+                        )
                         logging.debug(
                             f"Tool use started in stream: {json.dumps(current_tool_use)} (ID: {tool_use_id})"
                         )
@@ -730,7 +751,23 @@ class BedrockCompletion(BaseLLM):
                     elif "toolUse" in delta and current_tool_use:
                         tool_input = delta["toolUse"].get("input", "")
                         if tool_input:
+                            accumulated_tool_input += tool_input
                             logging.debug(f"Tool input delta: {tool_input}")
+                            self._emit_stream_chunk_event(
+                                chunk=tool_input,
+                                from_task=from_task,
+                                from_agent=from_agent,
+                                tool_call={
+                                    "id": tool_use_id or "",
+                                    "function": {
+                                        "name": current_tool_use.get("name", ""),
+                                        "arguments": accumulated_tool_input,
+                                    },
+                                    "type": "function",
+                                    "index": tool_use_index,
+                                },
+                                call_type=LLMCallType.TOOL_CALL,
+                            )
                 elif "contentBlockStop" in event:
                     logging.debug("Content block stopped in stream")
                     if current_tool_use and available_functions:
@@ -848,7 +885,7 @@ class BedrockCompletion(BaseLLM):

     async def _ahandle_converse(
         self,
-        messages: list[dict[str, Any]],
+        messages: list[LLMMessage],
         body: BedrockConverseRequestBody,
         available_functions: Mapping[str, Any] | None = None,
         from_task: Any | None = None,
@@ -1013,7 +1050,7 @@ class BedrockCompletion(BaseLLM):

     async def _ahandle_streaming_converse(
         self,
-        messages: list[dict[str, Any]],
+        messages: list[LLMMessage],
         body: BedrockConverseRequestBody,
         available_functions: dict[str, Any] | None = None,
         from_task: Any | None = None,
@@ -1021,8 +1058,10 @@ class BedrockCompletion(BaseLLM):
     ) -> str:
         """Handle async streaming converse API call."""
         full_response = ""
-        current_tool_use = None
-        tool_use_id = None
+        current_tool_use: dict[str, Any] | None = None
+        tool_use_id: str | None = None
+        tool_use_index = 0
+        accumulated_tool_input = ""

         try:
             async_client = await self._ensure_async_client()
@@ -1044,9 +1083,30 @@ class BedrockCompletion(BaseLLM):

                 elif "contentBlockStart" in event:
                     start = event["contentBlockStart"].get("start", {})
+                    content_block_index = event["contentBlockStart"].get(
+                        "contentBlockIndex", 0
+                    )
                     if "toolUse" in start:
-                        current_tool_use = start["toolUse"]
+                        tool_use_block = start["toolUse"]
+                        current_tool_use = cast(dict[str, Any], tool_use_block)
                         tool_use_id = current_tool_use.get("toolUseId")
+                        tool_use_index = content_block_index
+                        accumulated_tool_input = ""
+                        self._emit_stream_chunk_event(
+                            chunk="",
+                            from_task=from_task,
+                            from_agent=from_agent,
+                            tool_call={
+                                "id": tool_use_id or "",
+                                "function": {
+                                    "name": current_tool_use.get("name", ""),
+                                    "arguments": "",
+                                },
+                                "type": "function",
+                                "index": tool_use_index,
+                            },
+                            call_type=LLMCallType.TOOL_CALL,
+                        )
                         logging.debug(
                             f"Tool use started in stream: {current_tool_use.get('name')} (ID: {tool_use_id})"
                         )
@@ -1065,7 +1125,23 @@ class BedrockCompletion(BaseLLM):
                     elif "toolUse" in delta and current_tool_use:
                         tool_input = delta["toolUse"].get("input", "")
                         if tool_input:
+                            accumulated_tool_input += tool_input
                             logging.debug(f"Tool input delta: {tool_input}")
+                            self._emit_stream_chunk_event(
+                                chunk=tool_input,
+                                from_task=from_task,
+                                from_agent=from_agent,
+                                tool_call={
+                                    "id": tool_use_id or "",
+                                    "function": {
+                                        "name": current_tool_use.get("name", ""),
+                                        "arguments": accumulated_tool_input,
+                                    },
+                                    "type": "function",
+                                    "index": tool_use_index,
+                                },
+                                call_type=LLMCallType.TOOL_CALL,
+                            )

                 elif "contentBlockStop" in event:
                     logging.debug("Content block stopped in stream")
@@ -1174,7 +1250,7 @@ class BedrockCompletion(BaseLLM):

     def _format_messages_for_converse(
         self, messages: str | list[LLMMessage]
-    ) -> tuple[list[dict[str, Any]], str | None]:
+    ) -> tuple[list[LLMMessage], str | None]:
         """Format messages for Converse API following AWS documentation.

         Note: Returns dict[str, Any] instead of LLMMessage because Bedrock uses
@@ -1184,7 +1260,7 @@ class BedrockCompletion(BaseLLM):
         # Use base class formatting first
         formatted_messages = self._format_messages(messages)

-        converse_messages: list[dict[str, Any]] = []
+        converse_messages: list[LLMMessage] = []
         system_message: str | None = None

         for message in formatted_messages:
@@ -1,5 +1,6 @@
 from __future__ import annotations

+import json
 import logging
 import os
 import re
@@ -677,17 +678,39 @@ class GeminiCompletion(BaseLLM):
         if chunk.candidates:
             candidate = chunk.candidates[0]
             if candidate.content and candidate.content.parts:
-                for part in candidate.content.parts:
+                for idx, part in enumerate(candidate.content.parts):
                     if hasattr(part, "function_call") and part.function_call:
                         call_id = part.function_call.name or "default"
+                        args_dict = (
+                            dict(part.function_call.args)
+                            if part.function_call.args
+                            else {}
+                        )
+                        args_json = json.dumps(args_dict)
+
                         if call_id not in function_calls:
                             function_calls[call_id] = {
                                 "name": part.function_call.name,
-                                "args": dict(part.function_call.args)
-                                if part.function_call.args
-                                else {},
+                                "args": args_dict,
+                                "index": idx,
                             }

+                        self._emit_stream_chunk_event(
+                            chunk=args_json,
+                            from_task=from_task,
+                            from_agent=from_agent,
+                            tool_call={
+                                "id": call_id,
+                                "function": {
+                                    "name": part.function_call.name or "",
+                                    "arguments": args_json,
+                                },
+                                "type": "function",
+                                "index": function_calls[call_id]["index"],
+                            },
+                            call_type=LLMCallType.TOOL_CALL,
+                        )
+
         return full_response, function_calls, usage_data

     def _finalize_streaming_response(
@@ -521,7 +521,7 @@ class OpenAICompletion(BaseLLM):
     ) -> str:
         """Handle streaming chat completion."""
         full_response = ""
-        tool_calls = {}
+        tool_calls: dict[str, dict[str, Any]] = {}

         if response_model:
             parse_params = {
@@ -592,10 +592,12 @@ class OpenAICompletion(BaseLLM):
                 if chunk_delta.tool_calls:
                     for tool_call in chunk_delta.tool_calls:
                         call_id = tool_call.id or "default"
+                        tool_index = tool_call.index if tool_call.index is not None else 0
                         if call_id not in tool_calls:
                             tool_calls[call_id] = {
                                 "name": "",
                                 "arguments": "",
+                                "index": tool_index,
                             }

                         if tool_call.function and tool_call.function.name:
@@ -603,6 +605,24 @@ class OpenAICompletion(BaseLLM):
                         if tool_call.function and tool_call.function.arguments:
                             tool_calls[call_id]["arguments"] += tool_call.function.arguments

+                        self._emit_stream_chunk_event(
+                            chunk=tool_call.function.arguments
+                            if tool_call.function and tool_call.function.arguments
+                            else "",
+                            from_task=from_task,
+                            from_agent=from_agent,
+                            tool_call={
+                                "id": call_id,
+                                "function": {
+                                    "name": tool_calls[call_id]["name"],
+                                    "arguments": tool_calls[call_id]["arguments"],
+                                },
+                                "type": "function",
+                                "index": tool_calls[call_id]["index"],
+                            },
+                            call_type=LLMCallType.TOOL_CALL,
+                        )
+
         self._track_token_usage_internal(usage_data)

         if tool_calls and available_functions:
@@ -789,7 +809,7 @@ class OpenAICompletion(BaseLLM):
     ) -> str:
         """Handle async streaming chat completion."""
         full_response = ""
-        tool_calls = {}
+        tool_calls: dict[str, dict[str, Any]] = {}

         if response_model:
             completion_stream: AsyncIterator[
@@ -871,10 +891,12 @@ class OpenAICompletion(BaseLLM):
                 if chunk_delta.tool_calls:
                     for tool_call in chunk_delta.tool_calls:
                         call_id = tool_call.id or "default"
+                        tool_index = tool_call.index if tool_call.index is not None else 0
                         if call_id not in tool_calls:
                             tool_calls[call_id] = {
                                 "name": "",
                                 "arguments": "",
+                                "index": tool_index,
                             }

                         if tool_call.function and tool_call.function.name:
@@ -882,6 +904,24 @@ class OpenAICompletion(BaseLLM):
                         if tool_call.function and tool_call.function.arguments:
                             tool_calls[call_id]["arguments"] += tool_call.function.arguments

+                        self._emit_stream_chunk_event(
+                            chunk=tool_call.function.arguments
+                            if tool_call.function and tool_call.function.arguments
+                            else "",
+                            from_task=from_task,
+                            from_agent=from_agent,
+                            tool_call={
+                                "id": call_id,
+                                "function": {
+                                    "name": tool_calls[call_id]["name"],
+                                    "arguments": tool_calls[call_id]["arguments"],
+                                },
+                                "type": "function",
+                                "index": tool_calls[call_id]["index"],
+                            },
+                            call_type=LLMCallType.TOOL_CALL,
+                        )
+
         self._track_token_usage_internal(usage_data)

         if tool_calls and available_functions:
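Taken together, all five providers in this commit normalize tool-call chunks to the same OpenAI-style dict: {"id": ..., "type": "function", "index": ..., "function": {"name": ..., "arguments": ...}}, so downstream code can stay provider-agnostic. One illustrative helper (not part of the commit) that tolerates the partial JSON prefixes seen mid-stream:

    import json
    from typing import Any


    def describe_tool_chunk(tool_call: dict[str, Any]) -> str:
        # Accumulated arguments are usually an incomplete JSON prefix until
        # the final chunk of the block arrives.
        fn = tool_call["function"]
        try:
            args: Any = json.loads(fn["arguments"]) if fn["arguments"] else {}
            status = "complete"
        except json.JSONDecodeError:
            args = fn["arguments"]  # still streaming; show the raw prefix
            status = "partial"
        return f"#{tool_call['index']} {fn['name']} ({status}): {args!r}"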