mirror of https://github.com/crewAIInc/crewAI.git
synced 2026-04-14 23:12:37 +00:00

Compare commits: `1.13.0rc1...feature/ib` (2 commits: d3f422a121, e21c506214)

docs/docs.json (1303 lines changed; diff suppressed because it is too large)

docs/en/enterprise/features/sso.mdx (new file, 550 lines)
@@ -0,0 +1,550 @@
---
title: Single Sign-On (SSO)
icon: "key"
description: Configure enterprise SSO authentication for CrewAI Platform — SaaS and Factory
---

## Overview

CrewAI Platform supports enterprise Single Sign-On (SSO) across both **SaaS (AMP)** and **Factory (self-hosted)** deployments. SSO enables your team to authenticate using your organization's existing identity provider, enforcing centralized access control, MFA policies, and user lifecycle management.

### Supported Providers

| Provider | SaaS | Factory | Protocol | CLI Support |
|---|---|---|---|---|
| **WorkOS** | ✅ (default) | ✅ | OAuth 2.0 / OIDC | ✅ |
| **Microsoft Entra ID** (Azure AD) | ✅ (enterprise) | ✅ | OAuth 2.0 / SAML 2.0 | ✅ |
| **Okta** | ✅ (enterprise) | ✅ | OAuth 2.0 / OIDC | ✅ |
| **Auth0** | ✅ (enterprise) | ✅ | OAuth 2.0 / OIDC | ✅ |
| **Keycloak** | — | ✅ | OAuth 2.0 / OIDC | ✅ |

### Key Capabilities

- **SAML 2.0 and OAuth 2.0 / OIDC** protocol support
- **Device Authorization Grant** flow for CLI authentication
- **Role-Based Access Control (RBAC)** with custom roles and per-resource permissions
- **MFA enforcement** delegated to your identity provider
- **User provisioning** through IdP assignment (users/groups)

---
## SaaS SSO

### Default Authentication

CrewAI's managed SaaS platform (AMP) uses **WorkOS** as the default authentication provider. When you sign up at [app.crewai.com](https://app.crewai.com), authentication is handled through `login.crewai.com` — no additional SSO configuration is required.

### Enterprise Custom SSO

Enterprise SaaS customers can configure SSO with their own identity provider (Entra ID, Okta, Auth0). Contact your CrewAI account team to enable custom SSO for your organization. Once configured:

1. Your team members authenticate through your organization's IdP
2. Access control and MFA policies are enforced by your IdP
3. The CrewAI CLI automatically detects your SSO configuration via `crewai enterprise configure`

### CLI Defaults (SaaS)

| Setting | Default Value |
|---|---|
| `enterprise_base_url` | `https://app.crewai.com` |
| `oauth2_provider` | `workos` |
| `oauth2_domain` | `login.crewai.com` |

---
## Factory SSO Setup

Factory (self-hosted) deployments require you to configure SSO by setting environment variables in your Helm `values.yaml` and registering an application in your identity provider.

### Microsoft Entra ID (Azure AD)

<Steps>
<Step title="Register an Application">
1. Go to [portal.azure.com](https://portal.azure.com) → **Microsoft Entra ID** → **App registrations** → **New registration**
2. Configure:
   - **Name:** `CrewAI` (or your preferred name)
   - **Supported account types:** Accounts in this organizational directory only
   - **Redirect URI:** Select **Web**, enter `https://<your-domain>/auth/entra_id/callback`
3. Click **Register**
</Step>

<Step title="Collect Credentials">
From the app overview page, copy:
- **Application (client) ID** → `ENTRA_ID_CLIENT_ID`
- **Directory (tenant) ID** → `ENTRA_ID_TENANT_ID`
</Step>

<Step title="Create Client Secret">
1. Navigate to **Certificates & Secrets** → **New client secret**
2. Add a description and select an expiration period
3. Copy the secret value immediately (it won't be shown again) → `ENTRA_ID_CLIENT_SECRET`
</Step>

<Step title="Grant Admin Consent">
1. Go to **Enterprise applications** → select your app
2. Under **Security** → **Permissions**, click **Grant admin consent**
3. Ensure **Microsoft Graph → User.Read** is granted
</Step>

<Step title="Configure App Roles (Recommended)">
Under **App registrations** → your app → **App roles**, create:

| Display Name | Value | Allowed Member Types |
|---|---|---|
| Member | `member` | Users/Groups |
| Factory Admin | `factory-admin` | Users/Groups |

<Note>
The `member` role grants login access. The `factory-admin` role grants admin panel access. Roles are included in the JWT automatically.
</Note>
</Step>

<Step title="Assign Users">
1. Under **Properties**, set **Assignment required?** to **Yes**
2. Under **Users and groups**, assign users/groups with the appropriate role
</Step>

<Step title="Set Environment Variables">
```yaml
envVars:
  AUTH_PROVIDER: "entra_id"

secrets:
  ENTRA_ID_CLIENT_ID: "<Application (client) ID>"
  ENTRA_ID_CLIENT_SECRET: "<Client Secret>"
  ENTRA_ID_TENANT_ID: "<Directory (tenant) ID>"
```
</Step>

<Step title="Enable CLI Support (Optional)">
To allow `crewai login` via Device Authorization Grant:

1. Under **Authentication** → **Advanced settings**, enable **Allow public client flows**
2. Under **Expose an API**, add an Application ID URI (e.g., `api://crewai-cli`)
3. Add a scope (e.g., `read`) with **Admins and users** consent
4. Under **Manifest**, set `accessTokenAcceptedVersion` to `2`
5. Add environment variables:

```yaml
secrets:
  ENTRA_ID_DEVICE_AUTHORIZATION_CLIENT_ID: "<Application (client) ID>"
  ENTRA_ID_CUSTOM_OPENID_SCOPE: "<scope URI, e.g. api://crewai-cli/read>"
```
</Step>
</Steps>
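Before deploying, you can sanity-check the tenant ID: Entra ID publishes a v2.0 OIDC discovery document at a well-known URL derived from the tenant. The v2.0 endpoint corresponds to the `accessTokenAcceptedVersion: 2` manifest setting above. The helper name below is illustrative, not part of CrewAI:

```python
# Build the Entra ID v2.0 OIDC discovery URL from a tenant ID.
# Fetching this URL should return JSON whose "issuer" contains the
# tenant ID; an error response usually means a wrong ENTRA_ID_TENANT_ID.
def entra_discovery_url(tenant_id: str) -> str:
    return (
        f"https://login.microsoftonline.com/{tenant_id}"
        "/v2.0/.well-known/openid-configuration"
    )

print(entra_discovery_url("00000000-0000-0000-0000-000000000000"))
```

Requesting this URL with `curl` is a quick way to confirm the tenant exists before wiring the secrets into `values.yaml`.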
---

### Okta

<Steps>
<Step title="Create App Integration">
1. Open Okta Admin Console → **Applications** → **Create App Integration**
2. Select **OIDC - OpenID Connect** → **Web Application** → **Next**
3. Configure:
   - **App integration name:** `CrewAI SSO`
   - **Sign-in redirect URI:** `https://<your-domain>/auth/okta/callback`
   - **Sign-out redirect URI:** `https://<your-domain>`
   - **Assignments:** Choose who can access (everyone or specific groups)
4. Click **Save**
</Step>

<Step title="Collect Credentials">
From the app details page:
- **Client ID** → `OKTA_CLIENT_ID`
- **Client Secret** → `OKTA_CLIENT_SECRET`
- **Okta URL** (top-right corner, under your username) → `OKTA_SITE`
</Step>

<Step title="Configure Authorization Server">
1. Navigate to **Security** → **API**
2. Select your authorization server (default: `default`)
3. Under **Access Policies**, add a policy and rule:
   - In the rule, under **Scopes requested**, select **The following scopes** → **OIDC default scopes**
4. Note the **Name** and **Audience** of the authorization server

<Warning>
The authorization server name and audience must match `OKTA_AUTHORIZATION_SERVER` and `OKTA_AUDIENCE` exactly. Mismatches cause `401 Unauthorized` or `Invalid token: Signature verification failed` errors.
</Warning>
</Step>

<Step title="Set Environment Variables">
```yaml
envVars:
  AUTH_PROVIDER: "okta"

secrets:
  OKTA_CLIENT_ID: "<Okta app client ID>"
  OKTA_CLIENT_SECRET: "<Okta client secret>"
  OKTA_SITE: "https://your-domain.okta.com"
  OKTA_AUTHORIZATION_SERVER: "default"
  OKTA_AUDIENCE: "api://default"
```
</Step>

<Step title="Enable CLI Support (Optional)">
1. Create a **new** app integration: **OIDC** → **Native Application**
2. Enable **Device Authorization** and **Refresh Token** grant types
3. Allow everyone in your organization to access
4. Add environment variable:

```yaml
secrets:
  OKTA_DEVICE_AUTHORIZATION_CLIENT_ID: "<Native app client ID>"
```

<Note>
Device Authorization requires a **Native Application** — it cannot use the Web Application created for browser-based SSO.
</Note>
</Step>
</Steps>
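The warning above exists because Okta signs access tokens per authorization server: each server is its own issuer with its own keys, so a token minted by one server fails signature checks against another. A small sketch of how the issuer and discovery URLs derive from `OKTA_SITE` and `OKTA_AUTHORIZATION_SERVER` (helper names are ours, not CrewAI's):

```python
# Derive the issuer for a custom Okta authorization server. The token's
# "iss" claim must match this value for validation to succeed.
def okta_issuer(site: str, auth_server: str) -> str:
    return f"{site.rstrip('/')}/oauth2/{auth_server}"

# The server's OIDC metadata (signing keys, endpoints) lives under the issuer.
def okta_discovery_url(site: str, auth_server: str) -> str:
    return f"{okta_issuer(site, auth_server)}/.well-known/openid-configuration"

print(okta_discovery_url("https://your-domain.okta.com", "default"))
```

If `crewai login` fails with signature errors, comparing this computed issuer against the `iss` claim in a captured token is usually the fastest diagnosis.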
---

### Keycloak

<Steps>
<Step title="Create a Client">
1. Open Keycloak Admin Console → navigate to your realm
2. **Clients** → **Create client**:
   - **Client type:** OpenID Connect
   - **Client ID:** `crewai-factory` (suggested)
3. Capability config:
   - **Client authentication:** On
   - **Standard flow:** Checked
4. Login settings:
   - **Root URL:** `https://<your-domain>`
   - **Valid redirect URIs:** `https://<your-domain>/auth/keycloak/callback`
   - **Valid post logout redirect URIs:** `https://<your-domain>`
5. Click **Save**
</Step>

<Step title="Collect Credentials">
- **Client ID** → `KEYCLOAK_CLIENT_ID`
- Under **Credentials** tab: **Client secret** → `KEYCLOAK_CLIENT_SECRET`
- **Realm name** → `KEYCLOAK_REALM`
- **Keycloak server URL** → `KEYCLOAK_SITE`
</Step>

<Step title="Set Environment Variables">
```yaml
envVars:
  AUTH_PROVIDER: "keycloak"

secrets:
  KEYCLOAK_CLIENT_ID: "<client ID>"
  KEYCLOAK_CLIENT_SECRET: "<client secret>"
  KEYCLOAK_SITE: "https://keycloak.yourdomain.com"
  KEYCLOAK_REALM: "<realm name>"
  KEYCLOAK_AUDIENCE: "account"
  # Only set if using a custom base path (pre-v17 migrations):
  # KEYCLOAK_BASE_URL: "/auth"
```

<Note>
Keycloak includes `account` as the default audience in access tokens. For most installations, `KEYCLOAK_AUDIENCE=account` works without additional configuration. See [Keycloak audience documentation](https://www.keycloak.org/docs/latest/authorization_services/index.html) if you need a custom audience.
</Note>
</Step>

<Step title="Enable CLI Support (Optional)">
1. Create a **second** client:
   - **Client type:** OpenID Connect
   - **Client ID:** `crewai-factory-cli` (suggested)
   - **Client authentication:** Off (Device Authorization requires a public client)
   - **Authentication flow:** Check **only** OAuth 2.0 Device Authorization Grant
2. Add environment variable:

```yaml
secrets:
  KEYCLOAK_DEVICE_AUTHORIZATION_CLIENT_ID: "<CLI client ID>"
```
</Step>
</Steps>
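Keycloak's OIDC endpoints are addressed per realm, and installations migrated from pre-v17 versions keep a legacy `/auth` prefix, which is what the optional `KEYCLOAK_BASE_URL` accounts for. A sketch of how the discovery URL derives from the same three values used above (the helper is illustrative, not part of CrewAI):

```python
# Build the Keycloak OIDC discovery URL from KEYCLOAK_SITE, KEYCLOAK_REALM
# and the optional KEYCLOAK_BASE_URL prefix (e.g. "/auth" for pre-v17).
def keycloak_discovery_url(site: str, realm: str, base_path: str = "") -> str:
    base = site.rstrip("/") + base_path
    return f"{base}/realms/{realm}/.well-known/openid-configuration"

# Modern Keycloak (v17+, no base path):
print(keycloak_discovery_url("https://keycloak.yourdomain.com", "myrealm"))
# Installation migrated from pre-v17 with the legacy /auth prefix:
print(keycloak_discovery_url("https://keycloak.yourdomain.com", "myrealm", "/auth"))
```

Fetching the computed URL is a quick check that the realm name and base path in `values.yaml` are consistent with the running server.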
---

### WorkOS

<Steps>
<Step title="Configure in WorkOS Dashboard">
1. Create an application in the [WorkOS Dashboard](https://dashboard.workos.com)
2. Configure the redirect URI: `https://<your-domain>/auth/workos/callback`
3. Note the **Client ID** and **AuthKit domain**
4. Set up organizations in the WorkOS dashboard
</Step>

<Step title="Set Environment Variables">
```yaml
envVars:
  AUTH_PROVIDER: "workos"

secrets:
  WORKOS_CLIENT_ID: "<WorkOS client ID>"
  WORKOS_AUTHKIT_DOMAIN: "<your-authkit-domain.authkit.com>"
```
</Step>
</Steps>
---

### Auth0

<Steps>
<Step title="Create Application">
1. In the [Auth0 Dashboard](https://manage.auth0.com), create a new **Regular Web Application**
2. Configure:
   - **Allowed Callback URLs:** `https://<your-domain>/auth/auth0/callback`
   - **Allowed Logout URLs:** `https://<your-domain>`
3. Note the **Domain**, **Client ID**, and **Client Secret**
</Step>

<Step title="Set Environment Variables">
```yaml
envVars:
  AUTH_PROVIDER: "auth0"

secrets:
  AUTH0_CLIENT_ID: "<Auth0 client ID>"
  AUTH0_CLIENT_SECRET: "<Auth0 client secret>"
  AUTH0_DOMAIN: "<your-tenant.auth0.com>"
```
</Step>

<Step title="Enable CLI Support (Optional)">
1. Create a **Native** application in Auth0 for Device Authorization
2. Enable the **Device Authorization** grant type under application settings
3. Configure the CLI with the appropriate audience and client ID
</Step>
</Steps>
---

## CLI Authentication

The CrewAI CLI supports SSO authentication via the **Device Authorization Grant** flow. This allows developers to authenticate from their terminal without exposing credentials.

### Quick Setup

For Factory installations, the CLI can auto-configure all OAuth2 settings:

```bash
crewai enterprise configure https://your-factory-url.app
```

This command fetches the SSO configuration from your Factory instance and sets all required CLI parameters automatically.

Then authenticate:

```bash
crewai login
```

<Note>
Requires CrewAI CLI version **1.6.0** or higher for Entra ID, **0.159.0** or higher for Okta, and **1.9.0** or higher for Keycloak.
</Note>

### Manual CLI Configuration

If you need to configure the CLI manually, use `crewai config set`:

```bash
# Set the provider
crewai config set oauth2_provider okta

# Set provider-specific values
crewai config set oauth2_domain your-domain.okta.com
crewai config set oauth2_client_id your-client-id
crewai config set oauth2_audience api://default

# Set the enterprise base URL
crewai config set enterprise_base_url https://your-factory-url.app
```

### CLI Configuration Reference

| Setting | Description | Example |
|---|---|---|
| `enterprise_base_url` | Your CrewAI instance URL | `https://crewai.yourcompany.com` |
| `oauth2_provider` | Provider name | `workos`, `okta`, `auth0`, `entra_id`, `keycloak` |
| `oauth2_domain` | Provider domain | `your-domain.okta.com` |
| `oauth2_client_id` | OAuth2 client ID | `0oaqnwji7pGW7VT6T697` |
| `oauth2_audience` | API audience identifier | `api://default` |

View current configuration:

```bash
crewai config list
```

### How Device Authorization Works

1. Run `crewai login` — the CLI requests a device code from your IdP
2. A verification URL and code are displayed in your terminal
3. Your browser opens to the verification URL
4. Enter the code and authenticate with your IdP credentials
5. The CLI receives an access token and stores it locally
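Between steps 4 and 5 the client polls the token endpoint, and the Device Authorization Grant spec (RFC 8628 §3.5) defines how that polling is paced: the client waits the server-supplied `interval` between requests, keeps waiting on `authorization_pending`, and adds 5 seconds whenever the server answers `slow_down`. A minimal sketch of that scheduling rule (not CrewAI's actual implementation):

```python
# Compute the next poll delay for an RFC 8628 device-flow client:
# "authorization_pending" keeps the current interval, "slow_down" adds
# 5 seconds, and any other error code (e.g. "expired_token",
# "access_denied") is terminal.
def next_poll_interval(current: int, error_code: str) -> int:
    if error_code == "slow_down":
        return current + 5
    if error_code == "authorization_pending":
        return current
    raise ValueError(f"terminal device-flow error: {error_code}")

interval = 5  # initial interval suggested by the IdP's device endpoint
interval = next_poll_interval(interval, "authorization_pending")  # still 5
interval = next_poll_interval(interval, "slow_down")              # now 10
```

This is why a `crewai login` that appears to hang is often just waiting out the polling interval while you complete the browser step.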
---

## Role-Based Access Control (RBAC)

CrewAI Platform provides granular RBAC that integrates with your SSO provider.

### Permission Model

| Permission | Description |
|---|---|
| **Read** | View resources (dashboards, automations, logs) |
| **Write** | Create and modify resources |
| **Manage** | Full control including deletion and configuration |

### Resources

Permissions can be scoped to individual resources:

- **Usage Dashboard** — Platform usage metrics and analytics
- **Automations Dashboard** — Crew and flow management
- **Environment Variables** — Secret and configuration management
- **Individual Automations** — Per-automation access control

### Roles

- **Predefined roles** come out of the box with standard permission sets
- **Custom roles** can be created with any combination of permissions
- **Per-resource assignment** — limit specific automations to individual users or roles

### Factory Admin Access

For Factory deployments using Entra ID, admin access is controlled via App Roles:

- Assign the `factory-admin` role to users who need admin panel access
- Assign the `member` role for standard platform access
- Roles are communicated via JWT claims — no additional configuration needed after IdP setup

---
## Troubleshooting

### Invalid Redirect URI

**Symptom:** Authentication fails with a redirect URI mismatch error.

**Fix:** Ensure the redirect URI in your IdP exactly matches the expected callback URL:

| Provider | Callback URL |
|---|---|
| Entra ID | `https://<domain>/auth/entra_id/callback` |
| Okta | `https://<domain>/auth/okta/callback` |
| Keycloak | `https://<domain>/auth/keycloak/callback` |
| WorkOS | `https://<domain>/auth/workos/callback` |
| Auth0 | `https://<domain>/auth/auth0/callback` |

### CLI Login Fails (Device Authorization)

**Symptom:** `crewai login` returns an error or times out.

**Fix:**
- Verify that Device Authorization Grant is enabled in your IdP
- For Okta: ensure you have a **Native Application** (not Web) with the Device Authorization grant
- For Entra ID: ensure **Allow public client flows** is enabled
- For Keycloak: ensure the CLI client has **Client authentication: Off** and only Device Authorization Grant enabled
- Check that the `*_DEVICE_AUTHORIZATION_CLIENT_ID` environment variable is set on the server

### Token Validation Errors

**Symptom:** `Invalid token: Signature verification failed` or `401 Unauthorized` after login.

**Fix:**
- **Okta:** Verify `OKTA_AUTHORIZATION_SERVER` and `OKTA_AUDIENCE` match the authorization server's Name and Audience exactly
- **Entra ID:** Ensure `accessTokenAcceptedVersion` is set to `2` in the app manifest
- **Keycloak:** Verify `KEYCLOAK_AUDIENCE` matches the audience in your access tokens (default: `account`)
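When diagnosing audience or issuer mismatches, it helps to look at the claims actually present in the token. A JWT's payload is just base64url-encoded JSON, so it can be inspected without verifying the signature (for debugging only; the server must still verify signatures). The token in this sketch is a fabricated example for illustration:

```python
import base64
import json

# Decode a JWT's payload (the second dot-separated segment) for inspection.
# This does NOT verify the signature; use it only to read claims.
def jwt_claims(token: str) -> dict:
    payload = token.split(".")[1]
    payload += "=" * (-len(payload) % 4)  # restore stripped base64 padding
    return json.loads(base64.urlsafe_b64decode(payload))

# A hypothetical unsigned token carrying iss/aud claims:
header = base64.urlsafe_b64encode(b'{"alg":"none"}').rstrip(b"=").decode()
body = base64.urlsafe_b64encode(json.dumps(
    {"iss": "https://x.okta.com/oauth2/default", "aud": "api://default"}
).encode()).rstrip(b"=").decode()

claims = jwt_claims(f"{header}.{body}.")
print(claims["iss"], claims["aud"])  # compare against OKTA_* / KEYCLOAK_* values
```

Comparing the decoded `iss` and `aud` against your configured `OKTA_AUDIENCE` or `KEYCLOAK_AUDIENCE` usually pinpoints which side of the configuration is wrong.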
### Admin Consent Not Granted (Entra ID)

**Symptom:** Users can't log in and see a "needs admin approval" message.

**Fix:** Go to **Enterprise applications** → your app → **Permissions** → **Grant admin consent**. Ensure `User.Read` is granted for Microsoft Graph.

### 403 Forbidden After Login

**Symptom:** A user authenticates successfully but gets 403 errors.

**Fix:**
- Check that the user is assigned to the application in your IdP
- For Entra ID with **Assignment required = Yes**: ensure the user has a role assignment (Member or Factory Admin)
- For Okta: verify the user or their group is assigned under the app's **Assignments** tab

### CLI Can't Reach Factory Instance

**Symptom:** `crewai enterprise configure` fails to connect.

**Fix:**
- Verify the Factory URL is reachable from your machine
- Check that `enterprise_base_url` is set correctly: `crewai config list`
- Ensure TLS certificates are valid and trusted

---
## Environment Variables Reference

### Common

| Variable | Description |
|---|---|
| `AUTH_PROVIDER` | Authentication provider: `entra_id`, `okta`, `workos`, `auth0`, `keycloak`, `local` |

### Microsoft Entra ID

| Variable | Required | Description |
|---|---|---|
| `ENTRA_ID_CLIENT_ID` | ✅ | Application (client) ID from Azure |
| `ENTRA_ID_CLIENT_SECRET` | ✅ | Client secret from Azure |
| `ENTRA_ID_TENANT_ID` | ✅ | Directory (tenant) ID from Azure |
| `ENTRA_ID_DEVICE_AUTHORIZATION_CLIENT_ID` | CLI only | Client ID for Device Authorization Grant |
| `ENTRA_ID_CUSTOM_OPENID_SCOPE` | CLI only | Custom scope from "Expose an API" (e.g., `api://crewai-cli/read`) |

### Okta

| Variable | Required | Description |
|---|---|---|
| `OKTA_CLIENT_ID` | ✅ | Okta application client ID |
| `OKTA_CLIENT_SECRET` | ✅ | Okta client secret |
| `OKTA_SITE` | ✅ | Okta organization URL (e.g., `https://your-domain.okta.com`) |
| `OKTA_AUTHORIZATION_SERVER` | ✅ | Authorization server name (e.g., `default`) |
| `OKTA_AUDIENCE` | ✅ | Authorization server audience (e.g., `api://default`) |
| `OKTA_DEVICE_AUTHORIZATION_CLIENT_ID` | CLI only | Native app client ID for Device Authorization |

### WorkOS

| Variable | Required | Description |
|---|---|---|
| `WORKOS_CLIENT_ID` | ✅ | WorkOS application client ID |
| `WORKOS_AUTHKIT_DOMAIN` | ✅ | AuthKit domain (e.g., `your-domain.authkit.com`) |

### Auth0

| Variable | Required | Description |
|---|---|---|
| `AUTH0_CLIENT_ID` | ✅ | Auth0 application client ID |
| `AUTH0_CLIENT_SECRET` | ✅ | Auth0 client secret |
| `AUTH0_DOMAIN` | ✅ | Auth0 tenant domain (e.g., `your-tenant.auth0.com`) |

### Keycloak

| Variable | Required | Description |
|---|---|---|
| `KEYCLOAK_CLIENT_ID` | ✅ | Keycloak client ID |
| `KEYCLOAK_CLIENT_SECRET` | ✅ | Keycloak client secret |
| `KEYCLOAK_SITE` | ✅ | Keycloak server URL |
| `KEYCLOAK_REALM` | ✅ | Keycloak realm name |
| `KEYCLOAK_AUDIENCE` | ✅ | Token audience (default: `account`) |
| `KEYCLOAK_BASE_URL` | Optional | Base URL path (e.g., `/auth` for pre-v17 migrations) |
| `KEYCLOAK_DEVICE_AUTHORIZATION_CLIENT_ID` | CLI only | Public client ID for Device Authorization |

---

## Next Steps

- [Installation Guide](/installation) — Get started with CrewAI
- [Quickstart](/quickstart) — Build your first crew
- [RBAC Setup](/enterprise/features/rbac) — Detailed role and permission management
@@ -44,6 +44,7 @@ from crewai.llms.constants import (
    BEDROCK_MODELS,
    GEMINI_MODELS,
    OPENAI_MODELS,
    WATSONX_MODELS,
)
from crewai.utilities import InternalInstructor
from crewai.utilities.exceptions.context_window_exceeding_exception import (

@@ -309,6 +310,8 @@ SUPPORTED_NATIVE_PROVIDERS: Final[list[str]] = [
    "gemini",
    "bedrock",
    "aws",
    "watsonx",
    "ibm",
    # OpenAI-compatible providers
    "openrouter",
    "deepseek",

@@ -376,6 +379,8 @@ class LLM(BaseLLM):
        "gemini": "gemini",
        "bedrock": "bedrock",
        "aws": "bedrock",
        "watsonx": "watsonx",
        "ibm": "watsonx",
        # OpenAI-compatible providers
        "openrouter": "openrouter",
        "deepseek": "deepseek",

@@ -506,6 +511,12 @@ class LLM(BaseLLM):
            # OpenRouter uses org/model format but accepts anything
            return True

        if provider == "watsonx" or provider == "ibm":
            return any(
                model_lower.startswith(prefix)
                for prefix in ["ibm/granite", "granite"]
            )

        return False

    @classmethod

@@ -541,6 +552,9 @@ class LLM(BaseLLM):
            # azure does not provide a list of available models, determine a better way to handle this
            return True

        if (provider == "watsonx" or provider == "ibm") and model in WATSONX_MODELS:
            return True

        # Fallback to pattern matching for models not in constants
        return cls._matches_provider_pattern(model, provider)

@@ -573,6 +587,9 @@ class LLM(BaseLLM):
        if model in AZURE_MODELS:
            return "azure"

        if model in WATSONX_MODELS:
            return "watsonx"

        return "openai"

    @classmethod

@@ -605,6 +622,11 @@ class LLM(BaseLLM):
            return BedrockCompletion

        if provider == "watsonx" or provider == "ibm":
            from crewai.llms.providers.watsonx.completion import WatsonxCompletion

            return WatsonxCompletion

        # OpenAI-compatible providers
        openai_compatible_providers = {
            "openrouter",

@@ -568,3 +568,33 @@ BEDROCK_MODELS: list[BedrockModels] = [
    "qwen.qwen3-coder-30b-a3b-v1:0",
    "twelvelabs.pegasus-1-2-v1:0",
]


WatsonxModels: TypeAlias = Literal[
    "ibm/granite-3-2b-instruct",
    "ibm/granite-3-8b-instruct",
    "ibm/granite-3-1-2b-instruct",
    "ibm/granite-3-1-8b-instruct",
    "ibm/granite-3-1-8b-base",
    "ibm/granite-3-3-2b-instruct",
    "ibm/granite-3-3-8b-instruct",
    "ibm/granite-4-h-micro",
    "ibm/granite-4-h-tiny",
    "ibm/granite-4-h-small",
    "ibm/granite-8b-code-instruct",
    "ibm/granite-guardian-3-8b",
]
WATSONX_MODELS: list[WatsonxModels] = [
    "ibm/granite-3-2b-instruct",
    "ibm/granite-3-8b-instruct",
    "ibm/granite-3-1-2b-instruct",
    "ibm/granite-3-1-8b-instruct",
    "ibm/granite-3-1-8b-base",
    "ibm/granite-3-3-2b-instruct",
    "ibm/granite-3-3-8b-instruct",
    "ibm/granite-4-h-micro",
    "ibm/granite-4-h-tiny",
    "ibm/granite-4-h-small",
    "ibm/granite-8b-code-instruct",
    "ibm/granite-guardian-3-8b",
]
lib/crewai/src/crewai/llms/providers/watsonx/__init__.py (new file, 5 lines)
@@ -0,0 +1,5 @@
"""IBM watsonx.ai provider module."""

from crewai.llms.providers.watsonx.completion import WatsonxCompletion

__all__ = ["WatsonxCompletion"]
444
lib/crewai/src/crewai/llms/providers/watsonx/completion.py
Normal file
444
lib/crewai/src/crewai/llms/providers/watsonx/completion.py
Normal file
@@ -0,0 +1,444 @@
|
||||
"""IBM watsonx.ai provider implementation.
|
||||
|
||||
This module provides native support for IBM Granite models via the
|
||||
watsonx.ai Model Gateway, which exposes an OpenAI-compatible API.
|
||||
|
||||
Authentication uses IBM Cloud IAM token exchange: an API key is exchanged
|
||||
for a short-lived Bearer token via the IAM identity service.
|
||||
|
||||
Usage:
|
||||
llm = LLM(model="watsonx/ibm/granite-4-h-small")
|
||||
llm = LLM(model="ibm/granite-4-h-small", provider="watsonx")
|
||||
|
||||
Environment variables:
|
||||
WATSONX_API_KEY: IBM Cloud API key (required)
|
||||
WATSONX_PROJECT_ID: watsonx.ai project ID (required)
|
||||
WATSONX_REGION: IBM Cloud region (default: us-south)
|
||||
WATSONX_URL: Full base URL override (optional)
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
import os
|
||||
import threading
|
||||
import time
|
||||
from typing import Any
|
||||
|
||||
import httpx
|
||||
from openai import OpenAI
|
||||
|
||||
from crewai.llms.providers.openai.completion import OpenAICompletion
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# IBM Cloud IAM endpoint for token exchange
|
||||
_IAM_TOKEN_URL = "https://iam.cloud.ibm.com/identity/token"
|
||||
|
||||
# Default region for watsonx.ai
|
||||
_DEFAULT_REGION = "us-south"
|
||||
|
||||
# Refresh token 60 seconds before expiry to avoid race conditions
|
||||
_TOKEN_REFRESH_BUFFER_SECONDS = 60
|
||||
|
||||
# Supported watsonx.ai regions
|
||||
_SUPPORTED_REGIONS = frozenset({
|
||||
"us-south",
|
||||
"eu-de",
|
||||
"eu-gb",
|
||||
"jp-tok",
|
||||
"au-syd",
|
||||
})
|
||||
|
||||
|
||||
class _IAMTokenManager:
|
||||
"""Thread-safe IBM IAM token manager with automatic refresh.
|
||||
|
||||
Exchanges an IBM Cloud API key for a short-lived Bearer token and
|
||||
caches it, refreshing automatically when the token approaches expiry.
|
||||
"""
|
||||
|
||||
def __init__(self, api_key: str) -> None:
|
||||
self._api_key = api_key
|
||||
self._token: str | None = None
|
||||
self._expiry: float = 0.0
|
||||
self._lock = threading.Lock()
|
||||
|
||||
def get_token(self) -> str:
|
||||
"""Get a valid IAM Bearer token, refreshing if needed.
|
||||
|
||||
Returns:
|
||||
A valid Bearer token string.
|
||||
|
||||
Raises:
|
||||
RuntimeError: If the token exchange fails.
|
||||
"""
|
||||
if self._token and time.time() < self._expiry - _TOKEN_REFRESH_BUFFER_SECONDS:
|
||||
return self._token
|
||||
|
||||
with self._lock:
|
||||
# Double-check after acquiring lock
|
||||
if (
|
||||
self._token
|
||||
and time.time() < self._expiry - _TOKEN_REFRESH_BUFFER_SECONDS
|
||||
):
|
||||
return self._token
|
||||
|
||||
self._refresh_token()
|
||||
assert self._token is not None
|
||||
return self._token
|
||||
|
||||
def _refresh_token(self) -> None:
|
||||
"""Exchange API key for a new IAM token."""
|
||||
try:
|
||||
response = httpx.post(
|
||||
_IAM_TOKEN_URL,
|
||||
data={
|
||||
"grant_type": "urn:ibm:params:oauth:grant-type:apikey",
|
||||
"apikey": self._api_key,
|
||||
},
|
||||
headers={
|
||||
"Content-Type": "application/x-www-form-urlencoded",
|
||||
"Accept": "application/json",
|
||||
},
|
||||
timeout=30.0,
|
||||
)
|
||||
response.raise_for_status()
|
||||
except httpx.HTTPStatusError as e:
|
||||
raise RuntimeError(
|
||||
f"IBM IAM token exchange failed (HTTP {e.response.status_code}): "
|
||||
f"{e.response.text}"
|
||||
) from e
|
||||
except httpx.HTTPError as e:
|
||||
raise RuntimeError(
|
||||
f"IBM IAM token exchange request failed: {e}"
|
||||
) from e
|
||||
|
||||
data = response.json()
|
||||
self._token = data["access_token"]
|
||||
self._expiry = float(data["expiration"])
|
||||
logger.debug(
|
||||
"IBM IAM token refreshed, expires at %s",
|
||||
time.strftime("%Y-%m-%d %H:%M:%S", time.gmtime(self._expiry)),
|
||||
)
|
||||
|
||||
|
||||
class WatsonxCompletion(OpenAICompletion):
|
||||
"""IBM watsonx.ai completion implementation.
|
||||
|
||||
This class provides support for IBM Granite models and other foundation
|
||||
models hosted on watsonx.ai via the OpenAI-compatible Model Gateway.
|
||||
|
||||
Authentication is handled transparently via IBM Cloud IAM token exchange.
|
||||
The API key is exchanged for a Bearer token which is automatically
|
||||
refreshed when it approaches expiry.
|
||||
|
||||
Supported models include the IBM Granite family:
|
||||
- ibm/granite-4-h-small (32B hybrid)
|
||||
- ibm/granite-4-h-tiny (7B hybrid)
|
||||
- ibm/granite-4-h-micro (3B hybrid)
|
||||
- ibm/granite-3-8b-instruct
|
||||
- ibm/granite-3-3-8b-instruct
|
||||
- ibm/granite-8b-code-instruct
|
||||
- ibm/granite-guardian-3-8b
|
||||
- And other models available on watsonx.ai
|
||||
|
||||
Example:
|
||||
# Using provider prefix
|
||||
llm = LLM(model="watsonx/ibm/granite-4-h-small")
|
||||
|
||||
# Using explicit provider
|
||||
llm = LLM(model="ibm/granite-4-h-small", provider="watsonx")
|
||||
|
||||
# With custom configuration
|
||||
llm = LLM(
|
||||
model="ibm/granite-4-h-small",
|
||||
provider="watsonx",
|
||||
api_key="my-ibm-cloud-api-key",
|
||||
temperature=0.7,
|
||||
)
|
||||
"""
|
||||
|
||||
    def __init__(
        self,
        model: str,
        provider: str = "watsonx",
        api_key: str | None = None,
        base_url: str | None = None,
        project_id: str | None = None,
        region: str | None = None,
        **kwargs: Any,
    ) -> None:
        """Initialize the watsonx.ai completion client.

        Args:
            model: The model identifier (e.g., "ibm/granite-4-h-small").
            provider: The provider name (default: "watsonx").
            api_key: IBM Cloud API key. If not provided, reads from the
                WATSONX_API_KEY environment variable.
            base_url: Full base URL override for the watsonx.ai endpoint.
                If not provided, constructed from region.
            project_id: watsonx.ai project ID. If not provided, reads from
                the WATSONX_PROJECT_ID environment variable.
            region: IBM Cloud region (default: "us-south"). If not provided,
                reads from the WATSONX_REGION environment variable.
            **kwargs: Additional arguments passed to OpenAICompletion.

        Raises:
            ValueError: If required credentials are missing.
        """
        resolved_api_key = self._resolve_api_key(api_key)
        resolved_project_id = self._resolve_project_id(project_id)
        resolved_region = self._resolve_region(region)
        resolved_base_url = self._resolve_base_url(base_url, resolved_region)

        # Initialize the IAM token manager for transparent auth.
        self._iam_manager = _IAMTokenManager(resolved_api_key)
        self._project_id = resolved_project_id

        # Get an initial token for client construction.
        initial_token = self._iam_manager.get_token()

        # Pass the bearer token as api_key to the OpenAI client;
        # the OpenAI SDK sends it as "Authorization: Bearer <token>".
        super().__init__(
            model=model,
            provider=provider,
            api_key=initial_token,
            base_url=resolved_base_url,
            **kwargs,
        )

    @staticmethod
    def _resolve_api_key(api_key: str | None) -> str:
        """Resolve the IBM Cloud API key from parameter or environment.

        Args:
            api_key: Explicitly provided API key.

        Returns:
            The resolved API key.

        Raises:
            ValueError: If no API key is found.
        """
        resolved = api_key or os.getenv("WATSONX_API_KEY")
        if not resolved:
            raise ValueError(
                "IBM Cloud API key is required for watsonx.ai provider. "
                "Set the WATSONX_API_KEY environment variable or pass "
                "api_key parameter."
            )
        return resolved

    @staticmethod
    def _resolve_project_id(project_id: str | None) -> str:
        """Resolve the watsonx.ai project ID from parameter or environment.

        Args:
            project_id: Explicitly provided project ID.

        Returns:
            The resolved project ID.

        Raises:
            ValueError: If no project ID is found.
        """
        resolved = project_id or os.getenv("WATSONX_PROJECT_ID")
        if not resolved:
            raise ValueError(
                "watsonx.ai project ID is required. "
                "Set the WATSONX_PROJECT_ID environment variable or pass "
                "project_id parameter."
            )
        return resolved

    @staticmethod
    def _resolve_region(region: str | None) -> str:
        """Resolve the IBM Cloud region from parameter or environment.

        Args:
            region: Explicitly provided region.

        Returns:
            The resolved region string.
        """
        resolved = region or os.getenv("WATSONX_REGION", _DEFAULT_REGION)
        if resolved not in _SUPPORTED_REGIONS:
            logger.warning(
                "Region '%s' is not in the known supported regions: %s. "
                "Proceeding anyway in case IBM has added new regions.",
                resolved,
                ", ".join(sorted(_SUPPORTED_REGIONS)),
            )
        return resolved

    @staticmethod
    def _resolve_base_url(base_url: str | None, region: str) -> str:
        """Resolve the watsonx.ai base URL.

        Priority:
        1. Explicit base_url parameter
        2. WATSONX_URL environment variable
        3. Constructed from region

        Args:
            base_url: Explicitly provided base URL.
            region: IBM Cloud region for URL construction.

        Returns:
            The resolved base URL.
        """
        if base_url:
            return base_url.rstrip("/")

        env_url = os.getenv("WATSONX_URL")
        if env_url:
            return env_url.rstrip("/")

        return f"https://{region}.ml.cloud.ibm.com/ml/v1"

    def _build_client(
        self,
        api_key: str | None = None,
        base_url: str | None = None,
        default_headers: dict[str, str] | None = None,
    ) -> OpenAI:
        """Build the OpenAI client with watsonx-specific configuration.

        Overrides the parent method to inject the project_id header
        and ensure the IAM token is current.

        Args:
            api_key: Bearer token (from the IAM exchange).
            base_url: watsonx.ai endpoint URL.
            default_headers: Additional headers.

        Returns:
            Configured OpenAI client instance.
        """
        # Refresh the token if needed.
        current_token = self._iam_manager.get_token()

        # Merge watsonx-specific headers.
        watsonx_headers = {
            "X-Watsonx-Project-Id": self._project_id,
        }
        if default_headers:
            watsonx_headers.update(default_headers)

        return super()._build_client(
            api_key=current_token,
            base_url=base_url,
            default_headers=watsonx_headers,
        )

    def _ensure_fresh_token(self) -> None:
        """Refresh the IAM token on the client if needed.

        Updates the client's API key (Bearer token) if the cached
        token has been refreshed.
        """
        current_token = self._iam_manager.get_token()
        if hasattr(self, "client") and self.client is not None:
            self.client.api_key = current_token

    def call(
        self,
        messages,
        tools=None,
        callbacks=None,
        available_functions=None,
        from_task=None,
        from_agent=None,
        response_model=None,
    ):
        """Call the LLM, refreshing the IAM token if needed."""
        self._ensure_fresh_token()
        return super().call(
            messages=messages,
            tools=tools,
            callbacks=callbacks,
            available_functions=available_functions,
            from_task=from_task,
            from_agent=from_agent,
            response_model=response_model,
        )

    async def acall(
        self,
        messages,
        tools=None,
        callbacks=None,
        available_functions=None,
        from_task=None,
        from_agent=None,
        response_model=None,
    ):
        """Async call the LLM, refreshing the IAM token if needed."""
        self._ensure_fresh_token()
        return await super().acall(
            messages=messages,
            tools=tools,
            callbacks=callbacks,
            available_functions=available_functions,
            from_task=from_task,
            from_agent=from_agent,
            response_model=response_model,
        )

    def get_context_window_size(self) -> int:
        """Get the context window size for Granite models.

        Returns:
            The context window size in tokens.
        """
        model_lower = self.model.lower()

        # Granite 4.x models have a 128K context window.
        if "granite-4" in model_lower:
            return 131072

        # Granite 3.x instruct models have a 128K context window.
        if "granite-3" in model_lower and "instruct" in model_lower:
            return 131072

        # Granite 3.x base models have a 4K context window.
        if "granite-3" in model_lower:
            return 4096

        # Granite code models.
        if "granite" in model_lower and "code" in model_lower:
            return 8192

        # Default for unknown models.
        return 8192

    def supports_function_calling(self) -> bool:
        """Check if the model supports function calling / tool use.

        Granite 3.x instruct and 4.x models support tool use.

        Returns:
            True if the model supports function calling.
        """
        model_lower = self.model.lower()

        # Granite 4.x models support tool use.
        if "granite-4" in model_lower:
            return True

        # Granite 3.x instruct models support tool use.
        if "granite-3" in model_lower and "instruct" in model_lower:
            return True

        # Granite Guardian models do not support tool use.
        if "guardian" in model_lower:
            return False

        # Default: assume no tool use for unknown models.
        return False

    def supports_multimodal(self) -> bool:
        """Check if the model supports multimodal inputs.

        Currently, Granite models are text-only.

        Returns:
            False (Granite models are text-only).
        """
        return False

    def to_config_dict(self) -> dict[str, Any]:
        """Serialize this LLM to a dict for reconstruction.

        Returns:
            Configuration dict with watsonx-specific fields.
        """
        config = super().to_config_dict()
        # Prefix the provider exactly once so that round-tripping the
        # config does not produce "watsonx/watsonx/...".
        if self.model.startswith("watsonx/"):
            config["model"] = self.model
        else:
            config["model"] = f"watsonx/{self.model}"
        return config
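The IAM token lifecycle and endpoint resolution above reduce to two pure decisions: whether a cached bearer token is still safely short of its expiry, and which endpoint URL a region maps to. A minimal standalone sketch of that logic — the 60-second refresh margin is an assumption for illustration; the real margin constant is defined earlier in this file and may differ:

```python
import time

# Hypothetical refresh margin (seconds); the actual constant in
# completion.py may use a different value.
REFRESH_MARGIN_SECONDS = 60.0


def token_is_fresh(expiry: float, now: float) -> bool:
    """Reuse a cached token only while it is at least the margin
    short of the IAM-reported expiry timestamp."""
    return now < expiry - REFRESH_MARGIN_SECONDS


def base_url_for_region(region: str) -> str:
    """Mirror of the region -> endpoint mapping in _resolve_base_url."""
    return f"https://{region}.ml.cloud.ibm.com/ml/v1"


now = time.time()
print(token_is_fresh(now + 3600, now))  # True: an hour of validity left
print(token_is_fresh(now + 30, now))    # False: inside the margin, refresh
print(base_url_for_region("eu-de"))     # https://eu-de.ml.cloud.ibm.com/ml/v1
```

The margin exists so a token is never handed to the OpenAI client moments before IBM IAM rejects it mid-request.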
0	lib/crewai/tests/llms/providers/watsonx/__init__.py	Normal file
293	lib/crewai/tests/llms/providers/watsonx/test_completion.py	Normal file
@@ -0,0 +1,293 @@
"""Tests for the IBM watsonx.ai provider."""

from __future__ import annotations

import os
import time
from unittest.mock import MagicMock, patch

import httpx
import pytest


class TestIAMTokenManager:
    """Tests for the IAM token manager."""

    def test_token_exchange_success(self):
        """Test successful IAM token exchange."""
        from crewai.llms.providers.watsonx.completion import _IAMTokenManager

        mock_response = MagicMock()
        mock_response.json.return_value = {
            "access_token": "test-bearer-token",
            "expiration": time.time() + 3600,
            "token_type": "Bearer",
        }
        mock_response.raise_for_status = MagicMock()

        with patch("httpx.post", return_value=mock_response) as mock_post:
            manager = _IAMTokenManager("test-api-key")
            token = manager.get_token()

        assert token == "test-bearer-token"
        mock_post.assert_called_once()
        call_kwargs = mock_post.call_args
        assert call_kwargs[1]["data"]["apikey"] == "test-api-key"
        assert (
            call_kwargs[1]["data"]["grant_type"]
            == "urn:ibm:params:oauth:grant-type:apikey"
        )

    def test_token_caching(self):
        """Test that tokens are cached and not re-fetched."""
        from crewai.llms.providers.watsonx.completion import _IAMTokenManager

        mock_response = MagicMock()
        mock_response.json.return_value = {
            "access_token": "cached-token",
            "expiration": time.time() + 3600,
        }
        mock_response.raise_for_status = MagicMock()

        with patch("httpx.post", return_value=mock_response) as mock_post:
            manager = _IAMTokenManager("test-api-key")

            # First call - should fetch.
            token1 = manager.get_token()
            # Second call - should use the cache.
            token2 = manager.get_token()

        assert token1 == token2 == "cached-token"
        assert mock_post.call_count == 1  # Only one HTTP call

    def test_token_refresh_on_expiry(self):
        """Test that expired tokens are refreshed."""
        from crewai.llms.providers.watsonx.completion import _IAMTokenManager

        call_count = 0

        def mock_post(*args, **kwargs):
            nonlocal call_count
            call_count += 1
            mock_resp = MagicMock()
            mock_resp.json.return_value = {
                "access_token": f"token-{call_count}",
                "expiration": time.time() + (0 if call_count == 1 else 3600),
            }
            mock_resp.raise_for_status = MagicMock()
            return mock_resp

        with patch("httpx.post", side_effect=mock_post):
            manager = _IAMTokenManager("test-api-key")

            # First call - gets token-1, which is already expired.
            token1 = manager.get_token()
            assert token1 == "token-1"

            # Second call - token-1 is expired, so refresh to token-2.
            token2 = manager.get_token()
            assert token2 == "token-2"
            assert call_count == 2

    def test_token_exchange_http_error(self):
        """Test that HTTP errors during token exchange raise RuntimeError."""
        from crewai.llms.providers.watsonx.completion import _IAMTokenManager

        mock_response = MagicMock()
        mock_response.status_code = 401
        mock_response.text = "Unauthorized"
        mock_response.raise_for_status.side_effect = httpx.HTTPStatusError(
            "401", request=MagicMock(), response=mock_response
        )

        with patch("httpx.post", return_value=mock_response):
            manager = _IAMTokenManager("bad-api-key")
            with pytest.raises(RuntimeError, match="IBM IAM token exchange failed"):
                manager.get_token()


class TestWatsonxCompletionInit:
    """Tests for WatsonxCompletion initialization."""

    @patch.dict(os.environ, {}, clear=True)
    def test_missing_api_key_raises(self):
        """Test that a missing API key raises ValueError."""
        from crewai.llms.providers.watsonx.completion import WatsonxCompletion

        with pytest.raises(ValueError, match="IBM Cloud API key is required"):
            WatsonxCompletion(model="ibm/granite-4-h-small")

    @patch.dict(
        os.environ,
        {"WATSONX_API_KEY": "test-key"},
        clear=True,
    )
    def test_missing_project_id_raises(self):
        """Test that a missing project ID raises ValueError."""
        from crewai.llms.providers.watsonx.completion import WatsonxCompletion

        with pytest.raises(ValueError, match="project ID is required"):
            WatsonxCompletion(model="ibm/granite-4-h-small")

    @patch.dict(
        os.environ,
        {
            "WATSONX_API_KEY": "test-key",
            "WATSONX_PROJECT_ID": "test-project",
        },
        clear=True,
    )
    def test_default_region_url(self):
        """Test that the default region constructs the correct URL."""
        from crewai.llms.providers.watsonx.completion import WatsonxCompletion

        # Exercise the static URL construction directly; full client
        # construction (IAM exchange, OpenAI client) is covered elsewhere
        # and would require network mocks here.
        url = WatsonxCompletion._resolve_base_url(None, "us-south")
        assert url == "https://us-south.ml.cloud.ibm.com/ml/v1"

    def test_resolve_base_url_custom_region(self):
        """Test URL construction with a custom region."""
        from crewai.llms.providers.watsonx.completion import WatsonxCompletion

        url = WatsonxCompletion._resolve_base_url(None, "eu-de")
        assert url == "https://eu-de.ml.cloud.ibm.com/ml/v1"

    def test_resolve_base_url_explicit(self):
        """Test that an explicit base_url takes priority."""
        from crewai.llms.providers.watsonx.completion import WatsonxCompletion

        url = WatsonxCompletion._resolve_base_url(
            "https://custom.example.com/v1", "us-south"
        )
        assert url == "https://custom.example.com/v1"

    @patch.dict(
        os.environ,
        {"WATSONX_URL": "https://env-override.example.com/v1"},
        clear=True,
    )
    def test_resolve_base_url_env_override(self):
        """Test that the WATSONX_URL env var overrides the region-based URL."""
        from crewai.llms.providers.watsonx.completion import WatsonxCompletion

        url = WatsonxCompletion._resolve_base_url(None, "us-south")
        assert url == "https://env-override.example.com/v1"

    def test_resolve_region_default(self):
        """Test default region resolution."""
        from crewai.llms.providers.watsonx.completion import WatsonxCompletion

        with patch.dict(os.environ, {}, clear=True):
            region = WatsonxCompletion._resolve_region(None)
            assert region == "us-south"

    @patch.dict(os.environ, {"WATSONX_REGION": "eu-gb"}, clear=True)
    def test_resolve_region_from_env(self):
        """Test region resolution from an environment variable."""
        from crewai.llms.providers.watsonx.completion import WatsonxCompletion

        region = WatsonxCompletion._resolve_region(None)
        assert region == "eu-gb"

    def test_resolve_region_explicit(self):
        """Test that an explicit region parameter takes priority."""
        from crewai.llms.providers.watsonx.completion import WatsonxCompletion

        region = WatsonxCompletion._resolve_region("jp-tok")
        assert region == "jp-tok"


class TestWatsonxModelCapabilities:
    """Tests for model capability detection."""

    def _make_completion(self, model: str) -> object:
        """Create a minimal WatsonxCompletion-like object for testing."""
        from crewai.llms.providers.watsonx.completion import WatsonxCompletion

        # Create a bare instance without calling __init__.
        obj = object.__new__(WatsonxCompletion)
        obj.model = model
        return obj

    def test_granite_4_context_window(self):
        """Test that Granite 4.x models report a 128K context window."""
        comp = self._make_completion("ibm/granite-4-h-small")
        assert comp.get_context_window_size() == 131072

    def test_granite_3_instruct_context_window(self):
        """Test that Granite 3.x instruct models report a 128K context window."""
        comp = self._make_completion("ibm/granite-3-8b-instruct")
        assert comp.get_context_window_size() == 131072

    def test_granite_code_context_window(self):
        """Test that Granite code models report an 8K context window."""
        comp = self._make_completion("ibm/granite-8b-code-instruct")
        assert comp.get_context_window_size() == 8192

    def test_granite_4_supports_function_calling(self):
        """Test that Granite 4.x models support function calling."""
        comp = self._make_completion("ibm/granite-4-h-small")
        assert comp.supports_function_calling() is True

    def test_granite_3_instruct_supports_function_calling(self):
        """Test that Granite 3.x instruct models support function calling."""
        comp = self._make_completion("ibm/granite-3-8b-instruct")
        assert comp.supports_function_calling() is True

    def test_granite_guardian_no_function_calling(self):
        """Test that Granite Guardian models do not support function calling."""
        comp = self._make_completion("ibm/granite-guardian-3-8b")
        assert comp.supports_function_calling() is False

    def test_granite_not_multimodal(self):
        """Test that Granite models are not multimodal."""
        comp = self._make_completion("ibm/granite-4-h-small")
        assert comp.supports_multimodal() is False


class TestWatsonxModelRouting:
    """Tests for model routing through the LLM factory."""

    def test_watsonx_models_in_constants(self):
        """Test that WATSONX_MODELS is properly defined."""
        from crewai.llms.constants import WATSONX_MODELS

        assert "ibm/granite-4-h-small" in WATSONX_MODELS
        assert "ibm/granite-3-8b-instruct" in WATSONX_MODELS
        assert "ibm/granite-guardian-3-8b" in WATSONX_MODELS
        assert len(WATSONX_MODELS) >= 10

    def test_watsonx_in_supported_providers(self):
        """Test that watsonx is in the supported native providers list."""
        from crewai.llm import SUPPORTED_NATIVE_PROVIDERS

        assert "watsonx" in SUPPORTED_NATIVE_PROVIDERS
        assert "ibm" in SUPPORTED_NATIVE_PROVIDERS

    def test_get_native_provider_watsonx(self):
        """Test that _get_native_provider returns WatsonxCompletion."""
        from crewai.llm import LLM
        from crewai.llms.providers.watsonx.completion import WatsonxCompletion

        assert LLM._get_native_provider("watsonx") is WatsonxCompletion
        assert LLM._get_native_provider("ibm") is WatsonxCompletion

    def test_infer_provider_from_watsonx_model(self):
        """Test that Granite models are inferred as the watsonx provider."""
        from crewai.llm import LLM

        assert LLM._infer_provider_from_model("ibm/granite-4-h-small") == "watsonx"
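The capability expectations exercised by the tests above come down to substring rules over the model identifier. A standalone sketch of those rules — mirroring, not importing, the methods under test, and matching the same fixtures:

```python
def context_window(model: str) -> int:
    """Context-window rules matching the expectations in the tests above."""
    m = model.lower()
    if "granite-4" in m:
        return 131072  # Granite 4.x: 128K
    if "granite-3" in m and "instruct" in m:
        return 131072  # Granite 3.x instruct: 128K
    if "granite-3" in m:
        return 4096    # Granite 3.x base: 4K
    if "granite" in m and "code" in m:
        return 8192    # Granite code models
    return 8192        # conservative default for unknown models


def supports_tools(model: str) -> bool:
    """Tool-use rules: Granite 4.x and 3.x instruct only, never Guardian."""
    m = model.lower()
    if "guardian" in m:
        return False
    return "granite-4" in m or ("granite-3" in m and "instruct" in m)


print(context_window("ibm/granite-4-h-small"))        # 131072
print(context_window("ibm/granite-8b-code-instruct"))  # 8192
print(supports_tools("ibm/granite-guardian-3-8b"))     # False
```

Note that "ibm/granite-8b-code-instruct" does not contain the literal substring "granite-3", so it falls through to the code-model branch rather than the instruct branch; the branch order in the real methods depends on that.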