Compare commits

..

19 Commits

Author SHA1 Message Date
github-actions[bot]
2c78e60f56 chore: update tool specifications 2026-03-10 17:32:23 +00:00
Lorenze Jay
8e336a476f Merge branch 'main' into lorenze/feat/grep-tool 2026-03-10 10:31:04 -07:00
github-actions[bot]
2d0e81c10d chore: update tool specifications 2026-02-17 22:35:58 +00:00
Lorenze Jay
c8dd6c006c Merge branch 'main' into lorenze/feat/grep-tool 2026-02-17 14:34:36 -08:00
lorenzejay
73f44c878d Merge branch 'lorenze/feat/grep-tool' of github.com:crewAIInc/crewAI into lorenze/feat/grep-tool 2026-02-12 10:29:58 -08:00
lorenzejay
364143a682 fix test 2026-02-12 10:29:46 -08:00
github-actions[bot]
f894d8cf9d chore: update tool specifications 2026-02-12 18:29:36 +00:00
lorenzejay
1f0265781a Merge branch 'lorenze/feat/grep-tool' of github.com:crewAIInc/crewAI into lorenze/feat/grep-tool 2026-02-12 10:28:16 -08:00
lorenzejay
9fae6c0adf feat: enhance GrepTool with sensitive file exclusion and file size limit
- Added MAX_CONTEXT_LINES to define the upper limit for context lines shown in search results.
- Introduced MAX_FILE_SIZE_BYTES to skip files larger than 10 MB during searches.
- Implemented logic to exclude sensitive files (e.g., .env, .netrc) from search results to prevent accidental leakage of credentials.
- Updated tests to validate sensitive file exclusion and file size limits, ensuring robustness in handling sensitive content.
2026-02-12 10:27:24 -08:00
Lorenze Jay
dea2e1e715 Merge branch 'main' into lorenze/feat/grep-tool 2026-02-12 09:24:15 -08:00
github-actions[bot]
b97fc83656 chore: update tool specifications 2026-02-12 04:47:03 +00:00
lorenzejay
925ed7850e linted 2026-02-11 20:45:40 -08:00
lorenzejay
ec2b6a0287 feat: enhance GrepTool with regex length limit, path restrictions, and brace expansion support
- Added MAX_REGEX_LENGTH to limit regex pattern length and prevent ReDoS.
- Introduced allow_unrestricted_paths option to enable searching outside the current working directory.
- Implemented brace expansion for glob patterns to support multiple file types.
- Enhanced error handling for path traversal and regex compilation.
- Updated tests to cover new features and ensure robustness.
2026-02-11 20:44:46 -08:00
Lorenze Jay
25835ca795 Merge branch 'main' into lorenze/feat/grep-tool 2026-02-11 14:23:35 -08:00
Lorenze Jay
e65940816b Merge branch 'main' into lorenze/feat/grep-tool 2026-02-09 11:28:49 -08:00
Lorenze Jay
ad2435f5c1 Merge branch 'main' into lorenze/feat/grep-tool 2026-02-05 12:02:33 -08:00
github-actions[bot]
c9971a7418 chore: update tool specifications 2026-02-04 19:52:01 +00:00
lorenzejay
f04bedc9ab moved to tools 2026-02-04 11:50:43 -08:00
Lorenze Jay
5a14007511 native support for grep 2026-02-04 10:28:35 -08:00
243 changed files with 9408 additions and 12983 deletions

View File

@@ -55,7 +55,6 @@ jobs:
echo "${{ steps.changed-files.outputs.files }}" \
| tr ' ' '\n' \
| grep -v 'src/crewai/cli/templates/' \
| grep -v 'src/crewai_cli/templates/' \
| grep -v '/tests/' \
| xargs -I{} uv run ruff check "{}"

View File

