mirror of
https://github.com/crewAIInc/crewAI.git
synced 2026-01-11 09:08:31 +00:00
feat: Add ToolCollection class for named tool access (#339)
This change allows accessing tools by name (tools["tool_name"]) in addition to index (tools[0]), making it more intuitive and convenient to work with multiple tools without needing to remember their position in the list
This commit is contained in:
@@ -4,7 +4,7 @@ import logging
|
||||
from typing import TYPE_CHECKING, Any
|
||||
|
||||
from crewai.tools import BaseTool
|
||||
|
||||
from crewai_tools.adapters.tool_collection import ToolCollection
|
||||
"""
|
||||
MCPServer for CrewAI.
|
||||
|
||||
@@ -114,7 +114,7 @@ class MCPServerAdapter:
|
||||
self._adapter.__exit__(None, None, None)
|
||||
|
||||
@property
|
||||
def tools(self) -> list[BaseTool]:
|
||||
def tools(self) -> ToolCollection[BaseTool]:
|
||||
"""The CrewAI tools available from the MCP server.
|
||||
|
||||
Raises:
|
||||
@@ -127,7 +127,7 @@ class MCPServerAdapter:
|
||||
raise ValueError(
|
||||
"MCP server not started, run `mcp_server.start()` first before accessing `tools`"
|
||||
)
|
||||
return self._tools
|
||||
return ToolCollection(self._tools)
|
||||
|
||||
def __enter__(self):
|
||||
"""
|
||||
|
||||
59
src/crewai_tools/adapters/tool_collection.py
Normal file
59
src/crewai_tools/adapters/tool_collection.py
Normal file
@@ -0,0 +1,59 @@
|
||||
from typing import List, Optional, Union, TypeVar, Generic, Dict
|
||||
from crewai.tools import BaseTool
|
||||
|
||||
T = TypeVar('T', bound=BaseTool)
|
||||
|
||||
class ToolCollection(list, Generic[T]):
|
||||
"""
|
||||
A collection of tools that can be accessed by index or name
|
||||
|
||||
This class extends the built-in list to provide dictionary-like
|
||||
access to tools based on their name property.
|
||||
|
||||
Usage:
|
||||
tools = ToolCollection(list_of_tools)
|
||||
# Access by index (regular list behavior)
|
||||
first_tool = tools[0]
|
||||
# Access by name (new functionality)
|
||||
search_tool = tools["search"]
|
||||
"""
|
||||
|
||||
def __init__(self, tools: Optional[List[T]] = None):
|
||||
super().__init__(tools or [])
|
||||
self._name_cache: Dict[str, T] = {}
|
||||
self._build_name_cache()
|
||||
|
||||
def _build_name_cache(self) -> None:
|
||||
self._name_cache = {tool.name: tool for tool in self}
|
||||
|
||||
def __getitem__(self, key: Union[int, str]) -> T:
|
||||
if isinstance(key, str):
|
||||
return self._name_cache[key]
|
||||
return super().__getitem__(key)
|
||||
|
||||
def append(self, tool: T) -> None:
|
||||
super().append(tool)
|
||||
self._name_cache[tool.name] = tool
|
||||
|
||||
def extend(self, tools: List[T]) -> None:
|
||||
super().extend(tools)
|
||||
self._build_name_cache()
|
||||
|
||||
def insert(self, index: int, tool: T) -> None:
|
||||
super().insert(index, tool)
|
||||
self._name_cache[tool.name] = tool
|
||||
|
||||
def remove(self, tool: T) -> None:
|
||||
super().remove(tool)
|
||||
if tool.name in self._name_cache:
|
||||
del self._name_cache[tool.name]
|
||||
|
||||
def pop(self, index: int = -1) -> T:
|
||||
tool = super().pop(index)
|
||||
if tool.name in self._name_cache:
|
||||
del self._name_cache[tool.name]
|
||||
return tool
|
||||
|
||||
def clear(self) -> None:
|
||||
super().clear()
|
||||
self._name_cache.clear()
|
||||
@@ -7,6 +7,7 @@ import typing as t
|
||||
import logging
|
||||
from crewai.tools import BaseTool
|
||||
from crewai_tools.adapters.enterprise_adapter import EnterpriseActionKitToolAdapter
|
||||
from crewai_tools.adapters.tool_collection import ToolCollection
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@@ -16,7 +17,7 @@ def CrewaiEnterpriseTools(
|
||||
actions_list: t.Optional[t.List[str]] = None,
|
||||
enterprise_action_kit_project_id: t.Optional[str] = None,
|
||||
enterprise_action_kit_project_url: t.Optional[str] = None,
|
||||
) -> t.List[BaseTool]:
|
||||
) -> ToolCollection[BaseTool]:
|
||||
"""Factory function that returns crewai enterprise tools.
|
||||
|
||||
Args:
|
||||
@@ -24,9 +25,11 @@ def CrewaiEnterpriseTools(
|
||||
If not provided, will try to use CREWAI_ENTERPRISE_TOOLS_TOKEN env var.
|
||||
actions_list: Optional list of specific tool names to include.
|
||||
If provided, only tools with these names will be returned.
|
||||
enterprise_action_kit_project_id: Optional ID of the Enterprise Action Kit project.
|
||||
enterprise_action_kit_project_url: Optional URL of the Enterprise Action Kit project.
|
||||
|
||||
Returns:
|
||||
A list of BaseTool instances for enterprise actions
|
||||
A ToolCollection of BaseTool instances for enterprise actions
|
||||
"""
|
||||
if enterprise_token is None:
|
||||
enterprise_token = os.environ.get("CREWAI_ENTERPRISE_TOOLS_TOKEN")
|
||||
@@ -47,7 +50,8 @@ def CrewaiEnterpriseTools(
|
||||
all_tools = adapter.tools()
|
||||
|
||||
if actions_list is None:
|
||||
return all_tools
|
||||
return ToolCollection(all_tools)
|
||||
|
||||
# Filter tools based on the provided list
|
||||
return [tool for tool in all_tools if tool.name in actions_list]
|
||||
filtered_tools = [tool for tool in all_tools if tool.name in actions_list]
|
||||
return ToolCollection(filtered_tools)
|
||||
|
||||
@@ -4,7 +4,7 @@ import pytest
|
||||
from mcp import StdioServerParameters
|
||||
|
||||
from crewai_tools import MCPServerAdapter
|
||||
|
||||
from crewai_tools.adapters.tool_collection import ToolCollection
|
||||
|
||||
@pytest.fixture
|
||||
def echo_server_script():
|
||||
@@ -18,7 +18,7 @@ def echo_server_script():
|
||||
def echo_tool(text: str) -> str:
|
||||
"""Echo the input text"""
|
||||
return f"Echo: {text}"
|
||||
|
||||
|
||||
mcp.run()
|
||||
'''
|
||||
)
|
||||
@@ -68,6 +68,7 @@ def test_context_manager_syntax(echo_server_script):
|
||||
command="uv", args=["run", "python", "-c", echo_server_script]
|
||||
)
|
||||
with MCPServerAdapter(serverparams) as tools:
|
||||
assert isinstance(tools, ToolCollection)
|
||||
assert len(tools) == 1
|
||||
assert tools[0].name == "echo_tool"
|
||||
assert tools[0].run(text="hello") == "Echo: hello"
|
||||
@@ -91,7 +92,7 @@ def test_try_finally_syntax(echo_server_script):
|
||||
assert tools[0].run(text="hello") == "Echo: hello"
|
||||
finally:
|
||||
mcp_server_adapter.stop()
|
||||
|
||||
|
||||
def test_try_finally_syntax_sse(echo_sse_server):
|
||||
sse_serverparams = echo_sse_server
|
||||
mcp_server_adapter = MCPServerAdapter(sse_serverparams)
|
||||
70
tests/tools/crewai_enterprise_tools_test.py
Normal file
70
tests/tools/crewai_enterprise_tools_test.py
Normal file
@@ -0,0 +1,70 @@
|
||||
import os
|
||||
import unittest
|
||||
from unittest.mock import patch, MagicMock
|
||||
|
||||
from crewai.tools import BaseTool
|
||||
from crewai_tools.tools import CrewaiEnterpriseTools
|
||||
from crewai_tools.adapters.tool_collection import ToolCollection
|
||||
|
||||
|
||||
class TestCrewaiEnterpriseTools(unittest.TestCase):
|
||||
def setUp(self):
|
||||
self.mock_tools = [
|
||||
self._create_mock_tool("tool1", "Tool 1 Description"),
|
||||
self._create_mock_tool("tool2", "Tool 2 Description"),
|
||||
self._create_mock_tool("tool3", "Tool 3 Description"),
|
||||
]
|
||||
self.adapter_patcher = patch(
|
||||
'crewai_tools.tools.crewai_enterprise_tools.crewai_enterprise_tools.EnterpriseActionKitToolAdapter'
|
||||
)
|
||||
self.MockAdapter = self.adapter_patcher.start()
|
||||
|
||||
mock_adapter_instance = self.MockAdapter.return_value
|
||||
mock_adapter_instance.tools.return_value = self.mock_tools
|
||||
|
||||
def tearDown(self):
|
||||
self.adapter_patcher.stop()
|
||||
|
||||
def _create_mock_tool(self, name, description):
|
||||
mock_tool = MagicMock(spec=BaseTool)
|
||||
mock_tool.name = name
|
||||
mock_tool.description = description
|
||||
return mock_tool
|
||||
|
||||
@patch.dict(os.environ, {"CREWAI_ENTERPRISE_TOOLS_TOKEN": "env-token"})
|
||||
def test_returns_tool_collection(self):
|
||||
tools = CrewaiEnterpriseTools()
|
||||
self.assertIsInstance(tools, ToolCollection)
|
||||
|
||||
@patch.dict(os.environ, {"CREWAI_ENTERPRISE_TOOLS_TOKEN": "env-token"})
|
||||
def test_returns_all_tools_when_no_actions_list(self):
|
||||
tools = CrewaiEnterpriseTools()
|
||||
self.assertEqual(len(tools), 3)
|
||||
self.assertEqual(tools[0].name, "tool1")
|
||||
self.assertEqual(tools[1].name, "tool2")
|
||||
self.assertEqual(tools[2].name, "tool3")
|
||||
|
||||
@patch.dict(os.environ, {"CREWAI_ENTERPRISE_TOOLS_TOKEN": "env-token"})
|
||||
def test_filters_tools_by_actions_list(self):
|
||||
tools = CrewaiEnterpriseTools(actions_list=["tool1", "tool3"])
|
||||
self.assertEqual(len(tools), 2)
|
||||
self.assertEqual(tools[0].name, "tool1")
|
||||
self.assertEqual(tools[1].name, "tool3")
|
||||
|
||||
def test_uses_provided_parameters(self):
|
||||
CrewaiEnterpriseTools(
|
||||
enterprise_token="test-token",
|
||||
enterprise_action_kit_project_id="project-id",
|
||||
enterprise_action_kit_project_url="project-url"
|
||||
)
|
||||
|
||||
self.MockAdapter.assert_called_once_with(
|
||||
enterprise_action_token="test-token",
|
||||
enterprise_action_kit_project_id="project-id",
|
||||
enterprise_action_kit_project_url="project-url"
|
||||
)
|
||||
|
||||
@patch.dict(os.environ, {"CREWAI_ENTERPRISE_TOOLS_TOKEN": "env-token"})
|
||||
def test_uses_environment_token(self):
|
||||
CrewaiEnterpriseTools()
|
||||
self.MockAdapter.assert_called_once_with(enterprise_action_token="env-token")
|
||||
172
tests/tools/tool_collection_test.py
Normal file
172
tests/tools/tool_collection_test.py
Normal file
@@ -0,0 +1,172 @@
|
||||
import unittest
|
||||
from unittest.mock import MagicMock
|
||||
|
||||
from crewai.tools import BaseTool
|
||||
from crewai_tools.adapters.tool_collection import ToolCollection
|
||||
|
||||
|
||||
class TestToolCollection(unittest.TestCase):
|
||||
def setUp(self):
|
||||
|
||||
self.search_tool = self._create_mock_tool("search", "Search Tool")
|
||||
self.calculator_tool = self._create_mock_tool("calculator", "Calculator Tool")
|
||||
self.translator_tool = self._create_mock_tool("translator", "Translator Tool")
|
||||
|
||||
self.tools = ToolCollection([
|
||||
self.search_tool,
|
||||
self.calculator_tool,
|
||||
self.translator_tool
|
||||
])
|
||||
|
||||
def _create_mock_tool(self, name, description):
|
||||
mock_tool = MagicMock(spec=BaseTool)
|
||||
mock_tool.name = name
|
||||
mock_tool.description = description
|
||||
return mock_tool
|
||||
|
||||
def test_initialization(self):
|
||||
self.assertEqual(len(self.tools), 3)
|
||||
self.assertEqual(self.tools[0].name, "search")
|
||||
self.assertEqual(self.tools[1].name, "calculator")
|
||||
self.assertEqual(self.tools[2].name, "translator")
|
||||
|
||||
def test_empty_initialization(self):
|
||||
empty_collection = ToolCollection()
|
||||
self.assertEqual(len(empty_collection), 0)
|
||||
self.assertEqual(empty_collection._name_cache, {})
|
||||
|
||||
def test_initialization_with_none(self):
|
||||
collection = ToolCollection(None)
|
||||
self.assertEqual(len(collection), 0)
|
||||
self.assertEqual(collection._name_cache, {})
|
||||
|
||||
def test_access_by_index(self):
|
||||
self.assertEqual(self.tools[0], self.search_tool)
|
||||
self.assertEqual(self.tools[1], self.calculator_tool)
|
||||
self.assertEqual(self.tools[2], self.translator_tool)
|
||||
|
||||
def test_access_by_name(self):
|
||||
self.assertEqual(self.tools["search"], self.search_tool)
|
||||
self.assertEqual(self.tools["calculator"], self.calculator_tool)
|
||||
self.assertEqual(self.tools["translator"], self.translator_tool)
|
||||
|
||||
def test_key_error_for_invalid_name(self):
|
||||
with self.assertRaises(KeyError):
|
||||
_ = self.tools["nonexistent"]
|
||||
|
||||
def test_index_error_for_invalid_index(self):
|
||||
with self.assertRaises(IndexError):
|
||||
_ = self.tools[10]
|
||||
|
||||
def test_negative_index(self):
|
||||
self.assertEqual(self.tools[-1], self.translator_tool)
|
||||
self.assertEqual(self.tools[-2], self.calculator_tool)
|
||||
self.assertEqual(self.tools[-3], self.search_tool)
|
||||
|
||||
def test_append(self):
|
||||
new_tool = self._create_mock_tool("new", "New Tool")
|
||||
self.tools.append(new_tool)
|
||||
|
||||
self.assertEqual(len(self.tools), 4)
|
||||
self.assertEqual(self.tools[3], new_tool)
|
||||
self.assertEqual(self.tools["new"], new_tool)
|
||||
self.assertIn("new", self.tools._name_cache)
|
||||
|
||||
def test_append_duplicate_name(self):
|
||||
duplicate_tool = self._create_mock_tool("search", "Duplicate Search Tool")
|
||||
self.tools.append(duplicate_tool)
|
||||
|
||||
self.assertEqual(len(self.tools), 4)
|
||||
self.assertEqual(self.tools["search"], duplicate_tool)
|
||||
|
||||
def test_extend(self):
|
||||
new_tools = [
|
||||
self._create_mock_tool("tool4", "Tool 4"),
|
||||
self._create_mock_tool("tool5", "Tool 5"),
|
||||
]
|
||||
self.tools.extend(new_tools)
|
||||
|
||||
self.assertEqual(len(self.tools), 5)
|
||||
self.assertEqual(self.tools["tool4"], new_tools[0])
|
||||
self.assertEqual(self.tools["tool5"], new_tools[1])
|
||||
self.assertIn("tool4", self.tools._name_cache)
|
||||
self.assertIn("tool5", self.tools._name_cache)
|
||||
|
||||
def test_insert(self):
|
||||
new_tool = self._create_mock_tool("inserted", "Inserted Tool")
|
||||
self.tools.insert(1, new_tool)
|
||||
|
||||
self.assertEqual(len(self.tools), 4)
|
||||
self.assertEqual(self.tools[1], new_tool)
|
||||
self.assertEqual(self.tools["inserted"], new_tool)
|
||||
self.assertIn("inserted", self.tools._name_cache)
|
||||
|
||||
def test_remove(self):
|
||||
self.tools.remove(self.calculator_tool)
|
||||
|
||||
self.assertEqual(len(self.tools), 2)
|
||||
with self.assertRaises(KeyError):
|
||||
_ = self.tools["calculator"]
|
||||
self.assertNotIn("calculator", self.tools._name_cache)
|
||||
|
||||
def test_remove_nonexistent_tool(self):
|
||||
nonexistent_tool = self._create_mock_tool("nonexistent", "Nonexistent Tool")
|
||||
|
||||
with self.assertRaises(ValueError):
|
||||
self.tools.remove(nonexistent_tool)
|
||||
|
||||
def test_pop(self):
|
||||
popped = self.tools.pop(1)
|
||||
|
||||
self.assertEqual(popped, self.calculator_tool)
|
||||
self.assertEqual(len(self.tools), 2)
|
||||
with self.assertRaises(KeyError):
|
||||
_ = self.tools["calculator"]
|
||||
self.assertNotIn("calculator", self.tools._name_cache)
|
||||
|
||||
def test_pop_last(self):
|
||||
popped = self.tools.pop()
|
||||
|
||||
self.assertEqual(popped, self.translator_tool)
|
||||
self.assertEqual(len(self.tools), 2)
|
||||
with self.assertRaises(KeyError):
|
||||
_ = self.tools["translator"]
|
||||
self.assertNotIn("translator", self.tools._name_cache)
|
||||
|
||||
def test_clear(self):
|
||||
self.tools.clear()
|
||||
|
||||
self.assertEqual(len(self.tools), 0)
|
||||
self.assertEqual(self.tools._name_cache, {})
|
||||
with self.assertRaises(KeyError):
|
||||
_ = self.tools["search"]
|
||||
|
||||
def test_iteration(self):
|
||||
tools_list = list(self.tools)
|
||||
self.assertEqual(tools_list, [self.search_tool, self.calculator_tool, self.translator_tool])
|
||||
|
||||
def test_contains(self):
|
||||
self.assertIn(self.search_tool, self.tools)
|
||||
self.assertIn(self.calculator_tool, self.tools)
|
||||
self.assertIn(self.translator_tool, self.tools)
|
||||
|
||||
nonexistent_tool = self._create_mock_tool("nonexistent", "Nonexistent Tool")
|
||||
self.assertNotIn(nonexistent_tool, self.tools)
|
||||
|
||||
def test_slicing(self):
|
||||
slice_result = self.tools[1:3]
|
||||
self.assertEqual(len(slice_result), 2)
|
||||
self.assertEqual(slice_result[0], self.calculator_tool)
|
||||
self.assertEqual(slice_result[1], self.translator_tool)
|
||||
|
||||
self.assertIsInstance(slice_result, list)
|
||||
self.assertNotIsInstance(slice_result, ToolCollection)
|
||||
|
||||
def test_getitem_with_tool_name_as_int(self):
|
||||
numeric_name_tool = self._create_mock_tool("123", "Numeric Name Tool")
|
||||
self.tools.append(numeric_name_tool)
|
||||
|
||||
self.assertEqual(self.tools["123"], numeric_name_tool)
|
||||
|
||||
with self.assertRaises(IndexError):
|
||||
_ = self.tools[123]
|
||||
Reference in New Issue
Block a user