diff --git a/lib/crewai/src/crewai/llms/base_llm.py b/lib/crewai/src/crewai/llms/base_llm.py
index bb833ccc8..c09c26453 100644
--- a/lib/crewai/src/crewai/llms/base_llm.py
+++ b/lib/crewai/src/crewai/llms/base_llm.py
@@ -354,8 +354,17 @@ class BaseLLM(ABC):
         from_task: Task | None = None,
         from_agent: Agent | None = None,
         tool_call: dict[str, Any] | None = None,
+        call_type: LLMCallType | None = None,
     ) -> None:
-        """Emit stream chunk event."""
+        """Emit stream chunk event.
+
+        Args:
+            chunk: The text content of the chunk.
+            from_task: The task that initiated the call.
+            from_agent: The agent that initiated the call.
+            tool_call: Tool call information if this is a tool call chunk.
+            call_type: The type of LLM call (LLM_CALL or TOOL_CALL).
+        """
         if not hasattr(crewai_event_bus, "emit"):
             raise ValueError("crewai_event_bus does not have an emit method") from None
 
@@ -366,6 +375,7 @@ class BaseLLM(ABC):
                 tool_call=tool_call,
                 from_task=from_task,
                 from_agent=from_agent,
+                call_type=call_type,
             ),
         )
 
diff --git a/lib/crewai/src/crewai/llms/providers/anthropic/completion.py b/lib/crewai/src/crewai/llms/providers/anthropic/completion.py
index 79e53907d..5266c9097 100644
--- a/lib/crewai/src/crewai/llms/providers/anthropic/completion.py
+++ b/lib/crewai/src/crewai/llms/providers/anthropic/completion.py
@@ -598,6 +598,8 @@ class AnthropicCompletion(BaseLLM):
         # (the SDK sets it internally)
         stream_params = {k: v for k, v in params.items() if k != "stream"}
 
+        current_tool_calls: dict[int, dict[str, Any]] = {}
+
         # Make streaming API call
         with self.client.messages.stream(**stream_params) as stream:
             for event in stream:
@@ -610,6 +612,55 @@ class AnthropicCompletion(BaseLLM):
                         from_agent=from_agent,
                     )
 
+                if event.type == "content_block_start":
+                    block = event.content_block
+                    if block.type == "tool_use":
+                        block_index = event.index
+                        current_tool_calls[block_index] = {
+                            "id": block.id,
+                            "name": block.name,
+                            "arguments": "",
+                            "index": block_index,
+                        }
+                        self._emit_stream_chunk_event(
+                            chunk="",
+                            from_task=from_task,
+                            from_agent=from_agent,
+                            tool_call={
+                                "id": block.id,
+                                "function": {
+                                    "name": block.name,
+                                    "arguments": "",
+                                },
+                                "type": "function",
+                                "index": block_index,
+                            },
+                            call_type=LLMCallType.TOOL_CALL,
+                        )
+                elif event.type == "content_block_delta":
+                    if event.delta.type == "input_json_delta":
+                        block_index = event.index
+                        partial_json = event.delta.partial_json
+                        if block_index in current_tool_calls and partial_json:
+                            current_tool_calls[block_index]["arguments"] += partial_json
+                            self._emit_stream_chunk_event(
+                                chunk=partial_json,
+                                from_task=from_task,
+                                from_agent=from_agent,
+                                tool_call={
+                                    "id": current_tool_calls[block_index]["id"],
+                                    "function": {
+                                        "name": current_tool_calls[block_index]["name"],
+                                        "arguments": current_tool_calls[block_index][
+                                            "arguments"
+                                        ],
+                                    },
+                                    "type": "function",
+                                    "index": block_index,
+                                },
+                                call_type=LLMCallType.TOOL_CALL,
+                            )
+
             final_message: Message = stream.get_final_message()
 
             thinking_blocks: list[ThinkingBlock] = []
@@ -941,6 +992,8 @@ class AnthropicCompletion(BaseLLM):
 
         stream_params = {k: v for k, v in params.items() if k != "stream"}
 
+        current_tool_calls: dict[int, dict[str, Any]] = {}
+
         async with self.async_client.messages.stream(**stream_params) as stream:
             async for event in stream:
                 if hasattr(event, "delta") and hasattr(event.delta, "text"):
@@ -952,6 +1005,55 @@ class AnthropicCompletion(BaseLLM):
                         from_agent=from_agent,
                     )
 
