mirror of
https://github.com/crewAIInc/crewAI.git
synced 2026-01-09 16:18:30 +00:00
fix: emit tool call events in provider-specific LLM streaming implementations
Fixes #3982 This commit adds tool call event emission to all provider-specific LLM streaming implementations. Previously, only text chunks were emitted during streaming, but tool call information was missing. Changes: - Update BaseLLM._emit_stream_chunk_event to infer call_type from tool_call presence when not explicitly provided - Add tool call event emission in OpenAI provider streaming - Add tool call event emission in Azure provider streaming - Add tool call event emission in Gemini provider streaming - Add tool call event emission in Bedrock provider streaming - Add tool call event emission in Anthropic provider streaming - Add comprehensive tests for tool call streaming events The fix ensures that LLMStreamChunkEvent is emitted with: - call_type=LLMCallType.TOOL_CALL when tool calls are received - tool_call dict containing id, function (name, arguments), type, index - chunk containing the tool call arguments being streamed Co-Authored-By: João <joao@crewai.com>
This commit is contained in:
@@ -715,3 +715,243 @@ class TestStreamingImports:
|
||||
assert StreamChunk is not None
|
||||
assert StreamChunkType is not None
|
||||
assert ToolCallChunk is not None
|
||||
|
||||
|
||||
class TestLLMStreamChunkEventToolCall:
|
||||
"""Tests for LLMStreamChunkEvent with tool call information."""
|
||||
|
||||
def test_llm_stream_chunk_event_with_tool_call(self) -> None:
|
||||
"""Test that LLMStreamChunkEvent correctly handles tool call data."""
|
||||
from crewai.events.types.llm_events import (
|
||||
LLMCallType,
|
||||
LLMStreamChunkEvent,
|
||||
ToolCall,
|
||||
FunctionCall,
|
||||
)
|
||||
|
||||
# Create a tool call event
|
||||
tool_call = ToolCall(
|
||||
id="call-123",
|
||||
function=FunctionCall(
|
||||
name="search",
|
||||
arguments='{"query": "test"}',
|
||||
),
|
||||
type="function",
|
||||
index=0,
|
||||
)
|
||||
|
||||
event = LLMStreamChunkEvent(
|
||||
chunk='{"query": "test"}',
|
||||
tool_call=tool_call,
|
||||
call_type=LLMCallType.TOOL_CALL,
|
||||
)
|
||||
|
||||
assert event.chunk == '{"query": "test"}'
|
||||
assert event.tool_call is not None
|
||||
assert event.tool_call.id == "call-123"
|
||||
assert event.tool_call.function.name == "search"
|
||||
assert event.tool_call.function.arguments == '{"query": "test"}'
|
||||
assert event.call_type == LLMCallType.TOOL_CALL
|
||||
|
||||
def test_llm_stream_chunk_event_with_dict_tool_call(self) -> None:
|
||||
"""Test that LLMStreamChunkEvent correctly handles tool call as dict."""
|
||||
from crewai.events.types.llm_events import (
|
||||
LLMCallType,
|
||||
LLMStreamChunkEvent,
|
||||
)
|
||||
|
||||
# Create a tool call event using dict (as providers emit)
|
||||
tool_call_dict = {
|
||||
"id": "call-456",
|
||||
"function": {
|
||||
"name": "get_weather",
|
||||
"arguments": '{"location": "NYC"}',
|
||||
},
|
||||
"type": "function",
|
||||
"index": 1,
|
||||
}
|
||||
|
||||
event = LLMStreamChunkEvent(
|
||||
chunk='{"location": "NYC"}',
|
||||
tool_call=tool_call_dict,
|
||||
call_type=LLMCallType.TOOL_CALL,
|
||||
)
|
||||
|
||||
assert event.chunk == '{"location": "NYC"}'
|
||||
assert event.tool_call is not None
|
||||
assert event.tool_call.id == "call-456"
|
||||
assert event.tool_call.function.name == "get_weather"
|
||||
assert event.tool_call.function.arguments == '{"location": "NYC"}'
|
||||
assert event.call_type == LLMCallType.TOOL_CALL
|
||||
|
||||
def test_llm_stream_chunk_event_text_only(self) -> None:
|
||||
"""Test that LLMStreamChunkEvent works for text-only chunks."""
|
||||
from crewai.events.types.llm_events import (
|
||||
LLMCallType,
|
||||
LLMStreamChunkEvent,
|
||||
)
|
||||
|
||||
event = LLMStreamChunkEvent(
|
||||
chunk="Hello, world!",
|
||||
tool_call=None,
|
||||
call_type=LLMCallType.LLM_CALL,
|
||||
)
|
||||
|
||||
assert event.chunk == "Hello, world!"
|
||||
assert event.tool_call is None
|
||||
assert event.call_type == LLMCallType.LLM_CALL
|
||||
|
||||
|
||||
class TestBaseLLMEmitStreamChunkEvent:
|
||||
"""Tests for BaseLLM._emit_stream_chunk_event method."""
|
||||
|
||||
def test_emit_stream_chunk_event_infers_tool_call_type(self) -> None:
|
||||
"""Test that _emit_stream_chunk_event infers TOOL_CALL type when tool_call is present."""
|
||||
from unittest.mock import MagicMock, patch
|
||||
from crewai.llms.base_llm import BaseLLM
|
||||
from crewai.events.types.llm_events import LLMCallType, LLMStreamChunkEvent
|
||||
|
||||
# Create a mock BaseLLM instance
|
||||
with patch.object(BaseLLM, "__abstractmethods__", set()):
|
||||
llm = BaseLLM(model="test-model") # type: ignore
|
||||
|
||||
captured_events: list[LLMStreamChunkEvent] = []
|
||||
|
||||
def capture_emit(source: Any, event: Any) -> None:
|
||||
if isinstance(event, LLMStreamChunkEvent):
|
||||
captured_events.append(event)
|
||||
|
||||
with patch("crewai.llms.base_llm.crewai_event_bus") as mock_bus:
|
||||
mock_bus.emit = capture_emit
|
||||
|
||||
# Emit with tool_call - should infer TOOL_CALL type
|
||||
tool_call_dict = {
|
||||
"id": "call-789",
|
||||
"function": {
|
||||
"name": "test_tool",
|
||||
"arguments": '{"arg": "value"}',
|
||||
},
|
||||
"type": "function",
|
||||
"index": 0,
|
||||
}
|
||||
llm._emit_stream_chunk_event(
|
||||
chunk='{"arg": "value"}',
|
||||
tool_call=tool_call_dict,
|
||||
)
|
||||
|
||||
assert len(captured_events) == 1
|
||||
assert captured_events[0].call_type == LLMCallType.TOOL_CALL
|
||||
assert captured_events[0].tool_call is not None
|
||||
|
||||
def test_emit_stream_chunk_event_infers_llm_call_type(self) -> None:
|
||||
"""Test that _emit_stream_chunk_event infers LLM_CALL type when tool_call is None."""
|
||||
from unittest.mock import patch
|
||||
from crewai.llms.base_llm import BaseLLM
|
||||
from crewai.events.types.llm_events import LLMCallType, LLMStreamChunkEvent
|
||||
|
||||
# Create a mock BaseLLM instance
|
||||
with patch.object(BaseLLM, "__abstractmethods__", set()):
|
||||
llm = BaseLLM(model="test-model") # type: ignore
|
||||
|
||||
captured_events: list[LLMStreamChunkEvent] = []
|
||||
|
||||
def capture_emit(source: Any, event: Any) -> None:
|
||||
if isinstance(event, LLMStreamChunkEvent):
|
||||
captured_events.append(event)
|
||||
|
||||
with patch("crewai.llms.base_llm.crewai_event_bus") as mock_bus:
|
||||
mock_bus.emit = capture_emit
|
||||
|
||||
# Emit without tool_call - should infer LLM_CALL type
|
||||
llm._emit_stream_chunk_event(
|
||||
chunk="Hello, world!",
|
||||
tool_call=None,
|
||||
)
|
||||
|
||||
assert len(captured_events) == 1
|
||||
assert captured_events[0].call_type == LLMCallType.LLM_CALL
|
||||
assert captured_events[0].tool_call is None
|
||||
|
||||
def test_emit_stream_chunk_event_respects_explicit_call_type(self) -> None:
|
||||
"""Test that _emit_stream_chunk_event respects explicitly provided call_type."""
|
||||
from unittest.mock import patch
|
||||
from crewai.llms.base_llm import BaseLLM
|
||||
from crewai.events.types.llm_events import LLMCallType, LLMStreamChunkEvent
|
||||
|
||||
# Create a mock BaseLLM instance
|
||||
with patch.object(BaseLLM, "__abstractmethods__", set()):
|
||||
llm = BaseLLM(model="test-model") # type: ignore
|
||||
|
||||
captured_events: list[LLMStreamChunkEvent] = []
|
||||
|
||||
def capture_emit(source: Any, event: Any) -> None:
|
||||
if isinstance(event, LLMStreamChunkEvent):
|
||||
captured_events.append(event)
|
||||
|
||||
with patch("crewai.llms.base_llm.crewai_event_bus") as mock_bus:
|
||||
mock_bus.emit = capture_emit
|
||||
|
||||
# Emit with explicit call_type - should use provided type
|
||||
llm._emit_stream_chunk_event(
|
||||
chunk="test",
|
||||
tool_call=None,
|
||||
call_type=LLMCallType.TOOL_CALL, # Explicitly set even though no tool_call
|
||||
)
|
||||
|
||||
assert len(captured_events) == 1
|
||||
assert captured_events[0].call_type == LLMCallType.TOOL_CALL
|
||||
|
||||
|
||||
class TestStreamingToolCallExtraction:
|
||||
"""Tests for tool call extraction from streaming events."""
|
||||
|
||||
def test_extract_tool_call_info_from_event(self) -> None:
|
||||
"""Test that tool call info is correctly extracted from LLMStreamChunkEvent."""
|
||||
from crewai.utilities.streaming import _extract_tool_call_info
|
||||
from crewai.events.types.llm_events import (
|
||||
LLMStreamChunkEvent,
|
||||
ToolCall,
|
||||
FunctionCall,
|
||||
)
|
||||
from crewai.types.streaming import StreamChunkType
|
||||
|
||||
# Create event with tool call
|
||||
tool_call = ToolCall(
|
||||
id="call-extract-test",
|
||||
function=FunctionCall(
|
||||
name="extract_test",
|
||||
arguments='{"key": "value"}',
|
||||
),
|
||||
type="function",
|
||||
index=2,
|
||||
)
|
||||
|
||||
event = LLMStreamChunkEvent(
|
||||
chunk='{"key": "value"}',
|
||||
tool_call=tool_call,
|
||||
)
|
||||
|
||||
chunk_type, tool_call_chunk = _extract_tool_call_info(event)
|
||||
|
||||
assert chunk_type == StreamChunkType.TOOL_CALL
|
||||
assert tool_call_chunk is not None
|
||||
assert tool_call_chunk.tool_id == "call-extract-test"
|
||||
assert tool_call_chunk.tool_name == "extract_test"
|
||||
assert tool_call_chunk.arguments == '{"key": "value"}'
|
||||
assert tool_call_chunk.index == 2
|
||||
|
||||
def test_extract_tool_call_info_returns_text_for_no_tool_call(self) -> None:
|
||||
"""Test that TEXT type is returned when no tool call is present."""
|
||||
from crewai.utilities.streaming import _extract_tool_call_info
|
||||
from crewai.events.types.llm_events import LLMStreamChunkEvent
|
||||
from crewai.types.streaming import StreamChunkType
|
||||
|
||||
event = LLMStreamChunkEvent(
|
||||
chunk="Just text content",
|
||||
tool_call=None,
|
||||
)
|
||||
|
||||
chunk_type, tool_call_chunk = _extract_tool_call_info(event)
|
||||
|
||||
assert chunk_type == StreamChunkType.TEXT
|
||||
assert tool_call_chunk is None
|
||||
|
||||
Reference in New Issue
Block a user