Address PR review: Add validation and helper functions for endpoint handling

Co-Authored-By: Joe Moura <joao@crewai.com>

commit 556da08555 (parent 9e5084ba22)
Author: Devin AI
Date: 2025-02-25 10:57:55 +00:00
2 changed files with 36 additions and 10 deletions

src/crewai/agent.py

@@ -3,7 +3,7 @@ import shutil
 import subprocess
 from typing import Any, Dict, List, Literal, Optional, Sequence, Union
-from pydantic import Field, InstanceOf, PrivateAttr, model_validator
+from pydantic import Field, InstanceOf, PrivateAttr, model_validator, field_validator

 from crewai.agents import CacheHandler
 from crewai.agents.agent_builder.base_agent import BaseAgent
@@ -113,8 +113,18 @@ class Agent(BaseAgent):
     )
     endpoint: Optional[str] = Field(
         default=None,
-        description="Custom endpoint URL for the LLM API.",
+        description="Custom endpoint URL for the LLM API. Primarily used for Ollama models to specify alternative API endpoints.",
+        examples=["http://localhost:11434", "https://ollama.example.com:11434"],
     )

+    @field_validator("endpoint")
+    def validate_endpoint(cls, v):
+        if v is not None:
+            if not isinstance(v, str):
+                raise ValueError("Endpoint must be a string")
+            if not v.startswith(("http://", "https://")):
+                raise ValueError("Endpoint must start with http:// or https://")
+        return v
+
     @model_validator(mode="after")
     def post_init_setup(self):
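For context, a minimal sketch of how the new validator behaves at construction time. This is hypothetical usage, not part of the commit; it assumes pydantic v2 and Agent's usual required fields (role, goal, backstory).

from pydantic import ValidationError

# Accepted: the scheme is present, so validate_endpoint returns the value as-is.
agent = Agent(role="r", goal="g", backstory="b", endpoint="http://localhost:11434")

# Rejected: a missing http:// or https:// scheme surfaces as a ValidationError
# wrapping "Endpoint must start with http:// or https://".
try:
    Agent(role="r", goal="g", backstory="b", endpoint="localhost:11434")
except ValidationError as exc:
    print(exc)

Note that because the field is typed Optional[str], pydantic's own type check normally fires before the isinstance branch; the validator's practical job is enforcing the URL scheme.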

src/crewai/utilities/llm_utils.py

@@ -5,6 +5,20 @@ from crewai.cli.constants import DEFAULT_LLM_MODEL, ENV_VARS, LITELLM_PARAMS
 from crewai.llm import LLM


+def is_ollama_model(model_name: str) -> bool:
+    """Check if a model name is an Ollama model."""
+    return bool(model_name and "ollama" in str(model_name).lower())
+
+
+def validate_and_set_endpoint(endpoint: str) -> str:
+    """Validate and format an endpoint URL."""
+    if not endpoint:
+        return endpoint
+    if not endpoint.startswith(("http://", "https://")):
+        raise ValueError(f"Invalid endpoint URL: {endpoint}. Must start with http:// or https://")
+    return endpoint.rstrip("/")
+
+
 def create_llm(
     llm_value: Union[str, LLM, Any, None] = None,
     endpoint: Optional[str] = None,
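A quick sanity sketch of what the two helpers do (hypothetical calls, not part of the diff):

assert is_ollama_model("ollama/llama3") is True   # substring match on "ollama"
assert is_ollama_model("gpt-4") is False

# Trailing slashes are stripped so the resulting api_base stays consistent.
assert validate_and_set_endpoint("http://localhost:11434/") == "http://localhost:11434"

# A missing scheme raises ValueError: "Invalid endpoint URL: localhost:11434. ..."
try:
    validate_and_set_endpoint("localhost:11434")
except ValueError as exc:
    print(exc)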
@@ -14,12 +28,14 @@ def create_llm(
     Args:
         llm_value (str | LLM | Any | None):
-            - str: The model name (e.g., "gpt-4").
+            - str: Model name to use for instantiating a new LLM.
             - LLM: Already instantiated LLM, returned as-is.
             - Any: Attempt to extract known attributes like model_name, temperature, etc.
             - None: Use environment-based or fallback default model.
         endpoint (str | None):
-            - Optional endpoint URL for the LLM API.
+            - Optional API endpoint URL. When provided for Ollama models,
+              this will override the default API endpoint. Should be a valid
+              HTTP/HTTPS URL (e.g., "http://localhost:11434").

     Returns:
         An LLM instance if successful, or None if something fails.
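The documented argument forms, as hypothetical call sites (existing_llm stands in for any already-constructed LLM instance):

create_llm("gpt-4")                      # str: model name
create_llm(existing_llm)                 # LLM: returned as-is
create_llm(None)                         # None: environment/fallback default
create_llm("ollama/llama3", endpoint="http://localhost:11434")  # Ollama + endpoint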
@@ -33,8 +49,8 @@ def create_llm(
     if isinstance(llm_value, str):
         try:
             # If endpoint is provided and model is Ollama, use it as api_base
-            if endpoint and "ollama" in llm_value.lower():
-                created_llm = LLM(model=llm_value, api_base=endpoint)
+            if endpoint and is_ollama_model(llm_value):
+                created_llm = LLM(model=llm_value, api_base=validate_and_set_endpoint(endpoint))
             else:
                 created_llm = LLM(model=llm_value)
             return created_llm
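Put together, the string branch now behaves roughly like this (hypothetical calls, not part of the diff):

# Ollama model + endpoint: the validated, slash-stripped URL becomes api_base.
llm = create_llm("ollama/llama3", endpoint="http://localhost:11434/")
# equivalent to LLM(model="ollama/llama3", api_base="http://localhost:11434")

# Non-Ollama model: the endpoint argument is silently ignored.
llm = create_llm("gpt-4", endpoint="http://localhost:11434")
# equivalent to LLM(model="gpt-4")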
@@ -63,8 +79,8 @@ def create_llm(
         api_base: Optional[str] = getattr(llm_value, "api_base", None)

         # If endpoint is provided and model is Ollama, use it as api_base
-        if endpoint and "ollama" in str(model).lower():
-            api_base = endpoint
+        if endpoint and is_ollama_model(model):
+            api_base = validate_and_set_endpoint(endpoint)

         created_llm = LLM(
             model=model,
@@ -188,8 +204,8 @@ def _llm_via_environment_or_fallback(endpoint: Optional[str] = None) -> Optional[LLM]:
     llm_params = {k: v for k, v in llm_params.items() if v is not None}

     # If endpoint is provided and model is Ollama, use it as api_base
-    if endpoint and "ollama" in model_name.lower():
-        llm_params["api_base"] = endpoint
+    if endpoint and is_ollama_model(model_name):
+        llm_params["api_base"] = validate_and_set_endpoint(endpoint)

     # Try creating the LLM
     try:
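The environment fallback applies the same gate to its parameter dict; schematically (hypothetical values, not part of the diff):

model_name = "ollama/llama3"                    # resolved from the environment
endpoint = "https://ollama.example.com:11434/"
llm_params = {"model": model_name}

if endpoint and is_ollama_model(model_name):
    llm_params["api_base"] = validate_and_set_endpoint(endpoint)

# llm_params == {"model": "ollama/llama3",
#                "api_base": "https://ollama.example.com:11434"}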