Initial Stream working

This commit is contained in:
Brandon Hancock
2025-03-03 13:10:33 -05:00
parent a3d5c86218
commit 143832bd8b
5 changed files with 398 additions and 122 deletions

View File

@@ -224,6 +224,7 @@ CrewAI provides a wide range of events that you can listen for:
 - **LLMCallStartedEvent**: Emitted when an LLM call starts
 - **LLMCallCompletedEvent**: Emitted when an LLM call completes
 - **LLMCallFailedEvent**: Emitted when an LLM call fails
+- **LLMStreamChunkEvent**: Emitted for each chunk received during streaming LLM responses

 ## Event Handler Structure
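The new docs entry pairs naturally with a custom handler. A minimal sketch of listening for the event, mirroring the decorator pattern used in the `EventListener` changes later in this commit (the exact import path for the shared `crewai_event_bus` instance is an assumption):

```python
from crewai.utilities.events import crewai_event_bus  # assumed import path
from crewai.utilities.events.llm_events import LLMStreamChunkEvent

@crewai_event_bus.on(LLMStreamChunkEvent)
def on_llm_stream_chunk(source, event: LLMStreamChunkEvent):
    # Print streamed tokens as they arrive, without trailing newlines.
    print(event.chunk, end="", flush=True)
```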

View File

@@ -5,7 +5,17 @@ import sys
 import threading
 import warnings
 from contextlib import contextmanager
-from typing import Any, Dict, List, Literal, Optional, Type, Union, cast
+from typing import (
+    Any,
+    Dict,
+    List,
+    Literal,
+    Optional,
+    Type,
+    TypedDict,
+    Union,
+    cast,
+)

 from dotenv import load_dotenv
 from pydantic import BaseModel
@@ -15,6 +25,7 @@ from crewai.utilities.events.llm_events import (
     LLMCallFailedEvent,
     LLMCallStartedEvent,
     LLMCallType,
+    LLMStreamChunkEvent,
 )

 from crewai.utilities.events.tool_usage_events import ToolExecutionErrorEvent
@@ -126,6 +137,17 @@ def suppress_warnings():
         sys.stderr = old_stderr


+class Delta(TypedDict):
+    content: Optional[str]
+    role: Optional[str]
+
+
+class StreamingChoices(TypedDict):
+    delta: Delta
+    index: int
+    finish_reason: Optional[str]
+
+
 class LLM:
     def __init__(
         self,
@@ -150,6 +172,7 @@ class LLM:
         api_key: Optional[str] = None,
         callbacks: List[Any] = [],
         reasoning_effort: Optional[Literal["none", "low", "medium", "high"]] = None,
+        stream: bool = False,
         **kwargs,
     ):
         self.model = model
@@ -175,6 +198,7 @@ class LLM:
         self.reasoning_effort = reasoning_effort
         self.additional_params = kwargs
         self.is_anthropic = self._is_anthropic_model(model)
+        self.stream = stream

         litellm.drop_params = True
@@ -201,6 +225,322 @@ class LLM:
         ANTHROPIC_PREFIXES = ("anthropic/", "claude-", "claude/")
         return any(prefix in model.lower() for prefix in ANTHROPIC_PREFIXES)
+
+    def _prepare_completion_params(
+        self,
+        messages: Union[str, List[Dict[str, str]]],
+        tools: Optional[List[dict]] = None,
+    ) -> Dict[str, Any]:
+        """Prepare parameters for the completion call.
+
+        Args:
+            messages: Input messages for the LLM
+            tools: Optional list of tool schemas
+
+        Returns:
+            Dict[str, Any]: Parameters for the completion call
+        """
+        # --- 1) Format messages according to provider requirements
+        if isinstance(messages, str):
+            messages = [{"role": "user", "content": messages}]
+        formatted_messages = self._format_messages_for_provider(messages)
+
+        # --- 2) Prepare the parameters for the completion call
+        params = {
+            "model": self.model,
+            "messages": formatted_messages,
+            "timeout": self.timeout,
+            "temperature": self.temperature,
+            "top_p": self.top_p,
+            "n": self.n,
+            "stop": self.stop,
+            "max_tokens": self.max_tokens or self.max_completion_tokens,
+            "presence_penalty": self.presence_penalty,
+            "frequency_penalty": self.frequency_penalty,
+            "logit_bias": self.logit_bias,
+            "response_format": self.response_format,
+            "seed": self.seed,
+            "logprobs": self.logprobs,
+            "top_logprobs": self.top_logprobs,
+            "api_base": self.api_base,
+            "base_url": self.base_url,
+            "api_version": self.api_version,
+            "api_key": self.api_key,
+            "stream": self.stream,
+            "tools": tools,
+            "reasoning_effort": self.reasoning_effort,
+            **self.additional_params,
+        }
+
+        # Remove None values from params
+        return {k: v for k, v in params.items() if v is not None}
+
+    def _handle_streaming_response(
+        self,
+        params: Dict[str, Any],
+        available_functions: Optional[Dict[str, Any]] = None,
+    ) -> str:
+        """Handle a streaming response from the LLM.
+
+        Args:
+            params: Parameters for the completion call
+            available_functions: Dict of available functions
+
+        Returns:
+            str: The complete response text
+        """
+        # --- 1) Initialize response tracking
+        full_response = ""
+        last_chunk = None
+        chunk_count = 0
+        debug_info = []
+
+        # --- 2) Make sure stream is set to True
+        params["stream"] = True
+
+        try:
+            # --- 3) Process each chunk in the stream
+            for chunk in litellm.completion(**params):
+                chunk_count += 1
+                last_chunk = chunk
+
+                # Add debug info
+                debug_info.append(f"Chunk type: {type(chunk)}")
+
+                # Extract content from the chunk
+                chunk_content = None
+
+                # Handle ModelResponse objects
+                if isinstance(chunk, ModelResponse):
+                    debug_info.append("Chunk is ModelResponse")
+                    choices = getattr(chunk, "choices", [])
+                    if choices and len(choices) > 0:
+                        choice = choices[0]
+                        debug_info.append(f"Choice type: {type(choice)}")
+
+                        # Handle dictionary-style choices
+                        if isinstance(choice, dict):
+                            delta = choice.get("delta", {})
+                            debug_info.append(f"Delta: {delta}")
+                            if (
+                                isinstance(delta, dict)
+                                and "content" in delta
+                                and delta["content"] is not None
+                            ):
+                                chunk_content = delta["content"]
+                        # Handle object-style choices
+                        else:
+                            # Try to access delta attribute safely
+                            delta = getattr(choice, "delta", None)
+                            debug_info.append(f"Delta: {delta}")
+                            if delta is not None:
+                                # Try to get content from delta.content
+                                if (
+                                    hasattr(delta, "content")
+                                    and getattr(delta, "content", None) is not None
+                                ):
+                                    chunk_content = getattr(delta, "content")
+                                # Some models return delta as a string
+                                elif isinstance(delta, str):
+                                    chunk_content = delta
+
+                # Add content to response if found
+                if chunk_content:
+                    full_response += chunk_content
+                    logging.debug(f"Chunk content: {chunk_content}")
+                    crewai_event_bus.emit(
+                        self,
+                        event=LLMStreamChunkEvent(chunk=chunk_content),
+                    )
+                else:
+                    debug_info.append(f"No content found in chunk: {chunk}")
+
+            # --- 4) Fallback to non-streaming if no content received
+            if not full_response.strip() and chunk_count == 0:
+                logging.warning(
+                    "No chunks received in streaming response, falling back to non-streaming"
+                )
+                # Try non-streaming as fallback
+                non_streaming_params = params.copy()
+                non_streaming_params["stream"] = False
+                return self._handle_non_streaming_response(
+                    non_streaming_params, available_functions
+                )
+
+            # --- 5) Handle empty response with chunks
+            if not full_response.strip() and chunk_count > 0:
+                logging.warning(
+                    f"Received {chunk_count} chunks but no content. Debug info: {debug_info}"
+                )
+                if last_chunk is not None:
+                    # Try to extract any content from the last chunk
+                    if isinstance(last_chunk, ModelResponse):
+                        choices = getattr(last_chunk, "choices", [])
+                        if choices and len(choices) > 0:
+                            choice = choices[0]
+                            # Try to get content from message
+                            message = getattr(choice, "message", None)
+                            if message is not None and getattr(
+                                message, "content", None
+                            ):
+                                full_response = getattr(message, "content")
+                                logging.info(
+                                    f"Extracted content from last chunk message: {full_response}"
+                                )
+                            # Try to get content from text (some models use this)
+                            elif getattr(choice, "text", None):
+                                full_response = getattr(choice, "text")
+                                logging.info(
+                                    f"Extracted text from last chunk: {full_response}"
+                                )
+
+            # --- 6) If still empty, use a default response
+            if not full_response.strip():
+                logging.warning("Using default response as fallback")
+                full_response = "I apologize, but I couldn't generate a proper response. Please try again or rephrase your request."
+
+            # --- 7) Check for tool calls in the final response
+            if isinstance(last_chunk, ModelResponse):
+                choices = getattr(last_chunk, "choices", [])
+                if choices and len(choices) > 0:
+                    choice = choices[0]
+                    message = getattr(choice, "message", None)
+                    if message is not None:
+                        tool_calls = getattr(message, "tool_calls", [])
+                        tool_result = self._handle_tool_call(
+                            tool_calls, available_functions
+                        )
+                        if tool_result is not None:
+                            return tool_result
+
+            # --- 8) Emit completion event and return response
+            self._handle_emit_call_events(full_response, LLMCallType.LLM_CALL)
+            return full_response
+
+        except Exception as e:
+            logging.error(
+                f"Error in streaming response: {str(e)}, Debug info: {debug_info}"
+            )
+            # If we have any response content, return it instead of failing
+            if full_response.strip():
+                logging.warning(f"Returning partial response despite error: {str(e)}")
+                self._handle_emit_call_events(full_response, LLMCallType.LLM_CALL)
+                return full_response
+
+            # Try non-streaming as fallback
+            try:
+                logging.warning("Falling back to non-streaming after error")
+                non_streaming_params = params.copy()
+                non_streaming_params["stream"] = False
+                return self._handle_non_streaming_response(
+                    non_streaming_params, available_functions
+                )
+            except Exception as fallback_error:
+                logging.error(
+                    f"Fallback to non-streaming also failed: {str(fallback_error)}"
+                )
+                # Return a default response as last resort
+                default_response = "I apologize, but I couldn't generate a proper response. Please try again or rephrase your request."
+                self._handle_emit_call_events(default_response, LLMCallType.LLM_CALL)
+                return default_response
+
+    def _handle_non_streaming_response(
+        self,
+        params: Dict[str, Any],
+        available_functions: Optional[Dict[str, Any]] = None,
+    ) -> str:
+        """Handle a non-streaming response from the LLM.
+
+        Args:
+            params: Parameters for the completion call
+            available_functions: Dict of available functions
+
+        Returns:
+            str: The response text
+        """
+        # --- 1) Make the completion call
+        response = litellm.completion(**params)
+
+        # --- 2) Extract response message and content
+        response_message = cast(Choices, cast(ModelResponse, response).choices)[
+            0
+        ].message
+        text_response = response_message.content or ""
+
+        # --- 3) Check for tool calls
+        tool_calls = getattr(response_message, "tool_calls", [])
+
+        # --- 4) Handle tool calls if present
+        tool_result = self._handle_tool_call(tool_calls, available_functions)
+        if tool_result is not None:
+            return tool_result
+
+        # --- 5) Emit completion event and return response
+        self._handle_emit_call_events(text_response, LLMCallType.LLM_CALL)
+        return text_response
+
+    def _handle_tool_call(
+        self,
+        tool_calls: List[Any],
+        available_functions: Optional[Dict[str, Any]] = None,
+    ) -> Optional[str]:
+        """Handle a tool call from the LLM.
+
+        Args:
+            tool_calls: List of tool calls from the LLM
+            available_functions: Dict of available functions
+
+        Returns:
+            Optional[str]: The result of the tool call, or None if no tool call was made
+        """
+        # --- 1) Validate tool calls and available functions
+        if not tool_calls or not available_functions:
+            return None
+
+        # --- 2) Extract function name from first tool call
+        tool_call = tool_calls[0]
+        function_name = tool_call.function.name
+        function_args = {}  # Initialize to empty dict to avoid unbound variable
+
+        # --- 3) Check if function is available
+        if function_name in available_functions:
+            try:
+                # --- 3.1) Parse function arguments
+                function_args = json.loads(tool_call.function.arguments)
+                fn = available_functions[function_name]
+
+                # --- 3.2) Execute function
+                result = fn(**function_args)
+
+                # --- 3.3) Emit success event
+                self._handle_emit_call_events(result, LLMCallType.TOOL_CALL)
+                return result
+            except Exception as e:
+                # --- 3.4) Handle execution errors
+                fn = available_functions.get(
+                    function_name, lambda: None
+                )  # Ensure fn is always a callable
+                logging.error(f"Error executing function '{function_name}': {e}")
+                crewai_event_bus.emit(
+                    self,
+                    event=ToolExecutionErrorEvent(
+                        tool_name=function_name,
+                        tool_args=function_args,
+                        tool_class=fn,
+                        error=str(e),
+                    ),
+                )
+                crewai_event_bus.emit(
+                    self,
+                    event=LLMCallFailedEvent(error=f"Tool execution error: {str(e)}"),
+                )
+        return None
     def call(
         self,
         messages: Union[str, List[Dict[str, str]]],
@@ -230,22 +570,8 @@ class LLM:
             TypeError: If messages format is invalid
             ValueError: If response format is not supported
             LLMContextLengthExceededException: If input exceeds model's context limit
-
-        Examples:
-            # Example 1: Simple string input
-            >>> response = llm.call("Return the name of a random city.")
-            >>> print(response)
-            "Paris"
-
-            # Example 2: Message list with system and user messages
-            >>> messages = [
-            ...     {"role": "system", "content": "You are a geography expert"},
-            ...     {"role": "user", "content": "What is France's capital?"}
-            ... ]
-            >>> response = llm.call(messages)
-            >>> print(response)
-            "The capital of France is Paris."
         """
+        # --- 1) Emit call started event
         crewai_event_bus.emit(
             self,
             event=LLMCallStartedEvent(
@@ -255,127 +581,36 @@ class LLM:
                 available_functions=available_functions,
             ),
         )
-        # Validate parameters before proceeding with the call.
+        # --- 2) Validate parameters before proceeding with the call
         self._validate_call_params()

+        # --- 3) Convert string messages to proper format if needed
         if isinstance(messages, str):
             messages = [{"role": "user", "content": messages}]

-        # For O1 models, system messages are not supported.
-        # Convert any system messages into assistant messages.
+        # --- 4) Handle O1 model special case (system messages not supported)
         if "o1" in self.model.lower():
             for message in messages:
                 if message.get("role") == "system":
                     message["role"] = "assistant"

+        # --- 5) Set up callbacks if provided
         with suppress_warnings():
             if callbacks and len(callbacks) > 0:
                 self.set_callbacks(callbacks)

             try:
-                # --- 1) Format messages according to provider requirements
-                formatted_messages = self._format_messages_for_provider(messages)
-
-                # --- 2) Prepare the parameters for the completion call
-                params = {
-                    "model": self.model,
-                    "messages": formatted_messages,
-                    "timeout": self.timeout,
-                    "temperature": self.temperature,
-                    "top_p": self.top_p,
-                    "n": self.n,
-                    "stop": self.stop,
-                    "max_tokens": self.max_tokens or self.max_completion_tokens,
-                    "presence_penalty": self.presence_penalty,
-                    "frequency_penalty": self.frequency_penalty,
-                    "logit_bias": self.logit_bias,
-                    "response_format": self.response_format,
-                    "seed": self.seed,
-                    "logprobs": self.logprobs,
-                    "top_logprobs": self.top_logprobs,
-                    "api_base": self.api_base,
-                    "base_url": self.base_url,
-                    "api_version": self.api_version,
-                    "api_key": self.api_key,
-                    "stream": False,
-                    "tools": tools,
-                    "reasoning_effort": self.reasoning_effort,
-                    **self.additional_params,
-                }
-
-                # Remove None values from params
-                params = {k: v for k, v in params.items() if v is not None}
-
-                # --- 2) Make the completion call
-                response = litellm.completion(**params)
-                response_message = cast(Choices, cast(ModelResponse, response).choices)[
-                    0
-                ].message
-                text_response = response_message.content or ""
-                tool_calls = getattr(response_message, "tool_calls", [])
-
-                # --- 3) Handle callbacks with usage info
-                if callbacks and len(callbacks) > 0:
-                    for callback in callbacks:
-                        if hasattr(callback, "log_success_event"):
-                            usage_info = getattr(response, "usage", None)
-                            if usage_info:
-                                callback.log_success_event(
-                                    kwargs=params,
-                                    response_obj={"usage": usage_info},
-                                    start_time=0,
-                                    end_time=0,
-                                )
-
-                # --- 4) If no tool calls, return the text response
-                if not tool_calls or not available_functions:
-                    self._handle_emit_call_events(text_response, LLMCallType.LLM_CALL)
-                    return text_response
-
-                # --- 5) Handle the tool call
-                tool_call = tool_calls[0]
-                function_name = tool_call.function.name
-
-                if function_name in available_functions:
-                    try:
-                        function_args = json.loads(tool_call.function.arguments)
-                    except json.JSONDecodeError as e:
-                        logging.warning(f"Failed to parse function arguments: {e}")
-                        return text_response
-
-                    fn = available_functions[function_name]
-                    try:
-                        # Call the actual tool function
-                        result = fn(**function_args)
-                        self._handle_emit_call_events(result, LLMCallType.TOOL_CALL)
-                        return result
-                    except Exception as e:
-                        logging.error(
-                            f"Error executing function '{function_name}': {e}"
-                        )
-                        crewai_event_bus.emit(
-                            self,
-                            event=ToolExecutionErrorEvent(
-                                tool_name=function_name,
-                                tool_args=function_args,
-                                tool_class=fn,
-                                error=str(e),
-                            ),
-                        )
-                        crewai_event_bus.emit(
-                            self,
-                            event=LLMCallFailedEvent(
-                                error=f"Tool execution error: {str(e)}"
-                            ),
-                        )
-                        return text_response
-                else:
-                    logging.warning(
-                        f"Tool call requested unknown function '{function_name}'"
-                    )
-                    return text_response
+                # --- 6) Prepare parameters for the completion call
+                params = self._prepare_completion_params(messages, tools)
+
+                # --- 7) Make the completion call and handle response
+                if self.stream:
+                    return self._handle_streaming_response(params, available_functions)
+                else:
+                    return self._handle_non_streaming_response(
+                        params, available_functions
+                    )
             except Exception as e:
                 crewai_event_bus.emit(
@@ -426,6 +661,20 @@ class LLM:
"Invalid message format. Each message must be a dict with 'role' and 'content' keys" "Invalid message format. Each message must be a dict with 'role' and 'content' keys"
) )
+        # Handle O1 models specially
+        if "o1" in self.model.lower():
+            formatted_messages = []
+            for msg in messages:
+                # Convert system messages to assistant messages
+                if msg["role"] == "system":
+                    formatted_messages.append(
+                        {"role": "assistant", "content": msg["content"]}
+                    )
+                else:
+                    formatted_messages.append(msg)
+            return formatted_messages
+
+        # Handle Anthropic models
         if not self.is_anthropic:
             return messages
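With these changes, streaming is opt-in at construction time. A usage sketch (the model name and prompt are illustrative, and `LLM` is assumed to be exported from the package root):

```python
from crewai import LLM  # assumed top-level export

# stream=True routes call() through _handle_streaming_response, which emits
# an LLMStreamChunkEvent per chunk and returns the full concatenated text.
llm = LLM(model="gpt-4o", stream=True)
response = llm.call("Return the name of a random city.")
print(response)
```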

View File

@@ -14,7 +14,12 @@ from .agent_events import (
     AgentExecutionCompletedEvent,
     AgentExecutionErrorEvent,
 )
-from .task_events import TaskStartedEvent, TaskCompletedEvent, TaskFailedEvent, TaskEvaluationEvent
+from .task_events import (
+    TaskStartedEvent,
+    TaskCompletedEvent,
+    TaskFailedEvent,
+    TaskEvaluationEvent,
+)
 from .flow_events import (
     FlowCreatedEvent,
     FlowStartedEvent,
@@ -34,7 +39,13 @@ from .tool_usage_events import (
     ToolUsageEvent,
     ToolValidateInputErrorEvent,
 )
-from .llm_events import LLMCallCompletedEvent, LLMCallFailedEvent, LLMCallStartedEvent
+from .llm_events import (
+    LLMCallCompletedEvent,
+    LLMCallFailedEvent,
+    LLMCallStartedEvent,
+    LLMCallType,
+    LLMStreamChunkEvent,
+)

 # events
 from .event_listener import EventListener
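Since the package `__init__` now re-exports the new names, downstream code can pull them from one place (a sketch based on the imports shown in this hunk):

```python
from crewai.utilities.events import (
    LLMStreamChunkEvent,
    TaskEvaluationEvent,
)
```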

View File

@@ -11,6 +11,7 @@ from crewai.utilities.events.llm_events import (
     LLMCallCompletedEvent,
     LLMCallFailedEvent,
     LLMCallStartedEvent,
+    LLMStreamChunkEvent,
 )

 from .agent_events import AgentExecutionCompletedEvent, AgentExecutionStartedEvent
@@ -280,7 +281,14 @@ class EventListener(BaseEventListener):
         @crewai_event_bus.on(LLMCallFailedEvent)
         def on_llm_call_failed(source, event: LLMCallFailedEvent):
             self.logger.log(
-                f"❌ LLM Call Failed: '{event.error}'",
+                f"❌ LLM call failed: {event.error}",
+                event.timestamp,
+            )
+
+        @crewai_event_bus.on(LLMStreamChunkEvent)
+        def on_llm_stream_chunk(source, event: LLMStreamChunkEvent):
+            self.logger.log(
+                "📝 LLM stream chunk received",
                 event.timestamp,
             )

View File

@@ -34,3 +34,10 @@ class LLMCallFailedEvent(CrewEvent):
     error: str
     type: str = "llm_call_failed"
+
+
+class LLMStreamChunkEvent(CrewEvent):
+    """Event emitted when a streaming chunk is received"""
+
+    type: str = "llm_stream_chunk"
+    chunk: str
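For reference, the event is a plain pydantic model, so constructing one directly is straightforward (a sketch, assuming `CrewEvent` supplies defaults for its remaining fields such as the timestamp):

```python
from crewai.utilities.events.llm_events import LLMStreamChunkEvent

event = LLMStreamChunkEvent(chunk="Par")
assert event.type == "llm_stream_chunk"
assert event.chunk == "Par"
```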