@@ -59,8 +59,6 @@ jobs:
contents: read
steps:
- uses: actions/checkout@v4
with:
ref: ${{ inputs.release_tag || github.ref }}
- name: Install uv
uses: astral-sh/setup-uv@v6
@@ -95,72 +93,3 @@ jobs:
echo "Some packages failed to publish"
exit 1
fi
- name: Build Slack payload
if: success()
id: slack
env:
GH_TOKEN: ${{ secrets.GITHUB_TOKEN }}
RELEASE_TAG: ${{ inputs.release_tag }}
run: |
payload=$(uv run python -c "
import json, re, subprocess, sys
with open('lib/crewai/src/crewai/__init__.py') as f:
m = re.search(r\"__version__\s*=\s*[\\\"']([^\\\"']+)\", f.read())
version = m.group(1) if m else 'unknown'
import os
tag = os.environ.get('RELEASE_TAG') or version
try:
r = subprocess.run(['gh','release','view',tag,'--json','body','-q','.body'],
capture_output=True, text=True, check=True)
body = r.stdout.strip()
except Exception:
body = ''
blocks = [
{'type':'section','text':{'type':'mrkdwn',
'text':f':rocket: \`crewai v{version}\` published to PyPI'}},
{'type':'section','text':{'type':'mrkdwn',
'text':f'<https://pypi.org/project/crewai/{version}/|View on PyPI> · <https://github.com/crewAIInc/crewAI/releases/tag/{tag}|Release notes>'}},
{'type':'divider'},
]
if body:
heading, items = '', []
for line in body.split('\n'):
line = line.strip()
if not line: continue
hm = re.match(r'^#{2,3}\s+(.*)', line)
if hm:
if heading and items:
skip = heading in ('What\\'s Changed','') or 'Contributors' in heading
if not skip:
txt = f'*{heading}*\n' + '\n'.join(f'• {i}' for i in items)
blocks.append({'type':'section','text':{'type':'mrkdwn','text':txt}})
heading, items = hm.group(1), []
elif line.startswith('- ') or line.startswith('* '):
items.append(re.sub(r'\*\*([^*]*)\*\*', r'*\1*', line[2:]))
if heading and items:
skip = heading in ('What\\'s Changed','') or 'Contributors' in heading
if not skip:
txt = f'*{heading}*\n' + '\n'.join(f'• {i}' for i in items)
blocks.append({'type':'section','text':{'type':'mrkdwn','text':txt}})
blocks.append({'type':'divider'})
blocks.append({'type':'section','text':{'type':'mrkdwn',
'text':f'\`\`\`uv add \"crewai[tools]=={version}\"\`\`\`'}})
print(json.dumps({'blocks':blocks}))
")
echo "payload=$payload" >> $GITHUB_OUTPUT
- name: Notify Slack
if: success()
uses: slackapi/slack-github-action@v2.1.0
with:
webhook: ${{ secrets.SLACK_WEBHOOK_URL }}
webhook-type: incoming-webhook
payload: ${{ steps.slack.outputs.payload }}

View File

@@ -19,7 +19,7 @@ repos:
language: system
pass_filenames: true
types: [python]
exclude: ^(lib/crewai/src/crewai/cli/templates/|lib/cli/|lib/crewai/tests/|lib/crewai-tools/tests/|lib/crewai-files/tests/)
exclude: ^(lib/crewai/src/crewai/cli/templates/|lib/crewai/tests/|lib/crewai-tools/tests/|lib/crewai-files/tests/)
- repo: https://github.com/astral-sh/uv-pre-commit
rev: 0.9.3
hooks:

View File

@@ -226,7 +226,7 @@ def vcr_cassette_dir(request: Any) -> str:
for parent in test_file.parents:
if (
parent.name in ("crewai", "crewai-tools", "crewai-files", "cli")
parent.name in ("crewai", "crewai-tools", "crewai-files")
and parent.parent.name == "lib"
):
package_root = parent

View File

@@ -4,82 +4,6 @@ description: "Product updates, improvements, and bug fixes for CrewAI"
icon: "clock"
mode: "wide"
---
<Update label="Mar 14, 2026">
## v1.10.2rc2
[View release on GitHub](https://github.com/crewAIInc/crewAI/releases/tag/1.10.2rc2)
## What's Changed
### Bug Fixes
- Remove exclusive locks from read-only storage operations
### Documentation
- Update changelog and version for v1.10.2rc1
## Contributors
@greysonlalonde
</Update>
<Update label="Mar 13, 2026">
## v1.10.2rc1
[View release on GitHub](https://github.com/crewAIInc/crewAI/releases/tag/1.10.2rc1)
## What's Changed
### Features
- Add release command and trigger PyPI publish
### Bug Fixes
- Fix cross-process and thread-safe locking to unprotected I/O
- Propagate contextvars across all thread and executor boundaries
- Propagate ContextVars into async task threads
### Documentation
- Update changelog and version for v1.10.2a1
## Contributors
@danglies007, @greysonlalonde
</Update>
<Update label="Mar 11, 2026">
## v1.10.2a1
[View release on GitHub](https://github.com/crewAIInc/crewAI/releases/tag/1.10.2a1)
## What's Changed
### Features
- Add support for tool search, saving tokens, and dynamically injecting appropriate tools during execution for Anthropics.
- Introduce more Brave Search tools.
- Create action for nightly releases.
### Bug Fixes
- Fix LockException under concurrent multi-process execution.
- Resolve issues with grouping parallel tool results in a single user message.
- Address MCP tools resolutions and eliminate all shared mutable connections.
- Update LLM parameter handling in the human_feedback function.
- Add missing list/dict methods to LockedListProxy and LockedDictProxy.
- Propagate contextvars context to parallel tool call threads.
- Bump gitpython dependency to >=3.1.41 to resolve CVE path traversal vulnerability.
### Refactoring
- Refactor memory classes to be serializable.
### Documentation
- Update changelog and version for v1.10.1.
## Contributors
@akaKuruma, @github-actions[bot], @giulio-leone, @greysonlalonde, @joaomdmoura, @jonathansampson, @lorenzejay, @lucasgomide, @mattatcha
</Update>
<Update label="Mar 04, 2026">
## v1.10.1

View File

@@ -4,82 +4,6 @@ description: "CrewAI의 제품 업데이트, 개선 사항 및 버그 수정"
icon: "clock"
mode: "wide"
---
<Update label="2026년 3월 14일">
## v1.10.2rc2
[GitHub 릴리스 보기](https://github.com/crewAIInc/crewAI/releases/tag/1.10.2rc2)
## 변경 사항
### 버그 수정
- 읽기 전용 스토리지 작업에서 독점 잠금 제거
### 문서
- v1.10.2rc1에 대한 변경 로그 및 버전 업데이트
## 기여자
@greysonlalonde
</Update>
<Update label="2026년 3월 13일">
## v1.10.2rc1
[GitHub 릴리스 보기](https://github.com/crewAIInc/crewAI/releases/tag/1.10.2rc1)
## 변경 사항
### 기능
- 릴리스 명령 추가 및 PyPI 게시 트리거
### 버그 수정
- 보호되지 않은 I/O에 대한 프로세스 간 및 스레드 안전 잠금 수정
- 모든 스레드 및 실행기 경계를 넘는 contextvars 전파
- async 작업 스레드로 ContextVars 전파
### 문서
- v1.10.2a1에 대한 변경 로그 및 버전 업데이트
## 기여자
@danglies007, @greysonlalonde
</Update>
<Update label="2026년 3월 11일">
## v1.10.2a1
[GitHub 릴리스 보기](https://github.com/crewAIInc/crewAI/releases/tag/1.10.2a1)
## 변경 사항
### 기능
- Anthropics에 대한 도구 검색 지원 추가, 토큰 저장, 실행 중 적절한 도구를 동적으로 주입하는 기능 추가.
- 더 많은 Brave Search 도구 도입.
- 야간 릴리스를 위한 액션 생성.
### 버그 수정
- 동시 다중 프로세스 실행 중 LockException 수정.
- 단일 사용자 메시지에서 병렬 도구 결과 그룹화 문제 해결.
- MCP 도구 해상도 문제 해결 및 모든 공유 가변 연결 제거.
- human_feedback 함수에서 LLM 매개변수 처리 업데이트.
- LockedListProxy 및 LockedDictProxy에 누락된 list/dict 메서드 추가.
- 병렬 도구 호출 스레드에 contextvars 컨텍스트 전파.
- CVE 경로 탐색 취약점을 해결하기 위해 gitpython 의존성을 >=3.1.41로 업데이트.
### 리팩토링
- 메모리 클래스를 직렬화 가능하도록 리팩토링.
### 문서
- v1.10.1에 대한 변경 로그 및 버전 업데이트.
## 기여자
@akaKuruma, @github-actions[bot], @giulio-leone, @greysonlalonde, @joaomdmoura, @jonathansampson, @lorenzejay, @lucasgomide, @mattatcha
</Update>
<Update label="2026년 3월 4일">
## v1.10.1

View File

@@ -4,82 +4,6 @@ description: "Atualizações de produto, melhorias e correções do CrewAI"
icon: "clock"
mode: "wide"
---
<Update label="14 mar 2026">
## v1.10.2rc2
[Ver release no GitHub](https://github.com/crewAIInc/crewAI/releases/tag/1.10.2rc2)
## O que Mudou
### Correções de Bugs
- Remover bloqueios exclusivos de operações de armazenamento somente leitura
### Documentação
- Atualizar changelog e versão para v1.10.2rc1
## Contribuidores
@greysonlalonde
</Update>
<Update label="13 mar 2026">
## v1.10.2rc1
[Ver release no GitHub](https://github.com/crewAIInc/crewAI/releases/tag/1.10.2rc1)
## O que Mudou
### Funcionalidades
- Adicionar comando de lançamento e acionar publicação no PyPI
### Correções de Bugs
- Corrigir bloqueio seguro entre processos e threads para I/O não protegido
- Propagar contextvars através de todos os limites de thread e executor
- Propagar ContextVars para threads de tarefas assíncronas
### Documentação
- Atualizar changelog e versão para v1.10.2a1
## Contribuidores
@danglies007, @greysonlalonde
</Update>
<Update label="11 mar 2026">
## v1.10.2a1
[Ver release no GitHub](https://github.com/crewAIInc/crewAI/releases/tag/1.10.2a1)
## O que mudou
### Recursos
- Adicionar suporte para busca de ferramentas, salvamento de tokens e injeção dinâmica de ferramentas apropriadas durante a execução para Anthropics.
- Introduzir mais ferramentas de Busca Brave.
- Criar ação para lançamentos noturnos.
### Correções de Bugs
- Corrigir LockException durante a execução concorrente de múltiplos processos.
- Resolver problemas com a agrupação de resultados de ferramentas paralelas em uma única mensagem de usuário.
- Abordar resoluções de ferramentas MCP e eliminar todas as conexões mutáveis compartilhadas.
- Atualizar o manuseio de parâmetros LLM na função human_feedback.
- Adicionar métodos de lista/dicionário ausentes a LockedListProxy e LockedDictProxy.
- Propagar o contexto de contextvars para as threads de chamada de ferramentas paralelas.
- Atualizar a dependência gitpython para >=3.1.41 para resolver a vulnerabilidade de travessia de diretórios CVE.
### Refatoração
- Refatorar classes de memória para serem serializáveis.
### Documentação
- Atualizar o changelog e a versão para v1.10.1.
## Contribuidores
@akaKuruma, @github-actions[bot], @giulio-leone, @greysonlalonde, @joaomdmoura, @jonathansampson, @lorenzejay, @lucasgomide, @mattatcha
</Update>
<Update label="04 mar 2026">
## v1.10.1

View File

@@ -1,15 +0,0 @@
# crewai-cli
CLI for CrewAI - scaffold, run, deploy and manage AI agent crews without installing the full framework.
## Installation
```bash
pip install crewai-cli
```
Or install alongside the full framework:
```bash
pip install crewai[cli]
```

View File

@@ -1,39 +0,0 @@
[project]
name = "crewai-cli"
version = "1.10.0"
description = "CLI for CrewAI - scaffold, run, deploy and manage AI agent crews without installing the full framework."
readme = "README.md"
authors = [
{ name = "Joao Moura", email = "joao@crewai.com" }
]
requires-python = ">=3.10, <3.14"
dependencies = [
"click~=8.1.7",
"pydantic~=2.11.9",
"pydantic-settings~=2.10.1",
"appdirs~=1.4.4",
"httpx~=0.28.1",
"pyjwt>=2.9.0,<3",
"rich>=13.7.1",
"tomli~=2.0.2",
"tomli-w~=1.1.0",
"packaging>=23.0",
"python-dotenv~=1.1.1",
"uv~=0.9.13",
"portalocker~=2.7.0",
]
[project.urls]
Homepage = "https://crewai.com"
Documentation = "https://docs.crewai.com"
Repository = "https://github.com/crewAIInc/crewAI"
[project.scripts]
crewai = "crewai_cli.cli:crewai"
[build-system]
requires = ["hatchling"]
build-backend = "hatchling.build"
[tool.hatch.build.targets.wheel]
packages = ["src/crewai_cli"]

View File

@@ -1 +0,0 @@
__version__ = "1.10.0"

View File

@@ -1,4 +0,0 @@
from crewai_cli.authentication.main import AuthenticationCommand
__all__ = ["AuthenticationCommand"]

View File

@@ -1,23 +0,0 @@
"""Wrapper for the crew chat command.
Delegates to ``crewai.utilities.crew_chat.run_chat`` when the full crewai
package is installed, otherwise prints a helpful error message.
"""
from __future__ import annotations
import click
def run_chat() -> None:
try:
from crewai.utilities.crew_chat import run_chat as _run_chat
except ImportError:
click.secho(
"The 'chat' command requires the full crewai package.\n"
"Install it with: pip install crewai",
fg="red",
)
raise SystemExit(1) from None
_run_chat()

View File

@@ -1,210 +0,0 @@
import os
from typing import Any
from urllib.parse import urljoin
import httpx
from crewai_cli.config import Settings
from crewai_cli.constants import DEFAULT_CREWAI_ENTERPRISE_URL
from crewai_cli.version import get_crewai_version
class PlusAPI:
"""
This class exposes methods for working with the CrewAI+ API.
"""
TOOLS_RESOURCE = "/crewai_plus/api/v1/tools"
ORGANIZATIONS_RESOURCE = "/crewai_plus/api/v1/me/organizations"
CREWS_RESOURCE = "/crewai_plus/api/v1/crews"
AGENTS_RESOURCE = "/crewai_plus/api/v1/agents"
TRACING_RESOURCE = "/crewai_plus/api/v1/tracing"
EPHEMERAL_TRACING_RESOURCE = "/crewai_plus/api/v1/tracing/ephemeral"
INTEGRATIONS_RESOURCE = "/crewai_plus/api/v1/integrations"
def __init__(self, api_key: str) -> None:
self.api_key = api_key
self.headers = {
"Authorization": f"Bearer {api_key}",
"Content-Type": "application/json",
"User-Agent": f"CrewAI-CLI/{get_crewai_version()}",
"X-Crewai-Version": get_crewai_version(),
}
settings = Settings()
if settings.org_uuid:
self.headers["X-Crewai-Organization-Id"] = settings.org_uuid
self.base_url = (
os.getenv("CREWAI_PLUS_URL")
or str(settings.enterprise_base_url)
or DEFAULT_CREWAI_ENTERPRISE_URL
)
def _make_request(
self, method: str, endpoint: str, **kwargs: Any
) -> httpx.Response:
url = urljoin(self.base_url, endpoint)
verify = kwargs.pop("verify", True)
with httpx.Client(trust_env=False, verify=verify) as client:
return client.request(method, url, headers=self.headers, **kwargs)
def login_to_tool_repository(self) -> httpx.Response:
return self._make_request("POST", f"{self.TOOLS_RESOURCE}/login")
def get_tool(self, handle: str) -> httpx.Response:
return self._make_request("GET", f"{self.TOOLS_RESOURCE}/{handle}")
async def get_agent(self, handle: str) -> httpx.Response:
url = urljoin(self.base_url, f"{self.AGENTS_RESOURCE}/{handle}")
async with httpx.AsyncClient() as client:
return await client.get(url, headers=self.headers)
def publish_tool(
self,
handle: str,
is_public: bool,
version: str,
description: str | None,
encoded_file: str,
available_exports: list[dict[str, Any]] | None = None,
) -> httpx.Response:
params = {
"handle": handle,
"public": is_public,
"version": version,
"file": encoded_file,
"description": description,
"available_exports": available_exports,
}
return self._make_request("POST", f"{self.TOOLS_RESOURCE}", json=params)
def deploy_by_name(self, project_name: str) -> httpx.Response:
return self._make_request(
"POST", f"{self.CREWS_RESOURCE}/by-name/{project_name}/deploy"
)
def deploy_by_uuid(self, uuid: str) -> httpx.Response:
return self._make_request("POST", f"{self.CREWS_RESOURCE}/{uuid}/deploy")
def crew_status_by_name(self, project_name: str) -> httpx.Response:
return self._make_request(
"GET", f"{self.CREWS_RESOURCE}/by-name/{project_name}/status"
)
def crew_status_by_uuid(self, uuid: str) -> httpx.Response:
return self._make_request("GET", f"{self.CREWS_RESOURCE}/{uuid}/status")
def crew_by_name(
self, project_name: str, log_type: str = "deployment"
) -> httpx.Response:
return self._make_request(
"GET", f"{self.CREWS_RESOURCE}/by-name/{project_name}/logs/{log_type}"
)
def crew_by_uuid(self, uuid: str, log_type: str = "deployment") -> httpx.Response:
return self._make_request(
"GET", f"{self.CREWS_RESOURCE}/{uuid}/logs/{log_type}"
)
def delete_crew_by_name(self, project_name: str) -> httpx.Response:
return self._make_request(
"DELETE", f"{self.CREWS_RESOURCE}/by-name/{project_name}"
)
def delete_crew_by_uuid(self, uuid: str) -> httpx.Response:
return self._make_request("DELETE", f"{self.CREWS_RESOURCE}/{uuid}")
def list_crews(self) -> httpx.Response:
return self._make_request("GET", self.CREWS_RESOURCE)
def create_crew(self, payload: dict[str, Any]) -> httpx.Response:
return self._make_request("POST", self.CREWS_RESOURCE, json=payload)
def get_organizations(self) -> httpx.Response:
return self._make_request("GET", self.ORGANIZATIONS_RESOURCE)
def initialize_trace_batch(self, payload: dict[str, Any]) -> httpx.Response:
return self._make_request(
"POST",
f"{self.TRACING_RESOURCE}/batches",
json=payload,
timeout=30,
)
def initialize_ephemeral_trace_batch(
self, payload: dict[str, Any]
) -> httpx.Response:
return self._make_request(
"POST",
f"{self.EPHEMERAL_TRACING_RESOURCE}/batches",
json=payload,
)
def send_trace_events(
self, trace_batch_id: str, payload: dict[str, Any]
) -> httpx.Response:
return self._make_request(
"POST",
f"{self.TRACING_RESOURCE}/batches/{trace_batch_id}/events",
json=payload,
timeout=30,
)
def send_ephemeral_trace_events(
self, trace_batch_id: str, payload: dict[str, Any]
) -> httpx.Response:
return self._make_request(
"POST",
f"{self.EPHEMERAL_TRACING_RESOURCE}/batches/{trace_batch_id}/events",
json=payload,
timeout=30,
)
def finalize_trace_batch(
self, trace_batch_id: str, payload: dict[str, Any]
) -> httpx.Response:
return self._make_request(
"PATCH",
f"{self.TRACING_RESOURCE}/batches/{trace_batch_id}/finalize",
json=payload,
timeout=30,
)
def finalize_ephemeral_trace_batch(
self, trace_batch_id: str, payload: dict[str, Any]
) -> httpx.Response:
return self._make_request(
"PATCH",
f"{self.EPHEMERAL_TRACING_RESOURCE}/batches/{trace_batch_id}/finalize",
json=payload,
timeout=30,
)
def mark_trace_batch_as_failed(
self, trace_batch_id: str, error_message: str
) -> httpx.Response:
return self._make_request(
"PATCH",
f"{self.TRACING_RESOURCE}/batches/{trace_batch_id}",
json={"status": "failed", "failure_reason": error_message},
timeout=30,
)
def get_mcp_configs(self, slugs: list[str]) -> httpx.Response:
"""Get MCP server configurations for the given slugs."""
return self._make_request(
"GET",
f"{self.INTEGRATIONS_RESOURCE}/mcp_configs",
params={"slugs": ",".join(slugs)},
timeout=30,
)
def get_triggers(self) -> httpx.Response:
"""Get all available triggers from integrations."""
return self._make_request("GET", f"{self.INTEGRATIONS_RESOURCE}/apps")
def get_trigger_payload(self, app_slug: str, trigger_slug: str) -> httpx.Response:
"""Get sample payload for a specific trigger."""
return self._make_request(
"GET", f"{self.INTEGRATIONS_RESOURCE}/{app_slug}/{trigger_slug}/payload"
)

View File

@@ -1,31 +0,0 @@
"""Wrapper for the reset-memories command.
Delegates to ``crewai.utilities.reset_memories`` when the full crewai
package is installed, otherwise prints a helpful error message.
"""
from __future__ import annotations
import click
def reset_memories_command(
memory: bool,
knowledge: bool,
agent_knowledge: bool,
kickoff_outputs: bool,
all: bool,
) -> None:
try:
from crewai.utilities.reset_memories import (
reset_memories_command as _reset,
)
except ImportError:
click.secho(
"The 'reset-memories' command requires the full crewai package.\n"
"Install it with: pip install crewai",
fg="red",
)
raise SystemExit(1) from None
_reset(memory, knowledge, agent_knowledge, kickoff_outputs, all)

View File

@@ -1,54 +0,0 @@
"""Lightweight SQLite reader for kickoff task outputs.
Only used by the ``crewai log-tasks-outputs`` CLI command. Depends solely on
the standard library + *appdirs* so crewai-cli can read stored outputs without
importing the full crewai framework.
"""
from __future__ import annotations
import json
import logging
from pathlib import Path
import sqlite3
from typing import Any
from crewai_cli.user_data import _db_storage_path
logger = logging.getLogger(__name__)
def load_task_outputs(db_path: str | None = None) -> list[dict[str, Any]]:
"""Return all rows from the kickoff task outputs database."""
if db_path is None:
db_path = str(Path(_db_storage_path()) / "latest_kickoff_task_outputs.db")
if not Path(db_path).exists():
return []
try:
with sqlite3.connect(db_path) as conn:
cursor = conn.cursor()
cursor.execute("""
SELECT *
FROM latest_kickoff_task_outputs
ORDER BY task_index
""")
rows = cursor.fetchall()
results: list[dict[str, Any]] = [
{
"task_id": row[0],
"expected_output": row[1],
"output": json.loads(row[2]),
"task_index": row[3],
"inputs": json.loads(row[4]),
"was_replayed": row[5],
"timestamp": row[6],
}
for row in rows
]
return results
except sqlite3.Error as e:
logger.error("Failed to load task outputs: %s", e)
return []

View File

@@ -1,66 +0,0 @@
"""Standalone user-data helpers for the CLI package.
These mirror the functions in ``crewai.events.listeners.tracing.utils`` but
depend only on the standard library + *appdirs* so that crewai-cli can work
without importing the full crewai framework.
"""
from __future__ import annotations
import json
import logging
import os
from pathlib import Path
from typing import Any, cast
import appdirs
logger = logging.getLogger(__name__)
def _get_project_directory_name() -> str:
return os.environ.get("CREWAI_STORAGE_DIR", Path.cwd().name)
def _db_storage_path() -> str:
app_name = _get_project_directory_name()
app_author = "CrewAI"
data_dir = Path(appdirs.user_data_dir(app_name, app_author))
data_dir.mkdir(parents=True, exist_ok=True)
return str(data_dir)
def _user_data_file() -> Path:
base = Path(_db_storage_path())
base.mkdir(parents=True, exist_ok=True)
return base / ".crewai_user.json"
def _load_user_data() -> dict[str, Any]:
p = _user_data_file()
if p.exists():
try:
return cast(dict[str, Any], json.loads(p.read_text()))
except (json.JSONDecodeError, OSError, PermissionError) as e:
logger.warning("Failed to load user data: %s", e)
return {}
def _save_user_data(data: dict[str, Any]) -> None:
try:
p = _user_data_file()
p.write_text(json.dumps(data, indent=2))
except (OSError, PermissionError) as e:
logger.warning("Failed to save user data: %s", e)
def is_tracing_enabled() -> bool:
"""Check if tracing is enabled (mirrors crewai core logic)."""
data = _load_user_data()
if (
data.get("first_execution_done", False)
and data.get("trace_consent", False) is False
):
return False
return os.getenv("CREWAI_TRACING_ENABLED", "false").lower() == "true"

View File

@@ -1,369 +0,0 @@
from __future__ import annotations
from functools import reduce
from inspect import getmro, isclass
import os
from pathlib import Path
import shutil
import sys
from typing import Any, cast
import click
from rich.console import Console
import tomli
from crewai_cli.config import Settings
from crewai_cli.constants import ENV_VARS
if sys.version_info >= (3, 11):
import tomllib
console = Console()
def copy_template(
src: Path, dst: Path, name: str, class_name: str, folder_name: str
) -> None:
"""Copy a file from src to dst."""
with open(src, "r") as file:
content = file.read()
content = content.replace("{{name}}", name)
content = content.replace("{{crew_name}}", class_name)
content = content.replace("{{folder_name}}", folder_name)
with open(dst, "w") as file:
file.write(content)
click.secho(f" - Created {dst}", fg="green")
def read_toml(file_path: str = "pyproject.toml") -> dict[str, Any]:
"""Read the content of a TOML file and return it as a dictionary."""
with open(file_path, "rb") as f:
return tomli.load(f)
def parse_toml(content: str) -> dict[str, Any]:
if sys.version_info >= (3, 11):
return tomllib.loads(content)
return tomli.loads(content)
def get_project_name(
pyproject_path: str = "pyproject.toml", require: bool = False
) -> str | None:
"""Get the project name from the pyproject.toml file."""
return _get_project_attribute(pyproject_path, ["project", "name"], require=require)
def get_project_version(
pyproject_path: str = "pyproject.toml", require: bool = False
) -> str | None:
"""Get the project version from the pyproject.toml file."""
return _get_project_attribute(
pyproject_path, ["project", "version"], require=require
)
def get_project_description(
pyproject_path: str = "pyproject.toml", require: bool = False
) -> str | None:
"""Get the project description from the pyproject.toml file."""
return _get_project_attribute(
pyproject_path, ["project", "description"], require=require
)
def _get_project_attribute(
pyproject_path: str, keys: list[str], require: bool
) -> Any | None:
"""Get an attribute from the pyproject.toml file."""
attribute = None
try:
with open(pyproject_path, "r") as f:
pyproject_content = parse_toml(f.read())
dependencies = (
_get_nested_value(pyproject_content, ["project", "dependencies"]) or []
)
if not any(True for dep in dependencies if "crewai" in dep):
raise Exception("crewai is not in the dependencies.")
attribute = _get_nested_value(pyproject_content, keys)
except FileNotFoundError:
console.print(f"Error: {pyproject_path} not found.", style="bold red")
except KeyError:
console.print(
f"Error: {pyproject_path} is not a valid pyproject.toml file.",
style="bold red",
)
except Exception as e:
if sys.version_info >= (3, 11) and isinstance(e, tomllib.TOMLDecodeError):
console.print(
f"Error: {pyproject_path} is not a valid TOML file.", style="bold red"
)
else:
console.print(
f"Error reading the pyproject.toml file: {e}", style="bold red"
)
if require and not attribute:
console.print(
f"Unable to read '{'.'.join(keys)}' in the pyproject.toml file. Please verify that the file exists and contains the specified attribute.",
style="bold red",
)
raise SystemExit
return attribute
def _get_nested_value(data: dict[str, Any], keys: list[str]) -> Any:
return reduce(dict.__getitem__, keys, data)
def fetch_and_json_env_file(env_file_path: str = ".env") -> dict[str, Any]:
"""Fetch the environment variables from a .env file and return them as a dictionary."""
try:
with open(env_file_path, "r") as f:
env_content = f.read()
env_dict = {}
for line in env_content.splitlines():
if line.strip() and not line.strip().startswith("#"):
key, value = line.split("=", 1)
env_dict[key.strip()] = value.strip()
return env_dict
except FileNotFoundError:
console.print(f"Error: {env_file_path} not found.", style="bold red")
except Exception as e:
console.print(f"Error reading the .env file: {e}", style="bold red")
return {}
def tree_copy(source: Path, destination: Path) -> None:
"""Copies the entire directory structure from the source to the destination."""
for item in os.listdir(source):
source_item = os.path.join(source, item)
destination_item = os.path.join(destination, item)
if os.path.isdir(source_item):
shutil.copytree(source_item, destination_item)
else:
shutil.copy2(source_item, destination_item)
def tree_find_and_replace(directory: Path, find: str, replace: str) -> None:
"""Recursively searches through a directory, replacing a target string in
both file contents and filenames with a specified replacement string.
"""
for path, dirs, files in os.walk(os.path.abspath(directory), topdown=False):
for filename in files:
filepath = os.path.join(path, filename)
with open(filepath, "r", encoding="utf-8", errors="ignore") as file:
contents = file.read()
with open(filepath, "w") as file:
file.write(contents.replace(find, replace))
if find in filename:
new_filename = filename.replace(find, replace)
new_filepath = os.path.join(path, new_filename)
os.rename(filepath, new_filepath)
for dirname in dirs:
if find in dirname:
new_dirname = dirname.replace(find, replace)
new_dirpath = os.path.join(path, new_dirname)
old_dirpath = os.path.join(path, dirname)
os.rename(old_dirpath, new_dirpath)
def load_env_vars(folder_path: Path) -> dict[str, Any]:
"""Loads environment variables from a .env file in the specified folder path."""
env_file_path = folder_path / ".env"
env_vars = {}
if env_file_path.exists():
with open(env_file_path, "r") as file:
for line in file:
key, _, value = line.strip().partition("=")
if key and value:
env_vars[key] = value
return env_vars
def update_env_vars(
env_vars: dict[str, Any], provider: str, model: str
) -> dict[str, Any] | None:
"""Updates environment variables with the API key for the selected provider and model."""
provider_config = cast(
list[str],
ENV_VARS.get(
provider,
[
click.prompt(
f"Enter the environment variable name for your {provider.capitalize()} API key",
type=str,
)
],
),
)
api_key_var = provider_config[0]
if api_key_var not in env_vars:
try:
env_vars[api_key_var] = click.prompt(
f"Enter your {provider.capitalize()} API key", type=str, hide_input=True
)
except click.exceptions.Abort:
click.secho("Operation aborted by the user.", fg="red")
return None
else:
click.secho(f"API key already exists for {provider.capitalize()}.", fg="yellow")
env_vars["MODEL"] = model
click.secho(f"Selected model: {model}", fg="green")
return env_vars
def write_env_file(folder_path: Path, env_vars: dict[str, Any]) -> None:
"""Writes environment variables to a .env file in the specified folder."""
env_file_path = folder_path / ".env"
with open(env_file_path, "w") as file:
for key, value in env_vars.items():
file.write(f"{key.upper()}={value}\n")
def is_valid_tool(obj: Any) -> bool:
"""Check if an object is a valid tool class.
Works without importing crewai by checking MRO class names.
Falls back to crewai's ``is_valid_tool`` when available.
"""
try:
from crewai.utilities.project_utils import is_valid_tool as _core_is_valid_tool
return _core_is_valid_tool(obj)
except ImportError:
pass
if isclass(obj):
try:
return any(base.__name__ == "BaseTool" for base in getmro(obj))
except (TypeError, AttributeError):
return False
return False
def extract_available_exports(dir_path: str = "src") -> list[dict[str, Any]]:
"""Extract available tool classes from the project's __init__.py files."""
try:
init_files = Path(dir_path).glob("**/__init__.py")
available_exports: list[dict[str, Any]] = []
for init_file in init_files:
tools = _load_tools_from_init(init_file)
available_exports.extend(tools)
if not available_exports:
_print_no_tools_warning()
raise SystemExit(1)
return available_exports
except SystemExit:
raise
except Exception as e:
console.print(f"[red]Error: Could not extract tool classes: {e!s}[/red]")
console.print(
"Please ensure your project contains valid tools (classes inheriting from BaseTool or functions with @tool decorator)."
)
raise SystemExit(1) from e
def _load_tools_from_init(init_file: Path) -> list[dict[str, Any]]:
"""Load and validate tools from a given __init__.py file."""
import importlib.util as _importlib_util
spec = _importlib_util.spec_from_file_location("temp_module", init_file)
if not spec or not spec.loader:
return []
module = _importlib_util.module_from_spec(spec)
sys.modules["temp_module"] = module
try:
spec.loader.exec_module(module)
if not hasattr(module, "__all__"):
console.print(
f"Warning: No __all__ defined in {init_file}",
style="bold yellow",
)
raise SystemExit(1)
return [
{"name": name}
for name in module.__all__
if hasattr(module, name) and is_valid_tool(getattr(module, name))
]
except SystemExit:
raise
except Exception as e:
console.print(f"[red]Warning: Could not load {init_file}: {e!s}[/red]")
raise SystemExit(1) from e
finally:
sys.modules.pop("temp_module", None)
def _print_no_tools_warning() -> None:
"""Display warning and usage instructions if no tools were found."""
console.print(
"\n[bold yellow]Warning: No valid tools were exposed in your __init__.py file![/bold yellow]"
)
console.print(
"Your __init__.py file must contain all classes that inherit from [bold]BaseTool[/bold] "
"or functions decorated with [bold]@tool[/bold]."
)
console.print(
"\nExample:\n[dim]# In your __init__.py file[/dim]\n"
"[green]__all__ = ['YourTool', 'your_tool_function'][/green]\n\n"
"[dim]# In your tool.py file[/dim]\n"
"[green]from crewai.tools import BaseTool, tool\n\n"
"# Tool class example\n"
"class YourTool(BaseTool):\n"
' name = "your_tool"\n'
' description = "Your tool description"\n'
" # ... rest of implementation\n\n"
"# Decorated function example\n"
"@tool\n"
"def your_tool_function(text: str) -> str:\n"
' """Your tool description"""\n'
" # ... implementation\n"
" return result\n"
)
def build_env_with_tool_repository_credentials(
repository_handle: str,
) -> dict[str, Any]:
repository_handle = repository_handle.upper().replace("-", "_")
settings = Settings()
env = os.environ.copy()
env[f"UV_INDEX_{repository_handle}_USERNAME"] = str(
settings.tool_repository_username or ""
)
env[f"UV_INDEX_{repository_handle}_PASSWORD"] = str(
settings.tool_repository_password or ""
)
return env

View File

@@ -1,91 +0,0 @@
import pytest
from crewai_cli.authentication.main import Oauth2Settings
from crewai_cli.authentication.providers.auth0 import Auth0Provider
class TestAuth0Provider:
@pytest.fixture(autouse=True)
def setup_method(self):
self.valid_settings = Oauth2Settings(
provider="auth0",
domain="test-domain.auth0.com",
client_id="test-client-id",
audience="test-audience"
)
self.provider = Auth0Provider(self.valid_settings)
def test_initialization_with_valid_settings(self):
provider = Auth0Provider(self.valid_settings)
assert provider.settings == self.valid_settings
assert provider.settings.provider == "auth0"
assert provider.settings.domain == "test-domain.auth0.com"
assert provider.settings.client_id == "test-client-id"
assert provider.settings.audience == "test-audience"
def test_get_authorize_url(self):
expected_url = "https://test-domain.auth0.com/oauth/device/code"
assert self.provider.get_authorize_url() == expected_url
def test_get_authorize_url_with_different_domain(self):
settings = Oauth2Settings(
provider="auth0",
domain="my-company.auth0.com",
client_id="test-client",
audience="test-audience"
)
provider = Auth0Provider(settings)
expected_url = "https://my-company.auth0.com/oauth/device/code"
assert provider.get_authorize_url() == expected_url
def test_get_token_url(self):
expected_url = "https://test-domain.auth0.com/oauth/token"
assert self.provider.get_token_url() == expected_url
def test_get_token_url_with_different_domain(self):
settings = Oauth2Settings(
provider="auth0",
domain="another-domain.auth0.com",
client_id="test-client",
audience="test-audience"
)
provider = Auth0Provider(settings)
expected_url = "https://another-domain.auth0.com/oauth/token"
assert provider.get_token_url() == expected_url
def test_get_jwks_url(self):
expected_url = "https://test-domain.auth0.com/.well-known/jwks.json"
assert self.provider.get_jwks_url() == expected_url
def test_get_jwks_url_with_different_domain(self):
settings = Oauth2Settings(
provider="auth0",
domain="dev.auth0.com",
client_id="test-client",
audience="test-audience"
)
provider = Auth0Provider(settings)
expected_url = "https://dev.auth0.com/.well-known/jwks.json"
assert provider.get_jwks_url() == expected_url
def test_get_issuer(self):
expected_issuer = "https://test-domain.auth0.com/"
assert self.provider.get_issuer() == expected_issuer
def test_get_issuer_with_different_domain(self):
settings = Oauth2Settings(
provider="auth0",
domain="prod.auth0.com",
client_id="test-client",
audience="test-audience"
)
provider = Auth0Provider(settings)
expected_issuer = "https://prod.auth0.com/"
assert provider.get_issuer() == expected_issuer
def test_get_audience(self):
assert self.provider.get_audience() == "test-audience"
def test_get_client_id(self):
assert self.provider.get_client_id() == "test-client-id"

View File

@@ -1,141 +0,0 @@
import pytest
from crewai_cli.authentication.main import Oauth2Settings
from crewai_cli.authentication.providers.entra_id import EntraIdProvider
class TestEntraIdProvider:
@pytest.fixture(autouse=True)
def setup_method(self):
self.valid_settings = Oauth2Settings(
provider="entra_id",
domain="tenant-id-abcdef123456",
client_id="test-client-id",
audience="test-audience",
extra={
"scope": "openid profile email api://crewai-cli-dev/read"
}
)
self.provider = EntraIdProvider(self.valid_settings)
def test_initialization_with_valid_settings(self):
provider = EntraIdProvider(self.valid_settings)
assert provider.settings == self.valid_settings
assert provider.settings.provider == "entra_id"
assert provider.settings.domain == "tenant-id-abcdef123456"
assert provider.settings.client_id == "test-client-id"
assert provider.settings.audience == "test-audience"
def test_get_authorize_url(self):
expected_url = "https://login.microsoftonline.com/tenant-id-abcdef123456/oauth2/v2.0/devicecode"
assert self.provider.get_authorize_url() == expected_url
def test_get_authorize_url_with_different_domain(self):
# For EntraID, the domain is the tenant ID.
settings = Oauth2Settings(
provider="entra_id",
domain="my-company.entra.id",
client_id="test-client",
audience="test-audience",
)
provider = EntraIdProvider(settings)
expected_url = "https://login.microsoftonline.com/my-company.entra.id/oauth2/v2.0/devicecode"
assert provider.get_authorize_url() == expected_url
def test_get_token_url(self):
expected_url = "https://login.microsoftonline.com/tenant-id-abcdef123456/oauth2/v2.0/token"
assert self.provider.get_token_url() == expected_url
def test_get_token_url_with_different_domain(self):
# For EntraID, the domain is the tenant ID.
settings = Oauth2Settings(
provider="entra_id",
domain="another-domain.entra.id",
client_id="test-client",
audience="test-audience",
)
provider = EntraIdProvider(settings)
expected_url = "https://login.microsoftonline.com/another-domain.entra.id/oauth2/v2.0/token"
assert provider.get_token_url() == expected_url
def test_get_jwks_url(self):
expected_url = "https://login.microsoftonline.com/tenant-id-abcdef123456/discovery/v2.0/keys"
assert self.provider.get_jwks_url() == expected_url
def test_get_jwks_url_with_different_domain(self):
# For EntraID, the domain is the tenant ID.
settings = Oauth2Settings(
provider="entra_id",
domain="dev.entra.id",
client_id="test-client",
audience="test-audience",
)
provider = EntraIdProvider(settings)
expected_url = "https://login.microsoftonline.com/dev.entra.id/discovery/v2.0/keys"
assert provider.get_jwks_url() == expected_url
def test_get_issuer(self):
expected_issuer = "https://login.microsoftonline.com/tenant-id-abcdef123456/v2.0"
assert self.provider.get_issuer() == expected_issuer
def test_get_issuer_with_different_domain(self):
# For EntraID, the domain is the tenant ID.
settings = Oauth2Settings(
provider="entra_id",
domain="other-tenant-id-xpto",
client_id="test-client",
audience="test-audience",
)
provider = EntraIdProvider(settings)
expected_issuer = "https://login.microsoftonline.com/other-tenant-id-xpto/v2.0"
assert provider.get_issuer() == expected_issuer
def test_get_audience(self):
assert self.provider.get_audience() == "test-audience"
def test_get_audience_assertion_error_when_none(self):
settings = Oauth2Settings(
provider="entra_id",
domain="test-tenant-id",
client_id="test-client-id",
audience=None,
)
provider = EntraIdProvider(settings)
with pytest.raises(ValueError, match="Audience is required"):
provider.get_audience()
def test_get_client_id(self):
assert self.provider.get_client_id() == "test-client-id"
def test_get_required_fields(self):
assert set(self.provider.get_required_fields()) == set(["scope"])
def test_get_oauth_scopes(self):
settings = Oauth2Settings(
provider="entra_id",
domain="tenant-id-abcdef123456",
client_id="test-client-id",
audience="test-audience",
extra={
"scope": "api://crewai-cli-dev/read"
}
)
provider = EntraIdProvider(settings)
assert provider.get_oauth_scopes() == ["openid", "profile", "email", "api://crewai-cli-dev/read"]
def test_get_oauth_scopes_with_multiple_custom_scopes(self):
settings = Oauth2Settings(
provider="entra_id",
domain="tenant-id-abcdef123456",
client_id="test-client-id",
audience="test-audience",
extra={
"scope": "api://crewai-cli-dev/read api://crewai-cli-dev/write custom-scope1 custom-scope2"
}
)
provider = EntraIdProvider(settings)
assert provider.get_oauth_scopes() == ["openid", "profile", "email", "api://crewai-cli-dev/read", "api://crewai-cli-dev/write", "custom-scope1", "custom-scope2"]
def test_base_url(self):
assert self.provider._base_url() == "https://login.microsoftonline.com/tenant-id-abcdef123456"

View File

@@ -1,138 +0,0 @@
import pytest
from crewai_cli.authentication.main import Oauth2Settings
from crewai_cli.authentication.providers.keycloak import KeycloakProvider
class TestKeycloakProvider:
@pytest.fixture(autouse=True)
def setup_method(self):
self.valid_settings = Oauth2Settings(
provider="keycloak",
domain="keycloak.example.com",
client_id="test-client-id",
audience="test-audience",
extra={
"realm": "test-realm"
}
)
self.provider = KeycloakProvider(self.valid_settings)
def test_initialization_with_valid_settings(self):
provider = KeycloakProvider(self.valid_settings)
assert provider.settings == self.valid_settings
assert provider.settings.provider == "keycloak"
assert provider.settings.domain == "keycloak.example.com"
assert provider.settings.client_id == "test-client-id"
assert provider.settings.audience == "test-audience"
assert provider.settings.extra.get("realm") == "test-realm"
def test_get_authorize_url(self):
expected_url = "https://keycloak.example.com/realms/test-realm/protocol/openid-connect/auth/device"
assert self.provider.get_authorize_url() == expected_url
def test_get_authorize_url_with_different_domain(self):
settings = Oauth2Settings(
provider="keycloak",
domain="auth.company.com",
client_id="test-client",
audience="test-audience",
extra={
"realm": "my-realm"
}
)
provider = KeycloakProvider(settings)
expected_url = "https://auth.company.com/realms/my-realm/protocol/openid-connect/auth/device"
assert provider.get_authorize_url() == expected_url
def test_get_token_url(self):
expected_url = "https://keycloak.example.com/realms/test-realm/protocol/openid-connect/token"
assert self.provider.get_token_url() == expected_url
def test_get_token_url_with_different_domain(self):
settings = Oauth2Settings(
provider="keycloak",
domain="sso.enterprise.com",
client_id="test-client",
audience="test-audience",
extra={
"realm": "enterprise-realm"
}
)
provider = KeycloakProvider(settings)
expected_url = "https://sso.enterprise.com/realms/enterprise-realm/protocol/openid-connect/token"
assert provider.get_token_url() == expected_url
def test_get_jwks_url(self):
expected_url = "https://keycloak.example.com/realms/test-realm/protocol/openid-connect/certs"
assert self.provider.get_jwks_url() == expected_url
def test_get_jwks_url_with_different_domain(self):
settings = Oauth2Settings(
provider="keycloak",
domain="identity.org",
client_id="test-client",
audience="test-audience",
extra={
"realm": "org-realm"
}
)
provider = KeycloakProvider(settings)
expected_url = "https://identity.org/realms/org-realm/protocol/openid-connect/certs"
assert provider.get_jwks_url() == expected_url
def test_get_issuer(self):
expected_issuer = "https://keycloak.example.com/realms/test-realm"
assert self.provider.get_issuer() == expected_issuer
def test_get_issuer_with_different_domain(self):
settings = Oauth2Settings(
provider="keycloak",
domain="login.myapp.io",
client_id="test-client",
audience="test-audience",
extra={
"realm": "app-realm"
}
)
provider = KeycloakProvider(settings)
expected_issuer = "https://login.myapp.io/realms/app-realm"
assert provider.get_issuer() == expected_issuer
def test_get_audience(self):
assert self.provider.get_audience() == "test-audience"
def test_get_client_id(self):
assert self.provider.get_client_id() == "test-client-id"
def test_get_required_fields(self):
assert self.provider.get_required_fields() == ["realm"]
def test_oauth2_base_url(self):
assert self.provider._oauth2_base_url() == "https://keycloak.example.com"
def test_oauth2_base_url_strips_https_prefix(self):
settings = Oauth2Settings(
provider="keycloak",
domain="https://keycloak.example.com",
client_id="test-client-id",
audience="test-audience",
extra={
"realm": "test-realm"
}
)
provider = KeycloakProvider(settings)
assert provider._oauth2_base_url() == "https://keycloak.example.com"
def test_oauth2_base_url_strips_http_prefix(self):
settings = Oauth2Settings(
provider="keycloak",
domain="http://keycloak.example.com",
client_id="test-client-id",
audience="test-audience",
extra={
"realm": "test-realm"
}
)
provider = KeycloakProvider(settings)
assert provider._oauth2_base_url() == "https://keycloak.example.com"

View File

@@ -1,257 +0,0 @@
import pytest
from crewai_cli.authentication.main import Oauth2Settings
from crewai_cli.authentication.providers.okta import OktaProvider
class TestOktaProvider:
@pytest.fixture(autouse=True)
def setup_method(self):
self.valid_settings = Oauth2Settings(
provider="okta",
domain="test-domain.okta.com",
client_id="test-client-id",
audience="test-audience",
)
self.provider = OktaProvider(self.valid_settings)
def test_initialization_with_valid_settings(self):
provider = OktaProvider(self.valid_settings)
assert provider.settings == self.valid_settings
assert provider.settings.provider == "okta"
assert provider.settings.domain == "test-domain.okta.com"
assert provider.settings.client_id == "test-client-id"
assert provider.settings.audience == "test-audience"
def test_get_authorize_url(self):
expected_url = "https://test-domain.okta.com/oauth2/default/v1/device/authorize"
assert self.provider.get_authorize_url() == expected_url
def test_get_authorize_url_with_different_domain(self):
settings = Oauth2Settings(
provider="okta",
domain="my-company.okta.com",
client_id="test-client",
audience="test-audience",
)
provider = OktaProvider(settings)
expected_url = "https://my-company.okta.com/oauth2/default/v1/device/authorize"
assert provider.get_authorize_url() == expected_url
def test_get_authorize_url_with_custom_authorization_server_name(self):
settings = Oauth2Settings(
provider="okta",
domain="test-domain.okta.com",
client_id="test-client-id",
audience=None,
extra={
"using_org_auth_server": False,
"authorization_server_name": "my_auth_server_xxxAAA777"
}
)
provider = OktaProvider(settings)
expected_url = "https://test-domain.okta.com/oauth2/my_auth_server_xxxAAA777/v1/device/authorize"
assert provider.get_authorize_url() == expected_url
def test_get_authorize_url_when_using_org_auth_server(self):
settings = Oauth2Settings(
provider="okta",
domain="test-domain.okta.com",
client_id="test-client-id",
audience=None,
extra={
"using_org_auth_server": True,
"authorization_server_name": None
}
)
provider = OktaProvider(settings)
expected_url = "https://test-domain.okta.com/oauth2/v1/device/authorize"
assert provider.get_authorize_url() == expected_url
def test_get_token_url(self):
expected_url = "https://test-domain.okta.com/oauth2/default/v1/token"
assert self.provider.get_token_url() == expected_url
def test_get_token_url_with_different_domain(self):
settings = Oauth2Settings(
provider="okta",
domain="another-domain.okta.com",
client_id="test-client",
audience="test-audience",
)
provider = OktaProvider(settings)
expected_url = "https://another-domain.okta.com/oauth2/default/v1/token"
assert provider.get_token_url() == expected_url
def test_get_token_url_with_custom_authorization_server_name(self):
settings = Oauth2Settings(
provider="okta",
domain="test-domain.okta.com",
client_id="test-client-id",
audience=None,
extra={
"using_org_auth_server": False,
"authorization_server_name": "my_auth_server_xxxAAA777"
}
)
provider = OktaProvider(settings)
expected_url = "https://test-domain.okta.com/oauth2/my_auth_server_xxxAAA777/v1/token"
assert provider.get_token_url() == expected_url
def test_get_token_url_when_using_org_auth_server(self):
settings = Oauth2Settings(
provider="okta",
domain="test-domain.okta.com",
client_id="test-client-id",
audience=None,
extra={
"using_org_auth_server": True,
"authorization_server_name": None
}
)
provider = OktaProvider(settings)
expected_url = "https://test-domain.okta.com/oauth2/v1/token"
assert provider.get_token_url() == expected_url
def test_get_jwks_url(self):
expected_url = "https://test-domain.okta.com/oauth2/default/v1/keys"
assert self.provider.get_jwks_url() == expected_url
def test_get_jwks_url_with_different_domain(self):
settings = Oauth2Settings(
provider="okta",
domain="dev.okta.com",
client_id="test-client",
audience="test-audience",
)
provider = OktaProvider(settings)
expected_url = "https://dev.okta.com/oauth2/default/v1/keys"
assert provider.get_jwks_url() == expected_url
def test_get_jwks_url_with_custom_authorization_server_name(self):
settings = Oauth2Settings(
provider="okta",
domain="test-domain.okta.com",
client_id="test-client-id",
audience=None,
extra={
"using_org_auth_server": False,
"authorization_server_name": "my_auth_server_xxxAAA777"
}
)
provider = OktaProvider(settings)
expected_url = "https://test-domain.okta.com/oauth2/my_auth_server_xxxAAA777/v1/keys"
assert provider.get_jwks_url() == expected_url
def test_get_jwks_url_when_using_org_auth_server(self):
settings = Oauth2Settings(
provider="okta",
domain="test-domain.okta.com",
client_id="test-client-id",
audience=None,
extra={
"using_org_auth_server": True,
"authorization_server_name": None
}
)
provider = OktaProvider(settings)
expected_url = "https://test-domain.okta.com/oauth2/v1/keys"
assert provider.get_jwks_url() == expected_url
def test_get_issuer(self):
expected_issuer = "https://test-domain.okta.com/oauth2/default"
assert self.provider.get_issuer() == expected_issuer
def test_get_issuer_with_different_domain(self):
settings = Oauth2Settings(
provider="okta",
domain="prod.okta.com",
client_id="test-client",
audience="test-audience",
)
provider = OktaProvider(settings)
expected_issuer = "https://prod.okta.com/oauth2/default"
assert provider.get_issuer() == expected_issuer
def test_get_issuer_with_custom_authorization_server_name(self):
settings = Oauth2Settings(
provider="okta",
domain="test-domain.okta.com",
client_id="test-client-id",
audience=None,
extra={
"using_org_auth_server": False,
"authorization_server_name": "my_auth_server_xxxAAA777"
}
)
provider = OktaProvider(settings)
expected_issuer = "https://test-domain.okta.com/oauth2/my_auth_server_xxxAAA777"
assert provider.get_issuer() == expected_issuer
def test_get_issuer_when_using_org_auth_server(self):
settings = Oauth2Settings(
provider="okta",
domain="test-domain.okta.com",
client_id="test-client-id",
audience=None,
extra={
"using_org_auth_server": True,
"authorization_server_name": None
}
)
provider = OktaProvider(settings)
expected_issuer = "https://test-domain.okta.com"
assert provider.get_issuer() == expected_issuer
def test_get_audience(self):
assert self.provider.get_audience() == "test-audience"
def test_get_audience_assertion_error_when_none(self):
settings = Oauth2Settings(
provider="okta",
domain="test-domain.okta.com",
client_id="test-client-id",
audience=None,
)
provider = OktaProvider(settings)
with pytest.raises(ValueError, match="Audience is required"):
provider.get_audience()
def test_get_client_id(self):
assert self.provider.get_client_id() == "test-client-id"
def test_get_required_fields(self):
assert set(self.provider.get_required_fields()) == set(["authorization_server_name", "using_org_auth_server"])
def test_oauth2_base_url(self):
assert self.provider._oauth2_base_url() == "https://test-domain.okta.com/oauth2/default"
def test_oauth2_base_url_with_custom_authorization_server_name(self):
settings = Oauth2Settings(
provider="okta",
domain="test-domain.okta.com",
client_id="test-client-id",
audience=None,
extra={
"using_org_auth_server": False,
"authorization_server_name": "my_auth_server_xxxAAA777"
}
)
provider = OktaProvider(settings)
assert provider._oauth2_base_url() == "https://test-domain.okta.com/oauth2/my_auth_server_xxxAAA777"
def test_oauth2_base_url_when_using_org_auth_server(self):
settings = Oauth2Settings(
provider="okta",
domain="test-domain.okta.com",
client_id="test-client-id",
audience=None,
extra={
"using_org_auth_server": True,
"authorization_server_name": None
}
)
provider = OktaProvider(settings)
assert provider._oauth2_base_url() == "https://test-domain.okta.com/oauth2"

View File

@@ -1,100 +0,0 @@
import pytest
from crewai_cli.authentication.main import Oauth2Settings
from crewai_cli.authentication.providers.workos import WorkosProvider
class TestWorkosProvider:
@pytest.fixture(autouse=True)
def setup_method(self):
self.valid_settings = Oauth2Settings(
provider="workos",
domain="login.company.com",
client_id="test-client-id",
audience="test-audience"
)
self.provider = WorkosProvider(self.valid_settings)
def test_initialization_with_valid_settings(self):
provider = WorkosProvider(self.valid_settings)
assert provider.settings == self.valid_settings
assert provider.settings.provider == "workos"
assert provider.settings.domain == "login.company.com"
assert provider.settings.client_id == "test-client-id"
assert provider.settings.audience == "test-audience"
def test_get_authorize_url(self):
expected_url = "https://login.company.com/oauth2/device_authorization"
assert self.provider.get_authorize_url() == expected_url
def test_get_authorize_url_with_different_domain(self):
settings = Oauth2Settings(
provider="workos",
domain="login.example.com",
client_id="test-client",
audience="test-audience"
)
provider = WorkosProvider(settings)
expected_url = "https://login.example.com/oauth2/device_authorization"
assert provider.get_authorize_url() == expected_url
def test_get_token_url(self):
expected_url = "https://login.company.com/oauth2/token"
assert self.provider.get_token_url() == expected_url
def test_get_token_url_with_different_domain(self):
settings = Oauth2Settings(
provider="workos",
domain="api.workos.com",
client_id="test-client",
audience="test-audience"
)
provider = WorkosProvider(settings)
expected_url = "https://api.workos.com/oauth2/token"
assert provider.get_token_url() == expected_url
def test_get_jwks_url(self):
expected_url = "https://login.company.com/oauth2/jwks"
assert self.provider.get_jwks_url() == expected_url
def test_get_jwks_url_with_different_domain(self):
settings = Oauth2Settings(
provider="workos",
domain="auth.enterprise.com",
client_id="test-client",
audience="test-audience"
)
provider = WorkosProvider(settings)
expected_url = "https://auth.enterprise.com/oauth2/jwks"
assert provider.get_jwks_url() == expected_url
def test_get_issuer(self):
expected_issuer = "https://login.company.com"
assert self.provider.get_issuer() == expected_issuer
def test_get_issuer_with_different_domain(self):
settings = Oauth2Settings(
provider="workos",
domain="sso.company.com",
client_id="test-client",
audience="test-audience"
)
provider = WorkosProvider(settings)
expected_issuer = "https://sso.company.com"
assert provider.get_issuer() == expected_issuer
def test_get_audience(self):
assert self.provider.get_audience() == "test-audience"
def test_get_audience_fallback_to_default(self):
settings = Oauth2Settings(
provider="workos",
domain="login.company.com",
client_id="test-client-id",
audience=None
)
provider = WorkosProvider(settings)
assert provider.get_audience() == ""
def test_get_client_id(self):
assert self.provider.get_client_id() == "test-client-id"

View File

@@ -1,348 +0,0 @@
from datetime import datetime, timedelta
from unittest.mock import MagicMock, call, patch
import pytest
import httpx
from crewai_cli.authentication.main import AuthenticationCommand
from crewai_cli.constants import (
CREWAI_ENTERPRISE_DEFAULT_OAUTH2_AUDIENCE,
CREWAI_ENTERPRISE_DEFAULT_OAUTH2_CLIENT_ID,
CREWAI_ENTERPRISE_DEFAULT_OAUTH2_DOMAIN,
)
class TestAuthenticationCommand:
def setup_method(self):
# Mock Settings so we always use default constants regardless of local config.
with patch("crewai_cli.authentication.main.Settings") as mock_settings:
instance = mock_settings.return_value
instance.oauth2_provider = "workos"
instance.oauth2_domain = CREWAI_ENTERPRISE_DEFAULT_OAUTH2_DOMAIN
instance.oauth2_client_id = CREWAI_ENTERPRISE_DEFAULT_OAUTH2_CLIENT_ID
instance.oauth2_audience = CREWAI_ENTERPRISE_DEFAULT_OAUTH2_AUDIENCE
instance.oauth2_extra = {}
self.auth_command = AuthenticationCommand()
@pytest.mark.parametrize(
"user_provider,expected_urls",
[
(
"workos",
{
"device_code_url": f"https://{CREWAI_ENTERPRISE_DEFAULT_OAUTH2_DOMAIN}/oauth2/device_authorization",
"token_url": f"https://{CREWAI_ENTERPRISE_DEFAULT_OAUTH2_DOMAIN}/oauth2/token",
"client_id": CREWAI_ENTERPRISE_DEFAULT_OAUTH2_CLIENT_ID,
"audience": CREWAI_ENTERPRISE_DEFAULT_OAUTH2_AUDIENCE,
"domain": CREWAI_ENTERPRISE_DEFAULT_OAUTH2_DOMAIN,
},
),
],
)
@patch("crewai_cli.authentication.main.AuthenticationCommand._get_device_code")
@patch(
"crewai_cli.authentication.main.AuthenticationCommand._display_auth_instructions"
)
@patch("crewai_cli.authentication.main.AuthenticationCommand._poll_for_token")
@patch("crewai_cli.authentication.main.console.print")
def test_login(
self,
mock_console_print,
mock_poll,
mock_display,
mock_get_device,
user_provider,
expected_urls,
):
mock_get_device.return_value = {
"device_code": "test_code",
"user_code": "123456",
}
self.auth_command.login()
mock_console_print.assert_called_once_with(
"Signing in to CrewAI AMP...\n", style="bold blue"
)
mock_get_device.assert_called_once()
mock_display.assert_called_once_with(
{"device_code": "test_code", "user_code": "123456"}
)
mock_poll.assert_called_once_with(
{"device_code": "test_code", "user_code": "123456"},
)
assert (
self.auth_command.oauth2_provider.get_client_id()
== expected_urls["client_id"]
)
assert (
self.auth_command.oauth2_provider.get_audience()
== expected_urls["audience"]
)
assert (
self.auth_command.oauth2_provider._get_domain() == expected_urls["domain"]
)
@patch("crewai_cli.authentication.main.webbrowser")
@patch("crewai_cli.authentication.main.console.print")
def test_display_auth_instructions(self, mock_console_print, mock_webbrowser):
device_code_data = {
"verification_uri_complete": "https://example.com/auth",
"user_code": "123456",
}
self.auth_command._display_auth_instructions(device_code_data)
expected_calls = [
call("1. Navigate to: ", "https://example.com/auth"),
call("2. Enter the following code: ", "123456"),
]
mock_console_print.assert_has_calls(expected_calls)
mock_webbrowser.open.assert_called_once_with("https://example.com/auth")
@pytest.mark.parametrize(
"user_provider,jwt_config",
[
(
"workos",
{
"jwks_url": f"https://{CREWAI_ENTERPRISE_DEFAULT_OAUTH2_DOMAIN}/oauth2/jwks",
"issuer": f"https://{CREWAI_ENTERPRISE_DEFAULT_OAUTH2_DOMAIN}",
"audience": CREWAI_ENTERPRISE_DEFAULT_OAUTH2_AUDIENCE,
},
),
],
)
@pytest.mark.parametrize("has_expiration", [True, False])
@patch("crewai_cli.authentication.main.validate_jwt_token")
@patch("crewai_cli.authentication.main.TokenManager.save_tokens")
def test_validate_and_save_token(
self,
mock_save_tokens,
mock_validate_jwt,
user_provider,
jwt_config,
has_expiration,
):
from crewai_cli.authentication.main import Oauth2Settings
from crewai_cli.authentication.providers.workos import WorkosProvider
if user_provider == "workos":
self.auth_command.oauth2_provider = WorkosProvider(
settings=Oauth2Settings(
provider=user_provider,
client_id="test-client-id",
domain=CREWAI_ENTERPRISE_DEFAULT_OAUTH2_DOMAIN,
audience=jwt_config["audience"],
)
)
token_data = {"access_token": "test_access_token", "id_token": "test_id_token"}
if has_expiration:
future_timestamp = int((datetime.now() + timedelta(days=100)).timestamp())
decoded_token = {"exp": future_timestamp}
else:
decoded_token = {}
mock_validate_jwt.return_value = decoded_token
self.auth_command._validate_and_save_token(token_data)
mock_validate_jwt.assert_called_once_with(
jwt_token="test_access_token",
jwks_url=jwt_config["jwks_url"],
issuer=jwt_config["issuer"],
audience=jwt_config["audience"],
)
if has_expiration:
mock_save_tokens.assert_called_once_with(
"test_access_token", future_timestamp
)
else:
mock_save_tokens.assert_called_once_with("test_access_token", 0)
@patch("crewai_cli.tools.main.ToolCommand")
@patch("crewai_cli.authentication.main.Settings")
@patch("crewai_cli.authentication.main.console.print")
def test_login_to_tool_repository_success(
self, mock_console_print, mock_settings, mock_tool_command
):
mock_tool_instance = MagicMock()
mock_tool_command.return_value = mock_tool_instance
mock_settings_instance = MagicMock()
mock_settings_instance.org_name = "Test Org"
mock_settings_instance.org_uuid = "test-uuid-123"
mock_settings.return_value = mock_settings_instance
self.auth_command._login_to_tool_repository()
mock_tool_command.assert_called_once()
mock_tool_instance.login.assert_called_once()
expected_calls = [
call(
"Now logging you in to the Tool Repository... ",
style="bold blue",
end="",
),
call("Success!\n", style="bold green"),
call(
"You are now authenticated to the tool repository for organization [bold cyan]'Test Org'[/bold cyan]",
style="green",
),
]
mock_console_print.assert_has_calls(expected_calls)
@patch("crewai_cli.tools.main.ToolCommand")
@patch("crewai_cli.authentication.main.console.print")
def test_login_to_tool_repository_error(
self, mock_console_print, mock_tool_command
):
mock_tool_instance = MagicMock()
mock_tool_instance.login.side_effect = Exception("Tool repository error")
mock_tool_command.return_value = mock_tool_instance
self.auth_command._login_to_tool_repository()
mock_tool_command.assert_called_once()
mock_tool_instance.login.assert_called_once()
expected_calls = [
call(
"Now logging you in to the Tool Repository... ",
style="bold blue",
end="",
),
call(
"\n[bold yellow]Warning:[/bold yellow] Authentication with the Tool Repository failed.",
style="yellow",
),
call(
"Other features will work normally, but you may experience limitations with downloading and publishing tools.\nRun [bold]crewai login[/bold] to try logging in again.\n",
style="yellow",
),
]
mock_console_print.assert_has_calls(expected_calls)
@patch("crewai_cli.authentication.main.httpx.post")
def test_get_device_code(self, mock_post):
mock_response = MagicMock()
mock_response.json.return_value = {
"device_code": "test_device_code",
"user_code": "123456",
"verification_uri_complete": "https://example.com/auth",
}
mock_post.return_value = mock_response
self.auth_command.oauth2_provider = MagicMock()
self.auth_command.oauth2_provider.get_client_id.return_value = "test_client"
self.auth_command.oauth2_provider.get_authorize_url.return_value = (
"https://example.com/device"
)
self.auth_command.oauth2_provider.get_audience.return_value = "test_audience"
self.auth_command.oauth2_provider.get_oauth_scopes.return_value = ["openid", "profile", "email"]
result = self.auth_command._get_device_code()
mock_post.assert_called_once_with(
url="https://example.com/device",
data={
"client_id": "test_client",
"scope": "openid profile email",
"audience": "test_audience",
},
timeout=20,
)
assert result == {
"device_code": "test_device_code",
"user_code": "123456",
"verification_uri_complete": "https://example.com/auth",
}
@patch("crewai_cli.authentication.main.httpx.post")
@patch("crewai_cli.authentication.main.console.print")
def test_poll_for_token_success(self, mock_console_print, mock_post):
mock_response_success = MagicMock()
mock_response_success.status_code = 200
mock_response_success.json.return_value = {
"access_token": "test_access_token",
"id_token": "test_id_token",
}
mock_post.return_value = mock_response_success
device_code_data = {"device_code": "test_device_code", "interval": 1}
with (
patch.object(
self.auth_command, "_validate_and_save_token"
) as mock_validate,
patch.object(
self.auth_command, "_login_to_tool_repository"
) as mock_tool_login,
):
self.auth_command.oauth2_provider = MagicMock()
self.auth_command.oauth2_provider.get_token_url.return_value = (
"https://example.com/token"
)
self.auth_command.oauth2_provider.get_client_id.return_value = "test_client"
self.auth_command._poll_for_token(device_code_data)
mock_post.assert_called_once_with(
"https://example.com/token",
data={
"grant_type": "urn:ietf:params:oauth:grant-type:device_code",
"device_code": "test_device_code",
"client_id": "test_client",
},
timeout=30,
)
mock_validate.assert_called_once()
mock_tool_login.assert_called_once()
expected_calls = [
call("\nWaiting for authentication... ", style="bold blue", end=""),
call("Success!", style="bold green"),
call("\n[bold green]Welcome to CrewAI AMP![/bold green]\n"),
]
mock_console_print.assert_has_calls(expected_calls)
@patch("crewai_cli.authentication.main.httpx.post")
@patch("crewai_cli.authentication.main.console.print")
def test_poll_for_token_timeout(self, mock_console_print, mock_post):
mock_response_pending = MagicMock()
mock_response_pending.status_code = 400
mock_response_pending.json.return_value = {"error": "authorization_pending"}
mock_post.return_value = mock_response_pending
device_code_data = {
"device_code": "test_device_code",
"interval": 0.1, # Short interval for testing
}
self.auth_command._poll_for_token(device_code_data)
mock_console_print.assert_any_call(
"Timeout: Failed to get the token. Please try again.", style="bold red"
)
@patch("crewai_cli.authentication.main.httpx.post")
def test_poll_for_token_error(self, mock_post):
"""Test the method to poll for token (error path)."""
# Setup mock to return error
mock_response_error = MagicMock()
mock_response_error.status_code = 400
mock_response_error.json.return_value = {
"error": "access_denied",
"error_description": "User denied access",
}
mock_post.return_value = mock_response_error
device_code_data = {"device_code": "test_device_code", "interval": 1}
with pytest.raises(httpx.HTTPError):
self.auth_command._poll_for_token(device_code_data)

View File

@@ -1,107 +0,0 @@
import unittest
from unittest.mock import MagicMock, patch
import jwt
from crewai_cli.authentication.utils import validate_jwt_token
@patch("crewai_cli.authentication.utils.PyJWKClient", return_value=MagicMock())
@patch("crewai_cli.authentication.utils.jwt")
class TestUtils(unittest.TestCase):
def test_validate_jwt_token(self, mock_jwt, mock_pyjwkclient):
mock_jwt.decode.return_value = {"exp": 1719859200}
# Create signing key object mock with a .key attribute
mock_pyjwkclient.return_value.get_signing_key_from_jwt.return_value = MagicMock(
key="mock_signing_key"
)
jwt_token = "aaaaa.bbbbbb.cccccc" # noqa: S105
decoded_token = validate_jwt_token(
jwt_token=jwt_token,
jwks_url="https://mock_jwks_url",
issuer="https://mock_issuer",
audience="app_id_xxxx",
)
mock_jwt.decode.assert_called_with(
jwt_token,
"mock_signing_key",
algorithms=["RS256"],
audience="app_id_xxxx",
issuer="https://mock_issuer",
leeway=10.0,
options={
"verify_signature": True,
"verify_exp": True,
"verify_nbf": True,
"verify_iat": True,
"require": ["exp", "iat", "iss", "aud", "sub"],
},
)
mock_pyjwkclient.assert_called_once_with("https://mock_jwks_url")
self.assertEqual(decoded_token, {"exp": 1719859200})
def test_validate_jwt_token_expired(self, mock_jwt, mock_pyjwkclient):
mock_jwt.decode.side_effect = jwt.ExpiredSignatureError
with self.assertRaises(Exception): # noqa: B017
validate_jwt_token(
jwt_token="aaaaa.bbbbbb.cccccc", # noqa: S106
jwks_url="https://mock_jwks_url",
issuer="https://mock_issuer",
audience="app_id_xxxx",
)
def test_validate_jwt_token_invalid_audience(self, mock_jwt, mock_pyjwkclient):
mock_jwt.decode.side_effect = jwt.InvalidAudienceError
with self.assertRaises(Exception): # noqa: B017
validate_jwt_token(
jwt_token="aaaaa.bbbbbb.cccccc", # noqa: S106
jwks_url="https://mock_jwks_url",
issuer="https://mock_issuer",
audience="app_id_xxxx",
)
def test_validate_jwt_token_invalid_issuer(self, mock_jwt, mock_pyjwkclient):
mock_jwt.decode.side_effect = jwt.InvalidIssuerError
with self.assertRaises(Exception): # noqa: B017
validate_jwt_token(
jwt_token="aaaaa.bbbbbb.cccccc", # noqa: S106
jwks_url="https://mock_jwks_url",
issuer="https://mock_issuer",
audience="app_id_xxxx",
)
def test_validate_jwt_token_missing_required_claims(
self, mock_jwt, mock_pyjwkclient
):
mock_jwt.decode.side_effect = jwt.MissingRequiredClaimError
with self.assertRaises(Exception): # noqa: B017
validate_jwt_token(
jwt_token="aaaaa.bbbbbb.cccccc", # noqa: S106
jwks_url="https://mock_jwks_url",
issuer="https://mock_issuer",
audience="app_id_xxxx",
)
def test_validate_jwt_token_jwks_error(self, mock_jwt, mock_pyjwkclient):
mock_jwt.decode.side_effect = jwt.exceptions.PyJWKClientError
with self.assertRaises(Exception): # noqa: B017
validate_jwt_token(
jwt_token="aaaaa.bbbbbb.cccccc", # noqa: S106
jwks_url="https://mock_jwks_url",
issuer="https://mock_issuer",
audience="app_id_xxxx",
)
def test_validate_jwt_token_invalid_token(self, mock_jwt, mock_pyjwkclient):
mock_jwt.decode.side_effect = jwt.InvalidTokenError
with self.assertRaises(Exception): # noqa: B017
validate_jwt_token(
jwt_token="aaaaa.bbbbbb.cccccc", # noqa: S106
jwks_url="https://mock_jwks_url",
issuer="https://mock_issuer",
audience="app_id_xxxx",
)

View File

@@ -1,255 +0,0 @@
from pathlib import Path
from unittest import mock
import pytest
from click.testing import CliRunner
from crewai_cli.cli import (
deploy_create,
deploy_list,
deploy_logs,
deploy_push,
deploy_remove,
deply_status,
flow_add_crew,
login,
reset_memories,
test,
train,
version,
)
@pytest.fixture
def runner():
return CliRunner()
@mock.patch("crewai_cli.cli.train_crew")
def test_train_default_iterations(train_crew, runner):
result = runner.invoke(train)
train_crew.assert_called_once_with(5, "trained_agents_data.pkl")
assert result.exit_code == 0
assert "Training the Crew for 5 iterations" in result.output
@mock.patch("crewai_cli.cli.train_crew")
def test_train_custom_iterations(train_crew, runner):
result = runner.invoke(train, ["--n_iterations", "10"])
train_crew.assert_called_once_with(10, "trained_agents_data.pkl")
assert result.exit_code == 0
assert "Training the Crew for 10 iterations" in result.output
@mock.patch("crewai_cli.cli.train_crew")
def test_train_invalid_string_iterations(train_crew, runner):
result = runner.invoke(train, ["--n_iterations", "invalid"])
train_crew.assert_not_called()
assert result.exit_code == 2
assert (
"Usage: train [OPTIONS]\nTry 'train --help' for help.\n\nError: Invalid value for '-n' / '--n_iterations': 'invalid' is not a valid integer.\n"
in result.output
)
def test_reset_no_memory_flags(runner):
result = runner.invoke(
reset_memories,
)
assert (
result.output
== "Please specify at least one memory type to reset using the appropriate flags.\n"
)
def test_version_flag(runner):
result = runner.invoke(version)
assert result.exit_code == 0
assert "crewai version:" in result.output
def test_version_command(runner):
result = runner.invoke(version)
assert result.exit_code == 0
assert "crewai version:" in result.output
def test_version_command_with_tools(runner):
result = runner.invoke(version, ["--tools"])
assert result.exit_code == 0
assert "crewai version:" in result.output
assert (
"crewai tools version:" in result.output
or "crewai tools not installed" in result.output
)
@mock.patch("crewai_cli.cli.evaluate_crew")
def test_test_default_iterations(evaluate_crew, runner):
result = runner.invoke(test)
evaluate_crew.assert_called_once_with(3, "gpt-4o-mini")
assert result.exit_code == 0
assert "Testing the crew for 3 iterations with model gpt-4o-mini" in result.output
@mock.patch("crewai_cli.cli.evaluate_crew")
def test_test_custom_iterations(evaluate_crew, runner):
result = runner.invoke(test, ["--n_iterations", "5", "--model", "gpt-4o"])
evaluate_crew.assert_called_once_with(5, "gpt-4o")
assert result.exit_code == 0
assert "Testing the crew for 5 iterations with model gpt-4o" in result.output
@mock.patch("crewai_cli.cli.evaluate_crew")
def test_test_invalid_string_iterations(evaluate_crew, runner):
result = runner.invoke(test, ["--n_iterations", "invalid"])
evaluate_crew.assert_not_called()
assert result.exit_code == 2
assert (
"Usage: test [OPTIONS]\nTry 'test --help' for help.\n\nError: Invalid value for '-n' / '--n_iterations': 'invalid' is not a valid integer.\n"
in result.output
)
@mock.patch("crewai_cli.cli.AuthenticationCommand")
def test_login(command, runner):
mock_auth = command.return_value
result = runner.invoke(login)
assert result.exit_code == 0
mock_auth.login.assert_called_once()
@mock.patch("crewai_cli.cli.DeployCommand")
def test_deploy_create(command, runner):
mock_deploy = command.return_value
result = runner.invoke(deploy_create)
assert result.exit_code == 0
mock_deploy.create_crew.assert_called_once()
@mock.patch("crewai_cli.cli.DeployCommand")
def test_deploy_list(command, runner):
mock_deploy = command.return_value
result = runner.invoke(deploy_list)
assert result.exit_code == 0
mock_deploy.list_crews.assert_called_once()
@mock.patch("crewai_cli.cli.DeployCommand")
def test_deploy_push(command, runner):
mock_deploy = command.return_value
uuid = "test-uuid"
result = runner.invoke(deploy_push, ["-u", uuid])
assert result.exit_code == 0
mock_deploy.deploy.assert_called_once_with(uuid=uuid)
@mock.patch("crewai_cli.cli.DeployCommand")
def test_deploy_push_no_uuid(command, runner):
mock_deploy = command.return_value
result = runner.invoke(deploy_push)
assert result.exit_code == 0
mock_deploy.deploy.assert_called_once_with(uuid=None)
@mock.patch("crewai_cli.cli.DeployCommand")
def test_deploy_status(command, runner):
mock_deploy = command.return_value
uuid = "test-uuid"
result = runner.invoke(deply_status, ["-u", uuid])
assert result.exit_code == 0
mock_deploy.get_crew_status.assert_called_once_with(uuid=uuid)
@mock.patch("crewai_cli.cli.DeployCommand")
def test_deploy_status_no_uuid(command, runner):
mock_deploy = command.return_value
result = runner.invoke(deply_status)
assert result.exit_code == 0
mock_deploy.get_crew_status.assert_called_once_with(uuid=None)
@mock.patch("crewai_cli.cli.DeployCommand")
def test_deploy_logs(command, runner):
mock_deploy = command.return_value
uuid = "test-uuid"
result = runner.invoke(deploy_logs, ["-u", uuid])
assert result.exit_code == 0
mock_deploy.get_crew_logs.assert_called_once_with(uuid=uuid)
@mock.patch("crewai_cli.cli.DeployCommand")
def test_deploy_logs_no_uuid(command, runner):
mock_deploy = command.return_value
result = runner.invoke(deploy_logs)
assert result.exit_code == 0
mock_deploy.get_crew_logs.assert_called_once_with(uuid=None)
@mock.patch("crewai_cli.cli.DeployCommand")
def test_deploy_remove(command, runner):
mock_deploy = command.return_value
uuid = "test-uuid"
result = runner.invoke(deploy_remove, ["-u", uuid])
assert result.exit_code == 0
mock_deploy.remove_crew.assert_called_once_with(uuid=uuid)
@mock.patch("crewai_cli.cli.DeployCommand")
def test_deploy_remove_no_uuid(command, runner):
mock_deploy = command.return_value
result = runner.invoke(deploy_remove)
assert result.exit_code == 0
mock_deploy.remove_crew.assert_called_once_with(uuid=None)
@mock.patch("crewai_cli.add_crew_to_flow.create_embedded_crew")
@mock.patch("pathlib.Path.exists", return_value=True)
def test_flow_add_crew(mock_path_exists, mock_create_embedded_crew, runner):
crew_name = "new_crew"
result = runner.invoke(flow_add_crew, [crew_name])
assert result.exit_code == 0, f"Command failed with output: {result.output}"
assert f"Adding crew {crew_name} to the flow" in result.output
mock_create_embedded_crew.assert_called_once()
call_args, call_kwargs = mock_create_embedded_crew.call_args
assert call_args[0] == crew_name
assert "parent_folder" in call_kwargs
assert isinstance(call_kwargs["parent_folder"], Path)
def test_add_crew_to_flow_not_in_root(runner):
with mock.patch("pathlib.Path.exists", autospec=True) as mock_exists:
def exists_side_effect(self):
if self.name == "pyproject.toml":
return False
return True
mock_exists.side_effect = exists_side_effect
result = runner.invoke(flow_add_crew, ["new_crew"])
assert result.exit_code != 0
assert "This command must be run from the root of a flow project." in str(
result.output
)

View File

@@ -1,148 +0,0 @@
import json
import shutil
import tempfile
import unittest
from datetime import datetime, timedelta
from pathlib import Path
from unittest.mock import MagicMock, patch
from crewai_cli.config import (
CLI_SETTINGS_KEYS,
DEFAULT_CLI_SETTINGS,
USER_SETTINGS_KEYS,
Settings,
)
from crewai_cli.shared.token_manager import TokenManager
class TestSettings(unittest.TestCase):
def setUp(self):
self.test_dir = Path(tempfile.mkdtemp())
self.config_path = self.test_dir / "settings.json"
def tearDown(self):
shutil.rmtree(self.test_dir)
def test_empty_initialization(self):
settings = Settings(config_path=self.config_path)
self.assertIsNone(settings.tool_repository_username)
self.assertIsNone(settings.tool_repository_password)
def test_initialization_with_data(self):
settings = Settings(
config_path=self.config_path, tool_repository_username="user1"
)
self.assertEqual(settings.tool_repository_username, "user1")
self.assertIsNone(settings.tool_repository_password)
def test_initialization_with_existing_file(self):
self.config_path.parent.mkdir(parents=True, exist_ok=True)
with self.config_path.open("w") as f:
json.dump({"tool_repository_username": "file_user"}, f)
settings = Settings(config_path=self.config_path)
self.assertEqual(settings.tool_repository_username, "file_user")
def test_merge_file_and_input_data(self):
self.config_path.parent.mkdir(parents=True, exist_ok=True)
with self.config_path.open("w") as f:
json.dump(
{
"tool_repository_username": "file_user",
"tool_repository_password": "file_pass",
},
f,
)
settings = Settings(
config_path=self.config_path, tool_repository_username="new_user"
)
self.assertEqual(settings.tool_repository_username, "new_user")
self.assertEqual(settings.tool_repository_password, "file_pass")
def test_clear_user_settings(self):
user_settings = {key: f"value_for_{key}" for key in USER_SETTINGS_KEYS}
settings = Settings(config_path=self.config_path, **user_settings)
settings.clear_user_settings()
for key in user_settings.keys():
self.assertEqual(getattr(settings, key), None)
@patch("crewai_cli.config.TokenManager")
def test_reset_settings(self, mock_token_manager):
user_settings = {key: f"value_for_{key}" for key in USER_SETTINGS_KEYS}
cli_settings = {key: f"value_for_{key}" for key in CLI_SETTINGS_KEYS if key != "oauth2_extra"}
cli_settings["oauth2_extra"] = {"scope": "xxx", "other": "yyy"}
settings = Settings(
config_path=self.config_path, **user_settings, **cli_settings
)
mock_token_manager.return_value = MagicMock()
TokenManager().save_tokens(
"aaa.bbb.ccc", (datetime.now() + timedelta(seconds=36000)).timestamp()
)
settings.reset()
for key in user_settings.keys():
self.assertEqual(getattr(settings, key), None)
for key in cli_settings.keys():
self.assertEqual(getattr(settings, key), DEFAULT_CLI_SETTINGS.get(key))
mock_token_manager.return_value.clear_tokens.assert_called_once()
def test_dump_new_settings(self):
settings = Settings(
config_path=self.config_path, tool_repository_username="user1"
)
settings.dump()
with self.config_path.open("r") as f:
saved_data = json.load(f)
self.assertEqual(saved_data["tool_repository_username"], "user1")
def test_update_existing_settings(self):
self.config_path.parent.mkdir(parents=True, exist_ok=True)
with self.config_path.open("w") as f:
json.dump({"existing_setting": "value"}, f)
settings = Settings(
config_path=self.config_path, tool_repository_username="user1"
)
settings.dump()
with self.config_path.open("r") as f:
saved_data = json.load(f)
self.assertEqual(saved_data["existing_setting"], "value")
self.assertEqual(saved_data["tool_repository_username"], "user1")
def test_none_values(self):
settings = Settings(config_path=self.config_path, tool_repository_username=None)
settings.dump()
with self.config_path.open("r") as f:
saved_data = json.load(f)
self.assertIsNone(saved_data.get("tool_repository_username"))
def test_invalid_json_in_config(self):
self.config_path.parent.mkdir(parents=True, exist_ok=True)
with self.config_path.open("w") as f:
f.write("invalid json")
try:
settings = Settings(config_path=self.config_path)
self.assertIsNone(settings.tool_repository_username)
except json.JSONDecodeError:
self.fail("Settings initialization should handle invalid JSON")
def test_empty_config_file(self):
self.config_path.parent.mkdir(parents=True, exist_ok=True)
self.config_path.touch()
settings = Settings(config_path=self.config_path)
self.assertIsNone(settings.tool_repository_username)

View File

@@ -1,20 +0,0 @@
from crewai_cli.constants import ENV_VARS, MODELS, PROVIDERS
def test_huggingface_in_providers():
"""Test that Huggingface is in the PROVIDERS list."""
assert "huggingface" in PROVIDERS
def test_huggingface_env_vars():
"""Test that Huggingface environment variables are properly configured."""
assert "huggingface" in ENV_VARS
assert any(
detail.get("key_name") == "HF_TOKEN" for detail in ENV_VARS["huggingface"]
)
def test_huggingface_models():
"""Test that Huggingface models are properly configured."""
assert "huggingface" in MODELS
assert len(MODELS["huggingface"]) > 0

View File

@@ -1,356 +0,0 @@
import os
import unittest
from unittest.mock import ANY, AsyncMock, MagicMock, patch
import pytest
from crewai_cli.plus_api import PlusAPI
class TestPlusAPI(unittest.TestCase):
def setUp(self):
self.api_key = "test_api_key"
self.api = PlusAPI(self.api_key)
self.org_uuid = "test-org-uuid"
def test_init(self):
self.assertEqual(self.api.api_key, self.api_key)
self.assertEqual(self.api.headers["Authorization"], f"Bearer {self.api_key}")
self.assertEqual(self.api.headers["Content-Type"], "application/json")
self.assertTrue("CrewAI-CLI/" in self.api.headers["User-Agent"])
self.assertTrue(self.api.headers["X-Crewai-Version"])
@patch("crewai_cli.plus_api.PlusAPI._make_request")
def test_login_to_tool_repository(self, mock_make_request):
mock_response = MagicMock()
mock_make_request.return_value = mock_response
response = self.api.login_to_tool_repository()
mock_make_request.assert_called_once_with(
"POST", "/crewai_plus/api/v1/tools/login"
)
self.assertEqual(response, mock_response)
def assert_request_with_org_id(
self, mock_client_instance, method: str, endpoint: str, **kwargs
):
mock_client_instance.request.assert_called_once_with(
method,
f"{os.getenv('CREWAI_PLUS_URL')}{endpoint}",
headers={
"Authorization": ANY,
"Content-Type": ANY,
"User-Agent": ANY,
"X-Crewai-Version": ANY,
"X-Crewai-Organization-Id": self.org_uuid,
},
**kwargs,
)
@patch("crewai_cli.plus_api.Settings")
@patch("crewai_cli.plus_api.httpx.Client")
def test_login_to_tool_repository_with_org_uuid(
self, mock_client_class, mock_settings_class
):
mock_settings = MagicMock()
mock_settings.org_uuid = self.org_uuid
mock_settings.enterprise_base_url = os.getenv('CREWAI_PLUS_URL')
mock_settings_class.return_value = mock_settings
self.api = PlusAPI(self.api_key)
mock_client_instance = MagicMock()
mock_response = MagicMock()
mock_client_instance.request.return_value = mock_response
mock_client_class.return_value.__enter__.return_value = mock_client_instance
response = self.api.login_to_tool_repository()
self.assert_request_with_org_id(
mock_client_instance, "POST", "/crewai_plus/api/v1/tools/login"
)
self.assertEqual(response, mock_response)
@patch("crewai_cli.plus_api.PlusAPI._make_request")
def test_get_tool(self, mock_make_request):
mock_response = MagicMock()
mock_make_request.return_value = mock_response
response = self.api.get_tool("test_tool_handle")
mock_make_request.assert_called_once_with(
"GET", "/crewai_plus/api/v1/tools/test_tool_handle"
)
self.assertEqual(response, mock_response)
@patch("crewai_cli.plus_api.Settings")
@patch("crewai_cli.plus_api.httpx.Client")
def test_get_tool_with_org_uuid(self, mock_client_class, mock_settings_class):
mock_settings = MagicMock()
mock_settings.org_uuid = self.org_uuid
mock_settings.enterprise_base_url = os.getenv('CREWAI_PLUS_URL')
mock_settings_class.return_value = mock_settings
self.api = PlusAPI(self.api_key)
mock_client_instance = MagicMock()
mock_response = MagicMock()
mock_client_instance.request.return_value = mock_response
mock_client_class.return_value.__enter__.return_value = mock_client_instance
response = self.api.get_tool("test_tool_handle")
self.assert_request_with_org_id(
mock_client_instance, "GET", "/crewai_plus/api/v1/tools/test_tool_handle"
)
self.assertEqual(response, mock_response)
@patch("crewai_cli.plus_api.PlusAPI._make_request")
def test_publish_tool(self, mock_make_request):
mock_response = MagicMock()
mock_make_request.return_value = mock_response
handle = "test_tool_handle"
public = True
version = "1.0.0"
description = "Test tool description"
encoded_file = "encoded_test_file"
response = self.api.publish_tool(
handle, public, version, description, encoded_file
)
params = {
"handle": handle,
"public": public,
"version": version,
"file": encoded_file,
"description": description,
"available_exports": None,
}
mock_make_request.assert_called_once_with(
"POST", "/crewai_plus/api/v1/tools", json=params
)
self.assertEqual(response, mock_response)
@patch("crewai_cli.plus_api.Settings")
@patch("crewai_cli.plus_api.httpx.Client")
def test_publish_tool_with_org_uuid(self, mock_client_class, mock_settings_class):
mock_settings = MagicMock()
mock_settings.org_uuid = self.org_uuid
mock_settings.enterprise_base_url = os.getenv('CREWAI_PLUS_URL')
mock_settings_class.return_value = mock_settings
self.api = PlusAPI(self.api_key)
mock_client_instance = MagicMock()
mock_response = MagicMock()
mock_client_instance.request.return_value = mock_response
mock_client_class.return_value.__enter__.return_value = mock_client_instance
handle = "test_tool_handle"
public = True
version = "1.0.0"
description = "Test tool description"
encoded_file = "encoded_test_file"
response = self.api.publish_tool(
handle, public, version, description, encoded_file
)
expected_params = {
"handle": handle,
"public": public,
"version": version,
"file": encoded_file,
"description": description,
"available_exports": None,
}
self.assert_request_with_org_id(
mock_client_instance, "POST", "/crewai_plus/api/v1/tools", json=expected_params
)
self.assertEqual(response, mock_response)
@patch("crewai_cli.plus_api.PlusAPI._make_request")
def test_publish_tool_without_description(self, mock_make_request):
mock_response = MagicMock()
mock_make_request.return_value = mock_response
handle = "test_tool_handle"
public = False
version = "2.0.0"
description = None
encoded_file = "encoded_test_file"
response = self.api.publish_tool(
handle, public, version, description, encoded_file
)
params = {
"handle": handle,
"public": public,
"version": version,
"file": encoded_file,
"description": description,
"available_exports": None,
}
mock_make_request.assert_called_once_with(
"POST", "/crewai_plus/api/v1/tools", json=params
)
self.assertEqual(response, mock_response)
@patch("crewai_cli.plus_api.httpx.Client")
def test_make_request(self, mock_client_class):
mock_client_instance = MagicMock()
mock_response = MagicMock()
mock_client_instance.request.return_value = mock_response
mock_client_class.return_value.__enter__.return_value = mock_client_instance
response = self.api._make_request("GET", "test_endpoint")
mock_client_class.assert_called_once_with(trust_env=False, verify=True)
mock_client_instance.request.assert_called_once_with(
"GET", f"{self.api.base_url}/test_endpoint", headers=self.api.headers
)
self.assertEqual(response, mock_response)
@patch("crewai_cli.plus_api.PlusAPI._make_request")
def test_deploy_by_name(self, mock_make_request):
self.api.deploy_by_name("test_project")
mock_make_request.assert_called_once_with(
"POST", "/crewai_plus/api/v1/crews/by-name/test_project/deploy"
)
@patch("crewai_cli.plus_api.PlusAPI._make_request")
def test_deploy_by_uuid(self, mock_make_request):
self.api.deploy_by_uuid("test_uuid")
mock_make_request.assert_called_once_with(
"POST", "/crewai_plus/api/v1/crews/test_uuid/deploy"
)
@patch("crewai_cli.plus_api.PlusAPI._make_request")
def test_crew_status_by_name(self, mock_make_request):
self.api.crew_status_by_name("test_project")
mock_make_request.assert_called_once_with(
"GET", "/crewai_plus/api/v1/crews/by-name/test_project/status"
)
@patch("crewai_cli.plus_api.PlusAPI._make_request")
def test_crew_status_by_uuid(self, mock_make_request):
self.api.crew_status_by_uuid("test_uuid")
mock_make_request.assert_called_once_with(
"GET", "/crewai_plus/api/v1/crews/test_uuid/status"
)
@patch("crewai_cli.plus_api.PlusAPI._make_request")
def test_crew_by_name(self, mock_make_request):
self.api.crew_by_name("test_project")
mock_make_request.assert_called_once_with(
"GET", "/crewai_plus/api/v1/crews/by-name/test_project/logs/deployment"
)
self.api.crew_by_name("test_project", "custom_log")
mock_make_request.assert_called_with(
"GET", "/crewai_plus/api/v1/crews/by-name/test_project/logs/custom_log"
)
@patch("crewai_cli.plus_api.PlusAPI._make_request")
def test_crew_by_uuid(self, mock_make_request):
self.api.crew_by_uuid("test_uuid")
mock_make_request.assert_called_once_with(
"GET", "/crewai_plus/api/v1/crews/test_uuid/logs/deployment"
)
self.api.crew_by_uuid("test_uuid", "custom_log")
mock_make_request.assert_called_with(
"GET", "/crewai_plus/api/v1/crews/test_uuid/logs/custom_log"
)
@patch("crewai_cli.plus_api.PlusAPI._make_request")
def test_delete_crew_by_name(self, mock_make_request):
self.api.delete_crew_by_name("test_project")
mock_make_request.assert_called_once_with(
"DELETE", "/crewai_plus/api/v1/crews/by-name/test_project"
)
@patch("crewai_cli.plus_api.PlusAPI._make_request")
def test_delete_crew_by_uuid(self, mock_make_request):
self.api.delete_crew_by_uuid("test_uuid")
mock_make_request.assert_called_once_with(
"DELETE", "/crewai_plus/api/v1/crews/test_uuid"
)
@patch("crewai_cli.plus_api.PlusAPI._make_request")
def test_list_crews(self, mock_make_request):
self.api.list_crews()
mock_make_request.assert_called_once_with("GET", "/crewai_plus/api/v1/crews")
@patch("crewai_cli.plus_api.PlusAPI._make_request")
def test_create_crew(self, mock_make_request):
payload = {"name": "test_crew"}
self.api.create_crew(payload)
mock_make_request.assert_called_once_with(
"POST", "/crewai_plus/api/v1/crews", json=payload
)
@patch("crewai_cli.plus_api.Settings")
@patch.dict(os.environ, {"CREWAI_PLUS_URL": ""})
def test_custom_base_url(self, mock_settings_class):
mock_settings = MagicMock()
mock_settings.enterprise_base_url = "https://custom-url.com/api"
mock_settings_class.return_value = mock_settings
custom_api = PlusAPI("test_key")
self.assertEqual(
custom_api.base_url,
"https://custom-url.com/api",
)
@patch.dict(os.environ, {"CREWAI_PLUS_URL": "https://custom-url-from-env.com"})
def test_custom_base_url_from_env(self):
custom_api = PlusAPI("test_key")
self.assertEqual(
custom_api.base_url,
"https://custom-url-from-env.com",
)
@pytest.mark.asyncio
@patch("httpx.AsyncClient")
async def test_get_agent(mock_async_client_class):
api = PlusAPI("test_api_key")
mock_response = MagicMock()
mock_client_instance = AsyncMock()
mock_client_instance.get.return_value = mock_response
mock_async_client_class.return_value.__aenter__.return_value = mock_client_instance
response = await api.get_agent("test_agent_handle")
mock_client_instance.get.assert_called_once_with(
f"{api.base_url}/crewai_plus/api/v1/agents/test_agent_handle",
headers=api.headers,
)
assert response == mock_response
@pytest.mark.asyncio
@patch("httpx.AsyncClient")
@patch("crewai_cli.plus_api.Settings")
async def test_get_agent_with_org_uuid(mock_settings_class, mock_async_client_class):
org_uuid = "test-org-uuid"
mock_settings = MagicMock()
mock_settings.org_uuid = org_uuid
mock_settings.enterprise_base_url = os.getenv("CREWAI_PLUS_URL")
mock_settings_class.return_value = mock_settings
api = PlusAPI("test_api_key")
mock_response = MagicMock()
mock_client_instance = AsyncMock()
mock_client_instance.get.return_value = mock_response
mock_async_client_class.return_value.__aenter__.return_value = mock_client_instance
response = await api.get_agent("test_agent_handle")
mock_client_instance.get.assert_called_once_with(
f"{api.base_url}/crewai_plus/api/v1/agents/test_agent_handle",
headers=api.headers,
)
assert "X-Crewai-Organization-Id" in api.headers
assert api.headers["X-Crewai-Organization-Id"] == org_uuid
assert response == mock_response

View File

@@ -1,294 +0,0 @@
"""Tests for TokenManager with atomic file operations."""
import json
import os
import tempfile
import unittest
from datetime import datetime, timedelta
from pathlib import Path
from unittest.mock import patch
from cryptography.fernet import Fernet
from crewai_cli.shared.token_manager import TokenManager
class TestTokenManager(unittest.TestCase):
"""Test cases for TokenManager."""
@patch("crewai_cli.shared.token_manager.TokenManager._get_or_create_key")
def setUp(self, mock_get_key: unittest.mock.MagicMock) -> None:
"""Set up test fixtures."""
mock_get_key.return_value = Fernet.generate_key()
self.token_manager = TokenManager()
@patch("crewai_cli.shared.token_manager.TokenManager._read_secure_file")
@patch("crewai_cli.shared.token_manager.TokenManager._get_or_create_key")
def test_get_or_create_key_existing(
self,
mock_get_or_create: unittest.mock.MagicMock,
mock_read: unittest.mock.MagicMock,
) -> None:
"""Test that existing key is returned when present."""
mock_key = Fernet.generate_key()
mock_get_or_create.return_value = mock_key
token_manager = TokenManager()
result = token_manager.key
self.assertEqual(result, mock_key)
def test_get_or_create_key_new(self) -> None:
"""Test that new key is created when none exists."""
mock_key = Fernet.generate_key()
with (
patch.object(self.token_manager, "_read_secure_file", return_value=None) as mock_read,
patch.object(self.token_manager, "_atomic_create_secure_file", return_value=True) as mock_atomic_create,
patch("crewai_cli.shared.token_manager.Fernet.generate_key", return_value=mock_key) as mock_generate,
):
result = self.token_manager._get_or_create_key()
self.assertEqual(result, mock_key)
mock_read.assert_called_with("secret.key")
mock_generate.assert_called_once()
mock_atomic_create.assert_called_once_with("secret.key", mock_key)
def test_get_or_create_key_race_condition(self) -> None:
"""Test that another process's key is used when atomic create fails."""
our_key = Fernet.generate_key()
their_key = Fernet.generate_key()
with (
patch.object(self.token_manager, "_read_secure_file", side_effect=[None, their_key]) as mock_read,
patch.object(self.token_manager, "_atomic_create_secure_file", return_value=False) as mock_atomic_create,
patch("crewai_cli.shared.token_manager.Fernet.generate_key", return_value=our_key),
):
result = self.token_manager._get_or_create_key()
self.assertEqual(result, their_key)
self.assertEqual(mock_read.call_count, 2)
@patch("crewai_cli.shared.token_manager.TokenManager._atomic_write_secure_file")
def test_save_tokens(
self, mock_write: unittest.mock.MagicMock
) -> None:
"""Test saving tokens encrypts and writes atomically."""
access_token = "test_token"
expires_at = int((datetime.now() + timedelta(seconds=3600)).timestamp())
self.token_manager.save_tokens(access_token, expires_at)
mock_write.assert_called_once()
args = mock_write.call_args[0]
self.assertEqual(args[0], "tokens.enc")
decrypted_data = self.token_manager.fernet.decrypt(args[1])
data = json.loads(decrypted_data)
self.assertEqual(data["access_token"], access_token)
expiration = datetime.fromisoformat(data["expiration"])
self.assertEqual(expiration, datetime.fromtimestamp(expires_at))
@patch("crewai_cli.shared.token_manager.TokenManager._read_secure_file")
def test_get_token_valid(
self, mock_read: unittest.mock.MagicMock
) -> None:
"""Test getting a valid non-expired token."""
access_token = "test_token"
expiration = (datetime.now() + timedelta(hours=1)).isoformat()
data = {"access_token": access_token, "expiration": expiration}
encrypted_data = self.token_manager.fernet.encrypt(json.dumps(data).encode())
mock_read.return_value = encrypted_data
result = self.token_manager.get_token()
self.assertEqual(result, access_token)
@patch("crewai_cli.shared.token_manager.TokenManager._read_secure_file")
def test_get_token_expired(
self, mock_read: unittest.mock.MagicMock
) -> None:
"""Test that expired token returns None."""
access_token = "test_token"
expiration = (datetime.now() - timedelta(hours=1)).isoformat()
data = {"access_token": access_token, "expiration": expiration}
encrypted_data = self.token_manager.fernet.encrypt(json.dumps(data).encode())
mock_read.return_value = encrypted_data
result = self.token_manager.get_token()
self.assertIsNone(result)
@patch("crewai_cli.shared.token_manager.TokenManager._read_secure_file")
def test_get_token_not_found(
self, mock_read: unittest.mock.MagicMock
) -> None:
"""Test that missing token file returns None."""
mock_read.return_value = None
result = self.token_manager.get_token()
self.assertIsNone(result)
@patch("crewai_cli.shared.token_manager.TokenManager._delete_secure_file")
def test_clear_tokens(
self, mock_delete: unittest.mock.MagicMock
) -> None:
"""Test clearing tokens deletes the token file."""
self.token_manager.clear_tokens()
mock_delete.assert_called_once_with("tokens.enc")
class TestAtomicFileOperations(unittest.TestCase):
"""Test atomic file operations directly."""
def setUp(self) -> None:
"""Set up test fixtures with temp directory."""
self.temp_dir = tempfile.mkdtemp()
self.original_get_path = TokenManager._get_secure_storage_path
# Patch to use temp directory
def mock_get_path() -> Path:
return Path(self.temp_dir)
TokenManager._get_secure_storage_path = staticmethod(mock_get_path)
def tearDown(self) -> None:
"""Clean up temp directory."""
TokenManager._get_secure_storage_path = staticmethod(self.original_get_path)
import shutil
shutil.rmtree(self.temp_dir, ignore_errors=True)
@patch("crewai_cli.shared.token_manager.TokenManager._get_or_create_key")
def test_atomic_create_new_file(
self, mock_get_key: unittest.mock.MagicMock
) -> None:
"""Test atomic create succeeds for new file."""
mock_get_key.return_value = Fernet.generate_key()
tm = TokenManager()
result = tm._atomic_create_secure_file("test.txt", b"content")
self.assertTrue(result)
file_path = Path(self.temp_dir) / "test.txt"
self.assertTrue(file_path.exists())
self.assertEqual(file_path.read_bytes(), b"content")
self.assertEqual(file_path.stat().st_mode & 0o777, 0o600)
@patch("crewai_cli.shared.token_manager.TokenManager._get_or_create_key")
def test_atomic_create_existing_file(
self, mock_get_key: unittest.mock.MagicMock
) -> None:
"""Test atomic create fails for existing file."""
mock_get_key.return_value = Fernet.generate_key()
tm = TokenManager()
# Create file first
file_path = Path(self.temp_dir) / "test.txt"
file_path.write_bytes(b"original")
result = tm._atomic_create_secure_file("test.txt", b"new content")
self.assertFalse(result)
self.assertEqual(file_path.read_bytes(), b"original")
@patch("crewai_cli.shared.token_manager.TokenManager._get_or_create_key")
def test_atomic_write_new_file(
self, mock_get_key: unittest.mock.MagicMock
) -> None:
"""Test atomic write creates new file."""
mock_get_key.return_value = Fernet.generate_key()
tm = TokenManager()
tm._atomic_write_secure_file("test.txt", b"content")
file_path = Path(self.temp_dir) / "test.txt"
self.assertTrue(file_path.exists())
self.assertEqual(file_path.read_bytes(), b"content")
self.assertEqual(file_path.stat().st_mode & 0o777, 0o600)
@patch("crewai_cli.shared.token_manager.TokenManager._get_or_create_key")
def test_atomic_write_overwrites(
self, mock_get_key: unittest.mock.MagicMock
) -> None:
"""Test atomic write overwrites existing file."""
mock_get_key.return_value = Fernet.generate_key()
tm = TokenManager()
file_path = Path(self.temp_dir) / "test.txt"
file_path.write_bytes(b"original")
tm._atomic_write_secure_file("test.txt", b"new content")
self.assertEqual(file_path.read_bytes(), b"new content")
@patch("crewai_cli.shared.token_manager.TokenManager._get_or_create_key")
def test_atomic_write_no_temp_file_on_success(
self, mock_get_key: unittest.mock.MagicMock
) -> None:
"""Test that temp file is cleaned up after successful write."""
mock_get_key.return_value = Fernet.generate_key()
tm = TokenManager()
tm._atomic_write_secure_file("test.txt", b"content")
# Check no temp files remain
temp_files = list(Path(self.temp_dir).glob(".test.txt.*"))
self.assertEqual(len(temp_files), 0)
@patch("crewai_cli.shared.token_manager.TokenManager._get_or_create_key")
def test_read_secure_file_exists(
self, mock_get_key: unittest.mock.MagicMock
) -> None:
"""Test reading existing file."""
mock_get_key.return_value = Fernet.generate_key()
tm = TokenManager()
file_path = Path(self.temp_dir) / "test.txt"
file_path.write_bytes(b"content")
result = tm._read_secure_file("test.txt")
self.assertEqual(result, b"content")
@patch("crewai_cli.shared.token_manager.TokenManager._get_or_create_key")
def test_read_secure_file_not_exists(
self, mock_get_key: unittest.mock.MagicMock
) -> None:
"""Test reading non-existent file returns None."""
mock_get_key.return_value = Fernet.generate_key()
tm = TokenManager()
result = tm._read_secure_file("nonexistent.txt")
self.assertIsNone(result)
@patch("crewai_cli.shared.token_manager.TokenManager._get_or_create_key")
def test_delete_secure_file_exists(
self, mock_get_key: unittest.mock.MagicMock
) -> None:
"""Test deleting existing file."""
mock_get_key.return_value = Fernet.generate_key()
tm = TokenManager()
file_path = Path(self.temp_dir) / "test.txt"
file_path.write_bytes(b"content")
tm._delete_secure_file("test.txt")
self.assertFalse(file_path.exists())
@patch("crewai_cli.shared.token_manager.TokenManager._get_or_create_key")
def test_delete_secure_file_not_exists(
self, mock_get_key: unittest.mock.MagicMock
) -> None:
"""Test deleting non-existent file doesn't raise."""
mock_get_key.return_value = Fernet.generate_key()
tm = TokenManager()
# Should not raise
tm._delete_secure_file("nonexistent.txt")
if __name__ == "__main__":
unittest.main()

View File

@@ -1,146 +0,0 @@
import os
import shutil
import tempfile
from pathlib import Path
import pytest
from crewai_cli import utils
@pytest.fixture
def temp_tree():
root_dir = tempfile.mkdtemp()
create_file(os.path.join(root_dir, "file1.txt"), "Hello, world!")
create_file(os.path.join(root_dir, "file2.txt"), "Another file")
os.mkdir(os.path.join(root_dir, "empty_dir"))
nested_dir = os.path.join(root_dir, "nested_dir")
os.mkdir(nested_dir)
create_file(os.path.join(nested_dir, "nested_file.txt"), "Nested content")
yield root_dir
shutil.rmtree(root_dir)
def create_file(path, content):
with open(path, "w") as f:
f.write(content)
def test_tree_find_and_replace_file_content(temp_tree):
utils.tree_find_and_replace(temp_tree, "world", "universe")
with open(os.path.join(temp_tree, "file1.txt"), "r") as f:
assert f.read() == "Hello, universe!"
def test_tree_find_and_replace_file_name(temp_tree):
old_path = os.path.join(temp_tree, "file2.txt")
new_path = os.path.join(temp_tree, "file2_renamed.txt")
os.rename(old_path, new_path)
utils.tree_find_and_replace(temp_tree, "renamed", "modified")
assert os.path.exists(os.path.join(temp_tree, "file2_modified.txt"))
assert not os.path.exists(new_path)
def test_tree_find_and_replace_directory_name(temp_tree):
utils.tree_find_and_replace(temp_tree, "empty", "renamed")
assert os.path.exists(os.path.join(temp_tree, "renamed_dir"))
assert not os.path.exists(os.path.join(temp_tree, "empty_dir"))
def test_tree_find_and_replace_nested_content(temp_tree):
utils.tree_find_and_replace(temp_tree, "Nested", "Updated")
with open(os.path.join(temp_tree, "nested_dir", "nested_file.txt"), "r") as f:
assert f.read() == "Updated content"
def test_tree_find_and_replace_no_matches(temp_tree):
utils.tree_find_and_replace(temp_tree, "nonexistent", "replacement")
assert set(os.listdir(temp_tree)) == {
"file1.txt",
"file2.txt",
"empty_dir",
"nested_dir",
}
def test_tree_copy_full_structure(temp_tree):
dest_dir = tempfile.mkdtemp()
try:
utils.tree_copy(temp_tree, dest_dir)
assert set(os.listdir(dest_dir)) == set(os.listdir(temp_tree))
assert os.path.isfile(os.path.join(dest_dir, "file1.txt"))
assert os.path.isfile(os.path.join(dest_dir, "file2.txt"))
assert os.path.isdir(os.path.join(dest_dir, "empty_dir"))
assert os.path.isdir(os.path.join(dest_dir, "nested_dir"))
assert os.path.isfile(os.path.join(dest_dir, "nested_dir", "nested_file.txt"))
finally:
shutil.rmtree(dest_dir)
def test_tree_copy_preserve_content(temp_tree):
dest_dir = tempfile.mkdtemp()
try:
utils.tree_copy(temp_tree, dest_dir)
with open(os.path.join(dest_dir, "file1.txt"), "r") as f:
assert f.read() == "Hello, world!"
with open(os.path.join(dest_dir, "nested_dir", "nested_file.txt"), "r") as f:
assert f.read() == "Nested content"
finally:
shutil.rmtree(dest_dir)
def test_tree_copy_to_existing_directory(temp_tree):
dest_dir = tempfile.mkdtemp()
try:
create_file(os.path.join(dest_dir, "existing_file.txt"), "I was here first")
utils.tree_copy(temp_tree, dest_dir)
assert os.path.isfile(os.path.join(dest_dir, "existing_file.txt"))
assert os.path.isfile(os.path.join(dest_dir, "file1.txt"))
finally:
shutil.rmtree(dest_dir)
@pytest.fixture
def temp_project_dir():
"""Create a temporary directory for testing tool extraction."""
with tempfile.TemporaryDirectory() as temp_dir:
yield Path(temp_dir)
def create_init_file(directory, content):
return create_file(directory / "__init__.py", content)
def test_extract_available_exports_empty_project(temp_project_dir, capsys):
with pytest.raises(SystemExit):
utils.extract_available_exports(dir_path=temp_project_dir)
captured = capsys.readouterr()
assert "No valid tools were exposed in your __init__.py file" in captured.out
def test_extract_available_exports_no_init_file(temp_project_dir, capsys):
(temp_project_dir / "some_file.py").write_text("print('hello')")
with pytest.raises(SystemExit):
utils.extract_available_exports(dir_path=temp_project_dir)
captured = capsys.readouterr()
assert "No valid tools were exposed in your __init__.py file" in captured.out
def test_extract_available_exports_empty_init_file(temp_project_dir, capsys):
create_init_file(temp_project_dir, "")
with pytest.raises(SystemExit):
utils.extract_available_exports(dir_path=temp_project_dir)
captured = capsys.readouterr()
assert "Warning: No __all__ defined in" in captured.out
# Tests for extract_available_exports with crewai.tools (BaseTool, @tool)
# remain in lib/crewai/tests/cli/test_utils.py as they require the crewai core package.
# Tests for get_crews, get_flows, fetch_crews, is_valid_tool
# remain in lib/crewai/tests/cli/test_utils.py as they require the crewai core package.

View File

@@ -1,372 +0,0 @@
"""Test for version management."""
import json
from datetime import datetime, timedelta
from pathlib import Path
from unittest.mock import MagicMock, patch
from crewai_cli.version import get_crewai_version as _get_ver
from crewai_cli.version import (
_find_latest_non_yanked_version,
_get_cache_file,
_is_cache_valid,
_is_version_yanked,
get_crewai_version,
get_latest_version_from_pypi,
is_current_version_yanked,
is_newer_version_available,
)
def test_dynamic_versioning_consistency() -> None:
"""Test that dynamic versioning provides consistent version across all access methods."""
cli_version = get_crewai_version()
package_version = _get_ver()
assert cli_version == package_version
assert package_version is not None
assert len(package_version.strip()) > 0
class TestVersionChecking:
"""Test version checking utilities."""
def test_get_crewai_version(self) -> None:
"""Test getting current crewai version."""
version = get_crewai_version()
assert isinstance(version, str)
assert len(version) > 0
def test_get_cache_file(self) -> None:
"""Test cache file path generation."""
cache_file = _get_cache_file()
assert isinstance(cache_file, Path)
assert cache_file.name == "version_cache.json"
def test_is_cache_valid_with_fresh_cache(self) -> None:
"""Test cache validation with fresh cache."""
cache_data = {"timestamp": datetime.now().isoformat(), "version": "1.0.0"}
assert _is_cache_valid(cache_data) is True
def test_is_cache_valid_with_stale_cache(self) -> None:
"""Test cache validation with stale cache."""
old_time = datetime.now() - timedelta(hours=25)
cache_data = {"timestamp": old_time.isoformat(), "version": "1.0.0"}
assert _is_cache_valid(cache_data) is False
def test_is_cache_valid_with_missing_timestamp(self) -> None:
"""Test cache validation with missing timestamp."""
cache_data = {"version": "1.0.0"}
assert _is_cache_valid(cache_data) is False
@patch("crewai_cli.version.Path.exists")
@patch("crewai_cli.version.request.urlopen")
def test_get_latest_version_from_pypi_success(
self, mock_urlopen: MagicMock, mock_exists: MagicMock
) -> None:
"""Test successful PyPI version fetch uses releases data."""
mock_exists.return_value = False
releases = {
"1.0.0": [{"yanked": False}],
"2.0.0": [{"yanked": False}],
"2.1.0": [{"yanked": True, "yanked_reason": "bad release"}],
}
mock_response = MagicMock()
mock_response.read.return_value = json.dumps(
{"info": {"version": "2.1.0"}, "releases": releases}
).encode()
mock_urlopen.return_value.__enter__.return_value = mock_response
version = get_latest_version_from_pypi()
assert version == "2.0.0"
@patch("crewai_cli.version.Path.exists")
@patch("crewai_cli.version.request.urlopen")
def test_get_latest_version_from_pypi_failure(
self, mock_urlopen: MagicMock, mock_exists: MagicMock
) -> None:
"""Test PyPI version fetch failure."""
from urllib.error import URLError
mock_exists.return_value = False
mock_urlopen.side_effect = URLError("Network error")
version = get_latest_version_from_pypi()
assert version is None
@patch("crewai_cli.version.get_crewai_version")
@patch("crewai_cli.version.get_latest_version_from_pypi")
def test_is_newer_version_available_true(
self, mock_latest: MagicMock, mock_current: MagicMock
) -> None:
"""Test when newer version is available."""
mock_current.return_value = "1.0.0"
mock_latest.return_value = "2.0.0"
is_newer, current, latest = is_newer_version_available()
assert is_newer is True
assert current == "1.0.0"
assert latest == "2.0.0"
@patch("crewai_cli.version.get_crewai_version")
@patch("crewai_cli.version.get_latest_version_from_pypi")
def test_is_newer_version_available_false(
self, mock_latest: MagicMock, mock_current: MagicMock
) -> None:
"""Test when no newer version is available."""
mock_current.return_value = "2.0.0"
mock_latest.return_value = "2.0.0"
is_newer, current, latest = is_newer_version_available()
assert is_newer is False
assert current == "2.0.0"
assert latest == "2.0.0"
@patch("crewai_cli.version.get_crewai_version")
@patch("crewai_cli.version.get_latest_version_from_pypi")
def test_is_newer_version_available_with_none_latest(
self, mock_latest: MagicMock, mock_current: MagicMock
) -> None:
"""Test when PyPI fetch fails."""
mock_current.return_value = "1.0.0"
mock_latest.return_value = None
is_newer, current, latest = is_newer_version_available()
assert is_newer is False
assert current == "1.0.0"
assert latest is None
class TestFindLatestNonYankedVersion:
"""Test _find_latest_non_yanked_version helper."""
def test_skips_yanked_versions(self) -> None:
"""Test that yanked versions are skipped."""
releases = {
"1.0.0": [{"yanked": False}],
"2.0.0": [{"yanked": True}],
}
assert _find_latest_non_yanked_version(releases) == "1.0.0"
def test_returns_highest_non_yanked(self) -> None:
"""Test that the highest non-yanked version is returned."""
releases = {
"1.0.0": [{"yanked": False}],
"1.5.0": [{"yanked": False}],
"2.0.0": [{"yanked": True}],
}
assert _find_latest_non_yanked_version(releases) == "1.5.0"
def test_returns_none_when_all_yanked(self) -> None:
"""Test that None is returned when all versions are yanked."""
releases = {
"1.0.0": [{"yanked": True}],
"2.0.0": [{"yanked": True}],
}
assert _find_latest_non_yanked_version(releases) is None
def test_skips_prerelease_versions(self) -> None:
"""Test that pre-release versions are skipped."""
releases = {
"1.0.0": [{"yanked": False}],
"2.0.0a1": [{"yanked": False}],
"2.0.0rc1": [{"yanked": False}],
}
assert _find_latest_non_yanked_version(releases) == "1.0.0"
def test_skips_versions_with_empty_files(self) -> None:
"""Test that versions with no files are skipped."""
releases: dict[str, list[dict[str, bool]]] = {
"1.0.0": [{"yanked": False}],
"2.0.0": [],
}
assert _find_latest_non_yanked_version(releases) == "1.0.0"
def test_handles_invalid_version_strings(self) -> None:
"""Test that invalid version strings are skipped."""
releases = {
"1.0.0": [{"yanked": False}],
"not-a-version": [{"yanked": False}],
}
assert _find_latest_non_yanked_version(releases) == "1.0.0"
def test_partially_yanked_files_not_considered_yanked(self) -> None:
"""Test that a version with some non-yanked files is not yanked."""
releases = {
"1.0.0": [{"yanked": False}],
"2.0.0": [{"yanked": True}, {"yanked": False}],
}
assert _find_latest_non_yanked_version(releases) == "2.0.0"
class TestIsVersionYanked:
"""Test _is_version_yanked helper."""
def test_non_yanked_version(self) -> None:
"""Test a non-yanked version returns False."""
releases = {"1.0.0": [{"yanked": False}]}
is_yanked, reason = _is_version_yanked("1.0.0", releases)
assert is_yanked is False
assert reason == ""
def test_yanked_version_with_reason(self) -> None:
"""Test a yanked version returns True with reason."""
releases = {
"1.0.0": [{"yanked": True, "yanked_reason": "critical bug"}],
}
is_yanked, reason = _is_version_yanked("1.0.0", releases)
assert is_yanked is True
assert reason == "critical bug"
def test_yanked_version_without_reason(self) -> None:
"""Test a yanked version returns True with empty reason."""
releases = {"1.0.0": [{"yanked": True}]}
is_yanked, reason = _is_version_yanked("1.0.0", releases)
assert is_yanked is True
assert reason == ""
def test_unknown_version(self) -> None:
"""Test an unknown version returns False."""
releases = {"1.0.0": [{"yanked": False}]}
is_yanked, reason = _is_version_yanked("9.9.9", releases)
assert is_yanked is False
assert reason == ""
def test_partially_yanked_files(self) -> None:
"""Test a version with mixed yanked/non-yanked files is not yanked."""
releases = {
"1.0.0": [{"yanked": True}, {"yanked": False}],
}
is_yanked, reason = _is_version_yanked("1.0.0", releases)
assert is_yanked is False
assert reason == ""
def test_multiple_yanked_files_picks_first_reason(self) -> None:
"""Test that the first available reason is returned."""
releases = {
"1.0.0": [
{"yanked": True, "yanked_reason": ""},
{"yanked": True, "yanked_reason": "second reason"},
],
}
is_yanked, reason = _is_version_yanked("1.0.0", releases)
assert is_yanked is True
assert reason == "second reason"
class TestIsCurrentVersionYanked:
"""Test is_current_version_yanked public function."""
@patch("crewai_cli.version.get_crewai_version")
@patch("crewai_cli.version._get_cache_file")
def test_reads_from_valid_cache(
self, mock_cache_file: MagicMock, mock_version: MagicMock, tmp_path: Path
) -> None:
"""Test reading yanked status from a valid cache."""
mock_version.return_value = "1.0.0"
cache_file = tmp_path / "version_cache.json"
cache_data = {
"version": "2.0.0",
"timestamp": datetime.now().isoformat(),
"current_version": "1.0.0",
"current_version_yanked": True,
"current_version_yanked_reason": "bad release",
}
cache_file.write_text(json.dumps(cache_data))
mock_cache_file.return_value = cache_file
is_yanked, reason = is_current_version_yanked()
assert is_yanked is True
assert reason == "bad release"
@patch("crewai_cli.version.get_crewai_version")
@patch("crewai_cli.version._get_cache_file")
def test_not_yanked_from_cache(
self, mock_cache_file: MagicMock, mock_version: MagicMock, tmp_path: Path
) -> None:
"""Test non-yanked status from a valid cache."""
mock_version.return_value = "2.0.0"
cache_file = tmp_path / "version_cache.json"
cache_data = {
"version": "2.0.0",
"timestamp": datetime.now().isoformat(),
"current_version": "2.0.0",
"current_version_yanked": False,
"current_version_yanked_reason": "",
}
cache_file.write_text(json.dumps(cache_data))
mock_cache_file.return_value = cache_file
is_yanked, reason = is_current_version_yanked()
assert is_yanked is False
assert reason == ""
@patch("crewai_cli.version.get_latest_version_from_pypi")
@patch("crewai_cli.version.get_crewai_version")
@patch("crewai_cli.version._get_cache_file")
def test_triggers_fetch_on_stale_cache(
self,
mock_cache_file: MagicMock,
mock_version: MagicMock,
mock_fetch: MagicMock,
tmp_path: Path,
) -> None:
"""Test that a stale cache triggers a re-fetch."""
mock_version.return_value = "1.0.0"
cache_file = tmp_path / "version_cache.json"
old_time = datetime.now() - timedelta(hours=25)
cache_data = {
"version": "2.0.0",
"timestamp": old_time.isoformat(),
"current_version": "1.0.0",
"current_version_yanked": True,
"current_version_yanked_reason": "old reason",
}
cache_file.write_text(json.dumps(cache_data))
mock_cache_file.return_value = cache_file
fresh_cache = {
"version": "2.0.0",
"timestamp": datetime.now().isoformat(),
"current_version": "1.0.0",
"current_version_yanked": False,
"current_version_yanked_reason": "",
}
def write_fresh_cache() -> str:
cache_file.write_text(json.dumps(fresh_cache))
return "2.0.0"
mock_fetch.side_effect = lambda: write_fresh_cache()
is_yanked, reason = is_current_version_yanked()
assert is_yanked is False
mock_fetch.assert_called_once()
@patch("crewai_cli.version.get_latest_version_from_pypi")
@patch("crewai_cli.version.get_crewai_version")
@patch("crewai_cli.version._get_cache_file")
def test_returns_false_on_fetch_failure(
self,
mock_cache_file: MagicMock,
mock_version: MagicMock,
mock_fetch: MagicMock,
tmp_path: Path,
) -> None:
"""Test that fetch failure returns not yanked."""
mock_version.return_value = "1.0.0"
cache_file = tmp_path / "version_cache.json"
mock_cache_file.return_value = cache_file
mock_fetch.return_value = None
is_yanked, reason = is_current_version_yanked()
assert is_yanked is False
assert reason == ""
# TestConsoleFormatterVersionCheck tests remain in lib/crewai/tests/cli/test_version.py
# as they depend on crewai.events.utils.console_formatter (core package).

View File

@@ -152,4 +152,4 @@ __all__ = [
"wrap_file_source",
]
__version__ = "1.10.2rc2"
__version__ = "1.10.1"

View File

@@ -11,7 +11,7 @@ dependencies = [
"pytube~=15.0.0",
"requests~=2.32.5",
"docker~=7.1.0",
"crewai==1.10.2rc2",
"crewai==1.10.1",
"tiktoken~=0.8.0",
"beautifulsoup4~=4.13.4",
"python-docx~=1.2.0",

View File

@@ -88,6 +88,7 @@ from crewai_tools.tools.generate_crewai_automation_tool.generate_crewai_automati
GenerateCrewaiAutomationTool,
)
from crewai_tools.tools.github_search_tool.github_search_tool import GithubSearchTool
from crewai_tools.tools.grep_tool.grep_tool import GrepTool
from crewai_tools.tools.hyperbrowser_load_tool.hyperbrowser_load_tool import (
HyperbrowserLoadTool,
)
@@ -248,6 +249,7 @@ __all__ = [
"FirecrawlSearchTool",
"GenerateCrewaiAutomationTool",
"GithubSearchTool",
"GrepTool",
"HyperbrowserLoadTool",
"InvokeCrewAIAutomationTool",
"JSONSearchTool",
@@ -309,4 +311,4 @@ __all__ = [
"ZapierActionTools",
]
__version__ = "1.10.2rc2"
__version__ = "1.10.1"

View File

@@ -1,9 +1,7 @@
from collections.abc import Callable
import os
from pathlib import Path
from typing import Any
from crewai.utilities.lock_store import lock as store_lock
from lancedb import ( # type: ignore[import-untyped]
DBConnection as LanceDBConnection,
connect as lancedb_connect,
@@ -35,12 +33,10 @@ class LanceDBAdapter(Adapter):
_db: LanceDBConnection = PrivateAttr()
_table: LanceDBTable = PrivateAttr()
_lock_name: str = PrivateAttr(default="")
def model_post_init(self, __context: Any) -> None:
self._db = lancedb_connect(self.uri)
self._table = self._db.open_table(self.table_name)
self._lock_name = f"lancedb:{os.path.realpath(str(self.uri))}"
super().model_post_init(__context)
@@ -60,5 +56,4 @@ class LanceDBAdapter(Adapter):
*args: Any,
**kwargs: Any,
) -> None:
with store_lock(self._lock_name):
self._table.add(*args, **kwargs)
self._table.add(*args, **kwargs)

View File

@@ -1,9 +1,6 @@
from __future__ import annotations
import asyncio
import contextvars
import logging
import threading
from typing import TYPE_CHECKING
@@ -21,9 +18,6 @@ class BrowserSessionManager:
This class maintains separate browser sessions for different threads,
enabling concurrent usage of browsers in multi-threaded environments.
Browsers are created lazily only when needed by tools.
Uses per-key events to serialize creation for the same thread_id without
blocking unrelated callers or wasting resources on duplicate sessions.
"""
def __init__(self, region: str = "us-west-2"):
@@ -33,10 +27,8 @@ class BrowserSessionManager:
region: AWS region for browser client
"""
self.region = region
self._lock = threading.Lock()
self._async_sessions: dict[str, tuple[BrowserClient, AsyncBrowser]] = {}
self._sync_sessions: dict[str, tuple[BrowserClient, SyncBrowser]] = {}
self._creating: dict[str, threading.Event] = {}
async def get_async_browser(self, thread_id: str) -> AsyncBrowser:
"""Get or create an async browser for the specified thread.
@@ -47,29 +39,10 @@ class BrowserSessionManager:
Returns:
An async browser instance specific to the thread
"""
loop = asyncio.get_event_loop()
while True:
with self._lock:
if thread_id in self._async_sessions:
return self._async_sessions[thread_id][1]
if thread_id not in self._creating:
self._creating[thread_id] = threading.Event()
break
event = self._creating[thread_id]
ctx = contextvars.copy_context()
await loop.run_in_executor(None, ctx.run, event.wait)
if thread_id in self._async_sessions:
return self._async_sessions[thread_id][1]
try:
browser_client, browser = await self._create_async_browser_session(
thread_id
)
with self._lock:
self._async_sessions[thread_id] = (browser_client, browser)
return browser
finally:
with self._lock:
evt = self._creating.pop(thread_id)
evt.set()
return await self._create_async_browser_session(thread_id)
def get_sync_browser(self, thread_id: str) -> SyncBrowser:
"""Get or create a sync browser for the specified thread.
@@ -80,33 +53,19 @@ class BrowserSessionManager:
Returns:
A sync browser instance specific to the thread
"""
while True:
with self._lock:
if thread_id in self._sync_sessions:
return self._sync_sessions[thread_id][1]
if thread_id not in self._creating:
self._creating[thread_id] = threading.Event()
break
event = self._creating[thread_id]
event.wait()
if thread_id in self._sync_sessions:
return self._sync_sessions[thread_id][1]
try:
return self._create_sync_browser_session(thread_id)
finally:
with self._lock:
evt = self._creating.pop(thread_id)
evt.set()
return self._create_sync_browser_session(thread_id)
async def _create_async_browser_session(
self, thread_id: str
) -> tuple[BrowserClient, AsyncBrowser]:
async def _create_async_browser_session(self, thread_id: str) -> AsyncBrowser:
"""Create a new async browser session for the specified thread.
Args:
thread_id: Unique identifier for the thread
Returns:
Tuple of (BrowserClient, AsyncBrowser).
The newly created async browser instance
Raises:
Exception: If browser session creation fails
@@ -116,8 +75,10 @@ class BrowserSessionManager:
browser_client = BrowserClient(region=self.region)
try:
# Start browser session
browser_client.start()
# Get WebSocket connection info
ws_url, headers = browser_client.generate_ws_headers()
logger.info(
@@ -126,6 +87,7 @@ class BrowserSessionManager:
from playwright.async_api import async_playwright
# Connect to browser using Playwright
playwright = await async_playwright().start()
browser = await playwright.chromium.connect_over_cdp(
endpoint_url=ws_url, headers=headers, timeout=30000
@@ -134,13 +96,17 @@ class BrowserSessionManager:
f"Successfully connected to async browser for thread {thread_id}"
)
return browser_client, browser
# Store session resources
self._async_sessions[thread_id] = (browser_client, browser)
return browser
except Exception as e:
logger.error(
f"Failed to create async browser session for thread {thread_id}: {e}"
)
# Clean up resources if session creation fails
if browser_client:
try:
browser_client.stop()
@@ -166,8 +132,10 @@ class BrowserSessionManager:
browser_client = BrowserClient(region=self.region)
try:
# Start browser session
browser_client.start()
# Get WebSocket connection info
ws_url, headers = browser_client.generate_ws_headers()
logger.info(
@@ -176,6 +144,7 @@ class BrowserSessionManager:
from playwright.sync_api import sync_playwright
# Connect to browser using Playwright
playwright = sync_playwright().start()
browser = playwright.chromium.connect_over_cdp(
endpoint_url=ws_url, headers=headers, timeout=30000
@@ -184,8 +153,8 @@ class BrowserSessionManager:
f"Successfully connected to sync browser for thread {thread_id}"
)
with self._lock:
self._sync_sessions[thread_id] = (browser_client, browser)
# Store session resources
self._sync_sessions[thread_id] = (browser_client, browser)
return browser
@@ -194,6 +163,7 @@ class BrowserSessionManager:
f"Failed to create sync browser session for thread {thread_id}: {e}"
)
# Clean up resources if session creation fails
if browser_client:
try:
browser_client.stop()
@@ -208,13 +178,13 @@ class BrowserSessionManager:
Args:
thread_id: Unique identifier for the thread
"""
with self._lock:
if thread_id not in self._async_sessions:
logger.warning(f"No async browser session found for thread {thread_id}")
return
if thread_id not in self._async_sessions:
logger.warning(f"No async browser session found for thread {thread_id}")
return
browser_client, browser = self._async_sessions.pop(thread_id)
browser_client, browser = self._async_sessions[thread_id]
# Close browser
if browser:
try:
await browser.close()
@@ -223,6 +193,7 @@ class BrowserSessionManager:
f"Error closing async browser for thread {thread_id}: {e}"
)
# Stop browser client
if browser_client:
try:
browser_client.stop()
@@ -231,6 +202,8 @@ class BrowserSessionManager:
f"Error stopping browser client for thread {thread_id}: {e}"
)
# Remove session from dictionary
del self._async_sessions[thread_id]
logger.info(f"Async browser session cleaned up for thread {thread_id}")
def close_sync_browser(self, thread_id: str) -> None:
@@ -239,13 +212,13 @@ class BrowserSessionManager:
Args:
thread_id: Unique identifier for the thread
"""
with self._lock:
if thread_id not in self._sync_sessions:
logger.warning(f"No sync browser session found for thread {thread_id}")
return
if thread_id not in self._sync_sessions:
logger.warning(f"No sync browser session found for thread {thread_id}")
return
browser_client, browser = self._sync_sessions.pop(thread_id)
browser_client, browser = self._sync_sessions[thread_id]
# Close browser
if browser:
try:
browser.close()
@@ -254,6 +227,7 @@ class BrowserSessionManager:
f"Error closing sync browser for thread {thread_id}: {e}"
)
# Stop browser client
if browser_client:
try:
browser_client.stop()
@@ -262,17 +236,19 @@ class BrowserSessionManager:
f"Error stopping browser client for thread {thread_id}: {e}"
)
# Remove session from dictionary
del self._sync_sessions[thread_id]
logger.info(f"Sync browser session cleaned up for thread {thread_id}")
async def close_all_browsers(self) -> None:
"""Close all browser sessions."""
with self._lock:
async_thread_ids = list(self._async_sessions.keys())
sync_thread_ids = list(self._sync_sessions.keys())
# Close all async browsers
async_thread_ids = list(self._async_sessions.keys())
for thread_id in async_thread_ids:
await self.close_async_browser(thread_id)
# Close all sync browsers
sync_thread_ids = list(self._sync_sessions.keys())
for thread_id in sync_thread_ids:
self.close_sync_browser(thread_id)

View File

@@ -1,11 +1,9 @@
import logging
import os
from pathlib import Path
from typing import Any
from uuid import uuid4
import chromadb
from crewai.utilities.lock_store import lock as store_lock
from pydantic import BaseModel, Field, PrivateAttr
from crewai_tools.rag.base_loader import BaseLoader
@@ -40,32 +38,22 @@ class RAG(Adapter):
_client: Any = PrivateAttr()
_collection: Any = PrivateAttr()
_embedding_service: EmbeddingService = PrivateAttr()
_lock_name: str = PrivateAttr(default="")
def model_post_init(self, __context: Any) -> None:
try:
self._lock_name = (
f"chromadb:{os.path.realpath(self.persist_directory)}"
if self.persist_directory
else "chromadb:ephemeral"
if self.persist_directory:
self._client = chromadb.PersistentClient(path=self.persist_directory)
else:
self._client = chromadb.Client()
self._collection = self._client.get_or_create_collection(
name=self.collection_name,
metadata={
"hnsw:space": "cosine",
"description": "CrewAI Knowledge Base",
},
)
with store_lock(self._lock_name):
if self.persist_directory:
self._client = chromadb.PersistentClient(
path=self.persist_directory
)
else:
self._client = chromadb.Client()
self._collection = self._client.get_or_create_collection(
name=self.collection_name,
metadata={
"hnsw:space": "cosine",
"description": "CrewAI Knowledge Base",
},
)
self._embedding_service = EmbeddingService(
provider=self.embedding_provider,
model=self.embedding_model,
@@ -99,8 +87,29 @@ class RAG(Adapter):
loader_result = loader.load(source_content)
doc_id = loader_result.doc_id
chunks = chunker.chunk(loader_result.content)
existing_doc = self._collection.get(
where={"source": source_content.source_ref}, limit=1
)
existing_doc_id = (
existing_doc and existing_doc["metadatas"][0]["doc_id"]
if existing_doc["metadatas"]
else None
)
if existing_doc_id == doc_id:
logger.warning(
f"Document with source {loader_result.source} already exists"
)
return
# Document with same source ref does exists but the content has changed, deleting the oldest reference
if existing_doc_id and existing_doc_id != loader_result.doc_id:
logger.warning(f"Deleting old document with doc_id {existing_doc_id}")
self._collection.delete(where={"doc_id": existing_doc_id})
documents = []
chunks = chunker.chunk(loader_result.content)
for i, chunk in enumerate(chunks):
doc_metadata = (metadata or {}).copy()
doc_metadata["chunk_index"] = i
@@ -127,6 +136,7 @@ class RAG(Adapter):
ids = [doc.id for doc in documents]
metadatas = []
for doc in documents:
doc_metadata = doc.metadata.copy()
doc_metadata.update(
@@ -138,36 +148,16 @@ class RAG(Adapter):
)
metadatas.append(doc_metadata)
with store_lock(self._lock_name):
existing_doc = self._collection.get(
where={"source": source_content.source_ref}, limit=1
try:
self._collection.add(
ids=ids,
embeddings=embeddings,
documents=contents,
metadatas=metadatas,
)
existing_doc_id = (
existing_doc and existing_doc["metadatas"][0]["doc_id"]
if existing_doc["metadatas"]
else None
)
if existing_doc_id == doc_id:
logger.warning(
f"Document with source {loader_result.source} already exists"
)
return
if existing_doc_id and existing_doc_id != loader_result.doc_id:
logger.warning(f"Deleting old document with doc_id {existing_doc_id}")
self._collection.delete(where={"doc_id": existing_doc_id})
try:
self._collection.add(
ids=ids,
embeddings=embeddings,
documents=contents,
metadatas=metadatas,
)
logger.info(f"Added {len(documents)} documents to knowledge base")
except Exception as e:
logger.error(f"Failed to add documents to ChromaDB: {e}")
logger.info(f"Added {len(documents)} documents to knowledge base")
except Exception as e:
logger.error(f"Failed to add documents to ChromaDB: {e}")
def query(self, question: str, where: dict[str, Any] | None = None) -> str: # type: ignore
try:
@@ -211,8 +201,7 @@ class RAG(Adapter):
def delete_collection(self) -> None:
try:
with store_lock(self._lock_name):
self._client.delete_collection(self.collection_name)
self._client.delete_collection(self.collection_name)
logger.info(f"Deleted collection: {self.collection_name}")
except Exception as e:
logger.error(f"Failed to delete collection: {e}")

View File

@@ -77,6 +77,7 @@ from crewai_tools.tools.generate_crewai_automation_tool.generate_crewai_automati
GenerateCrewaiAutomationTool,
)
from crewai_tools.tools.github_search_tool.github_search_tool import GithubSearchTool
from crewai_tools.tools.grep_tool.grep_tool import GrepTool
from crewai_tools.tools.hyperbrowser_load_tool.hyperbrowser_load_tool import (
HyperbrowserLoadTool,
)
@@ -232,6 +233,7 @@ __all__ = [
"FirecrawlSearchTool",
"GenerateCrewaiAutomationTool",
"GithubSearchTool",
"GrepTool",
"HyperbrowserLoadTool",
"InvokeCrewAIAutomationTool",
"JSONSearchTool",

View File

@@ -1,3 +1,4 @@
from datetime import datetime
import json
import os
import time
@@ -9,8 +10,8 @@ from pydantic import BaseModel, Field
from pydantic.types import StringConstraints
import requests
from crewai_tools.tools.brave_search_tool.base import _save_results_to_file
from crewai_tools.tools.brave_search_tool.schemas import WebSearchParams
from crewai_tools.tools.brave_search_tool.base import _save_results_to_file
load_dotenv()

View File

@@ -30,8 +30,9 @@ class FileWriterTool(BaseTool):
def _run(self, **kwargs: Any) -> str:
try:
if kwargs.get("directory"):
os.makedirs(kwargs["directory"], exist_ok=True)
# Create the directory if it doesn't exist
if kwargs.get("directory") and not os.path.exists(kwargs["directory"]):
os.makedirs(kwargs["directory"])
# Construct the full path
filepath = os.path.join(kwargs.get("directory") or "", kwargs["filename"])

View File

@@ -99,8 +99,8 @@ class FileCompressorTool(BaseTool):
def _prepare_output(output_path: str, overwrite: bool) -> bool:
"""Ensures output path is ready for writing."""
output_dir = os.path.dirname(output_path)
if output_dir:
os.makedirs(output_dir, exist_ok=True)
if output_dir and not os.path.exists(output_dir):
os.makedirs(output_dir)
if os.path.exists(output_path) and not overwrite:
return False
return True

View File

@@ -0,0 +1,3 @@
from crewai_tools.tools.grep_tool.grep_tool import GrepTool
__all__ = ["GrepTool"]

View File

@@ -0,0 +1,542 @@
"""Tool for searching file contents on disk using regex patterns."""
from __future__ import annotations
from dataclasses import dataclass, field
from itertools import chain
import os
from pathlib import Path
import re
import signal
import sys
from typing import Literal
from crewai.tools import BaseTool
from pydantic import BaseModel, Field
MAX_OUTPUT_CHARS = 50_000
MAX_FILES = 10_000
MAX_MATCHES_PER_FILE = 200
MAX_LINE_LENGTH = 500
BINARY_CHECK_SIZE = 8192
MAX_REGEX_LENGTH = 1_000
REGEX_MATCH_TIMEOUT_SECONDS = 5
MAX_CONTEXT_LINES = 10
MAX_FILE_SIZE_BYTES = 10 * 1024 * 1024 # 10 MB
SKIP_DIRS = frozenset(
{
".git",
"__pycache__",
"node_modules",
".venv",
"venv",
".tox",
".mypy_cache",
".pytest_cache",
}
)
# File names that may contain secrets or credentials — always excluded from
# search results to prevent accidental sensitive-content leakage.
SENSITIVE_FILE_NAMES = frozenset(
{
".env",
".env.local",
".env.development",
".env.production",
".env.staging",
".env.test",
".netrc",
".npmrc",
".pypirc",
".docker/config.json",
".aws/credentials",
".ssh/id_rsa",
".ssh/id_ed25519",
".ssh/id_ecdsa",
".ssh/id_dsa",
"credentials.json",
"service-account.json",
"secrets.yaml",
"secrets.yml",
"secrets.json",
}
)
# Glob-style suffixes that indicate sensitive content (matched against the
# full file name, e.g. "app.env.bak" won't match, but ".env.bak" will).
SENSITIVE_FILE_PATTERNS = (
".pem",
".key",
".p12",
".pfx",
".jks",
".keystore",
)
@dataclass
class MatchLine:
"""A single line from a search result."""
line_number: int
text: str
is_match: bool # True for match, False for context line
@dataclass
class FileSearchResult:
"""Search results for a single file."""
file_path: Path
matches: list[list[MatchLine]] = field(default_factory=list)
match_count: int = 0
class GrepToolSchema(BaseModel):
"""Schema for grep tool arguments."""
pattern: str = Field(
..., description="Regex pattern to search for in file contents"
)
path: str | None = Field(
default=None,
description="File or directory to search in. Defaults to current working directory.",
)
glob_pattern: str | None = Field(
default=None,
description="Glob pattern to filter files (e.g. '*.py'). Supports brace expansion (e.g. '*.{ts,tsx}').",
)
output_mode: Literal["content", "files_with_matches", "count"] = Field(
default="content",
description="Output mode: 'content' shows matching lines, 'files_with_matches' shows only file paths, 'count' shows match counts per file",
)
case_insensitive: bool = Field(
default=False,
description="Whether to perform case-insensitive matching",
)
context_lines: int = Field(
default=0,
ge=0,
le=MAX_CONTEXT_LINES,
description=f"Number of lines to show before and after each match (0-{MAX_CONTEXT_LINES})",
)
include_line_numbers: bool = Field(
default=True,
description="Whether to prefix matching lines with line numbers",
)
class GrepTool(BaseTool):
"""Tool for searching file contents on disk using regex patterns.
Recursively searches files in a directory for lines matching a regex pattern.
Supports glob filtering, context lines, and multiple output modes.
Example:
>>> tool = GrepTool()
>>> result = tool.run(pattern="def.*main", path="src")
>>> result = tool.run(
... pattern="TODO",
... glob_pattern="*.py",
... context_lines=2,
... )
To search any path on the filesystem (opt-in):
>>> tool = GrepTool(allow_unrestricted_paths=True)
>>> result = tool.run(pattern="error", path="/var/log/app")
"""
name: str = "Search file contents"
description: str = (
"A tool that searches file contents on disk using regex patterns. "
"Recursively searches files in a directory for matching lines. "
"Returns matching content with line numbers, file paths only, or match counts."
)
args_schema: type[BaseModel] = GrepToolSchema
allow_unrestricted_paths: bool = Field(
default=False,
description=(
"When False (default), searches are restricted to the current working "
"directory. Set to True to allow searching any path on the filesystem."
),
)
max_file_size_bytes: int = Field(
default=MAX_FILE_SIZE_BYTES,
description=(
"Maximum file size in bytes to search. Files larger than this are "
"skipped. Defaults to 10 MB."
),
)
def _run(
self,
pattern: str,
path: str | None = None,
glob_pattern: str | None = None,
output_mode: Literal["content", "files_with_matches", "count"] = "content",
case_insensitive: bool = False,
context_lines: int = 0,
include_line_numbers: bool = True,
**kwargs: object,
) -> str:
"""Search files for a regex pattern.
Args:
pattern: Regex pattern to search for.
path: File or directory to search. Defaults to cwd.
glob_pattern: Glob pattern to filter files.
output_mode: What to return.
case_insensitive: Case-insensitive matching.
context_lines: Lines of context around matches.
include_line_numbers: Prefix lines with line numbers.
Returns:
Formatted search results as a string.
"""
# Resolve search path — constrained to cwd unless unrestricted
cwd = Path(os.getcwd()).resolve()
if path:
candidate = Path(path)
if candidate.is_absolute():
search_path = candidate.resolve()
else:
search_path = (cwd / candidate).resolve()
# Prevent traversal outside the working directory (unless opted in)
if not self.allow_unrestricted_paths:
try:
search_path.relative_to(cwd)
except ValueError:
return (
f"Error: Path '{path}' is outside the working directory. "
"Initialize with GrepTool(allow_unrestricted_paths=True) to allow this."
)
else:
search_path = cwd
if not search_path.exists():
return f"Error: Path '{search_path}' does not exist."
# Compile regex with length guard to mitigate ReDoS
if len(pattern) > MAX_REGEX_LENGTH:
return f"Error: Pattern too long ({len(pattern)} chars). Maximum is {MAX_REGEX_LENGTH}."
flags = re.IGNORECASE if case_insensitive else 0
try:
compiled = re.compile(pattern, flags)
except re.error as e:
return f"Error: Invalid regex pattern '{pattern}': {e}"
# Collect files
files = self._collect_files(search_path, glob_pattern)
# Search each file
results: list[FileSearchResult] = []
for file_path in files:
result = self._search_file(file_path, compiled, context_lines)
if result is not None:
results.append(result)
if not results:
return "No matches found."
# Format output
if output_mode == "files_with_matches":
output = self._format_files_with_matches(results)
elif output_mode == "count":
output = self._format_count(results)
else:
output = self._format_content(results, include_line_numbers)
# Truncate if needed
if len(output) > MAX_OUTPUT_CHARS:
output = (
output[:MAX_OUTPUT_CHARS]
+ "\n\n... Output truncated. Try a narrower search pattern or glob filter."
)
return output
@staticmethod
def _expand_brace_pattern(pattern: str) -> list[str]:
"""Expand a simple brace pattern into individual globs.
Handles a single level of brace expansion, e.g.
``*.{py,txt}`` -> ``['*.py', '*.txt']``.
Nested braces are *not* supported and the pattern is returned as-is.
Args:
pattern: Glob pattern that may contain ``{a,b,...}`` syntax.
Returns:
List of expanded patterns (or the original if no braces found).
"""
match = re.search(r"\{([^{}]+)\}", pattern)
if not match:
return [pattern]
prefix = pattern[: match.start()]
suffix = pattern[match.end() :]
alternatives = match.group(1).split(",")
return [f"{prefix}{alt.strip()}{suffix}" for alt in alternatives]
def _collect_files(self, search_path: Path, glob_pattern: str | None) -> list[Path]:
"""Collect files to search.
Sensitive files (e.g. ``.env``, ``.netrc``, key material) are
automatically excluded even when searched by explicit path so that
credentials cannot leak into tool output.
Args:
search_path: File or directory to search.
glob_pattern: Optional glob pattern to filter files.
Returns:
List of file paths to search.
"""
if search_path.is_file():
if self._is_sensitive_file(search_path):
return []
return [search_path]
patterns = self._expand_brace_pattern(glob_pattern) if glob_pattern else ["*"]
seen: set[Path] = set()
files: list[Path] = []
for p in chain.from_iterable(search_path.rglob(pat) for pat in patterns):
if not p.is_file():
continue
if p in seen:
continue
seen.add(p)
# Skip hidden/build directories
if any(part in SKIP_DIRS for part in p.relative_to(search_path).parts):
continue
if self._is_sensitive_file(p):
continue
files.append(p)
if len(files) >= MAX_FILES:
break
return sorted(files)
@staticmethod
def _safe_search(
compiled_pattern: re.Pattern[str], line: str
) -> re.Match[str] | None:
"""Run a regex search with a per-line timeout to mitigate ReDoS.
On platforms that support SIGALRM (Unix), a timeout is enforced.
On Windows, the search runs without a timeout but is still bounded
by MAX_LINE_LENGTH truncation applied earlier in the pipeline.
Args:
compiled_pattern: Compiled regex pattern.
line: The text line to search.
Returns:
Match object if found, None otherwise (including on timeout).
"""
if sys.platform == "win32":
return compiled_pattern.search(line)
def _timeout_handler(signum: int, frame: object) -> None:
raise TimeoutError
old_handler = signal.signal(signal.SIGALRM, _timeout_handler)
signal.alarm(REGEX_MATCH_TIMEOUT_SECONDS)
try:
return compiled_pattern.search(line)
except TimeoutError:
return None
finally:
signal.alarm(0)
signal.signal(signal.SIGALRM, old_handler)
@staticmethod
def _is_sensitive_file(file_path: Path) -> bool:
"""Check whether a file is likely to contain secrets or credentials.
The check is deliberately conservative — it matches exact file names
(e.g. ``.env``, ``.netrc``) as well as common key/certificate
extensions. Files whose *name* starts with ``.env`` (including
variants like ``.env.local``, ``.env.production``, etc.) are also
excluded.
Args:
file_path: Path to the file.
Returns:
True if the file should be skipped.
"""
name = file_path.name
# Exact-name match (e.g. ".env", ".netrc", "secrets.json")
if name in SENSITIVE_FILE_NAMES:
return True
# Any .env variant (.env.backup, .env.staging.old, …)
if name.startswith(".env"):
return True
# Extension-based match for key/cert material
if any(name.endswith(ext) for ext in SENSITIVE_FILE_PATTERNS):
return True
# Check path components for well-known sensitive dirs/files
# e.g. ".aws/credentials" or ".ssh/id_rsa"
parts = file_path.parts
for i, _part in enumerate(parts):
remaining = "/".join(parts[i:])
if remaining in SENSITIVE_FILE_NAMES:
return True
return False
def _is_binary_file(self, file_path: Path) -> bool:
"""Check if a file is binary by looking for null bytes.
Args:
file_path: Path to the file.
Returns:
True if the file appears to be binary.
"""
try:
with open(file_path, "rb") as f:
chunk = f.read(BINARY_CHECK_SIZE)
return b"\x00" in chunk
except (OSError, PermissionError):
return True
def _search_file(
self,
file_path: Path,
compiled_pattern: re.Pattern[str],
context_lines: int,
) -> FileSearchResult | None:
"""Search a single file for matches.
Args:
file_path: Path to the file.
compiled_pattern: Compiled regex pattern.
context_lines: Number of context lines around matches.
Returns:
FileSearchResult if matches found, None otherwise.
"""
if self._is_sensitive_file(file_path):
return None
if self._is_binary_file(file_path):
return None
# Skip files that are too large to safely read into memory
try:
file_size = file_path.stat().st_size
except OSError:
return None
if file_size > self.max_file_size_bytes:
return None
try:
with open(file_path, encoding="utf-8", errors="replace") as f:
lines = f.readlines()
except (OSError, PermissionError):
return None
# Find matching line numbers
match_line_nums: list[int] = []
for i, line in enumerate(lines):
if self._safe_search(compiled_pattern, line):
match_line_nums.append(i)
if len(match_line_nums) >= MAX_MATCHES_PER_FILE:
break
if not match_line_nums:
return None
# Build groups of contiguous match blocks with context
groups: list[list[MatchLine]] = []
current_group: list[MatchLine] = []
prev_end = -1
for match_idx in match_line_nums:
start = max(0, match_idx - context_lines)
end = min(len(lines), match_idx + context_lines + 1)
# If this block doesn't overlap with the previous, start a new group
if start > prev_end and current_group:
groups.append(current_group)
current_group = []
for i in range(max(start, prev_end), end):
text = lines[i].rstrip("\n\r")
if len(text) > MAX_LINE_LENGTH:
text = text[:MAX_LINE_LENGTH] + "..."
current_group.append(
MatchLine(
line_number=i + 1, # 1-indexed
text=text,
is_match=(i in match_line_nums),
)
)
prev_end = end
if current_group:
groups.append(current_group)
return FileSearchResult(
file_path=file_path,
matches=groups,
match_count=len(match_line_nums),
)
def _format_content(
self,
results: list[FileSearchResult],
include_line_numbers: bool,
) -> str:
"""Format results showing matching content.
Args:
results: List of file search results.
include_line_numbers: Whether to include line numbers.
Returns:
Formatted string with file paths and matching lines.
"""
parts: list[str] = []
for result in results:
parts.append(str(result.file_path))
for group_idx, group in enumerate(result.matches):
if group_idx > 0:
parts.append("--")
for match_line in group:
if include_line_numbers:
parts.append(f"{match_line.line_number}: {match_line.text}")
else:
parts.append(match_line.text)
parts.append("") # blank line between files
return "\n".join(parts).rstrip()
def _format_files_with_matches(self, results: list[FileSearchResult]) -> str:
"""Format results showing only file paths.
Args:
results: List of file search results.
Returns:
One file path per line.
"""
return "\n".join(str(r.file_path) for r in results)
def _format_count(self, results: list[FileSearchResult]) -> str:
"""Format results showing match counts per file.
Args:
results: List of file search results.
Returns:
Filepath and count per line.
"""
return "\n".join(f"{r.file_path}: {r.match_count}" for r in results)

View File

@@ -18,6 +18,7 @@ class MergeAgentHandlerToolError(Exception):
"""Base exception for Merge Agent Handler tool errors."""
class MergeAgentHandlerTool(BaseTool):
"""
Wrapper for Merge Agent Handler tools.
@@ -173,7 +174,7 @@ class MergeAgentHandlerTool(BaseTool):
>>> tool = MergeAgentHandlerTool.from_tool_name(
... tool_name="linear__create_issue",
... tool_pack_id="134e0111-0f67-44f6-98f0-597000290bb3",
... registered_user_id="91b2b905-e866-40c8-8be2-efe53827a0aa",
... registered_user_id="91b2b905-e866-40c8-8be2-efe53827a0aa"
... )
"""
# Create an empty args schema model (proper BaseModel subclass)
@@ -209,10 +210,7 @@ class MergeAgentHandlerTool(BaseTool):
if "parameters" in tool_schema:
try:
params = tool_schema["parameters"]
if (
params.get("type") == "object"
and "properties" in params
):
if params.get("type") == "object" and "properties" in params:
# Build field definitions for Pydantic
fields = {}
properties = params["properties"]
@@ -300,7 +298,7 @@ class MergeAgentHandlerTool(BaseTool):
>>> tools = MergeAgentHandlerTool.from_tool_pack(
... tool_pack_id="134e0111-0f67-44f6-98f0-597000290bb3",
... registered_user_id="91b2b905-e866-40c8-8be2-efe53827a0aa",
... tool_names=["linear__create_issue", "linear__get_issues"],
... tool_names=["linear__create_issue", "linear__get_issues"]
... )
"""
# Create a temporary instance to fetch the tool list

View File

@@ -110,13 +110,11 @@ class QdrantVectorSearchTool(BaseTool):
self.custom_embedding_fn(query)
if self.custom_embedding_fn
else (
lambda: (
__import__("openai")
.Client(api_key=os.getenv("OPENAI_API_KEY"))
.embeddings.create(input=[query], model="text-embedding-3-large")
.data[0]
.embedding
)
lambda: __import__("openai")
.Client(api_key=os.getenv("OPENAI_API_KEY"))
.embeddings.create(input=[query], model="text-embedding-3-large")
.data[0]
.embedding
)()
)
results = self.client.query_points(

View File

@@ -3,7 +3,6 @@ from __future__ import annotations
import asyncio
from concurrent.futures import ThreadPoolExecutor
import logging
import threading
from typing import TYPE_CHECKING, Any
from crewai.tools.base_tool import BaseTool
@@ -34,7 +33,6 @@ logger = logging.getLogger(__name__)
# Cache for query results
_query_cache: dict[str, list[dict[str, Any]]] = {}
_cache_lock = threading.Lock()
class SnowflakeConfig(BaseModel):
@@ -104,7 +102,7 @@ class SnowflakeSearchTool(BaseTool):
)
_connection_pool: list[SnowflakeConnection] | None = None
_pool_lock: threading.Lock | None = None
_pool_lock: asyncio.Lock | None = None
_thread_pool: ThreadPoolExecutor | None = None
_model_rebuilt: bool = False
package_dependencies: list[str] = Field(
@@ -124,7 +122,7 @@ class SnowflakeSearchTool(BaseTool):
try:
if SNOWFLAKE_AVAILABLE:
self._connection_pool = []
self._pool_lock = threading.Lock()
self._pool_lock = asyncio.Lock()
self._thread_pool = ThreadPoolExecutor(max_workers=self.pool_size)
else:
raise ImportError
@@ -149,7 +147,7 @@ class SnowflakeSearchTool(BaseTool):
)
self._connection_pool = []
self._pool_lock = threading.Lock()
self._pool_lock = asyncio.Lock()
self._thread_pool = ThreadPoolExecutor(max_workers=self.pool_size)
except subprocess.CalledProcessError as e:
raise ImportError("Failed to install Snowflake dependencies") from e
@@ -165,12 +163,13 @@ class SnowflakeSearchTool(BaseTool):
raise RuntimeError("Pool lock not initialized")
if self._connection_pool is None:
raise RuntimeError("Connection pool not initialized")
with self._pool_lock:
if self._connection_pool:
return self._connection_pool.pop()
return await asyncio.get_event_loop().run_in_executor(
self._thread_pool, self._create_connection
)
async with self._pool_lock:
if not self._connection_pool:
conn = await asyncio.get_event_loop().run_in_executor(
self._thread_pool, self._create_connection
)
self._connection_pool.append(conn)
return self._connection_pool.pop()
def _create_connection(self) -> SnowflakeConnection:
"""Create a new Snowflake connection."""
@@ -205,10 +204,9 @@ class SnowflakeSearchTool(BaseTool):
"""Execute a query with retries and return results."""
if self.enable_caching:
cache_key = self._get_cache_key(query, timeout)
with _cache_lock:
if cache_key in _query_cache:
logger.info("Returning cached result")
return _query_cache[cache_key]
if cache_key in _query_cache:
logger.info("Returning cached result")
return _query_cache[cache_key]
for attempt in range(self.max_retries):
try:
@@ -227,8 +225,7 @@ class SnowflakeSearchTool(BaseTool):
]
if self.enable_caching:
with _cache_lock:
_query_cache[self._get_cache_key(query, timeout)] = results
_query_cache[self._get_cache_key(query, timeout)] = results
return results
finally:
@@ -237,7 +234,7 @@ class SnowflakeSearchTool(BaseTool):
self._pool_lock is not None
and self._connection_pool is not None
):
with self._pool_lock:
async with self._pool_lock:
self._connection_pool.append(conn)
except (DatabaseError, OperationalError) as e: # noqa: PERF203
if attempt == self.max_retries - 1:

View File

@@ -1,5 +1,4 @@
import asyncio
import contextvars
import json
import os
import re
@@ -138,9 +137,7 @@ class StagehandTool(BaseTool):
- 'observe': For finding elements in a specific area
"""
args_schema: type[BaseModel] = StagehandToolSchema
package_dependencies: list[str] = Field(
default_factory=lambda: ["stagehand<=0.5.9"]
)
package_dependencies: list[str] = Field(default_factory=lambda: ["stagehand<=0.5.9"])
env_vars: list[EnvVar] = Field(
default_factory=lambda: [
EnvVar(
@@ -623,12 +620,9 @@ class StagehandTool(BaseTool):
# We're in an existing event loop, use it
import concurrent.futures
ctx = contextvars.copy_context()
with concurrent.futures.ThreadPoolExecutor() as executor:
future = executor.submit(
ctx.run,
asyncio.run,
self._async_run(instruction, url, command_type),
asyncio.run, self._async_run(instruction, url, command_type)
)
result = future.result()
else:
@@ -712,12 +706,11 @@ class StagehandTool(BaseTool):
if loop.is_running():
import concurrent.futures
ctx = contextvars.copy_context()
with (
concurrent.futures.ThreadPoolExecutor() as executor
):
future = executor.submit(
ctx.run, asyncio.run, self._async_close()
asyncio.run, self._async_close()
)
future.result()
else:

View File

@@ -0,0 +1,450 @@
"""Unit tests for GrepTool."""
from __future__ import annotations
from pathlib import Path
import pytest
from pydantic import ValidationError
from crewai_tools import GrepTool
from crewai_tools.tools.grep_tool.grep_tool import (
MAX_CONTEXT_LINES,
MAX_REGEX_LENGTH,
GrepToolSchema,
)
@pytest.fixture
def sample_dir(tmp_path: Path) -> Path:
"""Create a temp directory with sample files for testing."""
# src/main.py
src = tmp_path / "src"
src.mkdir()
(src / "main.py").write_text(
"def hello():\n"
" print('Hello, world!')\n"
"\n"
"def goodbye():\n"
" print('Goodbye, world!')\n"
"\n"
"class MyClass:\n"
" pass\n"
)
# src/utils.py
(src / "utils.py").write_text(
"import os\n"
"\n"
"def helper():\n"
" return os.getcwd()\n"
"\n"
"CONSTANT = 42\n"
)
# docs/readme.md
docs = tmp_path / "docs"
docs.mkdir()
(docs / "readme.md").write_text(
"# Project\n"
"\n"
"This is a sample project.\n"
"It has multiple files.\n"
)
# data/binary.bin
data = tmp_path / "data"
data.mkdir()
(data / "binary.bin").write_bytes(b"\x00\x01\x02\x03\x04binary content")
# empty.txt
(tmp_path / "empty.txt").write_text("")
# .git/config (should be skipped)
git_dir = tmp_path / ".git"
git_dir.mkdir()
(git_dir / "config").write_text("[core]\n repositoryformatversion = 0\n")
return tmp_path
class TestGrepTool:
"""Tests for GrepTool."""
def setup_method(self) -> None:
"""Set up test fixtures.
We use allow_unrestricted_paths=True so that tests using pytest's
tmp_path (which lives outside the working directory) are not rejected
by the path-restriction guard.
"""
self.tool = GrepTool(allow_unrestricted_paths=True)
def test_tool_metadata(self) -> None:
"""Test tool has correct name and description."""
assert self.tool.name == "Search file contents"
assert "search" in self.tool.description.lower() or "Search" in self.tool.description
def test_args_schema(self) -> None:
"""Test that args_schema has correct fields and defaults."""
schema = self.tool.args_schema
fields = schema.model_fields
assert "pattern" in fields
assert fields["pattern"].is_required()
assert "path" in fields
assert not fields["path"].is_required()
assert "glob_pattern" in fields
assert not fields["glob_pattern"].is_required()
assert "output_mode" in fields
assert not fields["output_mode"].is_required()
assert "case_insensitive" in fields
assert not fields["case_insensitive"].is_required()
assert "context_lines" in fields
assert not fields["context_lines"].is_required()
assert "include_line_numbers" in fields
assert not fields["include_line_numbers"].is_required()
def test_basic_pattern_match(self, sample_dir: Path) -> None:
"""Test simple string pattern found in output."""
result = self.tool._run(pattern="Hello", path=str(sample_dir))
assert "Hello" in result
def test_regex_pattern(self, sample_dir: Path) -> None:
"""Test regex pattern matches function definitions."""
result = self.tool._run(pattern=r"def\s+\w+", path=str(sample_dir))
assert "def hello" in result
assert "def goodbye" in result
assert "def helper" in result
def test_case_sensitive_default(self, sample_dir: Path) -> None:
"""Test that search is case-sensitive by default."""
result = self.tool._run(pattern="hello", path=str(sample_dir))
# "hello" (lowercase) appears in "def hello():" but not in "Hello, world!"
assert "hello" in result
# Verify it found the function definition line
assert "def hello" in result
def test_case_insensitive(self, sample_dir: Path) -> None:
"""Test case-insensitive matching."""
result = self.tool._run(
pattern="hello", path=str(sample_dir), case_insensitive=True
)
# Should match both "def hello():" and "Hello, world!"
assert "hello" in result.lower()
assert "Hello" in result
def test_output_mode_content(self, sample_dir: Path) -> None:
"""Test content output mode shows file paths, line numbers, and text."""
result = self.tool._run(
pattern="CONSTANT", path=str(sample_dir), output_mode="content"
)
assert "utils.py" in result
assert "CONSTANT" in result
# Should have line numbers by default
assert ": " in result
def test_output_mode_files_with_matches(self, sample_dir: Path) -> None:
"""Test files_with_matches output mode shows only file paths."""
result = self.tool._run(
pattern="def", path=str(sample_dir), output_mode="files_with_matches"
)
assert "main.py" in result
assert "utils.py" in result
# Should not contain line content
assert "print" not in result
def test_output_mode_count(self, sample_dir: Path) -> None:
"""Test count output mode shows filepath: N format."""
result = self.tool._run(
pattern="def", path=str(sample_dir), output_mode="count"
)
# main.py has 2 def lines, utils.py has 1
assert "main.py: 2" in result
assert "utils.py: 1" in result
def test_context_lines(self, sample_dir: Path) -> None:
"""Test surrounding context lines are included."""
result = self.tool._run(
pattern="CONSTANT", path=str(sample_dir), context_lines=2
)
# Two lines before CONSTANT = 42 is " return os.getcwd()"
assert "return os.getcwd()" in result
assert "CONSTANT" in result
def test_line_numbers_disabled(self, sample_dir: Path) -> None:
"""Test output without line number prefixes."""
result = self.tool._run(
pattern="CONSTANT",
path=str(sample_dir),
include_line_numbers=False,
)
assert "CONSTANT = 42" in result
# Verify no line number prefix (e.g., "6: ")
for line in result.strip().split("\n"):
if "CONSTANT" in line:
assert not line[0].isdigit() or ": " not in line
def test_glob_pattern_filtering(self, sample_dir: Path) -> None:
"""Test glob pattern filters to specific file types."""
result = self.tool._run(
pattern="project",
path=str(sample_dir),
glob_pattern="*.py",
case_insensitive=True,
)
# "project" appears in readme.md but not in .py files
assert "No matches found" in result
def test_search_single_file(self, sample_dir: Path) -> None:
"""Test searching a single file by path."""
file_path = str(sample_dir / "src" / "main.py")
result = self.tool._run(pattern="def", path=file_path)
assert "def hello" in result
assert "def goodbye" in result
# Should not include results from other files
assert "helper" not in result
def test_path_not_found(self) -> None:
"""Test error message when a relative path doesn't exist."""
result = self.tool._run(pattern="test", path="totally_nonexistent_subdir")
assert "Error" in result
assert "does not exist" in result
def test_invalid_regex(self, sample_dir: Path) -> None:
"""Test error message for invalid regex patterns."""
result = self.tool._run(pattern="[invalid", path=str(sample_dir))
assert "Error" in result
assert "Invalid regex" in result
def test_binary_files_skipped(self, sample_dir: Path) -> None:
"""Test binary files are not included in results."""
result = self.tool._run(pattern="binary", path=str(sample_dir))
# binary.bin has null bytes so it should be skipped
assert "binary.bin" not in result
def test_no_matches_found(self, sample_dir: Path) -> None:
"""Test message when no matches are found."""
result = self.tool._run(
pattern="zzz_nonexistent_pattern_zzz", path=str(sample_dir)
)
assert "No matches found" in result
def test_hidden_dirs_skipped(self, sample_dir: Path) -> None:
"""Test that .git/ directory contents are not searched."""
result = self.tool._run(pattern="repositoryformatversion", path=str(sample_dir))
assert "No matches found" in result
def test_empty_file(self, sample_dir: Path) -> None:
"""Test searching an empty file doesn't crash."""
result = self.tool._run(
pattern="anything", path=str(sample_dir / "empty.txt")
)
assert "No matches found" in result
def test_run_with_kwargs(self, sample_dir: Path) -> None:
"""Test _run ignores extra kwargs."""
result = self.tool._run(
pattern="Hello", path=str(sample_dir), extra_arg="ignored"
)
assert "Hello" in result
class TestPathRestriction:
"""Tests for path traversal prevention and allow_unrestricted_paths."""
def test_absolute_path_outside_cwd_blocked(self, tmp_path: Path) -> None:
"""An absolute path outside cwd is rejected by default."""
tool = GrepTool()
# tmp_path is almost certainly not under os.getcwd()
result = tool._run(pattern="anything", path=str(tmp_path))
assert "Error" in result
assert "outside the working directory" in result
def test_relative_traversal_blocked(self, sample_dir: Path) -> None:
"""A relative path with ../ that escapes cwd is rejected."""
tool = GrepTool()
result = tool._run(pattern="anything", path="../../etc")
assert "Error" in result
assert "outside the working directory" in result
def test_relative_path_within_cwd_allowed(self) -> None:
"""A relative path that stays inside cwd works fine."""
tool = GrepTool()
# "." is always within cwd
result = tool._run(pattern="zzz_will_not_match_anything_zzz", path=".")
# Should not get a traversal error — either matches or "No matches found"
assert "outside the working directory" not in result
def test_allow_unrestricted_paths_bypasses_check(self, tmp_path: Path) -> None:
"""With allow_unrestricted_paths=True, absolute paths outside cwd are allowed."""
# Write a searchable file in tmp_path
(tmp_path / "hello.txt").write_text("unrestricted search target\n")
tool = GrepTool(allow_unrestricted_paths=True)
result = tool._run(pattern="unrestricted", path=str(tmp_path))
assert "unrestricted search target" in result
def test_allow_unrestricted_defaults_false(self) -> None:
"""The flag defaults to False."""
tool = GrepTool()
assert tool.allow_unrestricted_paths is False
def test_error_message_includes_hint(self, tmp_path: Path) -> None:
"""The traversal error tells the user how to opt in."""
tool = GrepTool()
result = tool._run(pattern="x", path=str(tmp_path))
assert "GrepTool(allow_unrestricted_paths=True)" in result
class TestReDoSGuards:
"""Tests for regex denial-of-service mitigations."""
def test_pattern_length_rejected(self, sample_dir: Path) -> None:
"""Patterns exceeding MAX_REGEX_LENGTH are rejected before compilation."""
tool = GrepTool(allow_unrestricted_paths=True)
long_pattern = "a" * (MAX_REGEX_LENGTH + 1)
result = tool._run(pattern=long_pattern, path=str(sample_dir))
assert "Error" in result
assert "Pattern too long" in result
def test_pattern_at_max_length_accepted(self, sample_dir: Path) -> None:
"""A pattern exactly at MAX_REGEX_LENGTH is allowed (boundary check)."""
tool = GrepTool(allow_unrestricted_paths=True)
exact_pattern = "a" * MAX_REGEX_LENGTH
result = tool._run(pattern=exact_pattern, path=str(sample_dir))
# Should not get a length error — either matches or "No matches found"
assert "Pattern too long" not in result
def test_safe_search_returns_match(self) -> None:
"""_safe_search returns a match object for a normal pattern."""
compiled = __import__("re").compile(r"hello")
match = GrepTool._safe_search(compiled, "say hello world")
assert match is not None
assert match.group() == "hello"
def test_safe_search_returns_none_on_no_match(self) -> None:
"""_safe_search returns None when the pattern doesn't match."""
compiled = __import__("re").compile(r"zzz")
match = GrepTool._safe_search(compiled, "hello world")
assert match is None
class TestBraceExpansion:
"""Tests for glob brace expansion ({a,b} syntax)."""
def test_expand_simple_brace(self) -> None:
"""*.{py,txt} expands to ['*.py', '*.txt']."""
result = GrepTool._expand_brace_pattern("*.{py,txt}")
assert result == ["*.py", "*.txt"]
def test_expand_three_alternatives(self) -> None:
"""*.{py,txt,md} expands to three patterns."""
result = GrepTool._expand_brace_pattern("*.{py,txt,md}")
assert result == ["*.py", "*.txt", "*.md"]
def test_expand_no_braces_passthrough(self) -> None:
"""A pattern without braces is returned as a single-element list."""
result = GrepTool._expand_brace_pattern("*.py")
assert result == ["*.py"]
def test_expand_strips_whitespace(self) -> None:
"""Whitespace around alternatives inside braces is stripped."""
result = GrepTool._expand_brace_pattern("*.{ py , txt }")
assert result == ["*.py", "*.txt"]
def test_expand_prefix_and_suffix(self) -> None:
"""Prefix and suffix around the braces are preserved."""
result = GrepTool._expand_brace_pattern("src/*.{py,pyi}.bak")
assert result == ["src/*.py.bak", "src/*.pyi.bak"]
def test_brace_glob_end_to_end(self, tmp_path: Path) -> None:
"""Brace expansion works end-to-end with _collect_files."""
(tmp_path / "a.py").write_text("match_me\n")
(tmp_path / "b.txt").write_text("match_me\n")
(tmp_path / "c.md").write_text("match_me\n")
tool = GrepTool(allow_unrestricted_paths=True)
result = tool._run(
pattern="match_me",
path=str(tmp_path),
glob_pattern="*.{py,txt}",
)
assert "a.py" in result
assert "b.txt" in result
# .md should NOT be included
assert "c.md" not in result
def test_brace_glob_no_duplicates(self, tmp_path: Path) -> None:
"""Files are not reported twice when they match multiple expanded patterns."""
(tmp_path / "x.py").write_text("unique_content\n")
tool = GrepTool(allow_unrestricted_paths=True)
result = tool._run(
pattern="unique_content",
path=str(tmp_path),
glob_pattern="*.{py,py}",
output_mode="count",
)
# Should appear exactly once
assert result.count("x.py") == 1
class TestSensitiveFileProtection:
"""Tests for sensitive file exclusion (secrets leakage prevention)."""
@pytest.mark.parametrize(
"name",
[".env", ".env.local", ".netrc", ".npmrc", "secrets.json", "server.pem"],
)
def test_sensitive_files_excluded(self, tmp_path: Path, name: str) -> None:
"""Sensitive files are skipped even if they contain matches."""
(tmp_path / name).write_text("MATCH_ME\n")
tool = GrepTool(allow_unrestricted_paths=True)
result = tool._run(pattern="MATCH_ME", path=str(tmp_path))
assert "No matches found" in result
def test_sensitive_file_blocked_by_direct_path(self, tmp_path: Path) -> None:
"""A .env passed as the explicit path argument is still blocked."""
env = tmp_path / ".env"
env.write_text("SECRET=abc\n")
tool = GrepTool(allow_unrestricted_paths=True)
result = tool._run(pattern="SECRET", path=str(env))
assert "No matches found" in result
class TestFileSizeLimit:
"""Tests for max_file_size_bytes guard."""
def test_large_file_skipped(self, tmp_path: Path) -> None:
"""Files over max_file_size_bytes are skipped."""
(tmp_path / "big.txt").write_text("needle\n" * 100)
tool = GrepTool(allow_unrestricted_paths=True, max_file_size_bytes=50)
result = tool._run(pattern="needle", path=str(tmp_path))
assert "No matches found" in result
def test_large_file_searched_with_raised_limit(self, tmp_path: Path) -> None:
"""Raising the limit lets the same file be searched."""
(tmp_path / "big.txt").write_text("needle\n" * 100)
tool = GrepTool(allow_unrestricted_paths=True, max_file_size_bytes=50_000)
result = tool._run(pattern="needle", path=str(tmp_path))
assert "needle" in result
class TestContextLinesUpperBound:
"""Tests for context_lines validation bounds."""
def test_negative_rejected(self) -> None:
"""context_lines < 0 is rejected by Pydantic."""
with pytest.raises(ValidationError):
GrepToolSchema(pattern="x", context_lines=-1)
def test_over_max_rejected(self) -> None:
"""context_lines > MAX_CONTEXT_LINES is rejected by Pydantic."""
with pytest.raises(ValidationError):
GrepToolSchema(pattern="x", context_lines=MAX_CONTEXT_LINES + 1)

View File

@@ -10150,6 +10150,141 @@
"type": "object"
}
},
{
"description": "A tool that searches file contents on disk using regex patterns. Recursively searches files in a directory for matching lines. Returns matching content with line numbers, file paths only, or match counts.",
"env_vars": [],
"humanized_name": "Search file contents",
"init_params_schema": {
"$defs": {
"EnvVar": {
"properties": {
"default": {
"anyOf": [
{
"type": "string"
},
{
"type": "null"
}
],
"default": null,
"title": "Default"
},
"description": {
"title": "Description",
"type": "string"
},
"name": {
"title": "Name",
"type": "string"
},
"required": {
"default": true,
"title": "Required",
"type": "boolean"
}
},
"required": [
"name",
"description"
],
"title": "EnvVar",
"type": "object"
}
},
"description": "Tool for searching file contents on disk using regex patterns.\n\nRecursively searches files in a directory for lines matching a regex pattern.\nSupports glob filtering, context lines, and multiple output modes.\n\nExample:\n >>> tool = GrepTool()\n >>> result = tool.run(pattern=\"def.*main\", path=\"src\")\n >>> result = tool.run(\n ... pattern=\"TODO\",\n ... glob_pattern=\"*.py\",\n ... context_lines=2,\n ... )\n\n To search any path on the filesystem (opt-in):\n >>> tool = GrepTool(allow_unrestricted_paths=True)\n >>> result = tool.run(pattern=\"error\", path=\"/var/log/app\")",
"properties": {
"allow_unrestricted_paths": {
"default": false,
"description": "When False (default), searches are restricted to the current working directory. Set to True to allow searching any path on the filesystem.",
"title": "Allow Unrestricted Paths",
"type": "boolean"
},
"max_file_size_bytes": {
"default": 10485760,
"description": "Maximum file size in bytes to search. Files larger than this are skipped. Defaults to 10 MB.",
"title": "Max File Size Bytes",
"type": "integer"
}
},
"title": "GrepTool",
"type": "object"
},
"name": "GrepTool",
"package_dependencies": [],
"run_params_schema": {
"description": "Schema for grep tool arguments.",
"properties": {
"case_insensitive": {
"default": false,
"description": "Whether to perform case-insensitive matching",
"title": "Case Insensitive",
"type": "boolean"
},
"context_lines": {
"default": 0,
"description": "Number of lines to show before and after each match (0-10)",
"maximum": 10,
"minimum": 0,
"title": "Context Lines",
"type": "integer"
},
"glob_pattern": {
"anyOf": [
{
"type": "string"
},
{
"type": "null"
}
],
"default": null,
"description": "Glob pattern to filter files (e.g. '*.py'). Supports brace expansion (e.g. '*.{ts,tsx}').",
"title": "Glob Pattern"
},
"include_line_numbers": {
"default": true,
"description": "Whether to prefix matching lines with line numbers",
"title": "Include Line Numbers",
"type": "boolean"
},
"output_mode": {
"default": "content",
"description": "Output mode: 'content' shows matching lines, 'files_with_matches' shows only file paths, 'count' shows match counts per file",
"enum": [
"content",
"files_with_matches",
"count"
],
"title": "Output Mode",
"type": "string"
},
"path": {
"anyOf": [
{
"type": "string"
},
{
"type": "null"
}
],
"default": null,
"description": "File or directory to search in. Defaults to current working directory.",
"title": "Path"
},
"pattern": {
"description": "Regex pattern to search for in file contents",
"title": "Pattern",
"type": "string"
}
},
"required": [
"pattern"
],
"title": "GrepToolSchema",
"type": "object"
}
},
{
"description": "Scrape or crawl a website using Hyperbrowser and return the contents in properly formatted markdown or html",
"env_vars": [

View File

@@ -53,7 +53,7 @@ Repository = "https://github.com/crewAIInc/crewAI"
[project.optional-dependencies]
tools = [
"crewai-tools==1.10.2rc2",
"crewai-tools==1.10.1",
]
embeddings = [
"tiktoken~=0.8.0"
@@ -105,15 +105,10 @@ a2a = [
file-processing = [
"crewai-files",
]
cli = [
"crewai-cli",
]
# CLI entry point has moved to the crewai-cli package.
# Install it via: pip install crewai[cli]
# [project.scripts]
# crewai = "crewai.cli.cli:crewai"
[project.scripts]
crewai = "crewai.cli.cli:crewai"
# PyTorch index configuration, since torch 2.5.0 is not compatible with python 3.13

View File

@@ -1,4 +1,3 @@
import contextvars
import threading
from typing import Any
import urllib.request
@@ -41,7 +40,7 @@ def _suppress_pydantic_deprecation_warnings() -> None:
_suppress_pydantic_deprecation_warnings()
__version__ = "1.10.2rc2"
__version__ = "1.10.1"
_telemetry_submitted = False
@@ -67,8 +66,7 @@ def _track_install() -> None:
def _track_install_async() -> None:
"""Track installation in background thread to avoid blocking imports."""
if not Telemetry._is_telemetry_disabled():
ctx = contextvars.copy_context()
thread = threading.Thread(target=ctx.run, args=(_track_install,), daemon=True)
thread = threading.Thread(target=_track_install, daemon=True)
thread.start()

View File

@@ -5,7 +5,6 @@ from __future__ import annotations
import asyncio
from collections.abc import MutableMapping
import concurrent.futures
import contextvars
from functools import lru_cache
import ssl
import time
@@ -148,9 +147,8 @@ def fetch_agent_card(
has_running_loop = False
if has_running_loop:
ctx = contextvars.copy_context()
with concurrent.futures.ThreadPoolExecutor(max_workers=1) as pool:
return pool.submit(ctx.run, asyncio.run, coro).result()
return pool.submit(asyncio.run, coro).result()
return asyncio.run(coro)
@@ -217,9 +215,8 @@ def _fetch_agent_card_cached(
has_running_loop = False
if has_running_loop:
ctx = contextvars.copy_context()
with concurrent.futures.ThreadPoolExecutor(max_workers=1) as pool:
return pool.submit(ctx.run, asyncio.run, coro).result()
return pool.submit(asyncio.run, coro).result()
return asyncio.run(coro)

View File

@@ -7,7 +7,6 @@ import base64
from collections.abc import AsyncIterator, Callable, MutableMapping
import concurrent.futures
from contextlib import asynccontextmanager
import contextvars
import logging
from typing import TYPE_CHECKING, Any, Final, Literal
import uuid
@@ -230,9 +229,8 @@ def execute_a2a_delegation(
has_running_loop = False
if has_running_loop:
ctx = contextvars.copy_context()
with concurrent.futures.ThreadPoolExecutor(max_workers=1) as pool:
return pool.submit(ctx.run, asyncio.run, coro).result()
return pool.submit(asyncio.run, coro).result()
return asyncio.run(coro)

View File

@@ -8,7 +8,6 @@ from __future__ import annotations
import asyncio
from collections.abc import Callable, Coroutine, Mapping
from concurrent.futures import ThreadPoolExecutor, as_completed
import contextvars
from functools import wraps
import json
from types import MethodType
@@ -279,9 +278,7 @@ def _fetch_agent_cards_concurrently(
max_workers = min(len(a2a_agents), 10)
with ThreadPoolExecutor(max_workers=max_workers) as executor:
futures = {
executor.submit(
contextvars.copy_context().run, _fetch_card_from_config, config
): config
executor.submit(_fetch_card_from_config, config): config
for config in a2a_agents
}
for future in as_completed(futures):

View File

@@ -2,7 +2,6 @@ from __future__ import annotations
import asyncio
from collections.abc import Callable, Coroutine, Sequence
import contextvars
import shutil
import subprocess
import time
@@ -514,13 +513,9 @@ class Agent(BaseAgent):
"""
import concurrent.futures
ctx = contextvars.copy_context()
with concurrent.futures.ThreadPoolExecutor() as executor:
future = executor.submit(
ctx.run,
self._execute_without_timeout,
task_prompt=task_prompt,
task=task,
self._execute_without_timeout, task_prompt=task_prompt, task=task
)
try:

View File

@@ -38,7 +38,7 @@ from crewai.utilities.string_utils import interpolate_only
_SLUG_RE: Final[re.Pattern[str]] = re.compile(
r"^(?:crewai-amp:)?[a-zA-Z0-9][a-zA-Z0-9_-]*(?:#[\w-]+)?$"
r"^(?:crewai-amp:)?[a-zA-Z0-9][a-zA-Z0-9_-]*(?:#\w+)?$"
)

View File

@@ -895,9 +895,7 @@ class CrewAgentExecutor(CrewAgentExecutorMixin):
ToolUsageStartedEvent,
)
args_dict, parse_error = parse_tool_call_args(
func_args, func_name, call_id, original_tool
)
args_dict, parse_error = parse_tool_call_args(func_args, func_name, call_id, original_tool)
if parse_error is not None:
return parse_error

View File

@@ -1,7 +0,0 @@
"""Authentication utilities for the CrewAI platform."""
from crewai.auth.oauth2 import AuthenticationCommand
from crewai.auth.token import AuthError, get_auth_token
__all__ = ["AuthError", "AuthenticationCommand", "get_auth_token"]

View File

@@ -1,3 +0,0 @@
"""Authentication constants."""
ALGORITHMS = ["RS256"]

View File

@@ -1,184 +0,0 @@
"""OAuth2 authentication for the CrewAI platform."""
import time
from typing import TYPE_CHECKING, Any, TypeVar, cast
import webbrowser
import httpx
from pydantic import BaseModel, Field
from rich.console import Console
from crewai.auth.token_manager import TokenManager
from crewai.auth.utils import validate_jwt_token
from crewai.settings import Settings
console = Console()
TOauth2Settings = TypeVar("TOauth2Settings", bound="Oauth2Settings")
class Oauth2Settings(BaseModel):
"""OAuth2 provider configuration."""
provider: str = Field(
description="OAuth2 provider used for authentication (e.g., workos, okta, auth0)."
)
client_id: str = Field(
description="OAuth2 client ID issued by the provider, used during authentication requests."
)
domain: str = Field(
description="OAuth2 provider's domain (e.g., your-org.auth0.com) used for issuing tokens."
)
audience: str | None = Field(
description="OAuth2 audience value, typically used to identify the target API or resource.",
default=None,
)
extra: dict[str, Any] = Field(
description="Extra configuration for the OAuth2 provider.",
default={},
)
@classmethod
def from_settings(cls: type[TOauth2Settings]) -> TOauth2Settings:
"""Create an Oauth2Settings instance from the CLI settings."""
settings = Settings()
return cls(
provider=settings.oauth2_provider,
domain=settings.oauth2_domain,
client_id=settings.oauth2_client_id,
audience=settings.oauth2_audience,
extra=settings.oauth2_extra,
)
if TYPE_CHECKING:
from crewai.auth.providers.base_provider import BaseProvider
class ProviderFactory:
"""Factory for creating OAuth2 providers from settings."""
@classmethod
def from_settings(
cls: type["ProviderFactory"], # noqa: UP037
settings: Oauth2Settings | None = None,
) -> "BaseProvider": # noqa: UP037
"""Create a provider instance from settings."""
settings = settings or Oauth2Settings.from_settings()
import importlib
module = importlib.import_module(
f"crewai.auth.providers.{settings.provider.lower()}"
)
provider = getattr(
module,
f"{''.join(word.capitalize() for word in settings.provider.split('_'))}Provider",
)
return cast("BaseProvider", provider(settings))
class AuthenticationCommand:
"""Handles authentication with the CrewAI platform."""
def __init__(self) -> None:
self.token_manager = TokenManager()
self.oauth2_provider = ProviderFactory.from_settings()
def login(self) -> None:
"""Sign up to CrewAI+"""
console.print("Signing in to CrewAI AMP...\n", style="bold blue")
device_code_data = self._get_device_code()
self._display_auth_instructions(device_code_data)
return self._poll_for_token(device_code_data)
def _get_device_code(self) -> dict[str, Any]:
"""Get the device code to authenticate the user."""
device_code_payload = {
"client_id": self.oauth2_provider.get_client_id(),
"scope": " ".join(self.oauth2_provider.get_oauth_scopes()),
"audience": self.oauth2_provider.get_audience(),
}
response = httpx.post(
url=self.oauth2_provider.get_authorize_url(),
data=device_code_payload,
timeout=20,
)
response.raise_for_status()
return cast(dict[str, Any], response.json())
def _display_auth_instructions(self, device_code_data: dict[str, str]) -> None:
"""Display the authentication instructions to the user."""
verification_uri = device_code_data.get(
"verification_uri_complete", device_code_data.get("verification_uri", "")
)
console.print("1. Navigate to: ", verification_uri)
console.print("2. Enter the following code: ", device_code_data["user_code"])
webbrowser.open(verification_uri)
def _poll_for_token(self, device_code_data: dict[str, Any]) -> None:
"""Polls the server for the token until it is received, or max attempts are reached."""
token_payload = {
"grant_type": "urn:ietf:params:oauth:grant-type:device_code",
"device_code": device_code_data["device_code"],
"client_id": self.oauth2_provider.get_client_id(),
}
console.print("\nWaiting for authentication... ", style="bold blue", end="")
attempts = 0
while True and attempts < 10:
response = httpx.post(
self.oauth2_provider.get_token_url(), data=token_payload, timeout=30
)
token_data = response.json()
if response.status_code == 200:
self._validate_and_save_token(token_data)
console.print(
"Success!",
style="bold green",
)
self._post_login()
console.print("\n[bold green]Welcome to CrewAI AMP![/bold green]\n")
return
if token_data["error"] not in ("authorization_pending", "slow_down"):
raise httpx.HTTPError(
token_data.get("error_description") or token_data.get("error")
)
time.sleep(device_code_data["interval"])
attempts += 1
console.print(
"Timeout: Failed to get the token. Please try again.", style="bold red"
)
def _validate_and_save_token(self, token_data: dict[str, Any]) -> None:
"""Validates the JWT token and saves the token to the token manager."""
jwt_token = token_data["access_token"]
issuer = self.oauth2_provider.get_issuer()
jwt_token_data = {
"jwt_token": jwt_token,
"jwks_url": self.oauth2_provider.get_jwks_url(),
"issuer": issuer,
"audience": self.oauth2_provider.get_audience(),
}
decoded_token = validate_jwt_token(**jwt_token_data)
expires_at = decoded_token.get("exp", 0)
self.token_manager.save_tokens(jwt_token, expires_at)
def _post_login(self) -> None:
"""Hook called after successful login. Override in subclasses for additional behavior."""

View File

@@ -1 +0,0 @@
"""OAuth2 authentication providers."""

View File

@@ -1,38 +0,0 @@
"""Auth0 OAuth2 provider."""
from crewai.auth.providers.base_provider import BaseProvider
class Auth0Provider(BaseProvider):
"""Auth0 OAuth2 provider implementation."""
def get_authorize_url(self) -> str:
return f"https://{self._get_domain()}/oauth/device/code"
def get_token_url(self) -> str:
return f"https://{self._get_domain()}/oauth/token"
def get_jwks_url(self) -> str:
return f"https://{self._get_domain()}/.well-known/jwks.json"
def get_issuer(self) -> str:
return f"https://{self._get_domain()}/"
def get_audience(self) -> str:
if self.settings.audience is None:
raise ValueError(
"Audience is required. Please set it in the configuration."
)
return self.settings.audience
def get_client_id(self) -> str:
if self.settings.client_id is None:
raise ValueError(
"Client ID is required. Please set it in the configuration."
)
return self.settings.client_id
def _get_domain(self) -> str:
if self.settings.domain is None:
raise ValueError("Domain is required. Please set it in the configuration.")
return self.settings.domain

View File

@@ -1,38 +0,0 @@
"""Base OAuth2 provider interface."""
from abc import ABC, abstractmethod
from crewai.auth.oauth2 import Oauth2Settings
class BaseProvider(ABC):
"""Abstract base class for OAuth2 providers."""
def __init__(self, settings: Oauth2Settings):
self.settings = settings
@abstractmethod
def get_authorize_url(self) -> str: ...
@abstractmethod
def get_token_url(self) -> str: ...
@abstractmethod
def get_jwks_url(self) -> str: ...
@abstractmethod
def get_issuer(self) -> str: ...
@abstractmethod
def get_audience(self) -> str: ...
@abstractmethod
def get_client_id(self) -> str: ...
def get_required_fields(self) -> list[str]:
"""Returns which provider-specific fields inside the "extra" dict will be required."""
return []
def get_oauth_scopes(self) -> list[str]:
"""Returns the OAuth scopes to request."""
return ["openid", "profile", "email"]

View File

@@ -1,47 +0,0 @@
"""Entra ID (Azure AD) OAuth2 provider."""
from typing import cast
from crewai.auth.providers.base_provider import BaseProvider
class EntraIdProvider(BaseProvider):
"""Entra ID (Azure AD) OAuth2 provider implementation."""
def get_authorize_url(self) -> str:
return f"{self._base_url()}/oauth2/v2.0/devicecode"
def get_token_url(self) -> str:
return f"{self._base_url()}/oauth2/v2.0/token"
def get_jwks_url(self) -> str:
return f"{self._base_url()}/discovery/v2.0/keys"
def get_issuer(self) -> str:
return f"{self._base_url()}/v2.0"
def get_audience(self) -> str:
if self.settings.audience is None:
raise ValueError(
"Audience is required. Please set it in the configuration."
)
return self.settings.audience
def get_client_id(self) -> str:
if self.settings.client_id is None:
raise ValueError(
"Client ID is required. Please set it in the configuration."
)
return self.settings.client_id
def get_oauth_scopes(self) -> list[str]:
return [
*super().get_oauth_scopes(),
*cast(str, self.settings.extra.get("scope", "")).split(),
]
def get_required_fields(self) -> list[str]:
return ["scope"]
def _base_url(self) -> str:
return f"https://login.microsoftonline.com/{self.settings.domain}"

View File

@@ -1,36 +0,0 @@
"""Keycloak OAuth2 provider."""
from crewai.auth.providers.base_provider import BaseProvider
class KeycloakProvider(BaseProvider):
"""Keycloak OAuth2 provider implementation."""
def get_authorize_url(self) -> str:
return f"{self._oauth2_base_url()}/realms/{self.settings.extra.get('realm')}/protocol/openid-connect/auth/device"
def get_token_url(self) -> str:
return f"{self._oauth2_base_url()}/realms/{self.settings.extra.get('realm')}/protocol/openid-connect/token"
def get_jwks_url(self) -> str:
return f"{self._oauth2_base_url()}/realms/{self.settings.extra.get('realm')}/protocol/openid-connect/certs"
def get_issuer(self) -> str:
return f"{self._oauth2_base_url()}/realms/{self.settings.extra.get('realm')}"
def get_audience(self) -> str:
return self.settings.audience or "no-audience-provided"
def get_client_id(self) -> str:
if self.settings.client_id is None:
raise ValueError(
"Client ID is required. Please set it in the configuration."
)
return self.settings.client_id
def get_required_fields(self) -> list[str]:
return ["realm"]
def _oauth2_base_url(self) -> str:
domain = self.settings.domain.removeprefix("https://").removeprefix("http://")
return f"https://{domain}"

View File

@@ -1,46 +0,0 @@
"""Okta OAuth2 provider."""
from crewai.auth.providers.base_provider import BaseProvider
class OktaProvider(BaseProvider):
"""Okta OAuth2 provider implementation."""
def get_authorize_url(self) -> str:
return f"{self._oauth2_base_url()}/v1/device/authorize"
def get_token_url(self) -> str:
return f"{self._oauth2_base_url()}/v1/token"
def get_jwks_url(self) -> str:
return f"{self._oauth2_base_url()}/v1/keys"
def get_issuer(self) -> str:
return self._oauth2_base_url().removesuffix("/oauth2")
def get_audience(self) -> str:
if self.settings.audience is None:
raise ValueError(
"Audience is required. Please set it in the configuration."
)
return self.settings.audience
def get_client_id(self) -> str:
if self.settings.client_id is None:
raise ValueError(
"Client ID is required. Please set it in the configuration."
)
return self.settings.client_id
def get_required_fields(self) -> list[str]:
return ["authorization_server_name", "using_org_auth_server"]
def _oauth2_base_url(self) -> str:
using_org_auth_server = self.settings.extra.get("using_org_auth_server", False)
if using_org_auth_server:
base_url = f"https://{self.settings.domain}/oauth2"
else:
base_url = f"https://{self.settings.domain}/oauth2/{self.settings.extra.get('authorization_server_name', 'default')}"
return f"{base_url}"

View File

@@ -1,34 +0,0 @@
"""WorkOS OAuth2 provider."""
from crewai.auth.providers.base_provider import BaseProvider
class WorkosProvider(BaseProvider):
"""WorkOS OAuth2 provider implementation."""
def get_authorize_url(self) -> str:
return f"https://{self._get_domain()}/oauth2/device_authorization"
def get_token_url(self) -> str:
return f"https://{self._get_domain()}/oauth2/token"
def get_jwks_url(self) -> str:
return f"https://{self._get_domain()}/oauth2/jwks"
def get_issuer(self) -> str:
return f"https://{self._get_domain()}"
def get_audience(self) -> str:
return self.settings.audience or ""
def get_client_id(self) -> str:
if self.settings.client_id is None:
raise ValueError(
"Client ID is required. Please set it in the configuration."
)
return self.settings.client_id
def _get_domain(self) -> str:
if self.settings.domain is None:
raise ValueError("Domain is required. Please set it in the configuration.")
return self.settings.domain

View File

@@ -1,15 +0,0 @@
"""Authentication token retrieval."""
from crewai.auth.token_manager import TokenManager
class AuthError(Exception):
"""Raised when authentication fails."""
def get_auth_token() -> str:
"""Get the authentication token."""
access_token = TokenManager().get_token()
if not access_token:
raise AuthError("No token found, make sure you are logged in")
return access_token

View File

@@ -1,188 +0,0 @@
"""Manages encrypted token storage."""
from datetime import datetime
import json
import os
from pathlib import Path
import sys
import tempfile
from typing import Final, Literal, cast
from cryptography.fernet import Fernet
_FERNET_KEY_LENGTH: Final[Literal[44]] = 44
class TokenManager:
"""Manages encrypted token storage."""
def __init__(self, file_path: str = "tokens.enc") -> None:
"""Initialize the TokenManager.
Args:
file_path: The file path to store encrypted tokens.
"""
self.file_path = file_path
self.key = self._get_or_create_key()
self.fernet = Fernet(self.key)
def _get_or_create_key(self) -> bytes:
"""Get or create the encryption key.
Returns:
The encryption key as bytes.
"""
key_filename: str = "secret.key"
key = self._read_secure_file(key_filename)
if key is not None and len(key) == _FERNET_KEY_LENGTH:
return key
new_key = Fernet.generate_key()
if self._atomic_create_secure_file(key_filename, new_key):
return new_key
key = self._read_secure_file(key_filename)
if key is not None and len(key) == _FERNET_KEY_LENGTH:
return key
raise RuntimeError("Failed to create or read encryption key")
def save_tokens(self, access_token: str, expires_at: int) -> None:
"""Save the access token and its expiration time.
Args:
access_token: The access token to save.
expires_at: The UNIX timestamp of the expiration time.
"""
expiration_time = datetime.fromtimestamp(expires_at)
data = {
"access_token": access_token,
"expiration": expiration_time.isoformat(),
}
encrypted_data = self.fernet.encrypt(json.dumps(data).encode())
self._atomic_write_secure_file(self.file_path, encrypted_data)
def get_token(self) -> str | None:
"""Get the access token if it is valid and not expired.
Returns:
The access token if valid and not expired, otherwise None.
"""
encrypted_data = self._read_secure_file(self.file_path)
if encrypted_data is None:
return None
decrypted_data = self.fernet.decrypt(encrypted_data)
data = json.loads(decrypted_data)
expiration = datetime.fromisoformat(data["expiration"])
if expiration <= datetime.now():
return None
return cast(str | None, data.get("access_token"))
def clear_tokens(self) -> None:
"""Clear the stored tokens."""
self._delete_secure_file(self.file_path)
@staticmethod
def _get_secure_storage_path() -> Path:
"""Get the secure storage path based on the operating system.
Returns:
The secure storage path.
"""
if sys.platform == "win32":
base_path = os.environ.get("LOCALAPPDATA")
elif sys.platform == "darwin":
base_path = os.path.expanduser("~/Library/Application Support")
else:
base_path = os.path.expanduser("~/.local/share")
app_name = "crewai/credentials"
storage_path = Path(base_path) / app_name
storage_path.mkdir(parents=True, exist_ok=True)
return storage_path
def _atomic_create_secure_file(self, filename: str, content: bytes) -> bool:
"""Create a file only if it doesn't exist.
Args:
filename: The name of the file.
content: The content to write.
Returns:
True if file was created, False if it already exists.
"""
storage_path = self._get_secure_storage_path()
file_path = storage_path / filename
try:
fd = os.open(file_path, os.O_CREAT | os.O_EXCL | os.O_WRONLY, 0o600)
try:
os.write(fd, content)
finally:
os.close(fd)
return True
except FileExistsError:
return False
def _atomic_write_secure_file(self, filename: str, content: bytes) -> None:
"""Write content to a secure file.
Args:
filename: The name of the file.
content: The content to write.
"""
storage_path = self._get_secure_storage_path()
file_path = storage_path / filename
fd, temp_path = tempfile.mkstemp(dir=storage_path, prefix=f".{filename}.")
fd_closed = False
try:
os.write(fd, content)
os.close(fd)
fd_closed = True
os.chmod(temp_path, 0o600)
os.replace(temp_path, file_path)
except Exception:
if not fd_closed:
os.close(fd)
if os.path.exists(temp_path):
os.unlink(temp_path)
raise
def _read_secure_file(self, filename: str) -> bytes | None:
"""Read the content of a secure file.
Args:
filename: The name of the file.
Returns:
The content of the file if it exists, otherwise None.
"""
storage_path = self._get_secure_storage_path()
file_path = storage_path / filename
try:
with open(file_path, "rb") as f:
return f.read()
except FileNotFoundError:
return None
def _delete_secure_file(self, filename: str) -> None:
"""Delete a secure file.
Args:
filename: The name of the file.
"""
storage_path = self._get_secure_storage_path()
file_path = storage_path / filename
try:
file_path.unlink()
except FileNotFoundError:
pass

View File

@@ -1,67 +0,0 @@
"""JWT token validation utilities."""
from typing import Any
import jwt
from jwt import PyJWKClient
def validate_jwt_token(
jwt_token: str, jwks_url: str, issuer: str, audience: str
) -> Any:
"""Verify the token's signature and claims using PyJWT.
Args:
jwt_token: The JWT (JWS) string to validate.
jwks_url: The URL of the JWKS endpoint.
issuer: The expected issuer of the token.
audience: The expected audience of the token.
Returns:
The decoded token.
Raises:
Exception: If the token is invalid for any reason.
"""
try:
jwk_client = PyJWKClient(jwks_url)
signing_key = jwk_client.get_signing_key_from_jwt(jwt_token)
_unverified_decoded_token = jwt.decode(
jwt_token, options={"verify_signature": False}
)
return jwt.decode(
jwt_token,
signing_key.key,
algorithms=["RS256"],
audience=audience,
issuer=issuer,
leeway=10.0,
options={
"verify_signature": True,
"verify_exp": True,
"verify_nbf": True,
"verify_iat": True,
"require": ["exp", "iat", "iss", "aud", "sub"],
},
)
except jwt.ExpiredSignatureError as e:
raise Exception("Token has expired.") from e
except jwt.InvalidAudienceError as e:
actual_audience = _unverified_decoded_token.get("aud", "[no audience found]")
raise Exception(
f"Invalid token audience. Got: '{actual_audience}'. Expected: '{audience}'"
) from e
except jwt.InvalidIssuerError as e:
actual_issuer = _unverified_decoded_token.get("iss", "[no issuer found]")
raise Exception(
f"Invalid token issuer. Got: '{actual_issuer}'. Expected: '{issuer}'"
) from e
except jwt.MissingRequiredClaimError as e:
raise Exception(f"Token is missing required claims: {e!s}") from e
except jwt.exceptions.PyJWKClientError as e:
raise Exception(f"JWKS or key processing error: {e!s}") from e
except jwt.InvalidTokenError as e:
raise Exception(f"Invalid token: {e!s}") from e

View File

@@ -2,15 +2,19 @@ from pathlib import Path
import click
from crewai_cli.utils import copy_template
from crewai.cli.utils import copy_template
from crewai.utilities.printer import Printer
_printer = Printer()
def add_crew_to_flow(crew_name: str) -> None:
"""Add a new crew to the current flow."""
# Check if pyproject.toml exists in the current directory
if not Path("pyproject.toml").exists():
click.secho(
"This command must be run from the root of a flow project.", fg="red"
_printer.print(
"This command must be run from the root of a flow project.", color="red"
)
raise click.ClickException(
"This command must be run from the root of a flow project."
@@ -21,7 +25,7 @@ def add_crew_to_flow(crew_name: str) -> None:
crews_folder = flow_folder / "src" / flow_folder.name / "crews"
if not crews_folder.exists():
click.secho("Crews folder does not exist in the current flow.", fg="red")
_printer.print("Crews folder does not exist in the current flow.", color="red")
raise click.ClickException("Crews folder does not exist in the current flow.")
# Create the crew within the flow's crews directory

View File

@@ -0,0 +1,4 @@
from crewai.cli.authentication.main import AuthenticationCommand
__all__ = ["AuthenticationCommand"]

View File

@@ -6,9 +6,9 @@ import httpx
from pydantic import BaseModel, Field
from rich.console import Console
from crewai_cli.authentication.utils import validate_jwt_token
from crewai_cli.config import Settings
from crewai_cli.shared.token_manager import TokenManager
from crewai.cli.authentication.utils import validate_jwt_token
from crewai.cli.config import Settings
from crewai.cli.shared.token_manager import TokenManager
console = Console()
@@ -51,7 +51,7 @@ class Oauth2Settings(BaseModel):
if TYPE_CHECKING:
from crewai_cli.authentication.providers.base_provider import BaseProvider
from crewai.cli.authentication.providers.base_provider import BaseProvider
class ProviderFactory:
@@ -65,7 +65,7 @@ class ProviderFactory:
import importlib
module = importlib.import_module(
f"crewai_cli.authentication.providers.{settings.provider.lower()}"
f"crewai.cli.authentication.providers.{settings.provider.lower()}"
)
# Converts from snake_case to CamelCase to obtain the provider class name.
provider = getattr(
@@ -180,7 +180,7 @@ class AuthenticationCommand:
def _login_to_tool_repository(self) -> None:
"""Login to the tool repository."""
from crewai_cli.tools.main import ToolCommand
from crewai.cli.tools.main import ToolCommand
try:
console.print(

View File

@@ -1,4 +1,4 @@
from crewai_cli.authentication.providers.base_provider import BaseProvider
from crewai.cli.authentication.providers.base_provider import BaseProvider
class Auth0Provider(BaseProvider):

View File

@@ -1,6 +1,6 @@
from abc import ABC, abstractmethod
from crewai_cli.authentication.main import Oauth2Settings
from crewai.cli.authentication.main import Oauth2Settings
class BaseProvider(ABC):

View File

@@ -1,6 +1,6 @@
from typing import cast
from crewai_cli.authentication.providers.base_provider import BaseProvider
from crewai.cli.authentication.providers.base_provider import BaseProvider
class EntraIdProvider(BaseProvider):

View File

@@ -1,4 +1,4 @@
from crewai_cli.authentication.providers.base_provider import BaseProvider
from crewai.cli.authentication.providers.base_provider import BaseProvider
class KeycloakProvider(BaseProvider):

View File

@@ -1,4 +1,4 @@
from crewai_cli.authentication.providers.base_provider import BaseProvider
from crewai.cli.authentication.providers.base_provider import BaseProvider
class OktaProvider(BaseProvider):

View File

@@ -1,4 +1,4 @@
from crewai_cli.authentication.providers.base_provider import BaseProvider
from crewai.cli.authentication.providers.base_provider import BaseProvider
class WorkosProvider(BaseProvider):

View File

@@ -1,4 +1,4 @@
from crewai_cli.shared.token_manager import TokenManager
from crewai.cli.shared.token_manager import TokenManager
class AuthError(Exception):

View File

@@ -1,5 +1,3 @@
from __future__ import annotations
from importlib.metadata import version as get_version
import os
import subprocess
@@ -7,58 +5,44 @@ from typing import Any
import click
from crewai_cli.add_crew_to_flow import add_crew_to_flow
from crewai_cli.authentication.main import AuthenticationCommand
from crewai_cli.config import Settings
from crewai_cli.create_crew import create_crew
from crewai_cli.create_flow import create_flow
from crewai_cli.crew_chat import run_chat
from crewai_cli.deploy.main import DeployCommand
from crewai_cli.enterprise.main import EnterpriseConfigureCommand
from crewai_cli.evaluate_crew import evaluate_crew
from crewai_cli.install_crew import install_crew
from crewai_cli.kickoff_flow import kickoff_flow
from crewai_cli.organization.main import OrganizationCommand
from crewai_cli.plot_flow import plot_flow
from crewai_cli.replay_from_task import replay_task_command
from crewai_cli.reset_memories_command import reset_memories_command
from crewai_cli.run_crew import run_crew
from crewai_cli.settings.main import SettingsCommand
from crewai_cli.task_outputs import load_task_outputs
from crewai_cli.tools.main import ToolCommand
from crewai_cli.train_crew import train_crew
from crewai_cli.triggers.main import TriggersCommand
from crewai_cli.update_crew import update_crew
from crewai_cli.user_data import (
_load_user_data,
_save_user_data,
is_tracing_enabled,
from crewai.cli.add_crew_to_flow import add_crew_to_flow
from crewai.cli.authentication.main import AuthenticationCommand
from crewai.cli.config import Settings
from crewai.cli.create_crew import create_crew
from crewai.cli.create_flow import create_flow
from crewai.cli.crew_chat import run_chat
from crewai.cli.deploy.main import DeployCommand
from crewai.cli.enterprise.main import EnterpriseConfigureCommand
from crewai.cli.evaluate_crew import evaluate_crew
from crewai.cli.install_crew import install_crew
from crewai.cli.kickoff_flow import kickoff_flow
from crewai.cli.organization.main import OrganizationCommand
from crewai.cli.plot_flow import plot_flow
from crewai.cli.replay_from_task import replay_task_command
from crewai.cli.reset_memories_command import reset_memories_command
from crewai.cli.run_crew import run_crew
from crewai.cli.settings.main import SettingsCommand
from crewai.cli.tools.main import ToolCommand
from crewai.cli.train_crew import train_crew
from crewai.cli.triggers.main import TriggersCommand
from crewai.cli.update_crew import update_crew
from crewai.cli.utils import build_env_with_tool_repository_credentials, read_toml
from crewai.memory.storage.kickoff_task_outputs_storage import (
KickoffTaskOutputsSQLiteStorage,
)
from crewai_cli.utils import build_env_with_tool_repository_credentials, read_toml
def _get_cli_version() -> str:
"""Return the best available version string for the CLI."""
# Prefer crewai version if installed (keeps existing UX)
try:
return get_version("crewai")
except Exception: # noqa: S110
pass
try:
return get_version("crewai-cli")
except Exception:
return "unknown"
@click.group()
@click.version_option(_get_cli_version())
@click.version_option(get_version("crewai"))
def crewai():
"""Top-level command group for crewai."""
@crewai.command(
name="uv",
context_settings={"ignore_unknown_options": True},
context_settings=dict(
ignore_unknown_options=True,
),
)
@click.argument("uv_args", nargs=-1, type=click.UNPROCESSED)
def uv(uv_args):
@@ -123,7 +107,7 @@ def version(tools):
if tools:
try:
tools_version = get_version("crewai-tools")
tools_version = get_version("crewai")
click.echo(f"crewai tools version: {tools_version}")
except Exception:
click.echo("crewai tools not installed")
@@ -158,7 +142,12 @@ def train(n_iterations: int, filename: str):
help="Replay the crew from this task ID, including all subsequent tasks.",
)
def replay(task_id: str) -> None:
"""Replay the crew execution from a specific task."""
"""
Replay the crew execution from a specific task.
Args:
task_id (str): The ID of the task to replay from.
"""
try:
click.echo(f"Replaying the crew from task {task_id}")
replay_task_command(task_id)
@@ -168,9 +157,12 @@ def replay(task_id: str) -> None:
@crewai.command()
def log_tasks_outputs() -> None:
"""Retrieve your latest crew.kickoff() task outputs."""
"""
Retrieve your latest crew.kickoff() task outputs.
"""
try:
tasks = load_task_outputs()
storage = KickoffTaskOutputsSQLiteStorage()
tasks = storage.load()
if not tasks:
click.echo(
@@ -190,24 +182,15 @@ def log_tasks_outputs() -> None:
@crewai.command()
@click.option("-m", "--memory", is_flag=True, help="Reset MEMORY")
@click.option(
"-l",
"--long",
is_flag=True,
hidden=True,
"-l", "--long", is_flag=True, hidden=True,
help="[Deprecated: use --memory] Reset memory",
)
@click.option(
"-s",
"--short",
is_flag=True,
hidden=True,
"-s", "--short", is_flag=True, hidden=True,
help="[Deprecated: use --memory] Reset memory",
)
@click.option(
"-e",
"--entities",
is_flag=True,
hidden=True,
"-e", "--entities", is_flag=True, hidden=True,
help="[Deprecated: use --memory] Reset memory",
)
@click.option("-kn", "--knowledge", is_flag=True, help="Reset KNOWLEDGE storage")
@@ -228,17 +211,14 @@ def reset_memories(
agent_knowledge: bool,
all: bool,
) -> None:
"""Reset the crew memories (memory, knowledge, agent_knowledge, kickoff_outputs). This will delete all the data saved."""
"""
Reset the crew memories (memory, knowledge, agent_knowledge, kickoff_outputs). This will delete all the data saved.
"""
try:
# Treat legacy flags as --memory with a deprecation warning
if long or short or entities:
legacy_used = [
f
for f, v in [
("--long", long),
("--short", short),
("--entities", entities),
]
if v
f for f, v in [("--long", long), ("--short", short), ("--entities", entities)] if v
]
click.echo(
f"Warning: {', '.join(legacy_used)} {'is' if len(legacy_used) == 1 else 'are'} "
@@ -258,7 +238,9 @@ def reset_memories(
"Please specify at least one memory type to reset using the appropriate flags."
)
return
reset_memories_command(memory, knowledge, agent_knowledge, kickoff_outputs, all)
reset_memories_command(
memory, knowledge, agent_knowledge, kickoff_outputs, all
)
except Exception as e:
click.echo(f"An error occurred while resetting memories: {e}", err=True)
@@ -296,7 +278,7 @@ def memory(
) -> None:
"""Open the Memory TUI to browse scopes and recall memories."""
try:
from crewai_cli.memory_tui import MemoryTUI
from crewai.cli.memory_tui import MemoryTUI
except ImportError as exc:
click.echo(
"Textual is required for the memory TUI but could not be imported. "
@@ -346,10 +328,10 @@ def test(n_iterations: int, model: str):
@crewai.command(
context_settings={
"ignore_unknown_options": True,
"allow_extra_args": True,
}
context_settings=dict(
ignore_unknown_options=True,
allow_extra_args=True,
)
)
@click.pass_context
def install(context):
@@ -514,12 +496,14 @@ def triggers_run(trigger_path: str):
@crewai.command()
def chat():
"""Start a conversation with the Crew, collecting user-supplied inputs,
"""
Start a conversation with the Crew, collecting user-supplied inputs,
and using the Chat LLM to generate responses.
"""
click.secho(
"\nStarting a conversation with the Crew\nType 'exit' or Ctrl+C to quit.\n",
)
run_chat()
@@ -630,7 +614,7 @@ def env_view():
table.add_row(
"CREWAI_TRACING_ENABLED",
"[dim]Not set[/dim]",
"[dim]---[/dim]",
"[dim][/dim]",
)
# Check other related env vars
@@ -649,7 +633,7 @@ def env_view():
# Check if .env file exists
table.add_row(
".env file",
"Found" if env_file_exists else "Not found",
"Found" if env_file_exists else "Not found",
str(env_file.resolve()) if env_file_exists else "N/A",
)
@@ -665,11 +649,11 @@ def env_view():
# Show helpful message
if env_file_exists:
console.print(
"\n[dim]Tip: To enable tracing via .env, add: CREWAI_TRACING_ENABLED=true[/dim]"
"\n[dim]💡 Tip: To enable tracing via .env, add: CREWAI_TRACING_ENABLED=true[/dim]"
)
else:
console.print(
"\n[dim]Tip: Create a .env file in your project root and add: CREWAI_TRACING_ENABLED=true[/dim]"
"\n[dim]💡 Tip: Create a .env file in your project root and add: CREWAI_TRACING_ENABLED=true[/dim]"
)
console.print()
@@ -685,6 +669,11 @@ def traces_enable():
from rich.console import Console
from rich.panel import Panel
from crewai.events.listeners.tracing.utils import (
_load_user_data,
_save_user_data,
)
console = Console()
# Update user data to enable traces
@@ -694,7 +683,7 @@ def traces_enable():
_save_user_data(user_data)
panel = Panel(
"Trace collection has been enabled!\n\n"
"Trace collection has been enabled!\n\n"
"Your crew/flow executions will now send traces to CrewAI+.\n"
"Use 'crewai traces disable' to turn off trace collection.",
title="Traces Enabled",
@@ -710,6 +699,11 @@ def traces_disable():
from rich.console import Console
from rich.panel import Panel
from crewai.events.listeners.tracing.utils import (
_load_user_data,
_save_user_data,
)
console = Console()
# Update user data to disable traces
@@ -719,7 +713,7 @@ def traces_disable():
_save_user_data(user_data)
panel = Panel(
"Trace collection has been disabled!\n\n"
"Trace collection has been disabled!\n\n"
"Your crew/flow executions will no longer send traces.\n"
"Use 'crewai traces enable' to turn trace collection back on.",
title="Traces Disabled",
@@ -738,6 +732,11 @@ def traces_status():
from rich.panel import Panel
from rich.table import Table
from crewai.events.listeners.tracing.utils import (
_load_user_data,
is_tracing_enabled,
)
console = Console()
user_data = _load_user_data()
@@ -752,19 +751,19 @@ def traces_status():
# Check user consent
trace_consent = user_data.get("trace_consent")
if trace_consent is True:
consent_status = "Enabled (user consented)"
consent_status = "Enabled (user consented)"
elif trace_consent is False:
consent_status = "Disabled (user declined)"
consent_status = "Disabled (user declined)"
else:
consent_status = "Not set (first-time user)"
consent_status = "Not set (first-time user)"
table.add_row("User Consent", consent_status)
# Check overall status
if is_tracing_enabled():
overall_status = "ENABLED"
overall_status = "ENABLED"
border_style = "green"
else:
overall_status = "DISABLED"
overall_status = "DISABLED"
border_style = "red"
table.add_row("Overall Status", overall_status)

View File

@@ -1,12 +1,11 @@
from __future__ import annotations
import json
import httpx
from rich.console import Console
from crewai_cli.authentication.token import get_auth_token
from crewai_cli.plus_api import PlusAPI
from crewai.cli.authentication.token import get_auth_token
from crewai.cli.plus_api import PlusAPI
from crewai.telemetry.telemetry import Telemetry
console = Console()
@@ -14,14 +13,17 @@ console = Console()
class BaseCommand:
def __init__(self) -> None:
pass
self._telemetry = Telemetry()
self._telemetry.set_tracer()
class PlusAPIMixin:
def __init__(self) -> None:
def __init__(self, telemetry: Telemetry) -> None:
try:
telemetry.set_tracer()
self.plus_api_client = PlusAPI(api_key=get_auth_token())
except Exception:
telemetry.deploy_signup_error_span()
console.print(
"Please sign up/login to CrewAI+ before using the CLI.",
style="bold red",
@@ -30,6 +32,12 @@ class PlusAPIMixin:
raise SystemExit from None
def _validate_response(self, response: httpx.Response) -> None:
"""
Handle and display error messages from API responses.
Args:
response (httpx.Response): The response from the Plus API
"""
try:
json_response = response.json()
except (json.JSONDecodeError, ValueError):

View File

@@ -6,14 +6,14 @@ from typing import Any
from pydantic import BaseModel, Field
from crewai_cli.constants import (
from crewai.cli.constants import (
CREWAI_ENTERPRISE_DEFAULT_OAUTH2_AUDIENCE,
CREWAI_ENTERPRISE_DEFAULT_OAUTH2_CLIENT_ID,
CREWAI_ENTERPRISE_DEFAULT_OAUTH2_DOMAIN,
CREWAI_ENTERPRISE_DEFAULT_OAUTH2_PROVIDER,
DEFAULT_CREWAI_ENTERPRISE_URL,
)
from crewai_cli.shared.token_manager import TokenManager
from crewai.cli.shared.token_manager import TokenManager
logger = getLogger(__name__)

View File

@@ -5,13 +5,13 @@ import sys
import click
import tomli
from crewai_cli.constants import ENV_VARS, MODELS
from crewai_cli.provider import (
from crewai.cli.constants import ENV_VARS, MODELS
from crewai.cli.provider import (
get_provider_data,
select_model,
select_provider,
)
from crewai_cli.utils import copy_template, load_env_vars, write_env_file
from crewai.cli.utils import copy_template, load_env_vars, write_env_file
def get_reserved_script_names() -> set[str]:

View File

@@ -3,6 +3,8 @@ import shutil
import click
from crewai.telemetry import Telemetry
def create_flow(name):
"""Create a new flow."""
@@ -16,6 +18,10 @@ def create_flow(name):
click.secho(f"Error: Folder {folder_name} already exists.", fg="red")
return
# Initialize telemetry
telemetry = Telemetry()
telemetry.flow_creation_span(class_name)
# Create directory structure
(project_root / "src" / folder_name).mkdir(parents=True)
(project_root / "src" / folder_name / "crews").mkdir(parents=True)

View File

@@ -1,6 +1,3 @@
"""Interactive chat interface for CrewAI crews."""
import contextvars
import json
from pathlib import Path
import platform
@@ -14,15 +11,15 @@ import click
from packaging import version
import tomli
from crewai.cli.utils import read_toml
from crewai.cli.version import get_crewai_version
from crewai.crew import Crew
from crewai.llm import LLM
from crewai.llms.base_llm import BaseLLM
from crewai.types.crew_chat import ChatInputField, ChatInputs
from crewai.utilities.llm_utils import create_llm
from crewai.utilities.printer import Printer
from crewai.utilities.project_utils import read_toml
from crewai.utilities.types import LLMMessage
from crewai.version import get_crewai_version
_printer = Printer()
@@ -33,14 +30,15 @@ MIN_REQUIRED_VERSION: Final[Literal["0.98.0"]] = "0.98.0"
def check_conversational_crews_version(
crewai_version: str, pyproject_data: dict[str, Any]
) -> bool:
"""Check if the installed crewAI version supports conversational crews.
"""
Check if the installed crewAI version supports conversational crews.
Args:
crewai_version: The current version of crewAI.
pyproject_data: Dictionary containing pyproject.toml data.
Returns:
True if version check passes, False otherwise.
bool: True if version check passes, False otherwise.
"""
try:
if version.parse(crewai_version) < version.parse(MIN_REQUIRED_VERSION):
@@ -57,8 +55,8 @@ def check_conversational_crews_version(
def run_chat() -> None:
"""Run an interactive chat loop using the Crew's chat LLM with function calling.
"""
Runs an interactive chat loop using the Crew's chat LLM with function calling.
Incorporates crew_name, crew_description, and input fields to build a tool schema.
Exits if crew_name or crew_description are missing.
"""
@@ -73,17 +71,16 @@ def run_chat() -> None:
if not chat_llm:
return
# Indicate that the crew is being analyzed
click.secho(
"\nAnalyzing crew and required inputs - this may take 3 to 30 seconds "
"depending on the complexity of your crew.",
fg="white",
)
# Start loading indicator
loading_complete = threading.Event()
ctx = contextvars.copy_context()
loading_thread = threading.Thread(
target=ctx.run, args=(show_loading, loading_complete)
)
loading_thread = threading.Thread(target=show_loading, args=(loading_complete,))
loading_thread.start()
try:
@@ -91,13 +88,16 @@ def run_chat() -> None:
crew_tool_schema = generate_crew_tool_schema(crew_chat_inputs)
system_message = build_system_message(crew_chat_inputs)
# Call the LLM to generate the introductory message
introductory_message = chat_llm.call(
messages=[{"role": "system", "content": system_message}]
)
finally:
# Stop loading indicator
loading_complete.set()
loading_thread.join()
# Indicate that the analysis is complete
click.secho("\nFinished analyzing crew.\n", fg="white")
click.secho(f"Assistant: {introductory_message}\n", fg="green")
@@ -123,7 +123,7 @@ def show_loading(event: threading.Event) -> None:
def initialize_chat_llm(crew: Crew) -> LLM | BaseLLM | None:
"""Initialize the chat LLM and handle exceptions."""
"""Initializes the chat LLM and handles exceptions."""
try:
return create_llm(crew.chat_llm)
except Exception as e:
@@ -135,7 +135,7 @@ def initialize_chat_llm(crew: Crew) -> LLM | BaseLLM | None:
def build_system_message(crew_chat_inputs: ChatInputs) -> str:
"""Build the initial system message for the chat."""
"""Builds the initial system message for the chat."""
required_fields_str = (
", ".join(
f"{field.name} (desc: {field.description or 'n/a'})"
@@ -164,7 +164,7 @@ def build_system_message(crew_chat_inputs: ChatInputs) -> str:
def create_tool_function(crew: Crew, messages: list[LLMMessage]) -> Any:
"""Create a wrapper function for running the crew tool with messages."""
"""Creates a wrapper function for running the crew tool with messages."""
def run_crew_tool_with_messages(**kwargs: Any) -> str:
return run_crew_tool(crew, messages, **kwargs)
@@ -175,11 +175,13 @@ def create_tool_function(crew: Crew, messages: list[LLMMessage]) -> Any:
def flush_input() -> None:
"""Flush any pending input from the user."""
if platform.system() == "Windows":
# Windows platform
import msvcrt
while msvcrt.kbhit(): # type: ignore[attr-defined]
msvcrt.getch() # type: ignore[attr-defined]
else:
# Unix-like platforms (Linux, macOS)
import termios
termios.tcflush(sys.stdin, termios.TCIFLUSH)
@@ -194,6 +196,7 @@ def chat_loop(
"""Main chat loop for interacting with the user."""
while True:
try:
# Flush any pending input before accepting new input
flush_input()
user_input = get_user_input()
@@ -243,9 +246,11 @@ def handle_user_input(
messages.append({"role": "user", "content": user_input})
# Indicate that assistant is processing
click.echo()
click.secho("Assistant is processing your input. Please wait...", fg="green")
# Process assistant's response
final_response = chat_llm.call(
messages=messages,
tools=[crew_tool_schema],
@@ -257,11 +262,12 @@ def handle_user_input(
def generate_crew_tool_schema(crew_inputs: ChatInputs) -> dict[str, Any]:
"""Dynamically build a Littellm 'function' schema for the given crew.
"""
Dynamically build a Littellm 'function' schema for the given crew.
Args:
crew_inputs: A ChatInputs object containing crew_description
and a list of input fields (each with a name & description).
crew_name: The name of the crew (used for the function 'name').
crew_inputs: A ChatInputs object containing crew_description
and a list of input fields (each with a name & description).
"""
properties = {}
for field in crew_inputs.inputs:
@@ -287,51 +293,70 @@ def generate_crew_tool_schema(crew_inputs: ChatInputs) -> dict[str, Any]:
def run_crew_tool(crew: Crew, messages: list[LLMMessage], **kwargs: Any) -> str:
"""Run the crew using crew.kickoff(inputs=kwargs) and return the output.
"""
Runs the crew using crew.kickoff(inputs=kwargs) and returns the output.
Args:
crew: The crew instance to run.
messages: The chat messages up to this point.
crew (Crew): The crew instance to run.
messages (List[Dict[str, str]]): The chat messages up to this point.
**kwargs: The inputs collected from the user.
Returns:
The output from the crew's execution.
str: The output from the crew's execution.
Raises:
SystemExit: Exits the chat if an error occurs during crew execution.
"""
try:
# Serialize 'messages' to JSON string before adding to kwargs
kwargs["crew_chat_messages"] = json.dumps(messages)
# Run the crew with the provided inputs
crew_output = crew.kickoff(inputs=kwargs)
# Convert CrewOutput to a string to send back to the user
return str(crew_output)
except Exception as e:
# Exit the chat and show the error message
click.secho("An error occurred while running the crew:", fg="red")
click.secho(str(e), fg="red")
sys.exit(1)
def load_crew_and_name() -> tuple[Crew, str]:
"""Load the crew by importing the crew class from the user's project.
"""
Loads the crew by importing the crew class from the user's project.
Returns:
A tuple containing the Crew instance and the name of the crew.
Tuple[Crew, str]: A tuple containing the Crew instance and the name of the crew.
"""
# Get the current working directory
cwd = Path.cwd()
# Path to the pyproject.toml file
pyproject_path = cwd / "pyproject.toml"
if not pyproject_path.exists():
raise FileNotFoundError("pyproject.toml not found in the current directory.")
# Load the pyproject.toml file using 'tomli'
with pyproject_path.open("rb") as f:
pyproject_data = tomli.load(f)
# Get the project name from the 'project' section
project_name = pyproject_data["project"]["name"]
folder_name = project_name
# Derive the crew class name from the project name
# E.g., if project_name is 'my_project', crew_class_name is 'MyProject'
crew_class_name = project_name.replace("_", " ").title().replace(" ", "")
# Add the 'src' directory to sys.path
src_path = cwd / "src"
if str(src_path) not in sys.path:
sys.path.insert(0, str(src_path))
# Import the crew module
crew_module_name = f"{folder_name}.crew"
try:
crew_module = __import__(crew_module_name, fromlist=[crew_class_name])
@@ -340,6 +365,7 @@ def load_crew_and_name() -> tuple[Crew, str]:
f"Failed to import crew module {crew_module_name}: {e}"
) from e
# Get the crew class from the module
try:
crew_class = getattr(crew_module, crew_class_name)
except AttributeError as e:
@@ -347,6 +373,7 @@ def load_crew_and_name() -> tuple[Crew, str]:
f"Crew class {crew_class_name} not found in module {crew_module_name}"
) from e
# Instantiate the crew
crew_instance = crew_class().crew()
return crew_instance, crew_class_name
@@ -354,23 +381,27 @@ def load_crew_and_name() -> tuple[Crew, str]:
def generate_crew_chat_inputs(
crew: Crew, crew_name: str, chat_llm: LLM | BaseLLM
) -> ChatInputs:
"""Generate the ChatInputs required for the crew by analyzing the tasks and agents.
"""
Generates the ChatInputs required for the crew by analyzing the tasks and agents.
Args:
crew: The crew object containing tasks and agents.
crew_name: The name of the crew.
crew (Crew): The crew object containing tasks and agents.
crew_name (str): The name of the crew.
chat_llm: The chat language model to use for AI calls.
Returns:
An object containing the crew's name, description, and input fields.
ChatInputs: An object containing the crew's name, description, and input fields.
"""
# Extract placeholders from tasks and agents
required_inputs = fetch_required_inputs(crew)
# Generate descriptions for each input using AI
input_fields = []
for input_name in required_inputs:
description = generate_input_description_with_ai(input_name, crew, chat_llm)
input_fields.append(ChatInputField(name=input_name, description=description))
# Generate crew description using AI
crew_description = generate_crew_description_with_ai(crew, chat_llm)
return ChatInputs(
@@ -379,13 +410,13 @@ def generate_crew_chat_inputs(
def fetch_required_inputs(crew: Crew) -> set[str]:
"""Extract placeholders from the crew's tasks and agents.
"""Extracts placeholders from the crew's tasks and agents.
Args:
crew: The crew object.
crew (Crew): The crew object.
Returns:
A set of placeholder names.
Set[str]: A set of placeholder names.
"""
return crew.fetch_inputs()
@@ -393,16 +424,18 @@ def fetch_required_inputs(crew: Crew) -> set[str]:
def generate_input_description_with_ai(
input_name: str, crew: Crew, chat_llm: LLM | BaseLLM
) -> str:
"""Generate an input description using AI based on the context of the crew.
"""
Generates an input description using AI based on the context of the crew.
Args:
input_name: The name of the input placeholder.
crew: The crew object.
input_name (str): The name of the input placeholder.
crew (Crew): The crew object.
chat_llm: The chat language model to use for AI calls.
Returns:
A concise description of the input.
str: A concise description of the input.
"""
# Gather context from tasks and agents where the input is used
context_texts = []
placeholder_pattern = re.compile(r"\{(.+?)}")
@@ -411,6 +444,7 @@ def generate_input_description_with_ai(
f"{{{input_name}}}" in task.description
or f"{{{input_name}}}" in task.expected_output
):
# Replace placeholders with input names
task_description = placeholder_pattern.sub(
lambda m: m.group(1), task.description or ""
)
@@ -425,6 +459,7 @@ def generate_input_description_with_ai(
or f"{{{input_name}}}" in agent.goal
or f"{{{input_name}}}" in agent.backstory
):
# Replace placeholders with input names
agent_role = placeholder_pattern.sub(lambda m: m.group(1), agent.role or "")
agent_goal = placeholder_pattern.sub(lambda m: m.group(1), agent.goal or "")
agent_backstory = placeholder_pattern.sub(
@@ -436,6 +471,7 @@ def generate_input_description_with_ai(
context = "\n".join(context_texts)
if not context:
# If no context is found for the input, raise an exception as per instruction
raise ValueError(f"No context found for input '{input_name}'.")
prompt = (
@@ -449,19 +485,22 @@ def generate_input_description_with_ai(
def generate_crew_description_with_ai(crew: Crew, chat_llm: LLM | BaseLLM) -> str:
"""Generate a brief description of the crew using AI.
"""
Generates a brief description of the crew using AI.
Args:
crew: The crew object.
crew (Crew): The crew object.
chat_llm: The chat language model to use for AI calls.
Returns:
A concise description of the crew's purpose (15 words or less).
str: A concise description of the crew's purpose (15 words or less).
"""
# Gather context from tasks and agents
context_texts = []
placeholder_pattern = re.compile(r"\{(.+?)}")
for task in crew.tasks:
# Replace placeholders with input names
task_description = placeholder_pattern.sub(
lambda m: m.group(1), task.description or ""
)
@@ -471,6 +510,7 @@ def generate_crew_description_with_ai(crew: Crew, chat_llm: LLM | BaseLLM) -> st
context_texts.append(f"Task Description: {task_description}")
context_texts.append(f"Expected Output: {expected_output}")
for agent in crew.agents:
# Replace placeholders with input names
agent_role = placeholder_pattern.sub(lambda m: m.group(1), agent.role or "")
agent_goal = placeholder_pattern.sub(lambda m: m.group(1), agent.goal or "")
agent_backstory = placeholder_pattern.sub(

View File

@@ -1,11 +1,10 @@
from pathlib import Path
from typing import Any
from rich.console import Console
from crewai_cli import git
from crewai_cli.command import BaseCommand, PlusAPIMixin
from crewai_cli.utils import fetch_and_json_env_file, get_project_name
from crewai.cli import git
from crewai.cli.command import BaseCommand, PlusAPIMixin
from crewai.cli.utils import fetch_and_json_env_file, get_project_name
console = Console()
@@ -22,43 +21,8 @@ class DeployCommand(BaseCommand, PlusAPIMixin):
"""
BaseCommand.__init__(self)
PlusAPIMixin.__init__(self)
PlusAPIMixin.__init__(self, telemetry=self._telemetry)
self.project_name = get_project_name(require=True)
self._validate_project_structure()
def _validate_project_structure(self) -> None:
"""Validate that the local project has the files required for deployment."""
errors: list[str] = []
if not Path("pyproject.toml").exists():
errors.append("Cannot find pyproject.toml in the current directory.")
has_lockfile = Path("uv.lock").exists() or Path("poetry.lock").exists()
if not has_lockfile:
errors.append(
"No uv.lock or poetry.lock found. "
"Run 'uv lock' or 'poetry lock' to generate one."
)
src_dir = Path("src") / (self.project_name or "")
crew_py = src_dir / "crew.py"
config_dir = src_dir / "config"
if not crew_py.exists() and not config_dir.exists():
errors.append(
f"Cannot find src/{self.project_name}/crew.py or "
f"src/{self.project_name}/config. "
"Ensure you are running this command from the project root."
)
if errors:
console.print(
"\n[bold red]Pre-flight check failed:[/bold red] "
"Your project is missing required files for deployment.\n"
)
for error in errors:
console.print(f"{error}", style="red")
console.print()
raise SystemExit(1)
def _standard_no_param_error_message(self) -> None:
"""
@@ -103,6 +67,7 @@ class DeployCommand(BaseCommand, PlusAPIMixin):
Args:
uuid (Optional[str]): The UUID of the crew to deploy.
"""
self._start_deployment_span = self._telemetry.start_deployment_span(uuid)
console.print("Starting deployment...", style="bold blue")
if uuid:
response = self.plus_api_client.deploy_by_uuid(uuid)
@@ -119,6 +84,9 @@ class DeployCommand(BaseCommand, PlusAPIMixin):
"""
Create a new crew deployment.
"""
self._create_crew_deployment_span = (
self._telemetry.create_crew_deployment_span()
)
console.print("Creating deployment...", style="bold blue")
env_vars = fetch_and_json_env_file()
@@ -268,6 +236,7 @@ class DeployCommand(BaseCommand, PlusAPIMixin):
uuid (Optional[str]): The UUID of the crew to get logs for.
log_type (str): The type of logs to retrieve (default: "deployment").
"""
self._get_crew_logs_span = self._telemetry.get_crew_logs_span(uuid, log_type)
console.print(f"Fetching {log_type} logs...", style="bold blue")
if uuid:
@@ -288,6 +257,7 @@ class DeployCommand(BaseCommand, PlusAPIMixin):
Args:
uuid (Optional[str]): The UUID of the crew to remove.
"""
self._remove_crew_span = self._telemetry.remove_crew_span(uuid)
console.print("Removing deployment...", style="bold blue")
if uuid:

View File

@@ -4,10 +4,10 @@ from typing import Any, cast
import httpx
from rich.console import Console
from crewai_cli.authentication.main import Oauth2Settings, ProviderFactory
from crewai_cli.command import BaseCommand
from crewai_cli.settings.main import SettingsCommand
from crewai_cli.version import get_crewai_version
from crewai.cli.authentication.main import Oauth2Settings, ProviderFactory
from crewai.cli.command import BaseCommand
from crewai.cli.settings.main import SettingsCommand
from crewai.cli.version import get_crewai_version
console = Console()

Some files were not shown because too many files have changed in this diff Show More