diff --git a/lib/crewai/src/crewai/llms/providers/openai/completion.py b/lib/crewai/src/crewai/llms/providers/openai/completion.py
index 811a1a38d..f92440545 100644
--- a/lib/crewai/src/crewai/llms/providers/openai/completion.py
+++ b/lib/crewai/src/crewai/llms/providers/openai/completion.py
@@ -65,6 +65,7 @@ class OpenAICompletion(BaseLLM):
         self.client_params = client_params
         self.timeout = timeout
         self.base_url = base_url
+        self.api_base = kwargs.pop("api_base", None)
 
         super().__init__(
             model=model,
@@ -106,7 +107,10 @@ class OpenAICompletion(BaseLLM):
             "api_key": self.api_key,
             "organization": self.organization,
             "project": self.project,
-            "base_url": self.base_url,
+            "base_url": self.base_url
+            or self.api_base
+            or os.getenv("OPENAI_BASE_URL")
+            or None,
             "timeout": self.timeout,
             "max_retries": self.max_retries,
             "default_headers": self.default_headers,
@@ -239,6 +243,7 @@ class OpenAICompletion(BaseLLM):
             "provider",
             "api_key",
             "base_url",
+            "api_base",
             "timeout",
         }
 
diff --git a/lib/crewai/tests/llms/openai/test_openai.py b/lib/crewai/tests/llms/openai/test_openai.py
index b825f9621..053bfa6f1 100644
--- a/lib/crewai/tests/llms/openai/test_openai.py
+++ b/lib/crewai/tests/llms/openai/test_openai.py
@@ -407,3 +407,77 @@ def test_extra_arguments_are_passed_to_openai_completion():
     assert call_kwargs['max_tokens'] == 1000
     assert call_kwargs['top_p'] == 0.5
     assert call_kwargs['model'] == 'gpt-4o'
+
+
+
+def test_openai_get_client_params_with_api_base():
+    """
+    Test that _get_client_params correctly converts api_base to base_url
+    """
+    llm = OpenAICompletion(
+        model="gpt-4o",
+        api_base="https://custom.openai.com/v1",
+    )
+    client_params = llm._get_client_params()
+    assert client_params["base_url"] == "https://custom.openai.com/v1"
+
+def test_openai_get_client_params_with_base_url_priority():
+    """
+    Test that base_url takes priority over api_base in _get_client_params
+    """
+    llm = OpenAICompletion(
+        model="gpt-4o",
+        base_url="https://priority.openai.com/v1",
+        api_base="https://fallback.openai.com/v1",
+    )
+    client_params = llm._get_client_params()
+    assert client_params["base_url"] == "https://priority.openai.com/v1"
+
+def test_openai_get_client_params_with_env_var():
+    """
+    Test that _get_client_params uses OPENAI_BASE_URL environment variable as fallback
+    """
+    with patch.dict(os.environ, {
+        "OPENAI_BASE_URL": "https://env.openai.com/v1",
+    }):
+        llm = OpenAICompletion(model="gpt-4o")
+        client_params = llm._get_client_params()
+        assert client_params["base_url"] == "https://env.openai.com/v1"
+
+def test_openai_get_client_params_priority_order():
+    """
+    Test the priority order: base_url > api_base > OPENAI_BASE_URL env var
+    """
+    with patch.dict(os.environ, {
+        "OPENAI_BASE_URL": "https://env.openai.com/v1",
+    }):
+        # Test base_url beats api_base and env var
+        llm1 = OpenAICompletion(
+            model="gpt-4o",
+            base_url="https://base-url.openai.com/v1",
+            api_base="https://api-base.openai.com/v1",
+        )
+        params1 = llm1._get_client_params()
+        assert params1["base_url"] == "https://base-url.openai.com/v1"
+
+        # Test api_base beats env var when base_url is None
+        llm2 = OpenAICompletion(
+            model="gpt-4o",
+            api_base="https://api-base.openai.com/v1",
+        )
+        params2 = llm2._get_client_params()
+        assert params2["base_url"] == "https://api-base.openai.com/v1"
+
+        # Test env var is used when both base_url and api_base are None
+        llm3 = OpenAICompletion(model="gpt-4o")
+        params3 = llm3._get_client_params()
+        assert params3["base_url"] == "https://env.openai.com/v1"
+
+def test_openai_get_client_params_no_base_url():
+    """
+    Test that _get_client_params works correctly when no base_url is specified
+    """
+    llm = OpenAICompletion(model="gpt-4o")
+    client_params = llm._get_client_params()
+    # When no base_url is provided, it should not be in the params (filtered out as None)
+    assert "base_url" not in client_params or client_params.get("base_url") is None
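
Usage sketch (illustrative only, not part of the patch): with the constructor change above, api_base is accepted as an alias and folded into the OpenAI client's base_url, with an explicit base_url taking priority, then api_base, then the OPENAI_BASE_URL environment variable. The import path and proxy URLs below are assumptions, and a valid OPENAI_API_KEY is assumed to be set.

    import os
    from crewai.llms.providers.openai.completion import OpenAICompletion

    # Lowest-priority fallback (assumed URL for illustration).
    os.environ["OPENAI_BASE_URL"] = "https://env.example.com/v1"

    # api_base wins over the env var because no explicit base_url is given.
    llm = OpenAICompletion(model="gpt-4o", api_base="https://proxy.example.com/v1")
    print(llm._get_client_params()["base_url"])  # https://proxy.example.com/v1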