mirror of
https://github.com/crewAIInc/crewAI.git
synced 2026-01-08 15:48:29 +00:00
* feat: add capability to see and expose public Tool classes * feat: persist available Tools from repository on publish * ci: explicitly ignore templates from ruff check Ruff only applies --exclude to files it discovers itself. So we have to manually skip the same files excluded in `ruff.toml` * style: fix linter issues * refactor: renaming available_tools_classes by available_exports * feat: provide more context about exportable tools * feat: allow to install a Tool from pypi * test: fix tests * feat: add env_vars attribute to BaseTool * remove TODO: security check since we handle that on the enterprise side
453 lines
14 KiB
Python
453 lines
14 KiB
Python
import importlib.util
|
|
import os
|
|
import shutil
|
|
import sys
|
|
from functools import reduce
|
|
from inspect import getmro, isclass, isfunction, ismethod
|
|
from pathlib import Path
|
|
from typing import Any, Dict, List, get_type_hints
|
|
|
|
import click
|
|
import tomli
|
|
from rich.console import Console
|
|
|
|
from crewai.cli.constants import ENV_VARS
|
|
from crewai.crew import Crew
|
|
from crewai.flow import Flow
|
|
|
|
if sys.version_info >= (3, 11):
|
|
import tomllib
|
|
|
|
# Module-wide rich console; the helpers below use it to print styled
# warnings and error messages.
console = Console()
|
|
|
|
|
|
def copy_template(src, dst, name, class_name, folder_name):
    """Render the template at ``src`` and write the result to ``dst``.

    The template text is read in full and its ``{{name}}``, ``{{crew_name}}``
    and ``{{folder_name}}`` placeholders are replaced with ``name``,
    ``class_name`` and ``folder_name`` respectively. The rendered text is then
    written to ``dst`` (overwriting any existing file) and a green
    confirmation line is echoed.
    """
    substitutions = (
        ("{{name}}", name),
        ("{{crew_name}}", class_name),
        ("{{folder_name}}", folder_name),
    )

    with open(src, "r") as template_file:
        rendered = template_file.read()

    # Placeholders are substituted one after another, in the order listed.
    for placeholder, replacement in substitutions:
        rendered = rendered.replace(placeholder, replacement)

    with open(dst, "w") as output_file:
        output_file.write(rendered)

    click.secho(f" - Created {dst}", fg="green")
|
|
|
|
|
|
def read_toml(file_path: str = "pyproject.toml"):
    """Load the TOML file at ``file_path`` and return the parsed dictionary.

    The file is opened in binary mode, as ``tomli.load`` expects.
    """
    with open(file_path, "rb") as toml_file:
        return tomli.load(toml_file)
|
|
|
|
|
|
def parse_toml(content):
|
|
if sys.version_info >= (3, 11):
|
|
return tomllib.loads(content)
|
|
return tomli.loads(content)
|
|
|
|
|
|
def get_project_name(
    pyproject_path: str = "pyproject.toml", require: bool = False
) -> str | None:
    """Return the ``project.name`` entry of the given pyproject.toml.

    The lookup is delegated to ``_get_project_attribute``; ``require`` is
    forwarded unchanged to that helper.
    """
    name_keys = ["project", "name"]
    return _get_project_attribute(pyproject_path, name_keys, require=require)
|
|
|
|
|
|
def get_project_version(
    pyproject_path: str = "pyproject.toml", require: bool = False
) -> str | None:
    """Return the ``project.version`` entry of the given pyproject.toml.

    The lookup is delegated to ``_get_project_attribute``; ``require`` is
    forwarded unchanged to that helper.
    """
    version_keys = ["project", "version"]
    return _get_project_attribute(pyproject_path, version_keys, require=require)
|
|
|
|
|
|
def get_project_description(
    pyproject_path: str = "pyproject.toml", require: bool = False
) -> str | None:
    """Return the ``project.description`` entry of the given pyproject.toml.

    The lookup is delegated to ``_get_project_attribute``; ``require`` is
    forwarded unchanged to that helper.
    """
    description_keys = ["project", "description"]
    return _get_project_attribute(
        pyproject_path, description_keys, require=require
    )
