mirror of
https://github.com/crewAIInc/crewAI.git
synced 2026-04-12 05:52:39 +00:00
feat: native OpenAI-compatible providers (OpenRouter, DeepSeek, Ollama, vLLM, Cerebras, Dashscope) (#5042)
* feat: add native OpenAI-compatible providers (OpenRouter, DeepSeek, Ollama, vLLM, Cerebras, Dashscope) Add a data-driven OpenAI-compatible provider system that enables native support for multiple third-party APIs that implement the OpenAI API specification. New providers: - OpenRouter: 500+ models via openrouter.ai - DeepSeek: deepseek-chat, deepseek-coder, deepseek-reasoner - Ollama: local models (llama3, mistral, codellama, etc.) - hosted_vllm: self-hosted vLLM servers - Cerebras: ultra-fast inference - Dashscope: Alibaba Qwen models (qwen-turbo, qwen-max, etc.) Architecture: - Single OpenAICompatibleCompletion class extends OpenAICompletion - ProviderConfig dataclass stores per-provider settings - Registry dict makes adding new providers a single config entry - Handles provider-specific quirks (OpenRouter headers, Ollama base URL normalization, optional API keys) Usage: LLM(model="deepseek/deepseek-chat") LLM(model="ollama/llama3") LLM(model="openrouter/anthropic/claude-3-opus") LLM(model="llama3", provider="ollama") Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com> * fix: add is_litellm=True to tests that test litellm-specific methods Tests for _get_custom_llm_provider and _validate_call_params used openrouter/ model prefix which now routes to native provider. Added is_litellm=True to force litellm path since these test litellm-specific internals. --------- Co-authored-by: Joao Moura <joao@crewai.com> Co-authored-by: Claude Opus 4.5 <noreply@anthropic.com>
This commit is contained in:
@@ -309,6 +309,14 @@ SUPPORTED_NATIVE_PROVIDERS: Final[list[str]] = [
|
||||
"gemini",
|
||||
"bedrock",
|
||||
"aws",
|
||||
# OpenAI-compatible providers
|
||||
"openrouter",
|
||||
"deepseek",
|
||||
"ollama",
|
||||
"ollama_chat",
|
||||
"hosted_vllm",
|
||||
"cerebras",
|
||||
"dashscope",
|
||||
]
|
||||
|
||||
|
||||
@@ -368,6 +376,14 @@ class LLM(BaseLLM):
|
||||
"gemini": "gemini",
|
||||
"bedrock": "bedrock",
|
||||
"aws": "bedrock",
|
||||
# OpenAI-compatible providers
|
||||
"openrouter": "openrouter",
|
||||
"deepseek": "deepseek",
|
||||
"ollama": "ollama",
|
||||
"ollama_chat": "ollama_chat",
|
||||
"hosted_vllm": "hosted_vllm",
|
||||
"cerebras": "cerebras",
|
||||
"dashscope": "dashscope",
|
||||
}
|
||||
|
||||
canonical_provider = provider_mapping.get(prefix.lower())
|
||||
@@ -467,6 +483,29 @@ class LLM(BaseLLM):
|
||||
for prefix in ["gpt-", "gpt-35-", "o1", "o3", "o4", "azure-"]
|
||||
)
|
||||
|
||||
# OpenAI-compatible providers - accept any model name since these
|
||||
# providers host many different models with varied naming conventions
|
||||
if provider == "deepseek":
|
||||
return model_lower.startswith("deepseek")
|
||||
|
||||
if provider == "ollama" or provider == "ollama_chat":
|
||||
# Ollama accepts any local model name
|
||||
return True
|
||||
|
||||
if provider == "hosted_vllm":
|
||||
# vLLM serves any model
|
||||
return True
|
||||
|
||||
if provider == "cerebras":
|
||||
return True
|
||||
|
||||
if provider == "dashscope":
|
||||
return model_lower.startswith("qwen")
|
||||
|
||||
if provider == "openrouter":
|
||||
# OpenRouter uses org/model format but accepts anything
|
||||
return True
|
||||
|
||||
return False
|
||||
|
||||
@classmethod
|
||||
@@ -566,6 +605,23 @@ class LLM(BaseLLM):
|
||||
|
||||
return BedrockCompletion
|
||||
|
||||
# OpenAI-compatible providers
|
||||
openai_compatible_providers = {
|
||||
"openrouter",
|
||||
"deepseek",
|
||||
"ollama",
|
||||
"ollama_chat",
|
||||
"hosted_vllm",
|
||||
"cerebras",
|
||||
"dashscope",
|
||||
}
|
||||
if provider in openai_compatible_providers:
|
||||
from crewai.llms.providers.openai_compatible.completion import (
|
||||
OpenAICompatibleCompletion,
|
||||
)
|
||||
|
||||
return OpenAICompatibleCompletion
|
||||
|
||||
return None
|
||||
|
||||
def __init__(
|
||||
|
||||
@@ -0,0 +1,14 @@
|
||||
"""OpenAI-compatible providers module."""
|
||||
|
||||
from crewai.llms.providers.openai_compatible.completion import (
|
||||
OPENAI_COMPATIBLE_PROVIDERS,
|
||||
OpenAICompatibleCompletion,
|
||||
ProviderConfig,
|
||||
)
|
||||
|
||||
|
||||
__all__ = [
|
||||
"OPENAI_COMPATIBLE_PROVIDERS",
|
||||
"OpenAICompatibleCompletion",
|
||||
"ProviderConfig",
|
||||
]
|
||||
@@ -0,0 +1,282 @@
|
||||
"""OpenAI-compatible providers implementation.
|
||||
|
||||
This module provides a thin subclass of OpenAICompletion that supports
|
||||
various OpenAI-compatible APIs like OpenRouter, DeepSeek, Ollama, vLLM,
|
||||
Cerebras, and Dashscope (Alibaba/Qwen).
|
||||
|
||||
Usage:
|
||||
llm = LLM(model="deepseek/deepseek-chat") # Uses DeepSeek API
|
||||
llm = LLM(model="openrouter/anthropic/claude-3-opus") # Uses OpenRouter
|
||||
llm = LLM(model="ollama/llama3") # Uses local Ollama
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from dataclasses import dataclass, field
|
||||
import os
|
||||
from typing import Any
|
||||
|
||||
from crewai.llms.providers.openai.completion import OpenAICompletion
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class ProviderConfig:
|
||||
"""Configuration for an OpenAI-compatible provider.
|
||||
|
||||
Attributes:
|
||||
base_url: Default base URL for the provider's API endpoint.
|
||||
api_key_env: Environment variable name for the API key.
|
||||
base_url_env: Environment variable name for a custom base URL override.
|
||||
default_headers: HTTP headers to include in all requests.
|
||||
api_key_required: Whether an API key is required for this provider.
|
||||
default_api_key: Default API key to use if none is provided and not required.
|
||||
"""
|
||||
|
||||
base_url: str
|
||||
api_key_env: str
|
||||
base_url_env: str | None = None
|
||||
default_headers: dict[str, str] = field(default_factory=dict)
|
||||
api_key_required: bool = True
|
||||
default_api_key: str | None = None
|
||||
|
||||
|
||||
OPENAI_COMPATIBLE_PROVIDERS: dict[str, ProviderConfig] = {
|
||||
"openrouter": ProviderConfig(
|
||||
base_url="https://openrouter.ai/api/v1",
|
||||
api_key_env="OPENROUTER_API_KEY",
|
||||
base_url_env="OPENROUTER_BASE_URL",
|
||||
default_headers={"HTTP-Referer": "https://crewai.com"},
|
||||
api_key_required=True,
|
||||
),
|
||||
"deepseek": ProviderConfig(
|
||||
base_url="https://api.deepseek.com/v1",
|
||||
api_key_env="DEEPSEEK_API_KEY",
|
||||
base_url_env="DEEPSEEK_BASE_URL",
|
||||
api_key_required=True,
|
||||
),
|
||||
"ollama": ProviderConfig(
|
||||
base_url="http://localhost:11434/v1",
|
||||
api_key_env="OLLAMA_API_KEY",
|
||||
base_url_env="OLLAMA_HOST",
|
||||
api_key_required=False,
|
||||
default_api_key="ollama",
|
||||
),
|
||||
"ollama_chat": ProviderConfig(
|
||||
base_url="http://localhost:11434/v1",
|
||||
api_key_env="OLLAMA_API_KEY",
|
||||
base_url_env="OLLAMA_HOST",
|
||||
api_key_required=False,
|
||||
default_api_key="ollama",
|
||||
),
|
||||
"hosted_vllm": ProviderConfig(
|
||||
base_url="http://localhost:8000/v1",
|
||||
api_key_env="VLLM_API_KEY",
|
||||
base_url_env="VLLM_BASE_URL",
|
||||
api_key_required=False,
|
||||
default_api_key="dummy",
|
||||
),
|
||||
"cerebras": ProviderConfig(
|
||||
base_url="https://api.cerebras.ai/v1",
|
||||
api_key_env="CEREBRAS_API_KEY",
|
||||
base_url_env="CEREBRAS_BASE_URL",
|
||||
api_key_required=True,
|
||||
),
|
||||
"dashscope": ProviderConfig(
|
||||
base_url="https://dashscope-intl.aliyuncs.com/compatible-mode/v1",
|
||||
api_key_env="DASHSCOPE_API_KEY",
|
||||
base_url_env="DASHSCOPE_BASE_URL",
|
||||
api_key_required=True,
|
||||
),
|
||||
}
|
||||
|
||||
|
||||
def _normalize_ollama_base_url(base_url: str) -> str:
|
||||
"""Normalize Ollama base URL to ensure it ends with /v1.
|
||||
|
||||
Ollama uses OLLAMA_HOST which may not include the /v1 suffix,
|
||||
but the OpenAI-compatible endpoint requires it.
|
||||
|
||||
Args:
|
||||
base_url: The base URL, potentially without /v1 suffix.
|
||||
|
||||
Returns:
|
||||
The base URL with /v1 suffix if needed.
|
||||
"""
|
||||
base_url = base_url.rstrip("/")
|
||||
if not base_url.endswith("/v1"):
|
||||
return f"{base_url}/v1"
|
||||
return base_url
|
||||
|
||||
|
||||
class OpenAICompatibleCompletion(OpenAICompletion):
|
||||
"""OpenAI-compatible completion implementation.
|
||||
|
||||
This class provides support for various OpenAI-compatible APIs by
|
||||
automatically configuring the base URL, API key, and headers based
|
||||
on the provider name.
|
||||
|
||||
Supported providers:
|
||||
- openrouter: OpenRouter (https://openrouter.ai)
|
||||
- deepseek: DeepSeek (https://deepseek.com)
|
||||
- ollama: Ollama local server (https://ollama.ai)
|
||||
- ollama_chat: Alias for ollama
|
||||
- hosted_vllm: vLLM server (https://github.com/vllm-project/vllm)
|
||||
- cerebras: Cerebras (https://cerebras.ai)
|
||||
- dashscope: Alibaba Dashscope/Qwen (https://dashscope.aliyun.com)
|
||||
|
||||
Example:
|
||||
# Using provider prefix
|
||||
llm = LLM(model="deepseek/deepseek-chat")
|
||||
|
||||
# Using explicit provider parameter
|
||||
llm = LLM(model="llama3", provider="ollama")
|
||||
|
||||
# With custom configuration
|
||||
llm = LLM(
|
||||
model="deepseek-chat",
|
||||
provider="deepseek",
|
||||
api_key="my-key",
|
||||
temperature=0.7
|
||||
)
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
model: str,
|
||||
provider: str,
|
||||
api_key: str | None = None,
|
||||
base_url: str | None = None,
|
||||
default_headers: dict[str, str] | None = None,
|
||||
**kwargs: Any,
|
||||
) -> None:
|
||||
"""Initialize OpenAI-compatible completion client.
|
||||
|
||||
Args:
|
||||
model: The model identifier.
|
||||
provider: The provider name (must be in OPENAI_COMPATIBLE_PROVIDERS).
|
||||
api_key: Optional API key override. If not provided, uses the
|
||||
provider's configured environment variable.
|
||||
base_url: Optional base URL override. If not provided, uses the
|
||||
provider's configured default or environment variable.
|
||||
default_headers: Optional headers to merge with provider defaults.
|
||||
**kwargs: Additional arguments passed to OpenAICompletion.
|
||||
|
||||
Raises:
|
||||
ValueError: If the provider is not supported or required API key
|
||||
is missing.
|
||||
"""
|
||||
config = OPENAI_COMPATIBLE_PROVIDERS.get(provider)
|
||||
if config is None:
|
||||
supported = ", ".join(sorted(OPENAI_COMPATIBLE_PROVIDERS.keys()))
|
||||
raise ValueError(
|
||||
f"Unknown OpenAI-compatible provider: {provider}. "
|
||||
f"Supported providers: {supported}"
|
||||
)
|
||||
|
||||
resolved_api_key = self._resolve_api_key(api_key, config, provider)
|
||||
resolved_base_url = self._resolve_base_url(base_url, config, provider)
|
||||
resolved_headers = self._resolve_headers(default_headers, config)
|
||||
|
||||
super().__init__(
|
||||
model=model,
|
||||
provider=provider,
|
||||
api_key=resolved_api_key,
|
||||
base_url=resolved_base_url,
|
||||
default_headers=resolved_headers,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
def _resolve_api_key(
|
||||
self,
|
||||
api_key: str | None,
|
||||
config: ProviderConfig,
|
||||
provider: str,
|
||||
) -> str | None:
|
||||
"""Resolve the API key from explicit value, env var, or default.
|
||||
|
||||
Args:
|
||||
api_key: Explicitly provided API key.
|
||||
config: Provider configuration.
|
||||
provider: Provider name for error messages.
|
||||
|
||||
Returns:
|
||||
The resolved API key.
|
||||
|
||||
Raises:
|
||||
ValueError: If API key is required but not found.
|
||||
"""
|
||||
if api_key:
|
||||
return api_key
|
||||
|
||||
env_key = os.getenv(config.api_key_env)
|
||||
if env_key:
|
||||
return env_key
|
||||
|
||||
if config.api_key_required:
|
||||
raise ValueError(
|
||||
f"API key required for {provider}. "
|
||||
f"Set {config.api_key_env} environment variable or pass api_key parameter."
|
||||
)
|
||||
|
||||
return config.default_api_key
|
||||
|
||||
def _resolve_base_url(
|
||||
self,
|
||||
base_url: str | None,
|
||||
config: ProviderConfig,
|
||||
provider: str,
|
||||
) -> str:
|
||||
"""Resolve the base URL from explicit value, env var, or default.
|
||||
|
||||
Args:
|
||||
base_url: Explicitly provided base URL.
|
||||
config: Provider configuration.
|
||||
provider: Provider name (used for special handling like Ollama).
|
||||
|
||||
Returns:
|
||||
The resolved base URL.
|
||||
"""
|
||||
if base_url:
|
||||
resolved = base_url
|
||||
elif config.base_url_env:
|
||||
resolved = os.getenv(config.base_url_env, config.base_url)
|
||||
else:
|
||||
resolved = config.base_url
|
||||
|
||||
if provider in ("ollama", "ollama_chat"):
|
||||
resolved = _normalize_ollama_base_url(resolved)
|
||||
|
||||
return resolved
|
||||
|
||||
def _resolve_headers(
|
||||
self,
|
||||
headers: dict[str, str] | None,
|
||||
config: ProviderConfig,
|
||||
) -> dict[str, str] | None:
|
||||
"""Merge user headers with provider default headers.
|
||||
|
||||
Args:
|
||||
headers: User-provided headers.
|
||||
config: Provider configuration.
|
||||
|
||||
Returns:
|
||||
Merged headers dict, or None if empty.
|
||||
"""
|
||||
if not config.default_headers and not headers:
|
||||
return None
|
||||
|
||||
merged = dict(config.default_headers)
|
||||
if headers:
|
||||
merged.update(headers)
|
||||
|
||||
return merged if merged else None
|
||||
|
||||
def supports_function_calling(self) -> bool:
|
||||
"""Check if the provider supports function calling.
|
||||
|
||||
All modern OpenAI-compatible providers support function calling.
|
||||
|
||||
Returns:
|
||||
True, as all supported providers have function calling support.
|
||||
"""
|
||||
return True
|
||||
1
lib/crewai/tests/llms/openai_compatible/__init__.py
Normal file
1
lib/crewai/tests/llms/openai_compatible/__init__.py
Normal file
@@ -0,0 +1 @@
|
||||
"""Tests for OpenAI-compatible providers."""
|
||||
@@ -0,0 +1,310 @@
|
||||
"""Tests for OpenAI-compatible providers."""
|
||||
|
||||
import os
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
import pytest
|
||||
|
||||
from crewai.llm import LLM
|
||||
from crewai.llms.providers.openai_compatible.completion import (
|
||||
OPENAI_COMPATIBLE_PROVIDERS,
|
||||
OpenAICompatibleCompletion,
|
||||
ProviderConfig,
|
||||
_normalize_ollama_base_url,
|
||||
)
|
||||
|
||||
|
||||
class TestProviderConfig:
|
||||
"""Tests for ProviderConfig dataclass."""
|
||||
|
||||
def test_provider_config_immutable(self):
|
||||
"""Test that ProviderConfig is immutable (frozen)."""
|
||||
config = ProviderConfig(
|
||||
base_url="https://example.com/v1",
|
||||
api_key_env="TEST_API_KEY",
|
||||
)
|
||||
with pytest.raises(AttributeError):
|
||||
config.base_url = "https://other.com/v1"
|
||||
|
||||
def test_provider_config_defaults(self):
|
||||
"""Test ProviderConfig default values."""
|
||||
config = ProviderConfig(
|
||||
base_url="https://example.com/v1",
|
||||
api_key_env="TEST_API_KEY",
|
||||
)
|
||||
assert config.base_url_env is None
|
||||
assert config.default_headers == {}
|
||||
assert config.api_key_required is True
|
||||
assert config.default_api_key is None
|
||||
|
||||
|
||||
class TestProviderRegistry:
|
||||
"""Tests for the OPENAI_COMPATIBLE_PROVIDERS registry."""
|
||||
|
||||
def test_openrouter_config(self):
|
||||
"""Test OpenRouter provider configuration."""
|
||||
config = OPENAI_COMPATIBLE_PROVIDERS["openrouter"]
|
||||
assert config.base_url == "https://openrouter.ai/api/v1"
|
||||
assert config.api_key_env == "OPENROUTER_API_KEY"
|
||||
assert config.base_url_env == "OPENROUTER_BASE_URL"
|
||||
assert "HTTP-Referer" in config.default_headers
|
||||
assert config.api_key_required is True
|
||||
|
||||
def test_deepseek_config(self):
|
||||
"""Test DeepSeek provider configuration."""
|
||||
config = OPENAI_COMPATIBLE_PROVIDERS["deepseek"]
|
||||
assert config.base_url == "https://api.deepseek.com/v1"
|
||||
assert config.api_key_env == "DEEPSEEK_API_KEY"
|
||||
assert config.api_key_required is True
|
||||
|
||||
def test_ollama_config(self):
|
||||
"""Test Ollama provider configuration."""
|
||||
config = OPENAI_COMPATIBLE_PROVIDERS["ollama"]
|
||||
assert config.base_url == "http://localhost:11434/v1"
|
||||
assert config.api_key_env == "OLLAMA_API_KEY"
|
||||
assert config.base_url_env == "OLLAMA_HOST"
|
||||
assert config.api_key_required is False
|
||||
assert config.default_api_key == "ollama"
|
||||
|
||||
def test_ollama_chat_is_alias(self):
|
||||
"""Test ollama_chat is configured same as ollama."""
|
||||
ollama = OPENAI_COMPATIBLE_PROVIDERS["ollama"]
|
||||
ollama_chat = OPENAI_COMPATIBLE_PROVIDERS["ollama_chat"]
|
||||
assert ollama.base_url == ollama_chat.base_url
|
||||
assert ollama.api_key_required == ollama_chat.api_key_required
|
||||
|
||||
def test_hosted_vllm_config(self):
|
||||
"""Test hosted_vllm provider configuration."""
|
||||
config = OPENAI_COMPATIBLE_PROVIDERS["hosted_vllm"]
|
||||
assert config.base_url == "http://localhost:8000/v1"
|
||||
assert config.api_key_env == "VLLM_API_KEY"
|
||||
assert config.api_key_required is False
|
||||
assert config.default_api_key == "dummy"
|
||||
|
||||
def test_cerebras_config(self):
|
||||
"""Test Cerebras provider configuration."""
|
||||
config = OPENAI_COMPATIBLE_PROVIDERS["cerebras"]
|
||||
assert config.base_url == "https://api.cerebras.ai/v1"
|
||||
assert config.api_key_env == "CEREBRAS_API_KEY"
|
||||
assert config.api_key_required is True
|
||||
|
||||
def test_dashscope_config(self):
|
||||
"""Test Dashscope provider configuration."""
|
||||
config = OPENAI_COMPATIBLE_PROVIDERS["dashscope"]
|
||||
assert config.base_url == "https://dashscope-intl.aliyuncs.com/compatible-mode/v1"
|
||||
assert config.api_key_env == "DASHSCOPE_API_KEY"
|
||||
assert config.api_key_required is True
|
||||
|
||||
|
||||
class TestNormalizeOllamaBaseUrl:
|
||||
"""Tests for _normalize_ollama_base_url helper."""
|
||||
|
||||
def test_adds_v1_suffix(self):
|
||||
"""Test that /v1 is added when missing."""
|
||||
assert _normalize_ollama_base_url("http://localhost:11434") == "http://localhost:11434/v1"
|
||||
|
||||
def test_preserves_existing_v1(self):
|
||||
"""Test that existing /v1 is preserved."""
|
||||
assert _normalize_ollama_base_url("http://localhost:11434/v1") == "http://localhost:11434/v1"
|
||||
|
||||
def test_strips_trailing_slash(self):
|
||||
"""Test that trailing slash is handled."""
|
||||
assert _normalize_ollama_base_url("http://localhost:11434/") == "http://localhost:11434/v1"
|
||||
|
||||
def test_handles_v1_with_trailing_slash(self):
|
||||
"""Test /v1/ is normalized."""
|
||||
assert _normalize_ollama_base_url("http://localhost:11434/v1/") == "http://localhost:11434/v1"
|
||||
|
||||
|
||||
class TestOpenAICompatibleCompletion:
|
||||
"""Tests for OpenAICompatibleCompletion class."""
|
||||
|
||||
def test_unknown_provider_raises_error(self):
|
||||
"""Test that unknown provider raises ValueError."""
|
||||
with pytest.raises(ValueError, match="Unknown OpenAI-compatible provider"):
|
||||
OpenAICompatibleCompletion(model="test", provider="unknown_provider")
|
||||
|
||||
def test_missing_required_api_key_raises_error(self):
|
||||
"""Test that missing required API key raises ValueError."""
|
||||
# Clear any existing env var
|
||||
env_key = "DEEPSEEK_API_KEY"
|
||||
original = os.environ.pop(env_key, None)
|
||||
try:
|
||||
with pytest.raises(ValueError, match="API key required"):
|
||||
OpenAICompatibleCompletion(model="deepseek-chat", provider="deepseek")
|
||||
finally:
|
||||
if original:
|
||||
os.environ[env_key] = original
|
||||
|
||||
def test_api_key_from_env(self):
|
||||
"""Test API key is read from environment variable."""
|
||||
with patch.dict(os.environ, {"DEEPSEEK_API_KEY": "test-key-from-env"}):
|
||||
completion = OpenAICompatibleCompletion(
|
||||
model="deepseek-chat", provider="deepseek"
|
||||
)
|
||||
assert completion.api_key == "test-key-from-env"
|
||||
|
||||
def test_explicit_api_key_overrides_env(self):
|
||||
"""Test explicit API key overrides environment variable."""
|
||||
with patch.dict(os.environ, {"DEEPSEEK_API_KEY": "env-key"}):
|
||||
completion = OpenAICompatibleCompletion(
|
||||
model="deepseek-chat",
|
||||
provider="deepseek",
|
||||
api_key="explicit-key",
|
||||
)
|
||||
assert completion.api_key == "explicit-key"
|
||||
|
||||
def test_default_api_key_for_optional_providers(self):
|
||||
"""Test default API key is used for providers that don't require it."""
|
||||
# Ollama doesn't require API key
|
||||
completion = OpenAICompatibleCompletion(model="llama3", provider="ollama")
|
||||
assert completion.api_key == "ollama"
|
||||
|
||||
def test_base_url_from_config(self):
|
||||
"""Test base URL is set from provider config."""
|
||||
with patch.dict(os.environ, {"DEEPSEEK_API_KEY": "test-key"}):
|
||||
completion = OpenAICompatibleCompletion(
|
||||
model="deepseek-chat", provider="deepseek"
|
||||
)
|
||||
assert completion.base_url == "https://api.deepseek.com/v1"
|
||||
|
||||
def test_base_url_from_env(self):
|
||||
"""Test base URL is read from environment variable."""
|
||||
with patch.dict(
|
||||
os.environ,
|
||||
{"DEEPSEEK_API_KEY": "test-key", "DEEPSEEK_BASE_URL": "https://custom.deepseek.com/v1"},
|
||||
):
|
||||
completion = OpenAICompatibleCompletion(
|
||||
model="deepseek-chat", provider="deepseek"
|
||||
)
|
||||
assert completion.base_url == "https://custom.deepseek.com/v1"
|
||||
|
||||
def test_explicit_base_url_overrides_all(self):
|
||||
"""Test explicit base URL overrides env and config."""
|
||||
with patch.dict(
|
||||
os.environ,
|
||||
{"DEEPSEEK_API_KEY": "test-key", "DEEPSEEK_BASE_URL": "https://env.deepseek.com/v1"},
|
||||
):
|
||||
completion = OpenAICompatibleCompletion(
|
||||
model="deepseek-chat",
|
||||
provider="deepseek",
|
||||
base_url="https://explicit.deepseek.com/v1",
|
||||
)
|
||||
assert completion.base_url == "https://explicit.deepseek.com/v1"
|
||||
|
||||
def test_ollama_base_url_normalized(self):
|
||||
"""Test Ollama base URL is normalized to include /v1."""
|
||||
with patch.dict(os.environ, {"OLLAMA_HOST": "http://custom-ollama:11434"}):
|
||||
completion = OpenAICompatibleCompletion(model="llama3", provider="ollama")
|
||||
assert completion.base_url == "http://custom-ollama:11434/v1"
|
||||
|
||||
def test_openrouter_headers(self):
|
||||
"""Test OpenRouter has HTTP-Referer header."""
|
||||
with patch.dict(os.environ, {"OPENROUTER_API_KEY": "test-key"}):
|
||||
completion = OpenAICompatibleCompletion(
|
||||
model="anthropic/claude-3-opus", provider="openrouter"
|
||||
)
|
||||
assert completion.default_headers is not None
|
||||
assert "HTTP-Referer" in completion.default_headers
|
||||
|
||||
def test_custom_headers_merged_with_defaults(self):
|
||||
"""Test custom headers are merged with provider defaults."""
|
||||
with patch.dict(os.environ, {"OPENROUTER_API_KEY": "test-key"}):
|
||||
completion = OpenAICompatibleCompletion(
|
||||
model="anthropic/claude-3-opus",
|
||||
provider="openrouter",
|
||||
default_headers={"X-Custom": "value"},
|
||||
)
|
||||
assert completion.default_headers is not None
|
||||
assert "HTTP-Referer" in completion.default_headers
|
||||
assert completion.default_headers.get("X-Custom") == "value"
|
||||
|
||||
def test_supports_function_calling(self):
|
||||
"""Test that function calling is supported."""
|
||||
completion = OpenAICompatibleCompletion(model="llama3", provider="ollama")
|
||||
assert completion.supports_function_calling() is True
|
||||
|
||||
|
||||
class TestLLMIntegration:
|
||||
"""Tests for LLM factory integration with OpenAI-compatible providers."""
|
||||
|
||||
def test_llm_creates_openai_compatible_for_deepseek(self):
|
||||
"""Test LLM factory creates OpenAICompatibleCompletion for DeepSeek."""
|
||||
with patch.dict(os.environ, {"DEEPSEEK_API_KEY": "test-key"}):
|
||||
llm = LLM(model="deepseek/deepseek-chat")
|
||||
assert isinstance(llm, OpenAICompatibleCompletion)
|
||||
assert llm.provider == "deepseek"
|
||||
assert llm.model == "deepseek-chat"
|
||||
|
||||
def test_llm_creates_openai_compatible_for_ollama(self):
|
||||
"""Test LLM factory creates OpenAICompatibleCompletion for Ollama."""
|
||||
llm = LLM(model="ollama/llama3")
|
||||
assert isinstance(llm, OpenAICompatibleCompletion)
|
||||
assert llm.provider == "ollama"
|
||||
assert llm.model == "llama3"
|
||||
|
||||
def test_llm_creates_openai_compatible_for_openrouter(self):
|
||||
"""Test LLM factory creates OpenAICompatibleCompletion for OpenRouter."""
|
||||
with patch.dict(os.environ, {"OPENROUTER_API_KEY": "test-key"}):
|
||||
llm = LLM(model="openrouter/anthropic/claude-3-opus")
|
||||
assert isinstance(llm, OpenAICompatibleCompletion)
|
||||
assert llm.provider == "openrouter"
|
||||
# Model should include the full path after provider prefix
|
||||
assert llm.model == "anthropic/claude-3-opus"
|
||||
|
||||
def test_llm_creates_openai_compatible_for_hosted_vllm(self):
|
||||
"""Test LLM factory creates OpenAICompatibleCompletion for hosted_vllm."""
|
||||
llm = LLM(model="hosted_vllm/meta-llama/Llama-3-8b")
|
||||
assert isinstance(llm, OpenAICompatibleCompletion)
|
||||
assert llm.provider == "hosted_vllm"
|
||||
|
||||
def test_llm_creates_openai_compatible_for_cerebras(self):
|
||||
"""Test LLM factory creates OpenAICompatibleCompletion for Cerebras."""
|
||||
with patch.dict(os.environ, {"CEREBRAS_API_KEY": "test-key"}):
|
||||
llm = LLM(model="cerebras/llama3-8b")
|
||||
assert isinstance(llm, OpenAICompatibleCompletion)
|
||||
assert llm.provider == "cerebras"
|
||||
|
||||
def test_llm_creates_openai_compatible_for_dashscope(self):
|
||||
"""Test LLM factory creates OpenAICompatibleCompletion for Dashscope."""
|
||||
with patch.dict(os.environ, {"DASHSCOPE_API_KEY": "test-key"}):
|
||||
llm = LLM(model="dashscope/qwen-turbo")
|
||||
assert isinstance(llm, OpenAICompatibleCompletion)
|
||||
assert llm.provider == "dashscope"
|
||||
|
||||
def test_llm_with_explicit_provider(self):
|
||||
"""Test LLM with explicit provider parameter."""
|
||||
with patch.dict(os.environ, {"DEEPSEEK_API_KEY": "test-key"}):
|
||||
llm = LLM(model="deepseek-chat", provider="deepseek")
|
||||
assert isinstance(llm, OpenAICompatibleCompletion)
|
||||
assert llm.provider == "deepseek"
|
||||
assert llm.model == "deepseek-chat"
|
||||
|
||||
def test_llm_passes_kwargs_to_completion(self):
|
||||
"""Test LLM passes kwargs to OpenAICompatibleCompletion."""
|
||||
with patch.dict(os.environ, {"DEEPSEEK_API_KEY": "test-key"}):
|
||||
llm = LLM(
|
||||
model="deepseek/deepseek-chat",
|
||||
temperature=0.7,
|
||||
max_tokens=1000,
|
||||
)
|
||||
assert llm.temperature == 0.7
|
||||
assert llm.max_tokens == 1000
|
||||
|
||||
|
||||
class TestCallMocking:
|
||||
"""Tests for mocking the call method."""
|
||||
|
||||
def test_call_method_can_be_mocked(self):
|
||||
"""Test that the call method can be mocked for testing."""
|
||||
completion = OpenAICompatibleCompletion(model="llama3", provider="ollama")
|
||||
|
||||
with patch.object(completion, "call", return_value="Mocked response"):
|
||||
result = completion.call("Test message")
|
||||
assert result == "Mocked response"
|
||||
|
||||
def test_acall_method_exists(self):
|
||||
"""Test that acall method exists for async calls."""
|
||||
completion = OpenAICompatibleCompletion(model="llama3", provider="ollama")
|
||||
assert hasattr(completion, "acall")
|
||||
assert callable(completion.acall)
|
||||
@@ -211,7 +211,7 @@ def test_llm_passes_additional_params():
|
||||
|
||||
|
||||
def test_get_custom_llm_provider_openrouter():
|
||||
llm = LLM(model="openrouter/deepseek/deepseek-chat")
|
||||
llm = LLM(model="openrouter/deepseek/deepseek-chat", is_litellm=True)
|
||||
assert llm._get_custom_llm_provider() == "openrouter"
|
||||
|
||||
|
||||
@@ -232,7 +232,9 @@ def test_validate_call_params_supported():
|
||||
# Patch supports_response_schema to simulate a supported model.
|
||||
with patch("crewai.llm.supports_response_schema", return_value=True):
|
||||
llm = LLM(
|
||||
model="openrouter/deepseek/deepseek-chat", response_format=DummyResponse
|
||||
model="openrouter/deepseek/deepseek-chat",
|
||||
response_format=DummyResponse,
|
||||
is_litellm=True,
|
||||
)
|
||||
# Should not raise any error.
|
||||
llm._validate_call_params()
|
||||
|
||||
Reference in New Issue
Block a user