Compare commits

...

3 Commits

Author     SHA1         Message                                                                         Date
Devin AI   31449024a7   Fix linting issue in agent.py                                                   2025-02-25 10:59:42 +00:00
                        Co-Authored-By: Joe Moura <joao@crewai.com>
Devin AI   556da08555   Address PR review: Add validation and helper functions for endpoint handling    2025-02-25 10:57:55 +00:00
                        Co-Authored-By: Joe Moura <joao@crewai.com>
Devin AI   9e5084ba22   Fix issue 2216: Pass custom endpoint to LiteLLM for Ollama provider             2025-02-25 10:47:09 +00:00
                        Co-Authored-By: Joe Moura <joao@crewai.com>
2 changed files with 57 additions and 7 deletions

File: src/crewai/agent.py

@@ -3,7 +3,7 @@ import shutil
 import subprocess
 from typing import Any, Dict, List, Literal, Optional, Sequence, Union
 
-from pydantic import Field, InstanceOf, PrivateAttr, model_validator
+from pydantic import Field, InstanceOf, PrivateAttr, field_validator, model_validator
 
 from crewai.agents import CacheHandler
 from crewai.agents.agent_builder.base_agent import BaseAgent
@@ -111,15 +111,30 @@ class Agent(BaseAgent):
         default=None,
         description="Embedder configuration for the agent.",
     )
+    endpoint: Optional[str] = Field(
+        default=None,
+        description="Custom endpoint URL for the LLM API. Primarily used for Ollama models to specify alternative API endpoints.",
+        examples=["http://localhost:11434", "https://ollama.example.com:11434"],
+    )
+
+    @field_validator("endpoint")
+    def validate_endpoint(cls, v):
+        if v is not None:
+            if not isinstance(v, str):
+                raise ValueError("Endpoint must be a string")
+            if not v.startswith(("http://", "https://")):
+                raise ValueError("Endpoint must start with http:// or https://")
+        return v
+
     @model_validator(mode="after")
     def post_init_setup(self):
         self._set_knowledge()
         self.agent_ops_agent_name = self.role
-        self.llm = create_llm(self.llm)
+        # Pass endpoint to create_llm if it exists
+        self.llm = create_llm(self.llm, endpoint=self.endpoint)
         if self.function_calling_llm and not isinstance(self.function_calling_llm, LLM):
-            self.function_calling_llm = create_llm(self.function_calling_llm)
+            self.function_calling_llm = create_llm(self.function_calling_llm, endpoint=self.endpoint)
         if not self.agent_executor:
             self._setup_agent_executor()
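
Taken together, the agent.py hunks expose the new field to end users: endpoint is validated by validate_endpoint at construction time and forwarded to create_llm in post_init_setup. A minimal usage sketch of that behavior (the role/goal/backstory values are illustrative, not from this PR):

from crewai import Agent

# Hypothetical example: point an Ollama-backed agent at a non-default server.
agent = Agent(
    role="Endpoint smoke tester",
    goal="Verify that a custom Ollama endpoint is used",
    backstory="Created to exercise the endpoint field added in this PR.",
    llm="ollama/llama3",  # any model name containing "ollama" takes the new path
    endpoint="https://ollama.example.com:11434",  # must start with http:// or https://
)

# A non-HTTP(S) value fails validation before any LLM is created:
# Agent(..., endpoint="localhost:11434")  -> ValueError from validate_endpoint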

File: src/crewai/utilities/llm_utils.py

@@ -5,18 +5,37 @@ from crewai.cli.constants import DEFAULT_LLM_MODEL, ENV_VARS, LITELLM_PARAMS
 from crewai.llm import LLM
 
+def is_ollama_model(model_name: str) -> bool:
+    """Check if a model name is an Ollama model."""
+    return bool(model_name and "ollama" in str(model_name).lower())
+
+
+def validate_and_set_endpoint(endpoint: str) -> str:
+    """Validate and format an endpoint URL."""
+    if not endpoint:
+        return endpoint
+    if not endpoint.startswith(("http://", "https://")):
+        raise ValueError(f"Invalid endpoint URL: {endpoint}. Must start with http:// or https://")
+    return endpoint.rstrip("/")
+
+
 def create_llm(
     llm_value: Union[str, LLM, Any, None] = None,
+    endpoint: Optional[str] = None,
 ) -> Optional[LLM]:
     """
     Creates or returns an LLM instance based on the given llm_value.
 
     Args:
         llm_value (str | LLM | Any | None):
-            - str: The model name (e.g., "gpt-4").
+            - str: Model name to use for instantiating a new LLM.
             - LLM: Already instantiated LLM, returned as-is.
             - Any: Attempt to extract known attributes like model_name, temperature, etc.
             - None: Use environment-based or fallback default model.
+        endpoint (str | None):
+            - Optional API endpoint URL. When provided for Ollama models,
+              this will override the default API endpoint. Should be a valid
+              HTTP/HTTPS URL (e.g., "http://localhost:11434").
 
     Returns:
         An LLM instance if successful, or None if something fails.
@@ -29,7 +48,11 @@ def create_llm(
     # 2) If llm_value is a string (model name)
     if isinstance(llm_value, str):
         try:
-            created_llm = LLM(model=llm_value)
+            # If endpoint is provided and model is Ollama, use it as api_base
+            if endpoint and is_ollama_model(llm_value):
+                created_llm = LLM(model=llm_value, api_base=validate_and_set_endpoint(endpoint))
+            else:
+                created_llm = LLM(model=llm_value)
             return created_llm
         except Exception as e:
             print(f"Failed to instantiate LLM with model='{llm_value}': {e}")
@@ -37,7 +60,7 @@ def create_llm(
     # 3) If llm_value is None, parse environment variables or use default
     if llm_value is None:
-        return _llm_via_environment_or_fallback()
+        return _llm_via_environment_or_fallback(endpoint=endpoint)
 
     # 4) Otherwise, attempt to extract relevant attributes from an unknown object
     try:
@@ -55,6 +78,10 @@ def create_llm(
         base_url: Optional[str] = getattr(llm_value, "base_url", None)
         api_base: Optional[str] = getattr(llm_value, "api_base", None)
 
+        # If endpoint is provided and model is Ollama, use it as api_base
+        if endpoint and is_ollama_model(model):
+            api_base = validate_and_set_endpoint(endpoint)
+
         created_llm = LLM(
             model=model,
             temperature=temperature,
@@ -71,9 +98,13 @@ def create_llm(
         return None
 
 
-def _llm_via_environment_or_fallback() -> Optional[LLM]:
+def _llm_via_environment_or_fallback(endpoint: Optional[str] = None) -> Optional[LLM]:
     """
     Helper function: if llm_value is None, we load environment variables or fallback default model.
+
+    Args:
+        endpoint (str | None):
+            - Optional endpoint URL for the LLM API.
     """
     model_name = (
         os.environ.get("OPENAI_MODEL_NAME")
@@ -172,6 +203,10 @@ def _llm_via_environment_or_fallback() -> Optional[LLM]:
     # Remove None values
     llm_params = {k: v for k, v in llm_params.items() if v is not None}
 
+    # If endpoint is provided and model is Ollama, use it as api_base
+    if endpoint and is_ollama_model(model_name):
+        llm_params["api_base"] = validate_and_set_endpoint(endpoint)
+
     # Try creating the LLM
     try:
         new_llm = LLM(**llm_params)
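
For reviewers, the new helpers are easy to exercise in isolation. A rough sketch of the intended behavior, assuming the module path crewai.utilities.llm_utils inferred from the imports above (the compare view itself does not display file paths):

from crewai.utilities.llm_utils import create_llm, validate_and_set_endpoint

# Trailing slashes are stripped before the URL reaches LiteLLM.
assert validate_and_set_endpoint("http://localhost:11434/") == "http://localhost:11434"

# Ollama model + endpoint: the LLM is built with api_base set to the endpoint.
llm = create_llm("ollama/llama3", endpoint="http://localhost:11434")

# Non-Ollama model: the endpoint is ignored and the default api_base applies.
llm = create_llm("gpt-4", endpoint="http://localhost:11434")

# Invalid scheme: validate_and_set_endpoint raises ValueError.
try:
    validate_and_set_endpoint("localhost:11434")
except ValueError as exc:
    print(exc)  # Invalid endpoint URL: localhost:11434. Must start with http:// or https://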