import json
import logging
import os
import sys
import threading
import warnings
from contextlib import contextmanager
from typing import Any, Dict, List, Optional, Union, cast

from dotenv import load_dotenv

with warnings.catch_warnings():
    warnings.simplefilter("ignore", UserWarning)
    import litellm
    from litellm import Choices, get_supported_openai_params
    from litellm.types.utils import ModelResponse


from crewai.utilities.exceptions.context_window_exceeding_exception import (
    LLMContextLengthExceededException,
)

load_dotenv()


class FilteredStream:
    def __init__(self, original_stream):
        self._original_stream = original_stream
        self._lock = threading.Lock()

    def write(self, s) -> int:
        with self._lock:
            # Filter out extraneous messages from LiteLLM
            if (
                "Give Feedback / Get Help: https://github.com/BerriAI/litellm/issues/new"
                in s
                or "LiteLLM.Info: If you need to debug this error, use `litellm.set_verbose=True`"
                in s
            ):
                return 0
            return self._original_stream.write(s)

    def flush(self):
        with self._lock:
            return self._original_stream.flush()


LLM_CONTEXT_WINDOW_SIZES = {
    # openai
    "gpt-4": 8192,
    "gpt-4o": 128000,
    "gpt-4o-mini": 128000,
    "gpt-4-turbo": 128000,
    "o1-preview": 128000,
    "o1-mini": 128000,
    # gemini
    "gemini-2.0-flash": 1048576,
    "gemini-1.5-pro": 2097152,
    "gemini-1.5-flash": 1048576,
    "gemini-1.5-flash-8b": 1048576,
    # deepseek
    "deepseek-chat": 128000,
    # groq
    "gemma2-9b-it": 8192,
    "gemma-7b-it": 8192,
    "llama3-groq-70b-8192-tool-use-preview": 8192,
    "llama3-groq-8b-8192-tool-use-preview": 8192,
    "llama-3.1-70b-versatile": 131072,
    "llama-3.1-8b-instant": 131072,
    "llama-3.2-1b-preview": 8192,
    "llama-3.2-3b-preview": 8192,
    "llama-3.2-11b-text-preview": 8192,
    "llama-3.2-90b-text-preview": 8192,
    "llama3-70b-8192": 8192,
    "llama3-8b-8192": 8192,
    "mixtral-8x7b-32768": 32768,
    "llama-3.3-70b-versatile": 128000,
    "llama-3.3-70b-instruct": 128000,
    # sambanova
    "Meta-Llama-3.3-70B-Instruct": 131072,
    "QwQ-32B-Preview": 8192,
    "Qwen2.5-72B-Instruct": 8192,
    "Qwen2.5-Coder-32B-Instruct": 8192,
    "Meta-Llama-3.1-405B-Instruct": 8192,
    "Meta-Llama-3.1-70B-Instruct": 131072,
    "Meta-Llama-3.1-8B-Instruct": 131072,
    "Llama-3.2-90B-Vision-Instruct": 16384,
    "Llama-3.2-11B-Vision-Instruct": 16384,
    "Meta-Llama-3.2-3B-Instruct": 4096,
    "Meta-Llama-3.2-1B-Instruct": 16384,
}

DEFAULT_CONTEXT_WINDOW_SIZE = 8192
CONTEXT_WINDOW_USAGE_RATIO = 0.75
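
# Worked example (illustrative comment only): with CONTEXT_WINDOW_USAGE_RATIO = 0.75,
# a model that matches the "gpt-4o" key (128000-token window) yields
# int(128000 * 0.75) = 96000 usable tokens from get_context_window_size(), keeping
# roughly 25% headroom so a conversation is not cut off mid-thread.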


@contextmanager
def suppress_warnings():
    with warnings.catch_warnings():
        warnings.filterwarnings("ignore")
        warnings.filterwarnings(
            "ignore", message="open_text is deprecated*", category=DeprecationWarning
        )

        # Redirect stdout and stderr
        old_stdout = sys.stdout
        old_stderr = sys.stderr
        sys.stdout = FilteredStream(old_stdout)
        sys.stderr = FilteredStream(old_stderr)
        try:
            yield
        finally:
            sys.stdout = old_stdout
            sys.stderr = old_stderr
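
# Minimal usage sketch (illustrative only): any noisy LiteLLM interaction can be wrapped
# in suppress_warnings() so library warnings and the informational messages filtered by
# FilteredStream never reach the console, e.g.:
#
#     with suppress_warnings():
#         litellm.completion(model="gpt-4o-mini", messages=[{"role": "user", "content": "Hi"}])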


class LLM:
    def __init__(
        self,
        model: str,
        timeout: Optional[Union[float, int]] = None,
        temperature: Optional[float] = None,
        top_p: Optional[float] = None,
        n: Optional[int] = None,
        stop: Optional[Union[str, List[str]]] = None,
        max_completion_tokens: Optional[int] = None,
        max_tokens: Optional[int] = None,
        presence_penalty: Optional[float] = None,
        frequency_penalty: Optional[float] = None,
        logit_bias: Optional[Dict[int, float]] = None,
        response_format: Optional[Dict[str, Any]] = None,
        seed: Optional[int] = None,
        logprobs: Optional[int] = None,
        top_logprobs: Optional[int] = None,
        base_url: Optional[str] = None,
        api_base: Optional[str] = None,
        api_version: Optional[str] = None,
        api_key: Optional[str] = None,
        callbacks: List[Any] = [],
        **kwargs,
    ):
        self.model = model
        self.timeout = timeout
        self.temperature = temperature
        self.top_p = top_p
        self.n = n
        self.max_completion_tokens = max_completion_tokens
        self.max_tokens = max_tokens
        self.presence_penalty = presence_penalty
        self.frequency_penalty = frequency_penalty
        self.logit_bias = logit_bias
        self.response_format = response_format
        self.seed = seed
        self.logprobs = logprobs
        self.top_logprobs = top_logprobs
        self.base_url = base_url
        self.api_base = api_base
        self.api_version = api_version
        self.api_key = api_key
        self.callbacks = callbacks
        self.context_window_size = 0
        self.additional_params = kwargs

        litellm.drop_params = True

        # Normalize self.stop to always be a List[str]
        if stop is None:
            self.stop: List[str] = []
        elif isinstance(stop, str):
            self.stop = [stop]
        else:
            self.stop = stop

        self.set_callbacks(callbacks)
        self.set_env_callbacks()
    def call(
        self,
        messages: Union[str, List[Dict[str, str]]],
        tools: Optional[List[dict]] = None,
        callbacks: Optional[List[Any]] = None,
        available_functions: Optional[Dict[str, Any]] = None,
    ) -> str:
        """
        High-level LLM call method that:
          1) Accepts either a string or a list of messages
          2) Converts string input to the required message format
          3) Calls litellm.completion
          4) Handles function/tool calls if any
          5) Returns the final text response or tool result

        Parameters:
        - messages (Union[str, List[Dict[str, str]]]): The input messages for the LLM.
          - If a string is provided, it will be converted into a message list with a single entry.
          - If a list of dictionaries is provided, each dictionary should have 'role' and 'content' keys.
        - tools (Optional[List[dict]]): A list of tool schemas for function calling.
        - callbacks (Optional[List[Any]]): A list of callback functions to be executed.
        - available_functions (Optional[Dict[str, Any]]): A dictionary mapping function names to actual Python functions.

        Returns:
        - str: The final text response from the LLM or the result of a tool function call.

        Examples:
        ---------
        # Example 1: Using a string input
        response = llm.call("Return the name of a random city in the world.")
        print(response)

        # Example 2: Using a list of messages
        messages = [{"role": "user", "content": "What is the capital of France?"}]
        response = llm.call(messages)
        print(response)
        """
        if isinstance(messages, str):
            messages = [{"role": "user", "content": messages}]

        with suppress_warnings():
            if callbacks and len(callbacks) > 0:
                self.set_callbacks(callbacks)

            try:
                # --- 1) Prepare the parameters for the completion call
                params = {
                    "model": self.model,
                    "messages": messages,
                    "timeout": self.timeout,
                    "temperature": self.temperature,
                    "top_p": self.top_p,
                    "n": self.n,
                    "stop": self.stop,
                    "max_tokens": self.max_tokens or self.max_completion_tokens,
                    "presence_penalty": self.presence_penalty,
                    "frequency_penalty": self.frequency_penalty,
                    "logit_bias": self.logit_bias,
                    "response_format": self.response_format,
                    "seed": self.seed,
                    "logprobs": self.logprobs,
                    "top_logprobs": self.top_logprobs,
                    "api_base": self.api_base,
                    "base_url": self.base_url,
                    "api_version": self.api_version,
                    "api_key": self.api_key,
                    "stream": False,
                    "tools": tools,
                    **self.additional_params,
                }

                # Remove None values from params
                params = {k: v for k, v in params.items() if v is not None}

                # --- 2) Make the completion call
                response = litellm.completion(**params)
                response_message = cast(Choices, cast(ModelResponse, response).choices)[
                    0
                ].message
                text_response = response_message.content or ""
                tool_calls = getattr(response_message, "tool_calls", [])

                # --- 3) Handle callbacks with usage info
                if callbacks and len(callbacks) > 0:
                    for callback in callbacks:
                        if hasattr(callback, "log_success_event"):
                            usage_info = getattr(response, "usage", None)
                            if usage_info:
                                callback.log_success_event(
                                    kwargs=params,
                                    response_obj={"usage": usage_info},
                                    start_time=0,
                                    end_time=0,
                                )

                # --- 4) If no tool calls, return the text response
                if not tool_calls or not available_functions:
                    return text_response

                # --- 5) Handle the tool call
                tool_call = tool_calls[0]
                function_name = tool_call.function.name

                if function_name in available_functions:
                    try:
                        function_args = json.loads(tool_call.function.arguments)
                    except json.JSONDecodeError as e:
                        logging.warning(f"Failed to parse function arguments: {e}")
                        return text_response

                    fn = available_functions[function_name]
                    try:
                        # Call the actual tool function
                        result = fn(**function_args)
                        return result

                    except Exception as e:
                        logging.error(
                            f"Error executing function '{function_name}': {e}"
                        )
                        return text_response

                else:
                    logging.warning(
                        f"Tool call requested unknown function '{function_name}'"
                    )
                    return text_response

            except Exception as e:
                if not LLMContextLengthExceededException(
                    str(e)
                )._is_context_limit_error(str(e)):
                    logging.error(f"LiteLLM call failed: {str(e)}")
                raise
    def supports_function_calling(self) -> bool:
        try:
            params = get_supported_openai_params(model=self.model)
            return "response_format" in params
        except Exception as e:
            logging.error(f"Failed to get supported params: {str(e)}")
            return False

    def supports_stop_words(self) -> bool:
        try:
            params = get_supported_openai_params(model=self.model)
            return "stop" in params
        except Exception as e:
            logging.error(f"Failed to get supported params: {str(e)}")
            return False
    def get_context_window_size(self) -> int:
        """
        Returns the context window size, using 75% of the maximum to avoid
        cutting off messages mid-thread.
        """
        if self.context_window_size != 0:
            return self.context_window_size

        self.context_window_size = int(
            DEFAULT_CONTEXT_WINDOW_SIZE * CONTEXT_WINDOW_USAGE_RATIO
        )
        for key, value in LLM_CONTEXT_WINDOW_SIZES.items():
            if self.model.startswith(key):
                self.context_window_size = int(value * CONTEXT_WINDOW_USAGE_RATIO)
        return self.context_window_size
    def set_callbacks(self, callbacks: List[Any]):
        """
        Attempt to keep a single set of callbacks in litellm by removing old
        duplicates and adding new ones.
        """
        with suppress_warnings():
            callback_types = [type(callback) for callback in callbacks]
            for callback in litellm.success_callback[:]:
                if type(callback) in callback_types:
                    litellm.success_callback.remove(callback)

            for callback in litellm._async_success_callback[:]:
                if type(callback) in callback_types:
                    litellm._async_success_callback.remove(callback)

            litellm.callbacks = callbacks
    def set_env_callbacks(self):
        """
        Sets the success and failure callbacks for the LiteLLM library from environment variables.

        This method reads the `LITELLM_SUCCESS_CALLBACKS` and `LITELLM_FAILURE_CALLBACKS`
        environment variables, which should contain comma-separated lists of callback names.
        It then assigns these lists to `litellm.success_callback` and `litellm.failure_callback`,
        respectively.

        If the environment variables are not set or are empty, the corresponding callback lists
        will be set to empty lists.

        Example:
            LITELLM_SUCCESS_CALLBACKS="langfuse,langsmith"
            LITELLM_FAILURE_CALLBACKS="langfuse"

        This will set `litellm.success_callback` to ["langfuse", "langsmith"] and
        `litellm.failure_callback` to ["langfuse"].
        """
        with suppress_warnings():
            success_callbacks_str = os.environ.get("LITELLM_SUCCESS_CALLBACKS", "")
            success_callbacks = []
            if success_callbacks_str:
                success_callbacks = [
                    cb.strip() for cb in success_callbacks_str.split(",") if cb.strip()
                ]

            failure_callbacks_str = os.environ.get("LITELLM_FAILURE_CALLBACKS", "")
            failure_callbacks = []
            if failure_callbacks_str:
                failure_callbacks = [
                    cb.strip() for cb in failure_callbacks_str.split(",") if cb.strip()
                ]

            litellm.success_callback = success_callbacks
            litellm.failure_callback = failure_callbacks
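

# Minimal usage sketch (illustrative only, not executed as part of this module). It assumes
# an OpenAI-compatible key is available in the environment (e.g. OPENAI_API_KEY) and shows
# both a plain text call and a tool call routed through `available_functions`; the
# `get_weather` helper and its schema are hypothetical examples, not part of this library.
#
#     llm = LLM(model="gpt-4o-mini", temperature=0.7)
#
#     # Plain text completion
#     print(llm.call("Return the name of a random city in the world."))
#
#     # Tool calling: the schema follows the OpenAI-style "function" tool format accepted by
#     # litellm.completion; `available_functions` maps the tool name to a local callable that
#     # `call` invokes with the model's parsed arguments.
#     def get_weather(city: str) -> str:
#         return f"It is sunny in {city}."
#
#     tools = [
#         {
#             "type": "function",
#             "function": {
#                 "name": "get_weather",
#                 "description": "Get the current weather for a city.",
#                 "parameters": {
#                     "type": "object",
#                     "properties": {"city": {"type": "string"}},
#                     "required": ["city"],
#                 },
#             },
#         }
#     ]
#     print(
#         llm.call(
#             "What is the weather in Paris?",
#             tools=tools,
#             available_functions={"get_weather": get_weather},
#         )
#     )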