fix: emit tool call events in provider-specific LLM streaming implementations

Fixes #3982

This commit adds tool call event emission to all provider-specific LLM
streaming implementations. Previously, only text chunks were emitted
during streaming; tool call information never reached event consumers.

Changes:
- Update BaseLLM._emit_stream_chunk_event to infer call_type from
  tool_call presence when not explicitly provided
- Add tool call event emission in OpenAI provider streaming
- Add tool call event emission in Azure provider streaming
- Add tool call event emission in Gemini provider streaming
- Add tool call event emission in Bedrock provider streaming
- Add tool call event emission in Anthropic provider streaming
- Add comprehensive tests for tool call streaming events

The fix ensures that LLMStreamChunkEvent is emitted with (see the consumer sketch after this list):
- call_type=LLMCallType.TOOL_CALL when tool calls are received
- tool_call dict containing id, function (name, arguments), type, index
- chunk containing the tool call arguments being streamed
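
A minimal sketch of how a downstream listener might consume these events, assuming the decorator-based `crewai_event_bus.on` registration and the import paths used elsewhere in this diff (the handler name and accumulation strategy are illustrative, not part of this commit):

```python
from crewai.events import crewai_event_bus
from crewai.events.types.llm_events import LLMCallType, LLMStreamChunkEvent

# Accumulate streamed tool-call argument fragments, keyed by tool call index.
tool_args: dict[int, str] = {}


@crewai_event_bus.on(LLMStreamChunkEvent)
def on_stream_chunk(source: object, event: LLMStreamChunkEvent) -> None:
    if event.call_type == LLMCallType.TOOL_CALL and event.tool_call is not None:
        # For delta-based providers the chunk is an argument fragment;
        # concatenating fragments per index yields the full JSON arguments.
        index = event.tool_call.index
        tool_args[index] = tool_args.get(index, "") + (event.chunk or "")
    else:
        # Plain text token from the model.
        print(event.chunk, end="", flush=True)
```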

Co-Authored-By: João <joao@crewai.com>
Author: Devin AI
Date: 2025-11-27 07:19:36 +00:00
Commit: b70c4499d7 (parent: 2025a26fc3)
7 changed files with 413 additions and 2 deletions


@@ -316,11 +316,33 @@ class BaseLLM(ABC):
from_task: Task | None = None,
from_agent: Agent | None = None,
tool_call: dict[str, Any] | None = None,
call_type: LLMCallType | None = None,
) -> None:
"""Emit stream chunk event."""
"""Emit stream chunk event.
Args:
chunk: The text content of the chunk
from_task: Optional task that initiated the call
from_agent: Optional agent that initiated the call
tool_call: Optional tool call information as a dict with keys:
- id: Tool call ID
- function: Dict with 'name' and 'arguments'
- type: Tool call type (e.g., 'function')
- index: Index of the tool call
call_type: Optional call type. If not provided, it will be inferred
from the presence of tool_call (TOOL_CALL if tool_call is present,
LLM_CALL otherwise)
"""
if not hasattr(crewai_event_bus, "emit"):
raise ValueError("crewai_event_bus does not have an emit method") from None
# Infer call_type from tool_call presence if not explicitly provided
effective_call_type = call_type
if effective_call_type is None:
effective_call_type = (
LLMCallType.TOOL_CALL if tool_call is not None else LLMCallType.LLM_CALL
)
crewai_event_bus.emit(
self,
event=LLMStreamChunkEvent(
@@ -328,6 +350,7 @@ class BaseLLM(ABC):
tool_call=tool_call,
from_task=from_task,
from_agent=from_agent,
call_type=effective_call_type,
),
)
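
The inference keeps existing call sites backward compatible: a provider passes only the normalized tool_call dict and the event gets call_type=TOOL_CALL without further changes. A minimal sketch of the three paths, assuming `llm` is an instance of some concrete BaseLLM subclass (calling the private method directly here mirrors the tests below, for illustration only):

```python
from crewai.events.types.llm_events import LLMCallType

# 1. tool_call given, call_type omitted -> inferred as TOOL_CALL.
llm._emit_stream_chunk_event(
    chunk='{"query": "test"}',
    tool_call={
        "id": "call-123",
        "function": {"name": "search", "arguments": '{"query": "test"}'},
        "type": "function",
        "index": 0,
    },
)

# 2. No tool_call, call_type omitted -> inferred as LLM_CALL.
llm._emit_stream_chunk_event(chunk="plain text token")

# 3. An explicit call_type always takes precedence over inference.
llm._emit_stream_chunk_event(chunk="", call_type=LLMCallType.TOOL_CALL)
```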


@@ -450,9 +450,14 @@ class AnthropicCompletion(BaseLLM):
# (the SDK sets it internally)
stream_params = {k: v for k, v in params.items() if k != "stream"}
# Track tool use blocks during streaming
current_tool_use: dict[str, Any] = {}
tool_use_index = 0
# Make streaming API call
with self.client.messages.stream(**stream_params) as stream:
for event in stream:
# Handle text content
if hasattr(event, "delta") and hasattr(event.delta, "text"):
text_delta = event.delta.text
full_response += text_delta
@@ -462,6 +467,55 @@ class AnthropicCompletion(BaseLLM):
from_agent=from_agent,
)
# Handle tool use start (content_block_start event with tool_use type)
if hasattr(event, "content_block") and hasattr(event.content_block, "type"):
if event.content_block.type == "tool_use":
current_tool_use = {
"id": getattr(event.content_block, "id", None),
"name": getattr(event.content_block, "name", ""),
"input": "",
"index": tool_use_index,
}
tool_use_index += 1
# Emit tool call start event
tool_call_event_data = {
"id": current_tool_use["id"],
"function": {
"name": current_tool_use["name"],
"arguments": "",
},
"type": "function",
"index": current_tool_use["index"],
}
self._emit_stream_chunk_event(
chunk="",
from_task=from_task,
from_agent=from_agent,
tool_call=tool_call_event_data,
)
# Handle tool use input delta (input_json_delta events)
if hasattr(event, "delta") and hasattr(event.delta, "partial_json"):
partial_json = event.delta.partial_json
if current_tool_use and partial_json:
current_tool_use["input"] += partial_json
# Emit tool call delta event
tool_call_event_data = {
"id": current_tool_use["id"],
"function": {
"name": current_tool_use["name"],
"arguments": partial_json,
},
"type": "function",
"index": current_tool_use["index"],
}
self._emit_stream_chunk_event(
chunk=partial_json,
from_task=from_task,
from_agent=from_agent,
tool_call=tool_call_event_data,
)
final_message: Message = stream.get_final_message()
usage = self._extract_anthropic_token_usage(final_message)
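
Because Anthropic streams tool input as partial_json fragments (as does Bedrock via _accumulated_input below), the accumulated string only becomes valid JSON once the content block completes. A sketch of the final parse a consumer might perform; the helper name is hypothetical:

```python
import json
from typing import Any


def parse_accumulated_tool_input(accumulated: str) -> dict[str, Any]:
    """Parse the concatenation of partial_json fragments once streaming
    for the tool block has finished. Returns {} for empty or truncated
    input rather than raising mid-stream."""
    if not accumulated:
        return {}
    try:
        return json.loads(accumulated)
    except json.JSONDecodeError:
        return {}
```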


@@ -503,8 +503,10 @@ class AzureCompletion(BaseLLM):
call_id = tool_call.id or "default"
if call_id not in tool_calls:
tool_calls[call_id] = {
"id": call_id,
"name": "",
"arguments": "",
"index": getattr(tool_call, "index", 0) or 0,
}
if tool_call.function and tool_call.function.name:
@@ -514,6 +516,23 @@ class AzureCompletion(BaseLLM):
tool_call.function.arguments
)
# Emit tool call streaming event
tool_call_event_data = {
"id": tool_calls[call_id]["id"],
"function": {
"name": tool_calls[call_id]["name"],
"arguments": tool_call.function.arguments if tool_call.function and tool_call.function.arguments else "",
},
"type": "function",
"index": tool_calls[call_id]["index"],
}
self._emit_stream_chunk_event(
chunk=tool_call.function.arguments if tool_call.function and tool_call.function.arguments else "",
from_task=from_task,
from_agent=from_agent,
tool_call=tool_call_event_data,
)
# Handle completed tool calls
if tool_calls and available_functions:
for call_data in tool_calls.values():


@@ -567,12 +567,31 @@ class BedrockCompletion(BaseLLM):
elif "contentBlockStart" in event:
start = event["contentBlockStart"].get("start", {})
block_index = event["contentBlockStart"].get("contentBlockIndex", 0)
if "toolUse" in start:
current_tool_use = start["toolUse"]
current_tool_use["_block_index"] = block_index
current_tool_use["_accumulated_input"] = ""
tool_use_id = current_tool_use.get("toolUseId")
logging.debug(
f"Tool use started in stream: {current_tool_use.get('name')} (ID: {tool_use_id})"
)
# Emit tool call start event
tool_call_event_data = {
"id": tool_use_id,
"function": {
"name": current_tool_use.get("name", ""),
"arguments": "",
},
"type": "function",
"index": block_index,
}
self._emit_stream_chunk_event(
chunk="",
from_task=from_task,
from_agent=from_agent,
tool_call=tool_call_event_data,
)
elif "contentBlockDelta" in event:
delta = event["contentBlockDelta"]["delta"]
@@ -589,6 +608,23 @@ class BedrockCompletion(BaseLLM):
tool_input = delta["toolUse"].get("input", "")
if tool_input:
logging.debug(f"Tool input delta: {tool_input}")
current_tool_use["_accumulated_input"] += tool_input
# Emit tool call delta event
tool_call_event_data = {
"id": current_tool_use.get("toolUseId"),
"function": {
"name": current_tool_use.get("name", ""),
"arguments": tool_input,
},
"type": "function",
"index": current_tool_use.get("_block_index", 0),
}
self._emit_stream_chunk_event(
chunk=tool_input,
from_task=from_task,
from_agent=from_agent,
tool_call=tool_call_event_data,
)
# Content block stop - end of a content block
elif "contentBlockStop" in event:


@@ -1,3 +1,4 @@
import json
import logging
import os
import re
@@ -496,7 +497,7 @@ class GeminiCompletion(BaseLLM):
if hasattr(chunk, "candidates") and chunk.candidates:
candidate = chunk.candidates[0]
if candidate.content and candidate.content.parts:
for part in candidate.content.parts:
for part_index, part in enumerate(candidate.content.parts):
if hasattr(part, "function_call") and part.function_call:
call_id = part.function_call.name or "default"
if call_id not in function_calls:
@@ -505,8 +506,27 @@ class GeminiCompletion(BaseLLM):
"args": dict(part.function_call.args)
if part.function_call.args
else {},
"index": part_index,
}
# Emit tool call streaming event
args_str = json.dumps(function_calls[call_id]["args"]) if function_calls[call_id]["args"] else ""
tool_call_event_data = {
"id": call_id,
"function": {
"name": function_calls[call_id]["name"],
"arguments": args_str,
},
"type": "function",
"index": function_calls[call_id]["index"],
}
self._emit_stream_chunk_event(
chunk=args_str,
from_task=from_task,
from_agent=from_agent,
tool_call=tool_call_event_data,
)
# Handle completed function calls
if function_calls and available_functions:
for call_data in function_calls.values():
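
Unlike the delta-based providers above, Gemini surfaces the complete function_call.args in a single part, so each emitted chunk already carries the full JSON-serialized arguments. A consumer can therefore parse per chunk instead of accumulating fragments (hypothetical helper, for illustration):

```python
import json
from typing import Any


def parse_gemini_tool_args(arguments: str) -> dict[str, Any]:
    """Gemini emits complete, not incremental, arguments per chunk,
    so a single parse suffices."""
    return json.loads(arguments) if arguments else {}
```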


@@ -510,8 +510,10 @@ class OpenAICompletion(BaseLLM):
call_id = tool_call.id or "default"
if call_id not in tool_calls:
tool_calls[call_id] = {
"id": call_id,
"name": "",
"arguments": "",
"index": tool_call.index if tool_call.index is not None else 0,
}
if tool_call.function and tool_call.function.name:
@@ -519,6 +521,23 @@ class OpenAICompletion(BaseLLM):
if tool_call.function and tool_call.function.arguments:
tool_calls[call_id]["arguments"] += tool_call.function.arguments
# Emit tool call streaming event
tool_call_event_data = {
"id": tool_calls[call_id]["id"],
"function": {
"name": tool_calls[call_id]["name"],
"arguments": tool_call.function.arguments if tool_call.function and tool_call.function.arguments else "",
},
"type": "function",
"index": tool_calls[call_id]["index"],
}
self._emit_stream_chunk_event(
chunk=tool_call.function.arguments if tool_call.function and tool_call.function.arguments else "",
from_task=from_task,
from_agent=from_agent,
tool_call=tool_call_event_data,
)
if tool_calls and available_functions:
for call_data in tool_calls.values():
function_name = call_data["name"]


@@ -715,3 +715,243 @@ class TestStreamingImports:
assert StreamChunk is not None
assert StreamChunkType is not None
assert ToolCallChunk is not None
class TestLLMStreamChunkEventToolCall:
"""Tests for LLMStreamChunkEvent with tool call information."""
def test_llm_stream_chunk_event_with_tool_call(self) -> None:
"""Test that LLMStreamChunkEvent correctly handles tool call data."""
from crewai.events.types.llm_events import (
LLMCallType,
LLMStreamChunkEvent,
ToolCall,
FunctionCall,
)
# Create a tool call event
tool_call = ToolCall(
id="call-123",
function=FunctionCall(
name="search",
arguments='{"query": "test"}',
),
type="function",
index=0,
)
event = LLMStreamChunkEvent(
chunk='{"query": "test"}',
tool_call=tool_call,
call_type=LLMCallType.TOOL_CALL,
)
assert event.chunk == '{"query": "test"}'
assert event.tool_call is not None
assert event.tool_call.id == "call-123"
assert event.tool_call.function.name == "search"
assert event.tool_call.function.arguments == '{"query": "test"}'
assert event.call_type == LLMCallType.TOOL_CALL
def test_llm_stream_chunk_event_with_dict_tool_call(self) -> None:
"""Test that LLMStreamChunkEvent correctly handles tool call as dict."""
from crewai.events.types.llm_events import (
LLMCallType,
LLMStreamChunkEvent,
)
# Create a tool call event using dict (as providers emit)
tool_call_dict = {
"id": "call-456",
"function": {
"name": "get_weather",
"arguments": '{"location": "NYC"}',
},
"type": "function",
"index": 1,
}
event = LLMStreamChunkEvent(
chunk='{"location": "NYC"}',
tool_call=tool_call_dict,
call_type=LLMCallType.TOOL_CALL,
)
assert event.chunk == '{"location": "NYC"}'
assert event.tool_call is not None
assert event.tool_call.id == "call-456"
assert event.tool_call.function.name == "get_weather"
assert event.tool_call.function.arguments == '{"location": "NYC"}'
assert event.call_type == LLMCallType.TOOL_CALL
def test_llm_stream_chunk_event_text_only(self) -> None:
"""Test that LLMStreamChunkEvent works for text-only chunks."""
from crewai.events.types.llm_events import (
LLMCallType,
LLMStreamChunkEvent,
)
event = LLMStreamChunkEvent(
chunk="Hello, world!",
tool_call=None,
call_type=LLMCallType.LLM_CALL,
)
assert event.chunk == "Hello, world!"
assert event.tool_call is None
assert event.call_type == LLMCallType.LLM_CALL
class TestBaseLLMEmitStreamChunkEvent:
"""Tests for BaseLLM._emit_stream_chunk_event method."""
def test_emit_stream_chunk_event_infers_tool_call_type(self) -> None:
"""Test that _emit_stream_chunk_event infers TOOL_CALL type when tool_call is present."""
from unittest.mock import MagicMock, patch
from crewai.llms.base_llm import BaseLLM
from crewai.events.types.llm_events import LLMCallType, LLMStreamChunkEvent
# Create a mock BaseLLM instance
with patch.object(BaseLLM, "__abstractmethods__", set()):
llm = BaseLLM(model="test-model") # type: ignore
captured_events: list[LLMStreamChunkEvent] = []
def capture_emit(source: Any, event: Any) -> None:
if isinstance(event, LLMStreamChunkEvent):
captured_events.append(event)
with patch("crewai.llms.base_llm.crewai_event_bus") as mock_bus:
mock_bus.emit = capture_emit
# Emit with tool_call - should infer TOOL_CALL type
tool_call_dict = {
"id": "call-789",
"function": {
"name": "test_tool",
"arguments": '{"arg": "value"}',
},
"type": "function",
"index": 0,
}
llm._emit_stream_chunk_event(
chunk='{"arg": "value"}',
tool_call=tool_call_dict,
)
assert len(captured_events) == 1
assert captured_events[0].call_type == LLMCallType.TOOL_CALL
assert captured_events[0].tool_call is not None
def test_emit_stream_chunk_event_infers_llm_call_type(self) -> None:
"""Test that _emit_stream_chunk_event infers LLM_CALL type when tool_call is None."""
from unittest.mock import patch
from crewai.llms.base_llm import BaseLLM
from crewai.events.types.llm_events import LLMCallType, LLMStreamChunkEvent
# Create a mock BaseLLM instance
with patch.object(BaseLLM, "__abstractmethods__", set()):
llm = BaseLLM(model="test-model") # type: ignore
captured_events: list[LLMStreamChunkEvent] = []
def capture_emit(source: Any, event: Any) -> None:
if isinstance(event, LLMStreamChunkEvent):
captured_events.append(event)
with patch("crewai.llms.base_llm.crewai_event_bus") as mock_bus:
mock_bus.emit = capture_emit
# Emit without tool_call - should infer LLM_CALL type
llm._emit_stream_chunk_event(
chunk="Hello, world!",
tool_call=None,
)
assert len(captured_events) == 1
assert captured_events[0].call_type == LLMCallType.LLM_CALL
assert captured_events[0].tool_call is None
def test_emit_stream_chunk_event_respects_explicit_call_type(self) -> None:
"""Test that _emit_stream_chunk_event respects explicitly provided call_type."""
from unittest.mock import patch
from crewai.llms.base_llm import BaseLLM
from crewai.events.types.llm_events import LLMCallType, LLMStreamChunkEvent
# Create a mock BaseLLM instance
with patch.object(BaseLLM, "__abstractmethods__", set()):
llm = BaseLLM(model="test-model") # type: ignore
captured_events: list[LLMStreamChunkEvent] = []
def capture_emit(source: Any, event: Any) -> None:
if isinstance(event, LLMStreamChunkEvent):
captured_events.append(event)
with patch("crewai.llms.base_llm.crewai_event_bus") as mock_bus:
mock_bus.emit = capture_emit
# Emit with explicit call_type - should use provided type
llm._emit_stream_chunk_event(
chunk="test",
tool_call=None,
call_type=LLMCallType.TOOL_CALL, # Explicitly set even though no tool_call
)
assert len(captured_events) == 1
assert captured_events[0].call_type == LLMCallType.TOOL_CALL
class TestStreamingToolCallExtraction:
"""Tests for tool call extraction from streaming events."""
def test_extract_tool_call_info_from_event(self) -> None:
"""Test that tool call info is correctly extracted from LLMStreamChunkEvent."""
from crewai.utilities.streaming import _extract_tool_call_info
from crewai.events.types.llm_events import (
LLMStreamChunkEvent,
ToolCall,
FunctionCall,
)
from crewai.types.streaming import StreamChunkType
# Create event with tool call
tool_call = ToolCall(
id="call-extract-test",
function=FunctionCall(
name="extract_test",
arguments='{"key": "value"}',
),
type="function",
index=2,
)
event = LLMStreamChunkEvent(
chunk='{"key": "value"}',
tool_call=tool_call,
)
chunk_type, tool_call_chunk = _extract_tool_call_info(event)
assert chunk_type == StreamChunkType.TOOL_CALL
assert tool_call_chunk is not None
assert tool_call_chunk.tool_id == "call-extract-test"
assert tool_call_chunk.tool_name == "extract_test"
assert tool_call_chunk.arguments == '{"key": "value"}'
assert tool_call_chunk.index == 2
def test_extract_tool_call_info_returns_text_for_no_tool_call(self) -> None:
"""Test that TEXT type is returned when no tool call is present."""
from crewai.utilities.streaming import _extract_tool_call_info
from crewai.events.types.llm_events import LLMStreamChunkEvent
from crewai.types.streaming import StreamChunkType
event = LLMStreamChunkEvent(
chunk="Just text content",
tool_call=None,
)
chunk_type, tool_call_chunk = _extract_tool_call_info(event)
assert chunk_type == StreamChunkType.TEXT
assert tool_call_chunk is None