mirror of
https://github.com/crewAIInc/crewAI.git
synced 2026-01-23 15:18:14 +00:00
Fix #2698: Implement MCP SSE server connection for tools
Co-Authored-By: Joe Moura <joao@crewai.com>
This commit is contained in:
41
tests/integration/test_mcp_tools_integration.py
Normal file
41
tests/integration/test_mcp_tools_integration.py
Normal file
@@ -0,0 +1,41 @@
|
||||
import os
|
||||
import unittest
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
import pytest
|
||||
|
||||
from crewai.tools import BaseTool, MCPToolConnector, Tool
|
||||
|
||||
|
||||
@pytest.mark.integration
|
||||
class TestMCPToolsIntegration(unittest.TestCase):
|
||||
@pytest.mark.skipif(
|
||||
not os.environ.get("CREWAI_INTEGRATION_TEST"),
|
||||
reason="Integration test requires CREWAI_INTEGRATION_TEST=true"
|
||||
)
|
||||
@patch("crewai.tools.mcp_connector.SSEClient")
|
||||
def test_mcp_tool_connector_integration(self, mock_sse_client):
|
||||
def add(a: int, b: int) -> int:
|
||||
"""Add two numbers."""
|
||||
return a + b
|
||||
|
||||
calculator_tool = Tool(
|
||||
name="calculator_add",
|
||||
description="Add two numbers",
|
||||
func=add
|
||||
)
|
||||
|
||||
connector = MCPToolConnector(tools=[calculator_tool])
|
||||
|
||||
mock_sse = MagicMock()
|
||||
mock_sse_client.return_value = mock_sse
|
||||
|
||||
connector.connect()
|
||||
|
||||
tool_request_data = {
|
||||
"tool_name": "calculator_add",
|
||||
"arguments": {"a": 5, "b": 7},
|
||||
"request_id": "test-request-1"
|
||||
}
|
||||
|
||||
connector._handle_tool_request(tool_request_data)
|
||||
93
tests/tools/test_mcp_connector.py
Normal file
93
tests/tools/test_mcp_connector.py
Normal file
@@ -0,0 +1,93 @@
|
||||
import json
|
||||
import unittest
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
import pytest
|
||||
|
||||
from crewai.tools import BaseTool, Tool
|
||||
from crewai.tools.mcp_connector import MCPToolConnector
|
||||
|
||||
|
||||
class TestMCPToolConnector(unittest.TestCase):
|
||||
def setUp(self):
|
||||
self.mock_tool = MagicMock(spec=BaseTool)
|
||||
self.mock_tool.name = "test_tool"
|
||||
self.mock_tool.description = "A test tool"
|
||||
self.mock_tool.args_schema = MagicMock()
|
||||
self.mock_tool.args_schema.model_json_schema.return_value = {
|
||||
"properties": {"input": {"type": "string"}}
|
||||
}
|
||||
self.mock_tool.run.return_value = "Tool result"
|
||||
|
||||
self.connector = MCPToolConnector(tools=[self.mock_tool])
|
||||
|
||||
@patch("crewai.cli.authentication.utils.TokenManager.get_access_token")
|
||||
@patch("crewai.tools.mcp_connector.SSEClient")
|
||||
def test_connect_success(self, mock_sse_client, mock_get_token):
|
||||
mock_get_token.return_value = "test-token"
|
||||
mock_sse = MagicMock()
|
||||
mock_sse_client.return_value = mock_sse
|
||||
|
||||
self.connector.connect()
|
||||
|
||||
mock_get_token.assert_called_once()
|
||||
mock_sse_client.assert_called_once_with(
|
||||
base_url="https://app.crewai.com",
|
||||
endpoint="/api/v1/tools/events",
|
||||
headers={
|
||||
"Authorization": "Bearer test-token",
|
||||
"Accept": "text/event-stream",
|
||||
"Cache-Control": "no-cache",
|
||||
"X-Requested-With": "XMLHttpRequest",
|
||||
},
|
||||
timeout=30
|
||||
)
|
||||
mock_sse.connect.assert_called_once()
|
||||
|
||||
@patch("crewai.cli.authentication.utils.TokenManager.get_access_token")
|
||||
def test_connect_no_token(self, mock_get_token):
|
||||
mock_get_token.return_value = None
|
||||
|
||||
with pytest.raises(ValueError, match="Authentication token not found"):
|
||||
self.connector.connect()
|
||||
|
||||
@patch("crewai.cli.authentication.utils.TokenManager.get_access_token")
|
||||
@patch("crewai.tools.mcp_connector.SSEClient")
|
||||
def test_listen(self, mock_sse_client, mock_get_token):
|
||||
mock_get_token.return_value = "test-token"
|
||||
mock_sse = MagicMock()
|
||||
mock_sse_client.return_value = mock_sse
|
||||
|
||||
self.connector._sse_client = mock_sse
|
||||
self.connector.listen()
|
||||
|
||||
mock_sse.listen.assert_called_once()
|
||||
handlers = mock_sse.listen.call_args[0][0]
|
||||
assert "tool_request" in handlers
|
||||
assert "connection_closed" in handlers
|
||||
|
||||
@patch("crewai.cli.authentication.utils.TokenManager.get_access_token")
|
||||
@patch("crewai.tools.mcp_connector.SSEClient")
|
||||
def test_handle_tool_request(self, mock_sse_client, mock_get_token):
|
||||
mock_get_token.return_value = "test-token"
|
||||
|
||||
test_data = {
|
||||
"tool_name": "test_tool",
|
||||
"arguments": {"input": "test input"},
|
||||
"request_id": "123"
|
||||
}
|
||||
|
||||
self.connector._handle_tool_request(test_data)
|
||||
|
||||
self.mock_tool.run.assert_called_once_with(input="test input")
|
||||
|
||||
def test_handle_tool_request_not_found(self):
|
||||
test_data = {
|
||||
"tool_name": "non_existent_tool",
|
||||
"arguments": {"input": "test input"},
|
||||
"request_id": "123"
|
||||
}
|
||||
|
||||
self.connector._handle_tool_request(test_data)
|
||||
|
||||
self.mock_tool.run.assert_not_called()
|
||||
100
tests/utilities/test_sse_client.py
Normal file
100
tests/utilities/test_sse_client.py
Normal file
@@ -0,0 +1,100 @@
|
||||
import json
|
||||
import unittest
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
import pytest
|
||||
import requests
|
||||
import sseclient
|
||||
|
||||
from crewai.utilities.sse_client import (
|
||||
SSEClient,
|
||||
SSEConnectionErrorEvent,
|
||||
SSEConnectionStartedEvent,
|
||||
SSEMessageReceivedEvent,
|
||||
)
|
||||
|
||||
|
||||
class TestSSEClient(unittest.TestCase):
|
||||
def setUp(self):
|
||||
self.base_url = "https://test.example.com"
|
||||
self.endpoint = "/events"
|
||||
self.headers = {"Authorization": "Bearer test-token"}
|
||||
self.sse_client = SSEClient(
|
||||
base_url=self.base_url,
|
||||
endpoint=self.endpoint,
|
||||
headers=self.headers
|
||||
)
|
||||
|
||||
@patch("crewai.utilities.events.crewai_event_bus.emit")
|
||||
@patch("requests.get")
|
||||
@patch("sseclient.SSEClient")
|
||||
def test_connect_success(self, mock_sse_client, mock_get, mock_emit):
|
||||
mock_response = MagicMock()
|
||||
mock_get.return_value = mock_response
|
||||
|
||||
self.sse_client.connect()
|
||||
|
||||
mock_get.assert_called_once_with(
|
||||
"https://test.example.com/events",
|
||||
headers=self.headers,
|
||||
stream=True,
|
||||
timeout=30
|
||||
)
|
||||
mock_response.raise_for_status.assert_called_once()
|
||||
mock_sse_client.assert_called_once_with(mock_response)
|
||||
mock_emit.assert_called_once()
|
||||
event = mock_emit.call_args[1]["event"]
|
||||
assert isinstance(event, SSEConnectionStartedEvent)
|
||||
assert event.endpoint == "https://test.example.com/events"
|
||||
assert event.headers == self.headers
|
||||
|
||||
@patch("crewai.utilities.events.crewai_event_bus.emit")
|
||||
@patch("requests.get")
|
||||
def test_connect_error(self, mock_get, mock_emit):
|
||||
mock_get.side_effect = requests.exceptions.RequestException("Connection error")
|
||||
|
||||
with pytest.raises(requests.exceptions.RequestException):
|
||||
self.sse_client.connect()
|
||||
|
||||
mock_emit.assert_called_once()
|
||||
event = mock_emit.call_args[1]["event"]
|
||||
assert isinstance(event, SSEConnectionErrorEvent)
|
||||
assert event.endpoint == "https://test.example.com/events"
|
||||
assert "Connection error" in event.error
|
||||
|
||||
@patch("crewai.utilities.events.crewai_event_bus.emit")
|
||||
@patch("requests.get")
|
||||
def test_listen_with_handlers(self, mock_get, mock_emit):
|
||||
mock_response = MagicMock()
|
||||
mock_get.return_value = mock_response
|
||||
|
||||
mock_sse_client = MagicMock()
|
||||
mock_event1 = MagicMock(event="test_event", data='{"key": "value"}')
|
||||
mock_event2 = MagicMock(event="message", data="plain text")
|
||||
mock_sse_client.__iter__.return_value = [mock_event1, mock_event2]
|
||||
|
||||
self.sse_client._client = mock_sse_client
|
||||
|
||||
test_event_handler = MagicMock()
|
||||
message_handler = MagicMock()
|
||||
|
||||
event_handlers = {
|
||||
"test_event": test_event_handler,
|
||||
"message": message_handler
|
||||
}
|
||||
self.sse_client.listen(event_handlers)
|
||||
|
||||
test_event_handler.assert_called_once_with({"key": "value"})
|
||||
message_handler.assert_called_once_with("plain text")
|
||||
|
||||
assert mock_emit.call_count == 2
|
||||
event1 = mock_emit.call_args_list[0][1]["event"]
|
||||
event2 = mock_emit.call_args_list[1][1]["event"]
|
||||
|
||||
assert isinstance(event1, SSEMessageReceivedEvent)
|
||||
assert event1.event == "test_event"
|
||||
assert event1.data == {"key": "value"}
|
||||
|
||||
assert isinstance(event2, SSEMessageReceivedEvent)
|
||||
assert event2.event == "message"
|
||||
assert event2.data == "plain text"
|
||||
Reference in New Issue
Block a user