Compare commits

..

13 Commits

Author SHA1 Message Date
Greyson LaLonde
27bd16fd0c fix: use A2UI_EXTENSION_URI for catalog label 2026-03-14 21:28:45 -04:00
Greyson LaLonde
6cc0abbdc6 fix: use typing_extensions.TypedDict for Pydantic compatibility on Python < 3.12 2026-03-14 19:19:21 -04:00
Greyson Lalonde
bc20b96538 refactor: add Field descriptions, explicit defaults, and ConfigDict to A2UI models 2026-03-14 17:25:58 -04:00
Greyson Lalonde
d2e74fc0be refactor: improve typing and use instance state in A2UI client extension 2026-03-14 16:38:05 -04:00
Greyson Lalonde
afb6cbbb6e refactor: add Field descriptions, explicit defaults, and ConfigDict to A2UI catalog models 2026-03-14 15:25:20 -04:00
Greyson Lalonde
e63c6310ed style: consolidate catalog imports in A2UI conformance tests 2026-03-14 13:24:07 -04:00
Greyson LaLonde
aa01d3130f Merge branch 'main' into gl/feat/a2ui-extension 2026-03-14 03:07:44 -04:00
Greyson Lalonde
e6fa652364 fix: replace getattr with direct metadata access on DataPart 2026-03-14 02:55:35 -04:00
Greyson Lalonde
d653a40394 docs: register A2UI guide in docs navigation 2026-03-14 02:48:09 -04:00
Greyson Lalonde
b8f47aaa6b test: add A2UI schema conformance tests 2026-03-14 02:47:09 -04:00
Greyson Lalonde
4ee32fc7d1 feat: add A2UI content type to A2A utils 2026-03-14 02:43:49 -04:00
Greyson Lalonde
c8d776f62a docs: add A2UI extension guide 2026-03-14 02:42:36 -04:00
Greyson Lalonde
b2a2902f00 feat: add A2UI v0.8 extension for declarative UI generation 2026-03-14 02:41:48 -04:00
144 changed files with 4099 additions and 5926 deletions

View File

@@ -55,7 +55,6 @@ jobs:
echo "${{ steps.changed-files.outputs.files }}" \
| tr ' ' '\n' \
| grep -v 'src/crewai/cli/templates/' \
| grep -v 'src/crewai_cli/templates/' \
| grep -v '/tests/' \
| xargs -I{} uv run ruff check "{}"

View File

@@ -19,7 +19,7 @@ repos:
language: system
pass_filenames: true
types: [python]
exclude: ^(lib/crewai/src/crewai/cli/templates/|lib/cli/|lib/crewai/tests/|lib/crewai-tools/tests/|lib/crewai-files/tests/)
exclude: ^(lib/crewai/src/crewai/cli/templates/|lib/crewai/tests/|lib/crewai-tools/tests/|lib/crewai-files/tests/)
- repo: https://github.com/astral-sh/uv-pre-commit
rev: 0.9.3
hooks:

View File

@@ -226,7 +226,7 @@ def vcr_cassette_dir(request: Any) -> str:
for parent in test_file.parents:
if (
parent.name in ("crewai", "crewai-tools", "crewai-files", "cli")
parent.name in ("crewai", "crewai-tools", "crewai-files")
and parent.parent.name == "lib"
):
package_root = parent

View File

