feat: add streaming tool call events; fix provider id tracking; add tests and cassettes
Some checks failed
CodeQL Advanced / Analyze (actions) (push) Has been cancelled
CodeQL Advanced / Analyze (python) (push) Has been cancelled
Notify Downstream / notify-downstream (push) Has been cancelled
Build uv cache / build-cache (3.10) (push) Has been cancelled
Build uv cache / build-cache (3.11) (push) Has been cancelled
Build uv cache / build-cache (3.12) (push) Has been cancelled
Build uv cache / build-cache (3.13) (push) Has been cancelled

Adds support for streaming tool call events with test coverage, fixes tool-stream ID tracking (including OpenAI-style tracking for Azure), improves Gemini tool calling + streaming tests, adds Anthropic tests, generates Azure cassettes, and fixes Azure cassette URIs.
This commit is contained in:
Greyson LaLonde
2026-01-05 14:33:36 -05:00
committed by GitHub
parent f3c17a249b
commit f8deb0fd18
15 changed files with 1798 additions and 60 deletions

View File

@@ -354,8 +354,17 @@ class BaseLLM(ABC):
from_task: Task | None = None,
from_agent: Agent | None = None,
tool_call: dict[str, Any] | None = None,
call_type: LLMCallType | None = None,
) -> None:
"""Emit stream chunk event."""
"""Emit stream chunk event.
Args:
chunk: The text content of the chunk.
from_task: The task that initiated the call.
from_agent: The agent that initiated the call.
tool_call: Tool call information if this is a tool call chunk.
call_type: The type of LLM call (LLM_CALL or TOOL_CALL).
"""
if not hasattr(crewai_event_bus, "emit"):
raise ValueError("crewai_event_bus does not have an emit method") from None
@@ -366,6 +375,7 @@ class BaseLLM(ABC):
tool_call=tool_call,
from_task=from_task,
from_agent=from_agent,
call_type=call_type,
),
)

View File

@@ -598,6 +598,8 @@ class AnthropicCompletion(BaseLLM):
# (the SDK sets it internally)
stream_params = {k: v for k, v in params.items() if k != "stream"}
current_tool_calls: dict[int, dict[str, Any]] = {}
# Make streaming API call
with self.client.messages.stream(**stream_params) as stream:
for event in stream:
@@ -610,6 +612,55 @@ class AnthropicCompletion(BaseLLM):
from_agent=from_agent,
)
if event.type == "content_block_start":
block = event.content_block
if block.type == "tool_use":
block_index = event.index
current_tool_calls[block_index] = {
"id": block.id,
"name": block.name,
"arguments": "",
"index": block_index,
}
self._emit_stream_chunk_event(
chunk="",
from_task=from_task,
from_agent=from_agent,
tool_call={
"id": block.id,
"function": {
"name": block.name,
"arguments": "",
},
"type": "function",
"index": block_index,
},
call_type=LLMCallType.TOOL_CALL,
)
elif event.type == "content_block_delta":
if event.delta.type == "input_json_delta":
block_index = event.index
partial_json = event.delta.partial_json
if block_index in current_tool_calls and partial_json:
current_tool_calls[block_index]["arguments"] += partial_json
self._emit_stream_chunk_event(
chunk=partial_json,
from_task=from_task,
from_agent=from_agent,
tool_call={
"id": current_tool_calls[block_index]["id"],
"function": {
"name": current_tool_calls[block_index]["name"],
"arguments": current_tool_calls[block_index][
"arguments"
],
},
"type": "function",
"index": block_index,
},
call_type=LLMCallType.TOOL_CALL,
)
final_message: Message = stream.get_final_message()
thinking_blocks: list[ThinkingBlock] = []
@@ -941,6 +992,8 @@ class AnthropicCompletion(BaseLLM):
stream_params = {k: v for k, v in params.items() if k != "stream"}
current_tool_calls: dict[int, dict[str, Any]] = {}
async with self.async_client.messages.stream(**stream_params) as stream:
async for event in stream:
if hasattr(event, "delta") and hasattr(event.delta, "text"):
@@ -952,6 +1005,55 @@ class AnthropicCompletion(BaseLLM):
from_agent=from_agent,
)
if event.type == "content_block_start":
block = event.content_block
if block.type == "tool_use":
block_index = event.index
current_tool_calls[block_index] = {
"id": block.id,
"name": block.name,
"arguments": "",
"index": block_index,
}
self._emit_stream_chunk_event(
chunk="",
from_task=from_task,
from_agent=from_agent,
tool_call={
"id": block.id,
"function": {
"name": block.name,
"arguments": "",
},
"type": "function",
"index": block_index,
},
call_type=LLMCallType.TOOL_CALL,
)
elif event.type == "content_block_delta":
if event.delta.type == "input_json_delta":
block_index = event.index
partial_json = event.delta.partial_json
if block_index in current_tool_calls and partial_json:
current_tool_calls[block_index]["arguments"] += partial_json
self._emit_stream_chunk_event(
chunk=partial_json,
from_task=from_task,
from_agent=from_agent,
tool_call={
"id": current_tool_calls[block_index]["id"],
"function": {
"name": current_tool_calls[block_index]["name"],
"arguments": current_tool_calls[block_index][
"arguments"
],
},
"type": "function",
"index": block_index,
},
call_type=LLMCallType.TOOL_CALL,
)
final_message: Message = await stream.get_final_message()
usage = self._extract_anthropic_token_usage(final_message)

