Mirror of https://github.com/crewAIInc/crewAI.git, synced 2026-01-10 16:48:30 +00:00
Address PR review: Add validation and helper functions for endpoint handling
Co-Authored-By: Joe Moura <joao@crewai.com>
@@ -3,7 +3,7 @@ import shutil
 import subprocess
 from typing import Any, Dict, List, Literal, Optional, Sequence, Union
 
-from pydantic import Field, InstanceOf, PrivateAttr, model_validator
+from pydantic import Field, InstanceOf, PrivateAttr, model_validator, field_validator
 
 from crewai.agents import CacheHandler
 from crewai.agents.agent_builder.base_agent import BaseAgent
@@ -113,8 +113,18 @@ class Agent(BaseAgent):
     )
     endpoint: Optional[str] = Field(
         default=None,
-        description="Custom endpoint URL for the LLM API.",
+        description="Custom endpoint URL for the LLM API. Primarily used for Ollama models to specify alternative API endpoints.",
+        examples=["http://localhost:11434", "https://ollama.example.com:11434"],
     )
 
+    @field_validator("endpoint")
+    def validate_endpoint(cls, v):
+        if v is not None:
+            if not isinstance(v, str):
+                raise ValueError("Endpoint must be a string")
+            if not v.startswith(("http://", "https://")):
+                raise ValueError("Endpoint must start with http:// or https://")
+        return v
+
     @model_validator(mode="after")
     def post_init_setup(self):
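
For reference, a minimal sketch of how the new validator behaves; it assumes the usual required Agent fields (role, goal, backstory) and is illustrative only:

    from crewai import Agent
    from pydantic import ValidationError

    # Accepted: the endpoint carries an explicit http:// scheme.
    agent = Agent(
        role="Researcher",
        goal="Summarize papers",
        backstory="An analyst.",
        endpoint="http://localhost:11434",
    )

    # Rejected: no http:// or https:// prefix, so the field_validator raises.
    try:
        Agent(role="R", goal="G", backstory="B", endpoint="localhost:11434")
    except ValidationError as e:
        print(e)  # includes "Endpoint must start with http:// or https://"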
@@ -5,6 +5,20 @@ from crewai.cli.constants import DEFAULT_LLM_MODEL, ENV_VARS, LITELLM_PARAMS
 from crewai.llm import LLM
 
 
+def is_ollama_model(model_name: str) -> bool:
+    """Check if a model name is an Ollama model."""
+    return bool(model_name and "ollama" in str(model_name).lower())
+
+
+def validate_and_set_endpoint(endpoint: str) -> str:
+    """Validate and format an endpoint URL."""
+    if not endpoint:
+        return endpoint
+    if not endpoint.startswith(("http://", "https://")):
+        raise ValueError(f"Invalid endpoint URL: {endpoint}. Must start with http:// or https://")
+    return endpoint.rstrip("/")
+
+
 def create_llm(
     llm_value: Union[str, LLM, Any, None] = None,
     endpoint: Optional[str] = None,
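
The two helpers are intentionally small; a quick sketch of their edge behavior (the import path is assumed from where create_llm is defined and may differ):

    from crewai.utilities.llm_utils import is_ollama_model, validate_and_set_endpoint

    is_ollama_model("ollama/llama3")   # True: case-insensitive substring match
    is_ollama_model("gpt-4")           # False
    is_ollama_model("")                # False: empty names short-circuit to False

    validate_and_set_endpoint("http://localhost:11434/")  # "http://localhost:11434" (trailing slash stripped)
    validate_and_set_endpoint("")                         # "" (returned unchanged)
    validate_and_set_endpoint("localhost:11434")          # raises ValueError: missing scheme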
@@ -14,12 +28,14 @@ def create_llm(
 
     Args:
         llm_value (str | LLM | Any | None):
-            - str: The model name (e.g., "gpt-4").
+            - str: Model name to use for instantiating a new LLM.
             - LLM: Already instantiated LLM, returned as-is.
             - Any: Attempt to extract known attributes like model_name, temperature, etc.
             - None: Use environment-based or fallback default model.
         endpoint (str | None):
-            Optional endpoint URL for the LLM API.
+            Optional API endpoint URL. When provided for Ollama models,
+            this will override the default API endpoint. Should be a valid
+            HTTP/HTTPS URL (e.g., "http://localhost:11434").
 
     Returns:
         An LLM instance if successful, or None if something fails.
@@ -33,8 +49,8 @@ def create_llm(
     if isinstance(llm_value, str):
         try:
             # If endpoint is provided and model is Ollama, use it as api_base
-            if endpoint and "ollama" in llm_value.lower():
-                created_llm = LLM(model=llm_value, api_base=endpoint)
+            if endpoint and is_ollama_model(llm_value):
+                created_llm = LLM(model=llm_value, api_base=validate_and_set_endpoint(endpoint))
             else:
                 created_llm = LLM(model=llm_value)
             return created_llm
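
A hedged usage sketch of the string branch above; the model names and port are illustrative:

    from crewai.utilities.llm_utils import create_llm  # assumed import path

    # Ollama model plus endpoint: api_base is set to the normalized endpoint.
    llm = create_llm("ollama/llama3", endpoint="http://localhost:11434/")

    # Non-Ollama model: the endpoint is ignored and a plain LLM is created.
    llm = create_llm("gpt-4", endpoint="http://localhost:11434")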
@@ -63,8 +79,8 @@ def create_llm(
         api_base: Optional[str] = getattr(llm_value, "api_base", None)
 
         # If endpoint is provided and model is Ollama, use it as api_base
-        if endpoint and "ollama" in str(model).lower():
-            api_base = endpoint
+        if endpoint and is_ollama_model(model):
+            api_base = validate_and_set_endpoint(endpoint)
 
         created_llm = LLM(
             model=model,
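
For the attribute-extraction branch, a sketch of the kind of object it accepts; SimpleNamespace is a stand-in for any object with model-like attributes, not a crewAI type:

    from types import SimpleNamespace
    from crewai.utilities.llm_utils import create_llm  # assumed import path

    # The branch reads attributes such as the model name, temperature, and
    # api_base off the object; the endpoint overrides api_base here because
    # "ollama/mistral" matches is_ollama_model.
    cfg = SimpleNamespace(model="ollama/mistral", temperature=0.2, api_base=None)
    llm = create_llm(cfg, endpoint="https://ollama.example.com:11434/")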
@@ -188,8 +204,8 @@ def _llm_via_environment_or_fallback(endpoint: Optional[str] = None) -> Optional
     llm_params = {k: v for k, v in llm_params.items() if v is not None}
 
     # If endpoint is provided and model is Ollama, use it as api_base
-    if endpoint and "ollama" in model_name.lower():
-        llm_params["api_base"] = endpoint
+    if endpoint and is_ollama_model(model_name):
+        llm_params["api_base"] = validate_and_set_endpoint(endpoint)
 
     # Try creating the LLM
     try:
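
Finally, a sketch of the environment fallback path; MODEL is an assumed variable name for how _llm_via_environment_or_fallback derives model_name, and may differ in practice:

    import os
    from crewai.utilities.llm_utils import create_llm  # assumed import path

    # With llm_value=None, create_llm defers to _llm_via_environment_or_fallback;
    # the endpoint still flows through is_ollama_model/validate_and_set_endpoint.
    os.environ["MODEL"] = "ollama/llama3"  # assumed env var name
    llm = create_llm(None, endpoint="http://localhost:11434")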