From 7f5ffce057796bb86fe2e884c0e453710cface35 Mon Sep 17 00:00:00 2001 From: alex-clawd Date: Tue, 24 Mar 2026 08:05:43 -0700 Subject: [PATCH] feat: native OpenAI-compatible providers (OpenRouter, DeepSeek, Ollama, vLLM, Cerebras, Dashscope) (#5042) * feat: add native OpenAI-compatible providers (OpenRouter, DeepSeek, Ollama, vLLM, Cerebras, Dashscope) Add a data-driven OpenAI-compatible provider system that enables native support for multiple third-party APIs that implement the OpenAI API specification. New providers: - OpenRouter: 500+ models via openrouter.ai - DeepSeek: deepseek-chat, deepseek-coder, deepseek-reasoner - Ollama: local models (llama3, mistral, codellama, etc.) - hosted_vllm: self-hosted vLLM servers - Cerebras: ultra-fast inference - Dashscope: Alibaba Qwen models (qwen-turbo, qwen-max, etc.) Architecture: - Single OpenAICompatibleCompletion class extends OpenAICompletion - ProviderConfig dataclass stores per-provider settings - Registry dict makes adding new providers a single config entry - Handles provider-specific quirks (OpenRouter headers, Ollama base URL normalization, optional API keys) Usage: LLM(model="deepseek/deepseek-chat") LLM(model="ollama/llama3") LLM(model="openrouter/anthropic/claude-3-opus") LLM(model="llama3", provider="ollama") Co-Authored-By: Claude Opus 4.5 * fix: add is_litellm=True to tests that test litellm-specific methods Tests for _get_custom_llm_provider and _validate_call_params used openrouter/ model prefix which now routes to native provider. Added is_litellm=True to force litellm path since these test litellm-specific internals. --------- Co-authored-by: Joao Moura Co-authored-by: Claude Opus 4.5 --- lib/crewai/src/crewai/llm.py | 56 ++++ .../providers/openai_compatible/__init__.py | 14 + .../providers/openai_compatible/completion.py | 282 ++++++++++++++++ .../tests/llms/openai_compatible/__init__.py | 1 + .../test_openai_compatible.py | 310 ++++++++++++++++++ lib/crewai/tests/test_llm.py | 6 +- 6 files changed, 667 insertions(+), 2 deletions(-) create mode 100644 lib/crewai/src/crewai/llms/providers/openai_compatible/__init__.py create mode 100644 lib/crewai/src/crewai/llms/providers/openai_compatible/completion.py create mode 100644 lib/crewai/tests/llms/openai_compatible/__init__.py create mode 100644 lib/crewai/tests/llms/openai_compatible/test_openai_compatible.py diff --git a/lib/crewai/src/crewai/llm.py b/lib/crewai/src/crewai/llm.py index 6bf7c0942..cfb369c75 100644 --- a/lib/crewai/src/crewai/llm.py +++ b/lib/crewai/src/crewai/llm.py @@ -309,6 +309,14 @@ SUPPORTED_NATIVE_PROVIDERS: Final[list[str]] = [ "gemini", "bedrock", "aws", + # OpenAI-compatible providers + "openrouter", + "deepseek", + "ollama", + "ollama_chat", + "hosted_vllm", + "cerebras", + "dashscope", ] @@ -368,6 +376,14 @@ class LLM(BaseLLM): "gemini": "gemini", "bedrock": "bedrock", "aws": "bedrock", + # OpenAI-compatible providers + "openrouter": "openrouter", + "deepseek": "deepseek", + "ollama": "ollama", + "ollama_chat": "ollama_chat", + "hosted_vllm": "hosted_vllm", + "cerebras": "cerebras", + "dashscope": "dashscope", } canonical_provider = provider_mapping.get(prefix.lower()) @@ -467,6 +483,29 @@ class LLM(BaseLLM): for prefix in ["gpt-", "gpt-35-", "o1", "o3", "o4", "azure-"] ) + # OpenAI-compatible providers - accept any model name since these + # providers host many different models with varied naming conventions + if provider == "deepseek": + return model_lower.startswith("deepseek") + + if provider == "ollama" or provider == "ollama_chat": + # Ollama accepts any local model name + return True + + if provider == "hosted_vllm": + # vLLM serves any model + return True + + if provider == "cerebras": + return True + + if provider == "dashscope": + return model_lower.startswith("qwen") + + if provider == "openrouter": + # OpenRouter uses org/model format but accepts anything + return True + return False @classmethod @@ -566,6 +605,23 @@ class LLM(BaseLLM): return BedrockCompletion + # OpenAI-compatible providers + openai_compatible_providers = { + "openrouter", + "deepseek", + "ollama", + "ollama_chat", + "hosted_vllm", + "cerebras", + "dashscope", + } + if provider in openai_compatible_providers: + from crewai.llms.providers.openai_compatible.completion import ( + OpenAICompatibleCompletion, + ) + + return OpenAICompatibleCompletion + return None def __init__( diff --git a/lib/crewai/src/crewai/llms/providers/openai_compatible/__init__.py b/lib/crewai/src/crewai/llms/providers/openai_compatible/__init__.py new file mode 100644 index 000000000..12683e8cf --- /dev/null +++ b/lib/crewai/src/crewai/llms/providers/openai_compatible/__init__.py @@ -0,0 +1,14 @@ +"""OpenAI-compatible providers module.""" + +from crewai.llms.providers.openai_compatible.completion import ( + OPENAI_COMPATIBLE_PROVIDERS, + OpenAICompatibleCompletion, + ProviderConfig, +) + + +__all__ = [ + "OPENAI_COMPATIBLE_PROVIDERS", + "OpenAICompatibleCompletion", + "ProviderConfig", +] diff --git a/lib/crewai/src/crewai/llms/providers/openai_compatible/completion.py b/lib/crewai/src/crewai/llms/providers/openai_compatible/completion.py new file mode 100644 index 000000000..9c308f52e --- /dev/null +++ b/lib/crewai/src/crewai/llms/providers/openai_compatible/completion.py @@ -0,0 +1,282 @@ +"""OpenAI-compatible providers implementation. + +This module provides a thin subclass of OpenAICompletion that supports +various OpenAI-compatible APIs like OpenRouter, DeepSeek, Ollama, vLLM, +Cerebras, and Dashscope (Alibaba/Qwen). + +Usage: + llm = LLM(model="deepseek/deepseek-chat") # Uses DeepSeek API + llm = LLM(model="openrouter/anthropic/claude-3-opus") # Uses OpenRouter + llm = LLM(model="ollama/llama3") # Uses local Ollama +""" + +from __future__ import annotations + +from dataclasses import dataclass, field +import os +from typing import Any + +from crewai.llms.providers.openai.completion import OpenAICompletion + + +@dataclass(frozen=True) +class ProviderConfig: + """Configuration for an OpenAI-compatible provider. + + Attributes: + base_url: Default base URL for the provider's API endpoint. + api_key_env: Environment variable name for the API key. + base_url_env: Environment variable name for a custom base URL override. + default_headers: HTTP headers to include in all requests. + api_key_required: Whether an API key is required for this provider. + default_api_key: Default API key to use if none is provided and not required. + """ + + base_url: str + api_key_env: str + base_url_env: str | None = None + default_headers: dict[str, str] = field(default_factory=dict) + api_key_required: bool = True + default_api_key: str | None = None + + +OPENAI_COMPATIBLE_PROVIDERS: dict[str, ProviderConfig] = { + "openrouter": ProviderConfig( + base_url="https://openrouter.ai/api/v1", + api_key_env="OPENROUTER_API_KEY", + base_url_env="OPENROUTER_BASE_URL", + default_headers={"HTTP-Referer": "https://crewai.com"}, + api_key_required=True, + ), + "deepseek": ProviderConfig( + base_url="https://api.deepseek.com/v1", + api_key_env="DEEPSEEK_API_KEY", + base_url_env="DEEPSEEK_BASE_URL", + api_key_required=True, + ), + "ollama": ProviderConfig( + base_url="http://localhost:11434/v1", + api_key_env="OLLAMA_API_KEY", + base_url_env="OLLAMA_HOST", + api_key_required=False, + default_api_key="ollama", + ), + "ollama_chat": ProviderConfig( + base_url="http://localhost:11434/v1", + api_key_env="OLLAMA_API_KEY", + base_url_env="OLLAMA_HOST", + api_key_required=False, + default_api_key="ollama", + ), + "hosted_vllm": ProviderConfig( + base_url="http://localhost:8000/v1", + api_key_env="VLLM_API_KEY", + base_url_env="VLLM_BASE_URL", + api_key_required=False, + default_api_key="dummy", + ), + "cerebras": ProviderConfig( + base_url="https://api.cerebras.ai/v1", + api_key_env="CEREBRAS_API_KEY", + base_url_env="CEREBRAS_BASE_URL", + api_key_required=True, + ), + "dashscope": ProviderConfig( + base_url="https://dashscope-intl.aliyuncs.com/compatible-mode/v1", + api_key_env="DASHSCOPE_API_KEY", + base_url_env="DASHSCOPE_BASE_URL", + api_key_required=True, + ), +} + + +def _normalize_ollama_base_url(base_url: str) -> str: + """Normalize Ollama base URL to ensure it ends with /v1. + + Ollama uses OLLAMA_HOST which may not include the /v1 suffix, + but the OpenAI-compatible endpoint requires it. + + Args: + base_url: The base URL, potentially without /v1 suffix. + + Returns: + The base URL with /v1 suffix if needed. + """ + base_url = base_url.rstrip("/") + if not base_url.endswith("/v1"): + return f"{base_url}/v1" + return base_url + + +class OpenAICompatibleCompletion(OpenAICompletion): + """OpenAI-compatible completion implementation. + + This class provides support for various OpenAI-compatible APIs by + automatically configuring the base URL, API key, and headers based + on the provider name. + + Supported providers: + - openrouter: OpenRouter (https://openrouter.ai) + - deepseek: DeepSeek (https://deepseek.com) + - ollama: Ollama local server (https://ollama.ai) + - ollama_chat: Alias for ollama + - hosted_vllm: vLLM server (https://github.com/vllm-project/vllm) + - cerebras: Cerebras (https://cerebras.ai) + - dashscope: Alibaba Dashscope/Qwen (https://dashscope.aliyun.com) + + Example: + # Using provider prefix + llm = LLM(model="deepseek/deepseek-chat") + + # Using explicit provider parameter + llm = LLM(model="llama3", provider="ollama") + + # With custom configuration + llm = LLM( + model="deepseek-chat", + provider="deepseek", + api_key="my-key", + temperature=0.7 + ) + """ + + def __init__( + self, + model: str, + provider: str, + api_key: str | None = None, + base_url: str | None = None, + default_headers: dict[str, str] | None = None, + **kwargs: Any, + ) -> None: + """Initialize OpenAI-compatible completion client. + + Args: + model: The model identifier. + provider: The provider name (must be in OPENAI_COMPATIBLE_PROVIDERS). + api_key: Optional API key override. If not provided, uses the + provider's configured environment variable. + base_url: Optional base URL override. If not provided, uses the + provider's configured default or environment variable. + default_headers: Optional headers to merge with provider defaults. + **kwargs: Additional arguments passed to OpenAICompletion. + + Raises: + ValueError: If the provider is not supported or required API key + is missing. + """ + config = OPENAI_COMPATIBLE_PROVIDERS.get(provider) + if config is None: + supported = ", ".join(sorted(OPENAI_COMPATIBLE_PROVIDERS.keys())) + raise ValueError( + f"Unknown OpenAI-compatible provider: {provider}. " + f"Supported providers: {supported}" + ) + + resolved_api_key = self._resolve_api_key(api_key, config, provider) + resolved_base_url = self._resolve_base_url(base_url, config, provider) + resolved_headers = self._resolve_headers(default_headers, config) + + super().__init__( + model=model, + provider=provider, + api_key=resolved_api_key, + base_url=resolved_base_url, + default_headers=resolved_headers, + **kwargs, + ) + + def _resolve_api_key( + self, + api_key: str | None, + config: ProviderConfig, + provider: str, + ) -> str | None: + """Resolve the API key from explicit value, env var, or default. + + Args: + api_key: Explicitly provided API key. + config: Provider configuration. + provider: Provider name for error messages. + + Returns: + The resolved API key. + + Raises: + ValueError: If API key is required but not found. + """ + if api_key: + return api_key + + env_key = os.getenv(config.api_key_env) + if env_key: + return env_key + + if config.api_key_required: + raise ValueError( + f"API key required for {provider}. " + f"Set {config.api_key_env} environment variable or pass api_key parameter." + ) + + return config.default_api_key + + def _resolve_base_url( + self, + base_url: str | None, + config: ProviderConfig, + provider: str, + ) -> str: + """Resolve the base URL from explicit value, env var, or default. + + Args: + base_url: Explicitly provided base URL. + config: Provider configuration. + provider: Provider name (used for special handling like Ollama). + + Returns: + The resolved base URL. + """ + if base_url: + resolved = base_url + elif config.base_url_env: + resolved = os.getenv(config.base_url_env, config.base_url) + else: + resolved = config.base_url + + if provider in ("ollama", "ollama_chat"): + resolved = _normalize_ollama_base_url(resolved) + + return resolved + + def _resolve_headers( + self, + headers: dict[str, str] | None, + config: ProviderConfig, + ) -> dict[str, str] | None: + """Merge user headers with provider default headers. + + Args: + headers: User-provided headers. + config: Provider configuration. + + Returns: + Merged headers dict, or None if empty. + """ + if not config.default_headers and not headers: + return None + + merged = dict(config.default_headers) + if headers: + merged.update(headers) + + return merged if merged else None + + def supports_function_calling(self) -> bool: + """Check if the provider supports function calling. + + All modern OpenAI-compatible providers support function calling. + + Returns: + True, as all supported providers have function calling support. + """ + return True diff --git a/lib/crewai/tests/llms/openai_compatible/__init__.py b/lib/crewai/tests/llms/openai_compatible/__init__.py new file mode 100644 index 000000000..bb8da735f --- /dev/null +++ b/lib/crewai/tests/llms/openai_compatible/__init__.py @@ -0,0 +1 @@ +"""Tests for OpenAI-compatible providers.""" diff --git a/lib/crewai/tests/llms/openai_compatible/test_openai_compatible.py b/lib/crewai/tests/llms/openai_compatible/test_openai_compatible.py new file mode 100644 index 000000000..ade54fb8c --- /dev/null +++ b/lib/crewai/tests/llms/openai_compatible/test_openai_compatible.py @@ -0,0 +1,310 @@ +"""Tests for OpenAI-compatible providers.""" + +import os +from unittest.mock import MagicMock, patch + +import pytest + +from crewai.llm import LLM +from crewai.llms.providers.openai_compatible.completion import ( + OPENAI_COMPATIBLE_PROVIDERS, + OpenAICompatibleCompletion, + ProviderConfig, + _normalize_ollama_base_url, +) + + +class TestProviderConfig: + """Tests for ProviderConfig dataclass.""" + + def test_provider_config_immutable(self): + """Test that ProviderConfig is immutable (frozen).""" + config = ProviderConfig( + base_url="https://example.com/v1", + api_key_env="TEST_API_KEY", + ) + with pytest.raises(AttributeError): + config.base_url = "https://other.com/v1" + + def test_provider_config_defaults(self): + """Test ProviderConfig default values.""" + config = ProviderConfig( + base_url="https://example.com/v1", + api_key_env="TEST_API_KEY", + ) + assert config.base_url_env is None + assert config.default_headers == {} + assert config.api_key_required is True + assert config.default_api_key is None + + +class TestProviderRegistry: + """Tests for the OPENAI_COMPATIBLE_PROVIDERS registry.""" + + def test_openrouter_config(self): + """Test OpenRouter provider configuration.""" + config = OPENAI_COMPATIBLE_PROVIDERS["openrouter"] + assert config.base_url == "https://openrouter.ai/api/v1" + assert config.api_key_env == "OPENROUTER_API_KEY" + assert config.base_url_env == "OPENROUTER_BASE_URL" + assert "HTTP-Referer" in config.default_headers + assert config.api_key_required is True + + def test_deepseek_config(self): + """Test DeepSeek provider configuration.""" + config = OPENAI_COMPATIBLE_PROVIDERS["deepseek"] + assert config.base_url == "https://api.deepseek.com/v1" + assert config.api_key_env == "DEEPSEEK_API_KEY" + assert config.api_key_required is True + + def test_ollama_config(self): + """Test Ollama provider configuration.""" + config = OPENAI_COMPATIBLE_PROVIDERS["ollama"] + assert config.base_url == "http://localhost:11434/v1" + assert config.api_key_env == "OLLAMA_API_KEY" + assert config.base_url_env == "OLLAMA_HOST" + assert config.api_key_required is False + assert config.default_api_key == "ollama" + + def test_ollama_chat_is_alias(self): + """Test ollama_chat is configured same as ollama.""" + ollama = OPENAI_COMPATIBLE_PROVIDERS["ollama"] + ollama_chat = OPENAI_COMPATIBLE_PROVIDERS["ollama_chat"] + assert ollama.base_url == ollama_chat.base_url + assert ollama.api_key_required == ollama_chat.api_key_required + + def test_hosted_vllm_config(self): + """Test hosted_vllm provider configuration.""" + config = OPENAI_COMPATIBLE_PROVIDERS["hosted_vllm"] + assert config.base_url == "http://localhost:8000/v1" + assert config.api_key_env == "VLLM_API_KEY" + assert config.api_key_required is False + assert config.default_api_key == "dummy" + + def test_cerebras_config(self): + """Test Cerebras provider configuration.""" + config = OPENAI_COMPATIBLE_PROVIDERS["cerebras"] + assert config.base_url == "https://api.cerebras.ai/v1" + assert config.api_key_env == "CEREBRAS_API_KEY" + assert config.api_key_required is True + + def test_dashscope_config(self): + """Test Dashscope provider configuration.""" + config = OPENAI_COMPATIBLE_PROVIDERS["dashscope"] + assert config.base_url == "https://dashscope-intl.aliyuncs.com/compatible-mode/v1" + assert config.api_key_env == "DASHSCOPE_API_KEY" + assert config.api_key_required is True + + +class TestNormalizeOllamaBaseUrl: + """Tests for _normalize_ollama_base_url helper.""" + + def test_adds_v1_suffix(self): + """Test that /v1 is added when missing.""" + assert _normalize_ollama_base_url("http://localhost:11434") == "http://localhost:11434/v1" + + def test_preserves_existing_v1(self): + """Test that existing /v1 is preserved.""" + assert _normalize_ollama_base_url("http://localhost:11434/v1") == "http://localhost:11434/v1" + + def test_strips_trailing_slash(self): + """Test that trailing slash is handled.""" + assert _normalize_ollama_base_url("http://localhost:11434/") == "http://localhost:11434/v1" + + def test_handles_v1_with_trailing_slash(self): + """Test /v1/ is normalized.""" + assert _normalize_ollama_base_url("http://localhost:11434/v1/") == "http://localhost:11434/v1" + + +class TestOpenAICompatibleCompletion: + """Tests for OpenAICompatibleCompletion class.""" + + def test_unknown_provider_raises_error(self): + """Test that unknown provider raises ValueError.""" + with pytest.raises(ValueError, match="Unknown OpenAI-compatible provider"): + OpenAICompatibleCompletion(model="test", provider="unknown_provider") + + def test_missing_required_api_key_raises_error(self): + """Test that missing required API key raises ValueError.""" + # Clear any existing env var + env_key = "DEEPSEEK_API_KEY" + original = os.environ.pop(env_key, None) + try: + with pytest.raises(ValueError, match="API key required"): + OpenAICompatibleCompletion(model="deepseek-chat", provider="deepseek") + finally: + if original: + os.environ[env_key] = original + + def test_api_key_from_env(self): + """Test API key is read from environment variable.""" + with patch.dict(os.environ, {"DEEPSEEK_API_KEY": "test-key-from-env"}): + completion = OpenAICompatibleCompletion( + model="deepseek-chat", provider="deepseek" + ) + assert completion.api_key == "test-key-from-env" + + def test_explicit_api_key_overrides_env(self): + """Test explicit API key overrides environment variable.""" + with patch.dict(os.environ, {"DEEPSEEK_API_KEY": "env-key"}): + completion = OpenAICompatibleCompletion( + model="deepseek-chat", + provider="deepseek", + api_key="explicit-key", + ) + assert completion.api_key == "explicit-key" + + def test_default_api_key_for_optional_providers(self): + """Test default API key is used for providers that don't require it.""" + # Ollama doesn't require API key + completion = OpenAICompatibleCompletion(model="llama3", provider="ollama") + assert completion.api_key == "ollama" + + def test_base_url_from_config(self): + """Test base URL is set from provider config.""" + with patch.dict(os.environ, {"DEEPSEEK_API_KEY": "test-key"}): + completion = OpenAICompatibleCompletion( + model="deepseek-chat", provider="deepseek" + ) + assert completion.base_url == "https://api.deepseek.com/v1" + + def test_base_url_from_env(self): + """Test base URL is read from environment variable.""" + with patch.dict( + os.environ, + {"DEEPSEEK_API_KEY": "test-key", "DEEPSEEK_BASE_URL": "https://custom.deepseek.com/v1"}, + ): + completion = OpenAICompatibleCompletion( + model="deepseek-chat", provider="deepseek" + ) + assert completion.base_url == "https://custom.deepseek.com/v1" + + def test_explicit_base_url_overrides_all(self): + """Test explicit base URL overrides env and config.""" + with patch.dict( + os.environ, + {"DEEPSEEK_API_KEY": "test-key", "DEEPSEEK_BASE_URL": "https://env.deepseek.com/v1"}, + ): + completion = OpenAICompatibleCompletion( + model="deepseek-chat", + provider="deepseek", + base_url="https://explicit.deepseek.com/v1", + ) + assert completion.base_url == "https://explicit.deepseek.com/v1" + + def test_ollama_base_url_normalized(self): + """Test Ollama base URL is normalized to include /v1.""" + with patch.dict(os.environ, {"OLLAMA_HOST": "http://custom-ollama:11434"}): + completion = OpenAICompatibleCompletion(model="llama3", provider="ollama") + assert completion.base_url == "http://custom-ollama:11434/v1" + + def test_openrouter_headers(self): + """Test OpenRouter has HTTP-Referer header.""" + with patch.dict(os.environ, {"OPENROUTER_API_KEY": "test-key"}): + completion = OpenAICompatibleCompletion( + model="anthropic/claude-3-opus", provider="openrouter" + ) + assert completion.default_headers is not None + assert "HTTP-Referer" in completion.default_headers + + def test_custom_headers_merged_with_defaults(self): + """Test custom headers are merged with provider defaults.""" + with patch.dict(os.environ, {"OPENROUTER_API_KEY": "test-key"}): + completion = OpenAICompatibleCompletion( + model="anthropic/claude-3-opus", + provider="openrouter", + default_headers={"X-Custom": "value"}, + ) + assert completion.default_headers is not None + assert "HTTP-Referer" in completion.default_headers + assert completion.default_headers.get("X-Custom") == "value" + + def test_supports_function_calling(self): + """Test that function calling is supported.""" + completion = OpenAICompatibleCompletion(model="llama3", provider="ollama") + assert completion.supports_function_calling() is True + + +class TestLLMIntegration: + """Tests for LLM factory integration with OpenAI-compatible providers.""" + + def test_llm_creates_openai_compatible_for_deepseek(self): + """Test LLM factory creates OpenAICompatibleCompletion for DeepSeek.""" + with patch.dict(os.environ, {"DEEPSEEK_API_KEY": "test-key"}): + llm = LLM(model="deepseek/deepseek-chat") + assert isinstance(llm, OpenAICompatibleCompletion) + assert llm.provider == "deepseek" + assert llm.model == "deepseek-chat" + + def test_llm_creates_openai_compatible_for_ollama(self): + """Test LLM factory creates OpenAICompatibleCompletion for Ollama.""" + llm = LLM(model="ollama/llama3") + assert isinstance(llm, OpenAICompatibleCompletion) + assert llm.provider == "ollama" + assert llm.model == "llama3" + + def test_llm_creates_openai_compatible_for_openrouter(self): + """Test LLM factory creates OpenAICompatibleCompletion for OpenRouter.""" + with patch.dict(os.environ, {"OPENROUTER_API_KEY": "test-key"}): + llm = LLM(model="openrouter/anthropic/claude-3-opus") + assert isinstance(llm, OpenAICompatibleCompletion) + assert llm.provider == "openrouter" + # Model should include the full path after provider prefix + assert llm.model == "anthropic/claude-3-opus" + + def test_llm_creates_openai_compatible_for_hosted_vllm(self): + """Test LLM factory creates OpenAICompatibleCompletion for hosted_vllm.""" + llm = LLM(model="hosted_vllm/meta-llama/Llama-3-8b") + assert isinstance(llm, OpenAICompatibleCompletion) + assert llm.provider == "hosted_vllm" + + def test_llm_creates_openai_compatible_for_cerebras(self): + """Test LLM factory creates OpenAICompatibleCompletion for Cerebras.""" + with patch.dict(os.environ, {"CEREBRAS_API_KEY": "test-key"}): + llm = LLM(model="cerebras/llama3-8b") + assert isinstance(llm, OpenAICompatibleCompletion) + assert llm.provider == "cerebras" + + def test_llm_creates_openai_compatible_for_dashscope(self): + """Test LLM factory creates OpenAICompatibleCompletion for Dashscope.""" + with patch.dict(os.environ, {"DASHSCOPE_API_KEY": "test-key"}): + llm = LLM(model="dashscope/qwen-turbo") + assert isinstance(llm, OpenAICompatibleCompletion) + assert llm.provider == "dashscope" + + def test_llm_with_explicit_provider(self): + """Test LLM with explicit provider parameter.""" + with patch.dict(os.environ, {"DEEPSEEK_API_KEY": "test-key"}): + llm = LLM(model="deepseek-chat", provider="deepseek") + assert isinstance(llm, OpenAICompatibleCompletion) + assert llm.provider == "deepseek" + assert llm.model == "deepseek-chat" + + def test_llm_passes_kwargs_to_completion(self): + """Test LLM passes kwargs to OpenAICompatibleCompletion.""" + with patch.dict(os.environ, {"DEEPSEEK_API_KEY": "test-key"}): + llm = LLM( + model="deepseek/deepseek-chat", + temperature=0.7, + max_tokens=1000, + ) + assert llm.temperature == 0.7 + assert llm.max_tokens == 1000 + + +class TestCallMocking: + """Tests for mocking the call method.""" + + def test_call_method_can_be_mocked(self): + """Test that the call method can be mocked for testing.""" + completion = OpenAICompatibleCompletion(model="llama3", provider="ollama") + + with patch.object(completion, "call", return_value="Mocked response"): + result = completion.call("Test message") + assert result == "Mocked response" + + def test_acall_method_exists(self): + """Test that acall method exists for async calls.""" + completion = OpenAICompatibleCompletion(model="llama3", provider="ollama") + assert hasattr(completion, "acall") + assert callable(completion.acall) diff --git a/lib/crewai/tests/test_llm.py b/lib/crewai/tests/test_llm.py index 71cb69790..1ed217166 100644 --- a/lib/crewai/tests/test_llm.py +++ b/lib/crewai/tests/test_llm.py @@ -211,7 +211,7 @@ def test_llm_passes_additional_params(): def test_get_custom_llm_provider_openrouter(): - llm = LLM(model="openrouter/deepseek/deepseek-chat") + llm = LLM(model="openrouter/deepseek/deepseek-chat", is_litellm=True) assert llm._get_custom_llm_provider() == "openrouter" @@ -232,7 +232,9 @@ def test_validate_call_params_supported(): # Patch supports_response_schema to simulate a supported model. with patch("crewai.llm.supports_response_schema", return_value=True): llm = LLM( - model="openrouter/deepseek/deepseek-chat", response_format=DummyResponse + model="openrouter/deepseek/deepseek-chat", + response_format=DummyResponse, + is_litellm=True, ) # Should not raise any error. llm._validate_call_params()