Compare commits

..

2 Commits

Author SHA1 Message Date
Devin AI
c06478afcc Fix lint and type-checker issues
- Add 'cast' import to fix mypy type compatibility error
- Remove unused imports to fix lint warnings
- Add assertions to reproduction script to use agent variables
- All custom tool patterns now work correctly

Co-Authored-By: João <joao@crewai.com>
2025-07-27 01:54:23 +00:00
Devin AI
b3be8a6588 Fix custom tool registration in CrewAI 0.150.0
- Update validate_tools method in BaseAgent to accept all documented tool patterns:
  * Function tools (with or without @tool decorator)
  * Dict-based tool definitions
  * BaseTool class inheritance
  * Direct function assignment
- Change tools field type annotation from List[BaseTool] to List[Any] to allow Pydantic validation
- Update parse_tools function to accept all BaseTool instances (not just CrewAITool)
- Add comprehensive tests covering all custom tool patterns from issue #3226
- Add reproduction script to verify all patterns work correctly

Fixes #3226

Co-Authored-By: João <joao@crewai.com>
2025-07-27 01:46:12 +00:00
29 changed files with 571 additions and 687 deletions

View File

@@ -48,7 +48,7 @@ Documentation = "https://docs.crewai.com"
Repository = "https://github.com/crewAIInc/crewAI"
[project.optional-dependencies]
tools = ["crewai-tools~=0.59.0"]
tools = ["crewai-tools~=0.58.0"]
embeddings = [
"tiktoken~=0.8.0"
]

144
reproduce_issue_3226.py Normal file
View File

@@ -0,0 +1,144 @@
#!/usr/bin/env python3
"""
Reproduction script for issue #3226: Cannot Register Custom Tools with Agents in CrewAI 0.150.0
This script tests all the failing patterns mentioned in the issue.
"""
import sys
import traceback
def test_function_tool():
"""Test 1: Function Tool with @tool decorator"""
print("=== Test 1: Function Tool with @tool decorator ===")
try:
from crewai.tools import tool
from crewai import Agent
@tool
def fetch_logs(query: str) -> str:
"""Fetch logs from New Relic based on query"""
return f"Logs for query: {query}"
agent = Agent(
role='CrashFetcher',
goal='Extract logs',
backstory='An agent that fetches logs',
tools=[fetch_logs],
allow_delegation=False
)
assert len(agent.tools) == 1, f"Expected 1 tool, got {len(agent.tools)}"
print("✅ Function tool with @tool decorator: SUCCESS")
return True
except Exception as e:
print(f"❌ Function tool with @tool decorator: FAILED - {e}")
traceback.print_exc()
return False
def test_dict_tool():
"""Test 2: Dict-based tool definition"""
print("\n=== Test 2: Dict-based tool definition ===")
try:
from crewai import Agent
def fetch_logs_func(query: str) -> str:
return f"Logs for query: {query}"
fetch_logs_dict = {
'name': 'fetch_logs',
'description': 'Fetch logs from New Relic',
'func': fetch_logs_func
}
agent = Agent(
role='CrashFetcher',
goal='Extract logs',
backstory='An agent that fetches logs',
tools=[fetch_logs_dict],
allow_delegation=False
)
assert len(agent.tools) == 1, f"Expected 1 tool, got {len(agent.tools)}"
print("✅ Dict-based tool: SUCCESS")
return True
except Exception as e:
print(f"❌ Dict-based tool: FAILED - {e}")
traceback.print_exc()
return False
def test_basetool_class():
"""Test 3: BaseTool class inheritance"""
print("\n=== Test 3: BaseTool class inheritance ===")
try:
from crewai.tools import BaseTool
from crewai import Agent
class FetchLogsTool(BaseTool):
name: str = "fetch_logs"
description: str = "Fetch logs from New Relic based on query"
def _run(self, query: str) -> str:
return f"Logs for query: {query}"
agent = Agent(
role='CrashFetcher',
goal='Extract logs',
backstory='An agent that fetches logs',
tools=[FetchLogsTool()],
allow_delegation=False
)
assert len(agent.tools) == 1, f"Expected 1 tool, got {len(agent.tools)}"
print("✅ BaseTool class inheritance: SUCCESS")
return True
except Exception as e:
print(f"❌ BaseTool class inheritance: FAILED - {e}")
traceback.print_exc()
return False
def test_direct_function():
"""Test 4: Direct function assignment"""
print("\n=== Test 4: Direct function assignment ===")
try:
from crewai import Agent
def fetch_logs(query: str) -> str:
"""Fetch logs from New Relic based on query"""
return f"Logs for query: {query}"
agent = Agent(
role='CrashFetcher',
goal='Extract logs',
backstory='An agent that fetches logs',
tools=[fetch_logs],
allow_delegation=False
)
assert len(agent.tools) == 1, f"Expected 1 tool, got {len(agent.tools)}"
print("✅ Direct function assignment: SUCCESS")
return True
except Exception as e:
print(f"❌ Direct function assignment: FAILED - {e}")
traceback.print_exc()
return False
def main():
"""Run all tests and report results"""
print("Testing custom tool registration patterns from issue #3226\n")
results = []
results.append(test_function_tool())
results.append(test_dict_tool())
results.append(test_basetool_class())
results.append(test_direct_function())
print("\n=== SUMMARY ===")
passed = sum(results)
total = len(results)
print(f"Tests passed: {passed}/{total}")
if passed == total:
print("🎉 All custom tool patterns are working!")
return 0
else:
print("💥 Some custom tool patterns are still broken")
return 1
if __name__ == "__main__":
sys.exit(main())

View File

