mirror of
https://github.com/crewAIInc/crewAI.git
synced 2026-05-05 01:02:37 +00:00
(wip)feat: emit properly tools event when using a stream LLM mode
This commit is contained in:
@@ -4,7 +4,9 @@ import os
|
|||||||
import sys
|
import sys
|
||||||
import threading
|
import threading
|
||||||
import warnings
|
import warnings
|
||||||
|
from collections import defaultdict
|
||||||
from contextlib import contextmanager
|
from contextlib import contextmanager
|
||||||
|
from types import SimpleNamespace
|
||||||
from typing import (
|
from typing import (
|
||||||
Any,
|
Any,
|
||||||
Dict,
|
Dict,
|
||||||
@@ -371,6 +373,15 @@ class LLM(BaseLLM):
|
|||||||
last_chunk = None
|
last_chunk = None
|
||||||
chunk_count = 0
|
chunk_count = 0
|
||||||
usage_info = None
|
usage_info = None
|
||||||
|
tool_calls = None
|
||||||
|
accumulated_tool_args = defaultdict(
|
||||||
|
lambda: SimpleNamespace(
|
||||||
|
function=SimpleNamespace(
|
||||||
|
name="",
|
||||||
|
arguments="",
|
||||||
|
),
|
||||||
|
),
|
||||||
|
)
|
||||||
|
|
||||||
# --- 2) Make sure stream is set to True and include usage metrics
|
# --- 2) Make sure stream is set to True and include usage metrics
|
||||||
params["stream"] = True
|
params["stream"] = True
|
||||||
@@ -428,6 +439,19 @@ class LLM(BaseLLM):
|
|||||||
if chunk_content is None and isinstance(delta, dict):
|
if chunk_content is None and isinstance(delta, dict):
|
||||||
# Some models might send empty content chunks
|
# Some models might send empty content chunks
|
||||||
chunk_content = ""
|
chunk_content = ""
|
||||||
|
|
||||||
|
if "tool_calls" in delta:
|
||||||
|
tool_calls = delta["tool_calls"]
|
||||||
|
|
||||||
|
if tool_calls:
|
||||||
|
result = self._handle_streaming_tool_calls(
|
||||||
|
tool_calls=tool_calls,
|
||||||
|
accumulated_tool_args=accumulated_tool_args,
|
||||||
|
available_functions=available_functions,
|
||||||
|
)
|
||||||
|
if result is not None:
|
||||||
|
chunk_content = result
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logging.debug(f"Error extracting content from chunk: {e}")
|
logging.debug(f"Error extracting content from chunk: {e}")
|
||||||
logging.debug(f"Chunk format: {type(chunk)}, content: {chunk}")
|
logging.debug(f"Chunk format: {type(chunk)}, content: {chunk}")
|
||||||
@@ -501,7 +525,7 @@ class LLM(BaseLLM):
|
|||||||
)
|
)
|
||||||
|
|
||||||
# --- 6) If still empty, raise an error instead of using a default response
|
# --- 6) If still empty, raise an error instead of using a default response
|
||||||
if not full_response.strip():
|
if not full_response.strip() and len(accumulated_tool_args) == 0:
|
||||||
raise Exception(
|
raise Exception(
|
||||||
"No content received from streaming response. Received empty chunks or failed to extract content."
|
"No content received from streaming response. Received empty chunks or failed to extract content."
|
||||||
)
|
)
|
||||||
@@ -568,6 +592,52 @@ class LLM(BaseLLM):
|
|||||||
)
|
)
|
||||||
raise Exception(f"Failed to get streaming response: {str(e)}")
|
raise Exception(f"Failed to get streaming response: {str(e)}")
|
||||||
|
|
||||||
|
def _handle_streaming_tool_calls(
|
||||||
|
self,
|
||||||
|
tool_calls: List[Any],
|
||||||
|
accumulated_tool_args: Dict[int, SimpleNamespace],
|
||||||
|
available_functions: Optional[Dict[str, Any]] = None,
|
||||||
|
) -> None | str:
|
||||||
|
|
||||||
|
for tool_call in tool_calls:
|
||||||
|
if tool_call.function.name:
|
||||||
|
accumulated_tool_args[tool_call.index].function.name = (
|
||||||
|
tool_call.function.name
|
||||||
|
)
|
||||||
|
|
||||||
|
if tool_call.function.arguments:
|
||||||
|
accumulated_tool_args[
|
||||||
|
tool_call.index
|
||||||
|
].function.arguments += tool_call.function.arguments
|
||||||
|
|
||||||
|
crewai_event_bus.emit(
|
||||||
|
self,
|
||||||
|
event=LLMStreamChunkEvent(
|
||||||
|
tool_call=tool_call.to_dict(),
|
||||||
|
chunk=tool_call.function.arguments,
|
||||||
|
),
|
||||||
|
)
|
||||||
|
|
||||||
|
if (
|
||||||
|
accumulated_tool_args[tool_call.index].function.name
|
||||||
|
and accumulated_tool_args[tool_call.index].function.arguments
|
||||||
|
and available_functions
|
||||||
|
):
|
||||||
|
try:
|
||||||
|
# Try to parse the accumulated arguments
|
||||||
|
json.loads(
|
||||||
|
accumulated_tool_args[tool_call.index].function.arguments
|
||||||
|
)
|
||||||
|
|
||||||
|
# Execute the tool call
|
||||||
|
return self._handle_tool_call(
|
||||||
|
[accumulated_tool_args[tool_call.index]],
|
||||||
|
available_functions,
|
||||||
|
)
|
||||||
|
except json.JSONDecodeError:
|
||||||
|
# If JSON is incomplete, continue accumulating
|
||||||
|
continue
|
||||||
|
|
||||||
def _handle_streaming_callbacks(
|
def _handle_streaming_callbacks(
|
||||||
self,
|
self,
|
||||||
callbacks: Optional[List[Any]],
|
callbacks: Optional[List[Any]],
|
||||||
|
|||||||
@@ -46,3 +46,4 @@ class LLMStreamChunkEvent(BaseEvent):
|
|||||||
|
|
||||||
type: str = "llm_stream_chunk"
|
type: str = "llm_stream_chunk"
|
||||||
chunk: str
|
chunk: str
|
||||||
|
tool_call: Optional[dict] = None
|
||||||
|
|||||||
Reference in New Issue
Block a user