crewAI/lib/crewai/src/crewai/llms/providers/openai_compatible/completion.py
fix: address Copilot review on OpenAI-compatible providers (#5042) (#5089)
- Delegate supports_function_calling() to parent (handles o1 models via OpenRouter)
- Guard empty env vars in base_url resolution
- Fix misleading comment about model validation rules
- Remove unused MagicMock import
- Use 'is not None' for env var restoration in tests

Co-authored-by: Joao Moura <joao@crewai.com>

"""OpenAI-compatible providers implementation.
This module provides a thin subclass of OpenAICompletion that supports
various OpenAI-compatible APIs like OpenRouter, DeepSeek, Ollama, vLLM,
Cerebras, and Dashscope (Alibaba/Qwen).
Usage:
llm = LLM(model="deepseek/deepseek-chat") # Uses DeepSeek API
llm = LLM(model="openrouter/anthropic/claude-3-opus") # Uses OpenRouter
llm = LLM(model="ollama/llama3") # Uses local Ollama
"""
from __future__ import annotations
from dataclasses import dataclass, field
import os
from typing import Any
from crewai.llms.providers.openai.completion import OpenAICompletion


@dataclass(frozen=True)
class ProviderConfig:
    """Configuration for an OpenAI-compatible provider.

    Attributes:
        base_url: Default base URL for the provider's API endpoint.
        api_key_env: Environment variable name for the API key.
        base_url_env: Environment variable name for a custom base URL override.
        default_headers: HTTP headers to include in all requests.
        api_key_required: Whether an API key is required for this provider.
        default_api_key: Default API key to use if none is provided and not required.
    """

    base_url: str
    api_key_env: str
    base_url_env: str | None = None
    default_headers: dict[str, str] = field(default_factory=dict)
    api_key_required: bool = True
    default_api_key: str | None = None


OPENAI_COMPATIBLE_PROVIDERS: dict[str, ProviderConfig] = {
    "openrouter": ProviderConfig(
        base_url="https://openrouter.ai/api/v1",
        api_key_env="OPENROUTER_API_KEY",
        base_url_env="OPENROUTER_BASE_URL",
        default_headers={"HTTP-Referer": "https://crewai.com"},
        api_key_required=True,
    ),
    "deepseek": ProviderConfig(
        base_url="https://api.deepseek.com/v1",
        api_key_env="DEEPSEEK_API_KEY",
        base_url_env="DEEPSEEK_BASE_URL",
        api_key_required=True,
    ),
    "ollama": ProviderConfig(
        base_url="http://localhost:11434/v1",
        api_key_env="OLLAMA_API_KEY",
        base_url_env="OLLAMA_HOST",
        api_key_required=False,
        default_api_key="ollama",
    ),
    "ollama_chat": ProviderConfig(
        base_url="http://localhost:11434/v1",
        api_key_env="OLLAMA_API_KEY",
        base_url_env="OLLAMA_HOST",
        api_key_required=False,
        default_api_key="ollama",
    ),
    "hosted_vllm": ProviderConfig(
        base_url="http://localhost:8000/v1",
        api_key_env="VLLM_API_KEY",
        base_url_env="VLLM_BASE_URL",
        api_key_required=False,
        default_api_key="dummy",
    ),
    "cerebras": ProviderConfig(
        base_url="https://api.cerebras.ai/v1",
        api_key_env="CEREBRAS_API_KEY",
        base_url_env="CEREBRAS_BASE_URL",
        api_key_required=True,
    ),
    "dashscope": ProviderConfig(
        base_url="https://dashscope-intl.aliyuncs.com/compatible-mode/v1",
        api_key_env="DASHSCOPE_API_KEY",
        base_url_env="DASHSCOPE_BASE_URL",
        api_key_required=True,
    ),
}
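
# Registration sketch (illustrative, not part of this module): the registry is
# a plain module-level dict, so an additional OpenAI-compatible endpoint could
# be registered before constructing a client. The provider name, URL, and env
# var below are hypothetical placeholders.
#
#     OPENAI_COMPATIBLE_PROVIDERS["my_gateway"] = ProviderConfig(
#         base_url="https://llm-gateway.example.com/v1",
#         api_key_env="MY_GATEWAY_API_KEY",
#     )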


def _normalize_ollama_base_url(base_url: str) -> str:
    """Normalize Ollama base URL to ensure it ends with /v1.

    Ollama uses OLLAMA_HOST which may not include the /v1 suffix,
    but the OpenAI-compatible endpoint requires it.

    Args:
        base_url: The base URL, potentially without /v1 suffix.

    Returns:
        The base URL with /v1 suffix if needed.
    """
    base_url = base_url.rstrip("/")
    if not base_url.endswith("/v1"):
        return f"{base_url}/v1"
    return base_url
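
# Expected behavior (illustrative):
#     _normalize_ollama_base_url("http://localhost:11434")     -> "http://localhost:11434/v1"
#     _normalize_ollama_base_url("http://localhost:11434/")    -> "http://localhost:11434/v1"
#     _normalize_ollama_base_url("http://localhost:11434/v1/") -> "http://localhost:11434/v1"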


class OpenAICompatibleCompletion(OpenAICompletion):
    """OpenAI-compatible completion implementation.

    This class provides support for various OpenAI-compatible APIs by
    automatically configuring the base URL, API key, and headers based
    on the provider name.

    Supported providers:
        - openrouter: OpenRouter (https://openrouter.ai)
        - deepseek: DeepSeek (https://deepseek.com)
        - ollama: Ollama local server (https://ollama.ai)
        - ollama_chat: Alias for ollama
        - hosted_vllm: vLLM server (https://github.com/vllm-project/vllm)
        - cerebras: Cerebras (https://cerebras.ai)
        - dashscope: Alibaba Dashscope/Qwen (https://dashscope.aliyun.com)

    Example:
        # Using provider prefix
        llm = LLM(model="deepseek/deepseek-chat")

        # Using explicit provider parameter
        llm = LLM(model="llama3", provider="ollama")

        # With custom configuration
        llm = LLM(
            model="deepseek-chat",
            provider="deepseek",
            api_key="my-key",
            temperature=0.7
        )
    """

    def __init__(
        self,
        model: str,
        provider: str,
        api_key: str | None = None,
        base_url: str | None = None,
        default_headers: dict[str, str] | None = None,
        **kwargs: Any,
    ) -> None:
        """Initialize OpenAI-compatible completion client.

        Args:
            model: The model identifier.
            provider: The provider name (must be in OPENAI_COMPATIBLE_PROVIDERS).
            api_key: Optional API key override. If not provided, uses the
                provider's configured environment variable.
            base_url: Optional base URL override. If not provided, uses the
                provider's configured default or environment variable.
            default_headers: Optional headers to merge with provider defaults.
            **kwargs: Additional arguments passed to OpenAICompletion.

        Raises:
            ValueError: If the provider is not supported or a required API key
                is missing.
        """
        config = OPENAI_COMPATIBLE_PROVIDERS.get(provider)
        if config is None:
            supported = ", ".join(sorted(OPENAI_COMPATIBLE_PROVIDERS.keys()))
            raise ValueError(
                f"Unknown OpenAI-compatible provider: {provider}. "
                f"Supported providers: {supported}"
            )

        resolved_api_key = self._resolve_api_key(api_key, config, provider)
        resolved_base_url = self._resolve_base_url(base_url, config, provider)
        resolved_headers = self._resolve_headers(default_headers, config)

        super().__init__(
            model=model,
            provider=provider,
            api_key=resolved_api_key,
            base_url=resolved_base_url,
            default_headers=resolved_headers,
            **kwargs,
        )

    def _resolve_api_key(
        self,
        api_key: str | None,
        config: ProviderConfig,
        provider: str,
    ) -> str | None:
        """Resolve the API key from explicit value, env var, or default.

        Args:
            api_key: Explicitly provided API key.
            config: Provider configuration.
            provider: Provider name for error messages.

        Returns:
            The resolved API key, or None when no key is required and no
            default is configured.

        Raises:
            ValueError: If an API key is required but not found.
        """
        if api_key:
            return api_key

        env_key = os.getenv(config.api_key_env)
        if env_key:
            return env_key

        if config.api_key_required:
            raise ValueError(
                f"API key required for {provider}. "
                f"Set {config.api_key_env} environment variable or pass api_key parameter."
            )

        return config.default_api_key
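
    # Resolution precedence (illustrative), e.g. for provider="ollama"
    # (api_key_required=False, default_api_key="ollama"); key values here
    # are placeholders:
    #     explicit api_key="sk-x"          -> "sk-x"
    #     else OLLAMA_API_KEY="env-key"    -> "env-key"
    #     else                             -> "ollama"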

    def _resolve_base_url(
        self,
        base_url: str | None,
        config: ProviderConfig,
        provider: str,
    ) -> str:
        """Resolve the base URL from explicit value, env var, or default.

        Args:
            base_url: Explicitly provided base URL.
            config: Provider configuration.
            provider: Provider name (used for special handling like Ollama).

        Returns:
            The resolved base URL.
        """
        if base_url:
            resolved = base_url
        elif config.base_url_env:
            # Treat an empty env var the same as an unset one and fall back
            # to the provider's default base URL.
            env_value = os.getenv(config.base_url_env)
            resolved = env_value if env_value else config.base_url
        else:
            resolved = config.base_url

        # OLLAMA_HOST is often set without the /v1 suffix that the
        # OpenAI-compatible endpoint requires, so normalize it here.
        if provider in ("ollama", "ollama_chat"):
            resolved = _normalize_ollama_base_url(resolved)

        return resolved

    def _resolve_headers(
        self,
        headers: dict[str, str] | None,
        config: ProviderConfig,
    ) -> dict[str, str] | None:
        """Merge user headers with provider default headers.

        Args:
            headers: User-provided headers.
            config: Provider configuration.

        Returns:
            Merged headers dict, or None if empty.
        """
        if not config.default_headers and not headers:
            return None

        merged = dict(config.default_headers)
        if headers:
            merged.update(headers)
        return merged if merged else None
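
    # Merge example (illustrative): for provider="openrouter", whose defaults
    # include an HTTP-Referer header, a hypothetical user header is layered on
    # top, and user values win on key collisions:
    #     _resolve_headers({"X-Title": "My App"}, openrouter_config)
    #     -> {"HTTP-Referer": "https://crewai.com", "X-Title": "My App"}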

    def supports_function_calling(self) -> bool:
        """Check if the provider supports function calling.

        Delegates to the parent OpenAI implementation which handles
        edge cases like o1 models (which may be routed through
        OpenRouter or other compatible providers).

        Returns:
            Whether the model supports function calling.
        """
        return super().supports_function_calling()
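

# Direct-construction sketch (illustrative; most callers go through the LLM
# entry point shown in the module docstring). The api_key value is a
# placeholder and would normally be read from DEEPSEEK_API_KEY.
#
#     client = OpenAICompatibleCompletion(
#         model="deepseek-chat",
#         provider="deepseek",
#         api_key="sk-example",
#         temperature=0.7,
#     )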