@@ -54,7 +54,7 @@ def _track_install_async():
_track_install_async()
__version__ = "0.152.0"
__version__ = "0.150.0"
__all__ = [
"Agent",
"Crew",

View File

@@ -2,7 +2,7 @@ import uuid
from abc import ABC, abstractmethod
from copy import copy as shallow_copy
from hashlib import md5
from typing import Any, Callable, Dict, List, Optional, TypeVar
from typing import Any, Callable, Dict, List, Optional, TypeVar, cast
from pydantic import (
UUID4,
@@ -25,7 +25,6 @@ from crewai.security.security_config import SecurityConfig
from crewai.tools.base_tool import BaseTool, Tool
from crewai.utilities import I18N, Logger, RPMController
from crewai.utilities.config import process_config
from crewai.utilities.converter import Converter
from crewai.utilities.string_utils import interpolate_only
T = TypeVar("T", bound="BaseAgent")
@@ -108,7 +107,7 @@ class BaseAgent(ABC, BaseModel):
default=False,
description="Enable agent to delegate and ask questions among each other.",
)
tools: Optional[List[BaseTool]] = Field(
tools: Optional[List[Any]] = Field(
default_factory=list, description="Tools at agents' disposal"
)
max_iter: int = Field(
@@ -171,27 +170,48 @@ class BaseAgent(ABC, BaseModel):
def validate_tools(cls, tools: List[Any]) -> List[BaseTool]:
"""Validate and process the tools provided to the agent.
This method ensures that each tool is either an instance of BaseTool
or an object with 'name', 'func', and 'description' attributes. If the
tool meets these criteria, it is processed and added to the list of
tools. Otherwise, a ValueError is raised.
This method ensures that each tool is either an instance of BaseTool,
a function (with or without @tool decorator), a dict with tool definition,
or an object with 'name', 'func', and 'description' attributes.
"""
if not tools:
return []
from crewai.tools.base_tool import tool
processed_tools = []
required_attrs = ["name", "func", "description"]
for tool in tools:
if isinstance(tool, BaseTool):
processed_tools.append(tool)
elif all(hasattr(tool, attr) for attr in required_attrs):
# Tool has the required attributes, create a Tool instance
processed_tools.append(Tool.from_langchain(tool))
for tool_item in tools:
if isinstance(tool_item, BaseTool):
processed_tools.append(tool_item)
elif callable(tool_item):
if hasattr(tool_item, '__doc__') and tool_item.__doc__:
converted_tool = cast(BaseTool, tool(tool_item))
processed_tools.append(converted_tool)
else:
raise ValueError(
f"Function tool '{tool_item.__name__}' must have a docstring"
)
elif isinstance(tool_item, dict):
required_keys = ['name', 'description', 'func']
if all(key in tool_item for key in required_keys):
processed_tools.append(Tool(
name=tool_item['name'],
description=tool_item['description'],
func=tool_item['func']
))
else:
raise ValueError(
f"Dict tool must contain keys: {required_keys}. "
f"Got: {list(tool_item.keys())}"
)
elif hasattr(tool_item, 'name') and hasattr(tool_item, 'func') and hasattr(tool_item, 'description'):
processed_tools.append(Tool.from_langchain(tool_item))
else:
raise ValueError(
f"Invalid tool type: {type(tool)}. "
"Tool must be an instance of BaseTool or "
"an object with 'name', 'func', and 'description' attributes."
f"Invalid tool type: {type(tool_item)}. "
"Tool must be a BaseTool instance, function, dict with "
"'name'/'description'/'func' keys, or object with "
"'name'/'func'/'description' attributes."
)
return processed_tools

View File

@@ -3,7 +3,6 @@ from typing import Optional
import click
from crewai.cli.config import Settings
from crewai.cli.settings.main import SettingsCommand
from crewai.cli.add_crew_to_flow import add_crew_to_flow
from crewai.cli.create_crew import create_crew
from crewai.cli.create_flow import create_flow
@@ -228,7 +227,7 @@ def update():
@crewai.command()
def login():
"""Sign Up/Login to CrewAI Enterprise."""
Settings().clear_user_settings()
Settings().clear()
AuthenticationCommand().login()
@@ -370,8 +369,8 @@ def org():
pass
@org.command("list")
def org_list():
@org.command()
def list():
"""List available organizations."""
org_command = OrganizationCommand()
org_command.list()
@@ -392,34 +391,5 @@ def current():
org_command.current()
@crewai.group()
def config():
"""CLI Configuration commands."""
pass
@config.command("list")
def config_list():
"""List all CLI configuration parameters."""
config_command = SettingsCommand()
config_command.list()
@config.command("set")
@click.argument("key")
@click.argument("value")
def config_set(key: str, value: str):
"""Set a CLI configuration parameter."""
config_command = SettingsCommand()
config_command.set(key, value)
@config.command("reset")
def config_reset():
"""Reset all CLI configuration parameters to default values."""
config_command = SettingsCommand()
config_command.reset_all_settings()
if __name__ == "__main__":
crewai()

View File

@@ -4,47 +4,10 @@ from typing import Optional
from pydantic import BaseModel, Field
from crewai.cli.constants import DEFAULT_CREWAI_ENTERPRISE_URL
DEFAULT_CONFIG_PATH = Path.home() / ".config" / "crewai" / "settings.json"
# Settings that are related to the user's account
USER_SETTINGS_KEYS = [
"tool_repository_username",
"tool_repository_password",
"org_name",
"org_uuid",
]
# Settings that are related to the CLI
CLI_SETTINGS_KEYS = [
"enterprise_base_url",
]
# Default values for CLI settings
DEFAULT_CLI_SETTINGS = {
"enterprise_base_url": DEFAULT_CREWAI_ENTERPRISE_URL,
}
# Readonly settings - cannot be set by the user
READONLY_SETTINGS_KEYS = [
"org_name",
"org_uuid",
]
# Hidden settings - not displayed by the 'list' command and cannot be set by the user
HIDDEN_SETTINGS_KEYS = [
"config_path",
"tool_repository_username",
"tool_repository_password",
]
class Settings(BaseModel):
enterprise_base_url: Optional[str] = Field(
default=DEFAULT_CREWAI_ENTERPRISE_URL,
description="Base URL of the CrewAI Enterprise instance",
)
tool_repository_username: Optional[str] = Field(
None, description="Username for interacting with the Tool Repository"
)
@@ -57,7 +20,7 @@ class Settings(BaseModel):
org_uuid: Optional[str] = Field(
None, description="UUID of the currently active organization"
)
config_path: Path = Field(default=DEFAULT_CONFIG_PATH, frozen=True, exclude=True)
config_path: Path = Field(default=DEFAULT_CONFIG_PATH, exclude=True)
def __init__(self, config_path: Path = DEFAULT_CONFIG_PATH, **data):
"""Load Settings from config path"""
@@ -74,16 +37,9 @@ class Settings(BaseModel):
merged_data = {**file_data, **data}
super().__init__(config_path=config_path, **merged_data)
def clear_user_settings(self) -> None:
"""Clear all user settings"""
self._reset_user_settings()
self.dump()
def reset(self) -> None:
"""Reset all settings to default values"""
self._reset_user_settings()
self._reset_cli_settings()
self.dump()
def clear(self) -> None:
"""Clear all settings"""
self.config_path.unlink(missing_ok=True)
def dump(self) -> None:
"""Save current settings to settings.json"""
@@ -96,13 +52,3 @@ class Settings(BaseModel):
updated_data = {**existing_data, **self.model_dump(exclude_unset=True)}
with self.config_path.open("w") as f:
json.dump(updated_data, f, indent=4)
def _reset_user_settings(self) -> None:
"""Reset all user settings to default values"""
for key in USER_SETTINGS_KEYS:
setattr(self, key, None)
def _reset_cli_settings(self) -> None:
"""Reset all CLI settings to default values"""
for key in CLI_SETTINGS_KEYS:
setattr(self, key, DEFAULT_CLI_SETTINGS[key])

View File

@@ -1,5 +1,3 @@
DEFAULT_CREWAI_ENTERPRISE_URL = "https://app.crewai.com"
ENV_VARS = {
"openai": [
{
@@ -322,4 +320,5 @@ DEFAULT_LLM_MODEL = "gpt-4o-mini"
JSON_URL = "https://raw.githubusercontent.com/BerriAI/litellm/main/model_prices_and_context_window.json"
LITELLM_PARAMS = ["api_key", "api_base", "api_version"]

View File

@@ -1,3 +1,4 @@
from os import getenv
from typing import List, Optional
from urllib.parse import urljoin
@@ -5,7 +6,6 @@ import requests
from crewai.cli.config import Settings
from crewai.cli.version import get_crewai_version
from crewai.cli.constants import DEFAULT_CREWAI_ENTERPRISE_URL
class PlusAPI:
@@ -29,10 +29,7 @@ class PlusAPI:
settings = Settings()
if settings.org_uuid:
self.headers["X-Crewai-Organization-Id"] = settings.org_uuid
self.base_url = (
str(settings.enterprise_base_url) or DEFAULT_CREWAI_ENTERPRISE_URL
)
self.base_url = getenv("CREWAI_BASE_URL", "https://app.crewai.com")
def _make_request(self, method: str, endpoint: str, **kwargs) -> requests.Response:
url = urljoin(self.base_url, endpoint)
@@ -111,6 +108,7 @@ class PlusAPI:
def create_crew(self, payload) -> requests.Response:
return self._make_request("POST", self.CREWS_RESOURCE, json=payload)
def get_organizations(self) -> requests.Response:
return self._make_request("GET", self.ORGANIZATIONS_RESOURCE)

View File

@@ -1,67 +0,0 @@
from rich.console import Console
from rich.table import Table
from crewai.cli.command import BaseCommand
from crewai.cli.config import Settings, READONLY_SETTINGS_KEYS, HIDDEN_SETTINGS_KEYS
from typing import Any
console = Console()
class SettingsCommand(BaseCommand):
"""A class to handle CLI configuration commands."""
def __init__(self, settings_kwargs: dict[str, Any] = {}):
super().__init__()
self.settings = Settings(**settings_kwargs)
def list(self) -> None:
"""List all CLI configuration parameters."""
table = Table(title="CrewAI CLI Configuration")
table.add_column("Setting", style="cyan", no_wrap=True)
table.add_column("Value", style="green")
table.add_column("Description", style="yellow")
# Add all settings to the table
for field_name, field_info in Settings.model_fields.items():
if field_name in HIDDEN_SETTINGS_KEYS:
# Do not display hidden settings
continue
current_value = getattr(self.settings, field_name)
description = field_info.description or "No description available"
display_value = (
str(current_value) if current_value is not None else "Not set"
)
table.add_row(field_name, display_value, description)
console.print(table)
def set(self, key: str, value: str) -> None:
"""Set a CLI configuration parameter."""
readonly_settings = READONLY_SETTINGS_KEYS + HIDDEN_SETTINGS_KEYS
if not hasattr(self.settings, key) or key in readonly_settings:
console.print(
f"Error: Unknown or readonly configuration key '{key}'",
style="bold red",
)
console.print("Available keys:", style="yellow")
for field_name in Settings.model_fields.keys():
if field_name not in readonly_settings:
console.print(f" - {field_name}", style="yellow")
raise SystemExit(1)
setattr(self.settings, key, value)
self.settings.dump()
console.print(f"Successfully set '{key}' to '{value}'", style="bold green")
def reset_all_settings(self) -> None:
"""Reset all CLI configuration parameters to default values."""
self.settings.reset()
console.print(
"Successfully reset all configuration parameters to default values. It is recommended to run [bold yellow]'crewai login'[/bold yellow] to re-authenticate.",
style="bold green",
)

View File

@@ -5,7 +5,7 @@ description = "{{name}} using crewAI"
authors = [{ name = "Your Name", email = "you@example.com" }]
requires-python = ">=3.10,<3.14"
dependencies = [
"crewai[tools]>=0.152.0,<1.0.0"
"crewai[tools]>=0.150.0,<1.0.0"
]
[project.scripts]

View File

@@ -5,7 +5,7 @@ description = "{{name}} using crewAI"
authors = [{ name = "Your Name", email = "you@example.com" }]
requires-python = ">=3.10,<3.14"
dependencies = [
"crewai[tools]>=0.152.0,<1.0.0",
"crewai[tools]>=0.150.0,<1.0.0",
]
[project.scripts]

View File

@@ -5,7 +5,7 @@ description = "Power up your crews with {{folder_name}}"
readme = "README.md"
requires-python = ">=3.10,<3.14"
dependencies = [
"crewai[tools]>=0.152.0"
"crewai[tools]>=0.150.0"
]
[tool.crewai]

View File

@@ -436,7 +436,6 @@ class Flow(Generic[T], metaclass=FlowMeta):
_routers: Set[str] = set()
_router_paths: Dict[str, List[str]] = {}
initial_state: Union[Type[T], T, None] = None
name: Optional[str] = None
def __class_getitem__(cls: Type["Flow"], item: Type[T]) -> Type["Flow"]:
class _FlowGeneric(cls): # type: ignore
@@ -474,7 +473,7 @@ class Flow(Generic[T], metaclass=FlowMeta):
self,
FlowCreatedEvent(
type="flow_created",
flow_name=self.name or self.__class__.__name__,
flow_name=self.__class__.__name__,
),
)
@@ -770,7 +769,7 @@ class Flow(Generic[T], metaclass=FlowMeta):
self,
FlowStartedEvent(
type="flow_started",
flow_name=self.name or self.__class__.__name__,
flow_name=self.__class__.__name__,
inputs=inputs,
),
)
@@ -793,7 +792,7 @@ class Flow(Generic[T], metaclass=FlowMeta):
self,
FlowFinishedEvent(
type="flow_finished",
flow_name=self.name or self.__class__.__name__,
flow_name=self.__class__.__name__,
result=final_output,
),
)
@@ -835,7 +834,7 @@ class Flow(Generic[T], metaclass=FlowMeta):
MethodExecutionStartedEvent(
type="method_execution_started",
method_name=method_name,
flow_name=self.name or self.__class__.__name__,
flow_name=self.__class__.__name__,
params=dumped_params,
state=self._copy_state(),
),
@@ -857,7 +856,7 @@ class Flow(Generic[T], metaclass=FlowMeta):
MethodExecutionFinishedEvent(
type="method_execution_finished",
method_name=method_name,
flow_name=self.name or self.__class__.__name__,
flow_name=self.__class__.__name__,
state=self._copy_state(),
result=result,
),
@@ -870,7 +869,7 @@ class Flow(Generic[T], metaclass=FlowMeta):
MethodExecutionFailedEvent(
type="method_execution_failed",
method_name=method_name,
flow_name=self.name or self.__class__.__name__,
flow_name=self.__class__.__name__,
error=e,
),
)
@@ -1077,7 +1076,7 @@ class Flow(Generic[T], metaclass=FlowMeta):
self,
FlowPlotEvent(
type="flow_plot",
flow_name=self.name or self.__class__.__name__,
flow_name=self.__class__.__name__,
),
)
plot_flow(self, filename)

View File

@@ -1,8 +1,7 @@
import os
from typing import Any, Dict, List
from collections import defaultdict
from mem0 import Memory, MemoryClient
from crewai.utilities.chromadb import sanitize_collection_name
from crewai.memory.storage.interface import Storage
@@ -71,32 +70,26 @@ class Mem0Storage(Storage):
"""
Returns:
dict: A filter dictionary containing AND conditions for querying data.
- Includes user_id and agent_id if both are present.
- Includes user_id if only user_id is present.
- Includes agent_id if only agent_id is present.
- Includes user_id if memory_type is 'external'.
- Includes run_id if memory_type is 'short_term' and mem0_run_id is present.
"""
filter = defaultdict(list)
filter = {
"AND": []
}
# Add user_id condition if the memory type is external
if self.memory_type == "external":
filter["AND"].append({"user_id": self.config.get("user_id", "")})
# Add run_id condition if the memory type is short_term and a run ID is set
if self.memory_type == "short_term" and self.mem0_run_id:
filter["AND"].append({"run_id": self.mem0_run_id})
else:
user_id = self.config.get("user_id", "")
agent_id = self.config.get("agent_id", "")
if user_id and agent_id:
filter["OR"].append({"user_id": user_id})
filter["OR"].append({"agent_id": agent_id})
elif user_id:
filter["AND"].append({"user_id": user_id})
elif agent_id:
filter["AND"].append({"agent_id": agent_id})
return filter
def save(self, value: Any, metadata: Dict[str, Any]) -> None:
user_id = self.config.get("user_id", "")
assistant_message = [{"role" : "assistant","content" : value}]
assistant_message = [{"role" : "assistant","content" : value}]
base_metadata = {
"short_term": "short_term",
@@ -111,32 +104,31 @@ class Mem0Storage(Storage):
"infer": self.infer
}
# MemoryClient-specific overrides
if isinstance(self.memory, MemoryClient):
params["includes"] = self.includes
params["excludes"] = self.excludes
params["output_format"] = "v1.1"
params["version"] = "v2"
if self.memory_type == "short_term" and self.mem0_run_id:
params["run_id"] = self.mem0_run_id
if user_id:
if self.memory_type == "external":
params["user_id"] = user_id
if agent_id := self.config.get("agent_id", self._get_agent_name()):
params["agent_id"] = agent_id
if params:
# MemoryClient-specific overrides
if isinstance(self.memory, MemoryClient):
params["includes"] = self.includes
params["excludes"] = self.excludes
params["output_format"] = "v1.1"
params["version"]="v2"
self.memory.add(assistant_message, **params)
if self.memory_type == "short_term":
params["run_id"] = self.mem0_run_id
self.memory.add(assistant_message, **params)
def search(self,query: str,limit: int = 3,score_threshold: float = 0.35) -> List[Any]:
params = {
"query": query,
"limit": limit,
"query": query,
"limit": limit,
"version": "v2",
"output_format": "v1.1"
}
if user_id := self.config.get("user_id", ""):
params["user_id"] = user_id
@@ -146,7 +138,7 @@ class Mem0Storage(Storage):
"entities": {"type": "entity"},
"external": {"type": "external"},
}
if self.memory_type in memory_type_map:
params["metadata"] = memory_type_map[self.memory_type]
if self.memory_type == "short_term":
@@ -159,29 +151,11 @@ class Mem0Storage(Storage):
params['threshold'] = score_threshold
if isinstance(self.memory, Memory):
params.pop("metadata", None)
params.pop("version", None)
params.pop("output_format", None)
params.pop("run_id", None)
del params["metadata"], params["version"], params["run_id"], params['output_format']
results = self.memory.search(**params)
return [r for r in results["results"]]
def reset(self):
if self.memory:
self.memory.reset()
def _sanitize_role(self, role: str) -> str:
"""
Sanitizes agent roles to ensure valid directory names.
"""
return role.replace("\n", "").replace(" ", "_").replace("/", "_")
def _get_agent_name(self) -> str:
if not self.crew:
return ""
agents = self.crew.agents
agents = [self._sanitize_role(agent.role) for agent in agents]
agents = "_".join(agents)
return sanitize_collection_name(name=agents, max_collection_length=MAX_AGENT_ID_LENGTH_MEM0)

