mirror of
https://github.com/crewAIInc/crewAI.git
synced 2026-04-12 14:02:47 +00:00
Compare commits
1 Commits
chore/clea
...
devin/1775
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
7a649b226e |
@@ -5,6 +5,8 @@ from functools import wraps
|
||||
import inspect
|
||||
from typing import TYPE_CHECKING, Any, TypeVar, overload
|
||||
|
||||
from crewai.utilities.string_utils import sanitize_tool_name
|
||||
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from crewai.hooks.llm_hooks import LLMCallHookContext
|
||||
@@ -37,6 +39,13 @@ def _create_hook_decorator(
|
||||
tools: list[str] | None = None,
|
||||
agents: list[str] | None = None,
|
||||
) -> Callable[..., Any]:
|
||||
# Sanitize tool names so users can pass human-readable names
|
||||
# (e.g., "File Read Tool") and still match the sanitized tool_name
|
||||
# that appears in ToolCallHookContext at runtime.
|
||||
sanitized_tools: list[str] | None = (
|
||||
[sanitize_tool_name(t) for t in tools] if tools else tools
|
||||
)
|
||||
|
||||
def decorator(f: Callable[..., Any]) -> Callable[..., Any]:
|
||||
setattr(f, marker_attribute, True)
|
||||
|
||||
@@ -44,17 +53,17 @@ def _create_hook_decorator(
|
||||
params = list(sig.parameters.keys())
|
||||
is_method = len(params) >= 2 and params[0] == "self"
|
||||
|
||||
if tools:
|
||||
f._filter_tools = tools # type: ignore[attr-defined]
|
||||
if sanitized_tools:
|
||||
f._filter_tools = sanitized_tools # type: ignore[attr-defined]
|
||||
if agents:
|
||||
f._filter_agents = agents # type: ignore[attr-defined]
|
||||
|
||||
if tools or agents:
|
||||
if sanitized_tools or agents:
|
||||
|
||||
@wraps(f)
|
||||
def filtered_hook(context: Any) -> Any:
|
||||
if tools and hasattr(context, "tool_name"):
|
||||
if context.tool_name not in tools:
|
||||
if sanitized_tools and hasattr(context, "tool_name"):
|
||||
if context.tool_name not in sanitized_tools:
|
||||
return None
|
||||
|
||||
if agents and hasattr(context, "agent"):
|
||||
|
||||
@@ -293,6 +293,195 @@ class TestDecoratorAttributes:
|
||||
assert test_hook._filter_agents == ["Dev"]
|
||||
|
||||
|
||||
class TestToolNameSanitizationInHookFilters:
|
||||
"""Test that tool names in hook filters are auto-sanitized to match runtime tool_name."""
|
||||
|
||||
def test_before_tool_call_filter_matches_human_readable_name(self):
|
||||
"""Test that human-readable tool names like 'File Read Tool' match sanitized context.tool_name."""
|
||||
execution_log = []
|
||||
|
||||
# User passes the human-readable BaseTool.name
|
||||
@before_tool_call(tools=["File Read Tool"])
|
||||
def filtered_hook(context):
|
||||
execution_log.append(context.tool_name)
|
||||
return None
|
||||
|
||||
hooks = get_before_tool_call_hooks()
|
||||
mock_tool = Mock()
|
||||
|
||||
# At runtime, context.tool_name is the sanitized version
|
||||
context = ToolCallHookContext(
|
||||
tool_name="file_read_tool",
|
||||
tool_input={},
|
||||
tool=mock_tool,
|
||||
)
|
||||
hooks[0](context)
|
||||
|
||||
assert len(execution_log) == 1
|
||||
assert execution_log[0] == "file_read_tool"
|
||||
|
||||
def test_after_tool_call_filter_matches_human_readable_name(self):
|
||||
"""Test that after_tool_call also sanitizes tool filter names."""
|
||||
execution_log = []
|
||||
|
||||
@after_tool_call(tools=["Web Search Tool"])
|
||||
def filtered_hook(context):
|
||||
execution_log.append(context.tool_name)
|
||||
return None
|
||||
|
||||
hooks = get_after_tool_call_hooks()
|
||||
mock_tool = Mock()
|
||||
|
||||
context = ToolCallHookContext(
|
||||
tool_name="web_search_tool",
|
||||
tool_input={},
|
||||
tool=mock_tool,
|
||||
tool_result="some result",
|
||||
)
|
||||
hooks[0](context)
|
||||
|
||||
assert len(execution_log) == 1
|
||||
assert execution_log[0] == "web_search_tool"
|
||||
|
||||
def test_before_tool_call_filter_with_camel_case_name(self):
|
||||
"""Test that CamelCase tool names are sanitized to match snake_case context.tool_name."""
|
||||
execution_log = []
|
||||
|
||||
@before_tool_call(tools=["ExaSearchTool"])
|
||||
def filtered_hook(context):
|
||||
execution_log.append(context.tool_name)
|
||||
return None
|
||||
|
||||
hooks = get_before_tool_call_hooks()
|
||||
mock_tool = Mock()
|
||||
|
||||
# Sanitized CamelCase: ExaSearchTool -> exa_search_tool
|
||||
context = ToolCallHookContext(
|
||||
tool_name="exa_search_tool",
|
||||
tool_input={},
|
||||
tool=mock_tool,
|
||||
)
|
||||
hooks[0](context)
|
||||
|
||||
assert len(execution_log) == 1
|
||||
assert execution_log[0] == "exa_search_tool"
|
||||
|
||||
def test_filter_already_sanitized_name_still_works(self):
|
||||
"""Test that passing already-sanitized names (e.g. 'delete_file') still works."""
|
||||
execution_log = []
|
||||
|
||||
@before_tool_call(tools=["delete_file"])
|
||||
def filtered_hook(context):
|
||||
execution_log.append(context.tool_name)
|
||||
return None
|
||||
|
||||
hooks = get_before_tool_call_hooks()
|
||||
mock_tool = Mock()
|
||||
|
||||
context = ToolCallHookContext(
|
||||
tool_name="delete_file",
|
||||
tool_input={},
|
||||
tool=mock_tool,
|
||||
)
|
||||
hooks[0](context)
|
||||
|
||||
assert len(execution_log) == 1
|
||||
assert execution_log[0] == "delete_file"
|
||||
|
||||
def test_filter_non_matching_sanitized_name_skips_hook(self):
|
||||
"""Test that hooks are correctly skipped when sanitized names don't match."""
|
||||
execution_log = []
|
||||
|
||||
@before_tool_call(tools=["File Read Tool"])
|
||||
def filtered_hook(context):
|
||||
execution_log.append(context.tool_name)
|
||||
return None
|
||||
|
||||
hooks = get_before_tool_call_hooks()
|
||||
mock_tool = Mock()
|
||||
|
||||
# This tool name doesn't match "file_read_tool"
|
||||
context = ToolCallHookContext(
|
||||
tool_name="web_search_tool",
|
||||
tool_input={},
|
||||
tool=mock_tool,
|
||||
)
|
||||
hooks[0](context)
|
||||
|
||||
assert len(execution_log) == 0
|
||||
|
||||
def test_filter_tools_attribute_stores_sanitized_names(self):
|
||||
"""Test that _filter_tools attribute stores sanitized tool names."""
|
||||
|
||||
@before_tool_call(tools=["File Read Tool", "MyCustomTool"])
|
||||
def test_hook(context):
|
||||
return None
|
||||
|
||||
assert hasattr(test_hook, "_filter_tools")
|
||||
assert test_hook._filter_tools == ["file_read_tool", "my_custom_tool"]
|
||||
|
||||
def test_mixed_sanitized_and_unsanitized_tool_names(self):
|
||||
"""Test that a mix of human-readable and already-sanitized names all work."""
|
||||
execution_log = []
|
||||
|
||||
@before_tool_call(tools=["File Read Tool", "delete_file", "ExaSearchTool"])
|
||||
def filtered_hook(context):
|
||||
execution_log.append(context.tool_name)
|
||||
return None
|
||||
|
||||
hooks = get_before_tool_call_hooks()
|
||||
mock_tool = Mock()
|
||||
|
||||
# All three should match their sanitized counterparts
|
||||
for tool_name in ["file_read_tool", "delete_file", "exa_search_tool"]:
|
||||
context = ToolCallHookContext(
|
||||
tool_name=tool_name,
|
||||
tool_input={},
|
||||
tool=mock_tool,
|
||||
)
|
||||
hooks[0](context)
|
||||
|
||||
assert len(execution_log) == 3
|
||||
assert execution_log == ["file_read_tool", "delete_file", "exa_search_tool"]
|
||||
|
||||
def test_combined_sanitized_tool_filter_and_agent_filter(self):
|
||||
"""Test that sanitized tool filter works alongside agent filter."""
|
||||
execution_log = []
|
||||
|
||||
@before_tool_call(tools=["File Read Tool"], agents=["Researcher"])
|
||||
def filtered_hook(context):
|
||||
execution_log.append(f"{context.tool_name}-{context.agent.role}")
|
||||
return None
|
||||
|
||||
hooks = get_before_tool_call_hooks()
|
||||
mock_tool = Mock()
|
||||
mock_agent = Mock(role="Researcher")
|
||||
|
||||
# Both filters match
|
||||
context = ToolCallHookContext(
|
||||
tool_name="file_read_tool",
|
||||
tool_input={},
|
||||
tool=mock_tool,
|
||||
agent=mock_agent,
|
||||
)
|
||||
hooks[0](context)
|
||||
|
||||
assert len(execution_log) == 1
|
||||
assert execution_log[0] == "file_read_tool-Researcher"
|
||||
|
||||
# Tool matches but agent doesn't
|
||||
mock_agent2 = Mock(role="Developer")
|
||||
context2 = ToolCallHookContext(
|
||||
tool_name="file_read_tool",
|
||||
tool_input={},
|
||||
tool=mock_tool,
|
||||
agent=mock_agent2,
|
||||
)
|
||||
hooks[0](context2)
|
||||
|
||||
assert len(execution_log) == 1 # Still 1, hook was skipped
|
||||
|
||||
|
||||
class TestMultipleDecorators:
|
||||
"""Test using multiple decorators together."""
|
||||
|
||||
|
||||
Reference in New Issue
Block a user