mirror of
https://github.com/crewAIInc/crewAI.git
synced 2026-01-09 08:08:32 +00:00
Lorenze/OpenAI base url backwards support (#3723)
* fix: enhance OpenAICompletion class base URL handling - Updated the base URL assignment in the OpenAICompletion class to prioritize the new `api_base` attribute and fallback to the environment variable `OPENAI_BASE_URL` if both are not set. - Added `api_base` to the list of parameters in the OpenAICompletion class to ensure proper configuration and flexibility in API endpoint management. * feat: enhance OpenAICompletion class with api_base support - Added the `api_base` parameter to the OpenAICompletion class to allow for flexible API endpoint configuration. - Updated the `_get_client_params` method to prioritize `base_url` over `api_base`, ensuring correct URL handling. - Introduced comprehensive tests to validate the behavior of `api_base` and `base_url` in various scenarios, including environment variable fallback. - Enhanced test coverage for client parameter retrieval, ensuring robust integration with the OpenAI API. * fix: improve OpenAICompletion class configuration handling - Added a debug print statement to log the client configuration parameters during initialization for better traceability. - Updated the base URL assignment logic to ensure it defaults to None if no valid base URL is provided, enhancing robustness in API endpoint configuration. - Refined the retrieval of the `api_base` environment variable to streamline the configuration process. * drop print
This commit is contained in:
@@ -65,6 +65,7 @@ class OpenAICompletion(BaseLLM):
|
|||||||
self.client_params = client_params
|
self.client_params = client_params
|
||||||
self.timeout = timeout
|
self.timeout = timeout
|
||||||
self.base_url = base_url
|
self.base_url = base_url
|
||||||
|
self.api_base = kwargs.pop("api_base", None)
|
||||||
|
|
||||||
super().__init__(
|
super().__init__(
|
||||||
model=model,
|
model=model,
|
||||||
@@ -106,7 +107,10 @@ class OpenAICompletion(BaseLLM):
|
|||||||
"api_key": self.api_key,
|
"api_key": self.api_key,
|
||||||
"organization": self.organization,
|
"organization": self.organization,
|
||||||
"project": self.project,
|
"project": self.project,
|
||||||
"base_url": self.base_url,
|
"base_url": self.base_url
|
||||||
|
or self.api_base
|
||||||
|
or os.getenv("OPENAI_BASE_URL")
|
||||||
|
or None,
|
||||||
"timeout": self.timeout,
|
"timeout": self.timeout,
|
||||||
"max_retries": self.max_retries,
|
"max_retries": self.max_retries,
|
||||||
"default_headers": self.default_headers,
|
"default_headers": self.default_headers,
|
||||||
@@ -239,6 +243,7 @@ class OpenAICompletion(BaseLLM):
|
|||||||
"provider",
|
"provider",
|
||||||
"api_key",
|
"api_key",
|
||||||
"base_url",
|
"base_url",
|
||||||
|
"api_base",
|
||||||
"timeout",
|
"timeout",
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -407,3 +407,77 @@ def test_extra_arguments_are_passed_to_openai_completion():
|
|||||||
assert call_kwargs['max_tokens'] == 1000
|
assert call_kwargs['max_tokens'] == 1000
|
||||||
assert call_kwargs['top_p'] == 0.5
|
assert call_kwargs['top_p'] == 0.5
|
||||||
assert call_kwargs['model'] == 'gpt-4o'
|
assert call_kwargs['model'] == 'gpt-4o'
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
def test_openai_get_client_params_with_api_base():
|
||||||
|
"""
|
||||||
|
Test that _get_client_params correctly converts api_base to base_url
|
||||||
|
"""
|
||||||
|
llm = OpenAICompletion(
|
||||||
|
model="gpt-4o",
|
||||||
|
api_base="https://custom.openai.com/v1",
|
||||||
|
)
|
||||||
|
client_params = llm._get_client_params()
|
||||||
|
assert client_params["base_url"] == "https://custom.openai.com/v1"
|
||||||
|
|
||||||
|
def test_openai_get_client_params_with_base_url_priority():
|
||||||
|
"""
|
||||||
|
Test that base_url takes priority over api_base in _get_client_params
|
||||||
|
"""
|
||||||
|
llm = OpenAICompletion(
|
||||||
|
model="gpt-4o",
|
||||||
|
base_url="https://priority.openai.com/v1",
|
||||||
|
api_base="https://fallback.openai.com/v1",
|
||||||
|
)
|
||||||
|
client_params = llm._get_client_params()
|
||||||
|
assert client_params["base_url"] == "https://priority.openai.com/v1"
|
||||||
|
|
||||||
|
def test_openai_get_client_params_with_env_var():
|
||||||
|
"""
|
||||||
|
Test that _get_client_params uses OPENAI_BASE_URL environment variable as fallback
|
||||||
|
"""
|
||||||
|
with patch.dict(os.environ, {
|
||||||
|
"OPENAI_BASE_URL": "https://env.openai.com/v1",
|
||||||
|
}):
|
||||||
|
llm = OpenAICompletion(model="gpt-4o")
|
||||||
|
client_params = llm._get_client_params()
|
||||||
|
assert client_params["base_url"] == "https://env.openai.com/v1"
|
||||||
|
|
||||||
|
def test_openai_get_client_params_priority_order():
|
||||||
|
"""
|
||||||
|
Test the priority order: base_url > api_base > OPENAI_BASE_URL env var
|
||||||
|
"""
|
||||||
|
with patch.dict(os.environ, {
|
||||||
|
"OPENAI_BASE_URL": "https://env.openai.com/v1",
|
||||||
|
}):
|
||||||
|
# Test base_url beats api_base and env var
|
||||||
|
llm1 = OpenAICompletion(
|
||||||
|
model="gpt-4o",
|
||||||
|
base_url="https://base-url.openai.com/v1",
|
||||||
|
api_base="https://api-base.openai.com/v1",
|
||||||
|
)
|
||||||
|
params1 = llm1._get_client_params()
|
||||||
|
assert params1["base_url"] == "https://base-url.openai.com/v1"
|
||||||
|
|
||||||
|
# Test api_base beats env var when base_url is None
|
||||||
|
llm2 = OpenAICompletion(
|
||||||
|
model="gpt-4o",
|
||||||
|
api_base="https://api-base.openai.com/v1",
|
||||||
|
)
|
||||||
|
params2 = llm2._get_client_params()
|
||||||
|
assert params2["base_url"] == "https://api-base.openai.com/v1"
|
||||||
|
|
||||||
|
# Test env var is used when both base_url and api_base are None
|
||||||
|
llm3 = OpenAICompletion(model="gpt-4o")
|
||||||
|
params3 = llm3._get_client_params()
|
||||||
|
assert params3["base_url"] == "https://env.openai.com/v1"
|
||||||
|
|
||||||
|
def test_openai_get_client_params_no_base_url():
|
||||||
|
"""
|
||||||
|
Test that _get_client_params works correctly when no base_url is specified
|
||||||
|
"""
|
||||||
|
llm = OpenAICompletion(model="gpt-4o")
|
||||||
|
client_params = llm._get_client_params()
|
||||||
|
# When no base_url is provided, it should not be in the params (filtered out as None)
|
||||||
|
assert "base_url" not in client_params or client_params.get("base_url") is None
|
||||||
|
|||||||
Reference in New Issue
Block a user