View File

@@ -38,14 +38,7 @@ class EmbeddingConfigurator:
f"Unsupported embedding provider: {provider}, supported providers: {list(self.embedding_functions.keys())}"
)
try:
embedding_function = self.embedding_functions[provider]
except ImportError as e:
missing_package = str(e).split()[-1]
raise ImportError(
f"{missing_package} is not installed. Please install it with: pip install {missing_package}"
)
embedding_function = self.embedding_functions[provider]
return (
embedding_function(config)
if provider == "custom"

View File

@@ -1,7 +1,8 @@
from .base_tool import BaseTool, tool, EnvVar
from .base_tool import BaseTool, tool, EnvVar, Tool
__all__ = [
"BaseTool",
"Tool",
"tool",
"EnvVar",
]
]

View File

@@ -11,7 +11,6 @@ from crewai.agents.parser import (
)
from crewai.llm import LLM
from crewai.llms.base_llm import BaseLLM
from crewai.tools import BaseTool as CrewAITool
from crewai.tools.base_tool import BaseTool
from crewai.tools.structured_tool import CrewStructuredTool
from crewai.tools.tool_types import ToolResult
@@ -30,10 +29,13 @@ def parse_tools(tools: List[BaseTool]) -> List[CrewStructuredTool]:
tools_list = []
for tool in tools:
if isinstance(tool, CrewAITool):
if isinstance(tool, BaseTool):
tools_list.append(tool.to_structured_tool())
else:
raise ValueError("Tool is not a CrewStructuredTool or BaseTool")
raise ValueError(
f"Tool must be a BaseTool instance, got {type(tool)}. "
"Ensure tools are properly validated before calling parse_tools."
)
return tools_list

