diff --git a/lib/crewai/src/crewai/llm.py b/lib/crewai/src/crewai/llm.py index 75b1f6546..7512b9d7c 100644 --- a/lib/crewai/src/crewai/llm.py +++ b/lib/crewai/src/crewai/llm.py @@ -44,6 +44,7 @@ from crewai.llms.constants import ( BEDROCK_MODELS, GEMINI_MODELS, OPENAI_MODELS, + WATSONX_MODELS, ) from crewai.utilities import InternalInstructor from crewai.utilities.exceptions.context_window_exceeding_exception import ( @@ -309,6 +310,8 @@ SUPPORTED_NATIVE_PROVIDERS: Final[list[str]] = [ "gemini", "bedrock", "aws", + "watsonx", + "ibm", # OpenAI-compatible providers "openrouter", "deepseek", @@ -376,6 +379,8 @@ class LLM(BaseLLM): "gemini": "gemini", "bedrock": "bedrock", "aws": "bedrock", + "watsonx": "watsonx", + "ibm": "watsonx", # OpenAI-compatible providers "openrouter": "openrouter", "deepseek": "deepseek", @@ -506,6 +511,12 @@ class LLM(BaseLLM): # OpenRouter uses org/model format but accepts anything return True + if provider == "watsonx" or provider == "ibm": + return any( + model_lower.startswith(prefix) + for prefix in ["ibm/granite", "granite"] + ) + return False @classmethod @@ -541,6 +552,9 @@ class LLM(BaseLLM): # azure does not provide a list of available models, determine a better way to handle this return True + if (provider == "watsonx" or provider == "ibm") and model in WATSONX_MODELS: + return True + # Fallback to pattern matching for models not in constants return cls._matches_provider_pattern(model, provider) @@ -573,6 +587,9 @@ class LLM(BaseLLM): if model in AZURE_MODELS: return "azure" + if model in WATSONX_MODELS: + return "watsonx" + return "openai" @classmethod @@ -605,6 +622,11 @@ class LLM(BaseLLM): return BedrockCompletion + if provider == "watsonx" or provider == "ibm": + from crewai.llms.providers.watsonx.completion import WatsonxCompletion + + return WatsonxCompletion + # OpenAI-compatible providers openai_compatible_providers = { "openrouter", diff --git a/lib/crewai/src/crewai/llms/constants.py b/lib/crewai/src/crewai/llms/constants.py index 595a0a30d..0a6ba19af 100644 --- a/lib/crewai/src/crewai/llms/constants.py +++ b/lib/crewai/src/crewai/llms/constants.py @@ -568,3 +568,33 @@ BEDROCK_MODELS: list[BedrockModels] = [ "qwen.qwen3-coder-30b-a3b-v1:0", "twelvelabs.pegasus-1-2-v1:0", ] + + +WatsonxModels: TypeAlias = Literal[ + "ibm/granite-3-2b-instruct", + "ibm/granite-3-8b-instruct", + "ibm/granite-3-1-2b-instruct", + "ibm/granite-3-1-8b-instruct", + "ibm/granite-3-1-8b-base", + "ibm/granite-3-3-2b-instruct", + "ibm/granite-3-3-8b-instruct", + "ibm/granite-4-h-micro", + "ibm/granite-4-h-tiny", + "ibm/granite-4-h-small", + "ibm/granite-8b-code-instruct", + "ibm/granite-guardian-3-8b", +] +WATSONX_MODELS: list[WatsonxModels] = [ + "ibm/granite-3-2b-instruct", + "ibm/granite-3-8b-instruct", + "ibm/granite-3-1-2b-instruct", + "ibm/granite-3-1-8b-instruct", + "ibm/granite-3-1-8b-base", + "ibm/granite-3-3-2b-instruct", + "ibm/granite-3-3-8b-instruct", + "ibm/granite-4-h-micro", + "ibm/granite-4-h-tiny", + "ibm/granite-4-h-small", + "ibm/granite-8b-code-instruct", + "ibm/granite-guardian-3-8b", +] diff --git a/lib/crewai/src/crewai/llms/providers/watsonx/__init__.py b/lib/crewai/src/crewai/llms/providers/watsonx/__init__.py new file mode 100644 index 000000000..3e403d5c5 --- /dev/null +++ b/lib/crewai/src/crewai/llms/providers/watsonx/__init__.py @@ -0,0 +1,5 @@ +"""IBM watsonx.ai provider module.""" + +from crewai.llms.providers.watsonx.completion import WatsonxCompletion + +__all__ = ["WatsonxCompletion"] diff --git a/lib/crewai/src/crewai/llms/providers/watsonx/completion.py b/lib/crewai/src/crewai/llms/providers/watsonx/completion.py new file mode 100644 index 000000000..4bf43bfbb --- /dev/null +++ b/lib/crewai/src/crewai/llms/providers/watsonx/completion.py @@ -0,0 +1,444 @@ +"""IBM watsonx.ai provider implementation. + +This module provides native support for IBM Granite models via the +watsonx.ai Model Gateway, which exposes an OpenAI-compatible API. + +Authentication uses IBM Cloud IAM token exchange: an API key is exchanged +for a short-lived Bearer token via the IAM identity service. + +Usage: + llm = LLM(model="watsonx/ibm/granite-4-h-small") + llm = LLM(model="ibm/granite-4-h-small", provider="watsonx") + +Environment variables: + WATSONX_API_KEY: IBM Cloud API key (required) + WATSONX_PROJECT_ID: watsonx.ai project ID (required) + WATSONX_REGION: IBM Cloud region (default: us-south) + WATSONX_URL: Full base URL override (optional) +""" + +from __future__ import annotations + +import logging +import os +import threading +import time +from typing import Any + +import httpx +from openai import OpenAI + +from crewai.llms.providers.openai.completion import OpenAICompletion + +logger = logging.getLogger(__name__) + +# IBM Cloud IAM endpoint for token exchange +_IAM_TOKEN_URL = "https://iam.cloud.ibm.com/identity/token" + +# Default region for watsonx.ai +_DEFAULT_REGION = "us-south" + +# Refresh token 60 seconds before expiry to avoid race conditions +_TOKEN_REFRESH_BUFFER_SECONDS = 60 + +# Supported watsonx.ai regions +_SUPPORTED_REGIONS = frozenset({ + "us-south", + "eu-de", + "eu-gb", + "jp-tok", + "au-syd", +}) + + +class _IAMTokenManager: + """Thread-safe IBM IAM token manager with automatic refresh. + + Exchanges an IBM Cloud API key for a short-lived Bearer token and + caches it, refreshing automatically when the token approaches expiry. + """ + + def __init__(self, api_key: str) -> None: + self._api_key = api_key + self._token: str | None = None + self._expiry: float = 0.0 + self._lock = threading.Lock() + + def get_token(self) -> str: + """Get a valid IAM Bearer token, refreshing if needed. + + Returns: + A valid Bearer token string. + + Raises: + RuntimeError: If the token exchange fails. + """ + if self._token and time.time() < self._expiry - _TOKEN_REFRESH_BUFFER_SECONDS: + return self._token + + with self._lock: + # Double-check after acquiring lock + if ( + self._token + and time.time() < self._expiry - _TOKEN_REFRESH_BUFFER_SECONDS + ): + return self._token + + self._refresh_token() + assert self._token is not None + return self._token + + def _refresh_token(self) -> None: + """Exchange API key for a new IAM token.""" + try: + response = httpx.post( + _IAM_TOKEN_URL, + data={ + "grant_type": "urn:ibm:params:oauth:grant-type:apikey", + "apikey": self._api_key, + }, + headers={ + "Content-Type": "application/x-www-form-urlencoded", + "Accept": "application/json", + }, + timeout=30.0, + ) + response.raise_for_status() + except httpx.HTTPStatusError as e: + raise RuntimeError( + f"IBM IAM token exchange failed (HTTP {e.response.status_code}): " + f"{e.response.text}" + ) from e + except httpx.HTTPError as e: + raise RuntimeError( + f"IBM IAM token exchange request failed: {e}" + ) from e + + data = response.json() + self._token = data["access_token"] + self._expiry = float(data["expiration"]) + logger.debug( + "IBM IAM token refreshed, expires at %s", + time.strftime("%Y-%m-%d %H:%M:%S", time.gmtime(self._expiry)), + ) + + +class WatsonxCompletion(OpenAICompletion): + """IBM watsonx.ai completion implementation. + + This class provides support for IBM Granite models and other foundation + models hosted on watsonx.ai via the OpenAI-compatible Model Gateway. + + Authentication is handled transparently via IBM Cloud IAM token exchange. + The API key is exchanged for a Bearer token which is automatically + refreshed when it approaches expiry. + + Supported models include the IBM Granite family: + - ibm/granite-4-h-small (32B hybrid) + - ibm/granite-4-h-tiny (7B hybrid) + - ibm/granite-4-h-micro (3B hybrid) + - ibm/granite-3-8b-instruct + - ibm/granite-3-3-8b-instruct + - ibm/granite-8b-code-instruct + - ibm/granite-guardian-3-8b + - And other models available on watsonx.ai + + Example: + # Using provider prefix + llm = LLM(model="watsonx/ibm/granite-4-h-small") + + # Using explicit provider + llm = LLM(model="ibm/granite-4-h-small", provider="watsonx") + + # With custom configuration + llm = LLM( + model="ibm/granite-4-h-small", + provider="watsonx", + api_key="my-ibm-cloud-api-key", + temperature=0.7, + ) + """ + + def __init__( + self, + model: str, + provider: str = "watsonx", + api_key: str | None = None, + base_url: str | None = None, + project_id: str | None = None, + region: str | None = None, + **kwargs: Any, + ) -> None: + """Initialize watsonx.ai completion client. + + Args: + model: The model identifier (e.g., "ibm/granite-4-h-small"). + provider: The provider name (default: "watsonx"). + api_key: IBM Cloud API key. If not provided, reads from + WATSONX_API_KEY environment variable. + base_url: Full base URL override for the watsonx.ai endpoint. + If not provided, constructed from region. + project_id: watsonx.ai project ID. If not provided, reads from + WATSONX_PROJECT_ID environment variable. + region: IBM Cloud region (default: "us-south"). If not provided, + reads from WATSONX_REGION environment variable. + **kwargs: Additional arguments passed to OpenAICompletion. + + Raises: + ValueError: If required credentials are missing. + """ + resolved_api_key = self._resolve_api_key(api_key) + resolved_project_id = self._resolve_project_id(project_id) + resolved_region = self._resolve_region(region) + resolved_base_url = self._resolve_base_url(base_url, resolved_region) + + # Initialize IAM token manager for transparent auth + self._iam_manager = _IAMTokenManager(resolved_api_key) + self._project_id = resolved_project_id + + # Get initial token for client construction + initial_token = self._iam_manager.get_token() + + # Pass the bearer token as api_key to OpenAI client + # The OpenAI SDK uses this as Authorization: Bearer + super().__init__( + model=model, + provider=provider, + api_key=initial_token, + base_url=resolved_base_url, + **kwargs, + ) + + @staticmethod + def _resolve_api_key(api_key: str | None) -> str: + """Resolve IBM Cloud API key from parameter or environment. + + Args: + api_key: Explicitly provided API key. + + Returns: + The resolved API key. + + Raises: + ValueError: If no API key is found. + """ + resolved = api_key or os.getenv("WATSONX_API_KEY") + if not resolved: + raise ValueError( + "IBM Cloud API key is required for watsonx.ai provider. " + "Set the WATSONX_API_KEY environment variable or pass " + "api_key parameter." + ) + return resolved + + @staticmethod + def _resolve_project_id(project_id: str | None) -> str: + """Resolve watsonx.ai project ID from parameter or environment. + + Args: + project_id: Explicitly provided project ID. + + Returns: + The resolved project ID. + + Raises: + ValueError: If no project ID is found. + """ + resolved = project_id or os.getenv("WATSONX_PROJECT_ID") + if not resolved: + raise ValueError( + "watsonx.ai project ID is required. " + "Set the WATSONX_PROJECT_ID environment variable or pass " + "project_id parameter." + ) + return resolved + + @staticmethod + def _resolve_region(region: str | None) -> str: + """Resolve IBM Cloud region from parameter or environment. + + Args: + region: Explicitly provided region. + + Returns: + The resolved region string. + """ + resolved = region or os.getenv("WATSONX_REGION", _DEFAULT_REGION) + if resolved not in _SUPPORTED_REGIONS: + logger.warning( + "Region '%s' is not in the known supported regions: %s. " + "Proceeding anyway in case IBM has added new regions.", + resolved, + ", ".join(sorted(_SUPPORTED_REGIONS)), + ) + return resolved + + @staticmethod + def _resolve_base_url(base_url: str | None, region: str) -> str: + """Resolve the watsonx.ai base URL. + + Priority: + 1. Explicit base_url parameter + 2. WATSONX_URL environment variable + 3. Constructed from region + + Args: + base_url: Explicitly provided base URL. + region: IBM Cloud region for URL construction. + + Returns: + The resolved base URL. + """ + if base_url: + return base_url.rstrip("/") + + env_url = os.getenv("WATSONX_URL") + if env_url: + return env_url.rstrip("/") + + return f"https://{region}.ml.cloud.ibm.com/ml/v1" + + def _build_client( + self, + api_key: str | None = None, + base_url: str | None = None, + default_headers: dict[str, str] | None = None, + ) -> OpenAI: + """Build the OpenAI client with watsonx-specific configuration. + + Overrides the parent method to inject the project_id header + and ensure the IAM token is current. + + Args: + api_key: Bearer token (from IAM exchange). + base_url: watsonx.ai endpoint URL. + default_headers: Additional headers. + + Returns: + Configured OpenAI client instance. + """ + # Refresh token if needed + current_token = self._iam_manager.get_token() + + # Merge watsonx-specific headers + watsonx_headers = { + "X-Watsonx-Project-Id": self._project_id, + } + if default_headers: + watsonx_headers.update(default_headers) + + return super()._build_client( + api_key=current_token, + base_url=base_url, + default_headers=watsonx_headers, + ) + + def _ensure_fresh_token(self) -> None: + """Refresh the IAM token on the client if needed. + + Updates the client's API key (Bearer token) if the cached + token has been refreshed. + """ + current_token = self._iam_manager.get_token() + if hasattr(self, "client") and self.client is not None: + self.client.api_key = current_token + + def call(self, messages, tools=None, callbacks=None, available_functions=None, + from_task=None, from_agent=None, response_model=None): + """Call the LLM, refreshing the IAM token if needed.""" + self._ensure_fresh_token() + return super().call( + messages=messages, + tools=tools, + callbacks=callbacks, + available_functions=available_functions, + from_task=from_task, + from_agent=from_agent, + response_model=response_model, + ) + + async def acall(self, messages, tools=None, callbacks=None, available_functions=None, + from_task=None, from_agent=None, response_model=None): + """Async call the LLM, refreshing the IAM token if needed.""" + self._ensure_fresh_token() + return await super().acall( + messages=messages, + tools=tools, + callbacks=callbacks, + available_functions=available_functions, + from_task=from_task, + from_agent=from_agent, + response_model=response_model, + ) + + def get_context_window_size(self) -> int: + """Get context window size for Granite models. + + Returns: + The context window size in tokens. + """ + model_lower = self.model.lower() + + # Granite 4.x models have 128K context + if "granite-4" in model_lower: + return 131072 + + # Granite 3.x instruct models have 128K context + if "granite-3" in model_lower and "instruct" in model_lower: + return 131072 + + # Granite 3.x base models have 4K context + if "granite-3" in model_lower: + return 4096 + + # Granite code models + if "granite" in model_lower and "code" in model_lower: + return 8192 + + # Default for unknown models + return 8192 + + def supports_function_calling(self) -> bool: + """Check if the model supports function calling / tool use. + + Granite 3.x instruct and 4.x models support tool use. + + Returns: + True if the model supports function calling. + """ + model_lower = self.model.lower() + + # Granite 4.x models support tool use + if "granite-4" in model_lower: + return True + + # Granite 3.x instruct models support tool use + if "granite-3" in model_lower and "instruct" in model_lower: + return True + + # Granite guardian models don't do tool use + if "guardian" in model_lower: + return False + + # Default: assume no tool use for unknown models + return False + + def supports_multimodal(self) -> bool: + """Check if the model supports multimodal inputs. + + Currently, Granite models are text-only. + + Returns: + False (Granite models are text-only). + """ + return False + + def to_config_dict(self) -> dict[str, Any]: + """Serialize this LLM to a dict for reconstruction. + + Returns: + Configuration dict with watsonx-specific fields. + """ + config = super().to_config_dict() + config["model"] = f"watsonx/{self.model}" if "/" not in self.model else f"watsonx/{self.model}" + return config diff --git a/lib/crewai/tests/llms/providers/watsonx/__init__.py b/lib/crewai/tests/llms/providers/watsonx/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/lib/crewai/tests/llms/providers/watsonx/test_completion.py b/lib/crewai/tests/llms/providers/watsonx/test_completion.py new file mode 100644 index 000000000..b23724872 --- /dev/null +++ b/lib/crewai/tests/llms/providers/watsonx/test_completion.py @@ -0,0 +1,293 @@ +"""Tests for IBM watsonx.ai provider.""" + +from __future__ import annotations + +import os +import time +from unittest.mock import MagicMock, patch + +import httpx +import pytest + + +class TestIAMTokenManager: + """Tests for the IAM token manager.""" + + def test_token_exchange_success(self): + """Test successful IAM token exchange.""" + from crewai.llms.providers.watsonx.completion import _IAMTokenManager + + mock_response = MagicMock() + mock_response.json.return_value = { + "access_token": "test-bearer-token", + "expiration": time.time() + 3600, + "token_type": "Bearer", + } + mock_response.raise_for_status = MagicMock() + + with patch("httpx.post", return_value=mock_response) as mock_post: + manager = _IAMTokenManager("test-api-key") + token = manager.get_token() + + assert token == "test-bearer-token" + mock_post.assert_called_once() + call_kwargs = mock_post.call_args + assert call_kwargs[1]["data"]["apikey"] == "test-api-key" + assert ( + call_kwargs[1]["data"]["grant_type"] + == "urn:ibm:params:oauth:grant-type:apikey" + ) + + def test_token_caching(self): + """Test that tokens are cached and not re-fetched.""" + from crewai.llms.providers.watsonx.completion import _IAMTokenManager + + mock_response = MagicMock() + mock_response.json.return_value = { + "access_token": "cached-token", + "expiration": time.time() + 3600, + } + mock_response.raise_for_status = MagicMock() + + with patch("httpx.post", return_value=mock_response) as mock_post: + manager = _IAMTokenManager("test-api-key") + + # First call - should fetch + token1 = manager.get_token() + # Second call - should use cache + token2 = manager.get_token() + + assert token1 == token2 == "cached-token" + assert mock_post.call_count == 1 # Only one HTTP call + + def test_token_refresh_on_expiry(self): + """Test that expired tokens are refreshed.""" + from crewai.llms.providers.watsonx.completion import _IAMTokenManager + + call_count = 0 + + def mock_post(*args, **kwargs): + nonlocal call_count + call_count += 1 + mock_resp = MagicMock() + mock_resp.json.return_value = { + "access_token": f"token-{call_count}", + "expiration": time.time() + (0 if call_count == 1 else 3600), + } + mock_resp.raise_for_status = MagicMock() + return mock_resp + + with patch("httpx.post", side_effect=mock_post): + manager = _IAMTokenManager("test-api-key") + + # First call - gets token-1 which is already expired + token1 = manager.get_token() + assert token1 == "token-1" + + # Second call - token-1 is expired, should refresh to token-2 + token2 = manager.get_token() + assert token2 == "token-2" + assert call_count == 2 + + def test_token_exchange_http_error(self): + """Test that HTTP errors during token exchange raise RuntimeError.""" + from crewai.llms.providers.watsonx.completion import _IAMTokenManager + + mock_response = MagicMock() + mock_response.status_code = 401 + mock_response.text = "Unauthorized" + mock_response.raise_for_status.side_effect = httpx.HTTPStatusError( + "401", request=MagicMock(), response=mock_response + ) + + with patch("httpx.post", return_value=mock_response): + manager = _IAMTokenManager("bad-api-key") + with pytest.raises(RuntimeError, match="IBM IAM token exchange failed"): + manager.get_token() + + +class TestWatsonxCompletionInit: + """Tests for WatsonxCompletion initialization.""" + + @patch.dict(os.environ, {}, clear=True) + def test_missing_api_key_raises(self): + """Test that missing API key raises ValueError.""" + from crewai.llms.providers.watsonx.completion import WatsonxCompletion + + with pytest.raises(ValueError, match="IBM Cloud API key is required"): + WatsonxCompletion(model="ibm/granite-4-h-small") + + @patch.dict( + os.environ, + {"WATSONX_API_KEY": "test-key"}, + clear=True, + ) + def test_missing_project_id_raises(self): + """Test that missing project ID raises ValueError.""" + from crewai.llms.providers.watsonx.completion import WatsonxCompletion + + with pytest.raises(ValueError, match="project ID is required"): + WatsonxCompletion(model="ibm/granite-4-h-small") + + @patch.dict( + os.environ, + { + "WATSONX_API_KEY": "test-key", + "WATSONX_PROJECT_ID": "test-project", + }, + clear=True, + ) + def test_default_region_url(self): + """Test that default region constructs correct URL.""" + from crewai.llms.providers.watsonx.completion import WatsonxCompletion + + mock_response = MagicMock() + mock_response.json.return_value = { + "access_token": "test-token", + "expiration": time.time() + 3600, + } + mock_response.raise_for_status = MagicMock() + + with patch("httpx.post", return_value=mock_response): + with patch( + "crewai.llms.providers.openai.completion.OpenAICompletion.__init__", + return_value=None, + ) as mock_init: + completion = WatsonxCompletion.__new__(WatsonxCompletion) + # Manually set _iam_manager and _project_id since we skip __init__ + # Instead, test the static method directly + url = WatsonxCompletion._resolve_base_url(None, "us-south") + assert url == "https://us-south.ml.cloud.ibm.com/ml/v1" + + def test_resolve_base_url_custom_region(self): + """Test URL construction with custom region.""" + from crewai.llms.providers.watsonx.completion import WatsonxCompletion + + url = WatsonxCompletion._resolve_base_url(None, "eu-de") + assert url == "https://eu-de.ml.cloud.ibm.com/ml/v1" + + def test_resolve_base_url_explicit(self): + """Test that explicit base_url takes priority.""" + from crewai.llms.providers.watsonx.completion import WatsonxCompletion + + url = WatsonxCompletion._resolve_base_url( + "https://custom.example.com/v1", "us-south" + ) + assert url == "https://custom.example.com/v1" + + @patch.dict( + os.environ, + {"WATSONX_URL": "https://env-override.example.com/v1"}, + clear=True, + ) + def test_resolve_base_url_env_override(self): + """Test that WATSONX_URL env var overrides region-based URL.""" + from crewai.llms.providers.watsonx.completion import WatsonxCompletion + + url = WatsonxCompletion._resolve_base_url(None, "us-south") + assert url == "https://env-override.example.com/v1" + + def test_resolve_region_default(self): + """Test default region resolution.""" + from crewai.llms.providers.watsonx.completion import WatsonxCompletion + + with patch.dict(os.environ, {}, clear=True): + region = WatsonxCompletion._resolve_region(None) + assert region == "us-south" + + @patch.dict(os.environ, {"WATSONX_REGION": "eu-gb"}, clear=True) + def test_resolve_region_from_env(self): + """Test region resolution from environment variable.""" + from crewai.llms.providers.watsonx.completion import WatsonxCompletion + + region = WatsonxCompletion._resolve_region(None) + assert region == "eu-gb" + + def test_resolve_region_explicit(self): + """Test explicit region parameter takes priority.""" + from crewai.llms.providers.watsonx.completion import WatsonxCompletion + + region = WatsonxCompletion._resolve_region("jp-tok") + assert region == "jp-tok" + + +class TestWatsonxModelCapabilities: + """Tests for model capability detection.""" + + def _make_completion(self, model: str) -> object: + """Create a minimal WatsonxCompletion-like object for testing.""" + from crewai.llms.providers.watsonx.completion import WatsonxCompletion + + # Create a bare instance without calling __init__ + obj = object.__new__(WatsonxCompletion) + obj.model = model + return obj + + def test_granite_4_context_window(self): + """Test Granite 4.x models report 128K context.""" + comp = self._make_completion("ibm/granite-4-h-small") + assert comp.get_context_window_size() == 131072 + + def test_granite_3_instruct_context_window(self): + """Test Granite 3.x instruct models report 128K context.""" + comp = self._make_completion("ibm/granite-3-8b-instruct") + assert comp.get_context_window_size() == 131072 + + def test_granite_code_context_window(self): + """Test Granite code models report 8K context.""" + comp = self._make_completion("ibm/granite-8b-code-instruct") + assert comp.get_context_window_size() == 8192 + + def test_granite_4_supports_function_calling(self): + """Test Granite 4.x models support function calling.""" + comp = self._make_completion("ibm/granite-4-h-small") + assert comp.supports_function_calling() is True + + def test_granite_3_instruct_supports_function_calling(self): + """Test Granite 3.x instruct models support function calling.""" + comp = self._make_completion("ibm/granite-3-8b-instruct") + assert comp.supports_function_calling() is True + + def test_granite_guardian_no_function_calling(self): + """Test Granite Guardian models don't support function calling.""" + comp = self._make_completion("ibm/granite-guardian-3-8b") + assert comp.supports_function_calling() is False + + def test_granite_not_multimodal(self): + """Test Granite models are not multimodal.""" + comp = self._make_completion("ibm/granite-4-h-small") + assert comp.supports_multimodal() is False + + +class TestWatsonxModelRouting: + """Tests for model routing through the LLM factory.""" + + def test_watsonx_models_in_constants(self): + """Test that WATSONX_MODELS is properly defined.""" + from crewai.llms.constants import WATSONX_MODELS + + assert "ibm/granite-4-h-small" in WATSONX_MODELS + assert "ibm/granite-3-8b-instruct" in WATSONX_MODELS + assert "ibm/granite-guardian-3-8b" in WATSONX_MODELS + assert len(WATSONX_MODELS) >= 10 + + def test_watsonx_in_supported_providers(self): + """Test that watsonx is in the supported native providers list.""" + from crewai.llm import SUPPORTED_NATIVE_PROVIDERS + + assert "watsonx" in SUPPORTED_NATIVE_PROVIDERS + assert "ibm" in SUPPORTED_NATIVE_PROVIDERS + + def test_get_native_provider_watsonx(self): + """Test that _get_native_provider returns WatsonxCompletion.""" + from crewai.llm import LLM + from crewai.llms.providers.watsonx.completion import WatsonxCompletion + + assert LLM._get_native_provider("watsonx") is WatsonxCompletion + assert LLM._get_native_provider("ibm") is WatsonxCompletion + + def test_infer_provider_from_watsonx_model(self): + """Test that Granite models are inferred as watsonx provider.""" + from crewai.llm import LLM + + assert LLM._infer_provider_from_model("ibm/granite-4-h-small") == "watsonx"