mirror of
https://github.com/crewAIInc/crewAI.git
synced 2026-01-08 15:48:29 +00:00
Compare commits
4 Commits
gl/fix/add
...
devin/1745
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
a8e7e2db6d | ||
|
|
0328a8aaa4 | ||
|
|
66d35df858 | ||
|
|
f738e9ab62 |
20
README.md
20
README.md
@@ -688,6 +688,26 @@ A: Yes, CrewAI can integrate with custom-trained or fine-tuned models, allowing
|
|||||||
### Q: Can CrewAI agents interact with external tools and APIs?
|
### Q: Can CrewAI agents interact with external tools and APIs?
|
||||||
A: Absolutely! CrewAI agents can easily integrate with external tools, APIs, and databases, empowering them to leverage real-world data and resources.
|
A: Absolutely! CrewAI agents can easily integrate with external tools, APIs, and databases, empowering them to leverage real-world data and resources.
|
||||||
|
|
||||||
|
CrewAI also supports connecting your tools to the Management Control Plane (MCP) via Server-Sent Events (SSE), enabling real-time tool execution from the Crew Control Plane:
|
||||||
|
|
||||||
|
```python
|
||||||
|
from crewai.tools import MCPToolConnector, Tool
|
||||||
|
|
||||||
|
# Define your tools
|
||||||
|
search_tool = Tool(
|
||||||
|
name="search",
|
||||||
|
description="Search for information",
|
||||||
|
func=lambda query: f"Results for {query}"
|
||||||
|
)
|
||||||
|
|
||||||
|
# Connect tools to MCP
|
||||||
|
connector = MCPToolConnector(tools=[search_tool])
|
||||||
|
connector.connect()
|
||||||
|
|
||||||
|
# Listen for tool events from MCP
|
||||||
|
connector.listen() # This will block until interrupted
|
||||||
|
```
|
||||||
|
|
||||||
### Q: Is CrewAI suitable for production environments?
|
### Q: Is CrewAI suitable for production environments?
|
||||||
A: Yes, CrewAI is explicitly designed with production-grade standards, ensuring reliability, stability, and scalability for enterprise deployments.
|
A: Yes, CrewAI is explicitly designed with production-grade standards, ensuring reliability, stability, and scalability for enterprise deployments.
|
||||||
|
|
||||||
|
|||||||
67
examples/mcp_sse_tool_example.py
Normal file
67
examples/mcp_sse_tool_example.py
Normal file
@@ -0,0 +1,67 @@
|
|||||||
|
import logging
|
||||||
|
import os
|
||||||
|
import signal
|
||||||
|
import sys
|
||||||
|
import time
|
||||||
|
|
||||||
|
from crewai.tools import MCPToolConnector, Tool
|
||||||
|
|
||||||
|
|
||||||
|
def setup_logging():
|
||||||
|
"""Set up logging configuration."""
|
||||||
|
logging.basicConfig(
|
||||||
|
level=logging.INFO,
|
||||||
|
format="%(asctime)s - %(name)s - %(levelname)s - %(message)s",
|
||||||
|
handlers=[logging.StreamHandler()]
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def handle_exit(signum, frame):
|
||||||
|
"""Handle exit signals gracefully."""
|
||||||
|
print("\nExiting...")
|
||||||
|
sys.exit(0)
|
||||||
|
|
||||||
|
|
||||||
|
def main():
|
||||||
|
"""Main function to demonstrate MCP SSE tool connection."""
|
||||||
|
setup_logging()
|
||||||
|
signal.signal(signal.SIGINT, handle_exit)
|
||||||
|
|
||||||
|
print("CrewAI MCP SSE Tool Connection Example")
|
||||||
|
print("--------------------------------------")
|
||||||
|
print("This example connects tools to the MCP SSE server.")
|
||||||
|
print("Make sure you're logged in with 'crewai login' first.")
|
||||||
|
print("Press Ctrl+C to exit.")
|
||||||
|
print()
|
||||||
|
|
||||||
|
def search(query: str) -> str:
|
||||||
|
"""Search for information."""
|
||||||
|
return f"Searching for: {query}"
|
||||||
|
|
||||||
|
search_tool = Tool(
|
||||||
|
name="search",
|
||||||
|
description="Search for information",
|
||||||
|
func=search
|
||||||
|
)
|
||||||
|
|
||||||
|
tools = [search_tool]
|
||||||
|
|
||||||
|
connector = MCPToolConnector(tools=tools)
|
||||||
|
|
||||||
|
try:
|
||||||
|
print("Connecting to MCP SSE server...")
|
||||||
|
connector.connect()
|
||||||
|
print("Connected! Listening for tool events...")
|
||||||
|
|
||||||
|
connector.listen()
|
||||||
|
|
||||||
|
except KeyboardInterrupt:
|
||||||
|
print("\nExiting...")
|
||||||
|
except Exception as e:
|
||||||
|
print(f"Error: {str(e)}")
|
||||||
|
finally:
|
||||||
|
connector.close()
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
main()
|
||||||
@@ -1 +1,15 @@
|
|||||||
from .base_tool import BaseTool, tool
|
from .base_tool import BaseTool, tool
|
||||||
|
|
||||||
|
__all__ = [
|
||||||
|
"BaseTool",
|
||||||
|
"tool",
|
||||||
|
]
|
||||||
|
|
||||||
|
from .base_tool import Tool, to_langchain
|
||||||
|
from .mcp_connector import MCPToolConnector
|
||||||
|
|
||||||
|
__all__ += [
|
||||||
|
"Tool",
|
||||||
|
"to_langchain",
|
||||||
|
"MCPToolConnector",
|
||||||
|
]
|
||||||
|
|||||||
115
src/crewai/tools/mcp_connector.py
Normal file
115
src/crewai/tools/mcp_connector.py
Normal file
@@ -0,0 +1,115 @@
|
|||||||
|
import logging
|
||||||
|
from typing import Any, Callable, Dict, List, Mapping, Optional, Union
|
||||||
|
|
||||||
|
from crewai.cli.authentication.utils import TokenManager
|
||||||
|
from crewai.tools import BaseTool
|
||||||
|
from crewai.utilities.sse_client import SSEClient
|
||||||
|
|
||||||
|
|
||||||
|
class MCPToolConnector:
|
||||||
|
"""Connects tools to the Management Control Plane (MCP) via SSE."""
|
||||||
|
|
||||||
|
MCP_BASE_URL = "https://app.crewai.com"
|
||||||
|
SSE_ENDPOINT = "/api/v1/tools/events"
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
tools: Optional[List[BaseTool]] = None,
|
||||||
|
timeout: int = 30
|
||||||
|
):
|
||||||
|
"""Initialize the MCP Tool Connector.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
tools: List of tools to connect to the MCP.
|
||||||
|
timeout: Connection timeout in seconds.
|
||||||
|
"""
|
||||||
|
self.tools = tools or []
|
||||||
|
self.timeout = timeout
|
||||||
|
self.logger = logging.getLogger(__name__)
|
||||||
|
self.token_manager = TokenManager()
|
||||||
|
self._sse_client: Optional[SSEClient] = None
|
||||||
|
|
||||||
|
def connect(self) -> None:
|
||||||
|
"""Connect to the MCP SSE server for tools."""
|
||||||
|
token = self.token_manager.get_token()
|
||||||
|
if not token:
|
||||||
|
self.logger.error("Authentication token not found. Please log in first.")
|
||||||
|
raise ValueError("Authentication token not found. Please log in first.")
|
||||||
|
|
||||||
|
headers = {
|
||||||
|
"Authorization": f"Bearer {token}",
|
||||||
|
"Accept": "text/event-stream",
|
||||||
|
"Cache-Control": "no-cache",
|
||||||
|
"X-Requested-With": "XMLHttpRequest",
|
||||||
|
}
|
||||||
|
|
||||||
|
tool_data = {}
|
||||||
|
for tool in self.tools:
|
||||||
|
tool_data[tool.name] = {
|
||||||
|
"name": tool.name,
|
||||||
|
"description": tool.description,
|
||||||
|
"schema": tool.args_schema.model_json_schema() if hasattr(tool.args_schema, "model_json_schema") else {},
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
self._sse_client = SSEClient(
|
||||||
|
base_url=self.MCP_BASE_URL,
|
||||||
|
endpoint=self.SSE_ENDPOINT,
|
||||||
|
headers=headers,
|
||||||
|
timeout=self.timeout
|
||||||
|
)
|
||||||
|
|
||||||
|
try:
|
||||||
|
if self._sse_client is not None:
|
||||||
|
self._sse_client.connect()
|
||||||
|
except Exception as e:
|
||||||
|
self.logger.error(f"Failed to connect to MCP SSE server: {str(e)}")
|
||||||
|
raise
|
||||||
|
|
||||||
|
def listen(self) -> None:
|
||||||
|
"""Listen for tool events from the MCP SSE server."""
|
||||||
|
if not self._sse_client:
|
||||||
|
self.connect()
|
||||||
|
|
||||||
|
event_handlers = {
|
||||||
|
"tool_request": self._handle_tool_request,
|
||||||
|
"connection_closed": self._handle_connection_closed,
|
||||||
|
}
|
||||||
|
|
||||||
|
try:
|
||||||
|
if self._sse_client is not None:
|
||||||
|
self._sse_client.listen(event_handlers)
|
||||||
|
else:
|
||||||
|
self.logger.error("SSE client is not initialized")
|
||||||
|
raise ValueError("SSE client is not initialized")
|
||||||
|
except Exception as e:
|
||||||
|
self.logger.error(f"Error listening to MCP SSE events: {str(e)}")
|
||||||
|
raise
|
||||||
|
|
||||||
|
def _handle_tool_request(self, data: Dict[str, Any]) -> None:
|
||||||
|
"""Handle a tool request event from the MCP SSE server."""
|
||||||
|
try:
|
||||||
|
tool_name = data.get("tool_name")
|
||||||
|
arguments = data.get("arguments", {})
|
||||||
|
request_id = data.get("request_id")
|
||||||
|
|
||||||
|
tool = next((t for t in self.tools if t.name == tool_name), None)
|
||||||
|
if not tool:
|
||||||
|
self.logger.error(f"Tool '{tool_name}' not found")
|
||||||
|
return
|
||||||
|
|
||||||
|
result = tool.run(**arguments)
|
||||||
|
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
self.logger.error(f"Error handling tool request: {str(e)}")
|
||||||
|
|
||||||
|
def _handle_connection_closed(self, data: Any) -> None:
|
||||||
|
"""Handle a connection closed event from the MCP SSE server."""
|
||||||
|
self.logger.info("MCP SSE connection closed")
|
||||||
|
|
||||||
|
def close(self) -> None:
|
||||||
|
"""Close the MCP SSE connection."""
|
||||||
|
if self._sse_client:
|
||||||
|
self._sse_client.close()
|
||||||
|
self._sse_client = None
|
||||||
152
src/crewai/utilities/sse_client.py
Normal file
152
src/crewai/utilities/sse_client.py
Normal file
@@ -0,0 +1,152 @@
|
|||||||
|
import json
|
||||||
|
import logging
|
||||||
|
from typing import Any, Callable, Dict, Mapping, Optional, Union
|
||||||
|
from urllib.parse import urljoin
|
||||||
|
|
||||||
|
import requests
|
||||||
|
import sseclient
|
||||||
|
|
||||||
|
from crewai.utilities.events import crewai_event_bus
|
||||||
|
from crewai.utilities.events.base_events import BaseEvent
|
||||||
|
|
||||||
|
|
||||||
|
class SSEConnectionStartedEvent(BaseEvent):
|
||||||
|
"""Event emitted when an SSE connection is started"""
|
||||||
|
type: str = "sse_connection_started"
|
||||||
|
endpoint: str
|
||||||
|
headers: Dict[str, str]
|
||||||
|
|
||||||
|
|
||||||
|
class SSEConnectionErrorEvent(BaseEvent):
|
||||||
|
"""Event emitted when an SSE connection encounters an error"""
|
||||||
|
type: str = "sse_connection_error"
|
||||||
|
endpoint: str
|
||||||
|
error: str
|
||||||
|
|
||||||
|
|
||||||
|
class SSEMessageReceivedEvent(BaseEvent):
|
||||||
|
"""Event emitted when an SSE message is received"""
|
||||||
|
type: str = "sse_message_received"
|
||||||
|
endpoint: str
|
||||||
|
event: str
|
||||||
|
data: Any
|
||||||
|
|
||||||
|
|
||||||
|
class SSEClient:
|
||||||
|
"""Client for connecting to Server-Sent Events (SSE) endpoints"""
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
base_url: str,
|
||||||
|
endpoint: str = "",
|
||||||
|
headers: Optional[Dict[str, str]] = None,
|
||||||
|
timeout: int = 30,
|
||||||
|
):
|
||||||
|
"""Initialize the SSE client.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
base_url: Base URL for the SSE server.
|
||||||
|
endpoint: Endpoint path to connect to (will be joined with base_url).
|
||||||
|
headers: Headers to include in the SSE request.
|
||||||
|
timeout: Connection timeout in seconds.
|
||||||
|
"""
|
||||||
|
self.base_url = base_url
|
||||||
|
self.endpoint = endpoint
|
||||||
|
self.headers = headers or {}
|
||||||
|
self.timeout = timeout
|
||||||
|
self.logger = logging.getLogger(__name__)
|
||||||
|
self._client: Optional[sseclient.SSEClient] = None
|
||||||
|
self._response: Optional[requests.Response] = None
|
||||||
|
|
||||||
|
def connect(self) -> None:
|
||||||
|
"""Establish a connection to the SSE server."""
|
||||||
|
try:
|
||||||
|
url = urljoin(self.base_url, self.endpoint)
|
||||||
|
self.logger.info(f"Connecting to SSE server at {url}")
|
||||||
|
|
||||||
|
crewai_event_bus.emit(
|
||||||
|
self,
|
||||||
|
event=SSEConnectionStartedEvent(
|
||||||
|
endpoint=url,
|
||||||
|
headers=self.headers
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
self._response = requests.get(
|
||||||
|
url,
|
||||||
|
headers=self.headers,
|
||||||
|
stream=True,
|
||||||
|
timeout=self.timeout
|
||||||
|
)
|
||||||
|
if self._response is not None:
|
||||||
|
self._response.raise_for_status()
|
||||||
|
self._client = sseclient.SSEClient(self._response)
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
self.logger.error(f"Error connecting to SSE server: {str(e)}")
|
||||||
|
crewai_event_bus.emit(
|
||||||
|
self,
|
||||||
|
event=SSEConnectionErrorEvent(
|
||||||
|
endpoint=urljoin(self.base_url, self.endpoint),
|
||||||
|
error=str(e)
|
||||||
|
)
|
||||||
|
)
|
||||||
|
raise
|
||||||
|
|
||||||
|
def listen(self, event_handlers: Optional[Mapping[str, Callable[[Any], None]]] = None) -> None:
|
||||||
|
"""Listen for SSE events and process them with registered handlers.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
event_handlers: Dictionary mapping event types to handler functions.
|
||||||
|
"""
|
||||||
|
if self._client is None:
|
||||||
|
self.connect()
|
||||||
|
|
||||||
|
event_handlers = event_handlers or {}
|
||||||
|
|
||||||
|
try:
|
||||||
|
if self._client is None:
|
||||||
|
self.logger.error("SSE client is not initialized")
|
||||||
|
return
|
||||||
|
|
||||||
|
for event in self._client:
|
||||||
|
event_type = event.event or "message"
|
||||||
|
data = None
|
||||||
|
|
||||||
|
try:
|
||||||
|
data = json.loads(event.data)
|
||||||
|
except (json.JSONDecodeError, TypeError):
|
||||||
|
data = event.data
|
||||||
|
|
||||||
|
crewai_event_bus.emit(
|
||||||
|
self,
|
||||||
|
event=SSEMessageReceivedEvent(
|
||||||
|
endpoint=urljoin(self.base_url, self.endpoint),
|
||||||
|
event=event_type,
|
||||||
|
data=data
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
handler = event_handlers.get(event_type)
|
||||||
|
if handler:
|
||||||
|
handler(data)
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
self.logger.error(f"Error processing SSE events: {str(e)}")
|
||||||
|
crewai_event_bus.emit(
|
||||||
|
self,
|
||||||
|
event=SSEConnectionErrorEvent(
|
||||||
|
endpoint=urljoin(self.base_url, self.endpoint),
|
||||||
|
error=str(e)
|
||||||
|
)
|
||||||
|
)
|
||||||
|
raise
|
||||||
|
finally:
|
||||||
|
self.close()
|
||||||
|
|
||||||
|
def close(self) -> None:
|
||||||
|
"""Close the SSE connection."""
|
||||||
|
if self._response:
|
||||||
|
self._response.close()
|
||||||
|
self._response = None
|
||||||
|
self._client = None
|
||||||
41
tests/integration/test_mcp_tools_integration.py
Normal file
41
tests/integration/test_mcp_tools_integration.py
Normal file
@@ -0,0 +1,41 @@
|
|||||||
|
import os
|
||||||
|
import unittest
|
||||||
|
from unittest.mock import MagicMock, patch
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
from crewai.tools import BaseTool, MCPToolConnector, Tool
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.integration
|
||||||
|
class TestMCPToolsIntegration(unittest.TestCase):
|
||||||
|
@pytest.mark.skipif(
|
||||||
|
not os.environ.get("CREWAI_INTEGRATION_TEST"),
|
||||||
|
reason="Integration test requires CREWAI_INTEGRATION_TEST=true"
|
||||||
|
)
|
||||||
|
@patch("crewai.tools.mcp_connector.SSEClient")
|
||||||
|
def test_mcp_tool_connector_integration(self, mock_sse_client):
|
||||||
|
def add(a: int, b: int) -> int:
|
||||||
|
"""Add two numbers."""
|
||||||
|
return a + b
|
||||||
|
|
||||||
|
calculator_tool = Tool(
|
||||||
|
name="calculator_add",
|
||||||
|
description="Add two numbers",
|
||||||
|
func=add
|
||||||
|
)
|
||||||
|
|
||||||
|
connector = MCPToolConnector(tools=[calculator_tool])
|
||||||
|
|
||||||
|
mock_sse = MagicMock()
|
||||||
|
mock_sse_client.return_value = mock_sse
|
||||||
|
|
||||||
|
connector.connect()
|
||||||
|
|
||||||
|
tool_request_data = {
|
||||||
|
"tool_name": "calculator_add",
|
||||||
|
"arguments": {"a": 5, "b": 7},
|
||||||
|
"request_id": "test-request-1"
|
||||||
|
}
|
||||||
|
|
||||||
|
connector._handle_tool_request(tool_request_data)
|
||||||
93
tests/tools/test_mcp_connector.py
Normal file
93
tests/tools/test_mcp_connector.py
Normal file
@@ -0,0 +1,93 @@
|
|||||||
|
import json
|
||||||
|
import unittest
|
||||||
|
from unittest.mock import MagicMock, patch
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
from crewai.tools import BaseTool, Tool
|
||||||
|
from crewai.tools.mcp_connector import MCPToolConnector
|
||||||
|
|
||||||
|
|
||||||
|
class TestMCPToolConnector(unittest.TestCase):
|
||||||
|
def setUp(self):
|
||||||
|
self.mock_tool = MagicMock(spec=BaseTool)
|
||||||
|
self.mock_tool.name = "test_tool"
|
||||||
|
self.mock_tool.description = "A test tool"
|
||||||
|
self.mock_tool.args_schema = MagicMock()
|
||||||
|
self.mock_tool.args_schema.model_json_schema.return_value = {
|
||||||
|
"properties": {"input": {"type": "string"}}
|
||||||
|
}
|
||||||
|
self.mock_tool.run.return_value = "Tool result"
|
||||||
|
|
||||||
|
self.connector = MCPToolConnector(tools=[self.mock_tool])
|
||||||
|
|
||||||
|
@patch("crewai.cli.authentication.utils.TokenManager.get_access_token")
|
||||||
|
@patch("crewai.tools.mcp_connector.SSEClient")
|
||||||
|
def test_connect_success(self, mock_sse_client, mock_get_token):
|
||||||
|
mock_get_token.return_value = "test-token"
|
||||||
|
mock_sse = MagicMock()
|
||||||
|
mock_sse_client.return_value = mock_sse
|
||||||
|
|
||||||
|
self.connector.connect()
|
||||||
|
|
||||||
|
mock_get_token.assert_called_once()
|
||||||
|
mock_sse_client.assert_called_once_with(
|
||||||
|
base_url="https://app.crewai.com",
|
||||||
|
endpoint="/api/v1/tools/events",
|
||||||
|
headers={
|
||||||
|
"Authorization": "Bearer test-token",
|
||||||
|
"Accept": "text/event-stream",
|
||||||
|
"Cache-Control": "no-cache",
|
||||||
|
"X-Requested-With": "XMLHttpRequest",
|
||||||
|
},
|
||||||
|
timeout=30
|
||||||
|
)
|
||||||
|
mock_sse.connect.assert_called_once()
|
||||||
|
|
||||||
|
@patch("crewai.cli.authentication.utils.TokenManager.get_access_token")
|
||||||
|
def test_connect_no_token(self, mock_get_token):
|
||||||
|
mock_get_token.return_value = None
|
||||||
|
|
||||||
|
with pytest.raises(ValueError, match="Authentication token not found"):
|
||||||
|
self.connector.connect()
|
||||||
|
|
||||||
|
@patch("crewai.cli.authentication.utils.TokenManager.get_access_token")
|
||||||
|
@patch("crewai.tools.mcp_connector.SSEClient")
|
||||||
|
def test_listen(self, mock_sse_client, mock_get_token):
|
||||||
|
mock_get_token.return_value = "test-token"
|
||||||
|
mock_sse = MagicMock()
|
||||||
|
mock_sse_client.return_value = mock_sse
|
||||||
|
|
||||||
|
self.connector._sse_client = mock_sse
|
||||||
|
self.connector.listen()
|
||||||
|
|
||||||
|
mock_sse.listen.assert_called_once()
|
||||||
|
handlers = mock_sse.listen.call_args[0][0]
|
||||||
|
assert "tool_request" in handlers
|
||||||
|
assert "connection_closed" in handlers
|
||||||
|
|
||||||
|
@patch("crewai.cli.authentication.utils.TokenManager.get_access_token")
|
||||||
|
@patch("crewai.tools.mcp_connector.SSEClient")
|
||||||
|
def test_handle_tool_request(self, mock_sse_client, mock_get_token):
|
||||||
|
mock_get_token.return_value = "test-token"
|
||||||
|
|
||||||
|
test_data = {
|
||||||
|
"tool_name": "test_tool",
|
||||||
|
"arguments": {"input": "test input"},
|
||||||
|
"request_id": "123"
|
||||||
|
}
|
||||||
|
|
||||||
|
self.connector._handle_tool_request(test_data)
|
||||||
|
|
||||||
|
self.mock_tool.run.assert_called_once_with(input="test input")
|
||||||
|
|
||||||
|
def test_handle_tool_request_not_found(self):
|
||||||
|
test_data = {
|
||||||
|
"tool_name": "non_existent_tool",
|
||||||
|
"arguments": {"input": "test input"},
|
||||||
|
"request_id": "123"
|
||||||
|
}
|
||||||
|
|
||||||
|
self.connector._handle_tool_request(test_data)
|
||||||
|
|
||||||
|
self.mock_tool.run.assert_not_called()
|
||||||
100
tests/utilities/test_sse_client.py
Normal file
100
tests/utilities/test_sse_client.py
Normal file
@@ -0,0 +1,100 @@
|
|||||||
|
import json
|
||||||
|
import unittest
|
||||||
|
from unittest.mock import MagicMock, patch
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
import requests
|
||||||
|
import sseclient
|
||||||
|
|
||||||
|
from crewai.utilities.sse_client import (
|
||||||
|
SSEClient,
|
||||||
|
SSEConnectionErrorEvent,
|
||||||
|
SSEConnectionStartedEvent,
|
||||||
|
SSEMessageReceivedEvent,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class TestSSEClient(unittest.TestCase):
|
||||||
|
def setUp(self):
|
||||||
|
self.base_url = "https://test.example.com"
|
||||||
|
self.endpoint = "/events"
|
||||||
|
self.headers = {"Authorization": "Bearer test-token"}
|
||||||
|
self.sse_client = SSEClient(
|
||||||
|
base_url=self.base_url,
|
||||||
|
endpoint=self.endpoint,
|
||||||
|
headers=self.headers
|
||||||
|
)
|
||||||
|
|
||||||
|
@patch("crewai.utilities.events.crewai_event_bus.emit")
|
||||||
|
@patch("requests.get")
|
||||||
|
@patch("sseclient.SSEClient")
|
||||||
|
def test_connect_success(self, mock_sse_client, mock_get, mock_emit):
|
||||||
|
mock_response = MagicMock()
|
||||||
|
mock_get.return_value = mock_response
|
||||||
|
|
||||||
|
self.sse_client.connect()
|
||||||
|
|
||||||
|
mock_get.assert_called_once_with(
|
||||||
|
"https://test.example.com/events",
|
||||||
|
headers=self.headers,
|
||||||
|
stream=True,
|
||||||
|
timeout=30
|
||||||
|
)
|
||||||
|
mock_response.raise_for_status.assert_called_once()
|
||||||
|
mock_sse_client.assert_called_once_with(mock_response)
|
||||||
|
mock_emit.assert_called_once()
|
||||||
|
event = mock_emit.call_args[1]["event"]
|
||||||
|
assert isinstance(event, SSEConnectionStartedEvent)
|
||||||
|
assert event.endpoint == "https://test.example.com/events"
|
||||||
|
assert event.headers == self.headers
|
||||||
|
|
||||||
|
@patch("crewai.utilities.events.crewai_event_bus.emit")
|
||||||
|
@patch("requests.get")
|
||||||
|
def test_connect_error(self, mock_get, mock_emit):
|
||||||
|
mock_get.side_effect = requests.exceptions.RequestException("Connection error")
|
||||||
|
|
||||||
|
with pytest.raises(requests.exceptions.RequestException):
|
||||||
|
self.sse_client.connect()
|
||||||
|
|
||||||
|
mock_emit.assert_called_once()
|
||||||
|
event = mock_emit.call_args[1]["event"]
|
||||||
|
assert isinstance(event, SSEConnectionErrorEvent)
|
||||||
|
assert event.endpoint == "https://test.example.com/events"
|
||||||
|
assert "Connection error" in event.error
|
||||||
|
|
||||||
|
@patch("crewai.utilities.events.crewai_event_bus.emit")
|
||||||
|
@patch("requests.get")
|
||||||
|
def test_listen_with_handlers(self, mock_get, mock_emit):
|
||||||
|
mock_response = MagicMock()
|
||||||
|
mock_get.return_value = mock_response
|
||||||
|
|
||||||
|
mock_sse_client = MagicMock()
|
||||||
|
mock_event1 = MagicMock(event="test_event", data='{"key": "value"}')
|
||||||
|
mock_event2 = MagicMock(event="message", data="plain text")
|
||||||
|
mock_sse_client.__iter__.return_value = [mock_event1, mock_event2]
|
||||||
|
|
||||||
|
self.sse_client._client = mock_sse_client
|
||||||
|
|
||||||
|
test_event_handler = MagicMock()
|
||||||
|
message_handler = MagicMock()
|
||||||
|
|
||||||
|
event_handlers = {
|
||||||
|
"test_event": test_event_handler,
|
||||||
|
"message": message_handler
|
||||||
|
}
|
||||||
|
self.sse_client.listen(event_handlers)
|
||||||
|
|
||||||
|
test_event_handler.assert_called_once_with({"key": "value"})
|
||||||
|
message_handler.assert_called_once_with("plain text")
|
||||||
|
|
||||||
|
assert mock_emit.call_count == 2
|
||||||
|
event1 = mock_emit.call_args_list[0][1]["event"]
|
||||||
|
event2 = mock_emit.call_args_list[1][1]["event"]
|
||||||
|
|
||||||
|
assert isinstance(event1, SSEMessageReceivedEvent)
|
||||||
|
assert event1.event == "test_event"
|
||||||
|
assert event1.data == {"key": "value"}
|
||||||
|
|
||||||
|
assert isinstance(event2, SSEMessageReceivedEvent)
|
||||||
|
assert event2.event == "message"
|
||||||
|
assert event2.data == "plain text"
|
||||||
Reference in New Issue
Block a user