View File

@@ -1,5 +1,6 @@
from datetime import datetime, timezone
from datetime import datetime
from typing import Any, Dict, Optional
from pydantic import BaseModel, Field
from crewai.utilities.serialization import to_serializable
@@ -8,7 +9,7 @@ from crewai.utilities.serialization import to_serializable
class BaseEvent(BaseModel):
"""Base class for all events"""
timestamp: datetime = Field(default_factory=lambda: datetime.now(timezone.utc))
timestamp: datetime = Field(default_factory=datetime.now)
type: str
source_fingerprint: Optional[str] = None # UUID string of the source entity
source_type: Optional[str] = None # "agent", "task", "crew", "memory", "entity_memory", "short_term_memory", "long_term_memory", "external_memory"

View File

@@ -4,12 +4,7 @@ import tempfile
import unittest
from pathlib import Path
from crewai.cli.config import (
Settings,
USER_SETTINGS_KEYS,
CLI_SETTINGS_KEYS,
DEFAULT_CLI_SETTINGS,
)
from crewai.cli.config import Settings
class TestSettings(unittest.TestCase):
@@ -57,30 +52,6 @@ class TestSettings(unittest.TestCase):
self.assertEqual(settings.tool_repository_username, "new_user")
self.assertEqual(settings.tool_repository_password, "file_pass")
def test_clear_user_settings(self):
user_settings = {key: f"value_for_{key}" for key in USER_SETTINGS_KEYS}
settings = Settings(config_path=self.config_path, **user_settings)
settings.clear_user_settings()
for key in user_settings.keys():
self.assertEqual(getattr(settings, key), None)
def test_reset_settings(self):
user_settings = {key: f"value_for_{key}" for key in USER_SETTINGS_KEYS}
cli_settings = {key: f"value_for_{key}" for key in CLI_SETTINGS_KEYS}
settings = Settings(
config_path=self.config_path, **user_settings, **cli_settings
)
settings.reset()
for key in user_settings.keys():
self.assertEqual(getattr(settings, key), None)
for key in cli_settings.keys():
self.assertEqual(getattr(settings, key), DEFAULT_CLI_SETTINGS[key])
def test_dump_new_settings(self):
settings = Settings(
config_path=self.config_path, tool_repository_username="user1"

View File

@@ -6,7 +6,7 @@ from click.testing import CliRunner
import requests
from crewai.cli.organization.main import OrganizationCommand
from crewai.cli.cli import org_list, switch, current
from crewai.cli.cli import list, switch, current
@pytest.fixture
@@ -16,44 +16,44 @@ def runner():
@pytest.fixture
def org_command():
with patch.object(OrganizationCommand, "__init__", return_value=None):
with patch.object(OrganizationCommand, '__init__', return_value=None):
command = OrganizationCommand()
yield command
@pytest.fixture
def mock_settings():
with patch("crewai.cli.organization.main.Settings") as mock_settings_class:
with patch('crewai.cli.organization.main.Settings') as mock_settings_class:
mock_settings_instance = MagicMock()
mock_settings_class.return_value = mock_settings_instance
yield mock_settings_instance
@patch("crewai.cli.cli.OrganizationCommand")
@patch('crewai.cli.cli.OrganizationCommand')
def test_org_list_command(mock_org_command_class, runner):
mock_org_instance = MagicMock()
mock_org_command_class.return_value = mock_org_instance
result = runner.invoke(org_list)
result = runner.invoke(list)
assert result.exit_code == 0
mock_org_command_class.assert_called_once()
mock_org_instance.list.assert_called_once()
@patch("crewai.cli.cli.OrganizationCommand")
@patch('crewai.cli.cli.OrganizationCommand')
def test_org_switch_command(mock_org_command_class, runner):
mock_org_instance = MagicMock()
mock_org_command_class.return_value = mock_org_instance
result = runner.invoke(switch, ["test-id"])
result = runner.invoke(switch, ['test-id'])
assert result.exit_code == 0
mock_org_command_class.assert_called_once()
mock_org_instance.switch.assert_called_once_with("test-id")
mock_org_instance.switch.assert_called_once_with('test-id')
@patch("crewai.cli.cli.OrganizationCommand")
@patch('crewai.cli.cli.OrganizationCommand')
def test_org_current_command(mock_org_command_class, runner):
mock_org_instance = MagicMock()
mock_org_command_class.return_value = mock_org_instance
@@ -67,18 +67,18 @@ def test_org_current_command(mock_org_command_class, runner):
class TestOrganizationCommand(unittest.TestCase):
def setUp(self):
with patch.object(OrganizationCommand, "__init__", return_value=None):
with patch.object(OrganizationCommand, '__init__', return_value=None):
self.org_command = OrganizationCommand()
self.org_command.plus_api_client = MagicMock()
@patch("crewai.cli.organization.main.console")
@patch("crewai.cli.organization.main.Table")
@patch('crewai.cli.organization.main.console')
@patch('crewai.cli.organization.main.Table')
def test_list_organizations_success(self, mock_table, mock_console):
mock_response = MagicMock()
mock_response.raise_for_status = MagicMock()
mock_response.json.return_value = [
{"name": "Org 1", "uuid": "org-123"},
{"name": "Org 2", "uuid": "org-456"},
{"name": "Org 2", "uuid": "org-456"}
]
self.org_command.plus_api_client = MagicMock()
self.org_command.plus_api_client.get_organizations.return_value = mock_response
@@ -89,14 +89,16 @@ class TestOrganizationCommand(unittest.TestCase):
self.org_command.plus_api_client.get_organizations.assert_called_once()
mock_table.assert_called_once_with(title="Your Organizations")
mock_table.return_value.add_column.assert_has_calls(
[call("Name", style="cyan"), call("ID", style="green")]
)
mock_table.return_value.add_row.assert_has_calls(
[call("Org 1", "org-123"), call("Org 2", "org-456")]
)
mock_table.return_value.add_column.assert_has_calls([
call("Name", style="cyan"),
call("ID", style="green")
])
mock_table.return_value.add_row.assert_has_calls([
call("Org 1", "org-123"),
call("Org 2", "org-456")
])
@patch("crewai.cli.organization.main.console")
@patch('crewai.cli.organization.main.console')
def test_list_organizations_empty(self, mock_console):
mock_response = MagicMock()
mock_response.raise_for_status = MagicMock()
@@ -108,32 +110,33 @@ class TestOrganizationCommand(unittest.TestCase):
self.org_command.plus_api_client.get_organizations.assert_called_once()
mock_console.print.assert_called_once_with(
"You don't belong to any organizations yet.", style="yellow"
"You don't belong to any organizations yet.",
style="yellow"
)
@patch("crewai.cli.organization.main.console")
@patch('crewai.cli.organization.main.console')
def test_list_organizations_api_error(self, mock_console):
self.org_command.plus_api_client = MagicMock()
self.org_command.plus_api_client.get_organizations.side_effect = (
requests.exceptions.RequestException("API Error")
)
self.org_command.plus_api_client.get_organizations.side_effect = requests.exceptions.RequestException("API Error")
with pytest.raises(SystemExit):
self.org_command.list()
self.org_command.plus_api_client.get_organizations.assert_called_once()
mock_console.print.assert_called_once_with(
"Failed to retrieve organization list: API Error", style="bold red"
"Failed to retrieve organization list: API Error",
style="bold red"
)
@patch("crewai.cli.organization.main.console")
@patch("crewai.cli.organization.main.Settings")
@patch('crewai.cli.organization.main.console')
@patch('crewai.cli.organization.main.Settings')
def test_switch_organization_success(self, mock_settings_class, mock_console):
mock_response = MagicMock()
mock_response.raise_for_status = MagicMock()
mock_response.json.return_value = [
{"name": "Org 1", "uuid": "org-123"},
{"name": "Test Org", "uuid": "test-id"},
{"name": "Test Org", "uuid": "test-id"}
]
self.org_command.plus_api_client = MagicMock()
self.org_command.plus_api_client.get_organizations.return_value = mock_response
@@ -148,16 +151,17 @@ class TestOrganizationCommand(unittest.TestCase):
assert mock_settings_instance.org_name == "Test Org"
assert mock_settings_instance.org_uuid == "test-id"
mock_console.print.assert_called_once_with(
"Successfully switched to Test Org (test-id)", style="bold green"
"Successfully switched to Test Org (test-id)",
style="bold green"
)
@patch("crewai.cli.organization.main.console")
@patch('crewai.cli.organization.main.console')
def test_switch_organization_not_found(self, mock_console):
mock_response = MagicMock()
mock_response.raise_for_status = MagicMock()
mock_response.json.return_value = [
{"name": "Org 1", "uuid": "org-123"},
{"name": "Org 2", "uuid": "org-456"},
{"name": "Org 2", "uuid": "org-456"}
]
self.org_command.plus_api_client = MagicMock()
self.org_command.plus_api_client.get_organizations.return_value = mock_response
@@ -166,11 +170,12 @@ class TestOrganizationCommand(unittest.TestCase):
self.org_command.plus_api_client.get_organizations.assert_called_once()
mock_console.print.assert_called_once_with(
"Organization with id 'non-existent-id' not found.", style="bold red"
"Organization with id 'non-existent-id' not found.",
style="bold red"
)
@patch("crewai.cli.organization.main.console")
@patch("crewai.cli.organization.main.Settings")
@patch('crewai.cli.organization.main.console')
@patch('crewai.cli.organization.main.Settings')
def test_current_organization_with_org(self, mock_settings_class, mock_console):
mock_settings_instance = MagicMock()
mock_settings_instance.org_name = "Test Org"
@@ -181,11 +186,12 @@ class TestOrganizationCommand(unittest.TestCase):
self.org_command.plus_api_client.get_organizations.assert_not_called()
mock_console.print.assert_called_once_with(
"Currently logged in to organization Test Org (test-id)", style="bold green"
"Currently logged in to organization Test Org (test-id)",
style="bold green"
)
@patch("crewai.cli.organization.main.console")
@patch("crewai.cli.organization.main.Settings")
@patch('crewai.cli.organization.main.console')
@patch('crewai.cli.organization.main.Settings')
def test_current_organization_without_org(self, mock_settings_class, mock_console):
mock_settings_instance = MagicMock()
mock_settings_instance.org_uuid = None
@@ -195,14 +201,16 @@ class TestOrganizationCommand(unittest.TestCase):
assert mock_console.print.call_count == 3
mock_console.print.assert_any_call(
"You're not currently logged in to any organization.", style="yellow"
"You're not currently logged in to any organization.",
style="yellow"
)
@patch("crewai.cli.organization.main.console")
@patch('crewai.cli.organization.main.console')
def test_list_organizations_unauthorized(self, mock_console):
mock_response = MagicMock()
mock_http_error = requests.exceptions.HTTPError(
"401 Client Error: Unauthorized", response=MagicMock(status_code=401)
"401 Client Error: Unauthorized",
response=MagicMock(status_code=401)
)
mock_response.raise_for_status.side_effect = mock_http_error
@@ -213,14 +221,15 @@ class TestOrganizationCommand(unittest.TestCase):
self.org_command.plus_api_client.get_organizations.assert_called_once()
mock_console.print.assert_called_once_with(
"You are not logged in to any organization. Use 'crewai login' to login.",
style="bold red",
style="bold red"
)
@patch("crewai.cli.organization.main.console")
@patch('crewai.cli.organization.main.console')
def test_switch_organization_unauthorized(self, mock_console):
mock_response = MagicMock()
mock_http_error = requests.exceptions.HTTPError(
"401 Client Error: Unauthorized", response=MagicMock(status_code=401)
"401 Client Error: Unauthorized",
response=MagicMock(status_code=401)
)
mock_response.raise_for_status.side_effect = mock_http_error
@@ -231,5 +240,5 @@ class TestOrganizationCommand(unittest.TestCase):
self.org_command.plus_api_client.get_organizations.assert_called_once()
mock_console.print.assert_called_once_with(
"You are not logged in to any organization. Use 'crewai login' to login.",
style="bold red",
style="bold red"
)

View File

@@ -1,8 +1,8 @@
import os
import unittest
from unittest.mock import MagicMock, patch, ANY
from crewai.cli.plus_api import PlusAPI
from crewai.cli.constants import DEFAULT_CREWAI_ENTERPRISE_URL
class TestPlusAPI(unittest.TestCase):
@@ -30,41 +30,29 @@ class TestPlusAPI(unittest.TestCase):
)
self.assertEqual(response, mock_response)
def assert_request_with_org_id(
self, mock_make_request, method: str, endpoint: str, **kwargs
):
def assert_request_with_org_id(self, mock_make_request, method: str, endpoint: str, **kwargs):
mock_make_request.assert_called_once_with(
method,
f"{DEFAULT_CREWAI_ENTERPRISE_URL}{endpoint}",
headers={
"Authorization": ANY,
"Content-Type": ANY,
"User-Agent": ANY,
"X-Crewai-Version": ANY,
"X-Crewai-Organization-Id": self.org_uuid,
},
**kwargs,
method, f"https://app.crewai.com{endpoint}", headers={'Authorization': ANY, 'Content-Type': ANY, 'User-Agent': ANY, 'X-Crewai-Version': ANY, 'X-Crewai-Organization-Id': self.org_uuid}, **kwargs
)
@patch("crewai.cli.plus_api.Settings")
@patch("requests.Session.request")
def test_login_to_tool_repository_with_org_uuid(
self, mock_make_request, mock_settings_class
):
def test_login_to_tool_repository_with_org_uuid(self, mock_make_request, mock_settings_class):
mock_settings = MagicMock()
mock_settings.org_uuid = self.org_uuid
mock_settings.enterprise_base_url = DEFAULT_CREWAI_ENTERPRISE_URL
mock_settings_class.return_value = mock_settings
# re-initialize Client
self.api = PlusAPI(self.api_key)
mock_response = MagicMock()
mock_make_request.return_value = mock_response
response = self.api.login_to_tool_repository()
self.assert_request_with_org_id(
mock_make_request, "POST", "/crewai_plus/api/v1/tools/login"
mock_make_request,
'POST',
'/crewai_plus/api/v1/tools/login'
)
self.assertEqual(response, mock_response)
@@ -78,27 +66,28 @@ class TestPlusAPI(unittest.TestCase):
"GET", "/crewai_plus/api/v1/agents/test_agent_handle"
)
self.assertEqual(response, mock_response)
@patch("crewai.cli.plus_api.Settings")
@patch("requests.Session.request")
def test_get_agent_with_org_uuid(self, mock_make_request, mock_settings_class):
mock_settings = MagicMock()
mock_settings.org_uuid = self.org_uuid
mock_settings.enterprise_base_url = DEFAULT_CREWAI_ENTERPRISE_URL
mock_settings_class.return_value = mock_settings
# re-initialize Client
self.api = PlusAPI(self.api_key)
mock_response = MagicMock()
mock_make_request.return_value = mock_response
response = self.api.get_agent("test_agent_handle")
self.assert_request_with_org_id(
mock_make_request, "GET", "/crewai_plus/api/v1/agents/test_agent_handle"
mock_make_request,
"GET",
"/crewai_plus/api/v1/agents/test_agent_handle"
)
self.assertEqual(response, mock_response)
@patch("crewai.cli.plus_api.PlusAPI._make_request")
def test_get_tool(self, mock_make_request):
mock_response = MagicMock()
@@ -109,13 +98,12 @@ class TestPlusAPI(unittest.TestCase):
"GET", "/crewai_plus/api/v1/tools/test_tool_handle"
)
self.assertEqual(response, mock_response)
@patch("crewai.cli.plus_api.Settings")
@patch("requests.Session.request")
def test_get_tool_with_org_uuid(self, mock_make_request, mock_settings_class):
mock_settings = MagicMock()
mock_settings.org_uuid = self.org_uuid
mock_settings.enterprise_base_url = DEFAULT_CREWAI_ENTERPRISE_URL
mock_settings_class.return_value = mock_settings
# re-initialize Client
self.api = PlusAPI(self.api_key)
@@ -127,7 +115,9 @@ class TestPlusAPI(unittest.TestCase):
response = self.api.get_tool("test_tool_handle")
self.assert_request_with_org_id(
mock_make_request, "GET", "/crewai_plus/api/v1/tools/test_tool_handle"
mock_make_request,
"GET",
"/crewai_plus/api/v1/tools/test_tool_handle"
)
self.assertEqual(response, mock_response)
@@ -157,13 +147,12 @@ class TestPlusAPI(unittest.TestCase):
"POST", "/crewai_plus/api/v1/tools", json=params
)
self.assertEqual(response, mock_response)
@patch("crewai.cli.plus_api.Settings")
@patch("requests.Session.request")
def test_publish_tool_with_org_uuid(self, mock_make_request, mock_settings_class):
mock_settings = MagicMock()
mock_settings.org_uuid = self.org_uuid
mock_settings.enterprise_base_url = DEFAULT_CREWAI_ENTERPRISE_URL
mock_settings_class.return_value = mock_settings
# re-initialize Client
self.api = PlusAPI(self.api_key)
@@ -171,7 +160,7 @@ class TestPlusAPI(unittest.TestCase):
# Set up mock response
mock_response = MagicMock()
mock_make_request.return_value = mock_response
handle = "test_tool_handle"
public = True
version = "1.0.0"
@@ -191,9 +180,12 @@ class TestPlusAPI(unittest.TestCase):
"description": description,
"available_exports": None,
}
self.assert_request_with_org_id(
mock_make_request, "POST", "/crewai_plus/api/v1/tools", json=expected_params
mock_make_request,
"POST",
"/crewai_plus/api/v1/tools",
json=expected_params
)
self.assertEqual(response, mock_response)
@@ -319,11 +311,8 @@ class TestPlusAPI(unittest.TestCase):
"POST", "/crewai_plus/api/v1/crews", json=payload
)
@patch("crewai.cli.plus_api.Settings")
def test_custom_base_url(self, mock_settings_class):
mock_settings = MagicMock()
mock_settings.enterprise_base_url = "https://custom-url.com/api"
mock_settings_class.return_value = mock_settings
@patch.dict(os.environ, {"CREWAI_BASE_URL": "https://custom-url.com/api"})
def test_custom_base_url(self):
custom_api = PlusAPI("test_key")
self.assertEqual(
custom_api.base_url,

View File

@@ -1,91 +0,0 @@
import tempfile
import unittest
from pathlib import Path
from unittest.mock import patch, MagicMock, call
from crewai.cli.settings.main import SettingsCommand
from crewai.cli.config import (
Settings,
USER_SETTINGS_KEYS,
CLI_SETTINGS_KEYS,
DEFAULT_CLI_SETTINGS,
HIDDEN_SETTINGS_KEYS,
READONLY_SETTINGS_KEYS,
)
import shutil
class TestSettingsCommand(unittest.TestCase):
def setUp(self):
self.test_dir = Path(tempfile.mkdtemp())
self.config_path = self.test_dir / "settings.json"
self.settings = Settings(config_path=self.config_path)
self.settings_command = SettingsCommand(
settings_kwargs={"config_path": self.config_path}
)
def tearDown(self):
shutil.rmtree(self.test_dir)
@patch("crewai.cli.settings.main.console")
@patch("crewai.cli.settings.main.Table")
def test_list_settings(self, mock_table_class, mock_console):
mock_table_instance = MagicMock()
mock_table_class.return_value = mock_table_instance
self.settings_command.list()
# Tests that the table is created skipping hidden settings
mock_table_instance.add_row.assert_has_calls(
[
call(
field_name,
getattr(self.settings, field_name) or "Not set",
field_info.description,
)
for field_name, field_info in Settings.model_fields.items()
if field_name not in HIDDEN_SETTINGS_KEYS
]
)
# Tests that the table is printed
mock_console.print.assert_called_once_with(mock_table_instance)
def test_set_valid_keys(self):
valid_keys = Settings.model_fields.keys() - (
READONLY_SETTINGS_KEYS + HIDDEN_SETTINGS_KEYS
)
for key in valid_keys:
test_value = f"some_value_for_{key}"
self.settings_command.set(key, test_value)
self.assertEqual(getattr(self.settings_command.settings, key), test_value)
def test_set_invalid_key(self):
with self.assertRaises(SystemExit):
self.settings_command.set("invalid_key", "value")
def test_set_readonly_keys(self):
for key in READONLY_SETTINGS_KEYS:
with self.assertRaises(SystemExit):
self.settings_command.set(key, "some_readonly_key_value")
def test_set_hidden_keys(self):
for key in HIDDEN_SETTINGS_KEYS:
with self.assertRaises(SystemExit):
self.settings_command.set(key, "some_hidden_key_value")
def test_reset_all_settings(self):
for key in USER_SETTINGS_KEYS + CLI_SETTINGS_KEYS:
setattr(self.settings_command.settings, key, f"custom_value_for_{key}")
self.settings_command.settings.dump()
self.settings_command.reset_all_settings()
print(USER_SETTINGS_KEYS)
for key in USER_SETTINGS_KEYS:
self.assertEqual(getattr(self.settings_command.settings, key), None)
for key in CLI_SETTINGS_KEYS:
self.assertEqual(
getattr(self.settings_command.settings, key), DEFAULT_CLI_SETTINGS[key]
)

View File

@@ -755,15 +755,3 @@ def test_multiple_routers_from_same_trigger():
assert execution_order.index("anemia_analysis") > execution_order.index(
"anemia_router"
)
def test_flow_name():
class MyFlow(Flow):
name = "MyFlow"
@start()
def start(self):
return "Hello, world!"
flow = MyFlow()
assert flow.name == "MyFlow"

View File

@@ -191,39 +191,17 @@ def test_save_method_with_memory_oss(mem0_storage_with_mocked_config):
"""Test save method for different memory types"""
mem0_storage, _, _ = mem0_storage_with_mocked_config
mem0_storage.memory.add = MagicMock()
# Test short_term memory type (already set in fixture)
test_value = "This is a test memory"
test_metadata = {"key": "value"}
mem0_storage.save(test_value, test_metadata)
mem0_storage.memory.add.assert_called_once_with(
[{"role": "assistant" , "content": test_value}],
[{'role': 'assistant' , 'content': test_value}],
infer=True,
metadata={"type": "short_term", "key": "value"},
run_id="my_run_id",
user_id="test_user",
agent_id='Test_Agent'
)
def test_save_method_with_multiple_agents(mem0_storage_with_mocked_config):
mem0_storage, _, _ = mem0_storage_with_mocked_config
mem0_storage.crew.agents = [MagicMock(role="Test Agent"), MagicMock(role="Test Agent 2"), MagicMock(role="Test Agent 3")]
mem0_storage.memory.add = MagicMock()
test_value = "This is a test memory"
test_metadata = {"key": "value"}
mem0_storage.save(test_value, test_metadata)
mem0_storage.memory.add.assert_called_once_with(
[{"role": "assistant" , "content": test_value}],
infer=True,
metadata={"type": "short_term", "key": "value"},
run_id="my_run_id",
user_id="test_user",
agent_id='Test_Agent_Test_Agent_2_Test_Agent_3'
)
@@ -231,13 +209,13 @@ def test_save_method_with_memory_client(mem0_storage_with_memory_client_using_co
"""Test save method for different memory types"""
mem0_storage = mem0_storage_with_memory_client_using_config_from_crew
mem0_storage.memory.add = MagicMock()
# Test short_term memory type (already set in fixture)
test_value = "This is a test memory"
test_metadata = {"key": "value"}
mem0_storage.save(test_value, test_metadata)
mem0_storage.memory.add.assert_called_once_with(
[{'role': 'assistant' , 'content': test_value}],
infer=True,
@@ -246,9 +224,7 @@ def test_save_method_with_memory_client(mem0_storage_with_memory_client_using_co
run_id="my_run_id",
includes="include1",
excludes="exclude1",
output_format='v1.1',
user_id='test_user',
agent_id='Test_Agent'
output_format='v1.1'
)
@@ -261,10 +237,10 @@ def test_search_method_with_memory_oss(mem0_storage_with_mocked_config):
results = mem0_storage.search("test query", limit=5, score_threshold=0.5)
mem0_storage.memory.search.assert_called_once_with(
query="test query",
limit=5,
query="test query",
limit=5,
user_id="test_user",
filters={'AND': [{'run_id': 'my_run_id'}]},
filters={'AND': [{'run_id': 'my_run_id'}]},
threshold=0.5
)
@@ -281,8 +257,8 @@ def test_search_method_with_memory_client(mem0_storage_with_memory_client_using_
results = mem0_storage.search("test query", limit=5, score_threshold=0.5)
mem0_storage.memory.search.assert_called_once_with(
query="test query",
limit=5,
query="test query",
limit=5,
metadata={"type": "short_term"},
user_id="test_user",
version='v2',
@@ -310,108 +286,4 @@ def test_mem0_storage_default_infer_value(mock_mem0_memory_client):
)
mem0_storage = Mem0Storage(type="short_term", crew=crew)
assert mem0_storage.infer is True
def test_save_memory_using_agent_entity(mock_mem0_memory_client):
config = {
"agent_id": "agent-123",
}
mock_memory = MagicMock(spec=Memory)
with patch.object(Memory, "__new__", return_value=mock_memory):
mem0_storage = Mem0Storage(type="external", config=config)
mem0_storage.save("test memory", {"key": "value"})
mem0_storage.memory.add.assert_called_once_with(
[{'role': 'assistant' , 'content': 'test memory'}],
infer=True,
metadata={"type": "external", "key": "value"},
agent_id="agent-123",
)
def test_search_method_with_agent_entity():
mem0_storage = Mem0Storage(type="external", config={"agent_id": "agent-123"})
mock_results = {"results": [{"score": 0.9, "content": "Result 1"}, {"score": 0.4, "content": "Result 2"}]}
mem0_storage.memory.search = MagicMock(return_value=mock_results)
results = mem0_storage.search("test query", limit=5, score_threshold=0.5)
mem0_storage.memory.search.assert_called_once_with(
query="test query",
limit=5,
filters={"AND": [{"agent_id": "agent-123"}]},
threshold=0.5,
)
assert len(results) == 2
assert results[0]["content"] == "Result 1"
def test_search_method_with_agent_id_and_user_id():
mem0_storage = Mem0Storage(type="external", config={"agent_id": "agent-123", "user_id": "user-123"})
mock_results = {"results": [{"score": 0.9, "content": "Result 1"}, {"score": 0.4, "content": "Result 2"}]}
mem0_storage.memory.search = MagicMock(return_value=mock_results)
results = mem0_storage.search("test query", limit=5, score_threshold=0.5)
mem0_storage.memory.search.assert_called_once_with(
query="test query",
limit=5,
user_id='user-123',
filters={"OR": [{"user_id": "user-123"}, {"agent_id": "agent-123"}]},
threshold=0.5,
)
assert len(results) == 2
assert results[0]["content"] == "Result 1"
def test_search_method_with_memory_oss_missing_params():
"""Test search method handles missing parameters gracefully when using Memory (OSS)"""
config = {
"user_id": "test_user",
"local_mem0_config": {
"vector_store": {"provider": "mock_vector_store"},
"llm": {"provider": "mock_llm"},
}
}
mock_memory = MagicMock(spec=Memory)
mock_results = {"results": [{"score": 0.9, "content": "Result 1"}]}
mock_memory.search = MagicMock(return_value=mock_results)
with patch("mem0.memory.main.Memory.from_config", return_value=mock_memory):
mem0_storage = Mem0Storage(type="external", config=config)
results = mem0_storage.search("test query", limit=5, score_threshold=0.5)
mock_memory.search.assert_called_once_with(
query="test query",
limit=5,
user_id="test_user",
filters={"AND": [{"user_id": "test_user"}]},
threshold=0.5
)
assert len(results) == 1
assert results[0]["content"] == "Result 1"
def test_search_method_with_memory_oss_no_optional_params():
"""Test search method works when no optional parameters are present"""
mock_memory = MagicMock(spec=Memory)
mock_results = {"results": []}
mock_memory.search = MagicMock(return_value=mock_results)
with patch("mem0.memory.main.Memory", return_value=mock_memory):
mem0_storage = Mem0Storage(type="external", config={})
results = mem0_storage.search("test query")
mock_memory.search.assert_called_once_with(
query="test query",
limit=3,
filters={},
threshold=0.35
)
assert len(results) == 0
assert mem0_storage.infer is True

221
tests/test_custom_tools.py Normal file
View File

@@ -0,0 +1,221 @@
"""
Test custom tool registration patterns to ensure all documented patterns work correctly.
This addresses issue #3226 where custom tool registration was broken in CrewAI 0.150.0.
"""
import pytest
from crewai import Agent
from crewai.tools import BaseTool, tool
class TestCustomToolPatterns:
"""Test all custom tool patterns mentioned in issue #3226."""
def test_function_tool_with_decorator(self):
"""Test function tool with @tool decorator."""
@tool
def fetch_logs(query: str) -> str:
"""Fetch logs from New Relic based on query"""
return f"Logs for query: {query}"
agent = Agent(
role='CrashFetcher',
goal='Extract logs',
backstory='An agent that fetches logs',
tools=[fetch_logs],
allow_delegation=False
)
assert len(agent.tools) == 1
assert agent.tools[0].name == "fetch_logs"
assert "Fetch logs from New Relic" in agent.tools[0].description
def test_dict_based_tool(self):
"""Test dict-based tool definition."""
def fetch_logs_func(query: str) -> str:
return f"Logs for query: {query}"
fetch_logs_dict = {
'name': 'fetch_logs',
'description': 'Fetch logs from New Relic',
'func': fetch_logs_func
}
agent = Agent(
role='CrashFetcher',
goal='Extract logs',
backstory='An agent that fetches logs',
tools=[fetch_logs_dict],
allow_delegation=False
)
assert len(agent.tools) == 1
assert agent.tools[0].name == "fetch_logs"
assert "Fetch logs from New Relic" in agent.tools[0].description
def test_basetool_class_inheritance(self):
"""Test BaseTool class inheritance."""
class FetchLogsTool(BaseTool):
name: str = "fetch_logs"
description: str = "Fetch logs from New Relic based on query"
def _run(self, query: str) -> str:
return f"Logs for query: {query}"
agent = Agent(
role='CrashFetcher',
goal='Extract logs',
backstory='An agent that fetches logs',
tools=[FetchLogsTool()],
allow_delegation=False
)
assert len(agent.tools) == 1
assert agent.tools[0].name == "fetch_logs"
assert "Fetch logs from New Relic" in agent.tools[0].description
def test_direct_function_assignment(self):
"""Test direct function assignment."""
def fetch_logs(query: str) -> str:
"""Fetch logs from New Relic based on query"""
return f"Logs for query: {query}"
agent = Agent(
role='CrashFetcher',
goal='Extract logs',
backstory='An agent that fetches logs',
tools=[fetch_logs],
allow_delegation=False
)
assert len(agent.tools) == 1
assert agent.tools[0].name == "fetch_logs"
assert "Fetch logs from New Relic" in agent.tools[0].description
def test_mixed_tool_types(self):
"""Test mixing different tool types in the same agent."""
@tool
def decorated_tool(query: str) -> str:
"""A decorated tool"""
return f"Decorated: {query}"
class ClassTool(BaseTool):
name: str = "class_tool"
description: str = "A class-based tool"
def _run(self, query: str) -> str:
return f"Class: {query}"
def function_tool(query: str) -> str:
"""A function tool"""
return f"Function: {query}"
dict_tool = {
'name': 'dict_tool',
'description': 'A dict-based tool',
'func': lambda query: f"Dict: {query}"
}
agent = Agent(
role='MultiTool',
goal='Use multiple tool types',
backstory='An agent with various tools',
tools=[decorated_tool, ClassTool(), function_tool, dict_tool],
allow_delegation=False
)
assert len(agent.tools) == 4
tool_names = [tool.name for tool in agent.tools]
assert "decorated_tool" in tool_names
assert "class_tool" in tool_names
assert "function_tool" in tool_names
assert "dict_tool" in tool_names
def test_invalid_tool_types(self):
"""Test that invalid tool types raise appropriate errors."""
with pytest.raises(ValueError, match="Invalid tool type"):
Agent(
role='Test',
goal='Test invalid tools',
backstory='Testing',
tools=["invalid_string_tool"],
allow_delegation=False
)
with pytest.raises(ValueError, match="Invalid tool type"):
Agent(
role='Test',
goal='Test invalid tools',
backstory='Testing',
tools=[123],
allow_delegation=False
)
def test_function_without_docstring_fails(self):
"""Test that functions without docstrings fail validation."""
def no_docstring_func(query: str) -> str:
return f"No docstring: {query}"
with pytest.raises(ValueError, match="must have a docstring"):
Agent(
role='Test',
goal='Test function without docstring',
backstory='Testing',
tools=[no_docstring_func],
allow_delegation=False
)
def test_incomplete_dict_tool_fails(self):
"""Test that dict tools missing required keys fail validation."""
incomplete_dict = {
'name': 'incomplete',
'description': 'Missing func key'
}
with pytest.raises(ValueError, match="Dict tool must contain keys"):
Agent(
role='Test',
goal='Test incomplete dict tool',
backstory='Testing',
tools=[incomplete_dict],
allow_delegation=False
)
def test_tool_execution(self):
"""Test that tools can actually be executed."""
@tool
def test_execution_tool(message: str) -> str:
"""A tool for testing execution"""
return f"Executed: {message}"
agent = Agent(
role='Executor',
goal='Execute tools',
backstory='An agent that executes tools',
tools=[test_execution_tool],
allow_delegation=False
)
tool_instance = agent.tools[0]
result = tool_instance.run(message="test")
assert result == "Executed: test"
def test_tool_with_multiple_parameters(self):
"""Test tools with multiple parameters work correctly."""
@tool
def multi_param_tool(param1: str, param2: int, param3: bool = True) -> str:
"""A tool with multiple parameters"""
return f"p1={param1}, p2={param2}, p3={param3}"
agent = Agent(
role='MultiParam',
goal='Use multi-parameter tools',
backstory='An agent with complex tools',
tools=[multi_param_tool],
allow_delegation=False
)
assert len(agent.tools) == 1
tool_instance = agent.tools[0]
result = tool_instance.run(param1="test", param2=42, param3=False)
assert result == "p1=test, p2=42, p3=False"

View File

@@ -1,25 +0,0 @@
from unittest.mock import patch
import pytest
from crewai.rag.embeddings.configurator import EmbeddingConfigurator
def test_configure_embedder_importerror():
configurator = EmbeddingConfigurator()
embedder_config = {
'provider': 'openai',
'config': {
'model': 'text-embedding-ada-002',
}
}
with patch('chromadb.utils.embedding_functions.openai_embedding_function.OpenAIEmbeddingFunction') as mock_openai:
mock_openai.side_effect = ImportError("Module not found.")
with pytest.raises(ImportError) as exc_info:
configurator.configure_embedder(embedder_config)
assert str(exc_info.value) == "Module not found."
mock_openai.assert_called_once()

View File

@@ -64,8 +64,7 @@ def base_agent():
llm="gpt-4o-mini",
goal="Just say hi",
backstory="You are a helpful assistant that just says hi",
)
)
@pytest.fixture(scope="module")
def base_task(base_agent):
@@ -75,7 +74,6 @@ def base_task(base_agent):
agent=base_agent,
)
event_listener = EventListener()
@@ -450,27 +448,6 @@ def test_flow_emits_start_event():
assert received_events[0].type == "flow_started"
def test_flow_name_emitted_to_event_bus():
received_events = []
class MyFlowClass(Flow):
name = "PRODUCTION_FLOW"
@start()
def start(self):
return "Hello, world!"
@crewai_event_bus.on(FlowStartedEvent)
def handle_flow_start(source, event):
received_events.append(event)
flow = MyFlowClass()
flow.kickoff()
assert len(received_events) == 1
assert received_events[0].flow_name == "PRODUCTION_FLOW"
def test_flow_emits_finish_event():
received_events = []
@@ -779,7 +756,6 @@ def test_streaming_empty_response_handling():
received_chunks = []
with crewai_event_bus.scoped_handlers():
@crewai_event_bus.on(LLMStreamChunkEvent)
def handle_stream_chunk(source, event):
received_chunks.append(event.chunk)
@@ -817,7 +793,6 @@ def test_streaming_empty_response_handling():
# Restore the original method
llm.call = original_call
@pytest.mark.vcr(filter_headers=["authorization"])
def test_stream_llm_emits_event_with_task_and_agent_info():
completed_event = []
@@ -826,7 +801,6 @@ def test_stream_llm_emits_event_with_task_and_agent_info():
stream_event = []
with crewai_event_bus.scoped_handlers():
@crewai_event_bus.on(LLMCallFailedEvent)
def handle_llm_failed(source, event):
failed_event.append(event)
@@ -853,7 +827,7 @@ def test_stream_llm_emits_event_with_task_and_agent_info():
description="Just say hi",
expected_output="hi",
llm=LLM(model="gpt-4o-mini", stream=True),
agent=agent,
agent=agent
)
crew = Crew(agents=[agent], tasks=[task])
@@ -881,7 +855,6 @@ def test_stream_llm_emits_event_with_task_and_agent_info():
assert set(all_task_id) == {task.id}
assert set(all_task_name) == {task.name}
@pytest.mark.vcr(filter_headers=["authorization"])
def test_llm_emits_event_with_task_and_agent_info(base_agent, base_task):
completed_event = []
@@ -890,7 +863,6 @@ def test_llm_emits_event_with_task_and_agent_info(base_agent, base_task):
stream_event = []
with crewai_event_bus.scoped_handlers():
@crewai_event_bus.on(LLMCallFailedEvent)
def handle_llm_failed(source, event):
failed_event.append(event)
@@ -932,7 +904,6 @@ def test_llm_emits_event_with_task_and_agent_info(base_agent, base_task):
assert set(all_task_id) == {base_task.id}
assert set(all_task_name) == {base_task.name}
@pytest.mark.vcr(filter_headers=["authorization"])
def test_llm_emits_event_with_lite_agent():
completed_event = []
@@ -941,7 +912,6 @@ def test_llm_emits_event_with_lite_agent():
stream_event = []
with crewai_event_bus.scoped_handlers():
@crewai_event_bus.on(LLMCallFailedEvent)
def handle_llm_failed(source, event):
failed_event.append(event)
@@ -966,6 +936,7 @@ def test_llm_emits_event_with_lite_agent():
)
agent.kickoff(messages=[{"role": "user", "content": "say hi!"}])
assert len(completed_event) == 2
assert len(failed_event) == 0
assert len(started_event) == 2

