mirror of https://github.com/crewAIInc/crewAI.git (synced 2026-01-08 15:48:29 +00:00)
feat: add support for streaming tool call events
@@ -354,8 +354,17 @@ class BaseLLM(ABC):
         from_task: Task | None = None,
         from_agent: Agent | None = None,
+        tool_call: dict[str, Any] | None = None,
         call_type: LLMCallType | None = None,
     ) -> None:
-        """Emit stream chunk event."""
+        """Emit stream chunk event.
+
+        Args:
+            chunk: The text content of the chunk.
+            from_task: The task that initiated the call.
+            from_agent: The agent that initiated the call.
+            tool_call: Tool call information if this is a tool call chunk.
+            call_type: The type of LLM call (LLM_CALL or TOOL_CALL).
+        """
         if not hasattr(crewai_event_bus, "emit"):
             raise ValueError("crewai_event_bus does not have an emit method") from None

@@ -366,6 +375,7 @@ class BaseLLM(ABC):
+                tool_call=tool_call,
                 from_task=from_task,
                 from_agent=from_agent,
                 call_type=call_type,
             ),
         )
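For context on how these emissions are consumed: below is a minimal listener sketch, assuming the bus exposes the decorator-style `on` registration and that the stream-chunk event class is `LLMStreamChunkEvent` carrying the `chunk`, `tool_call`, and `call_type` fields passed above (import paths vary across crewAI versions, so treat them as assumptions):

from crewai.events import crewai_event_bus
from crewai.events.types.llm_events import LLMStreamChunkEvent

# Accumulated argument JSON per tool-call id, rebuilt from the events.
tool_args: dict[str, str] = {}

@crewai_event_bus.on(LLMStreamChunkEvent)
def on_chunk(source, event):
    if event.tool_call is not None:
        # Each tool-call event carries the arguments accumulated so far,
        # so overwriting is enough to end up with the complete JSON.
        tool_args[event.tool_call["id"]] = event.tool_call["function"]["arguments"]
    else:
        print(event.chunk, end="")  # plain text delta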
@@ -598,6 +598,8 @@ class AnthropicCompletion(BaseLLM):
         # (the SDK sets it internally)
         stream_params = {k: v for k, v in params.items() if k != "stream"}

+        current_tool_calls: dict[int, dict[str, Any]] = {}
+
         # Make streaming API call
         with self.client.messages.stream(**stream_params) as stream:
             for event in stream:

@@ -610,6 +612,55 @@ class AnthropicCompletion(BaseLLM):
                         from_agent=from_agent,
                     )

+                if event.type == "content_block_start":
+                    block = event.content_block
+                    if block.type == "tool_use":
+                        block_index = event.index
+                        current_tool_calls[block_index] = {
+                            "id": block.id,
+                            "name": block.name,
+                            "arguments": "",
+                            "index": block_index,
+                        }
+                        self._emit_stream_chunk_event(
+                            chunk="",
+                            from_task=from_task,
+                            from_agent=from_agent,
+                            tool_call={
+                                "id": block.id,
+                                "function": {
+                                    "name": block.name,
+                                    "arguments": "",
+                                },
+                                "type": "function",
+                                "index": block_index,
+                            },
+                            call_type=LLMCallType.TOOL_CALL,
+                        )
+                elif event.type == "content_block_delta":
+                    if event.delta.type == "input_json_delta":
+                        block_index = event.index
+                        partial_json = event.delta.partial_json
+                        if block_index in current_tool_calls and partial_json:
+                            current_tool_calls[block_index]["arguments"] += partial_json
+                            self._emit_stream_chunk_event(
+                                chunk=partial_json,
+                                from_task=from_task,
+                                from_agent=from_agent,
+                                tool_call={
+                                    "id": current_tool_calls[block_index]["id"],
+                                    "function": {
+                                        "name": current_tool_calls[block_index]["name"],
+                                        "arguments": current_tool_calls[block_index][
+                                            "arguments"
+                                        ],
+                                    },
+                                    "type": "function",
+                                    "index": block_index,
+                                },
+                                call_type=LLMCallType.TOOL_CALL,
+                            )
+
             final_message: Message = stream.get_final_message()

             thinking_blocks: list[ThinkingBlock] = []
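A standalone illustration (not part of the diff) of what the `input_json_delta` branch accumulates: Anthropic streams a tool's arguments as raw JSON fragments, and only the concatenation of all fragments is valid JSON. The fragment values here are fabricated:

import json

# Fabricated partial_json fragments, as Anthropic's input_json_delta
# events would deliver them for one tool_use content block.
fragments = ['{"cit', 'y": "Par', 'is"}']

arguments = ""
for partial_json in fragments:
    arguments += partial_json  # same concatenation the handler performs
    # intermediate states like '{"cit' are not parseable JSON

print(json.loads(arguments))   # {'city': 'Paris'}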
@@ -941,6 +992,8 @@ class AnthropicCompletion(BaseLLM):

         stream_params = {k: v for k, v in params.items() if k != "stream"}

+        current_tool_calls: dict[int, dict[str, Any]] = {}
+
         async with self.async_client.messages.stream(**stream_params) as stream:
             async for event in stream:
                 if hasattr(event, "delta") and hasattr(event.delta, "text"):

@@ -952,6 +1005,55 @@ class AnthropicCompletion(BaseLLM):
                         from_agent=from_agent,
                     )

+                if event.type == "content_block_start":
+                    block = event.content_block
+                    if block.type == "tool_use":
+                        block_index = event.index
+                        current_tool_calls[block_index] = {
+                            "id": block.id,
+                            "name": block.name,
+                            "arguments": "",
+                            "index": block_index,
+                        }
+                        self._emit_stream_chunk_event(
+                            chunk="",
+                            from_task=from_task,
+                            from_agent=from_agent,
+                            tool_call={
+                                "id": block.id,
+                                "function": {
+                                    "name": block.name,
+                                    "arguments": "",
+                                },
+                                "type": "function",
+                                "index": block_index,
+                            },
+                            call_type=LLMCallType.TOOL_CALL,
+                        )
+                elif event.type == "content_block_delta":
+                    if event.delta.type == "input_json_delta":
+                        block_index = event.index
+                        partial_json = event.delta.partial_json
+                        if block_index in current_tool_calls and partial_json:
+                            current_tool_calls[block_index]["arguments"] += partial_json
+                            self._emit_stream_chunk_event(
+                                chunk=partial_json,
+                                from_task=from_task,
+                                from_agent=from_agent,
+                                tool_call={
+                                    "id": current_tool_calls[block_index]["id"],
+                                    "function": {
+                                        "name": current_tool_calls[block_index]["name"],
+                                        "arguments": current_tool_calls[block_index][
+                                            "arguments"
+                                        ],
+                                    },
+                                    "type": "function",
+                                    "index": block_index,
+                                },
+                                call_type=LLMCallType.TOOL_CALL,
+                            )
+
             final_message: Message = await stream.get_final_message()

             usage = self._extract_anthropic_token_usage(final_message)
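Across both Anthropic handlers (and the other providers below), the `tool_call` payload handed to `_emit_stream_chunk_event` has the same OpenAI-style shape. A descriptive TypedDict sketch of that shape, added here for documentation only; the diff itself passes plain dicts:

from typing import TypedDict

class ToolCallFunction(TypedDict):
    name: str       # tool/function name; may be "" on the opening chunk
    arguments: str  # JSON accumulated so far; grows with each event

class StreamedToolCall(TypedDict):
    id: str                     # provider call id ("" when unavailable)
    function: ToolCallFunction
    type: str                   # always "function" in these handlers
    index: int                  # content-block / tool-call position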
@@ -674,7 +674,7 @@ class AzureCompletion(BaseLLM):
         self,
         update: StreamingChatCompletionsUpdate,
         full_response: str,
-        tool_calls: dict[str, dict[str, str]],
+        tool_calls: dict[str, dict[str, Any]],
         from_task: Any | None = None,
         from_agent: Any | None = None,
     ) -> str:

@@ -702,12 +702,13 @@ class AzureCompletion(BaseLLM):
             )

         if choice.delta and choice.delta.tool_calls:
-            for tool_call in choice.delta.tool_calls:
+            for idx, tool_call in enumerate(choice.delta.tool_calls):
                 call_id = tool_call.id or "default"
                 if call_id not in tool_calls:
                     tool_calls[call_id] = {
                         "name": "",
                         "arguments": "",
+                        "index": idx,
                     }

                 if tool_call.function and tool_call.function.name:

@@ -715,12 +716,30 @@ class AzureCompletion(BaseLLM):
                 if tool_call.function and tool_call.function.arguments:
                     tool_calls[call_id]["arguments"] += tool_call.function.arguments

+                self._emit_stream_chunk_event(
+                    chunk=tool_call.function.arguments
+                    if tool_call.function and tool_call.function.arguments
+                    else "",
+                    from_task=from_task,
+                    from_agent=from_agent,
+                    tool_call={
+                        "id": call_id,
+                        "function": {
+                            "name": tool_calls[call_id]["name"],
+                            "arguments": tool_calls[call_id]["arguments"],
+                        },
+                        "type": "function",
+                        "index": tool_calls[call_id]["index"],
+                    },
+                    call_type=LLMCallType.TOOL_CALL,
+                )

         return full_response

     def _finalize_streaming_response(
         self,
         full_response: str,
-        tool_calls: dict[str, dict[str, str]],
+        tool_calls: dict[str, dict[str, Any]],
         usage_data: dict[str, int],
         params: AzureCompletionParams,
         available_functions: dict[str, Any] | None = None,
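Because every emitted payload carries the cumulative arguments rather than just the latest fragment, a consumer can rebuild complete calls with last-write-wins bookkeeping. A small sketch of hypothetical consumer code, not from the commit:

from typing import Any

def collect_tool_calls(payloads: list[dict[str, Any]]) -> dict[str, dict[str, Any]]:
    """Reduce a stream of tool_call payloads to the final call per id."""
    calls: dict[str, dict[str, Any]] = {}
    for payload in payloads:
        # "arguments" is cumulative in each payload, so the latest
        # payload for an id supersedes the earlier ones.
        calls[payload["id"]] = payload
    return calls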
@@ -315,9 +315,7 @@ class BedrockCompletion(BaseLLM):
             messages
         )

-        if not self._invoke_before_llm_call_hooks(
-            cast(list[LLMMessage], formatted_messages), from_agent
-        ):
+        if not self._invoke_before_llm_call_hooks(formatted_messages, from_agent):
             raise ValueError("LLM call blocked by before_llm_call hook")

         # Prepare request body

@@ -361,7 +359,7 @@ class BedrockCompletion(BaseLLM):

         if self.stream:
             return self._handle_streaming_converse(
-                cast(list[LLMMessage], formatted_messages),
+                formatted_messages,
                 body,
                 available_functions,
                 from_task,

@@ -369,7 +367,7 @@ class BedrockCompletion(BaseLLM):
             )

         return self._handle_converse(
-            cast(list[LLMMessage], formatted_messages),
+            formatted_messages,
             body,
             available_functions,
             from_task,

@@ -433,7 +431,7 @@ class BedrockCompletion(BaseLLM):
         )

         formatted_messages, system_message = self._format_messages_for_converse(
-            messages  # type: ignore[arg-type]
+            messages
         )

         body: BedrockConverseRequestBody = {

@@ -687,8 +685,10 @@ class BedrockCompletion(BaseLLM):
     ) -> str:
         """Handle streaming converse API call with comprehensive event handling."""
         full_response = ""
-        current_tool_use = None
-        tool_use_id = None
+        current_tool_use: dict[str, Any] | None = None
+        tool_use_id: str | None = None
+        tool_use_index = 0
+        accumulated_tool_input = ""

         try:
             response = self.client.converse_stream(

@@ -709,9 +709,30 @@ class BedrockCompletion(BaseLLM):

                elif "contentBlockStart" in event:
                    start = event["contentBlockStart"].get("start", {})
+                    content_block_index = event["contentBlockStart"].get(
+                        "contentBlockIndex", 0
+                    )
                    if "toolUse" in start:
-                        current_tool_use = start["toolUse"]
+                        tool_use_block = start["toolUse"]
+                        current_tool_use = cast(dict[str, Any], tool_use_block)
                        tool_use_id = current_tool_use.get("toolUseId")
+                        tool_use_index = content_block_index
+                        accumulated_tool_input = ""
+                        self._emit_stream_chunk_event(
+                            chunk="",
+                            from_task=from_task,
+                            from_agent=from_agent,
+                            tool_call={
+                                "id": tool_use_id or "",
+                                "function": {
+                                    "name": current_tool_use.get("name", ""),
+                                    "arguments": "",
+                                },
+                                "type": "function",
+                                "index": tool_use_index,
+                            },
+                            call_type=LLMCallType.TOOL_CALL,
+                        )
                        logging.debug(
                            f"Tool use started in stream: {json.dumps(current_tool_use)} (ID: {tool_use_id})"
                        )

@@ -730,7 +751,23 @@ class BedrockCompletion(BaseLLM):
                    elif "toolUse" in delta and current_tool_use:
                        tool_input = delta["toolUse"].get("input", "")
                        if tool_input:
+                            accumulated_tool_input += tool_input
                            logging.debug(f"Tool input delta: {tool_input}")
+                            self._emit_stream_chunk_event(
+                                chunk=tool_input,
+                                from_task=from_task,
+                                from_agent=from_agent,
+                                tool_call={
+                                    "id": tool_use_id or "",
+                                    "function": {
+                                        "name": current_tool_use.get("name", ""),
+                                        "arguments": accumulated_tool_input,
+                                    },
+                                    "type": "function",
+                                    "index": tool_use_index,
+                                },
+                                call_type=LLMCallType.TOOL_CALL,
+                            )
                elif "contentBlockStop" in event:
                    logging.debug("Content block stopped in stream")
                    if current_tool_use and available_functions:
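For orientation, these are the Bedrock Converse stream event shapes the new branches match on, shown as fabricated dicts (field names follow the AWS Converse API; values are invented):

# contentBlockStart announces a tool use and carries its id and name.
start_event = {
    "contentBlockStart": {
        "start": {"toolUse": {"toolUseId": "tooluse_abc", "name": "get_weather"}},
        "contentBlockIndex": 1,
    }
}

# contentBlockDelta events then stream the input JSON in fragments.
delta_event = {
    "contentBlockDelta": {
        "delta": {"toolUse": {"input": '{"city": "Paris"}'}}
    }
}

# contentBlockStop closes the block; by then accumulated_tool_input
# holds the full JSON string for the tool call.
stop_event = {"contentBlockStop": {}}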
@@ -848,7 +885,7 @@ class BedrockCompletion(BaseLLM):

     async def _ahandle_converse(
         self,
-        messages: list[dict[str, Any]],
+        messages: list[LLMMessage],
         body: BedrockConverseRequestBody,
         available_functions: Mapping[str, Any] | None = None,
         from_task: Any | None = None,

@@ -1013,7 +1050,7 @@ class BedrockCompletion(BaseLLM):

     async def _ahandle_streaming_converse(
         self,
-        messages: list[dict[str, Any]],
+        messages: list[LLMMessage],
         body: BedrockConverseRequestBody,
         available_functions: dict[str, Any] | None = None,
         from_task: Any | None = None,

@@ -1021,8 +1058,10 @@ class BedrockCompletion(BaseLLM):
     ) -> str:
         """Handle async streaming converse API call."""
         full_response = ""
-        current_tool_use = None
-        tool_use_id = None
+        current_tool_use: dict[str, Any] | None = None
+        tool_use_id: str | None = None
+        tool_use_index = 0
+        accumulated_tool_input = ""

         try:
             async_client = await self._ensure_async_client()

@@ -1044,9 +1083,30 @@ class BedrockCompletion(BaseLLM):

                elif "contentBlockStart" in event:
                    start = event["contentBlockStart"].get("start", {})
+                    content_block_index = event["contentBlockStart"].get(
+                        "contentBlockIndex", 0
+                    )
                    if "toolUse" in start:
-                        current_tool_use = start["toolUse"]
+                        tool_use_block = start["toolUse"]
+                        current_tool_use = cast(dict[str, Any], tool_use_block)
                        tool_use_id = current_tool_use.get("toolUseId")
+                        tool_use_index = content_block_index
+                        accumulated_tool_input = ""
+                        self._emit_stream_chunk_event(
+                            chunk="",
+                            from_task=from_task,
+                            from_agent=from_agent,
+                            tool_call={
+                                "id": tool_use_id or "",
+                                "function": {
+                                    "name": current_tool_use.get("name", ""),
+                                    "arguments": "",
+                                },
+                                "type": "function",
+                                "index": tool_use_index,
+                            },
+                            call_type=LLMCallType.TOOL_CALL,
+                        )
                        logging.debug(
                            f"Tool use started in stream: {current_tool_use.get('name')} (ID: {tool_use_id})"
                        )

@@ -1065,7 +1125,23 @@ class BedrockCompletion(BaseLLM):
                    elif "toolUse" in delta and current_tool_use:
                        tool_input = delta["toolUse"].get("input", "")
                        if tool_input:
+                            accumulated_tool_input += tool_input
                            logging.debug(f"Tool input delta: {tool_input}")
+                            self._emit_stream_chunk_event(
+                                chunk=tool_input,
+                                from_task=from_task,
+                                from_agent=from_agent,
+                                tool_call={
+                                    "id": tool_use_id or "",
+                                    "function": {
+                                        "name": current_tool_use.get("name", ""),
+                                        "arguments": accumulated_tool_input,
+                                    },
+                                    "type": "function",
+                                    "index": tool_use_index,
+                                },
+                                call_type=LLMCallType.TOOL_CALL,
+                            )

                elif "contentBlockStop" in event:
                    logging.debug("Content block stopped in stream")

@@ -1174,7 +1250,7 @@ class BedrockCompletion(BaseLLM):

     def _format_messages_for_converse(
         self, messages: str | list[LLMMessage]
-    ) -> tuple[list[dict[str, Any]], str | None]:
+    ) -> tuple[list[LLMMessage], str | None]:
         """Format messages for Converse API following AWS documentation.

         Note: Returns dict[str, Any] instead of LLMMessage because Bedrock uses

@@ -1184,7 +1260,7 @@ class BedrockCompletion(BaseLLM):
         # Use base class formatting first
         formatted_messages = self._format_messages(messages)

-        converse_messages: list[dict[str, Any]] = []
+        converse_messages: list[LLMMessage] = []
         system_message: str | None = None

         for message in formatted_messages:
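The diff truncates the `contentBlockStop` handling, but its surrounding context (`if current_tool_use and available_functions:`) implies the accumulated input is parsed and the matching function dispatched there. A hypothetical sketch of that step, with invented helper names, not the commit's code:

import json
from typing import Any, Callable

def dispatch_tool_use(
    current_tool_use: dict[str, Any],
    accumulated_tool_input: str,
    available_functions: dict[str, Callable[..., Any]],
) -> Any:
    """Hypothetical finalization: parse the accumulated JSON, invoke the tool."""
    args = json.loads(accumulated_tool_input) if accumulated_tool_input else {}
    fn = available_functions.get(current_tool_use.get("name", ""))
    return fn(**args) if fn else None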
@@ -1,5 +1,6 @@
 from __future__ import annotations

+import json
 import logging
 import os
 import re

@@ -677,17 +678,39 @@ class GeminiCompletion(BaseLLM):
             if chunk.candidates:
                 candidate = chunk.candidates[0]
                 if candidate.content and candidate.content.parts:
-                    for part in candidate.content.parts:
+                    for idx, part in enumerate(candidate.content.parts):
                         if hasattr(part, "function_call") and part.function_call:
                             call_id = part.function_call.name or "default"
+                            args_dict = (
+                                dict(part.function_call.args)
+                                if part.function_call.args
+                                else {}
+                            )
+                            args_json = json.dumps(args_dict)

                             if call_id not in function_calls:
                                 function_calls[call_id] = {
                                     "name": part.function_call.name,
-                                    "args": dict(part.function_call.args)
-                                    if part.function_call.args
-                                    else {},
+                                    "args": args_dict,
+                                    "index": idx,
                                 }

+                            self._emit_stream_chunk_event(
+                                chunk=args_json,
+                                from_task=from_task,
+                                from_agent=from_agent,
+                                tool_call={
+                                    "id": call_id,
+                                    "function": {
+                                        "name": part.function_call.name or "",
+                                        "arguments": args_json,
+                                    },
+                                    "type": "function",
+                                    "index": function_calls[call_id]["index"],
+                                },
+                                call_type=LLMCallType.TOOL_CALL,
+                            )

         return full_response, function_calls, usage_data

     def _finalize_streaming_response(
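Note the provider difference visible here: Gemini delivers a function call's args as a complete structured mapping in one part rather than as incremental fragments, so each emitted chunk carries the fully serialized arguments. A toy illustration with invented values:

import json

# Stand-in for part.function_call.args, which arrives whole.
args_dict = {"city": "Paris", "unit": "celsius"}
args_json = json.dumps(args_dict)

# Unlike the Anthropic/OpenAI handlers, args_json here is already
# complete and parseable on the first (and only) tool-call chunk.
assert json.loads(args_json) == args_dict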
@@ -521,7 +521,7 @@ class OpenAICompletion(BaseLLM):
     ) -> str:
         """Handle streaming chat completion."""
         full_response = ""
-        tool_calls = {}
+        tool_calls: dict[str, dict[str, Any]] = {}

         if response_model:
             parse_params = {

@@ -592,10 +592,12 @@ class OpenAICompletion(BaseLLM):
             if chunk_delta.tool_calls:
                 for tool_call in chunk_delta.tool_calls:
                     call_id = tool_call.id or "default"
+                    tool_index = tool_call.index if tool_call.index is not None else 0
                     if call_id not in tool_calls:
                         tool_calls[call_id] = {
                             "name": "",
                             "arguments": "",
+                            "index": tool_index,
                         }

                     if tool_call.function and tool_call.function.name:

@@ -603,6 +605,24 @@ class OpenAICompletion(BaseLLM):
                     if tool_call.function and tool_call.function.arguments:
                         tool_calls[call_id]["arguments"] += tool_call.function.arguments

+                    self._emit_stream_chunk_event(
+                        chunk=tool_call.function.arguments
+                        if tool_call.function and tool_call.function.arguments
+                        else "",
+                        from_task=from_task,
+                        from_agent=from_agent,
+                        tool_call={
+                            "id": call_id,
+                            "function": {
+                                "name": tool_calls[call_id]["name"],
+                                "arguments": tool_calls[call_id]["arguments"],
+                            },
+                            "type": "function",
+                            "index": tool_calls[call_id]["index"],
+                        },
+                        call_type=LLMCallType.TOOL_CALL,
+                    )

         self._track_token_usage_internal(usage_data)

         if tool_calls and available_functions:

@@ -789,7 +809,7 @@ class OpenAICompletion(BaseLLM):
     ) -> str:
         """Handle async streaming chat completion."""
         full_response = ""
-        tool_calls = {}
+        tool_calls: dict[str, dict[str, Any]] = {}

         if response_model:
             completion_stream: AsyncIterator[

@@ -871,10 +891,12 @@ class OpenAICompletion(BaseLLM):
             if chunk_delta.tool_calls:
                 for tool_call in chunk_delta.tool_calls:
                     call_id = tool_call.id or "default"
+                    tool_index = tool_call.index if tool_call.index is not None else 0
                     if call_id not in tool_calls:
                         tool_calls[call_id] = {
                             "name": "",
                             "arguments": "",
+                            "index": tool_index,
                         }

                     if tool_call.function and tool_call.function.name:

@@ -882,6 +904,24 @@ class OpenAICompletion(BaseLLM):
                     if tool_call.function and tool_call.function.arguments:
                         tool_calls[call_id]["arguments"] += tool_call.function.arguments

+                    self._emit_stream_chunk_event(
+                        chunk=tool_call.function.arguments
+                        if tool_call.function and tool_call.function.arguments
+                        else "",
+                        from_task=from_task,
+                        from_agent=from_agent,
+                        tool_call={
+                            "id": call_id,
+                            "function": {
+                                "name": tool_calls[call_id]["name"],
+                                "arguments": tool_calls[call_id]["arguments"],
+                            },
+                            "type": "function",
+                            "index": tool_calls[call_id]["index"],
+                        },
+                        call_type=LLMCallType.TOOL_CALL,
+                    )

         self._track_token_usage_internal(usage_data)

         if tool_calls and available_functions:
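Finally, a fabricated walk-through of the OpenAI delta bookkeeping for two parallel tool calls, using plain dicts in place of the SDK's ChoiceDeltaToolCall objects. It mirrors the accumulation above; note the real SDK may omit the id on continuation deltas, which this sketch sidesteps by supplying ids throughout:

from typing import Any

deltas = [
    {"id": "call_a", "index": 0, "name": "search", "arguments": ""},
    {"id": "call_b", "index": 1, "name": "fetch", "arguments": ""},
    {"id": "call_a", "index": 0, "name": "", "arguments": '{"q": "crewai"}'},
    {"id": "call_b", "index": 1, "name": "", "arguments": '{"url": "https://example.com"}'},
]

tool_calls: dict[str, dict[str, Any]] = {}
for d in deltas:
    call_id = d["id"] or "default"
    entry = tool_calls.setdefault(
        call_id, {"name": "", "arguments": "", "index": d["index"]}
    )
    if d["name"]:
        entry["name"] = d["name"]
    if d["arguments"]:
        entry["arguments"] += d["arguments"]  # OpenAI argument deltas are fragments

# tool_calls now maps call_a/call_b to their complete name and JSON arguments.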