|
|
|
|
|
|
def _get_project_attribute(
    pyproject_path: str, keys: List[str], require: bool
) -> Any | None:
    """Read a nested value from a crewAI project's pyproject.toml.

    Args:
        pyproject_path: Path of the pyproject.toml file to read.
        keys: Key path to follow inside the parsed document, e.g.
            ``["project", "name"]``.
        require: When true, a missing or empty value terminates the program.

    Returns:
        The value found at ``keys``, or ``None`` when the file cannot be
        read or parsed, does not list crewai in ``project.dependencies``,
        or does not contain the requested key. Read problems are reported
        with ``print`` rather than raised.

    Raises:
        SystemExit: If ``require`` is true and no truthy value was obtained.
    """
    attribute = None

    try:
        # TOML documents are UTF-8 by specification, so decode explicitly
        # instead of relying on the platform's locale encoding.
        with open(pyproject_path, "r", encoding="utf-8") as f:
            pyproject_content = parse_toml(f.read())

        dependencies = (
            _get_nested_value(pyproject_content, ["project", "dependencies"]) or []
        )
        if not any("crewai" in dep for dep in dependencies):
            # Caught by the generic handler below and reported as a read error.
            raise Exception("crewai is not in the dependencies.")

        attribute = _get_nested_value(pyproject_content, keys)
    except FileNotFoundError:
        print(f"Error: {pyproject_path} not found.")
    except KeyError:
        # Raised by _get_nested_value when a key in the path is absent.
        print(f"Error: {pyproject_path} is not a valid pyproject.toml file.")
    except tomllib.TOMLDecodeError if sys.version_info >= (3, 11) else Exception as e:  # type: ignore
        # tomllib only exists on 3.11+; older interpreters fall back to a
        # catch-all here.
        print(
            f"Error: {pyproject_path} is not a valid TOML file."
            if sys.version_info >= (3, 11)
            else f"Error reading the pyproject.toml file: {e}"
        )
    except Exception as e:
        print(f"Error reading the pyproject.toml file: {e}")

    if require and not attribute:
        console.print(
            f"Unable to read '{'.'.join(keys)}' in the pyproject.toml file. Please verify that the file exists and contains the specified attribute.",
            style="bold red",
        )
        raise SystemExit

    return attribute
|
|
|
|
|
|
def _get_nested_value(data: Dict[str, Any], keys: List[str]) -> Any:
|
|
return reduce(dict.__getitem__, keys, data)
|
|
|
|
|
|
def fetch_and_json_env_file(env_file_path: str = ".env") -> dict:
    """Parse a ``.env`` file into a ``{name: value}`` dictionary.

    Blank lines and lines whose first non-blank character is ``#`` are
    skipped. Every other line is split on the first ``=`` and both sides are
    stripped of surrounding whitespace.

    Returns:
        The parsed variables. If the file is missing, or any error occurs
        while reading or parsing it (including a non-comment line without
        ``=``), the problem is printed and an empty dict is returned.
    """
    try:
        with open(env_file_path, "r") as env_file:
            raw_lines = env_file.read().splitlines()

        parsed = {}
        for raw_line in raw_lines:
            entry = raw_line.strip()
            # Nothing to record for blank lines and comments.
            if not entry or entry.startswith("#"):
                continue
            name, raw_value = raw_line.split("=", 1)
            parsed[name.strip()] = raw_value.strip()

        return parsed

    except FileNotFoundError:
        print(f"Error: {env_file_path} not found.")
    except Exception as e:
        print(f"Error reading the .env file: {e}")

    return {}
|
|
|
|
|
|
def tree_copy(source, destination):
    """Copy every entry found directly under ``source`` into ``destination``.

    Sub-directories are copied recursively with ``shutil.copytree``; any
    other entry is copied with ``shutil.copy2``. ``destination`` itself is
    not created here.
    """
    for entry_name in os.listdir(source):
        src_path = os.path.join(source, entry_name)
        dst_path = os.path.join(destination, entry_name)
        # Pick the copy routine that matches the kind of entry.
        copier = shutil.copytree if os.path.isdir(src_path) else shutil.copy2
        copier(src_path, dst_path)
