feat: add IBM Granite model support via watsonx.ai provider (OSS-35)

Add native support for IBM Granite models through the watsonx.ai
Model Gateway OpenAI-compatible API.

Implementation:
- New watsonx provider at llms/providers/watsonx/ extending OpenAICompletion
- IBM Cloud IAM token exchange with thread-safe caching and auto-refresh
- Support for WATSONX_API_KEY, WATSONX_PROJECT_ID, WATSONX_REGION env vars
- 12 Granite models in constants (3.x, 4.x, code, guardian families)
- Full LLM routing: watsonx/ibm/granite-4-h-small or provider='watsonx'
- No new dependencies required (uses existing openai + httpx)

Usage:
  llm = LLM(model='watsonx/ibm/granite-4-h-small')
  llm = LLM(model='ibm/granite-4-h-small', provider='watsonx')

Closes OSS-35
This commit is contained in:
Iris Clawd
2026-04-13 19:15:59 +00:00
parent e21c506214
commit d3f422a121
6 changed files with 794 additions and 0 deletions

View File

@@ -44,6 +44,7 @@ from crewai.llms.constants import (
BEDROCK_MODELS,
GEMINI_MODELS,
OPENAI_MODELS,
WATSONX_MODELS,
)
from crewai.utilities import InternalInstructor
from crewai.utilities.exceptions.context_window_exceeding_exception import (
@@ -309,6 +310,8 @@ SUPPORTED_NATIVE_PROVIDERS: Final[list[str]] = [
"gemini",
"bedrock",
"aws",
"watsonx",
"ibm",
# OpenAI-compatible providers
"openrouter",
"deepseek",
@@ -376,6 +379,8 @@ class LLM(BaseLLM):
"gemini": "gemini",
"bedrock": "bedrock",
"aws": "bedrock",
"watsonx": "watsonx",
"ibm": "watsonx",
# OpenAI-compatible providers
"openrouter": "openrouter",
"deepseek": "deepseek",
@@ -506,6 +511,12 @@ class LLM(BaseLLM):
# OpenRouter uses org/model format but accepts anything
return True
if provider == "watsonx" or provider == "ibm":
return any(
model_lower.startswith(prefix)
for prefix in ["ibm/granite", "granite"]
)
return False
@classmethod
@@ -541,6 +552,9 @@ class LLM(BaseLLM):
# azure does not provide a list of available models, determine a better way to handle this
return True
if (provider == "watsonx" or provider == "ibm") and model in WATSONX_MODELS:
return True
# Fallback to pattern matching for models not in constants
return cls._matches_provider_pattern(model, provider)
@@ -573,6 +587,9 @@ class LLM(BaseLLM):
if model in AZURE_MODELS:
return "azure"
if model in WATSONX_MODELS:
return "watsonx"
return "openai"
@classmethod
@@ -605,6 +622,11 @@ class LLM(BaseLLM):
return BedrockCompletion
if provider == "watsonx" or provider == "ibm":
from crewai.llms.providers.watsonx.completion import WatsonxCompletion
return WatsonxCompletion
# OpenAI-compatible providers
openai_compatible_providers = {
"openrouter",

View File

@@ -568,3 +568,33 @@ BEDROCK_MODELS: list[BedrockModels] = [
"qwen.qwen3-coder-30b-a3b-v1:0",
"twelvelabs.pegasus-1-2-v1:0",
]
# Type alias enumerating the IBM Granite model IDs served via the
# watsonx.ai Model Gateway (covers the 3.x, 4.x, code, and guardian families).
WatsonxModels: TypeAlias = Literal[
    "ibm/granite-3-2b-instruct",
    "ibm/granite-3-8b-instruct",
    "ibm/granite-3-1-2b-instruct",
    "ibm/granite-3-1-8b-instruct",
    "ibm/granite-3-1-8b-base",
    "ibm/granite-3-3-2b-instruct",
    "ibm/granite-3-3-8b-instruct",
    "ibm/granite-4-h-micro",
    "ibm/granite-4-h-tiny",
    "ibm/granite-4-h-small",
    "ibm/granite-8b-code-instruct",
    "ibm/granite-guardian-3-8b",
]

# Runtime list used for membership checks during provider routing.
# NOTE(review): must be kept in sync with the WatsonxModels Literal above —
# the two enumerations are maintained by hand.
WATSONX_MODELS: list[WatsonxModels] = [
    "ibm/granite-3-2b-instruct",
    "ibm/granite-3-8b-instruct",
    "ibm/granite-3-1-2b-instruct",
    "ibm/granite-3-1-8b-instruct",
    "ibm/granite-3-1-8b-base",
    "ibm/granite-3-3-2b-instruct",
    "ibm/granite-3-3-8b-instruct",
    "ibm/granite-4-h-micro",
    "ibm/granite-4-h-tiny",
    "ibm/granite-4-h-small",
    "ibm/granite-8b-code-instruct",
    "ibm/granite-guardian-3-8b",
]

View File

@@ -0,0 +1,5 @@
"""IBM watsonx.ai provider module."""
from crewai.llms.providers.watsonx.completion import WatsonxCompletion
__all__ = ["WatsonxCompletion"]

View File

@@ -0,0 +1,444 @@
"""IBM watsonx.ai provider implementation.
This module provides native support for IBM Granite models via the
watsonx.ai Model Gateway, which exposes an OpenAI-compatible API.
Authentication uses IBM Cloud IAM token exchange: an API key is exchanged
for a short-lived Bearer token via the IAM identity service.
Usage:
llm = LLM(model="watsonx/ibm/granite-4-h-small")
llm = LLM(model="ibm/granite-4-h-small", provider="watsonx")
Environment variables:
WATSONX_API_KEY: IBM Cloud API key (required)
WATSONX_PROJECT_ID: watsonx.ai project ID (required)
WATSONX_REGION: IBM Cloud region (default: us-south)
WATSONX_URL: Full base URL override (optional)
"""
from __future__ import annotations
import logging
import os
import threading
import time
from typing import Any
import httpx
from openai import OpenAI
from crewai.llms.providers.openai.completion import OpenAICompletion
# Module-level logger for this provider.
logger = logging.getLogger(__name__)

# IBM Cloud IAM endpoint for token exchange
_IAM_TOKEN_URL = "https://iam.cloud.ibm.com/identity/token"

# Default region for watsonx.ai
_DEFAULT_REGION = "us-south"

# Refresh token 60 seconds before expiry to avoid race conditions
_TOKEN_REFRESH_BUFFER_SECONDS = 60

# Supported watsonx.ai regions.
# NOTE(review): snapshot of known regions — _resolve_region only warns (does
# not fail) on an unknown value, so newly added IBM regions still work.
_SUPPORTED_REGIONS = frozenset({
    "us-south",
    "eu-de",
    "eu-gb",
    "jp-tok",
    "au-syd",
})
class _IAMTokenManager:
    """Thread-safe IBM IAM token manager with automatic refresh.

    Exchanges an IBM Cloud API key for a short-lived Bearer token and
    caches it, refreshing automatically when the token approaches expiry.
    """

    def __init__(self, api_key: str) -> None:
        self._api_key = api_key
        self._token: str | None = None
        # Unix timestamp (seconds) when the cached token expires; 0.0
        # guarantees a refresh on first use.
        self._expiry: float = 0.0
        self._lock = threading.Lock()

    def get_token(self) -> str:
        """Get a valid IAM Bearer token, refreshing if needed.

        Returns:
            A valid Bearer token string.

        Raises:
            RuntimeError: If the token exchange fails.
        """
        # Fast path without the lock: a stale read only sends us into the
        # locked double-check below, which is always correct.
        if self._token and time.time() < self._expiry - _TOKEN_REFRESH_BUFFER_SECONDS:
            return self._token
        with self._lock:
            # Double-check after acquiring the lock — another thread may
            # have refreshed the token while we were waiting.
            if (
                self._token
                and time.time() < self._expiry - _TOKEN_REFRESH_BUFFER_SECONDS
            ):
                return self._token
            self._refresh_token()
            assert self._token is not None
            return self._token

    def _refresh_token(self) -> None:
        """Exchange API key for a new IAM token.

        Raises:
            RuntimeError: If the HTTP request fails or the response body
                does not contain the expected token fields.
        """
        try:
            response = httpx.post(
                _IAM_TOKEN_URL,
                data={
                    "grant_type": "urn:ibm:params:oauth:grant-type:apikey",
                    "apikey": self._api_key,
                },
                headers={
                    "Content-Type": "application/x-www-form-urlencoded",
                    "Accept": "application/json",
                },
                timeout=30.0,
            )
            response.raise_for_status()
        except httpx.HTTPStatusError as e:
            raise RuntimeError(
                f"IBM IAM token exchange failed (HTTP {e.response.status_code}): "
                f"{e.response.text}"
            ) from e
        except httpx.HTTPError as e:
            raise RuntimeError(
                f"IBM IAM token exchange request failed: {e}"
            ) from e
        # Parse defensively: a malformed or unexpected response body must
        # surface as the documented RuntimeError, not a raw KeyError /
        # ValueError / JSON decode error.
        try:
            data = response.json()
            token = data["access_token"]
            expiry = float(data["expiration"])
        except (KeyError, TypeError, ValueError) as e:
            raise RuntimeError(
                f"IBM IAM token exchange returned an unexpected response: {e}"
            ) from e
        self._token = token
        self._expiry = expiry
        logger.debug(
            "IBM IAM token refreshed, expires at %s",
            time.strftime("%Y-%m-%d %H:%M:%S", time.gmtime(self._expiry)),
        )
class WatsonxCompletion(OpenAICompletion):
    """IBM watsonx.ai completion implementation.

    This class provides support for IBM Granite models and other foundation
    models hosted on watsonx.ai via the OpenAI-compatible Model Gateway.

    Authentication is handled transparently via IBM Cloud IAM token exchange.
    The API key is exchanged for a Bearer token which is automatically
    refreshed when it approaches expiry.

    Supported models include the IBM Granite family:
        - ibm/granite-4-h-small (32B hybrid)
        - ibm/granite-4-h-tiny (7B hybrid)
        - ibm/granite-4-h-micro (3B hybrid)
        - ibm/granite-3-8b-instruct
        - ibm/granite-3-3-8b-instruct
        - ibm/granite-8b-code-instruct
        - ibm/granite-guardian-3-8b
        - And other models available on watsonx.ai

    Example:
        # Using provider prefix
        llm = LLM(model="watsonx/ibm/granite-4-h-small")

        # Using explicit provider
        llm = LLM(model="ibm/granite-4-h-small", provider="watsonx")

        # With custom configuration
        llm = LLM(
            model="ibm/granite-4-h-small",
            provider="watsonx",
            api_key="my-ibm-cloud-api-key",
            temperature=0.7,
        )
    """

    def __init__(
        self,
        model: str,
        provider: str = "watsonx",
        api_key: str | None = None,
        base_url: str | None = None,
        project_id: str | None = None,
        region: str | None = None,
        **kwargs: Any,
    ) -> None:
        """Initialize watsonx.ai completion client.

        Args:
            model: The model identifier (e.g., "ibm/granite-4-h-small").
            provider: The provider name (default: "watsonx").
            api_key: IBM Cloud API key. If not provided, reads from
                WATSONX_API_KEY environment variable.
            base_url: Full base URL override for the watsonx.ai endpoint.
                If not provided, constructed from region.
            project_id: watsonx.ai project ID. If not provided, reads from
                WATSONX_PROJECT_ID environment variable.
            region: IBM Cloud region (default: "us-south"). If not provided,
                reads from WATSONX_REGION environment variable.
            **kwargs: Additional arguments passed to OpenAICompletion.

        Raises:
            ValueError: If required credentials are missing.
            RuntimeError: If the initial IAM token exchange fails.
        """
        resolved_api_key = self._resolve_api_key(api_key)
        resolved_project_id = self._resolve_project_id(project_id)
        resolved_region = self._resolve_region(region)
        resolved_base_url = self._resolve_base_url(base_url, resolved_region)

        # Initialize IAM token manager for transparent auth.
        self._iam_manager = _IAMTokenManager(resolved_api_key)
        self._project_id = resolved_project_id

        # Get initial token for client construction.
        initial_token = self._iam_manager.get_token()

        # Pass the bearer token as api_key to the OpenAI client; the SDK
        # sends it as "Authorization: Bearer <token>".
        super().__init__(
            model=model,
            provider=provider,
            api_key=initial_token,
            base_url=resolved_base_url,
            **kwargs,
        )

    @staticmethod
    def _resolve_api_key(api_key: str | None) -> str:
        """Resolve IBM Cloud API key from parameter or environment.

        Args:
            api_key: Explicitly provided API key.

        Returns:
            The resolved API key.

        Raises:
            ValueError: If no API key is found.
        """
        resolved = api_key or os.getenv("WATSONX_API_KEY")
        if not resolved:
            raise ValueError(
                "IBM Cloud API key is required for watsonx.ai provider. "
                "Set the WATSONX_API_KEY environment variable or pass "
                "api_key parameter."
            )
        return resolved

    @staticmethod
    def _resolve_project_id(project_id: str | None) -> str:
        """Resolve watsonx.ai project ID from parameter or environment.

        Args:
            project_id: Explicitly provided project ID.

        Returns:
            The resolved project ID.

        Raises:
            ValueError: If no project ID is found.
        """
        resolved = project_id or os.getenv("WATSONX_PROJECT_ID")
        if not resolved:
            raise ValueError(
                "watsonx.ai project ID is required. "
                "Set the WATSONX_PROJECT_ID environment variable or pass "
                "project_id parameter."
            )
        return resolved

    @staticmethod
    def _resolve_region(region: str | None) -> str:
        """Resolve IBM Cloud region from parameter or environment.

        Args:
            region: Explicitly provided region.

        Returns:
            The resolved region string.
        """
        resolved = region or os.getenv("WATSONX_REGION", _DEFAULT_REGION)
        if resolved not in _SUPPORTED_REGIONS:
            # Warn but do not fail: IBM may add regions we don't know about.
            logger.warning(
                "Region '%s' is not in the known supported regions: %s. "
                "Proceeding anyway in case IBM has added new regions.",
                resolved,
                ", ".join(sorted(_SUPPORTED_REGIONS)),
            )
        return resolved

    @staticmethod
    def _resolve_base_url(base_url: str | None, region: str) -> str:
        """Resolve the watsonx.ai base URL.

        Priority:
            1. Explicit base_url parameter
            2. WATSONX_URL environment variable
            3. Constructed from region

        Args:
            base_url: Explicitly provided base URL.
            region: IBM Cloud region for URL construction.

        Returns:
            The resolved base URL (no trailing slash).
        """
        if base_url:
            return base_url.rstrip("/")
        env_url = os.getenv("WATSONX_URL")
        if env_url:
            return env_url.rstrip("/")
        return f"https://{region}.ml.cloud.ibm.com/ml/v1"

    def _build_client(
        self,
        api_key: str | None = None,
        base_url: str | None = None,
        default_headers: dict[str, str] | None = None,
    ) -> OpenAI:
        """Build the OpenAI client with watsonx-specific configuration.

        Overrides the parent method to inject the project_id header
        and ensure the IAM token is current.

        Args:
            api_key: Bearer token (from IAM exchange). Ignored in favor of
                the freshly resolved token from the IAM manager.
            base_url: watsonx.ai endpoint URL.
            default_headers: Additional headers.

        Returns:
            Configured OpenAI client instance.
        """
        # Refresh token if needed.
        current_token = self._iam_manager.get_token()

        # Merge watsonx-specific headers; caller-supplied headers win on
        # key collisions.
        watsonx_headers = {
            "X-Watsonx-Project-Id": self._project_id,
        }
        if default_headers:
            watsonx_headers.update(default_headers)

        return super()._build_client(
            api_key=current_token,
            base_url=base_url,
            default_headers=watsonx_headers,
        )

    def _ensure_fresh_token(self) -> None:
        """Refresh the IAM token on the client if needed.

        Updates the client's API key (Bearer token) if the cached
        token has been refreshed.
        """
        current_token = self._iam_manager.get_token()
        if hasattr(self, "client") and self.client is not None:
            self.client.api_key = current_token

    def call(self, messages, tools=None, callbacks=None, available_functions=None,
             from_task=None, from_agent=None, response_model=None):
        """Call the LLM, refreshing the IAM token if needed."""
        self._ensure_fresh_token()
        return super().call(
            messages=messages,
            tools=tools,
            callbacks=callbacks,
            available_functions=available_functions,
            from_task=from_task,
            from_agent=from_agent,
            response_model=response_model,
        )

    async def acall(self, messages, tools=None, callbacks=None, available_functions=None,
                    from_task=None, from_agent=None, response_model=None):
        """Async call the LLM, refreshing the IAM token if needed."""
        self._ensure_fresh_token()
        return await super().acall(
            messages=messages,
            tools=tools,
            callbacks=callbacks,
            available_functions=available_functions,
            from_task=from_task,
            from_agent=from_agent,
            response_model=response_model,
        )

    def get_context_window_size(self) -> int:
        """Get context window size for Granite models.

        Returns:
            The context window size in tokens.
        """
        model_lower = self.model.lower()
        # Granite 4.x models have 128K context
        if "granite-4" in model_lower:
            return 131072
        # Granite 3.x instruct models have 128K context
        if "granite-3" in model_lower and "instruct" in model_lower:
            return 131072
        # Granite 3.x base models have 4K context
        if "granite-3" in model_lower:
            return 4096
        # Granite code models
        if "granite" in model_lower and "code" in model_lower:
            return 8192
        # Default for unknown models
        return 8192

    def supports_function_calling(self) -> bool:
        """Check if the model supports function calling / tool use.

        Granite 3.x instruct and 4.x models support tool use.

        Returns:
            True if the model supports function calling.
        """
        model_lower = self.model.lower()
        # Granite 4.x models support tool use
        if "granite-4" in model_lower:
            return True
        # Granite 3.x instruct models support tool use
        if "granite-3" in model_lower and "instruct" in model_lower:
            return True
        # Granite guardian models don't do tool use
        if "guardian" in model_lower:
            return False
        # Default: assume no tool use for unknown models
        return False

    def supports_multimodal(self) -> bool:
        """Check if the model supports multimodal inputs.

        Currently, Granite models are text-only.

        Returns:
            False (Granite models are text-only).
        """
        return False

    def to_config_dict(self) -> dict[str, Any]:
        """Serialize this LLM to a dict for reconstruction.

        Returns:
            Configuration dict with watsonx-specific fields.
        """
        config = super().to_config_dict()
        # Namespace the model id with the provider prefix so the config
        # round-trips through LLM(model=...) routing. Fixed: the original
        # conditional had identical branches and would double-prefix a
        # model id that was already namespaced.
        if self.model.startswith("watsonx/"):
            config["model"] = self.model
        else:
            config["model"] = f"watsonx/{self.model}"
        return config

View File

@@ -0,0 +1,293 @@
"""Tests for IBM watsonx.ai provider."""
from __future__ import annotations
import os
import time
from unittest.mock import MagicMock, patch
import httpx
import pytest
class TestIAMTokenManager:
    """Tests for the IAM token manager."""

    def test_token_exchange_success(self):
        """Test successful IAM token exchange."""
        from crewai.llms.providers.watsonx.completion import _IAMTokenManager

        mock_response = MagicMock()
        mock_response.json.return_value = {
            "access_token": "test-bearer-token",
            "expiration": time.time() + 3600,
            "token_type": "Bearer",
        }
        mock_response.raise_for_status = MagicMock()
        with patch("httpx.post", return_value=mock_response) as mock_post:
            manager = _IAMTokenManager("test-api-key")
            token = manager.get_token()
            assert token == "test-bearer-token"
            mock_post.assert_called_once()
            # Verify the IAM apikey grant payload was sent as expected.
            call_kwargs = mock_post.call_args
            assert call_kwargs[1]["data"]["apikey"] == "test-api-key"
            assert (
                call_kwargs[1]["data"]["grant_type"]
                == "urn:ibm:params:oauth:grant-type:apikey"
            )

    def test_token_caching(self):
        """Test that tokens are cached and not re-fetched."""
        from crewai.llms.providers.watsonx.completion import _IAMTokenManager

        mock_response = MagicMock()
        mock_response.json.return_value = {
            "access_token": "cached-token",
            "expiration": time.time() + 3600,
        }
        mock_response.raise_for_status = MagicMock()
        with patch("httpx.post", return_value=mock_response) as mock_post:
            manager = _IAMTokenManager("test-api-key")
            # First call - should fetch
            token1 = manager.get_token()
            # Second call - should use cache
            token2 = manager.get_token()
            assert token1 == token2 == "cached-token"
            assert mock_post.call_count == 1  # Only one HTTP call

    def test_token_refresh_on_expiry(self):
        """Test that expired tokens are refreshed."""
        from crewai.llms.providers.watsonx.completion import _IAMTokenManager

        call_count = 0

        def mock_post(*args, **kwargs):
            # First response expires immediately; subsequent ones last 1h.
            nonlocal call_count
            call_count += 1
            mock_resp = MagicMock()
            mock_resp.json.return_value = {
                "access_token": f"token-{call_count}",
                "expiration": time.time() + (0 if call_count == 1 else 3600),
            }
            mock_resp.raise_for_status = MagicMock()
            return mock_resp

        with patch("httpx.post", side_effect=mock_post):
            manager = _IAMTokenManager("test-api-key")
            # First call - gets token-1 which is already expired
            token1 = manager.get_token()
            assert token1 == "token-1"
            # Second call - token-1 is expired, should refresh to token-2
            token2 = manager.get_token()
            assert token2 == "token-2"
            assert call_count == 2

    def test_token_exchange_http_error(self):
        """Test that HTTP errors during token exchange raise RuntimeError."""
        from crewai.llms.providers.watsonx.completion import _IAMTokenManager

        mock_response = MagicMock()
        mock_response.status_code = 401
        mock_response.text = "Unauthorized"
        mock_response.raise_for_status.side_effect = httpx.HTTPStatusError(
            "401", request=MagicMock(), response=mock_response
        )
        with patch("httpx.post", return_value=mock_response):
            manager = _IAMTokenManager("bad-api-key")
            with pytest.raises(RuntimeError, match="IBM IAM token exchange failed"):
                manager.get_token()
class TestWatsonxCompletionInit:
    """Tests for WatsonxCompletion initialization."""

    @patch.dict(os.environ, {}, clear=True)
    def test_missing_api_key_raises(self):
        """Test that missing API key raises ValueError."""
        from crewai.llms.providers.watsonx.completion import WatsonxCompletion

        with pytest.raises(ValueError, match="IBM Cloud API key is required"):
            WatsonxCompletion(model="ibm/granite-4-h-small")

    @patch.dict(
        os.environ,
        {"WATSONX_API_KEY": "test-key"},
        clear=True,
    )
    def test_missing_project_id_raises(self):
        """Test that missing project ID raises ValueError."""
        from crewai.llms.providers.watsonx.completion import WatsonxCompletion

        with pytest.raises(ValueError, match="project ID is required"):
            WatsonxCompletion(model="ibm/granite-4-h-small")

    @patch.dict(
        os.environ,
        {
            "WATSONX_API_KEY": "test-key",
            "WATSONX_PROJECT_ID": "test-project",
        },
        clear=True,
    )
    def test_default_region_url(self):
        """Test that default region constructs correct URL.

        Fixed: the original test built unused mocks (patched __init__,
        a bare instance, an httpx.post patch) that were never exercised;
        the static URL resolver is what is actually under test.
        """
        from crewai.llms.providers.watsonx.completion import WatsonxCompletion

        url = WatsonxCompletion._resolve_base_url(None, "us-south")
        assert url == "https://us-south.ml.cloud.ibm.com/ml/v1"

    def test_resolve_base_url_custom_region(self):
        """Test URL construction with custom region."""
        from crewai.llms.providers.watsonx.completion import WatsonxCompletion

        url = WatsonxCompletion._resolve_base_url(None, "eu-de")
        assert url == "https://eu-de.ml.cloud.ibm.com/ml/v1"

    def test_resolve_base_url_explicit(self):
        """Test that explicit base_url takes priority."""
        from crewai.llms.providers.watsonx.completion import WatsonxCompletion

        url = WatsonxCompletion._resolve_base_url(
            "https://custom.example.com/v1", "us-south"
        )
        assert url == "https://custom.example.com/v1"

    @patch.dict(
        os.environ,
        {"WATSONX_URL": "https://env-override.example.com/v1"},
        clear=True,
    )
    def test_resolve_base_url_env_override(self):
        """Test that WATSONX_URL env var overrides region-based URL."""
        from crewai.llms.providers.watsonx.completion import WatsonxCompletion

        url = WatsonxCompletion._resolve_base_url(None, "us-south")
        assert url == "https://env-override.example.com/v1"

    def test_resolve_region_default(self):
        """Test default region resolution."""
        from crewai.llms.providers.watsonx.completion import WatsonxCompletion

        with patch.dict(os.environ, {}, clear=True):
            region = WatsonxCompletion._resolve_region(None)
            assert region == "us-south"

    @patch.dict(os.environ, {"WATSONX_REGION": "eu-gb"}, clear=True)
    def test_resolve_region_from_env(self):
        """Test region resolution from environment variable."""
        from crewai.llms.providers.watsonx.completion import WatsonxCompletion

        region = WatsonxCompletion._resolve_region(None)
        assert region == "eu-gb"

    def test_resolve_region_explicit(self):
        """Test explicit region parameter takes priority."""
        from crewai.llms.providers.watsonx.completion import WatsonxCompletion

        region = WatsonxCompletion._resolve_region("jp-tok")
        assert region == "jp-tok"
class TestWatsonxModelCapabilities:
    """Tests for model capability detection."""

    def _make_completion(self, model: str) -> object:
        """Create a minimal WatsonxCompletion-like object for testing."""
        from crewai.llms.providers.watsonx.completion import WatsonxCompletion

        # Skip __init__ entirely (it would require credentials and a
        # network round-trip); the capability methods only read .model.
        stub = object.__new__(WatsonxCompletion)
        stub.model = model
        return stub

    def test_granite_4_context_window(self):
        """Test Granite 4.x models report 128K context."""
        stub = self._make_completion("ibm/granite-4-h-small")
        window = stub.get_context_window_size()
        assert window == 131072

    def test_granite_3_instruct_context_window(self):
        """Test Granite 3.x instruct models report 128K context."""
        stub = self._make_completion("ibm/granite-3-8b-instruct")
        window = stub.get_context_window_size()
        assert window == 131072

    def test_granite_code_context_window(self):
        """Test Granite code models report 8K context."""
        stub = self._make_completion("ibm/granite-8b-code-instruct")
        window = stub.get_context_window_size()
        assert window == 8192

    def test_granite_4_supports_function_calling(self):
        """Test Granite 4.x models support function calling."""
        stub = self._make_completion("ibm/granite-4-h-small")
        assert stub.supports_function_calling() is True

    def test_granite_3_instruct_supports_function_calling(self):
        """Test Granite 3.x instruct models support function calling."""
        stub = self._make_completion("ibm/granite-3-8b-instruct")
        assert stub.supports_function_calling() is True

    def test_granite_guardian_no_function_calling(self):
        """Test Granite Guardian models don't support function calling."""
        stub = self._make_completion("ibm/granite-guardian-3-8b")
        assert stub.supports_function_calling() is False

    def test_granite_not_multimodal(self):
        """Test Granite models are not multimodal."""
        stub = self._make_completion("ibm/granite-4-h-small")
        assert stub.supports_multimodal() is False
class TestWatsonxModelRouting:
    """Tests for model routing through the LLM factory."""

    def test_watsonx_models_in_constants(self):
        """Test that WATSONX_MODELS is properly defined."""
        from crewai.llms.constants import WATSONX_MODELS

        assert "ibm/granite-4-h-small" in WATSONX_MODELS
        assert "ibm/granite-3-8b-instruct" in WATSONX_MODELS
        assert "ibm/granite-guardian-3-8b" in WATSONX_MODELS
        # Guard against the constants list being accidentally truncated.
        assert len(WATSONX_MODELS) >= 10

    def test_watsonx_in_supported_providers(self):
        """Test that watsonx is in the supported native providers list."""
        from crewai.llm import SUPPORTED_NATIVE_PROVIDERS

        # Both the canonical name and the "ibm" alias must route natively.
        assert "watsonx" in SUPPORTED_NATIVE_PROVIDERS
        assert "ibm" in SUPPORTED_NATIVE_PROVIDERS

    def test_get_native_provider_watsonx(self):
        """Test that _get_native_provider returns WatsonxCompletion."""
        from crewai.llm import LLM
        from crewai.llms.providers.watsonx.completion import WatsonxCompletion

        assert LLM._get_native_provider("watsonx") is WatsonxCompletion
        assert LLM._get_native_provider("ibm") is WatsonxCompletion

    def test_infer_provider_from_watsonx_model(self):
        """Test that Granite models are inferred as watsonx provider."""
        from crewai.llm import LLM

        assert LLM._infer_provider_from_model("ibm/granite-4-h-small") == "watsonx"