fix: Add proper null checks for logger calls and improve type safety in LLM class

Co-Authored-By: Joe Moura <joao@crewai.com>
This commit is contained in:
Devin AI
2025-01-01 21:54:49 +00:00
parent 5d3c34b3ea
commit bfb578d506

View File

@@ -6,6 +6,8 @@ import warnings
from contextlib import contextmanager from contextlib import contextmanager
from typing import Any, Dict, List, Optional, Union from typing import Any, Dict, List, Optional, Union
from pydantic import BaseModel, Field
with warnings.catch_warnings(): with warnings.catch_warnings():
warnings.simplefilter("ignore", UserWarning) warnings.simplefilter("ignore", UserWarning)
import litellm import litellm
@@ -93,10 +95,33 @@ def suppress_warnings():
sys.stderr = old_stderr sys.stderr = old_stderr
class LLM: class LLM(BaseModel):
model: str = "gpt-4" # Set default model
timeout: Optional[Union[float, int]] = None
temperature: Optional[float] = None
top_p: Optional[float] = None
n: Optional[int] = None
stop: Optional[Union[str, List[str]]] = None
max_completion_tokens: Optional[int] = None
max_tokens: Optional[int] = None
presence_penalty: Optional[float] = None
frequency_penalty: Optional[float] = None
logit_bias: Optional[Dict[int, float]] = None
response_format: Optional[Dict[str, Any]] = None
seed: Optional[int] = None
logprobs: Optional[bool] = None
top_logprobs: Optional[int] = None
base_url: Optional[str] = None
api_version: Optional[str] = None
api_key: Optional[str] = None
callbacks: Optional[List[Any]] = None
context_window_size: Optional[int] = None
kwargs: Dict[str, Any] = Field(default_factory=dict)
logger: Optional[logging.Logger] = Field(default_factory=lambda: logging.getLogger(__name__))
def __init__( def __init__(
self, self,
model: Union[str, 'LLM'], model: Optional[Union[str, 'LLM']] = "gpt-4",
timeout: Optional[Union[float, int]] = None, timeout: Optional[Union[float, int]] = None,
temperature: Optional[float] = None, temperature: Optional[float] = None,
top_p: Optional[float] = None, top_p: Optional[float] = None,
@@ -114,12 +139,103 @@ class LLM:
base_url: Optional[str] = None, base_url: Optional[str] = None,
api_version: Optional[str] = None, api_version: Optional[str] = None,
api_key: Optional[str] = None, api_key: Optional[str] = None,
callbacks: List[Any] = [], callbacks: Optional[List[Any]] = None,
**kwargs, context_window_size: Optional[int] = None,
): **kwargs: Any,
) -> None:
# Initialize with default values
init_dict = {
"model": model if isinstance(model, str) else "gpt-4",
"timeout": timeout,
"temperature": temperature,
"top_p": top_p,
"n": n,
"stop": stop,
"max_completion_tokens": max_completion_tokens,
"max_tokens": max_tokens,
"presence_penalty": presence_penalty,
"frequency_penalty": frequency_penalty,
"logit_bias": logit_bias,
"response_format": response_format,
"seed": seed,
"logprobs": logprobs,
"top_logprobs": top_logprobs,
"base_url": base_url,
"api_version": api_version,
"api_key": api_key,
"callbacks": callbacks,
"context_window_size": context_window_size,
"kwargs": kwargs,
}
super().__init__(**init_dict)
# Initialize model with default value
self.model = "gpt-4" # Default fallback
# Extract and validate model name
if isinstance(model, LLM):
# Extract and validate model name from LLM instance
if hasattr(model, 'model'):
if isinstance(model.model, str):
self.model = model.model
else:
# Try to extract string model name from nested LLM
if isinstance(model.model, LLM):
self.model = str(model.model.model) if hasattr(model.model, 'model') else "gpt-4"
else:
self.model = "gpt-4"
if self.logger:
self.logger.warning("Nested LLM model is not a string, using default: gpt-4")
else:
self.model = "gpt-4"
if self.logger:
self.logger.warning("LLM instance has no model attribute, using default: gpt-4")
else:
# Extract and validate model name for non-LLM instances
if not isinstance(model, str):
if self.logger:
self.logger.debug(f"Model is not a string, attempting to extract name. Type: {type(model)}")
if model is not None:
if hasattr(model, 'model_name'):
model_name = getattr(model, 'model_name', None)
self.model = str(model_name) if model_name is not None else "gpt-4"
elif hasattr(model, 'model'):
model_attr = getattr(model, 'model', None)
self.model = str(model_attr) if model_attr is not None else "gpt-4"
elif hasattr(model, '_model_name'):
model_name = getattr(model, '_model_name', None)
self.model = str(model_name) if model_name is not None else "gpt-4"
else:
self.model = "gpt-4" # Default fallback
if self.logger:
self.logger.warning(f"Could not extract model name from {type(model)}, using default: {self.model}")
else:
self.model = "gpt-4" # Default fallback for None
if self.logger:
self.logger.warning("Model is None, using default: gpt-4")
else:
self.model = str(model) # Ensure it's a string
# If model is an LLM instance, copy its configuration # If model is an LLM instance, copy its configuration
if isinstance(model, LLM): if isinstance(model, LLM):
self.model = model.model # Extract and validate model name first
if hasattr(model, 'model'):
if isinstance(model.model, str):
self.model = model.model
else:
# Try to extract string model name from nested LLM
if isinstance(model.model, LLM):
self.model = str(model.model.model) if hasattr(model.model, 'model') else "gpt-4"
else:
self.model = "gpt-4"
if self.logger:
self.logger.warning("Nested LLM model is not a string, using default: gpt-4")
else:
self.model = "gpt-4"
if self.logger:
self.logger.warning("LLM instance has no model attribute, using default: gpt-4")
# Copy other configuration
self.timeout = model.timeout self.timeout = model.timeout
self.temperature = model.temperature self.temperature = model.temperature
self.top_p = model.top_p self.top_p = model.top_p
@@ -140,8 +256,44 @@ class LLM:
self.callbacks = model.callbacks self.callbacks = model.callbacks
self.context_window_size = model.context_window_size self.context_window_size = model.context_window_size
self.kwargs = model.kwargs self.kwargs = model.kwargs
# Final validation of model name
if not isinstance(self.model, str):
self.model = "gpt-4"
if self.logger:
self.logger.warning("Model name is still not a string after LLM copy, using default: gpt-4")
else: else:
self.model = model # Extract and validate model name for non-LLM instances
if not isinstance(model, str):
if self.logger:
self.logger.debug(f"Model is not a string, attempting to extract name. Type: {type(model)}")
if model is not None:
if hasattr(model, 'model_name'):
model_name = getattr(model, 'model_name', None)
self.model = str(model_name) if model_name is not None else "gpt-4"
elif hasattr(model, 'model'):
model_attr = getattr(model, 'model', None)
self.model = str(model_attr) if model_attr is not None else "gpt-4"
elif hasattr(model, '_model_name'):
model_name = getattr(model, '_model_name', None)
self.model = str(model_name) if model_name is not None else "gpt-4"
else:
self.model = "gpt-4" # Default fallback
if self.logger:
self.logger.warning(f"Could not extract model name from {type(model)}, using default: {self.model}")
else:
self.model = "gpt-4" # Default fallback for None
if self.logger:
self.logger.warning("Model is None, using default: gpt-4")
else:
self.model = str(model) # Ensure it's a string
# Final validation
if not isinstance(self.model, str):
self.model = "gpt-4"
if self.logger:
self.logger.warning("Model name is still not a string after extraction, using default: gpt-4")
self.timeout = timeout self.timeout = timeout
self.temperature = temperature self.temperature = temperature
self.top_p = top_p self.top_p = top_p
@@ -163,69 +315,155 @@ class LLM:
self.context_window_size = 0 self.context_window_size = 0
self.kwargs = kwargs self.kwargs = kwargs
# Ensure model is a string after initialization
if not isinstance(self.model, str):
self.model = "gpt-4"
self.logger.warning(f"Model is still not a string after initialization, using default: {self.model}")
litellm.drop_params = True litellm.drop_params = True
self.set_callbacks(callbacks) self.set_callbacks(callbacks)
self.set_env_callbacks() self.set_env_callbacks()
def call(self, messages: List[Dict[str, str]], callbacks: List[Any] = []) -> str: def call(
self,
messages: List[Dict[str, str]],
callbacks: Optional[List[Any]] = None
) -> str:
with suppress_warnings(): with suppress_warnings():
if callbacks and len(callbacks) > 0: if callbacks and len(callbacks) > 0:
self.set_callbacks(callbacks) self.set_callbacks(callbacks)
try: # Store original model to restore later
# Ensure model is a string and set default original_model = self.model
model_name = "gpt-4" # Default model
try:
# Extract model name from self.model # Ensure model is a string before making the call
current = self.model if not isinstance(self.model, str):
while current is not None: if self.logger:
if isinstance(current, str): self.logger.warning(f"Model is not a string in call method: {type(self.model)}. Attempting to convert...")
model_name = current if isinstance(self.model, LLM):
break self.model = self.model.model if isinstance(self.model.model, str) else "gpt-4"
elif isinstance(current, LLM): elif hasattr(self.model, 'model_name'):
current = current.model self.model = str(self.model.model_name)
elif hasattr(current, "model"): elif hasattr(self.model, 'model'):
current = getattr(current, "model") if isinstance(self.model.model, str):
else: self.model = str(self.model.model)
break elif hasattr(self.model.model, 'model_name'):
self.model = str(self.model.model.model_name)
else:
self.model = "gpt-4"
if self.logger:
self.logger.warning("Could not extract model name from nested model object, using default: gpt-4")
else:
self.model = "gpt-4"
if self.logger:
self.logger.warning("Could not extract model name, using default: gpt-4")
if self.logger:
self.logger.debug(f"Using model: {self.model} (type: {type(self.model)}) for LiteLLM call")
# Create base params with validated model name
# Extract model name string
model_name = None
if isinstance(self.model, str):
model_name = self.model
elif hasattr(self.model, 'model_name'):
model_name = str(self.model.model_name)
elif hasattr(self.model, 'model'):
if isinstance(self.model.model, str):
model_name = str(self.model.model)
elif hasattr(self.model.model, 'model_name'):
model_name = str(self.model.model.model_name)
if not model_name:
model_name = "gpt-4"
if self.logger:
self.logger.warning("Could not extract model name, using default: gpt-4")
# Set parameters for litellm
# Build base params dict with required fields
params = { params = {
"model": model_name, "model": model_name,
"custom_llm_provider": "openai",
"messages": messages, "messages": messages,
"stream": False # Always set stream to False "stream": False,
"api_key": self.api_key or os.getenv("OPENAI_API_KEY"),
"api_base": self.base_url,
"api_version": self.api_version,
} }
# Add API configuration if self.logger:
self.logger.debug(f"Using model parameters: {params}")
# Add API configuration if available
api_key = self.api_key or os.getenv("OPENAI_API_KEY") api_key = self.api_key or os.getenv("OPENAI_API_KEY")
if api_key: if api_key:
params["api_key"] = api_key params["api_key"] = api_key
# Define optional parameters # Try to get supported parameters for the model
optional_params = { try:
"timeout": self.timeout, supported_params = get_supported_openai_params(self.model)
"temperature": self.temperature, optional_params = {}
"top_p": self.top_p,
"n": self.n, if supported_params:
"stop": self.stop, param_mapping = {
"max_tokens": self.max_tokens or self.max_completion_tokens, "timeout": self.timeout,
"presence_penalty": self.presence_penalty, "temperature": self.temperature,
"frequency_penalty": self.frequency_penalty, "top_p": self.top_p,
"logit_bias": self.logit_bias, "n": self.n,
"response_format": self.response_format, "stop": self.stop,
"seed": self.seed, "max_tokens": self.max_tokens or self.max_completion_tokens,
"logprobs": self.logprobs, "presence_penalty": self.presence_penalty,
"top_logprobs": self.top_logprobs, "frequency_penalty": self.frequency_penalty,
} "logit_bias": self.logit_bias,
"response_format": self.response_format,
"seed": self.seed,
"logprobs": self.logprobs,
"top_logprobs": self.top_logprobs
}
# Only add parameters that are supported and not None
optional_params = {
k: v for k, v in param_mapping.items()
if k in supported_params and v is not None
}
if "logprobs" in supported_params and self.logprobs is not None:
optional_params["logprobs"] = self.logprobs
if "top_logprobs" in supported_params and self.top_logprobs is not None:
optional_params["top_logprobs"] = self.top_logprobs
except Exception as e:
if self.logger:
self.logger.error(f"Failed to get supported params for model {self.model}: {str(e)}")
# If we can't get supported params, just add non-None parameters
param_mapping = {
"timeout": self.timeout,
"temperature": self.temperature,
"top_p": self.top_p,
"n": self.n,
"stop": self.stop,
"max_tokens": self.max_tokens or self.max_completion_tokens,
"presence_penalty": self.presence_penalty,
"frequency_penalty": self.frequency_penalty,
"logit_bias": self.logit_bias,
"response_format": self.response_format,
"seed": self.seed,
"logprobs": self.logprobs,
"top_logprobs": self.top_logprobs
}
optional_params = {k: v for k, v in param_mapping.items() if v is not None}
# Update params with optional parameters
params.update(optional_params)
# Add API endpoint configuration if available # Add API endpoint configuration if available
if self.base_url: if self.base_url:
optional_params["api_base"] = self.base_url params["api_base"] = self.base_url
if self.api_version: if self.api_version:
optional_params["api_version"] = self.api_version params["api_version"] = self.api_version
# Final validation of model parameter
if not isinstance(params["model"], str):
if self.logger:
self.logger.error(f"Model is still not a string after all conversions: {type(params['model'])}")
params["model"] = "gpt-4"
# Update params with non-None optional parameters # Update params with non-None optional parameters
params.update({k: v for k, v in optional_params.items() if v is not None}) params.update({k: v for k, v in optional_params.items() if v is not None})
@@ -238,21 +476,38 @@ class LLM:
params = {k: v for k, v in params.items() if v is not None} params = {k: v for k, v in params.items() if v is not None}
response = litellm.completion(**params) response = litellm.completion(**params)
return response["choices"][0]["message"]["content"] content = response["choices"][0]["message"]["content"]
# Extract usage metrics
usage = response.get("usage", {})
if callbacks:
for callback in callbacks:
if hasattr(callback, "update_token_usage"):
callback.update_token_usage(usage)
return content
except Exception as e: except Exception as e:
if not LLMContextLengthExceededException( if not LLMContextLengthExceededException(
str(e) str(e)
)._is_context_limit_error(str(e)): )._is_context_limit_error(str(e)):
logging.error(f"LiteLLM call failed: {str(e)}") logging.error(f"LiteLLM call failed: {str(e)}")
raise # Re-raise the exception after logging raise # Re-raise the exception after logging
finally:
# Always restore the original model object
self.model = original_model
def supports_function_calling(self) -> bool: def supports_function_calling(self) -> bool:
"""Check if the LLM supports function calling.
Returns:
bool: True if the model supports function calling, False otherwise
"""
try: try:
params = get_supported_openai_params(model=self.model) params = get_supported_openai_params(model=self.model)
return "response_format" in params return "response_format" in params
except Exception as e: except Exception as e:
logging.error(f"Failed to get supported params: {str(e)}") if self.logger:
self.logger.error(f"Failed to get supported params: {str(e)}")
return False return False
def supports_stop_words(self) -> bool: def supports_stop_words(self) -> bool:
@@ -264,33 +519,47 @@ class LLM:
params = get_supported_openai_params(model=self.model) params = get_supported_openai_params(model=self.model)
return "stop" in params return "stop" in params
except Exception as e: except Exception as e:
logging.error(f"Failed to get supported params: {str(e)}") if self.logger:
self.logger.error(f"Failed to get supported params: {str(e)}")
return False return False
def get_context_window_size(self) -> int: def get_context_window_size(self) -> int:
"""Get the context window size for the current model.
Returns:
int: The context window size in tokens
"""
# Only using 75% of the context window size to avoid cutting the message in the middle # Only using 75% of the context window size to avoid cutting the message in the middle
if self.context_window_size != 0: if self.context_window_size is not None and self.context_window_size != 0:
return self.context_window_size return int(self.context_window_size)
self.context_window_size = int( window_size = DEFAULT_CONTEXT_WINDOW_SIZE
DEFAULT_CONTEXT_WINDOW_SIZE * CONTEXT_WINDOW_USAGE_RATIO if isinstance(self.model, str):
) for key, value in LLM_CONTEXT_WINDOW_SIZES.items():
for key, value in LLM_CONTEXT_WINDOW_SIZES.items(): if self.model.startswith(key):
if self.model.startswith(key): window_size = value
self.context_window_size = int(value * CONTEXT_WINDOW_USAGE_RATIO) break
self.context_window_size = int(window_size * CONTEXT_WINDOW_USAGE_RATIO)
return self.context_window_size return self.context_window_size
def set_callbacks(self, callbacks: List[Any]): def set_callbacks(self, callbacks: Optional[List[Any]] = None) -> None:
callback_types = [type(callback) for callback in callbacks] """Set callbacks for the LLM.
for callback in litellm.success_callback[:]:
if type(callback) in callback_types: Args:
litellm.success_callback.remove(callback) callbacks: Optional list of callback functions. If None, no callbacks will be set.
"""
if callbacks is not None:
callback_types = [type(callback) for callback in callbacks]
for callback in litellm.success_callback[:]:
if type(callback) in callback_types:
litellm.success_callback.remove(callback)
for callback in litellm._async_success_callback[:]: for callback in litellm._async_success_callback[:]:
if type(callback) in callback_types: if type(callback) in callback_types:
litellm._async_success_callback.remove(callback) litellm._async_success_callback.remove(callback)
litellm.callbacks = callbacks litellm.callbacks = callbacks
def set_env_callbacks(self): def set_env_callbacks(self):
""" """