Adding new LLM class

This commit is contained in:
João Moura
2024-09-23 03:59:05 -03:00
parent 59e51f18fd
commit a19a4a5556
9 changed files with 124 additions and 93 deletions

View File

@@ -1,6 +1,6 @@
import os
from inspect import signature
from typing import Any, List, Optional
from typing import Any, List, Optional, Union
from pydantic import Field, InstanceOf, PrivateAttr, model_validator
from crewai.agents import CacheHandler
@@ -12,6 +12,7 @@ from crewai.memory.contextual.contextual_memory import ContextualMemory
from crewai.utilities.constants import TRAINED_AGENTS_DATA_FILE, TRAINING_DATA_FILE
from crewai.utilities.training_handler import CrewTrainingHandler
from crewai.utilities.token_counter_callback import TokenCalcHandler
from crewai.llm import LLM
def mock_agent_ops_provider():
@@ -81,8 +82,8 @@ class Agent(BaseAgent):
default=True,
description="Use system prompt for the agent.",
)
llm: Any = Field(
description="Language model that will run the agent.", default="gpt-4o-mini"
llm: Union[str, InstanceOf[LLM], Any] = Field(
description="Language model that will run the agent.", default=None
)
function_calling_llm: Optional[Any] = Field(
description="Language model that will run the agent.", default=None
@@ -118,17 +119,58 @@ class Agent(BaseAgent):
@model_validator(mode="after")
def post_init_setup(self):
self.agent_ops_agent_name = self.role
self.llm = (
getattr(self.llm, "model_name", None)
or getattr(self.llm, "deployment_name", None)
or self.llm
or os.environ.get("OPENAI_MODEL_NAME")
)
self.function_calling_llm = (
getattr(self.function_calling_llm, "model_name", None)
or getattr(self.function_calling_llm, "deployment_name", None)
or self.function_calling_llm
)
# Handle different cases for self.llm
if isinstance(self.llm, str):
# If it's a string, create an LLM instance
self.llm = LLM(model=self.llm)
elif isinstance(self.llm, LLM):
# If it's already an LLM instance, keep it as is
pass
elif self.llm is None:
# If it's None, use environment variables or default
model_name = os.environ.get("OPENAI_MODEL_NAME", "gpt-4o-mini")
llm_params = {"model": model_name}
api_base = os.environ.get("OPENAI_API_BASE")
if api_base:
llm_params["base_url"] = api_base
api_key = os.environ.get("OPENAI_API_KEY")
if api_key:
llm_params["api_key"] = api_key
self.llm = LLM(**llm_params)
else:
# For any other type, attempt to extract relevant attributes
llm_params = {
"model": getattr(self.llm, "model_name", None)
or getattr(self.llm, "deployment_name", None)
or str(self.llm),
"temperature": getattr(self.llm, "temperature", None),
"max_tokens": getattr(self.llm, "max_tokens", None),
"logprobs": getattr(self.llm, "logprobs", None),
"timeout": getattr(self.llm, "timeout", None),
"max_retries": getattr(self.llm, "max_retries", None),
"api_key": getattr(self.llm, "api_key", None),
"base_url": getattr(self.llm, "base_url", None),
"organization": getattr(self.llm, "organization", None),
}
# Remove None values to avoid passing unnecessary parameters
llm_params = {k: v for k, v in llm_params.items() if v is not None}
self.llm = LLM(**llm_params)
# Similar handling for function_calling_llm
if self.function_calling_llm:
if isinstance(self.function_calling_llm, str):
self.function_calling_llm = LLM(model=self.function_calling_llm)
elif not isinstance(self.function_calling_llm, LLM):
self.function_calling_llm = LLM(
model=getattr(self.function_calling_llm, "model_name", None)
or getattr(self.function_calling_llm, "deployment_name", None)
or str(self.function_calling_llm)
)
if not self.agent_executor:
self._setup_agent_executor()