9
uv.lock generated
View File

@@ -798,7 +798,7 @@ requires-dist = [
{ name = "blinker", specifier = ">=1.9.0" },
{ name = "chromadb", specifier = ">=0.5.23" },
{ name = "click", specifier = ">=8.1.7" },
{ name = "crewai-tools", marker = "extra == 'tools'", specifier = "~=0.59.0" },
{ name = "crewai-tools", marker = "extra == 'tools'", specifier = "~=0.58.0" },
{ name = "docling", marker = "extra == 'docling'", specifier = ">=2.12.0" },
{ name = "instructor", specifier = ">=1.3.3" },
{ name = "json-repair", specifier = "==0.25.2" },
@@ -850,7 +850,7 @@ dev = [
[[package]]
name = "crewai-tools"
version = "0.59.0"
version = "0.58.0"
source = { registry = "https://pypi.org/simple" }
dependencies = [
{ name = "chromadb" },
@@ -860,7 +860,6 @@ dependencies = [
{ name = "embedchain" },
{ name = "lancedb" },
{ name = "openai" },
{ name = "portalocker" },
{ name = "pydantic" },
{ name = "pyright" },
{ name = "pytube" },
@@ -868,9 +867,9 @@ dependencies = [
{ name = "stagehand" },
{ name = "tiktoken" },
]
sdist = { url = "https://files.pythonhosted.org/packages/10/cd/af005e7dca5a35ed6000db2d1594acad890997b92172d8af2c1d9a83784e/crewai_tools-0.59.0.tar.gz", hash = "sha256:030e4b65446f4c6eccdcba5380bbdc90896de74589cb1bdb3cabc2c22c83f326", size = 1032093, upload-time = "2025-07-30T21:15:54.182Z" }
sdist = { url = "https://files.pythonhosted.org/packages/1f/bf/72c3a0cb5a8be1f635a4e3b07ee2ad81a6d427e63b7748c2727a33ade0d4/crewai_tools-0.58.0.tar.gz", hash = "sha256:ea82d5df8611ae22a8291934c4cd0b7ed5b77eca475f81014f018b7eca4d3350", size = 1026853, upload-time = "2025-07-23T17:45:53.228Z" }
wheels = [
{ url = "https://files.pythonhosted.org/packages/7e/b2/6b11d770fda6df99ddc49bdb26d3cb0b5efd026ff395a72ca5eeb704c335/crewai_tools-0.59.0-py3-none-any.whl", hash = "sha256:816297934eb352368ff0869a0eb11dd9febb4c40b6d0eda8e6eba8f3e426f446", size = 657025, upload-time = "2025-07-30T21:15:52.531Z" },
{ url = "https://files.pythonhosted.org/packages/34/bd/1de36fbf8fb717817d3bf72a94da38e27af6cc5b888d7c1203a3f0b0cc2f/crewai_tools-0.58.0-py3-none-any.whl", hash = "sha256:151688bf0fa8c90e27dcdbaa8619f3dee2a14e97f1b420a38187b12d88305175", size = 650113, upload-time = "2025-07-23T17:45:51.056Z" },
]
[[package]]