View File

@@ -674,7 +674,7 @@ class AzureCompletion(BaseLLM):
self,
update: StreamingChatCompletionsUpdate,
full_response: str,
tool_calls: dict[str, dict[str, str]],
tool_calls: dict[int, dict[str, Any]],
from_task: Any | None = None,
from_agent: Any | None = None,
) -> str:
@@ -702,25 +702,45 @@ class AzureCompletion(BaseLLM):
)
if choice.delta and choice.delta.tool_calls:
for tool_call in choice.delta.tool_calls:
call_id = tool_call.id or "default"
if call_id not in tool_calls:
tool_calls[call_id] = {
for idx, tool_call in enumerate(choice.delta.tool_calls):
if idx not in tool_calls:
tool_calls[idx] = {
"id": tool_call.id,
"name": "",
"arguments": "",
}
elif tool_call.id and not tool_calls[idx]["id"]:
tool_calls[idx]["id"] = tool_call.id
if tool_call.function and tool_call.function.name:
tool_calls[call_id]["name"] = tool_call.function.name
tool_calls[idx]["name"] = tool_call.function.name
if tool_call.function and tool_call.function.arguments:
tool_calls[call_id]["arguments"] += tool_call.function.arguments
tool_calls[idx]["arguments"] += tool_call.function.arguments
self._emit_stream_chunk_event(
chunk=tool_call.function.arguments
if tool_call.function and tool_call.function.arguments
else "",
from_task=from_task,
from_agent=from_agent,
tool_call={
"id": tool_calls[idx]["id"],
"function": {
"name": tool_calls[idx]["name"],
"arguments": tool_calls[idx]["arguments"],
},
"type": "function",
"index": idx,
},
call_type=LLMCallType.TOOL_CALL,
)
return full_response
def _finalize_streaming_response(
self,
full_response: str,
tool_calls: dict[str, dict[str, str]],
tool_calls: dict[int, dict[str, Any]],
usage_data: dict[str, int],
params: AzureCompletionParams,
available_functions: dict[str, Any] | None = None,
@@ -804,7 +824,7 @@ class AzureCompletion(BaseLLM):
) -> str | Any:
"""Handle streaming chat completion."""
full_response = ""
tool_calls: dict[str, dict[str, Any]] = {}
tool_calls: dict[int, dict[str, Any]] = {}
usage_data = {"total_tokens": 0}
for update in self.client.complete(**params): # type: ignore[arg-type]
@@ -870,7 +890,7 @@ class AzureCompletion(BaseLLM):
) -> str | Any:
"""Handle streaming chat completion asynchronously."""
full_response = ""
tool_calls: dict[str, dict[str, Any]] = {}
tool_calls: dict[int, dict[str, Any]] = {}
usage_data = {"total_tokens": 0}

View File

