Fix issue 2216: Pass custom endpoint to LiteLLM for Ollama provider

Co-Authored-By: Joe Moura <joao@crewai.com>
Author: Devin AI
Date: 2025-02-25 10:47:09 +00:00
parent 5bae78639e
commit 9e5084ba22
2 changed files with 29 additions and 5 deletions

@@ -111,15 +111,20 @@ class Agent(BaseAgent):
         default=None,
         description="Embedder configuration for the agent.",
     )
+    endpoint: Optional[str] = Field(
+        default=None,
+        description="Custom endpoint URL for the LLM API.",
+    )
 
     @model_validator(mode="after")
     def post_init_setup(self):
         self._set_knowledge()
         self.agent_ops_agent_name = self.role
-        self.llm = create_llm(self.llm)
+        # Pass endpoint to create_llm if it exists
+        self.llm = create_llm(self.llm, endpoint=self.endpoint)
         if self.function_calling_llm and not isinstance(self.function_calling_llm, LLM):
-            self.function_calling_llm = create_llm(self.function_calling_llm)
+            self.function_calling_llm = create_llm(self.function_calling_llm, endpoint=self.endpoint)
         if not self.agent_executor:
             self._setup_agent_executor()

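For reference, a minimal usage sketch of the new `endpoint` field (not part of this commit; the role, goal, backstory, model name, and URL below are illustrative assumptions):

    from crewai import Agent

    # Hypothetical example: an agent backed by a local Ollama model.
    # The new `endpoint` field is forwarded to create_llm and, for Ollama
    # models, passed on to LiteLLM as api_base.
    agent = Agent(
        role="Researcher",
        goal="Summarize internal documents",
        backstory="Works exclusively against a self-hosted model.",
        llm="ollama/llama3",
        endpoint="http://192.168.1.50:11434",
    )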
@@ -7,6 +7,7 @@ from crewai.llm import LLM
 def create_llm(
     llm_value: Union[str, LLM, Any, None] = None,
+    endpoint: Optional[str] = None,
 ) -> Optional[LLM]:
     """
     Creates or returns an LLM instance based on the given llm_value.
@@ -17,6 +18,8 @@ def create_llm(
             - LLM: Already instantiated LLM, returned as-is.
             - Any: Attempt to extract known attributes like model_name, temperature, etc.
             - None: Use environment-based or fallback default model.
+        endpoint (str | None):
+            - Optional endpoint URL for the LLM API.
 
     Returns:
         An LLM instance if successful, or None if something fails.
@@ -29,7 +32,11 @@ def create_llm(
     # 2) If llm_value is a string (model name)
     if isinstance(llm_value, str):
         try:
-            created_llm = LLM(model=llm_value)
+            # If endpoint is provided and model is Ollama, use it as api_base
+            if endpoint and "ollama" in llm_value.lower():
+                created_llm = LLM(model=llm_value, api_base=endpoint)
+            else:
+                created_llm = LLM(model=llm_value)
             return created_llm
         except Exception as e:
             print(f"Failed to instantiate LLM with model='{llm_value}': {e}")
@@ -37,7 +44,7 @@ def create_llm(
     # 3) If llm_value is None, parse environment variables or use default
     if llm_value is None:
-        return _llm_via_environment_or_fallback()
+        return _llm_via_environment_or_fallback(endpoint=endpoint)
 
     # 4) Otherwise, attempt to extract relevant attributes from an unknown object
     try:
@@ -55,6 +62,10 @@ def create_llm(
         base_url: Optional[str] = getattr(llm_value, "base_url", None)
         api_base: Optional[str] = getattr(llm_value, "api_base", None)
 
+        # If endpoint is provided and model is Ollama, use it as api_base
+        if endpoint and "ollama" in str(model).lower():
+            api_base = endpoint
+
         created_llm = LLM(
             model=model,
             temperature=temperature,
@@ -71,9 +82,13 @@ def create_llm(
         return None
 
 
-def _llm_via_environment_or_fallback() -> Optional[LLM]:
+def _llm_via_environment_or_fallback(endpoint: Optional[str] = None) -> Optional[LLM]:
     """
     Helper function: if llm_value is None, we load environment variables or fallback default model.
+
+    Args:
+        endpoint (str | None):
+            - Optional endpoint URL for the LLM API.
     """
     model_name = (
         os.environ.get("OPENAI_MODEL_NAME")
@@ -172,6 +187,10 @@ def _llm_via_environment_or_fallback() -> Optional[LLM]:
     # Remove None values
     llm_params = {k: v for k, v in llm_params.items() if v is not None}
 
+    # If endpoint is provided and model is Ollama, use it as api_base
+    if endpoint and "ollama" in model_name.lower():
+        llm_params["api_base"] = endpoint
+
     # Try creating the LLM
     try:
         new_llm = LLM(**llm_params)
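
For an Ollama model string, the net effect of this change is equivalent to constructing the LLM with api_base set, which LiteLLM then uses as the Ollama server URL. A small sketch, not part of the diff (the LLM import path is taken from the hunk context above; the model name and URL are assumptions):

    from crewai.llm import LLM

    # Roughly what create_llm("ollama/llama3", endpoint="http://localhost:11434")
    # produces after this change: the custom endpoint is forwarded as api_base
    # because the model name contains "ollama".
    llm = LLM(model="ollama/llama3", api_base="http://localhost:11434")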