properly return tool call result

This commit is contained in:
Brandon Hancock
2024-12-30 13:32:53 -05:00
parent 5da6d36dd9
commit bcd838a2ff
3 changed files with 67 additions and 55 deletions

View File

@@ -154,6 +154,8 @@ def run_crew_tool(messages: List[Dict[str, str]], **kwargs: Any) -> str:
2) Passes the LLM-provided kwargs as CLI overrides (e.g. --key=value). 2) Passes the LLM-provided kwargs as CLI overrides (e.g. --key=value).
3) Also takes in messages from the main chat loop and passes them to the command. 3) Also takes in messages from the main chat loop and passes them to the command.
""" """
import json
import re
import subprocess import subprocess
command = ["uv", "run", "run_crew"] command = ["uv", "run", "run_crew"]
@@ -169,9 +171,28 @@ def run_crew_tool(messages: List[Dict[str, str]], **kwargs: Any) -> str:
try: try:
# Capture stdout so we can return it to the LLM # Capture stdout so we can return it to the LLM
result = subprocess.run(command, text=True, check=True) print(f"Command: {command}")
result = subprocess.run(command, text=True, capture_output=True, check=True)
print(f"Result: {result}")
stdout_str = result.stdout.strip() stdout_str = result.stdout.strip()
return stdout_str if stdout_str else "No output from run_crew command." print(f"Stdout: {stdout_str}")
# Remove ANSI escape sequences
ansi_escape = re.compile(r"\x1B\[[0-?]*[ -/]*[@-~]")
stdout_clean = ansi_escape.sub("", stdout_str)
# Find the last occurrence of '## Final Answer:'
final_answer_index = stdout_clean.rfind("## Final Answer:")
if final_answer_index != -1:
# Extract everything after '## Final Answer:'
final_output = stdout_clean[
final_answer_index + len("## Final Answer:") :
].strip()
print(f"Final output: {final_output}")
return final_output
else:
# If '## Final Answer:' is not found, return the cleaned stdout
return stdout_clean if stdout_clean else "No output from run_crew command."
except subprocess.CalledProcessError as e: except subprocess.CalledProcessError as e:
return ( return (
f"Error: Command failed with exit code {e.returncode}\n" f"Error: Command failed with exit code {e.returncode}\n"

View File

@@ -5,12 +5,13 @@ import sys
import threading import threading
import warnings import warnings
from contextlib import contextmanager from contextlib import contextmanager
from typing import Any, Dict, List, Optional, Union from typing import Any, Dict, List, Optional, Union, cast
# Load environment variables from .env file # Load environment variables from .env file
import litellm import litellm
from dotenv import load_dotenv from dotenv import load_dotenv
from litellm import get_supported_openai_params from litellm import Choices, get_supported_openai_params
from litellm.types.utils import ModelResponse
from crewai.utilities.exceptions.context_window_exceeding_exception import ( from crewai.utilities.exceptions.context_window_exceeding_exception import (
LLMContextLengthExceededException, LLMContextLengthExceededException,
@@ -209,24 +210,23 @@ class LLM:
High-level call method that: High-level call method that:
1) Calls litellm.completion 1) Calls litellm.completion
2) Checks for function/tool calls 2) Checks for function/tool calls
3) If tool calls found: 3) If a tool call is found:
a) executes each function a) executes the function
b) appends their output as tool messages b) returns the result
c) calls litellm.completion again with the updated messages 4) If no tool call, returns the text response
4) Returns the final text response
:param messages: The conversation messages :param messages: The conversation messages
:param tools: Optional list of function schemas for function calling :param tools: Optional list of function schemas for function calling
:param callbacks: Optional list of callbacks :param callbacks: Optional list of callbacks
:param available_functions: A dictionary mapping function_name -> actual Python function :param available_functions: A dictionary mapping function_name -> actual Python function
:return: Final text response from the LLM :return: Final text response from the LLM or the tool result
""" """
with suppress_warnings(): with suppress_warnings():
if callbacks: if callbacks:
self.set_callbacks(callbacks) self.set_callbacks(callbacks)
try: try:
# --- 1) Make first completion call # --- 1) Make the completion call
params = { params = {
"model": self.model, "model": self.model,
"messages": messages, "messages": messages,
@@ -250,68 +250,61 @@ class LLM:
"tools": tools, # pass the tool schema "tools": tools, # pass the tool schema
} }
# remove None values # Remove None values
params = {k: v for k, v in params.items() if v is not None} params = {k: v for k, v in params.items() if v is not None}
print(f"Params: {params}")
response = litellm.completion(**params) response = litellm.completion(**params)
response_message = response.choices[0].message response_message = cast(Choices, cast(ModelResponse, response).choices)[
0
].message
text_response = response_message.content or "" text_response = response_message.content or ""
tool_calls = getattr(response_message, "tool_calls", []) tool_calls = getattr(response_message, "tool_calls", [])
# --- 2) If no tool calls, we can just return # --- 2) If no tool calls, return the text response
if not tool_calls or not available_functions: if not tool_calls or not available_functions:
return text_response return text_response
# --- 3) We have tool calls and a dictionary of available functions # --- 3) Handle the tool call
# run them, append output to messages tool_call = tool_calls[0]
for tool_call in tool_calls: function_name = tool_call.function.name
function_name = tool_call.function.name
if function_name in available_functions:
# parse arguments
function_args = {}
try:
function_args = json.loads(tool_call.function.arguments)
except Exception as e:
logging.warning(f"Failed to parse function arguments: {e}")
fn = available_functions[function_name] if function_name in available_functions:
# call the actual tool function # Parse arguments
try:
function_args = json.loads(tool_call.function.arguments)
except json.JSONDecodeError as e:
logging.warning(f"Failed to parse function arguments: {e}")
return text_response # Fallback to text response
fn = available_functions[function_name]
try:
# Call the actual tool function
result = fn(**function_args) result = fn(**function_args)
# append the "tool" response to messages print(f"Result from function '{function_name}': {result}")
messages.append(
{
"tool_call_id": tool_call.id,
"role": "function",
"name": function_name,
"content": str(result),
}
)
else:
logging.warning(
f"Tool call requested unknown function '{function_name}'"
)
# --- 4) Make a second call so the LLM can incorporate the tool results # Return the result directly
second_params = dict(params) # copy the same params return result
second_params["messages"] = messages
# We'll remove "tools" from second call, or keep it if you want more calls
# but typically you keep it in case it wants additional calls
second_response = litellm.completion(**second_params)
second_msg = second_response.choices[0].message
final_response = second_msg.content or ""
return final_response except Exception as e:
logging.error(
f"Error executing function '{function_name}': {e}"
)
return text_response # Fallback to text response
else:
logging.warning(
f"Tool call requested unknown function '{function_name}'"
)
return text_response # Fallback to text response
except Exception as e: except Exception as e:
# check if context length was exceeded, otherwise log # Check if context length was exceeded, otherwise log
if not LLMContextLengthExceededException( if not LLMContextLengthExceededException(
str(e) str(e)
)._is_context_limit_error(str(e)): )._is_context_limit_error(str(e)):
logging.error(f"LiteLLM call failed: {str(e)}") logging.error(f"LiteLLM call failed: {str(e)}")
# re-raise # Re-raise the exception
raise raise
def supports_function_calling(self) -> bool: def supports_function_calling(self) -> bool:

View File

@@ -439,8 +439,6 @@ class Task(BaseModel):
f"\n\n{conversation_instruction}\n\n{conversation_history}" f"\n\n{conversation_instruction}\n\n{conversation_history}"
) )
print("UPDATED DESCRIPTION:", self.description)
def interpolate_only(self, input_string: str, inputs: Dict[str, Any]) -> str: def interpolate_only(self, input_string: str, inputs: Dict[str, Any]) -> str:
"""Interpolate placeholders (e.g., {key}) in a string while leaving JSON untouched.""" """Interpolate placeholders (e.g., {key}) in a string while leaving JSON untouched."""
escaped_string = input_string.replace("{", "{{").replace("}", "}}") escaped_string = input_string.replace("{", "{{").replace("}", "}}")