diff --git a/lib/crewai/src/crewai/llms/base_llm.py b/lib/crewai/src/crewai/llms/base_llm.py index a7026c5c5..717de499e 100644 --- a/lib/crewai/src/crewai/llms/base_llm.py +++ b/lib/crewai/src/crewai/llms/base_llm.py @@ -316,11 +316,33 @@ class BaseLLM(ABC): from_task: Task | None = None, from_agent: Agent | None = None, tool_call: dict[str, Any] | None = None, + call_type: LLMCallType | None = None, ) -> None: - """Emit stream chunk event.""" + """Emit stream chunk event. + + Args: + chunk: The text content of the chunk + from_task: Optional task that initiated the call + from_agent: Optional agent that initiated the call + tool_call: Optional tool call information as a dict with keys: + - id: Tool call ID + - function: Dict with 'name' and 'arguments' + - type: Tool call type (e.g., 'function') + - index: Index of the tool call + call_type: Optional call type. If not provided, it will be inferred + from the presence of tool_call (TOOL_CALL if tool_call is present, + LLM_CALL otherwise) + """ if not hasattr(crewai_event_bus, "emit"): raise ValueError("crewai_event_bus does not have an emit method") from None + # Infer call_type from tool_call presence if not explicitly provided + effective_call_type = call_type + if effective_call_type is None: + effective_call_type = ( + LLMCallType.TOOL_CALL if tool_call is not None else LLMCallType.LLM_CALL + ) + crewai_event_bus.emit( self, event=LLMStreamChunkEvent( @@ -328,6 +350,7 @@ class BaseLLM(ABC): tool_call=tool_call, from_task=from_task, from_agent=from_agent, + call_type=effective_call_type, ), ) diff --git a/lib/crewai/src/crewai/llms/providers/anthropic/completion.py b/lib/crewai/src/crewai/llms/providers/anthropic/completion.py index ea161fc63..a8ab257fb 100644 --- a/lib/crewai/src/crewai/llms/providers/anthropic/completion.py +++ b/lib/crewai/src/crewai/llms/providers/anthropic/completion.py @@ -450,9 +450,14 @@ class AnthropicCompletion(BaseLLM): # (the SDK sets it internally) stream_params = {k: v for k, v in params.items() if k != "stream"} + # Track tool use blocks during streaming + current_tool_use: dict[str, Any] = {} + tool_use_index = 0 + # Make streaming API call with self.client.messages.stream(**stream_params) as stream: for event in stream: + # Handle text content if hasattr(event, "delta") and hasattr(event.delta, "text"): text_delta = event.delta.text full_response += text_delta @@ -462,6 +467,55 @@ class AnthropicCompletion(BaseLLM): from_agent=from_agent, ) + # Handle tool use start (content_block_start event with tool_use type) + if hasattr(event, "content_block") and hasattr(event.content_block, "type"): + if event.content_block.type == "tool_use": + current_tool_use = { + "id": getattr(event.content_block, "id", None), + "name": getattr(event.content_block, "name", ""), + "input": "", + "index": tool_use_index, + } + tool_use_index += 1 + # Emit tool call start event + tool_call_event_data = { + "id": current_tool_use["id"], + "function": { + "name": current_tool_use["name"], + "arguments": "", + }, + "type": "function", + "index": current_tool_use["index"], + } + self._emit_stream_chunk_event( + chunk="", + from_task=from_task, + from_agent=from_agent, + tool_call=tool_call_event_data, + ) + + # Handle tool use input delta (input_json events) + if hasattr(event, "delta") and hasattr(event.delta, "partial_json"): + partial_json = event.delta.partial_json + if current_tool_use and partial_json: + current_tool_use["input"] += partial_json + # Emit tool call delta event + tool_call_event_data = { + "id": current_tool_use["id"], + "function": { + "name": current_tool_use["name"], + "arguments": partial_json, + }, + "type": "function", + "index": current_tool_use["index"], + } + self._emit_stream_chunk_event( + chunk=partial_json, + from_task=from_task, + from_agent=from_agent, + tool_call=tool_call_event_data, + ) + final_message: Message = stream.get_final_message() usage = self._extract_anthropic_token_usage(final_message) diff --git a/lib/crewai/src/crewai/llms/providers/azure/completion.py b/lib/crewai/src/crewai/llms/providers/azure/completion.py index e79bb72f2..1dfd4f7e2 100644 --- a/lib/crewai/src/crewai/llms/providers/azure/completion.py +++ b/lib/crewai/src/crewai/llms/providers/azure/completion.py @@ -503,8 +503,10 @@ class AzureCompletion(BaseLLM): call_id = tool_call.id or "default" if call_id not in tool_calls: tool_calls[call_id] = { + "id": call_id, "name": "", "arguments": "", + "index": getattr(tool_call, "index", 0) or 0, } if tool_call.function and tool_call.function.name: @@ -514,6 +516,23 @@ class AzureCompletion(BaseLLM): tool_call.function.arguments ) + # Emit tool call streaming event + tool_call_event_data = { + "id": tool_calls[call_id]["id"], + "function": { + "name": tool_calls[call_id]["name"], + "arguments": tool_call.function.arguments if tool_call.function and tool_call.function.arguments else "", + }, + "type": "function", + "index": tool_calls[call_id]["index"], + } + self._emit_stream_chunk_event( + chunk=tool_call.function.arguments if tool_call.function and tool_call.function.arguments else "", + from_task=from_task, + from_agent=from_agent, + tool_call=tool_call_event_data, + ) + # Handle completed tool calls if tool_calls and available_functions: for call_data in tool_calls.values(): diff --git a/lib/crewai/src/crewai/llms/providers/bedrock/completion.py b/lib/crewai/src/crewai/llms/providers/bedrock/completion.py index 20eabf763..7423e1057 100644 --- a/lib/crewai/src/crewai/llms/providers/bedrock/completion.py +++ b/lib/crewai/src/crewai/llms/providers/bedrock/completion.py @@ -567,12 +567,31 @@ class BedrockCompletion(BaseLLM): elif "contentBlockStart" in event: start = event["contentBlockStart"].get("start", {}) + block_index = event["contentBlockStart"].get("contentBlockIndex", 0) if "toolUse" in start: current_tool_use = start["toolUse"] + current_tool_use["_block_index"] = block_index + current_tool_use["_accumulated_input"] = "" tool_use_id = current_tool_use.get("toolUseId") logging.debug( f"Tool use started in stream: {current_tool_use.get('name')} (ID: {tool_use_id})" ) + # Emit tool call start event + tool_call_event_data = { + "id": tool_use_id, + "function": { + "name": current_tool_use.get("name", ""), + "arguments": "", + }, + "type": "function", + "index": block_index, + } + self._emit_stream_chunk_event( + chunk="", + from_task=from_task, + from_agent=from_agent, + tool_call=tool_call_event_data, + ) elif "contentBlockDelta" in event: delta = event["contentBlockDelta"]["delta"] @@ -589,6 +608,23 @@ class BedrockCompletion(BaseLLM): tool_input = delta["toolUse"].get("input", "") if tool_input: logging.debug(f"Tool input delta: {tool_input}") + current_tool_use["_accumulated_input"] += tool_input + # Emit tool call delta event + tool_call_event_data = { + "id": current_tool_use.get("toolUseId"), + "function": { + "name": current_tool_use.get("name", ""), + "arguments": tool_input, + }, + "type": "function", + "index": current_tool_use.get("_block_index", 0), + } + self._emit_stream_chunk_event( + chunk=tool_input, + from_task=from_task, + from_agent=from_agent, + tool_call=tool_call_event_data, + ) # Content block stop - end of a content block elif "contentBlockStop" in event: diff --git a/lib/crewai/src/crewai/llms/providers/gemini/completion.py b/lib/crewai/src/crewai/llms/providers/gemini/completion.py index 027262865..f012c1943 100644 --- a/lib/crewai/src/crewai/llms/providers/gemini/completion.py +++ b/lib/crewai/src/crewai/llms/providers/gemini/completion.py @@ -1,3 +1,4 @@ +import json import logging import os import re @@ -496,7 +497,7 @@ class GeminiCompletion(BaseLLM): if hasattr(chunk, "candidates") and chunk.candidates: candidate = chunk.candidates[0] if candidate.content and candidate.content.parts: - for part in candidate.content.parts: + for part_index, part in enumerate(candidate.content.parts): if hasattr(part, "function_call") and part.function_call: call_id = part.function_call.name or "default" if call_id not in function_calls: @@ -505,8 +506,27 @@ class GeminiCompletion(BaseLLM): "args": dict(part.function_call.args) if part.function_call.args else {}, + "index": part_index, } + # Emit tool call streaming event + args_str = json.dumps(function_calls[call_id]["args"]) if function_calls[call_id]["args"] else "" + tool_call_event_data = { + "id": call_id, + "function": { + "name": function_calls[call_id]["name"], + "arguments": args_str, + }, + "type": "function", + "index": function_calls[call_id]["index"], + } + self._emit_stream_chunk_event( + chunk=args_str, + from_task=from_task, + from_agent=from_agent, + tool_call=tool_call_event_data, + ) + # Handle completed function calls if function_calls and available_functions: for call_data in function_calls.values(): diff --git a/lib/crewai/src/crewai/llms/providers/openai/completion.py b/lib/crewai/src/crewai/llms/providers/openai/completion.py index b2aac6283..03b8ae6e5 100644 --- a/lib/crewai/src/crewai/llms/providers/openai/completion.py +++ b/lib/crewai/src/crewai/llms/providers/openai/completion.py @@ -510,8 +510,10 @@ class OpenAICompletion(BaseLLM): call_id = tool_call.id or "default" if call_id not in tool_calls: tool_calls[call_id] = { + "id": call_id, "name": "", "arguments": "", + "index": tool_call.index if tool_call.index is not None else 0, } if tool_call.function and tool_call.function.name: @@ -519,6 +521,23 @@ class OpenAICompletion(BaseLLM): if tool_call.function and tool_call.function.arguments: tool_calls[call_id]["arguments"] += tool_call.function.arguments + # Emit tool call streaming event + tool_call_event_data = { + "id": tool_calls[call_id]["id"], + "function": { + "name": tool_calls[call_id]["name"], + "arguments": tool_call.function.arguments if tool_call.function and tool_call.function.arguments else "", + }, + "type": "function", + "index": tool_calls[call_id]["index"], + } + self._emit_stream_chunk_event( + chunk=tool_call.function.arguments if tool_call.function and tool_call.function.arguments else "", + from_task=from_task, + from_agent=from_agent, + tool_call=tool_call_event_data, + ) + if tool_calls and available_functions: for call_data in tool_calls.values(): function_name = call_data["name"] diff --git a/lib/crewai/tests/test_streaming.py b/lib/crewai/tests/test_streaming.py index 66f639d0f..849856653 100644 --- a/lib/crewai/tests/test_streaming.py +++ b/lib/crewai/tests/test_streaming.py @@ -715,3 +715,243 @@ class TestStreamingImports: assert StreamChunk is not None assert StreamChunkType is not None assert ToolCallChunk is not None + + +class TestLLMStreamChunkEventToolCall: + """Tests for LLMStreamChunkEvent with tool call information.""" + + def test_llm_stream_chunk_event_with_tool_call(self) -> None: + """Test that LLMStreamChunkEvent correctly handles tool call data.""" + from crewai.events.types.llm_events import ( + LLMCallType, + LLMStreamChunkEvent, + ToolCall, + FunctionCall, + ) + + # Create a tool call event + tool_call = ToolCall( + id="call-123", + function=FunctionCall( + name="search", + arguments='{"query": "test"}', + ), + type="function", + index=0, + ) + + event = LLMStreamChunkEvent( + chunk='{"query": "test"}', + tool_call=tool_call, + call_type=LLMCallType.TOOL_CALL, + ) + + assert event.chunk == '{"query": "test"}' + assert event.tool_call is not None + assert event.tool_call.id == "call-123" + assert event.tool_call.function.name == "search" + assert event.tool_call.function.arguments == '{"query": "test"}' + assert event.call_type == LLMCallType.TOOL_CALL + + def test_llm_stream_chunk_event_with_dict_tool_call(self) -> None: + """Test that LLMStreamChunkEvent correctly handles tool call as dict.""" + from crewai.events.types.llm_events import ( + LLMCallType, + LLMStreamChunkEvent, + ) + + # Create a tool call event using dict (as providers emit) + tool_call_dict = { + "id": "call-456", + "function": { + "name": "get_weather", + "arguments": '{"location": "NYC"}', + }, + "type": "function", + "index": 1, + } + + event = LLMStreamChunkEvent( + chunk='{"location": "NYC"}', + tool_call=tool_call_dict, + call_type=LLMCallType.TOOL_CALL, + ) + + assert event.chunk == '{"location": "NYC"}' + assert event.tool_call is not None + assert event.tool_call.id == "call-456" + assert event.tool_call.function.name == "get_weather" + assert event.tool_call.function.arguments == '{"location": "NYC"}' + assert event.call_type == LLMCallType.TOOL_CALL + + def test_llm_stream_chunk_event_text_only(self) -> None: + """Test that LLMStreamChunkEvent works for text-only chunks.""" + from crewai.events.types.llm_events import ( + LLMCallType, + LLMStreamChunkEvent, + ) + + event = LLMStreamChunkEvent( + chunk="Hello, world!", + tool_call=None, + call_type=LLMCallType.LLM_CALL, + ) + + assert event.chunk == "Hello, world!" + assert event.tool_call is None + assert event.call_type == LLMCallType.LLM_CALL + + +class TestBaseLLMEmitStreamChunkEvent: + """Tests for BaseLLM._emit_stream_chunk_event method.""" + + def test_emit_stream_chunk_event_infers_tool_call_type(self) -> None: + """Test that _emit_stream_chunk_event infers TOOL_CALL type when tool_call is present.""" + from unittest.mock import MagicMock, patch + from crewai.llms.base_llm import BaseLLM + from crewai.events.types.llm_events import LLMCallType, LLMStreamChunkEvent + + # Create a mock BaseLLM instance + with patch.object(BaseLLM, "__abstractmethods__", set()): + llm = BaseLLM(model="test-model") # type: ignore + + captured_events: list[LLMStreamChunkEvent] = [] + + def capture_emit(source: Any, event: Any) -> None: + if isinstance(event, LLMStreamChunkEvent): + captured_events.append(event) + + with patch("crewai.llms.base_llm.crewai_event_bus") as mock_bus: + mock_bus.emit = capture_emit + + # Emit with tool_call - should infer TOOL_CALL type + tool_call_dict = { + "id": "call-789", + "function": { + "name": "test_tool", + "arguments": '{"arg": "value"}', + }, + "type": "function", + "index": 0, + } + llm._emit_stream_chunk_event( + chunk='{"arg": "value"}', + tool_call=tool_call_dict, + ) + + assert len(captured_events) == 1 + assert captured_events[0].call_type == LLMCallType.TOOL_CALL + assert captured_events[0].tool_call is not None + + def test_emit_stream_chunk_event_infers_llm_call_type(self) -> None: + """Test that _emit_stream_chunk_event infers LLM_CALL type when tool_call is None.""" + from unittest.mock import patch + from crewai.llms.base_llm import BaseLLM + from crewai.events.types.llm_events import LLMCallType, LLMStreamChunkEvent + + # Create a mock BaseLLM instance + with patch.object(BaseLLM, "__abstractmethods__", set()): + llm = BaseLLM(model="test-model") # type: ignore + + captured_events: list[LLMStreamChunkEvent] = [] + + def capture_emit(source: Any, event: Any) -> None: + if isinstance(event, LLMStreamChunkEvent): + captured_events.append(event) + + with patch("crewai.llms.base_llm.crewai_event_bus") as mock_bus: + mock_bus.emit = capture_emit + + # Emit without tool_call - should infer LLM_CALL type + llm._emit_stream_chunk_event( + chunk="Hello, world!", + tool_call=None, + ) + + assert len(captured_events) == 1 + assert captured_events[0].call_type == LLMCallType.LLM_CALL + assert captured_events[0].tool_call is None + + def test_emit_stream_chunk_event_respects_explicit_call_type(self) -> None: + """Test that _emit_stream_chunk_event respects explicitly provided call_type.""" + from unittest.mock import patch + from crewai.llms.base_llm import BaseLLM + from crewai.events.types.llm_events import LLMCallType, LLMStreamChunkEvent + + # Create a mock BaseLLM instance + with patch.object(BaseLLM, "__abstractmethods__", set()): + llm = BaseLLM(model="test-model") # type: ignore + + captured_events: list[LLMStreamChunkEvent] = [] + + def capture_emit(source: Any, event: Any) -> None: + if isinstance(event, LLMStreamChunkEvent): + captured_events.append(event) + + with patch("crewai.llms.base_llm.crewai_event_bus") as mock_bus: + mock_bus.emit = capture_emit + + # Emit with explicit call_type - should use provided type + llm._emit_stream_chunk_event( + chunk="test", + tool_call=None, + call_type=LLMCallType.TOOL_CALL, # Explicitly set even though no tool_call + ) + + assert len(captured_events) == 1 + assert captured_events[0].call_type == LLMCallType.TOOL_CALL + + +class TestStreamingToolCallExtraction: + """Tests for tool call extraction from streaming events.""" + + def test_extract_tool_call_info_from_event(self) -> None: + """Test that tool call info is correctly extracted from LLMStreamChunkEvent.""" + from crewai.utilities.streaming import _extract_tool_call_info + from crewai.events.types.llm_events import ( + LLMStreamChunkEvent, + ToolCall, + FunctionCall, + ) + from crewai.types.streaming import StreamChunkType + + # Create event with tool call + tool_call = ToolCall( + id="call-extract-test", + function=FunctionCall( + name="extract_test", + arguments='{"key": "value"}', + ), + type="function", + index=2, + ) + + event = LLMStreamChunkEvent( + chunk='{"key": "value"}', + tool_call=tool_call, + ) + + chunk_type, tool_call_chunk = _extract_tool_call_info(event) + + assert chunk_type == StreamChunkType.TOOL_CALL + assert tool_call_chunk is not None + assert tool_call_chunk.tool_id == "call-extract-test" + assert tool_call_chunk.tool_name == "extract_test" + assert tool_call_chunk.arguments == '{"key": "value"}' + assert tool_call_chunk.index == 2 + + def test_extract_tool_call_info_returns_text_for_no_tool_call(self) -> None: + """Test that TEXT type is returned when no tool call is present.""" + from crewai.utilities.streaming import _extract_tool_call_info + from crewai.events.types.llm_events import LLMStreamChunkEvent + from crewai.types.streaming import StreamChunkType + + event = LLMStreamChunkEvent( + chunk="Just text content", + tool_call=None, + ) + + chunk_type, tool_call_chunk = _extract_tool_call_info(event) + + assert chunk_type == StreamChunkType.TEXT + assert tool_call_chunk is None