feat: implement client-initiated real-time human input event streams

- Add HumanInputRequiredEvent and HumanInputCompletedEvent to task_events.py
- Implement HTTP server with WebSocket, SSE, and long-polling endpoints
- Add EventStreamManager for connection and event management
- Integrate event emission in agent executor human input flow
- Add comprehensive tests for server endpoints and event integration
- Add optional FastAPI dependencies for server functionality
- Include documentation and example usage
- Maintain backward compatibility with existing human input flow

Addresses issue #3259 for WebSocket/SSE/long-polling human input events

Co-Authored-By: João <joao@crewai.com>
This commit is contained in:
Devin AI
2025-08-02 16:05:43 +00:00
parent 88ed91561f
commit ea560d0af1
12 changed files with 1440 additions and 2 deletions

View File

@@ -1,4 +1,5 @@
import time
import uuid
from typing import TYPE_CHECKING
from crewai.memory.entity.entity_memory_item import EntityMemoryItem
@@ -8,6 +9,8 @@ from crewai.utilities.converter import ConverterError
from crewai.utilities.evaluators.task_evaluator import TaskEvaluator
from crewai.utilities.printer import Printer
from crewai.utilities.events.event_listener import event_listener
from crewai.utilities.events.crewai_event_bus import crewai_event_bus
from crewai.utilities.events.task_events import HumanInputRequiredEvent, HumanInputCompletedEvent
if TYPE_CHECKING:
from crewai.agents.agent_builder.base_agent import BaseAgent
@@ -126,6 +129,42 @@ class CrewAgentExecutorMixin:
def _ask_human_input(self, final_answer: str) -> str:
"""Prompt human input with mode-appropriate messaging."""
event_id = str(uuid.uuid4())
execution_id = getattr(self.crew, 'id', None) if self.crew else None
crew_id = str(execution_id) if execution_id else None
task_id = str(getattr(self.task, 'id', None)) if self.task else None
agent_id = str(getattr(self.agent, 'id', None)) if self.agent else None
if self.crew and getattr(self.crew, "_train", False):
prompt_text = (
"## TRAINING MODE: Provide feedback to improve the agent's performance.\n"
"This will be used to train better versions of the agent.\n"
"Please provide detailed feedback about the result quality and reasoning process."
)
else:
prompt_text = (
"## HUMAN FEEDBACK: Provide feedback on the Final Result and Agent's actions.\n"
"Please follow these guidelines:\n"
" - If you are happy with the result, simply hit Enter without typing anything.\n"
" - Otherwise, provide specific improvement requests.\n"
" - You can provide multiple rounds of feedback until satisfied."
)
crewai_event_bus.emit(
self,
HumanInputRequiredEvent(
execution_id=execution_id,
crew_id=crew_id,
task_id=task_id,
agent_id=agent_id,
prompt=prompt_text,
context=final_answer,
event_id=event_id,
reason_flags={"ambiguity": True, "missing_field": False}
)
)
event_listener.formatter.pause_live_updates()
try:
@@ -133,7 +172,6 @@ class CrewAgentExecutorMixin:
content=f"\033[1m\033[95m ## Final Result:\033[00m \033[92m{final_answer}\033[00m"
)
# Training mode prompt (single iteration)
if self.crew and getattr(self.crew, "_train", False):
prompt = (
"\n\n=====\n"
@@ -142,7 +180,6 @@ class CrewAgentExecutorMixin:
"Please provide detailed feedback about the result quality and reasoning process.\n"
"=====\n"
)
# Regular human-in-the-loop prompt (multiple iterations)
else:
prompt = (
"\n\n=====\n"
@@ -158,6 +195,19 @@ class CrewAgentExecutorMixin:
response = input()
if response.strip() != "":
self._printer.print(content="\nProcessing your feedback...", color="cyan")
crewai_event_bus.emit(
self,
HumanInputCompletedEvent(
execution_id=execution_id,
crew_id=crew_id,
task_id=task_id,
agent_id=agent_id,
event_id=event_id,
human_feedback=response
)
)
return response
finally:
event_listener.formatter.resume_live_updates()

View File

@@ -0,0 +1,4 @@
from .human_input_server import HumanInputServer
from .event_stream_manager import EventStreamManager
__all__ = ["HumanInputServer", "EventStreamManager"]

View File

@@ -0,0 +1,148 @@
import asyncio
import json
import uuid
from typing import Dict, List, Optional, Set
from datetime import datetime, timezone
from crewai.utilities.events.crewai_event_bus import crewai_event_bus
from crewai.utilities.events.task_events import HumanInputRequiredEvent, HumanInputCompletedEvent
class EventStreamManager:
"""Manages event streams for human input events"""
def __init__(self):
self._websocket_connections: Dict[str, Set] = {}
self._sse_connections: Dict[str, Set] = {}
self._polling_events: Dict[str, List] = {}
self._event_listeners_registered = False
def register_event_listeners(self):
"""Register event listeners for human input events"""
if self._event_listeners_registered:
return
@crewai_event_bus.on(HumanInputRequiredEvent)
def handle_human_input_required(event: HumanInputRequiredEvent):
self._broadcast_event(event)
@crewai_event_bus.on(HumanInputCompletedEvent)
def handle_human_input_completed(event: HumanInputCompletedEvent):
self._broadcast_event(event)
self._event_listeners_registered = True
def add_websocket_connection(self, execution_id: str, websocket):
"""Add a WebSocket connection for an execution"""
if execution_id not in self._websocket_connections:
self._websocket_connections[execution_id] = set()
self._websocket_connections[execution_id].add(websocket)
def remove_websocket_connection(self, execution_id: str, websocket):
"""Remove a WebSocket connection"""
if execution_id in self._websocket_connections:
self._websocket_connections[execution_id].discard(websocket)
if not self._websocket_connections[execution_id]:
del self._websocket_connections[execution_id]
def add_sse_connection(self, execution_id: str, queue):
"""Add an SSE connection for an execution"""
if execution_id not in self._sse_connections:
self._sse_connections[execution_id] = set()
self._sse_connections[execution_id].add(queue)
def remove_sse_connection(self, execution_id: str, queue):
"""Remove an SSE connection"""
if execution_id in self._sse_connections:
self._sse_connections[execution_id].discard(queue)
if not self._sse_connections[execution_id]:
del self._sse_connections[execution_id]
def get_polling_events(self, execution_id: str, last_event_id: Optional[str] = None) -> List[Dict]:
"""Get events for polling clients"""
if execution_id not in self._polling_events:
return []
events = self._polling_events[execution_id]
if last_event_id:
try:
last_index = next(
i for i, event in enumerate(events)
if event.get("event_id") == last_event_id
)
return events[last_index + 1:]
except StopIteration:
pass
return events
def _broadcast_event(self, event):
"""Broadcast event to all relevant connections"""
execution_id = getattr(event, 'execution_id', None)
if not execution_id:
return
event_data = self._serialize_event(event)
self._broadcast_websocket(execution_id, event_data)
self._broadcast_sse(execution_id, event_data)
self._store_polling_event(execution_id, event_data)
def _serialize_event(self, event) -> Dict:
"""Serialize event to dictionary format"""
event_dict = event.to_json()
if not event_dict.get("event_id"):
event_dict["event_id"] = str(uuid.uuid4())
return event_dict
def _broadcast_websocket(self, execution_id: str, event_data: Dict):
"""Broadcast event to WebSocket connections"""
if execution_id not in self._websocket_connections:
return
connections_to_remove = set()
for websocket in self._websocket_connections[execution_id]:
try:
asyncio.create_task(websocket.send_text(json.dumps(event_data)))
except Exception:
connections_to_remove.add(websocket)
for websocket in connections_to_remove:
self.remove_websocket_connection(execution_id, websocket)
def _broadcast_sse(self, execution_id: str, event_data: Dict):
"""Broadcast event to SSE connections"""
if execution_id not in self._sse_connections:
return
connections_to_remove = set()
for queue in self._sse_connections[execution_id]:
try:
queue.put_nowait(event_data)
except Exception:
connections_to_remove.add(queue)
for queue in connections_to_remove:
self.remove_sse_connection(execution_id, queue)
def _store_polling_event(self, execution_id: str, event_data: Dict):
"""Store event for polling clients"""
if execution_id not in self._polling_events:
self._polling_events[execution_id] = []
self._polling_events[execution_id].append(event_data)
if len(self._polling_events[execution_id]) > 100:
self._polling_events[execution_id] = self._polling_events[execution_id][-100:]
def cleanup_execution(self, execution_id: str):
"""Clean up all connections and events for an execution"""
self._websocket_connections.pop(execution_id, None)
self._sse_connections.pop(execution_id, None)
self._polling_events.pop(execution_id, None)
event_stream_manager = EventStreamManager()

View File

@@ -0,0 +1,124 @@
import asyncio
import json
import queue
import uuid
from typing import Optional, Dict, Any
from datetime import datetime, timezone
try:
from fastapi import FastAPI, WebSocket, WebSocketDisconnect, HTTPException, Depends, Query
from fastapi.responses import StreamingResponse
from fastapi.security import HTTPBearer, HTTPAuthorizationCredentials
import uvicorn
FASTAPI_AVAILABLE = True
except ImportError:
FASTAPI_AVAILABLE = False
from .event_stream_manager import event_stream_manager
class HumanInputServer:
"""HTTP server for human input event streaming"""
def __init__(self, host: str = "localhost", port: int = 8000, api_key: Optional[str] = None):
if not FASTAPI_AVAILABLE:
raise ImportError(
"FastAPI dependencies not available. Install with: pip install fastapi uvicorn websockets"
)
self.host = host
self.port = port
self.api_key = api_key
self.app = FastAPI(title="CrewAI Human Input Event Stream API")
self.security = HTTPBearer() if api_key else None
self._setup_routes()
event_stream_manager.register_event_listeners()
def _setup_routes(self):
"""Setup FastAPI routes"""
@self.app.websocket("/ws/human-input/{execution_id}")
async def websocket_endpoint(websocket: WebSocket, execution_id: str):
if self.api_key:
token = websocket.query_params.get("token")
if not token or token != self.api_key:
await websocket.close(code=4001, reason="Unauthorized")
return
await websocket.accept()
event_stream_manager.add_websocket_connection(execution_id, websocket)
try:
while True:
await websocket.receive_text()
except WebSocketDisconnect:
pass
finally:
event_stream_manager.remove_websocket_connection(execution_id, websocket)
@self.app.get("/events/human-input/{execution_id}")
async def sse_endpoint(execution_id: str, credentials: Optional[HTTPAuthorizationCredentials] = Depends(self.security) if self.security else None):
if self.api_key and (not credentials or credentials.credentials != self.api_key):
raise HTTPException(status_code=401, detail="Unauthorized")
async def event_stream():
event_queue = asyncio.Queue()
event_stream_manager.add_sse_connection(execution_id, event_queue)
try:
while True:
try:
event_data = await asyncio.wait_for(event_queue.get(), timeout=30.0)
yield f"data: {json.dumps(event_data)}\n\n"
except asyncio.TimeoutError:
yield "data: {\"type\": \"heartbeat\"}\n\n"
except asyncio.CancelledError:
pass
finally:
event_stream_manager.remove_sse_connection(execution_id, event_queue)
return StreamingResponse(
event_stream(),
media_type="text/event-stream",
headers={
"Cache-Control": "no-cache",
"Connection": "keep-alive",
"Access-Control-Allow-Origin": "*",
}
)
@self.app.get("/poll/human-input/{execution_id}")
async def polling_endpoint(
execution_id: str,
last_event_id: Optional[str] = Query(None),
credentials: Optional[HTTPAuthorizationCredentials] = Depends(self.security) if self.security else None
):
if self.api_key and (not credentials or credentials.credentials != self.api_key):
raise HTTPException(status_code=401, detail="Unauthorized")
events = event_stream_manager.get_polling_events(execution_id, last_event_id)
return {"events": events}
@self.app.get("/health")
async def health_check():
return {"status": "healthy", "timestamp": datetime.now(timezone.utc).isoformat()}
async def start_async(self):
"""Start the server asynchronously"""
config = uvicorn.Config(
self.app,
host=self.host,
port=self.port,
log_level="info"
)
server = uvicorn.Server(config)
await server.serve()
def start(self):
"""Start the server synchronously"""
uvicorn.run(
self.app,
host=self.host,
port=self.port,
log_level="info"
)

View File

@@ -35,6 +35,8 @@ from .llm_guardrail_events import (
LLMGuardrailStartedEvent,
)
from .task_events import (
HumanInputCompletedEvent,
HumanInputRequiredEvent,
TaskCompletedEvent,
TaskFailedEvent,
TaskStartedEvent,
@@ -85,6 +87,8 @@ EventTypes = Union[
TaskStartedEvent,
TaskCompletedEvent,
TaskFailedEvent,
HumanInputRequiredEvent,
HumanInputCompletedEvent,
FlowStartedEvent,
FlowFinishedEvent,
MethodExecutionStartedEvent,

View File

@@ -82,3 +82,35 @@ class TaskEvaluationEvent(BaseEvent):
and self.task.fingerprint.metadata
):
self.fingerprint_metadata = self.task.fingerprint.metadata
class HumanInputRequiredEvent(BaseEvent):
"""Event emitted when human input is required during task execution"""
type: str = "human_input_required"
execution_id: Optional[str] = None
crew_id: Optional[str] = None
task_id: Optional[str] = None
agent_id: Optional[str] = None
prompt: Optional[str] = None
context: Optional[str] = None
reason_flags: Optional[dict] = None
event_id: Optional[str] = None
def __init__(self, **data):
super().__init__(**data)
class HumanInputCompletedEvent(BaseEvent):
"""Event emitted when human input is completed"""
type: str = "human_input_completed"
execution_id: Optional[str] = None
crew_id: Optional[str] = None
task_id: Optional[str] = None
agent_id: Optional[str] = None
event_id: Optional[str] = None
human_feedback: Optional[str] = None
def __init__(self, **data):
super().__init__(**data)