feat: add support for streaming tool call events

Author: Greyson Lalonde
Date:   2025-12-11 02:40:34 -05:00
parent bdafe0fac7
commit 5ebc3d4a6c
6 changed files with 296 additions and 26 deletions

View File

@@ -354,8 +354,17 @@ class BaseLLM(ABC):
from_task: Task | None = None,
from_agent: Agent | None = None,
tool_call: dict[str, Any] | None = None,
call_type: LLMCallType | None = None,
) -> None:
"""Emit stream chunk event."""
"""Emit stream chunk event.
Args:
chunk: The text content of the chunk.
from_task: The task that initiated the call.
from_agent: The agent that initiated the call.
tool_call: Tool call information if this is a tool call chunk.
call_type: The type of LLM call (LLM_CALL or TOOL_CALL).
"""
if not hasattr(crewai_event_bus, "emit"):
raise ValueError("crewai_event_bus does not have an emit method") from None
@@ -366,6 +375,7 @@ class BaseLLM(ABC):
tool_call=tool_call,
from_task=from_task,
from_agent=from_agent,
call_type=call_type,
),
)
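
Every provider below funnels its tool-call deltas through this one helper, so a single listener observes them regardless of backend. A minimal consumer sketch, assuming the event class is named LLMStreamChunkEvent and that crewai_event_bus exposes an on(event_type) decorator (neither name is visible in this hunk):

from crewai.events import crewai_event_bus, LLMStreamChunkEvent  # assumed import path

tool_args: dict[int, str] = {}  # accumulated argument JSON, keyed by tool-call index

@crewai_event_bus.on(LLMStreamChunkEvent)
def handle_chunk(source, event) -> None:
    if event.tool_call is None:
        print(event.chunk, end="")  # ordinary text delta
        return
    # event.chunk holds only the newest fragment; the tool_call payload
    # carries the arguments accumulated so far.
    tool_args[event.tool_call["index"]] = event.tool_call["function"]["arguments"]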

View File

@@ -598,6 +598,8 @@ class AnthropicCompletion(BaseLLM):
# (the SDK sets it internally)
stream_params = {k: v for k, v in params.items() if k != "stream"}
current_tool_calls: dict[int, dict[str, Any]] = {}
# Make streaming API call
with self.client.messages.stream(**stream_params) as stream:
for event in stream:
@@ -610,6 +612,55 @@ class AnthropicCompletion(BaseLLM):
from_agent=from_agent,
)
if event.type == "content_block_start":
block = event.content_block
if block.type == "tool_use":
block_index = event.index
current_tool_calls[block_index] = {
"id": block.id,
"name": block.name,
"arguments": "",
"index": block_index,
}
self._emit_stream_chunk_event(
chunk="",
from_task=from_task,
from_agent=from_agent,
tool_call={
"id": block.id,
"function": {
"name": block.name,
"arguments": "",
},
"type": "function",
"index": block_index,
},
call_type=LLMCallType.TOOL_CALL,
)
elif event.type == "content_block_delta":
if event.delta.type == "input_json_delta":
block_index = event.index
partial_json = event.delta.partial_json
if block_index in current_tool_calls and partial_json:
current_tool_calls[block_index]["arguments"] += partial_json
self._emit_stream_chunk_event(
chunk=partial_json,
from_task=from_task,
from_agent=from_agent,
tool_call={
"id": current_tool_calls[block_index]["id"],
"function": {
"name": current_tool_calls[block_index]["name"],
"arguments": current_tool_calls[block_index][
"arguments"
],
},
"type": "function",
"index": block_index,
},
call_type=LLMCallType.TOOL_CALL,
)
final_message: Message = stream.get_final_message()
thinking_blocks: list[ThinkingBlock] = []
@@ -941,6 +992,8 @@ class AnthropicCompletion(BaseLLM):
stream_params = {k: v for k, v in params.items() if k != "stream"}
current_tool_calls: dict[int, dict[str, Any]] = {}
async with self.async_client.messages.stream(**stream_params) as stream:
async for event in stream:
if hasattr(event, "delta") and hasattr(event.delta, "text"):
@@ -952,6 +1005,55 @@ class AnthropicCompletion(BaseLLM):
from_agent=from_agent,
)
if event.type == "content_block_start":
block = event.content_block
if block.type == "tool_use":
block_index = event.index
current_tool_calls[block_index] = {
"id": block.id,
"name": block.name,
"arguments": "",
"index": block_index,
}
self._emit_stream_chunk_event(
chunk="",
from_task=from_task,
from_agent=from_agent,
tool_call={
"id": block.id,
"function": {
"name": block.name,
"arguments": "",
},
"type": "function",
"index": block_index,
},
call_type=LLMCallType.TOOL_CALL,
)
elif event.type == "content_block_delta":
if event.delta.type == "input_json_delta":
block_index = event.index
partial_json = event.delta.partial_json
if block_index in current_tool_calls and partial_json:
current_tool_calls[block_index]["arguments"] += partial_json
self._emit_stream_chunk_event(
chunk=partial_json,
from_task=from_task,
from_agent=from_agent,
tool_call={
"id": current_tool_calls[block_index]["id"],
"function": {
"name": current_tool_calls[block_index]["name"],
"arguments": current_tool_calls[block_index][
"arguments"
],
},
"type": "function",
"index": block_index,
},
call_type=LLMCallType.TOOL_CALL,
)
final_message: Message = await stream.get_final_message()
usage = self._extract_anthropic_token_usage(final_message)
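
Anthropic splits a tool call across a content_block_start event (carrying the id and name with empty arguments) and a series of input_json_delta events carrying JSON fragments, which is what the two new branches above re-emit. A self-contained sketch of the accumulation, with hand-built stand-ins for the SDK events:

from types import SimpleNamespace as NS

# Stand-in events mirroring the Anthropic stream shape handled above.
events = [
    NS(type="content_block_start", index=0,
       content_block=NS(type="tool_use", id="toolu_01", name="get_weather")),
    NS(type="content_block_delta", index=0,
       delta=NS(type="input_json_delta", partial_json='{"city": "Par')),
    NS(type="content_block_delta", index=0,
       delta=NS(type="input_json_delta", partial_json='is"}')),
]

current_tool_calls: dict[int, dict] = {}
for event in events:
    if event.type == "content_block_start" and event.content_block.type == "tool_use":
        current_tool_calls[event.index] = {
            "id": event.content_block.id,
            "name": event.content_block.name,
            "arguments": "",
        }
    elif event.type == "content_block_delta" and event.delta.type == "input_json_delta":
        current_tool_calls[event.index]["arguments"] += event.delta.partial_json

print(current_tool_calls[0]["arguments"])  # {"city": "Paris"}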

View File

@@ -674,7 +674,7 @@ class AzureCompletion(BaseLLM):
self,
update: StreamingChatCompletionsUpdate,
full_response: str,
tool_calls: dict[str, dict[str, str]],
tool_calls: dict[str, dict[str, Any]],
from_task: Any | None = None,
from_agent: Any | None = None,
) -> str:
@@ -702,12 +702,13 @@ class AzureCompletion(BaseLLM):
)
if choice.delta and choice.delta.tool_calls:
for tool_call in choice.delta.tool_calls:
for idx, tool_call in enumerate(choice.delta.tool_calls):
call_id = tool_call.id or "default"
if call_id not in tool_calls:
tool_calls[call_id] = {
"name": "",
"arguments": "",
"index": idx,
}
if tool_call.function and tool_call.function.name:
@@ -715,12 +716,30 @@ class AzureCompletion(BaseLLM):
if tool_call.function and tool_call.function.arguments:
tool_calls[call_id]["arguments"] += tool_call.function.arguments
self._emit_stream_chunk_event(
chunk=tool_call.function.arguments
if tool_call.function and tool_call.function.arguments
else "",
from_task=from_task,
from_agent=from_agent,
tool_call={
"id": call_id,
"function": {
"name": tool_calls[call_id]["name"],
"arguments": tool_calls[call_id]["arguments"],
},
"type": "function",
"index": tool_calls[call_id]["index"],
},
call_type=LLMCallType.TOOL_CALL,
)
return full_response
def _finalize_streaming_response(
self,
full_response: str,
tool_calls: dict[str, dict[str, str]],
tool_calls: dict[str, dict[str, Any]],
usage_data: dict[str, int],
params: AzureCompletionParams,
available_functions: dict[str, Any] | None = None,
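
Each argument delta above triggers one event whose chunk is only the newest fragment, while tool_call["function"]["arguments"] carries the running accumulation; the enumerate index is recorded once, when the entry is first created. A sketch of that contract (assuming, for simplicity, that the service repeats the call id on continuation deltas):

from types import SimpleNamespace as NS

deltas = [
    NS(id="call_1", function=NS(name="lookup", arguments='{"q":')),
    NS(id="call_1", function=NS(name=None, arguments=' "rust"}')),
]

tool_calls: dict[str, dict] = {}
emitted: list[tuple[str, str]] = []  # (chunk, accumulated arguments)
for idx, tc in enumerate(deltas):
    call_id = tc.id or "default"  # id-less continuations pool under "default"
    entry = tool_calls.setdefault(call_id, {"name": "", "arguments": "", "index": idx})
    if tc.function.name:
        entry["name"] = tc.function.name
    if tc.function.arguments:
        entry["arguments"] += tc.function.arguments
        emitted.append((tc.function.arguments, entry["arguments"]))

print(emitted)  # [('{"q":', '{"q":'), (' "rust"}', '{"q": "rust"}')]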

View File

@@ -315,9 +315,7 @@ class BedrockCompletion(BaseLLM):
messages
)
if not self._invoke_before_llm_call_hooks(
cast(list[LLMMessage], formatted_messages), from_agent
):
if not self._invoke_before_llm_call_hooks(formatted_messages, from_agent):
raise ValueError("LLM call blocked by before_llm_call hook")
# Prepare request body
@@ -361,7 +359,7 @@ class BedrockCompletion(BaseLLM):
if self.stream:
return self._handle_streaming_converse(
cast(list[LLMMessage], formatted_messages),
formatted_messages,
body,
available_functions,
from_task,
@@ -369,7 +367,7 @@ class BedrockCompletion(BaseLLM):
)
return self._handle_converse(
cast(list[LLMMessage], formatted_messages),
formatted_messages,
body,
available_functions,
from_task,
@@ -433,7 +431,7 @@ class BedrockCompletion(BaseLLM):
)
formatted_messages, system_message = self._format_messages_for_converse(
messages # type: ignore[arg-type]
messages
)
body: BedrockConverseRequestBody = {
@@ -687,8 +685,10 @@ class BedrockCompletion(BaseLLM):
) -> str:
"""Handle streaming converse API call with comprehensive event handling."""
full_response = ""
current_tool_use = None
tool_use_id = None
current_tool_use: dict[str, Any] | None = None
tool_use_id: str | None = None
tool_use_index = 0
accumulated_tool_input = ""
try:
response = self.client.converse_stream(
@@ -709,9 +709,30 @@ class BedrockCompletion(BaseLLM):
elif "contentBlockStart" in event:
start = event["contentBlockStart"].get("start", {})
content_block_index = event["contentBlockStart"].get(
"contentBlockIndex", 0
)
if "toolUse" in start:
current_tool_use = start["toolUse"]
tool_use_block = start["toolUse"]
current_tool_use = cast(dict[str, Any], tool_use_block)
tool_use_id = current_tool_use.get("toolUseId")
tool_use_index = content_block_index
accumulated_tool_input = ""
self._emit_stream_chunk_event(
chunk="",
from_task=from_task,
from_agent=from_agent,
tool_call={
"id": tool_use_id or "",
"function": {
"name": current_tool_use.get("name", ""),
"arguments": "",
},
"type": "function",
"index": tool_use_index,
},
call_type=LLMCallType.TOOL_CALL,
)
logging.debug(
f"Tool use started in stream: {json.dumps(current_tool_use)} (ID: {tool_use_id})"
)
@@ -730,7 +751,23 @@ class BedrockCompletion(BaseLLM):
elif "toolUse" in delta and current_tool_use:
tool_input = delta["toolUse"].get("input", "")
if tool_input:
accumulated_tool_input += tool_input
logging.debug(f"Tool input delta: {tool_input}")
self._emit_stream_chunk_event(
chunk=tool_input,
from_task=from_task,
from_agent=from_agent,
tool_call={
"id": tool_use_id or "",
"function": {
"name": current_tool_use.get("name", ""),
"arguments": accumulated_tool_input,
},
"type": "function",
"index": tool_use_index,
},
call_type=LLMCallType.TOOL_CALL,
)
elif "contentBlockStop" in event:
logging.debug("Content block stopped in stream")
if current_tool_use and available_functions:
@@ -848,7 +885,7 @@ class BedrockCompletion(BaseLLM):
async def _ahandle_converse(
self,
messages: list[dict[str, Any]],
messages: list[LLMMessage],
body: BedrockConverseRequestBody,
available_functions: Mapping[str, Any] | None = None,
from_task: Any | None = None,
@@ -1013,7 +1050,7 @@ class BedrockCompletion(BaseLLM):
async def _ahandle_streaming_converse(
self,
messages: list[dict[str, Any]],
messages: list[LLMMessage],
body: BedrockConverseRequestBody,
available_functions: dict[str, Any] | None = None,
from_task: Any | None = None,
@@ -1021,8 +1058,10 @@ class BedrockCompletion(BaseLLM):
) -> str:
"""Handle async streaming converse API call."""
full_response = ""
current_tool_use = None
tool_use_id = None
current_tool_use: dict[str, Any] | None = None
tool_use_id: str | None = None
tool_use_index = 0
accumulated_tool_input = ""
try:
async_client = await self._ensure_async_client()
@@ -1044,9 +1083,30 @@ class BedrockCompletion(BaseLLM):
elif "contentBlockStart" in event:
start = event["contentBlockStart"].get("start", {})
content_block_index = event["contentBlockStart"].get(
"contentBlockIndex", 0
)
if "toolUse" in start:
current_tool_use = start["toolUse"]
tool_use_block = start["toolUse"]
current_tool_use = cast(dict[str, Any], tool_use_block)
tool_use_id = current_tool_use.get("toolUseId")
tool_use_index = content_block_index
accumulated_tool_input = ""
self._emit_stream_chunk_event(
chunk="",
from_task=from_task,
from_agent=from_agent,
tool_call={
"id": tool_use_id or "",
"function": {
"name": current_tool_use.get("name", ""),
"arguments": "",
},
"type": "function",
"index": tool_use_index,
},
call_type=LLMCallType.TOOL_CALL,
)
logging.debug(
f"Tool use started in stream: {current_tool_use.get('name')} (ID: {tool_use_id})"
)
@@ -1065,7 +1125,23 @@ class BedrockCompletion(BaseLLM):
elif "toolUse" in delta and current_tool_use:
tool_input = delta["toolUse"].get("input", "")
if tool_input:
accumulated_tool_input += tool_input
logging.debug(f"Tool input delta: {tool_input}")
self._emit_stream_chunk_event(
chunk=tool_input,
from_task=from_task,
from_agent=from_agent,
tool_call={
"id": tool_use_id or "",
"function": {
"name": current_tool_use.get("name", ""),
"arguments": accumulated_tool_input,
},
"type": "function",
"index": tool_use_index,
},
call_type=LLMCallType.TOOL_CALL,
)
elif "contentBlockStop" in event:
logging.debug("Content block stopped in stream")
@@ -1174,7 +1250,7 @@ class BedrockCompletion(BaseLLM):
def _format_messages_for_converse(
self, messages: str | list[LLMMessage]
) -> tuple[list[dict[str, Any]], str | None]:
) -> tuple[list[LLMMessage], str | None]:
"""Format messages for Converse API following AWS documentation.
Note: Returns dict[str, Any] instead of LLMMessage because Bedrock uses
@@ -1184,7 +1260,7 @@ class BedrockCompletion(BaseLLM):
# Use base class formatting first
formatted_messages = self._format_messages(messages)
converse_messages: list[dict[str, Any]] = []
converse_messages: list[LLMMessage] = []
system_message: str | None = None
for message in formatted_messages:
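
Bedrock's converse_stream yields plain dicts: contentBlockStart announces the tool (toolUseId, name, and a contentBlockIndex), and each subsequent contentBlockDelta carries an input string fragment that the code above folds into accumulated_tool_input. A stand-alone sketch of that event shape:

events = [
    {"contentBlockStart": {"contentBlockIndex": 1,
                           "start": {"toolUse": {"toolUseId": "tu_01", "name": "search"}}}},
    {"contentBlockDelta": {"delta": {"toolUse": {"input": '{"query": "str'}}}},
    {"contentBlockDelta": {"delta": {"toolUse": {"input": 'eaming"}'}}}},
    {"contentBlockStop": {}},
]

current_tool_use: dict | None = None
tool_use_id: str | None = None
tool_use_index = 0
accumulated_tool_input = ""
for event in events:
    if "contentBlockStart" in event:
        start = event["contentBlockStart"].get("start", {})
        if "toolUse" in start:
            current_tool_use = start["toolUse"]
            tool_use_id = current_tool_use.get("toolUseId")
            tool_use_index = event["contentBlockStart"].get("contentBlockIndex", 0)
            accumulated_tool_input = ""
    elif "contentBlockDelta" in event:
        delta = event["contentBlockDelta"].get("delta", {})
        if "toolUse" in delta and current_tool_use:
            accumulated_tool_input += delta["toolUse"].get("input", "")

print(tool_use_id, tool_use_index, accumulated_tool_input)
# tu_01 1 {"query": "streaming"}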

View File

@@ -1,5 +1,6 @@
from __future__ import annotations
import json
import logging
import os
import re
@@ -677,17 +678,39 @@ class GeminiCompletion(BaseLLM):
if chunk.candidates:
candidate = chunk.candidates[0]
if candidate.content and candidate.content.parts:
for part in candidate.content.parts:
for idx, part in enumerate(candidate.content.parts):
if hasattr(part, "function_call") and part.function_call:
call_id = part.function_call.name or "default"
args_dict = (
dict(part.function_call.args)
if part.function_call.args
else {}
)
args_json = json.dumps(args_dict)
if call_id not in function_calls:
function_calls[call_id] = {
"name": part.function_call.name,
"args": dict(part.function_call.args)
if part.function_call.args
else {},
"args": args_dict,
"index": idx,
}
self._emit_stream_chunk_event(
chunk=args_json,
from_task=from_task,
from_agent=from_agent,
tool_call={
"id": call_id,
"function": {
"name": part.function_call.name or "",
"arguments": args_json,
},
"type": "function",
"index": function_calls[call_id]["index"],
},
call_type=LLMCallType.TOOL_CALL,
)
return full_response, function_calls, usage_data
def _finalize_streaming_response(
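
Unlike the delta-based providers, Gemini delivers each function_call with its arguments already complete, so the handler serializes them once with json.dumps (hence the new top-level import json) and emits the whole argument string as a single tool-call chunk. A sketch with a stand-in part object:

import json
from types import SimpleNamespace as NS

parts = [NS(function_call=NS(name="get_weather", args={"city": "Paris"}))]

function_calls: dict[str, dict] = {}
for idx, part in enumerate(parts):
    if part.function_call:
        # Gemini exposes no call id, so the function name stands in for one.
        call_id = part.function_call.name or "default"
        args_dict = dict(part.function_call.args) if part.function_call.args else {}
        args_json = json.dumps(args_dict)  # emitted as one complete chunk
        function_calls.setdefault(call_id, {"name": part.function_call.name,
                                            "args": args_dict, "index": idx})

print(function_calls["get_weather"])
# {'name': 'get_weather', 'args': {'city': 'Paris'}, 'index': 0}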

View File

@@ -521,7 +521,7 @@ class OpenAICompletion(BaseLLM):
) -> str:
"""Handle streaming chat completion."""
full_response = ""
tool_calls = {}
tool_calls: dict[str, dict[str, Any]] = {}
if response_model:
parse_params = {
@@ -592,10 +592,12 @@ class OpenAICompletion(BaseLLM):
if chunk_delta.tool_calls:
for tool_call in chunk_delta.tool_calls:
call_id = tool_call.id or "default"
tool_index = tool_call.index if tool_call.index is not None else 0
if call_id not in tool_calls:
tool_calls[call_id] = {
"name": "",
"arguments": "",
"index": tool_index,
}
if tool_call.function and tool_call.function.name:
@@ -603,6 +605,24 @@ class OpenAICompletion(BaseLLM):
if tool_call.function and tool_call.function.arguments:
tool_calls[call_id]["arguments"] += tool_call.function.arguments
self._emit_stream_chunk_event(
chunk=tool_call.function.arguments
if tool_call.function and tool_call.function.arguments
else "",
from_task=from_task,
from_agent=from_agent,
tool_call={
"id": call_id,
"function": {
"name": tool_calls[call_id]["name"],
"arguments": tool_calls[call_id]["arguments"],
},
"type": "function",
"index": tool_calls[call_id]["index"],
},
call_type=LLMCallType.TOOL_CALL,
)
self._track_token_usage_internal(usage_data)
if tool_calls and available_functions:
@@ -789,7 +809,7 @@ class OpenAICompletion(BaseLLM):
) -> str:
"""Handle async streaming chat completion."""
full_response = ""
tool_calls = {}
tool_calls: dict[str, dict[str, Any]] = {}
if response_model:
completion_stream: AsyncIterator[
@@ -871,10 +891,12 @@ class OpenAICompletion(BaseLLM):
if chunk_delta.tool_calls:
for tool_call in chunk_delta.tool_calls:
call_id = tool_call.id or "default"
tool_index = tool_call.index if tool_call.index is not None else 0
if call_id not in tool_calls:
tool_calls[call_id] = {
"name": "",
"arguments": "",
"index": tool_index,
}
if tool_call.function and tool_call.function.name:
@@ -882,6 +904,24 @@ class OpenAICompletion(BaseLLM):
if tool_call.function and tool_call.function.arguments:
tool_calls[call_id]["arguments"] += tool_call.function.arguments
self._emit_stream_chunk_event(
chunk=tool_call.function.arguments
if tool_call.function and tool_call.function.arguments
else "",
from_task=from_task,
from_agent=from_agent,
tool_call={
"id": call_id,
"function": {
"name": tool_calls[call_id]["name"],
"arguments": tool_calls[call_id]["arguments"],
},
"type": "function",
"index": tool_calls[call_id]["index"],
},
call_type=LLMCallType.TOOL_CALL,
)
self._track_token_usage_internal(usage_data)
if tool_calls and available_functions:
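
Taken together, all five providers normalize their streaming tool calls to the same OpenAI-style payload before emitting. A sketch pinning that contract down as TypedDicts (hypothetical names; the codebase may type this differently):

from typing import TypedDict

class ToolCallFunction(TypedDict):
    name: str       # may still be empty on early chunks
    arguments: str  # JSON accumulated so far, not just the newest fragment

class StreamedToolCall(TypedDict):
    id: str
    function: ToolCallFunction
    type: str       # always "function"
    index: int      # provider-reported content-block / delta index

example: StreamedToolCall = {
    "id": "call_1",
    "function": {"name": "lookup", "arguments": '{"q": "rust"}'},
    "type": "function",
    "index": 0,
}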