feat: implement client-initiated real-time human input event streams

- Add HumanInputRequiredEvent and HumanInputCompletedEvent to task_events.py
- Implement HTTP server with WebSocket, SSE, and long-polling endpoints
- Add EventStreamManager for connection and event management
- Integrate event emission in agent executor human input flow
- Add comprehensive tests for server endpoints and event integration
- Add optional FastAPI dependencies for server functionality
- Include documentation and example usage
- Maintain backward compatibility with existing human input flow

Addresses issue #3259 for WebSocket/SSE/long-polling human input events

Co-Authored-By: João <joao@crewai.com>
This commit is contained in:
Devin AI
2025-08-02 16:05:43 +00:00
parent 88ed91561f
commit ea560d0af1
12 changed files with 1440 additions and 2 deletions

View File

@@ -0,0 +1,213 @@
import pytest
import asyncio
import json
from unittest.mock import MagicMock, patch
from crewai.server.event_stream_manager import EventStreamManager
from crewai.utilities.events.task_events import HumanInputRequiredEvent, HumanInputCompletedEvent
class TestEventStreamManager:
"""Test the event stream manager"""
def setup_method(self):
"""Setup test environment"""
self.manager = EventStreamManager()
self.manager._websocket_connections.clear()
self.manager._sse_connections.clear()
self.manager._polling_events.clear()
def test_websocket_connection_management(self):
"""Test WebSocket connection management"""
execution_id = "test-execution"
websocket1 = MagicMock()
websocket2 = MagicMock()
self.manager.add_websocket_connection(execution_id, websocket1)
assert execution_id in self.manager._websocket_connections
assert websocket1 in self.manager._websocket_connections[execution_id]
self.manager.add_websocket_connection(execution_id, websocket2)
assert len(self.manager._websocket_connections[execution_id]) == 2
self.manager.remove_websocket_connection(execution_id, websocket1)
assert websocket1 not in self.manager._websocket_connections[execution_id]
assert websocket2 in self.manager._websocket_connections[execution_id]
self.manager.remove_websocket_connection(execution_id, websocket2)
assert execution_id not in self.manager._websocket_connections
def test_sse_connection_management(self):
"""Test SSE connection management"""
execution_id = "test-execution"
queue1 = MagicMock()
queue2 = MagicMock()
self.manager.add_sse_connection(execution_id, queue1)
assert execution_id in self.manager._sse_connections
assert queue1 in self.manager._sse_connections[execution_id]
self.manager.add_sse_connection(execution_id, queue2)
assert len(self.manager._sse_connections[execution_id]) == 2
self.manager.remove_sse_connection(execution_id, queue1)
assert queue1 not in self.manager._sse_connections[execution_id]
assert queue2 in self.manager._sse_connections[execution_id]
self.manager.remove_sse_connection(execution_id, queue2)
assert execution_id not in self.manager._sse_connections
def test_polling_events_storage(self):
"""Test polling events storage and retrieval"""
execution_id = "test-execution"
event1 = {"event_id": "event-1", "type": "test", "data": "test1"}
event2 = {"event_id": "event-2", "type": "test", "data": "test2"}
self.manager._store_polling_event(execution_id, event1)
self.manager._store_polling_event(execution_id, event2)
events = self.manager.get_polling_events(execution_id)
assert len(events) == 2
assert events[0] == event1
assert events[1] == event2
def test_polling_events_with_last_event_id(self):
"""Test polling events retrieval with last_event_id"""
execution_id = "test-execution"
event1 = {"event_id": "event-1", "type": "test", "data": "test1"}
event2 = {"event_id": "event-2", "type": "test", "data": "test2"}
event3 = {"event_id": "event-3", "type": "test", "data": "test3"}
self.manager._store_polling_event(execution_id, event1)
self.manager._store_polling_event(execution_id, event2)
self.manager._store_polling_event(execution_id, event3)
events = self.manager.get_polling_events(execution_id, "event-1")
assert len(events) == 2
assert events[0] == event2
assert events[1] == event3
def test_polling_events_limit(self):
"""Test polling events storage limit"""
execution_id = "test-execution"
for i in range(105):
event = {"event_id": f"event-{i}", "type": "test", "data": f"test{i}"}
self.manager._store_polling_event(execution_id, event)
events = self.manager.get_polling_events(execution_id)
assert len(events) == 100
assert events[0]["event_id"] == "event-5"
assert events[-1]["event_id"] == "event-104"
def test_event_serialization(self):
"""Test event serialization"""
event = HumanInputRequiredEvent(
execution_id="test-execution",
crew_id="test-crew",
task_id="test-task",
prompt="Test prompt"
)
serialized = self.manager._serialize_event(event)
assert isinstance(serialized, dict)
assert serialized["type"] == "human_input_required"
assert serialized["execution_id"] == "test-execution"
assert "event_id" in serialized
def test_broadcast_websocket(self):
"""Test WebSocket broadcasting"""
execution_id = "test-execution"
websocket = MagicMock()
self.manager.add_websocket_connection(execution_id, websocket)
event_data = {"type": "test", "data": "test"}
with patch('asyncio.create_task') as mock_create_task:
self.manager._broadcast_websocket(execution_id, event_data)
mock_create_task.assert_called_once()
def test_broadcast_sse(self):
"""Test SSE broadcasting"""
execution_id = "test-execution"
queue = MagicMock()
self.manager.add_sse_connection(execution_id, queue)
event_data = {"type": "test", "data": "test"}
self.manager._broadcast_sse(execution_id, event_data)
queue.put_nowait.assert_called_once_with(event_data)
def test_broadcast_event(self):
"""Test complete event broadcasting"""
execution_id = "test-execution"
event = HumanInputRequiredEvent(
execution_id=execution_id,
crew_id="test-crew",
task_id="test-task",
prompt="Test prompt"
)
with patch.object(self.manager, '_broadcast_websocket') as mock_ws, \
patch.object(self.manager, '_broadcast_sse') as mock_sse, \
patch.object(self.manager, '_store_polling_event') as mock_poll:
self.manager._broadcast_event(event)
mock_ws.assert_called_once()
mock_sse.assert_called_once()
mock_poll.assert_called_once()
def test_cleanup_execution(self):
"""Test execution cleanup"""
execution_id = "test-execution"
websocket = MagicMock()
queue = MagicMock()
event = {"event_id": "event-1", "type": "test"}
self.manager.add_websocket_connection(execution_id, websocket)
self.manager.add_sse_connection(execution_id, queue)
self.manager._store_polling_event(execution_id, event)
assert execution_id in self.manager._websocket_connections
assert execution_id in self.manager._sse_connections
assert execution_id in self.manager._polling_events
self.manager.cleanup_execution(execution_id)
assert execution_id not in self.manager._websocket_connections
assert execution_id not in self.manager._sse_connections
assert execution_id not in self.manager._polling_events
def test_register_event_listeners(self):
"""Test event listener registration"""
with patch('crewai.utilities.events.crewai_event_bus.crewai_event_bus.on') as mock_on:
self.manager.register_event_listeners()
assert mock_on.call_count == 2
self.manager.register_event_listeners()
assert mock_on.call_count == 2
def test_broadcast_event_without_execution_id(self):
"""Test broadcasting event without execution_id"""
event = HumanInputRequiredEvent(
crew_id="test-crew",
task_id="test-task",
prompt="Test prompt"
)
with patch.object(self.manager, '_broadcast_websocket') as mock_ws, \
patch.object(self.manager, '_broadcast_sse') as mock_sse, \
patch.object(self.manager, '_store_polling_event') as mock_poll:
self.manager._broadcast_event(event)
mock_ws.assert_not_called()
mock_sse.assert_not_called()
mock_poll.assert_not_called()

