From e8707e15efe902e4488904d012415a9ea04e7360 Mon Sep 17 00:00:00 2001
From: Brandon Hancock
Date: Thu, 6 Mar 2025 11:58:10 -0500
Subject: [PATCH] improve logic

---
 src/crewai/llm.py | 106 +++++++++++++++++++++++++++++-----------------
 1 file changed, 66 insertions(+), 40 deletions(-)

diff --git a/src/crewai/llm.py b/src/crewai/llm.py
index f739e7203..b7f8f3dc9 100644
--- a/src/crewai/llm.py
+++ b/src/crewai/llm.py
@@ -439,6 +439,7 @@ class LLM:
             )
 
             # --- 7) Check for tool calls in the final response
+            tool_calls = None
             try:
                 if last_chunk:
                     choices = None
@@ -458,53 +459,30 @@ class LLM:
                             message = getattr(choice, "message")
 
                         if message:
-                            tool_calls = None
                             if isinstance(message, dict) and "tool_calls" in message:
                                 tool_calls = message["tool_calls"]
                             elif hasattr(message, "tool_calls"):
                                 tool_calls = getattr(message, "tool_calls")
-
-                            if tool_calls:
-                                tool_result = self._handle_tool_call(
-                                    tool_calls, available_functions
-                                )
-                                if tool_result is not None:
-                                    return tool_result
             except Exception as e:
                 logging.debug(f"Error checking for tool calls: {e}")
 
-            # --- 8) Log token usage if available in streaming mode
-            # Safely handle callbacks with usage info
-            if callbacks and len(callbacks) > 0:
-                for callback in callbacks:
-                    if hasattr(callback, "log_success_event"):
-                        # Use the usage_info we've been tracking
-                        if not usage_info:
-                            # Try to get usage from the last chunk if we haven't already
-                            try:
-                                if last_chunk:
-                                    if (
-                                        isinstance(last_chunk, dict)
-                                        and "usage" in last_chunk
-                                    ):
-                                        usage_info = last_chunk["usage"]
-                                    elif hasattr(last_chunk, "usage"):
-                                        if not isinstance(
-                                            getattr(last_chunk, "usage"), type
-                                        ):
-                                            usage_info = getattr(last_chunk, "usage")
-                            except Exception as e:
-                                logging.debug(f"Error extracting usage info: {e}")
+            # --- 8) If no tool calls or no available functions, return the text response directly
+            if not tool_calls or not available_functions:
+                # Log token usage if available in streaming mode
+                self._handle_streaming_callbacks(callbacks, usage_info, last_chunk)
+                # Emit completion event and return response
+                self._handle_emit_call_events(full_response, LLMCallType.LLM_CALL)
+                return full_response
 
-                        if usage_info:
-                            callback.log_success_event(
-                                kwargs=params,
-                                response_obj={"usage": usage_info},
-                                start_time=0,
-                                end_time=0,
-                            )
+            # --- 9) Handle tool calls if present
+            tool_result = self._handle_tool_call(tool_calls, available_functions)
+            if tool_result is not None:
+                return tool_result
 
-            # --- 9) Emit completion event and return response
+            # --- 10) Log token usage if available in streaming mode
+            self._handle_streaming_callbacks(callbacks, usage_info, last_chunk)
+
+            # --- 11) Emit completion event and return response
             self._handle_emit_call_events(full_response, LLMCallType.LLM_CALL)
             return full_response
 
@@ -522,6 +500,48 @@ class LLM:
             )
             raise Exception(f"Failed to get streaming response: {str(e)}")
 
+    def _handle_streaming_callbacks(
+        self,
+        callbacks: Optional[List[Any]],
+        usage_info: Optional[Dict[str, Any]],
+        last_chunk: Optional[Any],
+    ) -> None:
+        """Handle callbacks with usage info for streaming responses.
+
+        Args:
+            callbacks: Optional list of callback functions
+            usage_info: Usage information collected during streaming
+            last_chunk: The last chunk received from the streaming response
+        """
+        if callbacks and len(callbacks) > 0:
+            for callback in callbacks:
+                if hasattr(callback, "log_success_event"):
+                    # Use the usage_info we've been tracking
+                    if not usage_info:
+                        # Try to get usage from the last chunk if we haven't already
+                        try:
+                            if last_chunk:
+                                if (
+                                    isinstance(last_chunk, dict)
+                                    and "usage" in last_chunk
+                                ):
+                                    usage_info = last_chunk["usage"]
+                                elif hasattr(last_chunk, "usage"):
+                                    if not isinstance(
+                                        getattr(last_chunk, "usage"), type
+                                    ):
+                                        usage_info = getattr(last_chunk, "usage")
+                        except Exception as e:
+                            logging.debug(f"Error extracting usage info: {e}")
+
+                    if usage_info:
+                        callback.log_success_event(
+                            kwargs={},  # We don't have the original params here
+                            response_obj={"usage": usage_info},
+                            start_time=0,
+                            end_time=0,
+                        )
+
     def _handle_non_streaming_response(
         self,
         params: Dict[str, Any],
@@ -532,6 +552,7 @@ class LLM:
 
         Args:
             params: Parameters for the completion call
+            callbacks: Optional list of callback functions
            available_functions: Dict of available functions
 
         Returns:
@@ -562,12 +583,17 @@ class LLM:
         # --- 4) Check for tool calls
         tool_calls = getattr(response_message, "tool_calls", [])
 
-        # --- 5) Handle tool calls if present
+        # --- 5) If no tool calls or no available functions, return the text response directly
+        if not tool_calls or not available_functions:
+            self._handle_emit_call_events(text_response, LLMCallType.LLM_CALL)
+            return text_response
+
+        # --- 6) Handle tool calls if present
         tool_result = self._handle_tool_call(tool_calls, available_functions)
         if tool_result is not None:
             return tool_result
 
-        # --- 6) Emit completion event and return response
+        # --- 7) If tool call handling didn't return a result, emit completion event and return text response
         self._handle_emit_call_events(text_response, LLMCallType.LLM_CALL)
         return text_response
 
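For reference, a minimal sketch of how the extracted _handle_streaming_callbacks helper behaves in isolation once this patch is applied. The UsageLogger class, the model string, and the fake chunk are illustrative stand-ins rather than part of crewai, and calling the private helper directly is only for demonstration:

    from crewai.llm import LLM

    class UsageLogger:
        # Illustrative stand-in for a litellm-style callback exposing log_success_event
        def log_success_event(self, kwargs, response_obj, start_time, end_time):
            print("token usage:", response_obj["usage"])

    llm = LLM(model="gpt-4o")  # any configured model string; no API call is made below

    # Simulate the final streaming chunk carrying usage info as a plain dict
    last_chunk = {"usage": {"prompt_tokens": 12, "completion_tokens": 34, "total_tokens": 46}}

    # With no usage_info tracked, the helper falls back to the chunk's "usage" field and
    # invokes log_success_event with kwargs={} (the original params are not available here).
    llm._handle_streaming_callbacks(callbacks=[UsageLogger()], usage_info=None, last_chunk=last_chunk)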