Address PR review: Add validation and helper functions for endpoint handling

Co-Authored-By: Joe Moura <joao@crewai.com>
This commit is contained in:
Devin AI
2025-02-25 10:57:55 +00:00
parent 9e5084ba22
commit 556da08555
2 changed files with 36 additions and 10 deletions

View File

@@ -3,7 +3,7 @@ import shutil
import subprocess
from typing import Any, Dict, List, Literal, Optional, Sequence, Union
from pydantic import Field, InstanceOf, PrivateAttr, model_validator
from pydantic import Field, InstanceOf, PrivateAttr, model_validator, field_validator
from crewai.agents import CacheHandler
from crewai.agents.agent_builder.base_agent import BaseAgent
@@ -113,8 +113,18 @@ class Agent(BaseAgent):
)
endpoint: Optional[str] = Field(
default=None,
description="Custom endpoint URL for the LLM API.",
description="Custom endpoint URL for the LLM API. Primarily used for Ollama models to specify alternative API endpoints.",
examples=["http://localhost:11434", "https://ollama.example.com:11434"],
)
@field_validator("endpoint")
def validate_endpoint(cls, v):
if v is not None:
if not isinstance(v, str):
raise ValueError("Endpoint must be a string")
if not v.startswith(("http://", "https://")):
raise ValueError("Endpoint must start with http:// or https://")
return v
@model_validator(mode="after")
def post_init_setup(self):

View File

@@ -5,6 +5,20 @@ from crewai.cli.constants import DEFAULT_LLM_MODEL, ENV_VARS, LITELLM_PARAMS
from crewai.llm import LLM
def is_ollama_model(model_name: str) -> bool:
"""Check if a model name is an Ollama model."""
return bool(model_name and "ollama" in str(model_name).lower())
def validate_and_set_endpoint(endpoint: str) -> str:
"""Validate and format an endpoint URL."""
if not endpoint:
return endpoint
if not endpoint.startswith(("http://", "https://")):
raise ValueError(f"Invalid endpoint URL: {endpoint}. Must start with http:// or https://")
return endpoint.rstrip("/")
def create_llm(
llm_value: Union[str, LLM, Any, None] = None,
endpoint: Optional[str] = None,
@@ -14,12 +28,14 @@ def create_llm(
Args:
llm_value (str | LLM | Any | None):
- str: The model name (e.g., "gpt-4").
- str: Model name to use for instantiating a new LLM.
- LLM: Already instantiated LLM, returned as-is.
- Any: Attempt to extract known attributes like model_name, temperature, etc.
- None: Use environment-based or fallback default model.
endpoint (str | None):
- Optional endpoint URL for the LLM API.
- Optional API endpoint URL. When provided for Ollama models,
this will override the default API endpoint. Should be a valid
HTTP/HTTPS URL (e.g., "http://localhost:11434").
Returns:
An LLM instance if successful, or None if something fails.
@@ -33,8 +49,8 @@ def create_llm(
if isinstance(llm_value, str):
try:
# If endpoint is provided and model is Ollama, use it as api_base
if endpoint and "ollama" in llm_value.lower():
created_llm = LLM(model=llm_value, api_base=endpoint)
if endpoint and is_ollama_model(llm_value):
created_llm = LLM(model=llm_value, api_base=validate_and_set_endpoint(endpoint))
else:
created_llm = LLM(model=llm_value)
return created_llm
@@ -63,8 +79,8 @@ def create_llm(
api_base: Optional[str] = getattr(llm_value, "api_base", None)
# If endpoint is provided and model is Ollama, use it as api_base
if endpoint and "ollama" in str(model).lower():
api_base = endpoint
if endpoint and is_ollama_model(model):
api_base = validate_and_set_endpoint(endpoint)
created_llm = LLM(
model=model,
@@ -188,8 +204,8 @@ def _llm_via_environment_or_fallback(endpoint: Optional[str] = None) -> Optional
llm_params = {k: v for k, v in llm_params.items() if v is not None}
# If endpoint is provided and model is Ollama, use it as api_base
if endpoint and "ollama" in model_name.lower():
llm_params["api_base"] = endpoint
if endpoint and is_ollama_model(model_name):
llm_params["api_base"] = validate_and_set_endpoint(endpoint)
# Try creating the LLM
try: