properly return tool call result

This commit is contained in:
Brandon Hancock
2024-12-30 13:32:53 -05:00
parent 5da6d36dd9
commit bcd838a2ff
3 changed files with 67 additions and 55 deletions

View File

@@ -5,12 +5,13 @@ import sys
import threading
import warnings
from contextlib import contextmanager
from typing import Any, Dict, List, Optional, Union
from typing import Any, Dict, List, Optional, Union, cast
# Load environment variables from .env file
import litellm
from dotenv import load_dotenv
from litellm import get_supported_openai_params
from litellm import Choices, get_supported_openai_params
from litellm.types.utils import ModelResponse
from crewai.utilities.exceptions.context_window_exceeding_exception import (
LLMContextLengthExceededException,
@@ -209,24 +210,23 @@ class LLM:
High-level call method that:
1) Calls litellm.completion
2) Checks for function/tool calls
3) If tool calls found:
a) executes each function
b) appends their output as tool messages
c) calls litellm.completion again with the updated messages
4) Returns the final text response
3) If a tool call is found:
a) executes the function
b) returns the result
4) If no tool call, returns the text response
:param messages: The conversation messages
:param tools: Optional list of function schemas for function calling
:param callbacks: Optional list of callbacks
:param available_functions: A dictionary mapping function_name -> actual Python function
:return: Final text response from the LLM
:return: Final text response from the LLM or the tool result
"""
with suppress_warnings():
if callbacks:
self.set_callbacks(callbacks)
try:
# --- 1) Make first completion call
# --- 1) Make the completion call
params = {
"model": self.model,
"messages": messages,
@@ -250,68 +250,61 @@ class LLM:
"tools": tools, # pass the tool schema
}
# remove None values
# Remove None values
params = {k: v for k, v in params.items() if v is not None}
print(f"Params: {params}")
response = litellm.completion(**params)
response_message = response.choices[0].message
response_message = cast(Choices, cast(ModelResponse, response).choices)[
0
].message
text_response = response_message.content or ""
tool_calls = getattr(response_message, "tool_calls", [])
# --- 2) If no tool calls, we can just return
# --- 2) If no tool calls, return the text response
if not tool_calls or not available_functions:
return text_response
# --- 3) We have tool calls and a dictionary of available functions
# run them, append output to messages
for tool_call in tool_calls:
function_name = tool_call.function.name
if function_name in available_functions:
# parse arguments
function_args = {}
try:
function_args = json.loads(tool_call.function.arguments)
except Exception as e:
logging.warning(f"Failed to parse function arguments: {e}")
# --- 3) Handle the tool call
tool_call = tool_calls[0]
function_name = tool_call.function.name
fn = available_functions[function_name]
# call the actual tool function
if function_name in available_functions:
# Parse arguments
try:
function_args = json.loads(tool_call.function.arguments)
except json.JSONDecodeError as e:
logging.warning(f"Failed to parse function arguments: {e}")
return text_response # Fallback to text response
fn = available_functions[function_name]
try:
# Call the actual tool function
result = fn(**function_args)
# append the "tool" response to messages
messages.append(
{
"tool_call_id": tool_call.id,
"role": "function",
"name": function_name,
"content": str(result),
}
)
else:
logging.warning(
f"Tool call requested unknown function '{function_name}'"
)
print(f"Result from function '{function_name}': {result}")
# --- 4) Make a second call so the LLM can incorporate the tool results
second_params = dict(params) # copy the same params
second_params["messages"] = messages
# We'll remove "tools" from second call, or keep it if you want more calls
# but typically you keep it in case it wants additional calls
second_response = litellm.completion(**second_params)
second_msg = second_response.choices[0].message
final_response = second_msg.content or ""
# Return the result directly
return result
return final_response
except Exception as e:
logging.error(
f"Error executing function '{function_name}': {e}"
)
return text_response # Fallback to text response
else:
logging.warning(
f"Tool call requested unknown function '{function_name}'"
)
return text_response # Fallback to text response
except Exception as e:
# check if context length was exceeded, otherwise log
# Check if context length was exceeded, otherwise log
if not LLMContextLengthExceededException(
str(e)
)._is_context_limit_error(str(e)):
logging.error(f"LiteLLM call failed: {str(e)}")
# re-raise
# Re-raise the exception
raise
def supports_function_calling(self) -> bool: