import json import logging import os import sys import threading import warnings from contextlib import contextmanager from typing import Any, Dict, List, Optional, Union # Load environment variables from .env file import litellm from dotenv import load_dotenv from litellm import get_supported_openai_params from crewai.utilities.exceptions.context_window_exceeding_exception import ( LLMContextLengthExceededException, ) load_dotenv() class FilteredStream: def __init__(self, original_stream): self._original_stream = original_stream self._lock = threading.Lock() def write(self, s) -> int: with self._lock: # Filter out extraneous messages from LiteLLM if ( "Give Feedback / Get Help: https://github.com/BerriAI/litellm/issues/new" in s or "LiteLLM.Info: If you need to debug this error, use `litellm.set_verbose=True`" in s ): return 0 return self._original_stream.write(s) def flush(self): with self._lock: return self._original_stream.flush() LLM_CONTEXT_WINDOW_SIZES = { # openai "gpt-4": 8192, "gpt-4o": 128000, "gpt-4o-mini": 128000, "gpt-4-turbo": 128000, "o1-preview": 128000, "o1-mini": 128000, # gemini "gemini-2.0-flash": 1048576, "gemini-1.5-pro": 2097152, "gemini-1.5-flash": 1048576, "gemini-1.5-flash-8b": 1048576, # deepseek "deepseek-chat": 128000, # groq "gemma2-9b-it": 8192, "gemma-7b-it": 8192, "llama3-groq-70b-8192-tool-use-preview": 8192, "llama3-groq-8b-8192-tool-use-preview": 8192, "llama-3.1-70b-versatile": 131072, "llama-3.1-8b-instant": 131072, "llama-3.2-1b-preview": 8192, "llama-3.2-3b-preview": 8192, "llama-3.2-11b-text-preview": 8192, "llama-3.2-90b-text-preview": 8192, "llama3-70b-8192": 8192, "llama3-8b-8192": 8192, "mixtral-8x7b-32768": 32768, } DEFAULT_CONTEXT_WINDOW_SIZE = 8192 CONTEXT_WINDOW_USAGE_RATIO = 0.75 @contextmanager def suppress_warnings(): with warnings.catch_warnings(): warnings.filterwarnings("ignore") # Redirect stdout and stderr old_stdout = sys.stdout old_stderr = sys.stderr sys.stdout = FilteredStream(old_stdout) sys.stderr = FilteredStream(old_stderr) try: yield finally: sys.stdout = old_stdout sys.stderr = old_stderr class LLM: def __init__( self, model: str, timeout: Optional[Union[float, int]] = None, temperature: Optional[float] = None, top_p: Optional[float] = None, n: Optional[int] = None, stop: Optional[Union[str, List[str]]] = None, max_completion_tokens: Optional[int] = None, max_tokens: Optional[int] = None, presence_penalty: Optional[float] = None, frequency_penalty: Optional[float] = None, logit_bias: Optional[Dict[int, float]] = None, response_format: Optional[Dict[str, Any]] = None, seed: Optional[int] = None, logprobs: Optional[bool] = None, top_logprobs: Optional[int] = None, base_url: Optional[str] = None, api_version: Optional[str] = None, api_key: Optional[str] = None, callbacks: List[Any] = [], ): self.model = model self.timeout = timeout self.temperature = temperature self.top_p = top_p self.n = n self.stop = stop self.max_completion_tokens = max_completion_tokens self.max_tokens = max_tokens self.presence_penalty = presence_penalty self.frequency_penalty = frequency_penalty self.logit_bias = logit_bias self.response_format = response_format self.seed = seed self.logprobs = logprobs self.top_logprobs = top_logprobs self.base_url = base_url self.api_version = api_version self.api_key = api_key self.callbacks = callbacks self.context_window_size = 0 # For safety, we disable passing init params to next calls litellm.drop_params = True self.set_callbacks(callbacks) self.set_env_callbacks() def to_dict(self) -> dict: """ Return a dict of all relevant parameters for serialization. """ return { "model": self.model, "timeout": self.timeout, "temperature": self.temperature, "top_p": self.top_p, "n": self.n, "stop": self.stop, "max_completion_tokens": self.max_completion_tokens, "max_tokens": self.max_tokens, "presence_penalty": self.presence_penalty, "frequency_penalty": self.frequency_penalty, "logit_bias": self.logit_bias, "response_format": self.response_format, "seed": self.seed, "logprobs": self.logprobs, "top_logprobs": self.top_logprobs, "base_url": self.base_url, "api_version": self.api_version, "api_key": self.api_key, "callbacks": self.callbacks, } @classmethod def from_dict(cls, data: dict) -> "LLM": """ Create an LLM instance from a dict. We assume the dict has all relevant keys that match what's in the constructor. """ known_fields = {} known_fields["model"] = data.pop("model", None) known_fields["timeout"] = data.pop("timeout", None) known_fields["temperature"] = data.pop("temperature", None) known_fields["top_p"] = data.pop("top_p", None) known_fields["n"] = data.pop("n", None) known_fields["stop"] = data.pop("stop", None) known_fields["max_completion_tokens"] = data.pop("max_completion_tokens", None) known_fields["max_tokens"] = data.pop("max_tokens", None) known_fields["presence_penalty"] = data.pop("presence_penalty", None) known_fields["frequency_penalty"] = data.pop("frequency_penalty", None) known_fields["logit_bias"] = data.pop("logit_bias", None) known_fields["response_format"] = data.pop("response_format", None) known_fields["seed"] = data.pop("seed", None) known_fields["logprobs"] = data.pop("logprobs", None) known_fields["top_logprobs"] = data.pop("top_logprobs", None) known_fields["base_url"] = data.pop("base_url", None) known_fields["api_version"] = data.pop("api_version", None) known_fields["api_key"] = data.pop("api_key", None) known_fields["callbacks"] = data.pop("callbacks", None) return cls(**known_fields, **data) def call( self, messages: List[Dict[str, str]], tools: Optional[List[dict]] = None, callbacks: Optional[List[Any]] = None, available_functions: Optional[Dict[str, Any]] = None, ) -> str: """ High-level call method that: 1) Calls litellm.completion 2) Checks for function/tool calls 3) If tool calls found: a) executes each function b) appends their output as tool messages c) calls litellm.completion again with the updated messages 4) Returns the final text response :param messages: The conversation messages :param tools: Optional list of function schemas for function calling :param callbacks: Optional list of callbacks :param available_functions: A dictionary mapping function_name -> actual Python function :return: Final text response from the LLM """ with suppress_warnings(): if callbacks: self.set_callbacks(callbacks) try: # --- 1) Make first completion call params = { "model": self.model, "messages": messages, "timeout": self.timeout, "temperature": self.temperature, "top_p": self.top_p, "n": self.n, "stop": self.stop, "max_tokens": self.max_tokens or self.max_completion_tokens, "presence_penalty": self.presence_penalty, "frequency_penalty": self.frequency_penalty, "logit_bias": self.logit_bias, "response_format": self.response_format, "seed": self.seed, "logprobs": self.logprobs, "top_logprobs": self.top_logprobs, "api_base": self.base_url, "api_version": self.api_version, "api_key": self.api_key, "stream": False, "tools": tools, # pass the tool schema } # remove None values params = {k: v for k, v in params.items() if v is not None} print(f"Params: {params}") response = litellm.completion(**params) response_message = response.choices[0].message text_response = response_message.content or "" tool_calls = getattr(response_message, "tool_calls", []) # --- 2) If no tool calls, we can just return if not tool_calls or not available_functions: return text_response # --- 3) We have tool calls and a dictionary of available functions # run them, append output to messages for tool_call in tool_calls: function_name = tool_call.function.name if function_name in available_functions: # parse arguments function_args = {} try: function_args = json.loads(tool_call.function.arguments) except Exception as e: logging.warning(f"Failed to parse function arguments: {e}") fn = available_functions[function_name] # call the actual tool function result = fn(**function_args) # append the "tool" response to messages messages.append( { "tool_call_id": tool_call.id, "role": "function", "name": function_name, "content": str(result), } ) else: logging.warning( f"Tool call requested unknown function '{function_name}'" ) # --- 4) Make a second call so the LLM can incorporate the tool results second_params = dict(params) # copy the same params second_params["messages"] = messages # We'll remove "tools" from second call, or keep it if you want more calls # but typically you keep it in case it wants additional calls second_response = litellm.completion(**second_params) second_msg = second_response.choices[0].message final_response = second_msg.content or "" return final_response except Exception as e: # check if context length was exceeded, otherwise log if not LLMContextLengthExceededException( str(e) )._is_context_limit_error(str(e)): logging.error(f"LiteLLM call failed: {str(e)}") # re-raise raise def supports_function_calling(self) -> bool: try: params = get_supported_openai_params(model=self.model) return "response_format" in params except Exception as e: logging.error(f"Failed to get supported params: {str(e)}") return False def supports_stop_words(self) -> bool: try: params = get_supported_openai_params(model=self.model) return "stop" in params except Exception as e: logging.error(f"Failed to get supported params: {str(e)}") return False def get_context_window_size(self) -> int: """ Returns the context window size, using 75% of the maximum to avoid cutting off messages mid-thread. """ if self.context_window_size != 0: return self.context_window_size self.context_window_size = int( DEFAULT_CONTEXT_WINDOW_SIZE * CONTEXT_WINDOW_USAGE_RATIO ) for key, value in LLM_CONTEXT_WINDOW_SIZES.items(): if self.model.startswith(key): self.context_window_size = int(value * CONTEXT_WINDOW_USAGE_RATIO) return self.context_window_size def set_callbacks(self, callbacks: List[Any]): """ Attempt to keep a single set of callbacks in litellm by removing old duplicates and adding new ones. """ callback_types = [type(callback) for callback in callbacks] for callback in litellm.success_callback[:]: if type(callback) in callback_types: litellm.success_callback.remove(callback) for callback in litellm._async_success_callback[:]: if type(callback) in callback_types: litellm._async_success_callback.remove(callback) litellm.callbacks = callbacks def set_env_callbacks(self): """ Sets the success and failure callbacks for the LiteLLM library from environment variables. """ success_callbacks_str = os.environ.get("LITELLM_SUCCESS_CALLBACKS", "") success_callbacks = [] if success_callbacks_str: success_callbacks = [ cb.strip() for cb in success_callbacks_str.split(",") if cb.strip() ] failure_callbacks_str = os.environ.get("LITELLM_FAILURE_CALLBACKS", "") failure_callbacks = [] if failure_callbacks_str: failure_callbacks = [ cb.strip() for cb in failure_callbacks_str.split(",") if cb.strip() ] litellm.success_callback = success_callbacks litellm.failure_callback = failure_callbacks