|
|
|
|
|
|
def tree_find_and_replace(directory, find, replace):
    """Replace ``find`` with ``replace`` throughout a directory tree.

    The substitution is applied to file contents, file names and directory
    names. The tree is walked bottom-up so that entries inside a directory
    are processed before the directory itself is renamed.

    Args:
        directory: Root of the tree to process.
        find: Text to search for.
        replace: Text substituted for every occurrence of ``find``.
    """
    for path, dirs, files in os.walk(os.path.abspath(directory), topdown=False):
        for filename in files:
            filepath = os.path.join(path, filename)

            with open(filepath, "r", encoding="utf-8", errors="ignore") as file:
                contents = file.read()

            # Only rewrite files that actually contain the target text.
            # Reading with errors="ignore" drops undecodable bytes, so
            # rewriting an unrelated binary file would corrupt it.
            if find in contents:
                # Write back with the same encoding used for reading rather
                # than the platform default.
                with open(filepath, "w", encoding="utf-8") as file:
                    file.write(contents.replace(find, replace))

            if find in filename:
                new_filename = filename.replace(find, replace)
                new_filepath = os.path.join(path, new_filename)
                os.rename(filepath, new_filepath)

        for dirname in dirs:
            if find in dirname:
                new_dirname = dirname.replace(find, replace)
                new_dirpath = os.path.join(path, new_dirname)
                old_dirpath = os.path.join(path, dirname)
                os.rename(old_dirpath, new_dirpath)
|
|
|
|
|
|
def load_env_vars(folder_path):
    """
    Reads the .env file located in the given folder, if there is one.

    Each line is stripped and split on its first "="; a line is kept only
    when both the name and the value are non-empty.

    Args:
    - folder_path (Path): The folder expected to contain the .env file.

    Returns:
    - dict: The variables that were read (empty when the file does not exist).
    """
    env_file_path = folder_path / ".env"
    if not env_file_path.exists():
        return {}

    loaded = {}
    with open(env_file_path, "r") as env_file:
        for raw_line in env_file:
            name, _, setting = raw_line.strip().partition("=")
            if not name or not setting:
                continue
            loaded[name] = setting
    return loaded
|
|
|
|
|
|
def update_env_vars(env_vars, provider, model):
    """
    Updates environment variables with the API key for the selected provider and model.

    The name of the API-key variable is taken from ENV_VARS when the provider
    is listed there; the user is asked for it only for unlisted providers.

    Args:
    - env_vars (dict): Environment variables dictionary (modified in place).
    - provider (str): Selected provider.
    - model (str): Selected model.

    Returns:
    - dict | None: The updated env_vars, or None if the user aborted the
      API key prompt.
    """
    # A prompt passed as the default of ENV_VARS.get() would run on every
    # call, so branch explicitly and only prompt for unlisted providers.
    if provider in ENV_VARS:
        api_key_var = ENV_VARS[provider][0]
    else:
        api_key_var = click.prompt(
            f"Enter the environment variable name for your {provider.capitalize()} API key",
            type=str,
        )

    if api_key_var not in env_vars:
        try:
            env_vars[api_key_var] = click.prompt(
                f"Enter your {provider.capitalize()} API key", type=str, hide_input=True
            )
        except click.exceptions.Abort:
            click.secho("Operation aborted by the user.", fg="red")
            return None
    else:
        click.secho(f"API key already exists for {provider.capitalize()}.", fg="yellow")

    env_vars["MODEL"] = model
    click.secho(f"Selected model: {model}", fg="green")
    return env_vars
|
|
|
|
|
|
def write_env_file(folder_path, env_vars):
    """
    Saves the given variables as KEY=value lines in the folder's .env file.

    Any existing .env file in the folder is overwritten.

    Args:
    - folder_path (Path): The folder in which the .env file is written.
    - env_vars (dict): The environment variables to store.
    """
    env_file_path = folder_path / ".env"
    with open(env_file_path, "w") as env_file:
        env_file.writelines(
            f"{key}={value}\n" for key, value in env_vars.items()
        )