View File

@@ -0,0 +1,139 @@
import pytest
import asyncio
import json
from unittest.mock import patch, MagicMock
try:
from fastapi.testclient import TestClient
from crewai.server.human_input_server import HumanInputServer
from crewai.server.event_stream_manager import event_stream_manager
from crewai.utilities.events.task_events import HumanInputRequiredEvent
FASTAPI_AVAILABLE = True
except ImportError:
FASTAPI_AVAILABLE = False
@pytest.mark.skipif(not FASTAPI_AVAILABLE, reason="FastAPI dependencies not available")
class TestHumanInputServer:
"""Test the human input server endpoints"""
def setup_method(self):
"""Setup test environment"""
self.server = HumanInputServer(host="localhost", port=8001, api_key="test-key")
self.client = TestClient(self.server.app)
event_stream_manager._websocket_connections.clear()
event_stream_manager._sse_connections.clear()
event_stream_manager._polling_events.clear()
def test_health_endpoint(self):
"""Test health check endpoint"""
response = self.client.get("/health")
assert response.status_code == 200
data = response.json()
assert data["status"] == "healthy"
assert "timestamp" in data
def test_polling_endpoint_unauthorized(self):
"""Test polling endpoint without authentication"""
response = self.client.get("/poll/human-input/test-execution-id")
assert response.status_code == 401
def test_polling_endpoint_authorized(self):
"""Test polling endpoint with authentication"""
headers = {"Authorization": "Bearer test-key"}
response = self.client.get("/poll/human-input/test-execution-id", headers=headers)
assert response.status_code == 200
data = response.json()
assert "events" in data
assert isinstance(data["events"], list)
def test_polling_endpoint_with_events(self):
"""Test polling endpoint returns stored events"""
execution_id = "test-execution-id"
event = HumanInputRequiredEvent(
execution_id=execution_id,
crew_id="test-crew",
task_id="test-task",
agent_id="test-agent",
prompt="Test prompt",
context="Test context",
event_id="test-event-1"
)
event_stream_manager._store_polling_event(execution_id, event.to_json())
headers = {"Authorization": "Bearer test-key"}
response = self.client.get(f"/poll/human-input/{execution_id}", headers=headers)
assert response.status_code == 200
data = response.json()
assert len(data["events"]) == 1
assert data["events"][0]["type"] == "human_input_required"
assert data["events"][0]["execution_id"] == execution_id
def test_polling_endpoint_with_last_event_id(self):
"""Test polling endpoint with last_event_id parameter"""
execution_id = "test-execution-id"
event1 = HumanInputRequiredEvent(
execution_id=execution_id,
event_id="event-1"
)
event2 = HumanInputRequiredEvent(
execution_id=execution_id,
event_id="event-2"
)
event_stream_manager._store_polling_event(execution_id, event1.to_json())
event_stream_manager._store_polling_event(execution_id, event2.to_json())
headers = {"Authorization": "Bearer test-key"}
response = self.client.get(
f"/poll/human-input/{execution_id}?last_event_id=event-1",
headers=headers
)
assert response.status_code == 200
data = response.json()
assert len(data["events"]) == 1
assert data["events"][0]["event_id"] == "event-2"
def test_sse_endpoint_unauthorized(self):
"""Test SSE endpoint without authentication"""
response = self.client.get("/events/human-input/test-execution-id")
assert response.status_code == 401
def test_sse_endpoint_authorized(self):
"""Test SSE endpoint with authentication"""
headers = {"Authorization": "Bearer test-key"}
with self.client.stream("GET", "/events/human-input/test-execution-id", headers=headers) as response:
assert response.status_code == 200
assert response.headers["content-type"] == "text/event-stream; charset=utf-8"
def test_websocket_endpoint_unauthorized(self):
"""Test WebSocket endpoint without authentication"""
with pytest.raises(Exception):
with self.client.websocket_connect("/ws/human-input/test-execution-id"):
pass
def test_websocket_endpoint_authorized(self):
"""Test WebSocket endpoint with authentication"""
with self.client.websocket_connect("/ws/human-input/test-execution-id?token=test-key") as websocket:
assert websocket is not None
def test_server_without_api_key(self):
"""Test server initialization without API key"""
server = HumanInputServer(host="localhost", port=8002)
client = TestClient(server.app)
response = client.get("/poll/human-input/test-execution-id")
assert response.status_code == 200
response = client.get("/events/human-input/test-execution-id")
assert response.status_code == 200
@pytest.mark.skipif(FASTAPI_AVAILABLE, reason="Testing import error handling")
def test_server_without_fastapi():
"""Test server initialization without FastAPI dependencies"""
with pytest.raises(ImportError, match="FastAPI dependencies not available"):
HumanInputServer()