+                if event.type == "content_block_start":
+                    block = event.content_block
+                    if block.type == "tool_use":
+                        block_index = event.index
+                        current_tool_calls[block_index] = {
+                            "id": block.id,
+                            "name": block.name,
+                            "arguments": "",
+                            "index": block_index,
+                        }
+                        self._emit_stream_chunk_event(
+                            chunk="",
+                            from_task=from_task,
+                            from_agent=from_agent,
+                            tool_call={
+                                "id": block.id,
+                                "function": {
+                                    "name": block.name,
+                                    "arguments": "",
+                                },
+                                "type": "function",
+                                "index": block_index,
+                            },
+                            call_type=LLMCallType.TOOL_CALL,
+                        )
+                elif event.type == "content_block_delta":
+                    if event.delta.type == "input_json_delta":
+                        block_index = event.index
+                        partial_json = event.delta.partial_json
+                        if block_index in current_tool_calls and partial_json:
+                            current_tool_calls[block_index]["arguments"] += partial_json
+                            self._emit_stream_chunk_event(
+                                chunk=partial_json,
+                                from_task=from_task,
+                                from_agent=from_agent,
+                                tool_call={
+                                    "id": current_tool_calls[block_index]["id"],
+                                    "function": {
+                                        "name": current_tool_calls[block_index]["name"],
+                                        "arguments": current_tool_calls[block_index][
+                                            "arguments"
+                                        ],
+                                    },
+                                    "type": "function",
+                                    "index": block_index,
+                                },
+                                call_type=LLMCallType.TOOL_CALL,
+                            )
+
             final_message: Message = await stream.get_final_message()
 
             usage = self._extract_anthropic_token_usage(final_message)
diff --git a/lib/crewai/src/crewai/llms/providers/azure/completion.py b/lib/crewai/src/crewai/llms/providers/azure/completion.py
index 687dee9c6..f3cc16492 100644
--- a/lib/crewai/src/crewai/llms/providers/azure/completion.py
+++ b/lib/crewai/src/crewai/llms/providers/azure/completion.py
@@ -674,7 +674,7 @@ class AzureCompletion(BaseLLM):
         self,
         update: StreamingChatCompletionsUpdate,
         full_response: str,
-        tool_calls: dict[str, dict[str, str]],
+        tool_calls: dict[str, dict[str, Any]],
         from_task: Any | None = None,
         from_agent: Any | None = None,
     ) -> str:
@@ -702,12 +702,13 @@ class AzureCompletion(BaseLLM):
             )
 
         if choice.delta and choice.delta.tool_calls:
-            for tool_call in choice.delta.tool_calls:
+            for idx, tool_call in enumerate(choice.delta.tool_calls):
                 call_id = tool_call.id or "default"
                 if call_id not in tool_calls:
                     tool_calls[call_id] = {
                         "name": "",
                         "arguments": "",
+                        "index": idx,
                     }
 
                 if tool_call.function and tool_call.function.name:
@@ -715,12 +716,30 @@ class AzureCompletion(BaseLLM):
                 if tool_call.function and tool_call.function.arguments:
                     tool_calls[call_id]["arguments"] += tool_call.function.arguments
 
+                self._emit_stream_chunk_event(
+                    chunk=tool_call.function.arguments
+                    if tool_call.function and tool_call.function.arguments
+                    else "",
+                    from_task=from_task,
+                    from_agent=from_agent,
+                    tool_call={
+                        "id": call_id,
+                        "function": {
+                            "name": tool_calls[call_id]["name"],
+                            "arguments": tool_calls[call_id]["arguments"],
+                        },
+                        "type": "function",
+                        "index": tool_calls[call_id]["index"],
+                    },
+                    call_type=LLMCallType.TOOL_CALL,
+                )
+
         return full_response
 
     def _finalize_streaming_response(
         self,
         full_response: str,
-        tool_calls: dict[str, dict[str, str]],
+        tool_calls: dict[str, dict[str, Any]],
         usage_data: dict[str, int],
         params: AzureCompletionParams,
         available_functions: dict[str, Any] | None = None,
diff --git a/lib/crewai/src/crewai/llms/providers/bedrock/completion.py b/lib/crewai/src/crewai/llms/providers/bedrock/completion.py
index 2057bd871..f66b1cf31 100644
--- a/lib/crewai/src/crewai/llms/providers/bedrock/completion.py
+++ b/lib/crewai/src/crewai/llms/providers/bedrock/completion.py
@@ -315,9 +315,7 @@ class BedrockCompletion(BaseLLM):
             messages
         )
 
