Core loop should be working and is ready for testing.

This commit is contained in:
Brandon Hancock
2024-12-26 14:18:42 -05:00
parent 1c45f730c6
commit 2bf5b15f1e
7 changed files with 402 additions and 169 deletions

View File

@@ -1,3 +1,4 @@
import json
import logging
import os
import sys
@@ -7,7 +8,7 @@ from contextlib import contextmanager
from typing import Any, Dict, List, Optional, Union
import litellm
from litellm import get_supported_openai_params
from litellm import ModelResponse, get_supported_openai_params
from crewai.utilities.exceptions.context_window_exceeding_exception import (
LLMContextLengthExceededException,
@@ -21,6 +22,7 @@ class FilteredStream:
def write(self, s) -> int:
with self._lock:
# Filter out extraneous messages from LiteLLM
if (
"Give Feedback / Get Help: https://github.com/BerriAI/litellm/issues/new"
in s
@@ -80,11 +82,9 @@ def suppress_warnings():
old_stderr = sys.stderr
sys.stdout = FilteredStream(old_stdout)
sys.stderr = FilteredStream(old_stderr)
try:
yield
finally:
# Restore stdout and stderr
sys.stdout = old_stdout
sys.stderr = old_stderr
@@ -135,8 +135,10 @@ class LLM:
self.context_window_size = 0
self.kwargs = kwargs
# For safety, we disable passing init params to next calls
litellm.drop_params = True
litellm.set_verbose = False
self.set_callbacks(callbacks)
self.set_env_callbacks()
@@ -173,8 +175,6 @@ class LLM:
Create an LLM instance from a dict.
We assume the dict has all relevant keys that match what's in the constructor.
"""
# We can pop off fields we know, then pass the rest into **kwargs
# so that any leftover keys still get passed into the LLM constructor.
known_fields = {}
known_fields["model"] = data.pop("model", None)
known_fields["timeout"] = data.pop("timeout", None)
@@ -196,15 +196,37 @@ class LLM:
known_fields["api_key"] = data.pop("api_key", None)
known_fields["callbacks"] = data.pop("callbacks", None)
# leftover keys go into kwargs:
return cls(**known_fields, **data)
def call(self, messages: List[Dict[str, str]], callbacks: List[Any] = []) -> str:
def call(
self,
messages: List[Dict[str, str]],
tools: Optional[List[dict]] = None,
callbacks: Optional[List[Any]] = None,
available_functions: Optional[Dict[str, Any]] = None,
) -> str:
"""
High-level call method that:
1) Calls litellm.completion
2) Checks for function/tool calls
3) If tool calls found:
a) executes each function
b) appends their output as tool messages
c) calls litellm.completion again with the updated messages
4) Returns the final text response
:param messages: The conversation messages
:param tools: Optional list of function schemas for function calling
:param callbacks: Optional list of callbacks
:param available_functions: A dictionary mapping function_name -> actual Python function
:return: Final text response from the LLM
"""
with suppress_warnings():
if callbacks and len(callbacks) > 0:
if callbacks:
self.set_callbacks(callbacks)
try:
# --- 1) Make first completion call
params = {
"model": self.model,
"messages": messages,
@@ -225,21 +247,71 @@ class LLM:
"api_version": self.api_version,
"api_key": self.api_key,
"stream": False,
"tools": tools, # pass the tool schema
**self.kwargs,
}
# Remove None values to avoid passing unnecessary parameters
# remove None values
params = {k: v for k, v in params.items() if v is not None}
response = litellm.completion(**params)
return response["choices"][0]["message"]["content"]
response_message = response.choices[0].message
text_response = response_message.content or ""
tool_calls = getattr(response_message, "tool_calls", [])
# --- 2) If no tool calls, we can just return
if not tool_calls or not available_functions:
return text_response
# --- 3) We have tool calls and a dictionary of available functions
# run them, append output to messages
for tool_call in tool_calls:
function_name = tool_call.function.name
if function_name in available_functions:
# parse arguments
function_args = {}
try:
function_args = json.loads(tool_call.function.arguments)
except Exception as e:
logging.warning(f"Failed to parse function arguments: {e}")
fn = available_functions[function_name]
# call the actual tool function
result = fn(**function_args)
# append the "tool" response to messages
messages.append(
{
"tool_call_id": tool_call.id,
"role": "tool",
"name": function_name,
"content": str(result),
}
)
else:
logging.warning(
f"Tool call requested unknown function '{function_name}'"
)
# --- 4) Make a second call so the LLM can incorporate the tool results
second_params = dict(params) # copy the same params
second_params["messages"] = messages
# We'll remove "tools" from second call, or keep it if you want more calls
# but typically you keep it in case it wants additional calls
second_response = litellm.completion(**second_params)
second_msg = second_response.choices[0].message
final_response = second_msg.content or ""
return final_response
except Exception as e:
# check if context length was exceeded, otherwise log
if not LLMContextLengthExceededException(
str(e)
)._is_context_limit_error(str(e)):
logging.error(f"LiteLLM call failed: {str(e)}")
raise # Re-raise the exception after logging
# re-raise
raise
def supports_function_calling(self) -> bool:
try:
@@ -258,7 +330,10 @@ class LLM:
return False
def get_context_window_size(self) -> int:
# Only using 75% of the context window size to avoid cutting the message in the middle
"""
Returns the context window size, using 75% of the maximum to avoid
cutting off messages mid-thread.
"""
if self.context_window_size != 0:
return self.context_window_size
@@ -271,6 +346,10 @@ class LLM:
return self.context_window_size
def set_callbacks(self, callbacks: List[Any]):
"""
Attempt to keep a single set of callbacks in litellm by removing old
duplicates and adding new ones.
"""
callback_types = [type(callback) for callback in callbacks]
for callback in litellm.success_callback[:]:
if type(callback) in callback_types:
@@ -285,34 +364,19 @@ class LLM:
def set_env_callbacks(self):
"""
Sets the success and failure callbacks for the LiteLLM library from environment variables.
This method reads the `LITELLM_SUCCESS_CALLBACKS` and `LITELLM_FAILURE_CALLBACKS`
environment variables, which should contain comma-separated lists of callback names.
It then assigns these lists to `litellm.success_callback` and `litellm.failure_callback`,
respectively.
If the environment variables are not set or are empty, the corresponding callback lists
will be set to empty lists.
Example:
LITELLM_SUCCESS_CALLBACKS="langfuse,langsmith"
LITELLM_FAILURE_CALLBACKS="langfuse"
This will set `litellm.success_callback` to ["langfuse", "langsmith"] and
`litellm.failure_callback` to ["langfuse"].
"""
success_callbacks_str = os.environ.get("LITELLM_SUCCESS_CALLBACKS", "")
success_callbacks = []
if success_callbacks_str:
success_callbacks = [
callback.strip() for callback in success_callbacks_str.split(",")
cb.strip() for cb in success_callbacks_str.split(",") if cb.strip()
]
failure_callbacks_str = os.environ.get("LITELLM_FAILURE_CALLBACKS", "")
failure_callbacks = []
if failure_callbacks_str:
failure_callbacks = [
callback.strip() for callback in failure_callbacks_str.split(",")
cb.strip() for cb in failure_callbacks_str.split(",") if cb.strip()
]
litellm.success_callback = success_callbacks