Initial Stream working

Brandon Hancock
2025-03-03 13:10:33 -05:00
parent a3d5c86218
commit 143832bd8b
5 changed files with 398 additions and 122 deletions

View File

@@ -224,6 +224,7 @@ CrewAI provides a wide range of events that you can listen for:
- **LLMCallStartedEvent**: Emitted when an LLM call starts
- **LLMCallCompletedEvent**: Emitted when an LLM call completes
- **LLMCallFailedEvent**: Emitted when an LLM call fails
- **LLMStreamChunkEvent**: Emitted for each chunk received during streaming LLM responses
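A minimal sketch of handling the new streaming event. The handler signature and `event.chunk` field follow the listener code added in this commit; the import path for `crewai_event_bus` is an assumption, so adjust it to your installation:

```python
from crewai.utilities.events import crewai_event_bus  # assumed export location
from crewai.utilities.events.llm_events import LLMStreamChunkEvent

@crewai_event_bus.on(LLMStreamChunkEvent)
def on_llm_stream_chunk(source, event: LLMStreamChunkEvent):
    # event.chunk carries the text of a single streamed fragment
    print(event.chunk, end="", flush=True)
```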
## Event Handler Structure

View File

@@ -5,7 +5,17 @@ import sys
import threading
import warnings
from contextlib import contextmanager
from typing import Any, Dict, List, Literal, Optional, Type, Union, cast
from typing import (
Any,
Dict,
List,
Literal,
Optional,
Type,
TypedDict,
Union,
cast,
)
from dotenv import load_dotenv
from pydantic import BaseModel
@@ -15,6 +25,7 @@ from crewai.utilities.events.llm_events import (
LLMCallFailedEvent,
LLMCallStartedEvent,
LLMCallType,
LLMStreamChunkEvent,
)
from crewai.utilities.events.tool_usage_events import ToolExecutionErrorEvent
@@ -126,6 +137,17 @@ def suppress_warnings():
sys.stderr = old_stderr
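# Typed descriptions of streaming chunk payloads: a Delta carries the incremental content/role, and a StreamingChoices entry pairs a delta with its index and finish_reason.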
class Delta(TypedDict):
content: Optional[str]
role: Optional[str]
class StreamingChoices(TypedDict):
delta: Delta
index: int
finish_reason: Optional[str]
class LLM:
def __init__(
self,
@@ -150,6 +172,7 @@ class LLM:
api_key: Optional[str] = None,
callbacks: List[Any] = [],
reasoning_effort: Optional[Literal["none", "low", "medium", "high"]] = None,
stream: bool = False,
**kwargs,
):
self.model = model
@@ -175,6 +198,7 @@ class LLM:
self.reasoning_effort = reasoning_effort
self.additional_params = kwargs
self.is_anthropic = self._is_anthropic_model(model)
self.stream = stream
litellm.drop_params = True
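# A minimal usage sketch of the new `stream` flag (illustrative, not part of this diff);
# the import path, model name, and prompt are assumptions:
from crewai import LLM  # assumed import path

# Opt in to streaming: call() still returns the full text, while each chunk
# is also emitted on the event bus as an LLMStreamChunkEvent.
llm = LLM(model="gpt-4o", stream=True)  # model name is illustrative
print(llm.call("Name a random city."))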
@@ -201,6 +225,322 @@ class LLM:
ANTHROPIC_PREFIXES = ("anthropic/", "claude-", "claude/")
return any(prefix in model.lower() for prefix in ANTHROPIC_PREFIXES)
def _prepare_completion_params(
self,
messages: Union[str, List[Dict[str, str]]],
tools: Optional[List[dict]] = None,
) -> Dict[str, Any]:
"""Prepare parameters for the completion call.
Args:
messages: Input messages for the LLM
tools: Optional list of tool schemas
Returns:
Dict[str, Any]: Parameters for the completion call
"""
# --- 1) Format messages according to provider requirements
if isinstance(messages, str):
messages = [{"role": "user", "content": messages}]
formatted_messages = self._format_messages_for_provider(messages)
# --- 2) Prepare the parameters for the completion call
params = {
"model": self.model,
"messages": formatted_messages,
"timeout": self.timeout,
"temperature": self.temperature,
"top_p": self.top_p,
"n": self.n,
"stop": self.stop,
"max_tokens": self.max_tokens or self.max_completion_tokens,
"presence_penalty": self.presence_penalty,
"frequency_penalty": self.frequency_penalty,
"logit_bias": self.logit_bias,
"response_format": self.response_format,
"seed": self.seed,
"logprobs": self.logprobs,
"top_logprobs": self.top_logprobs,
"api_base": self.api_base,
"base_url": self.base_url,
"api_version": self.api_version,
"api_key": self.api_key,
"stream": self.stream,
"tools": tools,
"reasoning_effort": self.reasoning_effort,
**self.additional_params,
}
# Remove None values from params
return {k: v for k, v in params.items() if v is not None}
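# A tiny sketch of the None-filtering above, with illustrative values (the method itself is internal):
params = {"model": "gpt-4o", "temperature": None, "stop": None, "stream": True}
# Keys whose value is None are dropped before the litellm call
params = {k: v for k, v in params.items() if v is not None}
assert params == {"model": "gpt-4o", "stream": True}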
def _handle_streaming_response(
self,
params: Dict[str, Any],
available_functions: Optional[Dict[str, Any]] = None,
) -> str:
"""Handle a streaming response from the LLM.
Args:
params: Parameters for the completion call
available_functions: Dict of available functions
Returns:
str: The complete response text
"""
# --- 1) Initialize response tracking
full_response = ""
last_chunk = None
chunk_count = 0
debug_info = []
# --- 2) Make sure stream is set to True
params["stream"] = True
try:
# --- 3) Process each chunk in the stream
for chunk in litellm.completion(**params):
chunk_count += 1
last_chunk = chunk
# Add debug info
debug_info.append(f"Chunk type: {type(chunk)}")
# Extract content from the chunk
chunk_content = None
# Handle ModelResponse objects
if isinstance(chunk, ModelResponse):
debug_info.append("Chunk is ModelResponse")
choices = getattr(chunk, "choices", [])
if choices and len(choices) > 0:
choice = choices[0]
debug_info.append(f"Choice type: {type(choice)}")
# Handle dictionary-style choices
if isinstance(choice, dict):
delta = choice.get("delta", {})
debug_info.append(f"Delta: {delta}")
if (
isinstance(delta, dict)
and "content" in delta
and delta["content"] is not None
):
chunk_content = delta["content"]
# Handle object-style choices
else:
# Try to access delta attribute safely
delta = getattr(choice, "delta", None)
debug_info.append(f"Delta: {delta}")
if delta is not None:
# Try to get content from delta.content
if (
hasattr(delta, "content")
and getattr(delta, "content", None) is not None
):
chunk_content = getattr(delta, "content")
# Some models return delta as a string
elif isinstance(delta, str):
chunk_content = delta
# Add content to response if found
if chunk_content:
full_response += chunk_content
print(f"Chunk content: {chunk_content}")
crewai_event_bus.emit(
self,
event=LLMStreamChunkEvent(chunk=chunk_content),
)
else:
debug_info.append(f"No content found in chunk: {chunk}")
# --- 4) Fallback to non-streaming if no content received
if not full_response.strip() and chunk_count == 0:
logging.warning(
"No chunks received in streaming response, falling back to non-streaming"
)
# Try non-streaming as fallback
non_streaming_params = params.copy()
non_streaming_params["stream"] = False
return self._handle_non_streaming_response(
non_streaming_params, available_functions
)
# --- 5) Handle empty response with chunks
if not full_response.strip() and chunk_count > 0:
logging.warning(
f"Received {chunk_count} chunks but no content. Debug info: {debug_info}"
)
if last_chunk is not None:
# Try to extract any content from the last chunk
if isinstance(last_chunk, ModelResponse):
choices = getattr(last_chunk, "choices", [])
if choices and len(choices) > 0:
choice = choices[0]
# Try to get content from message
message = getattr(choice, "message", None)
if message is not None and getattr(
message, "content", None
):
full_response = getattr(message, "content")
logging.info(
f"Extracted content from last chunk message: {full_response}"
)
# Try to get content from text (some models use this)
elif getattr(choice, "text", None):
full_response = getattr(choice, "text")
logging.info(
f"Extracted text from last chunk: {full_response}"
)
# --- 6) If still empty, use a default response
if not full_response.strip():
logging.warning("Using default response as fallback")
full_response = "I apologize, but I couldn't generate a proper response. Please try again or rephrase your request."
# --- 7) Check for tool calls in the final response
if isinstance(last_chunk, ModelResponse):
choices = getattr(last_chunk, "choices", [])
if choices and len(choices) > 0:
choice = choices[0]
message = getattr(choice, "message", None)
if message is not None:
tool_calls = getattr(message, "tool_calls", [])
tool_result = self._handle_tool_call(
tool_calls, available_functions
)
if tool_result is not None:
return tool_result
# --- 8) Emit completion event and return response
self._handle_emit_call_events(full_response, LLMCallType.LLM_CALL)
return full_response
except Exception as e:
logging.error(
f"Error in streaming response: {str(e)}, Debug info: {debug_info}"
)
# If we have any response content, return it instead of failing
if full_response.strip():
logging.warning(f"Returning partial response despite error: {str(e)}")
self._handle_emit_call_events(full_response, LLMCallType.LLM_CALL)
return full_response
# Try non-streaming as fallback
try:
logging.warning("Falling back to non-streaming after error")
non_streaming_params = params.copy()
non_streaming_params["stream"] = False
return self._handle_non_streaming_response(
non_streaming_params, available_functions
)
except Exception as fallback_error:
logging.error(
f"Fallback to non-streaming also failed: {str(fallback_error)}"
)
# Return a default response as last resort
default_response = "I apologize, but I couldn't generate a proper response. Please try again or rephrase your request."
self._handle_emit_call_events(default_response, LLMCallType.LLM_CALL)
return default_response
def _handle_non_streaming_response(
self,
params: Dict[str, Any],
available_functions: Optional[Dict[str, Any]] = None,
) -> str:
"""Handle a non-streaming response from the LLM.
Args:
params: Parameters for the completion call
available_functions: Dict of available functions
Returns:
str: The response text
"""
# --- 1) Make the completion call
response = litellm.completion(**params)
# --- 2) Extract response message and content
response_message = cast(Choices, cast(ModelResponse, response).choices)[
0
].message
text_response = response_message.content or ""
# --- 3) Check for tool calls
tool_calls = getattr(response_message, "tool_calls", [])
# --- 4) Handle tool calls if present
tool_result = self._handle_tool_call(tool_calls, available_functions)
if tool_result is not None:
return tool_result
# --- 5) Emit completion event and return response
self._handle_emit_call_events(text_response, LLMCallType.LLM_CALL)
return text_response
def _handle_tool_call(
self,
tool_calls: List[Any],
available_functions: Optional[Dict[str, Any]] = None,
) -> Optional[str]:
"""Handle a tool call from the LLM.
Args:
tool_calls: List of tool calls from the LLM
available_functions: Dict of available functions
Returns:
Optional[str]: The result of the tool call, or None if no tool call was made
"""
# --- 1) Validate tool calls and available functions
if not tool_calls or not available_functions:
return None
# --- 2) Extract function name from first tool call
tool_call = tool_calls[0]
function_name = tool_call.function.name
function_args = {} # Initialize to empty dict to avoid unbound variable
# --- 3) Check if function is available
if function_name in available_functions:
try:
# --- 3.1) Parse function arguments
function_args = json.loads(tool_call.function.arguments)
fn = available_functions[function_name]
# --- 3.2) Execute function
result = fn(**function_args)
# --- 3.3) Emit success event
self._handle_emit_call_events(result, LLMCallType.TOOL_CALL)
return result
except Exception as e:
# --- 3.4) Handle execution errors
fn = available_functions.get(
function_name, lambda: None
) # Ensure fn is always a callable
logging.error(f"Error executing function '{function_name}': {e}")
crewai_event_bus.emit(
self,
event=ToolExecutionErrorEvent(
tool_name=function_name,
tool_args=function_args,
tool_class=fn,
error=str(e),
),
)
crewai_event_bus.emit(
self,
event=LLMCallFailedEvent(error=f"Tool execution error: {str(e)}"),
)
return None
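# A hedged sketch of wiring a tool through call() (illustrative, not part of this diff),
# assuming call() accepts tools and available_functions keyword arguments as used below;
# the tool name, schema, and function body are made up for the example:
from crewai import LLM  # assumed import path

def get_weather(city: str) -> str:
    # Illustrative stand-in for a real tool implementation
    return f"It is sunny in {city}."

tools = [{
    "type": "function",
    "function": {
        "name": "get_weather",
        "description": "Look up the weather for a city",
        "parameters": {
            "type": "object",
            "properties": {"city": {"type": "string"}},
            "required": ["city"],
        },
    },
}]

llm = LLM(model="gpt-4o")  # model name is illustrative
# If the model issues a tool call, _handle_tool_call runs get_weather and
# returns its result; otherwise the plain text response is returned.
result = llm.call(
    "What is the weather in Paris?",
    tools=tools,
    available_functions={"get_weather": get_weather},
)
print(result)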
def call(
self,
messages: Union[str, List[Dict[str, str]]],
@@ -230,22 +570,8 @@ class LLM:
TypeError: If messages format is invalid
ValueError: If response format is not supported
LLMContextLengthExceededException: If input exceeds model's context limit
Examples:
# Example 1: Simple string input
>>> response = llm.call("Return the name of a random city.")
>>> print(response)
"Paris"
# Example 2: Message list with system and user messages
>>> messages = [
... {"role": "system", "content": "You are a geography expert"},
... {"role": "user", "content": "What is France's capital?"}
... ]
>>> response = llm.call(messages)
>>> print(response)
"The capital of France is Paris."
"""
# --- 1) Emit call started event
crewai_event_bus.emit(
self,
event=LLMCallStartedEvent(
@@ -255,127 +581,36 @@ class LLM:
available_functions=available_functions,
),
)
# Validate parameters before proceeding with the call.
# --- 2) Validate parameters before proceeding with the call
self._validate_call_params()
# --- 3) Convert string messages to proper format if needed
if isinstance(messages, str):
messages = [{"role": "user", "content": messages}]
# For O1 models, system messages are not supported.
# Convert any system messages into assistant messages.
# --- 4) Handle O1 model special case (system messages not supported)
if "o1" in self.model.lower():
for message in messages:
if message.get("role") == "system":
message["role"] = "assistant"
# --- 5) Set up callbacks if provided
with suppress_warnings():
if callbacks and len(callbacks) > 0:
self.set_callbacks(callbacks)
try:
# --- 1) Format messages according to provider requirements
formatted_messages = self._format_messages_for_provider(messages)
# --- 2) Prepare the parameters for the completion call
params = {
"model": self.model,
"messages": formatted_messages,
"timeout": self.timeout,
"temperature": self.temperature,
"top_p": self.top_p,
"n": self.n,
"stop": self.stop,
"max_tokens": self.max_tokens or self.max_completion_tokens,
"presence_penalty": self.presence_penalty,
"frequency_penalty": self.frequency_penalty,
"logit_bias": self.logit_bias,
"response_format": self.response_format,
"seed": self.seed,
"logprobs": self.logprobs,
"top_logprobs": self.top_logprobs,
"api_base": self.api_base,
"base_url": self.base_url,
"api_version": self.api_version,
"api_key": self.api_key,
"stream": False,
"tools": tools,
"reasoning_effort": self.reasoning_effort,
**self.additional_params,
}
# Remove None values from params
params = {k: v for k, v in params.items() if v is not None}
# --- 2) Make the completion call
response = litellm.completion(**params)
response_message = cast(Choices, cast(ModelResponse, response).choices)[
0
].message
text_response = response_message.content or ""
tool_calls = getattr(response_message, "tool_calls", [])
# --- 3) Handle callbacks with usage info
if callbacks and len(callbacks) > 0:
for callback in callbacks:
if hasattr(callback, "log_success_event"):
usage_info = getattr(response, "usage", None)
if usage_info:
callback.log_success_event(
kwargs=params,
response_obj={"usage": usage_info},
start_time=0,
end_time=0,
)
# --- 4) If no tool calls, return the text response
if not tool_calls or not available_functions:
self._handle_emit_call_events(text_response, LLMCallType.LLM_CALL)
return text_response
# --- 5) Handle the tool call
tool_call = tool_calls[0]
function_name = tool_call.function.name
if function_name in available_functions:
try:
function_args = json.loads(tool_call.function.arguments)
except json.JSONDecodeError as e:
logging.warning(f"Failed to parse function arguments: {e}")
return text_response
fn = available_functions[function_name]
try:
# Call the actual tool function
result = fn(**function_args)
self._handle_emit_call_events(result, LLMCallType.TOOL_CALL)
return result
except Exception as e:
logging.error(
f"Error executing function '{function_name}': {e}"
)
crewai_event_bus.emit(
self,
event=ToolExecutionErrorEvent(
tool_name=function_name,
tool_args=function_args,
tool_class=fn,
error=str(e),
),
)
crewai_event_bus.emit(
self,
event=LLMCallFailedEvent(
error=f"Tool execution error: {str(e)}"
),
)
return text_response
# --- 6) Prepare parameters for the completion call
params = self._prepare_completion_params(messages, tools)
# --- 7) Make the completion call and handle response
if self.stream:
return self._handle_streaming_response(params, available_functions)
else:
return self._handle_non_streaming_response(
params, available_functions
)
except Exception as e:
crewai_event_bus.emit(
@@ -426,6 +661,20 @@ class LLM:
"Invalid message format. Each message must be a dict with 'role' and 'content' keys"
)
# Handle O1 models specially
if "o1" in self.model.lower():
formatted_messages = []
for msg in messages:
# Convert system messages to assistant messages
if msg["role"] == "system":
formatted_messages.append(
{"role": "assistant", "content": msg["content"]}
)
else:
formatted_messages.append(msg)
return formatted_messages
# Handle Anthropic models
if not self.is_anthropic:
return messages
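# A hedged illustration of the o1 special-casing above (illustrative, not part of this diff);
# the method is internal and the model name is made up:
from crewai import LLM  # assumed import path

llm = LLM(model="o1-mini")  # any model name containing "o1" takes this path
messages = [
    {"role": "system", "content": "You are terse."},
    {"role": "user", "content": "Hello"},
]
# System messages are rewritten as assistant messages for o1-family models
formatted = llm._format_messages_for_provider(messages)
assert formatted[0]["role"] == "assistant"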

View File

@@ -14,7 +14,12 @@ from .agent_events import (
AgentExecutionCompletedEvent,
AgentExecutionErrorEvent,
)
from .task_events import TaskStartedEvent, TaskCompletedEvent, TaskFailedEvent, TaskEvaluationEvent
from .task_events import (
TaskStartedEvent,
TaskCompletedEvent,
TaskFailedEvent,
TaskEvaluationEvent,
)
from .flow_events import (
FlowCreatedEvent,
FlowStartedEvent,
@@ -34,7 +39,13 @@ from .tool_usage_events import (
ToolUsageEvent,
ToolValidateInputErrorEvent,
)
from .llm_events import LLMCallCompletedEvent, LLMCallFailedEvent, LLMCallStartedEvent
from .llm_events import (
LLMCallCompletedEvent,
LLMCallFailedEvent,
LLMCallStartedEvent,
LLMCallType,
LLMStreamChunkEvent,
)
# events
from .event_listener import EventListener

View File

@@ -11,6 +11,7 @@ from crewai.utilities.events.llm_events import (
LLMCallCompletedEvent,
LLMCallFailedEvent,
LLMCallStartedEvent,
LLMStreamChunkEvent,
)
from .agent_events import AgentExecutionCompletedEvent, AgentExecutionStartedEvent
@@ -280,7 +281,14 @@ class EventListener(BaseEventListener):
@crewai_event_bus.on(LLMCallFailedEvent)
def on_llm_call_failed(source, event: LLMCallFailedEvent):
self.logger.log(
f"❌ LLM Call Failed: '{event.error}'",
f"❌ LLM call failed: {event.error}",
event.timestamp,
)
@crewai_event_bus.on(LLMStreamChunkEvent)
def on_llm_stream_chunk(source, event: LLMStreamChunkEvent):
self.logger.log(
f"📝 LLM stream chunk received",
event.timestamp,
)

View File

@@ -34,3 +34,10 @@ class LLMCallFailedEvent(CrewEvent):
error: str
type: str = "llm_call_failed"
class LLMStreamChunkEvent(CrewEvent):
"""Event emitted when a streaming chunk is received"""
type: str = "llm_stream_chunk"
chunk: str