-        if not self._invoke_before_llm_call_hooks(
-            cast(list[LLMMessage], formatted_messages), from_agent
-        ):
+        if not self._invoke_before_llm_call_hooks(formatted_messages, from_agent):
             raise ValueError("LLM call blocked by before_llm_call hook")
 
         # Prepare request body
@@ -361,7 +359,7 @@ class BedrockCompletion(BaseLLM):
 
         if self.stream:
             return self._handle_streaming_converse(
-                cast(list[LLMMessage], formatted_messages),
+                formatted_messages,
                 body,
                 available_functions,
                 from_task,
@@ -369,7 +367,7 @@ class BedrockCompletion(BaseLLM):
             )
 
         return self._handle_converse(
-            cast(list[LLMMessage], formatted_messages),
+            formatted_messages,
             body,
             available_functions,
             from_task,
@@ -433,7 +431,7 @@ class BedrockCompletion(BaseLLM):
         )
 
         formatted_messages, system_message = self._format_messages_for_converse(
-            messages  # type: ignore[arg-type]
+            messages
         )
 
         body: BedrockConverseRequestBody = {
@@ -687,8 +685,10 @@ class BedrockCompletion(BaseLLM):
     ) -> str:
         """Handle streaming converse API call with comprehensive event handling."""
         full_response = ""
