mirror of
https://github.com/crewAIInc/crewAI.git
synced 2026-03-22 19:58:14 +00:00
Compare commits
2 Commits
devin/1774
...
feat/cli-l
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
f5aa229f72 | ||
|
|
f13d307534 |
@@ -22,6 +22,7 @@ from crewai.cli.replay_from_task import replay_task_command
|
||||
from crewai.cli.reset_memories_command import reset_memories_command
|
||||
from crewai.cli.run_crew import run_crew
|
||||
from crewai.cli.settings.main import SettingsCommand
|
||||
from crewai.cli.shared.token_manager import TokenManager
|
||||
from crewai.cli.tools.main import ToolCommand
|
||||
from crewai.cli.train_crew import train_crew
|
||||
from crewai.cli.triggers.main import TriggersCommand
|
||||
@@ -34,7 +35,7 @@ from crewai.memory.storage.kickoff_task_outputs_storage import (
|
||||
|
||||
@click.group()
|
||||
@click.version_option(get_version("crewai"))
|
||||
def crewai():
|
||||
def crewai() -> None:
|
||||
"""Top-level command group for crewai."""
|
||||
|
||||
|
||||
@@ -45,7 +46,7 @@ def crewai():
|
||||
),
|
||||
)
|
||||
@click.argument("uv_args", nargs=-1, type=click.UNPROCESSED)
|
||||
def uv(uv_args):
|
||||
def uv(uv_args: tuple[str, ...]) -> None:
|
||||
"""A wrapper around uv commands that adds custom tool authentication through env vars."""
|
||||
env = os.environ.copy()
|
||||
try:
|
||||
@@ -83,7 +84,9 @@ def uv(uv_args):
|
||||
@click.argument("name")
|
||||
@click.option("--provider", type=str, help="The provider to use for the crew")
|
||||
@click.option("--skip_provider", is_flag=True, help="Skip provider validation")
|
||||
def create(type, name, provider, skip_provider=False):
|
||||
def create(
|
||||
type: str, name: str, provider: str | None, skip_provider: bool = False
|
||||
) -> None:
|
||||
"""Create a new crew, or flow."""
|
||||
if type == "crew":
|
||||
create_crew(name, provider, skip_provider)
|
||||
@@ -97,7 +100,7 @@ def create(type, name, provider, skip_provider=False):
|
||||
@click.option(
|
||||
"--tools", is_flag=True, help="Show the installed version of crewai tools"
|
||||
)
|
||||
def version(tools):
|
||||
def version(tools: bool) -> None:
|
||||
"""Show the installed version of crewai."""
|
||||
try:
|
||||
crewai_version = get_version("crewai")
|
||||
@@ -128,7 +131,7 @@ def version(tools):
|
||||
default="trained_agents_data.pkl",
|
||||
help="Path to a custom file for training",
|
||||
)
|
||||
def train(n_iterations: int, filename: str):
|
||||
def train(n_iterations: int, filename: str) -> None:
|
||||
"""Train the crew."""
|
||||
click.echo(f"Training the Crew for {n_iterations} iterations")
|
||||
train_crew(n_iterations, filename)
|
||||
@@ -334,7 +337,7 @@ def memory(
|
||||
default="gpt-4o-mini",
|
||||
help="LLM Model to run the tests on the Crew. For now only accepting only OpenAI models.",
|
||||
)
|
||||
def test(n_iterations: int, model: str):
|
||||
def test(n_iterations: int, model: str) -> None:
|
||||
"""Test the crew and evaluate the results."""
|
||||
click.echo(f"Testing the crew for {n_iterations} iterations with model {model}")
|
||||
evaluate_crew(n_iterations, model)
|
||||
@@ -347,46 +350,62 @@ def test(n_iterations: int, model: str):
|
||||
)
|
||||
)
|
||||
@click.pass_context
|
||||
def install(context):
|
||||
def install(context: click.Context) -> None:
|
||||
"""Install the Crew."""
|
||||
install_crew(context.args)
|
||||
|
||||
|
||||
@crewai.command()
|
||||
def run():
|
||||
def run() -> None:
|
||||
"""Run the Crew."""
|
||||
run_crew()
|
||||
|
||||
|
||||
@crewai.command()
|
||||
def update():
|
||||
def update() -> None:
|
||||
"""Update the pyproject.toml of the Crew project to use uv."""
|
||||
update_crew()
|
||||
|
||||
|
||||
@crewai.command()
|
||||
def login():
|
||||
def login() -> None:
|
||||
"""Sign Up/Login to CrewAI AMP."""
|
||||
Settings().clear_user_settings()
|
||||
AuthenticationCommand().login()
|
||||
|
||||
|
||||
@crewai.command()
|
||||
@click.option(
|
||||
"--reset", is_flag=True, help="Also reset all CLI configuration to defaults"
|
||||
)
|
||||
def logout(reset: bool) -> None:
|
||||
"""Logout from CrewAI AMP."""
|
||||
settings = Settings()
|
||||
if reset:
|
||||
settings.reset()
|
||||
click.echo("Successfully logged out and reset all CLI configuration.")
|
||||
else:
|
||||
TokenManager().clear_tokens()
|
||||
settings.clear_user_settings()
|
||||
click.echo("Successfully logged out from CrewAI AMP.")
|
||||
|
||||
|
||||
# DEPLOY CREWAI+ COMMANDS
|
||||
@crewai.group()
|
||||
def deploy():
|
||||
def deploy() -> None:
|
||||
"""Deploy the Crew CLI group."""
|
||||
|
||||
|
||||
@deploy.command(name="create")
|
||||
@click.option("-y", "--yes", is_flag=True, help="Skip the confirmation prompt")
|
||||
def deploy_create(yes: bool):
|
||||
def deploy_create(yes: bool) -> None:
|
||||
"""Create a Crew deployment."""
|
||||
deploy_cmd = DeployCommand()
|
||||
deploy_cmd.create_crew(yes)
|
||||
|
||||
|
||||
@deploy.command(name="list")
|
||||
def deploy_list():
|
||||
def deploy_list() -> None:
|
||||
"""List all deployments."""
|
||||
deploy_cmd = DeployCommand()
|
||||
deploy_cmd.list_crews()
|
||||
@@ -394,7 +413,7 @@ def deploy_list():
|
||||
|
||||
@deploy.command(name="push")
|
||||
@click.option("-u", "--uuid", type=str, help="Crew UUID parameter")
|
||||
def deploy_push(uuid: str | None):
|
||||
def deploy_push(uuid: str | None) -> None:
|
||||
"""Deploy the Crew."""
|
||||
deploy_cmd = DeployCommand()
|
||||
deploy_cmd.deploy(uuid=uuid)
|
||||
@@ -402,7 +421,7 @@ def deploy_push(uuid: str | None):
|
||||
|
||||
@deploy.command(name="status")
|
||||
@click.option("-u", "--uuid", type=str, help="Crew UUID parameter")
|
||||
def deply_status(uuid: str | None):
|
||||
def deply_status(uuid: str | None) -> None:
|
||||
"""Get the status of a deployment."""
|
||||
deploy_cmd = DeployCommand()
|
||||
deploy_cmd.get_crew_status(uuid=uuid)
|
||||
@@ -410,7 +429,7 @@ def deply_status(uuid: str | None):
|
||||
|
||||
@deploy.command(name="logs")
|
||||
@click.option("-u", "--uuid", type=str, help="Crew UUID parameter")
|
||||
def deploy_logs(uuid: str | None):
|
||||
def deploy_logs(uuid: str | None) -> None:
|
||||
"""Get the logs of a deployment."""
|
||||
deploy_cmd = DeployCommand()
|
||||
deploy_cmd.get_crew_logs(uuid=uuid)
|
||||
@@ -418,27 +437,27 @@ def deploy_logs(uuid: str | None):
|
||||
|
||||
@deploy.command(name="remove")
|
||||
@click.option("-u", "--uuid", type=str, help="Crew UUID parameter")
|
||||
def deploy_remove(uuid: str | None):
|
||||
def deploy_remove(uuid: str | None) -> None:
|
||||
"""Remove a deployment."""
|
||||
deploy_cmd = DeployCommand()
|
||||
deploy_cmd.remove_crew(uuid=uuid)
|
||||
|
||||
|
||||
@crewai.group()
|
||||
def tool():
|
||||
def tool() -> None:
|
||||
"""Tool Repository related commands."""
|
||||
|
||||
|
||||
@tool.command(name="create")
|
||||
@click.argument("handle")
|
||||
def tool_create(handle: str):
|
||||
def tool_create(handle: str) -> None:
|
||||
tool_cmd = ToolCommand()
|
||||
tool_cmd.create(handle)
|
||||
|
||||
|
||||
@tool.command(name="install")
|
||||
@click.argument("handle")
|
||||
def tool_install(handle: str):
|
||||
def tool_install(handle: str) -> None:
|
||||
tool_cmd = ToolCommand()
|
||||
tool_cmd.login()
|
||||
tool_cmd.install(handle)
|
||||
@@ -454,26 +473,26 @@ def tool_install(handle: str):
|
||||
)
|
||||
@click.option("--public", "is_public", flag_value=True, default=False)
|
||||
@click.option("--private", "is_public", flag_value=False)
|
||||
def tool_publish(is_public: bool, force: bool):
|
||||
def tool_publish(is_public: bool, force: bool) -> None:
|
||||
tool_cmd = ToolCommand()
|
||||
tool_cmd.login()
|
||||
tool_cmd.publish(is_public, force)
|
||||
|
||||
|
||||
@crewai.group()
|
||||
def flow():
|
||||
def flow() -> None:
|
||||
"""Flow related commands."""
|
||||
|
||||
|
||||
@flow.command(name="kickoff")
|
||||
def flow_run():
|
||||
def flow_run() -> None:
|
||||
"""Kickoff the Flow."""
|
||||
click.echo("Running the Flow")
|
||||
kickoff_flow()
|
||||
|
||||
|
||||
@flow.command(name="plot")
|
||||
def flow_plot():
|
||||
def flow_plot() -> None:
|
||||
"""Plot the Flow."""
|
||||
click.echo("Plotting the Flow")
|
||||
plot_flow()
|
||||
@@ -481,19 +500,19 @@ def flow_plot():
|
||||
|
||||
@flow.command(name="add-crew")
|
||||
@click.argument("crew_name")
|
||||
def flow_add_crew(crew_name):
|
||||
def flow_add_crew(crew_name: str) -> None:
|
||||
"""Add a crew to an existing flow."""
|
||||
click.echo(f"Adding crew {crew_name} to the flow")
|
||||
add_crew_to_flow(crew_name)
|
||||
|
||||
|
||||
@crewai.group()
|
||||
def triggers():
|
||||
def triggers() -> None:
|
||||
"""Trigger related commands. Use 'crewai triggers list' to see available triggers, or 'crewai triggers run app_slug/trigger_slug' to execute."""
|
||||
|
||||
|
||||
@triggers.command(name="list")
|
||||
def triggers_list():
|
||||
def triggers_list() -> None:
|
||||
"""List all available triggers from integrations."""
|
||||
triggers_cmd = TriggersCommand()
|
||||
triggers_cmd.list_triggers()
|
||||
@@ -501,14 +520,14 @@ def triggers_list():
|
||||
|
||||
@triggers.command(name="run")
|
||||
@click.argument("trigger_path")
|
||||
def triggers_run(trigger_path: str):
|
||||
def triggers_run(trigger_path: str) -> None:
|
||||
"""Execute crew with trigger payload. Format: app_slug/trigger_slug"""
|
||||
triggers_cmd = TriggersCommand()
|
||||
triggers_cmd.execute_with_trigger(trigger_path)
|
||||
|
||||
|
||||
@crewai.command()
|
||||
def chat():
|
||||
def chat() -> None:
|
||||
"""
|
||||
Start a conversation with the Crew, collecting user-supplied inputs,
|
||||
and using the Chat LLM to generate responses.
|
||||
@@ -521,12 +540,12 @@ def chat():
|
||||
|
||||
|
||||
@crewai.group(invoke_without_command=True)
|
||||
def org():
|
||||
def org() -> None:
|
||||
"""Organization management commands."""
|
||||
|
||||
|
||||
@org.command("list")
|
||||
def org_list():
|
||||
def org_list() -> None:
|
||||
"""List available organizations."""
|
||||
org_command = OrganizationCommand()
|
||||
org_command.list()
|
||||
@@ -534,39 +553,39 @@ def org_list():
|
||||
|
||||
@org.command()
|
||||
@click.argument("id")
|
||||
def switch(id):
|
||||
def switch(id: str) -> None:
|
||||
"""Switch to a specific organization."""
|
||||
org_command = OrganizationCommand()
|
||||
org_command.switch(id)
|
||||
|
||||
|
||||
@org.command()
|
||||
def current():
|
||||
def current() -> None:
|
||||
"""Show current organization when 'crewai org' is called without subcommands."""
|
||||
org_command = OrganizationCommand()
|
||||
org_command.current()
|
||||
|
||||
|
||||
@crewai.group()
|
||||
def enterprise():
|
||||
def enterprise() -> None:
|
||||
"""Enterprise Configuration commands."""
|
||||
|
||||
|
||||
@enterprise.command("configure")
|
||||
@click.argument("enterprise_url")
|
||||
def enterprise_configure(enterprise_url: str):
|
||||
def enterprise_configure(enterprise_url: str) -> None:
|
||||
"""Configure CrewAI AMP OAuth2 settings from the provided Enterprise URL."""
|
||||
enterprise_command = EnterpriseConfigureCommand()
|
||||
enterprise_command.configure(enterprise_url)
|
||||
|
||||
|
||||
@crewai.group()
|
||||
def config():
|
||||
def config() -> None:
|
||||
"""CLI Configuration commands."""
|
||||
|
||||
|
||||
@config.command("list")
|
||||
def config_list():
|
||||
def config_list() -> None:
|
||||
"""List all CLI configuration parameters."""
|
||||
config_command = SettingsCommand()
|
||||
config_command.list()
|
||||
@@ -575,26 +594,26 @@ def config_list():
|
||||
@config.command("set")
|
||||
@click.argument("key")
|
||||
@click.argument("value")
|
||||
def config_set(key: str, value: str):
|
||||
def config_set(key: str, value: str) -> None:
|
||||
"""Set a CLI configuration parameter."""
|
||||
config_command = SettingsCommand()
|
||||
config_command.set(key, value)
|
||||
|
||||
|
||||
@config.command("reset")
|
||||
def config_reset():
|
||||
def config_reset() -> None:
|
||||
"""Reset all CLI configuration parameters to default values."""
|
||||
config_command = SettingsCommand()
|
||||
config_command.reset_all_settings()
|
||||
|
||||
|
||||
@crewai.group()
|
||||
def env():
|
||||
def env() -> None:
|
||||
"""Environment variable commands."""
|
||||
|
||||
|
||||
@env.command("view")
|
||||
def env_view():
|
||||
def env_view() -> None:
|
||||
"""View tracing-related environment variables."""
|
||||
import os
|
||||
from pathlib import Path
|
||||
@@ -672,12 +691,12 @@ def env_view():
|
||||
|
||||
|
||||
@crewai.group()
|
||||
def traces():
|
||||
def traces() -> None:
|
||||
"""Trace collection management commands."""
|
||||
|
||||
|
||||
@traces.command("enable")
|
||||
def traces_enable():
|
||||
def traces_enable() -> None:
|
||||
"""Enable trace collection for crew/flow executions."""
|
||||
from rich.console import Console
|
||||
from rich.panel import Panel
|
||||
@@ -700,7 +719,7 @@ def traces_enable():
|
||||
|
||||
|
||||
@traces.command("disable")
|
||||
def traces_disable():
|
||||
def traces_disable() -> None:
|
||||
"""Disable trace collection for crew/flow executions."""
|
||||
from rich.console import Console
|
||||
from rich.panel import Panel
|
||||
@@ -723,7 +742,7 @@ def traces_disable():
|
||||
|
||||
|
||||
@traces.command("status")
|
||||
def traces_status():
|
||||
def traces_status() -> None:
|
||||
"""Show current trace collection status."""
|
||||
import os
|
||||
|
||||
|
||||
@@ -6,7 +6,7 @@ import click
|
||||
from crewai.telemetry import Telemetry
|
||||
|
||||
|
||||
def create_flow(name):
|
||||
def create_flow(name: str) -> None:
|
||||
"""Create a new flow."""
|
||||
folder_name = name.replace(" ", "_").replace("-", "_").lower()
|
||||
class_name = name.replace("_", " ").replace("-", " ").title().replace(" ", "")
|
||||
@@ -49,7 +49,7 @@ def create_flow(name):
|
||||
"poem_crew",
|
||||
]
|
||||
|
||||
def process_file(src_file, dst_file):
|
||||
def process_file(src_file: Path, dst_file: Path) -> None:
|
||||
if src_file.suffix in [".pyc", ".pyo", ".pyd"]:
|
||||
return
|
||||
|
||||
|
||||
@@ -15,7 +15,7 @@ class DeployCommand(BaseCommand, PlusAPIMixin):
|
||||
A class to handle deployment-related operations for CrewAI projects.
|
||||
"""
|
||||
|
||||
def __init__(self):
|
||||
def __init__(self) -> None:
|
||||
"""
|
||||
Initialize the DeployCommand with project name and API client.
|
||||
"""
|
||||
@@ -67,7 +67,7 @@ class DeployCommand(BaseCommand, PlusAPIMixin):
|
||||
Args:
|
||||
uuid (Optional[str]): The UUID of the crew to deploy.
|
||||
"""
|
||||
self._start_deployment_span = self._telemetry.start_deployment_span(uuid)
|
||||
self._telemetry.start_deployment_span(uuid)
|
||||
console.print("Starting deployment...", style="bold blue")
|
||||
if uuid:
|
||||
response = self.plus_api_client.deploy_by_uuid(uuid)
|
||||
@@ -84,9 +84,7 @@ class DeployCommand(BaseCommand, PlusAPIMixin):
|
||||
"""
|
||||
Create a new crew deployment.
|
||||
"""
|
||||
self._create_crew_deployment_span = (
|
||||
self._telemetry.create_crew_deployment_span()
|
||||
)
|
||||
self._telemetry.create_crew_deployment_span()
|
||||
console.print("Creating deployment...", style="bold blue")
|
||||
env_vars = fetch_and_json_env_file()
|
||||
|
||||
@@ -236,7 +234,7 @@ class DeployCommand(BaseCommand, PlusAPIMixin):
|
||||
uuid (Optional[str]): The UUID of the crew to get logs for.
|
||||
log_type (str): The type of logs to retrieve (default: "deployment").
|
||||
"""
|
||||
self._get_crew_logs_span = self._telemetry.get_crew_logs_span(uuid, log_type)
|
||||
self._telemetry.get_crew_logs_span(uuid, log_type)
|
||||
console.print(f"Fetching {log_type} logs...", style="bold blue")
|
||||
|
||||
if uuid:
|
||||
@@ -257,7 +255,7 @@ class DeployCommand(BaseCommand, PlusAPIMixin):
|
||||
Args:
|
||||
uuid (Optional[str]): The UUID of the crew to remove.
|
||||
"""
|
||||
self._remove_crew_span = self._telemetry.remove_crew_span(uuid)
|
||||
self._telemetry.remove_crew_span(uuid)
|
||||
console.print("Removing deployment...", style="bold blue")
|
||||
|
||||
if uuid:
|
||||
|
||||
@@ -16,7 +16,7 @@ class TriggersCommand(BaseCommand, PlusAPIMixin):
|
||||
A class to handle trigger-related operations for CrewAI projects.
|
||||
"""
|
||||
|
||||
def __init__(self):
|
||||
def __init__(self) -> None:
|
||||
BaseCommand.__init__(self)
|
||||
PlusAPIMixin.__init__(self, telemetry=self._telemetry)
|
||||
|
||||
|
||||
@@ -3153,19 +3153,12 @@ class Flow(Generic[T], metaclass=FlowMeta):
|
||||
else:
|
||||
logger.warning(message)
|
||||
|
||||
def plot(
|
||||
self,
|
||||
filename: str = "crewai_flow.html",
|
||||
show: bool = True,
|
||||
output_dir: str | None = None,
|
||||
) -> str:
|
||||
def plot(self, filename: str = "crewai_flow.html", show: bool = True) -> str:
|
||||
"""Create interactive HTML visualization of Flow structure.
|
||||
|
||||
Args:
|
||||
filename: Output HTML filename (default: "crewai_flow.html").
|
||||
show: Whether to open in browser (default: True).
|
||||
output_dir: Directory to save generated files. Defaults to the
|
||||
current working directory.
|
||||
|
||||
Returns:
|
||||
Absolute path to generated HTML file.
|
||||
@@ -3178,9 +3171,7 @@ class Flow(Generic[T], metaclass=FlowMeta):
|
||||
),
|
||||
)
|
||||
structure = build_flow_structure(self)
|
||||
return render_interactive(
|
||||
structure, filename=filename, show=show, output_dir=output_dir
|
||||
)
|
||||
return render_interactive(structure, filename=filename, show=show)
|
||||
|
||||
@staticmethod
|
||||
def _show_tracing_disabled_message() -> None:
|
||||
|
||||
@@ -2,6 +2,7 @@
|
||||
|
||||
import json
|
||||
from pathlib import Path
|
||||
import tempfile
|
||||
from typing import Any, ClassVar
|
||||
import webbrowser
|
||||
|
||||
@@ -204,24 +205,20 @@ def render_interactive(
|
||||
dag: FlowStructure,
|
||||
filename: str = "flow_dag.html",
|
||||
show: bool = True,
|
||||
output_dir: str | None = None,
|
||||
) -> str:
|
||||
"""Create interactive HTML visualization of Flow structure.
|
||||
|
||||
Generates three output files: HTML template, CSS stylesheet, and
|
||||
JavaScript. Files are saved to the specified output directory, or the
|
||||
current working directory when *output_dir* is ``None``. Optionally
|
||||
opens the visualization in the default browser.
|
||||
Generates three output files in a temporary directory: HTML template,
|
||||
CSS stylesheet, and JavaScript. Optionally opens the visualization in
|
||||
default browser.
|
||||
|
||||
Args:
|
||||
dag: FlowStructure to visualize.
|
||||
filename: Output HTML filename (basename only, no path).
|
||||
show: Whether to open in browser.
|
||||
output_dir: Directory to save generated files. Defaults to the
|
||||
current working directory (``os.getcwd()``).
|
||||
|
||||
Returns:
|
||||
Absolute path to generated HTML file.
|
||||
Absolute path to generated HTML file in temporary directory.
|
||||
"""
|
||||
node_positions = calculate_node_positions(dag)
|
||||
|
||||
@@ -406,13 +403,12 @@ def render_interactive(
|
||||
extensions=[CSSExtension, JSExtension],
|
||||
)
|
||||
|
||||
dest_dir = Path(output_dir) if output_dir else Path.cwd()
|
||||
dest_dir.mkdir(parents=True, exist_ok=True)
|
||||
output_path = dest_dir / Path(filename).name
|
||||
temp_dir = Path(tempfile.mkdtemp(prefix="crewai_flow_"))
|
||||
output_path = temp_dir / Path(filename).name
|
||||
css_filename = output_path.stem + "_style.css"
|
||||
css_output_path = dest_dir / css_filename
|
||||
css_output_path = temp_dir / css_filename
|
||||
js_filename = output_path.stem + "_script.js"
|
||||
js_output_path = dest_dir / js_filename
|
||||
js_output_path = temp_dir / js_filename
|
||||
|
||||
css_file = template_dir / "style.css"
|
||||
css_content = css_file.read_text(encoding="utf-8")
|
||||
|
||||
@@ -281,6 +281,7 @@ class BaseTool(BaseModel, ABC):
|
||||
result_as_answer=self.result_as_answer,
|
||||
max_usage_count=self.max_usage_count,
|
||||
current_usage_count=self.current_usage_count,
|
||||
cache_function=self.cache_function,
|
||||
)
|
||||
structured_tool._original_tool = self
|
||||
return structured_tool
|
||||
|
||||
@@ -58,6 +58,7 @@ class CrewStructuredTool:
|
||||
result_as_answer: bool = False,
|
||||
max_usage_count: int | None = None,
|
||||
current_usage_count: int = 0,
|
||||
cache_function: Callable[..., bool] | None = None,
|
||||
) -> None:
|
||||
"""Initialize the structured tool.
|
||||
|
||||
@@ -69,6 +70,7 @@ class CrewStructuredTool:
|
||||
result_as_answer: Whether to return the output directly
|
||||
max_usage_count: Maximum number of times this tool can be used. None means unlimited usage.
|
||||
current_usage_count: Current number of times this tool has been used.
|
||||
cache_function: Function to determine if the tool result should be cached.
|
||||
"""
|
||||
self.name = name
|
||||
self.description = description
|
||||
@@ -78,6 +80,7 @@ class CrewStructuredTool:
|
||||
self.result_as_answer = result_as_answer
|
||||
self.max_usage_count = max_usage_count
|
||||
self.current_usage_count = current_usage_count
|
||||
self.cache_function = cache_function
|
||||
self._original_tool: BaseTool | None = None
|
||||
|
||||
# Validate the function signature matches the schema
|
||||
@@ -86,7 +89,7 @@ class CrewStructuredTool:
|
||||
@classmethod
|
||||
def from_function(
|
||||
cls,
|
||||
func: Callable,
|
||||
func: Callable[..., Any],
|
||||
name: str | None = None,
|
||||
description: str | None = None,
|
||||
return_direct: bool = False,
|
||||
@@ -147,7 +150,7 @@ class CrewStructuredTool:
|
||||
@staticmethod
|
||||
def _create_schema_from_function(
|
||||
name: str,
|
||||
func: Callable,
|
||||
func: Callable[..., Any],
|
||||
) -> type[BaseModel]:
|
||||
"""Create a Pydantic schema from a function's signature.
|
||||
|
||||
@@ -182,7 +185,7 @@ class CrewStructuredTool:
|
||||
|
||||
# Create model
|
||||
schema_name = f"{name.title()}Schema"
|
||||
return create_model(schema_name, **fields) # type: ignore[call-overload]
|
||||
return create_model(schema_name, **fields) # type: ignore[call-overload, no-any-return]
|
||||
|
||||
def _validate_function_signature(self) -> None:
|
||||
"""Validate that the function signature matches the args schema."""
|
||||
@@ -210,7 +213,7 @@ class CrewStructuredTool:
|
||||
f"not found in args_schema"
|
||||
)
|
||||
|
||||
def _parse_args(self, raw_args: str | dict) -> dict:
|
||||
def _parse_args(self, raw_args: str | dict[str, Any]) -> dict[str, Any]:
|
||||
"""Parse and validate the input arguments against the schema.
|
||||
|
||||
Args:
|
||||
@@ -234,8 +237,8 @@ class CrewStructuredTool:
|
||||
|
||||
async def ainvoke(
|
||||
self,
|
||||
input: str | dict,
|
||||
config: dict | None = None,
|
||||
input: str | dict[str, Any],
|
||||
config: dict[str, Any] | None = None,
|
||||
**kwargs: Any,
|
||||
) -> Any:
|
||||
"""Asynchronously invoke the tool.
|
||||
@@ -269,7 +272,7 @@ class CrewStructuredTool:
|
||||
except Exception:
|
||||
raise
|
||||
|
||||
def _run(self, *args, **kwargs) -> Any:
|
||||
def _run(self, *args: Any, **kwargs: Any) -> Any:
|
||||
"""Legacy method for compatibility."""
|
||||
# Convert args/kwargs to our expected format
|
||||
input_dict = dict(zip(self.args_schema.model_fields.keys(), args, strict=False))
|
||||
@@ -277,7 +280,10 @@ class CrewStructuredTool:
|
||||
return self.invoke(input_dict)
|
||||
|
||||
def invoke(
|
||||
self, input: str | dict, config: dict | None = None, **kwargs: Any
|
||||
self,
|
||||
input: str | dict[str, Any],
|
||||
config: dict[str, Any] | None = None,
|
||||
**kwargs: Any,
|
||||
) -> Any:
|
||||
"""Main method for tool execution."""
|
||||
parsed_args = self._parse_args(input)
|
||||
@@ -313,9 +319,10 @@ class CrewStructuredTool:
|
||||
self._original_tool.current_usage_count = self.current_usage_count
|
||||
|
||||
@property
|
||||
def args(self) -> dict:
|
||||
def args(self) -> dict[str, Any]:
|
||||
"""Get the tool's input arguments schema."""
|
||||
return self.args_schema.model_json_schema()["properties"]
|
||||
schema: dict[str, Any] = self.args_schema.model_json_schema()["properties"]
|
||||
return schema
|
||||
|
||||
def __repr__(self) -> str:
|
||||
return f"CrewStructuredTool(name='{sanitize_tool_name(self.name)}', description='{self.description}')"
|
||||
|
||||
@@ -333,9 +333,9 @@ def test_visualization_plot_method():
|
||||
"""Test that flow.plot() method works."""
|
||||
flow = SimpleFlow()
|
||||
|
||||
with tempfile.TemporaryDirectory() as tmp_dir:
|
||||
html_file = flow.plot("test_plot.html", show=False, output_dir=tmp_dir)
|
||||
assert os.path.exists(html_file)
|
||||
html_file = flow.plot("test_plot.html", show=False)
|
||||
|
||||
assert os.path.exists(html_file)
|
||||
|
||||
|
||||
def test_router_paths_to_string_conditions():
|
||||
@@ -667,94 +667,4 @@ def test_no_warning_for_properly_typed_router(caplog):
|
||||
# No warnings should be logged
|
||||
warning_messages = [r.message for r in caplog.records if r.levelno >= logging.WARNING]
|
||||
assert not any("Could not determine return paths" in msg for msg in warning_messages)
|
||||
assert not any("Found listeners waiting for triggers" in msg for msg in warning_messages)
|
||||
|
||||
|
||||
def test_plot_saves_to_current_working_directory():
|
||||
"""Test that plot() saves the HTML file to the current working directory by default.
|
||||
|
||||
Regression test for https://github.com/crewAIInc/crewAI/issues/4991
|
||||
"""
|
||||
flow = SimpleFlow()
|
||||
|
||||
with tempfile.TemporaryDirectory() as tmp_dir:
|
||||
original_cwd = os.getcwd()
|
||||
try:
|
||||
os.chdir(tmp_dir)
|
||||
html_file = flow.plot("test_cwd_plot.html", show=False)
|
||||
|
||||
# The returned path must live inside the CWD, not a hidden temp dir
|
||||
assert Path(html_file).parent == Path(tmp_dir)
|
||||
assert os.path.exists(html_file)
|
||||
assert html_file == str(Path(tmp_dir) / "test_cwd_plot.html")
|
||||
finally:
|
||||
os.chdir(original_cwd)
|
||||
|
||||
|
||||
def test_plot_saves_to_explicit_output_dir():
|
||||
"""Test that plot() saves files to a user-specified output directory."""
|
||||
flow = SimpleFlow()
|
||||
|
||||
with tempfile.TemporaryDirectory() as output_dir:
|
||||
html_file = flow.plot(
|
||||
"custom_output.html", show=False, output_dir=output_dir
|
||||
)
|
||||
|
||||
assert Path(html_file).parent == Path(output_dir)
|
||||
assert os.path.exists(html_file)
|
||||
|
||||
# CSS and JS companion files should also be in the same directory
|
||||
html_path = Path(html_file)
|
||||
css_file = html_path.parent / f"{html_path.stem}_style.css"
|
||||
js_file = html_path.parent / f"{html_path.stem}_script.js"
|
||||
assert css_file.exists()
|
||||
assert js_file.exists()
|
||||
|
||||
|
||||
def test_render_interactive_saves_to_cwd_by_default():
|
||||
"""Test that render_interactive() writes to CWD when output_dir is None.
|
||||
|
||||
Regression test for https://github.com/crewAIInc/crewAI/issues/4991
|
||||
"""
|
||||
flow = SimpleFlow()
|
||||
structure = build_flow_structure(flow)
|
||||
|
||||
with tempfile.TemporaryDirectory() as tmp_dir:
|
||||
original_cwd = os.getcwd()
|
||||
try:
|
||||
os.chdir(tmp_dir)
|
||||
html_file = visualize_flow_structure(
|
||||
structure, "cwd_test.html", show=False
|
||||
)
|
||||
|
||||
assert Path(html_file).parent == Path(tmp_dir)
|
||||
assert os.path.exists(html_file)
|
||||
finally:
|
||||
os.chdir(original_cwd)
|
||||
|
||||
|
||||
def test_render_interactive_saves_to_specified_output_dir():
|
||||
"""Test that render_interactive() writes to the specified output_dir."""
|
||||
flow = SimpleFlow()
|
||||
structure = build_flow_structure(flow)
|
||||
|
||||
with tempfile.TemporaryDirectory() as output_dir:
|
||||
html_file = visualize_flow_structure(
|
||||
structure, "output_dir_test.html", show=False, output_dir=output_dir
|
||||
)
|
||||
|
||||
assert Path(html_file).parent == Path(output_dir)
|
||||
assert os.path.exists(html_file)
|
||||
|
||||
with open(html_file, "r", encoding="utf-8") as f:
|
||||
html_content = f.read()
|
||||
assert "<!DOCTYPE html>" in html_content
|
||||
|
||||
|
||||
def test_plot_returned_path_is_absolute():
|
||||
"""Test that the path returned by plot() is always absolute."""
|
||||
flow = SimpleFlow()
|
||||
|
||||
with tempfile.TemporaryDirectory() as tmp_dir:
|
||||
html_file = flow.plot("abs_path_test.html", show=False, output_dir=tmp_dir)
|
||||
assert os.path.isabs(html_file)
|
||||
assert not any("Found listeners waiting for triggers" in msg for msg in warning_messages)
|
||||
@@ -38,6 +38,44 @@ def test_initialization(basic_function, schema_class):
|
||||
assert tool.args_schema == schema_class
|
||||
|
||||
|
||||
def test_cache_function_passed_through(basic_function, schema_class):
|
||||
"""Test that cache_function is stored on CrewStructuredTool."""
|
||||
|
||||
def no_cache(_args: dict, _result: str) -> bool:
|
||||
return False
|
||||
|
||||
tool = CrewStructuredTool(
|
||||
name="test_tool",
|
||||
description="Test tool description",
|
||||
func=basic_function,
|
||||
args_schema=schema_class,
|
||||
cache_function=no_cache,
|
||||
)
|
||||
|
||||
assert tool.cache_function is no_cache
|
||||
|
||||
|
||||
def test_base_tool_passes_cache_function_to_structured_tool():
|
||||
"""Test that BaseTool.to_structured_tool propagates cache_function."""
|
||||
from crewai.tools import BaseTool
|
||||
|
||||
def no_cache(_args: dict, _result: str) -> bool:
|
||||
return False
|
||||
|
||||
class MyCacheTool(BaseTool):
|
||||
name: str = "cache_test"
|
||||
description: str = "tool for testing cache passthrough"
|
||||
|
||||
def _run(self, query: str = "") -> str:
|
||||
return "result"
|
||||
|
||||
my_tool = MyCacheTool()
|
||||
my_tool.cache_function = no_cache # type: ignore[assignment]
|
||||
structured = my_tool.to_structured_tool()
|
||||
|
||||
assert structured.cache_function is no_cache
|
||||
|
||||
|
||||
def test_from_function(basic_function):
|
||||
"""Test creating tool from function"""
|
||||
tool = CrewStructuredTool.from_function(
|
||||
|
||||
Reference in New Issue
Block a user