Mirror of https://github.com/crewAIInc/crewAI.git (synced 2025-12-16 04:18:35 +00:00)
fix: emit tool call events in provider-specific LLM streaming implementations
Fixes #3982

This commit adds tool call event emission to all provider-specific LLM streaming implementations. Previously, only text chunks were emitted during streaming, and tool call information was missing.

Changes:
- Update BaseLLM._emit_stream_chunk_event to infer call_type from tool_call presence when not explicitly provided
- Add tool call event emission in OpenAI provider streaming
- Add tool call event emission in Azure provider streaming
- Add tool call event emission in Gemini provider streaming
- Add tool call event emission in Bedrock provider streaming
- Add tool call event emission in Anthropic provider streaming
- Add comprehensive tests for tool call streaming events

The fix ensures that LLMStreamChunkEvent is emitted with:
- call_type=LLMCallType.TOOL_CALL when tool calls are received
- tool_call dict containing id, function (name, arguments), type, index
- chunk containing the tool call arguments being streamed

Co-Authored-By: João <joao@crewai.com>
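To make the emitted event shape concrete, here is a minimal consumer-side sketch. The LLMCallType enum and StreamChunk dataclass below are simplified local stand-ins for crewAI's event types, not the library's actual classes:

from dataclasses import dataclass
from enum import Enum
from typing import Any


class LLMCallType(Enum):  # local stand-in mirroring the names in this commit
    LLM_CALL = "llm_call"
    TOOL_CALL = "tool_call"


@dataclass
class StreamChunk:  # simplified stand-in for LLMStreamChunkEvent
    chunk: str
    call_type: LLMCallType
    tool_call: dict[str, Any] | None = None


def route(event: StreamChunk) -> str:
    """Tool-call chunks carry argument fragments; text chunks carry prose."""
    if event.call_type is LLMCallType.TOOL_CALL and event.tool_call:
        return f"tool:{event.tool_call['function']['name']}:{event.chunk}"
    return f"text:{event.chunk}"


assert route(StreamChunk("Hi", LLMCallType.LLM_CALL)) == "text:Hi"
assert (
    route(
        StreamChunk(
            '{"q": 1}',
            LLMCallType.TOOL_CALL,
            {"id": "c1", "function": {"name": "search", "arguments": '{"q": 1}'}},
        )
    )
    == 'tool:search:{"q": 1}'
)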
@@ -316,11 +316,33 @@ class BaseLLM(ABC):
        from_task: Task | None = None,
        from_agent: Agent | None = None,
        tool_call: dict[str, Any] | None = None,
        call_type: LLMCallType | None = None,
    ) -> None:
        """Emit stream chunk event.

        Args:
            chunk: The text content of the chunk
            from_task: Optional task that initiated the call
            from_agent: Optional agent that initiated the call
            tool_call: Optional tool call information as a dict with keys:
                - id: Tool call ID
                - function: Dict with 'name' and 'arguments'
                - type: Tool call type (e.g., 'function')
                - index: Index of the tool call
            call_type: Optional call type. If not provided, it will be inferred
                from the presence of tool_call (TOOL_CALL if tool_call is present,
                LLM_CALL otherwise)
        """
        if not hasattr(crewai_event_bus, "emit"):
            raise ValueError("crewai_event_bus does not have an emit method") from None

        # Infer call_type from tool_call presence if not explicitly provided
        effective_call_type = call_type
        if effective_call_type is None:
            effective_call_type = (
                LLMCallType.TOOL_CALL if tool_call is not None else LLMCallType.LLM_CALL
            )

        crewai_event_bus.emit(
            self,
            event=LLMStreamChunkEvent(
@@ -328,6 +350,7 @@ class BaseLLM(ABC):
                tool_call=tool_call,
                from_task=from_task,
                from_agent=from_agent,
                call_type=effective_call_type,
            ),
        )

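The inference rule in this hunk is small enough to check in isolation. A minimal sketch, with a local stand-in for the LLMCallType enum rather than the crewAI import:

from enum import Enum
from typing import Any


class LLMCallType(Enum):  # local stand-in, not crewai.events.types.llm_events
    LLM_CALL = "llm_call"
    TOOL_CALL = "tool_call"


def infer_call_type(
    tool_call: dict[str, Any] | None,
    call_type: LLMCallType | None = None,
) -> LLMCallType:
    """Mirror the diff's rule: an explicit call_type wins, otherwise infer."""
    if call_type is not None:
        return call_type
    return LLMCallType.TOOL_CALL if tool_call is not None else LLMCallType.LLM_CALL


assert infer_call_type(None) is LLMCallType.LLM_CALL
assert infer_call_type({"id": "call-1"}) is LLMCallType.TOOL_CALL
# An explicit value is never overridden, matching the tests added below.
assert infer_call_type(None, LLMCallType.TOOL_CALL) is LLMCallType.TOOL_CALL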
@@ -450,9 +450,14 @@ class AnthropicCompletion(BaseLLM):
        # (the SDK sets it internally)
        stream_params = {k: v for k, v in params.items() if k != "stream"}

        # Track tool use blocks during streaming
        current_tool_use: dict[str, Any] = {}
        tool_use_index = 0

        # Make streaming API call
        with self.client.messages.stream(**stream_params) as stream:
            for event in stream:
                # Handle text content
                if hasattr(event, "delta") and hasattr(event.delta, "text"):
                    text_delta = event.delta.text
                    full_response += text_delta
@@ -462,6 +467,55 @@ class AnthropicCompletion(BaseLLM):
                        from_agent=from_agent,
                    )

                # Handle tool use start (content_block_start event with tool_use type)
                if hasattr(event, "content_block") and hasattr(event.content_block, "type"):
                    if event.content_block.type == "tool_use":
                        current_tool_use = {
                            "id": getattr(event.content_block, "id", None),
                            "name": getattr(event.content_block, "name", ""),
                            "input": "",
                            "index": tool_use_index,
                        }
                        tool_use_index += 1
                        # Emit tool call start event
                        tool_call_event_data = {
                            "id": current_tool_use["id"],
                            "function": {
                                "name": current_tool_use["name"],
                                "arguments": "",
                            },
                            "type": "function",
                            "index": current_tool_use["index"],
                        }
                        self._emit_stream_chunk_event(
                            chunk="",
                            from_task=from_task,
                            from_agent=from_agent,
                            tool_call=tool_call_event_data,
                        )

                # Handle tool use input delta (input_json events)
                if hasattr(event, "delta") and hasattr(event.delta, "partial_json"):
                    partial_json = event.delta.partial_json
                    if current_tool_use and partial_json:
                        current_tool_use["input"] += partial_json
                        # Emit tool call delta event
                        tool_call_event_data = {
                            "id": current_tool_use["id"],
                            "function": {
                                "name": current_tool_use["name"],
                                "arguments": partial_json,
                            },
                            "type": "function",
                            "index": current_tool_use["index"],
                        }
                        self._emit_stream_chunk_event(
                            chunk=partial_json,
                            from_task=from_task,
                            from_agent=from_agent,
                            tool_call=tool_call_event_data,
                        )

            final_message: Message = stream.get_final_message()

        usage = self._extract_anthropic_token_usage(final_message)

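The Anthropic handler keys off two event shapes: a content block start whose type is tool_use (carrying the id and name) and deltas carrying partial_json argument fragments. A self-contained sketch replays synthetic events through the same accumulation logic; the SimpleNamespace objects imitate the SDK's stream events and are not the real anthropic types:

from types import SimpleNamespace
from typing import Any

# Synthetic stand-ins for the Anthropic SDK's stream events; the real objects
# come from anthropic's messages.stream() iterator.
events = [
    SimpleNamespace(
        content_block=SimpleNamespace(type="tool_use", id="toolu_01", name="search")
    ),
    SimpleNamespace(delta=SimpleNamespace(partial_json='{"query": ')),
    SimpleNamespace(delta=SimpleNamespace(partial_json='"weather"}')),
]

current_tool_use: dict[str, Any] = {}
emitted: list[str] = []  # chunks that would be passed to _emit_stream_chunk_event

for event in events:
    if hasattr(event, "content_block") and getattr(event.content_block, "type", None) == "tool_use":
        current_tool_use = {
            "id": event.content_block.id,
            "name": event.content_block.name,
            "input": "",
        }
        emitted.append("")  # start event: empty chunk, arguments=""
    if hasattr(event, "delta") and hasattr(event.delta, "partial_json"):
        current_tool_use["input"] += event.delta.partial_json
        emitted.append(event.delta.partial_json)  # delta event: just the fragment

assert current_tool_use["input"] == '{"query": "weather"}'
assert emitted == ["", '{"query": ', '"weather"}']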
@@ -503,8 +503,10 @@ class AzureCompletion(BaseLLM):
                    call_id = tool_call.id or "default"
                    if call_id not in tool_calls:
                        tool_calls[call_id] = {
                            "id": call_id,
                            "name": "",
                            "arguments": "",
                            "index": getattr(tool_call, "index", 0) or 0,
                        }

                    if tool_call.function and tool_call.function.name:
@@ -514,6 +516,23 @@ class AzureCompletion(BaseLLM):
                            tool_call.function.arguments
                        )

                    # Emit tool call streaming event
                    tool_call_event_data = {
                        "id": tool_calls[call_id]["id"],
                        "function": {
                            "name": tool_calls[call_id]["name"],
                            "arguments": tool_call.function.arguments if tool_call.function and tool_call.function.arguments else "",
                        },
                        "type": "function",
                        "index": tool_calls[call_id]["index"],
                    }
                    self._emit_stream_chunk_event(
                        chunk=tool_call.function.arguments if tool_call.function and tool_call.function.arguments else "",
                        from_task=from_task,
                        from_agent=from_agent,
                        tool_call=tool_call_event_data,
                    )

        # Handle completed tool calls
        if tool_calls and available_functions:
            for call_data in tool_calls.values():

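Azure (and OpenAI below) stream one tool call across many deltas that share a call id, so the code accumulates the full argument string for later execution while each emitted event carries only the new fragment. A runnable sketch of that bookkeeping with synthetic deltas (plain dicts, not the SDK's delta objects):

from typing import Any

# Synthetic streamed deltas in the Azure/OpenAI style: one call id recurs,
# each delta carrying only a fragment of the JSON arguments.
deltas = [
    {"id": "call_1", "name": "get_weather", "arguments": '{"loc'},
    {"id": "call_1", "name": "", "arguments": 'ation": "NYC"}'},
]

tool_calls: dict[str, dict[str, Any]] = {}
emitted_chunks: list[str] = []

for delta in deltas:
    call_id = delta["id"] or "default"
    entry = tool_calls.setdefault(
        call_id, {"id": call_id, "name": "", "arguments": "", "index": 0}
    )
    if delta["name"]:
        entry["name"] = delta["name"]
    if delta["arguments"]:
        entry["arguments"] += delta["arguments"]   # accumulated for execution later
        emitted_chunks.append(delta["arguments"])  # each event carries the fragment

assert tool_calls["call_1"]["arguments"] == '{"location": "NYC"}'
assert emitted_chunks == ['{"loc', 'ation": "NYC"}']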
@@ -567,12 +567,31 @@ class BedrockCompletion(BaseLLM):

                elif "contentBlockStart" in event:
                    start = event["contentBlockStart"].get("start", {})
                    block_index = event["contentBlockStart"].get("contentBlockIndex", 0)
                    if "toolUse" in start:
                        current_tool_use = start["toolUse"]
                        current_tool_use["_block_index"] = block_index
                        current_tool_use["_accumulated_input"] = ""
                        tool_use_id = current_tool_use.get("toolUseId")
                        logging.debug(
                            f"Tool use started in stream: {current_tool_use.get('name')} (ID: {tool_use_id})"
                        )
                        # Emit tool call start event
                        tool_call_event_data = {
                            "id": tool_use_id,
                            "function": {
                                "name": current_tool_use.get("name", ""),
                                "arguments": "",
                            },
                            "type": "function",
                            "index": block_index,
                        }
                        self._emit_stream_chunk_event(
                            chunk="",
                            from_task=from_task,
                            from_agent=from_agent,
                            tool_call=tool_call_event_data,
                        )

                elif "contentBlockDelta" in event:
                    delta = event["contentBlockDelta"]["delta"]
@@ -589,6 +608,23 @@ class BedrockCompletion(BaseLLM):
                        tool_input = delta["toolUse"].get("input", "")
                        if tool_input:
                            logging.debug(f"Tool input delta: {tool_input}")
                            current_tool_use["_accumulated_input"] += tool_input
                            # Emit tool call delta event
                            tool_call_event_data = {
                                "id": current_tool_use.get("toolUseId"),
                                "function": {
                                    "name": current_tool_use.get("name", ""),
                                    "arguments": tool_input,
                                },
                                "type": "function",
                                "index": current_tool_use.get("_block_index", 0),
                            }
                            self._emit_stream_chunk_event(
                                chunk=tool_input,
                                from_task=from_task,
                                from_agent=from_agent,
                                tool_call=tool_call_event_data,
                            )

                # Content block stop - end of a content block
                elif "contentBlockStop" in event:

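Bedrock's converse-stream events are plain dicts, and the handler reads only a few keys: contentBlockStart.start.toolUse for the id/name and contentBlockDelta.delta.toolUse.input for argument fragments. A sketch replaying synthetic events of that shape (illustrative, not the full boto3 schema):

from typing import Any

# Synthetic Bedrock converse-stream events, shaped like the keys the diff
# inspects; real events are yielded by boto3's streaming response.
events: list[dict[str, Any]] = [
    {
        "contentBlockStart": {
            "start": {"toolUse": {"toolUseId": "tool-1", "name": "search"}},
            "contentBlockIndex": 1,
        }
    },
    {"contentBlockDelta": {"delta": {"toolUse": {"input": '{"q": "ai"}'}}}},
]

current_tool_use: dict[str, Any] = {}
emitted: list[tuple[str, int]] = []  # (chunk, block index) pairs

for event in events:
    if "contentBlockStart" in event:
        start = event["contentBlockStart"].get("start", {})
        if "toolUse" in start:
            current_tool_use = dict(start["toolUse"])
            current_tool_use["_block_index"] = event["contentBlockStart"].get("contentBlockIndex", 0)
            current_tool_use["_accumulated_input"] = ""
            emitted.append(("", current_tool_use["_block_index"]))
    elif "contentBlockDelta" in event:
        tool_input = event["contentBlockDelta"]["delta"].get("toolUse", {}).get("input", "")
        if tool_input and current_tool_use:
            current_tool_use["_accumulated_input"] += tool_input
            emitted.append((tool_input, current_tool_use["_block_index"]))

assert current_tool_use["_accumulated_input"] == '{"q": "ai"}'
assert emitted == [("", 1), ('{"q": "ai"}', 1)]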
@@ -1,3 +1,4 @@
import json
import logging
import os
import re
@@ -496,7 +497,7 @@ class GeminiCompletion(BaseLLM):
            if hasattr(chunk, "candidates") and chunk.candidates:
                candidate = chunk.candidates[0]
                if candidate.content and candidate.content.parts:
                    for part_index, part in enumerate(candidate.content.parts):
                        if hasattr(part, "function_call") and part.function_call:
                            call_id = part.function_call.name or "default"
                            if call_id not in function_calls:
@@ -505,8 +506,27 @@ class GeminiCompletion(BaseLLM):
                                    "args": dict(part.function_call.args)
                                    if part.function_call.args
                                    else {},
                                    "index": part_index,
                                }

                            # Emit tool call streaming event
                            args_str = json.dumps(function_calls[call_id]["args"]) if function_calls[call_id]["args"] else ""
                            tool_call_event_data = {
                                "id": call_id,
                                "function": {
                                    "name": function_calls[call_id]["name"],
                                    "arguments": args_str,
                                },
                                "type": "function",
                                "index": function_calls[call_id]["index"],
                            }
                            self._emit_stream_chunk_event(
                                chunk=args_str,
                                from_task=from_task,
                                from_agent=from_agent,
                                tool_call=tool_call_event_data,
                            )

        # Handle completed function calls
        if function_calls and available_functions:
            for call_data in function_calls.values():

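Unlike the other providers, Gemini delivers function-call arguments as an already-structured dict rather than as streamed JSON fragments, so the hunk serializes them once with json.dumps to keep the event payload string-typed like everywhere else. A tiny sketch with illustrative arguments:

import json

# Illustrative structured args as Gemini would provide them (not a real response).
function_call_args = {"location": "NYC", "unit": "celsius"}

args_str = json.dumps(function_call_args) if function_call_args else ""

# Both the event's chunk and function.arguments carry this serialized string,
# matching the fragment-streaming providers' payload shape.
assert args_str == '{"location": "NYC", "unit": "celsius"}'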
@@ -510,8 +510,10 @@ class OpenAICompletion(BaseLLM):
                    call_id = tool_call.id or "default"
                    if call_id not in tool_calls:
                        tool_calls[call_id] = {
                            "id": call_id,
                            "name": "",
                            "arguments": "",
                            "index": tool_call.index if tool_call.index is not None else 0,
                        }

                    if tool_call.function and tool_call.function.name:
@@ -519,6 +521,23 @@ class OpenAICompletion(BaseLLM):
                    if tool_call.function and tool_call.function.arguments:
                        tool_calls[call_id]["arguments"] += tool_call.function.arguments

                    # Emit tool call streaming event
                    tool_call_event_data = {
                        "id": tool_calls[call_id]["id"],
                        "function": {
                            "name": tool_calls[call_id]["name"],
                            "arguments": tool_call.function.arguments if tool_call.function and tool_call.function.arguments else "",
                        },
                        "type": "function",
                        "index": tool_calls[call_id]["index"],
                    }
                    self._emit_stream_chunk_event(
                        chunk=tool_call.function.arguments if tool_call.function and tool_call.function.arguments else "",
                        from_task=from_task,
                        from_agent=from_agent,
                        tool_call=tool_call_event_data,
                    )

        if tool_calls and available_functions:
            for call_data in tool_calls.values():
                function_name = call_data["name"]

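Both the Azure and OpenAI hunks end by walking the accumulated tool_calls once streaming finishes. A sketch of that completion step, under the assumption (from the visible context only, not the full diff) that arguments are parsed from the accumulated JSON and dispatched to available_functions; run_completed_tool_calls is a hypothetical helper, not a crewAI function:

import json
from typing import Any, Callable


def run_completed_tool_calls(
    tool_calls: dict[str, dict[str, Any]],
    available_functions: dict[str, Callable[..., Any]],
) -> list[Any]:
    """Hypothetical completion step: parse accumulated args and dispatch."""
    results = []
    for call_data in tool_calls.values():
        function_name = call_data["name"]
        fn = available_functions.get(function_name)
        if fn is None:
            continue  # unknown tool; the real code may log or raise instead
        args = json.loads(call_data["arguments"]) if call_data["arguments"] else {}
        results.append(fn(**args))
    return results


calls = {"call_1": {"id": "call_1", "name": "add", "arguments": '{"a": 2, "b": 3}', "index": 0}}
assert run_completed_tool_calls(calls, {"add": lambda a, b: a + b}) == [5]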
@@ -715,3 +715,243 @@ class TestStreamingImports:
        assert StreamChunk is not None
        assert StreamChunkType is not None
        assert ToolCallChunk is not None


class TestLLMStreamChunkEventToolCall:
    """Tests for LLMStreamChunkEvent with tool call information."""

    def test_llm_stream_chunk_event_with_tool_call(self) -> None:
        """Test that LLMStreamChunkEvent correctly handles tool call data."""
        from crewai.events.types.llm_events import (
            LLMCallType,
            LLMStreamChunkEvent,
            ToolCall,
            FunctionCall,
        )

        # Create a tool call event
        tool_call = ToolCall(
            id="call-123",
            function=FunctionCall(
                name="search",
                arguments='{"query": "test"}',
            ),
            type="function",
            index=0,
        )

        event = LLMStreamChunkEvent(
            chunk='{"query": "test"}',
            tool_call=tool_call,
            call_type=LLMCallType.TOOL_CALL,
        )

        assert event.chunk == '{"query": "test"}'
        assert event.tool_call is not None
        assert event.tool_call.id == "call-123"
        assert event.tool_call.function.name == "search"
        assert event.tool_call.function.arguments == '{"query": "test"}'
        assert event.call_type == LLMCallType.TOOL_CALL

    def test_llm_stream_chunk_event_with_dict_tool_call(self) -> None:
        """Test that LLMStreamChunkEvent correctly handles tool call as dict."""
        from crewai.events.types.llm_events import (
            LLMCallType,
            LLMStreamChunkEvent,
        )

        # Create a tool call event using dict (as providers emit)
        tool_call_dict = {
            "id": "call-456",
            "function": {
                "name": "get_weather",
                "arguments": '{"location": "NYC"}',
            },
            "type": "function",
            "index": 1,
        }

        event = LLMStreamChunkEvent(
            chunk='{"location": "NYC"}',
            tool_call=tool_call_dict,
            call_type=LLMCallType.TOOL_CALL,
        )

        assert event.chunk == '{"location": "NYC"}'
        assert event.tool_call is not None
        assert event.tool_call.id == "call-456"
        assert event.tool_call.function.name == "get_weather"
        assert event.tool_call.function.arguments == '{"location": "NYC"}'
        assert event.call_type == LLMCallType.TOOL_CALL

    def test_llm_stream_chunk_event_text_only(self) -> None:
        """Test that LLMStreamChunkEvent works for text-only chunks."""
        from crewai.events.types.llm_events import (
            LLMCallType,
            LLMStreamChunkEvent,
        )

        event = LLMStreamChunkEvent(
            chunk="Hello, world!",
            tool_call=None,
            call_type=LLMCallType.LLM_CALL,
        )

        assert event.chunk == "Hello, world!"
        assert event.tool_call is None
        assert event.call_type == LLMCallType.LLM_CALL


class TestBaseLLMEmitStreamChunkEvent:
    """Tests for BaseLLM._emit_stream_chunk_event method."""

    def test_emit_stream_chunk_event_infers_tool_call_type(self) -> None:
        """Test that _emit_stream_chunk_event infers TOOL_CALL type when tool_call is present."""
        from unittest.mock import patch

        from crewai.llms.base_llm import BaseLLM
        from crewai.events.types.llm_events import LLMCallType, LLMStreamChunkEvent

        # Create a mock BaseLLM instance
        with patch.object(BaseLLM, "__abstractmethods__", set()):
            llm = BaseLLM(model="test-model")  # type: ignore

        captured_events: list[LLMStreamChunkEvent] = []

        def capture_emit(source: Any, event: Any) -> None:
            if isinstance(event, LLMStreamChunkEvent):
                captured_events.append(event)

        with patch("crewai.llms.base_llm.crewai_event_bus") as mock_bus:
            mock_bus.emit = capture_emit

            # Emit with tool_call - should infer TOOL_CALL type
            tool_call_dict = {
                "id": "call-789",
                "function": {
                    "name": "test_tool",
                    "arguments": '{"arg": "value"}',
                },
                "type": "function",
                "index": 0,
            }
            llm._emit_stream_chunk_event(
                chunk='{"arg": "value"}',
                tool_call=tool_call_dict,
            )

        assert len(captured_events) == 1
        assert captured_events[0].call_type == LLMCallType.TOOL_CALL
        assert captured_events[0].tool_call is not None

    def test_emit_stream_chunk_event_infers_llm_call_type(self) -> None:
        """Test that _emit_stream_chunk_event infers LLM_CALL type when tool_call is None."""
        from unittest.mock import patch

        from crewai.llms.base_llm import BaseLLM
        from crewai.events.types.llm_events import LLMCallType, LLMStreamChunkEvent

        # Create a mock BaseLLM instance
        with patch.object(BaseLLM, "__abstractmethods__", set()):
            llm = BaseLLM(model="test-model")  # type: ignore

        captured_events: list[LLMStreamChunkEvent] = []

        def capture_emit(source: Any, event: Any) -> None:
            if isinstance(event, LLMStreamChunkEvent):
                captured_events.append(event)

        with patch("crewai.llms.base_llm.crewai_event_bus") as mock_bus:
            mock_bus.emit = capture_emit

            # Emit without tool_call - should infer LLM_CALL type
            llm._emit_stream_chunk_event(
                chunk="Hello, world!",
                tool_call=None,
            )

        assert len(captured_events) == 1
        assert captured_events[0].call_type == LLMCallType.LLM_CALL
        assert captured_events[0].tool_call is None

    def test_emit_stream_chunk_event_respects_explicit_call_type(self) -> None:
        """Test that _emit_stream_chunk_event respects explicitly provided call_type."""
        from unittest.mock import patch

        from crewai.llms.base_llm import BaseLLM
        from crewai.events.types.llm_events import LLMCallType, LLMStreamChunkEvent

        # Create a mock BaseLLM instance
        with patch.object(BaseLLM, "__abstractmethods__", set()):
            llm = BaseLLM(model="test-model")  # type: ignore

        captured_events: list[LLMStreamChunkEvent] = []

        def capture_emit(source: Any, event: Any) -> None:
            if isinstance(event, LLMStreamChunkEvent):
                captured_events.append(event)

        with patch("crewai.llms.base_llm.crewai_event_bus") as mock_bus:
            mock_bus.emit = capture_emit

            # Emit with explicit call_type - should use provided type
            llm._emit_stream_chunk_event(
                chunk="test",
                tool_call=None,
                call_type=LLMCallType.TOOL_CALL,  # Explicitly set even though no tool_call
            )

        assert len(captured_events) == 1
        assert captured_events[0].call_type == LLMCallType.TOOL_CALL


class TestStreamingToolCallExtraction:
    """Tests for tool call extraction from streaming events."""

    def test_extract_tool_call_info_from_event(self) -> None:
        """Test that tool call info is correctly extracted from LLMStreamChunkEvent."""
        from crewai.utilities.streaming import _extract_tool_call_info
        from crewai.events.types.llm_events import (
            LLMStreamChunkEvent,
            ToolCall,
            FunctionCall,
        )
        from crewai.types.streaming import StreamChunkType

        # Create event with tool call
        tool_call = ToolCall(
            id="call-extract-test",
            function=FunctionCall(
                name="extract_test",
                arguments='{"key": "value"}',
            ),
            type="function",
            index=2,
        )

        event = LLMStreamChunkEvent(
            chunk='{"key": "value"}',
            tool_call=tool_call,
        )

        chunk_type, tool_call_chunk = _extract_tool_call_info(event)

        assert chunk_type == StreamChunkType.TOOL_CALL
        assert tool_call_chunk is not None
        assert tool_call_chunk.tool_id == "call-extract-test"
        assert tool_call_chunk.tool_name == "extract_test"
        assert tool_call_chunk.arguments == '{"key": "value"}'
        assert tool_call_chunk.index == 2

    def test_extract_tool_call_info_returns_text_for_no_tool_call(self) -> None:
        """Test that TEXT type is returned when no tool call is present."""
        from crewai.utilities.streaming import _extract_tool_call_info
        from crewai.events.types.llm_events import LLMStreamChunkEvent
        from crewai.types.streaming import StreamChunkType

        event = LLMStreamChunkEvent(
            chunk="Just text content",
            tool_call=None,
        )

        chunk_type, tool_call_chunk = _extract_tool_call_info(event)

        assert chunk_type == StreamChunkType.TEXT
        assert tool_call_chunk is None