@@ -315,9 +315,7 @@ class BedrockCompletion(BaseLLM):
messages
)
if not self._invoke_before_llm_call_hooks(
cast(list[LLMMessage], formatted_messages), from_agent
):
if not self._invoke_before_llm_call_hooks(formatted_messages, from_agent):
raise ValueError("LLM call blocked by before_llm_call hook")
# Prepare request body
@@ -361,7 +359,7 @@ class BedrockCompletion(BaseLLM):
if self.stream:
return self._handle_streaming_converse(
cast(list[LLMMessage], formatted_messages),
formatted_messages,
body,
available_functions,
from_task,
@@ -369,7 +367,7 @@ class BedrockCompletion(BaseLLM):
)
return self._handle_converse(
cast(list[LLMMessage], formatted_messages),
formatted_messages,
body,
available_functions,
from_task,
@@ -433,7 +431,7 @@ class BedrockCompletion(BaseLLM):
)
formatted_messages, system_message = self._format_messages_for_converse(
messages # type: ignore[arg-type]
messages
)
body: BedrockConverseRequestBody = {
@@ -687,8 +685,10 @@ class BedrockCompletion(BaseLLM):
) -> str:
"""Handle streaming converse API call with comprehensive event handling."""
full_response = ""
current_tool_use = None
tool_use_id = None
current_tool_use: dict[str, Any] | None = None
tool_use_id: str | None = None
tool_use_index = 0
accumulated_tool_input = ""
try:
response = self.client.converse_stream(
@@ -709,9 +709,30 @@ class BedrockCompletion(BaseLLM):
elif "contentBlockStart" in event:
start = event["contentBlockStart"].get("start", {})
content_block_index = event["contentBlockStart"].get(
"contentBlockIndex", 0
)
if "toolUse" in start:
current_tool_use = start["toolUse"]
tool_use_block = start["toolUse"]
current_tool_use = cast(dict[str, Any], tool_use_block)
tool_use_id = current_tool_use.get("toolUseId")
tool_use_index = content_block_index
accumulated_tool_input = ""
self._emit_stream_chunk_event(
chunk="",
from_task=from_task,
from_agent=from_agent,
tool_call={
"id": tool_use_id or "",
"function": {
"name": current_tool_use.get("name", ""),
"arguments": "",
},
"type": "function",
"index": tool_use_index,
},
call_type=LLMCallType.TOOL_CALL,
)
logging.debug(
f"Tool use started in stream: {json.dumps(current_tool_use)} (ID: {tool_use_id})"
)
@@ -730,7 +751,23 @@ class BedrockCompletion(BaseLLM):
elif "toolUse" in delta and current_tool_use:
tool_input = delta["toolUse"].get("input", "")
if tool_input:
accumulated_tool_input += tool_input
logging.debug(f"Tool input delta: {tool_input}")
self._emit_stream_chunk_event(
chunk=tool_input,
from_task=from_task,
from_agent=from_agent,
tool_call={
"id": tool_use_id or "",
"function": {
"name": current_tool_use.get("name", ""),
"arguments": accumulated_tool_input,
},
"type": "function",
"index": tool_use_index,
},
call_type=LLMCallType.TOOL_CALL,
)
elif "contentBlockStop" in event:
logging.debug("Content block stopped in stream")
if current_tool_use and available_functions:
@@ -848,7 +885,7 @@ class BedrockCompletion(BaseLLM):
async def _ahandle_converse(
self,
messages: list[dict[str, Any]],
messages: list[LLMMessage],
body: BedrockConverseRequestBody,
available_functions: Mapping[str, Any] | None = None,
from_task: Any | None = None,
@@ -1013,7 +1050,7 @@ class BedrockCompletion(BaseLLM):
async def _ahandle_streaming_converse(
self,
messages: list[dict[str, Any]],
messages: list[LLMMessage],
body: BedrockConverseRequestBody,
available_functions: dict[str, Any] | None = None,
from_task: Any | None = None,
@@ -1021,8 +1058,10 @@ class BedrockCompletion(BaseLLM):
) -> str:
"""Handle async streaming converse API call."""
full_response = ""
current_tool_use = None
tool_use_id = None
current_tool_use: dict[str, Any] | None = None
tool_use_id: str | None = None
tool_use_index = 0
accumulated_tool_input = ""
try:
async_client = await self._ensure_async_client()
@@ -1044,9 +1083,30 @@ class BedrockCompletion(BaseLLM):
elif "contentBlockStart" in event:
start = event["contentBlockStart"].get("start", {})
content_block_index = event["contentBlockStart"].get(
"contentBlockIndex", 0
)
if "toolUse" in start:
current_tool_use = start["toolUse"]
tool_use_block = start["toolUse"]
current_tool_use = cast(dict[str, Any], tool_use_block)
tool_use_id = current_tool_use.get("toolUseId")
tool_use_index = content_block_index
accumulated_tool_input = ""
self._emit_stream_chunk_event(
chunk="",
from_task=from_task,
from_agent=from_agent,
tool_call={
"id": tool_use_id or "",
"function": {
"name": current_tool_use.get("name", ""),
"arguments": "",
},
"type": "function",
"index": tool_use_index,
},
call_type=LLMCallType.TOOL_CALL,
)
logging.debug(
f"Tool use started in stream: {current_tool_use.get('name')} (ID: {tool_use_id})"
)
@@ -1065,7 +1125,23 @@ class BedrockCompletion(BaseLLM):
elif "toolUse" in delta and current_tool_use:
tool_input = delta["toolUse"].get("input", "")
if tool_input:
accumulated_tool_input += tool_input
logging.debug(f"Tool input delta: {tool_input}")
self._emit_stream_chunk_event(
chunk=tool_input,
from_task=from_task,
from_agent=from_agent,
tool_call={
"id": tool_use_id or "",
"function": {
"name": current_tool_use.get("name", ""),
"arguments": accumulated_tool_input,
},
"type": "function",
"index": tool_use_index,
},
call_type=LLMCallType.TOOL_CALL,
)
elif "contentBlockStop" in event:
logging.debug("Content block stopped in stream")
@@ -1174,7 +1250,7 @@ class BedrockCompletion(BaseLLM):
def _format_messages_for_converse(
self, messages: str | list[LLMMessage]
) -> tuple[list[dict[str, Any]], str | None]:
) -> tuple[list[LLMMessage], str | None]:
"""Format messages for Converse API following AWS documentation.
Note: Returns list[LLMMessage] (previously dict[str, Any]) because Bedrock uses
@@ -1184,7 +1260,7 @@ class BedrockCompletion(BaseLLM):
# Use base class formatting first
formatted_messages = self._format_messages(messages)
converse_messages: list[dict[str, Any]] = []
converse_messages: list[LLMMessage] = []
system_message: str | None = None
for message in formatted_messages:

View File

@@ -1,5 +1,6 @@
from __future__ import annotations
import json
import logging
import os
import re
@@ -24,7 +25,7 @@ try:
from google import genai
from google.genai import types
from google.genai.errors import APIError
from google.genai.types import GenerateContentResponse, Schema
from google.genai.types import GenerateContentResponse
except ImportError:
raise ImportError(
'Google Gen AI native provider not available, to install: uv add "crewai[google-genai]"'
@@ -434,12 +435,9 @@ class GeminiCompletion(BaseLLM):
function_declaration = types.FunctionDeclaration(
name=name,
description=description,
parameters=parameters if parameters else None,
)
# Add parameters if present - ensure parameters is a dict
if parameters and isinstance(parameters, Schema):
function_declaration.parameters = parameters
gemini_tool = types.Tool(function_declarations=[function_declaration])
gemini_tools.append(gemini_tool)
@@ -609,7 +607,7 @@ class GeminiCompletion(BaseLLM):
candidate = response.candidates[0]
if candidate.content and candidate.content.parts:
for part in candidate.content.parts:
if hasattr(part, "function_call") and part.function_call:
if part.function_call:
function_name = part.function_call.name
if function_name is None:
continue
@@ -645,17 +643,17 @@ class GeminiCompletion(BaseLLM):
self,
chunk: GenerateContentResponse,
full_response: str,
function_calls: dict[str, dict[str, Any]],
function_calls: dict[int, dict[str, Any]],
usage_data: dict[str, int],
from_task: Any | None = None,
from_agent: Any | None = None,
) -> tuple[str, dict[str, dict[str, Any]], dict[str, int]]:
) -> tuple[str, dict[int, dict[str, Any]], dict[str, int]]:
"""Process a single streaming chunk.
Args:
chunk: The streaming chunk response
full_response: Accumulated response text
function_calls: Accumulated function calls
function_calls: Accumulated function calls keyed by sequential index
usage_data: Accumulated usage data
from_task: Task that initiated the call
from_agent: Agent that initiated the call
@@ -678,22 +676,44 @@ class GeminiCompletion(BaseLLM):
candidate = chunk.candidates[0]
if candidate.content and candidate.content.parts:
for part in candidate.content.parts:
if hasattr(part, "function_call") and part.function_call:
call_id = part.function_call.name or "default"
if call_id not in function_calls:
function_calls[call_id] = {
"name": part.function_call.name,
"args": dict(part.function_call.args)
if part.function_call.args
else {},
}
if part.function_call:
call_index = len(function_calls)
call_id = f"call_{call_index}"
args_dict = (
dict(part.function_call.args)
if part.function_call.args
else {}
)
args_json = json.dumps(args_dict)
function_calls[call_index] = {
"id": call_id,
"name": part.function_call.name,
"args": args_dict,
}
self._emit_stream_chunk_event(
chunk=args_json,
from_task=from_task,
from_agent=from_agent,
tool_call={
"id": call_id,
"function": {
"name": part.function_call.name or "",
"arguments": args_json,
},
"type": "function",
"index": call_index,
},
call_type=LLMCallType.TOOL_CALL,
)
return full_response, function_calls, usage_data
def _finalize_streaming_response(
self,
full_response: str,
function_calls: dict[str, dict[str, Any]],
function_calls: dict[int, dict[str, Any]],
usage_data: dict[str, int],
contents: list[types.Content],
available_functions: dict[str, Any] | None = None,
@@ -800,7 +820,7 @@ class GeminiCompletion(BaseLLM):
) -> str:
"""Handle streaming content generation."""
full_response = ""
function_calls: dict[str, dict[str, Any]] = {}
function_calls: dict[int, dict[str, Any]] = {}
usage_data = {"total_tokens": 0}
# The API accepts list[Content] but mypy is overly strict about variance
@@ -878,7 +898,7 @@ class GeminiCompletion(BaseLLM):
) -> str:
"""Handle async streaming content generation."""
full_response = ""
function_calls: dict[str, dict[str, Any]] = {}
function_calls: dict[int, dict[str, Any]] = {}
usage_data = {"total_tokens": 0}
# The API accepts list[Content] but mypy is overly strict about variance

