mirror of
https://github.com/crewAIInc/crewAI.git
synced 2026-05-03 08:12:39 +00:00
implementing initial LLM class
This commit is contained in:
@@ -1,20 +1,86 @@
|
||||
from typing import Any, Dict, List
|
||||
from litellm import completion
|
||||
from typing import Any, Dict, List, Optional, Union
|
||||
import logging
|
||||
import litellm
|
||||
|
||||
|
||||
class LLM:
|
||||
def __init__(self, model: str, stop: List[str] = [], callbacks: List[Any] = []):
|
||||
self.stop = stop
|
||||
def __init__(
|
||||
self,
|
||||
model: str,
|
||||
timeout: Optional[Union[float, int]] = None,
|
||||
temperature: Optional[float] = None,
|
||||
top_p: Optional[float] = None,
|
||||
n: Optional[int] = None,
|
||||
stop: Optional[Union[str, List[str]]] = None,
|
||||
max_completion_tokens: Optional[int] = None,
|
||||
max_tokens: Optional[int] = None,
|
||||
presence_penalty: Optional[float] = None,
|
||||
frequency_penalty: Optional[float] = None,
|
||||
logit_bias: Optional[Dict[int, float]] = None,
|
||||
response_format: Optional[Dict[str, Any]] = None,
|
||||
seed: Optional[int] = None,
|
||||
logprobs: Optional[bool] = None,
|
||||
top_logprobs: Optional[int] = None,
|
||||
base_url: Optional[str] = None,
|
||||
api_version: Optional[str] = None,
|
||||
api_key: Optional[str] = None,
|
||||
callbacks: List[Any] = [],
|
||||
**kwargs,
|
||||
):
|
||||
self.model = model
|
||||
self.timeout = timeout
|
||||
self.temperature = temperature
|
||||
self.top_p = top_p
|
||||
self.n = n
|
||||
self.stop = stop
|
||||
self.max_completion_tokens = max_completion_tokens
|
||||
self.max_tokens = max_tokens
|
||||
self.presence_penalty = presence_penalty
|
||||
self.frequency_penalty = frequency_penalty
|
||||
self.logit_bias = logit_bias
|
||||
self.response_format = response_format
|
||||
self.seed = seed
|
||||
self.logprobs = logprobs
|
||||
self.top_logprobs = top_logprobs
|
||||
self.base_url = base_url
|
||||
self.api_version = api_version
|
||||
self.api_key = api_key
|
||||
self.callbacks = callbacks
|
||||
self.kwargs = kwargs
|
||||
|
||||
litellm.callbacks = callbacks
|
||||
|
||||
def call(self, messages: List[Dict[str, str]]) -> Dict[str, Any]:
|
||||
response = completion(
|
||||
stop=self.stop, model=self.model, messages=messages, num_retries=5
|
||||
)
|
||||
return response["choices"][0]["message"]["content"]
|
||||
def call(self, messages: List[Dict[str, str]]) -> str:
|
||||
try:
|
||||
params = {
|
||||
"model": self.model,
|
||||
"messages": messages,
|
||||
"timeout": self.timeout,
|
||||
"temperature": self.temperature,
|
||||
"top_p": self.top_p,
|
||||
"n": self.n,
|
||||
"stop": self.stop,
|
||||
"max_tokens": self.max_tokens or self.max_completion_tokens,
|
||||
"presence_penalty": self.presence_penalty,
|
||||
"frequency_penalty": self.frequency_penalty,
|
||||
"logit_bias": self.logit_bias,
|
||||
"response_format": self.response_format,
|
||||
"seed": self.seed,
|
||||
"logprobs": self.logprobs,
|
||||
"top_logprobs": self.top_logprobs,
|
||||
"api_base": self.base_url,
|
||||
"api_version": self.api_version,
|
||||
"api_key": self.api_key,
|
||||
**self.kwargs,
|
||||
}
|
||||
# Remove None values to avoid passing unnecessary parameters
|
||||
params = {k: v for k, v in params.items() if v is not None}
|
||||
|
||||
def _call_callbacks(self, formatted_answer):
|
||||
for callback in self.callbacks:
|
||||
callback(formatted_answer)
|
||||
response = litellm.completion(**params)
|
||||
return response["choices"][0]["message"]["content"]
|
||||
except Exception as e:
|
||||
logging.error(f"LiteLLM call failed: {str(e)}")
|
||||
raise # Re-raise the exception after logging
|
||||
|
||||
def __getattr__(self, name):
|
||||
return self.kwargs.get(name)
|
||||
|
||||
Reference in New Issue
Block a user