From 9e5084ba22b9124fd57d3d138df2107e369b29d4 Mon Sep 17 00:00:00 2001 From: Devin AI <158243242+devin-ai-integration[bot]@users.noreply.github.com> Date: Tue, 25 Feb 2025 10:47:09 +0000 Subject: [PATCH] Fix issue 2216: Pass custom endpoint to LiteLLM for Ollama provider Co-Authored-By: Joe Moura --- src/crewai/agent.py | 9 +++++++-- src/crewai/utilities/llm_utils.py | 25 ++++++++++++++++++++++--- 2 files changed, 29 insertions(+), 5 deletions(-) diff --git a/src/crewai/agent.py b/src/crewai/agent.py index f07408133..558b8c78a 100644 --- a/src/crewai/agent.py +++ b/src/crewai/agent.py @@ -111,15 +111,20 @@ class Agent(BaseAgent): default=None, description="Embedder configuration for the agent.", ) + endpoint: Optional[str] = Field( + default=None, + description="Custom endpoint URL for the LLM API.", + ) @model_validator(mode="after") def post_init_setup(self): self._set_knowledge() self.agent_ops_agent_name = self.role - self.llm = create_llm(self.llm) + # Pass endpoint to create_llm if it exists + self.llm = create_llm(self.llm, endpoint=self.endpoint) if self.function_calling_llm and not isinstance(self.function_calling_llm, LLM): - self.function_calling_llm = create_llm(self.function_calling_llm) + self.function_calling_llm = create_llm(self.function_calling_llm, endpoint=self.endpoint) if not self.agent_executor: self._setup_agent_executor() diff --git a/src/crewai/utilities/llm_utils.py b/src/crewai/utilities/llm_utils.py index c774a71fb..7a10a0b30 100644 --- a/src/crewai/utilities/llm_utils.py +++ b/src/crewai/utilities/llm_utils.py @@ -7,6 +7,7 @@ from crewai.llm import LLM def create_llm( llm_value: Union[str, LLM, Any, None] = None, + endpoint: Optional[str] = None, ) -> Optional[LLM]: """ Creates or returns an LLM instance based on the given llm_value. @@ -17,6 +18,8 @@ def create_llm( - LLM: Already instantiated LLM, returned as-is. - Any: Attempt to extract known attributes like model_name, temperature, etc. - None: Use environment-based or fallback default model. + endpoint (str | None): + - Optional endpoint URL for the LLM API. Returns: An LLM instance if successful, or None if something fails. @@ -29,7 +32,11 @@ def create_llm( # 2) If llm_value is a string (model name) if isinstance(llm_value, str): try: - created_llm = LLM(model=llm_value) + # If endpoint is provided and model is Ollama, use it as api_base + if endpoint and "ollama" in llm_value.lower(): + created_llm = LLM(model=llm_value, api_base=endpoint) + else: + created_llm = LLM(model=llm_value) return created_llm except Exception as e: print(f"Failed to instantiate LLM with model='{llm_value}': {e}") @@ -37,7 +44,7 @@ def create_llm( # 3) If llm_value is None, parse environment variables or use default if llm_value is None: - return _llm_via_environment_or_fallback() + return _llm_via_environment_or_fallback(endpoint=endpoint) # 4) Otherwise, attempt to extract relevant attributes from an unknown object try: @@ -55,6 +62,10 @@ def create_llm( base_url: Optional[str] = getattr(llm_value, "base_url", None) api_base: Optional[str] = getattr(llm_value, "api_base", None) + # If endpoint is provided and model is Ollama, use it as api_base + if endpoint and "ollama" in str(model).lower(): + api_base = endpoint + created_llm = LLM( model=model, temperature=temperature, @@ -71,9 +82,13 @@ def create_llm( return None -def _llm_via_environment_or_fallback() -> Optional[LLM]: +def _llm_via_environment_or_fallback(endpoint: Optional[str] = None) -> Optional[LLM]: """ Helper function: if llm_value is None, we load environment variables or fallback default model. + + Args: + endpoint (str | None): + - Optional endpoint URL for the LLM API. """ model_name = ( os.environ.get("OPENAI_MODEL_NAME") @@ -172,6 +187,10 @@ def _llm_via_environment_or_fallback() -> Optional[LLM]: # Remove None values llm_params = {k: v for k, v in llm_params.items() if v is not None} + # If endpoint is provided and model is Ollama, use it as api_base + if endpoint and "ollama" in model_name.lower(): + llm_params["api_base"] = endpoint + # Try creating the LLM try: new_llm = LLM(**llm_params)