mirror of
https://github.com/crewAIInc/crewAI.git
synced 2026-01-29 10:08:13 +00:00
Compare commits
5 Commits
devin/1768
...
tm-push-cu
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
7c29545afb | ||
|
|
028db9dbbc | ||
|
|
08fc6ac6f9 | ||
|
|
31092293e5 | ||
|
|
0600843299 |
@@ -89,9 +89,6 @@ from crewai_tools.tools.jina_scrape_website_tool.jina_scrape_website_tool import
|
||||
from crewai_tools.tools.json_search_tool.json_search_tool import JSONSearchTool
|
||||
from crewai_tools.tools.linkup.linkup_search_tool import LinkupSearchTool
|
||||
from crewai_tools.tools.llamaindex_tool.llamaindex_tool import LlamaIndexTool
|
||||
from crewai_tools.tools.mcp_discovery_tool.mcp_discovery_tool import (
|
||||
MCPDiscoveryTool,
|
||||
)
|
||||
from crewai_tools.tools.mdx_search_tool.mdx_search_tool import MDXSearchTool
|
||||
from crewai_tools.tools.merge_agent_handler_tool.merge_agent_handler_tool import (
|
||||
MergeAgentHandlerTool,
|
||||
@@ -239,7 +236,6 @@ __all__ = [
|
||||
"JinaScrapeWebsiteTool",
|
||||
"LinkupSearchTool",
|
||||
"LlamaIndexTool",
|
||||
"MCPDiscoveryTool",
|
||||
"MCPServerAdapter",
|
||||
"MDXSearchTool",
|
||||
"MergeAgentHandlerTool",
|
||||
|
||||
@@ -1,16 +0,0 @@
|
||||
from crewai_tools.tools.mcp_discovery_tool.mcp_discovery_tool import (
|
||||
MCPDiscoveryResult,
|
||||
MCPDiscoveryTool,
|
||||
MCPDiscoveryToolSchema,
|
||||
MCPServerMetrics,
|
||||
MCPServerRecommendation,
|
||||
)
|
||||
|
||||
|
||||
__all__ = [
|
||||
"MCPDiscoveryResult",
|
||||
"MCPDiscoveryTool",
|
||||
"MCPDiscoveryToolSchema",
|
||||
"MCPServerMetrics",
|
||||
"MCPServerRecommendation",
|
||||
]
|
||||
@@ -1,414 +0,0 @@
|
||||
"""MCP Discovery Tool for CrewAI agents.
|
||||
|
||||
This tool enables agents to dynamically discover MCP servers based on
|
||||
natural language queries using the MCP Discovery API.
|
||||
"""
|
||||
|
||||
import json
|
||||
import logging
|
||||
import os
|
||||
from typing import Any, TypedDict
|
||||
|
||||
from crewai.tools import BaseTool, EnvVar
|
||||
from pydantic import BaseModel, Field
|
||||
import requests
|
||||
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class MCPServerMetrics(TypedDict, total=False):
|
||||
"""Performance metrics for an MCP server."""
|
||||
|
||||
avg_latency_ms: float | None
|
||||
uptime_pct: float | None
|
||||
last_checked: str | None
|
||||
|
||||
|
||||
class MCPServerRecommendation(TypedDict, total=False):
|
||||
"""A recommended MCP server from the discovery API."""
|
||||
|
||||
server: str
|
||||
npm_package: str
|
||||
install_command: str
|
||||
confidence: float
|
||||
description: str
|
||||
capabilities: list[str]
|
||||
metrics: MCPServerMetrics
|
||||
docs_url: str
|
||||
github_url: str
|
||||
|
||||
|
||||
class MCPDiscoveryResult(TypedDict):
|
||||
"""Result from the MCP Discovery API."""
|
||||
|
||||
recommendations: list[MCPServerRecommendation]
|
||||
total_found: int
|
||||
query_time_ms: int
|
||||
|
||||
|
||||
class MCPDiscoveryConstraints(BaseModel):
|
||||
"""Constraints for MCP server discovery."""
|
||||
|
||||
max_latency_ms: int | None = Field(
|
||||
default=None,
|
||||
description="Maximum acceptable latency in milliseconds",
|
||||
)
|
||||
required_features: list[str] | None = Field(
|
||||
default=None,
|
||||
description="List of required features/capabilities",
|
||||
)
|
||||
exclude_servers: list[str] | None = Field(
|
||||
default=None,
|
||||
description="List of server names to exclude from results",
|
||||
)
|
||||
|
||||
|
||||
class MCPDiscoveryToolSchema(BaseModel):
|
||||
"""Input schema for MCPDiscoveryTool."""
|
||||
|
||||
need: str = Field(
|
||||
...,
|
||||
description=(
|
||||
"Natural language description of what you need. "
|
||||
"For example: 'database with authentication', 'email automation', "
|
||||
"'file storage', 'web scraping'"
|
||||
),
|
||||
)
|
||||
constraints: MCPDiscoveryConstraints | None = Field(
|
||||
default=None,
|
||||
description="Optional constraints to filter results",
|
||||
)
|
||||
limit: int = Field(
|
||||
default=5,
|
||||
description="Maximum number of recommendations to return (1-10)",
|
||||
ge=1,
|
||||
le=10,
|
||||
)
|
||||
|
||||
|
||||
class MCPDiscoveryTool(BaseTool):
|
||||
"""Tool for discovering MCP servers dynamically.
|
||||
|
||||
This tool uses the MCP Discovery API to find MCP servers that match
|
||||
a natural language description of what the agent needs. It enables
|
||||
agents to dynamically discover and select the best MCP servers for
|
||||
their tasks without requiring pre-configuration.
|
||||
|
||||
Example:
|
||||
```python
|
||||
from crewai import Agent
|
||||
from crewai_tools import MCPDiscoveryTool
|
||||
|
||||
discovery_tool = MCPDiscoveryTool()
|
||||
|
||||
agent = Agent(
|
||||
role='Researcher',
|
||||
tools=[discovery_tool],
|
||||
goal='Research and analyze data'
|
||||
)
|
||||
|
||||
# The agent can now discover MCP servers dynamically:
|
||||
# discover_mcp_server(need="database with authentication")
|
||||
# Returns: Supabase MCP server with installation instructions
|
||||
```
|
||||
|
||||
Attributes:
|
||||
name: The name of the tool.
|
||||
description: A description of what the tool does.
|
||||
args_schema: The Pydantic model for input validation.
|
||||
base_url: The base URL for the MCP Discovery API.
|
||||
timeout: Request timeout in seconds.
|
||||
"""
|
||||
|
||||
name: str = "Discover MCP Server"
|
||||
description: str = (
|
||||
"Discover MCP (Model Context Protocol) servers that match your needs. "
|
||||
"Use this tool to find the best MCP server for any task using natural "
|
||||
"language. Returns server recommendations with installation instructions, "
|
||||
"capabilities, and performance metrics."
|
||||
)
|
||||
args_schema: type[BaseModel] = MCPDiscoveryToolSchema
|
||||
base_url: str = "https://mcp-discovery-production.up.railway.app"
|
||||
timeout: int = 30
|
||||
env_vars: list[EnvVar] = Field(
|
||||
default_factory=lambda: [
|
||||
EnvVar(
|
||||
name="MCP_DISCOVERY_API_KEY",
|
||||
description="API key for MCP Discovery (optional for free tier)",
|
||||
required=False,
|
||||
),
|
||||
]
|
||||
)
|
||||
|
||||
def _build_request_payload(
|
||||
self,
|
||||
need: str,
|
||||
constraints: MCPDiscoveryConstraints | None,
|
||||
limit: int,
|
||||
) -> dict[str, Any]:
|
||||
"""Build the request payload for the discovery API.
|
||||
|
||||
Args:
|
||||
need: Natural language description of what is needed.
|
||||
constraints: Optional constraints to filter results.
|
||||
limit: Maximum number of recommendations.
|
||||
|
||||
Returns:
|
||||
Dictionary containing the request payload.
|
||||
"""
|
||||
payload: dict[str, Any] = {
|
||||
"need": need,
|
||||
"limit": limit,
|
||||
}
|
||||
|
||||
if constraints:
|
||||
constraints_dict: dict[str, Any] = {}
|
||||
if constraints.max_latency_ms is not None:
|
||||
constraints_dict["max_latency_ms"] = constraints.max_latency_ms
|
||||
if constraints.required_features:
|
||||
constraints_dict["required_features"] = constraints.required_features
|
||||
if constraints.exclude_servers:
|
||||
constraints_dict["exclude_servers"] = constraints.exclude_servers
|
||||
if constraints_dict:
|
||||
payload["constraints"] = constraints_dict
|
||||
|
||||
return payload
|
||||
|
||||
def _make_api_request(self, payload: dict[str, Any]) -> dict[str, Any]:
|
||||
"""Make a request to the MCP Discovery API.
|
||||
|
||||
Args:
|
||||
payload: The request payload.
|
||||
|
||||
Returns:
|
||||
The API response as a dictionary.
|
||||
|
||||
Raises:
|
||||
ValueError: If the API returns an empty response.
|
||||
requests.exceptions.RequestException: If the request fails.
|
||||
"""
|
||||
url = f"{self.base_url}/api/v1/discover"
|
||||
headers = {
|
||||
"Content-Type": "application/json",
|
||||
}
|
||||
|
||||
api_key = os.environ.get("MCP_DISCOVERY_API_KEY")
|
||||
if api_key:
|
||||
headers["Authorization"] = f"Bearer {api_key}"
|
||||
|
||||
response = None
|
||||
try:
|
||||
response = requests.post(
|
||||
url,
|
||||
headers=headers,
|
||||
json=payload,
|
||||
timeout=self.timeout,
|
||||
)
|
||||
response.raise_for_status()
|
||||
results = response.json()
|
||||
if not results:
|
||||
logger.error("Empty response from MCP Discovery API")
|
||||
raise ValueError("Empty response from MCP Discovery API")
|
||||
return results
|
||||
except requests.exceptions.RequestException as e:
|
||||
error_msg = f"Error making request to MCP Discovery API: {e}"
|
||||
if response is not None and hasattr(response, "content"):
|
||||
error_msg += (
|
||||
f"\nResponse content: "
|
||||
f"{response.content.decode('utf-8', errors='replace')}"
|
||||
)
|
||||
logger.error(error_msg)
|
||||
raise
|
||||
except json.JSONDecodeError as e:
|
||||
if response is not None and hasattr(response, "content"):
|
||||
logger.error(f"Error decoding JSON response: {e}")
|
||||
logger.error(
|
||||
f"Response content: "
|
||||
f"{response.content.decode('utf-8', errors='replace')}"
|
||||
)
|
||||
else:
|
||||
logger.error(
|
||||
f"Error decoding JSON response: {e} (No response content available)"
|
||||
)
|
||||
raise
|
||||
|
||||
def _process_single_recommendation(
|
||||
self, rec: dict[str, Any]
|
||||
) -> MCPServerRecommendation | None:
|
||||
"""Process a single recommendation from the API.
|
||||
|
||||
Args:
|
||||
rec: Raw recommendation dictionary from the API.
|
||||
|
||||
Returns:
|
||||
Processed MCPServerRecommendation or None if malformed.
|
||||
"""
|
||||
try:
|
||||
metrics_data = rec.get("metrics", {}) if isinstance(rec, dict) else {}
|
||||
metrics: MCPServerMetrics = {
|
||||
"avg_latency_ms": metrics_data.get("avg_latency_ms"),
|
||||
"uptime_pct": metrics_data.get("uptime_pct"),
|
||||
"last_checked": metrics_data.get("last_checked"),
|
||||
}
|
||||
|
||||
recommendation: MCPServerRecommendation = {
|
||||
"server": rec.get("server", ""),
|
||||
"npm_package": rec.get("npm_package", ""),
|
||||
"install_command": rec.get("install_command", ""),
|
||||
"confidence": rec.get("confidence", 0.0),
|
||||
"description": rec.get("description", ""),
|
||||
"capabilities": rec.get("capabilities", []),
|
||||
"metrics": metrics,
|
||||
"docs_url": rec.get("docs_url", ""),
|
||||
"github_url": rec.get("github_url", ""),
|
||||
}
|
||||
return recommendation
|
||||
except (KeyError, TypeError, AttributeError) as e:
|
||||
logger.warning(f"Skipping malformed recommendation: {rec}, error: {e}")
|
||||
return None
|
||||
|
||||
def _process_recommendations(
|
||||
self, recommendations: list[dict[str, Any]]
|
||||
) -> list[MCPServerRecommendation]:
|
||||
"""Process and validate server recommendations.
|
||||
|
||||
Args:
|
||||
recommendations: Raw recommendations from the API.
|
||||
|
||||
Returns:
|
||||
List of processed MCPServerRecommendation objects.
|
||||
"""
|
||||
processed: list[MCPServerRecommendation] = []
|
||||
for rec in recommendations:
|
||||
result = self._process_single_recommendation(rec)
|
||||
if result is not None:
|
||||
processed.append(result)
|
||||
return processed
|
||||
|
||||
def _format_result(self, result: MCPDiscoveryResult) -> str:
|
||||
"""Format the discovery result as a human-readable string.
|
||||
|
||||
Args:
|
||||
result: The discovery result to format.
|
||||
|
||||
Returns:
|
||||
A formatted string representation of the result.
|
||||
"""
|
||||
if not result["recommendations"]:
|
||||
return "No MCP servers found matching your requirements."
|
||||
|
||||
lines = [
|
||||
f"Found {result['total_found']} MCP server(s) "
|
||||
f"(query took {result['query_time_ms']}ms):\n"
|
||||
]
|
||||
|
||||
for i, rec in enumerate(result["recommendations"], 1):
|
||||
confidence_pct = rec.get("confidence", 0) * 100
|
||||
lines.append(f"{i}. {rec.get('server', 'Unknown')} ({confidence_pct:.0f}% confidence)")
|
||||
lines.append(f" Description: {rec.get('description', 'N/A')}")
|
||||
lines.append(f" Capabilities: {', '.join(rec.get('capabilities', []))}")
|
||||
lines.append(f" Install: {rec.get('install_command', 'N/A')}")
|
||||
lines.append(f" NPM Package: {rec.get('npm_package', 'N/A')}")
|
||||
|
||||
metrics = rec.get("metrics", {})
|
||||
if metrics.get("avg_latency_ms") is not None:
|
||||
lines.append(f" Avg Latency: {metrics['avg_latency_ms']}ms")
|
||||
if metrics.get("uptime_pct") is not None:
|
||||
lines.append(f" Uptime: {metrics['uptime_pct']}%")
|
||||
|
||||
if rec.get("docs_url"):
|
||||
lines.append(f" Docs: {rec['docs_url']}")
|
||||
if rec.get("github_url"):
|
||||
lines.append(f" GitHub: {rec['github_url']}")
|
||||
lines.append("")
|
||||
|
||||
return "\n".join(lines)
|
||||
|
||||
def _run(self, **kwargs: Any) -> str:
|
||||
"""Execute the MCP discovery operation.
|
||||
|
||||
Args:
|
||||
**kwargs: Keyword arguments matching MCPDiscoveryToolSchema.
|
||||
|
||||
Returns:
|
||||
A formatted string with discovered MCP servers.
|
||||
|
||||
Raises:
|
||||
ValueError: If required parameters are missing.
|
||||
"""
|
||||
need: str | None = kwargs.get("need")
|
||||
if not need:
|
||||
raise ValueError("'need' parameter is required")
|
||||
|
||||
constraints_data = kwargs.get("constraints")
|
||||
constraints: MCPDiscoveryConstraints | None = None
|
||||
if constraints_data:
|
||||
if isinstance(constraints_data, dict):
|
||||
constraints = MCPDiscoveryConstraints(**constraints_data)
|
||||
elif isinstance(constraints_data, MCPDiscoveryConstraints):
|
||||
constraints = constraints_data
|
||||
|
||||
limit: int = kwargs.get("limit", 5)
|
||||
|
||||
payload = self._build_request_payload(need, constraints, limit)
|
||||
response = self._make_api_request(payload)
|
||||
|
||||
recommendations = self._process_recommendations(
|
||||
response.get("recommendations", [])
|
||||
)
|
||||
|
||||
result: MCPDiscoveryResult = {
|
||||
"recommendations": recommendations,
|
||||
"total_found": response.get("total_found", len(recommendations)),
|
||||
"query_time_ms": response.get("query_time_ms", 0),
|
||||
}
|
||||
|
||||
return self._format_result(result)
|
||||
|
||||
def discover(
|
||||
self,
|
||||
need: str,
|
||||
constraints: MCPDiscoveryConstraints | None = None,
|
||||
limit: int = 5,
|
||||
) -> MCPDiscoveryResult:
|
||||
"""Discover MCP servers matching the given requirements.
|
||||
|
||||
This is a convenience method that returns structured data instead
|
||||
of a formatted string.
|
||||
|
||||
Args:
|
||||
need: Natural language description of what is needed.
|
||||
constraints: Optional constraints to filter results.
|
||||
limit: Maximum number of recommendations (1-10).
|
||||
|
||||
Returns:
|
||||
MCPDiscoveryResult containing server recommendations.
|
||||
|
||||
Example:
|
||||
```python
|
||||
tool = MCPDiscoveryTool()
|
||||
result = tool.discover(
|
||||
need="database with authentication",
|
||||
constraints=MCPDiscoveryConstraints(
|
||||
max_latency_ms=200,
|
||||
required_features=["auth", "realtime"]
|
||||
),
|
||||
limit=3
|
||||
)
|
||||
for rec in result["recommendations"]:
|
||||
print(f"{rec['server']}: {rec['description']}")
|
||||
```
|
||||
"""
|
||||
payload = self._build_request_payload(need, constraints, limit)
|
||||
response = self._make_api_request(payload)
|
||||
|
||||
recommendations = self._process_recommendations(
|
||||
response.get("recommendations", [])
|
||||
)
|
||||
|
||||
return {
|
||||
"recommendations": recommendations,
|
||||
"total_found": response.get("total_found", len(recommendations)),
|
||||
"query_time_ms": response.get("query_time_ms", 0),
|
||||
}
|
||||
@@ -1 +0,0 @@
|
||||
|
||||
@@ -1,452 +0,0 @@
|
||||
"""Tests for the MCP Discovery Tool."""
|
||||
|
||||
import json
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
import pytest
|
||||
import requests
|
||||
|
||||
from crewai_tools.tools.mcp_discovery_tool.mcp_discovery_tool import (
|
||||
MCPDiscoveryConstraints,
|
||||
MCPDiscoveryResult,
|
||||
MCPDiscoveryTool,
|
||||
MCPDiscoveryToolSchema,
|
||||
MCPServerMetrics,
|
||||
MCPServerRecommendation,
|
||||
)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_api_response() -> dict:
|
||||
"""Create a mock API response for testing."""
|
||||
return {
|
||||
"recommendations": [
|
||||
{
|
||||
"server": "sqlite-server",
|
||||
"npm_package": "@modelcontextprotocol/server-sqlite",
|
||||
"install_command": "npx -y @modelcontextprotocol/server-sqlite",
|
||||
"confidence": 0.38,
|
||||
"description": "SQLite database server for MCP.",
|
||||
"capabilities": ["sqlite", "sql", "database", "embedded"],
|
||||
"metrics": {
|
||||
"avg_latency_ms": 50.0,
|
||||
"uptime_pct": 99.9,
|
||||
"last_checked": "2026-01-17T10:30:00Z",
|
||||
},
|
||||
"docs_url": "https://modelcontextprotocol.io/docs/servers/sqlite",
|
||||
"github_url": "https://github.com/modelcontextprotocol/servers",
|
||||
},
|
||||
{
|
||||
"server": "postgres-server",
|
||||
"npm_package": "@modelcontextprotocol/server-postgres",
|
||||
"install_command": "npx -y @modelcontextprotocol/server-postgres",
|
||||
"confidence": 0.33,
|
||||
"description": "PostgreSQL database server for MCP.",
|
||||
"capabilities": ["postgres", "sql", "database", "queries"],
|
||||
"metrics": {
|
||||
"avg_latency_ms": None,
|
||||
"uptime_pct": None,
|
||||
"last_checked": None,
|
||||
},
|
||||
"docs_url": "https://modelcontextprotocol.io/docs/servers/postgres",
|
||||
"github_url": "https://github.com/modelcontextprotocol/servers",
|
||||
},
|
||||
],
|
||||
"total_found": 2,
|
||||
"query_time_ms": 245,
|
||||
}
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def discovery_tool() -> MCPDiscoveryTool:
|
||||
"""Create an MCPDiscoveryTool instance for testing."""
|
||||
return MCPDiscoveryTool()
|
||||
|
||||
|
||||
class TestMCPDiscoveryToolSchema:
|
||||
"""Tests for MCPDiscoveryToolSchema."""
|
||||
|
||||
def test_schema_with_required_fields(self) -> None:
|
||||
"""Test schema with only required fields."""
|
||||
schema = MCPDiscoveryToolSchema(need="database server")
|
||||
assert schema.need == "database server"
|
||||
assert schema.constraints is None
|
||||
assert schema.limit == 5
|
||||
|
||||
def test_schema_with_all_fields(self) -> None:
|
||||
"""Test schema with all fields."""
|
||||
constraints = MCPDiscoveryConstraints(
|
||||
max_latency_ms=200,
|
||||
required_features=["auth", "realtime"],
|
||||
exclude_servers=["deprecated-server"],
|
||||
)
|
||||
schema = MCPDiscoveryToolSchema(
|
||||
need="database with authentication",
|
||||
constraints=constraints,
|
||||
limit=3,
|
||||
)
|
||||
assert schema.need == "database with authentication"
|
||||
assert schema.constraints is not None
|
||||
assert schema.constraints.max_latency_ms == 200
|
||||
assert schema.constraints.required_features == ["auth", "realtime"]
|
||||
assert schema.constraints.exclude_servers == ["deprecated-server"]
|
||||
assert schema.limit == 3
|
||||
|
||||
def test_schema_limit_validation(self) -> None:
|
||||
"""Test that limit is validated to be between 1 and 10."""
|
||||
with pytest.raises(ValueError):
|
||||
MCPDiscoveryToolSchema(need="test", limit=0)
|
||||
|
||||
with pytest.raises(ValueError):
|
||||
MCPDiscoveryToolSchema(need="test", limit=11)
|
||||
|
||||
|
||||
class TestMCPDiscoveryConstraints:
|
||||
"""Tests for MCPDiscoveryConstraints."""
|
||||
|
||||
def test_empty_constraints(self) -> None:
|
||||
"""Test creating empty constraints."""
|
||||
constraints = MCPDiscoveryConstraints()
|
||||
assert constraints.max_latency_ms is None
|
||||
assert constraints.required_features is None
|
||||
assert constraints.exclude_servers is None
|
||||
|
||||
def test_full_constraints(self) -> None:
|
||||
"""Test creating constraints with all fields."""
|
||||
constraints = MCPDiscoveryConstraints(
|
||||
max_latency_ms=100,
|
||||
required_features=["feature1", "feature2"],
|
||||
exclude_servers=["server1", "server2"],
|
||||
)
|
||||
assert constraints.max_latency_ms == 100
|
||||
assert constraints.required_features == ["feature1", "feature2"]
|
||||
assert constraints.exclude_servers == ["server1", "server2"]
|
||||
|
||||
|
||||
class TestMCPDiscoveryTool:
|
||||
"""Tests for MCPDiscoveryTool."""
|
||||
|
||||
def test_tool_initialization(self, discovery_tool: MCPDiscoveryTool) -> None:
|
||||
"""Test tool initialization with default values."""
|
||||
assert discovery_tool.name == "Discover MCP Server"
|
||||
assert "MCP" in discovery_tool.description
|
||||
assert discovery_tool.base_url == "https://mcp-discovery-production.up.railway.app"
|
||||
assert discovery_tool.timeout == 30
|
||||
|
||||
def test_build_request_payload_basic(
|
||||
self, discovery_tool: MCPDiscoveryTool
|
||||
) -> None:
|
||||
"""Test building request payload with basic parameters."""
|
||||
payload = discovery_tool._build_request_payload(
|
||||
need="database server",
|
||||
constraints=None,
|
||||
limit=5,
|
||||
)
|
||||
assert payload == {"need": "database server", "limit": 5}
|
||||
|
||||
def test_build_request_payload_with_constraints(
|
||||
self, discovery_tool: MCPDiscoveryTool
|
||||
) -> None:
|
||||
"""Test building request payload with constraints."""
|
||||
constraints = MCPDiscoveryConstraints(
|
||||
max_latency_ms=200,
|
||||
required_features=["auth"],
|
||||
exclude_servers=["old-server"],
|
||||
)
|
||||
payload = discovery_tool._build_request_payload(
|
||||
need="database",
|
||||
constraints=constraints,
|
||||
limit=3,
|
||||
)
|
||||
assert payload["need"] == "database"
|
||||
assert payload["limit"] == 3
|
||||
assert "constraints" in payload
|
||||
assert payload["constraints"]["max_latency_ms"] == 200
|
||||
assert payload["constraints"]["required_features"] == ["auth"]
|
||||
assert payload["constraints"]["exclude_servers"] == ["old-server"]
|
||||
|
||||
def test_build_request_payload_partial_constraints(
|
||||
self, discovery_tool: MCPDiscoveryTool
|
||||
) -> None:
|
||||
"""Test building request payload with partial constraints."""
|
||||
constraints = MCPDiscoveryConstraints(max_latency_ms=100)
|
||||
payload = discovery_tool._build_request_payload(
|
||||
need="test",
|
||||
constraints=constraints,
|
||||
limit=5,
|
||||
)
|
||||
assert payload["constraints"] == {"max_latency_ms": 100}
|
||||
|
||||
def test_process_recommendations(
|
||||
self, discovery_tool: MCPDiscoveryTool, mock_api_response: dict
|
||||
) -> None:
|
||||
"""Test processing recommendations from API response."""
|
||||
recommendations = discovery_tool._process_recommendations(
|
||||
mock_api_response["recommendations"]
|
||||
)
|
||||
assert len(recommendations) == 2
|
||||
|
||||
first_rec = recommendations[0]
|
||||
assert first_rec["server"] == "sqlite-server"
|
||||
assert first_rec["npm_package"] == "@modelcontextprotocol/server-sqlite"
|
||||
assert first_rec["confidence"] == 0.38
|
||||
assert first_rec["capabilities"] == ["sqlite", "sql", "database", "embedded"]
|
||||
assert first_rec["metrics"]["avg_latency_ms"] == 50.0
|
||||
assert first_rec["metrics"]["uptime_pct"] == 99.9
|
||||
|
||||
second_rec = recommendations[1]
|
||||
assert second_rec["server"] == "postgres-server"
|
||||
assert second_rec["metrics"]["avg_latency_ms"] is None
|
||||
|
||||
def test_process_recommendations_with_malformed_data(
|
||||
self, discovery_tool: MCPDiscoveryTool
|
||||
) -> None:
|
||||
"""Test processing recommendations with malformed data."""
|
||||
malformed_recommendations = [
|
||||
{"server": "valid-server", "confidence": 0.5},
|
||||
None,
|
||||
{"invalid": "data"},
|
||||
]
|
||||
recommendations = discovery_tool._process_recommendations(
|
||||
malformed_recommendations
|
||||
)
|
||||
assert len(recommendations) >= 1
|
||||
assert recommendations[0]["server"] == "valid-server"
|
||||
|
||||
def test_format_result_with_recommendations(
|
||||
self, discovery_tool: MCPDiscoveryTool
|
||||
) -> None:
|
||||
"""Test formatting results with recommendations."""
|
||||
result: MCPDiscoveryResult = {
|
||||
"recommendations": [
|
||||
{
|
||||
"server": "test-server",
|
||||
"npm_package": "@test/server",
|
||||
"install_command": "npx -y @test/server",
|
||||
"confidence": 0.85,
|
||||
"description": "A test server",
|
||||
"capabilities": ["test", "demo"],
|
||||
"metrics": {
|
||||
"avg_latency_ms": 100.0,
|
||||
"uptime_pct": 99.5,
|
||||
"last_checked": "2026-01-17T10:00:00Z",
|
||||
},
|
||||
"docs_url": "https://example.com/docs",
|
||||
"github_url": "https://github.com/test/server",
|
||||
}
|
||||
],
|
||||
"total_found": 1,
|
||||
"query_time_ms": 150,
|
||||
}
|
||||
formatted = discovery_tool._format_result(result)
|
||||
assert "Found 1 MCP server(s)" in formatted
|
||||
assert "test-server" in formatted
|
||||
assert "85% confidence" in formatted
|
||||
assert "A test server" in formatted
|
||||
assert "test, demo" in formatted
|
||||
assert "npx -y @test/server" in formatted
|
||||
assert "100.0ms" in formatted
|
||||
assert "99.5%" in formatted
|
||||
|
||||
def test_format_result_empty(self, discovery_tool: MCPDiscoveryTool) -> None:
|
||||
"""Test formatting results with no recommendations."""
|
||||
result: MCPDiscoveryResult = {
|
||||
"recommendations": [],
|
||||
"total_found": 0,
|
||||
"query_time_ms": 50,
|
||||
}
|
||||
formatted = discovery_tool._format_result(result)
|
||||
assert "No MCP servers found" in formatted
|
||||
|
||||
@patch("crewai_tools.tools.mcp_discovery_tool.mcp_discovery_tool.requests.post")
|
||||
def test_make_api_request_success(
|
||||
self,
|
||||
mock_post: MagicMock,
|
||||
discovery_tool: MCPDiscoveryTool,
|
||||
mock_api_response: dict,
|
||||
) -> None:
|
||||
"""Test successful API request."""
|
||||
mock_response = MagicMock()
|
||||
mock_response.json.return_value = mock_api_response
|
||||
mock_response.raise_for_status.return_value = None
|
||||
mock_post.return_value = mock_response
|
||||
|
||||
result = discovery_tool._make_api_request({"need": "database", "limit": 5})
|
||||
|
||||
assert result == mock_api_response
|
||||
mock_post.assert_called_once()
|
||||
call_args = mock_post.call_args
|
||||
assert call_args[1]["json"] == {"need": "database", "limit": 5}
|
||||
assert call_args[1]["timeout"] == 30
|
||||
|
||||
@patch("crewai_tools.tools.mcp_discovery_tool.mcp_discovery_tool.requests.post")
|
||||
def test_make_api_request_with_api_key(
|
||||
self,
|
||||
mock_post: MagicMock,
|
||||
discovery_tool: MCPDiscoveryTool,
|
||||
mock_api_response: dict,
|
||||
) -> None:
|
||||
"""Test API request with API key."""
|
||||
mock_response = MagicMock()
|
||||
mock_response.json.return_value = mock_api_response
|
||||
mock_response.raise_for_status.return_value = None
|
||||
mock_post.return_value = mock_response
|
||||
|
||||
with patch.dict("os.environ", {"MCP_DISCOVERY_API_KEY": "test-key"}):
|
||||
discovery_tool._make_api_request({"need": "test", "limit": 5})
|
||||
|
||||
call_args = mock_post.call_args
|
||||
assert "Authorization" in call_args[1]["headers"]
|
||||
assert call_args[1]["headers"]["Authorization"] == "Bearer test-key"
|
||||
|
||||
@patch("crewai_tools.tools.mcp_discovery_tool.mcp_discovery_tool.requests.post")
|
||||
def test_make_api_request_empty_response(
|
||||
self, mock_post: MagicMock, discovery_tool: MCPDiscoveryTool
|
||||
) -> None:
|
||||
"""Test API request with empty response."""
|
||||
mock_response = MagicMock()
|
||||
mock_response.json.return_value = {}
|
||||
mock_response.raise_for_status.return_value = None
|
||||
mock_post.return_value = mock_response
|
||||
|
||||
with pytest.raises(ValueError, match="Empty response"):
|
||||
discovery_tool._make_api_request({"need": "test", "limit": 5})
|
||||
|
||||
@patch("crewai_tools.tools.mcp_discovery_tool.mcp_discovery_tool.requests.post")
|
||||
def test_make_api_request_network_error(
|
||||
self, mock_post: MagicMock, discovery_tool: MCPDiscoveryTool
|
||||
) -> None:
|
||||
"""Test API request with network error."""
|
||||
mock_post.side_effect = requests.exceptions.ConnectionError("Network error")
|
||||
|
||||
with pytest.raises(requests.exceptions.ConnectionError):
|
||||
discovery_tool._make_api_request({"need": "test", "limit": 5})
|
||||
|
||||
@patch("crewai_tools.tools.mcp_discovery_tool.mcp_discovery_tool.requests.post")
|
||||
def test_make_api_request_json_decode_error(
|
||||
self, mock_post: MagicMock, discovery_tool: MCPDiscoveryTool
|
||||
) -> None:
|
||||
"""Test API request with JSON decode error."""
|
||||
mock_response = MagicMock()
|
||||
mock_response.json.side_effect = json.JSONDecodeError("Error", "", 0)
|
||||
mock_response.raise_for_status.return_value = None
|
||||
mock_response.content = b"invalid json"
|
||||
mock_post.return_value = mock_response
|
||||
|
||||
with pytest.raises(json.JSONDecodeError):
|
||||
discovery_tool._make_api_request({"need": "test", "limit": 5})
|
||||
|
||||
@patch("crewai_tools.tools.mcp_discovery_tool.mcp_discovery_tool.requests.post")
|
||||
def test_run_success(
|
||||
self,
|
||||
mock_post: MagicMock,
|
||||
discovery_tool: MCPDiscoveryTool,
|
||||
mock_api_response: dict,
|
||||
) -> None:
|
||||
"""Test successful _run execution."""
|
||||
mock_response = MagicMock()
|
||||
mock_response.json.return_value = mock_api_response
|
||||
mock_response.raise_for_status.return_value = None
|
||||
mock_post.return_value = mock_response
|
||||
|
||||
result = discovery_tool._run(need="database server")
|
||||
|
||||
assert "sqlite-server" in result
|
||||
assert "postgres-server" in result
|
||||
assert "Found 2 MCP server(s)" in result
|
||||
|
||||
@patch("crewai_tools.tools.mcp_discovery_tool.mcp_discovery_tool.requests.post")
|
||||
def test_run_with_constraints(
|
||||
self,
|
||||
mock_post: MagicMock,
|
||||
discovery_tool: MCPDiscoveryTool,
|
||||
mock_api_response: dict,
|
||||
) -> None:
|
||||
"""Test _run with constraints."""
|
||||
mock_response = MagicMock()
|
||||
mock_response.json.return_value = mock_api_response
|
||||
mock_response.raise_for_status.return_value = None
|
||||
mock_post.return_value = mock_response
|
||||
|
||||
result = discovery_tool._run(
|
||||
need="database",
|
||||
constraints={"max_latency_ms": 100, "required_features": ["sql"]},
|
||||
limit=3,
|
||||
)
|
||||
|
||||
assert "sqlite-server" in result
|
||||
call_args = mock_post.call_args
|
||||
payload = call_args[1]["json"]
|
||||
assert payload["constraints"]["max_latency_ms"] == 100
|
||||
assert payload["constraints"]["required_features"] == ["sql"]
|
||||
assert payload["limit"] == 3
|
||||
|
||||
def test_run_missing_need_parameter(
|
||||
self, discovery_tool: MCPDiscoveryTool
|
||||
) -> None:
|
||||
"""Test _run with missing need parameter."""
|
||||
with pytest.raises(ValueError, match="'need' parameter is required"):
|
||||
discovery_tool._run()
|
||||
|
||||
@patch("crewai_tools.tools.mcp_discovery_tool.mcp_discovery_tool.requests.post")
|
||||
def test_discover_method(
|
||||
self,
|
||||
mock_post: MagicMock,
|
||||
discovery_tool: MCPDiscoveryTool,
|
||||
mock_api_response: dict,
|
||||
) -> None:
|
||||
"""Test the discover convenience method."""
|
||||
mock_response = MagicMock()
|
||||
mock_response.json.return_value = mock_api_response
|
||||
mock_response.raise_for_status.return_value = None
|
||||
mock_post.return_value = mock_response
|
||||
|
||||
result = discovery_tool.discover(
|
||||
need="database",
|
||||
constraints=MCPDiscoveryConstraints(max_latency_ms=200),
|
||||
limit=3,
|
||||
)
|
||||
|
||||
assert isinstance(result, dict)
|
||||
assert "recommendations" in result
|
||||
assert "total_found" in result
|
||||
assert "query_time_ms" in result
|
||||
assert len(result["recommendations"]) == 2
|
||||
assert result["total_found"] == 2
|
||||
|
||||
@patch("crewai_tools.tools.mcp_discovery_tool.mcp_discovery_tool.requests.post")
|
||||
def test_discover_returns_structured_data(
|
||||
self,
|
||||
mock_post: MagicMock,
|
||||
discovery_tool: MCPDiscoveryTool,
|
||||
mock_api_response: dict,
|
||||
) -> None:
|
||||
"""Test that discover returns properly structured data."""
|
||||
mock_response = MagicMock()
|
||||
mock_response.json.return_value = mock_api_response
|
||||
mock_response.raise_for_status.return_value = None
|
||||
mock_post.return_value = mock_response
|
||||
|
||||
result = discovery_tool.discover(need="database")
|
||||
|
||||
first_rec = result["recommendations"][0]
|
||||
assert "server" in first_rec
|
||||
assert "npm_package" in first_rec
|
||||
assert "install_command" in first_rec
|
||||
assert "confidence" in first_rec
|
||||
assert "description" in first_rec
|
||||
assert "capabilities" in first_rec
|
||||
assert "metrics" in first_rec
|
||||
assert "docs_url" in first_rec
|
||||
assert "github_url" in first_rec
|
||||
|
||||
|
||||
class TestMCPDiscoveryToolIntegration:
|
||||
"""Integration tests for MCPDiscoveryTool (requires network)."""
|
||||
|
||||
@pytest.mark.skip(reason="Integration test - requires network access")
|
||||
def test_real_api_call(self) -> None:
|
||||
"""Test actual API call to MCP Discovery service."""
|
||||
tool = MCPDiscoveryTool()
|
||||
result = tool._run(need="database", limit=3)
|
||||
assert "MCP server" in result or "No MCP servers found" in result
|
||||
@@ -60,6 +60,7 @@ class PlusAPI:
|
||||
description: str | None,
|
||||
encoded_file: str,
|
||||
available_exports: list[dict[str, Any]] | None = None,
|
||||
tools_metadata: list[dict[str, Any]] | None = None,
|
||||
) -> requests.Response:
|
||||
params = {
|
||||
"handle": handle,
|
||||
@@ -68,6 +69,7 @@ class PlusAPI:
|
||||
"file": encoded_file,
|
||||
"description": description,
|
||||
"available_exports": available_exports,
|
||||
"tools_metadata": {"package": handle, "tools": tools_metadata} if tools_metadata else None,
|
||||
}
|
||||
return self._make_request("POST", f"{self.TOOLS_RESOURCE}", json=params)
|
||||
|
||||
|
||||
@@ -16,6 +16,7 @@ from crewai.cli.constants import DEFAULT_CREWAI_ENTERPRISE_URL
|
||||
from crewai.cli.utils import (
|
||||
build_env_with_tool_repository_credentials,
|
||||
extract_available_exports,
|
||||
extract_tools_metadata,
|
||||
get_project_description,
|
||||
get_project_name,
|
||||
get_project_version,
|
||||
@@ -94,6 +95,18 @@ class ToolCommand(BaseCommand, PlusAPIMixin):
|
||||
console.print(
|
||||
f"[green]Found these tools to publish: {', '.join([e['name'] for e in available_exports])}[/green]"
|
||||
)
|
||||
|
||||
console.print("[bold blue]Extracting tool metadata...[/bold blue]")
|
||||
try:
|
||||
tools_metadata = extract_tools_metadata()
|
||||
except Exception as e:
|
||||
console.print(
|
||||
f"[yellow]Warning: Could not extract tool metadata: {e}[/yellow]\n"
|
||||
f"Publishing will continue without detailed metadata."
|
||||
)
|
||||
tools_metadata = []
|
||||
|
||||
self._print_tools_preview(tools_metadata)
|
||||
self._print_current_organization()
|
||||
|
||||
with tempfile.TemporaryDirectory() as temp_build_dir:
|
||||
@@ -111,7 +124,7 @@ class ToolCommand(BaseCommand, PlusAPIMixin):
|
||||
"Project build failed. Please ensure that the command `uv build --sdist` completes successfully.",
|
||||
style="bold red",
|
||||
)
|
||||
raise SystemExit
|
||||
raise SystemExit(1)
|
||||
|
||||
tarball_path = os.path.join(temp_build_dir, tarball_filename)
|
||||
with open(tarball_path, "rb") as file:
|
||||
@@ -127,6 +140,7 @@ class ToolCommand(BaseCommand, PlusAPIMixin):
|
||||
description=project_description,
|
||||
encoded_file=f"data:application/x-gzip;base64,{encoded_tarball}",
|
||||
available_exports=available_exports,
|
||||
tools_metadata=tools_metadata,
|
||||
)
|
||||
|
||||
self._validate_response(publish_response)
|
||||
@@ -237,6 +251,41 @@ class ToolCommand(BaseCommand, PlusAPIMixin):
|
||||
)
|
||||
raise SystemExit
|
||||
|
||||
def _print_tools_preview(self, tools_metadata: list[dict[str, Any]]) -> None:
|
||||
if not tools_metadata:
|
||||
console.print("[yellow]No tool metadata extracted.[/yellow]")
|
||||
return
|
||||
|
||||
console.print(f"\n[bold]Tools to be published ({len(tools_metadata)}):[/bold]\n")
|
||||
|
||||
for tool in tools_metadata:
|
||||
console.print(f" [bold cyan]{tool.get('name', 'Unknown')}[/bold cyan]")
|
||||
if tool.get("module"):
|
||||
console.print(f" Module: {tool.get('module')}")
|
||||
console.print(f" Name: {tool.get('humanized_name', 'N/A')}")
|
||||
console.print(f" Description: {tool.get('description', 'N/A')[:80]}{'...' if len(tool.get('description', '')) > 80 else ''}")
|
||||
|
||||
init_params = tool.get("init_params_schema", {}).get("properties", {})
|
||||
if init_params:
|
||||
required = tool.get("init_params_schema", {}).get("required", [])
|
||||
console.print(" Init parameters:")
|
||||
for param_name, param_info in init_params.items():
|
||||
param_type = param_info.get("type", "any")
|
||||
is_required = param_name in required
|
||||
req_marker = "[red]*[/red]" if is_required else ""
|
||||
default = f" = {param_info['default']}" if "default" in param_info else ""
|
||||
console.print(f" - {param_name}: {param_type}{default} {req_marker}")
|
||||
|
||||
env_vars = tool.get("env_vars", [])
|
||||
if env_vars:
|
||||
console.print(" Environment variables:")
|
||||
for env_var in env_vars:
|
||||
req_marker = "[red]*[/red]" if env_var.get("required") else ""
|
||||
default = f" (default: {env_var['default']})" if env_var.get("default") else ""
|
||||
console.print(f" - {env_var['name']}: {env_var.get('description', 'N/A')}{default} {req_marker}")
|
||||
|
||||
console.print()
|
||||
|
||||
def _print_current_organization(self) -> None:
|
||||
settings = Settings()
|
||||
if settings.org_uuid:
|
||||
|
||||
@@ -1,5 +1,8 @@
|
||||
from functools import reduce
|
||||
from collections.abc import Mapping
|
||||
from functools import lru_cache, reduce
|
||||
import hashlib
|
||||
import importlib.util
|
||||
import inspect
|
||||
from inspect import getmro, isclass, isfunction, ismethod
|
||||
import os
|
||||
from pathlib import Path
|
||||
@@ -511,3 +514,217 @@ def _print_no_tools_warning() -> None:
|
||||
" # ... implementation\n"
|
||||
" return result\n"
|
||||
)
|
||||
|
||||
|
||||
def extract_tools_metadata(dir_path: str = "src") -> list[dict[str, Any]]:
|
||||
"""
|
||||
Extract rich metadata from tool classes in the project.
|
||||
|
||||
Returns a list of tool metadata dictionaries containing:
|
||||
- name: Class name
|
||||
- humanized_name: From name field default
|
||||
- description: From description field default
|
||||
- run_params_schema: JSON Schema for _run() params (from args_schema)
|
||||
- init_params_schema: JSON Schema for __init__ params (filtered)
|
||||
- env_vars: List of environment variable dicts
|
||||
"""
|
||||
tools_metadata: list[dict[str, Any]] = []
|
||||
|
||||
for init_file in Path(dir_path).glob("**/__init__.py"):
|
||||
tools = _extract_tool_metadata_from_init(init_file)
|
||||
tools_metadata.extend(tools)
|
||||
|
||||
return tools_metadata
|
||||
|
||||
|
||||
def _extract_tool_metadata_from_init(init_file: Path) -> list[dict[str, Any]]:
|
||||
"""
|
||||
Load module from init file and extract metadata from valid tool classes.
|
||||
"""
|
||||
from crewai.tools.base_tool import BaseTool
|
||||
|
||||
module_name = f"temp_metadata_{hashlib.md5(str(init_file).encode()).hexdigest()[:8]}"
|
||||
spec = importlib.util.spec_from_file_location(module_name, init_file)
|
||||
|
||||
if not spec or not spec.loader:
|
||||
return []
|
||||
|
||||
module = importlib.util.module_from_spec(spec)
|
||||
sys.modules[module_name] = module
|
||||
|
||||
try:
|
||||
spec.loader.exec_module(module)
|
||||
|
||||
exported_names = getattr(module, "__all__", None)
|
||||
if not exported_names:
|
||||
return []
|
||||
|
||||
tools_metadata = []
|
||||
for name in exported_names:
|
||||
obj = getattr(module, name, None)
|
||||
if obj is None or not (inspect.isclass(obj) and issubclass(obj, BaseTool)):
|
||||
continue
|
||||
if tool_info := _extract_single_tool_metadata(obj):
|
||||
tools_metadata.append(tool_info)
|
||||
|
||||
return tools_metadata
|
||||
|
||||
except Exception as e:
|
||||
console.print(
|
||||
f"[yellow]Warning: Could not extract metadata from {init_file}: {e}[/yellow]"
|
||||
)
|
||||
return []
|
||||
|
||||
finally:
|
||||
sys.modules.pop(module_name, None)
|
||||
|
||||
|
||||
def _extract_single_tool_metadata(tool_class: type) -> dict[str, Any] | None:
|
||||
"""
|
||||
Extract metadata from a single tool class.
|
||||
"""
|
||||
try:
|
||||
core_schema = tool_class.__pydantic_core_schema__
|
||||
if not core_schema:
|
||||
return None
|
||||
|
||||
schema = _unwrap_schema(core_schema)
|
||||
fields = schema.get("schema", {}).get("fields", {})
|
||||
|
||||
try:
|
||||
file_path = inspect.getfile(tool_class)
|
||||
relative_path = Path(file_path).relative_to(Path.cwd())
|
||||
module_path = relative_path.with_suffix("")
|
||||
if module_path.parts[0] == "src":
|
||||
module_path = Path(*module_path.parts[1:])
|
||||
module = ".".join(module_path.parts)
|
||||
except (TypeError, ValueError):
|
||||
module = tool_class.__module__
|
||||
|
||||
return {
|
||||
"name": tool_class.__name__,
|
||||
"module": module,
|
||||
"humanized_name": _extract_field_default(
|
||||
fields.get("name"), fallback=tool_class.__name__
|
||||
),
|
||||
"description": str(
|
||||
_extract_field_default(fields.get("description"))
|
||||
).strip(),
|
||||
"run_params_schema": _extract_run_params_schema(fields.get("args_schema")),
|
||||
"init_params_schema": _extract_init_params_schema(tool_class),
|
||||
"env_vars": _extract_env_vars(fields.get("env_vars")),
|
||||
}
|
||||
|
||||
except Exception:
|
||||
return None
|
||||
|
||||
|
||||
def _unwrap_schema(schema: Mapping[str, Any] | dict[str, Any]) -> dict[str, Any]:
|
||||
"""
|
||||
Unwrap nested schema structures to get to the actual schema definition.
|
||||
"""
|
||||
result: dict[str, Any] = dict(schema)
|
||||
while result.get("type") in {"function-after", "default"} and "schema" in result:
|
||||
result = dict(result["schema"])
|
||||
return result
|
||||
|
||||
|
||||
def _extract_field_default(
|
||||
field: dict[str, Any] | None, fallback: str | list[Any] = ""
|
||||
) -> str | list[Any] | int:
|
||||
"""
|
||||
Extract the default value from a field schema.
|
||||
"""
|
||||
if not field:
|
||||
return fallback
|
||||
|
||||
schema = field.get("schema", {})
|
||||
default = schema.get("default")
|
||||
return default if isinstance(default, (list, str, int)) else fallback
|
||||
|
||||
|
||||
@lru_cache(maxsize=1)
|
||||
def _get_schema_generator() -> type:
|
||||
"""Get a SchemaGenerator that omits non-serializable defaults."""
|
||||
from pydantic.json_schema import GenerateJsonSchema
|
||||
from pydantic_core import PydanticOmit
|
||||
|
||||
class SchemaGenerator(GenerateJsonSchema):
|
||||
def handle_invalid_for_json_schema(
|
||||
self, schema: Any, error_info: Any
|
||||
) -> dict[str, Any]:
|
||||
raise PydanticOmit
|
||||
|
||||
return SchemaGenerator
|
||||
|
||||
|
||||
def _extract_run_params_schema(args_schema_field: dict[str, Any] | None) -> dict[str, Any]:
|
||||
"""
|
||||
Extract JSON Schema for the tool's run parameters from args_schema field.
|
||||
"""
|
||||
from pydantic import BaseModel
|
||||
|
||||
if not args_schema_field:
|
||||
return {}
|
||||
|
||||
args_schema_class = args_schema_field.get("schema", {}).get("default")
|
||||
if not (inspect.isclass(args_schema_class) and issubclass(args_schema_class, BaseModel)):
|
||||
return {}
|
||||
|
||||
try:
|
||||
return args_schema_class.model_json_schema(schema_generator=_get_schema_generator())
|
||||
except Exception:
|
||||
return {}
|
||||
|
||||
|
||||
_IGNORED_INIT_PARAMS = frozenset({
|
||||
"name",
|
||||
"description",
|
||||
"env_vars",
|
||||
"args_schema",
|
||||
"description_updated",
|
||||
"cache_function",
|
||||
"result_as_answer",
|
||||
"max_usage_count",
|
||||
"current_usage_count",
|
||||
"package_dependencies",
|
||||
})
|
||||
|
||||
|
||||
def _extract_init_params_schema(tool_class: type) -> dict[str, Any]:
|
||||
"""
|
||||
Extract JSON Schema for the tool's __init__ parameters, filtering out base fields.
|
||||
"""
|
||||
try:
|
||||
json_schema = tool_class.model_json_schema(
|
||||
schema_generator=_get_schema_generator(), mode="serialization"
|
||||
)
|
||||
json_schema["properties"] = {
|
||||
key: value
|
||||
for key, value in json_schema.get("properties", {}).items()
|
||||
if key not in _IGNORED_INIT_PARAMS
|
||||
}
|
||||
return json_schema
|
||||
except Exception:
|
||||
return {}
|
||||
|
||||
|
||||
def _extract_env_vars(env_vars_field: dict[str, Any] | None) -> list[dict[str, Any]]:
|
||||
"""
|
||||
Extract environment variable definitions from env_vars field.
|
||||
"""
|
||||
from crewai.tools.base_tool import EnvVar
|
||||
|
||||
if not env_vars_field:
|
||||
return []
|
||||
|
||||
return [
|
||||
{
|
||||
"name": env_var.name,
|
||||
"description": env_var.description,
|
||||
"required": env_var.required,
|
||||
"default": env_var.default,
|
||||
}
|
||||
for env_var in env_vars_field.get("schema", {}).get("default", [])
|
||||
if isinstance(env_var, EnvVar)
|
||||
]
|
||||
|
||||
@@ -152,6 +152,7 @@ class TestPlusAPI(unittest.TestCase):
|
||||
"file": encoded_file,
|
||||
"description": description,
|
||||
"available_exports": None,
|
||||
"tools_metadata": None,
|
||||
}
|
||||
mock_make_request.assert_called_once_with(
|
||||
"POST", "/crewai_plus/api/v1/tools", json=params
|
||||
@@ -190,6 +191,7 @@ class TestPlusAPI(unittest.TestCase):
|
||||
"file": encoded_file,
|
||||
"description": description,
|
||||
"available_exports": None,
|
||||
"tools_metadata": None,
|
||||
}
|
||||
|
||||
self.assert_request_with_org_id(
|
||||
@@ -218,6 +220,48 @@ class TestPlusAPI(unittest.TestCase):
|
||||
"file": encoded_file,
|
||||
"description": description,
|
||||
"available_exports": None,
|
||||
"tools_metadata": None,
|
||||
}
|
||||
mock_make_request.assert_called_once_with(
|
||||
"POST", "/crewai_plus/api/v1/tools", json=params
|
||||
)
|
||||
self.assertEqual(response, mock_response)
|
||||
|
||||
@patch("crewai.cli.plus_api.PlusAPI._make_request")
|
||||
def test_publish_tool_with_tools_metadata(self, mock_make_request):
|
||||
mock_response = MagicMock()
|
||||
mock_make_request.return_value = mock_response
|
||||
handle = "test_tool_handle"
|
||||
public = True
|
||||
version = "1.0.0"
|
||||
description = "Test tool description"
|
||||
encoded_file = "encoded_test_file"
|
||||
available_exports = [{"name": "MyTool"}]
|
||||
tools_metadata = [
|
||||
{
|
||||
"name": "MyTool",
|
||||
"humanized_name": "my_tool",
|
||||
"description": "A test tool",
|
||||
"run_params_schema": {"type": "object", "properties": {}},
|
||||
"init_params_schema": {"type": "object", "properties": {}},
|
||||
"env_vars": [{"name": "API_KEY", "description": "API key", "required": True, "default": None}],
|
||||
}
|
||||
]
|
||||
|
||||
response = self.api.publish_tool(
|
||||
handle, public, version, description, encoded_file,
|
||||
available_exports=available_exports,
|
||||
tools_metadata=tools_metadata,
|
||||
)
|
||||
|
||||
params = {
|
||||
"handle": handle,
|
||||
"public": public,
|
||||
"version": version,
|
||||
"file": encoded_file,
|
||||
"description": description,
|
||||
"available_exports": available_exports,
|
||||
"tools_metadata": {"package": handle, "tools": tools_metadata},
|
||||
}
|
||||
mock_make_request.assert_called_once_with(
|
||||
"POST", "/crewai_plus/api/v1/tools", json=params
|
||||
|
||||
@@ -363,3 +363,261 @@ def test_get_crews_ignores_template_directories(
|
||||
utils.get_crews()
|
||||
|
||||
assert not template_crew_detected
|
||||
|
||||
|
||||
# Tests for extract_tools_metadata
|
||||
|
||||
|
||||
def test_extract_tools_metadata_empty_project(temp_project_dir):
|
||||
"""Test that extract_tools_metadata returns empty list for empty project."""
|
||||
metadata = utils.extract_tools_metadata(dir_path=str(temp_project_dir))
|
||||
assert metadata == []
|
||||
|
||||
|
||||
def test_extract_tools_metadata_no_init_file(temp_project_dir):
|
||||
"""Test that extract_tools_metadata returns empty list when no __init__.py exists."""
|
||||
(temp_project_dir / "some_file.py").write_text("print('hello')")
|
||||
metadata = utils.extract_tools_metadata(dir_path=str(temp_project_dir))
|
||||
assert metadata == []
|
||||
|
||||
|
||||
def test_extract_tools_metadata_empty_init_file(temp_project_dir):
|
||||
"""Test that extract_tools_metadata returns empty list for empty __init__.py."""
|
||||
create_init_file(temp_project_dir, "")
|
||||
metadata = utils.extract_tools_metadata(dir_path=str(temp_project_dir))
|
||||
assert metadata == []
|
||||
|
||||
|
||||
def test_extract_tools_metadata_no_all_variable(temp_project_dir):
|
||||
"""Test that extract_tools_metadata returns empty list when __all__ is not defined."""
|
||||
create_init_file(
|
||||
temp_project_dir,
|
||||
"from crewai.tools import BaseTool\n\nclass MyTool(BaseTool):\n pass",
|
||||
)
|
||||
metadata = utils.extract_tools_metadata(dir_path=str(temp_project_dir))
|
||||
assert metadata == []
|
||||
|
||||
|
||||
def test_extract_tools_metadata_valid_base_tool_class(temp_project_dir):
|
||||
"""Test that extract_tools_metadata extracts metadata from a valid BaseTool class."""
|
||||
create_init_file(
|
||||
temp_project_dir,
|
||||
"""from crewai.tools import BaseTool
|
||||
|
||||
class MyTool(BaseTool):
|
||||
name: str = "my_tool"
|
||||
description: str = "A test tool"
|
||||
|
||||
__all__ = ['MyTool']
|
||||
""",
|
||||
)
|
||||
metadata = utils.extract_tools_metadata(dir_path=str(temp_project_dir))
|
||||
assert len(metadata) == 1
|
||||
assert metadata[0]["name"] == "MyTool"
|
||||
assert metadata[0]["humanized_name"] == "my_tool"
|
||||
assert metadata[0]["description"] == "A test tool"
|
||||
|
||||
|
||||
def test_extract_tools_metadata_with_args_schema(temp_project_dir):
|
||||
"""Test that extract_tools_metadata extracts run_params_schema from args_schema."""
|
||||
create_init_file(
|
||||
temp_project_dir,
|
||||
"""from crewai.tools import BaseTool
|
||||
from pydantic import BaseModel
|
||||
|
||||
class MyToolInput(BaseModel):
|
||||
query: str
|
||||
limit: int = 10
|
||||
|
||||
class MyTool(BaseTool):
|
||||
name: str = "my_tool"
|
||||
description: str = "A test tool"
|
||||
args_schema: type[BaseModel] = MyToolInput
|
||||
|
||||
__all__ = ['MyTool']
|
||||
""",
|
||||
)
|
||||
metadata = utils.extract_tools_metadata(dir_path=str(temp_project_dir))
|
||||
assert len(metadata) == 1
|
||||
assert metadata[0]["name"] == "MyTool"
|
||||
run_params = metadata[0]["run_params_schema"]
|
||||
assert "properties" in run_params
|
||||
assert "query" in run_params["properties"]
|
||||
assert "limit" in run_params["properties"]
|
||||
|
||||
|
||||
def test_extract_tools_metadata_with_env_vars(temp_project_dir):
|
||||
"""Test that extract_tools_metadata extracts env_vars."""
|
||||
create_init_file(
|
||||
temp_project_dir,
|
||||
"""from crewai.tools import BaseTool
|
||||
from crewai.tools.base_tool import EnvVar
|
||||
|
||||
class MyTool(BaseTool):
|
||||
name: str = "my_tool"
|
||||
description: str = "A test tool"
|
||||
env_vars: list[EnvVar] = [
|
||||
EnvVar(name="MY_API_KEY", description="API key for service", required=True),
|
||||
EnvVar(name="MY_OPTIONAL_VAR", description="Optional var", required=False, default="default_value"),
|
||||
]
|
||||
|
||||
__all__ = ['MyTool']
|
||||
""",
|
||||
)
|
||||
metadata = utils.extract_tools_metadata(dir_path=str(temp_project_dir))
|
||||
assert len(metadata) == 1
|
||||
env_vars = metadata[0]["env_vars"]
|
||||
assert len(env_vars) == 2
|
||||
assert env_vars[0]["name"] == "MY_API_KEY"
|
||||
assert env_vars[0]["description"] == "API key for service"
|
||||
assert env_vars[0]["required"] is True
|
||||
assert env_vars[1]["name"] == "MY_OPTIONAL_VAR"
|
||||
assert env_vars[1]["required"] is False
|
||||
assert env_vars[1]["default"] == "default_value"
|
||||
|
||||
|
||||
def test_extract_tools_metadata_with_custom_init_params(temp_project_dir):
|
||||
"""Test that extract_tools_metadata extracts init_params_schema with custom params."""
|
||||
create_init_file(
|
||||
temp_project_dir,
|
||||
"""from crewai.tools import BaseTool
|
||||
|
||||
class MyTool(BaseTool):
|
||||
name: str = "my_tool"
|
||||
description: str = "A test tool"
|
||||
api_endpoint: str = "https://api.example.com"
|
||||
timeout: int = 30
|
||||
|
||||
__all__ = ['MyTool']
|
||||
""",
|
||||
)
|
||||
metadata = utils.extract_tools_metadata(dir_path=str(temp_project_dir))
|
||||
assert len(metadata) == 1
|
||||
init_params = metadata[0]["init_params_schema"]
|
||||
assert "properties" in init_params
|
||||
# Custom params should be included
|
||||
assert "api_endpoint" in init_params["properties"]
|
||||
assert "timeout" in init_params["properties"]
|
||||
# Base params should be filtered out
|
||||
assert "name" not in init_params["properties"]
|
||||
assert "description" not in init_params["properties"]
|
||||
|
||||
|
||||
def test_extract_tools_metadata_multiple_tools(temp_project_dir):
|
||||
"""Test that extract_tools_metadata extracts metadata from multiple tools."""
|
||||
create_init_file(
|
||||
temp_project_dir,
|
||||
"""from crewai.tools import BaseTool
|
||||
|
||||
class FirstTool(BaseTool):
|
||||
name: str = "first_tool"
|
||||
description: str = "First test tool"
|
||||
|
||||
class SecondTool(BaseTool):
|
||||
name: str = "second_tool"
|
||||
description: str = "Second test tool"
|
||||
|
||||
__all__ = ['FirstTool', 'SecondTool']
|
||||
""",
|
||||
)
|
||||
metadata = utils.extract_tools_metadata(dir_path=str(temp_project_dir))
|
||||
assert len(metadata) == 2
|
||||
names = [m["name"] for m in metadata]
|
||||
assert "FirstTool" in names
|
||||
assert "SecondTool" in names
|
||||
|
||||
|
||||
def test_extract_tools_metadata_multiple_init_files(temp_project_dir):
|
||||
"""Test that extract_tools_metadata extracts metadata from multiple __init__.py files."""
|
||||
# Create tool in root __init__.py
|
||||
create_init_file(
|
||||
temp_project_dir,
|
||||
"""from crewai.tools import BaseTool
|
||||
|
||||
class RootTool(BaseTool):
|
||||
name: str = "root_tool"
|
||||
description: str = "Root tool"
|
||||
|
||||
__all__ = ['RootTool']
|
||||
""",
|
||||
)
|
||||
|
||||
# Create nested package with another tool
|
||||
nested_dir = temp_project_dir / "nested"
|
||||
nested_dir.mkdir()
|
||||
create_init_file(
|
||||
nested_dir,
|
||||
"""from crewai.tools import BaseTool
|
||||
|
||||
class NestedTool(BaseTool):
|
||||
name: str = "nested_tool"
|
||||
description: str = "Nested tool"
|
||||
|
||||
__all__ = ['NestedTool']
|
||||
""",
|
||||
)
|
||||
|
||||
metadata = utils.extract_tools_metadata(dir_path=str(temp_project_dir))
|
||||
assert len(metadata) == 2
|
||||
names = [m["name"] for m in metadata]
|
||||
assert "RootTool" in names
|
||||
assert "NestedTool" in names
|
||||
|
||||
|
||||
def test_extract_tools_metadata_ignores_non_tool_exports(temp_project_dir):
|
||||
"""Test that extract_tools_metadata ignores non-BaseTool exports."""
|
||||
create_init_file(
|
||||
temp_project_dir,
|
||||
"""from crewai.tools import BaseTool
|
||||
|
||||
class MyTool(BaseTool):
|
||||
name: str = "my_tool"
|
||||
description: str = "A test tool"
|
||||
|
||||
def not_a_tool():
|
||||
pass
|
||||
|
||||
SOME_CONSTANT = "value"
|
||||
|
||||
__all__ = ['MyTool', 'not_a_tool', 'SOME_CONSTANT']
|
||||
""",
|
||||
)
|
||||
metadata = utils.extract_tools_metadata(dir_path=str(temp_project_dir))
|
||||
assert len(metadata) == 1
|
||||
assert metadata[0]["name"] == "MyTool"
|
||||
|
||||
|
||||
def test_extract_tools_metadata_import_error_returns_empty(temp_project_dir):
|
||||
"""Test that extract_tools_metadata returns empty list on import error."""
|
||||
create_init_file(
|
||||
temp_project_dir,
|
||||
"""from nonexistent_module import something
|
||||
|
||||
class MyTool(BaseTool):
|
||||
pass
|
||||
|
||||
__all__ = ['MyTool']
|
||||
""",
|
||||
)
|
||||
# Should not raise, just return empty list
|
||||
metadata = utils.extract_tools_metadata(dir_path=str(temp_project_dir))
|
||||
assert metadata == []
|
||||
|
||||
|
||||
def test_extract_tools_metadata_syntax_error_returns_empty(temp_project_dir):
|
||||
"""Test that extract_tools_metadata returns empty list on syntax error."""
|
||||
create_init_file(
|
||||
temp_project_dir,
|
||||
"""from crewai.tools import BaseTool
|
||||
|
||||
class MyTool(BaseTool):
|
||||
# Missing closing parenthesis
|
||||
def __init__(self, name:
|
||||
pass
|
||||
|
||||
__all__ = ['MyTool']
|
||||
""",
|
||||
)
|
||||
# Should not raise, just return empty list
|
||||
metadata = utils.extract_tools_metadata(dir_path=str(temp_project_dir))
|
||||
assert metadata == []
|
||||
|
||||
@@ -185,9 +185,14 @@ def test_publish_when_not_in_sync(mock_is_synced, capsys, tool_command):
|
||||
"crewai.cli.tools.main.extract_available_exports",
|
||||
return_value=[{"name": "SampleTool"}],
|
||||
)
|
||||
@patch(
|
||||
"crewai.cli.tools.main.extract_tools_metadata",
|
||||
return_value=[{"name": "SampleTool", "humanized_name": "sample_tool", "description": "A sample tool", "run_params_schema": {}, "init_params_schema": {}, "env_vars": []}],
|
||||
)
|
||||
@patch("crewai.cli.tools.main.ToolCommand._print_current_organization")
|
||||
def test_publish_when_not_in_sync_and_force(
|
||||
mock_print_org,
|
||||
mock_tools_metadata,
|
||||
mock_available_exports,
|
||||
mock_is_synced,
|
||||
mock_publish,
|
||||
@@ -222,6 +227,7 @@ def test_publish_when_not_in_sync_and_force(
|
||||
description="A sample tool",
|
||||
encoded_file=unittest.mock.ANY,
|
||||
available_exports=[{"name": "SampleTool"}],
|
||||
tools_metadata=[{"name": "SampleTool", "humanized_name": "sample_tool", "description": "A sample tool", "run_params_schema": {}, "init_params_schema": {}, "env_vars": []}],
|
||||
)
|
||||
mock_print_org.assert_called_once()
|
||||
|
||||
@@ -242,7 +248,12 @@ def test_publish_when_not_in_sync_and_force(
|
||||
"crewai.cli.tools.main.extract_available_exports",
|
||||
return_value=[{"name": "SampleTool"}],
|
||||
)
|
||||
@patch(
|
||||
"crewai.cli.tools.main.extract_tools_metadata",
|
||||
return_value=[{"name": "SampleTool", "humanized_name": "sample_tool", "description": "A sample tool", "run_params_schema": {}, "init_params_schema": {}, "env_vars": []}],
|
||||
)
|
||||
def test_publish_success(
|
||||
mock_tools_metadata,
|
||||
mock_available_exports,
|
||||
mock_is_synced,
|
||||
mock_publish,
|
||||
@@ -277,6 +288,7 @@ def test_publish_success(
|
||||
description="A sample tool",
|
||||
encoded_file=unittest.mock.ANY,
|
||||
available_exports=[{"name": "SampleTool"}],
|
||||
tools_metadata=[{"name": "SampleTool", "humanized_name": "sample_tool", "description": "A sample tool", "run_params_schema": {}, "init_params_schema": {}, "env_vars": []}],
|
||||
)
|
||||
|
||||
|
||||
@@ -295,7 +307,12 @@ def test_publish_success(
|
||||
"crewai.cli.tools.main.extract_available_exports",
|
||||
return_value=[{"name": "SampleTool"}],
|
||||
)
|
||||
@patch(
|
||||
"crewai.cli.tools.main.extract_tools_metadata",
|
||||
return_value=[{"name": "SampleTool", "humanized_name": "sample_tool", "description": "A sample tool", "run_params_schema": {}, "init_params_schema": {}, "env_vars": []}],
|
||||
)
|
||||
def test_publish_failure(
|
||||
mock_tools_metadata,
|
||||
mock_available_exports,
|
||||
mock_publish,
|
||||
mock_open,
|
||||
@@ -336,7 +353,12 @@ def test_publish_failure(
|
||||
"crewai.cli.tools.main.extract_available_exports",
|
||||
return_value=[{"name": "SampleTool"}],
|
||||
)
|
||||
@patch(
|
||||
"crewai.cli.tools.main.extract_tools_metadata",
|
||||
return_value=[{"name": "SampleTool", "humanized_name": "sample_tool", "description": "A sample tool", "run_params_schema": {}, "init_params_schema": {}, "env_vars": []}],
|
||||
)
|
||||
def test_publish_api_error(
|
||||
mock_tools_metadata,
|
||||
mock_available_exports,
|
||||
mock_publish,
|
||||
mock_open,
|
||||
@@ -362,6 +384,39 @@ def test_publish_api_error(
|
||||
mock_publish.assert_called_once()
|
||||
|
||||
|
||||
@patch("crewai.cli.tools.main.get_project_name", return_value="sample-tool")
|
||||
@patch("crewai.cli.tools.main.get_project_version", return_value="1.0.0")
|
||||
@patch("crewai.cli.tools.main.get_project_description", return_value="A sample tool")
|
||||
@patch("crewai.cli.tools.main.git.Repository.is_synced", return_value=True)
|
||||
@patch(
|
||||
"crewai.cli.tools.main.extract_available_exports",
|
||||
return_value=[{"name": "SampleTool"}],
|
||||
)
|
||||
@patch(
|
||||
"crewai.cli.tools.main.extract_tools_metadata",
|
||||
side_effect=Exception("Failed to extract metadata"),
|
||||
)
|
||||
def test_publish_metadata_extraction_failure_continues_with_warning(
|
||||
mock_tools_metadata,
|
||||
mock_available_exports,
|
||||
mock_is_synced,
|
||||
mock_get_project_description,
|
||||
mock_get_project_version,
|
||||
mock_get_project_name,
|
||||
capsys,
|
||||
tool_command,
|
||||
):
|
||||
"""Test that metadata extraction failure shows warning but continues publishing."""
|
||||
try:
|
||||
tool_command.publish(is_public=True)
|
||||
except SystemExit:
|
||||
pass # May fail later due to API mock, but should get past metadata extraction
|
||||
output = capsys.readouterr().out
|
||||
assert "Warning: Could not extract tool metadata" in output
|
||||
assert "Publishing will continue without detailed metadata" in output
|
||||
assert "No tool metadata extracted" in output
|
||||
|
||||
|
||||
@patch("crewai.cli.tools.main.Settings")
|
||||
def test_print_current_organization_with_org(mock_settings, capsys, tool_command):
|
||||
mock_settings_instance = MagicMock()
|
||||
|
||||
Reference in New Issue
Block a user