Compare commits

...

4 Commits

Author SHA1 Message Date
Brandon Hancock (bhancock_ai)
a8124cda06 Merge branch 'main' into bugfix/llm-env-fix-for-daniel 2025-01-29 19:31:30 -05:00
Lorenze Jay
c471f67074 Merge branch 'main' into bugfix/llm-env-fix-for-daniel 2025-01-29 14:42:55 -08:00
Brandon Hancock
973a860415 add in api_base 2025-01-29 16:15:01 -05:00
Brandon Hancock
9f6b4cefd3 iwp 2025-01-28 15:46:13 -05:00
2 changed files with 19 additions and 3 deletions

View File

@@ -133,6 +133,7 @@ class LLM:
logprobs: Optional[int] = None,
top_logprobs: Optional[int] = None,
base_url: Optional[str] = None,
api_base: Optional[str] = None,
api_version: Optional[str] = None,
api_key: Optional[str] = None,
callbacks: List[Any] = [],
@@ -152,6 +153,7 @@ class LLM:
self.logprobs = logprobs
self.top_logprobs = top_logprobs
self.base_url = base_url
self.api_base = api_base
self.api_version = api_version
self.api_key = api_key
self.callbacks = callbacks
@@ -232,7 +234,8 @@ class LLM:
"seed": self.seed,
"logprobs": self.logprobs,
"top_logprobs": self.top_logprobs,
"api_base": self.base_url,
"api_base": self.api_base,
"base_url": self.base_url,
"api_version": self.api_version,
"api_key": self.api_key,
"stream": False,

View File

@@ -53,6 +53,7 @@ def create_llm(
timeout: Optional[float] = getattr(llm_value, "timeout", None)
api_key: Optional[str] = getattr(llm_value, "api_key", None)
base_url: Optional[str] = getattr(llm_value, "base_url", None)
api_base: Optional[str] = getattr(llm_value, "api_base", None)
created_llm = LLM(
model=model,
@@ -62,6 +63,7 @@ def create_llm(
timeout=timeout,
api_key=api_key,
base_url=base_url,
api_base=api_base,
)
return created_llm
except Exception as e:
@@ -101,8 +103,18 @@ def _llm_via_environment_or_fallback() -> Optional[LLM]:
callbacks: List[Any] = []
# Optional base URL from env
api_base = os.environ.get("OPENAI_API_BASE") or os.environ.get("OPENAI_BASE_URL")
if api_base:
base_url = (
os.environ.get("BASE_URL")
or os.environ.get("OPENAI_API_BASE")
or os.environ.get("OPENAI_BASE_URL")
)
api_base = os.environ.get("API_BASE") or os.environ.get("AZURE_API_BASE")
# Synchronize base_url and api_base if one is populated and the other is not
if base_url and not api_base:
api_base = base_url
elif api_base and not base_url:
base_url = api_base
# Initialize llm_params dictionary
@@ -115,6 +127,7 @@ def _llm_via_environment_or_fallback() -> Optional[LLM]:
"timeout": timeout,
"api_key": api_key,
"base_url": base_url,
"api_base": api_base,
"api_version": api_version,
"presence_penalty": presence_penalty,
"frequency_penalty": frequency_penalty,