View File

@@ -0,0 +1,209 @@
import pytest
import uuid
from unittest.mock import patch, MagicMock
from crewai.utilities.events.task_events import HumanInputRequiredEvent, HumanInputCompletedEvent
from crewai.utilities.events.crewai_event_bus import crewai_event_bus
from crewai.agents.agent_builder.base_agent_executor_mixin import CrewAgentExecutorMixin
class TestHumanInputEventIntegration:
"""Test integration between human input flow and event system"""
def setup_method(self):
"""Setup test environment"""
self.executor = CrewAgentExecutorMixin()
self.executor.crew = MagicMock()
self.executor.crew.id = str(uuid.uuid4())
self.executor.crew._train = False
self.executor.task = MagicMock()
self.executor.task.id = str(uuid.uuid4())
self.executor.agent = MagicMock()
self.executor.agent.id = str(uuid.uuid4())
self.executor._printer = MagicMock()
@patch('builtins.input', return_value='test feedback')
def test_human_input_emits_required_event(self, mock_input):
"""Test that human input emits HumanInputRequiredEvent"""
events_captured = []
def capture_event(event):
events_captured.append(event)
with patch.object(crewai_event_bus, 'emit', side_effect=capture_event):
result = self.executor._ask_human_input("Test result")
assert result == 'test feedback'
assert len(events_captured) == 2
required_event = events_captured[0][1]
completed_event = events_captured[1][1]
assert isinstance(required_event, HumanInputRequiredEvent)
assert isinstance(completed_event, HumanInputCompletedEvent)
assert required_event.execution_id == str(self.executor.crew.id)
assert required_event.crew_id == str(self.executor.crew.id)
assert required_event.task_id == str(self.executor.task.id)
assert required_event.agent_id == str(self.executor.agent.id)
assert "HUMAN FEEDBACK" in required_event.prompt
assert required_event.context == "Test result"
assert required_event.event_id is not None
assert completed_event.execution_id == str(self.executor.crew.id)
assert completed_event.human_feedback == 'test feedback'
assert completed_event.event_id == required_event.event_id
@patch('builtins.input', return_value='training feedback')
def test_training_mode_human_input_events(self, mock_input):
"""Test human input events in training mode"""
self.executor.crew._train = True
events_captured = []
def capture_event(event):
events_captured.append(event)
with patch.object(crewai_event_bus, 'emit', side_effect=capture_event):
result = self.executor._ask_human_input("Test result")
assert result == 'training feedback'
assert len(events_captured) == 2
required_event = events_captured[0][1]
assert isinstance(required_event, HumanInputRequiredEvent)
assert "TRAINING MODE" in required_event.prompt
@patch('builtins.input', return_value='')
def test_empty_feedback_events(self, mock_input):
"""Test events with empty feedback"""
events_captured = []
def capture_event(event):
events_captured.append(event)
with patch.object(crewai_event_bus, 'emit', side_effect=capture_event):
result = self.executor._ask_human_input("Test result")
assert result == ''
assert len(events_captured) == 2
completed_event = events_captured[1][1]
assert isinstance(completed_event, HumanInputCompletedEvent)
assert completed_event.human_feedback == ''
def test_event_payload_structure(self):
"""Test that event payload matches GitHub issue specification"""
event = HumanInputRequiredEvent(
execution_id="test-execution-id",
crew_id="test-crew-id",
task_id="test-task-id",
agent_id="test-agent-id",
prompt="Test prompt",
context="Test context",
reason_flags={"ambiguity": True, "missing_field": False},
event_id="test-event-id"
)
payload = event.to_json()
assert payload["type"] == "human_input_required"
assert payload["execution_id"] == "test-execution-id"
assert payload["crew_id"] == "test-crew-id"
assert payload["task_id"] == "test-task-id"
assert payload["agent_id"] == "test-agent-id"
assert payload["prompt"] == "Test prompt"
assert payload["context"] == "Test context"
assert payload["reason_flags"]["ambiguity"] is True
assert payload["reason_flags"]["missing_field"] is False
assert payload["event_id"] == "test-event-id"
assert "timestamp" in payload
def test_completed_event_payload_structure(self):
"""Test that completed event payload is correct"""
event = HumanInputCompletedEvent(
execution_id="test-execution-id",
crew_id="test-crew-id",
task_id="test-task-id",
agent_id="test-agent-id",
event_id="test-event-id",
human_feedback="User feedback"
)
payload = event.to_json()
assert payload["type"] == "human_input_completed"
assert payload["execution_id"] == "test-execution-id"
assert payload["crew_id"] == "test-crew-id"
assert payload["task_id"] == "test-task-id"
assert payload["agent_id"] == "test-agent-id"
assert payload["event_id"] == "test-event-id"
assert payload["human_feedback"] == "User feedback"
assert "timestamp" in payload
@patch('builtins.input', side_effect=KeyboardInterrupt("Test interrupt"))
def test_human_input_exception_handling(self, mock_input):
"""Test that events are still emitted even if input is interrupted"""
events_captured = []
def capture_event(event):
events_captured.append(event)
with patch.object(crewai_event_bus, 'emit', side_effect=capture_event):
with pytest.raises(KeyboardInterrupt):
self.executor._ask_human_input("Test result")
assert len(events_captured) == 1
required_event = events_captured[0][1]
assert isinstance(required_event, HumanInputRequiredEvent)
def test_human_input_without_crew(self):
"""Test human input events when crew is None"""
self.executor.crew = None
events_captured = []
def capture_event(event):
events_captured.append(event)
with patch.object(crewai_event_bus, 'emit', side_effect=capture_event), \
patch('builtins.input', return_value='test'):
result = self.executor._ask_human_input("Test result")
assert len(events_captured) == 2
required_event = events_captured[0][1]
assert required_event.execution_id is None
assert required_event.crew_id is None
def test_human_input_without_task(self):
"""Test human input events when task is None"""
self.executor.task = None
events_captured = []
def capture_event(event):
events_captured.append(event)
with patch.object(crewai_event_bus, 'emit', side_effect=capture_event), \
patch('builtins.input', return_value='test'):
result = self.executor._ask_human_input("Test result")
assert len(events_captured) == 2
required_event = events_captured[0][1]
assert required_event.task_id is None
def test_human_input_without_agent(self):
"""Test human input events when agent is None"""
self.executor.agent = None
events_captured = []
def capture_event(event):
events_captured.append(event)
with patch.object(crewai_event_bus, 'emit', side_effect=capture_event), \
patch('builtins.input', return_value='test'):
result = self.executor._ask_human_input("Test result")
assert len(events_captured) == 2
required_event = events_captured[0][1]
assert required_event.agent_id is None