View File

@@ -521,7 +521,7 @@ class OpenAICompletion(BaseLLM):
) -> str:
"""Handle streaming chat completion."""
full_response = ""
tool_calls = {}
tool_calls: dict[int, dict[str, Any]] = {}
if response_model:
parse_params = {
@@ -591,17 +591,41 @@ class OpenAICompletion(BaseLLM):
if chunk_delta.tool_calls:
for tool_call in chunk_delta.tool_calls:
call_id = tool_call.id or "default"
if call_id not in tool_calls:
tool_calls[call_id] = {
tool_index = tool_call.index if tool_call.index is not None else 0
if tool_index not in tool_calls:
tool_calls[tool_index] = {
"id": tool_call.id,
"name": "",
"arguments": "",
"index": tool_index,
}
elif tool_call.id and not tool_calls[tool_index]["id"]:
tool_calls[tool_index]["id"] = tool_call.id
if tool_call.function and tool_call.function.name:
tool_calls[call_id]["name"] = tool_call.function.name
tool_calls[tool_index]["name"] = tool_call.function.name
if tool_call.function and tool_call.function.arguments:
tool_calls[call_id]["arguments"] += tool_call.function.arguments
tool_calls[tool_index]["arguments"] += (
tool_call.function.arguments
)
self._emit_stream_chunk_event(
chunk=tool_call.function.arguments
if tool_call.function and tool_call.function.arguments
else "",
from_task=from_task,
from_agent=from_agent,
tool_call={
"id": tool_calls[tool_index]["id"],
"function": {
"name": tool_calls[tool_index]["name"],
"arguments": tool_calls[tool_index]["arguments"],
},
"type": "function",
"index": tool_calls[tool_index]["index"],
},
call_type=LLMCallType.TOOL_CALL,
)
self._track_token_usage_internal(usage_data)
@@ -789,7 +813,7 @@ class OpenAICompletion(BaseLLM):
) -> str:
"""Handle async streaming chat completion."""
full_response = ""
tool_calls = {}
tool_calls: dict[int, dict[str, Any]] = {}
if response_model:
completion_stream: AsyncIterator[
@@ -870,17 +894,41 @@ class OpenAICompletion(BaseLLM):
if chunk_delta.tool_calls:
for tool_call in chunk_delta.tool_calls:
call_id = tool_call.id or "default"
if call_id not in tool_calls:
tool_calls[call_id] = {
tool_index = tool_call.index if tool_call.index is not None else 0
if tool_index not in tool_calls:
tool_calls[tool_index] = {
"id": tool_call.id,
"name": "",
"arguments": "",
"index": tool_index,
}
elif tool_call.id and not tool_calls[tool_index]["id"]:
tool_calls[tool_index]["id"] = tool_call.id
if tool_call.function and tool_call.function.name:
tool_calls[call_id]["name"] = tool_call.function.name
tool_calls[tool_index]["name"] = tool_call.function.name
if tool_call.function and tool_call.function.arguments:
tool_calls[call_id]["arguments"] += tool_call.function.arguments
tool_calls[tool_index]["arguments"] += (
tool_call.function.arguments
)
self._emit_stream_chunk_event(
chunk=tool_call.function.arguments
if tool_call.function and tool_call.function.arguments
else "",
from_task=from_task,
from_agent=from_agent,
tool_call={
"id": tool_calls[tool_index]["id"],
"function": {
"name": tool_calls[tool_index]["name"],
"arguments": tool_calls[tool_index]["arguments"],
},
"type": "function",
"index": tool_calls[tool_index]["index"],
},
call_type=LLMCallType.TOOL_CALL,
)
self._track_token_usage_internal(usage_data)