mirror of
https://github.com/crewAIInc/crewAI.git
synced 2026-03-15 08:18:19 +00:00
Compare commits
2 Commits
gl/chore/r
...
gl/ci/pr-c
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
be6453fc1c | ||
|
|
156b9d3285 |
1
.github/workflows/linter.yml
vendored
1
.github/workflows/linter.yml
vendored
@@ -55,7 +55,6 @@ jobs:
|
||||
echo "${{ steps.changed-files.outputs.files }}" \
|
||||
| tr ' ' '\n' \
|
||||
| grep -v 'src/crewai/cli/templates/' \
|
||||
| grep -v 'src/crewai_cli/templates/' \
|
||||
| grep -v '/tests/' \
|
||||
| xargs -I{} uv run ruff check "{}"
|
||||
|
||||
|
||||
37
.github/workflows/pr-size.yml
vendored
Normal file
37
.github/workflows/pr-size.yml
vendored
Normal file
@@ -0,0 +1,37 @@
|
||||
name: PR Size Check
|
||||
|
||||
on:
|
||||
pull_request:
|
||||
types: [opened, synchronize, reopened]
|
||||
|
||||
jobs:
|
||||
pr-size:
|
||||
runs-on: ubuntu-latest
|
||||
permissions:
|
||||
pull-requests: write
|
||||
steps:
|
||||
- uses: codelytv/pr-size-labeler@v1
|
||||
with:
|
||||
GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }}
|
||||
xs_label: "size/XS"
|
||||
xs_max_size: 25
|
||||
s_label: "size/S"
|
||||
s_max_size: 100
|
||||
m_label: "size/M"
|
||||
m_max_size: 250
|
||||
l_label: "size/L"
|
||||
l_max_size: 500
|
||||
xl_label: "size/XL"
|
||||
fail_if_xl: true
|
||||
message_if_xl: >
|
||||
This PR exceeds 500 lines changed and has been labeled `size/XL`.
|
||||
PRs of this size require release manager approval to merge.
|
||||
Please consider splitting into smaller PRs, or add a justification
|
||||
in the PR description for why this cannot be broken up.
|
||||
files_to_ignore: |
|
||||
uv.lock
|
||||
*.lock
|
||||
lib/crewai/src/crewai/cli/templates/**
|
||||
**/*.json
|
||||
**/test_durations/**
|
||||
**/cassettes/**
|
||||
38
.github/workflows/pr-title.yml
vendored
Normal file
38
.github/workflows/pr-title.yml
vendored
Normal file
@@ -0,0 +1,38 @@
|
||||
name: PR Title Check
|
||||
|
||||
on:
|
||||
pull_request:
|
||||
types: [opened, edited, synchronize, reopened]
|
||||
|
||||
jobs:
|
||||
pr-title:
|
||||
runs-on: ubuntu-latest
|
||||
steps:
|
||||
- uses: amannn/action-semantic-pull-request@v5
|
||||
env:
|
||||
GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }}
|
||||
with:
|
||||
types: |
|
||||
feat
|
||||
fix
|
||||
refactor
|
||||
perf
|
||||
test
|
||||
docs
|
||||
chore
|
||||
ci
|
||||
style
|
||||
revert
|
||||
requireScope: false
|
||||
subjectPattern: ^[a-z].+[^.]$
|
||||
subjectPatternError: >
|
||||
The PR title "{title}" does not follow conventional commit format.
|
||||
|
||||
Expected: <type>(<scope>): <lowercase description without trailing period>
|
||||
|
||||
Examples:
|
||||
feat(memory): add lancedb storage backend
|
||||
fix(agents): resolve deadlock in concurrent execution
|
||||
chore(deps): bump pydantic to 2.11.9
|
||||
|
||||
See RELEASE_PROCESS.md for the full commit message convention.
|
||||
71
.github/workflows/publish.yml
vendored
71
.github/workflows/publish.yml
vendored
@@ -59,8 +59,6 @@ jobs:
|
||||
contents: read
|
||||
steps:
|
||||
- uses: actions/checkout@v4
|
||||
with:
|
||||
ref: ${{ inputs.release_tag || github.ref }}
|
||||
|
||||
- name: Install uv
|
||||
uses: astral-sh/setup-uv@v6
|
||||
@@ -95,72 +93,3 @@ jobs:
|
||||
echo "Some packages failed to publish"
|
||||
exit 1
|
||||
fi
|
||||
|
||||
- name: Build Slack payload
|
||||
if: success()
|
||||
id: slack
|
||||
env:
|
||||
GH_TOKEN: ${{ secrets.GITHUB_TOKEN }}
|
||||
RELEASE_TAG: ${{ inputs.release_tag }}
|
||||
run: |
|
||||
payload=$(uv run python -c "
|
||||
import json, re, subprocess, sys
|
||||
|
||||
with open('lib/crewai/src/crewai/__init__.py') as f:
|
||||
m = re.search(r\"__version__\s*=\s*[\\\"']([^\\\"']+)\", f.read())
|
||||
version = m.group(1) if m else 'unknown'
|
||||
|
||||
import os
|
||||
tag = os.environ.get('RELEASE_TAG') or version
|
||||
|
||||
try:
|
||||
r = subprocess.run(['gh','release','view',tag,'--json','body','-q','.body'],
|
||||
capture_output=True, text=True, check=True)
|
||||
body = r.stdout.strip()
|
||||
except Exception:
|
||||
body = ''
|
||||
|
||||
blocks = [
|
||||
{'type':'section','text':{'type':'mrkdwn',
|
||||
'text':f':rocket: \`crewai v{version}\` published to PyPI'}},
|
||||
{'type':'section','text':{'type':'mrkdwn',
|
||||
'text':f'<https://pypi.org/project/crewai/{version}/|View on PyPI> · <https://github.com/crewAIInc/crewAI/releases/tag/{tag}|Release notes>'}},
|
||||
{'type':'divider'},
|
||||
]
|
||||
|
||||
if body:
|
||||
heading, items = '', []
|
||||
for line in body.split('\n'):
|
||||
line = line.strip()
|
||||
if not line: continue
|
||||
hm = re.match(r'^#{2,3}\s+(.*)', line)
|
||||
if hm:
|
||||
if heading and items:
|
||||
skip = heading in ('What\\'s Changed','') or 'Contributors' in heading
|
||||
if not skip:
|
||||
txt = f'*{heading}*\n' + '\n'.join(f'• {i}' for i in items)
|
||||
blocks.append({'type':'section','text':{'type':'mrkdwn','text':txt}})
|
||||
heading, items = hm.group(1), []
|
||||
elif line.startswith('- ') or line.startswith('* '):
|
||||
items.append(re.sub(r'\*\*([^*]*)\*\*', r'*\1*', line[2:]))
|
||||
if heading and items:
|
||||
skip = heading in ('What\\'s Changed','') or 'Contributors' in heading
|
||||
if not skip:
|
||||
txt = f'*{heading}*\n' + '\n'.join(f'• {i}' for i in items)
|
||||
blocks.append({'type':'section','text':{'type':'mrkdwn','text':txt}})
|
||||
|
||||
blocks.append({'type':'divider'})
|
||||
blocks.append({'type':'section','text':{'type':'mrkdwn',
|
||||
'text':f'\`\`\`uv add \"crewai[tools]=={version}\"\`\`\`'}})
|
||||
|
||||
print(json.dumps({'blocks':blocks}))
|
||||
")
|
||||
echo "payload=$payload" >> $GITHUB_OUTPUT
|
||||
|
||||
- name: Notify Slack
|
||||
if: success()
|
||||
uses: slackapi/slack-github-action@v2.1.0
|
||||
with:
|
||||
webhook: ${{ secrets.SLACK_WEBHOOK_URL }}
|
||||
webhook-type: incoming-webhook
|
||||
payload: ${{ steps.slack.outputs.payload }}
|
||||
|
||||
@@ -19,7 +19,7 @@ repos:
|
||||
language: system
|
||||
pass_filenames: true
|
||||
types: [python]
|
||||
exclude: ^(lib/crewai/src/crewai/cli/templates/|lib/cli/|lib/crewai/tests/|lib/crewai-tools/tests/|lib/crewai-files/tests/)
|
||||
exclude: ^(lib/crewai/src/crewai/cli/templates/|lib/crewai/tests/|lib/crewai-tools/tests/|lib/crewai-files/tests/)
|
||||
- repo: https://github.com/astral-sh/uv-pre-commit
|
||||
rev: 0.9.3
|
||||
hooks:
|
||||
|
||||
@@ -226,7 +226,7 @@ def vcr_cassette_dir(request: Any) -> str:
|
||||
|
||||
for parent in test_file.parents:
|
||||
if (
|
||||
parent.name in ("crewai", "crewai-tools", "crewai-files", "cli")
|
||||
parent.name in ("crewai", "crewai-tools", "crewai-files")
|
||||
and parent.parent.name == "lib"
|
||||
):
|
||||
package_root = parent
|
||||
|
||||
@@ -4,49 +4,6 @@ description: "Product updates, improvements, and bug fixes for CrewAI"
|
||||
icon: "clock"
|
||||
mode: "wide"
|
||||
---
|
||||
<Update label="Mar 14, 2026">
|
||||
## v1.10.2rc2
|
||||
|
||||
[View release on GitHub](https://github.com/crewAIInc/crewAI/releases/tag/1.10.2rc2)
|
||||
|
||||
## What's Changed
|
||||
|
||||
### Bug Fixes
|
||||
- Remove exclusive locks from read-only storage operations
|
||||
|
||||
### Documentation
|
||||
- Update changelog and version for v1.10.2rc1
|
||||
|
||||
## Contributors
|
||||
|
||||
@greysonlalonde
|
||||
|
||||
</Update>
|
||||
|
||||
<Update label="Mar 13, 2026">
|
||||
## v1.10.2rc1
|
||||
|
||||
[View release on GitHub](https://github.com/crewAIInc/crewAI/releases/tag/1.10.2rc1)
|
||||
|
||||
## What's Changed
|
||||
|
||||
### Features
|
||||
- Add release command and trigger PyPI publish
|
||||
|
||||
### Bug Fixes
|
||||
- Fix cross-process and thread-safe locking to unprotected I/O
|
||||
- Propagate contextvars across all thread and executor boundaries
|
||||
- Propagate ContextVars into async task threads
|
||||
|
||||
### Documentation
|
||||
- Update changelog and version for v1.10.2a1
|
||||
|
||||
## Contributors
|
||||
|
||||
@danglies007, @greysonlalonde
|
||||
|
||||
</Update>
|
||||
|
||||
<Update label="Mar 11, 2026">
|
||||
## v1.10.2a1
|
||||
|
||||
|
||||
@@ -4,49 +4,6 @@ description: "CrewAI의 제품 업데이트, 개선 사항 및 버그 수정"
|
||||
icon: "clock"
|
||||
mode: "wide"
|
||||
---
|
||||
<Update label="2026년 3월 14일">
|
||||
## v1.10.2rc2
|
||||
|
||||
[GitHub 릴리스 보기](https://github.com/crewAIInc/crewAI/releases/tag/1.10.2rc2)
|
||||
|
||||
## 변경 사항
|
||||
|
||||
### 버그 수정
|
||||
- 읽기 전용 스토리지 작업에서 독점 잠금 제거
|
||||
|
||||
### 문서
|
||||
- v1.10.2rc1에 대한 변경 로그 및 버전 업데이트
|
||||
|
||||
## 기여자
|
||||
|
||||
@greysonlalonde
|
||||
|
||||
</Update>
|
||||
|
||||
<Update label="2026년 3월 13일">
|
||||
## v1.10.2rc1
|
||||
|
||||
[GitHub 릴리스 보기](https://github.com/crewAIInc/crewAI/releases/tag/1.10.2rc1)
|
||||
|
||||
## 변경 사항
|
||||
|
||||
### 기능
|
||||
- 릴리스 명령 추가 및 PyPI 게시 트리거
|
||||
|
||||
### 버그 수정
|
||||
- 보호되지 않은 I/O에 대한 프로세스 간 및 스레드 안전 잠금 수정
|
||||
- 모든 스레드 및 실행기 경계를 넘는 contextvars 전파
|
||||
- async 작업 스레드로 ContextVars 전파
|
||||
|
||||
### 문서
|
||||
- v1.10.2a1에 대한 변경 로그 및 버전 업데이트
|
||||
|
||||
## 기여자
|
||||
|
||||
@danglies007, @greysonlalonde
|
||||
|
||||
</Update>
|
||||
|
||||
<Update label="2026년 3월 11일">
|
||||
## v1.10.2a1
|
||||
|
||||
|
||||
@@ -4,49 +4,6 @@ description: "Atualizações de produto, melhorias e correções do CrewAI"
|
||||
icon: "clock"
|
||||
mode: "wide"
|
||||
---
|
||||
<Update label="14 mar 2026">
|
||||
## v1.10.2rc2
|
||||
|
||||
[Ver release no GitHub](https://github.com/crewAIInc/crewAI/releases/tag/1.10.2rc2)
|
||||
|
||||
## O que Mudou
|
||||
|
||||
### Correções de Bugs
|
||||
- Remover bloqueios exclusivos de operações de armazenamento somente leitura
|
||||
|
||||
### Documentação
|
||||
- Atualizar changelog e versão para v1.10.2rc1
|
||||
|
||||
## Contribuidores
|
||||
|
||||
@greysonlalonde
|
||||
|
||||
</Update>
|
||||
|
||||
<Update label="13 mar 2026">
|
||||
## v1.10.2rc1
|
||||
|
||||
[Ver release no GitHub](https://github.com/crewAIInc/crewAI/releases/tag/1.10.2rc1)
|
||||
|
||||
## O que Mudou
|
||||
|
||||
### Funcionalidades
|
||||
- Adicionar comando de lançamento e acionar publicação no PyPI
|
||||
|
||||
### Correções de Bugs
|
||||
- Corrigir bloqueio seguro entre processos e threads para I/O não protegido
|
||||
- Propagar contextvars através de todos os limites de thread e executor
|
||||
- Propagar ContextVars para threads de tarefas assíncronas
|
||||
|
||||
### Documentação
|
||||
- Atualizar changelog e versão para v1.10.2a1
|
||||
|
||||
## Contribuidores
|
||||
|
||||
@danglies007, @greysonlalonde
|
||||
|
||||
</Update>
|
||||
|
||||
<Update label="11 mar 2026">
|
||||
## v1.10.2a1
|
||||
|
||||
|
||||
@@ -1,15 +0,0 @@
|
||||
# crewai-cli
|
||||
|
||||
CLI for CrewAI - scaffold, run, deploy and manage AI agent crews without installing the full framework.
|
||||
|
||||
## Installation
|
||||
|
||||
```bash
|
||||
pip install crewai-cli
|
||||
```
|
||||
|
||||
Or install alongside the full framework:
|
||||
|
||||
```bash
|
||||
pip install crewai[cli]
|
||||
```
|
||||
@@ -1,39 +0,0 @@
|
||||
[project]
|
||||
name = "crewai-cli"
|
||||
version = "1.10.0"
|
||||
description = "CLI for CrewAI - scaffold, run, deploy and manage AI agent crews without installing the full framework."
|
||||
readme = "README.md"
|
||||
authors = [
|
||||
{ name = "Joao Moura", email = "joao@crewai.com" }
|
||||
]
|
||||
requires-python = ">=3.10, <3.14"
|
||||
dependencies = [
|
||||
"click~=8.1.7",
|
||||
"pydantic~=2.11.9",
|
||||
"pydantic-settings~=2.10.1",
|
||||
"appdirs~=1.4.4",
|
||||
"httpx~=0.28.1",
|
||||
"pyjwt>=2.9.0,<3",
|
||||
"rich>=13.7.1",
|
||||
"tomli~=2.0.2",
|
||||
"tomli-w~=1.1.0",
|
||||
"packaging>=23.0",
|
||||
"python-dotenv~=1.1.1",
|
||||
"uv~=0.9.13",
|
||||
"portalocker~=2.7.0",
|
||||
]
|
||||
|
||||
[project.urls]
|
||||
Homepage = "https://crewai.com"
|
||||
Documentation = "https://docs.crewai.com"
|
||||
Repository = "https://github.com/crewAIInc/crewAI"
|
||||
|
||||
[project.scripts]
|
||||
crewai = "crewai_cli.cli:crewai"
|
||||
|
||||
[build-system]
|
||||
requires = ["hatchling"]
|
||||
build-backend = "hatchling.build"
|
||||
|
||||
[tool.hatch.build.targets.wheel]
|
||||
packages = ["src/crewai_cli"]
|
||||
@@ -1 +0,0 @@
|
||||
__version__ = "1.10.0"
|
||||
@@ -1,4 +0,0 @@
|
||||
from crewai_cli.authentication.main import AuthenticationCommand
|
||||
|
||||
|
||||
__all__ = ["AuthenticationCommand"]
|
||||
@@ -1 +0,0 @@
|
||||
ALGORITHMS = ["RS256"]
|
||||
@@ -1,215 +0,0 @@
|
||||
import time
|
||||
from typing import TYPE_CHECKING, Any, TypeVar, cast
|
||||
import webbrowser
|
||||
|
||||
import httpx
|
||||
from pydantic import BaseModel, Field
|
||||
from rich.console import Console
|
||||
|
||||
from crewai_cli.authentication.utils import validate_jwt_token
|
||||
from crewai_cli.config import Settings
|
||||
from crewai_cli.shared.token_manager import TokenManager
|
||||
|
||||
|
||||
console = Console()
|
||||
|
||||
TOauth2Settings = TypeVar("TOauth2Settings", bound="Oauth2Settings")
|
||||
|
||||
|
||||
class Oauth2Settings(BaseModel):
|
||||
provider: str = Field(
|
||||
description="OAuth2 provider used for authentication (e.g., workos, okta, auth0)."
|
||||
)
|
||||
client_id: str = Field(
|
||||
description="OAuth2 client ID issued by the provider, used during authentication requests."
|
||||
)
|
||||
domain: str = Field(
|
||||
description="OAuth2 provider's domain (e.g., your-org.auth0.com) used for issuing tokens."
|
||||
)
|
||||
audience: str | None = Field(
|
||||
description="OAuth2 audience value, typically used to identify the target API or resource.",
|
||||
default=None,
|
||||
)
|
||||
extra: dict[str, Any] = Field(
|
||||
description="Extra configuration for the OAuth2 provider.",
|
||||
default={},
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def from_settings(cls: type[TOauth2Settings]) -> TOauth2Settings:
|
||||
"""Create an Oauth2Settings instance from the CLI settings."""
|
||||
|
||||
settings = Settings()
|
||||
|
||||
return cls(
|
||||
provider=settings.oauth2_provider,
|
||||
domain=settings.oauth2_domain,
|
||||
client_id=settings.oauth2_client_id,
|
||||
audience=settings.oauth2_audience,
|
||||
extra=settings.oauth2_extra,
|
||||
)
|
||||
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from crewai_cli.authentication.providers.base_provider import BaseProvider
|
||||
|
||||
|
||||
class ProviderFactory:
|
||||
@classmethod
|
||||
def from_settings(
|
||||
cls: type["ProviderFactory"], # noqa: UP037
|
||||
settings: Oauth2Settings | None = None,
|
||||
) -> "BaseProvider": # noqa: UP037
|
||||
settings = settings or Oauth2Settings.from_settings()
|
||||
|
||||
import importlib
|
||||
|
||||
module = importlib.import_module(
|
||||
f"crewai_cli.authentication.providers.{settings.provider.lower()}"
|
||||
)
|
||||
# Converts from snake_case to CamelCase to obtain the provider class name.
|
||||
provider = getattr(
|
||||
module,
|
||||
f"{''.join(word.capitalize() for word in settings.provider.split('_'))}Provider",
|
||||
)
|
||||
|
||||
return cast("BaseProvider", provider(settings))
|
||||
|
||||
|
||||
class AuthenticationCommand:
|
||||
def __init__(self) -> None:
|
||||
self.token_manager = TokenManager()
|
||||
self.oauth2_provider = ProviderFactory.from_settings()
|
||||
|
||||
def login(self) -> None:
|
||||
"""Sign up to CrewAI+"""
|
||||
console.print("Signing in to CrewAI AMP...\n", style="bold blue")
|
||||
|
||||
device_code_data = self._get_device_code()
|
||||
self._display_auth_instructions(device_code_data)
|
||||
|
||||
return self._poll_for_token(device_code_data)
|
||||
|
||||
def _get_device_code(self) -> dict[str, Any]:
|
||||
"""Get the device code to authenticate the user."""
|
||||
|
||||
device_code_payload = {
|
||||
"client_id": self.oauth2_provider.get_client_id(),
|
||||
"scope": " ".join(self.oauth2_provider.get_oauth_scopes()),
|
||||
"audience": self.oauth2_provider.get_audience(),
|
||||
}
|
||||
response = httpx.post(
|
||||
url=self.oauth2_provider.get_authorize_url(),
|
||||
data=device_code_payload,
|
||||
timeout=20,
|
||||
)
|
||||
response.raise_for_status()
|
||||
return cast(dict[str, Any], response.json())
|
||||
|
||||
def _display_auth_instructions(self, device_code_data: dict[str, str]) -> None:
|
||||
"""Display the authentication instructions to the user."""
|
||||
|
||||
verification_uri = device_code_data.get(
|
||||
"verification_uri_complete", device_code_data.get("verification_uri", "")
|
||||
)
|
||||
|
||||
console.print("1. Navigate to: ", verification_uri)
|
||||
console.print("2. Enter the following code: ", device_code_data["user_code"])
|
||||
webbrowser.open(verification_uri)
|
||||
|
||||
def _poll_for_token(self, device_code_data: dict[str, Any]) -> None:
|
||||
"""Polls the server for the token until it is received, or max attempts are reached."""
|
||||
|
||||
token_payload = {
|
||||
"grant_type": "urn:ietf:params:oauth:grant-type:device_code",
|
||||
"device_code": device_code_data["device_code"],
|
||||
"client_id": self.oauth2_provider.get_client_id(),
|
||||
}
|
||||
|
||||
console.print("\nWaiting for authentication... ", style="bold blue", end="")
|
||||
|
||||
attempts = 0
|
||||
while True and attempts < 10:
|
||||
response = httpx.post(
|
||||
self.oauth2_provider.get_token_url(), data=token_payload, timeout=30
|
||||
)
|
||||
token_data = response.json()
|
||||
|
||||
if response.status_code == 200:
|
||||
self._validate_and_save_token(token_data)
|
||||
|
||||
console.print(
|
||||
"Success!",
|
||||
style="bold green",
|
||||
)
|
||||
|
||||
self._login_to_tool_repository()
|
||||
|
||||
console.print("\n[bold green]Welcome to CrewAI AMP![/bold green]\n")
|
||||
return
|
||||
|
||||
if token_data["error"] not in ("authorization_pending", "slow_down"):
|
||||
raise httpx.HTTPError(
|
||||
token_data.get("error_description") or token_data.get("error")
|
||||
)
|
||||
|
||||
time.sleep(device_code_data["interval"])
|
||||
attempts += 1
|
||||
|
||||
console.print(
|
||||
"Timeout: Failed to get the token. Please try again.", style="bold red"
|
||||
)
|
||||
|
||||
def _validate_and_save_token(self, token_data: dict[str, Any]) -> None:
|
||||
"""Validates the JWT token and saves the token to the token manager."""
|
||||
|
||||
jwt_token = token_data["access_token"]
|
||||
issuer = self.oauth2_provider.get_issuer()
|
||||
jwt_token_data = {
|
||||
"jwt_token": jwt_token,
|
||||
"jwks_url": self.oauth2_provider.get_jwks_url(),
|
||||
"issuer": issuer,
|
||||
"audience": self.oauth2_provider.get_audience(),
|
||||
}
|
||||
|
||||
decoded_token = validate_jwt_token(**jwt_token_data)
|
||||
|
||||
expires_at = decoded_token.get("exp", 0)
|
||||
self.token_manager.save_tokens(jwt_token, expires_at)
|
||||
|
||||
def _login_to_tool_repository(self) -> None:
|
||||
"""Login to the tool repository."""
|
||||
|
||||
from crewai_cli.tools.main import ToolCommand
|
||||
|
||||
try:
|
||||
console.print(
|
||||
"Now logging you in to the Tool Repository... ",
|
||||
style="bold blue",
|
||||
end="",
|
||||
)
|
||||
|
||||
ToolCommand().login()
|
||||
|
||||
console.print(
|
||||
"Success!\n",
|
||||
style="bold green",
|
||||
)
|
||||
|
||||
settings = Settings()
|
||||
|
||||
console.print(
|
||||
f"You are now authenticated to the tool repository for organization [bold cyan]'{settings.org_name if settings.org_name else settings.org_uuid}'[/bold cyan]",
|
||||
style="green",
|
||||
)
|
||||
except Exception:
|
||||
console.print(
|
||||
"\n[bold yellow]Warning:[/bold yellow] Authentication with the Tool Repository failed.",
|
||||
style="yellow",
|
||||
)
|
||||
console.print(
|
||||
"Other features will work normally, but you may experience limitations "
|
||||
"with downloading and publishing tools."
|
||||
"\nRun [bold]crewai login[/bold] to try logging in again.\n",
|
||||
style="yellow",
|
||||
)
|
||||
@@ -1,34 +0,0 @@
|
||||
from crewai_cli.authentication.providers.base_provider import BaseProvider
|
||||
|
||||
|
||||
class Auth0Provider(BaseProvider):
|
||||
def get_authorize_url(self) -> str:
|
||||
return f"https://{self._get_domain()}/oauth/device/code"
|
||||
|
||||
def get_token_url(self) -> str:
|
||||
return f"https://{self._get_domain()}/oauth/token"
|
||||
|
||||
def get_jwks_url(self) -> str:
|
||||
return f"https://{self._get_domain()}/.well-known/jwks.json"
|
||||
|
||||
def get_issuer(self) -> str:
|
||||
return f"https://{self._get_domain()}/"
|
||||
|
||||
def get_audience(self) -> str:
|
||||
if self.settings.audience is None:
|
||||
raise ValueError(
|
||||
"Audience is required. Please set it in the configuration."
|
||||
)
|
||||
return self.settings.audience
|
||||
|
||||
def get_client_id(self) -> str:
|
||||
if self.settings.client_id is None:
|
||||
raise ValueError(
|
||||
"Client ID is required. Please set it in the configuration."
|
||||
)
|
||||
return self.settings.client_id
|
||||
|
||||
def _get_domain(self) -> str:
|
||||
if self.settings.domain is None:
|
||||
raise ValueError("Domain is required. Please set it in the configuration.")
|
||||
return self.settings.domain
|
||||
@@ -1,33 +0,0 @@
|
||||
from abc import ABC, abstractmethod
|
||||
|
||||
from crewai_cli.authentication.main import Oauth2Settings
|
||||
|
||||
|
||||
class BaseProvider(ABC):
|
||||
def __init__(self, settings: Oauth2Settings):
|
||||
self.settings = settings
|
||||
|
||||
@abstractmethod
|
||||
def get_authorize_url(self) -> str: ...
|
||||
|
||||
@abstractmethod
|
||||
def get_token_url(self) -> str: ...
|
||||
|
||||
@abstractmethod
|
||||
def get_jwks_url(self) -> str: ...
|
||||
|
||||
@abstractmethod
|
||||
def get_issuer(self) -> str: ...
|
||||
|
||||
@abstractmethod
|
||||
def get_audience(self) -> str: ...
|
||||
|
||||
@abstractmethod
|
||||
def get_client_id(self) -> str: ...
|
||||
|
||||
def get_required_fields(self) -> list[str]:
|
||||
"""Returns which provider-specific fields inside the "extra" dict will be required"""
|
||||
return []
|
||||
|
||||
def get_oauth_scopes(self) -> list[str]:
|
||||
return ["openid", "profile", "email"]
|
||||
@@ -1,43 +0,0 @@
|
||||
from typing import cast
|
||||
|
||||
from crewai_cli.authentication.providers.base_provider import BaseProvider
|
||||
|
||||
|
||||
class EntraIdProvider(BaseProvider):
|
||||
def get_authorize_url(self) -> str:
|
||||
return f"{self._base_url()}/oauth2/v2.0/devicecode"
|
||||
|
||||
def get_token_url(self) -> str:
|
||||
return f"{self._base_url()}/oauth2/v2.0/token"
|
||||
|
||||
def get_jwks_url(self) -> str:
|
||||
return f"{self._base_url()}/discovery/v2.0/keys"
|
||||
|
||||
def get_issuer(self) -> str:
|
||||
return f"{self._base_url()}/v2.0"
|
||||
|
||||
def get_audience(self) -> str:
|
||||
if self.settings.audience is None:
|
||||
raise ValueError(
|
||||
"Audience is required. Please set it in the configuration."
|
||||
)
|
||||
return self.settings.audience
|
||||
|
||||
def get_client_id(self) -> str:
|
||||
if self.settings.client_id is None:
|
||||
raise ValueError(
|
||||
"Client ID is required. Please set it in the configuration."
|
||||
)
|
||||
return self.settings.client_id
|
||||
|
||||
def get_oauth_scopes(self) -> list[str]:
|
||||
return [
|
||||
*super().get_oauth_scopes(),
|
||||
*cast(str, self.settings.extra.get("scope", "")).split(),
|
||||
]
|
||||
|
||||
def get_required_fields(self) -> list[str]:
|
||||
return ["scope"]
|
||||
|
||||
def _base_url(self) -> str:
|
||||
return f"https://login.microsoftonline.com/{self.settings.domain}"
|
||||
@@ -1,32 +0,0 @@
|
||||
from crewai_cli.authentication.providers.base_provider import BaseProvider
|
||||
|
||||
|
||||
class KeycloakProvider(BaseProvider):
|
||||
def get_authorize_url(self) -> str:
|
||||
return f"{self._oauth2_base_url()}/realms/{self.settings.extra.get('realm')}/protocol/openid-connect/auth/device"
|
||||
|
||||
def get_token_url(self) -> str:
|
||||
return f"{self._oauth2_base_url()}/realms/{self.settings.extra.get('realm')}/protocol/openid-connect/token"
|
||||
|
||||
def get_jwks_url(self) -> str:
|
||||
return f"{self._oauth2_base_url()}/realms/{self.settings.extra.get('realm')}/protocol/openid-connect/certs"
|
||||
|
||||
def get_issuer(self) -> str:
|
||||
return f"{self._oauth2_base_url()}/realms/{self.settings.extra.get('realm')}"
|
||||
|
||||
def get_audience(self) -> str:
|
||||
return self.settings.audience or "no-audience-provided"
|
||||
|
||||
def get_client_id(self) -> str:
|
||||
if self.settings.client_id is None:
|
||||
raise ValueError(
|
||||
"Client ID is required. Please set it in the configuration."
|
||||
)
|
||||
return self.settings.client_id
|
||||
|
||||
def get_required_fields(self) -> list[str]:
|
||||
return ["realm"]
|
||||
|
||||
def _oauth2_base_url(self) -> str:
|
||||
domain = self.settings.domain.removeprefix("https://").removeprefix("http://")
|
||||
return f"https://{domain}"
|
||||
@@ -1,42 +0,0 @@
|
||||
from crewai_cli.authentication.providers.base_provider import BaseProvider
|
||||
|
||||
|
||||
class OktaProvider(BaseProvider):
|
||||
def get_authorize_url(self) -> str:
|
||||
return f"{self._oauth2_base_url()}/v1/device/authorize"
|
||||
|
||||
def get_token_url(self) -> str:
|
||||
return f"{self._oauth2_base_url()}/v1/token"
|
||||
|
||||
def get_jwks_url(self) -> str:
|
||||
return f"{self._oauth2_base_url()}/v1/keys"
|
||||
|
||||
def get_issuer(self) -> str:
|
||||
return self._oauth2_base_url().removesuffix("/oauth2")
|
||||
|
||||
def get_audience(self) -> str:
|
||||
if self.settings.audience is None:
|
||||
raise ValueError(
|
||||
"Audience is required. Please set it in the configuration."
|
||||
)
|
||||
return self.settings.audience
|
||||
|
||||
def get_client_id(self) -> str:
|
||||
if self.settings.client_id is None:
|
||||
raise ValueError(
|
||||
"Client ID is required. Please set it in the configuration."
|
||||
)
|
||||
return self.settings.client_id
|
||||
|
||||
def get_required_fields(self) -> list[str]:
|
||||
return ["authorization_server_name", "using_org_auth_server"]
|
||||
|
||||
def _oauth2_base_url(self) -> str:
|
||||
using_org_auth_server = self.settings.extra.get("using_org_auth_server", False)
|
||||
|
||||
if using_org_auth_server:
|
||||
base_url = f"https://{self.settings.domain}/oauth2"
|
||||
else:
|
||||
base_url = f"https://{self.settings.domain}/oauth2/{self.settings.extra.get('authorization_server_name', 'default')}"
|
||||
|
||||
return f"{base_url}"
|
||||
@@ -1,30 +0,0 @@
|
||||
from crewai_cli.authentication.providers.base_provider import BaseProvider
|
||||
|
||||
|
||||
class WorkosProvider(BaseProvider):
|
||||
def get_authorize_url(self) -> str:
|
||||
return f"https://{self._get_domain()}/oauth2/device_authorization"
|
||||
|
||||
def get_token_url(self) -> str:
|
||||
return f"https://{self._get_domain()}/oauth2/token"
|
||||
|
||||
def get_jwks_url(self) -> str:
|
||||
return f"https://{self._get_domain()}/oauth2/jwks"
|
||||
|
||||
def get_issuer(self) -> str:
|
||||
return f"https://{self._get_domain()}"
|
||||
|
||||
def get_audience(self) -> str:
|
||||
return self.settings.audience or ""
|
||||
|
||||
def get_client_id(self) -> str:
|
||||
if self.settings.client_id is None:
|
||||
raise ValueError(
|
||||
"Client ID is required. Please set it in the configuration."
|
||||
)
|
||||
return self.settings.client_id
|
||||
|
||||
def _get_domain(self) -> str:
|
||||
if self.settings.domain is None:
|
||||
raise ValueError("Domain is required. Please set it in the configuration.")
|
||||
return self.settings.domain
|
||||
@@ -1,13 +0,0 @@
|
||||
from crewai_cli.shared.token_manager import TokenManager
|
||||
|
||||
|
||||
class AuthError(Exception):
|
||||
pass
|
||||
|
||||
|
||||
def get_auth_token() -> str:
|
||||
"""Get the authentication token."""
|
||||
access_token = TokenManager().get_token()
|
||||
if not access_token:
|
||||
raise AuthError("No token found, make sure you are logged in")
|
||||
return access_token
|
||||
@@ -1,63 +0,0 @@
|
||||
from typing import Any
|
||||
|
||||
import jwt
|
||||
from jwt import PyJWKClient
|
||||
|
||||
|
||||
def validate_jwt_token(
|
||||
jwt_token: str, jwks_url: str, issuer: str, audience: str
|
||||
) -> Any:
|
||||
"""
|
||||
Verify the token's signature and claims using PyJWT.
|
||||
:param jwt_token: The JWT (JWS) string to validate.
|
||||
:param jwks_url: The URL of the JWKS endpoint.
|
||||
:param issuer: The expected issuer of the token.
|
||||
:param audience: The expected audience of the token.
|
||||
:return: The decoded token.
|
||||
:raises Exception: If the token is invalid for any reason (e.g., signature mismatch,
|
||||
expired, incorrect issuer/audience, JWKS fetching error,
|
||||
missing required claims).
|
||||
"""
|
||||
|
||||
try:
|
||||
jwk_client = PyJWKClient(jwks_url)
|
||||
signing_key = jwk_client.get_signing_key_from_jwt(jwt_token)
|
||||
|
||||
_unverified_decoded_token = jwt.decode(
|
||||
jwt_token, options={"verify_signature": False}
|
||||
)
|
||||
|
||||
return jwt.decode(
|
||||
jwt_token,
|
||||
signing_key.key,
|
||||
algorithms=["RS256"],
|
||||
audience=audience,
|
||||
issuer=issuer,
|
||||
leeway=10.0,
|
||||
options={
|
||||
"verify_signature": True,
|
||||
"verify_exp": True,
|
||||
"verify_nbf": True,
|
||||
"verify_iat": True,
|
||||
"require": ["exp", "iat", "iss", "aud", "sub"],
|
||||
},
|
||||
)
|
||||
|
||||
except jwt.ExpiredSignatureError as e:
|
||||
raise Exception("Token has expired.") from e
|
||||
except jwt.InvalidAudienceError as e:
|
||||
actual_audience = _unverified_decoded_token.get("aud", "[no audience found]")
|
||||
raise Exception(
|
||||
f"Invalid token audience. Got: '{actual_audience}'. Expected: '{audience}'"
|
||||
) from e
|
||||
except jwt.InvalidIssuerError as e:
|
||||
actual_issuer = _unverified_decoded_token.get("iss", "[no issuer found]")
|
||||
raise Exception(
|
||||
f"Invalid token issuer. Got: '{actual_issuer}'. Expected: '{issuer}'"
|
||||
) from e
|
||||
except jwt.MissingRequiredClaimError as e:
|
||||
raise Exception(f"Token is missing required claims: {e!s}") from e
|
||||
except jwt.exceptions.PyJWKClientError as e:
|
||||
raise Exception(f"JWKS or key processing error: {e!s}") from e
|
||||
except jwt.InvalidTokenError as e:
|
||||
raise Exception(f"Invalid token: {e!s}") from e
|
||||
@@ -1,68 +0,0 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
|
||||
import httpx
|
||||
from rich.console import Console
|
||||
|
||||
from crewai_cli.authentication.token import get_auth_token
|
||||
from crewai_cli.plus_api import PlusAPI
|
||||
|
||||
|
||||
console = Console()
|
||||
|
||||
|
||||
class BaseCommand:
|
||||
def __init__(self) -> None:
|
||||
pass
|
||||
|
||||
|
||||
class PlusAPIMixin:
|
||||
def __init__(self) -> None:
|
||||
try:
|
||||
self.plus_api_client = PlusAPI(api_key=get_auth_token())
|
||||
except Exception:
|
||||
console.print(
|
||||
"Please sign up/login to CrewAI+ before using the CLI.",
|
||||
style="bold red",
|
||||
)
|
||||
console.print("Run 'crewai login' to sign up/login.", style="bold green")
|
||||
raise SystemExit from None
|
||||
|
||||
def _validate_response(self, response: httpx.Response) -> None:
|
||||
try:
|
||||
json_response = response.json()
|
||||
except (json.JSONDecodeError, ValueError):
|
||||
console.print(
|
||||
"Failed to parse response from Enterprise API failed. Details:",
|
||||
style="bold red",
|
||||
)
|
||||
console.print(f"Status Code: {response.status_code}")
|
||||
console.print(
|
||||
f"Response:\n{response.content.decode('utf-8', errors='replace')}"
|
||||
)
|
||||
raise SystemExit from None
|
||||
|
||||
if response.status_code == 422:
|
||||
console.print(
|
||||
"Failed to complete operation. Please fix the following errors:",
|
||||
style="bold red",
|
||||
)
|
||||
for field, messages in json_response.items():
|
||||
for message in messages:
|
||||
console.print(
|
||||
f"* [bold red]{field.capitalize()}[/bold red] {message}"
|
||||
)
|
||||
raise SystemExit
|
||||
|
||||
if not response.is_success:
|
||||
console.print(
|
||||
"Request to Enterprise API failed. Details:", style="bold red"
|
||||
)
|
||||
details = (
|
||||
json_response.get("error")
|
||||
or json_response.get("message")
|
||||
or response.content.decode("utf-8", errors="replace")
|
||||
)
|
||||
console.print(f"{details}")
|
||||
raise SystemExit
|
||||
@@ -1,221 +0,0 @@
|
||||
import json
|
||||
from logging import getLogger
|
||||
from pathlib import Path
|
||||
import tempfile
|
||||
from typing import Any
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from crewai_cli.constants import (
|
||||
CREWAI_ENTERPRISE_DEFAULT_OAUTH2_AUDIENCE,
|
||||
CREWAI_ENTERPRISE_DEFAULT_OAUTH2_CLIENT_ID,
|
||||
CREWAI_ENTERPRISE_DEFAULT_OAUTH2_DOMAIN,
|
||||
CREWAI_ENTERPRISE_DEFAULT_OAUTH2_PROVIDER,
|
||||
DEFAULT_CREWAI_ENTERPRISE_URL,
|
||||
)
|
||||
from crewai_cli.shared.token_manager import TokenManager
|
||||
|
||||
|
||||
logger = getLogger(__name__)
|
||||
|
||||
DEFAULT_CONFIG_PATH = Path.home() / ".config" / "crewai" / "settings.json"
|
||||
|
||||
|
||||
def get_writable_config_path() -> Path | None:
|
||||
"""
|
||||
Find a writable location for the config file with fallback options.
|
||||
|
||||
Tries in order:
|
||||
1. Default: ~/.config/crewai/settings.json
|
||||
2. Temp directory: /tmp/crewai_settings.json (or OS equivalent)
|
||||
3. Current directory: ./crewai_settings.json
|
||||
4. In-memory only (returns None)
|
||||
|
||||
Returns:
|
||||
Path object for writable config location, or None if no writable location found
|
||||
"""
|
||||
fallback_paths = [
|
||||
DEFAULT_CONFIG_PATH, # Default location
|
||||
Path(tempfile.gettempdir()) / "crewai_settings.json", # Temporary directory
|
||||
Path.cwd() / "crewai_settings.json", # Current working directory
|
||||
]
|
||||
|
||||
for config_path in fallback_paths:
|
||||
try:
|
||||
config_path.parent.mkdir(parents=True, exist_ok=True)
|
||||
test_file = config_path.parent / ".crewai_write_test"
|
||||
try:
|
||||
test_file.write_text("test")
|
||||
test_file.unlink() # Clean up test file
|
||||
logger.info(f"Using config path: {config_path}")
|
||||
return config_path
|
||||
except Exception: # noqa: S112
|
||||
continue
|
||||
|
||||
except Exception: # noqa: S112
|
||||
continue
|
||||
|
||||
return None
|
||||
|
||||
|
||||
# Settings that are related to the user's account
|
||||
USER_SETTINGS_KEYS = [
|
||||
"tool_repository_username",
|
||||
"tool_repository_password",
|
||||
"org_name",
|
||||
"org_uuid",
|
||||
]
|
||||
|
||||
# Settings that are related to the CLI
|
||||
CLI_SETTINGS_KEYS = [
|
||||
"enterprise_base_url",
|
||||
"oauth2_provider",
|
||||
"oauth2_audience",
|
||||
"oauth2_client_id",
|
||||
"oauth2_domain",
|
||||
"oauth2_extra",
|
||||
]
|
||||
|
||||
# Default values for CLI settings
|
||||
DEFAULT_CLI_SETTINGS = {
|
||||
"enterprise_base_url": DEFAULT_CREWAI_ENTERPRISE_URL,
|
||||
"oauth2_provider": CREWAI_ENTERPRISE_DEFAULT_OAUTH2_PROVIDER,
|
||||
"oauth2_audience": CREWAI_ENTERPRISE_DEFAULT_OAUTH2_AUDIENCE,
|
||||
"oauth2_client_id": CREWAI_ENTERPRISE_DEFAULT_OAUTH2_CLIENT_ID,
|
||||
"oauth2_domain": CREWAI_ENTERPRISE_DEFAULT_OAUTH2_DOMAIN,
|
||||
"oauth2_extra": {},
|
||||
}
|
||||
|
||||
# Readonly settings - cannot be set by the user
|
||||
READONLY_SETTINGS_KEYS = [
|
||||
"org_name",
|
||||
"org_uuid",
|
||||
]
|
||||
|
||||
# Hidden settings - not displayed by the 'list' command and cannot be set by the user
|
||||
HIDDEN_SETTINGS_KEYS = [
|
||||
"config_path",
|
||||
"tool_repository_username",
|
||||
"tool_repository_password",
|
||||
]
|
||||
|
||||
|
||||
class Settings(BaseModel):
|
||||
enterprise_base_url: str | None = Field(
|
||||
default=DEFAULT_CLI_SETTINGS["enterprise_base_url"],
|
||||
description="Base URL of the CrewAI AMP instance",
|
||||
)
|
||||
tool_repository_username: str | None = Field(
|
||||
None, description="Username for interacting with the Tool Repository"
|
||||
)
|
||||
tool_repository_password: str | None = Field(
|
||||
None, description="Password for interacting with the Tool Repository"
|
||||
)
|
||||
org_name: str | None = Field(
|
||||
None, description="Name of the currently active organization"
|
||||
)
|
||||
org_uuid: str | None = Field(
|
||||
None, description="UUID of the currently active organization"
|
||||
)
|
||||
config_path: Path = Field(default=DEFAULT_CONFIG_PATH, frozen=True, exclude=True)
|
||||
|
||||
oauth2_provider: str = Field(
|
||||
description="OAuth2 provider used for authentication (e.g., workos, okta, auth0).",
|
||||
default=DEFAULT_CLI_SETTINGS["oauth2_provider"],
|
||||
)
|
||||
|
||||
oauth2_audience: str | None = Field(
|
||||
description="OAuth2 audience value, typically used to identify the target API or resource.",
|
||||
default=DEFAULT_CLI_SETTINGS["oauth2_audience"],
|
||||
)
|
||||
|
||||
oauth2_client_id: str = Field(
|
||||
default=DEFAULT_CLI_SETTINGS["oauth2_client_id"],
|
||||
description="OAuth2 client ID issued by the provider, used during authentication requests.",
|
||||
)
|
||||
|
||||
oauth2_domain: str = Field(
|
||||
description="OAuth2 provider's domain (e.g., your-org.auth0.com) used for issuing tokens.",
|
||||
default=DEFAULT_CLI_SETTINGS["oauth2_domain"],
|
||||
)
|
||||
|
||||
oauth2_extra: dict[str, Any] = Field(
|
||||
description="Extra configuration for the OAuth2 provider.",
|
||||
default={},
|
||||
)
|
||||
|
||||
def __init__(self, config_path: Path | None = None, **data: dict[str, Any]) -> None:
|
||||
"""Load Settings from config path with fallback support"""
|
||||
if config_path is None:
|
||||
config_path = get_writable_config_path()
|
||||
|
||||
# If config_path is None, we're in memory-only mode
|
||||
if config_path is None:
|
||||
merged_data = {**data}
|
||||
# Dummy path for memory-only mode
|
||||
super().__init__(config_path=Path("/dev/null"), **merged_data)
|
||||
return
|
||||
|
||||
try:
|
||||
config_path.parent.mkdir(parents=True, exist_ok=True)
|
||||
except Exception:
|
||||
merged_data = {**data}
|
||||
# Dummy path for memory-only mode
|
||||
super().__init__(config_path=Path("/dev/null"), **merged_data)
|
||||
return
|
||||
|
||||
file_data = {}
|
||||
if config_path.is_file():
|
||||
try:
|
||||
with config_path.open("r") as f:
|
||||
file_data = json.load(f)
|
||||
except Exception:
|
||||
file_data = {}
|
||||
|
||||
merged_data = {**file_data, **data}
|
||||
super().__init__(config_path=config_path, **merged_data)
|
||||
|
||||
def clear_user_settings(self) -> None:
|
||||
"""Clear all user settings"""
|
||||
self._reset_user_settings()
|
||||
self.dump()
|
||||
|
||||
def reset(self) -> None:
|
||||
"""Reset all settings to default values"""
|
||||
self._reset_user_settings()
|
||||
self._reset_cli_settings()
|
||||
self._clear_auth_tokens()
|
||||
self.dump()
|
||||
|
||||
def dump(self) -> None:
|
||||
"""Save current settings to settings.json"""
|
||||
if str(self.config_path) == "/dev/null":
|
||||
return
|
||||
|
||||
try:
|
||||
if self.config_path.is_file():
|
||||
with self.config_path.open("r") as f:
|
||||
existing_data = json.load(f)
|
||||
else:
|
||||
existing_data = {}
|
||||
|
||||
updated_data = {**existing_data, **self.model_dump(exclude_unset=True)}
|
||||
with self.config_path.open("w") as f:
|
||||
json.dump(updated_data, f, indent=4)
|
||||
|
||||
except Exception: # noqa: S110
|
||||
pass
|
||||
|
||||
def _reset_user_settings(self) -> None:
|
||||
"""Reset all user settings to default values"""
|
||||
for key in USER_SETTINGS_KEYS:
|
||||
setattr(self, key, None)
|
||||
|
||||
def _reset_cli_settings(self) -> None:
|
||||
"""Reset all CLI settings to default values"""
|
||||
for key in CLI_SETTINGS_KEYS:
|
||||
setattr(self, key, DEFAULT_CLI_SETTINGS.get(key))
|
||||
|
||||
def _clear_auth_tokens(self) -> None:
|
||||
"""Clear all authentication tokens"""
|
||||
TokenManager().clear_tokens()
|
||||
@@ -1,333 +0,0 @@
|
||||
from typing import Any
|
||||
|
||||
|
||||
DEFAULT_CREWAI_ENTERPRISE_URL = "https://app.crewai.com"
|
||||
CREWAI_ENTERPRISE_DEFAULT_OAUTH2_PROVIDER = "workos"
|
||||
CREWAI_ENTERPRISE_DEFAULT_OAUTH2_AUDIENCE = "client_01JNJQWBJ4SPFN3SWJM5T7BDG8"
|
||||
CREWAI_ENTERPRISE_DEFAULT_OAUTH2_CLIENT_ID = "client_01JYT06R59SP0NXYGD994NFXXX"
|
||||
CREWAI_ENTERPRISE_DEFAULT_OAUTH2_DOMAIN = "login.crewai.com"
|
||||
|
||||
ENV_VARS: dict[str, list[dict[str, Any]]] = {
|
||||
"openai": [
|
||||
{
|
||||
"prompt": "Enter your OPENAI API key (press Enter to skip)",
|
||||
"key_name": "OPENAI_API_KEY",
|
||||
}
|
||||
],
|
||||
"anthropic": [
|
||||
{
|
||||
"prompt": "Enter your ANTHROPIC API key (press Enter to skip)",
|
||||
"key_name": "ANTHROPIC_API_KEY",
|
||||
}
|
||||
],
|
||||
"gemini": [
|
||||
{
|
||||
"prompt": "Enter your GEMINI API key from https://ai.dev/apikey (press Enter to skip)",
|
||||
"key_name": "GEMINI_API_KEY",
|
||||
}
|
||||
],
|
||||
"nvidia_nim": [
|
||||
{
|
||||
"prompt": "Enter your NVIDIA API key (press Enter to skip)",
|
||||
"key_name": "NVIDIA_NIM_API_KEY",
|
||||
}
|
||||
],
|
||||
"groq": [
|
||||
{
|
||||
"prompt": "Enter your GROQ API key (press Enter to skip)",
|
||||
"key_name": "GROQ_API_KEY",
|
||||
}
|
||||
],
|
||||
"watson": [
|
||||
{
|
||||
"prompt": "Enter your WATSONX URL (press Enter to skip)",
|
||||
"key_name": "WATSONX_URL",
|
||||
},
|
||||
{
|
||||
"prompt": "Enter your WATSONX API Key (press Enter to skip)",
|
||||
"key_name": "WATSONX_APIKEY",
|
||||
},
|
||||
{
|
||||
"prompt": "Enter your WATSONX Project Id (press Enter to skip)",
|
||||
"key_name": "WATSONX_PROJECT_ID",
|
||||
},
|
||||
],
|
||||
"ollama": [
|
||||
{
|
||||
"default": True,
|
||||
"API_BASE": "http://localhost:11434",
|
||||
}
|
||||
],
|
||||
"bedrock": [
|
||||
{
|
||||
"prompt": "Enter your AWS Access Key ID (press Enter to skip)",
|
||||
"key_name": "AWS_ACCESS_KEY_ID",
|
||||
},
|
||||
{
|
||||
"prompt": "Enter your AWS Secret Access Key (press Enter to skip)",
|
||||
"key_name": "AWS_SECRET_ACCESS_KEY",
|
||||
},
|
||||
{
|
||||
"prompt": "Enter your AWS Region Name (press Enter to skip)",
|
||||
"key_name": "AWS_DEFAULT_REGION",
|
||||
},
|
||||
],
|
||||
"azure": [
|
||||
{
|
||||
"prompt": "Enter your Azure deployment name (must start with 'azure/')",
|
||||
"key_name": "model",
|
||||
},
|
||||
{
|
||||
"prompt": "Enter your AZURE API key (press Enter to skip)",
|
||||
"key_name": "AZURE_API_KEY",
|
||||
},
|
||||
{
|
||||
"prompt": "Enter your AZURE API base URL (press Enter to skip)",
|
||||
"key_name": "AZURE_API_BASE",
|
||||
},
|
||||
{
|
||||
"prompt": "Enter your AZURE API version (press Enter to skip)",
|
||||
"key_name": "AZURE_API_VERSION",
|
||||
},
|
||||
],
|
||||
"cerebras": [
|
||||
{
|
||||
"prompt": "Enter your Cerebras model name (must start with 'cerebras/')",
|
||||
"key_name": "model",
|
||||
},
|
||||
{
|
||||
"prompt": "Enter your Cerebras API version (press Enter to skip)",
|
||||
"key_name": "CEREBRAS_API_KEY",
|
||||
},
|
||||
],
|
||||
"huggingface": [
|
||||
{
|
||||
"prompt": "Enter your Huggingface API key (HF_TOKEN) (press Enter to skip)",
|
||||
"key_name": "HF_TOKEN",
|
||||
},
|
||||
],
|
||||
"sambanova": [
|
||||
{
|
||||
"prompt": "Enter your SambaNovaCloud API key (press Enter to skip)",
|
||||
"key_name": "SAMBANOVA_API_KEY",
|
||||
}
|
||||
],
|
||||
}
|
||||
|
||||
|
||||
PROVIDERS: list[str] = [
|
||||
"openai",
|
||||
"anthropic",
|
||||
"gemini",
|
||||
"nvidia_nim",
|
||||
"groq",
|
||||
"huggingface",
|
||||
"ollama",
|
||||
"watson",
|
||||
"bedrock",
|
||||
"azure",
|
||||
"cerebras",
|
||||
"sambanova",
|
||||
]
|
||||
|
||||
MODELS: dict[str, list[str]] = {
|
||||
"openai": [
|
||||
"gpt-4",
|
||||
"gpt-4.1",
|
||||
"gpt-4.1-mini-2025-04-14",
|
||||
"gpt-4.1-nano-2025-04-14",
|
||||
"gpt-4o",
|
||||
"gpt-4o-mini",
|
||||
"o1-mini",
|
||||
"o1-preview",
|
||||
],
|
||||
"anthropic": [
|
||||
"claude-3-5-sonnet-20240620",
|
||||
"claude-3-sonnet-20240229",
|
||||
"claude-3-opus-20240229",
|
||||
"claude-3-haiku-20240307",
|
||||
],
|
||||
"gemini": [
|
||||
"gemini/gemini-3-pro-preview",
|
||||
"gemini/gemini-1.5-flash",
|
||||
"gemini/gemini-1.5-pro",
|
||||
"gemini/gemini-2.0-flash-lite-001",
|
||||
"gemini/gemini-2.0-flash-001",
|
||||
"gemini/gemini-2.0-flash-thinking-exp-01-21",
|
||||
"gemini/gemini-2.5-flash-preview-04-17",
|
||||
"gemini/gemini-2.5-pro-exp-03-25",
|
||||
"gemini/gemini-gemma-2-9b-it",
|
||||
"gemini/gemini-gemma-2-27b-it",
|
||||
"gemini/gemma-3-1b-it",
|
||||
"gemini/gemma-3-4b-it",
|
||||
"gemini/gemma-3-12b-it",
|
||||
"gemini/gemma-3-27b-it",
|
||||
],
|
||||
"nvidia_nim": [
|
||||
"nvidia_nim/nvidia/mistral-nemo-minitron-8b-8k-instruct",
|
||||
"nvidia_nim/nvidia/nemotron-4-mini-hindi-4b-instruct",
|
||||
"nvidia_nim/nvidia/llama-3.1-nemotron-70b-instruct",
|
||||
"nvidia_nim/nvidia/llama3-chatqa-1.5-8b",
|
||||
"nvidia_nim/nvidia/llama3-chatqa-1.5-70b",
|
||||
"nvidia_nim/nvidia/vila",
|
||||
"nvidia_nim/nvidia/neva-22",
|
||||
"nvidia_nim/nvidia/nemotron-mini-4b-instruct",
|
||||
"nvidia_nim/nvidia/usdcode-llama3-70b-instruct",
|
||||
"nvidia_nim/nvidia/nemotron-4-340b-instruct",
|
||||
"nvidia_nim/meta/codellama-70b",
|
||||
"nvidia_nim/meta/llama2-70b",
|
||||
"nvidia_nim/meta/llama3-8b-instruct",
|
||||
"nvidia_nim/meta/llama3-70b-instruct",
|
||||
"nvidia_nim/meta/llama-3.1-8b-instruct",
|
||||
"nvidia_nim/meta/llama-3.1-70b-instruct",
|
||||
"nvidia_nim/meta/llama-3.1-405b-instruct",
|
||||
"nvidia_nim/meta/llama-3.2-1b-instruct",
|
||||
"nvidia_nim/meta/llama-3.2-3b-instruct",
|
||||
"nvidia_nim/meta/llama-3.2-11b-vision-instruct",
|
||||
"nvidia_nim/meta/llama-3.2-90b-vision-instruct",
|
||||
"nvidia_nim/meta/llama-3.1-70b-instruct",
|
||||
"nvidia_nim/google/gemma-7b",
|
||||
"nvidia_nim/google/gemma-2b",
|
||||
"nvidia_nim/google/codegemma-7b",
|
||||
"nvidia_nim/google/codegemma-1.1-7b",
|
||||
"nvidia_nim/google/recurrentgemma-2b",
|
||||
"nvidia_nim/google/gemma-2-9b-it",
|
||||
"nvidia_nim/google/gemma-2-27b-it",
|
||||
"nvidia_nim/google/gemma-2-2b-it",
|
||||
"nvidia_nim/google/deplot",
|
||||
"nvidia_nim/google/paligemma",
|
||||
"nvidia_nim/mistralai/mistral-7b-instruct-v0.2",
|
||||
"nvidia_nim/mistralai/mixtral-8x7b-instruct-v0.1",
|
||||
"nvidia_nim/mistralai/mistral-large",
|
||||
"nvidia_nim/mistralai/mixtral-8x22b-instruct-v0.1",
|
||||
"nvidia_nim/mistralai/mistral-7b-instruct-v0.3",
|
||||
"nvidia_nim/nv-mistralai/mistral-nemo-12b-instruct",
|
||||
"nvidia_nim/mistralai/mamba-codestral-7b-v0.1",
|
||||
"nvidia_nim/microsoft/phi-3-mini-128k-instruct",
|
||||
"nvidia_nim/microsoft/phi-3-mini-4k-instruct",
|
||||
"nvidia_nim/microsoft/phi-3-small-8k-instruct",
|
||||
"nvidia_nim/microsoft/phi-3-small-128k-instruct",
|
||||
"nvidia_nim/microsoft/phi-3-medium-4k-instruct",
|
||||
"nvidia_nim/microsoft/phi-3-medium-128k-instruct",
|
||||
"nvidia_nim/microsoft/phi-3.5-mini-instruct",
|
||||
"nvidia_nim/microsoft/phi-3.5-moe-instruct",
|
||||
"nvidia_nim/microsoft/kosmos-2",
|
||||
"nvidia_nim/microsoft/phi-3-vision-128k-instruct",
|
||||
"nvidia_nim/microsoft/phi-3.5-vision-instruct",
|
||||
"nvidia_nim/databricks/dbrx-instruct",
|
||||
"nvidia_nim/snowflake/arctic",
|
||||
"nvidia_nim/aisingapore/sea-lion-7b-instruct",
|
||||
"nvidia_nim/ibm/granite-8b-code-instruct",
|
||||
"nvidia_nim/ibm/granite-34b-code-instruct",
|
||||
"nvidia_nim/ibm/granite-3.0-8b-instruct",
|
||||
"nvidia_nim/ibm/granite-3.0-3b-a800m-instruct",
|
||||
"nvidia_nim/mediatek/breeze-7b-instruct",
|
||||
"nvidia_nim/upstage/solar-10.7b-instruct",
|
||||
"nvidia_nim/writer/palmyra-med-70b-32k",
|
||||
"nvidia_nim/writer/palmyra-med-70b",
|
||||
"nvidia_nim/writer/palmyra-fin-70b-32k",
|
||||
"nvidia_nim/01-ai/yi-large",
|
||||
"nvidia_nim/deepseek-ai/deepseek-coder-6.7b-instruct",
|
||||
"nvidia_nim/rakuten/rakutenai-7b-instruct",
|
||||
"nvidia_nim/rakuten/rakutenai-7b-chat",
|
||||
"nvidia_nim/baichuan-inc/baichuan2-13b-chat",
|
||||
],
|
||||
"groq": [
|
||||
"groq/llama-3.1-8b-instant",
|
||||
"groq/llama-3.1-70b-versatile",
|
||||
"groq/llama-3.1-405b-reasoning",
|
||||
"groq/gemma2-9b-it",
|
||||
"groq/gemma-7b-it",
|
||||
],
|
||||
"ollama": ["ollama/llama3.1", "ollama/mixtral"],
|
||||
"watson": [
|
||||
"watsonx/meta-llama/llama-3-1-70b-instruct",
|
||||
"watsonx/meta-llama/llama-3-1-8b-instruct",
|
||||
"watsonx/meta-llama/llama-3-2-11b-vision-instruct",
|
||||
"watsonx/meta-llama/llama-3-2-1b-instruct",
|
||||
"watsonx/meta-llama/llama-3-2-90b-vision-instruct",
|
||||
"watsonx/meta-llama/llama-3-405b-instruct",
|
||||
"watsonx/mistral/mistral-large",
|
||||
"watsonx/ibm/granite-3-8b-instruct",
|
||||
],
|
||||
"bedrock": [
|
||||
"bedrock/us.amazon.nova-pro-v1:0",
|
||||
"bedrock/us.amazon.nova-micro-v1:0",
|
||||
"bedrock/us.amazon.nova-lite-v1:0",
|
||||
"bedrock/us.anthropic.claude-3-5-sonnet-20240620-v1:0",
|
||||
"bedrock/us.anthropic.claude-3-5-haiku-20241022-v1:0",
|
||||
"bedrock/us.anthropic.claude-3-5-sonnet-20241022-v2:0",
|
||||
"bedrock/us.anthropic.claude-3-7-sonnet-20250219-v1:0",
|
||||
"bedrock/us.anthropic.claude-3-sonnet-20240229-v1:0",
|
||||
"bedrock/us.anthropic.claude-3-opus-20240229-v1:0",
|
||||
"bedrock/us.anthropic.claude-3-haiku-20240307-v1:0",
|
||||
"bedrock/us.meta.llama3-2-11b-instruct-v1:0",
|
||||
"bedrock/us.meta.llama3-2-3b-instruct-v1:0",
|
||||
"bedrock/us.meta.llama3-2-90b-instruct-v1:0",
|
||||
"bedrock/us.meta.llama3-2-1b-instruct-v1:0",
|
||||
"bedrock/us.meta.llama3-1-8b-instruct-v1:0",
|
||||
"bedrock/us.meta.llama3-1-70b-instruct-v1:0",
|
||||
"bedrock/us.meta.llama3-3-70b-instruct-v1:0",
|
||||
"bedrock/us.meta.llama3-1-405b-instruct-v1:0",
|
||||
"bedrock/eu.anthropic.claude-3-5-sonnet-20240620-v1:0",
|
||||
"bedrock/eu.anthropic.claude-3-sonnet-20240229-v1:0",
|
||||
"bedrock/eu.anthropic.claude-3-haiku-20240307-v1:0",
|
||||
"bedrock/eu.meta.llama3-2-3b-instruct-v1:0",
|
||||
"bedrock/eu.meta.llama3-2-1b-instruct-v1:0",
|
||||
"bedrock/apac.anthropic.claude-3-5-sonnet-20240620-v1:0",
|
||||
"bedrock/apac.anthropic.claude-3-5-sonnet-20241022-v2:0",
|
||||
"bedrock/apac.anthropic.claude-3-sonnet-20240229-v1:0",
|
||||
"bedrock/apac.anthropic.claude-3-haiku-20240307-v1:0",
|
||||
"bedrock/amazon.nova-pro-v1:0",
|
||||
"bedrock/amazon.nova-micro-v1:0",
|
||||
"bedrock/amazon.nova-lite-v1:0",
|
||||
"bedrock/anthropic.claude-3-5-sonnet-20240620-v1:0",
|
||||
"bedrock/anthropic.claude-3-5-haiku-20241022-v1:0",
|
||||
"bedrock/anthropic.claude-3-5-sonnet-20241022-v2:0",
|
||||
"bedrock/anthropic.claude-3-7-sonnet-20250219-v1:0",
|
||||
"bedrock/anthropic.claude-3-sonnet-20240229-v1:0",
|
||||
"bedrock/anthropic.claude-3-opus-20240229-v1:0",
|
||||
"bedrock/anthropic.claude-3-haiku-20240307-v1:0",
|
||||
"bedrock/anthropic.claude-v2:1",
|
||||
"bedrock/anthropic.claude-v2",
|
||||
"bedrock/anthropic.claude-instant-v1",
|
||||
"bedrock/meta.llama3-1-405b-instruct-v1:0",
|
||||
"bedrock/meta.llama3-1-70b-instruct-v1:0",
|
||||
"bedrock/meta.llama3-1-8b-instruct-v1:0",
|
||||
"bedrock/meta.llama3-70b-instruct-v1:0",
|
||||
"bedrock/meta.llama3-8b-instruct-v1:0",
|
||||
"bedrock/amazon.titan-text-lite-v1",
|
||||
"bedrock/amazon.titan-text-express-v1",
|
||||
"bedrock/cohere.command-text-v14",
|
||||
"bedrock/ai21.j2-mid-v1",
|
||||
"bedrock/ai21.j2-ultra-v1",
|
||||
"bedrock/ai21.jamba-instruct-v1:0",
|
||||
"bedrock/mistral.mistral-7b-instruct-v0:2",
|
||||
"bedrock/mistral.mixtral-8x7b-instruct-v0:1",
|
||||
],
|
||||
"huggingface": [
|
||||
"huggingface/meta-llama/Meta-Llama-3.1-8B-Instruct",
|
||||
"huggingface/mistralai/Mixtral-8x7B-Instruct-v0.1",
|
||||
"huggingface/tiiuae/falcon-180B-chat",
|
||||
"huggingface/google/gemma-7b-it",
|
||||
],
|
||||
"sambanova": [
|
||||
"sambanova/Meta-Llama-3.3-70B-Instruct",
|
||||
"sambanova/QwQ-32B-Preview",
|
||||
"sambanova/Qwen2.5-72B-Instruct",
|
||||
"sambanova/Qwen2.5-Coder-32B-Instruct",
|
||||
"sambanova/Meta-Llama-3.1-405B-Instruct",
|
||||
"sambanova/Meta-Llama-3.1-70B-Instruct",
|
||||
"sambanova/Meta-Llama-3.1-8B-Instruct",
|
||||
"sambanova/Llama-3.2-90B-Vision-Instruct",
|
||||
"sambanova/Llama-3.2-11B-Vision-Instruct",
|
||||
"sambanova/Meta-Llama-3.2-3B-Instruct",
|
||||
"sambanova/Meta-Llama-3.2-1B-Instruct",
|
||||
],
|
||||
}
|
||||
|
||||
DEFAULT_LLM_MODEL = "gpt-4.1-mini"
|
||||
|
||||
JSON_URL = "https://raw.githubusercontent.com/BerriAI/litellm/main/model_prices_and_context_window.json"
|
||||
|
||||
LITELLM_PARAMS = ["api_key", "api_base", "api_version"]
|
||||
@@ -1,23 +0,0 @@
|
||||
"""Wrapper for the crew chat command.
|
||||
|
||||
Delegates to ``crewai.cli.crew_chat.run_chat`` when the full crewai package is
|
||||
installed, otherwise prints a helpful error message.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import click
|
||||
|
||||
|
||||
def run_chat() -> None:
|
||||
try:
|
||||
from crewai.cli.crew_chat import run_chat as _run_chat
|
||||
except ImportError:
|
||||
click.secho(
|
||||
"The 'chat' command requires the full crewai package.\n"
|
||||
"Install it with: pip install crewai",
|
||||
fg="red",
|
||||
)
|
||||
raise SystemExit(1) from None
|
||||
|
||||
_run_chat()
|
||||
@@ -1,89 +0,0 @@
|
||||
from functools import lru_cache
|
||||
import subprocess
|
||||
|
||||
|
||||
class Repository:
|
||||
def __init__(self, path: str = ".") -> None:
|
||||
self.path = path
|
||||
|
||||
if not self.is_git_installed():
|
||||
raise ValueError("Git is not installed or not found in your PATH.")
|
||||
|
||||
if not self.is_git_repo():
|
||||
raise ValueError(f"{self.path} is not a Git repository.")
|
||||
|
||||
self.fetch()
|
||||
|
||||
@staticmethod
|
||||
def is_git_installed() -> bool:
|
||||
"""Check if Git is installed and available in the system."""
|
||||
try:
|
||||
subprocess.run(
|
||||
["git", "--version"], # noqa: S607
|
||||
capture_output=True,
|
||||
check=True,
|
||||
text=True,
|
||||
)
|
||||
return True
|
||||
except (subprocess.CalledProcessError, FileNotFoundError):
|
||||
return False
|
||||
|
||||
def fetch(self) -> None:
|
||||
"""Fetch latest updates from the remote."""
|
||||
subprocess.run(["git", "fetch"], cwd=self.path, check=True) # noqa: S607
|
||||
|
||||
def status(self) -> str:
|
||||
"""Get the git status in porcelain format."""
|
||||
return subprocess.check_output(
|
||||
["git", "status", "--branch", "--porcelain"], # noqa: S607
|
||||
cwd=self.path,
|
||||
encoding="utf-8",
|
||||
).strip()
|
||||
|
||||
@lru_cache(maxsize=None) # noqa: B019
|
||||
def is_git_repo(self) -> bool:
|
||||
"""Check if the current directory is a git repository.
|
||||
|
||||
Notes:
|
||||
- TODO: This method is cached to avoid redundant checks, but using lru_cache on methods can lead to memory leaks
|
||||
"""
|
||||
try:
|
||||
subprocess.check_output(
|
||||
["git", "rev-parse", "--is-inside-work-tree"], # noqa: S607
|
||||
cwd=self.path,
|
||||
encoding="utf-8",
|
||||
)
|
||||
return True
|
||||
except subprocess.CalledProcessError:
|
||||
return False
|
||||
|
||||
def has_uncommitted_changes(self) -> bool:
|
||||
"""Check if the repository has uncommitted changes."""
|
||||
return len(self.status().splitlines()) > 1
|
||||
|
||||
def is_ahead_or_behind(self) -> bool:
|
||||
"""Check if the repository is ahead or behind the remote."""
|
||||
for line in self.status().splitlines():
|
||||
if line.startswith("##") and ("ahead" in line or "behind" in line):
|
||||
return True
|
||||
return False
|
||||
|
||||
def is_synced(self) -> bool:
|
||||
"""Return True if the Git repository is fully synced with the remote, False otherwise."""
|
||||
if self.has_uncommitted_changes() or self.is_ahead_or_behind():
|
||||
return False
|
||||
return True
|
||||
|
||||
def origin_url(self) -> str | None:
|
||||
"""Get the Git repository's remote URL."""
|
||||
try:
|
||||
result = subprocess.run(
|
||||
["git", "remote", "get-url", "origin"], # noqa: S607
|
||||
cwd=self.path,
|
||||
capture_output=True,
|
||||
text=True,
|
||||
check=True,
|
||||
)
|
||||
return result.stdout.strip()
|
||||
except subprocess.CalledProcessError:
|
||||
return None
|
||||
@@ -1,210 +0,0 @@
|
||||
import os
|
||||
from typing import Any
|
||||
from urllib.parse import urljoin
|
||||
|
||||
import httpx
|
||||
|
||||
from crewai_cli.config import Settings
|
||||
from crewai_cli.constants import DEFAULT_CREWAI_ENTERPRISE_URL
|
||||
from crewai_cli.version import get_crewai_version
|
||||
|
||||
|
||||
class PlusAPI:
|
||||
"""
|
||||
This class exposes methods for working with the CrewAI+ API.
|
||||
"""
|
||||
|
||||
TOOLS_RESOURCE = "/crewai_plus/api/v1/tools"
|
||||
ORGANIZATIONS_RESOURCE = "/crewai_plus/api/v1/me/organizations"
|
||||
CREWS_RESOURCE = "/crewai_plus/api/v1/crews"
|
||||
AGENTS_RESOURCE = "/crewai_plus/api/v1/agents"
|
||||
TRACING_RESOURCE = "/crewai_plus/api/v1/tracing"
|
||||
EPHEMERAL_TRACING_RESOURCE = "/crewai_plus/api/v1/tracing/ephemeral"
|
||||
INTEGRATIONS_RESOURCE = "/crewai_plus/api/v1/integrations"
|
||||
|
||||
def __init__(self, api_key: str) -> None:
|
||||
self.api_key = api_key
|
||||
self.headers = {
|
||||
"Authorization": f"Bearer {api_key}",
|
||||
"Content-Type": "application/json",
|
||||
"User-Agent": f"CrewAI-CLI/{get_crewai_version()}",
|
||||
"X-Crewai-Version": get_crewai_version(),
|
||||
}
|
||||
settings = Settings()
|
||||
if settings.org_uuid:
|
||||
self.headers["X-Crewai-Organization-Id"] = settings.org_uuid
|
||||
|
||||
self.base_url = (
|
||||
os.getenv("CREWAI_PLUS_URL")
|
||||
or str(settings.enterprise_base_url)
|
||||
or DEFAULT_CREWAI_ENTERPRISE_URL
|
||||
)
|
||||
|
||||
def _make_request(
|
||||
self, method: str, endpoint: str, **kwargs: Any
|
||||
) -> httpx.Response:
|
||||
url = urljoin(self.base_url, endpoint)
|
||||
verify = kwargs.pop("verify", True)
|
||||
with httpx.Client(trust_env=False, verify=verify) as client:
|
||||
return client.request(method, url, headers=self.headers, **kwargs)
|
||||
|
||||
def login_to_tool_repository(self) -> httpx.Response:
|
||||
return self._make_request("POST", f"{self.TOOLS_RESOURCE}/login")
|
||||
|
||||
def get_tool(self, handle: str) -> httpx.Response:
|
||||
return self._make_request("GET", f"{self.TOOLS_RESOURCE}/{handle}")
|
||||
|
||||
async def get_agent(self, handle: str) -> httpx.Response:
|
||||
url = urljoin(self.base_url, f"{self.AGENTS_RESOURCE}/{handle}")
|
||||
async with httpx.AsyncClient() as client:
|
||||
return await client.get(url, headers=self.headers)
|
||||
|
||||
def publish_tool(
|
||||
self,
|
||||
handle: str,
|
||||
is_public: bool,
|
||||
version: str,
|
||||
description: str | None,
|
||||
encoded_file: str,
|
||||
available_exports: list[dict[str, Any]] | None = None,
|
||||
) -> httpx.Response:
|
||||
params = {
|
||||
"handle": handle,
|
||||
"public": is_public,
|
||||
"version": version,
|
||||
"file": encoded_file,
|
||||
"description": description,
|
||||
"available_exports": available_exports,
|
||||
}
|
||||
return self._make_request("POST", f"{self.TOOLS_RESOURCE}", json=params)
|
||||
|
||||
def deploy_by_name(self, project_name: str) -> httpx.Response:
|
||||
return self._make_request(
|
||||
"POST", f"{self.CREWS_RESOURCE}/by-name/{project_name}/deploy"
|
||||
)
|
||||
|
||||
def deploy_by_uuid(self, uuid: str) -> httpx.Response:
|
||||
return self._make_request("POST", f"{self.CREWS_RESOURCE}/{uuid}/deploy")
|
||||
|
||||
def crew_status_by_name(self, project_name: str) -> httpx.Response:
|
||||
return self._make_request(
|
||||
"GET", f"{self.CREWS_RESOURCE}/by-name/{project_name}/status"
|
||||
)
|
||||
|
||||
def crew_status_by_uuid(self, uuid: str) -> httpx.Response:
|
||||
return self._make_request("GET", f"{self.CREWS_RESOURCE}/{uuid}/status")
|
||||
|
||||
def crew_by_name(
|
||||
self, project_name: str, log_type: str = "deployment"
|
||||
) -> httpx.Response:
|
||||
return self._make_request(
|
||||
"GET", f"{self.CREWS_RESOURCE}/by-name/{project_name}/logs/{log_type}"
|
||||
)
|
||||
|
||||
def crew_by_uuid(self, uuid: str, log_type: str = "deployment") -> httpx.Response:
|
||||
return self._make_request(
|
||||
"GET", f"{self.CREWS_RESOURCE}/{uuid}/logs/{log_type}"
|
||||
)
|
||||
|
||||
def delete_crew_by_name(self, project_name: str) -> httpx.Response:
|
||||
return self._make_request(
|
||||
"DELETE", f"{self.CREWS_RESOURCE}/by-name/{project_name}"
|
||||
)
|
||||
|
||||
def delete_crew_by_uuid(self, uuid: str) -> httpx.Response:
|
||||
return self._make_request("DELETE", f"{self.CREWS_RESOURCE}/{uuid}")
|
||||
|
||||
def list_crews(self) -> httpx.Response:
|
||||
return self._make_request("GET", self.CREWS_RESOURCE)
|
||||
|
||||
def create_crew(self, payload: dict[str, Any]) -> httpx.Response:
|
||||
return self._make_request("POST", self.CREWS_RESOURCE, json=payload)
|
||||
|
||||
def get_organizations(self) -> httpx.Response:
|
||||
return self._make_request("GET", self.ORGANIZATIONS_RESOURCE)
|
||||
|
||||
def initialize_trace_batch(self, payload: dict[str, Any]) -> httpx.Response:
|
||||
return self._make_request(
|
||||
"POST",
|
||||
f"{self.TRACING_RESOURCE}/batches",
|
||||
json=payload,
|
||||
timeout=30,
|
||||
)
|
||||
|
||||
def initialize_ephemeral_trace_batch(
|
||||
self, payload: dict[str, Any]
|
||||
) -> httpx.Response:
|
||||
return self._make_request(
|
||||
"POST",
|
||||
f"{self.EPHEMERAL_TRACING_RESOURCE}/batches",
|
||||
json=payload,
|
||||
)
|
||||
|
||||
def send_trace_events(
|
||||
self, trace_batch_id: str, payload: dict[str, Any]
|
||||
) -> httpx.Response:
|
||||
return self._make_request(
|
||||
"POST",
|
||||
f"{self.TRACING_RESOURCE}/batches/{trace_batch_id}/events",
|
||||
json=payload,
|
||||
timeout=30,
|
||||
)
|
||||
|
||||
def send_ephemeral_trace_events(
|
||||
self, trace_batch_id: str, payload: dict[str, Any]
|
||||
) -> httpx.Response:
|
||||
return self._make_request(
|
||||
"POST",
|
||||
f"{self.EPHEMERAL_TRACING_RESOURCE}/batches/{trace_batch_id}/events",
|
||||
json=payload,
|
||||
timeout=30,
|
||||
)
|
||||
|
||||
def finalize_trace_batch(
|
||||
self, trace_batch_id: str, payload: dict[str, Any]
|
||||
) -> httpx.Response:
|
||||
return self._make_request(
|
||||
"PATCH",
|
||||
f"{self.TRACING_RESOURCE}/batches/{trace_batch_id}/finalize",
|
||||
json=payload,
|
||||
timeout=30,
|
||||
)
|
||||
|
||||
def finalize_ephemeral_trace_batch(
|
||||
self, trace_batch_id: str, payload: dict[str, Any]
|
||||
) -> httpx.Response:
|
||||
return self._make_request(
|
||||
"PATCH",
|
||||
f"{self.EPHEMERAL_TRACING_RESOURCE}/batches/{trace_batch_id}/finalize",
|
||||
json=payload,
|
||||
timeout=30,
|
||||
)
|
||||
|
||||
def mark_trace_batch_as_failed(
|
||||
self, trace_batch_id: str, error_message: str
|
||||
) -> httpx.Response:
|
||||
return self._make_request(
|
||||
"PATCH",
|
||||
f"{self.TRACING_RESOURCE}/batches/{trace_batch_id}",
|
||||
json={"status": "failed", "failure_reason": error_message},
|
||||
timeout=30,
|
||||
)
|
||||
|
||||
def get_mcp_configs(self, slugs: list[str]) -> httpx.Response:
|
||||
"""Get MCP server configurations for the given slugs."""
|
||||
return self._make_request(
|
||||
"GET",
|
||||
f"{self.INTEGRATIONS_RESOURCE}/mcp_configs",
|
||||
params={"slugs": ",".join(slugs)},
|
||||
timeout=30,
|
||||
)
|
||||
|
||||
def get_triggers(self) -> httpx.Response:
|
||||
"""Get all available triggers from integrations."""
|
||||
return self._make_request("GET", f"{self.INTEGRATIONS_RESOURCE}/apps")
|
||||
|
||||
def get_trigger_payload(self, app_slug: str, trigger_slug: str) -> httpx.Response:
|
||||
"""Get sample payload for a specific trigger."""
|
||||
return self._make_request(
|
||||
"GET", f"{self.INTEGRATIONS_RESOURCE}/{app_slug}/{trigger_slug}/payload"
|
||||
)
|
||||
@@ -1,231 +0,0 @@
|
||||
from collections import defaultdict
|
||||
from collections.abc import Sequence
|
||||
import json
|
||||
import os
|
||||
from pathlib import Path
|
||||
import time
|
||||
from typing import Any
|
||||
|
||||
import certifi
|
||||
import click
|
||||
import httpx
|
||||
|
||||
from crewai_cli.constants import JSON_URL, MODELS, PROVIDERS
|
||||
|
||||
|
||||
def select_choice(prompt_message: str, choices: Sequence[str]) -> str | None:
|
||||
"""Presents a list of choices to the user and prompts them to select one.
|
||||
|
||||
Args:
|
||||
prompt_message: The message to display to the user before presenting the choices.
|
||||
choices: A list of options to present to the user.
|
||||
|
||||
Returns:
|
||||
The selected choice from the list, or None if the user chooses to quit.
|
||||
"""
|
||||
|
||||
provider_models = get_provider_data()
|
||||
if not provider_models:
|
||||
return None
|
||||
click.secho(prompt_message, fg="cyan")
|
||||
for idx, choice in enumerate(choices, start=1):
|
||||
click.secho(f"{idx}. {choice}", fg="cyan")
|
||||
click.secho("q. Quit", fg="cyan")
|
||||
|
||||
while True:
|
||||
choice = click.prompt(
|
||||
"Enter the number of your choice or 'q' to quit", type=str
|
||||
)
|
||||
|
||||
if choice.lower() == "q":
|
||||
return None
|
||||
|
||||
try:
|
||||
selected_index = int(choice) - 1
|
||||
if 0 <= selected_index < len(choices):
|
||||
return choices[selected_index]
|
||||
except ValueError:
|
||||
pass
|
||||
|
||||
click.secho(
|
||||
"Invalid selection. Please select a number between 1 and 6 or 'q' to quit.",
|
||||
fg="red",
|
||||
)
|
||||
|
||||
|
||||
def select_provider(provider_models: dict[str, list[str]]) -> str | None | bool:
|
||||
"""Presents a list of providers to the user and prompts them to select one.
|
||||
|
||||
Args:
|
||||
provider_models: A dictionary of provider models.
|
||||
|
||||
Returns:
|
||||
The selected provider, None if user explicitly quits, or False if no selection.
|
||||
"""
|
||||
predefined_providers = [p.lower() for p in PROVIDERS]
|
||||
all_providers = sorted(set(predefined_providers + list(provider_models.keys())))
|
||||
|
||||
provider = select_choice(
|
||||
"Select a provider to set up:", [*predefined_providers, "other"]
|
||||
)
|
||||
if provider is None: # User typed 'q'
|
||||
return None
|
||||
|
||||
if provider == "other":
|
||||
provider = select_choice("Select a provider from the full list:", all_providers)
|
||||
if provider is None: # User typed 'q'
|
||||
return None
|
||||
|
||||
return provider.lower() if provider else False
|
||||
|
||||
|
||||
def select_model(provider: str, provider_models: dict[str, list[str]]) -> str | None:
|
||||
"""Presents a list of models for a given provider to the user and prompts them to select one.
|
||||
|
||||
Args:
|
||||
provider: The provider for which to select a model.
|
||||
provider_models: A dictionary of provider models.
|
||||
|
||||
Returns:
|
||||
The selected model, or None if the operation is aborted or an invalid selection is made.
|
||||
"""
|
||||
predefined_providers = [p.lower() for p in PROVIDERS]
|
||||
|
||||
if provider in predefined_providers:
|
||||
available_models = MODELS.get(provider, [])
|
||||
else:
|
||||
available_models = provider_models.get(provider, [])
|
||||
|
||||
if not available_models:
|
||||
click.secho(f"No models available for provider '{provider}'.", fg="red")
|
||||
return None
|
||||
|
||||
return select_choice(
|
||||
f"Select a model to use for {provider.capitalize()}:", available_models
|
||||
)
|
||||
|
||||
|
||||
def load_provider_data(cache_file: Path, cache_expiry: int) -> dict[str, Any] | None:
|
||||
"""Loads provider data from a cache file if it exists and is not expired.
|
||||
|
||||
If the cache is expired or corrupted, it fetches the data from the web.
|
||||
|
||||
Args:
|
||||
cache_file: The path to the cache file.
|
||||
cache_expiry: The cache expiry time in seconds.
|
||||
|
||||
Returns:
|
||||
The loaded provider data or None if the operation fails.
|
||||
"""
|
||||
current_time = time.time()
|
||||
if (
|
||||
cache_file.exists()
|
||||
and (current_time - cache_file.stat().st_mtime) < cache_expiry
|
||||
):
|
||||
data = read_cache_file(cache_file)
|
||||
if data:
|
||||
return data
|
||||
click.secho(
|
||||
"Cache is corrupted. Fetching provider data from the web...", fg="yellow"
|
||||
)
|
||||
else:
|
||||
click.secho(
|
||||
"Cache expired or not found. Fetching provider data from the web...",
|
||||
fg="cyan",
|
||||
)
|
||||
return fetch_provider_data(cache_file)
|
||||
|
||||
|
||||
def read_cache_file(cache_file: Path) -> dict[str, Any] | None:
|
||||
"""Reads and returns the JSON content from a cache file.
|
||||
|
||||
Args:
|
||||
cache_file: The path to the cache file.
|
||||
|
||||
Returns:
|
||||
The JSON content of the cache file or None if the JSON is invalid.
|
||||
"""
|
||||
try:
|
||||
with open(cache_file, "r") as f:
|
||||
data: dict[str, Any] = json.load(f)
|
||||
return data
|
||||
except json.JSONDecodeError:
|
||||
return None
|
||||
|
||||
|
||||
def fetch_provider_data(cache_file: Path) -> dict[str, Any] | None:
|
||||
"""Fetches provider data from a specified URL and caches it to a file.
|
||||
|
||||
Args:
|
||||
cache_file: The path to the cache file.
|
||||
|
||||
Returns:
|
||||
The fetched provider data or None if the operation fails.
|
||||
"""
|
||||
ssl_config = os.environ["SSL_CERT_FILE"] = certifi.where()
|
||||
|
||||
try:
|
||||
with httpx.stream("GET", JSON_URL, timeout=60, verify=ssl_config) as response:
|
||||
response.raise_for_status()
|
||||
data = download_data(response)
|
||||
with open(cache_file, "w") as f:
|
||||
json.dump(data, f)
|
||||
return data
|
||||
except httpx.HTTPError as e:
|
||||
click.secho(f"Error fetching provider data: {e}", fg="red")
|
||||
except json.JSONDecodeError:
|
||||
click.secho("Error parsing provider data. Invalid JSON format.", fg="red")
|
||||
return None
|
||||
|
||||
|
||||
def download_data(response: httpx.Response) -> dict[str, Any]:
|
||||
"""Downloads data from a given HTTP response and returns the JSON content.
|
||||
|
||||
Args:
|
||||
response: The HTTP response object.
|
||||
|
||||
Returns:
|
||||
The JSON content of the response.
|
||||
"""
|
||||
total_size = int(response.headers.get("content-length", 0))
|
||||
block_size = 8192
|
||||
data_chunks: list[bytes] = []
|
||||
bar: Any
|
||||
with click.progressbar(
|
||||
length=total_size, label="Downloading", show_pos=True
|
||||
) as bar:
|
||||
for chunk in response.iter_bytes(block_size):
|
||||
if chunk:
|
||||
data_chunks.append(chunk)
|
||||
bar.update(len(chunk))
|
||||
data_content = b"".join(data_chunks)
|
||||
result: dict[str, Any] = json.loads(data_content.decode("utf-8"))
|
||||
return result
|
||||
|
||||
|
||||
def get_provider_data() -> dict[str, list[str]] | None:
|
||||
"""Retrieves provider data from a cache file.
|
||||
|
||||
Filters out models based on provider criteria, and returns a dictionary of providers
|
||||
mapped to their models.
|
||||
|
||||
Returns:
|
||||
A dictionary of providers mapped to their models or None if the operation fails.
|
||||
"""
|
||||
cache_dir = Path.home() / ".crewai"
|
||||
cache_dir.mkdir(exist_ok=True)
|
||||
cache_file = cache_dir / "provider_cache.json"
|
||||
cache_expiry = 24 * 3600
|
||||
|
||||
data = load_provider_data(cache_file, cache_expiry)
|
||||
if not data:
|
||||
return None
|
||||
|
||||
provider_models = defaultdict(list)
|
||||
for model_name, properties in data.items():
|
||||
provider = properties.get("litellm_provider", "").strip().lower()
|
||||
if "http" in provider or provider == "other":
|
||||
continue
|
||||
if provider:
|
||||
provider_models[provider].append(model_name)
|
||||
return provider_models
|
||||
@@ -1,31 +0,0 @@
|
||||
"""Wrapper for the reset-memories command.
|
||||
|
||||
Delegates to ``crewai.cli.reset_memories_command`` when the full crewai
|
||||
package is installed, otherwise prints a helpful error message.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import click
|
||||
|
||||
|
||||
def reset_memories_command(
|
||||
memory: bool,
|
||||
knowledge: bool,
|
||||
agent_knowledge: bool,
|
||||
kickoff_outputs: bool,
|
||||
all: bool,
|
||||
) -> None:
|
||||
try:
|
||||
from crewai.cli.reset_memories_command import (
|
||||
reset_memories_command as _reset,
|
||||
)
|
||||
except ImportError:
|
||||
click.secho(
|
||||
"The 'reset-memories' command requires the full crewai package.\n"
|
||||
"Install it with: pip install crewai",
|
||||
fg="red",
|
||||
)
|
||||
raise SystemExit(1) from None
|
||||
|
||||
_reset(memory, knowledge, agent_knowledge, kickoff_outputs, all)
|
||||
@@ -1,186 +0,0 @@
|
||||
from datetime import datetime
|
||||
import json
|
||||
import os
|
||||
from pathlib import Path
|
||||
import sys
|
||||
import tempfile
|
||||
from typing import Final, Literal, cast
|
||||
|
||||
from cryptography.fernet import Fernet
|
||||
|
||||
|
||||
_FERNET_KEY_LENGTH: Final[Literal[44]] = 44
|
||||
|
||||
|
||||
class TokenManager:
|
||||
"""Manages encrypted token storage."""
|
||||
|
||||
def __init__(self, file_path: str = "tokens.enc") -> None:
|
||||
"""Initialize the TokenManager.
|
||||
|
||||
Args:
|
||||
file_path: The file path to store encrypted tokens.
|
||||
"""
|
||||
self.file_path = file_path
|
||||
self.key = self._get_or_create_key()
|
||||
self.fernet = Fernet(self.key)
|
||||
|
||||
def _get_or_create_key(self) -> bytes:
|
||||
"""Get or create the encryption key.
|
||||
|
||||
Returns:
|
||||
The encryption key as bytes.
|
||||
"""
|
||||
key_filename: str = "secret.key"
|
||||
|
||||
key = self._read_secure_file(key_filename)
|
||||
if key is not None and len(key) == _FERNET_KEY_LENGTH:
|
||||
return key
|
||||
|
||||
new_key = Fernet.generate_key()
|
||||
if self._atomic_create_secure_file(key_filename, new_key):
|
||||
return new_key
|
||||
|
||||
key = self._read_secure_file(key_filename)
|
||||
if key is not None and len(key) == _FERNET_KEY_LENGTH:
|
||||
return key
|
||||
|
||||
raise RuntimeError("Failed to create or read encryption key")
|
||||
|
||||
def save_tokens(self, access_token: str, expires_at: int) -> None:
|
||||
"""Save the access token and its expiration time.
|
||||
|
||||
Args:
|
||||
access_token: The access token to save.
|
||||
expires_at: The UNIX timestamp of the expiration time.
|
||||
"""
|
||||
expiration_time = datetime.fromtimestamp(expires_at)
|
||||
data = {
|
||||
"access_token": access_token,
|
||||
"expiration": expiration_time.isoformat(),
|
||||
}
|
||||
encrypted_data = self.fernet.encrypt(json.dumps(data).encode())
|
||||
self._atomic_write_secure_file(self.file_path, encrypted_data)
|
||||
|
||||
def get_token(self) -> str | None:
|
||||
"""Get the access token if it is valid and not expired.
|
||||
|
||||
Returns:
|
||||
The access token if valid and not expired, otherwise None.
|
||||
"""
|
||||
encrypted_data = self._read_secure_file(self.file_path)
|
||||
if encrypted_data is None:
|
||||
return None
|
||||
|
||||
decrypted_data = self.fernet.decrypt(encrypted_data)
|
||||
data = json.loads(decrypted_data)
|
||||
|
||||
expiration = datetime.fromisoformat(data["expiration"])
|
||||
if expiration <= datetime.now():
|
||||
return None
|
||||
|
||||
return cast(str | None, data.get("access_token"))
|
||||
|
||||
def clear_tokens(self) -> None:
|
||||
"""Clear the stored tokens."""
|
||||
self._delete_secure_file(self.file_path)
|
||||
|
||||
@staticmethod
|
||||
def _get_secure_storage_path() -> Path:
|
||||
"""Get the secure storage path based on the operating system.
|
||||
|
||||
Returns:
|
||||
The secure storage path.
|
||||
"""
|
||||
if sys.platform == "win32":
|
||||
base_path = os.environ.get("LOCALAPPDATA")
|
||||
elif sys.platform == "darwin":
|
||||
base_path = os.path.expanduser("~/Library/Application Support")
|
||||
else:
|
||||
base_path = os.path.expanduser("~/.local/share")
|
||||
|
||||
app_name = "crewai/credentials"
|
||||
storage_path = Path(base_path) / app_name
|
||||
|
||||
storage_path.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
return storage_path
|
||||
|
||||
def _atomic_create_secure_file(self, filename: str, content: bytes) -> bool:
|
||||
"""Create a file only if it doesn't exist.
|
||||
|
||||
Args:
|
||||
filename: The name of the file.
|
||||
content: The content to write.
|
||||
|
||||
Returns:
|
||||
True if file was created, False if it already exists.
|
||||
"""
|
||||
storage_path = self._get_secure_storage_path()
|
||||
file_path = storage_path / filename
|
||||
|
||||
try:
|
||||
fd = os.open(file_path, os.O_CREAT | os.O_EXCL | os.O_WRONLY, 0o600)
|
||||
try:
|
||||
os.write(fd, content)
|
||||
finally:
|
||||
os.close(fd)
|
||||
return True
|
||||
except FileExistsError:
|
||||
return False
|
||||
|
||||
def _atomic_write_secure_file(self, filename: str, content: bytes) -> None:
|
||||
"""Write content to a secure file.
|
||||
|
||||
Args:
|
||||
filename: The name of the file.
|
||||
content: The content to write.
|
||||
"""
|
||||
storage_path = self._get_secure_storage_path()
|
||||
file_path = storage_path / filename
|
||||
|
||||
fd, temp_path = tempfile.mkstemp(dir=storage_path, prefix=f".{filename}.")
|
||||
fd_closed = False
|
||||
try:
|
||||
os.write(fd, content)
|
||||
os.close(fd)
|
||||
fd_closed = True
|
||||
os.chmod(temp_path, 0o600)
|
||||
os.replace(temp_path, file_path)
|
||||
except Exception:
|
||||
if not fd_closed:
|
||||
os.close(fd)
|
||||
if os.path.exists(temp_path):
|
||||
os.unlink(temp_path)
|
||||
raise
|
||||
|
||||
def _read_secure_file(self, filename: str) -> bytes | None:
|
||||
"""Read the content of a secure file.
|
||||
|
||||
Args:
|
||||
filename: The name of the file.
|
||||
|
||||
Returns:
|
||||
The content of the file if it exists, otherwise None.
|
||||
"""
|
||||
storage_path = self._get_secure_storage_path()
|
||||
file_path = storage_path / filename
|
||||
|
||||
try:
|
||||
with open(file_path, "rb") as f:
|
||||
return f.read()
|
||||
except FileNotFoundError:
|
||||
return None
|
||||
|
||||
def _delete_secure_file(self, filename: str) -> None:
|
||||
"""Delete a secure file.
|
||||
|
||||
Args:
|
||||
filename: The name of the file.
|
||||
"""
|
||||
storage_path = self._get_secure_storage_path()
|
||||
file_path = storage_path / filename
|
||||
try:
|
||||
file_path.unlink()
|
||||
except FileNotFoundError:
|
||||
pass
|
||||
@@ -1,54 +0,0 @@
|
||||
"""Lightweight SQLite reader for kickoff task outputs.
|
||||
|
||||
Only used by the ``crewai log-tasks-outputs`` CLI command. Depends solely on
|
||||
the standard library + *appdirs* so crewai-cli can read stored outputs without
|
||||
importing the full crewai framework.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
import logging
|
||||
from pathlib import Path
|
||||
import sqlite3
|
||||
from typing import Any
|
||||
|
||||
from crewai_cli.user_data import _db_storage_path
|
||||
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def load_task_outputs(db_path: str | None = None) -> list[dict[str, Any]]:
|
||||
"""Return all rows from the kickoff task outputs database."""
|
||||
if db_path is None:
|
||||
db_path = str(Path(_db_storage_path()) / "latest_kickoff_task_outputs.db")
|
||||
|
||||
if not Path(db_path).exists():
|
||||
return []
|
||||
|
||||
try:
|
||||
with sqlite3.connect(db_path) as conn:
|
||||
cursor = conn.cursor()
|
||||
cursor.execute("""
|
||||
SELECT *
|
||||
FROM latest_kickoff_task_outputs
|
||||
ORDER BY task_index
|
||||
""")
|
||||
rows = cursor.fetchall()
|
||||
results: list[dict[str, Any]] = [
|
||||
{
|
||||
"task_id": row[0],
|
||||
"expected_output": row[1],
|
||||
"output": json.loads(row[2]),
|
||||
"task_index": row[3],
|
||||
"inputs": json.loads(row[4]),
|
||||
"was_replayed": row[5],
|
||||
"timestamp": row[6],
|
||||
}
|
||||
for row in rows
|
||||
]
|
||||
return results
|
||||
except sqlite3.Error as e:
|
||||
logger.error("Failed to load task outputs: %s", e)
|
||||
return []
|
||||
@@ -1,66 +0,0 @@
|
||||
"""Standalone user-data helpers for the CLI package.
|
||||
|
||||
These mirror the functions in ``crewai.events.listeners.tracing.utils`` but
|
||||
depend only on the standard library + *appdirs* so that crewai-cli can work
|
||||
without importing the full crewai framework.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
import logging
|
||||
import os
|
||||
from pathlib import Path
|
||||
from typing import Any, cast
|
||||
|
||||
import appdirs
|
||||
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def _get_project_directory_name() -> str:
|
||||
return os.environ.get("CREWAI_STORAGE_DIR", Path.cwd().name)
|
||||
|
||||
|
||||
def _db_storage_path() -> str:
|
||||
app_name = _get_project_directory_name()
|
||||
app_author = "CrewAI"
|
||||
data_dir = Path(appdirs.user_data_dir(app_name, app_author))
|
||||
data_dir.mkdir(parents=True, exist_ok=True)
|
||||
return str(data_dir)
|
||||
|
||||
|
||||
def _user_data_file() -> Path:
|
||||
base = Path(_db_storage_path())
|
||||
base.mkdir(parents=True, exist_ok=True)
|
||||
return base / ".crewai_user.json"
|
||||
|
||||
|
||||
def _load_user_data() -> dict[str, Any]:
|
||||
p = _user_data_file()
|
||||
if p.exists():
|
||||
try:
|
||||
return cast(dict[str, Any], json.loads(p.read_text()))
|
||||
except (json.JSONDecodeError, OSError, PermissionError) as e:
|
||||
logger.warning("Failed to load user data: %s", e)
|
||||
return {}
|
||||
|
||||
|
||||
def _save_user_data(data: dict[str, Any]) -> None:
|
||||
try:
|
||||
p = _user_data_file()
|
||||
p.write_text(json.dumps(data, indent=2))
|
||||
except (OSError, PermissionError) as e:
|
||||
logger.warning("Failed to save user data: %s", e)
|
||||
|
||||
|
||||
def is_tracing_enabled() -> bool:
|
||||
"""Check if tracing is enabled (mirrors crewai core logic)."""
|
||||
data = _load_user_data()
|
||||
if (
|
||||
data.get("first_execution_done", False)
|
||||
and data.get("trace_consent", False) is False
|
||||
):
|
||||
return False
|
||||
return os.getenv("CREWAI_TRACING_ENABLED", "false").lower() == "true"
|
||||
@@ -1,369 +0,0 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from functools import reduce
|
||||
from inspect import getmro, isclass
|
||||
import os
|
||||
from pathlib import Path
|
||||
import shutil
|
||||
import sys
|
||||
from typing import Any, cast
|
||||
|
||||
import click
|
||||
from rich.console import Console
|
||||
import tomli
|
||||
|
||||
from crewai_cli.config import Settings
|
||||
from crewai_cli.constants import ENV_VARS
|
||||
|
||||
|
||||
if sys.version_info >= (3, 11):
|
||||
import tomllib
|
||||
|
||||
console = Console()
|
||||
|
||||
|
||||
def copy_template(
|
||||
src: Path, dst: Path, name: str, class_name: str, folder_name: str
|
||||
) -> None:
|
||||
"""Copy a file from src to dst."""
|
||||
with open(src, "r") as file:
|
||||
content = file.read()
|
||||
|
||||
content = content.replace("{{name}}", name)
|
||||
content = content.replace("{{crew_name}}", class_name)
|
||||
content = content.replace("{{folder_name}}", folder_name)
|
||||
|
||||
with open(dst, "w") as file:
|
||||
file.write(content)
|
||||
|
||||
click.secho(f" - Created {dst}", fg="green")
|
||||
|
||||
|
||||
def read_toml(file_path: str = "pyproject.toml") -> dict[str, Any]:
|
||||
"""Read the content of a TOML file and return it as a dictionary."""
|
||||
with open(file_path, "rb") as f:
|
||||
return tomli.load(f)
|
||||
|
||||
|
||||
def parse_toml(content: str) -> dict[str, Any]:
|
||||
if sys.version_info >= (3, 11):
|
||||
return tomllib.loads(content)
|
||||
return tomli.loads(content)
|
||||
|
||||
|
||||
def get_project_name(
|
||||
pyproject_path: str = "pyproject.toml", require: bool = False
|
||||
) -> str | None:
|
||||
"""Get the project name from the pyproject.toml file."""
|
||||
return _get_project_attribute(pyproject_path, ["project", "name"], require=require)
|
||||
|
||||
|
||||
def get_project_version(
|
||||
pyproject_path: str = "pyproject.toml", require: bool = False
|
||||
) -> str | None:
|
||||
"""Get the project version from the pyproject.toml file."""
|
||||
return _get_project_attribute(
|
||||
pyproject_path, ["project", "version"], require=require
|
||||
)
|
||||
|
||||
|
||||
def get_project_description(
|
||||
pyproject_path: str = "pyproject.toml", require: bool = False
|
||||
) -> str | None:
|
||||
"""Get the project description from the pyproject.toml file."""
|
||||
return _get_project_attribute(
|
||||
pyproject_path, ["project", "description"], require=require
|
||||
)
|
||||
|
||||
|
||||
def _get_project_attribute(
|
||||
pyproject_path: str, keys: list[str], require: bool
|
||||
) -> Any | None:
|
||||
"""Get an attribute from the pyproject.toml file."""
|
||||
attribute = None
|
||||
|
||||
try:
|
||||
with open(pyproject_path, "r") as f:
|
||||
pyproject_content = parse_toml(f.read())
|
||||
|
||||
dependencies = (
|
||||
_get_nested_value(pyproject_content, ["project", "dependencies"]) or []
|
||||
)
|
||||
if not any(True for dep in dependencies if "crewai" in dep):
|
||||
raise Exception("crewai is not in the dependencies.")
|
||||
|
||||
attribute = _get_nested_value(pyproject_content, keys)
|
||||
except FileNotFoundError:
|
||||
console.print(f"Error: {pyproject_path} not found.", style="bold red")
|
||||
except KeyError:
|
||||
console.print(
|
||||
f"Error: {pyproject_path} is not a valid pyproject.toml file.",
|
||||
style="bold red",
|
||||
)
|
||||
except Exception as e:
|
||||
if sys.version_info >= (3, 11) and isinstance(e, tomllib.TOMLDecodeError):
|
||||
console.print(
|
||||
f"Error: {pyproject_path} is not a valid TOML file.", style="bold red"
|
||||
)
|
||||
else:
|
||||
console.print(
|
||||
f"Error reading the pyproject.toml file: {e}", style="bold red"
|
||||
)
|
||||
|
||||
if require and not attribute:
|
||||
console.print(
|
||||
f"Unable to read '{'.'.join(keys)}' in the pyproject.toml file. Please verify that the file exists and contains the specified attribute.",
|
||||
style="bold red",
|
||||
)
|
||||
raise SystemExit
|
||||
|
||||
return attribute
|
||||
|
||||
|
||||
def _get_nested_value(data: dict[str, Any], keys: list[str]) -> Any:
|
||||
return reduce(dict.__getitem__, keys, data)
|
||||
|
||||
|
||||
def fetch_and_json_env_file(env_file_path: str = ".env") -> dict[str, Any]:
|
||||
"""Fetch the environment variables from a .env file and return them as a dictionary."""
|
||||
try:
|
||||
with open(env_file_path, "r") as f:
|
||||
env_content = f.read()
|
||||
|
||||
env_dict = {}
|
||||
for line in env_content.splitlines():
|
||||
if line.strip() and not line.strip().startswith("#"):
|
||||
key, value = line.split("=", 1)
|
||||
env_dict[key.strip()] = value.strip()
|
||||
|
||||
return env_dict
|
||||
|
||||
except FileNotFoundError:
|
||||
console.print(f"Error: {env_file_path} not found.", style="bold red")
|
||||
except Exception as e:
|
||||
console.print(f"Error reading the .env file: {e}", style="bold red")
|
||||
|
||||
return {}
|
||||
|
||||
|
||||
def tree_copy(source: Path, destination: Path) -> None:
|
||||
"""Copies the entire directory structure from the source to the destination."""
|
||||
for item in os.listdir(source):
|
||||
source_item = os.path.join(source, item)
|
||||
destination_item = os.path.join(destination, item)
|
||||
if os.path.isdir(source_item):
|
||||
shutil.copytree(source_item, destination_item)
|
||||
else:
|
||||
shutil.copy2(source_item, destination_item)
|
||||
|
||||
|
||||
def tree_find_and_replace(directory: Path, find: str, replace: str) -> None:
|
||||
"""Recursively searches through a directory, replacing a target string in
|
||||
both file contents and filenames with a specified replacement string.
|
||||
"""
|
||||
for path, dirs, files in os.walk(os.path.abspath(directory), topdown=False):
|
||||
for filename in files:
|
||||
filepath = os.path.join(path, filename)
|
||||
|
||||
with open(filepath, "r", encoding="utf-8", errors="ignore") as file:
|
||||
contents = file.read()
|
||||
with open(filepath, "w") as file:
|
||||
file.write(contents.replace(find, replace))
|
||||
|
||||
if find in filename:
|
||||
new_filename = filename.replace(find, replace)
|
||||
new_filepath = os.path.join(path, new_filename)
|
||||
os.rename(filepath, new_filepath)
|
||||
|
||||
for dirname in dirs:
|
||||
if find in dirname:
|
||||
new_dirname = dirname.replace(find, replace)
|
||||
new_dirpath = os.path.join(path, new_dirname)
|
||||
old_dirpath = os.path.join(path, dirname)
|
||||
os.rename(old_dirpath, new_dirpath)
|
||||
|
||||
|
||||
def load_env_vars(folder_path: Path) -> dict[str, Any]:
|
||||
"""Loads environment variables from a .env file in the specified folder path."""
|
||||
env_file_path = folder_path / ".env"
|
||||
env_vars = {}
|
||||
if env_file_path.exists():
|
||||
with open(env_file_path, "r") as file:
|
||||
for line in file:
|
||||
key, _, value = line.strip().partition("=")
|
||||
if key and value:
|
||||
env_vars[key] = value
|
||||
return env_vars
|
||||
|
||||
|
||||
def update_env_vars(
|
||||
env_vars: dict[str, Any], provider: str, model: str
|
||||
) -> dict[str, Any] | None:
|
||||
"""Updates environment variables with the API key for the selected provider and model."""
|
||||
provider_config = cast(
|
||||
list[str],
|
||||
ENV_VARS.get(
|
||||
provider,
|
||||
[
|
||||
click.prompt(
|
||||
f"Enter the environment variable name for your {provider.capitalize()} API key",
|
||||
type=str,
|
||||
)
|
||||
],
|
||||
),
|
||||
)
|
||||
|
||||
api_key_var = provider_config[0]
|
||||
|
||||
if api_key_var not in env_vars:
|
||||
try:
|
||||
env_vars[api_key_var] = click.prompt(
|
||||
f"Enter your {provider.capitalize()} API key", type=str, hide_input=True
|
||||
)
|
||||
except click.exceptions.Abort:
|
||||
click.secho("Operation aborted by the user.", fg="red")
|
||||
return None
|
||||
else:
|
||||
click.secho(f"API key already exists for {provider.capitalize()}.", fg="yellow")
|
||||
|
||||
env_vars["MODEL"] = model
|
||||
click.secho(f"Selected model: {model}", fg="green")
|
||||
return env_vars
|
||||
|
||||
|
||||
def write_env_file(folder_path: Path, env_vars: dict[str, Any]) -> None:
|
||||
"""Writes environment variables to a .env file in the specified folder."""
|
||||
env_file_path = folder_path / ".env"
|
||||
with open(env_file_path, "w") as file:
|
||||
for key, value in env_vars.items():
|
||||
file.write(f"{key.upper()}={value}\n")
|
||||
|
||||
|
||||
def is_valid_tool(obj: Any) -> bool:
|
||||
"""Check if an object is a valid tool class.
|
||||
|
||||
Works without importing crewai by checking MRO class names.
|
||||
Falls back to crewai's ``is_valid_tool`` when available.
|
||||
"""
|
||||
try:
|
||||
from crewai.cli.utils import is_valid_tool as _core_is_valid_tool
|
||||
|
||||
return _core_is_valid_tool(obj)
|
||||
except ImportError:
|
||||
pass
|
||||
|
||||
if isclass(obj):
|
||||
try:
|
||||
return any(base.__name__ == "BaseTool" for base in getmro(obj))
|
||||
except (TypeError, AttributeError):
|
||||
return False
|
||||
return False
|
||||
|
||||
|
||||
def extract_available_exports(dir_path: str = "src") -> list[dict[str, Any]]:
|
||||
"""Extract available tool classes from the project's __init__.py files."""
|
||||
try:
|
||||
init_files = Path(dir_path).glob("**/__init__.py")
|
||||
available_exports: list[dict[str, Any]] = []
|
||||
|
||||
for init_file in init_files:
|
||||
tools = _load_tools_from_init(init_file)
|
||||
available_exports.extend(tools)
|
||||
|
||||
if not available_exports:
|
||||
_print_no_tools_warning()
|
||||
raise SystemExit(1)
|
||||
|
||||
return available_exports
|
||||
|
||||
except SystemExit:
|
||||
raise
|
||||
except Exception as e:
|
||||
console.print(f"[red]Error: Could not extract tool classes: {e!s}[/red]")
|
||||
console.print(
|
||||
"Please ensure your project contains valid tools (classes inheriting from BaseTool or functions with @tool decorator)."
|
||||
)
|
||||
raise SystemExit(1) from e
|
||||
|
||||
|
||||
def _load_tools_from_init(init_file: Path) -> list[dict[str, Any]]:
|
||||
"""Load and validate tools from a given __init__.py file."""
|
||||
import importlib.util as _importlib_util
|
||||
|
||||
spec = _importlib_util.spec_from_file_location("temp_module", init_file)
|
||||
|
||||
if not spec or not spec.loader:
|
||||
return []
|
||||
|
||||
module = _importlib_util.module_from_spec(spec)
|
||||
sys.modules["temp_module"] = module
|
||||
|
||||
try:
|
||||
spec.loader.exec_module(module)
|
||||
|
||||
if not hasattr(module, "__all__"):
|
||||
console.print(
|
||||
f"Warning: No __all__ defined in {init_file}",
|
||||
style="bold yellow",
|
||||
)
|
||||
raise SystemExit(1)
|
||||
|
||||
return [
|
||||
{"name": name}
|
||||
for name in module.__all__
|
||||
if hasattr(module, name) and is_valid_tool(getattr(module, name))
|
||||
]
|
||||
|
||||
except SystemExit:
|
||||
raise
|
||||
except Exception as e:
|
||||
console.print(f"[red]Warning: Could not load {init_file}: {e!s}[/red]")
|
||||
raise SystemExit(1) from e
|
||||
|
||||
finally:
|
||||
sys.modules.pop("temp_module", None)
|
||||
|
||||
|
||||
def _print_no_tools_warning() -> None:
|
||||
"""Display warning and usage instructions if no tools were found."""
|
||||
console.print(
|
||||
"\n[bold yellow]Warning: No valid tools were exposed in your __init__.py file![/bold yellow]"
|
||||
)
|
||||
console.print(
|
||||
"Your __init__.py file must contain all classes that inherit from [bold]BaseTool[/bold] "
|
||||
"or functions decorated with [bold]@tool[/bold]."
|
||||
)
|
||||
console.print(
|
||||
"\nExample:\n[dim]# In your __init__.py file[/dim]\n"
|
||||
"[green]__all__ = ['YourTool', 'your_tool_function'][/green]\n\n"
|
||||
"[dim]# In your tool.py file[/dim]\n"
|
||||
"[green]from crewai.tools import BaseTool, tool\n\n"
|
||||
"# Tool class example\n"
|
||||
"class YourTool(BaseTool):\n"
|
||||
' name = "your_tool"\n'
|
||||
' description = "Your tool description"\n'
|
||||
" # ... rest of implementation\n\n"
|
||||
"# Decorated function example\n"
|
||||
"@tool\n"
|
||||
"def your_tool_function(text: str) -> str:\n"
|
||||
' """Your tool description"""\n'
|
||||
" # ... implementation\n"
|
||||
" return result\n"
|
||||
)
|
||||
|
||||
|
||||
def build_env_with_tool_repository_credentials(
|
||||
repository_handle: str,
|
||||
) -> dict[str, Any]:
|
||||
repository_handle = repository_handle.upper().replace("-", "_")
|
||||
settings = Settings()
|
||||
|
||||
env = os.environ.copy()
|
||||
env[f"UV_INDEX_{repository_handle}_USERNAME"] = str(
|
||||
settings.tool_repository_username or ""
|
||||
)
|
||||
env[f"UV_INDEX_{repository_handle}_PASSWORD"] = str(
|
||||
settings.tool_repository_password or ""
|
||||
)
|
||||
|
||||
return env
|
||||
@@ -1,215 +0,0 @@
|
||||
"""Version utilities for CrewAI CLI."""
|
||||
|
||||
from collections.abc import Mapping
|
||||
from datetime import datetime, timedelta
|
||||
from functools import lru_cache
|
||||
import importlib.metadata
|
||||
import json
|
||||
from pathlib import Path
|
||||
from typing import Any
|
||||
from urllib import request
|
||||
from urllib.error import URLError
|
||||
|
||||
import appdirs
|
||||
from packaging.version import InvalidVersion, Version, parse
|
||||
|
||||
|
||||
@lru_cache(maxsize=1)
|
||||
def _get_cache_file() -> Path:
|
||||
"""Get the path to the version cache file.
|
||||
|
||||
Cached to avoid repeated filesystem operations.
|
||||
"""
|
||||
cache_dir = Path(appdirs.user_cache_dir("crewai"))
|
||||
cache_dir.mkdir(parents=True, exist_ok=True)
|
||||
return cache_dir / "version_cache.json"
|
||||
|
||||
|
||||
def get_crewai_version() -> str:
|
||||
"""Get the version number of CrewAI running the CLI."""
|
||||
return importlib.metadata.version("crewai")
|
||||
|
||||
|
||||
def _is_cache_valid(cache_data: Mapping[str, Any]) -> bool:
|
||||
"""Check if the cache is still valid, less than 24 hours old."""
|
||||
if "timestamp" not in cache_data:
|
||||
return False
|
||||
|
||||
try:
|
||||
cache_time = datetime.fromisoformat(str(cache_data["timestamp"]))
|
||||
return datetime.now() - cache_time < timedelta(hours=24)
|
||||
except (ValueError, TypeError):
|
||||
return False
|
||||
|
||||
|
||||
def _find_latest_non_yanked_version(
|
||||
releases: Mapping[str, list[dict[str, Any]]],
|
||||
) -> str | None:
|
||||
"""Find the latest non-yanked version from PyPI releases data.
|
||||
|
||||
Args:
|
||||
releases: PyPI releases dict mapping version strings to file info lists.
|
||||
|
||||
Returns:
|
||||
The latest non-yanked version string, or None if all versions are yanked.
|
||||
"""
|
||||
best_version: Version | None = None
|
||||
best_version_str: str | None = None
|
||||
|
||||
for version_str, files in releases.items():
|
||||
try:
|
||||
v = parse(version_str)
|
||||
except InvalidVersion:
|
||||
continue
|
||||
|
||||
if v.is_prerelease or v.is_devrelease:
|
||||
continue
|
||||
|
||||
if not files:
|
||||
continue
|
||||
|
||||
all_yanked = all(f.get("yanked", False) for f in files)
|
||||
if all_yanked:
|
||||
continue
|
||||
|
||||
if best_version is None or v > best_version:
|
||||
best_version = v
|
||||
best_version_str = version_str
|
||||
|
||||
return best_version_str
|
||||
|
||||
|
||||
def _is_version_yanked(
|
||||
version_str: str,
|
||||
releases: Mapping[str, list[dict[str, Any]]],
|
||||
) -> tuple[bool, str]:
|
||||
"""Check if a specific version is yanked.
|
||||
|
||||
Args:
|
||||
version_str: The version string to check.
|
||||
releases: PyPI releases dict mapping version strings to file info lists.
|
||||
|
||||
Returns:
|
||||
Tuple of (is_yanked, yanked_reason).
|
||||
"""
|
||||
files = releases.get(version_str, [])
|
||||
if not files:
|
||||
return False, ""
|
||||
|
||||
all_yanked = all(f.get("yanked", False) for f in files)
|
||||
if not all_yanked:
|
||||
return False, ""
|
||||
|
||||
for f in files:
|
||||
reason = f.get("yanked_reason", "")
|
||||
if reason:
|
||||
return True, str(reason)
|
||||
|
||||
return True, ""
|
||||
|
||||
|
||||
def get_latest_version_from_pypi(timeout: int = 2) -> str | None:
|
||||
"""Get the latest non-yanked version of CrewAI from PyPI.
|
||||
|
||||
Args:
|
||||
timeout: Request timeout in seconds.
|
||||
|
||||
Returns:
|
||||
Latest non-yanked version string or None if unable to fetch.
|
||||
"""
|
||||
cache_file = _get_cache_file()
|
||||
if cache_file.exists():
|
||||
try:
|
||||
cache_data = json.loads(cache_file.read_text())
|
||||
if _is_cache_valid(cache_data) and "current_version" in cache_data:
|
||||
version: str | None = cache_data.get("version")
|
||||
return version
|
||||
except (json.JSONDecodeError, OSError):
|
||||
pass
|
||||
|
||||
try:
|
||||
with request.urlopen(
|
||||
"https://pypi.org/pypi/crewai/json", timeout=timeout
|
||||
) as response:
|
||||
data = json.loads(response.read())
|
||||
releases: dict[str, list[dict[str, Any]]] = data["releases"]
|
||||
latest_version = _find_latest_non_yanked_version(releases)
|
||||
|
||||
current_version = get_crewai_version()
|
||||
is_yanked, yanked_reason = _is_version_yanked(current_version, releases)
|
||||
|
||||
cache_data = {
|
||||
"version": latest_version,
|
||||
"timestamp": datetime.now().isoformat(),
|
||||
"current_version": current_version,
|
||||
"current_version_yanked": is_yanked,
|
||||
"current_version_yanked_reason": yanked_reason,
|
||||
}
|
||||
cache_file.write_text(json.dumps(cache_data))
|
||||
|
||||
return latest_version
|
||||
except (URLError, json.JSONDecodeError, KeyError, OSError):
|
||||
return None
|
||||
|
||||
|
||||
def is_current_version_yanked() -> tuple[bool, str]:
|
||||
"""Check if the currently installed version has been yanked on PyPI.
|
||||
|
||||
Reads from cache if available, otherwise triggers a fetch.
|
||||
|
||||
Returns:
|
||||
Tuple of (is_yanked, yanked_reason).
|
||||
"""
|
||||
cache_file = _get_cache_file()
|
||||
if cache_file.exists():
|
||||
try:
|
||||
cache_data = json.loads(cache_file.read_text())
|
||||
if _is_cache_valid(cache_data) and "current_version" in cache_data:
|
||||
current = get_crewai_version()
|
||||
if cache_data.get("current_version") == current:
|
||||
return (
|
||||
bool(cache_data.get("current_version_yanked", False)),
|
||||
str(cache_data.get("current_version_yanked_reason", "")),
|
||||
)
|
||||
except (json.JSONDecodeError, OSError):
|
||||
pass
|
||||
|
||||
get_latest_version_from_pypi()
|
||||
|
||||
try:
|
||||
cache_data = json.loads(cache_file.read_text())
|
||||
return (
|
||||
bool(cache_data.get("current_version_yanked", False)),
|
||||
str(cache_data.get("current_version_yanked_reason", "")),
|
||||
)
|
||||
except (json.JSONDecodeError, OSError):
|
||||
return False, ""
|
||||
|
||||
|
||||
def check_version() -> tuple[str, str | None]:
|
||||
"""Check current and latest versions.
|
||||
|
||||
Returns:
|
||||
Tuple of (current_version, latest_version).
|
||||
latest_version is None if unable to fetch from PyPI.
|
||||
"""
|
||||
current = get_crewai_version()
|
||||
latest = get_latest_version_from_pypi()
|
||||
return current, latest
|
||||
|
||||
|
||||
def is_newer_version_available() -> tuple[bool, str, str | None]:
|
||||
"""Check if a newer version is available.
|
||||
|
||||
Returns:
|
||||
Tuple of (is_newer, current_version, latest_version).
|
||||
"""
|
||||
current, latest = check_version()
|
||||
|
||||
if latest is None:
|
||||
return False, current, None
|
||||
|
||||
try:
|
||||
return parse(latest) > parse(current), current, latest
|
||||
except (InvalidVersion, TypeError):
|
||||
return False, current, latest
|
||||
@@ -1,91 +0,0 @@
|
||||
import pytest
|
||||
from crewai_cli.authentication.main import Oauth2Settings
|
||||
from crewai_cli.authentication.providers.auth0 import Auth0Provider
|
||||
|
||||
|
||||
|
||||
class TestAuth0Provider:
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
def setup_method(self):
|
||||
self.valid_settings = Oauth2Settings(
|
||||
provider="auth0",
|
||||
domain="test-domain.auth0.com",
|
||||
client_id="test-client-id",
|
||||
audience="test-audience"
|
||||
)
|
||||
self.provider = Auth0Provider(self.valid_settings)
|
||||
|
||||
def test_initialization_with_valid_settings(self):
|
||||
provider = Auth0Provider(self.valid_settings)
|
||||
assert provider.settings == self.valid_settings
|
||||
assert provider.settings.provider == "auth0"
|
||||
assert provider.settings.domain == "test-domain.auth0.com"
|
||||
assert provider.settings.client_id == "test-client-id"
|
||||
assert provider.settings.audience == "test-audience"
|
||||
|
||||
def test_get_authorize_url(self):
|
||||
expected_url = "https://test-domain.auth0.com/oauth/device/code"
|
||||
assert self.provider.get_authorize_url() == expected_url
|
||||
|
||||
def test_get_authorize_url_with_different_domain(self):
|
||||
settings = Oauth2Settings(
|
||||
provider="auth0",
|
||||
domain="my-company.auth0.com",
|
||||
client_id="test-client",
|
||||
audience="test-audience"
|
||||
)
|
||||
provider = Auth0Provider(settings)
|
||||
expected_url = "https://my-company.auth0.com/oauth/device/code"
|
||||
assert provider.get_authorize_url() == expected_url
|
||||
|
||||
def test_get_token_url(self):
|
||||
expected_url = "https://test-domain.auth0.com/oauth/token"
|
||||
assert self.provider.get_token_url() == expected_url
|
||||
|
||||
def test_get_token_url_with_different_domain(self):
|
||||
settings = Oauth2Settings(
|
||||
provider="auth0",
|
||||
domain="another-domain.auth0.com",
|
||||
client_id="test-client",
|
||||
audience="test-audience"
|
||||
)
|
||||
provider = Auth0Provider(settings)
|
||||
expected_url = "https://another-domain.auth0.com/oauth/token"
|
||||
assert provider.get_token_url() == expected_url
|
||||
|
||||
def test_get_jwks_url(self):
|
||||
expected_url = "https://test-domain.auth0.com/.well-known/jwks.json"
|
||||
assert self.provider.get_jwks_url() == expected_url
|
||||
|
||||
def test_get_jwks_url_with_different_domain(self):
|
||||
settings = Oauth2Settings(
|
||||
provider="auth0",
|
||||
domain="dev.auth0.com",
|
||||
client_id="test-client",
|
||||
audience="test-audience"
|
||||
)
|
||||
provider = Auth0Provider(settings)
|
||||
expected_url = "https://dev.auth0.com/.well-known/jwks.json"
|
||||
assert provider.get_jwks_url() == expected_url
|
||||
|
||||
def test_get_issuer(self):
|
||||
expected_issuer = "https://test-domain.auth0.com/"
|
||||
assert self.provider.get_issuer() == expected_issuer
|
||||
|
||||
def test_get_issuer_with_different_domain(self):
|
||||
settings = Oauth2Settings(
|
||||
provider="auth0",
|
||||
domain="prod.auth0.com",
|
||||
client_id="test-client",
|
||||
audience="test-audience"
|
||||
)
|
||||
provider = Auth0Provider(settings)
|
||||
expected_issuer = "https://prod.auth0.com/"
|
||||
assert provider.get_issuer() == expected_issuer
|
||||
|
||||
def test_get_audience(self):
|
||||
assert self.provider.get_audience() == "test-audience"
|
||||
|
||||
def test_get_client_id(self):
|
||||
assert self.provider.get_client_id() == "test-client-id"
|
||||
@@ -1,141 +0,0 @@
|
||||
import pytest
|
||||
|
||||
from crewai_cli.authentication.main import Oauth2Settings
|
||||
from crewai_cli.authentication.providers.entra_id import EntraIdProvider
|
||||
|
||||
|
||||
class TestEntraIdProvider:
|
||||
@pytest.fixture(autouse=True)
|
||||
def setup_method(self):
|
||||
self.valid_settings = Oauth2Settings(
|
||||
provider="entra_id",
|
||||
domain="tenant-id-abcdef123456",
|
||||
client_id="test-client-id",
|
||||
audience="test-audience",
|
||||
extra={
|
||||
"scope": "openid profile email api://crewai-cli-dev/read"
|
||||
}
|
||||
)
|
||||
self.provider = EntraIdProvider(self.valid_settings)
|
||||
|
||||
def test_initialization_with_valid_settings(self):
|
||||
provider = EntraIdProvider(self.valid_settings)
|
||||
assert provider.settings == self.valid_settings
|
||||
assert provider.settings.provider == "entra_id"
|
||||
assert provider.settings.domain == "tenant-id-abcdef123456"
|
||||
assert provider.settings.client_id == "test-client-id"
|
||||
assert provider.settings.audience == "test-audience"
|
||||
|
||||
def test_get_authorize_url(self):
|
||||
expected_url = "https://login.microsoftonline.com/tenant-id-abcdef123456/oauth2/v2.0/devicecode"
|
||||
assert self.provider.get_authorize_url() == expected_url
|
||||
|
||||
def test_get_authorize_url_with_different_domain(self):
|
||||
# For EntraID, the domain is the tenant ID.
|
||||
settings = Oauth2Settings(
|
||||
provider="entra_id",
|
||||
domain="my-company.entra.id",
|
||||
client_id="test-client",
|
||||
audience="test-audience",
|
||||
)
|
||||
provider = EntraIdProvider(settings)
|
||||
expected_url = "https://login.microsoftonline.com/my-company.entra.id/oauth2/v2.0/devicecode"
|
||||
assert provider.get_authorize_url() == expected_url
|
||||
|
||||
def test_get_token_url(self):
|
||||
expected_url = "https://login.microsoftonline.com/tenant-id-abcdef123456/oauth2/v2.0/token"
|
||||
assert self.provider.get_token_url() == expected_url
|
||||
|
||||
def test_get_token_url_with_different_domain(self):
|
||||
# For EntraID, the domain is the tenant ID.
|
||||
settings = Oauth2Settings(
|
||||
provider="entra_id",
|
||||
domain="another-domain.entra.id",
|
||||
client_id="test-client",
|
||||
audience="test-audience",
|
||||
)
|
||||
provider = EntraIdProvider(settings)
|
||||
expected_url = "https://login.microsoftonline.com/another-domain.entra.id/oauth2/v2.0/token"
|
||||
assert provider.get_token_url() == expected_url
|
||||
|
||||
def test_get_jwks_url(self):
|
||||
expected_url = "https://login.microsoftonline.com/tenant-id-abcdef123456/discovery/v2.0/keys"
|
||||
assert self.provider.get_jwks_url() == expected_url
|
||||
|
||||
def test_get_jwks_url_with_different_domain(self):
|
||||
# For EntraID, the domain is the tenant ID.
|
||||
settings = Oauth2Settings(
|
||||
provider="entra_id",
|
||||
domain="dev.entra.id",
|
||||
client_id="test-client",
|
||||
audience="test-audience",
|
||||
)
|
||||
provider = EntraIdProvider(settings)
|
||||
expected_url = "https://login.microsoftonline.com/dev.entra.id/discovery/v2.0/keys"
|
||||
assert provider.get_jwks_url() == expected_url
|
||||
|
||||
def test_get_issuer(self):
|
||||
expected_issuer = "https://login.microsoftonline.com/tenant-id-abcdef123456/v2.0"
|
||||
assert self.provider.get_issuer() == expected_issuer
|
||||
|
||||
def test_get_issuer_with_different_domain(self):
|
||||
# For EntraID, the domain is the tenant ID.
|
||||
settings = Oauth2Settings(
|
||||
provider="entra_id",
|
||||
domain="other-tenant-id-xpto",
|
||||
client_id="test-client",
|
||||
audience="test-audience",
|
||||
)
|
||||
provider = EntraIdProvider(settings)
|
||||
expected_issuer = "https://login.microsoftonline.com/other-tenant-id-xpto/v2.0"
|
||||
assert provider.get_issuer() == expected_issuer
|
||||
|
||||
def test_get_audience(self):
|
||||
assert self.provider.get_audience() == "test-audience"
|
||||
|
||||
def test_get_audience_assertion_error_when_none(self):
|
||||
settings = Oauth2Settings(
|
||||
provider="entra_id",
|
||||
domain="test-tenant-id",
|
||||
client_id="test-client-id",
|
||||
audience=None,
|
||||
)
|
||||
provider = EntraIdProvider(settings)
|
||||
|
||||
with pytest.raises(ValueError, match="Audience is required"):
|
||||
provider.get_audience()
|
||||
|
||||
def test_get_client_id(self):
|
||||
assert self.provider.get_client_id() == "test-client-id"
|
||||
|
||||
def test_get_required_fields(self):
|
||||
assert set(self.provider.get_required_fields()) == set(["scope"])
|
||||
|
||||
def test_get_oauth_scopes(self):
|
||||
settings = Oauth2Settings(
|
||||
provider="entra_id",
|
||||
domain="tenant-id-abcdef123456",
|
||||
client_id="test-client-id",
|
||||
audience="test-audience",
|
||||
extra={
|
||||
"scope": "api://crewai-cli-dev/read"
|
||||
}
|
||||
)
|
||||
provider = EntraIdProvider(settings)
|
||||
assert provider.get_oauth_scopes() == ["openid", "profile", "email", "api://crewai-cli-dev/read"]
|
||||
|
||||
def test_get_oauth_scopes_with_multiple_custom_scopes(self):
|
||||
settings = Oauth2Settings(
|
||||
provider="entra_id",
|
||||
domain="tenant-id-abcdef123456",
|
||||
client_id="test-client-id",
|
||||
audience="test-audience",
|
||||
extra={
|
||||
"scope": "api://crewai-cli-dev/read api://crewai-cli-dev/write custom-scope1 custom-scope2"
|
||||
}
|
||||
)
|
||||
provider = EntraIdProvider(settings)
|
||||
assert provider.get_oauth_scopes() == ["openid", "profile", "email", "api://crewai-cli-dev/read", "api://crewai-cli-dev/write", "custom-scope1", "custom-scope2"]
|
||||
|
||||
def test_base_url(self):
|
||||
assert self.provider._base_url() == "https://login.microsoftonline.com/tenant-id-abcdef123456"
|
||||
@@ -1,138 +0,0 @@
|
||||
import pytest
|
||||
|
||||
from crewai_cli.authentication.main import Oauth2Settings
|
||||
from crewai_cli.authentication.providers.keycloak import KeycloakProvider
|
||||
|
||||
|
||||
class TestKeycloakProvider:
|
||||
@pytest.fixture(autouse=True)
|
||||
def setup_method(self):
|
||||
self.valid_settings = Oauth2Settings(
|
||||
provider="keycloak",
|
||||
domain="keycloak.example.com",
|
||||
client_id="test-client-id",
|
||||
audience="test-audience",
|
||||
extra={
|
||||
"realm": "test-realm"
|
||||
}
|
||||
)
|
||||
self.provider = KeycloakProvider(self.valid_settings)
|
||||
|
||||
def test_initialization_with_valid_settings(self):
|
||||
provider = KeycloakProvider(self.valid_settings)
|
||||
assert provider.settings == self.valid_settings
|
||||
assert provider.settings.provider == "keycloak"
|
||||
assert provider.settings.domain == "keycloak.example.com"
|
||||
assert provider.settings.client_id == "test-client-id"
|
||||
assert provider.settings.audience == "test-audience"
|
||||
assert provider.settings.extra.get("realm") == "test-realm"
|
||||
|
||||
def test_get_authorize_url(self):
|
||||
expected_url = "https://keycloak.example.com/realms/test-realm/protocol/openid-connect/auth/device"
|
||||
assert self.provider.get_authorize_url() == expected_url
|
||||
|
||||
def test_get_authorize_url_with_different_domain(self):
|
||||
settings = Oauth2Settings(
|
||||
provider="keycloak",
|
||||
domain="auth.company.com",
|
||||
client_id="test-client",
|
||||
audience="test-audience",
|
||||
extra={
|
||||
"realm": "my-realm"
|
||||
}
|
||||
)
|
||||
provider = KeycloakProvider(settings)
|
||||
expected_url = "https://auth.company.com/realms/my-realm/protocol/openid-connect/auth/device"
|
||||
assert provider.get_authorize_url() == expected_url
|
||||
|
||||
def test_get_token_url(self):
|
||||
expected_url = "https://keycloak.example.com/realms/test-realm/protocol/openid-connect/token"
|
||||
assert self.provider.get_token_url() == expected_url
|
||||
|
||||
def test_get_token_url_with_different_domain(self):
|
||||
settings = Oauth2Settings(
|
||||
provider="keycloak",
|
||||
domain="sso.enterprise.com",
|
||||
client_id="test-client",
|
||||
audience="test-audience",
|
||||
extra={
|
||||
"realm": "enterprise-realm"
|
||||
}
|
||||
)
|
||||
provider = KeycloakProvider(settings)
|
||||
expected_url = "https://sso.enterprise.com/realms/enterprise-realm/protocol/openid-connect/token"
|
||||
assert provider.get_token_url() == expected_url
|
||||
|
||||
def test_get_jwks_url(self):
|
||||
expected_url = "https://keycloak.example.com/realms/test-realm/protocol/openid-connect/certs"
|
||||
assert self.provider.get_jwks_url() == expected_url
|
||||
|
||||
def test_get_jwks_url_with_different_domain(self):
|
||||
settings = Oauth2Settings(
|
||||
provider="keycloak",
|
||||
domain="identity.org",
|
||||
client_id="test-client",
|
||||
audience="test-audience",
|
||||
extra={
|
||||
"realm": "org-realm"
|
||||
}
|
||||
)
|
||||
provider = KeycloakProvider(settings)
|
||||
expected_url = "https://identity.org/realms/org-realm/protocol/openid-connect/certs"
|
||||
assert provider.get_jwks_url() == expected_url
|
||||
|
||||
def test_get_issuer(self):
|
||||
expected_issuer = "https://keycloak.example.com/realms/test-realm"
|
||||
assert self.provider.get_issuer() == expected_issuer
|
||||
|
||||
def test_get_issuer_with_different_domain(self):
|
||||
settings = Oauth2Settings(
|
||||
provider="keycloak",
|
||||
domain="login.myapp.io",
|
||||
client_id="test-client",
|
||||
audience="test-audience",
|
||||
extra={
|
||||
"realm": "app-realm"
|
||||
}
|
||||
)
|
||||
provider = KeycloakProvider(settings)
|
||||
expected_issuer = "https://login.myapp.io/realms/app-realm"
|
||||
assert provider.get_issuer() == expected_issuer
|
||||
|
||||
def test_get_audience(self):
|
||||
assert self.provider.get_audience() == "test-audience"
|
||||
|
||||
def test_get_client_id(self):
|
||||
assert self.provider.get_client_id() == "test-client-id"
|
||||
|
||||
def test_get_required_fields(self):
|
||||
assert self.provider.get_required_fields() == ["realm"]
|
||||
|
||||
def test_oauth2_base_url(self):
|
||||
assert self.provider._oauth2_base_url() == "https://keycloak.example.com"
|
||||
|
||||
def test_oauth2_base_url_strips_https_prefix(self):
|
||||
settings = Oauth2Settings(
|
||||
provider="keycloak",
|
||||
domain="https://keycloak.example.com",
|
||||
client_id="test-client-id",
|
||||
audience="test-audience",
|
||||
extra={
|
||||
"realm": "test-realm"
|
||||
}
|
||||
)
|
||||
provider = KeycloakProvider(settings)
|
||||
assert provider._oauth2_base_url() == "https://keycloak.example.com"
|
||||
|
||||
def test_oauth2_base_url_strips_http_prefix(self):
|
||||
settings = Oauth2Settings(
|
||||
provider="keycloak",
|
||||
domain="http://keycloak.example.com",
|
||||
client_id="test-client-id",
|
||||
audience="test-audience",
|
||||
extra={
|
||||
"realm": "test-realm"
|
||||
}
|
||||
)
|
||||
provider = KeycloakProvider(settings)
|
||||
assert provider._oauth2_base_url() == "https://keycloak.example.com"
|
||||
@@ -1,257 +0,0 @@
|
||||
import pytest
|
||||
|
||||
from crewai_cli.authentication.main import Oauth2Settings
|
||||
from crewai_cli.authentication.providers.okta import OktaProvider
|
||||
|
||||
|
||||
class TestOktaProvider:
|
||||
@pytest.fixture(autouse=True)
|
||||
def setup_method(self):
|
||||
self.valid_settings = Oauth2Settings(
|
||||
provider="okta",
|
||||
domain="test-domain.okta.com",
|
||||
client_id="test-client-id",
|
||||
audience="test-audience",
|
||||
)
|
||||
self.provider = OktaProvider(self.valid_settings)
|
||||
|
||||
def test_initialization_with_valid_settings(self):
|
||||
provider = OktaProvider(self.valid_settings)
|
||||
assert provider.settings == self.valid_settings
|
||||
assert provider.settings.provider == "okta"
|
||||
assert provider.settings.domain == "test-domain.okta.com"
|
||||
assert provider.settings.client_id == "test-client-id"
|
||||
assert provider.settings.audience == "test-audience"
|
||||
|
||||
def test_get_authorize_url(self):
|
||||
expected_url = "https://test-domain.okta.com/oauth2/default/v1/device/authorize"
|
||||
assert self.provider.get_authorize_url() == expected_url
|
||||
|
||||
def test_get_authorize_url_with_different_domain(self):
|
||||
settings = Oauth2Settings(
|
||||
provider="okta",
|
||||
domain="my-company.okta.com",
|
||||
client_id="test-client",
|
||||
audience="test-audience",
|
||||
)
|
||||
provider = OktaProvider(settings)
|
||||
expected_url = "https://my-company.okta.com/oauth2/default/v1/device/authorize"
|
||||
assert provider.get_authorize_url() == expected_url
|
||||
|
||||
def test_get_authorize_url_with_custom_authorization_server_name(self):
|
||||
settings = Oauth2Settings(
|
||||
provider="okta",
|
||||
domain="test-domain.okta.com",
|
||||
client_id="test-client-id",
|
||||
audience=None,
|
||||
extra={
|
||||
"using_org_auth_server": False,
|
||||
"authorization_server_name": "my_auth_server_xxxAAA777"
|
||||
}
|
||||
)
|
||||
provider = OktaProvider(settings)
|
||||
expected_url = "https://test-domain.okta.com/oauth2/my_auth_server_xxxAAA777/v1/device/authorize"
|
||||
assert provider.get_authorize_url() == expected_url
|
||||
|
||||
def test_get_authorize_url_when_using_org_auth_server(self):
|
||||
settings = Oauth2Settings(
|
||||
provider="okta",
|
||||
domain="test-domain.okta.com",
|
||||
client_id="test-client-id",
|
||||
audience=None,
|
||||
extra={
|
||||
"using_org_auth_server": True,
|
||||
"authorization_server_name": None
|
||||
}
|
||||
)
|
||||
provider = OktaProvider(settings)
|
||||
expected_url = "https://test-domain.okta.com/oauth2/v1/device/authorize"
|
||||
assert provider.get_authorize_url() == expected_url
|
||||
|
||||
def test_get_token_url(self):
|
||||
expected_url = "https://test-domain.okta.com/oauth2/default/v1/token"
|
||||
assert self.provider.get_token_url() == expected_url
|
||||
|
||||
def test_get_token_url_with_different_domain(self):
|
||||
settings = Oauth2Settings(
|
||||
provider="okta",
|
||||
domain="another-domain.okta.com",
|
||||
client_id="test-client",
|
||||
audience="test-audience",
|
||||
)
|
||||
provider = OktaProvider(settings)
|
||||
expected_url = "https://another-domain.okta.com/oauth2/default/v1/token"
|
||||
assert provider.get_token_url() == expected_url
|
||||
|
||||
def test_get_token_url_with_custom_authorization_server_name(self):
|
||||
settings = Oauth2Settings(
|
||||
provider="okta",
|
||||
domain="test-domain.okta.com",
|
||||
client_id="test-client-id",
|
||||
audience=None,
|
||||
extra={
|
||||
"using_org_auth_server": False,
|
||||
"authorization_server_name": "my_auth_server_xxxAAA777"
|
||||
}
|
||||
)
|
||||
provider = OktaProvider(settings)
|
||||
expected_url = "https://test-domain.okta.com/oauth2/my_auth_server_xxxAAA777/v1/token"
|
||||
assert provider.get_token_url() == expected_url
|
||||
|
||||
def test_get_token_url_when_using_org_auth_server(self):
|
||||
settings = Oauth2Settings(
|
||||
provider="okta",
|
||||
domain="test-domain.okta.com",
|
||||
client_id="test-client-id",
|
||||
audience=None,
|
||||
extra={
|
||||
"using_org_auth_server": True,
|
||||
"authorization_server_name": None
|
||||
}
|
||||
)
|
||||
provider = OktaProvider(settings)
|
||||
expected_url = "https://test-domain.okta.com/oauth2/v1/token"
|
||||
assert provider.get_token_url() == expected_url
|
||||
|
||||
def test_get_jwks_url(self):
|
||||
expected_url = "https://test-domain.okta.com/oauth2/default/v1/keys"
|
||||
assert self.provider.get_jwks_url() == expected_url
|
||||
|
||||
def test_get_jwks_url_with_different_domain(self):
|
||||
settings = Oauth2Settings(
|
||||
provider="okta",
|
||||
domain="dev.okta.com",
|
||||
client_id="test-client",
|
||||
audience="test-audience",
|
||||
)
|
||||
provider = OktaProvider(settings)
|
||||
expected_url = "https://dev.okta.com/oauth2/default/v1/keys"
|
||||
assert provider.get_jwks_url() == expected_url
|
||||
|
||||
def test_get_jwks_url_with_custom_authorization_server_name(self):
|
||||
settings = Oauth2Settings(
|
||||
provider="okta",
|
||||
domain="test-domain.okta.com",
|
||||
client_id="test-client-id",
|
||||
audience=None,
|
||||
extra={
|
||||
"using_org_auth_server": False,
|
||||
"authorization_server_name": "my_auth_server_xxxAAA777"
|
||||
}
|
||||
)
|
||||
provider = OktaProvider(settings)
|
||||
expected_url = "https://test-domain.okta.com/oauth2/my_auth_server_xxxAAA777/v1/keys"
|
||||
assert provider.get_jwks_url() == expected_url
|
||||
|
||||
def test_get_jwks_url_when_using_org_auth_server(self):
|
||||
settings = Oauth2Settings(
|
||||
provider="okta",
|
||||
domain="test-domain.okta.com",
|
||||
client_id="test-client-id",
|
||||
audience=None,
|
||||
extra={
|
||||
"using_org_auth_server": True,
|
||||
"authorization_server_name": None
|
||||
}
|
||||
)
|
||||
provider = OktaProvider(settings)
|
||||
expected_url = "https://test-domain.okta.com/oauth2/v1/keys"
|
||||
assert provider.get_jwks_url() == expected_url
|
||||
|
||||
def test_get_issuer(self):
|
||||
expected_issuer = "https://test-domain.okta.com/oauth2/default"
|
||||
assert self.provider.get_issuer() == expected_issuer
|
||||
|
||||
def test_get_issuer_with_different_domain(self):
|
||||
settings = Oauth2Settings(
|
||||
provider="okta",
|
||||
domain="prod.okta.com",
|
||||
client_id="test-client",
|
||||
audience="test-audience",
|
||||
)
|
||||
provider = OktaProvider(settings)
|
||||
expected_issuer = "https://prod.okta.com/oauth2/default"
|
||||
assert provider.get_issuer() == expected_issuer
|
||||
|
||||
def test_get_issuer_with_custom_authorization_server_name(self):
|
||||
settings = Oauth2Settings(
|
||||
provider="okta",
|
||||
domain="test-domain.okta.com",
|
||||
client_id="test-client-id",
|
||||
audience=None,
|
||||
extra={
|
||||
"using_org_auth_server": False,
|
||||
"authorization_server_name": "my_auth_server_xxxAAA777"
|
||||
}
|
||||
)
|
||||
provider = OktaProvider(settings)
|
||||
expected_issuer = "https://test-domain.okta.com/oauth2/my_auth_server_xxxAAA777"
|
||||
assert provider.get_issuer() == expected_issuer
|
||||
|
||||
def test_get_issuer_when_using_org_auth_server(self):
|
||||
settings = Oauth2Settings(
|
||||
provider="okta",
|
||||
domain="test-domain.okta.com",
|
||||
client_id="test-client-id",
|
||||
audience=None,
|
||||
extra={
|
||||
"using_org_auth_server": True,
|
||||
"authorization_server_name": None
|
||||
}
|
||||
)
|
||||
provider = OktaProvider(settings)
|
||||
expected_issuer = "https://test-domain.okta.com"
|
||||
assert provider.get_issuer() == expected_issuer
|
||||
|
||||
def test_get_audience(self):
|
||||
assert self.provider.get_audience() == "test-audience"
|
||||
|
||||
def test_get_audience_assertion_error_when_none(self):
|
||||
settings = Oauth2Settings(
|
||||
provider="okta",
|
||||
domain="test-domain.okta.com",
|
||||
client_id="test-client-id",
|
||||
audience=None,
|
||||
)
|
||||
provider = OktaProvider(settings)
|
||||
|
||||
with pytest.raises(ValueError, match="Audience is required"):
|
||||
provider.get_audience()
|
||||
|
||||
def test_get_client_id(self):
|
||||
assert self.provider.get_client_id() == "test-client-id"
|
||||
|
||||
def test_get_required_fields(self):
|
||||
assert set(self.provider.get_required_fields()) == set(["authorization_server_name", "using_org_auth_server"])
|
||||
|
||||
def test_oauth2_base_url(self):
|
||||
assert self.provider._oauth2_base_url() == "https://test-domain.okta.com/oauth2/default"
|
||||
|
||||
def test_oauth2_base_url_with_custom_authorization_server_name(self):
|
||||
settings = Oauth2Settings(
|
||||
provider="okta",
|
||||
domain="test-domain.okta.com",
|
||||
client_id="test-client-id",
|
||||
audience=None,
|
||||
extra={
|
||||
"using_org_auth_server": False,
|
||||
"authorization_server_name": "my_auth_server_xxxAAA777"
|
||||
}
|
||||
)
|
||||
|
||||
provider = OktaProvider(settings)
|
||||
assert provider._oauth2_base_url() == "https://test-domain.okta.com/oauth2/my_auth_server_xxxAAA777"
|
||||
|
||||
def test_oauth2_base_url_when_using_org_auth_server(self):
|
||||
settings = Oauth2Settings(
|
||||
provider="okta",
|
||||
domain="test-domain.okta.com",
|
||||
client_id="test-client-id",
|
||||
audience=None,
|
||||
extra={
|
||||
"using_org_auth_server": True,
|
||||
"authorization_server_name": None
|
||||
}
|
||||
)
|
||||
provider = OktaProvider(settings)
|
||||
assert provider._oauth2_base_url() == "https://test-domain.okta.com/oauth2"
|
||||
@@ -1,100 +0,0 @@
|
||||
import pytest
|
||||
from crewai_cli.authentication.main import Oauth2Settings
|
||||
from crewai_cli.authentication.providers.workos import WorkosProvider
|
||||
|
||||
|
||||
class TestWorkosProvider:
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
def setup_method(self):
|
||||
self.valid_settings = Oauth2Settings(
|
||||
provider="workos",
|
||||
domain="login.company.com",
|
||||
client_id="test-client-id",
|
||||
audience="test-audience"
|
||||
)
|
||||
self.provider = WorkosProvider(self.valid_settings)
|
||||
|
||||
def test_initialization_with_valid_settings(self):
|
||||
provider = WorkosProvider(self.valid_settings)
|
||||
assert provider.settings == self.valid_settings
|
||||
assert provider.settings.provider == "workos"
|
||||
assert provider.settings.domain == "login.company.com"
|
||||
assert provider.settings.client_id == "test-client-id"
|
||||
assert provider.settings.audience == "test-audience"
|
||||
|
||||
def test_get_authorize_url(self):
|
||||
expected_url = "https://login.company.com/oauth2/device_authorization"
|
||||
assert self.provider.get_authorize_url() == expected_url
|
||||
|
||||
def test_get_authorize_url_with_different_domain(self):
|
||||
settings = Oauth2Settings(
|
||||
provider="workos",
|
||||
domain="login.example.com",
|
||||
client_id="test-client",
|
||||
audience="test-audience"
|
||||
)
|
||||
provider = WorkosProvider(settings)
|
||||
expected_url = "https://login.example.com/oauth2/device_authorization"
|
||||
assert provider.get_authorize_url() == expected_url
|
||||
|
||||
def test_get_token_url(self):
|
||||
expected_url = "https://login.company.com/oauth2/token"
|
||||
assert self.provider.get_token_url() == expected_url
|
||||
|
||||
def test_get_token_url_with_different_domain(self):
|
||||
settings = Oauth2Settings(
|
||||
provider="workos",
|
||||
domain="api.workos.com",
|
||||
client_id="test-client",
|
||||
audience="test-audience"
|
||||
)
|
||||
provider = WorkosProvider(settings)
|
||||
expected_url = "https://api.workos.com/oauth2/token"
|
||||
assert provider.get_token_url() == expected_url
|
||||
|
||||
def test_get_jwks_url(self):
|
||||
expected_url = "https://login.company.com/oauth2/jwks"
|
||||
assert self.provider.get_jwks_url() == expected_url
|
||||
|
||||
def test_get_jwks_url_with_different_domain(self):
|
||||
settings = Oauth2Settings(
|
||||
provider="workos",
|
||||
domain="auth.enterprise.com",
|
||||
client_id="test-client",
|
||||
audience="test-audience"
|
||||
)
|
||||
provider = WorkosProvider(settings)
|
||||
expected_url = "https://auth.enterprise.com/oauth2/jwks"
|
||||
assert provider.get_jwks_url() == expected_url
|
||||
|
||||
def test_get_issuer(self):
|
||||
expected_issuer = "https://login.company.com"
|
||||
assert self.provider.get_issuer() == expected_issuer
|
||||
|
||||
def test_get_issuer_with_different_domain(self):
|
||||
settings = Oauth2Settings(
|
||||
provider="workos",
|
||||
domain="sso.company.com",
|
||||
client_id="test-client",
|
||||
audience="test-audience"
|
||||
)
|
||||
provider = WorkosProvider(settings)
|
||||
expected_issuer = "https://sso.company.com"
|
||||
assert provider.get_issuer() == expected_issuer
|
||||
|
||||
def test_get_audience(self):
|
||||
assert self.provider.get_audience() == "test-audience"
|
||||
|
||||
def test_get_audience_fallback_to_default(self):
|
||||
settings = Oauth2Settings(
|
||||
provider="workos",
|
||||
domain="login.company.com",
|
||||
client_id="test-client-id",
|
||||
audience=None
|
||||
)
|
||||
provider = WorkosProvider(settings)
|
||||
assert provider.get_audience() == ""
|
||||
|
||||
def test_get_client_id(self):
|
||||
assert self.provider.get_client_id() == "test-client-id"
|
||||
@@ -1,348 +0,0 @@
|
||||
from datetime import datetime, timedelta
|
||||
from unittest.mock import MagicMock, call, patch
|
||||
|
||||
import pytest
|
||||
import httpx
|
||||
from crewai_cli.authentication.main import AuthenticationCommand
|
||||
from crewai_cli.constants import (
|
||||
CREWAI_ENTERPRISE_DEFAULT_OAUTH2_AUDIENCE,
|
||||
CREWAI_ENTERPRISE_DEFAULT_OAUTH2_CLIENT_ID,
|
||||
CREWAI_ENTERPRISE_DEFAULT_OAUTH2_DOMAIN,
|
||||
)
|
||||
|
||||
|
||||
class TestAuthenticationCommand:
|
||||
def setup_method(self):
|
||||
# Mock Settings so we always use default constants regardless of local config.
|
||||
with patch("crewai_cli.authentication.main.Settings") as mock_settings:
|
||||
instance = mock_settings.return_value
|
||||
instance.oauth2_provider = "workos"
|
||||
instance.oauth2_domain = CREWAI_ENTERPRISE_DEFAULT_OAUTH2_DOMAIN
|
||||
instance.oauth2_client_id = CREWAI_ENTERPRISE_DEFAULT_OAUTH2_CLIENT_ID
|
||||
instance.oauth2_audience = CREWAI_ENTERPRISE_DEFAULT_OAUTH2_AUDIENCE
|
||||
instance.oauth2_extra = {}
|
||||
self.auth_command = AuthenticationCommand()
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"user_provider,expected_urls",
|
||||
[
|
||||
(
|
||||
"workos",
|
||||
{
|
||||
"device_code_url": f"https://{CREWAI_ENTERPRISE_DEFAULT_OAUTH2_DOMAIN}/oauth2/device_authorization",
|
||||
"token_url": f"https://{CREWAI_ENTERPRISE_DEFAULT_OAUTH2_DOMAIN}/oauth2/token",
|
||||
"client_id": CREWAI_ENTERPRISE_DEFAULT_OAUTH2_CLIENT_ID,
|
||||
"audience": CREWAI_ENTERPRISE_DEFAULT_OAUTH2_AUDIENCE,
|
||||
"domain": CREWAI_ENTERPRISE_DEFAULT_OAUTH2_DOMAIN,
|
||||
},
|
||||
),
|
||||
],
|
||||
)
|
||||
@patch("crewai_cli.authentication.main.AuthenticationCommand._get_device_code")
|
||||
@patch(
|
||||
"crewai_cli.authentication.main.AuthenticationCommand._display_auth_instructions"
|
||||
)
|
||||
@patch("crewai_cli.authentication.main.AuthenticationCommand._poll_for_token")
|
||||
@patch("crewai_cli.authentication.main.console.print")
|
||||
def test_login(
|
||||
self,
|
||||
mock_console_print,
|
||||
mock_poll,
|
||||
mock_display,
|
||||
mock_get_device,
|
||||
user_provider,
|
||||
expected_urls,
|
||||
):
|
||||
mock_get_device.return_value = {
|
||||
"device_code": "test_code",
|
||||
"user_code": "123456",
|
||||
}
|
||||
|
||||
self.auth_command.login()
|
||||
|
||||
mock_console_print.assert_called_once_with(
|
||||
"Signing in to CrewAI AMP...\n", style="bold blue"
|
||||
)
|
||||
mock_get_device.assert_called_once()
|
||||
mock_display.assert_called_once_with(
|
||||
{"device_code": "test_code", "user_code": "123456"}
|
||||
)
|
||||
mock_poll.assert_called_once_with(
|
||||
{"device_code": "test_code", "user_code": "123456"},
|
||||
)
|
||||
assert (
|
||||
self.auth_command.oauth2_provider.get_client_id()
|
||||
== expected_urls["client_id"]
|
||||
)
|
||||
assert (
|
||||
self.auth_command.oauth2_provider.get_audience()
|
||||
== expected_urls["audience"]
|
||||
)
|
||||
assert (
|
||||
self.auth_command.oauth2_provider._get_domain() == expected_urls["domain"]
|
||||
)
|
||||
|
||||
@patch("crewai_cli.authentication.main.webbrowser")
|
||||
@patch("crewai_cli.authentication.main.console.print")
|
||||
def test_display_auth_instructions(self, mock_console_print, mock_webbrowser):
|
||||
device_code_data = {
|
||||
"verification_uri_complete": "https://example.com/auth",
|
||||
"user_code": "123456",
|
||||
}
|
||||
|
||||
self.auth_command._display_auth_instructions(device_code_data)
|
||||
|
||||
expected_calls = [
|
||||
call("1. Navigate to: ", "https://example.com/auth"),
|
||||
call("2. Enter the following code: ", "123456"),
|
||||
]
|
||||
mock_console_print.assert_has_calls(expected_calls)
|
||||
mock_webbrowser.open.assert_called_once_with("https://example.com/auth")
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"user_provider,jwt_config",
|
||||
[
|
||||
(
|
||||
"workos",
|
||||
{
|
||||
"jwks_url": f"https://{CREWAI_ENTERPRISE_DEFAULT_OAUTH2_DOMAIN}/oauth2/jwks",
|
||||
"issuer": f"https://{CREWAI_ENTERPRISE_DEFAULT_OAUTH2_DOMAIN}",
|
||||
"audience": CREWAI_ENTERPRISE_DEFAULT_OAUTH2_AUDIENCE,
|
||||
},
|
||||
),
|
||||
],
|
||||
)
|
||||
@pytest.mark.parametrize("has_expiration", [True, False])
|
||||
@patch("crewai_cli.authentication.main.validate_jwt_token")
|
||||
@patch("crewai_cli.authentication.main.TokenManager.save_tokens")
|
||||
def test_validate_and_save_token(
|
||||
self,
|
||||
mock_save_tokens,
|
||||
mock_validate_jwt,
|
||||
user_provider,
|
||||
jwt_config,
|
||||
has_expiration,
|
||||
):
|
||||
from crewai_cli.authentication.main import Oauth2Settings
|
||||
from crewai_cli.authentication.providers.workos import WorkosProvider
|
||||
|
||||
if user_provider == "workos":
|
||||
self.auth_command.oauth2_provider = WorkosProvider(
|
||||
settings=Oauth2Settings(
|
||||
provider=user_provider,
|
||||
client_id="test-client-id",
|
||||
domain=CREWAI_ENTERPRISE_DEFAULT_OAUTH2_DOMAIN,
|
||||
audience=jwt_config["audience"],
|
||||
)
|
||||
)
|
||||
|
||||
token_data = {"access_token": "test_access_token", "id_token": "test_id_token"}
|
||||
|
||||
if has_expiration:
|
||||
future_timestamp = int((datetime.now() + timedelta(days=100)).timestamp())
|
||||
decoded_token = {"exp": future_timestamp}
|
||||
else:
|
||||
decoded_token = {}
|
||||
|
||||
mock_validate_jwt.return_value = decoded_token
|
||||
|
||||
self.auth_command._validate_and_save_token(token_data)
|
||||
|
||||
mock_validate_jwt.assert_called_once_with(
|
||||
jwt_token="test_access_token",
|
||||
jwks_url=jwt_config["jwks_url"],
|
||||
issuer=jwt_config["issuer"],
|
||||
audience=jwt_config["audience"],
|
||||
)
|
||||
|
||||
if has_expiration:
|
||||
mock_save_tokens.assert_called_once_with(
|
||||
"test_access_token", future_timestamp
|
||||
)
|
||||
else:
|
||||
mock_save_tokens.assert_called_once_with("test_access_token", 0)
|
||||
|
||||
@patch("crewai_cli.tools.main.ToolCommand")
|
||||
@patch("crewai_cli.authentication.main.Settings")
|
||||
@patch("crewai_cli.authentication.main.console.print")
|
||||
def test_login_to_tool_repository_success(
|
||||
self, mock_console_print, mock_settings, mock_tool_command
|
||||
):
|
||||
mock_tool_instance = MagicMock()
|
||||
mock_tool_command.return_value = mock_tool_instance
|
||||
|
||||
mock_settings_instance = MagicMock()
|
||||
mock_settings_instance.org_name = "Test Org"
|
||||
mock_settings_instance.org_uuid = "test-uuid-123"
|
||||
mock_settings.return_value = mock_settings_instance
|
||||
|
||||
self.auth_command._login_to_tool_repository()
|
||||
|
||||
mock_tool_command.assert_called_once()
|
||||
mock_tool_instance.login.assert_called_once()
|
||||
|
||||
expected_calls = [
|
||||
call(
|
||||
"Now logging you in to the Tool Repository... ",
|
||||
style="bold blue",
|
||||
end="",
|
||||
),
|
||||
call("Success!\n", style="bold green"),
|
||||
call(
|
||||
"You are now authenticated to the tool repository for organization [bold cyan]'Test Org'[/bold cyan]",
|
||||
style="green",
|
||||
),
|
||||
]
|
||||
mock_console_print.assert_has_calls(expected_calls)
|
||||
|
||||
@patch("crewai_cli.tools.main.ToolCommand")
|
||||
@patch("crewai_cli.authentication.main.console.print")
|
||||
def test_login_to_tool_repository_error(
|
||||
self, mock_console_print, mock_tool_command
|
||||
):
|
||||
mock_tool_instance = MagicMock()
|
||||
mock_tool_instance.login.side_effect = Exception("Tool repository error")
|
||||
mock_tool_command.return_value = mock_tool_instance
|
||||
|
||||
self.auth_command._login_to_tool_repository()
|
||||
|
||||
mock_tool_command.assert_called_once()
|
||||
mock_tool_instance.login.assert_called_once()
|
||||
|
||||
expected_calls = [
|
||||
call(
|
||||
"Now logging you in to the Tool Repository... ",
|
||||
style="bold blue",
|
||||
end="",
|
||||
),
|
||||
call(
|
||||
"\n[bold yellow]Warning:[/bold yellow] Authentication with the Tool Repository failed.",
|
||||
style="yellow",
|
||||
),
|
||||
call(
|
||||
"Other features will work normally, but you may experience limitations with downloading and publishing tools.\nRun [bold]crewai login[/bold] to try logging in again.\n",
|
||||
style="yellow",
|
||||
),
|
||||
]
|
||||
mock_console_print.assert_has_calls(expected_calls)
|
||||
|
||||
@patch("crewai_cli.authentication.main.httpx.post")
|
||||
def test_get_device_code(self, mock_post):
|
||||
mock_response = MagicMock()
|
||||
mock_response.json.return_value = {
|
||||
"device_code": "test_device_code",
|
||||
"user_code": "123456",
|
||||
"verification_uri_complete": "https://example.com/auth",
|
||||
}
|
||||
mock_post.return_value = mock_response
|
||||
|
||||
self.auth_command.oauth2_provider = MagicMock()
|
||||
self.auth_command.oauth2_provider.get_client_id.return_value = "test_client"
|
||||
self.auth_command.oauth2_provider.get_authorize_url.return_value = (
|
||||
"https://example.com/device"
|
||||
)
|
||||
self.auth_command.oauth2_provider.get_audience.return_value = "test_audience"
|
||||
self.auth_command.oauth2_provider.get_oauth_scopes.return_value = ["openid", "profile", "email"]
|
||||
|
||||
result = self.auth_command._get_device_code()
|
||||
|
||||
mock_post.assert_called_once_with(
|
||||
url="https://example.com/device",
|
||||
data={
|
||||
"client_id": "test_client",
|
||||
"scope": "openid profile email",
|
||||
"audience": "test_audience",
|
||||
},
|
||||
timeout=20,
|
||||
)
|
||||
|
||||
assert result == {
|
||||
"device_code": "test_device_code",
|
||||
"user_code": "123456",
|
||||
"verification_uri_complete": "https://example.com/auth",
|
||||
}
|
||||
|
||||
@patch("crewai_cli.authentication.main.httpx.post")
|
||||
@patch("crewai_cli.authentication.main.console.print")
|
||||
def test_poll_for_token_success(self, mock_console_print, mock_post):
|
||||
mock_response_success = MagicMock()
|
||||
mock_response_success.status_code = 200
|
||||
mock_response_success.json.return_value = {
|
||||
"access_token": "test_access_token",
|
||||
"id_token": "test_id_token",
|
||||
}
|
||||
mock_post.return_value = mock_response_success
|
||||
|
||||
device_code_data = {"device_code": "test_device_code", "interval": 1}
|
||||
|
||||
with (
|
||||
patch.object(
|
||||
self.auth_command, "_validate_and_save_token"
|
||||
) as mock_validate,
|
||||
patch.object(
|
||||
self.auth_command, "_login_to_tool_repository"
|
||||
) as mock_tool_login,
|
||||
):
|
||||
self.auth_command.oauth2_provider = MagicMock()
|
||||
self.auth_command.oauth2_provider.get_token_url.return_value = (
|
||||
"https://example.com/token"
|
||||
)
|
||||
self.auth_command.oauth2_provider.get_client_id.return_value = "test_client"
|
||||
|
||||
self.auth_command._poll_for_token(device_code_data)
|
||||
|
||||
mock_post.assert_called_once_with(
|
||||
"https://example.com/token",
|
||||
data={
|
||||
"grant_type": "urn:ietf:params:oauth:grant-type:device_code",
|
||||
"device_code": "test_device_code",
|
||||
"client_id": "test_client",
|
||||
},
|
||||
timeout=30,
|
||||
)
|
||||
|
||||
mock_validate.assert_called_once()
|
||||
mock_tool_login.assert_called_once()
|
||||
|
||||
expected_calls = [
|
||||
call("\nWaiting for authentication... ", style="bold blue", end=""),
|
||||
call("Success!", style="bold green"),
|
||||
call("\n[bold green]Welcome to CrewAI AMP![/bold green]\n"),
|
||||
]
|
||||
mock_console_print.assert_has_calls(expected_calls)
|
||||
|
||||
@patch("crewai_cli.authentication.main.httpx.post")
|
||||
@patch("crewai_cli.authentication.main.console.print")
|
||||
def test_poll_for_token_timeout(self, mock_console_print, mock_post):
|
||||
mock_response_pending = MagicMock()
|
||||
mock_response_pending.status_code = 400
|
||||
mock_response_pending.json.return_value = {"error": "authorization_pending"}
|
||||
mock_post.return_value = mock_response_pending
|
||||
|
||||
device_code_data = {
|
||||
"device_code": "test_device_code",
|
||||
"interval": 0.1, # Short interval for testing
|
||||
}
|
||||
|
||||
self.auth_command._poll_for_token(device_code_data)
|
||||
|
||||
mock_console_print.assert_any_call(
|
||||
"Timeout: Failed to get the token. Please try again.", style="bold red"
|
||||
)
|
||||
|
||||
@patch("crewai_cli.authentication.main.httpx.post")
|
||||
def test_poll_for_token_error(self, mock_post):
|
||||
"""Test the method to poll for token (error path)."""
|
||||
# Setup mock to return error
|
||||
mock_response_error = MagicMock()
|
||||
mock_response_error.status_code = 400
|
||||
mock_response_error.json.return_value = {
|
||||
"error": "access_denied",
|
||||
"error_description": "User denied access",
|
||||
}
|
||||
mock_post.return_value = mock_response_error
|
||||
|
||||
device_code_data = {"device_code": "test_device_code", "interval": 1}
|
||||
|
||||
with pytest.raises(httpx.HTTPError):
|
||||
self.auth_command._poll_for_token(device_code_data)
|
||||
@@ -1,107 +0,0 @@
|
||||
import unittest
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
import jwt
|
||||
|
||||
from crewai_cli.authentication.utils import validate_jwt_token
|
||||
|
||||
|
||||
@patch("crewai_cli.authentication.utils.PyJWKClient", return_value=MagicMock())
|
||||
@patch("crewai_cli.authentication.utils.jwt")
|
||||
class TestUtils(unittest.TestCase):
|
||||
def test_validate_jwt_token(self, mock_jwt, mock_pyjwkclient):
|
||||
mock_jwt.decode.return_value = {"exp": 1719859200}
|
||||
|
||||
# Create signing key object mock with a .key attribute
|
||||
mock_pyjwkclient.return_value.get_signing_key_from_jwt.return_value = MagicMock(
|
||||
key="mock_signing_key"
|
||||
)
|
||||
|
||||
jwt_token = "aaaaa.bbbbbb.cccccc" # noqa: S105
|
||||
|
||||
decoded_token = validate_jwt_token(
|
||||
jwt_token=jwt_token,
|
||||
jwks_url="https://mock_jwks_url",
|
||||
issuer="https://mock_issuer",
|
||||
audience="app_id_xxxx",
|
||||
)
|
||||
|
||||
mock_jwt.decode.assert_called_with(
|
||||
jwt_token,
|
||||
"mock_signing_key",
|
||||
algorithms=["RS256"],
|
||||
audience="app_id_xxxx",
|
||||
issuer="https://mock_issuer",
|
||||
leeway=10.0,
|
||||
options={
|
||||
"verify_signature": True,
|
||||
"verify_exp": True,
|
||||
"verify_nbf": True,
|
||||
"verify_iat": True,
|
||||
"require": ["exp", "iat", "iss", "aud", "sub"],
|
||||
},
|
||||
)
|
||||
mock_pyjwkclient.assert_called_once_with("https://mock_jwks_url")
|
||||
self.assertEqual(decoded_token, {"exp": 1719859200})
|
||||
|
||||
def test_validate_jwt_token_expired(self, mock_jwt, mock_pyjwkclient):
|
||||
mock_jwt.decode.side_effect = jwt.ExpiredSignatureError
|
||||
with self.assertRaises(Exception): # noqa: B017
|
||||
validate_jwt_token(
|
||||
jwt_token="aaaaa.bbbbbb.cccccc", # noqa: S106
|
||||
jwks_url="https://mock_jwks_url",
|
||||
issuer="https://mock_issuer",
|
||||
audience="app_id_xxxx",
|
||||
)
|
||||
|
||||
def test_validate_jwt_token_invalid_audience(self, mock_jwt, mock_pyjwkclient):
|
||||
mock_jwt.decode.side_effect = jwt.InvalidAudienceError
|
||||
with self.assertRaises(Exception): # noqa: B017
|
||||
validate_jwt_token(
|
||||
jwt_token="aaaaa.bbbbbb.cccccc", # noqa: S106
|
||||
jwks_url="https://mock_jwks_url",
|
||||
issuer="https://mock_issuer",
|
||||
audience="app_id_xxxx",
|
||||
)
|
||||
|
||||
def test_validate_jwt_token_invalid_issuer(self, mock_jwt, mock_pyjwkclient):
|
||||
mock_jwt.decode.side_effect = jwt.InvalidIssuerError
|
||||
with self.assertRaises(Exception): # noqa: B017
|
||||
validate_jwt_token(
|
||||
jwt_token="aaaaa.bbbbbb.cccccc", # noqa: S106
|
||||
jwks_url="https://mock_jwks_url",
|
||||
issuer="https://mock_issuer",
|
||||
audience="app_id_xxxx",
|
||||
)
|
||||
|
||||
def test_validate_jwt_token_missing_required_claims(
|
||||
self, mock_jwt, mock_pyjwkclient
|
||||
):
|
||||
mock_jwt.decode.side_effect = jwt.MissingRequiredClaimError
|
||||
with self.assertRaises(Exception): # noqa: B017
|
||||
validate_jwt_token(
|
||||
jwt_token="aaaaa.bbbbbb.cccccc", # noqa: S106
|
||||
jwks_url="https://mock_jwks_url",
|
||||
issuer="https://mock_issuer",
|
||||
audience="app_id_xxxx",
|
||||
)
|
||||
|
||||
def test_validate_jwt_token_jwks_error(self, mock_jwt, mock_pyjwkclient):
|
||||
mock_jwt.decode.side_effect = jwt.exceptions.PyJWKClientError
|
||||
with self.assertRaises(Exception): # noqa: B017
|
||||
validate_jwt_token(
|
||||
jwt_token="aaaaa.bbbbbb.cccccc", # noqa: S106
|
||||
jwks_url="https://mock_jwks_url",
|
||||
issuer="https://mock_issuer",
|
||||
audience="app_id_xxxx",
|
||||
)
|
||||
|
||||
def test_validate_jwt_token_invalid_token(self, mock_jwt, mock_pyjwkclient):
|
||||
mock_jwt.decode.side_effect = jwt.InvalidTokenError
|
||||
with self.assertRaises(Exception): # noqa: B017
|
||||
validate_jwt_token(
|
||||
jwt_token="aaaaa.bbbbbb.cccccc", # noqa: S106
|
||||
jwks_url="https://mock_jwks_url",
|
||||
issuer="https://mock_issuer",
|
||||
audience="app_id_xxxx",
|
||||
)
|
||||
@@ -1,255 +0,0 @@
|
||||
from pathlib import Path
|
||||
from unittest import mock
|
||||
|
||||
import pytest
|
||||
from click.testing import CliRunner
|
||||
from crewai_cli.cli import (
|
||||
deploy_create,
|
||||
deploy_list,
|
||||
deploy_logs,
|
||||
deploy_push,
|
||||
deploy_remove,
|
||||
deply_status,
|
||||
flow_add_crew,
|
||||
login,
|
||||
reset_memories,
|
||||
test,
|
||||
train,
|
||||
version,
|
||||
)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def runner():
|
||||
return CliRunner()
|
||||
|
||||
|
||||
@mock.patch("crewai_cli.cli.train_crew")
|
||||
def test_train_default_iterations(train_crew, runner):
|
||||
result = runner.invoke(train)
|
||||
|
||||
train_crew.assert_called_once_with(5, "trained_agents_data.pkl")
|
||||
assert result.exit_code == 0
|
||||
assert "Training the Crew for 5 iterations" in result.output
|
||||
|
||||
|
||||
@mock.patch("crewai_cli.cli.train_crew")
|
||||
def test_train_custom_iterations(train_crew, runner):
|
||||
result = runner.invoke(train, ["--n_iterations", "10"])
|
||||
|
||||
train_crew.assert_called_once_with(10, "trained_agents_data.pkl")
|
||||
assert result.exit_code == 0
|
||||
assert "Training the Crew for 10 iterations" in result.output
|
||||
|
||||
|
||||
@mock.patch("crewai_cli.cli.train_crew")
|
||||
def test_train_invalid_string_iterations(train_crew, runner):
|
||||
result = runner.invoke(train, ["--n_iterations", "invalid"])
|
||||
|
||||
train_crew.assert_not_called()
|
||||
assert result.exit_code == 2
|
||||
assert (
|
||||
"Usage: train [OPTIONS]\nTry 'train --help' for help.\n\nError: Invalid value for '-n' / '--n_iterations': 'invalid' is not a valid integer.\n"
|
||||
in result.output
|
||||
)
|
||||
|
||||
|
||||
def test_reset_no_memory_flags(runner):
|
||||
result = runner.invoke(
|
||||
reset_memories,
|
||||
)
|
||||
assert (
|
||||
result.output
|
||||
== "Please specify at least one memory type to reset using the appropriate flags.\n"
|
||||
)
|
||||
|
||||
|
||||
def test_version_flag(runner):
|
||||
result = runner.invoke(version)
|
||||
|
||||
assert result.exit_code == 0
|
||||
assert "crewai version:" in result.output
|
||||
|
||||
|
||||
def test_version_command(runner):
|
||||
result = runner.invoke(version)
|
||||
|
||||
assert result.exit_code == 0
|
||||
assert "crewai version:" in result.output
|
||||
|
||||
|
||||
def test_version_command_with_tools(runner):
|
||||
result = runner.invoke(version, ["--tools"])
|
||||
|
||||
assert result.exit_code == 0
|
||||
assert "crewai version:" in result.output
|
||||
assert (
|
||||
"crewai tools version:" in result.output
|
||||
or "crewai tools not installed" in result.output
|
||||
)
|
||||
|
||||
|
||||
@mock.patch("crewai_cli.cli.evaluate_crew")
|
||||
def test_test_default_iterations(evaluate_crew, runner):
|
||||
result = runner.invoke(test)
|
||||
|
||||
evaluate_crew.assert_called_once_with(3, "gpt-4o-mini")
|
||||
assert result.exit_code == 0
|
||||
assert "Testing the crew for 3 iterations with model gpt-4o-mini" in result.output
|
||||
|
||||
|
||||
@mock.patch("crewai_cli.cli.evaluate_crew")
|
||||
def test_test_custom_iterations(evaluate_crew, runner):
|
||||
result = runner.invoke(test, ["--n_iterations", "5", "--model", "gpt-4o"])
|
||||
|
||||
evaluate_crew.assert_called_once_with(5, "gpt-4o")
|
||||
assert result.exit_code == 0
|
||||
assert "Testing the crew for 5 iterations with model gpt-4o" in result.output
|
||||
|
||||
|
||||
@mock.patch("crewai_cli.cli.evaluate_crew")
|
||||
def test_test_invalid_string_iterations(evaluate_crew, runner):
|
||||
result = runner.invoke(test, ["--n_iterations", "invalid"])
|
||||
|
||||
evaluate_crew.assert_not_called()
|
||||
assert result.exit_code == 2
|
||||
assert (
|
||||
"Usage: test [OPTIONS]\nTry 'test --help' for help.\n\nError: Invalid value for '-n' / '--n_iterations': 'invalid' is not a valid integer.\n"
|
||||
in result.output
|
||||
)
|
||||
|
||||
|
||||
@mock.patch("crewai_cli.cli.AuthenticationCommand")
|
||||
def test_login(command, runner):
|
||||
mock_auth = command.return_value
|
||||
result = runner.invoke(login)
|
||||
|
||||
assert result.exit_code == 0
|
||||
mock_auth.login.assert_called_once()
|
||||
|
||||
|
||||
@mock.patch("crewai_cli.cli.DeployCommand")
|
||||
def test_deploy_create(command, runner):
|
||||
mock_deploy = command.return_value
|
||||
result = runner.invoke(deploy_create)
|
||||
|
||||
assert result.exit_code == 0
|
||||
mock_deploy.create_crew.assert_called_once()
|
||||
|
||||
|
||||
@mock.patch("crewai_cli.cli.DeployCommand")
|
||||
def test_deploy_list(command, runner):
|
||||
mock_deploy = command.return_value
|
||||
result = runner.invoke(deploy_list)
|
||||
|
||||
assert result.exit_code == 0
|
||||
mock_deploy.list_crews.assert_called_once()
|
||||
|
||||
|
||||
@mock.patch("crewai_cli.cli.DeployCommand")
|
||||
def test_deploy_push(command, runner):
|
||||
mock_deploy = command.return_value
|
||||
uuid = "test-uuid"
|
||||
result = runner.invoke(deploy_push, ["-u", uuid])
|
||||
|
||||
assert result.exit_code == 0
|
||||
mock_deploy.deploy.assert_called_once_with(uuid=uuid)
|
||||
|
||||
|
||||
@mock.patch("crewai_cli.cli.DeployCommand")
|
||||
def test_deploy_push_no_uuid(command, runner):
|
||||
mock_deploy = command.return_value
|
||||
result = runner.invoke(deploy_push)
|
||||
|
||||
assert result.exit_code == 0
|
||||
mock_deploy.deploy.assert_called_once_with(uuid=None)
|
||||
|
||||
|
||||
@mock.patch("crewai_cli.cli.DeployCommand")
|
||||
def test_deploy_status(command, runner):
|
||||
mock_deploy = command.return_value
|
||||
uuid = "test-uuid"
|
||||
result = runner.invoke(deply_status, ["-u", uuid])
|
||||
|
||||
assert result.exit_code == 0
|
||||
mock_deploy.get_crew_status.assert_called_once_with(uuid=uuid)
|
||||
|
||||
|
||||
@mock.patch("crewai_cli.cli.DeployCommand")
|
||||
def test_deploy_status_no_uuid(command, runner):
|
||||
mock_deploy = command.return_value
|
||||
result = runner.invoke(deply_status)
|
||||
|
||||
assert result.exit_code == 0
|
||||
mock_deploy.get_crew_status.assert_called_once_with(uuid=None)
|
||||
|
||||
|
||||
@mock.patch("crewai_cli.cli.DeployCommand")
|
||||
def test_deploy_logs(command, runner):
|
||||
mock_deploy = command.return_value
|
||||
uuid = "test-uuid"
|
||||
result = runner.invoke(deploy_logs, ["-u", uuid])
|
||||
|
||||
assert result.exit_code == 0
|
||||
mock_deploy.get_crew_logs.assert_called_once_with(uuid=uuid)
|
||||
|
||||
|
||||
@mock.patch("crewai_cli.cli.DeployCommand")
|
||||
def test_deploy_logs_no_uuid(command, runner):
|
||||
mock_deploy = command.return_value
|
||||
result = runner.invoke(deploy_logs)
|
||||
|
||||
assert result.exit_code == 0
|
||||
mock_deploy.get_crew_logs.assert_called_once_with(uuid=None)
|
||||
|
||||
|
||||
@mock.patch("crewai_cli.cli.DeployCommand")
|
||||
def test_deploy_remove(command, runner):
|
||||
mock_deploy = command.return_value
|
||||
uuid = "test-uuid"
|
||||
result = runner.invoke(deploy_remove, ["-u", uuid])
|
||||
|
||||
assert result.exit_code == 0
|
||||
mock_deploy.remove_crew.assert_called_once_with(uuid=uuid)
|
||||
|
||||
|
||||
@mock.patch("crewai_cli.cli.DeployCommand")
|
||||
def test_deploy_remove_no_uuid(command, runner):
|
||||
mock_deploy = command.return_value
|
||||
result = runner.invoke(deploy_remove)
|
||||
|
||||
assert result.exit_code == 0
|
||||
mock_deploy.remove_crew.assert_called_once_with(uuid=None)
|
||||
|
||||
|
||||
@mock.patch("crewai_cli.add_crew_to_flow.create_embedded_crew")
|
||||
@mock.patch("pathlib.Path.exists", return_value=True)
|
||||
def test_flow_add_crew(mock_path_exists, mock_create_embedded_crew, runner):
|
||||
crew_name = "new_crew"
|
||||
result = runner.invoke(flow_add_crew, [crew_name])
|
||||
|
||||
assert result.exit_code == 0, f"Command failed with output: {result.output}"
|
||||
assert f"Adding crew {crew_name} to the flow" in result.output
|
||||
|
||||
mock_create_embedded_crew.assert_called_once()
|
||||
call_args, call_kwargs = mock_create_embedded_crew.call_args
|
||||
assert call_args[0] == crew_name
|
||||
assert "parent_folder" in call_kwargs
|
||||
assert isinstance(call_kwargs["parent_folder"], Path)
|
||||
|
||||
|
||||
def test_add_crew_to_flow_not_in_root(runner):
|
||||
with mock.patch("pathlib.Path.exists", autospec=True) as mock_exists:
|
||||
def exists_side_effect(self):
|
||||
if self.name == "pyproject.toml":
|
||||
return False
|
||||
return True
|
||||
|
||||
mock_exists.side_effect = exists_side_effect
|
||||
|
||||
result = runner.invoke(flow_add_crew, ["new_crew"])
|
||||
|
||||
assert result.exit_code != 0
|
||||
assert "This command must be run from the root of a flow project." in str(
|
||||
result.output
|
||||
)
|
||||
@@ -1,148 +0,0 @@
|
||||
import json
|
||||
import shutil
|
||||
import tempfile
|
||||
import unittest
|
||||
from datetime import datetime, timedelta
|
||||
from pathlib import Path
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
from crewai_cli.config import (
|
||||
CLI_SETTINGS_KEYS,
|
||||
DEFAULT_CLI_SETTINGS,
|
||||
USER_SETTINGS_KEYS,
|
||||
Settings,
|
||||
)
|
||||
from crewai_cli.shared.token_manager import TokenManager
|
||||
|
||||
|
||||
class TestSettings(unittest.TestCase):
|
||||
def setUp(self):
|
||||
self.test_dir = Path(tempfile.mkdtemp())
|
||||
self.config_path = self.test_dir / "settings.json"
|
||||
|
||||
def tearDown(self):
|
||||
shutil.rmtree(self.test_dir)
|
||||
|
||||
def test_empty_initialization(self):
|
||||
settings = Settings(config_path=self.config_path)
|
||||
self.assertIsNone(settings.tool_repository_username)
|
||||
self.assertIsNone(settings.tool_repository_password)
|
||||
|
||||
def test_initialization_with_data(self):
|
||||
settings = Settings(
|
||||
config_path=self.config_path, tool_repository_username="user1"
|
||||
)
|
||||
self.assertEqual(settings.tool_repository_username, "user1")
|
||||
self.assertIsNone(settings.tool_repository_password)
|
||||
|
||||
def test_initialization_with_existing_file(self):
|
||||
self.config_path.parent.mkdir(parents=True, exist_ok=True)
|
||||
with self.config_path.open("w") as f:
|
||||
json.dump({"tool_repository_username": "file_user"}, f)
|
||||
|
||||
settings = Settings(config_path=self.config_path)
|
||||
self.assertEqual(settings.tool_repository_username, "file_user")
|
||||
|
||||
def test_merge_file_and_input_data(self):
|
||||
self.config_path.parent.mkdir(parents=True, exist_ok=True)
|
||||
with self.config_path.open("w") as f:
|
||||
json.dump(
|
||||
{
|
||||
"tool_repository_username": "file_user",
|
||||
"tool_repository_password": "file_pass",
|
||||
},
|
||||
f,
|
||||
)
|
||||
|
||||
settings = Settings(
|
||||
config_path=self.config_path, tool_repository_username="new_user"
|
||||
)
|
||||
self.assertEqual(settings.tool_repository_username, "new_user")
|
||||
self.assertEqual(settings.tool_repository_password, "file_pass")
|
||||
|
||||
def test_clear_user_settings(self):
|
||||
user_settings = {key: f"value_for_{key}" for key in USER_SETTINGS_KEYS}
|
||||
|
||||
settings = Settings(config_path=self.config_path, **user_settings)
|
||||
settings.clear_user_settings()
|
||||
|
||||
for key in user_settings.keys():
|
||||
self.assertEqual(getattr(settings, key), None)
|
||||
|
||||
@patch("crewai_cli.config.TokenManager")
|
||||
def test_reset_settings(self, mock_token_manager):
|
||||
user_settings = {key: f"value_for_{key}" for key in USER_SETTINGS_KEYS}
|
||||
cli_settings = {key: f"value_for_{key}" for key in CLI_SETTINGS_KEYS if key != "oauth2_extra"}
|
||||
cli_settings["oauth2_extra"] = {"scope": "xxx", "other": "yyy"}
|
||||
|
||||
settings = Settings(
|
||||
config_path=self.config_path, **user_settings, **cli_settings
|
||||
)
|
||||
|
||||
mock_token_manager.return_value = MagicMock()
|
||||
TokenManager().save_tokens(
|
||||
"aaa.bbb.ccc", (datetime.now() + timedelta(seconds=36000)).timestamp()
|
||||
)
|
||||
|
||||
settings.reset()
|
||||
|
||||
for key in user_settings.keys():
|
||||
self.assertEqual(getattr(settings, key), None)
|
||||
for key in cli_settings.keys():
|
||||
self.assertEqual(getattr(settings, key), DEFAULT_CLI_SETTINGS.get(key))
|
||||
|
||||
mock_token_manager.return_value.clear_tokens.assert_called_once()
|
||||
|
||||
def test_dump_new_settings(self):
|
||||
settings = Settings(
|
||||
config_path=self.config_path, tool_repository_username="user1"
|
||||
)
|
||||
settings.dump()
|
||||
|
||||
with self.config_path.open("r") as f:
|
||||
saved_data = json.load(f)
|
||||
|
||||
self.assertEqual(saved_data["tool_repository_username"], "user1")
|
||||
|
||||
def test_update_existing_settings(self):
|
||||
self.config_path.parent.mkdir(parents=True, exist_ok=True)
|
||||
with self.config_path.open("w") as f:
|
||||
json.dump({"existing_setting": "value"}, f)
|
||||
|
||||
settings = Settings(
|
||||
config_path=self.config_path, tool_repository_username="user1"
|
||||
)
|
||||
settings.dump()
|
||||
|
||||
with self.config_path.open("r") as f:
|
||||
saved_data = json.load(f)
|
||||
|
||||
self.assertEqual(saved_data["existing_setting"], "value")
|
||||
self.assertEqual(saved_data["tool_repository_username"], "user1")
|
||||
|
||||
def test_none_values(self):
|
||||
settings = Settings(config_path=self.config_path, tool_repository_username=None)
|
||||
settings.dump()
|
||||
|
||||
with self.config_path.open("r") as f:
|
||||
saved_data = json.load(f)
|
||||
|
||||
self.assertIsNone(saved_data.get("tool_repository_username"))
|
||||
|
||||
def test_invalid_json_in_config(self):
|
||||
self.config_path.parent.mkdir(parents=True, exist_ok=True)
|
||||
with self.config_path.open("w") as f:
|
||||
f.write("invalid json")
|
||||
|
||||
try:
|
||||
settings = Settings(config_path=self.config_path)
|
||||
self.assertIsNone(settings.tool_repository_username)
|
||||
except json.JSONDecodeError:
|
||||
self.fail("Settings initialization should handle invalid JSON")
|
||||
|
||||
def test_empty_config_file(self):
|
||||
self.config_path.parent.mkdir(parents=True, exist_ok=True)
|
||||
self.config_path.touch()
|
||||
|
||||
settings = Settings(config_path=self.config_path)
|
||||
self.assertIsNone(settings.tool_repository_username)
|
||||
@@ -1,20 +0,0 @@
|
||||
from crewai_cli.constants import ENV_VARS, MODELS, PROVIDERS
|
||||
|
||||
|
||||
def test_huggingface_in_providers():
|
||||
"""Test that Huggingface is in the PROVIDERS list."""
|
||||
assert "huggingface" in PROVIDERS
|
||||
|
||||
|
||||
def test_huggingface_env_vars():
|
||||
"""Test that Huggingface environment variables are properly configured."""
|
||||
assert "huggingface" in ENV_VARS
|
||||
assert any(
|
||||
detail.get("key_name") == "HF_TOKEN" for detail in ENV_VARS["huggingface"]
|
||||
)
|
||||
|
||||
|
||||
def test_huggingface_models():
|
||||
"""Test that Huggingface models are properly configured."""
|
||||
assert "huggingface" in MODELS
|
||||
assert len(MODELS["huggingface"]) > 0
|
||||
@@ -1,101 +0,0 @@
|
||||
import pytest
|
||||
from crewai_cli.git import Repository
|
||||
|
||||
|
||||
@pytest.fixture()
|
||||
def repository(fp):
|
||||
fp.register(["git", "--version"], stdout="git version 2.30.0\n")
|
||||
fp.register(["git", "rev-parse", "--is-inside-work-tree"], stdout="true\n")
|
||||
fp.register(["git", "fetch"], stdout="")
|
||||
return Repository(path=".")
|
||||
|
||||
|
||||
def test_init_with_invalid_git_repo(fp):
|
||||
fp.register(["git", "--version"], stdout="git version 2.30.0\n")
|
||||
fp.register(
|
||||
["git", "rev-parse", "--is-inside-work-tree"],
|
||||
returncode=1,
|
||||
stderr="fatal: not a git repository\n",
|
||||
)
|
||||
|
||||
with pytest.raises(ValueError):
|
||||
Repository(path="invalid/path")
|
||||
|
||||
|
||||
def test_is_git_not_installed(fp):
|
||||
fp.register(["git", "--version"], returncode=1)
|
||||
|
||||
with pytest.raises(
|
||||
ValueError, match="Git is not installed or not found in your PATH."
|
||||
):
|
||||
Repository(path=".")
|
||||
|
||||
|
||||
def test_status(fp, repository):
|
||||
fp.register(
|
||||
["git", "status", "--branch", "--porcelain"],
|
||||
stdout="## main...origin/main [ahead 1]\n",
|
||||
)
|
||||
assert repository.status() == "## main...origin/main [ahead 1]"
|
||||
|
||||
|
||||
def test_has_uncommitted_changes(fp, repository):
|
||||
fp.register(
|
||||
["git", "status", "--branch", "--porcelain"],
|
||||
stdout="## main...origin/main\n M somefile.txt\n",
|
||||
)
|
||||
assert repository.has_uncommitted_changes() is True
|
||||
|
||||
|
||||
def test_is_ahead_or_behind(fp, repository):
|
||||
fp.register(
|
||||
["git", "status", "--branch", "--porcelain"],
|
||||
stdout="## main...origin/main [ahead 1]\n",
|
||||
)
|
||||
assert repository.is_ahead_or_behind() is True
|
||||
|
||||
|
||||
def test_is_synced_when_synced(fp, repository):
|
||||
fp.register(
|
||||
["git", "status", "--branch", "--porcelain"], stdout="## main...origin/main\n"
|
||||
)
|
||||
fp.register(
|
||||
["git", "status", "--branch", "--porcelain"], stdout="## main...origin/main\n"
|
||||
)
|
||||
assert repository.is_synced() is True
|
||||
|
||||
|
||||
def test_is_synced_with_uncommitted_changes(fp, repository):
|
||||
fp.register(
|
||||
["git", "status", "--branch", "--porcelain"],
|
||||
stdout="## main...origin/main\n M somefile.txt\n",
|
||||
)
|
||||
assert repository.is_synced() is False
|
||||
|
||||
|
||||
def test_is_synced_when_ahead_or_behind(fp, repository):
|
||||
fp.register(
|
||||
["git", "status", "--branch", "--porcelain"],
|
||||
stdout="## main...origin/main [ahead 1]\n",
|
||||
)
|
||||
fp.register(
|
||||
["git", "status", "--branch", "--porcelain"],
|
||||
stdout="## main...origin/main [ahead 1]\n",
|
||||
)
|
||||
assert repository.is_synced() is False
|
||||
|
||||
|
||||
def test_is_synced_with_uncommitted_changes_and_ahead(fp, repository):
|
||||
fp.register(
|
||||
["git", "status", "--branch", "--porcelain"],
|
||||
stdout="## main...origin/main [ahead 1]\n M somefile.txt\n",
|
||||
)
|
||||
assert repository.is_synced() is False
|
||||
|
||||
|
||||
def test_origin_url(fp, repository):
|
||||
fp.register(
|
||||
["git", "remote", "get-url", "origin"],
|
||||
stdout="https://github.com/user/repo.git\n",
|
||||
)
|
||||
assert repository.origin_url() == "https://github.com/user/repo.git"
|
||||
@@ -1,356 +0,0 @@
|
||||
import os
|
||||
import unittest
|
||||
from unittest.mock import ANY, AsyncMock, MagicMock, patch
|
||||
|
||||
import pytest
|
||||
|
||||
from crewai_cli.plus_api import PlusAPI
|
||||
|
||||
|
||||
class TestPlusAPI(unittest.TestCase):
|
||||
def setUp(self):
|
||||
self.api_key = "test_api_key"
|
||||
self.api = PlusAPI(self.api_key)
|
||||
self.org_uuid = "test-org-uuid"
|
||||
|
||||
def test_init(self):
|
||||
self.assertEqual(self.api.api_key, self.api_key)
|
||||
self.assertEqual(self.api.headers["Authorization"], f"Bearer {self.api_key}")
|
||||
self.assertEqual(self.api.headers["Content-Type"], "application/json")
|
||||
self.assertTrue("CrewAI-CLI/" in self.api.headers["User-Agent"])
|
||||
self.assertTrue(self.api.headers["X-Crewai-Version"])
|
||||
|
||||
@patch("crewai_cli.plus_api.PlusAPI._make_request")
|
||||
def test_login_to_tool_repository(self, mock_make_request):
|
||||
mock_response = MagicMock()
|
||||
mock_make_request.return_value = mock_response
|
||||
|
||||
response = self.api.login_to_tool_repository()
|
||||
|
||||
mock_make_request.assert_called_once_with(
|
||||
"POST", "/crewai_plus/api/v1/tools/login"
|
||||
)
|
||||
self.assertEqual(response, mock_response)
|
||||
|
||||
def assert_request_with_org_id(
|
||||
self, mock_client_instance, method: str, endpoint: str, **kwargs
|
||||
):
|
||||
mock_client_instance.request.assert_called_once_with(
|
||||
method,
|
||||
f"{os.getenv('CREWAI_PLUS_URL')}{endpoint}",
|
||||
headers={
|
||||
"Authorization": ANY,
|
||||
"Content-Type": ANY,
|
||||
"User-Agent": ANY,
|
||||
"X-Crewai-Version": ANY,
|
||||
"X-Crewai-Organization-Id": self.org_uuid,
|
||||
},
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
@patch("crewai_cli.plus_api.Settings")
|
||||
@patch("crewai_cli.plus_api.httpx.Client")
|
||||
def test_login_to_tool_repository_with_org_uuid(
|
||||
self, mock_client_class, mock_settings_class
|
||||
):
|
||||
mock_settings = MagicMock()
|
||||
mock_settings.org_uuid = self.org_uuid
|
||||
mock_settings.enterprise_base_url = os.getenv('CREWAI_PLUS_URL')
|
||||
mock_settings_class.return_value = mock_settings
|
||||
self.api = PlusAPI(self.api_key)
|
||||
|
||||
mock_client_instance = MagicMock()
|
||||
mock_response = MagicMock()
|
||||
mock_client_instance.request.return_value = mock_response
|
||||
mock_client_class.return_value.__enter__.return_value = mock_client_instance
|
||||
|
||||
response = self.api.login_to_tool_repository()
|
||||
|
||||
self.assert_request_with_org_id(
|
||||
mock_client_instance, "POST", "/crewai_plus/api/v1/tools/login"
|
||||
)
|
||||
self.assertEqual(response, mock_response)
|
||||
|
||||
@patch("crewai_cli.plus_api.PlusAPI._make_request")
|
||||
def test_get_tool(self, mock_make_request):
|
||||
mock_response = MagicMock()
|
||||
mock_make_request.return_value = mock_response
|
||||
|
||||
response = self.api.get_tool("test_tool_handle")
|
||||
mock_make_request.assert_called_once_with(
|
||||
"GET", "/crewai_plus/api/v1/tools/test_tool_handle"
|
||||
)
|
||||
self.assertEqual(response, mock_response)
|
||||
|
||||
@patch("crewai_cli.plus_api.Settings")
|
||||
@patch("crewai_cli.plus_api.httpx.Client")
|
||||
def test_get_tool_with_org_uuid(self, mock_client_class, mock_settings_class):
|
||||
mock_settings = MagicMock()
|
||||
mock_settings.org_uuid = self.org_uuid
|
||||
mock_settings.enterprise_base_url = os.getenv('CREWAI_PLUS_URL')
|
||||
mock_settings_class.return_value = mock_settings
|
||||
self.api = PlusAPI(self.api_key)
|
||||
|
||||
mock_client_instance = MagicMock()
|
||||
mock_response = MagicMock()
|
||||
mock_client_instance.request.return_value = mock_response
|
||||
mock_client_class.return_value.__enter__.return_value = mock_client_instance
|
||||
|
||||
response = self.api.get_tool("test_tool_handle")
|
||||
|
||||
self.assert_request_with_org_id(
|
||||
mock_client_instance, "GET", "/crewai_plus/api/v1/tools/test_tool_handle"
|
||||
)
|
||||
self.assertEqual(response, mock_response)
|
||||
|
||||
@patch("crewai_cli.plus_api.PlusAPI._make_request")
|
||||
def test_publish_tool(self, mock_make_request):
|
||||
mock_response = MagicMock()
|
||||
mock_make_request.return_value = mock_response
|
||||
handle = "test_tool_handle"
|
||||
public = True
|
||||
version = "1.0.0"
|
||||
description = "Test tool description"
|
||||
encoded_file = "encoded_test_file"
|
||||
|
||||
response = self.api.publish_tool(
|
||||
handle, public, version, description, encoded_file
|
||||
)
|
||||
|
||||
params = {
|
||||
"handle": handle,
|
||||
"public": public,
|
||||
"version": version,
|
||||
"file": encoded_file,
|
||||
"description": description,
|
||||
"available_exports": None,
|
||||
}
|
||||
mock_make_request.assert_called_once_with(
|
||||
"POST", "/crewai_plus/api/v1/tools", json=params
|
||||
)
|
||||
self.assertEqual(response, mock_response)
|
||||
|
||||
@patch("crewai_cli.plus_api.Settings")
|
||||
@patch("crewai_cli.plus_api.httpx.Client")
|
||||
def test_publish_tool_with_org_uuid(self, mock_client_class, mock_settings_class):
|
||||
mock_settings = MagicMock()
|
||||
mock_settings.org_uuid = self.org_uuid
|
||||
mock_settings.enterprise_base_url = os.getenv('CREWAI_PLUS_URL')
|
||||
mock_settings_class.return_value = mock_settings
|
||||
self.api = PlusAPI(self.api_key)
|
||||
|
||||
mock_client_instance = MagicMock()
|
||||
mock_response = MagicMock()
|
||||
mock_client_instance.request.return_value = mock_response
|
||||
mock_client_class.return_value.__enter__.return_value = mock_client_instance
|
||||
|
||||
handle = "test_tool_handle"
|
||||
public = True
|
||||
version = "1.0.0"
|
||||
description = "Test tool description"
|
||||
encoded_file = "encoded_test_file"
|
||||
|
||||
response = self.api.publish_tool(
|
||||
handle, public, version, description, encoded_file
|
||||
)
|
||||
|
||||
expected_params = {
|
||||
"handle": handle,
|
||||
"public": public,
|
||||
"version": version,
|
||||
"file": encoded_file,
|
||||
"description": description,
|
||||
"available_exports": None,
|
||||
}
|
||||
|
||||
self.assert_request_with_org_id(
|
||||
mock_client_instance, "POST", "/crewai_plus/api/v1/tools", json=expected_params
|
||||
)
|
||||
self.assertEqual(response, mock_response)
|
||||
|
||||
@patch("crewai_cli.plus_api.PlusAPI._make_request")
|
||||
def test_publish_tool_without_description(self, mock_make_request):
|
||||
mock_response = MagicMock()
|
||||
mock_make_request.return_value = mock_response
|
||||
handle = "test_tool_handle"
|
||||
public = False
|
||||
version = "2.0.0"
|
||||
description = None
|
||||
encoded_file = "encoded_test_file"
|
||||
|
||||
response = self.api.publish_tool(
|
||||
handle, public, version, description, encoded_file
|
||||
)
|
||||
|
||||
params = {
|
||||
"handle": handle,
|
||||
"public": public,
|
||||
"version": version,
|
||||
"file": encoded_file,
|
||||
"description": description,
|
||||
"available_exports": None,
|
||||
}
|
||||
mock_make_request.assert_called_once_with(
|
||||
"POST", "/crewai_plus/api/v1/tools", json=params
|
||||
)
|
||||
self.assertEqual(response, mock_response)
|
||||
|
||||
@patch("crewai_cli.plus_api.httpx.Client")
|
||||
def test_make_request(self, mock_client_class):
|
||||
mock_client_instance = MagicMock()
|
||||
mock_response = MagicMock()
|
||||
mock_client_instance.request.return_value = mock_response
|
||||
mock_client_class.return_value.__enter__.return_value = mock_client_instance
|
||||
|
||||
response = self.api._make_request("GET", "test_endpoint")
|
||||
|
||||
mock_client_class.assert_called_once_with(trust_env=False, verify=True)
|
||||
mock_client_instance.request.assert_called_once_with(
|
||||
"GET", f"{self.api.base_url}/test_endpoint", headers=self.api.headers
|
||||
)
|
||||
self.assertEqual(response, mock_response)
|
||||
|
||||
@patch("crewai_cli.plus_api.PlusAPI._make_request")
|
||||
def test_deploy_by_name(self, mock_make_request):
|
||||
self.api.deploy_by_name("test_project")
|
||||
mock_make_request.assert_called_once_with(
|
||||
"POST", "/crewai_plus/api/v1/crews/by-name/test_project/deploy"
|
||||
)
|
||||
|
||||
@patch("crewai_cli.plus_api.PlusAPI._make_request")
|
||||
def test_deploy_by_uuid(self, mock_make_request):
|
||||
self.api.deploy_by_uuid("test_uuid")
|
||||
mock_make_request.assert_called_once_with(
|
||||
"POST", "/crewai_plus/api/v1/crews/test_uuid/deploy"
|
||||
)
|
||||
|
||||
@patch("crewai_cli.plus_api.PlusAPI._make_request")
|
||||
def test_crew_status_by_name(self, mock_make_request):
|
||||
self.api.crew_status_by_name("test_project")
|
||||
mock_make_request.assert_called_once_with(
|
||||
"GET", "/crewai_plus/api/v1/crews/by-name/test_project/status"
|
||||
)
|
||||
|
||||
@patch("crewai_cli.plus_api.PlusAPI._make_request")
|
||||
def test_crew_status_by_uuid(self, mock_make_request):
|
||||
self.api.crew_status_by_uuid("test_uuid")
|
||||
mock_make_request.assert_called_once_with(
|
||||
"GET", "/crewai_plus/api/v1/crews/test_uuid/status"
|
||||
)
|
||||
|
||||
@patch("crewai_cli.plus_api.PlusAPI._make_request")
|
||||
def test_crew_by_name(self, mock_make_request):
|
||||
self.api.crew_by_name("test_project")
|
||||
mock_make_request.assert_called_once_with(
|
||||
"GET", "/crewai_plus/api/v1/crews/by-name/test_project/logs/deployment"
|
||||
)
|
||||
|
||||
self.api.crew_by_name("test_project", "custom_log")
|
||||
mock_make_request.assert_called_with(
|
||||
"GET", "/crewai_plus/api/v1/crews/by-name/test_project/logs/custom_log"
|
||||
)
|
||||
|
||||
@patch("crewai_cli.plus_api.PlusAPI._make_request")
|
||||
def test_crew_by_uuid(self, mock_make_request):
|
||||
self.api.crew_by_uuid("test_uuid")
|
||||
mock_make_request.assert_called_once_with(
|
||||
"GET", "/crewai_plus/api/v1/crews/test_uuid/logs/deployment"
|
||||
)
|
||||
|
||||
self.api.crew_by_uuid("test_uuid", "custom_log")
|
||||
mock_make_request.assert_called_with(
|
||||
"GET", "/crewai_plus/api/v1/crews/test_uuid/logs/custom_log"
|
||||
)
|
||||
|
||||
@patch("crewai_cli.plus_api.PlusAPI._make_request")
|
||||
def test_delete_crew_by_name(self, mock_make_request):
|
||||
self.api.delete_crew_by_name("test_project")
|
||||
mock_make_request.assert_called_once_with(
|
||||
"DELETE", "/crewai_plus/api/v1/crews/by-name/test_project"
|
||||
)
|
||||
|
||||
@patch("crewai_cli.plus_api.PlusAPI._make_request")
|
||||
def test_delete_crew_by_uuid(self, mock_make_request):
|
||||
self.api.delete_crew_by_uuid("test_uuid")
|
||||
mock_make_request.assert_called_once_with(
|
||||
"DELETE", "/crewai_plus/api/v1/crews/test_uuid"
|
||||
)
|
||||
|
||||
@patch("crewai_cli.plus_api.PlusAPI._make_request")
|
||||
def test_list_crews(self, mock_make_request):
|
||||
self.api.list_crews()
|
||||
mock_make_request.assert_called_once_with("GET", "/crewai_plus/api/v1/crews")
|
||||
|
||||
@patch("crewai_cli.plus_api.PlusAPI._make_request")
|
||||
def test_create_crew(self, mock_make_request):
|
||||
payload = {"name": "test_crew"}
|
||||
self.api.create_crew(payload)
|
||||
mock_make_request.assert_called_once_with(
|
||||
"POST", "/crewai_plus/api/v1/crews", json=payload
|
||||
)
|
||||
|
||||
@patch("crewai_cli.plus_api.Settings")
|
||||
@patch.dict(os.environ, {"CREWAI_PLUS_URL": ""})
|
||||
def test_custom_base_url(self, mock_settings_class):
|
||||
mock_settings = MagicMock()
|
||||
mock_settings.enterprise_base_url = "https://custom-url.com/api"
|
||||
mock_settings_class.return_value = mock_settings
|
||||
custom_api = PlusAPI("test_key")
|
||||
self.assertEqual(
|
||||
custom_api.base_url,
|
||||
"https://custom-url.com/api",
|
||||
)
|
||||
|
||||
@patch.dict(os.environ, {"CREWAI_PLUS_URL": "https://custom-url-from-env.com"})
|
||||
def test_custom_base_url_from_env(self):
|
||||
custom_api = PlusAPI("test_key")
|
||||
self.assertEqual(
|
||||
custom_api.base_url,
|
||||
"https://custom-url-from-env.com",
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@patch("httpx.AsyncClient")
|
||||
async def test_get_agent(mock_async_client_class):
|
||||
api = PlusAPI("test_api_key")
|
||||
mock_response = MagicMock()
|
||||
mock_client_instance = AsyncMock()
|
||||
mock_client_instance.get.return_value = mock_response
|
||||
mock_async_client_class.return_value.__aenter__.return_value = mock_client_instance
|
||||
|
||||
response = await api.get_agent("test_agent_handle")
|
||||
|
||||
mock_client_instance.get.assert_called_once_with(
|
||||
f"{api.base_url}/crewai_plus/api/v1/agents/test_agent_handle",
|
||||
headers=api.headers,
|
||||
)
|
||||
assert response == mock_response
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@patch("httpx.AsyncClient")
|
||||
@patch("crewai_cli.plus_api.Settings")
|
||||
async def test_get_agent_with_org_uuid(mock_settings_class, mock_async_client_class):
|
||||
org_uuid = "test-org-uuid"
|
||||
mock_settings = MagicMock()
|
||||
mock_settings.org_uuid = org_uuid
|
||||
mock_settings.enterprise_base_url = os.getenv("CREWAI_PLUS_URL")
|
||||
mock_settings_class.return_value = mock_settings
|
||||
|
||||
api = PlusAPI("test_api_key")
|
||||
|
||||
mock_response = MagicMock()
|
||||
mock_client_instance = AsyncMock()
|
||||
mock_client_instance.get.return_value = mock_response
|
||||
mock_async_client_class.return_value.__aenter__.return_value = mock_client_instance
|
||||
|
||||
response = await api.get_agent("test_agent_handle")
|
||||
|
||||
mock_client_instance.get.assert_called_once_with(
|
||||
f"{api.base_url}/crewai_plus/api/v1/agents/test_agent_handle",
|
||||
headers=api.headers,
|
||||
)
|
||||
assert "X-Crewai-Organization-Id" in api.headers
|
||||
assert api.headers["X-Crewai-Organization-Id"] == org_uuid
|
||||
assert response == mock_response
|
||||
@@ -1,294 +0,0 @@
|
||||
"""Tests for TokenManager with atomic file operations."""
|
||||
|
||||
import json
|
||||
import os
|
||||
import tempfile
|
||||
import unittest
|
||||
from datetime import datetime, timedelta
|
||||
from pathlib import Path
|
||||
from unittest.mock import patch
|
||||
|
||||
from cryptography.fernet import Fernet
|
||||
|
||||
from crewai_cli.shared.token_manager import TokenManager
|
||||
|
||||
|
||||
class TestTokenManager(unittest.TestCase):
|
||||
"""Test cases for TokenManager."""
|
||||
|
||||
@patch("crewai_cli.shared.token_manager.TokenManager._get_or_create_key")
|
||||
def setUp(self, mock_get_key: unittest.mock.MagicMock) -> None:
|
||||
"""Set up test fixtures."""
|
||||
mock_get_key.return_value = Fernet.generate_key()
|
||||
self.token_manager = TokenManager()
|
||||
|
||||
@patch("crewai_cli.shared.token_manager.TokenManager._read_secure_file")
|
||||
@patch("crewai_cli.shared.token_manager.TokenManager._get_or_create_key")
|
||||
def test_get_or_create_key_existing(
|
||||
self,
|
||||
mock_get_or_create: unittest.mock.MagicMock,
|
||||
mock_read: unittest.mock.MagicMock,
|
||||
) -> None:
|
||||
"""Test that existing key is returned when present."""
|
||||
mock_key = Fernet.generate_key()
|
||||
mock_get_or_create.return_value = mock_key
|
||||
|
||||
token_manager = TokenManager()
|
||||
result = token_manager.key
|
||||
|
||||
self.assertEqual(result, mock_key)
|
||||
|
||||
def test_get_or_create_key_new(self) -> None:
|
||||
"""Test that new key is created when none exists."""
|
||||
mock_key = Fernet.generate_key()
|
||||
|
||||
with (
|
||||
patch.object(self.token_manager, "_read_secure_file", return_value=None) as mock_read,
|
||||
patch.object(self.token_manager, "_atomic_create_secure_file", return_value=True) as mock_atomic_create,
|
||||
patch("crewai_cli.shared.token_manager.Fernet.generate_key", return_value=mock_key) as mock_generate,
|
||||
):
|
||||
result = self.token_manager._get_or_create_key()
|
||||
|
||||
self.assertEqual(result, mock_key)
|
||||
mock_read.assert_called_with("secret.key")
|
||||
mock_generate.assert_called_once()
|
||||
mock_atomic_create.assert_called_once_with("secret.key", mock_key)
|
||||
|
||||
def test_get_or_create_key_race_condition(self) -> None:
|
||||
"""Test that another process's key is used when atomic create fails."""
|
||||
our_key = Fernet.generate_key()
|
||||
their_key = Fernet.generate_key()
|
||||
|
||||
with (
|
||||
patch.object(self.token_manager, "_read_secure_file", side_effect=[None, their_key]) as mock_read,
|
||||
patch.object(self.token_manager, "_atomic_create_secure_file", return_value=False) as mock_atomic_create,
|
||||
patch("crewai_cli.shared.token_manager.Fernet.generate_key", return_value=our_key),
|
||||
):
|
||||
result = self.token_manager._get_or_create_key()
|
||||
|
||||
self.assertEqual(result, their_key)
|
||||
self.assertEqual(mock_read.call_count, 2)
|
||||
|
||||
@patch("crewai_cli.shared.token_manager.TokenManager._atomic_write_secure_file")
|
||||
def test_save_tokens(
|
||||
self, mock_write: unittest.mock.MagicMock
|
||||
) -> None:
|
||||
"""Test saving tokens encrypts and writes atomically."""
|
||||
access_token = "test_token"
|
||||
expires_at = int((datetime.now() + timedelta(seconds=3600)).timestamp())
|
||||
|
||||
self.token_manager.save_tokens(access_token, expires_at)
|
||||
|
||||
mock_write.assert_called_once()
|
||||
args = mock_write.call_args[0]
|
||||
self.assertEqual(args[0], "tokens.enc")
|
||||
decrypted_data = self.token_manager.fernet.decrypt(args[1])
|
||||
data = json.loads(decrypted_data)
|
||||
self.assertEqual(data["access_token"], access_token)
|
||||
expiration = datetime.fromisoformat(data["expiration"])
|
||||
self.assertEqual(expiration, datetime.fromtimestamp(expires_at))
|
||||
|
||||
@patch("crewai_cli.shared.token_manager.TokenManager._read_secure_file")
|
||||
def test_get_token_valid(
|
||||
self, mock_read: unittest.mock.MagicMock
|
||||
) -> None:
|
||||
"""Test getting a valid non-expired token."""
|
||||
access_token = "test_token"
|
||||
expiration = (datetime.now() + timedelta(hours=1)).isoformat()
|
||||
data = {"access_token": access_token, "expiration": expiration}
|
||||
encrypted_data = self.token_manager.fernet.encrypt(json.dumps(data).encode())
|
||||
mock_read.return_value = encrypted_data
|
||||
|
||||
result = self.token_manager.get_token()
|
||||
|
||||
self.assertEqual(result, access_token)
|
||||
|
||||
@patch("crewai_cli.shared.token_manager.TokenManager._read_secure_file")
|
||||
def test_get_token_expired(
|
||||
self, mock_read: unittest.mock.MagicMock
|
||||
) -> None:
|
||||
"""Test that expired token returns None."""
|
||||
access_token = "test_token"
|
||||
expiration = (datetime.now() - timedelta(hours=1)).isoformat()
|
||||
data = {"access_token": access_token, "expiration": expiration}
|
||||
encrypted_data = self.token_manager.fernet.encrypt(json.dumps(data).encode())
|
||||
mock_read.return_value = encrypted_data
|
||||
|
||||
result = self.token_manager.get_token()
|
||||
|
||||
self.assertIsNone(result)
|
||||
|
||||
@patch("crewai_cli.shared.token_manager.TokenManager._read_secure_file")
|
||||
def test_get_token_not_found(
|
||||
self, mock_read: unittest.mock.MagicMock
|
||||
) -> None:
|
||||
"""Test that missing token file returns None."""
|
||||
mock_read.return_value = None
|
||||
|
||||
result = self.token_manager.get_token()
|
||||
|
||||
self.assertIsNone(result)
|
||||
|
||||
@patch("crewai_cli.shared.token_manager.TokenManager._delete_secure_file")
|
||||
def test_clear_tokens(
|
||||
self, mock_delete: unittest.mock.MagicMock
|
||||
) -> None:
|
||||
"""Test clearing tokens deletes the token file."""
|
||||
self.token_manager.clear_tokens()
|
||||
|
||||
mock_delete.assert_called_once_with("tokens.enc")
|
||||
|
||||
|
||||
class TestAtomicFileOperations(unittest.TestCase):
|
||||
"""Test atomic file operations directly."""
|
||||
|
||||
def setUp(self) -> None:
|
||||
"""Set up test fixtures with temp directory."""
|
||||
self.temp_dir = tempfile.mkdtemp()
|
||||
self.original_get_path = TokenManager._get_secure_storage_path
|
||||
|
||||
# Patch to use temp directory
|
||||
def mock_get_path() -> Path:
|
||||
return Path(self.temp_dir)
|
||||
|
||||
TokenManager._get_secure_storage_path = staticmethod(mock_get_path)
|
||||
|
||||
def tearDown(self) -> None:
|
||||
"""Clean up temp directory."""
|
||||
TokenManager._get_secure_storage_path = staticmethod(self.original_get_path)
|
||||
import shutil
|
||||
shutil.rmtree(self.temp_dir, ignore_errors=True)
|
||||
|
||||
@patch("crewai_cli.shared.token_manager.TokenManager._get_or_create_key")
|
||||
def test_atomic_create_new_file(
|
||||
self, mock_get_key: unittest.mock.MagicMock
|
||||
) -> None:
|
||||
"""Test atomic create succeeds for new file."""
|
||||
mock_get_key.return_value = Fernet.generate_key()
|
||||
tm = TokenManager()
|
||||
|
||||
result = tm._atomic_create_secure_file("test.txt", b"content")
|
||||
|
||||
self.assertTrue(result)
|
||||
file_path = Path(self.temp_dir) / "test.txt"
|
||||
self.assertTrue(file_path.exists())
|
||||
self.assertEqual(file_path.read_bytes(), b"content")
|
||||
self.assertEqual(file_path.stat().st_mode & 0o777, 0o600)
|
||||
|
||||
@patch("crewai_cli.shared.token_manager.TokenManager._get_or_create_key")
|
||||
def test_atomic_create_existing_file(
|
||||
self, mock_get_key: unittest.mock.MagicMock
|
||||
) -> None:
|
||||
"""Test atomic create fails for existing file."""
|
||||
mock_get_key.return_value = Fernet.generate_key()
|
||||
tm = TokenManager()
|
||||
|
||||
# Create file first
|
||||
file_path = Path(self.temp_dir) / "test.txt"
|
||||
file_path.write_bytes(b"original")
|
||||
|
||||
result = tm._atomic_create_secure_file("test.txt", b"new content")
|
||||
|
||||
self.assertFalse(result)
|
||||
self.assertEqual(file_path.read_bytes(), b"original")
|
||||
|
||||
@patch("crewai_cli.shared.token_manager.TokenManager._get_or_create_key")
|
||||
def test_atomic_write_new_file(
|
||||
self, mock_get_key: unittest.mock.MagicMock
|
||||
) -> None:
|
||||
"""Test atomic write creates new file."""
|
||||
mock_get_key.return_value = Fernet.generate_key()
|
||||
tm = TokenManager()
|
||||
|
||||
tm._atomic_write_secure_file("test.txt", b"content")
|
||||
|
||||
file_path = Path(self.temp_dir) / "test.txt"
|
||||
self.assertTrue(file_path.exists())
|
||||
self.assertEqual(file_path.read_bytes(), b"content")
|
||||
self.assertEqual(file_path.stat().st_mode & 0o777, 0o600)
|
||||
|
||||
@patch("crewai_cli.shared.token_manager.TokenManager._get_or_create_key")
|
||||
def test_atomic_write_overwrites(
|
||||
self, mock_get_key: unittest.mock.MagicMock
|
||||
) -> None:
|
||||
"""Test atomic write overwrites existing file."""
|
||||
mock_get_key.return_value = Fernet.generate_key()
|
||||
tm = TokenManager()
|
||||
|
||||
file_path = Path(self.temp_dir) / "test.txt"
|
||||
file_path.write_bytes(b"original")
|
||||
|
||||
tm._atomic_write_secure_file("test.txt", b"new content")
|
||||
|
||||
self.assertEqual(file_path.read_bytes(), b"new content")
|
||||
|
||||
@patch("crewai_cli.shared.token_manager.TokenManager._get_or_create_key")
|
||||
def test_atomic_write_no_temp_file_on_success(
|
||||
self, mock_get_key: unittest.mock.MagicMock
|
||||
) -> None:
|
||||
"""Test that temp file is cleaned up after successful write."""
|
||||
mock_get_key.return_value = Fernet.generate_key()
|
||||
tm = TokenManager()
|
||||
|
||||
tm._atomic_write_secure_file("test.txt", b"content")
|
||||
|
||||
# Check no temp files remain
|
||||
temp_files = list(Path(self.temp_dir).glob(".test.txt.*"))
|
||||
self.assertEqual(len(temp_files), 0)
|
||||
|
||||
@patch("crewai_cli.shared.token_manager.TokenManager._get_or_create_key")
|
||||
def test_read_secure_file_exists(
|
||||
self, mock_get_key: unittest.mock.MagicMock
|
||||
) -> None:
|
||||
"""Test reading existing file."""
|
||||
mock_get_key.return_value = Fernet.generate_key()
|
||||
tm = TokenManager()
|
||||
|
||||
file_path = Path(self.temp_dir) / "test.txt"
|
||||
file_path.write_bytes(b"content")
|
||||
|
||||
result = tm._read_secure_file("test.txt")
|
||||
|
||||
self.assertEqual(result, b"content")
|
||||
|
||||
@patch("crewai_cli.shared.token_manager.TokenManager._get_or_create_key")
|
||||
def test_read_secure_file_not_exists(
|
||||
self, mock_get_key: unittest.mock.MagicMock
|
||||
) -> None:
|
||||
"""Test reading non-existent file returns None."""
|
||||
mock_get_key.return_value = Fernet.generate_key()
|
||||
tm = TokenManager()
|
||||
|
||||
result = tm._read_secure_file("nonexistent.txt")
|
||||
|
||||
self.assertIsNone(result)
|
||||
|
||||
@patch("crewai_cli.shared.token_manager.TokenManager._get_or_create_key")
|
||||
def test_delete_secure_file_exists(
|
||||
self, mock_get_key: unittest.mock.MagicMock
|
||||
) -> None:
|
||||
"""Test deleting existing file."""
|
||||
mock_get_key.return_value = Fernet.generate_key()
|
||||
tm = TokenManager()
|
||||
|
||||
file_path = Path(self.temp_dir) / "test.txt"
|
||||
file_path.write_bytes(b"content")
|
||||
|
||||
tm._delete_secure_file("test.txt")
|
||||
|
||||
self.assertFalse(file_path.exists())
|
||||
|
||||
@patch("crewai_cli.shared.token_manager.TokenManager._get_or_create_key")
|
||||
def test_delete_secure_file_not_exists(
|
||||
self, mock_get_key: unittest.mock.MagicMock
|
||||
) -> None:
|
||||
"""Test deleting non-existent file doesn't raise."""
|
||||
mock_get_key.return_value = Fernet.generate_key()
|
||||
tm = TokenManager()
|
||||
|
||||
# Should not raise
|
||||
tm._delete_secure_file("nonexistent.txt")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
||||
@@ -1,146 +0,0 @@
|
||||
import os
|
||||
import shutil
|
||||
import tempfile
|
||||
from pathlib import Path
|
||||
|
||||
import pytest
|
||||
from crewai_cli import utils
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def temp_tree():
|
||||
root_dir = tempfile.mkdtemp()
|
||||
|
||||
create_file(os.path.join(root_dir, "file1.txt"), "Hello, world!")
|
||||
create_file(os.path.join(root_dir, "file2.txt"), "Another file")
|
||||
os.mkdir(os.path.join(root_dir, "empty_dir"))
|
||||
nested_dir = os.path.join(root_dir, "nested_dir")
|
||||
os.mkdir(nested_dir)
|
||||
create_file(os.path.join(nested_dir, "nested_file.txt"), "Nested content")
|
||||
|
||||
yield root_dir
|
||||
|
||||
shutil.rmtree(root_dir)
|
||||
|
||||
|
||||
def create_file(path, content):
|
||||
with open(path, "w") as f:
|
||||
f.write(content)
|
||||
|
||||
|
||||
def test_tree_find_and_replace_file_content(temp_tree):
|
||||
utils.tree_find_and_replace(temp_tree, "world", "universe")
|
||||
with open(os.path.join(temp_tree, "file1.txt"), "r") as f:
|
||||
assert f.read() == "Hello, universe!"
|
||||
|
||||
|
||||
def test_tree_find_and_replace_file_name(temp_tree):
|
||||
old_path = os.path.join(temp_tree, "file2.txt")
|
||||
new_path = os.path.join(temp_tree, "file2_renamed.txt")
|
||||
os.rename(old_path, new_path)
|
||||
utils.tree_find_and_replace(temp_tree, "renamed", "modified")
|
||||
assert os.path.exists(os.path.join(temp_tree, "file2_modified.txt"))
|
||||
assert not os.path.exists(new_path)
|
||||
|
||||
|
||||
def test_tree_find_and_replace_directory_name(temp_tree):
|
||||
utils.tree_find_and_replace(temp_tree, "empty", "renamed")
|
||||
assert os.path.exists(os.path.join(temp_tree, "renamed_dir"))
|
||||
assert not os.path.exists(os.path.join(temp_tree, "empty_dir"))
|
||||
|
||||
|
||||
def test_tree_find_and_replace_nested_content(temp_tree):
|
||||
utils.tree_find_and_replace(temp_tree, "Nested", "Updated")
|
||||
with open(os.path.join(temp_tree, "nested_dir", "nested_file.txt"), "r") as f:
|
||||
assert f.read() == "Updated content"
|
||||
|
||||
|
||||
def test_tree_find_and_replace_no_matches(temp_tree):
|
||||
utils.tree_find_and_replace(temp_tree, "nonexistent", "replacement")
|
||||
assert set(os.listdir(temp_tree)) == {
|
||||
"file1.txt",
|
||||
"file2.txt",
|
||||
"empty_dir",
|
||||
"nested_dir",
|
||||
}
|
||||
|
||||
|
||||
def test_tree_copy_full_structure(temp_tree):
|
||||
dest_dir = tempfile.mkdtemp()
|
||||
try:
|
||||
utils.tree_copy(temp_tree, dest_dir)
|
||||
assert set(os.listdir(dest_dir)) == set(os.listdir(temp_tree))
|
||||
assert os.path.isfile(os.path.join(dest_dir, "file1.txt"))
|
||||
assert os.path.isfile(os.path.join(dest_dir, "file2.txt"))
|
||||
assert os.path.isdir(os.path.join(dest_dir, "empty_dir"))
|
||||
assert os.path.isdir(os.path.join(dest_dir, "nested_dir"))
|
||||
assert os.path.isfile(os.path.join(dest_dir, "nested_dir", "nested_file.txt"))
|
||||
finally:
|
||||
shutil.rmtree(dest_dir)
|
||||
|
||||
|
||||
def test_tree_copy_preserve_content(temp_tree):
|
||||
dest_dir = tempfile.mkdtemp()
|
||||
try:
|
||||
utils.tree_copy(temp_tree, dest_dir)
|
||||
with open(os.path.join(dest_dir, "file1.txt"), "r") as f:
|
||||
assert f.read() == "Hello, world!"
|
||||
with open(os.path.join(dest_dir, "nested_dir", "nested_file.txt"), "r") as f:
|
||||
assert f.read() == "Nested content"
|
||||
finally:
|
||||
shutil.rmtree(dest_dir)
|
||||
|
||||
|
||||
def test_tree_copy_to_existing_directory(temp_tree):
|
||||
dest_dir = tempfile.mkdtemp()
|
||||
try:
|
||||
create_file(os.path.join(dest_dir, "existing_file.txt"), "I was here first")
|
||||
utils.tree_copy(temp_tree, dest_dir)
|
||||
assert os.path.isfile(os.path.join(dest_dir, "existing_file.txt"))
|
||||
assert os.path.isfile(os.path.join(dest_dir, "file1.txt"))
|
||||
finally:
|
||||
shutil.rmtree(dest_dir)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def temp_project_dir():
|
||||
"""Create a temporary directory for testing tool extraction."""
|
||||
with tempfile.TemporaryDirectory() as temp_dir:
|
||||
yield Path(temp_dir)
|
||||
|
||||
|
||||
def create_init_file(directory, content):
|
||||
return create_file(directory / "__init__.py", content)
|
||||
|
||||
|
||||
def test_extract_available_exports_empty_project(temp_project_dir, capsys):
|
||||
with pytest.raises(SystemExit):
|
||||
utils.extract_available_exports(dir_path=temp_project_dir)
|
||||
captured = capsys.readouterr()
|
||||
|
||||
assert "No valid tools were exposed in your __init__.py file" in captured.out
|
||||
|
||||
|
||||
def test_extract_available_exports_no_init_file(temp_project_dir, capsys):
|
||||
(temp_project_dir / "some_file.py").write_text("print('hello')")
|
||||
with pytest.raises(SystemExit):
|
||||
utils.extract_available_exports(dir_path=temp_project_dir)
|
||||
captured = capsys.readouterr()
|
||||
|
||||
assert "No valid tools were exposed in your __init__.py file" in captured.out
|
||||
|
||||
|
||||
def test_extract_available_exports_empty_init_file(temp_project_dir, capsys):
|
||||
create_init_file(temp_project_dir, "")
|
||||
with pytest.raises(SystemExit):
|
||||
utils.extract_available_exports(dir_path=temp_project_dir)
|
||||
captured = capsys.readouterr()
|
||||
|
||||
assert "Warning: No __all__ defined in" in captured.out
|
||||
|
||||
|
||||
# Tests for extract_available_exports with crewai.tools (BaseTool, @tool)
|
||||
# remain in lib/crewai/tests/cli/test_utils.py as they require the crewai core package.
|
||||
|
||||
# Tests for get_crews, get_flows, fetch_crews, is_valid_tool
|
||||
# remain in lib/crewai/tests/cli/test_utils.py as they require the crewai core package.
|
||||
@@ -1,372 +0,0 @@
|
||||
"""Test for version management."""
|
||||
|
||||
import json
|
||||
from datetime import datetime, timedelta
|
||||
from pathlib import Path
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
from crewai_cli.version import get_crewai_version as _get_ver
|
||||
from crewai_cli.version import (
|
||||
_find_latest_non_yanked_version,
|
||||
_get_cache_file,
|
||||
_is_cache_valid,
|
||||
_is_version_yanked,
|
||||
get_crewai_version,
|
||||
get_latest_version_from_pypi,
|
||||
is_current_version_yanked,
|
||||
is_newer_version_available,
|
||||
)
|
||||
|
||||
|
||||
def test_dynamic_versioning_consistency() -> None:
|
||||
"""Test that dynamic versioning provides consistent version across all access methods."""
|
||||
cli_version = get_crewai_version()
|
||||
package_version = _get_ver()
|
||||
|
||||
assert cli_version == package_version
|
||||
|
||||
assert package_version is not None
|
||||
assert len(package_version.strip()) > 0
|
||||
|
||||
|
||||
class TestVersionChecking:
|
||||
"""Test version checking utilities."""
|
||||
|
||||
def test_get_crewai_version(self) -> None:
|
||||
"""Test getting current crewai version."""
|
||||
version = get_crewai_version()
|
||||
assert isinstance(version, str)
|
||||
assert len(version) > 0
|
||||
|
||||
def test_get_cache_file(self) -> None:
|
||||
"""Test cache file path generation."""
|
||||
cache_file = _get_cache_file()
|
||||
assert isinstance(cache_file, Path)
|
||||
assert cache_file.name == "version_cache.json"
|
||||
|
||||
def test_is_cache_valid_with_fresh_cache(self) -> None:
|
||||
"""Test cache validation with fresh cache."""
|
||||
cache_data = {"timestamp": datetime.now().isoformat(), "version": "1.0.0"}
|
||||
assert _is_cache_valid(cache_data) is True
|
||||
|
||||
def test_is_cache_valid_with_stale_cache(self) -> None:
|
||||
"""Test cache validation with stale cache."""
|
||||
old_time = datetime.now() - timedelta(hours=25)
|
||||
cache_data = {"timestamp": old_time.isoformat(), "version": "1.0.0"}
|
||||
assert _is_cache_valid(cache_data) is False
|
||||
|
||||
def test_is_cache_valid_with_missing_timestamp(self) -> None:
|
||||
"""Test cache validation with missing timestamp."""
|
||||
cache_data = {"version": "1.0.0"}
|
||||
assert _is_cache_valid(cache_data) is False
|
||||
|
||||
@patch("crewai_cli.version.Path.exists")
|
||||
@patch("crewai_cli.version.request.urlopen")
|
||||
def test_get_latest_version_from_pypi_success(
|
||||
self, mock_urlopen: MagicMock, mock_exists: MagicMock
|
||||
) -> None:
|
||||
"""Test successful PyPI version fetch uses releases data."""
|
||||
mock_exists.return_value = False
|
||||
|
||||
releases = {
|
||||
"1.0.0": [{"yanked": False}],
|
||||
"2.0.0": [{"yanked": False}],
|
||||
"2.1.0": [{"yanked": True, "yanked_reason": "bad release"}],
|
||||
}
|
||||
mock_response = MagicMock()
|
||||
mock_response.read.return_value = json.dumps(
|
||||
{"info": {"version": "2.1.0"}, "releases": releases}
|
||||
).encode()
|
||||
mock_urlopen.return_value.__enter__.return_value = mock_response
|
||||
|
||||
version = get_latest_version_from_pypi()
|
||||
assert version == "2.0.0"
|
||||
|
||||
@patch("crewai_cli.version.Path.exists")
|
||||
@patch("crewai_cli.version.request.urlopen")
|
||||
def test_get_latest_version_from_pypi_failure(
|
||||
self, mock_urlopen: MagicMock, mock_exists: MagicMock
|
||||
) -> None:
|
||||
"""Test PyPI version fetch failure."""
|
||||
from urllib.error import URLError
|
||||
|
||||
mock_exists.return_value = False
|
||||
|
||||
mock_urlopen.side_effect = URLError("Network error")
|
||||
|
||||
version = get_latest_version_from_pypi()
|
||||
assert version is None
|
||||
|
||||
@patch("crewai_cli.version.get_crewai_version")
|
||||
@patch("crewai_cli.version.get_latest_version_from_pypi")
|
||||
def test_is_newer_version_available_true(
|
||||
self, mock_latest: MagicMock, mock_current: MagicMock
|
||||
) -> None:
|
||||
"""Test when newer version is available."""
|
||||
mock_current.return_value = "1.0.0"
|
||||
mock_latest.return_value = "2.0.0"
|
||||
|
||||
is_newer, current, latest = is_newer_version_available()
|
||||
assert is_newer is True
|
||||
assert current == "1.0.0"
|
||||
assert latest == "2.0.0"
|
||||
|
||||
@patch("crewai_cli.version.get_crewai_version")
|
||||
@patch("crewai_cli.version.get_latest_version_from_pypi")
|
||||
def test_is_newer_version_available_false(
|
||||
self, mock_latest: MagicMock, mock_current: MagicMock
|
||||
) -> None:
|
||||
"""Test when no newer version is available."""
|
||||
mock_current.return_value = "2.0.0"
|
||||
mock_latest.return_value = "2.0.0"
|
||||
|
||||
is_newer, current, latest = is_newer_version_available()
|
||||
assert is_newer is False
|
||||
assert current == "2.0.0"
|
||||
assert latest == "2.0.0"
|
||||
|
||||
@patch("crewai_cli.version.get_crewai_version")
|
||||
@patch("crewai_cli.version.get_latest_version_from_pypi")
|
||||
def test_is_newer_version_available_with_none_latest(
|
||||
self, mock_latest: MagicMock, mock_current: MagicMock
|
||||
) -> None:
|
||||
"""Test when PyPI fetch fails."""
|
||||
mock_current.return_value = "1.0.0"
|
||||
mock_latest.return_value = None
|
||||
|
||||
is_newer, current, latest = is_newer_version_available()
|
||||
assert is_newer is False
|
||||
assert current == "1.0.0"
|
||||
assert latest is None
|
||||
|
||||
|
||||
class TestFindLatestNonYankedVersion:
|
||||
"""Test _find_latest_non_yanked_version helper."""
|
||||
|
||||
def test_skips_yanked_versions(self) -> None:
|
||||
"""Test that yanked versions are skipped."""
|
||||
releases = {
|
||||
"1.0.0": [{"yanked": False}],
|
||||
"2.0.0": [{"yanked": True}],
|
||||
}
|
||||
assert _find_latest_non_yanked_version(releases) == "1.0.0"
|
||||
|
||||
def test_returns_highest_non_yanked(self) -> None:
|
||||
"""Test that the highest non-yanked version is returned."""
|
||||
releases = {
|
||||
"1.0.0": [{"yanked": False}],
|
||||
"1.5.0": [{"yanked": False}],
|
||||
"2.0.0": [{"yanked": True}],
|
||||
}
|
||||
assert _find_latest_non_yanked_version(releases) == "1.5.0"
|
||||
|
||||
def test_returns_none_when_all_yanked(self) -> None:
|
||||
"""Test that None is returned when all versions are yanked."""
|
||||
releases = {
|
||||
"1.0.0": [{"yanked": True}],
|
||||
"2.0.0": [{"yanked": True}],
|
||||
}
|
||||
assert _find_latest_non_yanked_version(releases) is None
|
||||
|
||||
def test_skips_prerelease_versions(self) -> None:
|
||||
"""Test that pre-release versions are skipped."""
|
||||
releases = {
|
||||
"1.0.0": [{"yanked": False}],
|
||||
"2.0.0a1": [{"yanked": False}],
|
||||
"2.0.0rc1": [{"yanked": False}],
|
||||
}
|
||||
assert _find_latest_non_yanked_version(releases) == "1.0.0"
|
||||
|
||||
def test_skips_versions_with_empty_files(self) -> None:
|
||||
"""Test that versions with no files are skipped."""
|
||||
releases: dict[str, list[dict[str, bool]]] = {
|
||||
"1.0.0": [{"yanked": False}],
|
||||
"2.0.0": [],
|
||||
}
|
||||
assert _find_latest_non_yanked_version(releases) == "1.0.0"
|
||||
|
||||
def test_handles_invalid_version_strings(self) -> None:
|
||||
"""Test that invalid version strings are skipped."""
|
||||
releases = {
|
||||
"1.0.0": [{"yanked": False}],
|
||||
"not-a-version": [{"yanked": False}],
|
||||
}
|
||||
assert _find_latest_non_yanked_version(releases) == "1.0.0"
|
||||
|
||||
def test_partially_yanked_files_not_considered_yanked(self) -> None:
|
||||
"""Test that a version with some non-yanked files is not yanked."""
|
||||
releases = {
|
||||
"1.0.0": [{"yanked": False}],
|
||||
"2.0.0": [{"yanked": True}, {"yanked": False}],
|
||||
}
|
||||
assert _find_latest_non_yanked_version(releases) == "2.0.0"
|
||||
|
||||
|
||||
class TestIsVersionYanked:
|
||||
"""Test _is_version_yanked helper."""
|
||||
|
||||
def test_non_yanked_version(self) -> None:
|
||||
"""Test a non-yanked version returns False."""
|
||||
releases = {"1.0.0": [{"yanked": False}]}
|
||||
is_yanked, reason = _is_version_yanked("1.0.0", releases)
|
||||
assert is_yanked is False
|
||||
assert reason == ""
|
||||
|
||||
def test_yanked_version_with_reason(self) -> None:
|
||||
"""Test a yanked version returns True with reason."""
|
||||
releases = {
|
||||
"1.0.0": [{"yanked": True, "yanked_reason": "critical bug"}],
|
||||
}
|
||||
is_yanked, reason = _is_version_yanked("1.0.0", releases)
|
||||
assert is_yanked is True
|
||||
assert reason == "critical bug"
|
||||
|
||||
def test_yanked_version_without_reason(self) -> None:
|
||||
"""Test a yanked version returns True with empty reason."""
|
||||
releases = {"1.0.0": [{"yanked": True}]}
|
||||
is_yanked, reason = _is_version_yanked("1.0.0", releases)
|
||||
assert is_yanked is True
|
||||
assert reason == ""
|
||||
|
||||
def test_unknown_version(self) -> None:
|
||||
"""Test an unknown version returns False."""
|
||||
releases = {"1.0.0": [{"yanked": False}]}
|
||||
is_yanked, reason = _is_version_yanked("9.9.9", releases)
|
||||
assert is_yanked is False
|
||||
assert reason == ""
|
||||
|
||||
def test_partially_yanked_files(self) -> None:
|
||||
"""Test a version with mixed yanked/non-yanked files is not yanked."""
|
||||
releases = {
|
||||
"1.0.0": [{"yanked": True}, {"yanked": False}],
|
||||
}
|
||||
is_yanked, reason = _is_version_yanked("1.0.0", releases)
|
||||
assert is_yanked is False
|
||||
assert reason == ""
|
||||
|
||||
def test_multiple_yanked_files_picks_first_reason(self) -> None:
|
||||
"""Test that the first available reason is returned."""
|
||||
releases = {
|
||||
"1.0.0": [
|
||||
{"yanked": True, "yanked_reason": ""},
|
||||
{"yanked": True, "yanked_reason": "second reason"},
|
||||
],
|
||||
}
|
||||
is_yanked, reason = _is_version_yanked("1.0.0", releases)
|
||||
assert is_yanked is True
|
||||
assert reason == "second reason"
|
||||
|
||||
|
||||
class TestIsCurrentVersionYanked:
|
||||
"""Test is_current_version_yanked public function."""
|
||||
|
||||
@patch("crewai_cli.version.get_crewai_version")
|
||||
@patch("crewai_cli.version._get_cache_file")
|
||||
def test_reads_from_valid_cache(
|
||||
self, mock_cache_file: MagicMock, mock_version: MagicMock, tmp_path: Path
|
||||
) -> None:
|
||||
"""Test reading yanked status from a valid cache."""
|
||||
mock_version.return_value = "1.0.0"
|
||||
cache_file = tmp_path / "version_cache.json"
|
||||
cache_data = {
|
||||
"version": "2.0.0",
|
||||
"timestamp": datetime.now().isoformat(),
|
||||
"current_version": "1.0.0",
|
||||
"current_version_yanked": True,
|
||||
"current_version_yanked_reason": "bad release",
|
||||
}
|
||||
cache_file.write_text(json.dumps(cache_data))
|
||||
mock_cache_file.return_value = cache_file
|
||||
|
||||
is_yanked, reason = is_current_version_yanked()
|
||||
assert is_yanked is True
|
||||
assert reason == "bad release"
|
||||
|
||||
@patch("crewai_cli.version.get_crewai_version")
|
||||
@patch("crewai_cli.version._get_cache_file")
|
||||
def test_not_yanked_from_cache(
|
||||
self, mock_cache_file: MagicMock, mock_version: MagicMock, tmp_path: Path
|
||||
) -> None:
|
||||
"""Test non-yanked status from a valid cache."""
|
||||
mock_version.return_value = "2.0.0"
|
||||
cache_file = tmp_path / "version_cache.json"
|
||||
cache_data = {
|
||||
"version": "2.0.0",
|
||||
"timestamp": datetime.now().isoformat(),
|
||||
"current_version": "2.0.0",
|
||||
"current_version_yanked": False,
|
||||
"current_version_yanked_reason": "",
|
||||
}
|
||||
cache_file.write_text(json.dumps(cache_data))
|
||||
mock_cache_file.return_value = cache_file
|
||||
|
||||
is_yanked, reason = is_current_version_yanked()
|
||||
assert is_yanked is False
|
||||
assert reason == ""
|
||||
|
||||
@patch("crewai_cli.version.get_latest_version_from_pypi")
|
||||
@patch("crewai_cli.version.get_crewai_version")
|
||||
@patch("crewai_cli.version._get_cache_file")
|
||||
def test_triggers_fetch_on_stale_cache(
|
||||
self,
|
||||
mock_cache_file: MagicMock,
|
||||
mock_version: MagicMock,
|
||||
mock_fetch: MagicMock,
|
||||
tmp_path: Path,
|
||||
) -> None:
|
||||
"""Test that a stale cache triggers a re-fetch."""
|
||||
mock_version.return_value = "1.0.0"
|
||||
cache_file = tmp_path / "version_cache.json"
|
||||
old_time = datetime.now() - timedelta(hours=25)
|
||||
cache_data = {
|
||||
"version": "2.0.0",
|
||||
"timestamp": old_time.isoformat(),
|
||||
"current_version": "1.0.0",
|
||||
"current_version_yanked": True,
|
||||
"current_version_yanked_reason": "old reason",
|
||||
}
|
||||
cache_file.write_text(json.dumps(cache_data))
|
||||
mock_cache_file.return_value = cache_file
|
||||
|
||||
fresh_cache = {
|
||||
"version": "2.0.0",
|
||||
"timestamp": datetime.now().isoformat(),
|
||||
"current_version": "1.0.0",
|
||||
"current_version_yanked": False,
|
||||
"current_version_yanked_reason": "",
|
||||
}
|
||||
|
||||
def write_fresh_cache() -> str:
|
||||
cache_file.write_text(json.dumps(fresh_cache))
|
||||
return "2.0.0"
|
||||
|
||||
mock_fetch.side_effect = lambda: write_fresh_cache()
|
||||
|
||||
is_yanked, reason = is_current_version_yanked()
|
||||
assert is_yanked is False
|
||||
mock_fetch.assert_called_once()
|
||||
|
||||
@patch("crewai_cli.version.get_latest_version_from_pypi")
|
||||
@patch("crewai_cli.version.get_crewai_version")
|
||||
@patch("crewai_cli.version._get_cache_file")
|
||||
def test_returns_false_on_fetch_failure(
|
||||
self,
|
||||
mock_cache_file: MagicMock,
|
||||
mock_version: MagicMock,
|
||||
mock_fetch: MagicMock,
|
||||
tmp_path: Path,
|
||||
) -> None:
|
||||
"""Test that fetch failure returns not yanked."""
|
||||
mock_version.return_value = "1.0.0"
|
||||
cache_file = tmp_path / "version_cache.json"
|
||||
mock_cache_file.return_value = cache_file
|
||||
mock_fetch.return_value = None
|
||||
|
||||
is_yanked, reason = is_current_version_yanked()
|
||||
assert is_yanked is False
|
||||
assert reason == ""
|
||||
|
||||
|
||||
|
||||
# TestConsoleFormatterVersionCheck tests remain in lib/crewai/tests/cli/test_version.py
|
||||
# as they depend on crewai.events.utils.console_formatter (core package).
|
||||
@@ -152,4 +152,4 @@ __all__ = [
|
||||
"wrap_file_source",
|
||||
]
|
||||
|
||||
__version__ = "1.10.2rc2"
|
||||
__version__ = "1.10.2a1"
|
||||
|
||||
@@ -11,7 +11,7 @@ dependencies = [
|
||||
"pytube~=15.0.0",
|
||||
"requests~=2.32.5",
|
||||
"docker~=7.1.0",
|
||||
"crewai==1.10.2rc2",
|
||||
"crewai==1.10.2a1",
|
||||
"tiktoken~=0.8.0",
|
||||
"beautifulsoup4~=4.13.4",
|
||||
"python-docx~=1.2.0",
|
||||
|
||||
@@ -309,4 +309,4 @@ __all__ = [
|
||||
"ZapierActionTools",
|
||||
]
|
||||
|
||||
__version__ = "1.10.2rc2"
|
||||
__version__ = "1.10.2a1"
|
||||
|
||||
@@ -1,9 +1,7 @@
|
||||
from collections.abc import Callable
|
||||
import os
|
||||
from pathlib import Path
|
||||
from typing import Any
|
||||
|
||||
from crewai.utilities.lock_store import lock as store_lock
|
||||
from lancedb import ( # type: ignore[import-untyped]
|
||||
DBConnection as LanceDBConnection,
|
||||
connect as lancedb_connect,
|
||||
@@ -35,12 +33,10 @@ class LanceDBAdapter(Adapter):
|
||||
|
||||
_db: LanceDBConnection = PrivateAttr()
|
||||
_table: LanceDBTable = PrivateAttr()
|
||||
_lock_name: str = PrivateAttr(default="")
|
||||
|
||||
def model_post_init(self, __context: Any) -> None:
|
||||
self._db = lancedb_connect(self.uri)
|
||||
self._table = self._db.open_table(self.table_name)
|
||||
self._lock_name = f"lancedb:{os.path.realpath(str(self.uri))}"
|
||||
|
||||
super().model_post_init(__context)
|
||||
|
||||
@@ -60,5 +56,4 @@ class LanceDBAdapter(Adapter):
|
||||
*args: Any,
|
||||
**kwargs: Any,
|
||||
) -> None:
|
||||
with store_lock(self._lock_name):
|
||||
self._table.add(*args, **kwargs)
|
||||
self._table.add(*args, **kwargs)
|
||||
|
||||
@@ -1,9 +1,6 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import contextvars
|
||||
import logging
|
||||
import threading
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
|
||||
@@ -21,9 +18,6 @@ class BrowserSessionManager:
|
||||
This class maintains separate browser sessions for different threads,
|
||||
enabling concurrent usage of browsers in multi-threaded environments.
|
||||
Browsers are created lazily only when needed by tools.
|
||||
|
||||
Uses per-key events to serialize creation for the same thread_id without
|
||||
blocking unrelated callers or wasting resources on duplicate sessions.
|
||||
"""
|
||||
|
||||
def __init__(self, region: str = "us-west-2"):
|
||||
@@ -33,10 +27,8 @@ class BrowserSessionManager:
|
||||
region: AWS region for browser client
|
||||
"""
|
||||
self.region = region
|
||||
self._lock = threading.Lock()
|
||||
self._async_sessions: dict[str, tuple[BrowserClient, AsyncBrowser]] = {}
|
||||
self._sync_sessions: dict[str, tuple[BrowserClient, SyncBrowser]] = {}
|
||||
self._creating: dict[str, threading.Event] = {}
|
||||
|
||||
async def get_async_browser(self, thread_id: str) -> AsyncBrowser:
|
||||
"""Get or create an async browser for the specified thread.
|
||||
@@ -47,29 +39,10 @@ class BrowserSessionManager:
|
||||
Returns:
|
||||
An async browser instance specific to the thread
|
||||
"""
|
||||
loop = asyncio.get_event_loop()
|
||||
while True:
|
||||
with self._lock:
|
||||
if thread_id in self._async_sessions:
|
||||
return self._async_sessions[thread_id][1]
|
||||
if thread_id not in self._creating:
|
||||
self._creating[thread_id] = threading.Event()
|
||||
break
|
||||
event = self._creating[thread_id]
|
||||
ctx = contextvars.copy_context()
|
||||
await loop.run_in_executor(None, ctx.run, event.wait)
|
||||
if thread_id in self._async_sessions:
|
||||
return self._async_sessions[thread_id][1]
|
||||
|
||||
try:
|
||||
browser_client, browser = await self._create_async_browser_session(
|
||||
thread_id
|
||||
)
|
||||
with self._lock:
|
||||
self._async_sessions[thread_id] = (browser_client, browser)
|
||||
return browser
|
||||
finally:
|
||||
with self._lock:
|
||||
evt = self._creating.pop(thread_id)
|
||||
evt.set()
|
||||
return await self._create_async_browser_session(thread_id)
|
||||
|
||||
def get_sync_browser(self, thread_id: str) -> SyncBrowser:
|
||||
"""Get or create a sync browser for the specified thread.
|
||||
@@ -80,33 +53,19 @@ class BrowserSessionManager:
|
||||
Returns:
|
||||
A sync browser instance specific to the thread
|
||||
"""
|
||||
while True:
|
||||
with self._lock:
|
||||
if thread_id in self._sync_sessions:
|
||||
return self._sync_sessions[thread_id][1]
|
||||
if thread_id not in self._creating:
|
||||
self._creating[thread_id] = threading.Event()
|
||||
break
|
||||
event = self._creating[thread_id]
|
||||
event.wait()
|
||||
if thread_id in self._sync_sessions:
|
||||
return self._sync_sessions[thread_id][1]
|
||||
|
||||
try:
|
||||
return self._create_sync_browser_session(thread_id)
|
||||
finally:
|
||||
with self._lock:
|
||||
evt = self._creating.pop(thread_id)
|
||||
evt.set()
|
||||
return self._create_sync_browser_session(thread_id)
|
||||
|
||||
async def _create_async_browser_session(
|
||||
self, thread_id: str
|
||||
) -> tuple[BrowserClient, AsyncBrowser]:
|
||||
async def _create_async_browser_session(self, thread_id: str) -> AsyncBrowser:
|
||||
"""Create a new async browser session for the specified thread.
|
||||
|
||||
Args:
|
||||
thread_id: Unique identifier for the thread
|
||||
|
||||
Returns:
|
||||
Tuple of (BrowserClient, AsyncBrowser).
|
||||
The newly created async browser instance
|
||||
|
||||
Raises:
|
||||
Exception: If browser session creation fails
|
||||
@@ -116,8 +75,10 @@ class BrowserSessionManager:
|
||||
browser_client = BrowserClient(region=self.region)
|
||||
|
||||
try:
|
||||
# Start browser session
|
||||
browser_client.start()
|
||||
|
||||
# Get WebSocket connection info
|
||||
ws_url, headers = browser_client.generate_ws_headers()
|
||||
|
||||
logger.info(
|
||||
@@ -126,6 +87,7 @@ class BrowserSessionManager:
|
||||
|
||||
from playwright.async_api import async_playwright
|
||||
|
||||
# Connect to browser using Playwright
|
||||
playwright = await async_playwright().start()
|
||||
browser = await playwright.chromium.connect_over_cdp(
|
||||
endpoint_url=ws_url, headers=headers, timeout=30000
|
||||
@@ -134,13 +96,17 @@ class BrowserSessionManager:
|
||||
f"Successfully connected to async browser for thread {thread_id}"
|
||||
)
|
||||
|
||||
return browser_client, browser
|
||||
# Store session resources
|
||||
self._async_sessions[thread_id] = (browser_client, browser)
|
||||
|
||||
return browser
|
||||
|
||||
except Exception as e:
|
||||
logger.error(
|
||||
f"Failed to create async browser session for thread {thread_id}: {e}"
|
||||
)
|
||||
|
||||
# Clean up resources if session creation fails
|
||||
if browser_client:
|
||||
try:
|
||||
browser_client.stop()
|
||||
@@ -166,8 +132,10 @@ class BrowserSessionManager:
|
||||
browser_client = BrowserClient(region=self.region)
|
||||
|
||||
try:
|
||||
# Start browser session
|
||||
browser_client.start()
|
||||
|
||||
# Get WebSocket connection info
|
||||
ws_url, headers = browser_client.generate_ws_headers()
|
||||
|
||||
logger.info(
|
||||
@@ -176,6 +144,7 @@ class BrowserSessionManager:
|
||||
|
||||
from playwright.sync_api import sync_playwright
|
||||
|
||||
# Connect to browser using Playwright
|
||||
playwright = sync_playwright().start()
|
||||
browser = playwright.chromium.connect_over_cdp(
|
||||
endpoint_url=ws_url, headers=headers, timeout=30000
|
||||
@@ -184,8 +153,8 @@ class BrowserSessionManager:
|
||||
f"Successfully connected to sync browser for thread {thread_id}"
|
||||
)
|
||||
|
||||
with self._lock:
|
||||
self._sync_sessions[thread_id] = (browser_client, browser)
|
||||
# Store session resources
|
||||
self._sync_sessions[thread_id] = (browser_client, browser)
|
||||
|
||||
return browser
|
||||
|
||||
@@ -194,6 +163,7 @@ class BrowserSessionManager:
|
||||
f"Failed to create sync browser session for thread {thread_id}: {e}"
|
||||
)
|
||||
|
||||
# Clean up resources if session creation fails
|
||||
if browser_client:
|
||||
try:
|
||||
browser_client.stop()
|
||||
@@ -208,13 +178,13 @@ class BrowserSessionManager:
|
||||
Args:
|
||||
thread_id: Unique identifier for the thread
|
||||
"""
|
||||
with self._lock:
|
||||
if thread_id not in self._async_sessions:
|
||||
logger.warning(f"No async browser session found for thread {thread_id}")
|
||||
return
|
||||
if thread_id not in self._async_sessions:
|
||||
logger.warning(f"No async browser session found for thread {thread_id}")
|
||||
return
|
||||
|
||||
browser_client, browser = self._async_sessions.pop(thread_id)
|
||||
browser_client, browser = self._async_sessions[thread_id]
|
||||
|
||||
# Close browser
|
||||
if browser:
|
||||
try:
|
||||
await browser.close()
|
||||
@@ -223,6 +193,7 @@ class BrowserSessionManager:
|
||||
f"Error closing async browser for thread {thread_id}: {e}"
|
||||
)
|
||||
|
||||
# Stop browser client
|
||||
if browser_client:
|
||||
try:
|
||||
browser_client.stop()
|
||||
@@ -231,6 +202,8 @@ class BrowserSessionManager:
|
||||
f"Error stopping browser client for thread {thread_id}: {e}"
|
||||
)
|
||||
|
||||
# Remove session from dictionary
|
||||
del self._async_sessions[thread_id]
|
||||
logger.info(f"Async browser session cleaned up for thread {thread_id}")
|
||||
|
||||
def close_sync_browser(self, thread_id: str) -> None:
|
||||
@@ -239,13 +212,13 @@ class BrowserSessionManager:
|
||||
Args:
|
||||
thread_id: Unique identifier for the thread
|
||||
"""
|
||||
with self._lock:
|
||||
if thread_id not in self._sync_sessions:
|
||||
logger.warning(f"No sync browser session found for thread {thread_id}")
|
||||
return
|
||||
if thread_id not in self._sync_sessions:
|
||||
logger.warning(f"No sync browser session found for thread {thread_id}")
|
||||
return
|
||||
|
||||
browser_client, browser = self._sync_sessions.pop(thread_id)
|
||||
browser_client, browser = self._sync_sessions[thread_id]
|
||||
|
||||
# Close browser
|
||||
if browser:
|
||||
try:
|
||||
browser.close()
|
||||
@@ -254,6 +227,7 @@ class BrowserSessionManager:
|
||||
f"Error closing sync browser for thread {thread_id}: {e}"
|
||||
)
|
||||
|
||||
# Stop browser client
|
||||
if browser_client:
|
||||
try:
|
||||
browser_client.stop()
|
||||
@@ -262,17 +236,19 @@ class BrowserSessionManager:
|
||||
f"Error stopping browser client for thread {thread_id}: {e}"
|
||||
)
|
||||
|
||||
# Remove session from dictionary
|
||||
del self._sync_sessions[thread_id]
|
||||
logger.info(f"Sync browser session cleaned up for thread {thread_id}")
|
||||
|
||||
async def close_all_browsers(self) -> None:
|
||||
"""Close all browser sessions."""
|
||||
with self._lock:
|
||||
async_thread_ids = list(self._async_sessions.keys())
|
||||
sync_thread_ids = list(self._sync_sessions.keys())
|
||||
|
||||
# Close all async browsers
|
||||
async_thread_ids = list(self._async_sessions.keys())
|
||||
for thread_id in async_thread_ids:
|
||||
await self.close_async_browser(thread_id)
|
||||
|
||||
# Close all sync browsers
|
||||
sync_thread_ids = list(self._sync_sessions.keys())
|
||||
for thread_id in sync_thread_ids:
|
||||
self.close_sync_browser(thread_id)
|
||||
|
||||
|
||||
@@ -1,11 +1,9 @@
|
||||
import logging
|
||||
import os
|
||||
from pathlib import Path
|
||||
from typing import Any
|
||||
from uuid import uuid4
|
||||
|
||||
import chromadb
|
||||
from crewai.utilities.lock_store import lock as store_lock
|
||||
from pydantic import BaseModel, Field, PrivateAttr
|
||||
|
||||
from crewai_tools.rag.base_loader import BaseLoader
|
||||
@@ -40,32 +38,22 @@ class RAG(Adapter):
|
||||
_client: Any = PrivateAttr()
|
||||
_collection: Any = PrivateAttr()
|
||||
_embedding_service: EmbeddingService = PrivateAttr()
|
||||
_lock_name: str = PrivateAttr(default="")
|
||||
|
||||
def model_post_init(self, __context: Any) -> None:
|
||||
try:
|
||||
self._lock_name = (
|
||||
f"chromadb:{os.path.realpath(self.persist_directory)}"
|
||||
if self.persist_directory
|
||||
else "chromadb:ephemeral"
|
||||
if self.persist_directory:
|
||||
self._client = chromadb.PersistentClient(path=self.persist_directory)
|
||||
else:
|
||||
self._client = chromadb.Client()
|
||||
|
||||
self._collection = self._client.get_or_create_collection(
|
||||
name=self.collection_name,
|
||||
metadata={
|
||||
"hnsw:space": "cosine",
|
||||
"description": "CrewAI Knowledge Base",
|
||||
},
|
||||
)
|
||||
|
||||
with store_lock(self._lock_name):
|
||||
if self.persist_directory:
|
||||
self._client = chromadb.PersistentClient(
|
||||
path=self.persist_directory
|
||||
)
|
||||
else:
|
||||
self._client = chromadb.Client()
|
||||
|
||||
self._collection = self._client.get_or_create_collection(
|
||||
name=self.collection_name,
|
||||
metadata={
|
||||
"hnsw:space": "cosine",
|
||||
"description": "CrewAI Knowledge Base",
|
||||
},
|
||||
)
|
||||
|
||||
self._embedding_service = EmbeddingService(
|
||||
provider=self.embedding_provider,
|
||||
model=self.embedding_model,
|
||||
@@ -99,8 +87,29 @@ class RAG(Adapter):
|
||||
loader_result = loader.load(source_content)
|
||||
doc_id = loader_result.doc_id
|
||||
|
||||
chunks = chunker.chunk(loader_result.content)
|
||||
existing_doc = self._collection.get(
|
||||
where={"source": source_content.source_ref}, limit=1
|
||||
)
|
||||
existing_doc_id = (
|
||||
existing_doc and existing_doc["metadatas"][0]["doc_id"]
|
||||
if existing_doc["metadatas"]
|
||||
else None
|
||||
)
|
||||
|
||||
if existing_doc_id == doc_id:
|
||||
logger.warning(
|
||||
f"Document with source {loader_result.source} already exists"
|
||||
)
|
||||
return
|
||||
|
||||
# Document with same source ref does exists but the content has changed, deleting the oldest reference
|
||||
if existing_doc_id and existing_doc_id != loader_result.doc_id:
|
||||
logger.warning(f"Deleting old document with doc_id {existing_doc_id}")
|
||||
self._collection.delete(where={"doc_id": existing_doc_id})
|
||||
|
||||
documents = []
|
||||
|
||||
chunks = chunker.chunk(loader_result.content)
|
||||
for i, chunk in enumerate(chunks):
|
||||
doc_metadata = (metadata or {}).copy()
|
||||
doc_metadata["chunk_index"] = i
|
||||
@@ -127,6 +136,7 @@ class RAG(Adapter):
|
||||
|
||||
ids = [doc.id for doc in documents]
|
||||
metadatas = []
|
||||
|
||||
for doc in documents:
|
||||
doc_metadata = doc.metadata.copy()
|
||||
doc_metadata.update(
|
||||
@@ -138,36 +148,16 @@ class RAG(Adapter):
|
||||
)
|
||||
metadatas.append(doc_metadata)
|
||||
|
||||
with store_lock(self._lock_name):
|
||||
existing_doc = self._collection.get(
|
||||
where={"source": source_content.source_ref}, limit=1
|
||||
try:
|
||||
self._collection.add(
|
||||
ids=ids,
|
||||
embeddings=embeddings,
|
||||
documents=contents,
|
||||
metadatas=metadatas,
|
||||
)
|
||||
existing_doc_id = (
|
||||
existing_doc and existing_doc["metadatas"][0]["doc_id"]
|
||||
if existing_doc["metadatas"]
|
||||
else None
|
||||
)
|
||||
|
||||
if existing_doc_id == doc_id:
|
||||
logger.warning(
|
||||
f"Document with source {loader_result.source} already exists"
|
||||
)
|
||||
return
|
||||
|
||||
if existing_doc_id and existing_doc_id != loader_result.doc_id:
|
||||
logger.warning(f"Deleting old document with doc_id {existing_doc_id}")
|
||||
self._collection.delete(where={"doc_id": existing_doc_id})
|
||||
|
||||
try:
|
||||
self._collection.add(
|
||||
ids=ids,
|
||||
embeddings=embeddings,
|
||||
documents=contents,
|
||||
metadatas=metadatas,
|
||||
)
|
||||
logger.info(f"Added {len(documents)} documents to knowledge base")
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to add documents to ChromaDB: {e}")
|
||||
logger.info(f"Added {len(documents)} documents to knowledge base")
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to add documents to ChromaDB: {e}")
|
||||
|
||||
def query(self, question: str, where: dict[str, Any] | None = None) -> str: # type: ignore
|
||||
try:
|
||||
@@ -211,8 +201,7 @@ class RAG(Adapter):
|
||||
|
||||
def delete_collection(self) -> None:
|
||||
try:
|
||||
with store_lock(self._lock_name):
|
||||
self._client.delete_collection(self.collection_name)
|
||||
self._client.delete_collection(self.collection_name)
|
||||
logger.info(f"Deleted collection: {self.collection_name}")
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to delete collection: {e}")
|
||||
|
||||
@@ -1,3 +1,4 @@
|
||||
from datetime import datetime
|
||||
import json
|
||||
import os
|
||||
import time
|
||||
@@ -9,8 +10,8 @@ from pydantic import BaseModel, Field
|
||||
from pydantic.types import StringConstraints
|
||||
import requests
|
||||
|
||||
from crewai_tools.tools.brave_search_tool.base import _save_results_to_file
|
||||
from crewai_tools.tools.brave_search_tool.schemas import WebSearchParams
|
||||
from crewai_tools.tools.brave_search_tool.base import _save_results_to_file
|
||||
|
||||
|
||||
load_dotenv()
|
||||
|
||||
@@ -30,8 +30,9 @@ class FileWriterTool(BaseTool):
|
||||
|
||||
def _run(self, **kwargs: Any) -> str:
|
||||
try:
|
||||
if kwargs.get("directory"):
|
||||
os.makedirs(kwargs["directory"], exist_ok=True)
|
||||
# Create the directory if it doesn't exist
|
||||
if kwargs.get("directory") and not os.path.exists(kwargs["directory"]):
|
||||
os.makedirs(kwargs["directory"])
|
||||
|
||||
# Construct the full path
|
||||
filepath = os.path.join(kwargs.get("directory") or "", kwargs["filename"])
|
||||
|
||||
@@ -99,8 +99,8 @@ class FileCompressorTool(BaseTool):
|
||||
def _prepare_output(output_path: str, overwrite: bool) -> bool:
|
||||
"""Ensures output path is ready for writing."""
|
||||
output_dir = os.path.dirname(output_path)
|
||||
if output_dir:
|
||||
os.makedirs(output_dir, exist_ok=True)
|
||||
if output_dir and not os.path.exists(output_dir):
|
||||
os.makedirs(output_dir)
|
||||
if os.path.exists(output_path) and not overwrite:
|
||||
return False
|
||||
return True
|
||||
|
||||
@@ -18,6 +18,7 @@ class MergeAgentHandlerToolError(Exception):
|
||||
"""Base exception for Merge Agent Handler tool errors."""
|
||||
|
||||
|
||||
|
||||
class MergeAgentHandlerTool(BaseTool):
|
||||
"""
|
||||
Wrapper for Merge Agent Handler tools.
|
||||
@@ -173,7 +174,7 @@ class MergeAgentHandlerTool(BaseTool):
|
||||
>>> tool = MergeAgentHandlerTool.from_tool_name(
|
||||
... tool_name="linear__create_issue",
|
||||
... tool_pack_id="134e0111-0f67-44f6-98f0-597000290bb3",
|
||||
... registered_user_id="91b2b905-e866-40c8-8be2-efe53827a0aa",
|
||||
... registered_user_id="91b2b905-e866-40c8-8be2-efe53827a0aa"
|
||||
... )
|
||||
"""
|
||||
# Create an empty args schema model (proper BaseModel subclass)
|
||||
@@ -209,10 +210,7 @@ class MergeAgentHandlerTool(BaseTool):
|
||||
if "parameters" in tool_schema:
|
||||
try:
|
||||
params = tool_schema["parameters"]
|
||||
if (
|
||||
params.get("type") == "object"
|
||||
and "properties" in params
|
||||
):
|
||||
if params.get("type") == "object" and "properties" in params:
|
||||
# Build field definitions for Pydantic
|
||||
fields = {}
|
||||
properties = params["properties"]
|
||||
@@ -300,7 +298,7 @@ class MergeAgentHandlerTool(BaseTool):
|
||||
>>> tools = MergeAgentHandlerTool.from_tool_pack(
|
||||
... tool_pack_id="134e0111-0f67-44f6-98f0-597000290bb3",
|
||||
... registered_user_id="91b2b905-e866-40c8-8be2-efe53827a0aa",
|
||||
... tool_names=["linear__create_issue", "linear__get_issues"],
|
||||
... tool_names=["linear__create_issue", "linear__get_issues"]
|
||||
... )
|
||||
"""
|
||||
# Create a temporary instance to fetch the tool list
|
||||
|
||||
@@ -110,13 +110,11 @@ class QdrantVectorSearchTool(BaseTool):
|
||||
self.custom_embedding_fn(query)
|
||||
if self.custom_embedding_fn
|
||||
else (
|
||||
lambda: (
|
||||
__import__("openai")
|
||||
.Client(api_key=os.getenv("OPENAI_API_KEY"))
|
||||
.embeddings.create(input=[query], model="text-embedding-3-large")
|
||||
.data[0]
|
||||
.embedding
|
||||
)
|
||||
lambda: __import__("openai")
|
||||
.Client(api_key=os.getenv("OPENAI_API_KEY"))
|
||||
.embeddings.create(input=[query], model="text-embedding-3-large")
|
||||
.data[0]
|
||||
.embedding
|
||||
)()
|
||||
)
|
||||
results = self.client.query_points(
|
||||
|
||||
@@ -3,7 +3,6 @@ from __future__ import annotations
|
||||
import asyncio
|
||||
from concurrent.futures import ThreadPoolExecutor
|
||||
import logging
|
||||
import threading
|
||||
from typing import TYPE_CHECKING, Any
|
||||
|
||||
from crewai.tools.base_tool import BaseTool
|
||||
@@ -34,7 +33,6 @@ logger = logging.getLogger(__name__)
|
||||
|
||||
# Cache for query results
|
||||
_query_cache: dict[str, list[dict[str, Any]]] = {}
|
||||
_cache_lock = threading.Lock()
|
||||
|
||||
|
||||
class SnowflakeConfig(BaseModel):
|
||||
@@ -104,7 +102,7 @@ class SnowflakeSearchTool(BaseTool):
|
||||
)
|
||||
|
||||
_connection_pool: list[SnowflakeConnection] | None = None
|
||||
_pool_lock: threading.Lock | None = None
|
||||
_pool_lock: asyncio.Lock | None = None
|
||||
_thread_pool: ThreadPoolExecutor | None = None
|
||||
_model_rebuilt: bool = False
|
||||
package_dependencies: list[str] = Field(
|
||||
@@ -124,7 +122,7 @@ class SnowflakeSearchTool(BaseTool):
|
||||
try:
|
||||
if SNOWFLAKE_AVAILABLE:
|
||||
self._connection_pool = []
|
||||
self._pool_lock = threading.Lock()
|
||||
self._pool_lock = asyncio.Lock()
|
||||
self._thread_pool = ThreadPoolExecutor(max_workers=self.pool_size)
|
||||
else:
|
||||
raise ImportError
|
||||
@@ -149,7 +147,7 @@ class SnowflakeSearchTool(BaseTool):
|
||||
)
|
||||
|
||||
self._connection_pool = []
|
||||
self._pool_lock = threading.Lock()
|
||||
self._pool_lock = asyncio.Lock()
|
||||
self._thread_pool = ThreadPoolExecutor(max_workers=self.pool_size)
|
||||
except subprocess.CalledProcessError as e:
|
||||
raise ImportError("Failed to install Snowflake dependencies") from e
|
||||
@@ -165,12 +163,13 @@ class SnowflakeSearchTool(BaseTool):
|
||||
raise RuntimeError("Pool lock not initialized")
|
||||
if self._connection_pool is None:
|
||||
raise RuntimeError("Connection pool not initialized")
|
||||
with self._pool_lock:
|
||||
if self._connection_pool:
|
||||
return self._connection_pool.pop()
|
||||
return await asyncio.get_event_loop().run_in_executor(
|
||||
self._thread_pool, self._create_connection
|
||||
)
|
||||
async with self._pool_lock:
|
||||
if not self._connection_pool:
|
||||
conn = await asyncio.get_event_loop().run_in_executor(
|
||||
self._thread_pool, self._create_connection
|
||||
)
|
||||
self._connection_pool.append(conn)
|
||||
return self._connection_pool.pop()
|
||||
|
||||
def _create_connection(self) -> SnowflakeConnection:
|
||||
"""Create a new Snowflake connection."""
|
||||
@@ -205,10 +204,9 @@ class SnowflakeSearchTool(BaseTool):
|
||||
"""Execute a query with retries and return results."""
|
||||
if self.enable_caching:
|
||||
cache_key = self._get_cache_key(query, timeout)
|
||||
with _cache_lock:
|
||||
if cache_key in _query_cache:
|
||||
logger.info("Returning cached result")
|
||||
return _query_cache[cache_key]
|
||||
if cache_key in _query_cache:
|
||||
logger.info("Returning cached result")
|
||||
return _query_cache[cache_key]
|
||||
|
||||
for attempt in range(self.max_retries):
|
||||
try:
|
||||
@@ -227,8 +225,7 @@ class SnowflakeSearchTool(BaseTool):
|
||||
]
|
||||
|
||||
if self.enable_caching:
|
||||
with _cache_lock:
|
||||
_query_cache[self._get_cache_key(query, timeout)] = results
|
||||
_query_cache[self._get_cache_key(query, timeout)] = results
|
||||
|
||||
return results
|
||||
finally:
|
||||
@@ -237,7 +234,7 @@ class SnowflakeSearchTool(BaseTool):
|
||||
self._pool_lock is not None
|
||||
and self._connection_pool is not None
|
||||
):
|
||||
with self._pool_lock:
|
||||
async with self._pool_lock:
|
||||
self._connection_pool.append(conn)
|
||||
except (DatabaseError, OperationalError) as e: # noqa: PERF203
|
||||
if attempt == self.max_retries - 1:
|
||||
|
||||
@@ -1,5 +1,4 @@
|
||||
import asyncio
|
||||
import contextvars
|
||||
import json
|
||||
import os
|
||||
import re
|
||||
@@ -138,9 +137,7 @@ class StagehandTool(BaseTool):
|
||||
- 'observe': For finding elements in a specific area
|
||||
"""
|
||||
args_schema: type[BaseModel] = StagehandToolSchema
|
||||
package_dependencies: list[str] = Field(
|
||||
default_factory=lambda: ["stagehand<=0.5.9"]
|
||||
)
|
||||
package_dependencies: list[str] = Field(default_factory=lambda: ["stagehand<=0.5.9"])
|
||||
env_vars: list[EnvVar] = Field(
|
||||
default_factory=lambda: [
|
||||
EnvVar(
|
||||
@@ -623,12 +620,9 @@ class StagehandTool(BaseTool):
|
||||
# We're in an existing event loop, use it
|
||||
import concurrent.futures
|
||||
|
||||
ctx = contextvars.copy_context()
|
||||
with concurrent.futures.ThreadPoolExecutor() as executor:
|
||||
future = executor.submit(
|
||||
ctx.run,
|
||||
asyncio.run,
|
||||
self._async_run(instruction, url, command_type),
|
||||
asyncio.run, self._async_run(instruction, url, command_type)
|
||||
)
|
||||
result = future.result()
|
||||
else:
|
||||
@@ -712,12 +706,11 @@ class StagehandTool(BaseTool):
|
||||
if loop.is_running():
|
||||
import concurrent.futures
|
||||
|
||||
ctx = contextvars.copy_context()
|
||||
with (
|
||||
concurrent.futures.ThreadPoolExecutor() as executor
|
||||
):
|
||||
future = executor.submit(
|
||||
ctx.run, asyncio.run, self._async_close()
|
||||
asyncio.run, self._async_close()
|
||||
)
|
||||
future.result()
|
||||
else:
|
||||
|
||||
@@ -53,7 +53,7 @@ Repository = "https://github.com/crewAIInc/crewAI"
|
||||
|
||||
[project.optional-dependencies]
|
||||
tools = [
|
||||
"crewai-tools==1.10.2rc2",
|
||||
"crewai-tools==1.10.2a1",
|
||||
]
|
||||
embeddings = [
|
||||
"tiktoken~=0.8.0"
|
||||
@@ -105,15 +105,10 @@ a2a = [
|
||||
file-processing = [
|
||||
"crewai-files",
|
||||
]
|
||||
cli = [
|
||||
"crewai-cli",
|
||||
]
|
||||
|
||||
|
||||
# CLI entry point has moved to the crewai-cli package.
|
||||
# Install it via: pip install crewai[cli]
|
||||
# [project.scripts]
|
||||
# crewai = "crewai.cli.cli:crewai"
|
||||
[project.scripts]
|
||||
crewai = "crewai.cli.cli:crewai"
|
||||
|
||||
|
||||
# PyTorch index configuration, since torch 2.5.0 is not compatible with python 3.13
|
||||
|
||||
@@ -1,4 +1,3 @@
|
||||
import contextvars
|
||||
import threading
|
||||
from typing import Any
|
||||
import urllib.request
|
||||
@@ -41,7 +40,7 @@ def _suppress_pydantic_deprecation_warnings() -> None:
|
||||
|
||||
_suppress_pydantic_deprecation_warnings()
|
||||
|
||||
__version__ = "1.10.2rc2"
|
||||
__version__ = "1.10.2a1"
|
||||
_telemetry_submitted = False
|
||||
|
||||
|
||||
@@ -67,8 +66,7 @@ def _track_install() -> None:
|
||||
def _track_install_async() -> None:
|
||||
"""Track installation in background thread to avoid blocking imports."""
|
||||
if not Telemetry._is_telemetry_disabled():
|
||||
ctx = contextvars.copy_context()
|
||||
thread = threading.Thread(target=ctx.run, args=(_track_install,), daemon=True)
|
||||
thread = threading.Thread(target=_track_install, daemon=True)
|
||||
thread.start()
|
||||
|
||||
|
||||
|
||||
@@ -5,7 +5,6 @@ from __future__ import annotations
|
||||
import asyncio
|
||||
from collections.abc import MutableMapping
|
||||
import concurrent.futures
|
||||
import contextvars
|
||||
from functools import lru_cache
|
||||
import ssl
|
||||
import time
|
||||
@@ -148,9 +147,8 @@ def fetch_agent_card(
|
||||
has_running_loop = False
|
||||
|
||||
if has_running_loop:
|
||||
ctx = contextvars.copy_context()
|
||||
with concurrent.futures.ThreadPoolExecutor(max_workers=1) as pool:
|
||||
return pool.submit(ctx.run, asyncio.run, coro).result()
|
||||
return pool.submit(asyncio.run, coro).result()
|
||||
return asyncio.run(coro)
|
||||
|
||||
|
||||
@@ -217,9 +215,8 @@ def _fetch_agent_card_cached(
|
||||
has_running_loop = False
|
||||
|
||||
if has_running_loop:
|
||||
ctx = contextvars.copy_context()
|
||||
with concurrent.futures.ThreadPoolExecutor(max_workers=1) as pool:
|
||||
return pool.submit(ctx.run, asyncio.run, coro).result()
|
||||
return pool.submit(asyncio.run, coro).result()
|
||||
return asyncio.run(coro)
|
||||
|
||||
|
||||
|
||||
@@ -7,7 +7,6 @@ import base64
|
||||
from collections.abc import AsyncIterator, Callable, MutableMapping
|
||||
import concurrent.futures
|
||||
from contextlib import asynccontextmanager
|
||||
import contextvars
|
||||
import logging
|
||||
from typing import TYPE_CHECKING, Any, Final, Literal
|
||||
import uuid
|
||||
@@ -230,9 +229,8 @@ def execute_a2a_delegation(
|
||||
has_running_loop = False
|
||||
|
||||
if has_running_loop:
|
||||
ctx = contextvars.copy_context()
|
||||
with concurrent.futures.ThreadPoolExecutor(max_workers=1) as pool:
|
||||
return pool.submit(ctx.run, asyncio.run, coro).result()
|
||||
return pool.submit(asyncio.run, coro).result()
|
||||
return asyncio.run(coro)
|
||||
|
||||
|
||||
|
||||
@@ -8,7 +8,6 @@ from __future__ import annotations
|
||||
import asyncio
|
||||
from collections.abc import Callable, Coroutine, Mapping
|
||||
from concurrent.futures import ThreadPoolExecutor, as_completed
|
||||
import contextvars
|
||||
from functools import wraps
|
||||
import json
|
||||
from types import MethodType
|
||||
@@ -279,9 +278,7 @@ def _fetch_agent_cards_concurrently(
|
||||
max_workers = min(len(a2a_agents), 10)
|
||||
with ThreadPoolExecutor(max_workers=max_workers) as executor:
|
||||
futures = {
|
||||
executor.submit(
|
||||
contextvars.copy_context().run, _fetch_card_from_config, config
|
||||
): config
|
||||
executor.submit(_fetch_card_from_config, config): config
|
||||
for config in a2a_agents
|
||||
}
|
||||
for future in as_completed(futures):
|
||||
|
||||
@@ -2,7 +2,6 @@ from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
from collections.abc import Callable, Coroutine, Sequence
|
||||
import contextvars
|
||||
import shutil
|
||||
import subprocess
|
||||
import time
|
||||
@@ -514,13 +513,9 @@ class Agent(BaseAgent):
|
||||
"""
|
||||
import concurrent.futures
|
||||
|
||||
ctx = contextvars.copy_context()
|
||||
with concurrent.futures.ThreadPoolExecutor() as executor:
|
||||
future = executor.submit(
|
||||
ctx.run,
|
||||
self._execute_without_timeout,
|
||||
task_prompt=task_prompt,
|
||||
task=task,
|
||||
self._execute_without_timeout, task_prompt=task_prompt, task=task
|
||||
)
|
||||
|
||||
try:
|
||||
|
||||
@@ -895,9 +895,7 @@ class CrewAgentExecutor(CrewAgentExecutorMixin):
|
||||
ToolUsageStartedEvent,
|
||||
)
|
||||
|
||||
args_dict, parse_error = parse_tool_call_args(
|
||||
func_args, func_name, call_id, original_tool
|
||||
)
|
||||
args_dict, parse_error = parse_tool_call_args(func_args, func_name, call_id, original_tool)
|
||||
if parse_error is not None:
|
||||
return parse_error
|
||||
|
||||
|
||||
@@ -2,15 +2,19 @@ from pathlib import Path
|
||||
|
||||
import click
|
||||
|
||||
from crewai_cli.utils import copy_template
|
||||
from crewai.cli.utils import copy_template
|
||||
from crewai.utilities.printer import Printer
|
||||
|
||||
|
||||
_printer = Printer()
|
||||
|
||||
|
||||
def add_crew_to_flow(crew_name: str) -> None:
|
||||
"""Add a new crew to the current flow."""
|
||||
# Check if pyproject.toml exists in the current directory
|
||||
if not Path("pyproject.toml").exists():
|
||||
click.secho(
|
||||
"This command must be run from the root of a flow project.", fg="red"
|
||||
_printer.print(
|
||||
"This command must be run from the root of a flow project.", color="red"
|
||||
)
|
||||
raise click.ClickException(
|
||||
"This command must be run from the root of a flow project."
|
||||
@@ -21,7 +25,7 @@ def add_crew_to_flow(crew_name: str) -> None:
|
||||
crews_folder = flow_folder / "src" / flow_folder.name / "crews"
|
||||
|
||||
if not crews_folder.exists():
|
||||
click.secho("Crews folder does not exist in the current flow.", fg="red")
|
||||
_printer.print("Crews folder does not exist in the current flow.", color="red")
|
||||
raise click.ClickException("Crews folder does not exist in the current flow.")
|
||||
|
||||
# Create the crew within the flow's crews directory
|
||||
@@ -180,7 +180,7 @@ class AuthenticationCommand:
|
||||
def _login_to_tool_repository(self) -> None:
|
||||
"""Login to the tool repository."""
|
||||
|
||||
from crewai_cli.tools.main import ToolCommand
|
||||
from crewai.cli.tools.main import ToolCommand
|
||||
|
||||
try:
|
||||
console.print(
|
||||
|
||||
@@ -1,5 +1,3 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from importlib.metadata import version as get_version
|
||||
import os
|
||||
import subprocess
|
||||
@@ -7,58 +5,44 @@ from typing import Any
|
||||
|
||||
import click
|
||||
|
||||
from crewai_cli.add_crew_to_flow import add_crew_to_flow
|
||||
from crewai_cli.authentication.main import AuthenticationCommand
|
||||
from crewai_cli.config import Settings
|
||||
from crewai_cli.create_crew import create_crew
|
||||
from crewai_cli.create_flow import create_flow
|
||||
from crewai_cli.crew_chat import run_chat
|
||||
from crewai_cli.deploy.main import DeployCommand
|
||||
from crewai_cli.enterprise.main import EnterpriseConfigureCommand
|
||||
from crewai_cli.evaluate_crew import evaluate_crew
|
||||
from crewai_cli.install_crew import install_crew
|
||||
from crewai_cli.kickoff_flow import kickoff_flow
|
||||
from crewai_cli.organization.main import OrganizationCommand
|
||||
from crewai_cli.plot_flow import plot_flow
|
||||
from crewai_cli.replay_from_task import replay_task_command
|
||||
from crewai_cli.reset_memories_command import reset_memories_command
|
||||
from crewai_cli.run_crew import run_crew
|
||||
from crewai_cli.settings.main import SettingsCommand
|
||||
from crewai_cli.task_outputs import load_task_outputs
|
||||
from crewai_cli.tools.main import ToolCommand
|
||||
from crewai_cli.train_crew import train_crew
|
||||
from crewai_cli.triggers.main import TriggersCommand
|
||||
from crewai_cli.update_crew import update_crew
|
||||
from crewai_cli.user_data import (
|
||||
_load_user_data,
|
||||
_save_user_data,
|
||||
is_tracing_enabled,
|
||||
from crewai.cli.add_crew_to_flow import add_crew_to_flow
|
||||
from crewai.cli.authentication.main import AuthenticationCommand
|
||||
from crewai.cli.config import Settings
|
||||
from crewai.cli.create_crew import create_crew
|
||||
from crewai.cli.create_flow import create_flow
|
||||
from crewai.cli.crew_chat import run_chat
|
||||
from crewai.cli.deploy.main import DeployCommand
|
||||
from crewai.cli.enterprise.main import EnterpriseConfigureCommand
|
||||
from crewai.cli.evaluate_crew import evaluate_crew
|
||||
from crewai.cli.install_crew import install_crew
|
||||
from crewai.cli.kickoff_flow import kickoff_flow
|
||||
from crewai.cli.organization.main import OrganizationCommand
|
||||
from crewai.cli.plot_flow import plot_flow
|
||||
from crewai.cli.replay_from_task import replay_task_command
|
||||
from crewai.cli.reset_memories_command import reset_memories_command
|
||||
from crewai.cli.run_crew import run_crew
|
||||
from crewai.cli.settings.main import SettingsCommand
|
||||
from crewai.cli.tools.main import ToolCommand
|
||||
from crewai.cli.train_crew import train_crew
|
||||
from crewai.cli.triggers.main import TriggersCommand
|
||||
from crewai.cli.update_crew import update_crew
|
||||
from crewai.cli.utils import build_env_with_tool_repository_credentials, read_toml
|
||||
from crewai.memory.storage.kickoff_task_outputs_storage import (
|
||||
KickoffTaskOutputsSQLiteStorage,
|
||||
)
|
||||
from crewai_cli.utils import build_env_with_tool_repository_credentials, read_toml
|
||||
|
||||
|
||||
def _get_cli_version() -> str:
|
||||
"""Return the best available version string for the CLI."""
|
||||
# Prefer crewai version if installed (keeps existing UX)
|
||||
try:
|
||||
return get_version("crewai")
|
||||
except Exception: # noqa: S110
|
||||
pass
|
||||
try:
|
||||
return get_version("crewai-cli")
|
||||
except Exception:
|
||||
return "unknown"
|
||||
|
||||
|
||||
@click.group()
|
||||
@click.version_option(_get_cli_version())
|
||||
@click.version_option(get_version("crewai"))
|
||||
def crewai():
|
||||
"""Top-level command group for crewai."""
|
||||
|
||||
|
||||
@crewai.command(
|
||||
name="uv",
|
||||
context_settings={"ignore_unknown_options": True},
|
||||
context_settings=dict(
|
||||
ignore_unknown_options=True,
|
||||
),
|
||||
)
|
||||
@click.argument("uv_args", nargs=-1, type=click.UNPROCESSED)
|
||||
def uv(uv_args):
|
||||
@@ -123,7 +107,7 @@ def version(tools):
|
||||
|
||||
if tools:
|
||||
try:
|
||||
tools_version = get_version("crewai-tools")
|
||||
tools_version = get_version("crewai")
|
||||
click.echo(f"crewai tools version: {tools_version}")
|
||||
except Exception:
|
||||
click.echo("crewai tools not installed")
|
||||
@@ -158,7 +142,12 @@ def train(n_iterations: int, filename: str):
|
||||
help="Replay the crew from this task ID, including all subsequent tasks.",
|
||||
)
|
||||
def replay(task_id: str) -> None:
|
||||
"""Replay the crew execution from a specific task."""
|
||||
"""
|
||||
Replay the crew execution from a specific task.
|
||||
|
||||
Args:
|
||||
task_id (str): The ID of the task to replay from.
|
||||
"""
|
||||
try:
|
||||
click.echo(f"Replaying the crew from task {task_id}")
|
||||
replay_task_command(task_id)
|
||||
@@ -168,9 +157,12 @@ def replay(task_id: str) -> None:
|
||||
|
||||
@crewai.command()
|
||||
def log_tasks_outputs() -> None:
|
||||
"""Retrieve your latest crew.kickoff() task outputs."""
|
||||
"""
|
||||
Retrieve your latest crew.kickoff() task outputs.
|
||||
"""
|
||||
try:
|
||||
tasks = load_task_outputs()
|
||||
storage = KickoffTaskOutputsSQLiteStorage()
|
||||
tasks = storage.load()
|
||||
|
||||
if not tasks:
|
||||
click.echo(
|
||||
@@ -190,24 +182,15 @@ def log_tasks_outputs() -> None:
|
||||
@crewai.command()
|
||||
@click.option("-m", "--memory", is_flag=True, help="Reset MEMORY")
|
||||
@click.option(
|
||||
"-l",
|
||||
"--long",
|
||||
is_flag=True,
|
||||
hidden=True,
|
||||
"-l", "--long", is_flag=True, hidden=True,
|
||||
help="[Deprecated: use --memory] Reset memory",
|
||||
)
|
||||
@click.option(
|
||||
"-s",
|
||||
"--short",
|
||||
is_flag=True,
|
||||
hidden=True,
|
||||
"-s", "--short", is_flag=True, hidden=True,
|
||||
help="[Deprecated: use --memory] Reset memory",
|
||||
)
|
||||
@click.option(
|
||||
"-e",
|
||||
"--entities",
|
||||
is_flag=True,
|
||||
hidden=True,
|
||||
"-e", "--entities", is_flag=True, hidden=True,
|
||||
help="[Deprecated: use --memory] Reset memory",
|
||||
)
|
||||
@click.option("-kn", "--knowledge", is_flag=True, help="Reset KNOWLEDGE storage")
|
||||
@@ -228,17 +211,14 @@ def reset_memories(
|
||||
agent_knowledge: bool,
|
||||
all: bool,
|
||||
) -> None:
|
||||
"""Reset the crew memories (memory, knowledge, agent_knowledge, kickoff_outputs). This will delete all the data saved."""
|
||||
"""
|
||||
Reset the crew memories (memory, knowledge, agent_knowledge, kickoff_outputs). This will delete all the data saved.
|
||||
"""
|
||||
try:
|
||||
# Treat legacy flags as --memory with a deprecation warning
|
||||
if long or short or entities:
|
||||
legacy_used = [
|
||||
f
|
||||
for f, v in [
|
||||
("--long", long),
|
||||
("--short", short),
|
||||
("--entities", entities),
|
||||
]
|
||||
if v
|
||||
f for f, v in [("--long", long), ("--short", short), ("--entities", entities)] if v
|
||||
]
|
||||
click.echo(
|
||||
f"Warning: {', '.join(legacy_used)} {'is' if len(legacy_used) == 1 else 'are'} "
|
||||
@@ -258,7 +238,9 @@ def reset_memories(
|
||||
"Please specify at least one memory type to reset using the appropriate flags."
|
||||
)
|
||||
return
|
||||
reset_memories_command(memory, knowledge, agent_knowledge, kickoff_outputs, all)
|
||||
reset_memories_command(
|
||||
memory, knowledge, agent_knowledge, kickoff_outputs, all
|
||||
)
|
||||
except Exception as e:
|
||||
click.echo(f"An error occurred while resetting memories: {e}", err=True)
|
||||
|
||||
@@ -296,7 +278,7 @@ def memory(
|
||||
) -> None:
|
||||
"""Open the Memory TUI to browse scopes and recall memories."""
|
||||
try:
|
||||
from crewai_cli.memory_tui import MemoryTUI
|
||||
from crewai.cli.memory_tui import MemoryTUI
|
||||
except ImportError as exc:
|
||||
click.echo(
|
||||
"Textual is required for the memory TUI but could not be imported. "
|
||||
@@ -346,10 +328,10 @@ def test(n_iterations: int, model: str):
|
||||
|
||||
|
||||
@crewai.command(
|
||||
context_settings={
|
||||
"ignore_unknown_options": True,
|
||||
"allow_extra_args": True,
|
||||
}
|
||||
context_settings=dict(
|
||||
ignore_unknown_options=True,
|
||||
allow_extra_args=True,
|
||||
)
|
||||
)
|
||||
@click.pass_context
|
||||
def install(context):
|
||||
@@ -514,12 +496,14 @@ def triggers_run(trigger_path: str):
|
||||
|
||||
@crewai.command()
|
||||
def chat():
|
||||
"""Start a conversation with the Crew, collecting user-supplied inputs,
|
||||
"""
|
||||
Start a conversation with the Crew, collecting user-supplied inputs,
|
||||
and using the Chat LLM to generate responses.
|
||||
"""
|
||||
click.secho(
|
||||
"\nStarting a conversation with the Crew\nType 'exit' or Ctrl+C to quit.\n",
|
||||
)
|
||||
|
||||
run_chat()
|
||||
|
||||
|
||||
@@ -630,7 +614,7 @@ def env_view():
|
||||
table.add_row(
|
||||
"CREWAI_TRACING_ENABLED",
|
||||
"[dim]Not set[/dim]",
|
||||
"[dim]---[/dim]",
|
||||
"[dim]—[/dim]",
|
||||
)
|
||||
|
||||
# Check other related env vars
|
||||
@@ -649,7 +633,7 @@ def env_view():
|
||||
# Check if .env file exists
|
||||
table.add_row(
|
||||
".env file",
|
||||
"Found" if env_file_exists else "Not found",
|
||||
"✅ Found" if env_file_exists else "❌ Not found",
|
||||
str(env_file.resolve()) if env_file_exists else "N/A",
|
||||
)
|
||||
|
||||
@@ -665,11 +649,11 @@ def env_view():
|
||||
# Show helpful message
|
||||
if env_file_exists:
|
||||
console.print(
|
||||
"\n[dim]Tip: To enable tracing via .env, add: CREWAI_TRACING_ENABLED=true[/dim]"
|
||||
"\n[dim]💡 Tip: To enable tracing via .env, add: CREWAI_TRACING_ENABLED=true[/dim]"
|
||||
)
|
||||
else:
|
||||
console.print(
|
||||
"\n[dim]Tip: Create a .env file in your project root and add: CREWAI_TRACING_ENABLED=true[/dim]"
|
||||
"\n[dim]💡 Tip: Create a .env file in your project root and add: CREWAI_TRACING_ENABLED=true[/dim]"
|
||||
)
|
||||
console.print()
|
||||
|
||||
@@ -685,6 +669,11 @@ def traces_enable():
|
||||
from rich.console import Console
|
||||
from rich.panel import Panel
|
||||
|
||||
from crewai.events.listeners.tracing.utils import (
|
||||
_load_user_data,
|
||||
_save_user_data,
|
||||
)
|
||||
|
||||
console = Console()
|
||||
|
||||
# Update user data to enable traces
|
||||
@@ -694,7 +683,7 @@ def traces_enable():
|
||||
_save_user_data(user_data)
|
||||
|
||||
panel = Panel(
|
||||
"Trace collection has been enabled!\n\n"
|
||||
"✅ Trace collection has been enabled!\n\n"
|
||||
"Your crew/flow executions will now send traces to CrewAI+.\n"
|
||||
"Use 'crewai traces disable' to turn off trace collection.",
|
||||
title="Traces Enabled",
|
||||
@@ -710,6 +699,11 @@ def traces_disable():
|
||||
from rich.console import Console
|
||||
from rich.panel import Panel
|
||||
|
||||
from crewai.events.listeners.tracing.utils import (
|
||||
_load_user_data,
|
||||
_save_user_data,
|
||||
)
|
||||
|
||||
console = Console()
|
||||
|
||||
# Update user data to disable traces
|
||||
@@ -719,7 +713,7 @@ def traces_disable():
|
||||
_save_user_data(user_data)
|
||||
|
||||
panel = Panel(
|
||||
"Trace collection has been disabled!\n\n"
|
||||
"❌ Trace collection has been disabled!\n\n"
|
||||
"Your crew/flow executions will no longer send traces.\n"
|
||||
"Use 'crewai traces enable' to turn trace collection back on.",
|
||||
title="Traces Disabled",
|
||||
@@ -738,6 +732,11 @@ def traces_status():
|
||||
from rich.panel import Panel
|
||||
from rich.table import Table
|
||||
|
||||
from crewai.events.listeners.tracing.utils import (
|
||||
_load_user_data,
|
||||
is_tracing_enabled,
|
||||
)
|
||||
|
||||
console = Console()
|
||||
user_data = _load_user_data()
|
||||
|
||||
@@ -752,19 +751,19 @@ def traces_status():
|
||||
# Check user consent
|
||||
trace_consent = user_data.get("trace_consent")
|
||||
if trace_consent is True:
|
||||
consent_status = "Enabled (user consented)"
|
||||
consent_status = "✅ Enabled (user consented)"
|
||||
elif trace_consent is False:
|
||||
consent_status = "Disabled (user declined)"
|
||||
consent_status = "❌ Disabled (user declined)"
|
||||
else:
|
||||
consent_status = "Not set (first-time user)"
|
||||
consent_status = "⚪ Not set (first-time user)"
|
||||
table.add_row("User Consent", consent_status)
|
||||
|
||||
# Check overall status
|
||||
if is_tracing_enabled():
|
||||
overall_status = "ENABLED"
|
||||
overall_status = "✅ ENABLED"
|
||||
border_style = "green"
|
||||
else:
|
||||
overall_status = "DISABLED"
|
||||
overall_status = "❌ DISABLED"
|
||||
border_style = "red"
|
||||
table.add_row("Overall Status", overall_status)
|
||||
|
||||
@@ -5,13 +5,13 @@ import sys
|
||||
import click
|
||||
import tomli
|
||||
|
||||
from crewai_cli.constants import ENV_VARS, MODELS
|
||||
from crewai_cli.provider import (
|
||||
from crewai.cli.constants import ENV_VARS, MODELS
|
||||
from crewai.cli.provider import (
|
||||
get_provider_data,
|
||||
select_model,
|
||||
select_provider,
|
||||
)
|
||||
from crewai_cli.utils import copy_template, load_env_vars, write_env_file
|
||||
from crewai.cli.utils import copy_template, load_env_vars, write_env_file
|
||||
|
||||
|
||||
def get_reserved_script_names() -> set[str]:
|
||||
@@ -3,6 +3,8 @@ import shutil
|
||||
|
||||
import click
|
||||
|
||||
from crewai.telemetry import Telemetry
|
||||
|
||||
|
||||
def create_flow(name):
|
||||
"""Create a new flow."""
|
||||
@@ -16,6 +18,10 @@ def create_flow(name):
|
||||
click.secho(f"Error: Folder {folder_name} already exists.", fg="red")
|
||||
return
|
||||
|
||||
# Initialize telemetry
|
||||
telemetry = Telemetry()
|
||||
telemetry.flow_creation_span(class_name)
|
||||
|
||||
# Create directory structure
|
||||
(project_root / "src" / folder_name).mkdir(parents=True)
|
||||
(project_root / "src" / folder_name / "crews").mkdir(parents=True)
|
||||
@@ -1,4 +1,3 @@
|
||||
import contextvars
|
||||
import json
|
||||
from pathlib import Path
|
||||
import platform
|
||||
@@ -81,10 +80,7 @@ def run_chat() -> None:
|
||||
|
||||
# Start loading indicator
|
||||
loading_complete = threading.Event()
|
||||
ctx = contextvars.copy_context()
|
||||
loading_thread = threading.Thread(
|
||||
target=ctx.run, args=(show_loading, loading_complete)
|
||||
)
|
||||
loading_thread = threading.Thread(target=show_loading, args=(loading_complete,))
|
||||
loading_thread.start()
|
||||
|
||||
try:
|
||||
|
||||
@@ -1,11 +1,10 @@
|
||||
from pathlib import Path
|
||||
from typing import Any
|
||||
|
||||
from rich.console import Console
|
||||
|
||||
from crewai_cli import git
|
||||
from crewai_cli.command import BaseCommand, PlusAPIMixin
|
||||
from crewai_cli.utils import fetch_and_json_env_file, get_project_name
|
||||
from crewai.cli import git
|
||||
from crewai.cli.command import BaseCommand, PlusAPIMixin
|
||||
from crewai.cli.utils import fetch_and_json_env_file, get_project_name
|
||||
|
||||
|
||||
console = Console()
|
||||
@@ -22,43 +21,8 @@ class DeployCommand(BaseCommand, PlusAPIMixin):
|
||||
"""
|
||||
|
||||
BaseCommand.__init__(self)
|
||||
PlusAPIMixin.__init__(self)
|
||||
PlusAPIMixin.__init__(self, telemetry=self._telemetry)
|
||||
self.project_name = get_project_name(require=True)
|
||||
self._validate_project_structure()
|
||||
|
||||
def _validate_project_structure(self) -> None:
|
||||
"""Validate that the local project has the files required for deployment."""
|
||||
errors: list[str] = []
|
||||
|
||||
if not Path("pyproject.toml").exists():
|
||||
errors.append("Cannot find pyproject.toml in the current directory.")
|
||||
|
||||
has_lockfile = Path("uv.lock").exists() or Path("poetry.lock").exists()
|
||||
if not has_lockfile:
|
||||
errors.append(
|
||||
"No uv.lock or poetry.lock found. "
|
||||
"Run 'uv lock' or 'poetry lock' to generate one."
|
||||
)
|
||||
|
||||
src_dir = Path("src") / (self.project_name or "")
|
||||
crew_py = src_dir / "crew.py"
|
||||
config_dir = src_dir / "config"
|
||||
if not crew_py.exists() and not config_dir.exists():
|
||||
errors.append(
|
||||
f"Cannot find src/{self.project_name}/crew.py or "
|
||||
f"src/{self.project_name}/config. "
|
||||
"Ensure you are running this command from the project root."
|
||||
)
|
||||
|
||||
if errors:
|
||||
console.print(
|
||||
"\n[bold red]Pre-flight check failed:[/bold red] "
|
||||
"Your project is missing required files for deployment.\n"
|
||||
)
|
||||
for error in errors:
|
||||
console.print(f" • {error}", style="red")
|
||||
console.print()
|
||||
raise SystemExit(1)
|
||||
|
||||
def _standard_no_param_error_message(self) -> None:
|
||||
"""
|
||||
@@ -103,6 +67,7 @@ class DeployCommand(BaseCommand, PlusAPIMixin):
|
||||
Args:
|
||||
uuid (Optional[str]): The UUID of the crew to deploy.
|
||||
"""
|
||||
self._start_deployment_span = self._telemetry.start_deployment_span(uuid)
|
||||
console.print("Starting deployment...", style="bold blue")
|
||||
if uuid:
|
||||
response = self.plus_api_client.deploy_by_uuid(uuid)
|
||||
@@ -119,6 +84,9 @@ class DeployCommand(BaseCommand, PlusAPIMixin):
|
||||
"""
|
||||
Create a new crew deployment.
|
||||
"""
|
||||
self._create_crew_deployment_span = (
|
||||
self._telemetry.create_crew_deployment_span()
|
||||
)
|
||||
console.print("Creating deployment...", style="bold blue")
|
||||
env_vars = fetch_and_json_env_file()
|
||||
|
||||
@@ -268,6 +236,7 @@ class DeployCommand(BaseCommand, PlusAPIMixin):
|
||||
uuid (Optional[str]): The UUID of the crew to get logs for.
|
||||
log_type (str): The type of logs to retrieve (default: "deployment").
|
||||
"""
|
||||
self._get_crew_logs_span = self._telemetry.get_crew_logs_span(uuid, log_type)
|
||||
console.print(f"Fetching {log_type} logs...", style="bold blue")
|
||||
|
||||
if uuid:
|
||||
@@ -288,6 +257,7 @@ class DeployCommand(BaseCommand, PlusAPIMixin):
|
||||
Args:
|
||||
uuid (Optional[str]): The UUID of the crew to remove.
|
||||
"""
|
||||
self._remove_crew_span = self._telemetry.remove_crew_span(uuid)
|
||||
console.print("Removing deployment...", style="bold blue")
|
||||
|
||||
if uuid:
|
||||
@@ -4,10 +4,10 @@ from typing import Any, cast
|
||||
import httpx
|
||||
from rich.console import Console
|
||||
|
||||
from crewai_cli.authentication.main import Oauth2Settings, ProviderFactory
|
||||
from crewai_cli.command import BaseCommand
|
||||
from crewai_cli.settings.main import SettingsCommand
|
||||
from crewai_cli.version import get_crewai_version
|
||||
from crewai.cli.authentication.main import Oauth2Settings, ProviderFactory
|
||||
from crewai.cli.command import BaseCommand
|
||||
from crewai.cli.settings.main import SettingsCommand
|
||||
from crewai.cli.version import get_crewai_version
|
||||
|
||||
|
||||
console = Console()
|
||||
@@ -125,19 +125,13 @@ class MemoryTUI(App[None]):
|
||||
from crewai.memory.storage.lancedb_storage import LanceDBStorage
|
||||
from crewai.memory.unified_memory import Memory
|
||||
|
||||
storage = (
|
||||
LanceDBStorage(path=storage_path) if storage_path else LanceDBStorage()
|
||||
)
|
||||
storage = LanceDBStorage(path=storage_path) if storage_path else LanceDBStorage()
|
||||
embedder = None
|
||||
if embedder_config is not None:
|
||||
from crewai.rag.embeddings.factory import build_embedder
|
||||
|
||||
embedder = build_embedder(embedder_config)
|
||||
self._memory = (
|
||||
Memory(storage=storage, embedder=embedder)
|
||||
if embedder
|
||||
else Memory(storage=storage)
|
||||
)
|
||||
self._memory = Memory(storage=storage, embedder=embedder) if embedder else Memory(storage=storage)
|
||||
except Exception as e:
|
||||
self._init_error = str(e)
|
||||
|
||||
@@ -206,7 +200,11 @@ class MemoryTUI(App[None]):
|
||||
if len(record.content) > 80
|
||||
else record.content
|
||||
)
|
||||
label = f"{date_str} [bold]{record.importance:.1f}[/] {preview}"
|
||||
label = (
|
||||
f"{date_str} "
|
||||
f"[bold]{record.importance:.1f}[/] "
|
||||
f"{preview}"
|
||||
)
|
||||
option_list.add_option(label)
|
||||
|
||||
def _populate_recall_list(self) -> None:
|
||||
@@ -222,7 +220,9 @@ class MemoryTUI(App[None]):
|
||||
else m.record.content
|
||||
)
|
||||
label = (
|
||||
f"[bold]\\[{m.score:.2f}][/] {preview} [dim]scope={m.record.scope}[/]"
|
||||
f"[bold]\\[{m.score:.2f}][/] "
|
||||
f"{preview} "
|
||||
f"[dim]scope={m.record.scope}[/]"
|
||||
)
|
||||
option_list.add_option(label)
|
||||
|
||||
@@ -251,7 +251,8 @@ class MemoryTUI(App[None]):
|
||||
lines.append(f"[dim]Scope:[/] [bold]{record.scope}[/]")
|
||||
lines.append(f"[dim]Importance:[/] [bold]{record.importance:.2f}[/]")
|
||||
lines.append(
|
||||
f"[dim]Created:[/] {record.created_at.strftime('%Y-%m-%d %H:%M:%S')}"
|
||||
f"[dim]Created:[/] "
|
||||
f"{record.created_at.strftime('%Y-%m-%d %H:%M:%S')}"
|
||||
)
|
||||
lines.append(
|
||||
f"[dim]Last accessed:[/] "
|
||||
@@ -361,11 +362,17 @@ class MemoryTUI(App[None]):
|
||||
panel = self.query_one("#info-panel", Static)
|
||||
panel.loading = True
|
||||
try:
|
||||
scope = self._selected_scope if self._selected_scope != "/" else None
|
||||
scope = (
|
||||
self._selected_scope
|
||||
if self._selected_scope != "/"
|
||||
else None
|
||||
)
|
||||
loop = asyncio.get_event_loop()
|
||||
matches = await loop.run_in_executor(
|
||||
None,
|
||||
lambda: self._memory.recall(query, scope=scope, limit=10, depth="deep"),
|
||||
lambda: self._memory.recall(
|
||||
query, scope=scope, limit=10, depth="deep"
|
||||
),
|
||||
)
|
||||
self._recall_matches = matches or []
|
||||
self._view_mode = "recall"
|
||||
@@ -2,8 +2,8 @@ from httpx import HTTPStatusError
|
||||
from rich.console import Console
|
||||
from rich.table import Table
|
||||
|
||||
from crewai_cli.command import BaseCommand, PlusAPIMixin
|
||||
from crewai_cli.config import Settings
|
||||
from crewai.cli.command import BaseCommand, PlusAPIMixin
|
||||
from crewai.cli.config import Settings
|
||||
|
||||
|
||||
console = Console()
|
||||
@@ -12,7 +12,7 @@ console = Console()
|
||||
class OrganizationCommand(BaseCommand, PlusAPIMixin):
|
||||
def __init__(self) -> None:
|
||||
BaseCommand.__init__(self)
|
||||
PlusAPIMixin.__init__(self)
|
||||
PlusAPIMixin.__init__(self, telemetry=self._telemetry)
|
||||
|
||||
def list(self) -> None:
|
||||
try:
|
||||
@@ -95,7 +95,9 @@ def reset_memories_command(
|
||||
continue
|
||||
if memory:
|
||||
_reset_flow_memory(flow)
|
||||
click.echo(f"[Flow ({flow_name})] Memory has been reset.")
|
||||
click.echo(
|
||||
f"[Flow ({flow_name})] Memory has been reset."
|
||||
)
|
||||
|
||||
except subprocess.CalledProcessError as e:
|
||||
click.echo(f"An error occurred while resetting the memories: {e}", err=True)
|
||||
|
||||
@@ -5,8 +5,8 @@ import subprocess
|
||||
import click
|
||||
from packaging import version
|
||||
|
||||
from crewai_cli.utils import build_env_with_tool_repository_credentials, read_toml
|
||||
from crewai_cli.version import get_crewai_version
|
||||
from crewai.cli.utils import build_env_with_tool_repository_credentials, read_toml
|
||||
from crewai.cli.version import get_crewai_version
|
||||
|
||||
|
||||
class CrewType(Enum):
|
||||
@@ -5,9 +5,9 @@ from typing import Any
|
||||
from rich.console import Console
|
||||
from rich.table import Table
|
||||
|
||||
from crewai_cli.command import BaseCommand
|
||||
from crewai_cli.config import HIDDEN_SETTINGS_KEYS, READONLY_SETTINGS_KEYS, Settings
|
||||
from crewai_cli.user_data import _load_user_data
|
||||
from crewai.cli.command import BaseCommand
|
||||
from crewai.cli.config import HIDDEN_SETTINGS_KEYS, READONLY_SETTINGS_KEYS, Settings
|
||||
from crewai.events.listeners.tracing.utils import _load_user_data
|
||||
|
||||
|
||||
console = Console()
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user