Unblock LLM(stream=True) to work with tools (#2582)

* feat: unblock LLM(stream=True) to work with tools

* feat: replace pytest-vcr by pytest-recording

1. pytest-vcr does not support httpx - which LiteLLM uses for streaming responses.
2. pytest-vcr is no longer maintained, last commit 6 years ago :fist::skin-tone-4:
3. pytest-recording supports modern request libraries (including httpx) and actively maintained

* refactor: remove @skip_streaming_in_ci

Since we have fixed streaming response issue we can remove this @skip_streaming_in_ci

---------

Co-authored-by: Lorenze Jay <63378463+lorenzejay@users.noreply.github.com>
This commit is contained in:
Lucas Gomide
2025-04-17 12:58:52 -03:00
committed by GitHub
parent 8e555149f7
commit ced3c8f0e0
21 changed files with 995 additions and 402 deletions

View File

@@ -4,9 +4,12 @@ import os
import sys
import threading
import warnings
from collections import defaultdict
from contextlib import contextmanager
from types import SimpleNamespace
from typing import (
Any,
DefaultDict,
Dict,
List,
Literal,
@@ -18,7 +21,8 @@ from typing import (
)
from dotenv import load_dotenv
from pydantic import BaseModel
from litellm.types.utils import ChatCompletionDeltaToolCall
from pydantic import BaseModel, Field
from crewai.utilities.events.llm_events import (
LLMCallCompletedEvent,
@@ -219,6 +223,15 @@ class StreamingChoices(TypedDict):
finish_reason: Optional[str]
class FunctionArgs(BaseModel):
name: str = ""
arguments: str = ""
class AccumulatedToolArgs(BaseModel):
function: FunctionArgs = Field(default_factory=FunctionArgs)
class LLM(BaseLLM):
def __init__(
self,
@@ -371,6 +384,11 @@ class LLM(BaseLLM):
last_chunk = None
chunk_count = 0
usage_info = None
tool_calls = None
accumulated_tool_args: DefaultDict[int, AccumulatedToolArgs] = defaultdict(
AccumulatedToolArgs
)
# --- 2) Make sure stream is set to True and include usage metrics
params["stream"] = True
@@ -428,6 +446,20 @@ class LLM(BaseLLM):
if chunk_content is None and isinstance(delta, dict):
# Some models might send empty content chunks
chunk_content = ""
# Enable tool calls using streaming
if "tool_calls" in delta:
tool_calls = delta["tool_calls"]
if tool_calls:
result = self._handle_streaming_tool_calls(
tool_calls=tool_calls,
accumulated_tool_args=accumulated_tool_args,
available_functions=available_functions,
)
if result is not None:
chunk_content = result
except Exception as e:
logging.debug(f"Error extracting content from chunk: {e}")
logging.debug(f"Chunk format: {type(chunk)}, content: {chunk}")
@@ -442,7 +474,6 @@ class LLM(BaseLLM):
self,
event=LLMStreamChunkEvent(chunk=chunk_content),
)
# --- 4) Fallback to non-streaming if no content received
if not full_response.strip() and chunk_count == 0:
logging.warning(
@@ -501,7 +532,7 @@ class LLM(BaseLLM):
)
# --- 6) If still empty, raise an error instead of using a default response
if not full_response.strip():
if not full_response.strip() and len(accumulated_tool_args) == 0:
raise Exception(
"No content received from streaming response. Received empty chunks or failed to extract content."
)
@@ -533,8 +564,8 @@ class LLM(BaseLLM):
tool_calls = getattr(message, "tool_calls")
except Exception as e:
logging.debug(f"Error checking for tool calls: {e}")
# --- 8) If no tool calls or no available functions, return the text response directly
if not tool_calls or not available_functions:
# Log token usage if available in streaming mode
self._handle_streaming_callbacks(callbacks, usage_info, last_chunk)
@@ -568,6 +599,47 @@ class LLM(BaseLLM):
)
raise Exception(f"Failed to get streaming response: {str(e)}")
def _handle_streaming_tool_calls(
self,
tool_calls: List[ChatCompletionDeltaToolCall],
accumulated_tool_args: DefaultDict[int, AccumulatedToolArgs],
available_functions: Optional[Dict[str, Any]] = None,
) -> None | str:
for tool_call in tool_calls:
current_tool_accumulator = accumulated_tool_args[tool_call.index]
if tool_call.function.name:
current_tool_accumulator.function.name = tool_call.function.name
if tool_call.function.arguments:
current_tool_accumulator.function.arguments += (
tool_call.function.arguments
)
crewai_event_bus.emit(
self,
event=LLMStreamChunkEvent(
tool_call=tool_call.to_dict(),
chunk=tool_call.function.arguments,
),
)
if (
current_tool_accumulator.function.name
and current_tool_accumulator.function.arguments
and available_functions
):
try:
json.loads(current_tool_accumulator.function.arguments)
return self._handle_tool_call(
[current_tool_accumulator],
available_functions,
)
except json.JSONDecodeError:
continue
return None
def _handle_streaming_callbacks(
self,
callbacks: Optional[List[Any]],