fix: emit tool call events in provider-specific LLM streaming implementations

Fixes #3982

This commit adds tool call event emission to all provider-specific LLM
streaming implementations. Previously, only text chunks were emitted
during streaming, so tool call information never reached event listeners.

Changes:
- Update BaseLLM._emit_stream_chunk_event to infer call_type from
  tool_call presence when not explicitly provided
- Add tool call event emission in OpenAI provider streaming
- Add tool call event emission in Azure provider streaming
- Add tool call event emission in Gemini provider streaming
- Add tool call event emission in Bedrock provider streaming
- Add tool call event emission in Anthropic provider streaming
- Add comprehensive tests for tool call streaming events

The fix ensures that LLMStreamChunkEvent is emitted with:
- call_type=LLMCallType.TOOL_CALL when tool calls are received
- tool_call dict containing id, function (name, arguments), type, index
- chunk containing the tool call arguments being streamed
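
For illustration, a handler registered on the event bus can then distinguish
text chunks from tool call chunks. The sketch below is not part of this
commit: the `crewai_event_bus.on` registration decorator and the import
paths are assumed from crewai's public events API, and `handle_stream_chunk`
is a hypothetical name:

    from typing import Any

    from crewai.events import crewai_event_bus  # import path assumed
    from crewai.events.types.llm_events import LLMCallType, LLMStreamChunkEvent

    @crewai_event_bus.on(LLMStreamChunkEvent)  # registration API assumed
    def handle_stream_chunk(source: Any, event: LLMStreamChunkEvent) -> None:
        if event.call_type == LLMCallType.TOOL_CALL and event.tool_call is not None:
            # Tool call chunk: event.chunk carries the incremental arguments
            print(f"{event.tool_call.function.name} delta: {event.chunk!r}")
        else:
            # Plain text chunk
            print(event.chunk, end="")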

Co-Authored-By: João <joao@crewai.com>
Author: Devin AI
Date: 2025-11-27 07:19:36 +00:00
parent 2025a26fc3
commit b70c4499d7

7 changed files with 413 additions and 2 deletions


@@ -316,11 +316,33 @@ class BaseLLM(ABC):
         from_task: Task | None = None,
         from_agent: Agent | None = None,
         tool_call: dict[str, Any] | None = None,
+        call_type: LLMCallType | None = None,
     ) -> None:
-        """Emit stream chunk event."""
+        """Emit stream chunk event.
+
+        Args:
+            chunk: The text content of the chunk
+            from_task: Optional task that initiated the call
+            from_agent: Optional agent that initiated the call
+            tool_call: Optional tool call information as a dict with keys:
+                - id: Tool call ID
+                - function: Dict with 'name' and 'arguments'
+                - type: Tool call type (e.g., 'function')
+                - index: Index of the tool call
+            call_type: Optional call type. If not provided, it will be inferred
+                from the presence of tool_call (TOOL_CALL if tool_call is present,
+                LLM_CALL otherwise)
+        """
         if not hasattr(crewai_event_bus, "emit"):
             raise ValueError("crewai_event_bus does not have an emit method") from None

+        # Infer call_type from tool_call presence if not explicitly provided
+        effective_call_type = call_type
+        if effective_call_type is None:
+            effective_call_type = (
+                LLMCallType.TOOL_CALL if tool_call is not None else LLMCallType.LLM_CALL
+            )
+
         crewai_event_bus.emit(
             self,
             event=LLMStreamChunkEvent(
@@ -328,6 +350,7 @@ class BaseLLM(ABC):
                 tool_call=tool_call,
                 from_task=from_task,
                 from_agent=from_agent,
+                call_type=effective_call_type,
             ),
         )


@@ -450,9 +450,14 @@ class AnthropicCompletion(BaseLLM):
         # (the SDK sets it internally)
         stream_params = {k: v for k, v in params.items() if k != "stream"}

+        # Track tool use blocks during streaming
+        current_tool_use: dict[str, Any] = {}
+        tool_use_index = 0
+
         # Make streaming API call
         with self.client.messages.stream(**stream_params) as stream:
             for event in stream:
+                # Handle text content
                 if hasattr(event, "delta") and hasattr(event.delta, "text"):
                     text_delta = event.delta.text
                     full_response += text_delta
@@ -462,6 +467,55 @@ class AnthropicCompletion(BaseLLM):
                         from_agent=from_agent,
                     )

+                # Handle tool use start (content_block_start event with tool_use type)
+                if hasattr(event, "content_block") and hasattr(event.content_block, "type"):
+                    if event.content_block.type == "tool_use":
+                        current_tool_use = {
+                            "id": getattr(event.content_block, "id", None),
+                            "name": getattr(event.content_block, "name", ""),
+                            "input": "",
+                            "index": tool_use_index,
+                        }
+                        tool_use_index += 1
+                        # Emit tool call start event
+                        tool_call_event_data = {
+                            "id": current_tool_use["id"],
+                            "function": {
+                                "name": current_tool_use["name"],
+                                "arguments": "",
+                            },
+                            "type": "function",
+                            "index": current_tool_use["index"],
+                        }
+                        self._emit_stream_chunk_event(
+                            chunk="",
+                            from_task=from_task,
+                            from_agent=from_agent,
+                            tool_call=tool_call_event_data,
+                        )
+
+                # Handle tool use input delta (input_json events)
+                if hasattr(event, "delta") and hasattr(event.delta, "partial_json"):
+                    partial_json = event.delta.partial_json
+                    if current_tool_use and partial_json:
+                        current_tool_use["input"] += partial_json
+                        # Emit tool call delta event
+                        tool_call_event_data = {
+                            "id": current_tool_use["id"],
+                            "function": {
+                                "name": current_tool_use["name"],
+                                "arguments": partial_json,
+                            },
+                            "type": "function",
+                            "index": current_tool_use["index"],
+                        }
+                        self._emit_stream_chunk_event(
+                            chunk=partial_json,
+                            from_task=from_task,
+                            from_agent=from_agent,
+                            tool_call=tool_call_event_data,
+                        )

             final_message: Message = stream.get_final_message()
             usage = self._extract_anthropic_token_usage(final_message)


@@ -503,8 +503,10 @@ class AzureCompletion(BaseLLM):
                     call_id = tool_call.id or "default"
                     if call_id not in tool_calls:
                         tool_calls[call_id] = {
+                            "id": call_id,
                             "name": "",
                             "arguments": "",
+                            "index": getattr(tool_call, "index", 0) or 0,
                         }

                     if tool_call.function and tool_call.function.name:
@@ -514,6 +516,23 @@ class AzureCompletion(BaseLLM):
                             tool_call.function.arguments
                         )

+                    # Emit tool call streaming event
+                    tool_call_event_data = {
+                        "id": tool_calls[call_id]["id"],
+                        "function": {
+                            "name": tool_calls[call_id]["name"],
+                            "arguments": tool_call.function.arguments if tool_call.function and tool_call.function.arguments else "",
+                        },
+                        "type": "function",
+                        "index": tool_calls[call_id]["index"],
+                    }
+                    self._emit_stream_chunk_event(
+                        chunk=tool_call.function.arguments if tool_call.function and tool_call.function.arguments else "",
+                        from_task=from_task,
+                        from_agent=from_agent,
+                        tool_call=tool_call_event_data,
+                    )
+
         # Handle completed tool calls
         if tool_calls and available_functions:
             for call_data in tool_calls.values():


@@ -567,12 +567,31 @@ class BedrockCompletion(BaseLLM):
                 elif "contentBlockStart" in event:
                     start = event["contentBlockStart"].get("start", {})
+                    block_index = event["contentBlockStart"].get("contentBlockIndex", 0)
                     if "toolUse" in start:
                         current_tool_use = start["toolUse"]
+                        current_tool_use["_block_index"] = block_index
+                        current_tool_use["_accumulated_input"] = ""
                         tool_use_id = current_tool_use.get("toolUseId")
                         logging.debug(
                             f"Tool use started in stream: {current_tool_use.get('name')} (ID: {tool_use_id})"
                         )
+                        # Emit tool call start event
+                        tool_call_event_data = {
+                            "id": tool_use_id,
+                            "function": {
+                                "name": current_tool_use.get("name", ""),
+                                "arguments": "",
+                            },
+                            "type": "function",
+                            "index": block_index,
+                        }
+                        self._emit_stream_chunk_event(
+                            chunk="",
+                            from_task=from_task,
+                            from_agent=from_agent,
+                            tool_call=tool_call_event_data,
+                        )

                 elif "contentBlockDelta" in event:
                     delta = event["contentBlockDelta"]["delta"]
@@ -589,6 +608,23 @@ class BedrockCompletion(BaseLLM):
                             tool_input = delta["toolUse"].get("input", "")
                             if tool_input:
                                 logging.debug(f"Tool input delta: {tool_input}")
+                                current_tool_use["_accumulated_input"] += tool_input
+                                # Emit tool call delta event
+                                tool_call_event_data = {
+                                    "id": current_tool_use.get("toolUseId"),
+                                    "function": {
+                                        "name": current_tool_use.get("name", ""),
+                                        "arguments": tool_input,
+                                    },
+                                    "type": "function",
+                                    "index": current_tool_use.get("_block_index", 0),
+                                }
+                                self._emit_stream_chunk_event(
+                                    chunk=tool_input,
+                                    from_task=from_task,
+                                    from_agent=from_agent,
+                                    tool_call=tool_call_event_data,
+                                )

                 # Content block stop - end of a content block
                 elif "contentBlockStop" in event:


@@ -1,3 +1,4 @@
+import json
 import logging
 import os
 import re
@@ -496,7 +497,7 @@ class GeminiCompletion(BaseLLM):
             if hasattr(chunk, "candidates") and chunk.candidates:
                 candidate = chunk.candidates[0]
                 if candidate.content and candidate.content.parts:
-                    for part in candidate.content.parts:
+                    for part_index, part in enumerate(candidate.content.parts):
                         if hasattr(part, "function_call") and part.function_call:
                             call_id = part.function_call.name or "default"
                             if call_id not in function_calls:
@@ -505,8 +506,27 @@ class GeminiCompletion(BaseLLM):
                                     "args": dict(part.function_call.args)
                                     if part.function_call.args
                                     else {},
+                                    "index": part_index,
                                 }

+                            # Emit tool call streaming event
+                            args_str = json.dumps(function_calls[call_id]["args"]) if function_calls[call_id]["args"] else ""
+                            tool_call_event_data = {
+                                "id": call_id,
+                                "function": {
+                                    "name": function_calls[call_id]["name"],
+                                    "arguments": args_str,
+                                },
+                                "type": "function",
+                                "index": function_calls[call_id]["index"],
+                            }
+                            self._emit_stream_chunk_event(
+                                chunk=args_str,
+                                from_task=from_task,
+                                from_agent=from_agent,
+                                tool_call=tool_call_event_data,
+                            )
+
             # Handle completed function calls
             if function_calls and available_functions:
                 for call_data in function_calls.values():


@@ -510,8 +510,10 @@ class OpenAICompletion(BaseLLM):
                     call_id = tool_call.id or "default"
                     if call_id not in tool_calls:
                         tool_calls[call_id] = {
+                            "id": call_id,
                             "name": "",
                             "arguments": "",
+                            "index": tool_call.index if tool_call.index is not None else 0,
                         }

                     if tool_call.function and tool_call.function.name:
@@ -519,6 +521,23 @@ class OpenAICompletion(BaseLLM):
                     if tool_call.function and tool_call.function.arguments:
                         tool_calls[call_id]["arguments"] += tool_call.function.arguments

+                    # Emit tool call streaming event
+                    tool_call_event_data = {
+                        "id": tool_calls[call_id]["id"],
+                        "function": {
+                            "name": tool_calls[call_id]["name"],
+                            "arguments": tool_call.function.arguments if tool_call.function and tool_call.function.arguments else "",
+                        },
+                        "type": "function",
+                        "index": tool_calls[call_id]["index"],
+                    }
+                    self._emit_stream_chunk_event(
+                        chunk=tool_call.function.arguments if tool_call.function and tool_call.function.arguments else "",
+                        from_task=from_task,
+                        from_agent=from_agent,
+                        tool_call=tool_call_event_data,
+                    )
+
             if tool_calls and available_functions:
                 for call_data in tool_calls.values():
                     function_name = call_data["name"]


@@ -715,3 +715,243 @@ class TestStreamingImports:
         assert StreamChunk is not None
         assert StreamChunkType is not None
         assert ToolCallChunk is not None
+
+
+class TestLLMStreamChunkEventToolCall:
+    """Tests for LLMStreamChunkEvent with tool call information."""
+
+    def test_llm_stream_chunk_event_with_tool_call(self) -> None:
+        """Test that LLMStreamChunkEvent correctly handles tool call data."""
+        from crewai.events.types.llm_events import (
+            LLMCallType,
+            LLMStreamChunkEvent,
+            ToolCall,
+            FunctionCall,
+        )
+
+        # Create a tool call event
+        tool_call = ToolCall(
+            id="call-123",
+            function=FunctionCall(
+                name="search",
+                arguments='{"query": "test"}',
+            ),
+            type="function",
+            index=0,
+        )
+        event = LLMStreamChunkEvent(
+            chunk='{"query": "test"}',
+            tool_call=tool_call,
+            call_type=LLMCallType.TOOL_CALL,
+        )
+
+        assert event.chunk == '{"query": "test"}'
+        assert event.tool_call is not None
+        assert event.tool_call.id == "call-123"
+        assert event.tool_call.function.name == "search"
+        assert event.tool_call.function.arguments == '{"query": "test"}'
+        assert event.call_type == LLMCallType.TOOL_CALL
+
+    def test_llm_stream_chunk_event_with_dict_tool_call(self) -> None:
+        """Test that LLMStreamChunkEvent correctly handles tool call as dict."""
+        from crewai.events.types.llm_events import (
+            LLMCallType,
+            LLMStreamChunkEvent,
+        )
+
+        # Create a tool call event using dict (as providers emit)
+        tool_call_dict = {
+            "id": "call-456",
+            "function": {
+                "name": "get_weather",
+                "arguments": '{"location": "NYC"}',
+            },
+            "type": "function",
+            "index": 1,
+        }
+        event = LLMStreamChunkEvent(
+            chunk='{"location": "NYC"}',
+            tool_call=tool_call_dict,
+            call_type=LLMCallType.TOOL_CALL,
+        )
+
+        assert event.chunk == '{"location": "NYC"}'
+        assert event.tool_call is not None
+        assert event.tool_call.id == "call-456"
+        assert event.tool_call.function.name == "get_weather"
+        assert event.tool_call.function.arguments == '{"location": "NYC"}'
+        assert event.call_type == LLMCallType.TOOL_CALL
+
+    def test_llm_stream_chunk_event_text_only(self) -> None:
+        """Test that LLMStreamChunkEvent works for text-only chunks."""
+        from crewai.events.types.llm_events import (
+            LLMCallType,
+            LLMStreamChunkEvent,
+        )
+
+        event = LLMStreamChunkEvent(
+            chunk="Hello, world!",
+            tool_call=None,
+            call_type=LLMCallType.LLM_CALL,
+        )
+
+        assert event.chunk == "Hello, world!"
+        assert event.tool_call is None
+        assert event.call_type == LLMCallType.LLM_CALL
+
+
+class TestBaseLLMEmitStreamChunkEvent:
+    """Tests for BaseLLM._emit_stream_chunk_event method."""
+
+    def test_emit_stream_chunk_event_infers_tool_call_type(self) -> None:
+        """Test that _emit_stream_chunk_event infers TOOL_CALL type when tool_call is present."""
+        from unittest.mock import MagicMock, patch
+
+        from crewai.llms.base_llm import BaseLLM
+        from crewai.events.types.llm_events import LLMCallType, LLMStreamChunkEvent
+
+        # Create a mock BaseLLM instance
+        with patch.object(BaseLLM, "__abstractmethods__", set()):
+            llm = BaseLLM(model="test-model")  # type: ignore
+
+            captured_events: list[LLMStreamChunkEvent] = []
+
+            def capture_emit(source: Any, event: Any) -> None:
+                if isinstance(event, LLMStreamChunkEvent):
+                    captured_events.append(event)
+
+            with patch("crewai.llms.base_llm.crewai_event_bus") as mock_bus:
+                mock_bus.emit = capture_emit
+
+                # Emit with tool_call - should infer TOOL_CALL type
+                tool_call_dict = {
+                    "id": "call-789",
+                    "function": {
+                        "name": "test_tool",
+                        "arguments": '{"arg": "value"}',
+                    },
+                    "type": "function",
+                    "index": 0,
+                }
+                llm._emit_stream_chunk_event(
+                    chunk='{"arg": "value"}',
+                    tool_call=tool_call_dict,
+                )
+
+            assert len(captured_events) == 1
+            assert captured_events[0].call_type == LLMCallType.TOOL_CALL
+            assert captured_events[0].tool_call is not None
+
+    def test_emit_stream_chunk_event_infers_llm_call_type(self) -> None:
+        """Test that _emit_stream_chunk_event infers LLM_CALL type when tool_call is None."""
+        from unittest.mock import patch
+
+        from crewai.llms.base_llm import BaseLLM
+        from crewai.events.types.llm_events import LLMCallType, LLMStreamChunkEvent
+
+        # Create a mock BaseLLM instance
+        with patch.object(BaseLLM, "__abstractmethods__", set()):
+            llm = BaseLLM(model="test-model")  # type: ignore
+
+            captured_events: list[LLMStreamChunkEvent] = []
+
+            def capture_emit(source: Any, event: Any) -> None:
+                if isinstance(event, LLMStreamChunkEvent):
+                    captured_events.append(event)
+
+            with patch("crewai.llms.base_llm.crewai_event_bus") as mock_bus:
+                mock_bus.emit = capture_emit
+
+                # Emit without tool_call - should infer LLM_CALL type
+                llm._emit_stream_chunk_event(
+                    chunk="Hello, world!",
+                    tool_call=None,
+                )
+
+            assert len(captured_events) == 1
+            assert captured_events[0].call_type == LLMCallType.LLM_CALL
+            assert captured_events[0].tool_call is None
+
+    def test_emit_stream_chunk_event_respects_explicit_call_type(self) -> None:
+        """Test that _emit_stream_chunk_event respects explicitly provided call_type."""
+        from unittest.mock import patch
+
+        from crewai.llms.base_llm import BaseLLM
+        from crewai.events.types.llm_events import LLMCallType, LLMStreamChunkEvent
+
+        # Create a mock BaseLLM instance
+        with patch.object(BaseLLM, "__abstractmethods__", set()):
+            llm = BaseLLM(model="test-model")  # type: ignore
+
+            captured_events: list[LLMStreamChunkEvent] = []
+
+            def capture_emit(source: Any, event: Any) -> None:
+                if isinstance(event, LLMStreamChunkEvent):
+                    captured_events.append(event)
+
+            with patch("crewai.llms.base_llm.crewai_event_bus") as mock_bus:
+                mock_bus.emit = capture_emit
+
+                # Emit with explicit call_type - should use provided type
+                llm._emit_stream_chunk_event(
+                    chunk="test",
+                    tool_call=None,
+                    call_type=LLMCallType.TOOL_CALL,  # Explicitly set even though no tool_call
+                )
+
+            assert len(captured_events) == 1
+            assert captured_events[0].call_type == LLMCallType.TOOL_CALL
+
+
+class TestStreamingToolCallExtraction:
+    """Tests for tool call extraction from streaming events."""
+
+    def test_extract_tool_call_info_from_event(self) -> None:
+        """Test that tool call info is correctly extracted from LLMStreamChunkEvent."""
+        from crewai.utilities.streaming import _extract_tool_call_info
+        from crewai.events.types.llm_events import (
+            LLMStreamChunkEvent,
+            ToolCall,
+            FunctionCall,
+        )
+        from crewai.types.streaming import StreamChunkType
+
+        # Create event with tool call
+        tool_call = ToolCall(
+            id="call-extract-test",
+            function=FunctionCall(
+                name="extract_test",
+                arguments='{"key": "value"}',
+            ),
+            type="function",
+            index=2,
+        )
+        event = LLMStreamChunkEvent(
+            chunk='{"key": "value"}',
+            tool_call=tool_call,
+        )
+
+        chunk_type, tool_call_chunk = _extract_tool_call_info(event)
+
+        assert chunk_type == StreamChunkType.TOOL_CALL
+        assert tool_call_chunk is not None
+        assert tool_call_chunk.tool_id == "call-extract-test"
+        assert tool_call_chunk.tool_name == "extract_test"
+        assert tool_call_chunk.arguments == '{"key": "value"}'
+        assert tool_call_chunk.index == 2
+
+    def test_extract_tool_call_info_returns_text_for_no_tool_call(self) -> None:
+        """Test that TEXT type is returned when no tool call is present."""
+        from crewai.utilities.streaming import _extract_tool_call_info
+        from crewai.events.types.llm_events import LLMStreamChunkEvent
+        from crewai.types.streaming import StreamChunkType
+
+        event = LLMStreamChunkEvent(
+            chunk="Just text content",
+            tool_call=None,
+        )
+
+        chunk_type, tool_call_chunk = _extract_tool_call_info(event)
+
+        assert chunk_type == StreamChunkType.TEXT
+        assert tool_call_chunk is None