mirror of
https://github.com/crewAIInc/crewAI.git
synced 2026-01-07 07:08:31 +00:00
Lorenze/OpenAI base url backwards support (#3723)
* fix: enhance OpenAICompletion class base URL handling - Updated the base URL assignment in the OpenAICompletion class to prioritize the new `api_base` attribute and fallback to the environment variable `OPENAI_BASE_URL` if both are not set. - Added `api_base` to the list of parameters in the OpenAICompletion class to ensure proper configuration and flexibility in API endpoint management. * feat: enhance OpenAICompletion class with api_base support - Added the `api_base` parameter to the OpenAICompletion class to allow for flexible API endpoint configuration. - Updated the `_get_client_params` method to prioritize `base_url` over `api_base`, ensuring correct URL handling. - Introduced comprehensive tests to validate the behavior of `api_base` and `base_url` in various scenarios, including environment variable fallback. - Enhanced test coverage for client parameter retrieval, ensuring robust integration with the OpenAI API. * fix: improve OpenAICompletion class configuration handling - Added a debug print statement to log the client configuration parameters during initialization for better traceability. - Updated the base URL assignment logic to ensure it defaults to None if no valid base URL is provided, enhancing robustness in API endpoint configuration. - Refined the retrieval of the `api_base` environment variable to streamline the configuration process. * drop print
This commit is contained in:
@@ -65,6 +65,7 @@ class OpenAICompletion(BaseLLM):
|
||||
self.client_params = client_params
|
||||
self.timeout = timeout
|
||||
self.base_url = base_url
|
||||
self.api_base = kwargs.pop("api_base", None)
|
||||
|
||||
super().__init__(
|
||||
model=model,
|
||||
@@ -106,7 +107,10 @@ class OpenAICompletion(BaseLLM):
|
||||
"api_key": self.api_key,
|
||||
"organization": self.organization,
|
||||
"project": self.project,
|
||||
"base_url": self.base_url,
|
||||
"base_url": self.base_url
|
||||
or self.api_base
|
||||
or os.getenv("OPENAI_BASE_URL")
|
||||
or None,
|
||||
"timeout": self.timeout,
|
||||
"max_retries": self.max_retries,
|
||||
"default_headers": self.default_headers,
|
||||
@@ -239,6 +243,7 @@ class OpenAICompletion(BaseLLM):
|
||||
"provider",
|
||||
"api_key",
|
||||
"base_url",
|
||||
"api_base",
|
||||
"timeout",
|
||||
}
|
||||
|
||||
|
||||
@@ -407,3 +407,77 @@ def test_extra_arguments_are_passed_to_openai_completion():
|
||||
assert call_kwargs['max_tokens'] == 1000
|
||||
assert call_kwargs['top_p'] == 0.5
|
||||
assert call_kwargs['model'] == 'gpt-4o'
|
||||
|
||||
|
||||
|
||||
def test_openai_get_client_params_with_api_base():
|
||||
"""
|
||||
Test that _get_client_params correctly converts api_base to base_url
|
||||
"""
|
||||
llm = OpenAICompletion(
|
||||
model="gpt-4o",
|
||||
api_base="https://custom.openai.com/v1",
|
||||
)
|
||||
client_params = llm._get_client_params()
|
||||
assert client_params["base_url"] == "https://custom.openai.com/v1"
|
||||
|
||||
def test_openai_get_client_params_with_base_url_priority():
|
||||
"""
|
||||
Test that base_url takes priority over api_base in _get_client_params
|
||||
"""
|
||||
llm = OpenAICompletion(
|
||||
model="gpt-4o",
|
||||
base_url="https://priority.openai.com/v1",
|
||||
api_base="https://fallback.openai.com/v1",
|
||||
)
|
||||
client_params = llm._get_client_params()
|
||||
assert client_params["base_url"] == "https://priority.openai.com/v1"
|
||||
|
||||
def test_openai_get_client_params_with_env_var():
|
||||
"""
|
||||
Test that _get_client_params uses OPENAI_BASE_URL environment variable as fallback
|
||||
"""
|
||||
with patch.dict(os.environ, {
|
||||
"OPENAI_BASE_URL": "https://env.openai.com/v1",
|
||||
}):
|
||||
llm = OpenAICompletion(model="gpt-4o")
|
||||
client_params = llm._get_client_params()
|
||||
assert client_params["base_url"] == "https://env.openai.com/v1"
|
||||
|
||||
def test_openai_get_client_params_priority_order():
|
||||
"""
|
||||
Test the priority order: base_url > api_base > OPENAI_BASE_URL env var
|
||||
"""
|
||||
with patch.dict(os.environ, {
|
||||
"OPENAI_BASE_URL": "https://env.openai.com/v1",
|
||||
}):
|
||||
# Test base_url beats api_base and env var
|
||||
llm1 = OpenAICompletion(
|
||||
model="gpt-4o",
|
||||
base_url="https://base-url.openai.com/v1",
|
||||
api_base="https://api-base.openai.com/v1",
|
||||
)
|
||||
params1 = llm1._get_client_params()
|
||||
assert params1["base_url"] == "https://base-url.openai.com/v1"
|
||||
|
||||
# Test api_base beats env var when base_url is None
|
||||
llm2 = OpenAICompletion(
|
||||
model="gpt-4o",
|
||||
api_base="https://api-base.openai.com/v1",
|
||||
)
|
||||
params2 = llm2._get_client_params()
|
||||
assert params2["base_url"] == "https://api-base.openai.com/v1"
|
||||
|
||||
# Test env var is used when both base_url and api_base are None
|
||||
llm3 = OpenAICompletion(model="gpt-4o")
|
||||
params3 = llm3._get_client_params()
|
||||
assert params3["base_url"] == "https://env.openai.com/v1"
|
||||
|
||||
def test_openai_get_client_params_no_base_url():
|
||||
"""
|
||||
Test that _get_client_params works correctly when no base_url is specified
|
||||
"""
|
||||
llm = OpenAICompletion(model="gpt-4o")
|
||||
client_params = llm._get_client_params()
|
||||
# When no base_url is provided, it should not be in the params (filtered out as None)
|
||||
assert "base_url" not in client_params or client_params.get("base_url") is None
|
||||
|
||||
Reference in New Issue
Block a user