mirror of
https://github.com/crewAIInc/crewAI.git
synced 2026-03-22 19:58:14 +00:00
Compare commits
2 Commits
feat/cli-l
...
devin/1774
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
287ffe2f6d | ||
|
|
9a9cb48d09 |
@@ -1,3 +1,4 @@
|
||||
import re
|
||||
from typing import Any
|
||||
|
||||
from crewai.tools import BaseTool
|
||||
@@ -52,7 +53,18 @@ class NL2SQLTool(BaseTool):
|
||||
"SELECT table_name FROM information_schema.tables WHERE table_schema = 'public';"
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def _validate_identifier(value: str, name: str) -> str:
|
||||
"""Validate a SQL identifier to prevent SQL injection."""
|
||||
if not re.match(r"^[A-Za-z_][A-Za-z0-9_$]*$", value):
|
||||
raise ValueError(
|
||||
f"Invalid {name}: {value!r}. "
|
||||
f"Only alphanumeric characters, underscores, and dollar signs are allowed."
|
||||
)
|
||||
return value
|
||||
|
||||
def _fetch_all_available_columns(self, table_name: str):
|
||||
table_name = self._validate_identifier(table_name, "table_name")
|
||||
return self.execute_sql(
|
||||
f"SELECT column_name, data_type FROM information_schema.columns WHERE table_name = '{table_name}';" # noqa: S608
|
||||
)
|
||||
|
||||
@@ -3,6 +3,7 @@ from __future__ import annotations
|
||||
import asyncio
|
||||
from concurrent.futures import ThreadPoolExecutor
|
||||
import logging
|
||||
import re
|
||||
import threading
|
||||
from typing import TYPE_CHECKING, Any
|
||||
|
||||
@@ -71,8 +72,16 @@ class SnowflakeSearchToolInput(BaseModel):
|
||||
model_config = ConfigDict(protected_namespaces=())
|
||||
|
||||
query: str = Field(..., description="SQL query or semantic search query to execute")
|
||||
database: str | None = Field(None, description="Override default database")
|
||||
snowflake_schema: str | None = Field(None, description="Override default schema")
|
||||
database: str | None = Field(
|
||||
None,
|
||||
description="Override default database",
|
||||
pattern=r"^[A-Za-z_][A-Za-z0-9_$]*$",
|
||||
)
|
||||
snowflake_schema: str | None = Field(
|
||||
None,
|
||||
description="Override default schema",
|
||||
pattern=r"^[A-Za-z_][A-Za-z0-9_$]*$",
|
||||
)
|
||||
timeout: int | None = Field(300, description="Query timeout in seconds")
|
||||
|
||||
|
||||
@@ -247,6 +256,16 @@ class SnowflakeSearchTool(BaseTool):
|
||||
continue
|
||||
raise RuntimeError("Query failed after all retries")
|
||||
|
||||
@staticmethod
|
||||
def _validate_identifier(value: str, name: str) -> str:
|
||||
"""Validate and sanitize a Snowflake identifier to prevent SQL injection."""
|
||||
if not re.match(r"^[A-Za-z_][A-Za-z0-9_$]*$", value):
|
||||
raise ValueError(
|
||||
f"Invalid {name}: {value!r}. "
|
||||
f"Only alphanumeric characters, underscores, and dollar signs are allowed."
|
||||
)
|
||||
return value
|
||||
|
||||
async def _run(
|
||||
self,
|
||||
query: str,
|
||||
@@ -259,9 +278,11 @@ class SnowflakeSearchTool(BaseTool):
|
||||
try:
|
||||
# Override database/schema if provided
|
||||
if database:
|
||||
await self._execute_query(f"USE DATABASE {database}")
|
||||
database = self._validate_identifier(database, "database")
|
||||
await self._execute_query(f'USE DATABASE "{database}"')
|
||||
if snowflake_schema:
|
||||
await self._execute_query(f"USE SCHEMA {snowflake_schema}")
|
||||
snowflake_schema = self._validate_identifier(snowflake_schema, "schema")
|
||||
await self._execute_query(f'USE SCHEMA "{snowflake_schema}"')
|
||||
|
||||
return await self._execute_query(query, timeout)
|
||||
except Exception as e:
|
||||
|
||||
72
lib/crewai-tools/tests/tools/nl2sql_tool_test.py
Normal file
72
lib/crewai-tools/tests/tools/nl2sql_tool_test.py
Normal file
@@ -0,0 +1,72 @@
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
import pytest
|
||||
|
||||
from crewai_tools.tools.nl2sql.nl2sql_tool import NL2SQLTool
|
||||
|
||||
|
||||
class TestNL2SQLToolValidateIdentifier:
|
||||
"""Tests for SQL injection prevention via identifier validation."""
|
||||
|
||||
def test_valid_identifiers(self):
|
||||
assert NL2SQLTool._validate_identifier("users", "table_name") == "users"
|
||||
assert NL2SQLTool._validate_identifier("MY_TABLE", "table_name") == "MY_TABLE"
|
||||
assert NL2SQLTool._validate_identifier("table$1", "table_name") == "table$1"
|
||||
assert NL2SQLTool._validate_identifier("_private", "table_name") == "_private"
|
||||
|
||||
def test_rejects_sql_injection_with_semicolon(self):
|
||||
with pytest.raises(ValueError, match="Invalid table_name"):
|
||||
NL2SQLTool._validate_identifier("users; DROP TABLE users;--", "table_name")
|
||||
|
||||
def test_rejects_sql_injection_with_quotes(self):
|
||||
with pytest.raises(ValueError, match="Invalid table_name"):
|
||||
NL2SQLTool._validate_identifier("users'--", "table_name")
|
||||
|
||||
def test_rejects_sql_injection_with_spaces(self):
|
||||
with pytest.raises(ValueError, match="Invalid table_name"):
|
||||
NL2SQLTool._validate_identifier("users DROP TABLE", "table_name")
|
||||
|
||||
def test_rejects_leading_number(self):
|
||||
with pytest.raises(ValueError, match="Invalid table_name"):
|
||||
NL2SQLTool._validate_identifier("1table", "table_name")
|
||||
|
||||
def test_rejects_empty_string(self):
|
||||
with pytest.raises(ValueError, match="Invalid table_name"):
|
||||
NL2SQLTool._validate_identifier("", "table_name")
|
||||
|
||||
def test_rejects_parentheses(self):
|
||||
with pytest.raises(ValueError, match="Invalid table_name"):
|
||||
NL2SQLTool._validate_identifier("users()", "table_name")
|
||||
|
||||
def test_rejects_dash_comment(self):
|
||||
with pytest.raises(ValueError, match="Invalid table_name"):
|
||||
NL2SQLTool._validate_identifier("users--comment", "table_name")
|
||||
|
||||
|
||||
@patch("crewai_tools.tools.nl2sql.nl2sql_tool.SQLALCHEMY_AVAILABLE", True)
|
||||
class TestNL2SQLToolFetchColumns:
|
||||
"""Tests that _fetch_all_available_columns validates table names."""
|
||||
|
||||
def _make_tool(self):
|
||||
"""Create an NL2SQLTool instance bypassing model_post_init DB calls."""
|
||||
with patch.object(NL2SQLTool, "model_post_init"):
|
||||
tool = NL2SQLTool(
|
||||
db_uri="sqlite:///:memory:",
|
||||
name="NL2SQLTool",
|
||||
description="test",
|
||||
)
|
||||
return tool
|
||||
|
||||
def test_rejects_malicious_table_name(self):
|
||||
tool = self._make_tool()
|
||||
with pytest.raises(ValueError, match="Invalid table_name"):
|
||||
tool._fetch_all_available_columns("users'; DROP TABLE users;--")
|
||||
|
||||
def test_accepts_valid_table_name(self):
|
||||
tool = self._make_tool()
|
||||
with patch.object(NL2SQLTool, "execute_sql", return_value=[]) as mock_exec:
|
||||
result = tool._fetch_all_available_columns("valid_table")
|
||||
mock_exec.assert_called_once()
|
||||
call_sql = mock_exec.call_args[0][0]
|
||||
assert "valid_table" in call_sql
|
||||
assert result == []
|
||||
@@ -2,6 +2,9 @@ import asyncio
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
from crewai_tools import SnowflakeConfig, SnowflakeSearchTool
|
||||
from crewai_tools.tools.snowflake_search_tool.snowflake_search_tool import (
|
||||
SnowflakeSearchToolInput,
|
||||
)
|
||||
import pytest
|
||||
|
||||
|
||||
@@ -100,3 +103,136 @@ def test_config_validation():
|
||||
# Test missing authentication
|
||||
with pytest.raises(ValueError):
|
||||
SnowflakeConfig(account="test_account", user="test_user")
|
||||
|
||||
|
||||
# SQL Injection Prevention Tests
|
||||
class TestSnowflakeSearchToolInputValidation:
|
||||
"""Tests for SQL injection prevention via input schema validation."""
|
||||
|
||||
def test_valid_database_identifier(self):
|
||||
inp = SnowflakeSearchToolInput(query="SELECT 1", database="my_database")
|
||||
assert inp.database == "my_database"
|
||||
|
||||
def test_valid_schema_identifier(self):
|
||||
inp = SnowflakeSearchToolInput(query="SELECT 1", snowflake_schema="public")
|
||||
assert inp.snowflake_schema == "public"
|
||||
|
||||
def test_valid_identifier_with_dollar_sign(self):
|
||||
inp = SnowflakeSearchToolInput(query="SELECT 1", database="my$db")
|
||||
assert inp.database == "my$db"
|
||||
|
||||
def test_database_with_sql_injection_semicolon(self):
|
||||
with pytest.raises(ValueError):
|
||||
SnowflakeSearchToolInput(
|
||||
query="SELECT 1", database="test_db; DROP TABLE users; --"
|
||||
)
|
||||
|
||||
def test_schema_with_sql_injection_semicolon(self):
|
||||
with pytest.raises(ValueError):
|
||||
SnowflakeSearchToolInput(
|
||||
query="SELECT 1", snowflake_schema="public; DROP TABLE users; --"
|
||||
)
|
||||
|
||||
def test_database_with_sql_injection_spaces(self):
|
||||
with pytest.raises(ValueError):
|
||||
SnowflakeSearchToolInput(
|
||||
query="SELECT 1", database="test_db DROP TABLE"
|
||||
)
|
||||
|
||||
def test_schema_with_sql_injection_quotes(self):
|
||||
with pytest.raises(ValueError):
|
||||
SnowflakeSearchToolInput(
|
||||
query="SELECT 1", snowflake_schema="public\"--"
|
||||
)
|
||||
|
||||
def test_database_with_sql_injection_dash_comment(self):
|
||||
with pytest.raises(ValueError):
|
||||
SnowflakeSearchToolInput(
|
||||
query="SELECT 1", database="test--comment"
|
||||
)
|
||||
|
||||
def test_database_starting_with_number(self):
|
||||
with pytest.raises(ValueError):
|
||||
SnowflakeSearchToolInput(query="SELECT 1", database="1invalid")
|
||||
|
||||
def test_none_database_is_allowed(self):
|
||||
inp = SnowflakeSearchToolInput(query="SELECT 1", database=None)
|
||||
assert inp.database is None
|
||||
|
||||
def test_none_schema_is_allowed(self):
|
||||
inp = SnowflakeSearchToolInput(query="SELECT 1", snowflake_schema=None)
|
||||
assert inp.snowflake_schema is None
|
||||
|
||||
|
||||
class TestSnowflakeSearchToolValidateIdentifier:
|
||||
"""Tests for the _validate_identifier runtime check."""
|
||||
|
||||
def test_valid_identifiers(self):
|
||||
assert SnowflakeSearchTool._validate_identifier("my_db", "database") == "my_db"
|
||||
assert SnowflakeSearchTool._validate_identifier("PROD_DB", "database") == "PROD_DB"
|
||||
assert SnowflakeSearchTool._validate_identifier("schema$1", "schema") == "schema$1"
|
||||
assert SnowflakeSearchTool._validate_identifier("_private", "schema") == "_private"
|
||||
|
||||
def test_rejects_semicolons(self):
|
||||
with pytest.raises(ValueError, match="Invalid database"):
|
||||
SnowflakeSearchTool._validate_identifier("db; DROP TABLE users;--", "database")
|
||||
|
||||
def test_rejects_spaces(self):
|
||||
with pytest.raises(ValueError, match="Invalid schema"):
|
||||
SnowflakeSearchTool._validate_identifier("public schema", "schema")
|
||||
|
||||
def test_rejects_quotes(self):
|
||||
with pytest.raises(ValueError, match="Invalid database"):
|
||||
SnowflakeSearchTool._validate_identifier('db"--', "database")
|
||||
|
||||
def test_rejects_leading_number(self):
|
||||
with pytest.raises(ValueError, match="Invalid database"):
|
||||
SnowflakeSearchTool._validate_identifier("1db", "database")
|
||||
|
||||
def test_rejects_empty_string(self):
|
||||
with pytest.raises(ValueError, match="Invalid database"):
|
||||
SnowflakeSearchTool._validate_identifier("", "database")
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_run_uses_quoted_identifiers(snowflake_tool, mock_snowflake_connection):
|
||||
"""Verify that _run wraps database/schema in double quotes in the SQL."""
|
||||
with patch.object(snowflake_tool, "_create_connection") as mock_create_conn:
|
||||
mock_create_conn.return_value = mock_snowflake_connection
|
||||
|
||||
await snowflake_tool._run(
|
||||
query="SELECT 1",
|
||||
database="my_db",
|
||||
snowflake_schema="my_schema",
|
||||
)
|
||||
|
||||
calls = mock_snowflake_connection.cursor().execute.call_args_list
|
||||
sql_statements = [call[0][0] for call in calls]
|
||||
assert 'USE DATABASE "my_db"' in sql_statements
|
||||
assert 'USE SCHEMA "my_schema"' in sql_statements
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_run_rejects_malicious_database(snowflake_tool, mock_snowflake_connection):
|
||||
"""Verify that _run raises ValueError for SQL injection attempts in database."""
|
||||
with patch.object(snowflake_tool, "_create_connection") as mock_create_conn:
|
||||
mock_create_conn.return_value = mock_snowflake_connection
|
||||
|
||||
with pytest.raises(ValueError, match="Invalid database"):
|
||||
await snowflake_tool._run(
|
||||
query="SELECT 1",
|
||||
database="test_db; DROP TABLE users; --",
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_run_rejects_malicious_schema(snowflake_tool, mock_snowflake_connection):
|
||||
"""Verify that _run raises ValueError for SQL injection attempts in schema."""
|
||||
with patch.object(snowflake_tool, "_create_connection") as mock_create_conn:
|
||||
mock_create_conn.return_value = mock_snowflake_connection
|
||||
|
||||
with pytest.raises(ValueError, match="Invalid schema"):
|
||||
await snowflake_tool._run(
|
||||
query="SELECT 1",
|
||||
snowflake_schema="public; DROP TABLE users; --",
|
||||
)
|
||||
|
||||
@@ -21672,6 +21672,7 @@
|
||||
"database": {
|
||||
"anyOf": [
|
||||
{
|
||||
"pattern": "^[A-Za-z_][A-Za-z0-9_$]*$",
|
||||
"type": "string"
|
||||
},
|
||||
{
|
||||
@@ -21690,6 +21691,7 @@
|
||||
"snowflake_schema": {
|
||||
"anyOf": [
|
||||
{
|
||||
"pattern": "^[A-Za-z_][A-Za-z0-9_$]*$",
|
||||
"type": "string"
|
||||
},
|
||||
{
|
||||
|
||||
@@ -22,7 +22,6 @@ from crewai.cli.replay_from_task import replay_task_command
|
||||
from crewai.cli.reset_memories_command import reset_memories_command
|
||||
from crewai.cli.run_crew import run_crew
|
||||
from crewai.cli.settings.main import SettingsCommand
|
||||
from crewai.cli.shared.token_manager import TokenManager
|
||||
from crewai.cli.tools.main import ToolCommand
|
||||
from crewai.cli.train_crew import train_crew
|
||||
from crewai.cli.triggers.main import TriggersCommand
|
||||
@@ -35,7 +34,7 @@ from crewai.memory.storage.kickoff_task_outputs_storage import (
|
||||
|
||||
@click.group()
|
||||
@click.version_option(get_version("crewai"))
|
||||
def crewai() -> None:
|
||||
def crewai():
|
||||
"""Top-level command group for crewai."""
|
||||
|
||||
|
||||
@@ -46,7 +45,7 @@ def crewai() -> None:
|
||||
),
|
||||
)
|
||||
@click.argument("uv_args", nargs=-1, type=click.UNPROCESSED)
|
||||
def uv(uv_args: tuple[str, ...]) -> None:
|
||||
def uv(uv_args):
|
||||
"""A wrapper around uv commands that adds custom tool authentication through env vars."""
|
||||
env = os.environ.copy()
|
||||
try:
|
||||
@@ -84,9 +83,7 @@ def uv(uv_args: tuple[str, ...]) -> None:
|
||||
@click.argument("name")
|
||||
@click.option("--provider", type=str, help="The provider to use for the crew")
|
||||
@click.option("--skip_provider", is_flag=True, help="Skip provider validation")
|
||||
def create(
|
||||
type: str, name: str, provider: str | None, skip_provider: bool = False
|
||||
) -> None:
|
||||
def create(type, name, provider, skip_provider=False):
|
||||
"""Create a new crew, or flow."""
|
||||
if type == "crew":
|
||||
create_crew(name, provider, skip_provider)
|
||||
@@ -100,7 +97,7 @@ def create(
|
||||
@click.option(
|
||||
"--tools", is_flag=True, help="Show the installed version of crewai tools"
|
||||
)
|
||||
def version(tools: bool) -> None:
|
||||
def version(tools):
|
||||
"""Show the installed version of crewai."""
|
||||
try:
|
||||
crewai_version = get_version("crewai")
|
||||
@@ -131,7 +128,7 @@ def version(tools: bool) -> None:
|
||||
default="trained_agents_data.pkl",
|
||||
help="Path to a custom file for training",
|
||||
)
|
||||
def train(n_iterations: int, filename: str) -> None:
|
||||
def train(n_iterations: int, filename: str):
|
||||
"""Train the crew."""
|
||||
click.echo(f"Training the Crew for {n_iterations} iterations")
|
||||
train_crew(n_iterations, filename)
|
||||
@@ -337,7 +334,7 @@ def memory(
|
||||
default="gpt-4o-mini",
|
||||
help="LLM Model to run the tests on the Crew. For now only accepting only OpenAI models.",
|
||||
)
|
||||
def test(n_iterations: int, model: str) -> None:
|
||||
def test(n_iterations: int, model: str):
|
||||
"""Test the crew and evaluate the results."""
|
||||
click.echo(f"Testing the crew for {n_iterations} iterations with model {model}")
|
||||
evaluate_crew(n_iterations, model)
|
||||
@@ -350,62 +347,46 @@ def test(n_iterations: int, model: str) -> None:
|
||||
)
|
||||
)
|
||||
@click.pass_context
|
||||
def install(context: click.Context) -> None:
|
||||
def install(context):
|
||||
"""Install the Crew."""
|
||||
install_crew(context.args)
|
||||
|
||||
|
||||
@crewai.command()
|
||||
def run() -> None:
|
||||
def run():
|
||||
"""Run the Crew."""
|
||||
run_crew()
|
||||
|
||||
|
||||
@crewai.command()
|
||||
def update() -> None:
|
||||
def update():
|
||||
"""Update the pyproject.toml of the Crew project to use uv."""
|
||||
update_crew()
|
||||
|
||||
|
||||
@crewai.command()
|
||||
def login() -> None:
|
||||
def login():
|
||||
"""Sign Up/Login to CrewAI AMP."""
|
||||
Settings().clear_user_settings()
|
||||
AuthenticationCommand().login()
|
||||
|
||||
|
||||
@crewai.command()
|
||||
@click.option(
|
||||
"--reset", is_flag=True, help="Also reset all CLI configuration to defaults"
|
||||
)
|
||||
def logout(reset: bool) -> None:
|
||||
"""Logout from CrewAI AMP."""
|
||||
settings = Settings()
|
||||
if reset:
|
||||
settings.reset()
|
||||
click.echo("Successfully logged out and reset all CLI configuration.")
|
||||
else:
|
||||
TokenManager().clear_tokens()
|
||||
settings.clear_user_settings()
|
||||
click.echo("Successfully logged out from CrewAI AMP.")
|
||||
|
||||
|
||||
# DEPLOY CREWAI+ COMMANDS
|
||||
@crewai.group()
|
||||
def deploy() -> None:
|
||||
def deploy():
|
||||
"""Deploy the Crew CLI group."""
|
||||
|
||||
|
||||
@deploy.command(name="create")
|
||||
@click.option("-y", "--yes", is_flag=True, help="Skip the confirmation prompt")
|
||||
def deploy_create(yes: bool) -> None:
|
||||
def deploy_create(yes: bool):
|
||||
"""Create a Crew deployment."""
|
||||
deploy_cmd = DeployCommand()
|
||||
deploy_cmd.create_crew(yes)
|
||||
|
||||
|
||||
@deploy.command(name="list")
|
||||
def deploy_list() -> None:
|
||||
def deploy_list():
|
||||
"""List all deployments."""
|
||||
deploy_cmd = DeployCommand()
|
||||
deploy_cmd.list_crews()
|
||||
@@ -413,7 +394,7 @@ def deploy_list() -> None:
|
||||
|
||||
@deploy.command(name="push")
|
||||
@click.option("-u", "--uuid", type=str, help="Crew UUID parameter")
|
||||
def deploy_push(uuid: str | None) -> None:
|
||||
def deploy_push(uuid: str | None):
|
||||
"""Deploy the Crew."""
|
||||
deploy_cmd = DeployCommand()
|
||||
deploy_cmd.deploy(uuid=uuid)
|
||||
@@ -421,7 +402,7 @@ def deploy_push(uuid: str | None) -> None:
|
||||
|
||||
@deploy.command(name="status")
|
||||
@click.option("-u", "--uuid", type=str, help="Crew UUID parameter")
|
||||
def deply_status(uuid: str | None) -> None:
|
||||
def deply_status(uuid: str | None):
|
||||
"""Get the status of a deployment."""
|
||||
deploy_cmd = DeployCommand()
|
||||
deploy_cmd.get_crew_status(uuid=uuid)
|
||||
@@ -429,7 +410,7 @@ def deply_status(uuid: str | None) -> None:
|
||||
|
||||
@deploy.command(name="logs")
|
||||
@click.option("-u", "--uuid", type=str, help="Crew UUID parameter")
|
||||
def deploy_logs(uuid: str | None) -> None:
|
||||
def deploy_logs(uuid: str | None):
|
||||
"""Get the logs of a deployment."""
|
||||
deploy_cmd = DeployCommand()
|
||||
deploy_cmd.get_crew_logs(uuid=uuid)
|
||||
@@ -437,27 +418,27 @@ def deploy_logs(uuid: str | None) -> None:
|
||||
|
||||
@deploy.command(name="remove")
|
||||
@click.option("-u", "--uuid", type=str, help="Crew UUID parameter")
|
||||
def deploy_remove(uuid: str | None) -> None:
|
||||
def deploy_remove(uuid: str | None):
|
||||
"""Remove a deployment."""
|
||||
deploy_cmd = DeployCommand()
|
||||
deploy_cmd.remove_crew(uuid=uuid)
|
||||
|
||||
|
||||
@crewai.group()
|
||||
def tool() -> None:
|
||||
def tool():
|
||||
"""Tool Repository related commands."""
|
||||
|
||||
|
||||
@tool.command(name="create")
|
||||
@click.argument("handle")
|
||||
def tool_create(handle: str) -> None:
|
||||
def tool_create(handle: str):
|
||||
tool_cmd = ToolCommand()
|
||||
tool_cmd.create(handle)
|
||||
|
||||
|
||||
@tool.command(name="install")
|
||||
@click.argument("handle")
|
||||
def tool_install(handle: str) -> None:
|
||||
def tool_install(handle: str):
|
||||
tool_cmd = ToolCommand()
|
||||
tool_cmd.login()
|
||||
tool_cmd.install(handle)
|
||||
@@ -473,26 +454,26 @@ def tool_install(handle: str) -> None:
|
||||
)
|
||||
@click.option("--public", "is_public", flag_value=True, default=False)
|
||||
@click.option("--private", "is_public", flag_value=False)
|
||||
def tool_publish(is_public: bool, force: bool) -> None:
|
||||
def tool_publish(is_public: bool, force: bool):
|
||||
tool_cmd = ToolCommand()
|
||||
tool_cmd.login()
|
||||
tool_cmd.publish(is_public, force)
|
||||
|
||||
|
||||
@crewai.group()
|
||||
def flow() -> None:
|
||||
def flow():
|
||||
"""Flow related commands."""
|
||||
|
||||
|
||||
@flow.command(name="kickoff")
|
||||
def flow_run() -> None:
|
||||
def flow_run():
|
||||
"""Kickoff the Flow."""
|
||||
click.echo("Running the Flow")
|
||||
kickoff_flow()
|
||||
|
||||
|
||||
@flow.command(name="plot")
|
||||
def flow_plot() -> None:
|
||||
def flow_plot():
|
||||
"""Plot the Flow."""
|
||||
click.echo("Plotting the Flow")
|
||||
plot_flow()
|
||||
@@ -500,19 +481,19 @@ def flow_plot() -> None:
|
||||
|
||||
@flow.command(name="add-crew")
|
||||
@click.argument("crew_name")
|
||||
def flow_add_crew(crew_name: str) -> None:
|
||||
def flow_add_crew(crew_name):
|
||||
"""Add a crew to an existing flow."""
|
||||
click.echo(f"Adding crew {crew_name} to the flow")
|
||||
add_crew_to_flow(crew_name)
|
||||
|
||||
|
||||
@crewai.group()
|
||||
def triggers() -> None:
|
||||
def triggers():
|
||||
"""Trigger related commands. Use 'crewai triggers list' to see available triggers, or 'crewai triggers run app_slug/trigger_slug' to execute."""
|
||||
|
||||
|
||||
@triggers.command(name="list")
|
||||
def triggers_list() -> None:
|
||||
def triggers_list():
|
||||
"""List all available triggers from integrations."""
|
||||
triggers_cmd = TriggersCommand()
|
||||
triggers_cmd.list_triggers()
|
||||
@@ -520,14 +501,14 @@ def triggers_list() -> None:
|
||||
|
||||
@triggers.command(name="run")
|
||||
@click.argument("trigger_path")
|
||||
def triggers_run(trigger_path: str) -> None:
|
||||
def triggers_run(trigger_path: str):
|
||||
"""Execute crew with trigger payload. Format: app_slug/trigger_slug"""
|
||||
triggers_cmd = TriggersCommand()
|
||||
triggers_cmd.execute_with_trigger(trigger_path)
|
||||
|
||||
|
||||
@crewai.command()
|
||||
def chat() -> None:
|
||||
def chat():
|
||||
"""
|
||||
Start a conversation with the Crew, collecting user-supplied inputs,
|
||||
and using the Chat LLM to generate responses.
|
||||
@@ -540,12 +521,12 @@ def chat() -> None:
|
||||
|
||||
|
||||
@crewai.group(invoke_without_command=True)
|
||||
def org() -> None:
|
||||
def org():
|
||||
"""Organization management commands."""
|
||||
|
||||
|
||||
@org.command("list")
|
||||
def org_list() -> None:
|
||||
def org_list():
|
||||
"""List available organizations."""
|
||||
org_command = OrganizationCommand()
|
||||
org_command.list()
|
||||
@@ -553,39 +534,39 @@ def org_list() -> None:
|
||||
|
||||
@org.command()
|
||||
@click.argument("id")
|
||||
def switch(id: str) -> None:
|
||||
def switch(id):
|
||||
"""Switch to a specific organization."""
|
||||
org_command = OrganizationCommand()
|
||||
org_command.switch(id)
|
||||
|
||||
|
||||
@org.command()
|
||||
def current() -> None:
|
||||
def current():
|
||||
"""Show current organization when 'crewai org' is called without subcommands."""
|
||||
org_command = OrganizationCommand()
|
||||
org_command.current()
|
||||
|
||||
|
||||
@crewai.group()
|
||||
def enterprise() -> None:
|
||||
def enterprise():
|
||||
"""Enterprise Configuration commands."""
|
||||
|
||||
|
||||
@enterprise.command("configure")
|
||||
@click.argument("enterprise_url")
|
||||
def enterprise_configure(enterprise_url: str) -> None:
|
||||
def enterprise_configure(enterprise_url: str):
|
||||
"""Configure CrewAI AMP OAuth2 settings from the provided Enterprise URL."""
|
||||
enterprise_command = EnterpriseConfigureCommand()
|
||||
enterprise_command.configure(enterprise_url)
|
||||
|
||||
|
||||
@crewai.group()
|
||||
def config() -> None:
|
||||
def config():
|
||||
"""CLI Configuration commands."""
|
||||
|
||||
|
||||
@config.command("list")
|
||||
def config_list() -> None:
|
||||
def config_list():
|
||||
"""List all CLI configuration parameters."""
|
||||
config_command = SettingsCommand()
|
||||
config_command.list()
|
||||
@@ -594,26 +575,26 @@ def config_list() -> None:
|
||||
@config.command("set")
|
||||
@click.argument("key")
|
||||
@click.argument("value")
|
||||
def config_set(key: str, value: str) -> None:
|
||||
def config_set(key: str, value: str):
|
||||
"""Set a CLI configuration parameter."""
|
||||
config_command = SettingsCommand()
|
||||
config_command.set(key, value)
|
||||
|
||||
|
||||
@config.command("reset")
|
||||
def config_reset() -> None:
|
||||
def config_reset():
|
||||
"""Reset all CLI configuration parameters to default values."""
|
||||
config_command = SettingsCommand()
|
||||
config_command.reset_all_settings()
|
||||
|
||||
|
||||
@crewai.group()
|
||||
def env() -> None:
|
||||
def env():
|
||||
"""Environment variable commands."""
|
||||
|
||||
|
||||
@env.command("view")
|
||||
def env_view() -> None:
|
||||
def env_view():
|
||||
"""View tracing-related environment variables."""
|
||||
import os
|
||||
from pathlib import Path
|
||||
@@ -691,12 +672,12 @@ def env_view() -> None:
|
||||
|
||||
|
||||
@crewai.group()
|
||||
def traces() -> None:
|
||||
def traces():
|
||||
"""Trace collection management commands."""
|
||||
|
||||
|
||||
@traces.command("enable")
|
||||
def traces_enable() -> None:
|
||||
def traces_enable():
|
||||
"""Enable trace collection for crew/flow executions."""
|
||||
from rich.console import Console
|
||||
from rich.panel import Panel
|
||||
@@ -719,7 +700,7 @@ def traces_enable() -> None:
|
||||
|
||||
|
||||
@traces.command("disable")
|
||||
def traces_disable() -> None:
|
||||
def traces_disable():
|
||||
"""Disable trace collection for crew/flow executions."""
|
||||
from rich.console import Console
|
||||
from rich.panel import Panel
|
||||
@@ -742,7 +723,7 @@ def traces_disable() -> None:
|
||||
|
||||
|
||||
@traces.command("status")
|
||||
def traces_status() -> None:
|
||||
def traces_status():
|
||||
"""Show current trace collection status."""
|
||||
import os
|
||||
|
||||
|
||||
@@ -6,7 +6,7 @@ import click
|
||||
from crewai.telemetry import Telemetry
|
||||
|
||||
|
||||
def create_flow(name: str) -> None:
|
||||
def create_flow(name):
|
||||
"""Create a new flow."""
|
||||
folder_name = name.replace(" ", "_").replace("-", "_").lower()
|
||||
class_name = name.replace("_", " ").replace("-", " ").title().replace(" ", "")
|
||||
@@ -49,7 +49,7 @@ def create_flow(name: str) -> None:
|
||||
"poem_crew",
|
||||
]
|
||||
|
||||
def process_file(src_file: Path, dst_file: Path) -> None:
|
||||
def process_file(src_file, dst_file):
|
||||
if src_file.suffix in [".pyc", ".pyo", ".pyd"]:
|
||||
return
|
||||
|
||||
|
||||
@@ -15,7 +15,7 @@ class DeployCommand(BaseCommand, PlusAPIMixin):
|
||||
A class to handle deployment-related operations for CrewAI projects.
|
||||
"""
|
||||
|
||||
def __init__(self) -> None:
|
||||
def __init__(self):
|
||||
"""
|
||||
Initialize the DeployCommand with project name and API client.
|
||||
"""
|
||||
@@ -67,7 +67,7 @@ class DeployCommand(BaseCommand, PlusAPIMixin):
|
||||
Args:
|
||||
uuid (Optional[str]): The UUID of the crew to deploy.
|
||||
"""
|
||||
self._telemetry.start_deployment_span(uuid)
|
||||
self._start_deployment_span = self._telemetry.start_deployment_span(uuid)
|
||||
console.print("Starting deployment...", style="bold blue")
|
||||
if uuid:
|
||||
response = self.plus_api_client.deploy_by_uuid(uuid)
|
||||
@@ -84,7 +84,9 @@ class DeployCommand(BaseCommand, PlusAPIMixin):
|
||||
"""
|
||||
Create a new crew deployment.
|
||||
"""
|
||||
self._telemetry.create_crew_deployment_span()
|
||||
self._create_crew_deployment_span = (
|
||||
self._telemetry.create_crew_deployment_span()
|
||||
)
|
||||
console.print("Creating deployment...", style="bold blue")
|
||||
env_vars = fetch_and_json_env_file()
|
||||
|
||||
@@ -234,7 +236,7 @@ class DeployCommand(BaseCommand, PlusAPIMixin):
|
||||
uuid (Optional[str]): The UUID of the crew to get logs for.
|
||||
log_type (str): The type of logs to retrieve (default: "deployment").
|
||||
"""
|
||||
self._telemetry.get_crew_logs_span(uuid, log_type)
|
||||
self._get_crew_logs_span = self._telemetry.get_crew_logs_span(uuid, log_type)
|
||||
console.print(f"Fetching {log_type} logs...", style="bold blue")
|
||||
|
||||
if uuid:
|
||||
@@ -255,7 +257,7 @@ class DeployCommand(BaseCommand, PlusAPIMixin):
|
||||
Args:
|
||||
uuid (Optional[str]): The UUID of the crew to remove.
|
||||
"""
|
||||
self._telemetry.remove_crew_span(uuid)
|
||||
self._remove_crew_span = self._telemetry.remove_crew_span(uuid)
|
||||
console.print("Removing deployment...", style="bold blue")
|
||||
|
||||
if uuid:
|
||||
|
||||
@@ -16,7 +16,7 @@ class TriggersCommand(BaseCommand, PlusAPIMixin):
|
||||
A class to handle trigger-related operations for CrewAI projects.
|
||||
"""
|
||||
|
||||
def __init__(self) -> None:
|
||||
def __init__(self):
|
||||
BaseCommand.__init__(self)
|
||||
PlusAPIMixin.__init__(self, telemetry=self._telemetry)
|
||||
|
||||
|
||||
Reference in New Issue
Block a user