Mirror of https://github.com/crewAIInc/crewAI.git, synced 2025-12-16 04:18:35 +00:00
fix: emit tool call events in provider-specific LLM streaming implementations
Fixes #3982

This commit adds tool call event emission to all provider-specific LLM streaming implementations. Previously, only text chunks were emitted during streaming, but tool call information was missing.

Changes:
- Update BaseLLM._emit_stream_chunk_event to infer call_type from tool_call presence when not explicitly provided
- Add tool call event emission in OpenAI provider streaming
- Add tool call event emission in Azure provider streaming
- Add tool call event emission in Gemini provider streaming
- Add tool call event emission in Bedrock provider streaming
- Add tool call event emission in Anthropic provider streaming
- Add comprehensive tests for tool call streaming events

The fix ensures that LLMStreamChunkEvent is emitted with:
- call_type=LLMCallType.TOOL_CALL when tool calls are received
- tool_call dict containing id, function (name, arguments), type, index
- chunk containing the tool call arguments being streamed

Co-Authored-By: João <joao@crewai.com>
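As a consumer-side illustration (not part of this commit), a handler registered on the event bus can now tell text chunks and tool-call chunks apart. The event-bus import path and the on() registration below are assumptions about crewai's event API; the chunk, tool_call, and call_type fields are the ones this commit adds and tests.

from typing import Any

from crewai.events import crewai_event_bus  # assumed import path; may vary by version
from crewai.events.types.llm_events import LLMCallType, LLMStreamChunkEvent


@crewai_event_bus.on(LLMStreamChunkEvent)
def on_stream_chunk(source: Any, event: LLMStreamChunkEvent) -> None:
    """Route streamed chunks: tool-call deltas vs. plain text."""
    if event.call_type == LLMCallType.TOOL_CALL and event.tool_call is not None:
        # chunk carries the streamed arguments fragment; tool_call carries
        # id, function (name, arguments), type, and index.
        print(f"[tool:{event.tool_call.function.name}] {event.chunk}")
    else:
        print(event.chunk, end="")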
@@ -316,11 +316,33 @@ class BaseLLM(ABC):
         from_task: Task | None = None,
         from_agent: Agent | None = None,
         tool_call: dict[str, Any] | None = None,
+        call_type: LLMCallType | None = None,
     ) -> None:
-        """Emit stream chunk event."""
+        """Emit stream chunk event.
+
+        Args:
+            chunk: The text content of the chunk
+            from_task: Optional task that initiated the call
+            from_agent: Optional agent that initiated the call
+            tool_call: Optional tool call information as a dict with keys:
+                - id: Tool call ID
+                - function: Dict with 'name' and 'arguments'
+                - type: Tool call type (e.g., 'function')
+                - index: Index of the tool call
+            call_type: Optional call type. If not provided, it will be inferred
+                from the presence of tool_call (TOOL_CALL if tool_call is present,
+                LLM_CALL otherwise)
+        """
         if not hasattr(crewai_event_bus, "emit"):
             raise ValueError("crewai_event_bus does not have an emit method") from None
 
+        # Infer call_type from tool_call presence if not explicitly provided
+        effective_call_type = call_type
+        if effective_call_type is None:
+            effective_call_type = (
+                LLMCallType.TOOL_CALL if tool_call is not None else LLMCallType.LLM_CALL
+            )
+
         crewai_event_bus.emit(
             self,
             event=LLMStreamChunkEvent(
@@ -328,6 +350,7 @@ class BaseLLM(ABC):
                 tool_call=tool_call,
                 from_task=from_task,
                 from_agent=from_agent,
+                call_type=effective_call_type,
             ),
         )
 
@@ -450,9 +450,14 @@ class AnthropicCompletion(BaseLLM):
         # (the SDK sets it internally)
         stream_params = {k: v for k, v in params.items() if k != "stream"}
 
+        # Track tool use blocks during streaming
+        current_tool_use: dict[str, Any] = {}
+        tool_use_index = 0
+
         # Make streaming API call
         with self.client.messages.stream(**stream_params) as stream:
             for event in stream:
+                # Handle text content
                 if hasattr(event, "delta") and hasattr(event.delta, "text"):
                     text_delta = event.delta.text
                     full_response += text_delta
@@ -462,6 +467,55 @@ class AnthropicCompletion(BaseLLM):
                         from_agent=from_agent,
                     )
 
+                # Handle tool use start (content_block_start event with tool_use type)
+                if hasattr(event, "content_block") and hasattr(event.content_block, "type"):
+                    if event.content_block.type == "tool_use":
+                        current_tool_use = {
+                            "id": getattr(event.content_block, "id", None),
+                            "name": getattr(event.content_block, "name", ""),
+                            "input": "",
+                            "index": tool_use_index,
+                        }
+                        tool_use_index += 1
+                        # Emit tool call start event
+                        tool_call_event_data = {
+                            "id": current_tool_use["id"],
+                            "function": {
+                                "name": current_tool_use["name"],
+                                "arguments": "",
+                            },
+                            "type": "function",
+                            "index": current_tool_use["index"],
+                        }
+                        self._emit_stream_chunk_event(
+                            chunk="",
+                            from_task=from_task,
+                            from_agent=from_agent,
+                            tool_call=tool_call_event_data,
+                        )
+
+                # Handle tool use input delta (input_json events)
+                if hasattr(event, "delta") and hasattr(event.delta, "partial_json"):
+                    partial_json = event.delta.partial_json
+                    if current_tool_use and partial_json:
+                        current_tool_use["input"] += partial_json
+                        # Emit tool call delta event
+                        tool_call_event_data = {
+                            "id": current_tool_use["id"],
+                            "function": {
+                                "name": current_tool_use["name"],
+                                "arguments": partial_json,
+                            },
+                            "type": "function",
+                            "index": current_tool_use["index"],
+                        }
+                        self._emit_stream_chunk_event(
+                            chunk=partial_json,
+                            from_task=from_task,
+                            from_agent=from_agent,
+                            tool_call=tool_call_event_data,
+                        )
+
         final_message: Message = stream.get_final_message()
 
         usage = self._extract_anthropic_token_usage(final_message)
@@ -503,8 +503,10 @@ class AzureCompletion(BaseLLM):
                     call_id = tool_call.id or "default"
                     if call_id not in tool_calls:
                         tool_calls[call_id] = {
+                            "id": call_id,
                             "name": "",
                             "arguments": "",
+                            "index": getattr(tool_call, "index", 0) or 0,
                         }
 
                     if tool_call.function and tool_call.function.name:
@@ -514,6 +516,23 @@ class AzureCompletion(BaseLLM):
                             tool_call.function.arguments
                         )
 
+                    # Emit tool call streaming event
+                    tool_call_event_data = {
+                        "id": tool_calls[call_id]["id"],
+                        "function": {
+                            "name": tool_calls[call_id]["name"],
+                            "arguments": tool_call.function.arguments if tool_call.function and tool_call.function.arguments else "",
+                        },
+                        "type": "function",
+                        "index": tool_calls[call_id]["index"],
+                    }
+                    self._emit_stream_chunk_event(
+                        chunk=tool_call.function.arguments if tool_call.function and tool_call.function.arguments else "",
+                        from_task=from_task,
+                        from_agent=from_agent,
+                        tool_call=tool_call_event_data,
+                    )
+
             # Handle completed tool calls
             if tool_calls and available_functions:
                 for call_data in tool_calls.values():
@@ -567,12 +567,31 @@ class BedrockCompletion(BaseLLM):
 
                 elif "contentBlockStart" in event:
                     start = event["contentBlockStart"].get("start", {})
+                    block_index = event["contentBlockStart"].get("contentBlockIndex", 0)
                     if "toolUse" in start:
                         current_tool_use = start["toolUse"]
+                        current_tool_use["_block_index"] = block_index
+                        current_tool_use["_accumulated_input"] = ""
                        tool_use_id = current_tool_use.get("toolUseId")
                         logging.debug(
                             f"Tool use started in stream: {current_tool_use.get('name')} (ID: {tool_use_id})"
                         )
+                        # Emit tool call start event
+                        tool_call_event_data = {
+                            "id": tool_use_id,
+                            "function": {
+                                "name": current_tool_use.get("name", ""),
+                                "arguments": "",
+                            },
+                            "type": "function",
+                            "index": block_index,
+                        }
+                        self._emit_stream_chunk_event(
+                            chunk="",
+                            from_task=from_task,
+                            from_agent=from_agent,
+                            tool_call=tool_call_event_data,
+                        )
 
                 elif "contentBlockDelta" in event:
                     delta = event["contentBlockDelta"]["delta"]
@@ -589,6 +608,23 @@ class BedrockCompletion(BaseLLM):
                         tool_input = delta["toolUse"].get("input", "")
                         if tool_input:
                             logging.debug(f"Tool input delta: {tool_input}")
+                            current_tool_use["_accumulated_input"] += tool_input
+                            # Emit tool call delta event
+                            tool_call_event_data = {
+                                "id": current_tool_use.get("toolUseId"),
+                                "function": {
+                                    "name": current_tool_use.get("name", ""),
+                                    "arguments": tool_input,
+                                },
+                                "type": "function",
+                                "index": current_tool_use.get("_block_index", 0),
+                            }
+                            self._emit_stream_chunk_event(
+                                chunk=tool_input,
+                                from_task=from_task,
+                                from_agent=from_agent,
+                                tool_call=tool_call_event_data,
+                            )
 
                 # Content block stop - end of a content block
                 elif "contentBlockStop" in event:
@@ -1,3 +1,4 @@
+import json
 import logging
 import os
 import re
@@ -496,7 +497,7 @@ class GeminiCompletion(BaseLLM):
             if hasattr(chunk, "candidates") and chunk.candidates:
                 candidate = chunk.candidates[0]
                 if candidate.content and candidate.content.parts:
-                    for part in candidate.content.parts:
+                    for part_index, part in enumerate(candidate.content.parts):
                         if hasattr(part, "function_call") and part.function_call:
                             call_id = part.function_call.name or "default"
                             if call_id not in function_calls:
@@ -505,8 +506,27 @@ class GeminiCompletion(BaseLLM):
                                     "args": dict(part.function_call.args)
                                     if part.function_call.args
                                     else {},
+                                    "index": part_index,
                                 }
+
+                                # Emit tool call streaming event
+                                args_str = json.dumps(function_calls[call_id]["args"]) if function_calls[call_id]["args"] else ""
+                                tool_call_event_data = {
+                                    "id": call_id,
+                                    "function": {
+                                        "name": function_calls[call_id]["name"],
+                                        "arguments": args_str,
+                                    },
+                                    "type": "function",
+                                    "index": function_calls[call_id]["index"],
+                                }
+                                self._emit_stream_chunk_event(
+                                    chunk=args_str,
+                                    from_task=from_task,
+                                    from_agent=from_agent,
+                                    tool_call=tool_call_event_data,
+                                )
 
             # Handle completed function calls
             if function_calls and available_functions:
                 for call_data in function_calls.values():
@@ -510,8 +510,10 @@ class OpenAICompletion(BaseLLM):
                     call_id = tool_call.id or "default"
                     if call_id not in tool_calls:
                         tool_calls[call_id] = {
+                            "id": call_id,
                             "name": "",
                             "arguments": "",
+                            "index": tool_call.index if tool_call.index is not None else 0,
                         }
 
                     if tool_call.function and tool_call.function.name:
@@ -519,6 +521,23 @@ class OpenAICompletion(BaseLLM):
                     if tool_call.function and tool_call.function.arguments:
                         tool_calls[call_id]["arguments"] += tool_call.function.arguments
 
+                    # Emit tool call streaming event
+                    tool_call_event_data = {
+                        "id": tool_calls[call_id]["id"],
+                        "function": {
+                            "name": tool_calls[call_id]["name"],
+                            "arguments": tool_call.function.arguments if tool_call.function and tool_call.function.arguments else "",
+                        },
+                        "type": "function",
+                        "index": tool_calls[call_id]["index"],
+                    }
+                    self._emit_stream_chunk_event(
+                        chunk=tool_call.function.arguments if tool_call.function and tool_call.function.arguments else "",
+                        from_task=from_task,
+                        from_agent=from_agent,
+                        tool_call=tool_call_event_data,
+                    )
+
             if tool_calls and available_functions:
                 for call_data in tool_calls.values():
                     function_name = call_data["name"]
@@ -715,3 +715,243 @@ class TestStreamingImports:
         assert StreamChunk is not None
         assert StreamChunkType is not None
         assert ToolCallChunk is not None
+
+
+class TestLLMStreamChunkEventToolCall:
+    """Tests for LLMStreamChunkEvent with tool call information."""
+
+    def test_llm_stream_chunk_event_with_tool_call(self) -> None:
+        """Test that LLMStreamChunkEvent correctly handles tool call data."""
+        from crewai.events.types.llm_events import (
+            LLMCallType,
+            LLMStreamChunkEvent,
+            ToolCall,
+            FunctionCall,
+        )
+
+        # Create a tool call event
+        tool_call = ToolCall(
+            id="call-123",
+            function=FunctionCall(
+                name="search",
+                arguments='{"query": "test"}',
+            ),
+            type="function",
+            index=0,
+        )
+
+        event = LLMStreamChunkEvent(
+            chunk='{"query": "test"}',
+            tool_call=tool_call,
+            call_type=LLMCallType.TOOL_CALL,
+        )
+
+        assert event.chunk == '{"query": "test"}'
+        assert event.tool_call is not None
+        assert event.tool_call.id == "call-123"
+        assert event.tool_call.function.name == "search"
+        assert event.tool_call.function.arguments == '{"query": "test"}'
+        assert event.call_type == LLMCallType.TOOL_CALL
+
+    def test_llm_stream_chunk_event_with_dict_tool_call(self) -> None:
+        """Test that LLMStreamChunkEvent correctly handles tool call as dict."""
+        from crewai.events.types.llm_events import (
+            LLMCallType,
+            LLMStreamChunkEvent,
+        )
+
+        # Create a tool call event using dict (as providers emit)
+        tool_call_dict = {
+            "id": "call-456",
+            "function": {
+                "name": "get_weather",
+                "arguments": '{"location": "NYC"}',
+            },
+            "type": "function",
+            "index": 1,
+        }
+
+        event = LLMStreamChunkEvent(
+            chunk='{"location": "NYC"}',
+            tool_call=tool_call_dict,
+            call_type=LLMCallType.TOOL_CALL,
+        )
+
+        assert event.chunk == '{"location": "NYC"}'
+        assert event.tool_call is not None
+        assert event.tool_call.id == "call-456"
+        assert event.tool_call.function.name == "get_weather"
+        assert event.tool_call.function.arguments == '{"location": "NYC"}'
+        assert event.call_type == LLMCallType.TOOL_CALL
+
+    def test_llm_stream_chunk_event_text_only(self) -> None:
+        """Test that LLMStreamChunkEvent works for text-only chunks."""
+        from crewai.events.types.llm_events import (
+            LLMCallType,
+            LLMStreamChunkEvent,
+        )
+
+        event = LLMStreamChunkEvent(
+            chunk="Hello, world!",
+            tool_call=None,
+            call_type=LLMCallType.LLM_CALL,
+        )
+
+        assert event.chunk == "Hello, world!"
+        assert event.tool_call is None
+        assert event.call_type == LLMCallType.LLM_CALL
+
+
+class TestBaseLLMEmitStreamChunkEvent:
+    """Tests for BaseLLM._emit_stream_chunk_event method."""
+
+    def test_emit_stream_chunk_event_infers_tool_call_type(self) -> None:
+        """Test that _emit_stream_chunk_event infers TOOL_CALL type when tool_call is present."""
+        from unittest.mock import MagicMock, patch
+        from crewai.llms.base_llm import BaseLLM
+        from crewai.events.types.llm_events import LLMCallType, LLMStreamChunkEvent
+
+        # Create a mock BaseLLM instance
+        with patch.object(BaseLLM, "__abstractmethods__", set()):
+            llm = BaseLLM(model="test-model")  # type: ignore
+
+            captured_events: list[LLMStreamChunkEvent] = []
+
+            def capture_emit(source: Any, event: Any) -> None:
+                if isinstance(event, LLMStreamChunkEvent):
+                    captured_events.append(event)
+
+            with patch("crewai.llms.base_llm.crewai_event_bus") as mock_bus:
+                mock_bus.emit = capture_emit
+
+                # Emit with tool_call - should infer TOOL_CALL type
+                tool_call_dict = {
+                    "id": "call-789",
+                    "function": {
+                        "name": "test_tool",
+                        "arguments": '{"arg": "value"}',
+                    },
+                    "type": "function",
+                    "index": 0,
+                }
+                llm._emit_stream_chunk_event(
+                    chunk='{"arg": "value"}',
+                    tool_call=tool_call_dict,
+                )
+
+                assert len(captured_events) == 1
+                assert captured_events[0].call_type == LLMCallType.TOOL_CALL
+                assert captured_events[0].tool_call is not None
+
+    def test_emit_stream_chunk_event_infers_llm_call_type(self) -> None:
+        """Test that _emit_stream_chunk_event infers LLM_CALL type when tool_call is None."""
+        from unittest.mock import patch
+        from crewai.llms.base_llm import BaseLLM
+        from crewai.events.types.llm_events import LLMCallType, LLMStreamChunkEvent
+
+        # Create a mock BaseLLM instance
+        with patch.object(BaseLLM, "__abstractmethods__", set()):
+            llm = BaseLLM(model="test-model")  # type: ignore
+
+            captured_events: list[LLMStreamChunkEvent] = []
+
+            def capture_emit(source: Any, event: Any) -> None:
+                if isinstance(event, LLMStreamChunkEvent):
+                    captured_events.append(event)
+
+            with patch("crewai.llms.base_llm.crewai_event_bus") as mock_bus:
+                mock_bus.emit = capture_emit
+
+                # Emit without tool_call - should infer LLM_CALL type
+                llm._emit_stream_chunk_event(
+                    chunk="Hello, world!",
+                    tool_call=None,
+                )
+
+                assert len(captured_events) == 1
+                assert captured_events[0].call_type == LLMCallType.LLM_CALL
+                assert captured_events[0].tool_call is None
+
+    def test_emit_stream_chunk_event_respects_explicit_call_type(self) -> None:
+        """Test that _emit_stream_chunk_event respects explicitly provided call_type."""
+        from unittest.mock import patch
+        from crewai.llms.base_llm import BaseLLM
+        from crewai.events.types.llm_events import LLMCallType, LLMStreamChunkEvent
+
+        # Create a mock BaseLLM instance
+        with patch.object(BaseLLM, "__abstractmethods__", set()):
+            llm = BaseLLM(model="test-model")  # type: ignore
+
+            captured_events: list[LLMStreamChunkEvent] = []
+
+            def capture_emit(source: Any, event: Any) -> None:
+                if isinstance(event, LLMStreamChunkEvent):
+                    captured_events.append(event)
+
+            with patch("crewai.llms.base_llm.crewai_event_bus") as mock_bus:
+                mock_bus.emit = capture_emit
+
+                # Emit with explicit call_type - should use provided type
+                llm._emit_stream_chunk_event(
+                    chunk="test",
+                    tool_call=None,
+                    call_type=LLMCallType.TOOL_CALL,  # Explicitly set even though no tool_call
+                )
+
+                assert len(captured_events) == 1
+                assert captured_events[0].call_type == LLMCallType.TOOL_CALL
+
+
+class TestStreamingToolCallExtraction:
+    """Tests for tool call extraction from streaming events."""
+
+    def test_extract_tool_call_info_from_event(self) -> None:
+        """Test that tool call info is correctly extracted from LLMStreamChunkEvent."""
+        from crewai.utilities.streaming import _extract_tool_call_info
+        from crewai.events.types.llm_events import (
+            LLMStreamChunkEvent,
+            ToolCall,
+            FunctionCall,
+        )
+        from crewai.types.streaming import StreamChunkType
+
+        # Create event with tool call
+        tool_call = ToolCall(
+            id="call-extract-test",
+            function=FunctionCall(
+                name="extract_test",
+                arguments='{"key": "value"}',
+            ),
+            type="function",
+            index=2,
+        )
+
+        event = LLMStreamChunkEvent(
+            chunk='{"key": "value"}',
+            tool_call=tool_call,
+        )
+
+        chunk_type, tool_call_chunk = _extract_tool_call_info(event)
+
+        assert chunk_type == StreamChunkType.TOOL_CALL
+        assert tool_call_chunk is not None
+        assert tool_call_chunk.tool_id == "call-extract-test"
+        assert tool_call_chunk.tool_name == "extract_test"
+        assert tool_call_chunk.arguments == '{"key": "value"}'
+        assert tool_call_chunk.index == 2
+
+    def test_extract_tool_call_info_returns_text_for_no_tool_call(self) -> None:
+        """Test that TEXT type is returned when no tool call is present."""
+        from crewai.utilities.streaming import _extract_tool_call_info
+        from crewai.events.types.llm_events import LLMStreamChunkEvent
+        from crewai.types.streaming import StreamChunkType
+
+        event = LLMStreamChunkEvent(
+            chunk="Just text content",
+            tool_call=None,
+        )
+
+        chunk_type, tool_call_chunk = _extract_tool_call_info(event)
+
+        assert chunk_type == StreamChunkType.TEXT
+        assert tool_call_chunk is None
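One consequence of the per-delta design above: each emitted event's arguments field holds only the newest fragment (mirroring chunk), so a consumer that wants complete tool calls must accumulate per call id, just as the providers do internally. A rough sketch under that assumption, using only the event fields this commit tests:

from typing import Any

from crewai.events.types.llm_events import LLMStreamChunkEvent

accumulated: dict[str, dict[str, Any]] = {}


def accumulate(event: LLMStreamChunkEvent) -> None:
    """Reassemble complete tool calls from streamed tool-call deltas."""
    if event.tool_call is None:
        return
    call = accumulated.setdefault(
        event.tool_call.id or "default",
        {"name": "", "arguments": "", "index": event.tool_call.index},
    )
    if event.tool_call.function.name:
        call["name"] = event.tool_call.function.name
    # Each event carries only the newest arguments fragment.
    call["arguments"] += event.tool_call.function.arguments or ""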