fix: Add proper null checks for logger calls and improve type safety in LLM class

Author: Devin AI
Date: 2025-01-01 21:54:49 +00:00
Co-Authored-By: Joe Moura <joao@crewai.com>
parent 5d3c34b3ea
commit bfb578d506


@@ -6,6 +6,8 @@ import warnings
from contextlib import contextmanager
from typing import Any, Dict, List, Optional, Union
from pydantic import BaseModel, Field
with warnings.catch_warnings():
    warnings.simplefilter("ignore", UserWarning)
    import litellm
@@ -93,10 +95,33 @@ def suppress_warnings():
sys.stderr = old_stderr
class LLM:
class LLM(BaseModel):
model: str = "gpt-4" # Set default model
timeout: Optional[Union[float, int]] = None
temperature: Optional[float] = None
top_p: Optional[float] = None
n: Optional[int] = None
stop: Optional[Union[str, List[str]]] = None
max_completion_tokens: Optional[int] = None
max_tokens: Optional[int] = None
presence_penalty: Optional[float] = None
frequency_penalty: Optional[float] = None
logit_bias: Optional[Dict[int, float]] = None
response_format: Optional[Dict[str, Any]] = None
seed: Optional[int] = None
logprobs: Optional[bool] = None
top_logprobs: Optional[int] = None
base_url: Optional[str] = None
api_version: Optional[str] = None
api_key: Optional[str] = None
callbacks: Optional[List[Any]] = None
context_window_size: Optional[int] = None
kwargs: Dict[str, Any] = Field(default_factory=dict)
logger: Optional[logging.Logger] = Field(default_factory=lambda: logging.getLogger(__name__))
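For orientation, the following is a minimal standalone sketch (not the crewAI source) of how pydantic field declarations like the ones above behave, assuming pydantic v2; `arbitrary_types_allowed` is an assumption needed because `logging.Logger` is not a native pydantic type:

import logging
from typing import Optional

from pydantic import BaseModel, Field


class MiniLLM(BaseModel):
    # Reduced, hypothetical subset of the fields declared above.
    model_config = {"arbitrary_types_allowed": True}

    model: str = "gpt-4"
    temperature: Optional[float] = None
    logger: Optional[logging.Logger] = Field(
        default_factory=lambda: logging.getLogger(__name__)
    )


llm = MiniLLM()
assert llm.model == "gpt-4" and llm.temperature is None
llm.logger.debug("fields picked up their declared defaults")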
def __init__(
self,
model: Union[str, 'LLM'],
model: Optional[Union[str, 'LLM']] = "gpt-4",
timeout: Optional[Union[float, int]] = None,
temperature: Optional[float] = None,
top_p: Optional[float] = None,
@@ -114,12 +139,103 @@ class LLM:
base_url: Optional[str] = None,
api_version: Optional[str] = None,
api_key: Optional[str] = None,
callbacks: List[Any] = [],
**kwargs,
):
callbacks: Optional[List[Any]] = None,
context_window_size: Optional[int] = None,
**kwargs: Any,
) -> None:
# Initialize with default values
init_dict = {
"model": model if isinstance(model, str) else "gpt-4",
"timeout": timeout,
"temperature": temperature,
"top_p": top_p,
"n": n,
"stop": stop,
"max_completion_tokens": max_completion_tokens,
"max_tokens": max_tokens,
"presence_penalty": presence_penalty,
"frequency_penalty": frequency_penalty,
"logit_bias": logit_bias,
"response_format": response_format,
"seed": seed,
"logprobs": logprobs,
"top_logprobs": top_logprobs,
"base_url": base_url,
"api_version": api_version,
"api_key": api_key,
"callbacks": callbacks,
"context_window_size": context_window_size,
"kwargs": kwargs,
}
super().__init__(**init_dict)
# Initialize model with default value
self.model = "gpt-4" # Default fallback
# Extract and validate model name
if isinstance(model, LLM):
# Extract and validate model name from LLM instance
if hasattr(model, 'model'):
if isinstance(model.model, str):
self.model = model.model
else:
# Try to extract string model name from nested LLM
if isinstance(model.model, LLM):
self.model = str(model.model.model) if hasattr(model.model, 'model') else "gpt-4"
else:
self.model = "gpt-4"
if self.logger:
self.logger.warning("Nested LLM model is not a string, using default: gpt-4")
else:
self.model = "gpt-4"
if self.logger:
self.logger.warning("LLM instance has no model attribute, using default: gpt-4")
else:
# Extract and validate model name for non-LLM instances
if not isinstance(model, str):
if self.logger:
self.logger.debug(f"Model is not a string, attempting to extract name. Type: {type(model)}")
if model is not None:
if hasattr(model, 'model_name'):
model_name = getattr(model, 'model_name', None)
self.model = str(model_name) if model_name is not None else "gpt-4"
elif hasattr(model, 'model'):
model_attr = getattr(model, 'model', None)
self.model = str(model_attr) if model_attr is not None else "gpt-4"
elif hasattr(model, '_model_name'):
model_name = getattr(model, '_model_name', None)
self.model = str(model_name) if model_name is not None else "gpt-4"
else:
self.model = "gpt-4" # Default fallback
if self.logger:
self.logger.warning(f"Could not extract model name from {type(model)}, using default: {self.model}")
else:
self.model = "gpt-4" # Default fallback for None
if self.logger:
self.logger.warning("Model is None, using default: gpt-4")
else:
self.model = str(model) # Ensure it's a string
# If model is an LLM instance, copy its configuration
if isinstance(model, LLM):
self.model = model.model
# Extract and validate model name first
if hasattr(model, 'model'):
if isinstance(model.model, str):
self.model = model.model
else:
# Try to extract string model name from nested LLM
if isinstance(model.model, LLM):
self.model = str(model.model.model) if hasattr(model.model, 'model') else "gpt-4"
else:
self.model = "gpt-4"
if self.logger:
self.logger.warning("Nested LLM model is not a string, using default: gpt-4")
else:
self.model = "gpt-4"
if self.logger:
self.logger.warning("LLM instance has no model attribute, using default: gpt-4")
# Copy other configuration
self.timeout = model.timeout
self.temperature = model.temperature
self.top_p = model.top_p
@@ -140,8 +256,44 @@ class LLM:
self.callbacks = model.callbacks
self.context_window_size = model.context_window_size
self.kwargs = model.kwargs
# Final validation of model name
if not isinstance(self.model, str):
self.model = "gpt-4"
if self.logger:
self.logger.warning("Model name is still not a string after LLM copy, using default: gpt-4")
else:
self.model = model
# Extract and validate model name for non-LLM instances
if not isinstance(model, str):
if self.logger:
self.logger.debug(f"Model is not a string, attempting to extract name. Type: {type(model)}")
if model is not None:
if hasattr(model, 'model_name'):
model_name = getattr(model, 'model_name', None)
self.model = str(model_name) if model_name is not None else "gpt-4"
elif hasattr(model, 'model'):
model_attr = getattr(model, 'model', None)
self.model = str(model_attr) if model_attr is not None else "gpt-4"
elif hasattr(model, '_model_name'):
model_name = getattr(model, '_model_name', None)
self.model = str(model_name) if model_name is not None else "gpt-4"
else:
self.model = "gpt-4" # Default fallback
if self.logger:
self.logger.warning(f"Could not extract model name from {type(model)}, using default: {self.model}")
else:
self.model = "gpt-4" # Default fallback for None
if self.logger:
self.logger.warning("Model is None, using default: gpt-4")
else:
self.model = str(model) # Ensure it's a string
# Final validation
if not isinstance(self.model, str):
self.model = "gpt-4"
if self.logger:
self.logger.warning("Model name is still not a string after extraction, using default: gpt-4")
self.timeout = timeout
self.temperature = temperature
self.top_p = top_p
@@ -163,69 +315,155 @@ class LLM:
self.context_window_size = 0
self.kwargs = kwargs
# Ensure model is a string after initialization
if not isinstance(self.model, str):
self.model = "gpt-4"
self.logger.warning(f"Model is still not a string after initialization, using default: {self.model}")
litellm.drop_params = True
self.set_callbacks(callbacks)
self.set_env_callbacks()
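A hedged usage sketch of the constructor behavior above (assuming the public import path `from crewai import LLM`; warnings are only emitted when a logger is configured):

from crewai import LLM  # assumed import path

base = LLM(model="gpt-4o", temperature=0.1)
copied = LLM(model=base)    # an LLM instance: name and configuration are copied
fallback = LLM(model=None)  # falls back to the "gpt-4" default and logs a warning

print(copied.model, copied.temperature)  # gpt-4o 0.1
print(fallback.model)                    # gpt-4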
def call(self, messages: List[Dict[str, str]], callbacks: List[Any] = []) -> str:
def call(
self,
messages: List[Dict[str, str]],
callbacks: Optional[List[Any]] = None
) -> str:
with suppress_warnings():
if callbacks and len(callbacks) > 0:
self.set_callbacks(callbacks)
try:
# Ensure model is a string and set default
model_name = "gpt-4" # Default model
# Extract model name from self.model
current = self.model
while current is not None:
if isinstance(current, str):
model_name = current
break
elif isinstance(current, LLM):
current = current.model
elif hasattr(current, "model"):
current = getattr(current, "model")
else:
break
# Store original model to restore later
original_model = self.model
try:
# Ensure model is a string before making the call
if not isinstance(self.model, str):
if self.logger:
self.logger.warning(f"Model is not a string in call method: {type(self.model)}. Attempting to convert...")
if isinstance(self.model, LLM):
self.model = self.model.model if isinstance(self.model.model, str) else "gpt-4"
elif hasattr(self.model, 'model_name'):
self.model = str(self.model.model_name)
elif hasattr(self.model, 'model'):
if isinstance(self.model.model, str):
self.model = str(self.model.model)
elif hasattr(self.model.model, 'model_name'):
self.model = str(self.model.model.model_name)
else:
self.model = "gpt-4"
if self.logger:
self.logger.warning("Could not extract model name from nested model object, using default: gpt-4")
else:
self.model = "gpt-4"
if self.logger:
self.logger.warning("Could not extract model name, using default: gpt-4")
if self.logger:
self.logger.debug(f"Using model: {self.model} (type: {type(self.model)}) for LiteLLM call")
# Create base params with validated model name
# Extract model name string
model_name = None
if isinstance(self.model, str):
model_name = self.model
elif hasattr(self.model, 'model_name'):
model_name = str(self.model.model_name)
elif hasattr(self.model, 'model'):
if isinstance(self.model.model, str):
model_name = str(self.model.model)
elif hasattr(self.model.model, 'model_name'):
model_name = str(self.model.model.model_name)
if not model_name:
model_name = "gpt-4"
if self.logger:
self.logger.warning("Could not extract model name, using default: gpt-4")
# Set parameters for litellm
# Build base params dict with required fields
params = {
"model": model_name,
"custom_llm_provider": "openai",
"messages": messages,
"stream": False # Always set stream to False
"stream": False,
"api_key": self.api_key or os.getenv("OPENAI_API_KEY"),
"api_base": self.base_url,
"api_version": self.api_version,
}
# Add API configuration
if self.logger:
self.logger.debug(f"Using model parameters: {params}")
# Add API configuration if available
api_key = self.api_key or os.getenv("OPENAI_API_KEY")
if api_key:
params["api_key"] = api_key
# Define optional parameters
optional_params = {
"timeout": self.timeout,
"temperature": self.temperature,
"top_p": self.top_p,
"n": self.n,
"stop": self.stop,
"max_tokens": self.max_tokens or self.max_completion_tokens,
"presence_penalty": self.presence_penalty,
"frequency_penalty": self.frequency_penalty,
"logit_bias": self.logit_bias,
"response_format": self.response_format,
"seed": self.seed,
"logprobs": self.logprobs,
"top_logprobs": self.top_logprobs,
}
# Try to get supported parameters for the model
try:
supported_params = get_supported_openai_params(self.model)
optional_params = {}
if supported_params:
param_mapping = {
"timeout": self.timeout,
"temperature": self.temperature,
"top_p": self.top_p,
"n": self.n,
"stop": self.stop,
"max_tokens": self.max_tokens or self.max_completion_tokens,
"presence_penalty": self.presence_penalty,
"frequency_penalty": self.frequency_penalty,
"logit_bias": self.logit_bias,
"response_format": self.response_format,
"seed": self.seed,
"logprobs": self.logprobs,
"top_logprobs": self.top_logprobs
}
# Only add parameters that are supported and not None
optional_params = {
k: v for k, v in param_mapping.items()
if k in supported_params and v is not None
}
if "logprobs" in supported_params and self.logprobs is not None:
optional_params["logprobs"] = self.logprobs
if "top_logprobs" in supported_params and self.top_logprobs is not None:
optional_params["top_logprobs"] = self.top_logprobs
except Exception as e:
if self.logger:
self.logger.error(f"Failed to get supported params for model {self.model}: {str(e)}")
# If we can't get supported params, just add non-None parameters
param_mapping = {
"timeout": self.timeout,
"temperature": self.temperature,
"top_p": self.top_p,
"n": self.n,
"stop": self.stop,
"max_tokens": self.max_tokens or self.max_completion_tokens,
"presence_penalty": self.presence_penalty,
"frequency_penalty": self.frequency_penalty,
"logit_bias": self.logit_bias,
"response_format": self.response_format,
"seed": self.seed,
"logprobs": self.logprobs,
"top_logprobs": self.top_logprobs
}
optional_params = {k: v for k, v in param_mapping.items() if v is not None}
# Update params with optional parameters
params.update(optional_params)
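A standalone sketch of the capability-filtering pattern used above, assuming `litellm.get_supported_openai_params` returns a list of accepted parameter names for the model (the candidate values here are illustrative):

import litellm


def filter_supported_params(model: str, candidates: dict) -> dict:
    """Keep only non-None parameters the target model is reported to support."""
    try:
        supported = litellm.get_supported_openai_params(model=model) or []
    except Exception:
        # If support info is unavailable, only drop the None values.
        return {k: v for k, v in candidates.items() if v is not None}
    return {k: v for k, v in candidates.items() if k in supported and v is not None}


print(filter_supported_params("gpt-4", {"temperature": 0.2, "top_logprobs": None}))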
# Add API endpoint configuration if available
if self.base_url:
optional_params["api_base"] = self.base_url
params["api_base"] = self.base_url
if self.api_version:
optional_params["api_version"] = self.api_version
params["api_version"] = self.api_version
# Final validation of model parameter
if not isinstance(params["model"], str):
if self.logger:
self.logger.error(f"Model is still not a string after all conversions: {type(params['model'])}")
params["model"] = "gpt-4"
# Update params with non-None optional parameters
params.update({k: v for k, v in optional_params.items() if v is not None})
@@ -238,21 +476,38 @@ class LLM:
params = {k: v for k, v in params.items() if v is not None}
response = litellm.completion(**params)
return response["choices"][0]["message"]["content"]
content = response["choices"][0]["message"]["content"]
# Extract usage metrics
usage = response.get("usage", {})
if callbacks:
for callback in callbacks:
if hasattr(callback, "update_token_usage"):
callback.update_token_usage(usage)
return content
except Exception as e:
if not LLMContextLengthExceededException(
str(e)
)._is_context_limit_error(str(e)):
logging.error(f"LiteLLM call failed: {str(e)}")
raise # Re-raise the exception after logging
finally:
# Always restore the original model object
self.model = original_model
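For reference, a minimal end-to-end sketch of the litellm call that `call()` ultimately issues (placeholder message and only a small subset of the parameters built above; assumes `OPENAI_API_KEY` is set in the environment):

import litellm

response = litellm.completion(
    model="gpt-4",
    messages=[{"role": "user", "content": "Say hello in one word."}],
    stream=False,
    temperature=0.2,
)
print(response["choices"][0]["message"]["content"])
print(response.get("usage", {}))  # the usage dict forwarded to update_token_usage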
def supports_function_calling(self) -> bool:
"""Check if the LLM supports function calling.
Returns:
bool: True if the model supports function calling, False otherwise
"""
try:
params = get_supported_openai_params(model=self.model)
return "response_format" in params
except Exception as e:
logging.error(f"Failed to get supported params: {str(e)}")
if self.logger:
self.logger.error(f"Failed to get supported params: {str(e)}")
return False
def supports_stop_words(self) -> bool:
@@ -264,33 +519,47 @@ class LLM:
params = get_supported_openai_params(model=self.model)
return "stop" in params
except Exception as e:
logging.error(f"Failed to get supported params: {str(e)}")
if self.logger:
self.logger.error(f"Failed to get supported params: {str(e)}")
return False
def get_context_window_size(self) -> int:
"""Get the context window size for the current model.
Returns:
int: The context window size in tokens
"""
# Only using 75% of the context window size to avoid cutting the message in the middle
if self.context_window_size != 0:
return self.context_window_size
if self.context_window_size is not None and self.context_window_size != 0:
return int(self.context_window_size)
self.context_window_size = int(
DEFAULT_CONTEXT_WINDOW_SIZE * CONTEXT_WINDOW_USAGE_RATIO
)
for key, value in LLM_CONTEXT_WINDOW_SIZES.items():
if self.model.startswith(key):
self.context_window_size = int(value * CONTEXT_WINDOW_USAGE_RATIO)
window_size = DEFAULT_CONTEXT_WINDOW_SIZE
if isinstance(self.model, str):
for key, value in LLM_CONTEXT_WINDOW_SIZES.items():
if self.model.startswith(key):
window_size = value
break
self.context_window_size = int(window_size * CONTEXT_WINDOW_USAGE_RATIO)
return self.context_window_size
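A hedged illustration of the window arithmetic above; the 75% ratio follows the comment in the code, while the prefix-to-size mapping uses assumed example values rather than the crewAI constants:

CONTEXT_WINDOW_USAGE_RATIO = 0.75  # only 75% of the raw window is used
DEFAULT_CONTEXT_WINDOW_SIZE = 8192
LLM_CONTEXT_WINDOW_SIZES = {"gpt-4o": 128_000, "gpt-4": 8_192}  # illustrative


def usable_window(model: str) -> int:
    window = DEFAULT_CONTEXT_WINDOW_SIZE
    for prefix, size in LLM_CONTEXT_WINDOW_SIZES.items():
        if model.startswith(prefix):
            window = size
            break
    return int(window * CONTEXT_WINDOW_USAGE_RATIO)


print(usable_window("gpt-4o"))  # 96000
print(usable_window("gpt-4"))   # 6144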
def set_callbacks(self, callbacks: List[Any]):
callback_types = [type(callback) for callback in callbacks]
for callback in litellm.success_callback[:]:
if type(callback) in callback_types:
litellm.success_callback.remove(callback)
def set_callbacks(self, callbacks: Optional[List[Any]] = None) -> None:
"""Set callbacks for the LLM.
Args:
callbacks: Optional list of callback functions. If None, no callbacks will be set.
"""
if callbacks is not None:
callback_types = [type(callback) for callback in callbacks]
for callback in litellm.success_callback[:]:
if type(callback) in callback_types:
litellm.success_callback.remove(callback)
for callback in litellm._async_success_callback[:]:
if type(callback) in callback_types:
litellm._async_success_callback.remove(callback)
for callback in litellm._async_success_callback[:]:
if type(callback) in callback_types:
litellm._async_success_callback.remove(callback)
litellm.callbacks = callbacks
litellm.callbacks = callbacks
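A minimal sketch of a callback object compatible with the flow above: `set_callbacks()` deduplicates registered callbacks by type, and `call()` feeds the usage dict to `update_token_usage` (the class name here is hypothetical):

import litellm


class TokenUsageTracker:
    """Hypothetical callback that accumulates the usage dicts from LLM.call()."""

    def __init__(self) -> None:
        self.total_tokens = 0

    def update_token_usage(self, usage: dict) -> None:
        self.total_tokens += usage.get("total_tokens", 0)


tracker = TokenUsageTracker()
litellm.callbacks = [tracker]  # as set_callbacks() does after deduplication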
def set_env_callbacks(self):
"""