|
|
|
|
|
|
def get_crews(crew_path: str = "crew.py", require: bool = False) -> list[Crew]:
    """Collect the Crew instances defined in the first matching crew file.

    The current directory tree is walked looking for a file named
    ``crew_path``. The first such file that can be loaded is executed as a
    module and every attribute of that module is passed to ``fetch_crews``.

    Args:
        crew_path: File name to look for while walking the tree.
        require: When true, finding no crew (or hitting an unexpected error)
            prints a message and terminates the program.

    Returns:
        The crew instances that were found (possibly empty when ``require``
        is false).

    Raises:
        SystemExit: If ``require`` is true and no crew instance was found or
            an unexpected error occurred.
    """
    crew_instances = []
    try:
        for root, _, files in os.walk("."):
            if crew_path in files:
                crew_os_path = os.path.join(root, crew_path)
                try:
                    spec = importlib.util.spec_from_file_location(
                        "crew_module", crew_os_path
                    )
                    if not spec or not spec.loader:
                        continue
                    module = importlib.util.module_from_spec(spec)
                    try:
                        # Register the module before executing it so code
                        # that looks itself up in sys.modules keeps working.
                        sys.modules[spec.name] = module
                        spec.loader.exec_module(module)

                        for attr_name in dir(module):
                            module_attr = getattr(module, attr_name)

                            try:
                                crew_instances.extend(fetch_crews(module_attr))
                            except Exception as e:
                                # One bad attribute must not hide the others.
                                print(f"Error processing attribute {attr_name}: {e}")
                                continue

                    except Exception as exec_error:
                        print(f"Error executing module: {exec_error}")
                        import traceback

                        print(f"Traceback: {traceback.format_exc()}")
                except (ImportError, AttributeError) as e:
                    if require:
                        console.print(
                            f"Error importing crew from {crew_path}: {str(e)}",
                            style="bold red",
                        )
                    continue

                # Only the first crew file that was processed is considered.
                break

        # Abort only when nothing was found; exiting unconditionally would
        # make require=True fail even after crews were loaded successfully.
        if require and not crew_instances:
            console.print("No valid Crew instance found in crew.py", style="bold red")
            raise SystemExit

    except Exception as e:
        if require:
            console.print(
                f"Unexpected error while loading crew: {str(e)}", style="bold red"
            )
            raise SystemExit
    return crew_instances
|
|
|
|
|
|
def get_crew_instance(module_attr) -> Crew | None:
    """Turn a module attribute into a ``Crew`` instance when possible.

    Three shapes are recognised, checked in this order:

    * a callable with a truthy ``is_crew_class`` attribute, which is called
      and whose ``crew()`` result is returned;
    * a function or method whose return annotation is ``Crew``, which is
      called with no arguments;
    * an object that already is a ``Crew``.

    Anything else yields ``None``.
    """
    if callable(module_attr) and getattr(module_attr, "is_crew_class", False):
        return module_attr().crew()

    is_plain_callable = ismethod(module_attr) or isfunction(module_attr)
    if is_plain_callable and get_type_hints(module_attr).get("return") is Crew:
        return module_attr()

    if isinstance(module_attr, Crew):
        return module_attr

    return None
|
|
|
|
|
|
def fetch_crews(module_attr) -> list[Crew]:
    """Gather every crew reachable from a single module attribute.

    The attribute itself is checked with ``get_crew_instance``. If it is
    also a ``Flow`` subclass, the class is instantiated and each attribute of
    that instance is checked the same way.
    """
    found: list[Crew] = []

    direct = get_crew_instance(module_attr)
    if direct:
        found.append(direct)

    # Only Flow subclasses are inspected any further.
    if not (isinstance(module_attr, type) and issubclass(module_attr, Flow)):
        return found

    flow = module_attr()
    for member_name in dir(flow):
        nested = get_crew_instance(getattr(flow, member_name))
        if nested:
            found.append(nested)
    return found
