From 143832bd8b82bdf0b6ef39902b28b6e1ec583140 Mon Sep 17 00:00:00 2001 From: Brandon Hancock Date: Mon, 3 Mar 2025 13:10:33 -0500 Subject: [PATCH] Initial Stream working --- docs/concepts/event-listner.mdx | 1 + src/crewai/llm.py | 487 +++++++++++++----- src/crewai/utilities/events/__init__.py | 15 +- src/crewai/utilities/events/event_listener.py | 10 +- src/crewai/utilities/events/llm_events.py | 7 + 5 files changed, 398 insertions(+), 122 deletions(-) diff --git a/docs/concepts/event-listner.mdx b/docs/concepts/event-listner.mdx index 7fdeec485..1439e1456 100644 --- a/docs/concepts/event-listner.mdx +++ b/docs/concepts/event-listner.mdx @@ -224,6 +224,7 @@ CrewAI provides a wide range of events that you can listen for: - **LLMCallStartedEvent**: Emitted when an LLM call starts - **LLMCallCompletedEvent**: Emitted when an LLM call completes - **LLMCallFailedEvent**: Emitted when an LLM call fails +- **LLMStreamChunkEvent**: Emitted for each chunk received during streaming LLM responses ## Event Handler Structure diff --git a/src/crewai/llm.py b/src/crewai/llm.py index 0c8a46214..ec4306bc4 100644 --- a/src/crewai/llm.py +++ b/src/crewai/llm.py @@ -5,7 +5,17 @@ import sys import threading import warnings from contextlib import contextmanager -from typing import Any, Dict, List, Literal, Optional, Type, Union, cast +from typing import ( + Any, + Dict, + List, + Literal, + Optional, + Type, + TypedDict, + Union, + cast, +) from dotenv import load_dotenv from pydantic import BaseModel @@ -15,6 +25,7 @@ from crewai.utilities.events.llm_events import ( LLMCallFailedEvent, LLMCallStartedEvent, LLMCallType, + LLMStreamChunkEvent, ) from crewai.utilities.events.tool_usage_events import ToolExecutionErrorEvent @@ -126,6 +137,17 @@ def suppress_warnings(): sys.stderr = old_stderr +class Delta(TypedDict): + content: Optional[str] + role: Optional[str] + + +class StreamingChoices(TypedDict): + delta: Delta + index: int + finish_reason: Optional[str] + + class LLM: def __init__( self, @@ -150,6 +172,7 @@ class LLM: api_key: Optional[str] = None, callbacks: List[Any] = [], reasoning_effort: Optional[Literal["none", "low", "medium", "high"]] = None, + stream: bool = False, **kwargs, ): self.model = model @@ -175,6 +198,7 @@ class LLM: self.reasoning_effort = reasoning_effort self.additional_params = kwargs self.is_anthropic = self._is_anthropic_model(model) + self.stream = stream litellm.drop_params = True @@ -201,6 +225,322 @@ class LLM: ANTHROPIC_PREFIXES = ("anthropic/", "claude-", "claude/") return any(prefix in model.lower() for prefix in ANTHROPIC_PREFIXES) + def _prepare_completion_params( + self, + messages: Union[str, List[Dict[str, str]]], + tools: Optional[List[dict]] = None, + ) -> Dict[str, Any]: + """Prepare parameters for the completion call. 
+
+        Args:
+            messages: Input messages for the LLM
+            tools: Optional list of tool schemas
+
+        Returns:
+            Dict[str, Any]: Parameters for the completion call
+        """
+        # --- 1) Format messages according to provider requirements
+        if isinstance(messages, str):
+            messages = [{"role": "user", "content": messages}]
+        formatted_messages = self._format_messages_for_provider(messages)
+
+        # --- 2) Prepare the parameters for the completion call
+        params = {
+            "model": self.model,
+            "messages": formatted_messages,
+            "timeout": self.timeout,
+            "temperature": self.temperature,
+            "top_p": self.top_p,
+            "n": self.n,
+            "stop": self.stop,
+            "max_tokens": self.max_tokens or self.max_completion_tokens,
+            "presence_penalty": self.presence_penalty,
+            "frequency_penalty": self.frequency_penalty,
+            "logit_bias": self.logit_bias,
+            "response_format": self.response_format,
+            "seed": self.seed,
+            "logprobs": self.logprobs,
+            "top_logprobs": self.top_logprobs,
+            "api_base": self.api_base,
+            "base_url": self.base_url,
+            "api_version": self.api_version,
+            "api_key": self.api_key,
+            "stream": self.stream,
+            "tools": tools,
+            "reasoning_effort": self.reasoning_effort,
+            **self.additional_params,
+        }
+
+        # Remove None values from params
+        return {k: v for k, v in params.items() if v is not None}
+
+    def _handle_streaming_response(
+        self,
+        params: Dict[str, Any],
+        available_functions: Optional[Dict[str, Any]] = None,
+    ) -> str:
+        """Handle a streaming response from the LLM.
+
+        Args:
+            params: Parameters for the completion call
+            available_functions: Dict of available functions
+
+        Returns:
+            str: The complete response text
+        """
+        # --- 1) Initialize response tracking
+        full_response = ""
+        last_chunk = None
+        chunk_count = 0
+        debug_info = []
+
+        # --- 2) Make sure stream is set to True
+        params["stream"] = True
+
+        try:
+            # --- 3) Process each chunk in the stream
+            for chunk in litellm.completion(**params):
+                chunk_count += 1
+                last_chunk = chunk
+
+                # Add debug info
+                debug_info.append(f"Chunk type: {type(chunk)}")
+
+                # Extract content from the chunk
+                chunk_content = None
+
+                # Handle ModelResponse objects
+                if isinstance(chunk, ModelResponse):
+                    debug_info.append("Chunk is ModelResponse")
+                    choices = getattr(chunk, "choices", [])
+                    if choices and len(choices) > 0:
+                        choice = choices[0]
+                        debug_info.append(f"Choice type: {type(choice)}")
+
+                        # Handle dictionary-style choices
+                        if isinstance(choice, dict):
+                            delta = choice.get("delta", {})
+                            debug_info.append(f"Delta: {delta}")
+                            if (
+                                isinstance(delta, dict)
+                                and "content" in delta
+                                and delta["content"] is not None
+                            ):
+                                chunk_content = delta["content"]
+
+                        # Handle object-style choices
+                        else:
+                            # Try to access delta attribute safely
+                            delta = getattr(choice, "delta", None)
+                            debug_info.append(f"Delta: {delta}")
+
+                            if delta is not None:
+                                # Try to get content from delta.content
+                                if (
+                                    hasattr(delta, "content")
+                                    and getattr(delta, "content", None) is not None
+                                ):
+                                    chunk_content = getattr(delta, "content")
+                                # Some models return delta as a string
+                                elif isinstance(delta, str):
+                                    chunk_content = delta
+
+                # Add content to response if found
+                if chunk_content:
+                    full_response += chunk_content
+                    crewai_event_bus.emit(
+                        self,
+                        event=LLMStreamChunkEvent(chunk=chunk_content),
+                    )
+                else:
+                    debug_info.append(f"No content found in chunk: {chunk}")
+
+            # --- 4) Fallback to non-streaming if no content received
+            if 
not full_response.strip() and chunk_count == 0: + logging.warning( + "No chunks received in streaming response, falling back to non-streaming" + ) + # Try non-streaming as fallback + non_streaming_params = params.copy() + non_streaming_params["stream"] = False + return self._handle_non_streaming_response( + non_streaming_params, available_functions + ) + + # --- 5) Handle empty response with chunks + if not full_response.strip() and chunk_count > 0: + logging.warning( + f"Received {chunk_count} chunks but no content. Debug info: {debug_info}" + ) + if last_chunk is not None: + # Try to extract any content from the last chunk + if isinstance(last_chunk, ModelResponse): + choices = getattr(last_chunk, "choices", []) + if choices and len(choices) > 0: + choice = choices[0] + + # Try to get content from message + message = getattr(choice, "message", None) + if message is not None and getattr( + message, "content", None + ): + full_response = getattr(message, "content") + logging.info( + f"Extracted content from last chunk message: {full_response}" + ) + + # Try to get content from text (some models use this) + elif getattr(choice, "text", None): + full_response = getattr(choice, "text") + logging.info( + f"Extracted text from last chunk: {full_response}" + ) + + # --- 6) If still empty, use a default response + if not full_response.strip(): + logging.warning("Using default response as fallback") + full_response = "I apologize, but I couldn't generate a proper response. Please try again or rephrase your request." + + # --- 7) Check for tool calls in the final response + if isinstance(last_chunk, ModelResponse): + choices = getattr(last_chunk, "choices", []) + if choices and len(choices) > 0: + choice = choices[0] + message = getattr(choice, "message", None) + if message is not None: + tool_calls = getattr(message, "tool_calls", []) + tool_result = self._handle_tool_call( + tool_calls, available_functions + ) + if tool_result is not None: + return tool_result + + # --- 8) Emit completion event and return response + self._handle_emit_call_events(full_response, LLMCallType.LLM_CALL) + return full_response + + except Exception as e: + logging.error( + f"Error in streaming response: {str(e)}, Debug info: {debug_info}" + ) + # If we have any response content, return it instead of failing + if full_response.strip(): + logging.warning(f"Returning partial response despite error: {str(e)}") + self._handle_emit_call_events(full_response, LLMCallType.LLM_CALL) + return full_response + + # Try non-streaming as fallback + try: + logging.warning("Falling back to non-streaming after error") + non_streaming_params = params.copy() + non_streaming_params["stream"] = False + return self._handle_non_streaming_response( + non_streaming_params, available_functions + ) + except Exception as fallback_error: + logging.error( + f"Fallback to non-streaming also failed: {str(fallback_error)}" + ) + # Return a default response as last resort + default_response = "I apologize, but I couldn't generate a proper response. Please try again or rephrase your request." + self._handle_emit_call_events(default_response, LLMCallType.LLM_CALL) + return default_response + + def _handle_non_streaming_response( + self, + params: Dict[str, Any], + available_functions: Optional[Dict[str, Any]] = None, + ) -> str: + """Handle a non-streaming response from the LLM. 
+ + Args: + params: Parameters for the completion call + available_functions: Dict of available functions + + Returns: + str: The response text + """ + # --- 1) Make the completion call + response = litellm.completion(**params) + + # --- 2) Extract response message and content + response_message = cast(Choices, cast(ModelResponse, response).choices)[ + 0 + ].message + text_response = response_message.content or "" + + # --- 3) Check for tool calls + tool_calls = getattr(response_message, "tool_calls", []) + + # --- 4) Handle tool calls if present + tool_result = self._handle_tool_call(tool_calls, available_functions) + if tool_result is not None: + return tool_result + + # --- 5) Emit completion event and return response + self._handle_emit_call_events(text_response, LLMCallType.LLM_CALL) + return text_response + + def _handle_tool_call( + self, + tool_calls: List[Any], + available_functions: Optional[Dict[str, Any]] = None, + ) -> Optional[str]: + """Handle a tool call from the LLM. + + Args: + tool_calls: List of tool calls from the LLM + available_functions: Dict of available functions + + Returns: + Optional[str]: The result of the tool call, or None if no tool call was made + """ + # --- 1) Validate tool calls and available functions + if not tool_calls or not available_functions: + return None + + # --- 2) Extract function name from first tool call + tool_call = tool_calls[0] + function_name = tool_call.function.name + function_args = {} # Initialize to empty dict to avoid unbound variable + + # --- 3) Check if function is available + if function_name in available_functions: + try: + # --- 3.1) Parse function arguments + function_args = json.loads(tool_call.function.arguments) + fn = available_functions[function_name] + + # --- 3.2) Execute function + result = fn(**function_args) + + # --- 3.3) Emit success event + self._handle_emit_call_events(result, LLMCallType.TOOL_CALL) + return result + except Exception as e: + # --- 3.4) Handle execution errors + fn = available_functions.get( + function_name, lambda: None + ) # Ensure fn is always a callable + logging.error(f"Error executing function '{function_name}': {e}") + crewai_event_bus.emit( + self, + event=ToolExecutionErrorEvent( + tool_name=function_name, + tool_args=function_args, + tool_class=fn, + error=str(e), + ), + ) + crewai_event_bus.emit( + self, + event=LLMCallFailedEvent(error=f"Tool execution error: {str(e)}"), + ) + return None + def call( self, messages: Union[str, List[Dict[str, str]]], @@ -230,22 +570,8 @@ class LLM: TypeError: If messages format is invalid ValueError: If response format is not supported LLMContextLengthExceededException: If input exceeds model's context limit - - Examples: - # Example 1: Simple string input - >>> response = llm.call("Return the name of a random city.") - >>> print(response) - "Paris" - - # Example 2: Message list with system and user messages - >>> messages = [ - ... {"role": "system", "content": "You are a geography expert"}, - ... {"role": "user", "content": "What is France's capital?"} - ... ] - >>> response = llm.call(messages) - >>> print(response) - "The capital of France is Paris." """ + # --- 1) Emit call started event crewai_event_bus.emit( self, event=LLMCallStartedEvent( @@ -255,127 +581,36 @@ class LLM: available_functions=available_functions, ), ) - # Validate parameters before proceeding with the call. 
+ + # --- 2) Validate parameters before proceeding with the call self._validate_call_params() + # --- 3) Convert string messages to proper format if needed if isinstance(messages, str): messages = [{"role": "user", "content": messages}] - # For O1 models, system messages are not supported. - # Convert any system messages into assistant messages. + # --- 4) Handle O1 model special case (system messages not supported) if "o1" in self.model.lower(): for message in messages: if message.get("role") == "system": message["role"] = "assistant" + # --- 5) Set up callbacks if provided with suppress_warnings(): if callbacks and len(callbacks) > 0: self.set_callbacks(callbacks) try: - # --- 1) Format messages according to provider requirements - formatted_messages = self._format_messages_for_provider(messages) - - # --- 2) Prepare the parameters for the completion call - params = { - "model": self.model, - "messages": formatted_messages, - "timeout": self.timeout, - "temperature": self.temperature, - "top_p": self.top_p, - "n": self.n, - "stop": self.stop, - "max_tokens": self.max_tokens or self.max_completion_tokens, - "presence_penalty": self.presence_penalty, - "frequency_penalty": self.frequency_penalty, - "logit_bias": self.logit_bias, - "response_format": self.response_format, - "seed": self.seed, - "logprobs": self.logprobs, - "top_logprobs": self.top_logprobs, - "api_base": self.api_base, - "base_url": self.base_url, - "api_version": self.api_version, - "api_key": self.api_key, - "stream": False, - "tools": tools, - "reasoning_effort": self.reasoning_effort, - **self.additional_params, - } - - # Remove None values from params - params = {k: v for k, v in params.items() if v is not None} - - # --- 2) Make the completion call - response = litellm.completion(**params) - response_message = cast(Choices, cast(ModelResponse, response).choices)[ - 0 - ].message - text_response = response_message.content or "" - tool_calls = getattr(response_message, "tool_calls", []) - - # --- 3) Handle callbacks with usage info - if callbacks and len(callbacks) > 0: - for callback in callbacks: - if hasattr(callback, "log_success_event"): - usage_info = getattr(response, "usage", None) - if usage_info: - callback.log_success_event( - kwargs=params, - response_obj={"usage": usage_info}, - start_time=0, - end_time=0, - ) - - # --- 4) If no tool calls, return the text response - if not tool_calls or not available_functions: - self._handle_emit_call_events(text_response, LLMCallType.LLM_CALL) - return text_response - - # --- 5) Handle the tool call - tool_call = tool_calls[0] - function_name = tool_call.function.name - - if function_name in available_functions: - try: - function_args = json.loads(tool_call.function.arguments) - except json.JSONDecodeError as e: - logging.warning(f"Failed to parse function arguments: {e}") - return text_response - - fn = available_functions[function_name] - try: - # Call the actual tool function - result = fn(**function_args) - self._handle_emit_call_events(result, LLMCallType.TOOL_CALL) - return result - - except Exception as e: - logging.error( - f"Error executing function '{function_name}': {e}" - ) - crewai_event_bus.emit( - self, - event=ToolExecutionErrorEvent( - tool_name=function_name, - tool_args=function_args, - tool_class=fn, - error=str(e), - ), - ) - crewai_event_bus.emit( - self, - event=LLMCallFailedEvent( - error=f"Tool execution error: {str(e)}" - ), - ) - return text_response + # --- 6) Prepare parameters for the completion call + params = 
self._prepare_completion_params(messages, tools) + # --- 7) Make the completion call and handle response + if self.stream: + return self._handle_streaming_response(params, available_functions) else: - logging.warning( - f"Tool call requested unknown function '{function_name}'" + return self._handle_non_streaming_response( + params, available_functions ) - return text_response except Exception as e: crewai_event_bus.emit( @@ -426,6 +661,20 @@ class LLM: "Invalid message format. Each message must be a dict with 'role' and 'content' keys" ) + # Handle O1 models specially + if "o1" in self.model.lower(): + formatted_messages = [] + for msg in messages: + # Convert system messages to assistant messages + if msg["role"] == "system": + formatted_messages.append( + {"role": "assistant", "content": msg["content"]} + ) + else: + formatted_messages.append(msg) + return formatted_messages + + # Handle Anthropic models if not self.is_anthropic: return messages diff --git a/src/crewai/utilities/events/__init__.py b/src/crewai/utilities/events/__init__.py index aa4a24ac5..264f0ac5e 100644 --- a/src/crewai/utilities/events/__init__.py +++ b/src/crewai/utilities/events/__init__.py @@ -14,7 +14,12 @@ from .agent_events import ( AgentExecutionCompletedEvent, AgentExecutionErrorEvent, ) -from .task_events import TaskStartedEvent, TaskCompletedEvent, TaskFailedEvent, TaskEvaluationEvent +from .task_events import ( + TaskStartedEvent, + TaskCompletedEvent, + TaskFailedEvent, + TaskEvaluationEvent, +) from .flow_events import ( FlowCreatedEvent, FlowStartedEvent, @@ -34,7 +39,13 @@ from .tool_usage_events import ( ToolUsageEvent, ToolValidateInputErrorEvent, ) -from .llm_events import LLMCallCompletedEvent, LLMCallFailedEvent, LLMCallStartedEvent +from .llm_events import ( + LLMCallCompletedEvent, + LLMCallFailedEvent, + LLMCallStartedEvent, + LLMCallType, + LLMStreamChunkEvent, +) # events from .event_listener import EventListener diff --git a/src/crewai/utilities/events/event_listener.py b/src/crewai/utilities/events/event_listener.py index d853a5f7c..5b18b79ab 100644 --- a/src/crewai/utilities/events/event_listener.py +++ b/src/crewai/utilities/events/event_listener.py @@ -11,6 +11,7 @@ from crewai.utilities.events.llm_events import ( LLMCallCompletedEvent, LLMCallFailedEvent, LLMCallStartedEvent, + LLMStreamChunkEvent, ) from .agent_events import AgentExecutionCompletedEvent, AgentExecutionStartedEvent @@ -280,7 +281,14 @@ class EventListener(BaseEventListener): @crewai_event_bus.on(LLMCallFailedEvent) def on_llm_call_failed(source, event: LLMCallFailedEvent): self.logger.log( - f"❌ LLM Call Failed: '{event.error}'", + f"❌ LLM call failed: {event.error}", + event.timestamp, + ) + + @crewai_event_bus.on(LLMStreamChunkEvent) + def on_llm_stream_chunk(source, event: LLMStreamChunkEvent): + self.logger.log( + f"📝 LLM stream chunk received", event.timestamp, ) diff --git a/src/crewai/utilities/events/llm_events.py b/src/crewai/utilities/events/llm_events.py index 8c2554a21..988b6f945 100644 --- a/src/crewai/utilities/events/llm_events.py +++ b/src/crewai/utilities/events/llm_events.py @@ -34,3 +34,10 @@ class LLMCallFailedEvent(CrewEvent): error: str type: str = "llm_call_failed" + + +class LLMStreamChunkEvent(CrewEvent): + """Event emitted when a streaming chunk is received""" + + type: str = "llm_stream_chunk" + chunk: str
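
Usage sketch (editor's addition, not part of the patch): the diff above adds a `stream` flag to `LLM` and emits an `LLMStreamChunkEvent` for every chunk, so a caller can subscribe to chunks while still receiving the full response string from `call()`. The snippet below is a minimal illustration only; the import path for `crewai_event_bus` and the model string are assumptions not confirmed by this diff.

    from crewai.llm import LLM
    from crewai.utilities.events.crewai_event_bus import crewai_event_bus  # assumed path
    from crewai.utilities.events.llm_events import LLMStreamChunkEvent

    # Register a handler that runs for every streamed chunk (same decorator
    # style as the listener added in event_listener.py above)
    @crewai_event_bus.on(LLMStreamChunkEvent)
    def on_chunk(source, event: LLMStreamChunkEvent):
        print(event.chunk, end="", flush=True)

    # stream=True routes call() through _handle_streaming_response()
    llm = LLM(model="gpt-4o-mini", stream=True)  # placeholder model name
    full_text = llm.call("Summarize what streaming changes in this patch.")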