mirror of https://github.com/crewAIInc/crewAI.git (synced 2026-01-15 02:58:30 +00:00)
implementing initial LLM class
@@ -104,11 +104,23 @@ class CrewAgentExecutor(CrewAgentExecutorMixin):
         try:
             while not isinstance(formatted_answer, AgentFinish):
                 if not self.request_within_rpm_limit or self.request_within_rpm_limit():
-                    answer = LLM(
-                        self.llm,
-                        stop=self.stop if self.use_stop_words else None,
-                        callbacks=self.callbacks,
-                    ).call(self.messages)
+                    if isinstance(self.llm, str):
+                        llm = LLM(
+                            model=self.llm,
+                            stop=self.stop if self.use_stop_words else None,
+                            callbacks=self.callbacks,
+                        )
+                    elif isinstance(self.llm, LLM):
+                        llm = self.llm
+                    else:
+                        llm = LLM(
+                            model=self.llm.model,
+                            provider=getattr(self.llm, "provider", "litellm"),
+                            stop=self.stop if self.use_stop_words else None,
+                            callbacks=self.callbacks,
+                            **getattr(self.llm, "llm_kwargs", {}),
+                        )
+                    answer = llm.call(self.messages)

                     if not self.use_stop_words:
                         try:
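Note: the executor now resolves self.llm (a model-name string, an LLM instance, or a legacy LLM object) into a concrete LLM before each call. A minimal standalone sketch of that resolution, assuming the class is importable as "from crewai.llm import LLM"; the helper name normalize_llm is hypothetical, while the provider/llm_kwargs fallbacks mirror the diff:

from typing import Any, List, Optional, Union

from crewai.llm import LLM  # assumption: module path of the new class


def normalize_llm(
    llm: Union[str, LLM, Any],
    stop: Optional[List[str]] = None,
    callbacks: Optional[List[Any]] = None,
) -> LLM:
    """Hypothetical helper mirroring the branch logic in the diff."""
    if isinstance(llm, str):
        # Bare model name, e.g. "gpt-4o-mini": wrap it.
        return LLM(model=llm, stop=stop, callbacks=callbacks or [])
    if isinstance(llm, LLM):
        # Already the new class: use as-is.
        return llm
    # Legacy object: salvage its model name plus any extra settings it carries.
    return LLM(
        model=llm.model,
        provider=getattr(llm, "provider", "litellm"),
        stop=stop,
        callbacks=callbacks or [],
        **getattr(llm, "llm_kwargs", {}),
    )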
@@ -241,7 +253,16 @@ class CrewAgentExecutor(CrewAgentExecutorMixin):
         return tool_result

     def _summarize_messages(self) -> None:
-        llm = LLM(self.llm)
+        if isinstance(self.llm, str):
+            llm = LLM(model=self.llm)
+        elif isinstance(self.llm, LLM):
+            llm = self.llm
+        else:
+            llm = LLM(
+                model=self.llm.model,
+                provider=getattr(self.llm, "provider", "litellm"),
+                **getattr(self.llm, "llm_kwargs", {}),
+            )
         messages_groups = []

         for message in self.messages:
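The same three-way resolution recurs here, minus the stop-word handling. A small sketch of the legacy branch it guards, under the same import assumption as the sketch above; the attribute names model, provider, and llm_kwargs come from the diff, the values are made up:

class OldStyleLLM:
    """Hypothetical pre-refactor object, standing in for e.g. a LangChain model."""
    model = "gpt-4o-mini"              # assumption: any litellm-routable name
    provider = "openai"
    llm_kwargs = {"temperature": 0.3}  # extra settings forwarded as **kwargs


old = OldStyleLLM()
llm = LLM(
    model=old.model,
    provider=getattr(old, "provider", "litellm"),  # falls back to "litellm"
    **getattr(old, "llm_kwargs", {}),
)

Note that the new __init__ has no explicit provider parameter, so provider rides along in **kwargs and is forwarded to litellm.completion on each call.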
@@ -1,20 +1,86 @@
-from typing import Any, Dict, List
-from litellm import completion
+from typing import Any, Dict, List, Optional, Union
+import logging
+import litellm


 class LLM:
-    def __init__(self, model: str, stop: List[str] = [], callbacks: List[Any] = []):
-        self.stop = stop
+    def __init__(
+        self,
+        model: str,
+        timeout: Optional[Union[float, int]] = None,
+        temperature: Optional[float] = None,
+        top_p: Optional[float] = None,
+        n: Optional[int] = None,
+        stop: Optional[Union[str, List[str]]] = None,
+        max_completion_tokens: Optional[int] = None,
+        max_tokens: Optional[int] = None,
+        presence_penalty: Optional[float] = None,
+        frequency_penalty: Optional[float] = None,
+        logit_bias: Optional[Dict[int, float]] = None,
+        response_format: Optional[Dict[str, Any]] = None,
+        seed: Optional[int] = None,
+        logprobs: Optional[bool] = None,
+        top_logprobs: Optional[int] = None,
+        base_url: Optional[str] = None,
+        api_version: Optional[str] = None,
+        api_key: Optional[str] = None,
+        callbacks: List[Any] = [],
+        **kwargs,
+    ):
         self.model = model
+        self.timeout = timeout
+        self.temperature = temperature
+        self.top_p = top_p
+        self.n = n
+        self.stop = stop
+        self.max_completion_tokens = max_completion_tokens
+        self.max_tokens = max_tokens
+        self.presence_penalty = presence_penalty
+        self.frequency_penalty = frequency_penalty
+        self.logit_bias = logit_bias
+        self.response_format = response_format
+        self.seed = seed
+        self.logprobs = logprobs
+        self.top_logprobs = top_logprobs
+        self.base_url = base_url
+        self.api_version = api_version
+        self.api_key = api_key
         self.callbacks = callbacks
+        self.kwargs = kwargs
+
+        litellm.callbacks = callbacks

-    def call(self, messages: List[Dict[str, str]]) -> Dict[str, Any]:
-        response = completion(
-            stop=self.stop, model=self.model, messages=messages, num_retries=5
-        )
-        return response["choices"][0]["message"]["content"]
+    def call(self, messages: List[Dict[str, str]]) -> str:
+        try:
+            params = {
+                "model": self.model,
+                "messages": messages,
+                "timeout": self.timeout,
+                "temperature": self.temperature,
+                "top_p": self.top_p,
+                "n": self.n,
+                "stop": self.stop,
+                "max_tokens": self.max_tokens or self.max_completion_tokens,
+                "presence_penalty": self.presence_penalty,
+                "frequency_penalty": self.frequency_penalty,
+                "logit_bias": self.logit_bias,
+                "response_format": self.response_format,
+                "seed": self.seed,
+                "logprobs": self.logprobs,
+                "top_logprobs": self.top_logprobs,
+                "api_base": self.base_url,
+                "api_version": self.api_version,
+                "api_key": self.api_key,
+                **self.kwargs,
+            }
+            # Remove None values to avoid passing unnecessary parameters
+            params = {k: v for k, v in params.items() if v is not None}

-    def _call_callbacks(self, formatted_answer):
-        for callback in self.callbacks:
-            callback(formatted_answer)
+            response = litellm.completion(**params)
+            return response["choices"][0]["message"]["content"]
+        except Exception as e:
+            logging.error(f"LiteLLM call failed: {str(e)}")
+            raise  # Re-raise the exception after logging
+
+    def __getattr__(self, name):
+        return self.kwargs.get(name)
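A minimal usage sketch of the new class; the model name and prompt are assumptions, and it requires the litellm package plus valid provider credentials in the environment:

from crewai.llm import LLM  # assumption: module path of the new class

llm = LLM(
    model="gpt-4o-mini",      # assumption: any model name litellm can route
    temperature=0.2,
    max_tokens=64,
    stop=["\nObservation:"],  # mirrors the executor's stop-word usage
)

# call() returns the assistant message content as a plain string.
print(llm.call([{"role": "user", "content": "Say hello in one word."}]))

Two behaviors worth noting from the diff: unset parameters are stripped before the litellm.completion call, so provider defaults apply; and __getattr__ falls back to self.kwargs.get(name), so unknown attributes read as None instead of raising AttributeError.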
@@ -27,7 +27,8 @@ class Converter(OutputConverter):
         if self.is_gpt:
             return self._create_instructor().to_pydantic()
         else:
-            return LLM(model=self.llm).call(
+            llm = self._create_llm()
+            return llm.call(
                 [
                     {"role": "system", "content": self.instructions},
                     {"role": "user", "content": self.text},
@@ -46,8 +47,9 @@ class Converter(OutputConverter):
         if self.is_gpt:
             return self._create_instructor().to_json()
         else:
+            llm = self._create_llm()
             return json.dumps(
-                LLM(model=self.llm).call(
+                llm.call(
                     [
                         {"role": "system", "content": self.instructions},
                         {"role": "user", "content": self.text},
@@ -59,6 +61,19 @@ class Converter(OutputConverter):
                 return self.to_json(current_attempt + 1)
             return ConverterError(f"Failed to convert text into JSON, error: {e}.")

+    def _create_llm(self):
+        """Create an LLM instance."""
+        if isinstance(self.llm, str):
+            return LLM(model=self.llm)
+        elif isinstance(self.llm, LLM):
+            return self.llm
+        else:
+            return LLM(
+                model=self.llm.model,
+                provider=getattr(self.llm, "provider", "litellm"),
+                **getattr(self.llm, "llm_kwargs", {}),
+            )
+
     def _create_instructor(self):
         """Create an instructor."""
         from crewai.utilities import InternalInstructor
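For the non-GPT path, the converter now builds its LLM once via _create_llm and sends the same system/user pair both callers used before. Roughly, under the same import assumption as above; the model, instruction, and text values are made up:

llm = LLM(model="ollama/llama3")  # assumption: a non-GPT, litellm-routable model

raw = llm.call(
    [
        {"role": "system", "content": "Return the text as JSON with keys name and age."},
        {"role": "user", "content": "Jane is 31 years old."},
    ]
)
# to_json() in the diff then wraps this returned string with json.dumps(...)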