@@ -351,7 +351,9 @@
"en/learn/using-annotations",
"en/learn/execution-hooks",
"en/learn/llm-hooks",
"en/learn/tool-hooks"
"en/learn/tool-hooks",
"en/learn/a2a-agent-delegation",
"en/learn/a2ui"
]
},
{
@@ -810,7 +812,9 @@
"en/learn/using-annotations",
"en/learn/execution-hooks",
"en/learn/llm-hooks",
"en/learn/tool-hooks"
"en/learn/tool-hooks",
"en/learn/a2a-agent-delegation",
"en/learn/a2ui"
]
},
{

344
docs/en/learn/a2ui.mdx Normal file
View File

@@ -0,0 +1,344 @@
---
title: Agent-to-UI (A2UI) Protocol
description: Enable agents to generate declarative UI surfaces for rich client rendering via the A2UI extension.
icon: window-restore
mode: "wide"
---
## A2UI Overview
A2UI is a declarative UI protocol extension for [A2A](/en/learn/a2a-agent-delegation) that lets agents emit structured JSON messages describing interactive surfaces. Clients receive these messages and render them as rich UI components — forms, cards, lists, modals, and more — without the agent needing to know anything about the client's rendering stack.
A2UI is built on the A2A extension mechanism and identified by the URI `https://a2ui.org/a2a-extension/a2ui/v0.8`.
<Note>
A2UI requires the `a2a-sdk` package. Install with: `uv add 'crewai[a2a]'` or `pip install 'crewai[a2a]'`
</Note>
## How It Works
1. The **server extension** scans agent output for A2UI JSON objects
2. Valid messages are wrapped as `DataPart` entries with the `application/json+a2ui` MIME type
3. The **client extension** augments the agent's system prompt with A2UI instructions and the component catalog
4. The client tracks surface state (active surfaces and data models) across conversation turns
## Server Setup
Add `A2UIServerExtension` to your `A2AServerConfig` to enable A2UI output:
```python Code
from crewai import Agent
from crewai.a2a import A2AServerConfig
from crewai.a2a.extensions.a2ui import A2UIServerExtension
agent = Agent(
role="Dashboard Agent",
goal="Present data through interactive UI surfaces",
backstory="Expert at building clear, actionable dashboards",
llm="gpt-4o",
a2a=A2AServerConfig(
url="https://your-server.com",
extensions=[A2UIServerExtension()],
),
)
```
### Server Extension Options
<ParamField path="catalog_ids" type="list[str] | None" default="None">
Component catalog identifiers the server supports. When set, only these catalogs are advertised to clients.
</ParamField>
<ParamField path="accept_inline_catalogs" type="bool" default="False">
Whether to accept inline catalog definitions from clients in addition to named catalogs.
</ParamField>
## Client Setup
Add `A2UIClientExtension` to your `A2AClientConfig` to enable A2UI rendering:
```python Code
from crewai import Agent
from crewai.a2a import A2AClientConfig
from crewai.a2a.extensions.a2ui import A2UIClientExtension
agent = Agent(
role="UI Coordinator",
goal="Coordinate tasks and render agent responses as rich UI",
backstory="Expert at presenting agent output in interactive formats",
llm="gpt-4o",
a2a=A2AClientConfig(
endpoint="https://dashboard-agent.example.com/.well-known/agent-card.json",
extensions=[A2UIClientExtension()],
),
)
```
### Client Extension Options
<ParamField path="catalog_id" type="str | None" default="None">
Preferred component catalog identifier. Defaults to `"standard (v0.8)"` when not set.
</ParamField>
<ParamField path="allowed_components" type="list[str] | None" default="None">
Restrict which components the agent may use. When `None`, all 18 standard catalog components are available.
</ParamField>
## Message Types
A2UI defines four server-to-client message types. Each message targets a **surface** identified by `surfaceId`.
<Tabs>
<Tab title="beginRendering">
Initializes a new surface with a root component and optional styles.
```json
{
"beginRendering": {
"surfaceId": "dashboard-1",
"root": "main-column",
"catalogId": "standard (v0.8)",
"styles": {
"primaryColor": "#EB6658"
}
}
}
```
</Tab>
<Tab title="surfaceUpdate">
Sends or updates one or more components on an existing surface.
```json
{
"surfaceUpdate": {
"surfaceId": "dashboard-1",
"components": [
{
"id": "main-column",
"component": {
"Column": {
"children": { "explicitList": ["title", "content"] },
"alignment": "start"
}
}
},
{
"id": "title",
"component": {
"Text": {
"text": { "literalString": "Dashboard" },
"usageHint": "h1"
}
}
}
]
}
}
```
</Tab>
<Tab title="dataModelUpdate">
Updates the data model bound to a surface, enabling dynamic content.
```json
{
"dataModelUpdate": {
"surfaceId": "dashboard-1",
"path": "/data/model",
"contents": [
{
"key": "userName",
"valueString": "Alice"
},
{
"key": "score",
"valueNumber": 42
}
]
}
}
```
</Tab>
<Tab title="deleteSurface">
Removes a surface and all its components.
```json
{
"deleteSurface": {
"surfaceId": "dashboard-1"
}
}
```
</Tab>
</Tabs>
## Component Catalog
A2UI ships with 18 standard components organized into three categories:
### Content
| Component | Description | Required Fields |
|-----------|-------------|-----------------|
| **Text** | Renders text with optional heading/body hints | `text` (StringBinding) |
| **Image** | Displays an image with fit and size options | `url` (StringBinding) |
| **Icon** | Renders a named icon from a set of 47 icons | `name` (IconBinding) |
| **Video** | Embeds a video player | `url` (StringBinding) |
| **AudioPlayer** | Embeds an audio player with optional description | `url` (StringBinding) |
### Layout
| Component | Description | Required Fields |
|-----------|-------------|-----------------|
| **Row** | Horizontal flex container | `children` (ChildrenDef) |
| **Column** | Vertical flex container | `children` (ChildrenDef) |
| **List** | Scrollable list (vertical or horizontal) | `children` (ChildrenDef) |
| **Card** | Elevated container for a single child | `child` (str) |
| **Tabs** | Tabbed container | `tabItems` (list of TabItem) |
| **Divider** | Visual separator (horizontal or vertical) | — |
| **Modal** | Overlay triggered by an entry point | `entryPointChild`, `contentChild` (str) |
### Interactive
| Component | Description | Required Fields |
|-----------|-------------|-----------------|
| **Button** | Clickable button that triggers an action | `child` (str), `action` (Action) |
| **CheckBox** | Boolean toggle with a label | `label` (StringBinding), `value` (BooleanBinding) |
| **TextField** | Text input with type and validation options | `label` (StringBinding) |
| **DateTimeInput** | Date and/or time picker | `value` (StringBinding) |
| **MultipleChoice** | Selection from a list of options | `selections` (ArrayBinding), `options` (list) |
| **Slider** | Numeric range slider | `value` (NumberBinding) |
## Data Binding
Components reference values through **bindings** rather than raw literals. This allows surfaces to update dynamically when the data model changes.
There are two ways to bind a value:
- **Literal values** — hardcoded directly in the component definition
- **Path references** — point to a key in the surface's data model
```json
{
"surfaceUpdate": {
"surfaceId": "profile-1",
"components": [
{
"id": "greeting",
"component": {
"Text": {
"text": { "path": "/data/model/userName" },
"usageHint": "h2"
}
}
},
{
"id": "status",
"component": {
"Text": {
"text": { "literalString": "Online" },
"usageHint": "caption"
}
}
}
]
}
}
```
In this example, `greeting` reads the user's name from the data model (updated via `dataModelUpdate`), while `status` uses a hardcoded literal.
## Handling User Actions
Interactive components like `Button` trigger `userAction` events that flow back to the server. Each action includes a `name`, the originating `surfaceId` and `sourceComponentId`, and an optional `context` with key-value pairs.
```json
{
"userAction": {
"name": "submitForm",
"surfaceId": "form-1",
"sourceComponentId": "submit-btn",
"timestamp": "2026-03-12T10:00:00Z",
"context": {
"selectedOption": "optionA"
}
}
}
```
Action context values can also use path bindings to send current data model values back to the server:
```json
{
"Button": {
"child": "confirm-label",
"action": {
"name": "confirm",
"context": [
{
"key": "currentScore",
"value": { "path": "/data/model/score" }
}
]
}
}
}
```
## Validation
Use `validate_a2ui_message` to validate server-to-client messages and `validate_a2ui_event` for client-to-server events:
```python Code
from crewai.a2a.extensions.a2ui import validate_a2ui_message
from crewai.a2a.extensions.a2ui.validator import (
validate_a2ui_event,
A2UIValidationError,
)
# Validate a server message
try:
msg = validate_a2ui_message({"beginRendering": {"surfaceId": "s1", "root": "r1"}})
except A2UIValidationError as exc:
print(exc.errors)
# Validate a client event
try:
event = validate_a2ui_event({
"userAction": {
"name": "click",
"surfaceId": "s1",
"sourceComponentId": "btn-1",
"timestamp": "2026-03-12T10:00:00Z",
}
})
except A2UIValidationError as exc:
print(exc.errors)
```
## Best Practices
<CardGroup cols={2}>
<Card title="Start Simple" icon="play">
Begin with a `beginRendering` message and a single `surfaceUpdate`. Add data binding and interactivity once the basic flow works.
</Card>
<Card title="Use Data Binding for Dynamic Content" icon="arrows-rotate">
Prefer path bindings over literal values for content that changes. Use `dataModelUpdate` to push new values without resending the full component tree.
</Card>
<Card title="Filter Components" icon="filter">
Use the `allowed_components` option on `A2UIClientExtension` to restrict which components the agent may emit, reducing prompt size and keeping output predictable.
</Card>
<Card title="Validate Messages" icon="check">
Use `validate_a2ui_message` and `validate_a2ui_event` to catch malformed payloads early, especially when building custom integrations.
</Card>
</CardGroup>
## Learn More
- [A2A Agent Delegation](/en/learn/a2a-agent-delegation) — configure agents for remote delegation via the A2A protocol
- [A2A Protocol Documentation](https://a2a-protocol.org) — official protocol specification

View File

@@ -1,15 +0,0 @@
# crewai-cli
CLI for CrewAI - scaffold, run, deploy and manage AI agent crews without installing the full framework.
## Installation
```bash
pip install crewai-cli
```
Or install alongside the full framework:
```bash
pip install crewai[cli]
```

View File

@@ -1,39 +0,0 @@
[project]
name = "crewai-cli"
version = "1.10.0"
description = "CLI for CrewAI - scaffold, run, deploy and manage AI agent crews without installing the full framework."
readme = "README.md"
authors = [
{ name = "Joao Moura", email = "joao@crewai.com" }
]
requires-python = ">=3.10, <3.14"
dependencies = [
"click~=8.1.7",
"pydantic~=2.11.9",
"pydantic-settings~=2.10.1",
"appdirs~=1.4.4",
"httpx~=0.28.1",
"pyjwt>=2.9.0,<3",
"rich>=13.7.1",
"tomli~=2.0.2",
"tomli-w~=1.1.0",
"packaging>=23.0",
"python-dotenv~=1.1.1",
"uv~=0.9.13",
"portalocker~=2.7.0",
]
[project.urls]
Homepage = "https://crewai.com"
Documentation = "https://docs.crewai.com"
Repository = "https://github.com/crewAIInc/crewAI"
[project.scripts]
crewai = "crewai_cli.cli:crewai"
[build-system]
requires = ["hatchling"]
build-backend = "hatchling.build"
[tool.hatch.build.targets.wheel]
packages = ["src/crewai_cli"]

View File

@@ -1 +0,0 @@
__version__ = "1.10.0"

View File

@@ -1,4 +0,0 @@
from crewai_cli.authentication.main import AuthenticationCommand
__all__ = ["AuthenticationCommand"]

View File

@@ -1 +0,0 @@
ALGORITHMS = ["RS256"]

View File

@@ -1,215 +0,0 @@
import time
from typing import TYPE_CHECKING, Any, TypeVar, cast
import webbrowser
import httpx
from pydantic import BaseModel, Field
from rich.console import Console
from crewai_cli.authentication.utils import validate_jwt_token
from crewai_cli.config import Settings
from crewai_cli.shared.token_manager import TokenManager
console = Console()
TOauth2Settings = TypeVar("TOauth2Settings", bound="Oauth2Settings")
class Oauth2Settings(BaseModel):
provider: str = Field(
description="OAuth2 provider used for authentication (e.g., workos, okta, auth0)."
)
client_id: str = Field(
description="OAuth2 client ID issued by the provider, used during authentication requests."
)
domain: str = Field(
description="OAuth2 provider's domain (e.g., your-org.auth0.com) used for issuing tokens."
)
audience: str | None = Field(
description="OAuth2 audience value, typically used to identify the target API or resource.",
default=None,
)
extra: dict[str, Any] = Field(
description="Extra configuration for the OAuth2 provider.",
default={},
)
@classmethod
def from_settings(cls: type[TOauth2Settings]) -> TOauth2Settings:
"""Create an Oauth2Settings instance from the CLI settings."""
settings = Settings()
return cls(
provider=settings.oauth2_provider,
domain=settings.oauth2_domain,
client_id=settings.oauth2_client_id,
audience=settings.oauth2_audience,
extra=settings.oauth2_extra,
)
if TYPE_CHECKING:
from crewai_cli.authentication.providers.base_provider import BaseProvider
class ProviderFactory:
@classmethod
def from_settings(
cls: type["ProviderFactory"], # noqa: UP037
settings: Oauth2Settings | None = None,
) -> "BaseProvider": # noqa: UP037
settings = settings or Oauth2Settings.from_settings()
import importlib
module = importlib.import_module(
f"crewai_cli.authentication.providers.{settings.provider.lower()}"
)
# Converts from snake_case to CamelCase to obtain the provider class name.
provider = getattr(
module,
f"{''.join(word.capitalize() for word in settings.provider.split('_'))}Provider",
)
return cast("BaseProvider", provider(settings))
class AuthenticationCommand:
def __init__(self) -> None:
self.token_manager = TokenManager()
self.oauth2_provider = ProviderFactory.from_settings()
def login(self) -> None:
"""Sign up to CrewAI+"""
console.print("Signing in to CrewAI AMP...\n", style="bold blue")
device_code_data = self._get_device_code()
self._display_auth_instructions(device_code_data)
return self._poll_for_token(device_code_data)
def _get_device_code(self) -> dict[str, Any]:
"""Get the device code to authenticate the user."""
device_code_payload = {
"client_id": self.oauth2_provider.get_client_id(),
"scope": " ".join(self.oauth2_provider.get_oauth_scopes()),
"audience": self.oauth2_provider.get_audience(),
}
response = httpx.post(
url=self.oauth2_provider.get_authorize_url(),
data=device_code_payload,
timeout=20,
)
response.raise_for_status()
return cast(dict[str, Any], response.json())
def _display_auth_instructions(self, device_code_data: dict[str, str]) -> None:
"""Display the authentication instructions to the user."""
verification_uri = device_code_data.get(
"verification_uri_complete", device_code_data.get("verification_uri", "")
)
console.print("1. Navigate to: ", verification_uri)
console.print("2. Enter the following code: ", device_code_data["user_code"])
webbrowser.open(verification_uri)
def _poll_for_token(self, device_code_data: dict[str, Any]) -> None:
"""Polls the server for the token until it is received, or max attempts are reached."""
token_payload = {
"grant_type": "urn:ietf:params:oauth:grant-type:device_code",
"device_code": device_code_data["device_code"],
"client_id": self.oauth2_provider.get_client_id(),
}
console.print("\nWaiting for authentication... ", style="bold blue", end="")
attempts = 0
while True and attempts < 10:
response = httpx.post(
self.oauth2_provider.get_token_url(), data=token_payload, timeout=30
)
token_data = response.json()
if response.status_code == 200:
self._validate_and_save_token(token_data)
console.print(
"Success!",
style="bold green",
)
self._login_to_tool_repository()
console.print("\n[bold green]Welcome to CrewAI AMP![/bold green]\n")
return
if token_data["error"] not in ("authorization_pending", "slow_down"):
raise httpx.HTTPError(
token_data.get("error_description") or token_data.get("error")
)
time.sleep(device_code_data["interval"])
attempts += 1
console.print(
"Timeout: Failed to get the token. Please try again.", style="bold red"
)
def _validate_and_save_token(self, token_data: dict[str, Any]) -> None:
"""Validates the JWT token and saves the token to the token manager."""
jwt_token = token_data["access_token"]
issuer = self.oauth2_provider.get_issuer()
jwt_token_data = {
"jwt_token": jwt_token,
"jwks_url": self.oauth2_provider.get_jwks_url(),
"issuer": issuer,
"audience": self.oauth2_provider.get_audience(),
}
decoded_token = validate_jwt_token(**jwt_token_data)
expires_at = decoded_token.get("exp", 0)
self.token_manager.save_tokens(jwt_token, expires_at)
def _login_to_tool_repository(self) -> None:
"""Login to the tool repository."""
from crewai_cli.tools.main import ToolCommand
try:
console.print(
"Now logging you in to the Tool Repository... ",
style="bold blue",
end="",
)
ToolCommand().login()
console.print(
"Success!\n",
style="bold green",
)
settings = Settings()
console.print(
f"You are now authenticated to the tool repository for organization [bold cyan]'{settings.org_name if settings.org_name else settings.org_uuid}'[/bold cyan]",
style="green",
)
except Exception:
console.print(
"\n[bold yellow]Warning:[/bold yellow] Authentication with the Tool Repository failed.",
style="yellow",
)
console.print(
"Other features will work normally, but you may experience limitations "
"with downloading and publishing tools."
"\nRun [bold]crewai login[/bold] to try logging in again.\n",
style="yellow",
)

View File

@@ -1,34 +0,0 @@
from crewai_cli.authentication.providers.base_provider import BaseProvider
class Auth0Provider(BaseProvider):
def get_authorize_url(self) -> str:
return f"https://{self._get_domain()}/oauth/device/code"
def get_token_url(self) -> str:
return f"https://{self._get_domain()}/oauth/token"
def get_jwks_url(self) -> str:
return f"https://{self._get_domain()}/.well-known/jwks.json"
def get_issuer(self) -> str:
return f"https://{self._get_domain()}/"
def get_audience(self) -> str:
if self.settings.audience is None:
raise ValueError(
"Audience is required. Please set it in the configuration."
)
return self.settings.audience
def get_client_id(self) -> str:
if self.settings.client_id is None:
raise ValueError(
"Client ID is required. Please set it in the configuration."
)
return self.settings.client_id
def _get_domain(self) -> str:
if self.settings.domain is None:
raise ValueError("Domain is required. Please set it in the configuration.")
return self.settings.domain

View File

@@ -1,33 +0,0 @@
from abc import ABC, abstractmethod
from crewai_cli.authentication.main import Oauth2Settings
class BaseProvider(ABC):
def __init__(self, settings: Oauth2Settings):
self.settings = settings
@abstractmethod
def get_authorize_url(self) -> str: ...
@abstractmethod
def get_token_url(self) -> str: ...
@abstractmethod
def get_jwks_url(self) -> str: ...
@abstractmethod
def get_issuer(self) -> str: ...
@abstractmethod
def get_audience(self) -> str: ...
@abstractmethod
def get_client_id(self) -> str: ...
def get_required_fields(self) -> list[str]:
"""Returns which provider-specific fields inside the "extra" dict will be required"""
return []
def get_oauth_scopes(self) -> list[str]:
return ["openid", "profile", "email"]

View File

@@ -1,43 +0,0 @@
from typing import cast
from crewai_cli.authentication.providers.base_provider import BaseProvider
class EntraIdProvider(BaseProvider):
def get_authorize_url(self) -> str:
return f"{self._base_url()}/oauth2/v2.0/devicecode"
def get_token_url(self) -> str:
return f"{self._base_url()}/oauth2/v2.0/token"
def get_jwks_url(self) -> str:
return f"{self._base_url()}/discovery/v2.0/keys"
def get_issuer(self) -> str:
return f"{self._base_url()}/v2.0"
def get_audience(self) -> str:
if self.settings.audience is None:
raise ValueError(
"Audience is required. Please set it in the configuration."
)
return self.settings.audience
def get_client_id(self) -> str:
if self.settings.client_id is None:
raise ValueError(
"Client ID is required. Please set it in the configuration."
)
return self.settings.client_id
def get_oauth_scopes(self) -> list[str]:
return [
*super().get_oauth_scopes(),
*cast(str, self.settings.extra.get("scope", "")).split(),
]
def get_required_fields(self) -> list[str]:
return ["scope"]
def _base_url(self) -> str:
return f"https://login.microsoftonline.com/{self.settings.domain}"

View File

@@ -1,32 +0,0 @@
from crewai_cli.authentication.providers.base_provider import BaseProvider
class KeycloakProvider(BaseProvider):
def get_authorize_url(self) -> str:
return f"{self._oauth2_base_url()}/realms/{self.settings.extra.get('realm')}/protocol/openid-connect/auth/device"
def get_token_url(self) -> str:
return f"{self._oauth2_base_url()}/realms/{self.settings.extra.get('realm')}/protocol/openid-connect/token"
def get_jwks_url(self) -> str:
return f"{self._oauth2_base_url()}/realms/{self.settings.extra.get('realm')}/protocol/openid-connect/certs"
def get_issuer(self) -> str:
return f"{self._oauth2_base_url()}/realms/{self.settings.extra.get('realm')}"
def get_audience(self) -> str:
return self.settings.audience or "no-audience-provided"
def get_client_id(self) -> str:
if self.settings.client_id is None:
raise ValueError(
"Client ID is required. Please set it in the configuration."
)
return self.settings.client_id
def get_required_fields(self) -> list[str]:
return ["realm"]
def _oauth2_base_url(self) -> str:
domain = self.settings.domain.removeprefix("https://").removeprefix("http://")
return f"https://{domain}"

View File

@@ -1,42 +0,0 @@
from crewai_cli.authentication.providers.base_provider import BaseProvider
class OktaProvider(BaseProvider):
def get_authorize_url(self) -> str:
return f"{self._oauth2_base_url()}/v1/device/authorize"
def get_token_url(self) -> str:
return f"{self._oauth2_base_url()}/v1/token"
def get_jwks_url(self) -> str:
return f"{self._oauth2_base_url()}/v1/keys"
def get_issuer(self) -> str:
return self._oauth2_base_url().removesuffix("/oauth2")
def get_audience(self) -> str:
if self.settings.audience is None:
raise ValueError(
"Audience is required. Please set it in the configuration."
)
return self.settings.audience
def get_client_id(self) -> str:
if self.settings.client_id is None:
raise ValueError(
"Client ID is required. Please set it in the configuration."
)
return self.settings.client_id
def get_required_fields(self) -> list[str]:
return ["authorization_server_name", "using_org_auth_server"]
def _oauth2_base_url(self) -> str:
using_org_auth_server = self.settings.extra.get("using_org_auth_server", False)
if using_org_auth_server:
base_url = f"https://{self.settings.domain}/oauth2"
else:
base_url = f"https://{self.settings.domain}/oauth2/{self.settings.extra.get('authorization_server_name', 'default')}"
return f"{base_url}"

View File

@@ -1,30 +0,0 @@
from crewai_cli.authentication.providers.base_provider import BaseProvider
class WorkosProvider(BaseProvider):
def get_authorize_url(self) -> str:
return f"https://{self._get_domain()}/oauth2/device_authorization"
def get_token_url(self) -> str:
return f"https://{self._get_domain()}/oauth2/token"
def get_jwks_url(self) -> str:
return f"https://{self._get_domain()}/oauth2/jwks"
def get_issuer(self) -> str:
return f"https://{self._get_domain()}"
def get_audience(self) -> str:
return self.settings.audience or ""
def get_client_id(self) -> str:
if self.settings.client_id is None:
raise ValueError(
"Client ID is required. Please set it in the configuration."
)
return self.settings.client_id
def _get_domain(self) -> str:
if self.settings.domain is None:
raise ValueError("Domain is required. Please set it in the configuration.")
return self.settings.domain

View File

@@ -1,13 +0,0 @@
from crewai_cli.shared.token_manager import TokenManager
class AuthError(Exception):
pass
def get_auth_token() -> str:
"""Get the authentication token."""
access_token = TokenManager().get_token()
if not access_token:
raise AuthError("No token found, make sure you are logged in")
return access_token

View File

@@ -1,63 +0,0 @@
from typing import Any
import jwt
from jwt import PyJWKClient
def validate_jwt_token(
jwt_token: str, jwks_url: str, issuer: str, audience: str
) -> Any:
"""
Verify the token's signature and claims using PyJWT.
:param jwt_token: The JWT (JWS) string to validate.
:param jwks_url: The URL of the JWKS endpoint.
:param issuer: The expected issuer of the token.
:param audience: The expected audience of the token.
:return: The decoded token.
:raises Exception: If the token is invalid for any reason (e.g., signature mismatch,
expired, incorrect issuer/audience, JWKS fetching error,
missing required claims).
"""
try:
jwk_client = PyJWKClient(jwks_url)
signing_key = jwk_client.get_signing_key_from_jwt(jwt_token)
_unverified_decoded_token = jwt.decode(
jwt_token, options={"verify_signature": False}
)
return jwt.decode(
jwt_token,
signing_key.key,
algorithms=["RS256"],
audience=audience,
issuer=issuer,
leeway=10.0,
options={
"verify_signature": True,
"verify_exp": True,
"verify_nbf": True,
"verify_iat": True,
"require": ["exp", "iat", "iss", "aud", "sub"],
},
)
except jwt.ExpiredSignatureError as e:
raise Exception("Token has expired.") from e
except jwt.InvalidAudienceError as e:
actual_audience = _unverified_decoded_token.get("aud", "[no audience found]")
raise Exception(
f"Invalid token audience. Got: '{actual_audience}'. Expected: '{audience}'"
) from e
except jwt.InvalidIssuerError as e:
actual_issuer = _unverified_decoded_token.get("iss", "[no issuer found]")
raise Exception(
f"Invalid token issuer. Got: '{actual_issuer}'. Expected: '{issuer}'"
) from e
except jwt.MissingRequiredClaimError as e:
raise Exception(f"Token is missing required claims: {e!s}") from e
except jwt.exceptions.PyJWKClientError as e:
raise Exception(f"JWKS or key processing error: {e!s}") from e
except jwt.InvalidTokenError as e:
raise Exception(f"Invalid token: {e!s}") from e

View File

@@ -1,68 +0,0 @@
from __future__ import annotations
import json
import httpx
from rich.console import Console
from crewai_cli.authentication.token import get_auth_token
from crewai_cli.plus_api import PlusAPI
console = Console()
class BaseCommand:
def __init__(self) -> None:
pass
class PlusAPIMixin:
def __init__(self) -> None:
try:
self.plus_api_client = PlusAPI(api_key=get_auth_token())
except Exception:
console.print(
"Please sign up/login to CrewAI+ before using the CLI.",
style="bold red",
)
console.print("Run 'crewai login' to sign up/login.", style="bold green")
raise SystemExit from None
def _validate_response(self, response: httpx.Response) -> None:
try:
json_response = response.json()
except (json.JSONDecodeError, ValueError):
console.print(
"Failed to parse response from Enterprise API failed. Details:",
style="bold red",
)
console.print(f"Status Code: {response.status_code}")
console.print(
f"Response:\n{response.content.decode('utf-8', errors='replace')}"
)
raise SystemExit from None
if response.status_code == 422:
console.print(
"Failed to complete operation. Please fix the following errors:",
style="bold red",
)
for field, messages in json_response.items():
for message in messages:
console.print(
f"* [bold red]{field.capitalize()}[/bold red] {message}"
)
raise SystemExit
if not response.is_success:
console.print(
"Request to Enterprise API failed. Details:", style="bold red"
)
details = (
json_response.get("error")
or json_response.get("message")
or response.content.decode("utf-8", errors="replace")
)
console.print(f"{details}")
raise SystemExit

View File

@@ -1,221 +0,0 @@
import json
from logging import getLogger
from pathlib import Path
import tempfile
from typing import Any
from pydantic import BaseModel, Field
from crewai_cli.constants import (
CREWAI_ENTERPRISE_DEFAULT_OAUTH2_AUDIENCE,
CREWAI_ENTERPRISE_DEFAULT_OAUTH2_CLIENT_ID,
CREWAI_ENTERPRISE_DEFAULT_OAUTH2_DOMAIN,
CREWAI_ENTERPRISE_DEFAULT_OAUTH2_PROVIDER,
DEFAULT_CREWAI_ENTERPRISE_URL,
)
from crewai_cli.shared.token_manager import TokenManager
logger = getLogger(__name__)
DEFAULT_CONFIG_PATH = Path.home() / ".config" / "crewai" / "settings.json"
def get_writable_config_path() -> Path | None:
"""
Find a writable location for the config file with fallback options.
Tries in order:
1. Default: ~/.config/crewai/settings.json
2. Temp directory: /tmp/crewai_settings.json (or OS equivalent)
3. Current directory: ./crewai_settings.json
4. In-memory only (returns None)
Returns:
Path object for writable config location, or None if no writable location found
"""
fallback_paths = [
DEFAULT_CONFIG_PATH, # Default location
Path(tempfile.gettempdir()) / "crewai_settings.json", # Temporary directory
Path.cwd() / "crewai_settings.json", # Current working directory
]
for config_path in fallback_paths:
try:
config_path.parent.mkdir(parents=True, exist_ok=True)
test_file = config_path.parent / ".crewai_write_test"
try:
test_file.write_text("test")
test_file.unlink() # Clean up test file
logger.info(f"Using config path: {config_path}")
return config_path
except Exception: # noqa: S112
continue
except Exception: # noqa: S112
continue
return None
# Settings that are related to the user's account
USER_SETTINGS_KEYS = [
"tool_repository_username",
"tool_repository_password",
"org_name",
"org_uuid",
]
# Settings that are related to the CLI
CLI_SETTINGS_KEYS = [
"enterprise_base_url",
"oauth2_provider",
"oauth2_audience",
"oauth2_client_id",
"oauth2_domain",
"oauth2_extra",
]
# Default values for CLI settings
DEFAULT_CLI_SETTINGS = {
"enterprise_base_url": DEFAULT_CREWAI_ENTERPRISE_URL,
"oauth2_provider": CREWAI_ENTERPRISE_DEFAULT_OAUTH2_PROVIDER,
"oauth2_audience": CREWAI_ENTERPRISE_DEFAULT_OAUTH2_AUDIENCE,
"oauth2_client_id": CREWAI_ENTERPRISE_DEFAULT_OAUTH2_CLIENT_ID,
"oauth2_domain": CREWAI_ENTERPRISE_DEFAULT_OAUTH2_DOMAIN,
"oauth2_extra": {},
}
# Readonly settings - cannot be set by the user
READONLY_SETTINGS_KEYS = [
"org_name",
"org_uuid",
]
# Hidden settings - not displayed by the 'list' command and cannot be set by the user
HIDDEN_SETTINGS_KEYS = [
"config_path",
"tool_repository_username",
"tool_repository_password",
]
class Settings(BaseModel):
enterprise_base_url: str | None = Field(
default=DEFAULT_CLI_SETTINGS["enterprise_base_url"],
description="Base URL of the CrewAI AMP instance",
)
tool_repository_username: str | None = Field(
None, description="Username for interacting with the Tool Repository"
)
tool_repository_password: str | None = Field(
None, description="Password for interacting with the Tool Repository"
)
org_name: str | None = Field(
None, description="Name of the currently active organization"
)
org_uuid: str | None = Field(
None, description="UUID of the currently active organization"
)
config_path: Path = Field(default=DEFAULT_CONFIG_PATH, frozen=True, exclude=True)
oauth2_provider: str = Field(
description="OAuth2 provider used for authentication (e.g., workos, okta, auth0).",
default=DEFAULT_CLI_SETTINGS["oauth2_provider"],
)
oauth2_audience: str | None = Field(
description="OAuth2 audience value, typically used to identify the target API or resource.",
default=DEFAULT_CLI_SETTINGS["oauth2_audience"],
)
oauth2_client_id: str = Field(
default=DEFAULT_CLI_SETTINGS["oauth2_client_id"],
description="OAuth2 client ID issued by the provider, used during authentication requests.",
)
oauth2_domain: str = Field(
description="OAuth2 provider's domain (e.g., your-org.auth0.com) used for issuing tokens.",
default=DEFAULT_CLI_SETTINGS["oauth2_domain"],
)
oauth2_extra: dict[str, Any] = Field(
description="Extra configuration for the OAuth2 provider.",
default={},
)
def __init__(self, config_path: Path | None = None, **data: dict[str, Any]) -> None:
"""Load Settings from config path with fallback support"""
if config_path is None:
config_path = get_writable_config_path()
# If config_path is None, we're in memory-only mode
if config_path is None:
merged_data = {**data}
# Dummy path for memory-only mode
super().__init__(config_path=Path("/dev/null"), **merged_data)
return
try:
config_path.parent.mkdir(parents=True, exist_ok=True)
except Exception:
merged_data = {**data}
# Dummy path for memory-only mode
super().__init__(config_path=Path("/dev/null"), **merged_data)
return
file_data = {}
if config_path.is_file():
try:
with config_path.open("r") as f:
file_data = json.load(f)
except Exception:
file_data = {}
merged_data = {**file_data, **data}
super().__init__(config_path=config_path, **merged_data)
def clear_user_settings(self) -> None:
"""Clear all user settings"""
self._reset_user_settings()
self.dump()
def reset(self) -> None:
"""Reset all settings to default values"""
self._reset_user_settings()
self._reset_cli_settings()
self._clear_auth_tokens()
self.dump()
def dump(self) -> None:
"""Save current settings to settings.json"""
if str(self.config_path) == "/dev/null":
return
try:
if self.config_path.is_file():
with self.config_path.open("r") as f:
existing_data = json.load(f)
else:
existing_data = {}
updated_data = {**existing_data, **self.model_dump(exclude_unset=True)}
with self.config_path.open("w") as f:
json.dump(updated_data, f, indent=4)
except Exception: # noqa: S110
pass
def _reset_user_settings(self) -> None:
"""Reset all user settings to default values"""
for key in USER_SETTINGS_KEYS:
setattr(self, key, None)
def _reset_cli_settings(self) -> None:
"""Reset all CLI settings to default values"""
for key in CLI_SETTINGS_KEYS:
setattr(self, key, DEFAULT_CLI_SETTINGS.get(key))
def _clear_auth_tokens(self) -> None:
"""Clear all authentication tokens"""
TokenManager().clear_tokens()

View File

@@ -1,333 +0,0 @@
from typing import Any
DEFAULT_CREWAI_ENTERPRISE_URL = "https://app.crewai.com"
CREWAI_ENTERPRISE_DEFAULT_OAUTH2_PROVIDER = "workos"
CREWAI_ENTERPRISE_DEFAULT_OAUTH2_AUDIENCE = "client_01JNJQWBJ4SPFN3SWJM5T7BDG8"
CREWAI_ENTERPRISE_DEFAULT_OAUTH2_CLIENT_ID = "client_01JYT06R59SP0NXYGD994NFXXX"
CREWAI_ENTERPRISE_DEFAULT_OAUTH2_DOMAIN = "login.crewai.com"
ENV_VARS: dict[str, list[dict[str, Any]]] = {
"openai": [
{
"prompt": "Enter your OPENAI API key (press Enter to skip)",
"key_name": "OPENAI_API_KEY",
}
],
"anthropic": [
{
"prompt": "Enter your ANTHROPIC API key (press Enter to skip)",
"key_name": "ANTHROPIC_API_KEY",
}
],
"gemini": [
{
"prompt": "Enter your GEMINI API key from https://ai.dev/apikey (press Enter to skip)",
"key_name": "GEMINI_API_KEY",
}
],
"nvidia_nim": [
{
"prompt": "Enter your NVIDIA API key (press Enter to skip)",
"key_name": "NVIDIA_NIM_API_KEY",
}
],
"groq": [
{
"prompt": "Enter your GROQ API key (press Enter to skip)",
"key_name": "GROQ_API_KEY",
}
],
"watson": [
{
"prompt": "Enter your WATSONX URL (press Enter to skip)",
"key_name": "WATSONX_URL",
},
{
"prompt": "Enter your WATSONX API Key (press Enter to skip)",
"key_name": "WATSONX_APIKEY",
},
{
"prompt": "Enter your WATSONX Project Id (press Enter to skip)",
"key_name": "WATSONX_PROJECT_ID",
},
],
"ollama": [
{
"default": True,
"API_BASE": "http://localhost:11434",
}
],
"bedrock": [
{
"prompt": "Enter your AWS Access Key ID (press Enter to skip)",
"key_name": "AWS_ACCESS_KEY_ID",
},
{
"prompt": "Enter your AWS Secret Access Key (press Enter to skip)",
"key_name": "AWS_SECRET_ACCESS_KEY",
},
{
"prompt": "Enter your AWS Region Name (press Enter to skip)",
"key_name": "AWS_DEFAULT_REGION",
},
],
"azure": [
{
"prompt": "Enter your Azure deployment name (must start with 'azure/')",
"key_name": "model",
},
{
"prompt": "Enter your AZURE API key (press Enter to skip)",
"key_name": "AZURE_API_KEY",
},
{
"prompt": "Enter your AZURE API base URL (press Enter to skip)",
"key_name": "AZURE_API_BASE",
},
{
"prompt": "Enter your AZURE API version (press Enter to skip)",
"key_name": "AZURE_API_VERSION",
},
],
"cerebras": [
{
"prompt": "Enter your Cerebras model name (must start with 'cerebras/')",
"key_name": "model",
},
{
"prompt": "Enter your Cerebras API version (press Enter to skip)",
"key_name": "CEREBRAS_API_KEY",
},
],
"huggingface": [
{
"prompt": "Enter your Huggingface API key (HF_TOKEN) (press Enter to skip)",
"key_name": "HF_TOKEN",
},
],
"sambanova": [
{
"prompt": "Enter your SambaNovaCloud API key (press Enter to skip)",
"key_name": "SAMBANOVA_API_KEY",
}
],
}
PROVIDERS: list[str] = [
"openai",
"anthropic",
"gemini",
"nvidia_nim",
"groq",
"huggingface",
"ollama",
"watson",
"bedrock",
"azure",
"cerebras",
"sambanova",
]
MODELS: dict[str, list[str]] = {
"openai": [
"gpt-4",
"gpt-4.1",
"gpt-4.1-mini-2025-04-14",
"gpt-4.1-nano-2025-04-14",
"gpt-4o",
"gpt-4o-mini",
"o1-mini",
"o1-preview",
],
"anthropic": [
"claude-3-5-sonnet-20240620",
"claude-3-sonnet-20240229",
"claude-3-opus-20240229",
"claude-3-haiku-20240307",
],
"gemini": [
"gemini/gemini-3-pro-preview",
"gemini/gemini-1.5-flash",
"gemini/gemini-1.5-pro",
"gemini/gemini-2.0-flash-lite-001",
"gemini/gemini-2.0-flash-001",
"gemini/gemini-2.0-flash-thinking-exp-01-21",
"gemini/gemini-2.5-flash-preview-04-17",
"gemini/gemini-2.5-pro-exp-03-25",
"gemini/gemini-gemma-2-9b-it",
"gemini/gemini-gemma-2-27b-it",
"gemini/gemma-3-1b-it",
"gemini/gemma-3-4b-it",
"gemini/gemma-3-12b-it",
"gemini/gemma-3-27b-it",
],
"nvidia_nim": [
"nvidia_nim/nvidia/mistral-nemo-minitron-8b-8k-instruct",
"nvidia_nim/nvidia/nemotron-4-mini-hindi-4b-instruct",
"nvidia_nim/nvidia/llama-3.1-nemotron-70b-instruct",
"nvidia_nim/nvidia/llama3-chatqa-1.5-8b",
"nvidia_nim/nvidia/llama3-chatqa-1.5-70b",
"nvidia_nim/nvidia/vila",
"nvidia_nim/nvidia/neva-22",
"nvidia_nim/nvidia/nemotron-mini-4b-instruct",
"nvidia_nim/nvidia/usdcode-llama3-70b-instruct",
"nvidia_nim/nvidia/nemotron-4-340b-instruct",
"nvidia_nim/meta/codellama-70b",
"nvidia_nim/meta/llama2-70b",
"nvidia_nim/meta/llama3-8b-instruct",
"nvidia_nim/meta/llama3-70b-instruct",
"nvidia_nim/meta/llama-3.1-8b-instruct",
"nvidia_nim/meta/llama-3.1-70b-instruct",
"nvidia_nim/meta/llama-3.1-405b-instruct",
"nvidia_nim/meta/llama-3.2-1b-instruct",
"nvidia_nim/meta/llama-3.2-3b-instruct",
"nvidia_nim/meta/llama-3.2-11b-vision-instruct",
"nvidia_nim/meta/llama-3.2-90b-vision-instruct",
"nvidia_nim/meta/llama-3.1-70b-instruct",
"nvidia_nim/google/gemma-7b",
"nvidia_nim/google/gemma-2b",
"nvidia_nim/google/codegemma-7b",
"nvidia_nim/google/codegemma-1.1-7b",
"nvidia_nim/google/recurrentgemma-2b",
"nvidia_nim/google/gemma-2-9b-it",
"nvidia_nim/google/gemma-2-27b-it",
"nvidia_nim/google/gemma-2-2b-it",
"nvidia_nim/google/deplot",
"nvidia_nim/google/paligemma",
"nvidia_nim/mistralai/mistral-7b-instruct-v0.2",
"nvidia_nim/mistralai/mixtral-8x7b-instruct-v0.1",
"nvidia_nim/mistralai/mistral-large",
"nvidia_nim/mistralai/mixtral-8x22b-instruct-v0.1",
"nvidia_nim/mistralai/mistral-7b-instruct-v0.3",
"nvidia_nim/nv-mistralai/mistral-nemo-12b-instruct",
"nvidia_nim/mistralai/mamba-codestral-7b-v0.1",
"nvidia_nim/microsoft/phi-3-mini-128k-instruct",
"nvidia_nim/microsoft/phi-3-mini-4k-instruct",
"nvidia_nim/microsoft/phi-3-small-8k-instruct",
"nvidia_nim/microsoft/phi-3-small-128k-instruct",
"nvidia_nim/microsoft/phi-3-medium-4k-instruct",
"nvidia_nim/microsoft/phi-3-medium-128k-instruct",
"nvidia_nim/microsoft/phi-3.5-mini-instruct",
"nvidia_nim/microsoft/phi-3.5-moe-instruct",
"nvidia_nim/microsoft/kosmos-2",
"nvidia_nim/microsoft/phi-3-vision-128k-instruct",
"nvidia_nim/microsoft/phi-3.5-vision-instruct",
"nvidia_nim/databricks/dbrx-instruct",
"nvidia_nim/snowflake/arctic",
"nvidia_nim/aisingapore/sea-lion-7b-instruct",
"nvidia_nim/ibm/granite-8b-code-instruct",
"nvidia_nim/ibm/granite-34b-code-instruct",
"nvidia_nim/ibm/granite-3.0-8b-instruct",
"nvidia_nim/ibm/granite-3.0-3b-a800m-instruct",
"nvidia_nim/mediatek/breeze-7b-instruct",
"nvidia_nim/upstage/solar-10.7b-instruct",
"nvidia_nim/writer/palmyra-med-70b-32k",
"nvidia_nim/writer/palmyra-med-70b",
"nvidia_nim/writer/palmyra-fin-70b-32k",
"nvidia_nim/01-ai/yi-large",
"nvidia_nim/deepseek-ai/deepseek-coder-6.7b-instruct",
"nvidia_nim/rakuten/rakutenai-7b-instruct",
"nvidia_nim/rakuten/rakutenai-7b-chat",
"nvidia_nim/baichuan-inc/baichuan2-13b-chat",
],
"groq": [
"groq/llama-3.1-8b-instant",
"groq/llama-3.1-70b-versatile",
"groq/llama-3.1-405b-reasoning",
"groq/gemma2-9b-it",
"groq/gemma-7b-it",
],
"ollama": ["ollama/llama3.1", "ollama/mixtral"],
"watson": [
"watsonx/meta-llama/llama-3-1-70b-instruct",
"watsonx/meta-llama/llama-3-1-8b-instruct",
"watsonx/meta-llama/llama-3-2-11b-vision-instruct",
"watsonx/meta-llama/llama-3-2-1b-instruct",
"watsonx/meta-llama/llama-3-2-90b-vision-instruct",
"watsonx/meta-llama/llama-3-405b-instruct",
"watsonx/mistral/mistral-large",
"watsonx/ibm/granite-3-8b-instruct",
],
"bedrock": [
"bedrock/us.amazon.nova-pro-v1:0",
"bedrock/us.amazon.nova-micro-v1:0",
"bedrock/us.amazon.nova-lite-v1:0",
"bedrock/us.anthropic.claude-3-5-sonnet-20240620-v1:0",
"bedrock/us.anthropic.claude-3-5-haiku-20241022-v1:0",
"bedrock/us.anthropic.claude-3-5-sonnet-20241022-v2:0",
"bedrock/us.anthropic.claude-3-7-sonnet-20250219-v1:0",
"bedrock/us.anthropic.claude-3-sonnet-20240229-v1:0",
"bedrock/us.anthropic.claude-3-opus-20240229-v1:0",
"bedrock/us.anthropic.claude-3-haiku-20240307-v1:0",
"bedrock/us.meta.llama3-2-11b-instruct-v1:0",
"bedrock/us.meta.llama3-2-3b-instruct-v1:0",
"bedrock/us.meta.llama3-2-90b-instruct-v1:0",
"bedrock/us.meta.llama3-2-1b-instruct-v1:0",
"bedrock/us.meta.llama3-1-8b-instruct-v1:0",
"bedrock/us.meta.llama3-1-70b-instruct-v1:0",
"bedrock/us.meta.llama3-3-70b-instruct-v1:0",
"bedrock/us.meta.llama3-1-405b-instruct-v1:0",
"bedrock/eu.anthropic.claude-3-5-sonnet-20240620-v1:0",
"bedrock/eu.anthropic.claude-3-sonnet-20240229-v1:0",
"bedrock/eu.anthropic.claude-3-haiku-20240307-v1:0",
"bedrock/eu.meta.llama3-2-3b-instruct-v1:0",
"bedrock/eu.meta.llama3-2-1b-instruct-v1:0",
"bedrock/apac.anthropic.claude-3-5-sonnet-20240620-v1:0",
"bedrock/apac.anthropic.claude-3-5-sonnet-20241022-v2:0",
"bedrock/apac.anthropic.claude-3-sonnet-20240229-v1:0",
"bedrock/apac.anthropic.claude-3-haiku-20240307-v1:0",
"bedrock/amazon.nova-pro-v1:0",
"bedrock/amazon.nova-micro-v1:0",
"bedrock/amazon.nova-lite-v1:0",
"bedrock/anthropic.claude-3-5-sonnet-20240620-v1:0",
"bedrock/anthropic.claude-3-5-haiku-20241022-v1:0",
"bedrock/anthropic.claude-3-5-sonnet-20241022-v2:0",
"bedrock/anthropic.claude-3-7-sonnet-20250219-v1:0",
"bedrock/anthropic.claude-3-sonnet-20240229-v1:0",
"bedrock/anthropic.claude-3-opus-20240229-v1:0",
"bedrock/anthropic.claude-3-haiku-20240307-v1:0",
"bedrock/anthropic.claude-v2:1",
"bedrock/anthropic.claude-v2",
"bedrock/anthropic.claude-instant-v1",
"bedrock/meta.llama3-1-405b-instruct-v1:0",
"bedrock/meta.llama3-1-70b-instruct-v1:0",
"bedrock/meta.llama3-1-8b-instruct-v1:0",
"bedrock/meta.llama3-70b-instruct-v1:0",
"bedrock/meta.llama3-8b-instruct-v1:0",
"bedrock/amazon.titan-text-lite-v1",
"bedrock/amazon.titan-text-express-v1",
"bedrock/cohere.command-text-v14",
"bedrock/ai21.j2-mid-v1",
"bedrock/ai21.j2-ultra-v1",
"bedrock/ai21.jamba-instruct-v1:0",
"bedrock/mistral.mistral-7b-instruct-v0:2",
"bedrock/mistral.mixtral-8x7b-instruct-v0:1",
],
"huggingface": [
"huggingface/meta-llama/Meta-Llama-3.1-8B-Instruct",
"huggingface/mistralai/Mixtral-8x7B-Instruct-v0.1",
"huggingface/tiiuae/falcon-180B-chat",
"huggingface/google/gemma-7b-it",
],
"sambanova": [
"sambanova/Meta-Llama-3.3-70B-Instruct",
"sambanova/QwQ-32B-Preview",
"sambanova/Qwen2.5-72B-Instruct",
"sambanova/Qwen2.5-Coder-32B-Instruct",
"sambanova/Meta-Llama-3.1-405B-Instruct",
"sambanova/Meta-Llama-3.1-70B-Instruct",
"sambanova/Meta-Llama-3.1-8B-Instruct",
"sambanova/Llama-3.2-90B-Vision-Instruct",
"sambanova/Llama-3.2-11B-Vision-Instruct",
"sambanova/Meta-Llama-3.2-3B-Instruct",
"sambanova/Meta-Llama-3.2-1B-Instruct",
],
}
DEFAULT_LLM_MODEL = "gpt-4.1-mini"
JSON_URL = "https://raw.githubusercontent.com/BerriAI/litellm/main/model_prices_and_context_window.json"
LITELLM_PARAMS = ["api_key", "api_base", "api_version"]

View File

@@ -1,23 +0,0 @@
"""Wrapper for the crew chat command.
Delegates to ``crewai.cli.crew_chat.run_chat`` when the full crewai package is
installed, otherwise prints a helpful error message.
"""
from __future__ import annotations
import click
def run_chat() -> None:
try:
from crewai.cli.crew_chat import run_chat as _run_chat
except ImportError:
click.secho(
"The 'chat' command requires the full crewai package.\n"
"Install it with: pip install crewai",
fg="red",
)
raise SystemExit(1) from None
_run_chat()

View File

@@ -1,89 +0,0 @@
from functools import lru_cache
import subprocess
class Repository:
def __init__(self, path: str = ".") -> None:
self.path = path
if not self.is_git_installed():
raise ValueError("Git is not installed or not found in your PATH.")
if not self.is_git_repo():
raise ValueError(f"{self.path} is not a Git repository.")
self.fetch()
@staticmethod
def is_git_installed() -> bool:
"""Check if Git is installed and available in the system."""
try:
subprocess.run(
["git", "--version"], # noqa: S607
capture_output=True,
check=True,
text=True,
)
return True
except (subprocess.CalledProcessError, FileNotFoundError):
return False
def fetch(self) -> None:
"""Fetch latest updates from the remote."""
subprocess.run(["git", "fetch"], cwd=self.path, check=True) # noqa: S607
def status(self) -> str:
"""Get the git status in porcelain format."""
return subprocess.check_output(
["git", "status", "--branch", "--porcelain"], # noqa: S607
cwd=self.path,
encoding="utf-8",
).strip()
@lru_cache(maxsize=None) # noqa: B019
def is_git_repo(self) -> bool:
"""Check if the current directory is a git repository.
Notes:
- TODO: This method is cached to avoid redundant checks, but using lru_cache on methods can lead to memory leaks
"""
try:
subprocess.check_output(
["git", "rev-parse", "--is-inside-work-tree"], # noqa: S607
cwd=self.path,
encoding="utf-8",
)
return True
except subprocess.CalledProcessError:
return False
def has_uncommitted_changes(self) -> bool:
"""Check if the repository has uncommitted changes."""
return len(self.status().splitlines()) > 1
def is_ahead_or_behind(self) -> bool:
"""Check if the repository is ahead or behind the remote."""
for line in self.status().splitlines():
if line.startswith("##") and ("ahead" in line or "behind" in line):
return True
return False
def is_synced(self) -> bool:
"""Return True if the Git repository is fully synced with the remote, False otherwise."""
if self.has_uncommitted_changes() or self.is_ahead_or_behind():
return False
return True
def origin_url(self) -> str | None:
"""Get the Git repository's remote URL."""
try:
result = subprocess.run(
["git", "remote", "get-url", "origin"], # noqa: S607
cwd=self.path,
capture_output=True,
text=True,
check=True,
)
return result.stdout.strip()
except subprocess.CalledProcessError:
return None

View File

@@ -1,210 +0,0 @@
import os
from typing import Any
from urllib.parse import urljoin
import httpx
from crewai_cli.config import Settings
from crewai_cli.constants import DEFAULT_CREWAI_ENTERPRISE_URL
from crewai_cli.version import get_crewai_version
class PlusAPI:
"""
This class exposes methods for working with the CrewAI+ API.
"""
TOOLS_RESOURCE = "/crewai_plus/api/v1/tools"
ORGANIZATIONS_RESOURCE = "/crewai_plus/api/v1/me/organizations"
CREWS_RESOURCE = "/crewai_plus/api/v1/crews"
AGENTS_RESOURCE = "/crewai_plus/api/v1/agents"
TRACING_RESOURCE = "/crewai_plus/api/v1/tracing"
EPHEMERAL_TRACING_RESOURCE = "/crewai_plus/api/v1/tracing/ephemeral"
INTEGRATIONS_RESOURCE = "/crewai_plus/api/v1/integrations"
def __init__(self, api_key: str) -> None:
self.api_key = api_key
self.headers = {
"Authorization": f"Bearer {api_key}",
"Content-Type": "application/json",
"User-Agent": f"CrewAI-CLI/{get_crewai_version()}",
"X-Crewai-Version": get_crewai_version(),
}
settings = Settings()
if settings.org_uuid:
self.headers["X-Crewai-Organization-Id"] = settings.org_uuid
self.base_url = (
os.getenv("CREWAI_PLUS_URL")
or str(settings.enterprise_base_url)
or DEFAULT_CREWAI_ENTERPRISE_URL
)
def _make_request(
self, method: str, endpoint: str, **kwargs: Any
) -> httpx.Response:
url = urljoin(self.base_url, endpoint)
verify = kwargs.pop("verify", True)
with httpx.Client(trust_env=False, verify=verify) as client:
return client.request(method, url, headers=self.headers, **kwargs)
def login_to_tool_repository(self) -> httpx.Response:
return self._make_request("POST", f"{self.TOOLS_RESOURCE}/login")
def get_tool(self, handle: str) -> httpx.Response:
return self._make_request("GET", f"{self.TOOLS_RESOURCE}/{handle}")
async def get_agent(self, handle: str) -> httpx.Response:
url = urljoin(self.base_url, f"{self.AGENTS_RESOURCE}/{handle}")
async with httpx.AsyncClient() as client:
return await client.get(url, headers=self.headers)
def publish_tool(
self,
handle: str,
is_public: bool,
version: str,
description: str | None,
encoded_file: str,
available_exports: list[dict[str, Any]] | None = None,
) -> httpx.Response:
params = {
"handle": handle,
"public": is_public,
"version": version,
"file": encoded_file,
"description": description,
"available_exports": available_exports,
}
return self._make_request("POST", f"{self.TOOLS_RESOURCE}", json=params)
def deploy_by_name(self, project_name: str) -> httpx.Response:
return self._make_request(
"POST", f"{self.CREWS_RESOURCE}/by-name/{project_name}/deploy"
)
def deploy_by_uuid(self, uuid: str) -> httpx.Response:
return self._make_request("POST", f"{self.CREWS_RESOURCE}/{uuid}/deploy")
def crew_status_by_name(self, project_name: str) -> httpx.Response:
return self._make_request(
"GET", f"{self.CREWS_RESOURCE}/by-name/{project_name}/status"
)
def crew_status_by_uuid(self, uuid: str) -> httpx.Response:
return self._make_request("GET", f"{self.CREWS_RESOURCE}/{uuid}/status")
def crew_by_name(
self, project_name: str, log_type: str = "deployment"
) -> httpx.Response:
return self._make_request(
"GET", f"{self.CREWS_RESOURCE}/by-name/{project_name}/logs/{log_type}"
)
def crew_by_uuid(self, uuid: str, log_type: str = "deployment") -> httpx.Response:
return self._make_request(
"GET", f"{self.CREWS_RESOURCE}/{uuid}/logs/{log_type}"
)
def delete_crew_by_name(self, project_name: str) -> httpx.Response:
return self._make_request(
"DELETE", f"{self.CREWS_RESOURCE}/by-name/{project_name}"
)
def delete_crew_by_uuid(self, uuid: str) -> httpx.Response:
return self._make_request("DELETE", f"{self.CREWS_RESOURCE}/{uuid}")
def list_crews(self) -> httpx.Response:
return self._make_request("GET", self.CREWS_RESOURCE)
def create_crew(self, payload: dict[str, Any]) -> httpx.Response:
return self._make_request("POST", self.CREWS_RESOURCE, json=payload)
def get_organizations(self) -> httpx.Response:
return self._make_request("GET", self.ORGANIZATIONS_RESOURCE)
def initialize_trace_batch(self, payload: dict[str, Any]) -> httpx.Response:
return self._make_request(
"POST",
f"{self.TRACING_RESOURCE}/batches",
json=payload,
timeout=30,
)
def initialize_ephemeral_trace_batch(
self, payload: dict[str, Any]
) -> httpx.Response:
return self._make_request(
"POST",
f"{self.EPHEMERAL_TRACING_RESOURCE}/batches",
json=payload,
)
def send_trace_events(
self, trace_batch_id: str, payload: dict[str, Any]
) -> httpx.Response:
return self._make_request(
"POST",
f"{self.TRACING_RESOURCE}/batches/{trace_batch_id}/events",
json=payload,
timeout=30,
)
def send_ephemeral_trace_events(
self, trace_batch_id: str, payload: dict[str, Any]
) -> httpx.Response:
return self._make_request(
"POST",
f"{self.EPHEMERAL_TRACING_RESOURCE}/batches/{trace_batch_id}/events",
json=payload,
timeout=30,
)
def finalize_trace_batch(
self, trace_batch_id: str, payload: dict[str, Any]
) -> httpx.Response:
return self._make_request(
"PATCH",
f"{self.TRACING_RESOURCE}/batches/{trace_batch_id}/finalize",
json=payload,
timeout=30,
)
def finalize_ephemeral_trace_batch(
self, trace_batch_id: str, payload: dict[str, Any]
) -> httpx.Response:
return self._make_request(
"PATCH",
f"{self.EPHEMERAL_TRACING_RESOURCE}/batches/{trace_batch_id}/finalize",
json=payload,
timeout=30,
)
def mark_trace_batch_as_failed(
self, trace_batch_id: str, error_message: str
) -> httpx.Response:
return self._make_request(
"PATCH",
f"{self.TRACING_RESOURCE}/batches/{trace_batch_id}",
json={"status": "failed", "failure_reason": error_message},
timeout=30,
)
def get_mcp_configs(self, slugs: list[str]) -> httpx.Response:
"""Get MCP server configurations for the given slugs."""
return self._make_request(
"GET",
f"{self.INTEGRATIONS_RESOURCE}/mcp_configs",
params={"slugs": ",".join(slugs)},
timeout=30,
)
def get_triggers(self) -> httpx.Response:
"""Get all available triggers from integrations."""
return self._make_request("GET", f"{self.INTEGRATIONS_RESOURCE}/apps")
def get_trigger_payload(self, app_slug: str, trigger_slug: str) -> httpx.Response:
"""Get sample payload for a specific trigger."""
return self._make_request(
"GET", f"{self.INTEGRATIONS_RESOURCE}/{app_slug}/{trigger_slug}/payload"
)

View File

@@ -1,231 +0,0 @@
from collections import defaultdict
from collections.abc import Sequence
import json
import os
from pathlib import Path
import time
from typing import Any
import certifi
import click
import httpx
from crewai_cli.constants import JSON_URL, MODELS, PROVIDERS
def select_choice(prompt_message: str, choices: Sequence[str]) -> str | None:
"""Presents a list of choices to the user and prompts them to select one.
Args:
prompt_message: The message to display to the user before presenting the choices.
choices: A list of options to present to the user.
Returns:
The selected choice from the list, or None if the user chooses to quit.
"""
provider_models = get_provider_data()
if not provider_models:
return None
click.secho(prompt_message, fg="cyan")
for idx, choice in enumerate(choices, start=1):
click.secho(f"{idx}. {choice}", fg="cyan")
click.secho("q. Quit", fg="cyan")
while True:
choice = click.prompt(
"Enter the number of your choice or 'q' to quit", type=str
)
if choice.lower() == "q":
return None
try:
selected_index = int(choice) - 1
if 0 <= selected_index < len(choices):
return choices[selected_index]
except ValueError:
pass
click.secho(
"Invalid selection. Please select a number between 1 and 6 or 'q' to quit.",
fg="red",
)
def select_provider(provider_models: dict[str, list[str]]) -> str | None | bool:
"""Presents a list of providers to the user and prompts them to select one.
Args:
provider_models: A dictionary of provider models.
Returns:
The selected provider, None if user explicitly quits, or False if no selection.
"""
predefined_providers = [p.lower() for p in PROVIDERS]
all_providers = sorted(set(predefined_providers + list(provider_models.keys())))
provider = select_choice(
"Select a provider to set up:", [*predefined_providers, "other"]
)
if provider is None: # User typed 'q'
return None
if provider == "other":
provider = select_choice("Select a provider from the full list:", all_providers)
if provider is None: # User typed 'q'
return None
return provider.lower() if provider else False
def select_model(provider: str, provider_models: dict[str, list[str]]) -> str | None:
"""Presents a list of models for a given provider to the user and prompts them to select one.
Args:
provider: The provider for which to select a model.
provider_models: A dictionary of provider models.
Returns:
The selected model, or None if the operation is aborted or an invalid selection is made.
"""
predefined_providers = [p.lower() for p in PROVIDERS]
if provider in predefined_providers:
available_models = MODELS.get(provider, [])
else:
available_models = provider_models.get(provider, [])
if not available_models:
click.secho(f"No models available for provider '{provider}'.", fg="red")
return None
return select_choice(
f"Select a model to use for {provider.capitalize()}:", available_models
)
def load_provider_data(cache_file: Path, cache_expiry: int) -> dict[str, Any] | None:
"""Loads provider data from a cache file if it exists and is not expired.
If the cache is expired or corrupted, it fetches the data from the web.
Args:
cache_file: The path to the cache file.
cache_expiry: The cache expiry time in seconds.
Returns:
The loaded provider data or None if the operation fails.
"""
current_time = time.time()
if (
cache_file.exists()
and (current_time - cache_file.stat().st_mtime) < cache_expiry
):
data = read_cache_file(cache_file)
if data:
return data
click.secho(
"Cache is corrupted. Fetching provider data from the web...", fg="yellow"
)
else:
click.secho(
"Cache expired or not found. Fetching provider data from the web...",
fg="cyan",
)
return fetch_provider_data(cache_file)
def read_cache_file(cache_file: Path) -> dict[str, Any] | None:
"""Reads and returns the JSON content from a cache file.
Args:
cache_file: The path to the cache file.
Returns:
The JSON content of the cache file or None if the JSON is invalid.
"""
try:
with open(cache_file, "r") as f:
data: dict[str, Any] = json.load(f)
return data
except json.JSONDecodeError:
return None
def fetch_provider_data(cache_file: Path) -> dict[str, Any] | None:
"""Fetches provider data from a specified URL and caches it to a file.
Args:
cache_file: The path to the cache file.
Returns:
The fetched provider data or None if the operation fails.
"""
ssl_config = os.environ["SSL_CERT_FILE"] = certifi.where()
try:
with httpx.stream("GET", JSON_URL, timeout=60, verify=ssl_config) as response:
response.raise_for_status()
data = download_data(response)
with open(cache_file, "w") as f:
json.dump(data, f)
return data
except httpx.HTTPError as e:
click.secho(f"Error fetching provider data: {e}", fg="red")
except json.JSONDecodeError:
click.secho("Error parsing provider data. Invalid JSON format.", fg="red")
return None
def download_data(response: httpx.Response) -> dict[str, Any]:
"""Downloads data from a given HTTP response and returns the JSON content.
Args:
response: The HTTP response object.
Returns:
The JSON content of the response.
"""
total_size = int(response.headers.get("content-length", 0))
block_size = 8192
data_chunks: list[bytes] = []
bar: Any
with click.progressbar(
length=total_size, label="Downloading", show_pos=True
) as bar:
for chunk in response.iter_bytes(block_size):
if chunk:
data_chunks.append(chunk)
bar.update(len(chunk))
data_content = b"".join(data_chunks)
result: dict[str, Any] = json.loads(data_content.decode("utf-8"))
return result
def get_provider_data() -> dict[str, list[str]] | None:
"""Retrieves provider data from a cache file.
Filters out models based on provider criteria, and returns a dictionary of providers
mapped to their models.
Returns:
A dictionary of providers mapped to their models or None if the operation fails.
"""
cache_dir = Path.home() / ".crewai"
cache_dir.mkdir(exist_ok=True)
cache_file = cache_dir / "provider_cache.json"
cache_expiry = 24 * 3600
data = load_provider_data(cache_file, cache_expiry)
if not data:
return None
provider_models = defaultdict(list)
for model_name, properties in data.items():
provider = properties.get("litellm_provider", "").strip().lower()
if "http" in provider or provider == "other":
continue
if provider:
provider_models[provider].append(model_name)
return provider_models

View File

@@ -1,31 +0,0 @@
"""Wrapper for the reset-memories command.
Delegates to ``crewai.cli.reset_memories_command`` when the full crewai
package is installed, otherwise prints a helpful error message.
"""
from __future__ import annotations
import click
def reset_memories_command(
memory: bool,
knowledge: bool,
agent_knowledge: bool,
kickoff_outputs: bool,
all: bool,
) -> None:
try:
from crewai.cli.reset_memories_command import (
reset_memories_command as _reset,
)
except ImportError:
click.secho(
"The 'reset-memories' command requires the full crewai package.\n"
"Install it with: pip install crewai",
fg="red",
)
raise SystemExit(1) from None
_reset(memory, knowledge, agent_knowledge, kickoff_outputs, all)

View File

@@ -1,186 +0,0 @@
from datetime import datetime
import json
import os
from pathlib import Path
import sys
import tempfile
from typing import Final, Literal, cast
from cryptography.fernet import Fernet
_FERNET_KEY_LENGTH: Final[Literal[44]] = 44
class TokenManager:
"""Manages encrypted token storage."""
def __init__(self, file_path: str = "tokens.enc") -> None:
"""Initialize the TokenManager.
Args:
file_path: The file path to store encrypted tokens.
"""
self.file_path = file_path
self.key = self._get_or_create_key()
self.fernet = Fernet(self.key)
def _get_or_create_key(self) -> bytes:
"""Get or create the encryption key.
Returns:
The encryption key as bytes.
"""
key_filename: str = "secret.key"
key = self._read_secure_file(key_filename)
if key is not None and len(key) == _FERNET_KEY_LENGTH:
return key
new_key = Fernet.generate_key()
if self._atomic_create_secure_file(key_filename, new_key):
return new_key
key = self._read_secure_file(key_filename)
if key is not None and len(key) == _FERNET_KEY_LENGTH:
return key
raise RuntimeError("Failed to create or read encryption key")
def save_tokens(self, access_token: str, expires_at: int) -> None:
"""Save the access token and its expiration time.
Args:
access_token: The access token to save.
expires_at: The UNIX timestamp of the expiration time.
"""
expiration_time = datetime.fromtimestamp(expires_at)
data = {
"access_token": access_token,
"expiration": expiration_time.isoformat(),
}
encrypted_data = self.fernet.encrypt(json.dumps(data).encode())
self._atomic_write_secure_file(self.file_path, encrypted_data)
def get_token(self) -> str | None:
"""Get the access token if it is valid and not expired.
Returns:
The access token if valid and not expired, otherwise None.
"""
encrypted_data = self._read_secure_file(self.file_path)
if encrypted_data is None:
return None
decrypted_data = self.fernet.decrypt(encrypted_data)
data = json.loads(decrypted_data)
expiration = datetime.fromisoformat(data["expiration"])
if expiration <= datetime.now():
return None
return cast(str | None, data.get("access_token"))
def clear_tokens(self) -> None:
"""Clear the stored tokens."""
self._delete_secure_file(self.file_path)
@staticmethod
def _get_secure_storage_path() -> Path:
"""Get the secure storage path based on the operating system.
Returns:
The secure storage path.
"""
if sys.platform == "win32":
base_path = os.environ.get("LOCALAPPDATA")
elif sys.platform == "darwin":
base_path = os.path.expanduser("~/Library/Application Support")
else:
base_path = os.path.expanduser("~/.local/share")
app_name = "crewai/credentials"
storage_path = Path(base_path) / app_name
storage_path.mkdir(parents=True, exist_ok=True)
return storage_path
def _atomic_create_secure_file(self, filename: str, content: bytes) -> bool:
"""Create a file only if it doesn't exist.
Args:
filename: The name of the file.
content: The content to write.
Returns:
True if file was created, False if it already exists.
"""
storage_path = self._get_secure_storage_path()
file_path = storage_path / filename
try:
fd = os.open(file_path, os.O_CREAT | os.O_EXCL | os.O_WRONLY, 0o600)
try:
os.write(fd, content)
finally:
os.close(fd)
return True
except FileExistsError:
return False
def _atomic_write_secure_file(self, filename: str, content: bytes) -> None:
"""Write content to a secure file.
Args:
filename: The name of the file.
content: The content to write.
"""
storage_path = self._get_secure_storage_path()
file_path = storage_path / filename
fd, temp_path = tempfile.mkstemp(dir=storage_path, prefix=f".{filename}.")
fd_closed = False
try:
os.write(fd, content)
os.close(fd)
fd_closed = True
os.chmod(temp_path, 0o600)
os.replace(temp_path, file_path)
except Exception:
if not fd_closed:
os.close(fd)
if os.path.exists(temp_path):
os.unlink(temp_path)
raise
def _read_secure_file(self, filename: str) -> bytes | None:
"""Read the content of a secure file.
Args:
filename: The name of the file.
Returns:
The content of the file if it exists, otherwise None.
"""
storage_path = self._get_secure_storage_path()
file_path = storage_path / filename
try:
with open(file_path, "rb") as f:
return f.read()
except FileNotFoundError:
return None
def _delete_secure_file(self, filename: str) -> None:
"""Delete a secure file.
Args:
filename: The name of the file.
"""
storage_path = self._get_secure_storage_path()
file_path = storage_path / filename
try:
file_path.unlink()
except FileNotFoundError:
pass

View File

@@ -1,54 +0,0 @@
"""Lightweight SQLite reader for kickoff task outputs.
Only used by the ``crewai log-tasks-outputs`` CLI command. Depends solely on
the standard library + *appdirs* so crewai-cli can read stored outputs without
importing the full crewai framework.
"""
from __future__ import annotations
import json
import logging
from pathlib import Path
import sqlite3
from typing import Any
from crewai_cli.user_data import _db_storage_path
logger = logging.getLogger(__name__)
def load_task_outputs(db_path: str | None = None) -> list[dict[str, Any]]:
"""Return all rows from the kickoff task outputs database."""
if db_path is None:
db_path = str(Path(_db_storage_path()) / "latest_kickoff_task_outputs.db")
if not Path(db_path).exists():
return []
try:
with sqlite3.connect(db_path) as conn:
cursor = conn.cursor()
cursor.execute("""
SELECT *
FROM latest_kickoff_task_outputs
ORDER BY task_index
""")
rows = cursor.fetchall()
results: list[dict[str, Any]] = [
{
"task_id": row[0],
"expected_output": row[1],
"output": json.loads(row[2]),
"task_index": row[3],
"inputs": json.loads(row[4]),
"was_replayed": row[5],
"timestamp": row[6],
}
for row in rows
]
return results
except sqlite3.Error as e:
logger.error("Failed to load task outputs: %s", e)
return []

View File

@@ -1,66 +0,0 @@
"""Standalone user-data helpers for the CLI package.
These mirror the functions in ``crewai.events.listeners.tracing.utils`` but
depend only on the standard library + *appdirs* so that crewai-cli can work
without importing the full crewai framework.
"""
from __future__ import annotations
import json
import logging
import os
from pathlib import Path
from typing import Any, cast
import appdirs
logger = logging.getLogger(__name__)
def _get_project_directory_name() -> str:
return os.environ.get("CREWAI_STORAGE_DIR", Path.cwd().name)
def _db_storage_path() -> str:
app_name = _get_project_directory_name()
app_author = "CrewAI"
data_dir = Path(appdirs.user_data_dir(app_name, app_author))
data_dir.mkdir(parents=True, exist_ok=True)
return str(data_dir)
def _user_data_file() -> Path:
base = Path(_db_storage_path())
base.mkdir(parents=True, exist_ok=True)
return base / ".crewai_user.json"
def _load_user_data() -> dict[str, Any]:
p = _user_data_file()
if p.exists():
try:
return cast(dict[str, Any], json.loads(p.read_text()))
except (json.JSONDecodeError, OSError, PermissionError) as e:
logger.warning("Failed to load user data: %s", e)
return {}
def _save_user_data(data: dict[str, Any]) -> None:
try:
p = _user_data_file()
p.write_text(json.dumps(data, indent=2))
except (OSError, PermissionError) as e:
logger.warning("Failed to save user data: %s", e)
def is_tracing_enabled() -> bool:
"""Check if tracing is enabled (mirrors crewai core logic)."""
data = _load_user_data()
if (
data.get("first_execution_done", False)
and data.get("trace_consent", False) is False
):
return False
return os.getenv("CREWAI_TRACING_ENABLED", "false").lower() == "true"

View File

@@ -1,369 +0,0 @@
from __future__ import annotations
from functools import reduce
from inspect import getmro, isclass
import os
from pathlib import Path
import shutil
import sys
from typing import Any, cast
import click
from rich.console import Console
import tomli
from crewai_cli.config import Settings
from crewai_cli.constants import ENV_VARS
if sys.version_info >= (3, 11):
import tomllib
console = Console()
def copy_template(
src: Path, dst: Path, name: str, class_name: str, folder_name: str
) -> None:
"""Copy a file from src to dst."""
with open(src, "r") as file:
content = file.read()
content = content.replace("{{name}}", name)
content = content.replace("{{crew_name}}", class_name)
content = content.replace("{{folder_name}}", folder_name)
with open(dst, "w") as file:
file.write(content)
click.secho(f" - Created {dst}", fg="green")
def read_toml(file_path: str = "pyproject.toml") -> dict[str, Any]:
"""Read the content of a TOML file and return it as a dictionary."""
with open(file_path, "rb") as f:
return tomli.load(f)
def parse_toml(content: str) -> dict[str, Any]:
if sys.version_info >= (3, 11):
return tomllib.loads(content)
return tomli.loads(content)
def get_project_name(
pyproject_path: str = "pyproject.toml", require: bool = False
) -> str | None:
"""Get the project name from the pyproject.toml file."""
return _get_project_attribute(pyproject_path, ["project", "name"], require=require)
def get_project_version(
pyproject_path: str = "pyproject.toml", require: bool = False
) -> str | None:
"""Get the project version from the pyproject.toml file."""
return _get_project_attribute(
pyproject_path, ["project", "version"], require=require
)
def get_project_description(
pyproject_path: str = "pyproject.toml", require: bool = False
) -> str | None:
"""Get the project description from the pyproject.toml file."""
return _get_project_attribute(
pyproject_path, ["project", "description"], require=require
)
def _get_project_attribute(
pyproject_path: str, keys: list[str], require: bool
) -> Any | None:
"""Get an attribute from the pyproject.toml file."""
attribute = None
try:
with open(pyproject_path, "r") as f:
pyproject_content = parse_toml(f.read())
dependencies = (
_get_nested_value(pyproject_content, ["project", "dependencies"]) or []
)
if not any(True for dep in dependencies if "crewai" in dep):
raise Exception("crewai is not in the dependencies.")
attribute = _get_nested_value(pyproject_content, keys)
except FileNotFoundError:
console.print(f"Error: {pyproject_path} not found.", style="bold red")
except KeyError:
console.print(
f"Error: {pyproject_path} is not a valid pyproject.toml file.",
style="bold red",
)
except Exception as e:
if sys.version_info >= (3, 11) and isinstance(e, tomllib.TOMLDecodeError):
console.print(
f"Error: {pyproject_path} is not a valid TOML file.", style="bold red"
)
else:
console.print(
f"Error reading the pyproject.toml file: {e}", style="bold red"
)
if require and not attribute:
console.print(
f"Unable to read '{'.'.join(keys)}' in the pyproject.toml file. Please verify that the file exists and contains the specified attribute.",
style="bold red",
)
raise SystemExit
return attribute
def _get_nested_value(data: dict[str, Any], keys: list[str]) -> Any:
return reduce(dict.__getitem__, keys, data)
def fetch_and_json_env_file(env_file_path: str = ".env") -> dict[str, Any]:
"""Fetch the environment variables from a .env file and return them as a dictionary."""
try:
with open(env_file_path, "r") as f:
env_content = f.read()
env_dict = {}
for line in env_content.splitlines():
if line.strip() and not line.strip().startswith("#"):
key, value = line.split("=", 1)
env_dict[key.strip()] = value.strip()
return env_dict
except FileNotFoundError:
console.print(f"Error: {env_file_path} not found.", style="bold red")
except Exception as e:
console.print(f"Error reading the .env file: {e}", style="bold red")
return {}
def tree_copy(source: Path, destination: Path) -> None:
"""Copies the entire directory structure from the source to the destination."""
for item in os.listdir(source):
source_item = os.path.join(source, item)
destination_item = os.path.join(destination, item)
if os.path.isdir(source_item):
shutil.copytree(source_item, destination_item)
else:
shutil.copy2(source_item, destination_item)
def tree_find_and_replace(directory: Path, find: str, replace: str) -> None:
"""Recursively searches through a directory, replacing a target string in
both file contents and filenames with a specified replacement string.
"""
for path, dirs, files in os.walk(os.path.abspath(directory), topdown=False):
for filename in files:
filepath = os.path.join(path, filename)
with open(filepath, "r", encoding="utf-8", errors="ignore") as file:
contents = file.read()
with open(filepath, "w") as file:
file.write(contents.replace(find, replace))
if find in filename:
new_filename = filename.replace(find, replace)
new_filepath = os.path.join(path, new_filename)
os.rename(filepath, new_filepath)
for dirname in dirs:
if find in dirname:
new_dirname = dirname.replace(find, replace)
new_dirpath = os.path.join(path, new_dirname)
old_dirpath = os.path.join(path, dirname)
os.rename(old_dirpath, new_dirpath)
def load_env_vars(folder_path: Path) -> dict[str, Any]:
"""Loads environment variables from a .env file in the specified folder path."""
env_file_path = folder_path / ".env"
env_vars = {}
if env_file_path.exists():
with open(env_file_path, "r") as file:
for line in file:
key, _, value = line.strip().partition("=")
if key and value:
env_vars[key] = value
return env_vars
def update_env_vars(
env_vars: dict[str, Any], provider: str, model: str
) -> dict[str, Any] | None:
"""Updates environment variables with the API key for the selected provider and model."""
provider_config = cast(
list[str],
ENV_VARS.get(
provider,
[
click.prompt(
f"Enter the environment variable name for your {provider.capitalize()} API key",
type=str,
)
],
),
)
api_key_var = provider_config[0]
if api_key_var not in env_vars:
try:
env_vars[api_key_var] = click.prompt(
f"Enter your {provider.capitalize()} API key", type=str, hide_input=True
)
except click.exceptions.Abort:
click.secho("Operation aborted by the user.", fg="red")
return None
else:
click.secho(f"API key already exists for {provider.capitalize()}.", fg="yellow")
env_vars["MODEL"] = model
click.secho(f"Selected model: {model}", fg="green")
return env_vars
def write_env_file(folder_path: Path, env_vars: dict[str, Any]) -> None:
"""Writes environment variables to a .env file in the specified folder."""
env_file_path = folder_path / ".env"
with open(env_file_path, "w") as file:
for key, value in env_vars.items():
file.write(f"{key.upper()}={value}\n")
def is_valid_tool(obj: Any) -> bool:
"""Check if an object is a valid tool class.
Works without importing crewai by checking MRO class names.
Falls back to crewai's ``is_valid_tool`` when available.
"""
try:
from crewai.cli.utils import is_valid_tool as _core_is_valid_tool
return _core_is_valid_tool(obj)
except ImportError:
pass
if isclass(obj):
try:
return any(base.__name__ == "BaseTool" for base in getmro(obj))
except (TypeError, AttributeError):
return False
return False
def extract_available_exports(dir_path: str = "src") -> list[dict[str, Any]]:
"""Extract available tool classes from the project's __init__.py files."""
try:
init_files = Path(dir_path).glob("**/__init__.py")
available_exports: list[dict[str, Any]] = []
for init_file in init_files:
tools = _load_tools_from_init(init_file)
available_exports.extend(tools)
if not available_exports:
_print_no_tools_warning()
raise SystemExit(1)
return available_exports
except SystemExit:
raise
except Exception as e:
console.print(f"[red]Error: Could not extract tool classes: {e!s}[/red]")
console.print(
"Please ensure your project contains valid tools (classes inheriting from BaseTool or functions with @tool decorator)."
)
raise SystemExit(1) from e
def _load_tools_from_init(init_file: Path) -> list[dict[str, Any]]:
"""Load and validate tools from a given __init__.py file."""
import importlib.util as _importlib_util
spec = _importlib_util.spec_from_file_location("temp_module", init_file)
if not spec or not spec.loader:
return []
module = _importlib_util.module_from_spec(spec)
sys.modules["temp_module"] = module
try:
spec.loader.exec_module(module)
if not hasattr(module, "__all__"):
console.print(
f"Warning: No __all__ defined in {init_file}",
style="bold yellow",
)
raise SystemExit(1)
return [
{"name": name}
for name in module.__all__
if hasattr(module, name) and is_valid_tool(getattr(module, name))
]
except SystemExit:
raise
except Exception as e:
console.print(f"[red]Warning: Could not load {init_file}: {e!s}[/red]")
raise SystemExit(1) from e
finally:
sys.modules.pop("temp_module", None)
def _print_no_tools_warning() -> None:
"""Display warning and usage instructions if no tools were found."""
console.print(
"\n[bold yellow]Warning: No valid tools were exposed in your __init__.py file![/bold yellow]"
)
console.print(
"Your __init__.py file must contain all classes that inherit from [bold]BaseTool[/bold] "
"or functions decorated with [bold]@tool[/bold]."
)
console.print(
"\nExample:\n[dim]# In your __init__.py file[/dim]\n"
"[green]__all__ = ['YourTool', 'your_tool_function'][/green]\n\n"
"[dim]# In your tool.py file[/dim]\n"
"[green]from crewai.tools import BaseTool, tool\n\n"
"# Tool class example\n"
"class YourTool(BaseTool):\n"
' name = "your_tool"\n'
' description = "Your tool description"\n'
" # ... rest of implementation\n\n"
"# Decorated function example\n"
"@tool\n"
"def your_tool_function(text: str) -> str:\n"
' """Your tool description"""\n'
" # ... implementation\n"
" return result\n"
)
def build_env_with_tool_repository_credentials(
repository_handle: str,
) -> dict[str, Any]:
repository_handle = repository_handle.upper().replace("-", "_")
settings = Settings()
env = os.environ.copy()
env[f"UV_INDEX_{repository_handle}_USERNAME"] = str(
settings.tool_repository_username or ""
)
env[f"UV_INDEX_{repository_handle}_PASSWORD"] = str(
settings.tool_repository_password or ""
)
return env

View File

@@ -1,215 +0,0 @@
"""Version utilities for CrewAI CLI."""
from collections.abc import Mapping
from datetime import datetime, timedelta
from functools import lru_cache
import importlib.metadata
import json
from pathlib import Path
from typing import Any
from urllib import request
from urllib.error import URLError
import appdirs
from packaging.version import InvalidVersion, Version, parse
@lru_cache(maxsize=1)
def _get_cache_file() -> Path:
"""Get the path to the version cache file.
Cached to avoid repeated filesystem operations.
"""
cache_dir = Path(appdirs.user_cache_dir("crewai"))
cache_dir.mkdir(parents=True, exist_ok=True)
return cache_dir / "version_cache.json"
def get_crewai_version() -> str:
"""Get the version number of CrewAI running the CLI."""
return importlib.metadata.version("crewai")
def _is_cache_valid(cache_data: Mapping[str, Any]) -> bool:
"""Check if the cache is still valid, less than 24 hours old."""
if "timestamp" not in cache_data:
return False
try:
cache_time = datetime.fromisoformat(str(cache_data["timestamp"]))
return datetime.now() - cache_time < timedelta(hours=24)
except (ValueError, TypeError):
return False
def _find_latest_non_yanked_version(
releases: Mapping[str, list[dict[str, Any]]],
) -> str | None:
"""Find the latest non-yanked version from PyPI releases data.
Args:
releases: PyPI releases dict mapping version strings to file info lists.
Returns:
The latest non-yanked version string, or None if all versions are yanked.
"""
best_version: Version | None = None
best_version_str: str | None = None
for version_str, files in releases.items():
try:
v = parse(version_str)
except InvalidVersion:
continue
if v.is_prerelease or v.is_devrelease:
continue
if not files:
continue
all_yanked = all(f.get("yanked", False) for f in files)
if all_yanked:
continue
if best_version is None or v > best_version:
best_version = v
best_version_str = version_str
return best_version_str
def _is_version_yanked(
version_str: str,
releases: Mapping[str, list[dict[str, Any]]],
) -> tuple[bool, str]:
"""Check if a specific version is yanked.
Args:
version_str: The version string to check.
releases: PyPI releases dict mapping version strings to file info lists.
Returns:
Tuple of (is_yanked, yanked_reason).
"""
files = releases.get(version_str, [])
if not files:
return False, ""
all_yanked = all(f.get("yanked", False) for f in files)
if not all_yanked:
return False, ""
for f in files:
reason = f.get("yanked_reason", "")
if reason:
return True, str(reason)
return True, ""
def get_latest_version_from_pypi(timeout: int = 2) -> str | None:
"""Get the latest non-yanked version of CrewAI from PyPI.
Args:
timeout: Request timeout in seconds.
Returns:
Latest non-yanked version string or None if unable to fetch.
"""
cache_file = _get_cache_file()
if cache_file.exists():
try:
cache_data = json.loads(cache_file.read_text())
if _is_cache_valid(cache_data) and "current_version" in cache_data:
version: str | None = cache_data.get("version")
return version
except (json.JSONDecodeError, OSError):
pass
try:
with request.urlopen(
"https://pypi.org/pypi/crewai/json", timeout=timeout
) as response:
data = json.loads(response.read())
releases: dict[str, list[dict[str, Any]]] = data["releases"]
latest_version = _find_latest_non_yanked_version(releases)
current_version = get_crewai_version()
is_yanked, yanked_reason = _is_version_yanked(current_version, releases)
cache_data = {
"version": latest_version,
"timestamp": datetime.now().isoformat(),
"current_version": current_version,
"current_version_yanked": is_yanked,
"current_version_yanked_reason": yanked_reason,
}
cache_file.write_text(json.dumps(cache_data))
return latest_version
except (URLError, json.JSONDecodeError, KeyError, OSError):
return None
def is_current_version_yanked() -> tuple[bool, str]:
"""Check if the currently installed version has been yanked on PyPI.
Reads from cache if available, otherwise triggers a fetch.
Returns:
Tuple of (is_yanked, yanked_reason).
"""
cache_file = _get_cache_file()
if cache_file.exists():
try:
cache_data = json.loads(cache_file.read_text())
if _is_cache_valid(cache_data) and "current_version" in cache_data:
current = get_crewai_version()
if cache_data.get("current_version") == current:
return (
bool(cache_data.get("current_version_yanked", False)),
str(cache_data.get("current_version_yanked_reason", "")),
)
except (json.JSONDecodeError, OSError):
pass
get_latest_version_from_pypi()
try:
cache_data = json.loads(cache_file.read_text())
return (
bool(cache_data.get("current_version_yanked", False)),
str(cache_data.get("current_version_yanked_reason", "")),
)
except (json.JSONDecodeError, OSError):
return False, ""
def check_version() -> tuple[str, str | None]:
"""Check current and latest versions.
Returns:
Tuple of (current_version, latest_version).
latest_version is None if unable to fetch from PyPI.
"""
current = get_crewai_version()
latest = get_latest_version_from_pypi()
return current, latest
def is_newer_version_available() -> tuple[bool, str, str | None]:
"""Check if a newer version is available.
Returns:
Tuple of (is_newer, current_version, latest_version).
"""
current, latest = check_version()
if latest is None:
return False, current, None
try:
return parse(latest) > parse(current), current, latest
except (InvalidVersion, TypeError):
return False, current, latest

View File

@@ -1,91 +0,0 @@
import pytest
from crewai_cli.authentication.main import Oauth2Settings
from crewai_cli.authentication.providers.auth0 import Auth0Provider
class TestAuth0Provider:
@pytest.fixture(autouse=True)
def setup_method(self):
self.valid_settings = Oauth2Settings(
provider="auth0",
domain="test-domain.auth0.com",
client_id="test-client-id",
audience="test-audience"
)
self.provider = Auth0Provider(self.valid_settings)
def test_initialization_with_valid_settings(self):
provider = Auth0Provider(self.valid_settings)
assert provider.settings == self.valid_settings
assert provider.settings.provider == "auth0"
assert provider.settings.domain == "test-domain.auth0.com"
assert provider.settings.client_id == "test-client-id"
assert provider.settings.audience == "test-audience"
def test_get_authorize_url(self):
expected_url = "https://test-domain.auth0.com/oauth/device/code"
assert self.provider.get_authorize_url() == expected_url
def test_get_authorize_url_with_different_domain(self):
settings = Oauth2Settings(
provider="auth0",
domain="my-company.auth0.com",
client_id="test-client",
audience="test-audience"
)
provider = Auth0Provider(settings)
expected_url = "https://my-company.auth0.com/oauth/device/code"
assert provider.get_authorize_url() == expected_url
def test_get_token_url(self):
expected_url = "https://test-domain.auth0.com/oauth/token"
assert self.provider.get_token_url() == expected_url
def test_get_token_url_with_different_domain(self):
settings = Oauth2Settings(
provider="auth0",
domain="another-domain.auth0.com",
client_id="test-client",
audience="test-audience"
)
provider = Auth0Provider(settings)
expected_url = "https://another-domain.auth0.com/oauth/token"
assert provider.get_token_url() == expected_url
def test_get_jwks_url(self):
expected_url = "https://test-domain.auth0.com/.well-known/jwks.json"
assert self.provider.get_jwks_url() == expected_url
def test_get_jwks_url_with_different_domain(self):
settings = Oauth2Settings(
provider="auth0",
domain="dev.auth0.com",
client_id="test-client",
audience="test-audience"
)
provider = Auth0Provider(settings)
expected_url = "https://dev.auth0.com/.well-known/jwks.json"
assert provider.get_jwks_url() == expected_url
def test_get_issuer(self):
expected_issuer = "https://test-domain.auth0.com/"
assert self.provider.get_issuer() == expected_issuer
def test_get_issuer_with_different_domain(self):
settings = Oauth2Settings(
provider="auth0",
domain="prod.auth0.com",
client_id="test-client",
audience="test-audience"
)
provider = Auth0Provider(settings)
expected_issuer = "https://prod.auth0.com/"
assert provider.get_issuer() == expected_issuer
def test_get_audience(self):
assert self.provider.get_audience() == "test-audience"
def test_get_client_id(self):
assert self.provider.get_client_id() == "test-client-id"

View File

@@ -1,141 +0,0 @@
import pytest
from crewai_cli.authentication.main import Oauth2Settings
from crewai_cli.authentication.providers.entra_id import EntraIdProvider
class TestEntraIdProvider:
@pytest.fixture(autouse=True)
def setup_method(self):
self.valid_settings = Oauth2Settings(
provider="entra_id",
domain="tenant-id-abcdef123456",
client_id="test-client-id",
audience="test-audience",
extra={
"scope": "openid profile email api://crewai-cli-dev/read"
}
)
self.provider = EntraIdProvider(self.valid_settings)
def test_initialization_with_valid_settings(self):
provider = EntraIdProvider(self.valid_settings)
assert provider.settings == self.valid_settings
assert provider.settings.provider == "entra_id"
assert provider.settings.domain == "tenant-id-abcdef123456"
assert provider.settings.client_id == "test-client-id"
assert provider.settings.audience == "test-audience"
def test_get_authorize_url(self):
expected_url = "https://login.microsoftonline.com/tenant-id-abcdef123456/oauth2/v2.0/devicecode"
assert self.provider.get_authorize_url() == expected_url
def test_get_authorize_url_with_different_domain(self):
# For EntraID, the domain is the tenant ID.
settings = Oauth2Settings(
provider="entra_id",
domain="my-company.entra.id",
client_id="test-client",
audience="test-audience",
)
provider = EntraIdProvider(settings)
expected_url = "https://login.microsoftonline.com/my-company.entra.id/oauth2/v2.0/devicecode"
assert provider.get_authorize_url() == expected_url
def test_get_token_url(self):
expected_url = "https://login.microsoftonline.com/tenant-id-abcdef123456/oauth2/v2.0/token"
assert self.provider.get_token_url() == expected_url
def test_get_token_url_with_different_domain(self):
# For EntraID, the domain is the tenant ID.
settings = Oauth2Settings(
provider="entra_id",
domain="another-domain.entra.id",
client_id="test-client",
audience="test-audience",
)
provider = EntraIdProvider(settings)
expected_url = "https://login.microsoftonline.com/another-domain.entra.id/oauth2/v2.0/token"
assert provider.get_token_url() == expected_url
def test_get_jwks_url(self):
expected_url = "https://login.microsoftonline.com/tenant-id-abcdef123456/discovery/v2.0/keys"
assert self.provider.get_jwks_url() == expected_url
def test_get_jwks_url_with_different_domain(self):
# For EntraID, the domain is the tenant ID.
settings = Oauth2Settings(
provider="entra_id",
domain="dev.entra.id",
client_id="test-client",
audience="test-audience",
)
provider = EntraIdProvider(settings)
expected_url = "https://login.microsoftonline.com/dev.entra.id/discovery/v2.0/keys"
assert provider.get_jwks_url() == expected_url
def test_get_issuer(self):
expected_issuer = "https://login.microsoftonline.com/tenant-id-abcdef123456/v2.0"
assert self.provider.get_issuer() == expected_issuer
def test_get_issuer_with_different_domain(self):
# For EntraID, the domain is the tenant ID.
settings = Oauth2Settings(
provider="entra_id",
domain="other-tenant-id-xpto",
client_id="test-client",
audience="test-audience",
)
provider = EntraIdProvider(settings)
expected_issuer = "https://login.microsoftonline.com/other-tenant-id-xpto/v2.0"
assert provider.get_issuer() == expected_issuer
def test_get_audience(self):
assert self.provider.get_audience() == "test-audience"
def test_get_audience_assertion_error_when_none(self):
settings = Oauth2Settings(
provider="entra_id",
domain="test-tenant-id",
client_id="test-client-id",
audience=None,
)
provider = EntraIdProvider(settings)
with pytest.raises(ValueError, match="Audience is required"):
provider.get_audience()
def test_get_client_id(self):
assert self.provider.get_client_id() == "test-client-id"
def test_get_required_fields(self):
assert set(self.provider.get_required_fields()) == set(["scope"])
def test_get_oauth_scopes(self):
settings = Oauth2Settings(
provider="entra_id",
domain="tenant-id-abcdef123456",
client_id="test-client-id",
audience="test-audience",
extra={
"scope": "api://crewai-cli-dev/read"
}
)
provider = EntraIdProvider(settings)
assert provider.get_oauth_scopes() == ["openid", "profile", "email", "api://crewai-cli-dev/read"]
def test_get_oauth_scopes_with_multiple_custom_scopes(self):
settings = Oauth2Settings(
provider="entra_id",
domain="tenant-id-abcdef123456",
client_id="test-client-id",
audience="test-audience",
extra={
"scope": "api://crewai-cli-dev/read api://crewai-cli-dev/write custom-scope1 custom-scope2"
}
)
provider = EntraIdProvider(settings)
assert provider.get_oauth_scopes() == ["openid", "profile", "email", "api://crewai-cli-dev/read", "api://crewai-cli-dev/write", "custom-scope1", "custom-scope2"]
def test_base_url(self):
assert self.provider._base_url() == "https://login.microsoftonline.com/tenant-id-abcdef123456"

View File

@@ -1,138 +0,0 @@
import pytest
from crewai_cli.authentication.main import Oauth2Settings
from crewai_cli.authentication.providers.keycloak import KeycloakProvider
class TestKeycloakProvider:
@pytest.fixture(autouse=True)
def setup_method(self):
self.valid_settings = Oauth2Settings(
provider="keycloak",
domain="keycloak.example.com",
client_id="test-client-id",
audience="test-audience",
extra={
"realm": "test-realm"
}
)
self.provider = KeycloakProvider(self.valid_settings)
def test_initialization_with_valid_settings(self):
provider = KeycloakProvider(self.valid_settings)
assert provider.settings == self.valid_settings
assert provider.settings.provider == "keycloak"
assert provider.settings.domain == "keycloak.example.com"
assert provider.settings.client_id == "test-client-id"
assert provider.settings.audience == "test-audience"
assert provider.settings.extra.get("realm") == "test-realm"
def test_get_authorize_url(self):
expected_url = "https://keycloak.example.com/realms/test-realm/protocol/openid-connect/auth/device"
assert self.provider.get_authorize_url() == expected_url
def test_get_authorize_url_with_different_domain(self):
settings = Oauth2Settings(
provider="keycloak",
domain="auth.company.com",
client_id="test-client",
audience="test-audience",
extra={
"realm": "my-realm"
}
)
provider = KeycloakProvider(settings)
expected_url = "https://auth.company.com/realms/my-realm/protocol/openid-connect/auth/device"
assert provider.get_authorize_url() == expected_url
def test_get_token_url(self):
expected_url = "https://keycloak.example.com/realms/test-realm/protocol/openid-connect/token"
assert self.provider.get_token_url() == expected_url
def test_get_token_url_with_different_domain(self):
settings = Oauth2Settings(
provider="keycloak",
domain="sso.enterprise.com",
client_id="test-client",
audience="test-audience",
extra={
"realm": "enterprise-realm"
}
)
provider = KeycloakProvider(settings)
expected_url = "https://sso.enterprise.com/realms/enterprise-realm/protocol/openid-connect/token"
assert provider.get_token_url() == expected_url
def test_get_jwks_url(self):
expected_url = "https://keycloak.example.com/realms/test-realm/protocol/openid-connect/certs"
assert self.provider.get_jwks_url() == expected_url
def test_get_jwks_url_with_different_domain(self):
settings = Oauth2Settings(
provider="keycloak",
domain="identity.org",
client_id="test-client",
audience="test-audience",
extra={
"realm": "org-realm"
}
)
provider = KeycloakProvider(settings)
expected_url = "https://identity.org/realms/org-realm/protocol/openid-connect/certs"
assert provider.get_jwks_url() == expected_url
def test_get_issuer(self):
expected_issuer = "https://keycloak.example.com/realms/test-realm"
assert self.provider.get_issuer() == expected_issuer
def test_get_issuer_with_different_domain(self):
settings = Oauth2Settings(
provider="keycloak",
domain="login.myapp.io",
client_id="test-client",
audience="test-audience",
extra={
"realm": "app-realm"
}
)
provider = KeycloakProvider(settings)
expected_issuer = "https://login.myapp.io/realms/app-realm"
assert provider.get_issuer() == expected_issuer
def test_get_audience(self):
assert self.provider.get_audience() == "test-audience"
def test_get_client_id(self):
assert self.provider.get_client_id() == "test-client-id"
def test_get_required_fields(self):
assert self.provider.get_required_fields() == ["realm"]
def test_oauth2_base_url(self):
assert self.provider._oauth2_base_url() == "https://keycloak.example.com"
def test_oauth2_base_url_strips_https_prefix(self):
settings = Oauth2Settings(
provider="keycloak",
domain="https://keycloak.example.com",
client_id="test-client-id",
audience="test-audience",
extra={
"realm": "test-realm"
}
)
provider = KeycloakProvider(settings)
assert provider._oauth2_base_url() == "https://keycloak.example.com"
def test_oauth2_base_url_strips_http_prefix(self):
settings = Oauth2Settings(
provider="keycloak",
domain="http://keycloak.example.com",
client_id="test-client-id",
audience="test-audience",
extra={
"realm": "test-realm"
}
)
provider = KeycloakProvider(settings)
assert provider._oauth2_base_url() == "https://keycloak.example.com"

View File

@@ -1,257 +0,0 @@
import pytest
from crewai_cli.authentication.main import Oauth2Settings
from crewai_cli.authentication.providers.okta import OktaProvider
class TestOktaProvider:
@pytest.fixture(autouse=True)
def setup_method(self):
self.valid_settings = Oauth2Settings(
provider="okta",
domain="test-domain.okta.com",
client_id="test-client-id",
audience="test-audience",
)
self.provider = OktaProvider(self.valid_settings)
def test_initialization_with_valid_settings(self):
provider = OktaProvider(self.valid_settings)
assert provider.settings == self.valid_settings
assert provider.settings.provider == "okta"
assert provider.settings.domain == "test-domain.okta.com"
assert provider.settings.client_id == "test-client-id"
assert provider.settings.audience == "test-audience"
def test_get_authorize_url(self):
expected_url = "https://test-domain.okta.com/oauth2/default/v1/device/authorize"
assert self.provider.get_authorize_url() == expected_url
def test_get_authorize_url_with_different_domain(self):
settings = Oauth2Settings(
provider="okta",
domain="my-company.okta.com",
client_id="test-client",
audience="test-audience",
)
provider = OktaProvider(settings)
expected_url = "https://my-company.okta.com/oauth2/default/v1/device/authorize"
assert provider.get_authorize_url() == expected_url
def test_get_authorize_url_with_custom_authorization_server_name(self):
settings = Oauth2Settings(
provider="okta",
domain="test-domain.okta.com",
client_id="test-client-id",
audience=None,
extra={
"using_org_auth_server": False,
"authorization_server_name": "my_auth_server_xxxAAA777"
}
)
provider = OktaProvider(settings)
expected_url = "https://test-domain.okta.com/oauth2/my_auth_server_xxxAAA777/v1/device/authorize"
assert provider.get_authorize_url() == expected_url
def test_get_authorize_url_when_using_org_auth_server(self):
settings = Oauth2Settings(
provider="okta",
domain="test-domain.okta.com",
client_id="test-client-id",
audience=None,
extra={
"using_org_auth_server": True,
"authorization_server_name": None
}
)
provider = OktaProvider(settings)
expected_url = "https://test-domain.okta.com/oauth2/v1/device/authorize"
assert provider.get_authorize_url() == expected_url
def test_get_token_url(self):
expected_url = "https://test-domain.okta.com/oauth2/default/v1/token"
assert self.provider.get_token_url() == expected_url
def test_get_token_url_with_different_domain(self):
settings = Oauth2Settings(
provider="okta",
domain="another-domain.okta.com",
client_id="test-client",
audience="test-audience",
)
provider = OktaProvider(settings)
expected_url = "https://another-domain.okta.com/oauth2/default/v1/token"
assert provider.get_token_url() == expected_url
def test_get_token_url_with_custom_authorization_server_name(self):
settings = Oauth2Settings(
provider="okta",
domain="test-domain.okta.com",
client_id="test-client-id",
audience=None,
extra={
"using_org_auth_server": False,
"authorization_server_name": "my_auth_server_xxxAAA777"
}
)
provider = OktaProvider(settings)
expected_url = "https://test-domain.okta.com/oauth2/my_auth_server_xxxAAA777/v1/token"
assert provider.get_token_url() == expected_url
def test_get_token_url_when_using_org_auth_server(self):
settings = Oauth2Settings(
provider="okta",
domain="test-domain.okta.com",
client_id="test-client-id",
audience=None,
extra={
"using_org_auth_server": True,
"authorization_server_name": None
}
)
provider = OktaProvider(settings)
expected_url = "https://test-domain.okta.com/oauth2/v1/token"
assert provider.get_token_url() == expected_url
def test_get_jwks_url(self):
expected_url = "https://test-domain.okta.com/oauth2/default/v1/keys"
assert self.provider.get_jwks_url() == expected_url
def test_get_jwks_url_with_different_domain(self):
settings = Oauth2Settings(
provider="okta",
domain="dev.okta.com",
client_id="test-client",
audience="test-audience",
)
provider = OktaProvider(settings)
expected_url = "https://dev.okta.com/oauth2/default/v1/keys"
assert provider.get_jwks_url() == expected_url
def test_get_jwks_url_with_custom_authorization_server_name(self):
settings = Oauth2Settings(
provider="okta",
domain="test-domain.okta.com",
client_id="test-client-id",
audience=None,
extra={
"using_org_auth_server": False,
"authorization_server_name": "my_auth_server_xxxAAA777"
}
)
provider = OktaProvider(settings)
expected_url = "https://test-domain.okta.com/oauth2/my_auth_server_xxxAAA777/v1/keys"
assert provider.get_jwks_url() == expected_url
def test_get_jwks_url_when_using_org_auth_server(self):
settings = Oauth2Settings(
provider="okta",
domain="test-domain.okta.com",
client_id="test-client-id",
audience=None,
extra={
"using_org_auth_server": True,
"authorization_server_name": None
}
)
provider = OktaProvider(settings)
expected_url = "https://test-domain.okta.com/oauth2/v1/keys"
assert provider.get_jwks_url() == expected_url
def test_get_issuer(self):
expected_issuer = "https://test-domain.okta.com/oauth2/default"
assert self.provider.get_issuer() == expected_issuer
def test_get_issuer_with_different_domain(self):
settings = Oauth2Settings(
provider="okta",
domain="prod.okta.com",
client_id="test-client",
audience="test-audience",
)
provider = OktaProvider(settings)
expected_issuer = "https://prod.okta.com/oauth2/default"
assert provider.get_issuer() == expected_issuer
def test_get_issuer_with_custom_authorization_server_name(self):
settings = Oauth2Settings(
provider="okta",
domain="test-domain.okta.com",
client_id="test-client-id",
audience=None,
extra={
"using_org_auth_server": False,
"authorization_server_name": "my_auth_server_xxxAAA777"
}
)
provider = OktaProvider(settings)
expected_issuer = "https://test-domain.okta.com/oauth2/my_auth_server_xxxAAA777"
assert provider.get_issuer() == expected_issuer
def test_get_issuer_when_using_org_auth_server(self):
settings = Oauth2Settings(
provider="okta",
domain="test-domain.okta.com",
client_id="test-client-id",
audience=None,
extra={
"using_org_auth_server": True,
"authorization_server_name": None
}
)
provider = OktaProvider(settings)
expected_issuer = "https://test-domain.okta.com"
assert provider.get_issuer() == expected_issuer
def test_get_audience(self):
assert self.provider.get_audience() == "test-audience"
def test_get_audience_assertion_error_when_none(self):
settings = Oauth2Settings(
provider="okta",
domain="test-domain.okta.com",
client_id="test-client-id",
audience=None,
)
provider = OktaProvider(settings)
with pytest.raises(ValueError, match="Audience is required"):
provider.get_audience()
def test_get_client_id(self):
assert self.provider.get_client_id() == "test-client-id"
def test_get_required_fields(self):
assert set(self.provider.get_required_fields()) == set(["authorization_server_name", "using_org_auth_server"])
def test_oauth2_base_url(self):
assert self.provider._oauth2_base_url() == "https://test-domain.okta.com/oauth2/default"
def test_oauth2_base_url_with_custom_authorization_server_name(self):
settings = Oauth2Settings(
provider="okta",
domain="test-domain.okta.com",
client_id="test-client-id",
audience=None,
extra={
"using_org_auth_server": False,
"authorization_server_name": "my_auth_server_xxxAAA777"
}
)
provider = OktaProvider(settings)
assert provider._oauth2_base_url() == "https://test-domain.okta.com/oauth2/my_auth_server_xxxAAA777"
def test_oauth2_base_url_when_using_org_auth_server(self):
settings = Oauth2Settings(
provider="okta",
domain="test-domain.okta.com",
client_id="test-client-id",
audience=None,
extra={
"using_org_auth_server": True,
"authorization_server_name": None
}
)
provider = OktaProvider(settings)
assert provider._oauth2_base_url() == "https://test-domain.okta.com/oauth2"

View File

@@ -1,100 +0,0 @@
import pytest
from crewai_cli.authentication.main import Oauth2Settings
from crewai_cli.authentication.providers.workos import WorkosProvider
class TestWorkosProvider:
@pytest.fixture(autouse=True)
def setup_method(self):
self.valid_settings = Oauth2Settings(
provider="workos",
domain="login.company.com",
client_id="test-client-id",
audience="test-audience"
)
self.provider = WorkosProvider(self.valid_settings)
def test_initialization_with_valid_settings(self):
provider = WorkosProvider(self.valid_settings)
assert provider.settings == self.valid_settings
assert provider.settings.provider == "workos"
assert provider.settings.domain == "login.company.com"
assert provider.settings.client_id == "test-client-id"
assert provider.settings.audience == "test-audience"
def test_get_authorize_url(self):
expected_url = "https://login.company.com/oauth2/device_authorization"
assert self.provider.get_authorize_url() == expected_url
def test_get_authorize_url_with_different_domain(self):
settings = Oauth2Settings(
provider="workos",
domain="login.example.com",
client_id="test-client",
audience="test-audience"
)
provider = WorkosProvider(settings)
expected_url = "https://login.example.com/oauth2/device_authorization"
assert provider.get_authorize_url() == expected_url
def test_get_token_url(self):
expected_url = "https://login.company.com/oauth2/token"
assert self.provider.get_token_url() == expected_url
def test_get_token_url_with_different_domain(self):
settings = Oauth2Settings(
provider="workos",
domain="api.workos.com",
client_id="test-client",
audience="test-audience"
)
provider = WorkosProvider(settings)
expected_url = "https://api.workos.com/oauth2/token"
assert provider.get_token_url() == expected_url
def test_get_jwks_url(self):
expected_url = "https://login.company.com/oauth2/jwks"
assert self.provider.get_jwks_url() == expected_url
def test_get_jwks_url_with_different_domain(self):
settings = Oauth2Settings(
provider="workos",
domain="auth.enterprise.com",
client_id="test-client",
audience="test-audience"
)
provider = WorkosProvider(settings)
expected_url = "https://auth.enterprise.com/oauth2/jwks"
assert provider.get_jwks_url() == expected_url
def test_get_issuer(self):
expected_issuer = "https://login.company.com"
assert self.provider.get_issuer() == expected_issuer
def test_get_issuer_with_different_domain(self):
settings = Oauth2Settings(
provider="workos",
domain="sso.company.com",
client_id="test-client",
audience="test-audience"
)
provider = WorkosProvider(settings)
expected_issuer = "https://sso.company.com"
assert provider.get_issuer() == expected_issuer
def test_get_audience(self):
assert self.provider.get_audience() == "test-audience"
def test_get_audience_fallback_to_default(self):
settings = Oauth2Settings(
provider="workos",
domain="login.company.com",
client_id="test-client-id",
audience=None
)
provider = WorkosProvider(settings)
assert provider.get_audience() == ""
def test_get_client_id(self):
assert self.provider.get_client_id() == "test-client-id"

View File

@@ -1,348 +0,0 @@
from datetime import datetime, timedelta
from unittest.mock import MagicMock, call, patch
import pytest
import httpx
from crewai_cli.authentication.main import AuthenticationCommand
from crewai_cli.constants import (
CREWAI_ENTERPRISE_DEFAULT_OAUTH2_AUDIENCE,
CREWAI_ENTERPRISE_DEFAULT_OAUTH2_CLIENT_ID,
CREWAI_ENTERPRISE_DEFAULT_OAUTH2_DOMAIN,
)
class TestAuthenticationCommand:
def setup_method(self):
# Mock Settings so we always use default constants regardless of local config.
with patch("crewai_cli.authentication.main.Settings") as mock_settings:
instance = mock_settings.return_value
instance.oauth2_provider = "workos"
instance.oauth2_domain = CREWAI_ENTERPRISE_DEFAULT_OAUTH2_DOMAIN
instance.oauth2_client_id = CREWAI_ENTERPRISE_DEFAULT_OAUTH2_CLIENT_ID
instance.oauth2_audience = CREWAI_ENTERPRISE_DEFAULT_OAUTH2_AUDIENCE
instance.oauth2_extra = {}
self.auth_command = AuthenticationCommand()
@pytest.mark.parametrize(
"user_provider,expected_urls",
[
(
"workos",
{
"device_code_url": f"https://{CREWAI_ENTERPRISE_DEFAULT_OAUTH2_DOMAIN}/oauth2/device_authorization",
"token_url": f"https://{CREWAI_ENTERPRISE_DEFAULT_OAUTH2_DOMAIN}/oauth2/token",
"client_id": CREWAI_ENTERPRISE_DEFAULT_OAUTH2_CLIENT_ID,
"audience": CREWAI_ENTERPRISE_DEFAULT_OAUTH2_AUDIENCE,
"domain": CREWAI_ENTERPRISE_DEFAULT_OAUTH2_DOMAIN,
},
),
],
)
@patch("crewai_cli.authentication.main.AuthenticationCommand._get_device_code")
@patch(
"crewai_cli.authentication.main.AuthenticationCommand._display_auth_instructions"
)
@patch("crewai_cli.authentication.main.AuthenticationCommand._poll_for_token")
@patch("crewai_cli.authentication.main.console.print")
def test_login(
self,
mock_console_print,
mock_poll,
mock_display,
mock_get_device,
user_provider,
expected_urls,
):
mock_get_device.return_value = {
"device_code": "test_code",
"user_code": "123456",
}
self.auth_command.login()
mock_console_print.assert_called_once_with(
"Signing in to CrewAI AMP...\n", style="bold blue"
)
mock_get_device.assert_called_once()
mock_display.assert_called_once_with(
{"device_code": "test_code", "user_code": "123456"}
)
mock_poll.assert_called_once_with(
{"device_code": "test_code", "user_code": "123456"},
)
assert (
self.auth_command.oauth2_provider.get_client_id()
== expected_urls["client_id"]
)
assert (
self.auth_command.oauth2_provider.get_audience()
== expected_urls["audience"]
)
assert (
self.auth_command.oauth2_provider._get_domain() == expected_urls["domain"]
)
@patch("crewai_cli.authentication.main.webbrowser")
@patch("crewai_cli.authentication.main.console.print")
def test_display_auth_instructions(self, mock_console_print, mock_webbrowser):
device_code_data = {
"verification_uri_complete": "https://example.com/auth",
"user_code": "123456",
}
self.auth_command._display_auth_instructions(device_code_data)
expected_calls = [
call("1. Navigate to: ", "https://example.com/auth"),
call("2. Enter the following code: ", "123456"),
]
mock_console_print.assert_has_calls(expected_calls)
mock_webbrowser.open.assert_called_once_with("https://example.com/auth")
@pytest.mark.parametrize(
"user_provider,jwt_config",
[
(
"workos",
{
"jwks_url": f"https://{CREWAI_ENTERPRISE_DEFAULT_OAUTH2_DOMAIN}/oauth2/jwks",
"issuer": f"https://{CREWAI_ENTERPRISE_DEFAULT_OAUTH2_DOMAIN}",
"audience": CREWAI_ENTERPRISE_DEFAULT_OAUTH2_AUDIENCE,
},
),
],
)
@pytest.mark.parametrize("has_expiration", [True, False])
@patch("crewai_cli.authentication.main.validate_jwt_token")
@patch("crewai_cli.authentication.main.TokenManager.save_tokens")
def test_validate_and_save_token(
self,
mock_save_tokens,
mock_validate_jwt,
user_provider,
jwt_config,
has_expiration,
):
from crewai_cli.authentication.main import Oauth2Settings
from crewai_cli.authentication.providers.workos import WorkosProvider
if user_provider == "workos":
self.auth_command.oauth2_provider = WorkosProvider(
settings=Oauth2Settings(
provider=user_provider,
client_id="test-client-id",
domain=CREWAI_ENTERPRISE_DEFAULT_OAUTH2_DOMAIN,
audience=jwt_config["audience"],
)
)
token_data = {"access_token": "test_access_token", "id_token": "test_id_token"}
if has_expiration:
future_timestamp = int((datetime.now() + timedelta(days=100)).timestamp())
decoded_token = {"exp": future_timestamp}
else:
decoded_token = {}
mock_validate_jwt.return_value = decoded_token
self.auth_command._validate_and_save_token(token_data)
mock_validate_jwt.assert_called_once_with(
jwt_token="test_access_token",
jwks_url=jwt_config["jwks_url"],
issuer=jwt_config["issuer"],
audience=jwt_config["audience"],
)
if has_expiration:
mock_save_tokens.assert_called_once_with(
"test_access_token", future_timestamp
)
else:
mock_save_tokens.assert_called_once_with("test_access_token", 0)
@patch("crewai_cli.tools.main.ToolCommand")
@patch("crewai_cli.authentication.main.Settings")
@patch("crewai_cli.authentication.main.console.print")
def test_login_to_tool_repository_success(
self, mock_console_print, mock_settings, mock_tool_command
):
mock_tool_instance = MagicMock()
mock_tool_command.return_value = mock_tool_instance
mock_settings_instance = MagicMock()
mock_settings_instance.org_name = "Test Org"
mock_settings_instance.org_uuid = "test-uuid-123"
mock_settings.return_value = mock_settings_instance
self.auth_command._login_to_tool_repository()
mock_tool_command.assert_called_once()
mock_tool_instance.login.assert_called_once()
expected_calls = [
call(
"Now logging you in to the Tool Repository... ",
style="bold blue",
end="",
),
call("Success!\n", style="bold green"),
call(
"You are now authenticated to the tool repository for organization [bold cyan]'Test Org'[/bold cyan]",
style="green",
),
]
mock_console_print.assert_has_calls(expected_calls)
@patch("crewai_cli.tools.main.ToolCommand")
@patch("crewai_cli.authentication.main.console.print")
def test_login_to_tool_repository_error(
self, mock_console_print, mock_tool_command
):
mock_tool_instance = MagicMock()
mock_tool_instance.login.side_effect = Exception("Tool repository error")
mock_tool_command.return_value = mock_tool_instance
self.auth_command._login_to_tool_repository()
mock_tool_command.assert_called_once()
mock_tool_instance.login.assert_called_once()
expected_calls = [
call(
"Now logging you in to the Tool Repository... ",
style="bold blue",
end="",
),
call(
"\n[bold yellow]Warning:[/bold yellow] Authentication with the Tool Repository failed.",
style="yellow",
),
call(
"Other features will work normally, but you may experience limitations with downloading and publishing tools.\nRun [bold]crewai login[/bold] to try logging in again.\n",
style="yellow",
),
]
mock_console_print.assert_has_calls(expected_calls)
@patch("crewai_cli.authentication.main.httpx.post")
def test_get_device_code(self, mock_post):
mock_response = MagicMock()
mock_response.json.return_value = {
"device_code": "test_device_code",
"user_code": "123456",
"verification_uri_complete": "https://example.com/auth",
}
mock_post.return_value = mock_response
self.auth_command.oauth2_provider = MagicMock()
self.auth_command.oauth2_provider.get_client_id.return_value = "test_client"
self.auth_command.oauth2_provider.get_authorize_url.return_value = (
"https://example.com/device"
)
self.auth_command.oauth2_provider.get_audience.return_value = "test_audience"
self.auth_command.oauth2_provider.get_oauth_scopes.return_value = ["openid", "profile", "email"]
result = self.auth_command._get_device_code()
mock_post.assert_called_once_with(
url="https://example.com/device",
data={
"client_id": "test_client",
"scope": "openid profile email",
"audience": "test_audience",
},
timeout=20,
)
assert result == {
"device_code": "test_device_code",
"user_code": "123456",
"verification_uri_complete": "https://example.com/auth",
}
@patch("crewai_cli.authentication.main.httpx.post")
@patch("crewai_cli.authentication.main.console.print")
def test_poll_for_token_success(self, mock_console_print, mock_post):
mock_response_success = MagicMock()
mock_response_success.status_code = 200
mock_response_success.json.return_value = {
"access_token": "test_access_token",
"id_token": "test_id_token",
}
mock_post.return_value = mock_response_success
device_code_data = {"device_code": "test_device_code", "interval": 1}
with (
patch.object(
self.auth_command, "_validate_and_save_token"
) as mock_validate,
patch.object(
self.auth_command, "_login_to_tool_repository"
) as mock_tool_login,
):
self.auth_command.oauth2_provider = MagicMock()
self.auth_command.oauth2_provider.get_token_url.return_value = (
"https://example.com/token"
)
self.auth_command.oauth2_provider.get_client_id.return_value = "test_client"
self.auth_command._poll_for_token(device_code_data)
mock_post.assert_called_once_with(
"https://example.com/token",
data={
"grant_type": "urn:ietf:params:oauth:grant-type:device_code",
"device_code": "test_device_code",
"client_id": "test_client",
},
timeout=30,
)
mock_validate.assert_called_once()
mock_tool_login.assert_called_once()
expected_calls = [
call("\nWaiting for authentication... ", style="bold blue", end=""),
call("Success!", style="bold green"),
call("\n[bold green]Welcome to CrewAI AMP![/bold green]\n"),
]
mock_console_print.assert_has_calls(expected_calls)
@patch("crewai_cli.authentication.main.httpx.post")
@patch("crewai_cli.authentication.main.console.print")
def test_poll_for_token_timeout(self, mock_console_print, mock_post):
mock_response_pending = MagicMock()
mock_response_pending.status_code = 400
mock_response_pending.json.return_value = {"error": "authorization_pending"}
mock_post.return_value = mock_response_pending
device_code_data = {
"device_code": "test_device_code",
"interval": 0.1, # Short interval for testing
}
self.auth_command._poll_for_token(device_code_data)
mock_console_print.assert_any_call(
"Timeout: Failed to get the token. Please try again.", style="bold red"
)
@patch("crewai_cli.authentication.main.httpx.post")
def test_poll_for_token_error(self, mock_post):
"""Test the method to poll for token (error path)."""
# Setup mock to return error
mock_response_error = MagicMock()
mock_response_error.status_code = 400
mock_response_error.json.return_value = {
"error": "access_denied",
"error_description": "User denied access",
}
mock_post.return_value = mock_response_error
device_code_data = {"device_code": "test_device_code", "interval": 1}
with pytest.raises(httpx.HTTPError):
self.auth_command._poll_for_token(device_code_data)

View File

@@ -1,107 +0,0 @@
import unittest
from unittest.mock import MagicMock, patch
import jwt
from crewai_cli.authentication.utils import validate_jwt_token
@patch("crewai_cli.authentication.utils.PyJWKClient", return_value=MagicMock())
@patch("crewai_cli.authentication.utils.jwt")
class TestUtils(unittest.TestCase):
def test_validate_jwt_token(self, mock_jwt, mock_pyjwkclient):
mock_jwt.decode.return_value = {"exp": 1719859200}
# Create signing key object mock with a .key attribute
mock_pyjwkclient.return_value.get_signing_key_from_jwt.return_value = MagicMock(
key="mock_signing_key"
)
jwt_token = "aaaaa.bbbbbb.cccccc" # noqa: S105
decoded_token = validate_jwt_token(
jwt_token=jwt_token,
jwks_url="https://mock_jwks_url",
issuer="https://mock_issuer",
audience="app_id_xxxx",
)
mock_jwt.decode.assert_called_with(
jwt_token,
"mock_signing_key",
algorithms=["RS256"],
audience="app_id_xxxx",
issuer="https://mock_issuer",
leeway=10.0,
options={
"verify_signature": True,
"verify_exp": True,
"verify_nbf": True,
"verify_iat": True,
"require": ["exp", "iat", "iss", "aud", "sub"],
},
)
mock_pyjwkclient.assert_called_once_with("https://mock_jwks_url")
self.assertEqual(decoded_token, {"exp": 1719859200})
def test_validate_jwt_token_expired(self, mock_jwt, mock_pyjwkclient):
mock_jwt.decode.side_effect = jwt.ExpiredSignatureError
with self.assertRaises(Exception): # noqa: B017
validate_jwt_token(
jwt_token="aaaaa.bbbbbb.cccccc", # noqa: S106
jwks_url="https://mock_jwks_url",
issuer="https://mock_issuer",
audience="app_id_xxxx",
)
def test_validate_jwt_token_invalid_audience(self, mock_jwt, mock_pyjwkclient):
mock_jwt.decode.side_effect = jwt.InvalidAudienceError
with self.assertRaises(Exception): # noqa: B017
validate_jwt_token(
jwt_token="aaaaa.bbbbbb.cccccc", # noqa: S106
jwks_url="https://mock_jwks_url",
issuer="https://mock_issuer",
audience="app_id_xxxx",
)
def test_validate_jwt_token_invalid_issuer(self, mock_jwt, mock_pyjwkclient):
mock_jwt.decode.side_effect = jwt.InvalidIssuerError
with self.assertRaises(Exception): # noqa: B017
validate_jwt_token(
jwt_token="aaaaa.bbbbbb.cccccc", # noqa: S106
jwks_url="https://mock_jwks_url",
issuer="https://mock_issuer",
audience="app_id_xxxx",
)
def test_validate_jwt_token_missing_required_claims(
self, mock_jwt, mock_pyjwkclient
):
mock_jwt.decode.side_effect = jwt.MissingRequiredClaimError
with self.assertRaises(Exception): # noqa: B017
validate_jwt_token(
jwt_token="aaaaa.bbbbbb.cccccc", # noqa: S106
jwks_url="https://mock_jwks_url",
issuer="https://mock_issuer",
audience="app_id_xxxx",
)
def test_validate_jwt_token_jwks_error(self, mock_jwt, mock_pyjwkclient):
mock_jwt.decode.side_effect = jwt.exceptions.PyJWKClientError
with self.assertRaises(Exception): # noqa: B017
validate_jwt_token(
jwt_token="aaaaa.bbbbbb.cccccc", # noqa: S106
jwks_url="https://mock_jwks_url",
issuer="https://mock_issuer",
audience="app_id_xxxx",
)
def test_validate_jwt_token_invalid_token(self, mock_jwt, mock_pyjwkclient):
mock_jwt.decode.side_effect = jwt.InvalidTokenError
with self.assertRaises(Exception): # noqa: B017
validate_jwt_token(
jwt_token="aaaaa.bbbbbb.cccccc", # noqa: S106
jwks_url="https://mock_jwks_url",
issuer="https://mock_issuer",
audience="app_id_xxxx",
)

View File

@@ -1,255 +0,0 @@
from pathlib import Path
from unittest import mock
import pytest
from click.testing import CliRunner
from crewai_cli.cli import (
deploy_create,
deploy_list,
deploy_logs,
deploy_push,
deploy_remove,
deply_status,
flow_add_crew,
login,
reset_memories,
test,
train,
version,
)
@pytest.fixture
def runner():
return CliRunner()
@mock.patch("crewai_cli.cli.train_crew")
def test_train_default_iterations(train_crew, runner):
result = runner.invoke(train)
train_crew.assert_called_once_with(5, "trained_agents_data.pkl")
assert result.exit_code == 0
assert "Training the Crew for 5 iterations" in result.output
@mock.patch("crewai_cli.cli.train_crew")
def test_train_custom_iterations(train_crew, runner):
result = runner.invoke(train, ["--n_iterations", "10"])
train_crew.assert_called_once_with(10, "trained_agents_data.pkl")
assert result.exit_code == 0
assert "Training the Crew for 10 iterations" in result.output
@mock.patch("crewai_cli.cli.train_crew")
def test_train_invalid_string_iterations(train_crew, runner):
result = runner.invoke(train, ["--n_iterations", "invalid"])
train_crew.assert_not_called()
assert result.exit_code == 2
assert (
"Usage: train [OPTIONS]\nTry 'train --help' for help.\n\nError: Invalid value for '-n' / '--n_iterations': 'invalid' is not a valid integer.\n"
in result.output
)
def test_reset_no_memory_flags(runner):
result = runner.invoke(
reset_memories,
)
assert (
result.output
== "Please specify at least one memory type to reset using the appropriate flags.\n"
)
def test_version_flag(runner):
result = runner.invoke(version)
assert result.exit_code == 0
assert "crewai version:" in result.output
def test_version_command(runner):
result = runner.invoke(version)
assert result.exit_code == 0
assert "crewai version:" in result.output
def test_version_command_with_tools(runner):
result = runner.invoke(version, ["--tools"])
assert result.exit_code == 0
assert "crewai version:" in result.output
assert (
"crewai tools version:" in result.output
or "crewai tools not installed" in result.output
)
@mock.patch("crewai_cli.cli.evaluate_crew")
def test_test_default_iterations(evaluate_crew, runner):
result = runner.invoke(test)
evaluate_crew.assert_called_once_with(3, "gpt-4o-mini")
assert result.exit_code == 0
assert "Testing the crew for 3 iterations with model gpt-4o-mini" in result.output
@mock.patch("crewai_cli.cli.evaluate_crew")
def test_test_custom_iterations(evaluate_crew, runner):
result = runner.invoke(test, ["--n_iterations", "5", "--model", "gpt-4o"])
evaluate_crew.assert_called_once_with(5, "gpt-4o")
assert result.exit_code == 0
assert "Testing the crew for 5 iterations with model gpt-4o" in result.output
@mock.patch("crewai_cli.cli.evaluate_crew")
def test_test_invalid_string_iterations(evaluate_crew, runner):
result = runner.invoke(test, ["--n_iterations", "invalid"])
evaluate_crew.assert_not_called()
assert result.exit_code == 2
assert (
"Usage: test [OPTIONS]\nTry 'test --help' for help.\n\nError: Invalid value for '-n' / '--n_iterations': 'invalid' is not a valid integer.\n"
in result.output
)
@mock.patch("crewai_cli.cli.AuthenticationCommand")
def test_login(command, runner):
mock_auth = command.return_value
result = runner.invoke(login)
assert result.exit_code == 0
mock_auth.login.assert_called_once()
@mock.patch("crewai_cli.cli.DeployCommand")
def test_deploy_create(command, runner):
mock_deploy = command.return_value
result = runner.invoke(deploy_create)
assert result.exit_code == 0
mock_deploy.create_crew.assert_called_once()
@mock.patch("crewai_cli.cli.DeployCommand")
def test_deploy_list(command, runner):
mock_deploy = command.return_value
result = runner.invoke(deploy_list)
assert result.exit_code == 0
mock_deploy.list_crews.assert_called_once()
@mock.patch("crewai_cli.cli.DeployCommand")
def test_deploy_push(command, runner):
mock_deploy = command.return_value
uuid = "test-uuid"
result = runner.invoke(deploy_push, ["-u", uuid])
assert result.exit_code == 0
mock_deploy.deploy.assert_called_once_with(uuid=uuid)
@mock.patch("crewai_cli.cli.DeployCommand")
def test_deploy_push_no_uuid(command, runner):
mock_deploy = command.return_value
result = runner.invoke(deploy_push)
assert result.exit_code == 0
mock_deploy.deploy.assert_called_once_with(uuid=None)
@mock.patch("crewai_cli.cli.DeployCommand")
def test_deploy_status(command, runner):
mock_deploy = command.return_value
uuid = "test-uuid"
result = runner.invoke(deply_status, ["-u", uuid])
assert result.exit_code == 0
mock_deploy.get_crew_status.assert_called_once_with(uuid=uuid)
@mock.patch("crewai_cli.cli.DeployCommand")
def test_deploy_status_no_uuid(command, runner):
mock_deploy = command.return_value
result = runner.invoke(deply_status)
assert result.exit_code == 0
mock_deploy.get_crew_status.assert_called_once_with(uuid=None)
@mock.patch("crewai_cli.cli.DeployCommand")
def test_deploy_logs(command, runner):
mock_deploy = command.return_value
uuid = "test-uuid"
result = runner.invoke(deploy_logs, ["-u", uuid])
assert result.exit_code == 0
mock_deploy.get_crew_logs.assert_called_once_with(uuid=uuid)
@mock.patch("crewai_cli.cli.DeployCommand")
def test_deploy_logs_no_uuid(command, runner):
mock_deploy = command.return_value
result = runner.invoke(deploy_logs)
assert result.exit_code == 0
mock_deploy.get_crew_logs.assert_called_once_with(uuid=None)
@mock.patch("crewai_cli.cli.DeployCommand")
def test_deploy_remove(command, runner):
mock_deploy = command.return_value
uuid = "test-uuid"
result = runner.invoke(deploy_remove, ["-u", uuid])
assert result.exit_code == 0
mock_deploy.remove_crew.assert_called_once_with(uuid=uuid)
@mock.patch("crewai_cli.cli.DeployCommand")
def test_deploy_remove_no_uuid(command, runner):
mock_deploy = command.return_value
result = runner.invoke(deploy_remove)
assert result.exit_code == 0
mock_deploy.remove_crew.assert_called_once_with(uuid=None)
@mock.patch("crewai_cli.add_crew_to_flow.create_embedded_crew")
@mock.patch("pathlib.Path.exists", return_value=True)
def test_flow_add_crew(mock_path_exists, mock_create_embedded_crew, runner):
crew_name = "new_crew"
result = runner.invoke(flow_add_crew, [crew_name])
assert result.exit_code == 0, f"Command failed with output: {result.output}"
assert f"Adding crew {crew_name} to the flow" in result.output
mock_create_embedded_crew.assert_called_once()
call_args, call_kwargs = mock_create_embedded_crew.call_args
assert call_args[0] == crew_name
assert "parent_folder" in call_kwargs
assert isinstance(call_kwargs["parent_folder"], Path)
def test_add_crew_to_flow_not_in_root(runner):
with mock.patch("pathlib.Path.exists", autospec=True) as mock_exists:
def exists_side_effect(self):
if self.name == "pyproject.toml":
return False
return True
mock_exists.side_effect = exists_side_effect
result = runner.invoke(flow_add_crew, ["new_crew"])
assert result.exit_code != 0
assert "This command must be run from the root of a flow project." in str(
result.output
)

View File

@@ -1,148 +0,0 @@
import json
import shutil
import tempfile
import unittest
from datetime import datetime, timedelta
from pathlib import Path
from unittest.mock import MagicMock, patch
from crewai_cli.config import (
CLI_SETTINGS_KEYS,
DEFAULT_CLI_SETTINGS,
USER_SETTINGS_KEYS,
Settings,
)
from crewai_cli.shared.token_manager import TokenManager
class TestSettings(unittest.TestCase):
def setUp(self):
self.test_dir = Path(tempfile.mkdtemp())
self.config_path = self.test_dir / "settings.json"
def tearDown(self):
shutil.rmtree(self.test_dir)
def test_empty_initialization(self):
settings = Settings(config_path=self.config_path)
self.assertIsNone(settings.tool_repository_username)
self.assertIsNone(settings.tool_repository_password)
def test_initialization_with_data(self):
settings = Settings(
config_path=self.config_path, tool_repository_username="user1"
)
self.assertEqual(settings.tool_repository_username, "user1")
self.assertIsNone(settings.tool_repository_password)
def test_initialization_with_existing_file(self):
self.config_path.parent.mkdir(parents=True, exist_ok=True)
with self.config_path.open("w") as f:
json.dump({"tool_repository_username": "file_user"}, f)
settings = Settings(config_path=self.config_path)
self.assertEqual(settings.tool_repository_username, "file_user")
def test_merge_file_and_input_data(self):
self.config_path.parent.mkdir(parents=True, exist_ok=True)
with self.config_path.open("w") as f:
json.dump(
{
"tool_repository_username": "file_user",
"tool_repository_password": "file_pass",
},
f,
)
settings = Settings(
config_path=self.config_path, tool_repository_username="new_user"
)
self.assertEqual(settings.tool_repository_username, "new_user")
self.assertEqual(settings.tool_repository_password, "file_pass")
def test_clear_user_settings(self):
user_settings = {key: f"value_for_{key}" for key in USER_SETTINGS_KEYS}
settings = Settings(config_path=self.config_path, **user_settings)
settings.clear_user_settings()
for key in user_settings.keys():
self.assertEqual(getattr(settings, key), None)
@patch("crewai_cli.config.TokenManager")
def test_reset_settings(self, mock_token_manager):
user_settings = {key: f"value_for_{key}" for key in USER_SETTINGS_KEYS}
cli_settings = {key: f"value_for_{key}" for key in CLI_SETTINGS_KEYS if key != "oauth2_extra"}
cli_settings["oauth2_extra"] = {"scope": "xxx", "other": "yyy"}
settings = Settings(
config_path=self.config_path, **user_settings, **cli_settings
)
mock_token_manager.return_value = MagicMock()
TokenManager().save_tokens(
"aaa.bbb.ccc", (datetime.now() + timedelta(seconds=36000)).timestamp()
)
settings.reset()
for key in user_settings.keys():
self.assertEqual(getattr(settings, key), None)
for key in cli_settings.keys():
self.assertEqual(getattr(settings, key), DEFAULT_CLI_SETTINGS.get(key))
mock_token_manager.return_value.clear_tokens.assert_called_once()
def test_dump_new_settings(self):
settings = Settings(
config_path=self.config_path, tool_repository_username="user1"
)
settings.dump()
with self.config_path.open("r") as f:
saved_data = json.load(f)
self.assertEqual(saved_data["tool_repository_username"], "user1")
def test_update_existing_settings(self):
self.config_path.parent.mkdir(parents=True, exist_ok=True)
with self.config_path.open("w") as f:
json.dump({"existing_setting": "value"}, f)
settings = Settings(
config_path=self.config_path, tool_repository_username="user1"
)
settings.dump()
with self.config_path.open("r") as f:
saved_data = json.load(f)
self.assertEqual(saved_data["existing_setting"], "value")
self.assertEqual(saved_data["tool_repository_username"], "user1")
def test_none_values(self):
settings = Settings(config_path=self.config_path, tool_repository_username=None)
settings.dump()
with self.config_path.open("r") as f:
saved_data = json.load(f)
self.assertIsNone(saved_data.get("tool_repository_username"))
def test_invalid_json_in_config(self):
self.config_path.parent.mkdir(parents=True, exist_ok=True)
with self.config_path.open("w") as f:
f.write("invalid json")
try:
settings = Settings(config_path=self.config_path)
self.assertIsNone(settings.tool_repository_username)
except json.JSONDecodeError:
self.fail("Settings initialization should handle invalid JSON")
def test_empty_config_file(self):
self.config_path.parent.mkdir(parents=True, exist_ok=True)
self.config_path.touch()
settings = Settings(config_path=self.config_path)
self.assertIsNone(settings.tool_repository_username)

View File

@@ -1,20 +0,0 @@
from crewai_cli.constants import ENV_VARS, MODELS, PROVIDERS
def test_huggingface_in_providers():
"""Test that Huggingface is in the PROVIDERS list."""
assert "huggingface" in PROVIDERS
def test_huggingface_env_vars():
"""Test that Huggingface environment variables are properly configured."""
assert "huggingface" in ENV_VARS
assert any(
detail.get("key_name") == "HF_TOKEN" for detail in ENV_VARS["huggingface"]
)
def test_huggingface_models():
"""Test that Huggingface models are properly configured."""
assert "huggingface" in MODELS
assert len(MODELS["huggingface"]) > 0

View File

@@ -1,101 +0,0 @@
import pytest
from crewai_cli.git import Repository
@pytest.fixture()
def repository(fp):
fp.register(["git", "--version"], stdout="git version 2.30.0\n")
fp.register(["git", "rev-parse", "--is-inside-work-tree"], stdout="true\n")
fp.register(["git", "fetch"], stdout="")
return Repository(path=".")
def test_init_with_invalid_git_repo(fp):
fp.register(["git", "--version"], stdout="git version 2.30.0\n")
fp.register(
["git", "rev-parse", "--is-inside-work-tree"],
returncode=1,
stderr="fatal: not a git repository\n",
)
with pytest.raises(ValueError):
Repository(path="invalid/path")
def test_is_git_not_installed(fp):
fp.register(["git", "--version"], returncode=1)
with pytest.raises(
ValueError, match="Git is not installed or not found in your PATH."
):
Repository(path=".")
def test_status(fp, repository):
fp.register(
["git", "status", "--branch", "--porcelain"],
stdout="## main...origin/main [ahead 1]\n",
)
assert repository.status() == "## main...origin/main [ahead 1]"
def test_has_uncommitted_changes(fp, repository):
fp.register(
["git", "status", "--branch", "--porcelain"],
stdout="## main...origin/main\n M somefile.txt\n",
)
assert repository.has_uncommitted_changes() is True
def test_is_ahead_or_behind(fp, repository):
fp.register(
["git", "status", "--branch", "--porcelain"],
stdout="## main...origin/main [ahead 1]\n",
)
assert repository.is_ahead_or_behind() is True
def test_is_synced_when_synced(fp, repository):
fp.register(
["git", "status", "--branch", "--porcelain"], stdout="## main...origin/main\n"
)
fp.register(
["git", "status", "--branch", "--porcelain"], stdout="## main...origin/main\n"
)
assert repository.is_synced() is True
def test_is_synced_with_uncommitted_changes(fp, repository):
fp.register(
["git", "status", "--branch", "--porcelain"],
stdout="## main...origin/main\n M somefile.txt\n",
)
assert repository.is_synced() is False
def test_is_synced_when_ahead_or_behind(fp, repository):
fp.register(
["git", "status", "--branch", "--porcelain"],
stdout="## main...origin/main [ahead 1]\n",
)
fp.register(
["git", "status", "--branch", "--porcelain"],
stdout="## main...origin/main [ahead 1]\n",
)
assert repository.is_synced() is False
def test_is_synced_with_uncommitted_changes_and_ahead(fp, repository):
fp.register(
["git", "status", "--branch", "--porcelain"],
stdout="## main...origin/main [ahead 1]\n M somefile.txt\n",
)
assert repository.is_synced() is False
def test_origin_url(fp, repository):
fp.register(
["git", "remote", "get-url", "origin"],
stdout="https://github.com/user/repo.git\n",
)
assert repository.origin_url() == "https://github.com/user/repo.git"

View File

@@ -1,356 +0,0 @@
import os
import unittest
from unittest.mock import ANY, AsyncMock, MagicMock, patch
import pytest
from crewai_cli.plus_api import PlusAPI
class TestPlusAPI(unittest.TestCase):
def setUp(self):
self.api_key = "test_api_key"
self.api = PlusAPI(self.api_key)
self.org_uuid = "test-org-uuid"
def test_init(self):
self.assertEqual(self.api.api_key, self.api_key)
self.assertEqual(self.api.headers["Authorization"], f"Bearer {self.api_key}")
self.assertEqual(self.api.headers["Content-Type"], "application/json")
self.assertTrue("CrewAI-CLI/" in self.api.headers["User-Agent"])
self.assertTrue(self.api.headers["X-Crewai-Version"])
@patch("crewai_cli.plus_api.PlusAPI._make_request")
def test_login_to_tool_repository(self, mock_make_request):
mock_response = MagicMock()
mock_make_request.return_value = mock_response
response = self.api.login_to_tool_repository()
mock_make_request.assert_called_once_with(
"POST", "/crewai_plus/api/v1/tools/login"
)
self.assertEqual(response, mock_response)
def assert_request_with_org_id(
self, mock_client_instance, method: str, endpoint: str, **kwargs
):
mock_client_instance.request.assert_called_once_with(
method,
f"{os.getenv('CREWAI_PLUS_URL')}{endpoint}",
headers={
"Authorization": ANY,
"Content-Type": ANY,
"User-Agent": ANY,
"X-Crewai-Version": ANY,
"X-Crewai-Organization-Id": self.org_uuid,
},
**kwargs,
)
@patch("crewai_cli.plus_api.Settings")
@patch("crewai_cli.plus_api.httpx.Client")
def test_login_to_tool_repository_with_org_uuid(
self, mock_client_class, mock_settings_class
):
mock_settings = MagicMock()
mock_settings.org_uuid = self.org_uuid
mock_settings.enterprise_base_url = os.getenv('CREWAI_PLUS_URL')
mock_settings_class.return_value = mock_settings
self.api = PlusAPI(self.api_key)
mock_client_instance = MagicMock()
mock_response = MagicMock()
mock_client_instance.request.return_value = mock_response
mock_client_class.return_value.__enter__.return_value = mock_client_instance
response = self.api.login_to_tool_repository()
self.assert_request_with_org_id(
mock_client_instance, "POST", "/crewai_plus/api/v1/tools/login"
)
self.assertEqual(response, mock_response)
@patch("crewai_cli.plus_api.PlusAPI._make_request")
def test_get_tool(self, mock_make_request):
mock_response = MagicMock()
mock_make_request.return_value = mock_response
response = self.api.get_tool("test_tool_handle")
mock_make_request.assert_called_once_with(
"GET", "/crewai_plus/api/v1/tools/test_tool_handle"
)
self.assertEqual(response, mock_response)
@patch("crewai_cli.plus_api.Settings")
@patch("crewai_cli.plus_api.httpx.Client")
def test_get_tool_with_org_uuid(self, mock_client_class, mock_settings_class):
mock_settings = MagicMock()
mock_settings.org_uuid = self.org_uuid
mock_settings.enterprise_base_url = os.getenv('CREWAI_PLUS_URL')
mock_settings_class.return_value = mock_settings
self.api = PlusAPI(self.api_key)
mock_client_instance = MagicMock()
mock_response = MagicMock()
mock_client_instance.request.return_value = mock_response
mock_client_class.return_value.__enter__.return_value = mock_client_instance
response = self.api.get_tool("test_tool_handle")
self.assert_request_with_org_id(
mock_client_instance, "GET", "/crewai_plus/api/v1/tools/test_tool_handle"
)
self.assertEqual(response, mock_response)
@patch("crewai_cli.plus_api.PlusAPI._make_request")
def test_publish_tool(self, mock_make_request):
mock_response = MagicMock()
mock_make_request.return_value = mock_response
handle = "test_tool_handle"
public = True
version = "1.0.0"
description = "Test tool description"
encoded_file = "encoded_test_file"
response = self.api.publish_tool(
handle, public, version, description, encoded_file
)
params = {
"handle": handle,
"public": public,
"version": version,
"file": encoded_file,
"description": description,
"available_exports": None,
}
mock_make_request.assert_called_once_with(
"POST", "/crewai_plus/api/v1/tools", json=params
)
self.assertEqual(response, mock_response)
@patch("crewai_cli.plus_api.Settings")
@patch("crewai_cli.plus_api.httpx.Client")
def test_publish_tool_with_org_uuid(self, mock_client_class, mock_settings_class):
mock_settings = MagicMock()
mock_settings.org_uuid = self.org_uuid
mock_settings.enterprise_base_url = os.getenv('CREWAI_PLUS_URL')
mock_settings_class.return_value = mock_settings
self.api = PlusAPI(self.api_key)
mock_client_instance = MagicMock()
mock_response = MagicMock()
mock_client_instance.request.return_value = mock_response
mock_client_class.return_value.__enter__.return_value = mock_client_instance
handle = "test_tool_handle"
public = True
version = "1.0.0"
description = "Test tool description"
encoded_file = "encoded_test_file"
response = self.api.publish_tool(
handle, public, version, description, encoded_file
)
expected_params = {
"handle": handle,
"public": public,
"version": version,
"file": encoded_file,
"description": description,
"available_exports": None,
}
self.assert_request_with_org_id(
mock_client_instance, "POST", "/crewai_plus/api/v1/tools", json=expected_params
)
self.assertEqual(response, mock_response)
@patch("crewai_cli.plus_api.PlusAPI._make_request")
def test_publish_tool_without_description(self, mock_make_request):
mock_response = MagicMock()
mock_make_request.return_value = mock_response
handle = "test_tool_handle"
public = False
version = "2.0.0"
description = None
encoded_file = "encoded_test_file"
response = self.api.publish_tool(
handle, public, version, description, encoded_file
)
params = {
"handle": handle,
"public": public,
"version": version,
"file": encoded_file,
"description": description,
"available_exports": None,
}
mock_make_request.assert_called_once_with(
"POST", "/crewai_plus/api/v1/tools", json=params
)
self.assertEqual(response, mock_response)
@patch("crewai_cli.plus_api.httpx.Client")
def test_make_request(self, mock_client_class):
mock_client_instance = MagicMock()
mock_response = MagicMock()
mock_client_instance.request.return_value = mock_response
mock_client_class.return_value.__enter__.return_value = mock_client_instance
response = self.api._make_request("GET", "test_endpoint")
mock_client_class.assert_called_once_with(trust_env=False, verify=True)
mock_client_instance.request.assert_called_once_with(
"GET", f"{self.api.base_url}/test_endpoint", headers=self.api.headers
)
self.assertEqual(response, mock_response)
@patch("crewai_cli.plus_api.PlusAPI._make_request")
def test_deploy_by_name(self, mock_make_request):
self.api.deploy_by_name("test_project")
mock_make_request.assert_called_once_with(
"POST", "/crewai_plus/api/v1/crews/by-name/test_project/deploy"
)
@patch("crewai_cli.plus_api.PlusAPI._make_request")
def test_deploy_by_uuid(self, mock_make_request):
self.api.deploy_by_uuid("test_uuid")
mock_make_request.assert_called_once_with(
"POST", "/crewai_plus/api/v1/crews/test_uuid/deploy"
)
@patch("crewai_cli.plus_api.PlusAPI._make_request")
def test_crew_status_by_name(self, mock_make_request):
self.api.crew_status_by_name("test_project")
mock_make_request.assert_called_once_with(
"GET", "/crewai_plus/api/v1/crews/by-name/test_project/status"
)
@patch("crewai_cli.plus_api.PlusAPI._make_request")
def test_crew_status_by_uuid(self, mock_make_request):
self.api.crew_status_by_uuid("test_uuid")
mock_make_request.assert_called_once_with(
"GET", "/crewai_plus/api/v1/crews/test_uuid/status"
)
@patch("crewai_cli.plus_api.PlusAPI._make_request")
def test_crew_by_name(self, mock_make_request):
self.api.crew_by_name("test_project")
mock_make_request.assert_called_once_with(
"GET", "/crewai_plus/api/v1/crews/by-name/test_project/logs/deployment"
)
self.api.crew_by_name("test_project", "custom_log")
mock_make_request.assert_called_with(
"GET", "/crewai_plus/api/v1/crews/by-name/test_project/logs/custom_log"
)
@patch("crewai_cli.plus_api.PlusAPI._make_request")
def test_crew_by_uuid(self, mock_make_request):
self.api.crew_by_uuid("test_uuid")
mock_make_request.assert_called_once_with(
"GET", "/crewai_plus/api/v1/crews/test_uuid/logs/deployment"
)
self.api.crew_by_uuid("test_uuid", "custom_log")
mock_make_request.assert_called_with(
"GET", "/crewai_plus/api/v1/crews/test_uuid/logs/custom_log"
)
@patch("crewai_cli.plus_api.PlusAPI._make_request")
def test_delete_crew_by_name(self, mock_make_request):
self.api.delete_crew_by_name("test_project")
mock_make_request.assert_called_once_with(
"DELETE", "/crewai_plus/api/v1/crews/by-name/test_project"
)
@patch("crewai_cli.plus_api.PlusAPI._make_request")
def test_delete_crew_by_uuid(self, mock_make_request):
self.api.delete_crew_by_uuid("test_uuid")
mock_make_request.assert_called_once_with(
"DELETE", "/crewai_plus/api/v1/crews/test_uuid"
)
@patch("crewai_cli.plus_api.PlusAPI._make_request")
def test_list_crews(self, mock_make_request):
self.api.list_crews()
mock_make_request.assert_called_once_with("GET", "/crewai_plus/api/v1/crews")
@patch("crewai_cli.plus_api.PlusAPI._make_request")
def test_create_crew(self, mock_make_request):
payload = {"name": "test_crew"}
self.api.create_crew(payload)
mock_make_request.assert_called_once_with(
"POST", "/crewai_plus/api/v1/crews", json=payload
)
@patch("crewai_cli.plus_api.Settings")
@patch.dict(os.environ, {"CREWAI_PLUS_URL": ""})
def test_custom_base_url(self, mock_settings_class):
mock_settings = MagicMock()
mock_settings.enterprise_base_url = "https://custom-url.com/api"
mock_settings_class.return_value = mock_settings
custom_api = PlusAPI("test_key")
self.assertEqual(
custom_api.base_url,
"https://custom-url.com/api",
)
@patch.dict(os.environ, {"CREWAI_PLUS_URL": "https://custom-url-from-env.com"})
def test_custom_base_url_from_env(self):
custom_api = PlusAPI("test_key")
self.assertEqual(
custom_api.base_url,
"https://custom-url-from-env.com",
)
@pytest.mark.asyncio
@patch("httpx.AsyncClient")
async def test_get_agent(mock_async_client_class):
api = PlusAPI("test_api_key")
mock_response = MagicMock()
mock_client_instance = AsyncMock()
mock_client_instance.get.return_value = mock_response
mock_async_client_class.return_value.__aenter__.return_value = mock_client_instance
response = await api.get_agent("test_agent_handle")
mock_client_instance.get.assert_called_once_with(
f"{api.base_url}/crewai_plus/api/v1/agents/test_agent_handle",
headers=api.headers,
)
assert response == mock_response
@pytest.mark.asyncio
@patch("httpx.AsyncClient")
@patch("crewai_cli.plus_api.Settings")
async def test_get_agent_with_org_uuid(mock_settings_class, mock_async_client_class):
org_uuid = "test-org-uuid"
mock_settings = MagicMock()
mock_settings.org_uuid = org_uuid
mock_settings.enterprise_base_url = os.getenv("CREWAI_PLUS_URL")
mock_settings_class.return_value = mock_settings
api = PlusAPI("test_api_key")
mock_response = MagicMock()
mock_client_instance = AsyncMock()
mock_client_instance.get.return_value = mock_response
mock_async_client_class.return_value.__aenter__.return_value = mock_client_instance
response = await api.get_agent("test_agent_handle")
mock_client_instance.get.assert_called_once_with(
f"{api.base_url}/crewai_plus/api/v1/agents/test_agent_handle",
headers=api.headers,
)
assert "X-Crewai-Organization-Id" in api.headers
assert api.headers["X-Crewai-Organization-Id"] == org_uuid
assert response == mock_response

View File

@@ -1,294 +0,0 @@
"""Tests for TokenManager with atomic file operations."""
import json
import os
import tempfile
import unittest
from datetime import datetime, timedelta
from pathlib import Path
from unittest.mock import patch
from cryptography.fernet import Fernet
from crewai_cli.shared.token_manager import TokenManager
class TestTokenManager(unittest.TestCase):
"""Test cases for TokenManager."""
@patch("crewai_cli.shared.token_manager.TokenManager._get_or_create_key")
def setUp(self, mock_get_key: unittest.mock.MagicMock) -> None:
"""Set up test fixtures."""
mock_get_key.return_value = Fernet.generate_key()
self.token_manager = TokenManager()
@patch("crewai_cli.shared.token_manager.TokenManager._read_secure_file")
@patch("crewai_cli.shared.token_manager.TokenManager._get_or_create_key")
def test_get_or_create_key_existing(
self,
mock_get_or_create: unittest.mock.MagicMock,
mock_read: unittest.mock.MagicMock,
) -> None:
"""Test that existing key is returned when present."""
mock_key = Fernet.generate_key()
mock_get_or_create.return_value = mock_key
token_manager = TokenManager()
result = token_manager.key
self.assertEqual(result, mock_key)
def test_get_or_create_key_new(self) -> None:
"""Test that new key is created when none exists."""
mock_key = Fernet.generate_key()
with (
patch.object(self.token_manager, "_read_secure_file", return_value=None) as mock_read,
patch.object(self.token_manager, "_atomic_create_secure_file", return_value=True) as mock_atomic_create,
patch("crewai_cli.shared.token_manager.Fernet.generate_key", return_value=mock_key) as mock_generate,
):
result = self.token_manager._get_or_create_key()
self.assertEqual(result, mock_key)
mock_read.assert_called_with("secret.key")
mock_generate.assert_called_once()
mock_atomic_create.assert_called_once_with("secret.key", mock_key)
def test_get_or_create_key_race_condition(self) -> None:
"""Test that another process's key is used when atomic create fails."""
our_key = Fernet.generate_key()
their_key = Fernet.generate_key()
with (
patch.object(self.token_manager, "_read_secure_file", side_effect=[None, their_key]) as mock_read,
patch.object(self.token_manager, "_atomic_create_secure_file", return_value=False) as mock_atomic_create,
patch("crewai_cli.shared.token_manager.Fernet.generate_key", return_value=our_key),
):
result = self.token_manager._get_or_create_key()
self.assertEqual(result, their_key)
self.assertEqual(mock_read.call_count, 2)
@patch("crewai_cli.shared.token_manager.TokenManager._atomic_write_secure_file")
def test_save_tokens(
self, mock_write: unittest.mock.MagicMock
) -> None:
"""Test saving tokens encrypts and writes atomically."""
access_token = "test_token"
expires_at = int((datetime.now() + timedelta(seconds=3600)).timestamp())
self.token_manager.save_tokens(access_token, expires_at)
mock_write.assert_called_once()
args = mock_write.call_args[0]
self.assertEqual(args[0], "tokens.enc")
decrypted_data = self.token_manager.fernet.decrypt(args[1])
data = json.loads(decrypted_data)
self.assertEqual(data["access_token"], access_token)
expiration = datetime.fromisoformat(data["expiration"])
self.assertEqual(expiration, datetime.fromtimestamp(expires_at))
@patch("crewai_cli.shared.token_manager.TokenManager._read_secure_file")
def test_get_token_valid(
self, mock_read: unittest.mock.MagicMock
) -> None:
"""Test getting a valid non-expired token."""
access_token = "test_token"
expiration = (datetime.now() + timedelta(hours=1)).isoformat()
data = {"access_token": access_token, "expiration": expiration}
encrypted_data = self.token_manager.fernet.encrypt(json.dumps(data).encode())
mock_read.return_value = encrypted_data
result = self.token_manager.get_token()
self.assertEqual(result, access_token)
@patch("crewai_cli.shared.token_manager.TokenManager._read_secure_file")
def test_get_token_expired(
self, mock_read: unittest.mock.MagicMock
) -> None:
"""Test that expired token returns None."""
access_token = "test_token"
expiration = (datetime.now() - timedelta(hours=1)).isoformat()
data = {"access_token": access_token, "expiration": expiration}
encrypted_data = self.token_manager.fernet.encrypt(json.dumps(data).encode())
mock_read.return_value = encrypted_data
result = self.token_manager.get_token()
self.assertIsNone(result)
@patch("crewai_cli.shared.token_manager.TokenManager._read_secure_file")
def test_get_token_not_found(
self, mock_read: unittest.mock.MagicMock
) -> None:
"""Test that missing token file returns None."""
mock_read.return_value = None
result = self.token_manager.get_token()
self.assertIsNone(result)
@patch("crewai_cli.shared.token_manager.TokenManager._delete_secure_file")
def test_clear_tokens(
self, mock_delete: unittest.mock.MagicMock
) -> None:
"""Test clearing tokens deletes the token file."""
self.token_manager.clear_tokens()
mock_delete.assert_called_once_with("tokens.enc")
class TestAtomicFileOperations(unittest.TestCase):
"""Test atomic file operations directly."""
def setUp(self) -> None:
"""Set up test fixtures with temp directory."""
self.temp_dir = tempfile.mkdtemp()
self.original_get_path = TokenManager._get_secure_storage_path
# Patch to use temp directory
def mock_get_path() -> Path:
return Path(self.temp_dir)
TokenManager._get_secure_storage_path = staticmethod(mock_get_path)
def tearDown(self) -> None:
"""Clean up temp directory."""
TokenManager._get_secure_storage_path = staticmethod(self.original_get_path)
import shutil
shutil.rmtree(self.temp_dir, ignore_errors=True)
@patch("crewai_cli.shared.token_manager.TokenManager._get_or_create_key")
def test_atomic_create_new_file(
self, mock_get_key: unittest.mock.MagicMock
) -> None:
"""Test atomic create succeeds for new file."""
mock_get_key.return_value = Fernet.generate_key()
tm = TokenManager()
result = tm._atomic_create_secure_file("test.txt", b"content")
self.assertTrue(result)
file_path = Path(self.temp_dir) / "test.txt"
self.assertTrue(file_path.exists())
self.assertEqual(file_path.read_bytes(), b"content")
self.assertEqual(file_path.stat().st_mode & 0o777, 0o600)
@patch("crewai_cli.shared.token_manager.TokenManager._get_or_create_key")
def test_atomic_create_existing_file(
self, mock_get_key: unittest.mock.MagicMock
) -> None:
"""Test atomic create fails for existing file."""
mock_get_key.return_value = Fernet.generate_key()
tm = TokenManager()
# Create file first
file_path = Path(self.temp_dir) / "test.txt"
file_path.write_bytes(b"original")
result = tm._atomic_create_secure_file("test.txt", b"new content")
self.assertFalse(result)
self.assertEqual(file_path.read_bytes(), b"original")
@patch("crewai_cli.shared.token_manager.TokenManager._get_or_create_key")
def test_atomic_write_new_file(
self, mock_get_key: unittest.mock.MagicMock
) -> None:
"""Test atomic write creates new file."""
mock_get_key.return_value = Fernet.generate_key()
tm = TokenManager()
tm._atomic_write_secure_file("test.txt", b"content")
file_path = Path(self.temp_dir) / "test.txt"
self.assertTrue(file_path.exists())
self.assertEqual(file_path.read_bytes(), b"content")
self.assertEqual(file_path.stat().st_mode & 0o777, 0o600)
@patch("crewai_cli.shared.token_manager.TokenManager._get_or_create_key")
def test_atomic_write_overwrites(
self, mock_get_key: unittest.mock.MagicMock
) -> None:
"""Test atomic write overwrites existing file."""
mock_get_key.return_value = Fernet.generate_key()
tm = TokenManager()
file_path = Path(self.temp_dir) / "test.txt"
file_path.write_bytes(b"original")
tm._atomic_write_secure_file("test.txt", b"new content")
self.assertEqual(file_path.read_bytes(), b"new content")
@patch("crewai_cli.shared.token_manager.TokenManager._get_or_create_key")
def test_atomic_write_no_temp_file_on_success(
self, mock_get_key: unittest.mock.MagicMock
) -> None:
"""Test that temp file is cleaned up after successful write."""
mock_get_key.return_value = Fernet.generate_key()
tm = TokenManager()
tm._atomic_write_secure_file("test.txt", b"content")
# Check no temp files remain
temp_files = list(Path(self.temp_dir).glob(".test.txt.*"))
self.assertEqual(len(temp_files), 0)
@patch("crewai_cli.shared.token_manager.TokenManager._get_or_create_key")
def test_read_secure_file_exists(
self, mock_get_key: unittest.mock.MagicMock
) -> None:
"""Test reading existing file."""
mock_get_key.return_value = Fernet.generate_key()
tm = TokenManager()
file_path = Path(self.temp_dir) / "test.txt"
file_path.write_bytes(b"content")
result = tm._read_secure_file("test.txt")
self.assertEqual(result, b"content")
@patch("crewai_cli.shared.token_manager.TokenManager._get_or_create_key")
def test_read_secure_file_not_exists(
self, mock_get_key: unittest.mock.MagicMock
) -> None:
"""Test reading non-existent file returns None."""
mock_get_key.return_value = Fernet.generate_key()
tm = TokenManager()
result = tm._read_secure_file("nonexistent.txt")
self.assertIsNone(result)
@patch("crewai_cli.shared.token_manager.TokenManager._get_or_create_key")
def test_delete_secure_file_exists(
self, mock_get_key: unittest.mock.MagicMock
) -> None:
"""Test deleting existing file."""
mock_get_key.return_value = Fernet.generate_key()
tm = TokenManager()
file_path = Path(self.temp_dir) / "test.txt"
file_path.write_bytes(b"content")
tm._delete_secure_file("test.txt")
self.assertFalse(file_path.exists())
@patch("crewai_cli.shared.token_manager.TokenManager._get_or_create_key")
def test_delete_secure_file_not_exists(
self, mock_get_key: unittest.mock.MagicMock
) -> None:
"""Test deleting non-existent file doesn't raise."""
mock_get_key.return_value = Fernet.generate_key()
tm = TokenManager()
# Should not raise
tm._delete_secure_file("nonexistent.txt")
if __name__ == "__main__":
unittest.main()

View File

@@ -1,146 +0,0 @@
import os
import shutil
import tempfile
from pathlib import Path
import pytest
from crewai_cli import utils
@pytest.fixture
def temp_tree():
root_dir = tempfile.mkdtemp()
create_file(os.path.join(root_dir, "file1.txt"), "Hello, world!")
create_file(os.path.join(root_dir, "file2.txt"), "Another file")
os.mkdir(os.path.join(root_dir, "empty_dir"))
nested_dir = os.path.join(root_dir, "nested_dir")
os.mkdir(nested_dir)
create_file(os.path.join(nested_dir, "nested_file.txt"), "Nested content")
yield root_dir
shutil.rmtree(root_dir)
def create_file(path, content):
with open(path, "w") as f:
f.write(content)
def test_tree_find_and_replace_file_content(temp_tree):
utils.tree_find_and_replace(temp_tree, "world", "universe")
with open(os.path.join(temp_tree, "file1.txt"), "r") as f:
assert f.read() == "Hello, universe!"
def test_tree_find_and_replace_file_name(temp_tree):
old_path = os.path.join(temp_tree, "file2.txt")
new_path = os.path.join(temp_tree, "file2_renamed.txt")
os.rename(old_path, new_path)
utils.tree_find_and_replace(temp_tree, "renamed", "modified")
assert os.path.exists(os.path.join(temp_tree, "file2_modified.txt"))
assert not os.path.exists(new_path)
def test_tree_find_and_replace_directory_name(temp_tree):
utils.tree_find_and_replace(temp_tree, "empty", "renamed")
assert os.path.exists(os.path.join(temp_tree, "renamed_dir"))
assert not os.path.exists(os.path.join(temp_tree, "empty_dir"))
def test_tree_find_and_replace_nested_content(temp_tree):
utils.tree_find_and_replace(temp_tree, "Nested", "Updated")
with open(os.path.join(temp_tree, "nested_dir", "nested_file.txt"), "r") as f:
assert f.read() == "Updated content"
def test_tree_find_and_replace_no_matches(temp_tree):
utils.tree_find_and_replace(temp_tree, "nonexistent", "replacement")
assert set(os.listdir(temp_tree)) == {
"file1.txt",
"file2.txt",
"empty_dir",
"nested_dir",
}
def test_tree_copy_full_structure(temp_tree):
dest_dir = tempfile.mkdtemp()
try:
utils.tree_copy(temp_tree, dest_dir)
assert set(os.listdir(dest_dir)) == set(os.listdir(temp_tree))
assert os.path.isfile(os.path.join(dest_dir, "file1.txt"))
assert os.path.isfile(os.path.join(dest_dir, "file2.txt"))
assert os.path.isdir(os.path.join(dest_dir, "empty_dir"))
assert os.path.isdir(os.path.join(dest_dir, "nested_dir"))
assert os.path.isfile(os.path.join(dest_dir, "nested_dir", "nested_file.txt"))
finally:
shutil.rmtree(dest_dir)
def test_tree_copy_preserve_content(temp_tree):
dest_dir = tempfile.mkdtemp()
try:
utils.tree_copy(temp_tree, dest_dir)
with open(os.path.join(dest_dir, "file1.txt"), "r") as f:
assert f.read() == "Hello, world!"
with open(os.path.join(dest_dir, "nested_dir", "nested_file.txt"), "r") as f:
assert f.read() == "Nested content"
finally:
shutil.rmtree(dest_dir)
def test_tree_copy_to_existing_directory(temp_tree):
dest_dir = tempfile.mkdtemp()
try:
create_file(os.path.join(dest_dir, "existing_file.txt"), "I was here first")
utils.tree_copy(temp_tree, dest_dir)
assert os.path.isfile(os.path.join(dest_dir, "existing_file.txt"))
assert os.path.isfile(os.path.join(dest_dir, "file1.txt"))
finally:
shutil.rmtree(dest_dir)
@pytest.fixture
def temp_project_dir():
"""Create a temporary directory for testing tool extraction."""
with tempfile.TemporaryDirectory() as temp_dir:
yield Path(temp_dir)
def create_init_file(directory, content):
return create_file(directory / "__init__.py", content)
def test_extract_available_exports_empty_project(temp_project_dir, capsys):
with pytest.raises(SystemExit):
utils.extract_available_exports(dir_path=temp_project_dir)
captured = capsys.readouterr()
assert "No valid tools were exposed in your __init__.py file" in captured.out
def test_extract_available_exports_no_init_file(temp_project_dir, capsys):
(temp_project_dir / "some_file.py").write_text("print('hello')")
with pytest.raises(SystemExit):
utils.extract_available_exports(dir_path=temp_project_dir)
captured = capsys.readouterr()
assert "No valid tools were exposed in your __init__.py file" in captured.out
def test_extract_available_exports_empty_init_file(temp_project_dir, capsys):
create_init_file(temp_project_dir, "")
with pytest.raises(SystemExit):
utils.extract_available_exports(dir_path=temp_project_dir)
captured = capsys.readouterr()
assert "Warning: No __all__ defined in" in captured.out
# Tests for extract_available_exports with crewai.tools (BaseTool, @tool)
# remain in lib/crewai/tests/cli/test_utils.py as they require the crewai core package.
# Tests for get_crews, get_flows, fetch_crews, is_valid_tool
# remain in lib/crewai/tests/cli/test_utils.py as they require the crewai core package.

View File

@@ -1,372 +0,0 @@
"""Test for version management."""
import json
from datetime import datetime, timedelta
from pathlib import Path
from unittest.mock import MagicMock, patch
from crewai_cli.version import get_crewai_version as _get_ver
from crewai_cli.version import (
_find_latest_non_yanked_version,
_get_cache_file,
_is_cache_valid,
_is_version_yanked,
get_crewai_version,
get_latest_version_from_pypi,
is_current_version_yanked,
is_newer_version_available,
)
def test_dynamic_versioning_consistency() -> None:
"""Test that dynamic versioning provides consistent version across all access methods."""
cli_version = get_crewai_version()
package_version = _get_ver()
assert cli_version == package_version
assert package_version is not None
assert len(package_version.strip()) > 0
class TestVersionChecking:
"""Test version checking utilities."""
def test_get_crewai_version(self) -> None:
"""Test getting current crewai version."""
version = get_crewai_version()
assert isinstance(version, str)
assert len(version) > 0
def test_get_cache_file(self) -> None:
"""Test cache file path generation."""
cache_file = _get_cache_file()
assert isinstance(cache_file, Path)
assert cache_file.name == "version_cache.json"
def test_is_cache_valid_with_fresh_cache(self) -> None:
"""Test cache validation with fresh cache."""
cache_data = {"timestamp": datetime.now().isoformat(), "version": "1.0.0"}
assert _is_cache_valid(cache_data) is True
def test_is_cache_valid_with_stale_cache(self) -> None:
"""Test cache validation with stale cache."""
old_time = datetime.now() - timedelta(hours=25)
cache_data = {"timestamp": old_time.isoformat(), "version": "1.0.0"}
assert _is_cache_valid(cache_data) is False
def test_is_cache_valid_with_missing_timestamp(self) -> None:
"""Test cache validation with missing timestamp."""
cache_data = {"version": "1.0.0"}
assert _is_cache_valid(cache_data) is False
@patch("crewai_cli.version.Path.exists")
@patch("crewai_cli.version.request.urlopen")
def test_get_latest_version_from_pypi_success(
self, mock_urlopen: MagicMock, mock_exists: MagicMock
) -> None:
"""Test successful PyPI version fetch uses releases data."""
mock_exists.return_value = False
releases = {
"1.0.0": [{"yanked": False}],
"2.0.0": [{"yanked": False}],
"2.1.0": [{"yanked": True, "yanked_reason": "bad release"}],
}
mock_response = MagicMock()
mock_response.read.return_value = json.dumps(
{"info": {"version": "2.1.0"}, "releases": releases}
).encode()
mock_urlopen.return_value.__enter__.return_value = mock_response
version = get_latest_version_from_pypi()
assert version == "2.0.0"
@patch("crewai_cli.version.Path.exists")
@patch("crewai_cli.version.request.urlopen")
def test_get_latest_version_from_pypi_failure(
self, mock_urlopen: MagicMock, mock_exists: MagicMock
) -> None:
"""Test PyPI version fetch failure."""
from urllib.error import URLError
mock_exists.return_value = False
mock_urlopen.side_effect = URLError("Network error")
version = get_latest_version_from_pypi()
assert version is None
@patch("crewai_cli.version.get_crewai_version")
@patch("crewai_cli.version.get_latest_version_from_pypi")
def test_is_newer_version_available_true(
self, mock_latest: MagicMock, mock_current: MagicMock
) -> None:
"""Test when newer version is available."""
mock_current.return_value = "1.0.0"
mock_latest.return_value = "2.0.0"
is_newer, current, latest = is_newer_version_available()
assert is_newer is True
assert current == "1.0.0"
assert latest == "2.0.0"
@patch("crewai_cli.version.get_crewai_version")
@patch("crewai_cli.version.get_latest_version_from_pypi")
def test_is_newer_version_available_false(
self, mock_latest: MagicMock, mock_current: MagicMock
) -> None:
"""Test when no newer version is available."""
mock_current.return_value = "2.0.0"
mock_latest.return_value = "2.0.0"
is_newer, current, latest = is_newer_version_available()
assert is_newer is False
assert current == "2.0.0"
assert latest == "2.0.0"
@patch("crewai_cli.version.get_crewai_version")
@patch("crewai_cli.version.get_latest_version_from_pypi")
def test_is_newer_version_available_with_none_latest(
self, mock_latest: MagicMock, mock_current: MagicMock
) -> None:
"""Test when PyPI fetch fails."""
mock_current.return_value = "1.0.0"
mock_latest.return_value = None
is_newer, current, latest = is_newer_version_available()
assert is_newer is False
assert current == "1.0.0"
assert latest is None
class TestFindLatestNonYankedVersion:
"""Test _find_latest_non_yanked_version helper."""
def test_skips_yanked_versions(self) -> None:
"""Test that yanked versions are skipped."""
releases = {
"1.0.0": [{"yanked": False}],
"2.0.0": [{"yanked": True}],
}
assert _find_latest_non_yanked_version(releases) == "1.0.0"
def test_returns_highest_non_yanked(self) -> None:
"""Test that the highest non-yanked version is returned."""
releases = {
"1.0.0": [{"yanked": False}],
"1.5.0": [{"yanked": False}],
"2.0.0": [{"yanked": True}],
}
assert _find_latest_non_yanked_version(releases) == "1.5.0"
def test_returns_none_when_all_yanked(self) -> None:
"""Test that None is returned when all versions are yanked."""
releases = {
"1.0.0": [{"yanked": True}],
"2.0.0": [{"yanked": True}],
}
assert _find_latest_non_yanked_version(releases) is None
def test_skips_prerelease_versions(self) -> None:
"""Test that pre-release versions are skipped."""
releases = {
"1.0.0": [{"yanked": False}],
"2.0.0a1": [{"yanked": False}],
"2.0.0rc1": [{"yanked": False}],
}
assert _find_latest_non_yanked_version(releases) == "1.0.0"
def test_skips_versions_with_empty_files(self) -> None:
"""Test that versions with no files are skipped."""
releases: dict[str, list[dict[str, bool]]] = {
"1.0.0": [{"yanked": False}],
"2.0.0": [],
}
assert _find_latest_non_yanked_version(releases) == "1.0.0"
def test_handles_invalid_version_strings(self) -> None:
"""Test that invalid version strings are skipped."""
releases = {
"1.0.0": [{"yanked": False}],
"not-a-version": [{"yanked": False}],
}
assert _find_latest_non_yanked_version(releases) == "1.0.0"
def test_partially_yanked_files_not_considered_yanked(self) -> None:
"""Test that a version with some non-yanked files is not yanked."""
releases = {
"1.0.0": [{"yanked": False}],
"2.0.0": [{"yanked": True}, {"yanked": False}],
}
assert _find_latest_non_yanked_version(releases) == "2.0.0"
class TestIsVersionYanked:
"""Test _is_version_yanked helper."""
def test_non_yanked_version(self) -> None:
"""Test a non-yanked version returns False."""
releases = {"1.0.0": [{"yanked": False}]}
is_yanked, reason = _is_version_yanked("1.0.0", releases)
assert is_yanked is False
assert reason == ""
def test_yanked_version_with_reason(self) -> None:
"""Test a yanked version returns True with reason."""
releases = {
"1.0.0": [{"yanked": True, "yanked_reason": "critical bug"}],
}
is_yanked, reason = _is_version_yanked("1.0.0", releases)
assert is_yanked is True
assert reason == "critical bug"
def test_yanked_version_without_reason(self) -> None:
"""Test a yanked version returns True with empty reason."""
releases = {"1.0.0": [{"yanked": True}]}
is_yanked, reason = _is_version_yanked("1.0.0", releases)
assert is_yanked is True
assert reason == ""
def test_unknown_version(self) -> None:
"""Test an unknown version returns False."""
releases = {"1.0.0": [{"yanked": False}]}
is_yanked, reason = _is_version_yanked("9.9.9", releases)
assert is_yanked is False
assert reason == ""
def test_partially_yanked_files(self) -> None:
"""Test a version with mixed yanked/non-yanked files is not yanked."""
releases = {
"1.0.0": [{"yanked": True}, {"yanked": False}],
}
is_yanked, reason = _is_version_yanked("1.0.0", releases)
assert is_yanked is False
assert reason == ""
def test_multiple_yanked_files_picks_first_reason(self) -> None:
"""Test that the first available reason is returned."""
releases = {
"1.0.0": [
{"yanked": True, "yanked_reason": ""},
{"yanked": True, "yanked_reason": "second reason"},
],
}
is_yanked, reason = _is_version_yanked("1.0.0", releases)
assert is_yanked is True
assert reason == "second reason"
class TestIsCurrentVersionYanked:
"""Test is_current_version_yanked public function."""
@patch("crewai_cli.version.get_crewai_version")
@patch("crewai_cli.version._get_cache_file")
def test_reads_from_valid_cache(
self, mock_cache_file: MagicMock, mock_version: MagicMock, tmp_path: Path
) -> None:
"""Test reading yanked status from a valid cache."""
mock_version.return_value = "1.0.0"
cache_file = tmp_path / "version_cache.json"
cache_data = {
"version": "2.0.0",
"timestamp": datetime.now().isoformat(),
"current_version": "1.0.0",
"current_version_yanked": True,
"current_version_yanked_reason": "bad release",
}
cache_file.write_text(json.dumps(cache_data))
mock_cache_file.return_value = cache_file
is_yanked, reason = is_current_version_yanked()
assert is_yanked is True
assert reason == "bad release"
@patch("crewai_cli.version.get_crewai_version")
@patch("crewai_cli.version._get_cache_file")
def test_not_yanked_from_cache(
self, mock_cache_file: MagicMock, mock_version: MagicMock, tmp_path: Path
) -> None:
"""Test non-yanked status from a valid cache."""
mock_version.return_value = "2.0.0"
cache_file = tmp_path / "version_cache.json"
cache_data = {
"version": "2.0.0",
"timestamp": datetime.now().isoformat(),
"current_version": "2.0.0",
"current_version_yanked": False,
"current_version_yanked_reason": "",
}
cache_file.write_text(json.dumps(cache_data))
mock_cache_file.return_value = cache_file
is_yanked, reason = is_current_version_yanked()
assert is_yanked is False
assert reason == ""
@patch("crewai_cli.version.get_latest_version_from_pypi")
@patch("crewai_cli.version.get_crewai_version")
@patch("crewai_cli.version._get_cache_file")
def test_triggers_fetch_on_stale_cache(
self,
mock_cache_file: MagicMock,
mock_version: MagicMock,
mock_fetch: MagicMock,
tmp_path: Path,
) -> None:
"""Test that a stale cache triggers a re-fetch."""
mock_version.return_value = "1.0.0"
cache_file = tmp_path / "version_cache.json"
old_time = datetime.now() - timedelta(hours=25)
cache_data = {
"version": "2.0.0",
"timestamp": old_time.isoformat(),
"current_version": "1.0.0",
"current_version_yanked": True,
"current_version_yanked_reason": "old reason",
}
cache_file.write_text(json.dumps(cache_data))
mock_cache_file.return_value = cache_file
fresh_cache = {
"version": "2.0.0",
"timestamp": datetime.now().isoformat(),
"current_version": "1.0.0",
"current_version_yanked": False,
"current_version_yanked_reason": "",
}
def write_fresh_cache() -> str:
cache_file.write_text(json.dumps(fresh_cache))
return "2.0.0"
mock_fetch.side_effect = lambda: write_fresh_cache()
is_yanked, reason = is_current_version_yanked()
assert is_yanked is False
mock_fetch.assert_called_once()
@patch("crewai_cli.version.get_latest_version_from_pypi")
@patch("crewai_cli.version.get_crewai_version")
@patch("crewai_cli.version._get_cache_file")
def test_returns_false_on_fetch_failure(
self,
mock_cache_file: MagicMock,
mock_version: MagicMock,
mock_fetch: MagicMock,
tmp_path: Path,
) -> None:
"""Test that fetch failure returns not yanked."""
mock_version.return_value = "1.0.0"
cache_file = tmp_path / "version_cache.json"
mock_cache_file.return_value = cache_file
mock_fetch.return_value = None
is_yanked, reason = is_current_version_yanked()
assert is_yanked is False
assert reason == ""
# TestConsoleFormatterVersionCheck tests remain in lib/crewai/tests/cli/test_version.py
# as they depend on crewai.events.utils.console_formatter (core package).

View File

@@ -105,15 +105,10 @@ a2a = [
file-processing = [
"crewai-files",
]
cli = [
"crewai-cli",
]
# CLI entry point has moved to the crewai-cli package.
# Install it via: pip install crewai[cli]
# [project.scripts]
# crewai = "crewai.cli.cli:crewai"
[project.scripts]
crewai = "crewai.cli.cli:crewai"
# PyTorch index configuration, since torch 2.5.0 is not compatible with python 3.13

View File

@@ -0,0 +1,68 @@
"""A2UI (Agent to UI) declarative UI protocol support for CrewAI."""
from crewai.a2a.extensions.a2ui.catalog import (
AudioPlayer,
Button,
Card,
CheckBox,
Column,
DateTimeInput,
Divider,
Icon,
Image,
List,
Modal,
MultipleChoice,
Row,
Slider,
Tabs,
Text,
TextField,
Video,
)
from crewai.a2a.extensions.a2ui.client_extension import A2UIClientExtension
from crewai.a2a.extensions.a2ui.models import (
A2UIEvent,
A2UIMessage,
A2UIResponse,
BeginRendering,
DataModelUpdate,
DeleteSurface,
SurfaceUpdate,
UserAction,
)
from crewai.a2a.extensions.a2ui.server_extension import A2UIServerExtension
from crewai.a2a.extensions.a2ui.validator import validate_a2ui_message
__all__ = [
"A2UIClientExtension",
"A2UIEvent",
"A2UIMessage",
"A2UIResponse",
"A2UIServerExtension",
"AudioPlayer",
"BeginRendering",
"Button",
"Card",
"CheckBox",
"Column",
"DataModelUpdate",
"DateTimeInput",
"DeleteSurface",
"Divider",
"Icon",
"Image",
"List",
"Modal",
"MultipleChoice",
"Row",
"Slider",
"SurfaceUpdate",
"Tabs",
"Text",
"TextField",
"UserAction",
"Video",
"validate_a2ui_message",
]

View File

@@ -0,0 +1,467 @@
"""Typed helpers for A2UI standard catalog components.
These models provide optional type safety for standard catalog components.
Agents can also use raw dicts validated against the JSON schema.
"""
from __future__ import annotations
from typing import Literal
from pydantic import BaseModel, ConfigDict, Field
class StringBinding(BaseModel):
"""A string value: literal or data-model path."""
literal_string: str | None = Field(
default=None, alias="literalString", description="Literal string value."
)
path: str | None = Field(default=None, description="Data-model path reference.")
model_config = ConfigDict(populate_by_name=True, extra="forbid")
class NumberBinding(BaseModel):
"""A numeric value: literal or data-model path."""
literal_number: float | None = Field(
default=None, alias="literalNumber", description="Literal numeric value."
)
path: str | None = Field(default=None, description="Data-model path reference.")
model_config = ConfigDict(populate_by_name=True, extra="forbid")
class BooleanBinding(BaseModel):
"""A boolean value: literal or data-model path."""
literal_boolean: bool | None = Field(
default=None, alias="literalBoolean", description="Literal boolean value."
)
path: str | None = Field(default=None, description="Data-model path reference.")
model_config = ConfigDict(populate_by_name=True, extra="forbid")
class ArrayBinding(BaseModel):
"""An array value: literal or data-model path."""
literal_array: list[str] | None = Field(
default=None, alias="literalArray", description="Literal array of strings."
)
path: str | None = Field(default=None, description="Data-model path reference.")
model_config = ConfigDict(populate_by_name=True, extra="forbid")
class ChildrenDef(BaseModel):
"""Children definition for layout components."""
explicit_list: list[str] | None = Field(
default=None,
alias="explicitList",
description="Explicit list of child component IDs.",
)
template: ChildTemplate | None = Field(
default=None, description="Template for generating dynamic children."
)
model_config = ConfigDict(populate_by_name=True, extra="forbid")
class ChildTemplate(BaseModel):
"""Template for generating dynamic children from a data model list."""
component_id: str = Field(
alias="componentId", description="ID of the component to repeat."
)
data_binding: str = Field(
alias="dataBinding", description="Data-model path to bind the template to."
)
model_config = ConfigDict(populate_by_name=True, extra="forbid")
class ActionContextEntry(BaseModel):
"""A key-value pair in an action context payload."""
key: str = Field(description="Context entry key.")
value: ActionBoundValue = Field(description="Context entry value.")
model_config = ConfigDict(extra="forbid")
class ActionBoundValue(BaseModel):
"""A value in an action context: literal or data-model path."""
path: str | None = Field(default=None, description="Data-model path reference.")
literal_string: str | None = Field(
default=None, alias="literalString", description="Literal string value."
)
literal_number: float | None = Field(
default=None, alias="literalNumber", description="Literal numeric value."
)
literal_boolean: bool | None = Field(
default=None, alias="literalBoolean", description="Literal boolean value."
)
model_config = ConfigDict(populate_by_name=True, extra="forbid")
class Action(BaseModel):
"""Client-side action dispatched by interactive components."""
name: str = Field(description="Action name dispatched on interaction.")
context: list[ActionContextEntry] | None = Field(
default=None, description="Key-value pairs sent with the action."
)
model_config = ConfigDict(extra="forbid")
class TabItem(BaseModel):
"""A single tab definition."""
title: StringBinding = Field(description="Tab title text.")
child: str = Field(description="Component ID rendered as the tab content.")
model_config = ConfigDict(extra="forbid")
class MultipleChoiceOption(BaseModel):
"""A single option in a MultipleChoice component."""
label: StringBinding = Field(description="Display label for the option.")
value: str = Field(description="Value submitted when the option is selected.")
model_config = ConfigDict(extra="forbid")
class Text(BaseModel):
"""Displays text content."""
text: StringBinding = Field(description="Text content to display.")
usage_hint: Literal["h1", "h2", "h3", "h4", "h5", "caption", "body"] | None = Field(
default=None, alias="usageHint", description="Semantic hint for text styling."
)
model_config = ConfigDict(populate_by_name=True, extra="forbid")
class Image(BaseModel):
"""Displays an image."""
url: StringBinding = Field(description="Image source URL.")
fit: Literal["contain", "cover", "fill", "none", "scale-down"] | None = Field(
default=None, description="Object-fit behavior for the image."
)
usage_hint: (
Literal[
"icon", "avatar", "smallFeature", "mediumFeature", "largeFeature", "header"
]
| None
) = Field(
default=None, alias="usageHint", description="Semantic hint for image sizing."
)
model_config = ConfigDict(populate_by_name=True, extra="forbid")
IconName = Literal[
"accountCircle",
"add",
"arrowBack",
"arrowForward",
"attachFile",
"calendarToday",
"call",
"camera",
"check",
"close",
"delete",
"download",
"edit",
"event",
"error",
"favorite",
"favoriteOff",
"folder",
"help",
"home",
"info",
"locationOn",
"lock",
"lockOpen",
"mail",
"menu",
"moreVert",
"moreHoriz",
"notificationsOff",
"notifications",
"payment",
"person",
"phone",
"photo",
"print",
"refresh",
"search",
"send",
"settings",
"share",
"shoppingCart",
"star",
"starHalf",
"starOff",
"upload",
"visibility",
"visibilityOff",
"warning",
]
class IconBinding(BaseModel):
"""Icon name: literal enum or data-model path."""
literal_string: IconName | None = Field(
default=None, alias="literalString", description="Literal icon name."
)
path: str | None = Field(default=None, description="Data-model path reference.")
model_config = ConfigDict(populate_by_name=True, extra="forbid")
class Icon(BaseModel):
"""Displays a named icon."""
name: IconBinding = Field(description="Icon name binding.")
model_config = ConfigDict(extra="forbid")
class Video(BaseModel):
"""Displays a video player."""
url: StringBinding = Field(description="Video source URL.")
model_config = ConfigDict(extra="forbid")
class AudioPlayer(BaseModel):
"""Displays an audio player."""
url: StringBinding = Field(description="Audio source URL.")
description: StringBinding | None = Field(
default=None, description="Accessible description of the audio content."
)
model_config = ConfigDict(extra="forbid")
class Row(BaseModel):
"""Horizontal layout container."""
children: ChildrenDef = Field(description="Child components in this row.")
distribution: (
Literal["center", "end", "spaceAround", "spaceBetween", "spaceEvenly", "start"]
| None
) = Field(
default=None, description="How children are distributed along the main axis."
)
alignment: Literal["start", "center", "end", "stretch"] | None = Field(
default=None, description="How children are aligned on the cross axis."
)
model_config = ConfigDict(extra="forbid")
class Column(BaseModel):
"""Vertical layout container."""
children: ChildrenDef = Field(description="Child components in this column.")
distribution: (
Literal["start", "center", "end", "spaceBetween", "spaceAround", "spaceEvenly"]
| None
) = Field(
default=None, description="How children are distributed along the main axis."
)
alignment: Literal["center", "end", "start", "stretch"] | None = Field(
default=None, description="How children are aligned on the cross axis."
)
model_config = ConfigDict(extra="forbid")
class List(BaseModel):
"""Scrollable list container."""
children: ChildrenDef = Field(description="Child components in this list.")
direction: Literal["vertical", "horizontal"] | None = Field(
default=None, description="Scroll direction of the list."
)
alignment: Literal["start", "center", "end", "stretch"] | None = Field(
default=None, description="How children are aligned on the cross axis."
)
model_config = ConfigDict(extra="forbid")
class Card(BaseModel):
"""Card container wrapping a single child."""
child: str = Field(description="Component ID of the card content.")
model_config = ConfigDict(extra="forbid")
class Tabs(BaseModel):
"""Tabbed navigation container."""
tab_items: list[TabItem] = Field(
alias="tabItems", description="List of tab definitions."
)
model_config = ConfigDict(populate_by_name=True, extra="forbid")
class Divider(BaseModel):
"""A visual divider line."""
axis: Literal["horizontal", "vertical"] | None = Field(
default=None, description="Orientation of the divider."
)
model_config = ConfigDict(extra="forbid")
class Modal(BaseModel):
"""A modal dialog with an entry point trigger and content."""
entry_point_child: str = Field(
alias="entryPointChild", description="Component ID that triggers the modal."
)
content_child: str = Field(
alias="contentChild", description="Component ID rendered inside the modal."
)
model_config = ConfigDict(populate_by_name=True, extra="forbid")
class Button(BaseModel):
"""An interactive button with an action."""
child: str = Field(description="Component ID of the button label.")
primary: bool | None = Field(
default=None, description="Whether the button uses primary styling."
)
action: Action = Field(description="Action dispatched when the button is clicked.")
model_config = ConfigDict(extra="forbid")
class CheckBox(BaseModel):
"""A checkbox input."""
label: StringBinding = Field(description="Label text for the checkbox.")
value: BooleanBinding = Field(
description="Boolean value binding for the checkbox state."
)
model_config = ConfigDict(extra="forbid")
class TextField(BaseModel):
"""A text input field."""
label: StringBinding = Field(description="Label text for the input.")
text: StringBinding | None = Field(
default=None, description="Current text value binding."
)
text_field_type: (
Literal["date", "longText", "number", "shortText", "obscured"] | None
) = Field(default=None, alias="textFieldType", description="Input type variant.")
validation_regexp: str | None = Field(
default=None,
alias="validationRegexp",
description="Regex pattern for client-side validation.",
)
model_config = ConfigDict(populate_by_name=True, extra="forbid")
class DateTimeInput(BaseModel):
"""A date and/or time picker."""
value: StringBinding = Field(description="ISO date/time string value binding.")
enable_date: bool | None = Field(
default=None,
alias="enableDate",
description="Whether the date picker is enabled.",
)
enable_time: bool | None = Field(
default=None,
alias="enableTime",
description="Whether the time picker is enabled.",
)
model_config = ConfigDict(populate_by_name=True, extra="forbid")
class MultipleChoice(BaseModel):
"""A multiple-choice selection component."""
selections: ArrayBinding = Field(description="Array binding for selected values.")
options: list[MultipleChoiceOption] = Field(description="Available choices.")
max_allowed_selections: int | None = Field(
default=None,
alias="maxAllowedSelections",
description="Maximum number of selections allowed.",
)
variant: Literal["checkbox", "chips"] | None = Field(
default=None, description="Visual variant for the selection UI."
)
filterable: bool | None = Field(
default=None, description="Whether options can be filtered by typing."
)
model_config = ConfigDict(populate_by_name=True, extra="forbid")
class Slider(BaseModel):
"""A numeric slider input."""
value: NumberBinding = Field(
description="Numeric value binding for the slider position."
)
min_value: float | None = Field(
default=None, alias="minValue", description="Minimum slider value."
)
max_value: float | None = Field(
default=None, alias="maxValue", description="Maximum slider value."
)
model_config = ConfigDict(populate_by_name=True, extra="forbid")
STANDARD_CATALOG_COMPONENTS: frozenset[str] = frozenset(
{
"Text",
"Image",
"Icon",
"Video",
"AudioPlayer",
"Row",
"Column",
"List",
"Card",
"Tabs",
"Divider",
"Modal",
"Button",
"CheckBox",
"TextField",
"DateTimeInput",
"MultipleChoice",
"Slider",
}
)

View File

@@ -0,0 +1,285 @@
"""A2UI client extension for the A2A protocol."""
from __future__ import annotations
from collections.abc import Sequence
import logging
from typing import TYPE_CHECKING, Any, cast
from pydantic import Field
from pydantic.dataclasses import dataclass
from typing_extensions import TypedDict
from crewai.a2a.extensions.a2ui.models import extract_a2ui_json_objects
from crewai.a2a.extensions.a2ui.prompt import build_a2ui_system_prompt
from crewai.a2a.extensions.a2ui.server_extension import A2UI_MIME_TYPE
from crewai.a2a.extensions.a2ui.validator import (
A2UIValidationError,
validate_a2ui_message,
)
if TYPE_CHECKING:
from a2a.types import Message
from crewai.agent.core import Agent
logger = logging.getLogger(__name__)
class StylesDict(TypedDict, total=False):
"""Serialized surface styling."""
font: str
primaryColor: str
class ComponentEntryDict(TypedDict, total=False):
"""Serialized component entry in a surface update."""
id: str
weight: float
component: dict[str, Any]
class BeginRenderingDict(TypedDict, total=False):
"""Serialized beginRendering payload."""
surfaceId: str
root: str
catalogId: str
styles: StylesDict
class SurfaceUpdateDict(TypedDict, total=False):
"""Serialized surfaceUpdate payload."""
surfaceId: str
components: list[ComponentEntryDict]
class DataEntryDict(TypedDict, total=False):
"""Serialized data model entry."""
key: str
valueString: str
valueNumber: float
valueBoolean: bool
valueMap: list[DataEntryDict]
class DataModelUpdateDict(TypedDict, total=False):
"""Serialized dataModelUpdate payload."""
surfaceId: str
path: str
contents: list[DataEntryDict]
class DeleteSurfaceDict(TypedDict):
"""Serialized deleteSurface payload."""
surfaceId: str
class A2UIMessageDict(TypedDict, total=False):
"""Serialized A2UI server-to-client message with exactly one key set."""
beginRendering: BeginRenderingDict
surfaceUpdate: SurfaceUpdateDict
dataModelUpdate: DataModelUpdateDict
deleteSurface: DeleteSurfaceDict
@dataclass
class A2UIConversationState:
"""Tracks active A2UI surfaces and data models across a conversation."""
active_surfaces: dict[str, dict[str, Any]] = Field(default_factory=dict)
data_models: dict[str, list[dict[str, Any]]] = Field(default_factory=dict)
last_a2ui_messages: list[A2UIMessageDict] = Field(default_factory=list)
def is_ready(self) -> bool:
"""Return True when at least one surface is active."""
return bool(self.active_surfaces)
class A2UIClientExtension:
"""A2A client extension that adds A2UI support to agents.
Implements the ``A2AExtension`` protocol to inject A2UI prompt
instructions, track UI state across conversations, and validate
A2UI messages in responses.
Example::
A2AClientConfig(
endpoint="...",
extensions=["https://a2ui.org/a2a-extension/a2ui/v0.8"],
client_extensions=[A2UIClientExtension()],
)
"""
def __init__(
self,
catalog_id: str | None = None,
allowed_components: list[str] | None = None,
) -> None:
"""Initialize the A2UI client extension.
Args:
catalog_id: Catalog identifier to use for prompt generation.
allowed_components: Subset of component names to expose to the agent.
"""
self._catalog_id = catalog_id
self._allowed_components = allowed_components
def inject_tools(self, agent: Agent) -> None:
"""No-op — A2UI uses prompt augmentation rather than tool injection."""
def extract_state_from_history(
self, conversation_history: Sequence[Message]
) -> A2UIConversationState | None:
"""Scan conversation history for A2UI DataParts and track surface state.
When ``catalog_id`` is set, only surfaces matching that catalog are tracked.
"""
state = A2UIConversationState()
for message in conversation_history:
for part in message.parts:
root = part.root
if root.kind != "data":
continue
metadata = root.metadata or {}
mime_type = metadata.get("mimeType", "")
if mime_type != A2UI_MIME_TYPE:
continue
data = root.data
if not isinstance(data, dict):
continue
surface_id = _get_surface_id(data)
if not surface_id:
continue
if self._catalog_id and "beginRendering" in data:
catalog_id = data["beginRendering"].get("catalogId")
if catalog_id and catalog_id != self._catalog_id:
continue
if "deleteSurface" in data:
state.active_surfaces.pop(surface_id, None)
state.data_models.pop(surface_id, None)
elif "beginRendering" in data:
state.active_surfaces[surface_id] = data["beginRendering"]
elif "surfaceUpdate" in data:
state.active_surfaces[surface_id] = data["surfaceUpdate"]
elif "dataModelUpdate" in data:
contents = data["dataModelUpdate"].get("contents", [])
state.data_models.setdefault(surface_id, []).extend(contents)
if not state.active_surfaces and not state.data_models:
return None
return state
def augment_prompt(
self,
base_prompt: str,
_conversation_state: A2UIConversationState | None,
) -> str:
"""Append A2UI system prompt instructions to the base prompt."""
a2ui_prompt = build_a2ui_system_prompt(
catalog_id=self._catalog_id,
allowed_components=self._allowed_components,
)
return f"{base_prompt}\n\n{a2ui_prompt}"
def process_response(
self,
agent_response: Any,
conversation_state: A2UIConversationState | None,
) -> Any:
"""Extract and validate A2UI JSON from agent output.
When ``allowed_components`` is set, components not in the allowlist are
logged and stripped from surface updates. Stores extracted A2UI messages
on the conversation state and returns the original response unchanged.
"""
text = (
agent_response if isinstance(agent_response, str) else str(agent_response)
)
a2ui_messages = _extract_and_validate(text)
if self._allowed_components:
allowed = set(self._allowed_components)
a2ui_messages = [_filter_components(msg, allowed) for msg in a2ui_messages]
if a2ui_messages and conversation_state is not None:
conversation_state.last_a2ui_messages = a2ui_messages
return agent_response
def _get_surface_id(data: dict[str, Any]) -> str | None:
"""Extract surfaceId from any A2UI message type."""
for key in ("beginRendering", "surfaceUpdate", "dataModelUpdate", "deleteSurface"):
inner = data.get(key)
if isinstance(inner, dict):
sid = inner.get("surfaceId")
if isinstance(sid, str):
return sid
return None
def _filter_components(msg: A2UIMessageDict, allowed: set[str]) -> A2UIMessageDict:
"""Strip components whose type is not in *allowed* from a surfaceUpdate."""
surface_update = msg.get("surfaceUpdate")
if not isinstance(surface_update, dict):
return msg
components = surface_update.get("components")
if not isinstance(components, list):
return msg
filtered = []
for entry in components:
component = entry.get("component", {})
component_types = set(component.keys())
disallowed = component_types - allowed
if disallowed:
logger.debug(
"Stripping disallowed component type(s) %s from surface update",
disallowed,
)
continue
filtered.append(entry)
if len(filtered) == len(components):
return msg
return {**msg, "surfaceUpdate": {**surface_update, "components": filtered}}
def _extract_and_validate(text: str) -> list[A2UIMessageDict]:
"""Extract A2UI JSON objects from text and validate them."""
return [
dumped
for candidate in extract_a2ui_json_objects(text)
if (dumped := _try_validate(candidate)) is not None
]
def _try_validate(candidate: dict[str, Any]) -> A2UIMessageDict | None:
"""Validate a single A2UI candidate, returning None on failure."""
try:
msg = validate_a2ui_message(candidate)
except A2UIValidationError:
logger.debug(
"Skipping invalid A2UI candidate in agent output",
exc_info=True,
)
return None
return cast(A2UIMessageDict, msg.model_dump(by_alias=True, exclude_none=True))

View File

@@ -0,0 +1,281 @@
"""Pydantic models for A2UI server-to-client messages and client-to-server events."""
from __future__ import annotations
import json
import re
from typing import Any
from pydantic import BaseModel, ConfigDict, Field, model_validator
class BoundValue(BaseModel):
"""A value that can be a literal or a data-model path reference."""
literal_string: str | None = Field(
default=None, alias="literalString", description="Literal string value."
)
literal_number: float | None = Field(
default=None, alias="literalNumber", description="Literal numeric value."
)
literal_boolean: bool | None = Field(
default=None, alias="literalBoolean", description="Literal boolean value."
)
literal_array: list[str] | None = Field(
default=None, alias="literalArray", description="Literal array of strings."
)
path: str | None = Field(default=None, description="Data-model path reference.")
model_config = ConfigDict(populate_by_name=True, extra="forbid")
class MapEntry(BaseModel):
"""A single entry in a valueMap adjacency list, supporting recursive nesting."""
key: str = Field(description="Entry key.")
value_string: str | None = Field(
default=None, alias="valueString", description="String value."
)
value_number: float | None = Field(
default=None, alias="valueNumber", description="Numeric value."
)
value_boolean: bool | None = Field(
default=None, alias="valueBoolean", description="Boolean value."
)
value_map: list[MapEntry] | None = Field(
default=None, alias="valueMap", description="Nested map entries."
)
model_config = ConfigDict(populate_by_name=True, extra="forbid")
class DataEntry(BaseModel):
"""A data model entry with a key and exactly one typed value."""
key: str = Field(description="Entry key.")
value_string: str | None = Field(
default=None, alias="valueString", description="String value."
)
value_number: float | None = Field(
default=None, alias="valueNumber", description="Numeric value."
)
value_boolean: bool | None = Field(
default=None, alias="valueBoolean", description="Boolean value."
)
value_map: list[MapEntry] | None = Field(
default=None, alias="valueMap", description="Nested map entries."
)
model_config = ConfigDict(populate_by_name=True, extra="forbid")
_HEX_COLOR_PATTERN: re.Pattern[str] = re.compile(r"^#[0-9a-fA-F]{6}$")
class Styles(BaseModel):
"""Surface styling information."""
font: str | None = Field(default=None, description="Font family name.")
primary_color: str | None = Field(
default=None,
alias="primaryColor",
pattern=_HEX_COLOR_PATTERN.pattern,
description="Primary color as a hex string.",
)
model_config = ConfigDict(populate_by_name=True, extra="allow")
class ComponentEntry(BaseModel):
"""A single component in a UI widget tree.
The ``component`` dict must contain exactly one key — the component type
name (e.g. ``"Text"``, ``"Button"``) — whose value holds the component
properties. Component internals are left as ``dict[str, Any]`` because
they are catalog-dependent; use the typed helpers in ``catalog.py`` for
the standard catalog.
"""
id: str = Field(description="Unique component identifier.")
weight: float | None = Field(
default=None, description="Flex weight for layout distribution."
)
component: dict[str, Any] = Field(
description="Component type name mapped to its properties."
)
model_config = ConfigDict(extra="forbid")
class BeginRendering(BaseModel):
"""Signals the client to begin rendering a surface."""
surface_id: str = Field(alias="surfaceId", description="Unique surface identifier.")
root: str = Field(description="Component ID of the root element.")
catalog_id: str | None = Field(
default=None,
alias="catalogId",
description="Catalog identifier for the surface.",
)
styles: Styles | None = Field(
default=None, description="Surface styling overrides."
)
model_config = ConfigDict(populate_by_name=True, extra="forbid")
class SurfaceUpdate(BaseModel):
"""Updates a surface with a new set of components."""
surface_id: str = Field(alias="surfaceId", description="Target surface identifier.")
components: list[ComponentEntry] = Field(
min_length=1, description="Components to render on the surface."
)
model_config = ConfigDict(populate_by_name=True, extra="forbid")
class DataModelUpdate(BaseModel):
"""Updates the data model for a surface."""
surface_id: str = Field(alias="surfaceId", description="Target surface identifier.")
path: str | None = Field(
default=None, description="Data-model path prefix for the update."
)
contents: list[DataEntry] = Field(
description="Data entries to merge into the model."
)
model_config = ConfigDict(populate_by_name=True, extra="forbid")
class DeleteSurface(BaseModel):
"""Signals the client to delete a surface."""
surface_id: str = Field(
alias="surfaceId", description="Surface identifier to delete."
)
model_config = ConfigDict(populate_by_name=True, extra="forbid")
class A2UIMessage(BaseModel):
"""Union wrapper for the four server-to-client A2UI message types.
Exactly one of the fields must be set.
"""
begin_rendering: BeginRendering | None = Field(
default=None,
alias="beginRendering",
description="Begin rendering a new surface.",
)
surface_update: SurfaceUpdate | None = Field(
default=None,
alias="surfaceUpdate",
description="Update components on a surface.",
)
data_model_update: DataModelUpdate | None = Field(
default=None,
alias="dataModelUpdate",
description="Update the surface data model.",
)
delete_surface: DeleteSurface | None = Field(
default=None, alias="deleteSurface", description="Delete an existing surface."
)
model_config = ConfigDict(populate_by_name=True, extra="forbid")
@model_validator(mode="after")
def _check_exactly_one(self) -> A2UIMessage:
"""Enforce the spec's exactly-one-of constraint."""
fields = [
self.begin_rendering,
self.surface_update,
self.data_model_update,
self.delete_surface,
]
count = sum(f is not None for f in fields)
if count != 1:
raise ValueError(f"Exactly one A2UI message type must be set, got {count}")
return self
class UserAction(BaseModel):
"""Reports a user-initiated action from a component."""
name: str = Field(description="Action name.")
surface_id: str = Field(alias="surfaceId", description="Source surface identifier.")
source_component_id: str = Field(
alias="sourceComponentId", description="Component that triggered the action."
)
timestamp: str = Field(description="ISO 8601 timestamp of the action.")
context: dict[str, Any] = Field(description="Action context payload.")
model_config = ConfigDict(populate_by_name=True)
class ClientError(BaseModel):
"""Reports a client-side error."""
model_config = ConfigDict(extra="allow")
class A2UIEvent(BaseModel):
"""Union wrapper for client-to-server events."""
user_action: UserAction | None = Field(
default=None, alias="userAction", description="User-initiated action event."
)
error: ClientError | None = Field(
default=None, description="Client-side error report."
)
model_config = ConfigDict(populate_by_name=True)
@model_validator(mode="after")
def _check_exactly_one(self) -> A2UIEvent:
"""Enforce the spec's exactly-one-of constraint."""
fields = [self.user_action, self.error]
count = sum(f is not None for f in fields)
if count != 1:
raise ValueError(f"Exactly one A2UI event type must be set, got {count}")
return self
class A2UIResponse(BaseModel):
"""Typed wrapper for responses containing A2UI messages."""
text: str = Field(description="Raw text content of the response.")
a2ui_parts: list[dict[str, Any]] = Field(
default_factory=list, description="A2UI DataParts extracted from the response."
)
a2ui_messages: list[dict[str, Any]] = Field(
default_factory=list, description="Validated A2UI message dicts."
)
_A2UI_KEYS = {"beginRendering", "surfaceUpdate", "dataModelUpdate", "deleteSurface"}
def extract_a2ui_json_objects(text: str) -> list[dict[str, Any]]:
"""Extract JSON objects containing A2UI keys from text.
Uses ``json.JSONDecoder.raw_decode`` for robust parsing that correctly
handles braces inside string literals.
"""
decoder = json.JSONDecoder()
results: list[dict[str, Any]] = []
idx = 0
while idx < len(text):
idx = text.find("{", idx)
if idx == -1:
break
try:
obj, end_idx = decoder.raw_decode(text, idx)
if isinstance(obj, dict) and _A2UI_KEYS & obj.keys():
results.append(obj)
idx = end_idx
except json.JSONDecodeError:
idx += 1
return results

View File

@@ -0,0 +1,76 @@
"""System prompt generation for A2UI-capable agents."""
from __future__ import annotations
import json
from crewai.a2a.extensions.a2ui.catalog import STANDARD_CATALOG_COMPONENTS
from crewai.a2a.extensions.a2ui.schema import load_schema
from crewai.a2a.extensions.a2ui.server_extension import A2UI_EXTENSION_URI
def build_a2ui_system_prompt(
catalog_id: str | None = None,
allowed_components: list[str] | None = None,
) -> str:
"""Build a system prompt fragment instructing the LLM to produce A2UI output.
The prompt describes the A2UI message format, available components, and
data binding rules. It includes the resolved schema so the LLM can
generate structured output.
Args:
catalog_id: Catalog identifier to reference. Defaults to the
standard catalog version derived from ``A2UI_EXTENSION_URI``.
allowed_components: Subset of component names to expose. When
``None``, all standard catalog components are available.
Returns:
A system prompt string to append to the agent's instructions.
"""
components = sorted(
allowed_components
if allowed_components is not None
else STANDARD_CATALOG_COMPONENTS
)
catalog_label = catalog_id or f"standard ({A2UI_EXTENSION_URI.rsplit('/', 1)[-1]})"
resolved_schema = load_schema("server_to_client_with_standard_catalog")
schema_json = json.dumps(resolved_schema, indent=2)
return f"""\
<A2UI_INSTRUCTIONS>
You can generate rich, declarative UI by emitting A2UI JSON messages.
CATALOG: {catalog_label}
AVAILABLE COMPONENTS: {", ".join(components)}
MESSAGE TYPES (emit exactly ONE per message):
- beginRendering: Initialize a new surface with a root component and optional styles.
- surfaceUpdate: Send/update components for a surface. Each component has a unique id \
and a "component" wrapper containing exactly one component-type key.
- dataModelUpdate: Update the data model for a surface. Data entries have a key and \
one typed value (valueString, valueNumber, valueBoolean, valueMap).
- deleteSurface: Remove a surface.
DATA BINDING:
- Use {{"literalString": "..."}} for inline string values.
- Use {{"literalNumber": ...}} for inline numeric values.
- Use {{"literalBoolean": ...}} for inline boolean values.
- Use {{"literalArray": ["...", "..."]}} for inline array values.
- Use {{"path": "/data/model/path"}} to bind to data model values.
ACTIONS:
- Interactive components (Button, etc.) have an "action" with a "name" and optional \
"context" array of key/value pairs.
- Values in action context can use data binding (path or literal).
OUTPUT FORMAT:
Emit each A2UI message as a valid JSON object. When generating UI, produce a \
beginRendering message first, then surfaceUpdate messages with components, and \
optionally dataModelUpdate messages to populate data-bound values.
SCHEMA:
{schema_json}
</A2UI_INSTRUCTIONS>"""

View File

@@ -0,0 +1,48 @@
"""Schema loading utilities for vendored A2UI JSON schemas."""
from __future__ import annotations
import json
from pathlib import Path
from typing import Any
_SCHEMA_DIR = Path(__file__).parent / "v0_8"
_SCHEMA_CACHE: dict[str, dict[str, Any]] = {}
SCHEMA_NAMES: frozenset[str] = frozenset(
{
"server_to_client",
"client_to_server",
"standard_catalog_definition",
"server_to_client_with_standard_catalog",
}
)
def load_schema(name: str) -> dict[str, Any]:
"""Load a vendored A2UI JSON schema by name.
Args:
name: Schema name without extension (e.g. ``"server_to_client"``).
Returns:
Parsed JSON schema dict.
Raises:
ValueError: If the schema name is not recognized.
FileNotFoundError: If the schema file is missing from the package.
"""
if name not in SCHEMA_NAMES:
raise ValueError(f"Unknown schema {name!r}. Available: {sorted(SCHEMA_NAMES)}")
if name in _SCHEMA_CACHE:
return _SCHEMA_CACHE[name]
path = _SCHEMA_DIR / f"{name}.json"
with path.open() as f:
schema: dict[str, Any] = json.load(f)
_SCHEMA_CACHE[name] = schema
return schema

View File

@@ -0,0 +1,53 @@
{
"title": "A2UI (Agent to UI) Client-to-Server Event Schema",
"description": "Describes a JSON payload for a client-to-server event message.",
"type": "object",
"minProperties": 1,
"maxProperties": 1,
"properties": {
"userAction": {
"type": "object",
"description": "Reports a user-initiated action from a component.",
"properties": {
"name": {
"type": "string",
"description": "The name of the action, taken from the component's action.name property."
},
"surfaceId": {
"type": "string",
"description": "The id of the surface where the event originated."
},
"sourceComponentId": {
"type": "string",
"description": "The id of the component that triggered the event."
},
"timestamp": {
"type": "string",
"format": "date-time",
"description": "An ISO 8601 timestamp of when the event occurred."
},
"context": {
"type": "object",
"description": "A JSON object containing the key-value pairs from the component's action.context, after resolving all data bindings.",
"additionalProperties": true
}
},
"required": [
"name",
"surfaceId",
"sourceComponentId",
"timestamp",
"context"
]
},
"error": {
"type": "object",
"description": "Reports a client-side error. The content is flexible.",
"additionalProperties": true
}
},
"oneOf": [
{ "required": ["userAction"] },
{ "required": ["error"] }
]
}

View File

@@ -0,0 +1,148 @@
{
"title": "A2UI Message Schema",
"description": "Describes a JSON payload for an A2UI (Agent to UI) message, which is used to dynamically construct and update user interfaces. A message MUST contain exactly ONE of the action properties: 'beginRendering', 'surfaceUpdate', 'dataModelUpdate', or 'deleteSurface'.",
"type": "object",
"additionalProperties": false,
"properties": {
"beginRendering": {
"type": "object",
"description": "Signals the client to begin rendering a surface with a root component and specific styles.",
"additionalProperties": false,
"properties": {
"surfaceId": {
"type": "string",
"description": "The unique identifier for the UI surface to be rendered."
},
"catalogId": {
"type": "string",
"description": "The identifier of the component catalog to use for this surface. If omitted, the client MUST default to the standard catalog for this A2UI version (https://a2ui.org/specification/v0_8/standard_catalog_definition.json)."
},
"root": {
"type": "string",
"description": "The ID of the root component to render."
},
"styles": {
"type": "object",
"description": "Styling information for the UI.",
"additionalProperties": true
}
},
"required": ["root", "surfaceId"]
},
"surfaceUpdate": {
"type": "object",
"description": "Updates a surface with a new set of components.",
"additionalProperties": false,
"properties": {
"surfaceId": {
"type": "string",
"description": "The unique identifier for the UI surface to be updated. If you are adding a new surface this *must* be a new, unique identified that has never been used for any existing surfaces shown."
},
"components": {
"type": "array",
"description": "A list containing all UI components for the surface.",
"minItems": 1,
"items": {
"type": "object",
"description": "Represents a *single* component in a UI widget tree. This component could be one of many supported types.",
"additionalProperties": false,
"properties": {
"id": {
"type": "string",
"description": "The unique identifier for this component."
},
"weight": {
"type": "number",
"description": "The relative weight of this component within a Row or Column. This corresponds to the CSS 'flex-grow' property. Note: this may ONLY be set when the component is a direct descendant of a Row or Column."
},
"component": {
"type": "object",
"description": "A wrapper object that MUST contain exactly one key, which is the name of the component type. The value is an object containing the properties for that specific component.",
"additionalProperties": true
}
},
"required": ["id", "component"]
}
}
},
"required": ["surfaceId", "components"]
},
"dataModelUpdate": {
"type": "object",
"description": "Updates the data model for a surface.",
"additionalProperties": false,
"properties": {
"surfaceId": {
"type": "string",
"description": "The unique identifier for the UI surface this data model update applies to."
},
"path": {
"type": "string",
"description": "An optional path to a location within the data model (e.g., '/user/name'). If omitted, or set to '/', the entire data model will be replaced."
},
"contents": {
"type": "array",
"description": "An array of data entries. Each entry must contain a 'key' and exactly one corresponding typed 'value*' property.",
"items": {
"type": "object",
"description": "A single data entry. Exactly one 'value*' property should be provided alongside the key.",
"additionalProperties": false,
"properties": {
"key": {
"type": "string",
"description": "The key for this data entry."
},
"valueString": {
"type": "string"
},
"valueNumber": {
"type": "number"
},
"valueBoolean": {
"type": "boolean"
},
"valueMap": {
"description": "Represents a map as an adjacency list.",
"type": "array",
"items": {
"type": "object",
"description": "One entry in the map. Exactly one 'value*' property should be provided alongside the key.",
"additionalProperties": false,
"properties": {
"key": {
"type": "string"
},
"valueString": {
"type": "string"
},
"valueNumber": {
"type": "number"
},
"valueBoolean": {
"type": "boolean"
}
},
"required": ["key"]
}
}
},
"required": ["key"]
}
}
},
"required": ["contents", "surfaceId"]
},
"deleteSurface": {
"type": "object",
"description": "Signals the client to delete the surface identified by 'surfaceId'.",
"additionalProperties": false,
"properties": {
"surfaceId": {
"type": "string",
"description": "The unique identifier for the UI surface to be deleted."
}
},
"required": ["surfaceId"]
}
}
}

View File

@@ -0,0 +1,823 @@
{
"title": "A2UI Message Schema",
"description": "Describes a JSON payload for an A2UI (Agent to UI) message, which is used to dynamically construct and update user interfaces. A message MUST contain exactly ONE of the action properties: 'beginRendering', 'surfaceUpdate', 'dataModelUpdate', or 'deleteSurface'.",
"type": "object",
"additionalProperties": false,
"properties": {
"beginRendering": {
"type": "object",
"description": "Signals the client to begin rendering a surface with a root component and specific styles.",
"additionalProperties": false,
"properties": {
"surfaceId": {
"type": "string",
"description": "The unique identifier for the UI surface to be rendered."
},
"root": {
"type": "string",
"description": "The ID of the root component to render."
},
"styles": {
"type": "object",
"description": "Styling information for the UI.",
"additionalProperties": false,
"properties": {
"font": {
"type": "string",
"description": "The primary font for the UI."
},
"primaryColor": {
"type": "string",
"description": "The primary UI color as a hexadecimal code (e.g., '#00BFFF').",
"pattern": "^#[0-9a-fA-F]{6}$"
}
}
}
},
"required": ["root", "surfaceId"]
},
"surfaceUpdate": {
"type": "object",
"description": "Updates a surface with a new set of components.",
"additionalProperties": false,
"properties": {
"surfaceId": {
"type": "string",
"description": "The unique identifier for the UI surface to be updated. If you are adding a new surface this *must* be a new, unique identified that has never been used for any existing surfaces shown."
},
"components": {
"type": "array",
"description": "A list containing all UI components for the surface.",
"minItems": 1,
"items": {
"type": "object",
"description": "Represents a *single* component in a UI widget tree. This component could be one of many supported types.",
"additionalProperties": false,
"properties": {
"id": {
"type": "string",
"description": "The unique identifier for this component."
},
"weight": {
"type": "number",
"description": "The relative weight of this component within a Row or Column. This corresponds to the CSS 'flex-grow' property. Note: this may ONLY be set when the component is a direct descendant of a Row or Column."
},
"component": {
"type": "object",
"description": "A wrapper object that MUST contain exactly one key, which is the name of the component type (e.g., 'Heading'). The value is an object containing the properties for that specific component.",
"additionalProperties": false,
"properties": {
"Text": {
"type": "object",
"additionalProperties": false,
"properties": {
"text": {
"type": "object",
"description": "The text content to display. This can be a literal string or a reference to a value in the data model ('path', e.g., '/doc/title'). While simple Markdown formatting is supported (i.e. without HTML, images, or links), utilizing dedicated UI components is generally preferred for a richer and more structured presentation.",
"additionalProperties": false,
"properties": {
"literalString": {
"type": "string"
},
"path": {
"type": "string"
}
}
},
"usageHint": {
"type": "string",
"description": "A hint for the base text style. One of:\n- `h1`: Largest heading.\n- `h2`: Second largest heading.\n- `h3`: Third largest heading.\n- `h4`: Fourth largest heading.\n- `h5`: Fifth largest heading.\n- `caption`: Small text for captions.\n- `body`: Standard body text.",
"enum": [
"h1",
"h2",
"h3",
"h4",
"h5",
"caption",
"body"
]
}
},
"required": ["text"]
},
"Image": {
"type": "object",
"additionalProperties": false,
"properties": {
"url": {
"type": "object",
"description": "The URL of the image to display. This can be a literal string ('literal') or a reference to a value in the data model ('path', e.g. '/thumbnail/url').",
"additionalProperties": false,
"properties": {
"literalString": {
"type": "string"
},
"path": {
"type": "string"
}
}
},
"fit": {
"type": "string",
"description": "Specifies how the image should be resized to fit its container. This corresponds to the CSS 'object-fit' property.",
"enum": [
"contain",
"cover",
"fill",
"none",
"scale-down"
]
},
"usageHint": {
"type": "string",
"description": "A hint for the image size and style. One of:\n- `icon`: Small square icon.\n- `avatar`: Circular avatar image.\n- `smallFeature`: Small feature image.\n- `mediumFeature`: Medium feature image.\n- `largeFeature`: Large feature image.\n- `header`: Full-width, full bleed, header image.",
"enum": [
"icon",
"avatar",
"smallFeature",
"mediumFeature",
"largeFeature",
"header"
]
}
},
"required": ["url"]
},
"Icon": {
"type": "object",
"additionalProperties": false,
"properties": {
"name": {
"type": "object",
"description": "The name of the icon to display. This can be a literal string or a reference to a value in the data model ('path', e.g. '/form/submit').",
"additionalProperties": false,
"properties": {
"literalString": {
"type": "string",
"enum": [
"accountCircle",
"add",
"arrowBack",
"arrowForward",
"attachFile",
"calendarToday",
"call",
"camera",
"check",
"close",
"delete",
"download",
"edit",
"event",
"error",
"favorite",
"favoriteOff",
"folder",
"help",
"home",
"info",
"locationOn",
"lock",
"lockOpen",
"mail",
"menu",
"moreVert",
"moreHoriz",
"notificationsOff",
"notifications",
"payment",
"person",
"phone",
"photo",
"print",
"refresh",
"search",
"send",
"settings",
"share",
"shoppingCart",
"star",
"starHalf",
"starOff",
"upload",
"visibility",
"visibilityOff",
"warning"
]
},
"path": {
"type": "string"
}
}
}
},
"required": ["name"]
},
"Video": {
"type": "object",
"additionalProperties": false,
"properties": {
"url": {
"type": "object",
"description": "The URL of the video to display. This can be a literal string or a reference to a value in the data model ('path', e.g. '/video/url').",
"additionalProperties": false,
"properties": {
"literalString": {
"type": "string"
},
"path": {
"type": "string"
}
}
}
},
"required": ["url"]
},
"AudioPlayer": {
"type": "object",
"additionalProperties": false,
"properties": {
"url": {
"type": "object",
"description": "The URL of the audio to be played. This can be a literal string ('literal') or a reference to a value in the data model ('path', e.g. '/song/url').",
"additionalProperties": false,
"properties": {
"literalString": {
"type": "string"
},
"path": {
"type": "string"
}
}
},
"description": {
"type": "object",
"description": "A description of the audio, such as a title or summary. This can be a literal string or a reference to a value in the data model ('path', e.g. '/song/title').",
"additionalProperties": false,
"properties": {
"literalString": {
"type": "string"
},
"path": {
"type": "string"
}
}
}
},
"required": ["url"]
},
"Row": {
"type": "object",
"additionalProperties": false,
"properties": {
"children": {
"type": "object",
"description": "Defines the children. Use 'explicitList' for a fixed set of children, or 'template' to generate children from a data list.",
"additionalProperties": false,
"properties": {
"explicitList": {
"type": "array",
"items": {
"type": "string"
}
},
"template": {
"type": "object",
"description": "A template for generating a dynamic list of children from a data model list. `componentId` is the component to use as a template, and `dataBinding` is the path to the map of components in the data model. Values in the map will define the list of children.",
"additionalProperties": false,
"properties": {
"componentId": {
"type": "string"
},
"dataBinding": {
"type": "string"
}
},
"required": ["componentId", "dataBinding"]
}
}
},
"distribution": {
"type": "string",
"description": "Defines the arrangement of children along the main axis (horizontally). This corresponds to the CSS 'justify-content' property.",
"enum": [
"center",
"end",
"spaceAround",
"spaceBetween",
"spaceEvenly",
"start"
]
},
"alignment": {
"type": "string",
"description": "Defines the alignment of children along the cross axis (vertically). This corresponds to the CSS 'align-items' property.",
"enum": ["start", "center", "end", "stretch"]
}
},
"required": ["children"]
},
"Column": {
"type": "object",
"additionalProperties": false,
"properties": {
"children": {
"type": "object",
"description": "Defines the children. Use 'explicitList' for a fixed set of children, or 'template' to generate children from a data list.",
"additionalProperties": false,
"properties": {
"explicitList": {
"type": "array",
"items": {
"type": "string"
}
},
"template": {
"type": "object",
"description": "A template for generating a dynamic list of children from a data model list. `componentId` is the component to use as a template, and `dataBinding` is the path to the map of components in the data model. Values in the map will define the list of children.",
"additionalProperties": false,
"properties": {
"componentId": {
"type": "string"
},
"dataBinding": {
"type": "string"
}
},
"required": ["componentId", "dataBinding"]
}
}
},
"distribution": {
"type": "string",
"description": "Defines the arrangement of children along the main axis (vertically). This corresponds to the CSS 'justify-content' property.",
"enum": [
"start",
"center",
"end",
"spaceBetween",
"spaceAround",
"spaceEvenly"
]
},
"alignment": {
"type": "string",
"description": "Defines the alignment of children along the cross axis (horizontally). This corresponds to the CSS 'align-items' property.",
"enum": ["center", "end", "start", "stretch"]
}
},
"required": ["children"]
},
"List": {
"type": "object",
"additionalProperties": false,
"properties": {
"children": {
"type": "object",
"description": "Defines the children. Use 'explicitList' for a fixed set of children, or 'template' to generate children from a data list.",
"additionalProperties": false,
"properties": {
"explicitList": {
"type": "array",
"items": {
"type": "string"
}
},
"template": {
"type": "object",
"description": "A template for generating a dynamic list of children from a data model list. `componentId` is the component to use as a template, and `dataBinding` is the path to the map of components in the data model. Values in the map will define the list of children.",
"additionalProperties": false,
"properties": {
"componentId": {
"type": "string"
},
"dataBinding": {
"type": "string"
}
},
"required": ["componentId", "dataBinding"]
}
}
},
"direction": {
"type": "string",
"description": "The direction in which the list items are laid out.",
"enum": ["vertical", "horizontal"]
},
"alignment": {
"type": "string",
"description": "Defines the alignment of children along the cross axis.",
"enum": ["start", "center", "end", "stretch"]
}
},
"required": ["children"]
},
"Card": {
"type": "object",
"additionalProperties": false,
"properties": {
"child": {
"type": "string",
"description": "The ID of the component to be rendered inside the card."
}
},
"required": ["child"]
},
"Tabs": {
"type": "object",
"additionalProperties": false,
"properties": {
"tabItems": {
"type": "array",
"description": "An array of objects, where each object defines a tab with a title and a child component.",
"items": {
"type": "object",
"additionalProperties": false,
"properties": {
"title": {
"type": "object",
"description": "The tab title. Defines the value as either a literal value or a path to data model value (e.g. '/options/title').",
"additionalProperties": false,
"properties": {
"literalString": {
"type": "string"
},
"path": {
"type": "string"
}
}
},
"child": {
"type": "string"
}
},
"required": ["title", "child"]
}
}
},
"required": ["tabItems"]
},
"Divider": {
"type": "object",
"additionalProperties": false,
"properties": {
"axis": {
"type": "string",
"description": "The orientation of the divider.",
"enum": ["horizontal", "vertical"]
}
}
},
"Modal": {
"type": "object",
"additionalProperties": false,
"properties": {
"entryPointChild": {
"type": "string",
"description": "The ID of the component that opens the modal when interacted with (e.g., a button)."
},
"contentChild": {
"type": "string",
"description": "The ID of the component to be displayed inside the modal."
}
},
"required": ["entryPointChild", "contentChild"]
},
"Button": {
"type": "object",
"additionalProperties": false,
"properties": {
"child": {
"type": "string",
"description": "The ID of the component to display in the button, typically a Text component."
},
"primary": {
"type": "boolean",
"description": "Indicates if this button should be styled as the primary action."
},
"action": {
"type": "object",
"description": "The client-side action to be dispatched when the button is clicked. It includes the action's name and an optional context payload.",
"additionalProperties": false,
"properties": {
"name": {
"type": "string"
},
"context": {
"type": "array",
"items": {
"type": "object",
"additionalProperties": false,
"properties": {
"key": {
"type": "string"
},
"value": {
"type": "object",
"description": "Defines the value to be included in the context as either a literal value or a path to a data model value (e.g. '/user/name').",
"additionalProperties": false,
"properties": {
"path": {
"type": "string"
},
"literalString": {
"type": "string"
},
"literalNumber": {
"type": "number"
},
"literalBoolean": {
"type": "boolean"
}
}
}
},
"required": ["key", "value"]
}
}
},
"required": ["name"]
}
},
"required": ["child", "action"]
},
"CheckBox": {
"type": "object",
"additionalProperties": false,
"properties": {
"label": {
"type": "object",
"description": "The text to display next to the checkbox. Defines the value as either a literal value or a path to data model ('path', e.g. '/option/label').",
"additionalProperties": false,
"properties": {
"literalString": {
"type": "string"
},
"path": {
"type": "string"
}
}
},
"value": {
"type": "object",
"description": "The current state of the checkbox (true for checked, false for unchecked). This can be a literal boolean ('literalBoolean') or a reference to a value in the data model ('path', e.g. '/filter/open').",
"additionalProperties": false,
"properties": {
"literalBoolean": {
"type": "boolean"
},
"path": {
"type": "string"
}
}
}
},
"required": ["label", "value"]
},
"TextField": {
"type": "object",
"additionalProperties": false,
"properties": {
"label": {
"type": "object",
"description": "The text label for the input field. This can be a literal string or a reference to a value in the data model ('path, e.g. '/user/name').",
"additionalProperties": false,
"properties": {
"literalString": {
"type": "string"
},
"path": {
"type": "string"
}
}
},
"text": {
"type": "object",
"description": "The value of the text field. This can be a literal string or a reference to a value in the data model ('path', e.g. '/user/name').",
"additionalProperties": false,
"properties": {
"literalString": {
"type": "string"
},
"path": {
"type": "string"
}
}
},
"textFieldType": {
"type": "string",
"description": "The type of input field to display.",
"enum": [
"date",
"longText",
"number",
"shortText",
"obscured"
]
},
"validationRegexp": {
"type": "string",
"description": "A regular expression used for client-side validation of the input."
}
},
"required": ["label"]
},
"DateTimeInput": {
"type": "object",
"additionalProperties": false,
"properties": {
"value": {
"type": "object",
"description": "The selected date and/or time value in ISO 8601 format. This can be a literal string ('literalString') or a reference to a value in the data model ('path', e.g. '/user/dob').",
"additionalProperties": false,
"properties": {
"literalString": {
"type": "string"
},
"path": {
"type": "string"
}
}
},
"enableDate": {
"type": "boolean",
"description": "If true, allows the user to select a date."
},
"enableTime": {
"type": "boolean",
"description": "If true, allows the user to select a time."
}
},
"required": ["value"]
},
"MultipleChoice": {
"type": "object",
"additionalProperties": false,
"properties": {
"selections": {
"type": "object",
"description": "The currently selected values for the component. This can be a literal array of strings or a path to an array in the data model('path', e.g. '/hotel/options').",
"additionalProperties": false,
"properties": {
"literalArray": {
"type": "array",
"items": {
"type": "string"
}
},
"path": {
"type": "string"
}
}
},
"options": {
"type": "array",
"description": "An array of available options for the user to choose from.",
"items": {
"type": "object",
"additionalProperties": false,
"properties": {
"label": {
"type": "object",
"description": "The text to display for this option. This can be a literal string or a reference to a value in the data model (e.g. '/option/label').",
"additionalProperties": false,
"properties": {
"literalString": {
"type": "string"
},
"path": {
"type": "string"
}
}
},
"value": {
"type": "string",
"description": "The value to be associated with this option when selected."
}
},
"required": ["label", "value"]
}
},
"maxAllowedSelections": {
"type": "integer",
"description": "The maximum number of options that the user is allowed to select."
}
},
"required": ["selections", "options"]
},
"Slider": {
"type": "object",
"additionalProperties": false,
"properties": {
"value": {
"type": "object",
"description": "The current value of the slider. This can be a literal number ('literalNumber') or a reference to a value in the data model ('path', e.g. '/restaurant/cost').",
"additionalProperties": false,
"properties": {
"literalNumber": {
"type": "number"
},
"path": {
"type": "string"
}
}
},
"minValue": {
"type": "number",
"description": "The minimum value of the slider."
},
"maxValue": {
"type": "number",
"description": "The maximum value of the slider."
}
},
"required": ["value"]
}
}
}
},
"required": ["id", "component"]
}
}
},
"required": ["surfaceId", "components"]
},
"dataModelUpdate": {
"type": "object",
"description": "Updates the data model for a surface.",
"additionalProperties": false,
"properties": {
"surfaceId": {
"type": "string",
"description": "The unique identifier for the UI surface this data model update applies to."
},
"path": {
"type": "string",
"description": "An optional path to a location within the data model (e.g., '/user/name'). If omitted, or set to '/', the entire data model will be replaced."
},
"contents": {
"type": "array",
"description": "An array of data entries. Each entry must contain a 'key' and exactly one corresponding typed 'value*' property.",
"items": {
"type": "object",
"description": "A single data entry. Exactly one 'value*' property should be provided alongside the key.",
"additionalProperties": false,
"properties": {
"key": {
"type": "string",
"description": "The key for this data entry."
},
"valueString": {
"type": "string"
},
"valueNumber": {
"type": "number"
},
"valueBoolean": {
"type": "boolean"
},
"valueMap": {
"description": "Represents a map as an adjacency list.",
"type": "array",
"items": {
"type": "object",
"description": "One entry in the map. Exactly one 'value*' property should be provided alongside the key.",
"additionalProperties": false,
"properties": {
"key": {
"type": "string"
},
"valueString": {
"type": "string"
},
"valueNumber": {
"type": "number"
},
"valueBoolean": {
"type": "boolean"
}
},
"required": ["key"]
}
}
},
"required": ["key"]
}
}
},
"required": ["contents", "surfaceId"]
},
"deleteSurface": {
"type": "object",
"description": "Signals the client to delete the surface identified by 'surfaceId'.",
"additionalProperties": false,
"properties": {
"surfaceId": {
"type": "string",
"description": "The unique identifier for the UI surface to be deleted."
}
},
"required": ["surfaceId"]
}
}
}

View File

@@ -0,0 +1,459 @@
{
"components": {
"Text": {
"type": "object",
"additionalProperties": false,
"properties": {
"text": {
"type": "object",
"description": "The text content to display. This can be a literal string or a reference to a value in the data model ('path', e.g., '/doc/title'). While simple Markdown formatting is supported (i.e. without HTML, images, or links), utilizing dedicated UI components is generally preferred for a richer and more structured presentation.",
"additionalProperties": false,
"properties": {
"literalString": { "type": "string" },
"path": { "type": "string" }
}
},
"usageHint": {
"type": "string",
"description": "A hint for the base text style.",
"enum": ["h1", "h2", "h3", "h4", "h5", "caption", "body"]
}
},
"required": ["text"]
},
"Image": {
"type": "object",
"additionalProperties": false,
"properties": {
"url": {
"type": "object",
"description": "The URL of the image to display.",
"additionalProperties": false,
"properties": {
"literalString": { "type": "string" },
"path": { "type": "string" }
}
},
"fit": {
"type": "string",
"description": "Specifies how the image should be resized to fit its container.",
"enum": ["contain", "cover", "fill", "none", "scale-down"]
},
"usageHint": {
"type": "string",
"description": "A hint for the image size and style.",
"enum": ["icon", "avatar", "smallFeature", "mediumFeature", "largeFeature", "header"]
}
},
"required": ["url"]
},
"Icon": {
"type": "object",
"additionalProperties": false,
"properties": {
"name": {
"type": "object",
"description": "The name of the icon to display.",
"additionalProperties": false,
"properties": {
"literalString": {
"type": "string",
"enum": [
"accountCircle", "add", "arrowBack", "arrowForward", "attachFile",
"calendarToday", "call", "camera", "check", "close", "delete",
"download", "edit", "event", "error", "favorite", "favoriteOff",
"folder", "help", "home", "info", "locationOn", "lock", "lockOpen",
"mail", "menu", "moreVert", "moreHoriz", "notificationsOff",
"notifications", "payment", "person", "phone", "photo", "print",
"refresh", "search", "send", "settings", "share", "shoppingCart",
"star", "starHalf", "starOff", "upload", "visibility",
"visibilityOff", "warning"
]
},
"path": { "type": "string" }
}
}
},
"required": ["name"]
},
"Video": {
"type": "object",
"additionalProperties": false,
"properties": {
"url": {
"type": "object",
"description": "The URL of the video to display.",
"additionalProperties": false,
"properties": {
"literalString": { "type": "string" },
"path": { "type": "string" }
}
}
},
"required": ["url"]
},
"AudioPlayer": {
"type": "object",
"additionalProperties": false,
"properties": {
"url": {
"type": "object",
"description": "The URL of the audio to be played.",
"additionalProperties": false,
"properties": {
"literalString": { "type": "string" },
"path": { "type": "string" }
}
},
"description": {
"type": "object",
"description": "A description of the audio, such as a title or summary.",
"additionalProperties": false,
"properties": {
"literalString": { "type": "string" },
"path": { "type": "string" }
}
}
},
"required": ["url"]
},
"Row": {
"type": "object",
"additionalProperties": false,
"properties": {
"children": {
"type": "object",
"description": "Defines the children. Use 'explicitList' for a fixed set of children, or 'template' to generate children from a data list.",
"additionalProperties": false,
"properties": {
"explicitList": { "type": "array", "items": { "type": "string" } },
"template": {
"type": "object",
"additionalProperties": false,
"properties": {
"componentId": { "type": "string" },
"dataBinding": { "type": "string" }
},
"required": ["componentId", "dataBinding"]
}
}
},
"distribution": {
"type": "string",
"enum": ["center", "end", "spaceAround", "spaceBetween", "spaceEvenly", "start"]
},
"alignment": {
"type": "string",
"enum": ["start", "center", "end", "stretch"]
}
},
"required": ["children"]
},
"Column": {
"type": "object",
"additionalProperties": false,
"properties": {
"children": {
"type": "object",
"description": "Defines the children. Use 'explicitList' for a fixed set of children, or 'template' to generate children from a data list.",
"additionalProperties": false,
"properties": {
"explicitList": { "type": "array", "items": { "type": "string" } },
"template": {
"type": "object",
"additionalProperties": false,
"properties": {
"componentId": { "type": "string" },
"dataBinding": { "type": "string" }
},
"required": ["componentId", "dataBinding"]
}
}
},
"distribution": {
"type": "string",
"enum": ["start", "center", "end", "spaceBetween", "spaceAround", "spaceEvenly"]
},
"alignment": {
"type": "string",
"enum": ["center", "end", "start", "stretch"]
}
},
"required": ["children"]
},
"List": {
"type": "object",
"additionalProperties": false,
"properties": {
"children": {
"type": "object",
"description": "Defines the children. Use 'explicitList' for a fixed set of children, or 'template' to generate children from a data list.",
"additionalProperties": false,
"properties": {
"explicitList": { "type": "array", "items": { "type": "string" } },
"template": {
"type": "object",
"additionalProperties": false,
"properties": {
"componentId": { "type": "string" },
"dataBinding": { "type": "string" }
},
"required": ["componentId", "dataBinding"]
}
}
},
"direction": {
"type": "string",
"enum": ["vertical", "horizontal"]
},
"alignment": {
"type": "string",
"enum": ["start", "center", "end", "stretch"]
}
},
"required": ["children"]
},
"Card": {
"type": "object",
"additionalProperties": false,
"properties": {
"child": {
"type": "string",
"description": "The ID of the component to be rendered inside the card."
}
},
"required": ["child"]
},
"Tabs": {
"type": "object",
"additionalProperties": false,
"properties": {
"tabItems": {
"type": "array",
"items": {
"type": "object",
"additionalProperties": false,
"properties": {
"title": {
"type": "object",
"additionalProperties": false,
"properties": {
"literalString": { "type": "string" },
"path": { "type": "string" }
}
},
"child": { "type": "string" }
},
"required": ["title", "child"]
}
}
},
"required": ["tabItems"]
},
"Divider": {
"type": "object",
"additionalProperties": false,
"properties": {
"axis": {
"type": "string",
"enum": ["horizontal", "vertical"]
}
}
},
"Modal": {
"type": "object",
"additionalProperties": false,
"properties": {
"entryPointChild": {
"type": "string",
"description": "The ID of the component that opens the modal when interacted with."
},
"contentChild": {
"type": "string",
"description": "The ID of the component to be displayed inside the modal."
}
},
"required": ["entryPointChild", "contentChild"]
},
"Button": {
"type": "object",
"additionalProperties": false,
"properties": {
"child": {
"type": "string",
"description": "The ID of the component to display in the button."
},
"primary": {
"type": "boolean",
"description": "Indicates if this button should be styled as the primary action."
},
"action": {
"type": "object",
"additionalProperties": false,
"properties": {
"name": { "type": "string" },
"context": {
"type": "array",
"items": {
"type": "object",
"additionalProperties": false,
"properties": {
"key": { "type": "string" },
"value": {
"type": "object",
"additionalProperties": false,
"properties": {
"path": { "type": "string" },
"literalString": { "type": "string" },
"literalNumber": { "type": "number" },
"literalBoolean": { "type": "boolean" }
}
}
},
"required": ["key", "value"]
}
}
},
"required": ["name"]
}
},
"required": ["child", "action"]
},
"CheckBox": {
"type": "object",
"additionalProperties": false,
"properties": {
"label": {
"type": "object",
"additionalProperties": false,
"properties": {
"literalString": { "type": "string" },
"path": { "type": "string" }
}
},
"value": {
"type": "object",
"additionalProperties": false,
"properties": {
"literalBoolean": { "type": "boolean" },
"path": { "type": "string" }
}
}
},
"required": ["label", "value"]
},
"TextField": {
"type": "object",
"additionalProperties": false,
"properties": {
"label": {
"type": "object",
"additionalProperties": false,
"properties": {
"literalString": { "type": "string" },
"path": { "type": "string" }
}
},
"text": {
"type": "object",
"additionalProperties": false,
"properties": {
"literalString": { "type": "string" },
"path": { "type": "string" }
}
},
"textFieldType": {
"type": "string",
"enum": ["date", "longText", "number", "shortText", "obscured"]
},
"validationRegexp": { "type": "string" }
},
"required": ["label"]
},
"DateTimeInput": {
"type": "object",
"additionalProperties": false,
"properties": {
"value": {
"type": "object",
"additionalProperties": false,
"properties": {
"literalString": { "type": "string" },
"path": { "type": "string" }
}
},
"enableDate": { "type": "boolean" },
"enableTime": { "type": "boolean" }
},
"required": ["value"]
},
"MultipleChoice": {
"type": "object",
"additionalProperties": false,
"properties": {
"selections": {
"type": "object",
"additionalProperties": false,
"properties": {
"literalArray": { "type": "array", "items": { "type": "string" } },
"path": { "type": "string" }
}
},
"options": {
"type": "array",
"items": {
"type": "object",
"additionalProperties": false,
"properties": {
"label": {
"type": "object",
"additionalProperties": false,
"properties": {
"literalString": { "type": "string" },
"path": { "type": "string" }
}
},
"value": { "type": "string" }
},
"required": ["label", "value"]
}
},
"maxAllowedSelections": { "type": "integer" },
"variant": {
"type": "string",
"enum": ["checkbox", "chips"]
},
"filterable": { "type": "boolean" }
},
"required": ["selections", "options"]
},
"Slider": {
"type": "object",
"additionalProperties": false,
"properties": {
"value": {
"type": "object",
"additionalProperties": false,
"properties": {
"literalNumber": { "type": "number" },
"path": { "type": "string" }
}
},
"minValue": { "type": "number" },
"maxValue": { "type": "number" }
},
"required": ["value"]
}
},
"styles": {
"font": {
"type": "string",
"description": "The primary font for the UI."
},
"primaryColor": {
"type": "string",
"description": "The primary UI color as a hexadecimal code (e.g., '#00BFFF').",
"pattern": "^#[0-9a-fA-F]{6}$"
}
}
}

View File

@@ -0,0 +1,125 @@
"""A2UI server extension for the A2A protocol."""
from __future__ import annotations
import logging
from typing import Any
from crewai.a2a.extensions.a2ui.models import A2UIResponse, extract_a2ui_json_objects
from crewai.a2a.extensions.a2ui.validator import (
A2UIValidationError,
validate_a2ui_message,
)
from crewai.a2a.extensions.server import ExtensionContext, ServerExtension
logger = logging.getLogger(__name__)
A2UI_MIME_TYPE = "application/json+a2ui"
A2UI_EXTENSION_URI = "https://a2ui.org/a2a-extension/a2ui/v0.8"
class A2UIServerExtension(ServerExtension):
"""A2A server extension that enables A2UI declarative UI generation.
When activated by a client that declares A2UI v0.8 support,
this extension:
* Negotiates catalog preferences during ``on_request``.
* Wraps A2UI messages in the agent response as A2A DataParts with
``application/json+a2ui`` MIME type during ``on_response``.
Example::
A2AServerConfig(
server_extensions=[A2UIServerExtension()],
default_output_modes=["text/plain", "application/json+a2ui"],
)
"""
uri: str = A2UI_EXTENSION_URI
required: bool = False
description: str = "A2UI declarative UI generation"
def __init__(
self,
catalog_ids: list[str] | None = None,
accept_inline_catalogs: bool = False,
) -> None:
"""Initialize the A2UI server extension.
Args:
catalog_ids: Catalog identifiers this server supports.
accept_inline_catalogs: Whether inline catalog definitions are accepted.
"""
self._catalog_ids = catalog_ids or []
self._accept_inline_catalogs = accept_inline_catalogs
@property
def params(self) -> dict[str, Any]:
"""Extension parameters advertised in the AgentCard."""
result: dict[str, Any] = {}
if self._catalog_ids:
result["supportedCatalogIds"] = self._catalog_ids
result["acceptsInlineCatalogs"] = self._accept_inline_catalogs
return result
async def on_request(self, context: ExtensionContext) -> None:
"""Extract A2UI catalog preferences from the client request.
Stores the negotiated catalog in ``context.state`` under
``"a2ui_catalog_id"`` for downstream use.
"""
if not self.is_active(context):
return
catalog_id = context.get_extension_metadata(self.uri, "catalogId")
if isinstance(catalog_id, str):
context.state["a2ui_catalog_id"] = catalog_id
elif self._catalog_ids:
context.state["a2ui_catalog_id"] = self._catalog_ids[0]
context.state["a2ui_active"] = True
async def on_response(self, context: ExtensionContext, result: Any) -> Any:
"""Wrap A2UI messages in the result as A2A DataParts.
Scans the result for A2UI JSON payloads and converts them into
DataParts with ``application/json+a2ui`` MIME type and A2UI metadata.
"""
if not context.state.get("a2ui_active"):
return result
if not isinstance(result, str):
return result
a2ui_messages = extract_a2ui_json_objects(result)
if not a2ui_messages:
return result
data_parts = [
part
for part in (_build_data_part(msg_data) for msg_data in a2ui_messages)
if part is not None
]
if not data_parts:
return result
return A2UIResponse(text=result, a2ui_parts=data_parts)
def _build_data_part(msg_data: dict[str, Any]) -> dict[str, Any] | None:
"""Validate a single A2UI message and wrap it as a DataPart dict."""
try:
validated = validate_a2ui_message(msg_data)
except A2UIValidationError:
logger.warning("Skipping invalid A2UI message in response", exc_info=True)
return None
return {
"kind": "data",
"data": validated.model_dump(by_alias=True, exclude_none=True),
"metadata": {
"mimeType": A2UI_MIME_TYPE,
},
}

View File

@@ -0,0 +1,59 @@
"""Validate A2UI message dicts via Pydantic models."""
from __future__ import annotations
from typing import Any
from pydantic import ValidationError
from crewai.a2a.extensions.a2ui.models import A2UIEvent, A2UIMessage
class A2UIValidationError(Exception):
"""Raised when an A2UI message fails validation."""
def __init__(self, message: str, errors: list[Any] | None = None) -> None:
super().__init__(message)
self.errors = errors or []
def validate_a2ui_message(data: dict[str, Any]) -> A2UIMessage:
"""Parse and validate an A2UI server-to-client message.
Args:
data: Raw message dict (JSON-decoded).
Returns:
Validated ``A2UIMessage`` instance.
Raises:
A2UIValidationError: If the data does not conform to the A2UI schema.
"""
try:
return A2UIMessage.model_validate(data)
except ValidationError as exc:
raise A2UIValidationError(
f"Invalid A2UI message: {exc.error_count()} validation error(s)",
errors=exc.errors(),
) from exc
def validate_a2ui_event(data: dict[str, Any]) -> A2UIEvent:
"""Parse and validate an A2UI client-to-server event.
Args:
data: Raw event dict (JSON-decoded).
Returns:
Validated ``A2UIEvent`` instance.
Raises:
A2UIValidationError: If the data does not conform to the A2UI event schema.
"""
try:
return A2UIEvent.model_validate(data)
except ValidationError as exc:
raise A2UIValidationError(
f"Invalid A2UI event: {exc.error_count()} validation error(s)",
errors=exc.errors(),
) from exc

View File

@@ -3,7 +3,7 @@
from __future__ import annotations
from collections.abc import AsyncIterator
from typing import TYPE_CHECKING, Any, TypedDict
from typing import TYPE_CHECKING, Any
import uuid
from a2a.client.errors import A2AClientHTTPError
@@ -18,7 +18,7 @@ from a2a.types import (
TaskStatusUpdateEvent,
TextPart,
)
from typing_extensions import NotRequired
from typing_extensions import NotRequired, TypedDict
from crewai.events.event_bus import crewai_event_bus
from crewai.events.types.a2a_events import (

View File

@@ -7,12 +7,11 @@ from typing import (
Any,
Literal,
Protocol,
TypedDict,
runtime_checkable,
)
from pydantic import BeforeValidator, HttpUrl, TypeAdapter
from typing_extensions import NotRequired
from typing_extensions import NotRequired, TypedDict
try:

View File

@@ -2,10 +2,11 @@
from __future__ import annotations
from typing import TYPE_CHECKING, Any, NamedTuple, Protocol, TypedDict
from typing import TYPE_CHECKING, Any, NamedTuple, Protocol
from pydantic import GetCoreSchemaHandler
from pydantic_core import CoreSchema, core_schema
from typing_extensions import TypedDict
class CommonParams(NamedTuple):

View File

@@ -28,6 +28,7 @@ APPLICATION_PDF: Literal["application/pdf"] = "application/pdf"
APPLICATION_OCTET_STREAM: Literal["application/octet-stream"] = (
"application/octet-stream"
)
APPLICATION_A2UI_JSON: Literal["application/json+a2ui"] = "application/json+a2ui"
DEFAULT_CLIENT_INPUT_MODES: Final[list[Literal["text/plain", "application/json"]]] = [
TEXT_PLAIN,
@@ -311,6 +312,10 @@ def get_part_content_type(part: Part) -> str:
if root.kind == "text":
return TEXT_PLAIN
if root.kind == "data":
metadata = root.metadata or {}
mime = metadata.get("mimeType", "")
if mime == APPLICATION_A2UI_JSON:
return APPLICATION_A2UI_JSON
return APPLICATION_JSON
if root.kind == "file":
return root.file.mime_type or APPLICATION_OCTET_STREAM

View File

@@ -10,7 +10,7 @@ from functools import wraps
import json
import logging
import os
from typing import TYPE_CHECKING, Any, ParamSpec, TypeVar, TypedDict, cast
from typing import TYPE_CHECKING, Any, ParamSpec, TypeVar, cast
from urllib.parse import urlparse
from a2a.server.agent_execution import RequestContext
@@ -38,6 +38,7 @@ from a2a.utils import (
from a2a.utils.errors import ServerError
from aiocache import SimpleMemoryCache, caches # type: ignore[import-untyped]
from pydantic import BaseModel
from typing_extensions import TypedDict
from crewai.a2a.utils.agent_card import _get_server_config
from crewai.a2a.utils.content_type import validate_message_parts

View File

@@ -2,15 +2,19 @@ from pathlib import Path
import click
from crewai_cli.utils import copy_template
from crewai.cli.utils import copy_template
from crewai.utilities.printer import Printer
_printer = Printer()
def add_crew_to_flow(crew_name: str) -> None:
"""Add a new crew to the current flow."""
# Check if pyproject.toml exists in the current directory
if not Path("pyproject.toml").exists():
click.secho(
"This command must be run from the root of a flow project.", fg="red"
_printer.print(
"This command must be run from the root of a flow project.", color="red"
)
raise click.ClickException(
"This command must be run from the root of a flow project."
@@ -21,7 +25,7 @@ def add_crew_to_flow(crew_name: str) -> None:
crews_folder = flow_folder / "src" / flow_folder.name / "crews"
if not crews_folder.exists():
click.secho("Crews folder does not exist in the current flow.", fg="red")
_printer.print("Crews folder does not exist in the current flow.", color="red")
raise click.ClickException("Crews folder does not exist in the current flow.")
# Create the crew within the flow's crews directory

View File

@@ -180,7 +180,7 @@ class AuthenticationCommand:
def _login_to_tool_repository(self) -> None:
"""Login to the tool repository."""
from crewai_cli.tools.main import ToolCommand
from crewai.cli.tools.main import ToolCommand
try:
console.print(

View File

@@ -1,5 +1,3 @@
from __future__ import annotations
from importlib.metadata import version as get_version
import os
import subprocess
@@ -7,58 +5,44 @@ from typing import Any
import click
from crewai_cli.add_crew_to_flow import add_crew_to_flow
from crewai_cli.authentication.main import AuthenticationCommand
from crewai_cli.config import Settings
from crewai_cli.create_crew import create_crew
from crewai_cli.create_flow import create_flow
from crewai_cli.crew_chat import run_chat
from crewai_cli.deploy.main import DeployCommand
from crewai_cli.enterprise.main import EnterpriseConfigureCommand
from crewai_cli.evaluate_crew import evaluate_crew
from crewai_cli.install_crew import install_crew
from crewai_cli.kickoff_flow import kickoff_flow
from crewai_cli.organization.main import OrganizationCommand
from crewai_cli.plot_flow import plot_flow
from crewai_cli.replay_from_task import replay_task_command
from crewai_cli.reset_memories_command import reset_memories_command
from crewai_cli.run_crew import run_crew
from crewai_cli.settings.main import SettingsCommand
from crewai_cli.task_outputs import load_task_outputs
from crewai_cli.tools.main import ToolCommand
from crewai_cli.train_crew import train_crew
from crewai_cli.triggers.main import TriggersCommand
from crewai_cli.update_crew import update_crew
from crewai_cli.user_data import (
_load_user_data,
_save_user_data,
is_tracing_enabled,
from crewai.cli.add_crew_to_flow import add_crew_to_flow
from crewai.cli.authentication.main import AuthenticationCommand
from crewai.cli.config import Settings
from crewai.cli.create_crew import create_crew
from crewai.cli.create_flow import create_flow
from crewai.cli.crew_chat import run_chat
from crewai.cli.deploy.main import DeployCommand
from crewai.cli.enterprise.main import EnterpriseConfigureCommand
from crewai.cli.evaluate_crew import evaluate_crew
from crewai.cli.install_crew import install_crew
from crewai.cli.kickoff_flow import kickoff_flow
from crewai.cli.organization.main import OrganizationCommand
from crewai.cli.plot_flow import plot_flow
from crewai.cli.replay_from_task import replay_task_command
from crewai.cli.reset_memories_command import reset_memories_command
from crewai.cli.run_crew import run_crew
from crewai.cli.settings.main import SettingsCommand
from crewai.cli.tools.main import ToolCommand
from crewai.cli.train_crew import train_crew
from crewai.cli.triggers.main import TriggersCommand
from crewai.cli.update_crew import update_crew
from crewai.cli.utils import build_env_with_tool_repository_credentials, read_toml
from crewai.memory.storage.kickoff_task_outputs_storage import (
KickoffTaskOutputsSQLiteStorage,
)
from crewai_cli.utils import build_env_with_tool_repository_credentials, read_toml
def _get_cli_version() -> str:
"""Return the best available version string for the CLI."""
# Prefer crewai version if installed (keeps existing UX)
try:
return get_version("crewai")
except Exception: # noqa: S110
pass
try:
return get_version("crewai-cli")
except Exception:
return "unknown"
@click.group()
@click.version_option(_get_cli_version())
@click.version_option(get_version("crewai"))
def crewai():
"""Top-level command group for crewai."""
@crewai.command(
name="uv",
context_settings={"ignore_unknown_options": True},
context_settings=dict(
ignore_unknown_options=True,
),
)
@click.argument("uv_args", nargs=-1, type=click.UNPROCESSED)
def uv(uv_args):
@@ -123,7 +107,7 @@ def version(tools):
if tools:
try:
tools_version = get_version("crewai-tools")
tools_version = get_version("crewai")
click.echo(f"crewai tools version: {tools_version}")
except Exception:
click.echo("crewai tools not installed")
@@ -158,7 +142,12 @@ def train(n_iterations: int, filename: str):
help="Replay the crew from this task ID, including all subsequent tasks.",
)
def replay(task_id: str) -> None:
"""Replay the crew execution from a specific task."""
"""
Replay the crew execution from a specific task.
Args:
task_id (str): The ID of the task to replay from.
"""
try:
click.echo(f"Replaying the crew from task {task_id}")
replay_task_command(task_id)
@@ -168,9 +157,12 @@ def replay(task_id: str) -> None:
@crewai.command()
def log_tasks_outputs() -> None:
"""Retrieve your latest crew.kickoff() task outputs."""
"""
Retrieve your latest crew.kickoff() task outputs.
"""
try:
tasks = load_task_outputs()
storage = KickoffTaskOutputsSQLiteStorage()
tasks = storage.load()
if not tasks:
click.echo(
@@ -228,8 +220,11 @@ def reset_memories(
agent_knowledge: bool,
all: bool,
) -> None:
"""Reset the crew memories (memory, knowledge, agent_knowledge, kickoff_outputs). This will delete all the data saved."""
"""
Reset the crew memories (memory, knowledge, agent_knowledge, kickoff_outputs). This will delete all the data saved.
"""
try:
# Treat legacy flags as --memory with a deprecation warning
if long or short or entities:
legacy_used = [
f
@@ -296,7 +291,7 @@ def memory(
) -> None:
"""Open the Memory TUI to browse scopes and recall memories."""
try:
from crewai_cli.memory_tui import MemoryTUI
from crewai.cli.memory_tui import MemoryTUI
except ImportError as exc:
click.echo(
"Textual is required for the memory TUI but could not be imported. "
@@ -346,10 +341,10 @@ def test(n_iterations: int, model: str):
@crewai.command(
context_settings={
"ignore_unknown_options": True,
"allow_extra_args": True,
}
context_settings=dict(
ignore_unknown_options=True,
allow_extra_args=True,
)
)
@click.pass_context
def install(context):
@@ -514,12 +509,14 @@ def triggers_run(trigger_path: str):
@crewai.command()
def chat():
"""Start a conversation with the Crew, collecting user-supplied inputs,
"""
Start a conversation with the Crew, collecting user-supplied inputs,
and using the Chat LLM to generate responses.
"""
click.secho(
"\nStarting a conversation with the Crew\nType 'exit' or Ctrl+C to quit.\n",
)
run_chat()
@@ -630,7 +627,7 @@ def env_view():
table.add_row(
"CREWAI_TRACING_ENABLED",
"[dim]Not set[/dim]",
"[dim]---[/dim]",
"[dim][/dim]",
)
# Check other related env vars
@@ -649,7 +646,7 @@ def env_view():
# Check if .env file exists
table.add_row(
".env file",
"Found" if env_file_exists else "Not found",
"Found" if env_file_exists else "Not found",
str(env_file.resolve()) if env_file_exists else "N/A",
)
@@ -665,11 +662,11 @@ def env_view():
# Show helpful message
if env_file_exists:
console.print(
"\n[dim]Tip: To enable tracing via .env, add: CREWAI_TRACING_ENABLED=true[/dim]"
"\n[dim]💡 Tip: To enable tracing via .env, add: CREWAI_TRACING_ENABLED=true[/dim]"
)
else:
console.print(
"\n[dim]Tip: Create a .env file in your project root and add: CREWAI_TRACING_ENABLED=true[/dim]"
"\n[dim]💡 Tip: Create a .env file in your project root and add: CREWAI_TRACING_ENABLED=true[/dim]"
)
console.print()
@@ -685,16 +682,14 @@ def traces_enable():
from rich.console import Console
from rich.panel import Panel
from crewai.events.listeners.tracing.utils import update_user_data
console = Console()
# Update user data to enable traces
user_data = _load_user_data()
user_data["trace_consent"] = True
user_data["first_execution_done"] = True
_save_user_data(user_data)
update_user_data({"trace_consent": True, "first_execution_done": True})
panel = Panel(
"Trace collection has been enabled!\n\n"
"Trace collection has been enabled!\n\n"
"Your crew/flow executions will now send traces to CrewAI+.\n"
"Use 'crewai traces disable' to turn off trace collection.",
title="Traces Enabled",
@@ -710,16 +705,14 @@ def traces_disable():
from rich.console import Console
from rich.panel import Panel
from crewai.events.listeners.tracing.utils import update_user_data
console = Console()
# Update user data to disable traces
user_data = _load_user_data()
user_data["trace_consent"] = False
user_data["first_execution_done"] = True
_save_user_data(user_data)
update_user_data({"trace_consent": False, "first_execution_done": True})
panel = Panel(
"Trace collection has been disabled!\n\n"
"Trace collection has been disabled!\n\n"
"Your crew/flow executions will no longer send traces.\n"
"Use 'crewai traces enable' to turn trace collection back on.",
title="Traces Disabled",
@@ -738,6 +731,11 @@ def traces_status():
from rich.panel import Panel
from rich.table import Table
from crewai.events.listeners.tracing.utils import (
_load_user_data,
is_tracing_enabled,
)
console = Console()
user_data = _load_user_data()
@@ -752,19 +750,19 @@ def traces_status():
# Check user consent
trace_consent = user_data.get("trace_consent")
if trace_consent is True:
consent_status = "Enabled (user consented)"
consent_status = "Enabled (user consented)"
elif trace_consent is False:
consent_status = "Disabled (user declined)"
consent_status = "Disabled (user declined)"
else:
consent_status = "Not set (first-time user)"
consent_status = "Not set (first-time user)"
table.add_row("User Consent", consent_status)
# Check overall status
if is_tracing_enabled():
overall_status = "ENABLED"
overall_status = "ENABLED"
border_style = "green"
else:
overall_status = "DISABLED"
overall_status = "DISABLED"
border_style = "red"
table.add_row("Overall Status", overall_status)

View File

@@ -5,13 +5,13 @@ import sys
import click
import tomli
from crewai_cli.constants import ENV_VARS, MODELS
from crewai_cli.provider import (
from crewai.cli.constants import ENV_VARS, MODELS
from crewai.cli.provider import (
get_provider_data,
select_model,
select_provider,
)
from crewai_cli.utils import copy_template, load_env_vars, write_env_file
from crewai.cli.utils import copy_template, load_env_vars, write_env_file
def get_reserved_script_names() -> set[str]:

View File

@@ -3,6 +3,8 @@ import shutil
import click
from crewai.telemetry import Telemetry
def create_flow(name):
"""Create a new flow."""
@@ -16,6 +18,10 @@ def create_flow(name):
click.secho(f"Error: Folder {folder_name} already exists.", fg="red")
return
# Initialize telemetry
telemetry = Telemetry()
telemetry.flow_creation_span(class_name)
# Create directory structure
(project_root / "src" / folder_name).mkdir(parents=True)
(project_root / "src" / folder_name / "crews").mkdir(parents=True)

View File

@@ -1,11 +1,10 @@
from pathlib import Path
from typing import Any
from rich.console import Console
from crewai_cli import git
from crewai_cli.command import BaseCommand, PlusAPIMixin
from crewai_cli.utils import fetch_and_json_env_file, get_project_name
from crewai.cli import git
from crewai.cli.command import BaseCommand, PlusAPIMixin
from crewai.cli.utils import fetch_and_json_env_file, get_project_name
console = Console()
@@ -22,43 +21,8 @@ class DeployCommand(BaseCommand, PlusAPIMixin):
"""
BaseCommand.__init__(self)
PlusAPIMixin.__init__(self)
PlusAPIMixin.__init__(self, telemetry=self._telemetry)
self.project_name = get_project_name(require=True)
self._validate_project_structure()
def _validate_project_structure(self) -> None:
"""Validate that the local project has the files required for deployment."""
errors: list[str] = []
if not Path("pyproject.toml").exists():
errors.append("Cannot find pyproject.toml in the current directory.")
has_lockfile = Path("uv.lock").exists() or Path("poetry.lock").exists()
if not has_lockfile:
errors.append(
"No uv.lock or poetry.lock found. "
"Run 'uv lock' or 'poetry lock' to generate one."
)
src_dir = Path("src") / (self.project_name or "")
crew_py = src_dir / "crew.py"
config_dir = src_dir / "config"
if not crew_py.exists() and not config_dir.exists():
errors.append(
f"Cannot find src/{self.project_name}/crew.py or "
f"src/{self.project_name}/config. "
"Ensure you are running this command from the project root."
)
if errors:
console.print(
"\n[bold red]Pre-flight check failed:[/bold red] "
"Your project is missing required files for deployment.\n"
)
for error in errors:
console.print(f"{error}", style="red")
console.print()
raise SystemExit(1)
def _standard_no_param_error_message(self) -> None:
"""
@@ -103,6 +67,7 @@ class DeployCommand(BaseCommand, PlusAPIMixin):
Args:
uuid (Optional[str]): The UUID of the crew to deploy.
"""
self._start_deployment_span = self._telemetry.start_deployment_span(uuid)
console.print("Starting deployment...", style="bold blue")
if uuid:
response = self.plus_api_client.deploy_by_uuid(uuid)
@@ -119,6 +84,9 @@ class DeployCommand(BaseCommand, PlusAPIMixin):
"""
Create a new crew deployment.
"""
self._create_crew_deployment_span = (
self._telemetry.create_crew_deployment_span()
)
console.print("Creating deployment...", style="bold blue")
env_vars = fetch_and_json_env_file()
@@ -268,6 +236,7 @@ class DeployCommand(BaseCommand, PlusAPIMixin):
uuid (Optional[str]): The UUID of the crew to get logs for.
log_type (str): The type of logs to retrieve (default: "deployment").
"""
self._get_crew_logs_span = self._telemetry.get_crew_logs_span(uuid, log_type)
console.print(f"Fetching {log_type} logs...", style="bold blue")
if uuid:
@@ -288,6 +257,7 @@ class DeployCommand(BaseCommand, PlusAPIMixin):
Args:
uuid (Optional[str]): The UUID of the crew to remove.
"""
self._remove_crew_span = self._telemetry.remove_crew_span(uuid)
console.print("Removing deployment...", style="bold blue")
if uuid:

View File

@@ -4,10 +4,10 @@ from typing import Any, cast
import httpx
from rich.console import Console
from crewai_cli.authentication.main import Oauth2Settings, ProviderFactory
from crewai_cli.command import BaseCommand
from crewai_cli.settings.main import SettingsCommand
from crewai_cli.version import get_crewai_version
from crewai.cli.authentication.main import Oauth2Settings, ProviderFactory
from crewai.cli.command import BaseCommand
from crewai.cli.settings.main import SettingsCommand
from crewai.cli.version import get_crewai_version
console = Console()

View File

@@ -2,8 +2,8 @@ from httpx import HTTPStatusError
from rich.console import Console
from rich.table import Table
from crewai_cli.command import BaseCommand, PlusAPIMixin
from crewai_cli.config import Settings
from crewai.cli.command import BaseCommand, PlusAPIMixin
from crewai.cli.config import Settings
console = Console()
@@ -12,7 +12,7 @@ console = Console()
class OrganizationCommand(BaseCommand, PlusAPIMixin):
def __init__(self) -> None:
BaseCommand.__init__(self)
PlusAPIMixin.__init__(self)
PlusAPIMixin.__init__(self, telemetry=self._telemetry)
def list(self) -> None:
try:

View File

@@ -5,8 +5,8 @@ import subprocess
import click
from packaging import version
from crewai_cli.utils import build_env_with_tool_repository_credentials, read_toml
from crewai_cli.version import get_crewai_version
from crewai.cli.utils import build_env_with_tool_repository_credentials, read_toml
from crewai.cli.version import get_crewai_version
class CrewType(Enum):

View File

@@ -5,9 +5,9 @@ from typing import Any
from rich.console import Console
from rich.table import Table
from crewai_cli.command import BaseCommand
from crewai_cli.config import HIDDEN_SETTINGS_KEYS, READONLY_SETTINGS_KEYS, Settings
from crewai_cli.user_data import _load_user_data
from crewai.cli.command import BaseCommand
from crewai.cli.config import HIDDEN_SETTINGS_KEYS, READONLY_SETTINGS_KEYS, Settings
from crewai.events.listeners.tracing.utils import _load_user_data
console = Console()

Some files were not shown because too many files have changed in this diff Show More