From 556da085551ad6fcbe061dab34103551cab3ee3a Mon Sep 17 00:00:00 2001 From: Devin AI <158243242+devin-ai-integration[bot]@users.noreply.github.com> Date: Tue, 25 Feb 2025 10:57:55 +0000 Subject: [PATCH] Address PR review: Add validation and helper functions for endpoint handling Co-Authored-By: Joe Moura --- src/crewai/agent.py | 14 ++++++++++++-- src/crewai/utilities/llm_utils.py | 32 +++++++++++++++++++++++-------- 2 files changed, 36 insertions(+), 10 deletions(-) diff --git a/src/crewai/agent.py b/src/crewai/agent.py index 558b8c78a..ca014f1d6 100644 --- a/src/crewai/agent.py +++ b/src/crewai/agent.py @@ -3,7 +3,7 @@ import shutil import subprocess from typing import Any, Dict, List, Literal, Optional, Sequence, Union -from pydantic import Field, InstanceOf, PrivateAttr, model_validator +from pydantic import Field, InstanceOf, PrivateAttr, model_validator, field_validator from crewai.agents import CacheHandler from crewai.agents.agent_builder.base_agent import BaseAgent @@ -113,8 +113,18 @@ class Agent(BaseAgent): ) endpoint: Optional[str] = Field( default=None, - description="Custom endpoint URL for the LLM API.", + description="Custom endpoint URL for the LLM API. Primarily used for Ollama models to specify alternative API endpoints.", + examples=["http://localhost:11434", "https://ollama.example.com:11434"], ) + + @field_validator("endpoint") + def validate_endpoint(cls, v): + if v is not None: + if not isinstance(v, str): + raise ValueError("Endpoint must be a string") + if not v.startswith(("http://", "https://")): + raise ValueError("Endpoint must start with http:// or https://") + return v @model_validator(mode="after") def post_init_setup(self): diff --git a/src/crewai/utilities/llm_utils.py b/src/crewai/utilities/llm_utils.py index 7a10a0b30..6720e0865 100644 --- a/src/crewai/utilities/llm_utils.py +++ b/src/crewai/utilities/llm_utils.py @@ -5,6 +5,20 @@ from crewai.cli.constants import DEFAULT_LLM_MODEL, ENV_VARS, LITELLM_PARAMS from crewai.llm import LLM +def is_ollama_model(model_name: str) -> bool: + """Check if a model name is an Ollama model.""" + return bool(model_name and "ollama" in str(model_name).lower()) + + +def validate_and_set_endpoint(endpoint: str) -> str: + """Validate and format an endpoint URL.""" + if not endpoint: + return endpoint + if not endpoint.startswith(("http://", "https://")): + raise ValueError(f"Invalid endpoint URL: {endpoint}. Must start with http:// or https://") + return endpoint.rstrip("/") + + def create_llm( llm_value: Union[str, LLM, Any, None] = None, endpoint: Optional[str] = None, @@ -14,12 +28,14 @@ def create_llm( Args: llm_value (str | LLM | Any | None): - - str: The model name (e.g., "gpt-4"). + - str: Model name to use for instantiating a new LLM. - LLM: Already instantiated LLM, returned as-is. - Any: Attempt to extract known attributes like model_name, temperature, etc. - None: Use environment-based or fallback default model. endpoint (str | None): - - Optional endpoint URL for the LLM API. + - Optional API endpoint URL. When provided for Ollama models, + this will override the default API endpoint. Should be a valid + HTTP/HTTPS URL (e.g., "http://localhost:11434"). Returns: An LLM instance if successful, or None if something fails. @@ -33,8 +49,8 @@ def create_llm( if isinstance(llm_value, str): try: # If endpoint is provided and model is Ollama, use it as api_base - if endpoint and "ollama" in llm_value.lower(): - created_llm = LLM(model=llm_value, api_base=endpoint) + if endpoint and is_ollama_model(llm_value): + created_llm = LLM(model=llm_value, api_base=validate_and_set_endpoint(endpoint)) else: created_llm = LLM(model=llm_value) return created_llm @@ -63,8 +79,8 @@ def create_llm( api_base: Optional[str] = getattr(llm_value, "api_base", None) # If endpoint is provided and model is Ollama, use it as api_base - if endpoint and "ollama" in str(model).lower(): - api_base = endpoint + if endpoint and is_ollama_model(model): + api_base = validate_and_set_endpoint(endpoint) created_llm = LLM( model=model, @@ -188,8 +204,8 @@ def _llm_via_environment_or_fallback(endpoint: Optional[str] = None) -> Optional llm_params = {k: v for k, v in llm_params.items() if v is not None} # If endpoint is provided and model is Ollama, use it as api_base - if endpoint and "ollama" in model_name.lower(): - llm_params["api_base"] = endpoint + if endpoint and is_ollama_model(model_name): + llm_params["api_base"] = validate_and_set_endpoint(endpoint) # Try creating the LLM try: