Fix issue 2216: Pass custom endpoint to LiteLLM for Ollama provider

Co-Authored-By: Joe Moura <joao@crewai.com>
This commit is contained in:
Devin AI
2025-02-25 10:47:09 +00:00
parent 5bae78639e
commit 9e5084ba22
2 changed files with 29 additions and 5 deletions

View File

@@ -111,15 +111,20 @@ class Agent(BaseAgent):
default=None,
description="Embedder configuration for the agent.",
)
endpoint: Optional[str] = Field(
default=None,
description="Custom endpoint URL for the LLM API.",
)
@model_validator(mode="after")
def post_init_setup(self):
self._set_knowledge()
self.agent_ops_agent_name = self.role
self.llm = create_llm(self.llm)
# Pass endpoint to create_llm if it exists
self.llm = create_llm(self.llm, endpoint=self.endpoint)
if self.function_calling_llm and not isinstance(self.function_calling_llm, LLM):
self.function_calling_llm = create_llm(self.function_calling_llm)
self.function_calling_llm = create_llm(self.function_calling_llm, endpoint=self.endpoint)
if not self.agent_executor:
self._setup_agent_executor()

View File

@@ -7,6 +7,7 @@ from crewai.llm import LLM
def create_llm(
llm_value: Union[str, LLM, Any, None] = None,
endpoint: Optional[str] = None,
) -> Optional[LLM]:
"""
Creates or returns an LLM instance based on the given llm_value.
@@ -17,6 +18,8 @@ def create_llm(
- LLM: Already instantiated LLM, returned as-is.
- Any: Attempt to extract known attributes like model_name, temperature, etc.
- None: Use environment-based or fallback default model.
endpoint (str | None):
- Optional endpoint URL for the LLM API.
Returns:
An LLM instance if successful, or None if something fails.
@@ -29,7 +32,11 @@ def create_llm(
# 2) If llm_value is a string (model name)
if isinstance(llm_value, str):
try:
created_llm = LLM(model=llm_value)
# If endpoint is provided and model is Ollama, use it as api_base
if endpoint and "ollama" in llm_value.lower():
created_llm = LLM(model=llm_value, api_base=endpoint)
else:
created_llm = LLM(model=llm_value)
return created_llm
except Exception as e:
print(f"Failed to instantiate LLM with model='{llm_value}': {e}")
@@ -37,7 +44,7 @@ def create_llm(
# 3) If llm_value is None, parse environment variables or use default
if llm_value is None:
return _llm_via_environment_or_fallback()
return _llm_via_environment_or_fallback(endpoint=endpoint)
# 4) Otherwise, attempt to extract relevant attributes from an unknown object
try:
@@ -55,6 +62,10 @@ def create_llm(
base_url: Optional[str] = getattr(llm_value, "base_url", None)
api_base: Optional[str] = getattr(llm_value, "api_base", None)
# If endpoint is provided and model is Ollama, use it as api_base
if endpoint and "ollama" in str(model).lower():
api_base = endpoint
created_llm = LLM(
model=model,
temperature=temperature,
@@ -71,9 +82,13 @@ def create_llm(
return None
def _llm_via_environment_or_fallback() -> Optional[LLM]:
def _llm_via_environment_or_fallback(endpoint: Optional[str] = None) -> Optional[LLM]:
"""
Helper function: if llm_value is None, we load environment variables or fallback default model.
Args:
endpoint (str | None):
- Optional endpoint URL for the LLM API.
"""
model_name = (
os.environ.get("OPENAI_MODEL_NAME")
@@ -172,6 +187,10 @@ def _llm_via_environment_or_fallback() -> Optional[LLM]:
# Remove None values
llm_params = {k: v for k, v in llm_params.items() if v is not None}
# If endpoint is provided and model is Ollama, use it as api_base
if endpoint and "ollama" in model_name.lower():
llm_params["api_base"] = endpoint
# Try creating the LLM
try:
new_llm = LLM(**llm_params)