Mirror of https://github.com/crewAIInc/crewAI.git, synced 2026-01-10 08:38:30 +00:00
Unblock LLM(stream=True) to work with tools (#2582)
* feat: unblock LLM(stream=True) to work with tools
* feat: replace pytest-vcr with pytest-recording
  1. pytest-vcr does not support httpx, which LiteLLM uses for streaming responses.
  2. pytest-vcr is no longer maintained; its last commit was six years ago.
  3. pytest-recording supports modern request libraries (including httpx) and is actively maintained.
* refactor: remove @skip_streaming_in_ci
  Since the streaming response issue is fixed, the @skip_streaming_in_ci workaround is no longer needed.

Co-authored-by: Lorenze Jay <63378463+lorenzejay@users.noreply.github.com>
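A rough illustration of the pytest-recording workflow described above (hypothetical test, not taken from this PR's test suite; the model string and prompt are placeholders): record the cassette once with pytest --record-mode=once, then replay it on later runs.

# Hypothetical test sketch using pytest-recording's vcr marker.
import pytest

from crewai import LLM


@pytest.mark.vcr(filter_headers=["authorization"])
def test_llm_streaming_returns_text():
    llm = LLM(model="gpt-4o-mini", stream=True)  # placeholder model name
    result = llm.call("Say hello in one word.")
    assert isinstance(result, str)
    assert result.strip()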
@@ -4,9 +4,12 @@ import os
import sys
import threading
import warnings
from collections import defaultdict
from contextlib import contextmanager
from types import SimpleNamespace
from typing import (
    Any,
    DefaultDict,
    Dict,
    List,
    Literal,
@@ -18,7 +21,8 @@ from typing import (
)

from dotenv import load_dotenv
from pydantic import BaseModel
from litellm.types.utils import ChatCompletionDeltaToolCall
from pydantic import BaseModel, Field

from crewai.utilities.events.llm_events import (
    LLMCallCompletedEvent,
@@ -219,6 +223,15 @@ class StreamingChoices(TypedDict):
    finish_reason: Optional[str]


class FunctionArgs(BaseModel):
    name: str = ""
    arguments: str = ""


class AccumulatedToolArgs(BaseModel):
    function: FunctionArgs = Field(default_factory=FunctionArgs)


class LLM(BaseLLM):
    def __init__(
        self,
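A minimal standalone sketch (not part of the diff) showing how these two models pair with defaultdict so that indexing a new tool-call index yields a fresh accumulator; the tool name and argument fragments are invented:

from collections import defaultdict
from typing import DefaultDict

from pydantic import BaseModel, Field


class FunctionArgs(BaseModel):
    name: str = ""
    arguments: str = ""


class AccumulatedToolArgs(BaseModel):
    function: FunctionArgs = Field(default_factory=FunctionArgs)


# Indexing a missing key creates an empty accumulator for that tool-call index.
accumulated: DefaultDict[int, AccumulatedToolArgs] = defaultdict(AccumulatedToolArgs)
accumulated[0].function.name = "lookup_weather"    # hypothetical tool name
accumulated[0].function.arguments += '{"city": '   # first streamed fragment
accumulated[0].function.arguments += '"Paris"}'    # second fragment completes the JSON
print(accumulated[0].function.arguments)           # -> {"city": "Paris"}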
@@ -371,6 +384,11 @@ class LLM(BaseLLM):
        last_chunk = None
        chunk_count = 0
        usage_info = None
        tool_calls = None

        accumulated_tool_args: DefaultDict[int, AccumulatedToolArgs] = defaultdict(
            AccumulatedToolArgs
        )

        # --- 2) Make sure stream is set to True and include usage metrics
        params["stream"] = True
@@ -428,6 +446,20 @@ class LLM(BaseLLM):
                    if chunk_content is None and isinstance(delta, dict):
                        # Some models might send empty content chunks
                        chunk_content = ""

                        # Enable tool calls using streaming
                        if "tool_calls" in delta:
                            tool_calls = delta["tool_calls"]

                        if tool_calls:
                            result = self._handle_streaming_tool_calls(
                                tool_calls=tool_calls,
                                accumulated_tool_args=accumulated_tool_args,
                                available_functions=available_functions,
                            )
                            if result is not None:
                                chunk_content = result

                except Exception as e:
                    logging.debug(f"Error extracting content from chunk: {e}")
                    logging.debug(f"Chunk format: {type(chunk)}, content: {chunk}")
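For context, a rough simulation with plain dicts (real chunks carry ChatCompletionDeltaToolCall objects of the same shape) of how tool-call data arrives split across streaming deltas, which is why an accumulator is needed; the tool name and fragments are invented:

# Simulated streaming deltas: the function name arrives once,
# the JSON arguments arrive in pieces.
deltas = [
    {"tool_calls": [{"index": 0, "function": {"name": "lookup_weather", "arguments": ""}}]},
    {"tool_calls": [{"index": 0, "function": {"name": None, "arguments": '{"city": "Par'}}]},
    {"tool_calls": [{"index": 0, "function": {"name": None, "arguments": 'is"}'}}]},
]

name, arguments = "", ""
for delta in deltas:
    for fragment in delta.get("tool_calls") or []:
        if fragment["function"]["name"]:
            name = fragment["function"]["name"]
        if fragment["function"]["arguments"]:
            arguments += fragment["function"]["arguments"]

print(name, arguments)  # -> lookup_weather {"city": "Paris"}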
@@ -442,7 +474,6 @@ class LLM(BaseLLM):
                    self,
                    event=LLMStreamChunkEvent(chunk=chunk_content),
                )

        # --- 4) Fallback to non-streaming if no content received
        if not full_response.strip() and chunk_count == 0:
            logging.warning(
@@ -501,7 +532,7 @@ class LLM(BaseLLM):
        )

        # --- 6) If still empty, raise an error instead of using a default response
        if not full_response.strip():
        if not full_response.strip() and len(accumulated_tool_args) == 0:
            raise Exception(
                "No content received from streaming response. Received empty chunks or failed to extract content."
            )
@@ -533,8 +564,8 @@ class LLM(BaseLLM):
            tool_calls = getattr(message, "tool_calls")
        except Exception as e:
            logging.debug(f"Error checking for tool calls: {e}")

        # --- 8) If no tool calls or no available functions, return the text response directly

        if not tool_calls or not available_functions:
            # Log token usage if available in streaming mode
            self._handle_streaming_callbacks(callbacks, usage_info, last_chunk)
@@ -568,6 +599,47 @@ class LLM(BaseLLM):
            )
            raise Exception(f"Failed to get streaming response: {str(e)}")

    def _handle_streaming_tool_calls(
        self,
        tool_calls: List[ChatCompletionDeltaToolCall],
        accumulated_tool_args: DefaultDict[int, AccumulatedToolArgs],
        available_functions: Optional[Dict[str, Any]] = None,
    ) -> None | str:
        for tool_call in tool_calls:
            current_tool_accumulator = accumulated_tool_args[tool_call.index]

            if tool_call.function.name:
                current_tool_accumulator.function.name = tool_call.function.name

            if tool_call.function.arguments:
                current_tool_accumulator.function.arguments += (
                    tool_call.function.arguments
                )

            crewai_event_bus.emit(
                self,
                event=LLMStreamChunkEvent(
                    tool_call=tool_call.to_dict(),
                    chunk=tool_call.function.arguments,
                ),
            )

            if (
                current_tool_accumulator.function.name
                and current_tool_accumulator.function.arguments
                and available_functions
            ):
                try:
                    json.loads(current_tool_accumulator.function.arguments)

                    return self._handle_tool_call(
                        [current_tool_accumulator],
                        available_functions,
                    )
                except json.JSONDecodeError:
                    continue
        return None

    def _handle_streaming_callbacks(
        self,
        callbacks: Optional[List[Any]],
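The completeness check in _handle_streaming_tool_calls relies on json.loads failing until the final argument fragment arrives; a small standalone sketch of that gate, with invented fragments:

import json

# Append fragments as they stream in; dispatch the tool only once the
# accumulated argument string parses as valid JSON.
fragments = ['{"city": ', '"Paris"', "}"]
accumulated_args = ""
for piece in fragments:
    accumulated_args += piece
    try:
        parsed = json.loads(accumulated_args)
    except json.JSONDecodeError:
        continue  # arguments still incomplete; wait for more fragments
    print("ready to call tool with:", parsed)  # -> {'city': 'Paris'}
    break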