Mirror of https://github.com/crewAIInc/crewAI.git, synced 2026-01-10 08:38:30 +00:00
fix: Add proper null checks for logger calls and improve type safety in LLM class
Co-Authored-By: Joe Moura <joao@crewai.com>
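The pattern applied throughout the diff below: `logger` becomes an Optional attribute, so every log call is wrapped in a null check, and any `model` value that cannot be resolved to a string falls back to the "gpt-4" default. A minimal standalone sketch of that pattern (illustrative helper name only, not code from this commit):

import logging
from typing import Any, Optional

def resolve_model_name(model: Any, logger: Optional[logging.Logger] = None) -> str:
    """Return a model name string, falling back to a default (sketch only)."""
    if isinstance(model, str):
        return model
    # Probe the attribute names the patch also checks, in the same order.
    for attr in ("model_name", "model", "_model_name"):
        value = getattr(model, attr, None)
        if isinstance(value, str):
            return value
    if logger:  # guarded call: logger may be None
        logger.warning("Could not extract model name from %r, using default: gpt-4", model)
    return "gpt-4"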
@@ -6,6 +6,8 @@ import warnings
 from contextlib import contextmanager
 from typing import Any, Dict, List, Optional, Union
 
+from pydantic import BaseModel, Field
+
 with warnings.catch_warnings():
     warnings.simplefilter("ignore", UserWarning)
     import litellm
@@ -93,10 +95,33 @@ def suppress_warnings():
             sys.stderr = old_stderr
 
 
-class LLM:
+class LLM(BaseModel):
+    model: str = "gpt-4"  # Set default model
+    timeout: Optional[Union[float, int]] = None
+    temperature: Optional[float] = None
+    top_p: Optional[float] = None
+    n: Optional[int] = None
+    stop: Optional[Union[str, List[str]]] = None
+    max_completion_tokens: Optional[int] = None
+    max_tokens: Optional[int] = None
+    presence_penalty: Optional[float] = None
+    frequency_penalty: Optional[float] = None
+    logit_bias: Optional[Dict[int, float]] = None
+    response_format: Optional[Dict[str, Any]] = None
+    seed: Optional[int] = None
+    logprobs: Optional[bool] = None
+    top_logprobs: Optional[int] = None
+    base_url: Optional[str] = None
+    api_version: Optional[str] = None
+    api_key: Optional[str] = None
+    callbacks: Optional[List[Any]] = None
+    context_window_size: Optional[int] = None
+    kwargs: Dict[str, Any] = Field(default_factory=dict)
+    logger: Optional[logging.Logger] = Field(default_factory=lambda: logging.getLogger(__name__))
+
     def __init__(
         self,
-        model: Union[str, 'LLM'],
+        model: Optional[Union[str, 'LLM']] = "gpt-4",
         timeout: Optional[Union[float, int]] = None,
         temperature: Optional[float] = None,
         top_p: Optional[float] = None,
@@ -114,12 +139,103 @@ class LLM:
         base_url: Optional[str] = None,
         api_version: Optional[str] = None,
         api_key: Optional[str] = None,
-        callbacks: List[Any] = [],
-        **kwargs,
-    ):
+        callbacks: Optional[List[Any]] = None,
+        context_window_size: Optional[int] = None,
+        **kwargs: Any,
+    ) -> None:
+        # Initialize with default values
+        init_dict = {
+            "model": model if isinstance(model, str) else "gpt-4",
+            "timeout": timeout,
+            "temperature": temperature,
+            "top_p": top_p,
+            "n": n,
+            "stop": stop,
+            "max_completion_tokens": max_completion_tokens,
+            "max_tokens": max_tokens,
+            "presence_penalty": presence_penalty,
+            "frequency_penalty": frequency_penalty,
+            "logit_bias": logit_bias,
+            "response_format": response_format,
+            "seed": seed,
+            "logprobs": logprobs,
+            "top_logprobs": top_logprobs,
+            "base_url": base_url,
+            "api_version": api_version,
+            "api_key": api_key,
+            "callbacks": callbacks,
+            "context_window_size": context_window_size,
+            "kwargs": kwargs,
+        }
+        super().__init__(**init_dict)
+
+        # Initialize model with default value
+        self.model = "gpt-4"  # Default fallback
+
+        # Extract and validate model name
+        if isinstance(model, LLM):
+            # Extract and validate model name from LLM instance
+            if hasattr(model, 'model'):
+                if isinstance(model.model, str):
+                    self.model = model.model
+                else:
+                    # Try to extract string model name from nested LLM
+                    if isinstance(model.model, LLM):
+                        self.model = str(model.model.model) if hasattr(model.model, 'model') else "gpt-4"
+                    else:
+                        self.model = "gpt-4"
+                        if self.logger:
+                            self.logger.warning("Nested LLM model is not a string, using default: gpt-4")
+            else:
+                self.model = "gpt-4"
+                if self.logger:
+                    self.logger.warning("LLM instance has no model attribute, using default: gpt-4")
+        else:
+            # Extract and validate model name for non-LLM instances
+            if not isinstance(model, str):
+                if self.logger:
+                    self.logger.debug(f"Model is not a string, attempting to extract name. Type: {type(model)}")
+                if model is not None:
+                    if hasattr(model, 'model_name'):
+                        model_name = getattr(model, 'model_name', None)
+                        self.model = str(model_name) if model_name is not None else "gpt-4"
+                    elif hasattr(model, 'model'):
+                        model_attr = getattr(model, 'model', None)
+                        self.model = str(model_attr) if model_attr is not None else "gpt-4"
+                    elif hasattr(model, '_model_name'):
+                        model_name = getattr(model, '_model_name', None)
+                        self.model = str(model_name) if model_name is not None else "gpt-4"
+                    else:
+                        self.model = "gpt-4"  # Default fallback
+                        if self.logger:
+                            self.logger.warning(f"Could not extract model name from {type(model)}, using default: {self.model}")
+                else:
+                    self.model = "gpt-4"  # Default fallback for None
+                    if self.logger:
+                        self.logger.warning("Model is None, using default: gpt-4")
+            else:
+                self.model = str(model)  # Ensure it's a string
+
         # If model is an LLM instance, copy its configuration
         if isinstance(model, LLM):
-            self.model = model.model
+            # Extract and validate model name first
+            if hasattr(model, 'model'):
+                if isinstance(model.model, str):
+                    self.model = model.model
+                else:
+                    # Try to extract string model name from nested LLM
+                    if isinstance(model.model, LLM):
+                        self.model = str(model.model.model) if hasattr(model.model, 'model') else "gpt-4"
+                    else:
+                        self.model = "gpt-4"
+                        if self.logger:
+                            self.logger.warning("Nested LLM model is not a string, using default: gpt-4")
+            else:
+                self.model = "gpt-4"
+                if self.logger:
+                    self.logger.warning("LLM instance has no model attribute, using default: gpt-4")
+
+            # Copy other configuration
             self.timeout = model.timeout
             self.temperature = model.temperature
             self.top_p = model.top_p
@@ -140,8 +256,44 @@ class LLM:
             self.callbacks = model.callbacks
             self.context_window_size = model.context_window_size
             self.kwargs = model.kwargs
+
+            # Final validation of model name
+            if not isinstance(self.model, str):
+                self.model = "gpt-4"
+                if self.logger:
+                    self.logger.warning("Model name is still not a string after LLM copy, using default: gpt-4")
         else:
-            self.model = model
+            # Extract and validate model name for non-LLM instances
+            if not isinstance(model, str):
+                if self.logger:
+                    self.logger.debug(f"Model is not a string, attempting to extract name. Type: {type(model)}")
+                if model is not None:
+                    if hasattr(model, 'model_name'):
+                        model_name = getattr(model, 'model_name', None)
+                        self.model = str(model_name) if model_name is not None else "gpt-4"
+                    elif hasattr(model, 'model'):
+                        model_attr = getattr(model, 'model', None)
+                        self.model = str(model_attr) if model_attr is not None else "gpt-4"
+                    elif hasattr(model, '_model_name'):
+                        model_name = getattr(model, '_model_name', None)
+                        self.model = str(model_name) if model_name is not None else "gpt-4"
+                    else:
+                        self.model = "gpt-4"  # Default fallback
+                        if self.logger:
+                            self.logger.warning(f"Could not extract model name from {type(model)}, using default: {self.model}")
+                else:
+                    self.model = "gpt-4"  # Default fallback for None
+                    if self.logger:
+                        self.logger.warning("Model is None, using default: gpt-4")
+            else:
+                self.model = str(model)  # Ensure it's a string
+
+            # Final validation
+            if not isinstance(self.model, str):
+                self.model = "gpt-4"
+                if self.logger:
+                    self.logger.warning("Model name is still not a string after extraction, using default: gpt-4")
+
             self.timeout = timeout
             self.temperature = temperature
             self.top_p = top_p
@@ -163,69 +315,155 @@ class LLM:
             self.context_window_size = 0
             self.kwargs = kwargs
 
+        # Ensure model is a string after initialization
+        if not isinstance(self.model, str):
+            self.model = "gpt-4"
+            self.logger.warning(f"Model is still not a string after initialization, using default: {self.model}")
+
         litellm.drop_params = True
 
         self.set_callbacks(callbacks)
         self.set_env_callbacks()
 
-    def call(self, messages: List[Dict[str, str]], callbacks: List[Any] = []) -> str:
+    def call(
+        self,
+        messages: List[Dict[str, str]],
+        callbacks: Optional[List[Any]] = None
+    ) -> str:
         with suppress_warnings():
             if callbacks and len(callbacks) > 0:
                 self.set_callbacks(callbacks)
 
-            try:
-                # Ensure model is a string and set default
-                model_name = "gpt-4"  # Default model
-                # Extract model name from self.model
-                current = self.model
-                while current is not None:
-                    if isinstance(current, str):
-                        model_name = current
-                        break
-                    elif isinstance(current, LLM):
-                        current = current.model
-                    elif hasattr(current, "model"):
-                        current = getattr(current, "model")
-                    else:
-                        break
-                # Set parameters for litellm
-                # Build base params dict with required fields
+            # Store original model to restore later
+            original_model = self.model
+
+            try:
+                # Ensure model is a string before making the call
+                if not isinstance(self.model, str):
+                    if self.logger:
+                        self.logger.warning(f"Model is not a string in call method: {type(self.model)}. Attempting to convert...")
+                    if isinstance(self.model, LLM):
+                        self.model = self.model.model if isinstance(self.model.model, str) else "gpt-4"
+                    elif hasattr(self.model, 'model_name'):
+                        self.model = str(self.model.model_name)
+                    elif hasattr(self.model, 'model'):
+                        if isinstance(self.model.model, str):
+                            self.model = str(self.model.model)
+                        elif hasattr(self.model.model, 'model_name'):
+                            self.model = str(self.model.model.model_name)
+                        else:
+                            self.model = "gpt-4"
+                            if self.logger:
+                                self.logger.warning("Could not extract model name from nested model object, using default: gpt-4")
+                    else:
+                        self.model = "gpt-4"
+                        if self.logger:
+                            self.logger.warning("Could not extract model name, using default: gpt-4")
+
+                if self.logger:
+                    self.logger.debug(f"Using model: {self.model} (type: {type(self.model)}) for LiteLLM call")
+
+                # Create base params with validated model name
+                # Extract model name string
+                model_name = None
+                if isinstance(self.model, str):
+                    model_name = self.model
+                elif hasattr(self.model, 'model_name'):
+                    model_name = str(self.model.model_name)
+                elif hasattr(self.model, 'model'):
+                    if isinstance(self.model.model, str):
+                        model_name = str(self.model.model)
+                    elif hasattr(self.model.model, 'model_name'):
+                        model_name = str(self.model.model.model_name)
+
+                if not model_name:
+                    model_name = "gpt-4"
+                    if self.logger:
+                        self.logger.warning("Could not extract model name, using default: gpt-4")
+
                 params = {
                     "model": model_name,
-                    "custom_llm_provider": "openai",
                     "messages": messages,
-                    "stream": False  # Always set stream to False
+                    "stream": False,
+                    "api_key": self.api_key or os.getenv("OPENAI_API_KEY"),
+                    "api_base": self.base_url,
+                    "api_version": self.api_version,
                 }
 
-                # Add API configuration
+                if self.logger:
+                    self.logger.debug(f"Using model parameters: {params}")
+
+                # Add API configuration if available
                 api_key = self.api_key or os.getenv("OPENAI_API_KEY")
                 if api_key:
                     params["api_key"] = api_key
 
-                # Define optional parameters
-                optional_params = {
-                    "timeout": self.timeout,
-                    "temperature": self.temperature,
-                    "top_p": self.top_p,
-                    "n": self.n,
-                    "stop": self.stop,
-                    "max_tokens": self.max_tokens or self.max_completion_tokens,
-                    "presence_penalty": self.presence_penalty,
-                    "frequency_penalty": self.frequency_penalty,
-                    "logit_bias": self.logit_bias,
-                    "response_format": self.response_format,
-                    "seed": self.seed,
-                    "logprobs": self.logprobs,
-                    "top_logprobs": self.top_logprobs,
-                }
+                # Try to get supported parameters for the model
+                try:
+                    supported_params = get_supported_openai_params(self.model)
+                    optional_params = {}
+
+                    if supported_params:
+                        param_mapping = {
+                            "timeout": self.timeout,
+                            "temperature": self.temperature,
+                            "top_p": self.top_p,
+                            "n": self.n,
+                            "stop": self.stop,
+                            "max_tokens": self.max_tokens or self.max_completion_tokens,
+                            "presence_penalty": self.presence_penalty,
+                            "frequency_penalty": self.frequency_penalty,
+                            "logit_bias": self.logit_bias,
+                            "response_format": self.response_format,
+                            "seed": self.seed,
+                            "logprobs": self.logprobs,
+                            "top_logprobs": self.top_logprobs
+                        }
+
+                        # Only add parameters that are supported and not None
+                        optional_params = {
+                            k: v for k, v in param_mapping.items()
+                            if k in supported_params and v is not None
+                        }
+                        if "logprobs" in supported_params and self.logprobs is not None:
+                            optional_params["logprobs"] = self.logprobs
+                        if "top_logprobs" in supported_params and self.top_logprobs is not None:
+                            optional_params["top_logprobs"] = self.top_logprobs
+                except Exception as e:
+                    if self.logger:
+                        self.logger.error(f"Failed to get supported params for model {self.model}: {str(e)}")
+                    # If we can't get supported params, just add non-None parameters
+                    param_mapping = {
+                        "timeout": self.timeout,
+                        "temperature": self.temperature,
+                        "top_p": self.top_p,
+                        "n": self.n,
+                        "stop": self.stop,
+                        "max_tokens": self.max_tokens or self.max_completion_tokens,
+                        "presence_penalty": self.presence_penalty,
+                        "frequency_penalty": self.frequency_penalty,
+                        "logit_bias": self.logit_bias,
+                        "response_format": self.response_format,
+                        "seed": self.seed,
+                        "logprobs": self.logprobs,
+                        "top_logprobs": self.top_logprobs
+                    }
+                    optional_params = {k: v for k, v in param_mapping.items() if v is not None}
+
+                # Update params with optional parameters
+                params.update(optional_params)
 
                 # Add API endpoint configuration if available
                 if self.base_url:
-                    optional_params["api_base"] = self.base_url
+                    params["api_base"] = self.base_url
                 if self.api_version:
-                    optional_params["api_version"] = self.api_version
+                    params["api_version"] = self.api_version
+
+                # Final validation of model parameter
+                if not isinstance(params["model"], str):
+                    if self.logger:
+                        self.logger.error(f"Model is still not a string after all conversions: {type(params['model'])}")
+                    params["model"] = "gpt-4"
 
                 # Update params with non-None optional parameters
                 params.update({k: v for k, v in optional_params.items() if v is not None})
@@ -238,21 +476,38 @@ class LLM:
                 params = {k: v for k, v in params.items() if v is not None}
 
                 response = litellm.completion(**params)
-                return response["choices"][0]["message"]["content"]
+                content = response["choices"][0]["message"]["content"]
+
+                # Extract usage metrics
+                usage = response.get("usage", {})
+                if callbacks:
+                    for callback in callbacks:
+                        if hasattr(callback, "update_token_usage"):
+                            callback.update_token_usage(usage)
+
+                return content
             except Exception as e:
                 if not LLMContextLengthExceededException(
                     str(e)
                 )._is_context_limit_error(str(e)):
                     logging.error(f"LiteLLM call failed: {str(e)}")
-
                 raise  # Re-raise the exception after logging
+            finally:
+                # Always restore the original model object
+                self.model = original_model
 
     def supports_function_calling(self) -> bool:
+        """Check if the LLM supports function calling.
+
+        Returns:
+            bool: True if the model supports function calling, False otherwise
+        """
         try:
             params = get_supported_openai_params(model=self.model)
             return "response_format" in params
         except Exception as e:
-            logging.error(f"Failed to get supported params: {str(e)}")
+            if self.logger:
+                self.logger.error(f"Failed to get supported params: {str(e)}")
             return False
 
     def supports_stop_words(self) -> bool:
@@ -264,33 +519,47 @@ class LLM:
             params = get_supported_openai_params(model=self.model)
             return "stop" in params
         except Exception as e:
-            logging.error(f"Failed to get supported params: {str(e)}")
+            if self.logger:
+                self.logger.error(f"Failed to get supported params: {str(e)}")
             return False
 
     def get_context_window_size(self) -> int:
+        """Get the context window size for the current model.
+
+        Returns:
+            int: The context window size in tokens
+        """
         # Only using 75% of the context window size to avoid cutting the message in the middle
-        if self.context_window_size != 0:
-            return self.context_window_size
+        if self.context_window_size is not None and self.context_window_size != 0:
+            return int(self.context_window_size)
 
-        self.context_window_size = int(
-            DEFAULT_CONTEXT_WINDOW_SIZE * CONTEXT_WINDOW_USAGE_RATIO
-        )
-        for key, value in LLM_CONTEXT_WINDOW_SIZES.items():
-            if self.model.startswith(key):
-                self.context_window_size = int(value * CONTEXT_WINDOW_USAGE_RATIO)
+        window_size = DEFAULT_CONTEXT_WINDOW_SIZE
+        if isinstance(self.model, str):
+            for key, value in LLM_CONTEXT_WINDOW_SIZES.items():
+                if self.model.startswith(key):
+                    window_size = value
+                    break
+
+        self.context_window_size = int(window_size * CONTEXT_WINDOW_USAGE_RATIO)
         return self.context_window_size
 
-    def set_callbacks(self, callbacks: List[Any]):
-        callback_types = [type(callback) for callback in callbacks]
-        for callback in litellm.success_callback[:]:
-            if type(callback) in callback_types:
-                litellm.success_callback.remove(callback)
+    def set_callbacks(self, callbacks: Optional[List[Any]] = None) -> None:
+        """Set callbacks for the LLM.
+
+        Args:
+            callbacks: Optional list of callback functions. If None, no callbacks will be set.
+        """
+        if callbacks is not None:
+            callback_types = [type(callback) for callback in callbacks]
+            for callback in litellm.success_callback[:]:
+                if type(callback) in callback_types:
+                    litellm.success_callback.remove(callback)
 
         for callback in litellm._async_success_callback[:]:
             if type(callback) in callback_types:
                 litellm._async_success_callback.remove(callback)
 
         litellm.callbacks = callbacks
 
     def set_env_callbacks(self):
         """
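For reference, a minimal usage sketch of the patched class (hypothetical example, not part of the commit; assumes OPENAI_API_KEY is set in the environment):

from crewai import LLM

# Construct with a plain string model name; omitted settings keep their Optional defaults.
llm = LLM(model="gpt-4", temperature=0.2)
print(llm.supports_function_calling())
print(llm.call([{"role": "user", "content": "Say hello in one word."}]))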