mirror of
https://github.com/crewAIInc/crewAI.git
synced 2026-04-14 15:02:37 +00:00
feat: add IBM Granite model support via watsonx.ai provider (OSS-35)
Add native support for IBM Granite models through the watsonx.ai Model Gateway OpenAI-compatible API. Implementation: - New watsonx provider at llms/providers/watsonx/ extending OpenAICompletion - IBM Cloud IAM token exchange with thread-safe caching and auto-refresh - Support for WATSONX_API_KEY, WATSONX_PROJECT_ID, WATSONX_REGION env vars - 12 Granite models in constants (3.x, 4.x, code, guardian families) - Full LLM routing: watsonx/ibm/granite-4-h-small or provider='watsonx' - No new dependencies required (uses existing openai + httpx) Usage: llm = LLM(model='watsonx/ibm/granite-4-h-small') llm = LLM(model='ibm/granite-4-h-small', provider='watsonx') Closes OSS-35
This commit is contained in:
@@ -44,6 +44,7 @@ from crewai.llms.constants import (
|
||||
BEDROCK_MODELS,
|
||||
GEMINI_MODELS,
|
||||
OPENAI_MODELS,
|
||||
WATSONX_MODELS,
|
||||
)
|
||||
from crewai.utilities import InternalInstructor
|
||||
from crewai.utilities.exceptions.context_window_exceeding_exception import (
|
||||
@@ -309,6 +310,8 @@ SUPPORTED_NATIVE_PROVIDERS: Final[list[str]] = [
|
||||
"gemini",
|
||||
"bedrock",
|
||||
"aws",
|
||||
"watsonx",
|
||||
"ibm",
|
||||
# OpenAI-compatible providers
|
||||
"openrouter",
|
||||
"deepseek",
|
||||
@@ -376,6 +379,8 @@ class LLM(BaseLLM):
|
||||
"gemini": "gemini",
|
||||
"bedrock": "bedrock",
|
||||
"aws": "bedrock",
|
||||
"watsonx": "watsonx",
|
||||
"ibm": "watsonx",
|
||||
# OpenAI-compatible providers
|
||||
"openrouter": "openrouter",
|
||||
"deepseek": "deepseek",
|
||||
@@ -506,6 +511,12 @@ class LLM(BaseLLM):
|
||||
# OpenRouter uses org/model format but accepts anything
|
||||
return True
|
||||
|
||||
if provider == "watsonx" or provider == "ibm":
|
||||
return any(
|
||||
model_lower.startswith(prefix)
|
||||
for prefix in ["ibm/granite", "granite"]
|
||||
)
|
||||
|
||||
return False
|
||||
|
||||
@classmethod
|
||||
@@ -541,6 +552,9 @@ class LLM(BaseLLM):
|
||||
# azure does not provide a list of available models, determine a better way to handle this
|
||||
return True
|
||||
|
||||
if (provider == "watsonx" or provider == "ibm") and model in WATSONX_MODELS:
|
||||
return True
|
||||
|
||||
# Fallback to pattern matching for models not in constants
|
||||
return cls._matches_provider_pattern(model, provider)
|
||||
|
||||
@@ -573,6 +587,9 @@ class LLM(BaseLLM):
|
||||
if model in AZURE_MODELS:
|
||||
return "azure"
|
||||
|
||||
if model in WATSONX_MODELS:
|
||||
return "watsonx"
|
||||
|
||||
return "openai"
|
||||
|
||||
@classmethod
|
||||
@@ -605,6 +622,11 @@ class LLM(BaseLLM):
|
||||
|
||||
return BedrockCompletion
|
||||
|
||||
if provider == "watsonx" or provider == "ibm":
|
||||
from crewai.llms.providers.watsonx.completion import WatsonxCompletion
|
||||
|
||||
return WatsonxCompletion
|
||||
|
||||
# OpenAI-compatible providers
|
||||
openai_compatible_providers = {
|
||||
"openrouter",
|
||||
|
||||
@@ -568,3 +568,33 @@ BEDROCK_MODELS: list[BedrockModels] = [
|
||||
"qwen.qwen3-coder-30b-a3b-v1:0",
|
||||
"twelvelabs.pegasus-1-2-v1:0",
|
||||
]
|
||||
|
||||
|
||||
WatsonxModels: TypeAlias = Literal[
|
||||
"ibm/granite-3-2b-instruct",
|
||||
"ibm/granite-3-8b-instruct",
|
||||
"ibm/granite-3-1-2b-instruct",
|
||||
"ibm/granite-3-1-8b-instruct",
|
||||
"ibm/granite-3-1-8b-base",
|
||||
"ibm/granite-3-3-2b-instruct",
|
||||
"ibm/granite-3-3-8b-instruct",
|
||||
"ibm/granite-4-h-micro",
|
||||
"ibm/granite-4-h-tiny",
|
||||
"ibm/granite-4-h-small",
|
||||
"ibm/granite-8b-code-instruct",
|
||||
"ibm/granite-guardian-3-8b",
|
||||
]
|
||||
WATSONX_MODELS: list[WatsonxModels] = [
|
||||
"ibm/granite-3-2b-instruct",
|
||||
"ibm/granite-3-8b-instruct",
|
||||
"ibm/granite-3-1-2b-instruct",
|
||||
"ibm/granite-3-1-8b-instruct",
|
||||
"ibm/granite-3-1-8b-base",
|
||||
"ibm/granite-3-3-2b-instruct",
|
||||
"ibm/granite-3-3-8b-instruct",
|
||||
"ibm/granite-4-h-micro",
|
||||
"ibm/granite-4-h-tiny",
|
||||
"ibm/granite-4-h-small",
|
||||
"ibm/granite-8b-code-instruct",
|
||||
"ibm/granite-guardian-3-8b",
|
||||
]
|
||||
|
||||
5
lib/crewai/src/crewai/llms/providers/watsonx/__init__.py
Normal file
5
lib/crewai/src/crewai/llms/providers/watsonx/__init__.py
Normal file
@@ -0,0 +1,5 @@
|
||||
"""IBM watsonx.ai provider module."""
|
||||
|
||||
from crewai.llms.providers.watsonx.completion import WatsonxCompletion
|
||||
|
||||
__all__ = ["WatsonxCompletion"]
|
||||
444
lib/crewai/src/crewai/llms/providers/watsonx/completion.py
Normal file
444
lib/crewai/src/crewai/llms/providers/watsonx/completion.py
Normal file
@@ -0,0 +1,444 @@
|
||||
"""IBM watsonx.ai provider implementation.
|
||||
|
||||
This module provides native support for IBM Granite models via the
|
||||
watsonx.ai Model Gateway, which exposes an OpenAI-compatible API.
|
||||
|
||||
Authentication uses IBM Cloud IAM token exchange: an API key is exchanged
|
||||
for a short-lived Bearer token via the IAM identity service.
|
||||
|
||||
Usage:
|
||||
llm = LLM(model="watsonx/ibm/granite-4-h-small")
|
||||
llm = LLM(model="ibm/granite-4-h-small", provider="watsonx")
|
||||
|
||||
Environment variables:
|
||||
WATSONX_API_KEY: IBM Cloud API key (required)
|
||||
WATSONX_PROJECT_ID: watsonx.ai project ID (required)
|
||||
WATSONX_REGION: IBM Cloud region (default: us-south)
|
||||
WATSONX_URL: Full base URL override (optional)
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
import os
|
||||
import threading
|
||||
import time
|
||||
from typing import Any
|
||||
|
||||
import httpx
|
||||
from openai import OpenAI
|
||||
|
||||
from crewai.llms.providers.openai.completion import OpenAICompletion
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# IBM Cloud IAM endpoint for token exchange
|
||||
_IAM_TOKEN_URL = "https://iam.cloud.ibm.com/identity/token"
|
||||
|
||||
# Default region for watsonx.ai
|
||||
_DEFAULT_REGION = "us-south"
|
||||
|
||||
# Refresh token 60 seconds before expiry to avoid race conditions
|
||||
_TOKEN_REFRESH_BUFFER_SECONDS = 60
|
||||
|
||||
# Supported watsonx.ai regions
|
||||
_SUPPORTED_REGIONS = frozenset({
|
||||
"us-south",
|
||||
"eu-de",
|
||||
"eu-gb",
|
||||
"jp-tok",
|
||||
"au-syd",
|
||||
})
|
||||
|
||||
|
||||
class _IAMTokenManager:
    """Thread-safe IBM IAM token manager with automatic refresh.

    Exchanges an IBM Cloud API key for a short-lived Bearer token and
    caches it, refreshing automatically when the token approaches expiry.
    """

    def __init__(self, api_key: str) -> None:
        self._api_key = api_key
        # Cached Bearer token; None until the first successful exchange.
        self._token: str | None = None
        # Epoch seconds at which the cached token expires.
        self._expiry: float = 0.0
        self._lock = threading.Lock()

    def _token_is_fresh(self) -> bool:
        """Return True if the cached token is still safely inside its lifetime."""
        return (
            self._token is not None
            and time.time() < self._expiry - _TOKEN_REFRESH_BUFFER_SECONDS
        )

    def get_token(self) -> str:
        """Get a valid IAM Bearer token, refreshing if needed.

        Returns:
            A valid Bearer token string.

        Raises:
            RuntimeError: If the token exchange fails.
        """
        # Lock-free fast path: a stale read at worst sends us into the locked
        # slow path below, which re-checks under the lock.
        if self._token_is_fresh():
            return self._token  # type: ignore[return-value]

        with self._lock:
            # Double-check after acquiring the lock: another thread may have
            # refreshed while we were waiting.
            if self._token_is_fresh():
                return self._token  # type: ignore[return-value]

            self._refresh_token()
            # Explicit check instead of `assert`: asserts are stripped under
            # `python -O`, and this is a contract we always want enforced.
            if self._token is None:
                raise RuntimeError(
                    "IBM IAM token refresh completed without producing a token"
                )
            return self._token

    def _refresh_token(self) -> None:
        """Exchange the API key for a new IAM token.

        Raises:
            RuntimeError: If the HTTP exchange fails or the response payload
                does not contain the expected fields.
        """
        try:
            response = httpx.post(
                _IAM_TOKEN_URL,
                data={
                    "grant_type": "urn:ibm:params:oauth:grant-type:apikey",
                    "apikey": self._api_key,
                },
                headers={
                    "Content-Type": "application/x-www-form-urlencoded",
                    "Accept": "application/json",
                },
                timeout=30.0,
            )
            response.raise_for_status()
        except httpx.HTTPStatusError as e:
            raise RuntimeError(
                f"IBM IAM token exchange failed (HTTP {e.response.status_code}): "
                f"{e.response.text}"
            ) from e
        except httpx.HTTPError as e:
            raise RuntimeError(
                f"IBM IAM token exchange request failed: {e}"
            ) from e

        # A 200 with a malformed body should surface as the same error type
        # as a transport failure, not escape as a bare KeyError/ValueError.
        try:
            data = response.json()
            self._token = data["access_token"]
            self._expiry = float(data["expiration"])
        except (KeyError, TypeError, ValueError) as e:
            raise RuntimeError(
                f"IBM IAM token exchange returned an unexpected payload: {e}"
            ) from e

        logger.debug(
            "IBM IAM token refreshed, expires at %s",
            time.strftime("%Y-%m-%d %H:%M:%S", time.gmtime(self._expiry)),
        )
|
||||
|
||||
|
||||
class WatsonxCompletion(OpenAICompletion):
    """IBM watsonx.ai completion implementation.

    This class provides support for IBM Granite models and other foundation
    models hosted on watsonx.ai via the OpenAI-compatible Model Gateway.

    Authentication is handled transparently via IBM Cloud IAM token exchange.
    The API key is exchanged for a Bearer token which is automatically
    refreshed when it approaches expiry.

    Supported models include the IBM Granite family:
        - ibm/granite-4-h-small (32B hybrid)
        - ibm/granite-4-h-tiny (7B hybrid)
        - ibm/granite-4-h-micro (3B hybrid)
        - ibm/granite-3-8b-instruct
        - ibm/granite-3-3-8b-instruct
        - ibm/granite-8b-code-instruct
        - ibm/granite-guardian-3-8b
        - And other models available on watsonx.ai

    Example:
        # Using provider prefix
        llm = LLM(model="watsonx/ibm/granite-4-h-small")

        # Using explicit provider
        llm = LLM(model="ibm/granite-4-h-small", provider="watsonx")

        # With custom configuration
        llm = LLM(
            model="ibm/granite-4-h-small",
            provider="watsonx",
            api_key="my-ibm-cloud-api-key",
            temperature=0.7,
        )
    """

    def __init__(
        self,
        model: str,
        provider: str = "watsonx",
        api_key: str | None = None,
        base_url: str | None = None,
        project_id: str | None = None,
        region: str | None = None,
        **kwargs: Any,
    ) -> None:
        """Initialize watsonx.ai completion client.

        Args:
            model: The model identifier (e.g., "ibm/granite-4-h-small").
            provider: The provider name (default: "watsonx").
            api_key: IBM Cloud API key. If not provided, reads from the
                WATSONX_API_KEY environment variable.
            base_url: Full base URL override for the watsonx.ai endpoint.
                If not provided, constructed from region.
            project_id: watsonx.ai project ID. If not provided, reads from
                the WATSONX_PROJECT_ID environment variable.
            region: IBM Cloud region (default: "us-south"). If not provided,
                reads from the WATSONX_REGION environment variable.
            **kwargs: Additional arguments passed to OpenAICompletion.

        Raises:
            ValueError: If required credentials are missing.
        """
        resolved_api_key = self._resolve_api_key(api_key)
        resolved_project_id = self._resolve_project_id(project_id)
        resolved_region = self._resolve_region(region)
        resolved_base_url = self._resolve_base_url(base_url, resolved_region)

        # The IAM manager must exist before super().__init__ runs, since the
        # parent constructor may call our _build_client() override.
        self._iam_manager = _IAMTokenManager(resolved_api_key)
        self._project_id = resolved_project_id

        # Get an initial token for client construction.
        initial_token = self._iam_manager.get_token()

        # The Bearer token is passed as api_key: the OpenAI SDK sends it as
        # "Authorization: Bearer <token>", which the Model Gateway expects.
        super().__init__(
            model=model,
            provider=provider,
            api_key=initial_token,
            base_url=resolved_base_url,
            **kwargs,
        )

    @staticmethod
    def _resolve_api_key(api_key: str | None) -> str:
        """Resolve IBM Cloud API key from parameter or environment.

        Args:
            api_key: Explicitly provided API key.

        Returns:
            The resolved API key.

        Raises:
            ValueError: If no API key is found.
        """
        resolved = api_key or os.getenv("WATSONX_API_KEY")
        if not resolved:
            raise ValueError(
                "IBM Cloud API key is required for watsonx.ai provider. "
                "Set the WATSONX_API_KEY environment variable or pass "
                "api_key parameter."
            )
        return resolved

    @staticmethod
    def _resolve_project_id(project_id: str | None) -> str:
        """Resolve watsonx.ai project ID from parameter or environment.

        Args:
            project_id: Explicitly provided project ID.

        Returns:
            The resolved project ID.

        Raises:
            ValueError: If no project ID is found.
        """
        resolved = project_id or os.getenv("WATSONX_PROJECT_ID")
        if not resolved:
            raise ValueError(
                "watsonx.ai project ID is required. "
                "Set the WATSONX_PROJECT_ID environment variable or pass "
                "project_id parameter."
            )
        return resolved

    @staticmethod
    def _resolve_region(region: str | None) -> str:
        """Resolve IBM Cloud region from parameter or environment.

        Args:
            region: Explicitly provided region.

        Returns:
            The resolved region string.
        """
        resolved = region or os.getenv("WATSONX_REGION", _DEFAULT_REGION)
        if resolved not in _SUPPORTED_REGIONS:
            # Warn but proceed: IBM may add regions faster than this list
            # is updated, and a wrong region fails loudly at request time.
            logger.warning(
                "Region '%s' is not in the known supported regions: %s. "
                "Proceeding anyway in case IBM has added new regions.",
                resolved,
                ", ".join(sorted(_SUPPORTED_REGIONS)),
            )
        return resolved

    @staticmethod
    def _resolve_base_url(base_url: str | None, region: str) -> str:
        """Resolve the watsonx.ai base URL.

        Priority:
            1. Explicit base_url parameter
            2. WATSONX_URL environment variable
            3. Constructed from region

        Args:
            base_url: Explicitly provided base URL.
            region: IBM Cloud region for URL construction.

        Returns:
            The resolved base URL (no trailing slash).
        """
        if base_url:
            return base_url.rstrip("/")

        env_url = os.getenv("WATSONX_URL")
        if env_url:
            return env_url.rstrip("/")

        return f"https://{region}.ml.cloud.ibm.com/ml/v1"

    def _build_client(
        self,
        api_key: str | None = None,
        base_url: str | None = None,
        default_headers: dict[str, str] | None = None,
    ) -> OpenAI:
        """Build the OpenAI client with watsonx-specific configuration.

        Overrides the parent method to inject the project_id header and
        ensure the IAM token is current.

        Args:
            api_key: Ignored; replaced by a freshly resolved IAM token.
            base_url: watsonx.ai endpoint URL.
            default_headers: Additional headers; these take precedence over
                the injected project-id header on key collision.

        Returns:
            Configured OpenAI client instance.
        """
        # Always fetch via the manager so an expired token is refreshed.
        current_token = self._iam_manager.get_token()

        watsonx_headers: dict[str, str] = {
            "X-Watsonx-Project-Id": self._project_id,
        }
        if default_headers:
            watsonx_headers.update(default_headers)

        return super()._build_client(
            api_key=current_token,
            base_url=base_url,
            default_headers=watsonx_headers,
        )

    def _ensure_fresh_token(self) -> None:
        """Refresh the IAM token on the client if needed.

        Updates the client's API key (Bearer token) if the cached token
        has been refreshed since the client was built.
        """
        current_token = self._iam_manager.get_token()
        # Guard with hasattr: this may run before the parent constructor has
        # attached the client attribute.
        if hasattr(self, "client") and self.client is not None:
            self.client.api_key = current_token

    def call(
        self,
        messages: Any,
        tools: Any = None,
        callbacks: Any = None,
        available_functions: Any = None,
        from_task: Any = None,
        from_agent: Any = None,
        response_model: Any = None,
    ) -> Any:
        """Call the LLM, refreshing the IAM token if needed."""
        self._ensure_fresh_token()
        return super().call(
            messages=messages,
            tools=tools,
            callbacks=callbacks,
            available_functions=available_functions,
            from_task=from_task,
            from_agent=from_agent,
            response_model=response_model,
        )

    async def acall(
        self,
        messages: Any,
        tools: Any = None,
        callbacks: Any = None,
        available_functions: Any = None,
        from_task: Any = None,
        from_agent: Any = None,
        response_model: Any = None,
    ) -> Any:
        """Async call the LLM, refreshing the IAM token if needed."""
        self._ensure_fresh_token()
        return await super().acall(
            messages=messages,
            tools=tools,
            callbacks=callbacks,
            available_functions=available_functions,
            from_task=from_task,
            from_agent=from_agent,
            response_model=response_model,
        )

    def get_context_window_size(self) -> int:
        """Get context window size for Granite models.

        Returns:
            The context window size in tokens.
        """
        model_lower = self.model.lower()

        # Granite 4.x models have 128K context.
        if "granite-4" in model_lower:
            return 131072

        # Granite 3.x instruct models have 128K context.
        if "granite-3" in model_lower and "instruct" in model_lower:
            return 131072

        # Granite 3.x base models have 4K context.
        if "granite-3" in model_lower:
            return 4096

        # Granite code models.
        if "granite" in model_lower and "code" in model_lower:
            return 8192

        # Conservative default for unknown models.
        return 8192

    def supports_function_calling(self) -> bool:
        """Check if the model supports function calling / tool use.

        Granite 3.x instruct and 4.x models support tool use.

        Returns:
            True if the model supports function calling.
        """
        model_lower = self.model.lower()

        # Granite 4.x models support tool use.
        if "granite-4" in model_lower:
            return True

        # Granite 3.x instruct models support tool use.
        if "granite-3" in model_lower and "instruct" in model_lower:
            return True

        # Granite guardian models don't do tool use.
        if "guardian" in model_lower:
            return False

        # Default: assume no tool use for unknown models.
        return False

    def supports_multimodal(self) -> bool:
        """Check if the model supports multimodal inputs.

        Currently, Granite models are text-only.

        Returns:
            False (Granite models are text-only).
        """
        return False

    def to_config_dict(self) -> dict[str, Any]:
        """Serialize this LLM to a dict for reconstruction.

        Returns:
            Configuration dict with the model re-prefixed with "watsonx/"
            so reconstruction routes back to this provider.
        """
        config = super().to_config_dict()
        # Fix: the previous conditional had identical branches and always
        # prefixed; only add the prefix when it is not already present so
        # repeated serialize/deserialize round-trips don't stack prefixes.
        if self.model.startswith("watsonx/"):
            config["model"] = self.model
        else:
            config["model"] = f"watsonx/{self.model}"
        return config
|
||||
0
lib/crewai/tests/llms/providers/watsonx/__init__.py
Normal file
0
lib/crewai/tests/llms/providers/watsonx/__init__.py
Normal file
293
lib/crewai/tests/llms/providers/watsonx/test_completion.py
Normal file
293
lib/crewai/tests/llms/providers/watsonx/test_completion.py
Normal file
@@ -0,0 +1,293 @@
|
||||
"""Tests for IBM watsonx.ai provider."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import os
|
||||
import time
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
import httpx
|
||||
import pytest
|
||||
|
||||
|
||||
class TestIAMTokenManager:
    """Tests for the IAM token manager."""

    @staticmethod
    def _iam_response(token: str, expires_at: float) -> MagicMock:
        """Build a MagicMock mimicking a successful IAM token response."""
        resp = MagicMock()
        resp.json.return_value = {
            "access_token": token,
            "expiration": expires_at,
        }
        resp.raise_for_status = MagicMock()
        return resp

    def test_token_exchange_success(self):
        """Test successful IAM token exchange."""
        from crewai.llms.providers.watsonx.completion import _IAMTokenManager

        resp = self._iam_response("test-bearer-token", time.time() + 3600)
        resp.json.return_value["token_type"] = "Bearer"

        with patch("httpx.post", return_value=resp) as mock_post:
            token = _IAMTokenManager("test-api-key").get_token()

            assert token == "test-bearer-token"
            mock_post.assert_called_once()
            sent = mock_post.call_args[1]["data"]
            assert sent["apikey"] == "test-api-key"
            assert sent["grant_type"] == "urn:ibm:params:oauth:grant-type:apikey"

    def test_token_caching(self):
        """Test that tokens are cached and not re-fetched."""
        from crewai.llms.providers.watsonx.completion import _IAMTokenManager

        resp = self._iam_response("cached-token", time.time() + 3600)

        with patch("httpx.post", return_value=resp) as mock_post:
            manager = _IAMTokenManager("test-api-key")

            first = manager.get_token()  # first call fetches over HTTP
            second = manager.get_token()  # second call is served from cache

            assert first == second == "cached-token"
            assert mock_post.call_count == 1  # Only one HTTP call

    def test_token_refresh_on_expiry(self):
        """Test that expired tokens are refreshed."""
        from crewai.llms.providers.watsonx.completion import _IAMTokenManager

        calls = 0

        def fake_post(*args, **kwargs):
            nonlocal calls
            calls += 1
            # The first token is issued already expired; later ones last an hour.
            ttl = 0 if calls == 1 else 3600
            return self._iam_response(f"token-{calls}", time.time() + ttl)

        with patch("httpx.post", side_effect=fake_post):
            manager = _IAMTokenManager("test-api-key")

            # First call yields token-1, which is expired on arrival.
            assert manager.get_token() == "token-1"

            # Second call sees the expired token and refreshes to token-2.
            assert manager.get_token() == "token-2"
            assert calls == 2

    def test_token_exchange_http_error(self):
        """Test that HTTP errors during token exchange raise RuntimeError."""
        from crewai.llms.providers.watsonx.completion import _IAMTokenManager

        resp = MagicMock()
        resp.status_code = 401
        resp.text = "Unauthorized"
        resp.raise_for_status.side_effect = httpx.HTTPStatusError(
            "401", request=MagicMock(), response=resp
        )

        with patch("httpx.post", return_value=resp):
            manager = _IAMTokenManager("bad-api-key")
            with pytest.raises(RuntimeError, match="IBM IAM token exchange failed"):
                manager.get_token()
|
||||
|
||||
|
||||
class TestWatsonxCompletionInit:
    """Tests for WatsonxCompletion initialization."""

    @patch.dict(os.environ, {}, clear=True)
    def test_missing_api_key_raises(self):
        """Test that missing API key raises ValueError."""
        from crewai.llms.providers.watsonx.completion import WatsonxCompletion

        with pytest.raises(ValueError, match="IBM Cloud API key is required"):
            WatsonxCompletion(model="ibm/granite-4-h-small")

    @patch.dict(
        os.environ,
        {"WATSONX_API_KEY": "test-key"},
        clear=True,
    )
    def test_missing_project_id_raises(self):
        """Test that missing project ID raises ValueError."""
        from crewai.llms.providers.watsonx.completion import WatsonxCompletion

        with pytest.raises(ValueError, match="project ID is required"):
            WatsonxCompletion(model="ibm/granite-4-h-small")

    @patch.dict(os.environ, {}, clear=True)
    def test_default_region_url(self):
        """Test that default region constructs correct URL."""
        from crewai.llms.providers.watsonx.completion import WatsonxCompletion

        # _resolve_base_url is a pure static method, so it can be tested
        # directly. (The previous version built HTTP and __init__ mocks and a
        # bare instance that were never used — all dead scaffolding, removed.)
        url = WatsonxCompletion._resolve_base_url(None, "us-south")
        assert url == "https://us-south.ml.cloud.ibm.com/ml/v1"

    @patch.dict(os.environ, {}, clear=True)
    def test_resolve_base_url_custom_region(self):
        """Test URL construction with custom region."""
        from crewai.llms.providers.watsonx.completion import WatsonxCompletion

        # Env is cleared so a developer's WATSONX_URL cannot shadow the
        # region-based URL and make this test flaky.
        url = WatsonxCompletion._resolve_base_url(None, "eu-de")
        assert url == "https://eu-de.ml.cloud.ibm.com/ml/v1"

    @patch.dict(os.environ, {}, clear=True)
    def test_resolve_base_url_explicit(self):
        """Test that explicit base_url takes priority."""
        from crewai.llms.providers.watsonx.completion import WatsonxCompletion

        url = WatsonxCompletion._resolve_base_url(
            "https://custom.example.com/v1", "us-south"
        )
        assert url == "https://custom.example.com/v1"

    @patch.dict(
        os.environ,
        {"WATSONX_URL": "https://env-override.example.com/v1"},
        clear=True,
    )
    def test_resolve_base_url_env_override(self):
        """Test that WATSONX_URL env var overrides region-based URL."""
        from crewai.llms.providers.watsonx.completion import WatsonxCompletion

        url = WatsonxCompletion._resolve_base_url(None, "us-south")
        assert url == "https://env-override.example.com/v1"

    def test_resolve_region_default(self):
        """Test default region resolution."""
        from crewai.llms.providers.watsonx.completion import WatsonxCompletion

        with patch.dict(os.environ, {}, clear=True):
            region = WatsonxCompletion._resolve_region(None)
        assert region == "us-south"

    @patch.dict(os.environ, {"WATSONX_REGION": "eu-gb"}, clear=True)
    def test_resolve_region_from_env(self):
        """Test region resolution from environment variable."""
        from crewai.llms.providers.watsonx.completion import WatsonxCompletion

        assert WatsonxCompletion._resolve_region(None) == "eu-gb"

    def test_resolve_region_explicit(self):
        """Test explicit region parameter takes priority."""
        from crewai.llms.providers.watsonx.completion import WatsonxCompletion

        assert WatsonxCompletion._resolve_region("jp-tok") == "jp-tok"
|
||||
|
||||
|
||||
class TestWatsonxModelCapabilities:
    """Tests for model capability detection."""

    @staticmethod
    def _bare_completion(model: str) -> object:
        """Create a minimal WatsonxCompletion-like object for testing.

        Bypasses __init__ entirely so no credentials or network are needed;
        only the `model` attribute the capability methods read is set.
        """
        from crewai.llms.providers.watsonx.completion import WatsonxCompletion

        instance = object.__new__(WatsonxCompletion)
        instance.model = model
        return instance

    def test_granite_4_context_window(self):
        """Test Granite 4.x models report 128K context."""
        completion = self._bare_completion("ibm/granite-4-h-small")
        assert completion.get_context_window_size() == 131072

    def test_granite_3_instruct_context_window(self):
        """Test Granite 3.x instruct models report 128K context."""
        completion = self._bare_completion("ibm/granite-3-8b-instruct")
        assert completion.get_context_window_size() == 131072

    def test_granite_code_context_window(self):
        """Test Granite code models report 8K context."""
        completion = self._bare_completion("ibm/granite-8b-code-instruct")
        assert completion.get_context_window_size() == 8192

    def test_granite_4_supports_function_calling(self):
        """Test Granite 4.x models support function calling."""
        completion = self._bare_completion("ibm/granite-4-h-small")
        assert completion.supports_function_calling() is True

    def test_granite_3_instruct_supports_function_calling(self):
        """Test Granite 3.x instruct models support function calling."""
        completion = self._bare_completion("ibm/granite-3-8b-instruct")
        assert completion.supports_function_calling() is True

    def test_granite_guardian_no_function_calling(self):
        """Test Granite Guardian models don't support function calling."""
        completion = self._bare_completion("ibm/granite-guardian-3-8b")
        assert completion.supports_function_calling() is False

    def test_granite_not_multimodal(self):
        """Test Granite models are not multimodal."""
        completion = self._bare_completion("ibm/granite-4-h-small")
        assert completion.supports_multimodal() is False
|
||||
|
||||
|
||||
class TestWatsonxModelRouting:
    """Tests for model routing through the LLM factory."""

    def test_watsonx_models_in_constants(self):
        """Test that WATSONX_MODELS is properly defined."""
        from crewai.llms.constants import WATSONX_MODELS

        for model_id in (
            "ibm/granite-4-h-small",
            "ibm/granite-3-8b-instruct",
            "ibm/granite-guardian-3-8b",
        ):
            assert model_id in WATSONX_MODELS
        assert len(WATSONX_MODELS) >= 10

    def test_watsonx_in_supported_providers(self):
        """Test that watsonx is in the supported native providers list."""
        from crewai.llm import SUPPORTED_NATIVE_PROVIDERS

        assert "watsonx" in SUPPORTED_NATIVE_PROVIDERS
        assert "ibm" in SUPPORTED_NATIVE_PROVIDERS

    def test_get_native_provider_watsonx(self):
        """Test that _get_native_provider returns WatsonxCompletion."""
        from crewai.llm import LLM
        from crewai.llms.providers.watsonx.completion import WatsonxCompletion

        # Both aliases must route to the same native provider class.
        assert LLM._get_native_provider("watsonx") is WatsonxCompletion
        assert LLM._get_native_provider("ibm") is WatsonxCompletion

    def test_infer_provider_from_watsonx_model(self):
        """Test that Granite models are inferred as watsonx provider."""
        from crewai.llm import LLM

        assert LLM._infer_provider_from_model("ibm/granite-4-h-small") == "watsonx"
|
||||
Reference in New Issue
Block a user