|
|
|
|
|
|
def is_valid_tool(obj):
    """Tell whether ``obj`` counts as an exportable tool.

    A class qualifies when any class in its MRO is named ``BaseTool``; a
    non-class object qualifies when it is an instance of ``Tool``.
    """
    from crewai.tools.base_tool import Tool

    if not isclass(obj):
        return isinstance(obj, Tool)

    try:
        # Bases are matched by name, not by identity.
        for base in getmro(obj):
            if base.__name__ == "BaseTool":
                return True
        return False
    except (TypeError, AttributeError):
        return False
|
|
|
|
|
|
def extract_available_exports(dir_path: str = "src"):
    """
    Collect the tools exported by every __init__.py file under ``dir_path``.

    Each __init__.py found recursively is handed to ``_load_tools_from_init``
    and the results are concatenated.

    Returns:
        list: The combined entries returned by ``_load_tools_from_init``.

    Raises:
        SystemExit: With exit code 1 when no tools are found or when an
            error occurs while collecting them.
    """
    try:
        exports = []
        for init_path in Path(dir_path).glob("**/__init__.py"):
            exports += _load_tools_from_init(init_path)

        if exports:
            return exports

        _print_no_tools_warning()
        raise SystemExit(1)

    except Exception as e:
        console.print(f"[red]Error: Could not extract tool classes: {str(e)}[/red]")
        console.print(
            "Please ensure your project contains valid tools (classes inheriting from BaseTool or functions with @tool decorator)."
        )
        raise SystemExit(1)
|
|
|
|
|
|
def _load_tools_from_init(init_file: Path) -> list[dict[str, Any]]:
|
|
"""
|
|
Load and validate tools from a given __init__.py file.
|
|
"""
|
|
spec = importlib.util.spec_from_file_location("temp_module", init_file)
|
|
|
|
if not spec or not spec.loader:
|
|
return []
|
|
|
|
module = importlib.util.module_from_spec(spec)
|
|
sys.modules["temp_module"] = module
|
|
|
|
try:
|
|
spec.loader.exec_module(module)
|
|
|
|
if not hasattr(module, "__all__"):
|
|
console.print(
|
|
f"[bold yellow]Warning: No __all__ defined in {init_file}[/bold yellow]"
|
|
)
|
|
raise SystemExit(1)
|
|
|
|
return [
|
|
{
|
|
"name": name,
|
|
}
|
|
for name in module.__all__
|
|
if hasattr(module, name) and is_valid_tool(getattr(module, name))
|
|
]
|
|
|
|
except Exception as e:
|
|
console.print(f"[red]Warning: Could not load {init_file}: {str(e)}[/red]")
|
|
raise SystemExit(1)
|
|
|
|
finally:
|
|
sys.modules.pop("temp_module", None)
|
|
|
|
|
|
def _print_no_tools_warning():
    """
    Print the "no valid tools" warning followed by usage guidance.

    Three messages are sent to the console in order: the warning headline,
    the requirement for __init__.py, and an example of a tool class and a
    decorated tool function.
    """
    headline = "\n[bold yellow]Warning: No valid tools were exposed in your __init__.py file![/bold yellow]"
    requirement = (
        "Your __init__.py file must contain all classes that inherit from [bold]BaseTool[/bold] "
        "or functions decorated with [bold]@tool[/bold]."
    )
    example = (
        "\nExample:\n[dim]# In your __init__.py file[/dim]\n"
        "[green]__all__ = ['YourTool', 'your_tool_function'][/green]\n\n"
        "[dim]# In your tool.py file[/dim]\n"
        "[green]from crewai.tools import BaseTool, tool\n\n"
        "# Tool class example\n"
        "class YourTool(BaseTool):\n"
        ' name = "your_tool"\n'
        ' description = "Your tool description"\n'
        " # ... rest of implementation\n\n"
        "# Decorated function example\n"
        "@tool\n"
        "def your_tool_function(text: str) -> str:\n"
        ' """Your tool description"""\n'
        " # ... implementation\n"
        " return result\n"
    )

    for message in (headline, requirement, example):
        console.print(message)
|