Files
crewAI/lib/crewai/tests/llms/openai_compatible/test_openai_compatible.py
alex-clawd c183b77991
Some checks failed
CodeQL Advanced / Analyze (actions) (push) Has been cancelled
CodeQL Advanced / Analyze (python) (push) Has been cancelled
Check Documentation Broken Links / Check broken links (push) Has been cancelled
Build uv cache / build-cache (3.10) (push) Has been cancelled
Build uv cache / build-cache (3.11) (push) Has been cancelled
Build uv cache / build-cache (3.12) (push) Has been cancelled
Build uv cache / build-cache (3.13) (push) Has been cancelled
fix: address Copilot review on OpenAI-compatible providers (#5042) (#5089)
- Delegate supports_function_calling() to parent (handles o1 models via OpenRouter)
- Guard empty env vars in base_url resolution
- Fix misleading comment about model validation rules
- Remove unused MagicMock import
- Use 'is not None' for env var restoration in tests

Co-authored-by: Joao Moura <joao@crewai.com>
2026-03-25 18:22:13 -03:00

311 lines
13 KiB
Python

"""Tests for OpenAI-compatible providers."""
import os
from unittest.mock import patch
import pytest
from crewai.llm import LLM
from crewai.llms.providers.openai_compatible.completion import (
OPENAI_COMPATIBLE_PROVIDERS,
OpenAICompatibleCompletion,
ProviderConfig,
_normalize_ollama_base_url,
)
class TestProviderConfig:
    """Unit tests for the ProviderConfig dataclass itself."""

    def test_provider_config_immutable(self):
        """Assigning to a field of a frozen ProviderConfig must fail."""
        cfg = ProviderConfig(
            api_key_env="TEST_API_KEY",
            base_url="https://example.com/v1",
        )
        # Frozen dataclasses raise FrozenInstanceError, a subclass of AttributeError.
        with pytest.raises(AttributeError):
            cfg.base_url = "https://other.com/v1"

    def test_provider_config_defaults(self):
        """Optional fields that are not passed fall back to their defaults."""
        cfg = ProviderConfig(
            api_key_env="TEST_API_KEY",
            base_url="https://example.com/v1",
        )
        assert cfg.api_key_required is True
        assert cfg.base_url_env is None
        assert cfg.default_api_key is None
        assert cfg.default_headers == {}
class TestProviderRegistry:
    """Checks the built-in entries of the OPENAI_COMPATIBLE_PROVIDERS registry."""

    def test_openrouter_config(self):
        """OpenRouter uses its hosted endpoint, a required key and an HTTP-Referer header."""
        cfg = OPENAI_COMPATIBLE_PROVIDERS["openrouter"]
        assert cfg.api_key_env == "OPENROUTER_API_KEY"
        assert cfg.api_key_required is True
        assert cfg.base_url == "https://openrouter.ai/api/v1"
        assert cfg.base_url_env == "OPENROUTER_BASE_URL"
        assert "HTTP-Referer" in cfg.default_headers

    def test_deepseek_config(self):
        """DeepSeek uses its hosted endpoint and requires an API key."""
        cfg = OPENAI_COMPATIBLE_PROVIDERS["deepseek"]
        assert cfg.api_key_env == "DEEPSEEK_API_KEY"
        assert cfg.api_key_required is True
        assert cfg.base_url == "https://api.deepseek.com/v1"

    def test_ollama_config(self):
        """Ollama targets localhost, reads OLLAMA_HOST and needs no real key."""
        cfg = OPENAI_COMPATIBLE_PROVIDERS["ollama"]
        assert cfg.api_key_env == "OLLAMA_API_KEY"
        assert cfg.api_key_required is False
        assert cfg.base_url == "http://localhost:11434/v1"
        assert cfg.base_url_env == "OLLAMA_HOST"
        assert cfg.default_api_key == "ollama"

    def test_ollama_chat_is_alias(self):
        """ollama_chat shares ollama's endpoint and key requirement."""
        registry = OPENAI_COMPATIBLE_PROVIDERS
        primary, alias = registry["ollama"], registry["ollama_chat"]
        for field_name in ("base_url", "api_key_required"):
            assert getattr(primary, field_name) == getattr(alias, field_name)

    def test_hosted_vllm_config(self):
        """hosted_vllm targets localhost:8000 and falls back to a dummy key."""
        cfg = OPENAI_COMPATIBLE_PROVIDERS["hosted_vllm"]
        assert cfg.api_key_env == "VLLM_API_KEY"
        assert cfg.api_key_required is False
        assert cfg.base_url == "http://localhost:8000/v1"
        assert cfg.default_api_key == "dummy"

    def test_cerebras_config(self):
        """Cerebras uses its hosted endpoint and requires an API key."""
        cfg = OPENAI_COMPATIBLE_PROVIDERS["cerebras"]
        assert cfg.api_key_env == "CEREBRAS_API_KEY"
        assert cfg.api_key_required is True
        assert cfg.base_url == "https://api.cerebras.ai/v1"

    def test_dashscope_config(self):
        """Dashscope uses the international compatible-mode endpoint and requires a key."""
        cfg = OPENAI_COMPATIBLE_PROVIDERS["dashscope"]
        assert cfg.api_key_env == "DASHSCOPE_API_KEY"
        assert cfg.api_key_required is True
        assert cfg.base_url == "https://dashscope-intl.aliyuncs.com/compatible-mode/v1"
class TestNormalizeOllamaBaseUrl:
    """Checks that _normalize_ollama_base_url yields exactly one trailing /v1."""

    # Every input variant below must normalize to this same URL.
    _CANONICAL = "http://localhost:11434/v1"

    def test_adds_v1_suffix(self):
        """A bare host:port gets /v1 appended."""
        normalized = _normalize_ollama_base_url("http://localhost:11434")
        assert normalized == self._CANONICAL

    def test_preserves_existing_v1(self):
        """A URL already ending in /v1 is left alone."""
        normalized = _normalize_ollama_base_url("http://localhost:11434/v1")
        assert normalized == self._CANONICAL

    def test_strips_trailing_slash(self):
        """A trailing slash is dropped before /v1 is appended."""
        normalized = _normalize_ollama_base_url("http://localhost:11434/")
        assert normalized == self._CANONICAL

    def test_handles_v1_with_trailing_slash(self):
        """A trailing slash after /v1 is removed."""
        normalized = _normalize_ollama_base_url("http://localhost:11434/v1/")
        assert normalized == self._CANONICAL
class TestOpenAICompatibleCompletion:
    """Tests for OpenAICompatibleCompletion construction and configuration.

    Tests that need an environment variable to be *absent* remove it inside
    ``patch.dict(os.environ)``. This keeps values exported on the developer's
    or CI machine from leaking in. ``patch.dict`` restores the original
    environment on exit, even when the test fails.
    """

    def test_unknown_provider_raises_error(self):
        """Test that unknown provider raises ValueError."""
        with pytest.raises(ValueError, match="Unknown OpenAI-compatible provider"):
            OpenAICompatibleCompletion(model="test", provider="unknown_provider")

    def test_missing_required_api_key_raises_error(self):
        """Test that missing required API key raises ValueError."""
        # patch.dict snapshots os.environ and restores it on exit, so the key
        # can be removed without manual save/restore bookkeeping.
        with patch.dict(os.environ):
            os.environ.pop("DEEPSEEK_API_KEY", None)
            with pytest.raises(ValueError, match="API key required"):
                OpenAICompatibleCompletion(model="deepseek-chat", provider="deepseek")

    def test_api_key_from_env(self):
        """Test API key is read from environment variable."""
        with patch.dict(os.environ, {"DEEPSEEK_API_KEY": "test-key-from-env"}):
            completion = OpenAICompatibleCompletion(
                model="deepseek-chat", provider="deepseek"
            )
        assert completion.api_key == "test-key-from-env"

    def test_explicit_api_key_overrides_env(self):
        """Test explicit API key overrides environment variable."""
        with patch.dict(os.environ, {"DEEPSEEK_API_KEY": "env-key"}):
            completion = OpenAICompatibleCompletion(
                model="deepseek-chat",
                provider="deepseek",
                api_key="explicit-key",
            )
        assert completion.api_key == "explicit-key"

    def test_default_api_key_for_optional_providers(self):
        """Test default API key is used for providers that don't require it."""
        # Ollama doesn't require an API key. Drop OLLAMA_API_KEY so a value
        # exported in the real environment cannot interfere with the default.
        with patch.dict(os.environ):
            os.environ.pop("OLLAMA_API_KEY", None)
            completion = OpenAICompatibleCompletion(model="llama3", provider="ollama")
        assert completion.api_key == "ollama"

    def test_base_url_from_config(self):
        """Test base URL is set from provider config."""
        with patch.dict(os.environ, {"DEEPSEEK_API_KEY": "test-key"}):
            # DEEPSEEK_BASE_URL overrides the config URL (see
            # test_base_url_from_env), so a real one in the environment
            # would make this test fail spuriously.
            os.environ.pop("DEEPSEEK_BASE_URL", None)
            completion = OpenAICompatibleCompletion(
                model="deepseek-chat", provider="deepseek"
            )
        assert completion.base_url == "https://api.deepseek.com/v1"

    def test_base_url_from_env(self):
        """Test base URL is read from environment variable."""
        with patch.dict(
            os.environ,
            {"DEEPSEEK_API_KEY": "test-key", "DEEPSEEK_BASE_URL": "https://custom.deepseek.com/v1"},
        ):
            completion = OpenAICompatibleCompletion(
                model="deepseek-chat", provider="deepseek"
            )
        assert completion.base_url == "https://custom.deepseek.com/v1"

    def test_explicit_base_url_overrides_all(self):
        """Test explicit base URL overrides env and config."""
        with patch.dict(
            os.environ,
            {"DEEPSEEK_API_KEY": "test-key", "DEEPSEEK_BASE_URL": "https://env.deepseek.com/v1"},
        ):
            completion = OpenAICompatibleCompletion(
                model="deepseek-chat",
                provider="deepseek",
                base_url="https://explicit.deepseek.com/v1",
            )
        assert completion.base_url == "https://explicit.deepseek.com/v1"

    def test_ollama_base_url_normalized(self):
        """Test Ollama base URL is normalized to include /v1."""
        with patch.dict(os.environ, {"OLLAMA_HOST": "http://custom-ollama:11434"}):
            completion = OpenAICompatibleCompletion(model="llama3", provider="ollama")
        assert completion.base_url == "http://custom-ollama:11434/v1"

    def test_openrouter_headers(self):
        """Test OpenRouter has HTTP-Referer header."""
        with patch.dict(os.environ, {"OPENROUTER_API_KEY": "test-key"}):
            completion = OpenAICompatibleCompletion(
                model="anthropic/claude-3-opus", provider="openrouter"
            )
        assert completion.default_headers is not None
        assert "HTTP-Referer" in completion.default_headers

    def test_custom_headers_merged_with_defaults(self):
        """Test custom headers are merged with provider defaults."""
        with patch.dict(os.environ, {"OPENROUTER_API_KEY": "test-key"}):
            completion = OpenAICompatibleCompletion(
                model="anthropic/claude-3-opus",
                provider="openrouter",
                default_headers={"X-Custom": "value"},
            )
        assert completion.default_headers is not None
        assert "HTTP-Referer" in completion.default_headers
        assert completion.default_headers.get("X-Custom") == "value"

    def test_supports_function_calling(self):
        """Test that function calling is supported."""
        completion = OpenAICompatibleCompletion(model="llama3", provider="ollama")
        assert completion.supports_function_calling() is True
class TestLLMIntegration:
    """Tests for LLM factory integration with OpenAI-compatible providers."""

    def test_llm_creates_openai_compatible_for_deepseek(self):
        """Test LLM factory creates OpenAICompatibleCompletion for DeepSeek."""
        with patch.dict(os.environ, {"DEEPSEEK_API_KEY": "test-key"}):
            llm = LLM(model="deepseek/deepseek-chat")
        assert isinstance(llm, OpenAICompatibleCompletion)
        assert llm.provider == "deepseek"
        assert llm.model == "deepseek-chat"

    def test_llm_creates_openai_compatible_for_ollama(self):
        """Test LLM factory creates OpenAICompatibleCompletion for Ollama."""
        llm = LLM(model="ollama/llama3")
        assert isinstance(llm, OpenAICompatibleCompletion)
        assert llm.provider == "ollama"
        assert llm.model == "llama3"

    def test_llm_creates_openai_compatible_for_openrouter(self):
        """Test LLM factory creates OpenAICompatibleCompletion for OpenRouter."""
        with patch.dict(os.environ, {"OPENROUTER_API_KEY": "test-key"}):
            llm = LLM(model="openrouter/anthropic/claude-3-opus")
        assert isinstance(llm, OpenAICompatibleCompletion)
        assert llm.provider == "openrouter"
        # Only the leading provider prefix is stripped; the rest of the path is kept.
        assert llm.model == "anthropic/claude-3-opus"

    def test_llm_creates_openai_compatible_for_hosted_vllm(self):
        """Test LLM factory creates OpenAICompatibleCompletion for hosted_vllm."""
        llm = LLM(model="hosted_vllm/meta-llama/Llama-3-8b")
        assert isinstance(llm, OpenAICompatibleCompletion)
        assert llm.provider == "hosted_vllm"

    def test_llm_creates_openai_compatible_for_cerebras(self):
        """Test LLM factory creates OpenAICompatibleCompletion for Cerebras."""
        with patch.dict(os.environ, {"CEREBRAS_API_KEY": "test-key"}):
            llm = LLM(model="cerebras/llama3-8b")
        assert isinstance(llm, OpenAICompatibleCompletion)
        assert llm.provider == "cerebras"

    def test_llm_creates_openai_compatible_for_dashscope(self):
        """Test LLM factory creates OpenAICompatibleCompletion for Dashscope."""
        with patch.dict(os.environ, {"DASHSCOPE_API_KEY": "test-key"}):
            llm = LLM(model="dashscope/qwen-turbo")
        assert isinstance(llm, OpenAICompatibleCompletion)
        assert llm.provider == "dashscope"

    def test_llm_with_explicit_provider(self):
        """Test LLM with explicit provider parameter."""
        with patch.dict(os.environ, {"DEEPSEEK_API_KEY": "test-key"}):
            llm = LLM(model="deepseek-chat", provider="deepseek")
        assert isinstance(llm, OpenAICompatibleCompletion)
        assert llm.provider == "deepseek"
        assert llm.model == "deepseek-chat"

    def test_llm_passes_kwargs_to_completion(self):
        """Test LLM passes kwargs to OpenAICompatibleCompletion."""
        with patch.dict(os.environ, {"DEEPSEEK_API_KEY": "test-key"}):
            llm = LLM(
                model="deepseek/deepseek-chat",
                temperature=0.7,
                max_tokens=1000,
            )
        # Pin the routing first. Without it, this test would still pass if the
        # factory fell back to a different completion class that also stores
        # temperature/max_tokens.
        assert isinstance(llm, OpenAICompatibleCompletion)
        assert llm.provider == "deepseek"
        assert llm.temperature == 0.7
        assert llm.max_tokens == 1000
class TestCallMocking:
    """Checks that the completion's call entry points are usable from tests."""

    def test_call_method_can_be_mocked(self):
        """Patching ``call`` on an instance replaces what it returns."""
        llm = OpenAICompatibleCompletion(model="llama3", provider="ollama")
        canned = "Mocked response"
        with patch.object(llm, "call", return_value=canned):
            assert llm.call("Test message") == canned

    def test_acall_method_exists(self):
        """The async ``acall`` entry point is present and callable."""
        llm = OpenAICompatibleCompletion(model="llama3", provider="ollama")
        async_entry = getattr(llm, "acall", None)
        assert async_entry is not None
        assert callable(async_entry)