-        current_tool_use = None
-        tool_use_id = None
+        current_tool_use: dict[str, Any] | None = None
+        tool_use_id: str | None = None
+        tool_use_index = 0
+        accumulated_tool_input = ""
 
         try:
             response = self.client.converse_stream(
@@ -709,9 +709,30 @@ class BedrockCompletion(BaseLLM):
 
                 elif "contentBlockStart" in event:
                     start = event["contentBlockStart"].get("start", {})
+                    content_block_index = event["contentBlockStart"].get(
+                        "contentBlockIndex", 0
+                    )
                     if "toolUse" in start:
-                        current_tool_use = start["toolUse"]
+                        tool_use_block = start["toolUse"]
+                        current_tool_use = cast(dict[str, Any], tool_use_block)
                         tool_use_id = current_tool_use.get("toolUseId")
+                        tool_use_index = content_block_index
+                        accumulated_tool_input = ""
+                        self._emit_stream_chunk_event(
+                            chunk="",
+                            from_task=from_task,
+                            from_agent=from_agent,
+                            tool_call={
+                                "id": tool_use_id or "",
+                                "function": {
+                                    "name": current_tool_use.get("name", ""),
+                                    "arguments": "",
+                                },
+                                "type": "function",
+                                "index": tool_use_index,
+                            },
+                            call_type=LLMCallType.TOOL_CALL,
+                        )
                         logging.debug(
                             f"Tool use started in stream: {json.dumps(current_tool_use)} (ID: {tool_use_id})"
                         )
@@ -730,7 +751,23 @@ class BedrockCompletion(BaseLLM):
                     elif "toolUse" in delta and current_tool_use:
                         tool_input = delta["toolUse"].get("input", "")
                         if tool_input:
+                            accumulated_tool_input += tool_input
                             logging.debug(f"Tool input delta: {tool_input}")
+                            self._emit_stream_chunk_event(
+                                chunk=tool_input,
+                                from_task=from_task,
+                                from_agent=from_agent,
+                                tool_call={
+                                    "id": tool_use_id or "",
+                                    "function": {
+                                        "name": current_tool_use.get("name", ""),
+                                        "arguments": accumulated_tool_input,
+                                    },
+                                    "type": "function",
+                                    "index": tool_use_index,
+                                },
+                                call_type=LLMCallType.TOOL_CALL,
+                            )
                 elif "contentBlockStop" in event:
                     logging.debug("Content block stopped in stream")
                     if current_tool_use and available_functions:
@@ -848,7 +885,7 @@ class BedrockCompletion(BaseLLM):
 
     async def _ahandle_converse(
         self,
-        messages: list[dict[str, Any]],
+        messages: list[LLMMessage],
         body: BedrockConverseRequestBody,
         available_functions: Mapping[str, Any] | None = None,
         from_task: Any | None = None,
@@ -1013,7 +1050,7 @@ class BedrockCompletion(BaseLLM):
 
     async def _ahandle_streaming_converse(
         self,
-        messages: list[dict[str, Any]],
+        messages: list[LLMMessage],
         body: BedrockConverseRequestBody,
         available_functions: dict[str, Any] | None = None,
         from_task: Any | None = None,
@@ -1021,8 +1058,10 @@ class BedrockCompletion(BaseLLM):
     ) -> str:
         """Handle async streaming converse API call."""
         full_response = ""
-        current_tool_use = None
-        tool_use_id = None
+        current_tool_use: dict[str, Any] | None = None
+        tool_use_id: str | None = None
+        tool_use_index = 0
+        accumulated_tool_input = ""
 
         try:
             async_client = await self._ensure_async_client()
@@ -1044,9 +1083,30 @@ class BedrockCompletion(BaseLLM):
 
                 elif "contentBlockStart" in event:
                     start = event["contentBlockStart"].get("start", {})
+                    content_block_index = event["contentBlockStart"].get(
+                        "contentBlockIndex", 0
+                    )
                     if "toolUse" in start:
-                        current_tool_use = start["toolUse"]
+                        tool_use_block = start["toolUse"]
+                        current_tool_use = cast(dict[str, Any], tool_use_block)
                         tool_use_id = current_tool_use.get("toolUseId")
+                        tool_use_index = content_block_index
+                        accumulated_tool_input = ""
+                        self._emit_stream_chunk_event(
+                            chunk="",
+                            from_task=from_task,
+                            from_agent=from_agent,
+                            tool_call={
+                                "id": tool_use_id or "",
+                                "function": {
+                                    "name": current_tool_use.get("name", ""),
+                                    "arguments": "",
+                                },
+                                "type": "function",
+                                "index": tool_use_index,
+                            },
+                            call_type=LLMCallType.TOOL_CALL,
+                        )
                         logging.debug(
                             f"Tool use started in stream: {current_tool_use.get('name')} (ID: {tool_use_id})"
                         )
@@ -1065,7 +1125,23 @@ class BedrockCompletion(BaseLLM):
                     elif "toolUse" in delta and current_tool_use:
                         tool_input = delta["toolUse"].get("input", "")
                         if tool_input:
+                            accumulated_tool_input += tool_input
                             logging.debug(f"Tool input delta: {tool_input}")
+                            self._emit_stream_chunk_event(
+                                chunk=tool_input,
+                                from_task=from_task,
+                                from_agent=from_agent,
+                                tool_call={
+                                    "id": tool_use_id or "",
+                                    "function": {
+                                        "name": current_tool_use.get("name", ""),
+                                        "arguments": accumulated_tool_input,
+                                    },
+                                    "type": "function",
+                                    "index": tool_use_index,
+                                },
+                                call_type=LLMCallType.TOOL_CALL,
+                            )
 
                 elif "contentBlockStop" in event:
                     logging.debug("Content block stopped in stream")
@@ -1174,7 +1250,7 @@ class BedrockCompletion(BaseLLM):
 
     def _format_messages_for_converse(
         self, messages: str | list[LLMMessage]
-    ) -> tuple[list[dict[str, Any]], str | None]:
+    ) -> tuple[list[LLMMessage], str | None]:
         """Format messages for Converse API following AWS documentation.
 
         Note: Returns dict[str, Any] instead of LLMMessage because Bedrock uses
@@ -1184,7 +1260,7 @@ class BedrockCompletion(BaseLLM):
         # Use base class formatting first
         formatted_messages = self._format_messages(messages)
 
-        converse_messages: list[dict[str, Any]] = []
+        converse_messages: list[LLMMessage] = []
         system_message: str | None = None
 
         for message in formatted_messages:
diff --git a/lib/crewai/src/crewai/llms/providers/gemini/completion.py b/lib/crewai/src/crewai/llms/providers/gemini/completion.py
index e511c61b0..d6402cfa0 100644
--- a/lib/crewai/src/crewai/llms/providers/gemini/completion.py
+++ b/lib/crewai/src/crewai/llms/providers/gemini/completion.py
@@ -1,5 +1,6 @@
 from __future__ import annotations
 
+import json
 import logging
 import os
 import re
@@ -677,17 +678,39 @@ class GeminiCompletion(BaseLLM):
             if chunk.candidates:
                 candidate = chunk.candidates[0]
                 if candidate.content and candidate.content.parts:
-                    for part in candidate.content.parts:
+                    for idx, part in enumerate(candidate.content.parts):
                         if hasattr(part, "function_call") and part.function_call:
                             call_id = part.function_call.name or "default"
+                            args_dict = (
+                                dict(part.function_call.args)
+                                if part.function_call.args
+                                else {}
+                            )
+                            args_json = json.dumps(args_dict)
+
                             if call_id not in function_calls:
                                 function_calls[call_id] = {
                                     "name": part.function_call.name,
-                                    "args": dict(part.function_call.args)
-                                    if part.function_call.args
-                                    else {},
+                                    "args": args_dict,
+                                    "index": idx,
                                 }
+
+                            self._emit_stream_chunk_event(
+                                chunk=args_json,
+                                from_task=from_task,
+                                from_agent=from_agent,
+                                tool_call={
+                                    "id": call_id,
+                                    "function": {
+                                        "name": part.function_call.name or "",
+                                        "arguments": args_json,
+                                    },
+                                    "type": "function",
+                                    "index": function_calls[call_id]["index"],
+                                },
+                                call_type=LLMCallType.TOOL_CALL,
+                            )
+
         return full_response, function_calls, usage_data
 
     def _finalize_streaming_response(
diff --git a/lib/crewai/src/crewai/llms/providers/openai/completion.py b/lib/crewai/src/crewai/llms/providers/openai/completion.py
index becb5209b..16840ea12 100644
--- a/lib/crewai/src/crewai/llms/providers/openai/completion.py
+++ b/lib/crewai/src/crewai/llms/providers/openai/completion.py
@@ -521,7 +521,7 @@ class OpenAICompletion(BaseLLM):
     ) -> str:
         """Handle streaming chat completion."""
         full_response = ""
-        tool_calls = {}
+        tool_calls: dict[str, dict[str, Any]] = {}
 
         if response_model:
             parse_params = {
@@ -592,10 +592,12 @@ class OpenAICompletion(BaseLLM):
             if chunk_delta.tool_calls:
                 for tool_call in chunk_delta.tool_calls:
                     call_id = tool_call.id or "default"
+                    tool_index = tool_call.index if tool_call.index is not None else 0
                     if call_id not in tool_calls:
                         tool_calls[call_id] = {
                             "name": "",
                             "arguments": "",
+                            "index": tool_index,
                         }
 
                     if tool_call.function and tool_call.function.name:
@@ -603,6 +605,24 @@ class OpenAICompletion(BaseLLM):
                     if tool_call.function and tool_call.function.arguments:
                         tool_calls[call_id]["arguments"] += tool_call.function.arguments
 
+                    self._emit_stream_chunk_event(
+                        chunk=tool_call.function.arguments
+                        if tool_call.function and tool_call.function.arguments
+                        else "",
+                        from_task=from_task,
+                        from_agent=from_agent,
+                        tool_call={
+                            "id": call_id,
+                            "function": {
+                                "name": tool_calls[call_id]["name"],
+                                "arguments": tool_calls[call_id]["arguments"],
+                            },
+                            "type": "function",
+                            "index": tool_calls[call_id]["index"],
+                        },
+                        call_type=LLMCallType.TOOL_CALL,
+                    )
+
         self._track_token_usage_internal(usage_data)
 
         if tool_calls and available_functions:
@@ -789,7 +809,7 @@ class OpenAICompletion(BaseLLM):
     ) -> str:
         """Handle async streaming chat completion."""
         full_response = ""
-        tool_calls = {}
+        tool_calls: dict[str, dict[str, Any]] = {}
 
         if response_model:
             completion_stream: AsyncIterator[
@@ -871,10 +891,12 @@ class OpenAICompletion(BaseLLM):
             if chunk_delta.tool_calls:
                 for tool_call in chunk_delta.tool_calls:
                     call_id = tool_call.id or "default"
+                    tool_index = tool_call.index if tool_call.index is not None else 0
                    if call_id not in tool_calls:
                         tool_calls[call_id] = {
                             "name": "",
                             "arguments": "",
+                            "index": tool_index,
                         }
 
                     if tool_call.function and tool_call.function.name:
@@ -882,6 +904,24 @@ class OpenAICompletion(BaseLLM):
                     if tool_call.function and tool_call.function.arguments:
                         tool_calls[call_id]["arguments"] += tool_call.function.arguments
 
+                    self._emit_stream_chunk_event(
+                        chunk=tool_call.function.arguments
+                        if tool_call.function and tool_call.function.arguments
+                        else "",
+                        from_task=from_task,
+                        from_agent=from_agent,
+                        tool_call={
+                            "id": call_id,
+                            "function": {
+                                "name": tool_calls[call_id]["name"],
+                                "arguments": tool_calls[call_id]["arguments"],
+                            },
+                            "type": "function",
+                            "index": tool_calls[call_id]["index"],
+                        },
+                        call_type=LLMCallType.TOOL_CALL,
+                    )
+
         self._track_token_usage_internal(usage_data)
 
         if tool_calls and available_functions: