(wip) feat: properly emit tool events when using streaming LLM mode

Lucas Gomide
2025-04-10 17:43:31 -03:00
parent c9f47e6a37
commit 00b6e04106
2 changed files with 72 additions and 1 deletions


@@ -4,7 +4,9 @@ import os
 import sys
 import threading
 import warnings
+from collections import defaultdict
 from contextlib import contextmanager
+from types import SimpleNamespace
 from typing import (
     Any,
     Dict,
@@ -371,6 +373,15 @@ class LLM(BaseLLM):
         last_chunk = None
         chunk_count = 0
         usage_info = None
+        tool_calls = None
+        accumulated_tool_args = defaultdict(
+            lambda: SimpleNamespace(
+                function=SimpleNamespace(
+                    name="",
+                    arguments="",
+                ),
+            ),
+        )

         # --- 2) Make sure stream is set to True and include usage metrics
         params["stream"] = True
@@ -428,6 +439,19 @@ class LLM(BaseLLM):
                     if chunk_content is None and isinstance(delta, dict):
                         # Some models might send empty content chunks
                         chunk_content = ""
+
+                        if "tool_calls" in delta:
+                            tool_calls = delta["tool_calls"]
+
+                        if tool_calls:
+                            result = self._handle_streaming_tool_calls(
+                                tool_calls=tool_calls,
+                                accumulated_tool_args=accumulated_tool_args,
+                                available_functions=available_functions,
+                            )
+
+                            if result is not None:
+                                chunk_content = result
                 except Exception as e:
                     logging.debug(f"Error extracting content from chunk: {e}")
                     logging.debug(f"Chunk format: {type(chunk)}, content: {chunk}")
@@ -501,7 +525,7 @@ class LLM(BaseLLM):
             )

             # --- 6) If still empty, raise an error instead of using a default response
-            if not full_response.strip():
+            if not full_response.strip() and len(accumulated_tool_args) == 0:
                 raise Exception(
                     "No content received from streaming response. Received empty chunks or failed to extract content."
                 )
@@ -568,6 +592,52 @@ class LLM(BaseLLM):
             )
             raise Exception(f"Failed to get streaming response: {str(e)}")

+    def _handle_streaming_tool_calls(
+        self,
+        tool_calls: List[Any],
+        accumulated_tool_args: Dict[int, SimpleNamespace],
+        available_functions: Optional[Dict[str, Any]] = None,
+    ) -> None | str:
+
+        for tool_call in tool_calls:
+            if tool_call.function.name:
+                accumulated_tool_args[tool_call.index].function.name = (
+                    tool_call.function.name
+                )
+
+            if tool_call.function.arguments:
+                accumulated_tool_args[
+                    tool_call.index
+                ].function.arguments += tool_call.function.arguments
+
+            crewai_event_bus.emit(
+                self,
+                event=LLMStreamChunkEvent(
+                    tool_call=tool_call.to_dict(),
+                    chunk=tool_call.function.arguments,
+                ),
+            )
+
+            if (
+                accumulated_tool_args[tool_call.index].function.name
+                and accumulated_tool_args[tool_call.index].function.arguments
+                and available_functions
+            ):
+                try:
+                    # Try to parse the accumulated arguments
+                    json.loads(
+                        accumulated_tool_args[tool_call.index].function.arguments
+                    )
+
+                    # Execute the tool call
+                    return self._handle_tool_call(
+                        [accumulated_tool_args[tool_call.index]],
+                        available_functions,
+                    )
+                except json.JSONDecodeError:
+                    # If JSON is incomplete, continue accumulating
+                    continue
+
     def _handle_streaming_callbacks(
         self,
         callbacks: Optional[List[Any]],
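
_handle_streaming_tool_calls treats the arguments buffer as complete exactly when it parses as JSON: every failed json.loads means "keep accumulating", and the first success triggers dispatch. A minimal sketch of that gate in isolation:

    import json

    fragments = ['{"city": ', '"Paris"}']
    buffer = ""
    for fragment in fragments:
        buffer += fragment
        try:
            args = json.loads(buffer)  # raises until the JSON is complete
        except json.JSONDecodeError:
            continue                   # incomplete: keep accumulating
        print("dispatching with", args)  # dispatching with {'city': 'Paris'}

This works because a strict prefix of a JSON object never parses as valid JSON, so the gate cannot fire early on a partial fragment.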


@@ -46,3 +46,4 @@ class LLMStreamChunkEvent(BaseEvent):

     type: str = "llm_stream_chunk"
     chunk: str
+    tool_call: Optional[dict] = None
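
With the new optional tool_call field, a listener can tell text chunks apart from tool-call fragments. A sketch of a consumer, assuming the crewai_event_bus.on(...) decorator registration used elsewhere in the codebase; the import paths may differ by version:

    from crewai.utilities.events import crewai_event_bus  # path assumed
    from crewai.utilities.events.llm_events import LLMStreamChunkEvent  # path assumed

    @crewai_event_bus.on(LLMStreamChunkEvent)
    def on_llm_chunk(source, event: LLMStreamChunkEvent) -> None:
        if event.tool_call is not None:
            # Tool-call fragment: chunk carries the partial JSON arguments.
            print("tool args fragment:", event.chunk)
        else:
            # Plain text fragment.
            print("text:", event.chunk)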