mirror of
https://github.com/crewAIInc/crewAI.git
synced 2026-05-03 00:02:36 +00:00
Adding new LLM class
This commit is contained in:
@@ -1,6 +1,6 @@
|
||||
import os
|
||||
from inspect import signature
|
||||
from typing import Any, List, Optional
|
||||
from typing import Any, List, Optional, Union
|
||||
from pydantic import Field, InstanceOf, PrivateAttr, model_validator
|
||||
|
||||
from crewai.agents import CacheHandler
|
||||
@@ -12,6 +12,7 @@ from crewai.memory.contextual.contextual_memory import ContextualMemory
|
||||
from crewai.utilities.constants import TRAINED_AGENTS_DATA_FILE, TRAINING_DATA_FILE
|
||||
from crewai.utilities.training_handler import CrewTrainingHandler
|
||||
from crewai.utilities.token_counter_callback import TokenCalcHandler
|
||||
from crewai.llm import LLM
|
||||
|
||||
|
||||
def mock_agent_ops_provider():
|
||||
@@ -81,8 +82,8 @@ class Agent(BaseAgent):
|
||||
default=True,
|
||||
description="Use system prompt for the agent.",
|
||||
)
|
||||
llm: Any = Field(
|
||||
description="Language model that will run the agent.", default="gpt-4o-mini"
|
||||
llm: Union[str, InstanceOf[LLM], Any] = Field(
|
||||
description="Language model that will run the agent.", default=None
|
||||
)
|
||||
function_calling_llm: Optional[Any] = Field(
|
||||
description="Language model that will run the agent.", default=None
|
||||
@@ -118,17 +119,58 @@ class Agent(BaseAgent):
|
||||
@model_validator(mode="after")
|
||||
def post_init_setup(self):
|
||||
self.agent_ops_agent_name = self.role
|
||||
self.llm = (
|
||||
getattr(self.llm, "model_name", None)
|
||||
or getattr(self.llm, "deployment_name", None)
|
||||
or self.llm
|
||||
or os.environ.get("OPENAI_MODEL_NAME")
|
||||
)
|
||||
self.function_calling_llm = (
|
||||
getattr(self.function_calling_llm, "model_name", None)
|
||||
or getattr(self.function_calling_llm, "deployment_name", None)
|
||||
or self.function_calling_llm
|
||||
)
|
||||
|
||||
# Handle different cases for self.llm
|
||||
if isinstance(self.llm, str):
|
||||
# If it's a string, create an LLM instance
|
||||
self.llm = LLM(model=self.llm)
|
||||
elif isinstance(self.llm, LLM):
|
||||
# If it's already an LLM instance, keep it as is
|
||||
pass
|
||||
elif self.llm is None:
|
||||
# If it's None, use environment variables or default
|
||||
model_name = os.environ.get("OPENAI_MODEL_NAME", "gpt-4o-mini")
|
||||
llm_params = {"model": model_name}
|
||||
|
||||
api_base = os.environ.get("OPENAI_API_BASE")
|
||||
if api_base:
|
||||
llm_params["base_url"] = api_base
|
||||
|
||||
api_key = os.environ.get("OPENAI_API_KEY")
|
||||
if api_key:
|
||||
llm_params["api_key"] = api_key
|
||||
|
||||
self.llm = LLM(**llm_params)
|
||||
else:
|
||||
# For any other type, attempt to extract relevant attributes
|
||||
llm_params = {
|
||||
"model": getattr(self.llm, "model_name", None)
|
||||
or getattr(self.llm, "deployment_name", None)
|
||||
or str(self.llm),
|
||||
"temperature": getattr(self.llm, "temperature", None),
|
||||
"max_tokens": getattr(self.llm, "max_tokens", None),
|
||||
"logprobs": getattr(self.llm, "logprobs", None),
|
||||
"timeout": getattr(self.llm, "timeout", None),
|
||||
"max_retries": getattr(self.llm, "max_retries", None),
|
||||
"api_key": getattr(self.llm, "api_key", None),
|
||||
"base_url": getattr(self.llm, "base_url", None),
|
||||
"organization": getattr(self.llm, "organization", None),
|
||||
}
|
||||
# Remove None values to avoid passing unnecessary parameters
|
||||
llm_params = {k: v for k, v in llm_params.items() if v is not None}
|
||||
self.llm = LLM(**llm_params)
|
||||
|
||||
# Similar handling for function_calling_llm
|
||||
if self.function_calling_llm:
|
||||
if isinstance(self.function_calling_llm, str):
|
||||
self.function_calling_llm = LLM(model=self.function_calling_llm)
|
||||
elif not isinstance(self.function_calling_llm, LLM):
|
||||
self.function_calling_llm = LLM(
|
||||
model=getattr(self.function_calling_llm, "model_name", None)
|
||||
or getattr(self.function_calling_llm, "deployment_name", None)
|
||||
or str(self.function_calling_llm)
|
||||
)
|
||||
|
||||
if not self.agent_executor:
|
||||
self._setup_agent_executor()
|
||||
|
||||
|
||||
Reference in New Issue
Block a user