mirror of
https://github.com/crewAIInc/crewAI.git
synced 2026-01-15 19:18:30 +00:00
feat: implement client-initiated real-time human input event streams
- Add HumanInputRequiredEvent and HumanInputCompletedEvent to task_events.py - Implement HTTP server with WebSocket, SSE, and long-polling endpoints - Add EventStreamManager for connection and event management - Integrate event emission in agent executor human input flow - Add comprehensive tests for server endpoints and event integration - Add optional FastAPI dependencies for server functionality - Include documentation and example usage - Maintain backward compatibility with existing human input flow Addresses issue #3259 for WebSocket/SSE/long-polling human input events Co-Authored-By: João <joao@crewai.com>
This commit is contained in:
@@ -1,4 +1,5 @@
|
||||
import time
|
||||
import uuid
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
from crewai.memory.entity.entity_memory_item import EntityMemoryItem
|
||||
@@ -8,6 +9,8 @@ from crewai.utilities.converter import ConverterError
|
||||
from crewai.utilities.evaluators.task_evaluator import TaskEvaluator
|
||||
from crewai.utilities.printer import Printer
|
||||
from crewai.utilities.events.event_listener import event_listener
|
||||
from crewai.utilities.events.crewai_event_bus import crewai_event_bus
|
||||
from crewai.utilities.events.task_events import HumanInputRequiredEvent, HumanInputCompletedEvent
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from crewai.agents.agent_builder.base_agent import BaseAgent
|
||||
@@ -126,6 +129,42 @@ class CrewAgentExecutorMixin:
|
||||
|
||||
def _ask_human_input(self, final_answer: str) -> str:
|
||||
"""Prompt human input with mode-appropriate messaging."""
|
||||
event_id = str(uuid.uuid4())
|
||||
|
||||
execution_id = getattr(self.crew, 'id', None) if self.crew else None
|
||||
crew_id = str(execution_id) if execution_id else None
|
||||
task_id = str(getattr(self.task, 'id', None)) if self.task else None
|
||||
agent_id = str(getattr(self.agent, 'id', None)) if self.agent else None
|
||||
|
||||
if self.crew and getattr(self.crew, "_train", False):
|
||||
prompt_text = (
|
||||
"## TRAINING MODE: Provide feedback to improve the agent's performance.\n"
|
||||
"This will be used to train better versions of the agent.\n"
|
||||
"Please provide detailed feedback about the result quality and reasoning process."
|
||||
)
|
||||
else:
|
||||
prompt_text = (
|
||||
"## HUMAN FEEDBACK: Provide feedback on the Final Result and Agent's actions.\n"
|
||||
"Please follow these guidelines:\n"
|
||||
" - If you are happy with the result, simply hit Enter without typing anything.\n"
|
||||
" - Otherwise, provide specific improvement requests.\n"
|
||||
" - You can provide multiple rounds of feedback until satisfied."
|
||||
)
|
||||
|
||||
crewai_event_bus.emit(
|
||||
self,
|
||||
HumanInputRequiredEvent(
|
||||
execution_id=execution_id,
|
||||
crew_id=crew_id,
|
||||
task_id=task_id,
|
||||
agent_id=agent_id,
|
||||
prompt=prompt_text,
|
||||
context=final_answer,
|
||||
event_id=event_id,
|
||||
reason_flags={"ambiguity": True, "missing_field": False}
|
||||
)
|
||||
)
|
||||
|
||||
event_listener.formatter.pause_live_updates()
|
||||
|
||||
try:
|
||||
@@ -133,7 +172,6 @@ class CrewAgentExecutorMixin:
|
||||
content=f"\033[1m\033[95m ## Final Result:\033[00m \033[92m{final_answer}\033[00m"
|
||||
)
|
||||
|
||||
# Training mode prompt (single iteration)
|
||||
if self.crew and getattr(self.crew, "_train", False):
|
||||
prompt = (
|
||||
"\n\n=====\n"
|
||||
@@ -142,7 +180,6 @@ class CrewAgentExecutorMixin:
|
||||
"Please provide detailed feedback about the result quality and reasoning process.\n"
|
||||
"=====\n"
|
||||
)
|
||||
# Regular human-in-the-loop prompt (multiple iterations)
|
||||
else:
|
||||
prompt = (
|
||||
"\n\n=====\n"
|
||||
@@ -158,6 +195,19 @@ class CrewAgentExecutorMixin:
|
||||
response = input()
|
||||
if response.strip() != "":
|
||||
self._printer.print(content="\nProcessing your feedback...", color="cyan")
|
||||
|
||||
crewai_event_bus.emit(
|
||||
self,
|
||||
HumanInputCompletedEvent(
|
||||
execution_id=execution_id,
|
||||
crew_id=crew_id,
|
||||
task_id=task_id,
|
||||
agent_id=agent_id,
|
||||
event_id=event_id,
|
||||
human_feedback=response
|
||||
)
|
||||
)
|
||||
|
||||
return response
|
||||
finally:
|
||||
event_listener.formatter.resume_live_updates()
|
||||
|
||||
4
src/crewai/server/__init__.py
Normal file
4
src/crewai/server/__init__.py
Normal file
@@ -0,0 +1,4 @@
|
||||
from .human_input_server import HumanInputServer
|
||||
from .event_stream_manager import EventStreamManager
|
||||
|
||||
__all__ = ["HumanInputServer", "EventStreamManager"]
|
||||
148
src/crewai/server/event_stream_manager.py
Normal file
148
src/crewai/server/event_stream_manager.py
Normal file
@@ -0,0 +1,148 @@
|
||||
import asyncio
|
||||
import json
|
||||
import uuid
|
||||
from typing import Dict, List, Optional, Set
|
||||
from datetime import datetime, timezone
|
||||
|
||||
from crewai.utilities.events.crewai_event_bus import crewai_event_bus
|
||||
from crewai.utilities.events.task_events import HumanInputRequiredEvent, HumanInputCompletedEvent
|
||||
|
||||
|
||||
class EventStreamManager:
|
||||
"""Manages event streams for human input events"""
|
||||
|
||||
def __init__(self):
|
||||
self._websocket_connections: Dict[str, Set] = {}
|
||||
self._sse_connections: Dict[str, Set] = {}
|
||||
self._polling_events: Dict[str, List] = {}
|
||||
self._event_listeners_registered = False
|
||||
|
||||
def register_event_listeners(self):
|
||||
"""Register event listeners for human input events"""
|
||||
if self._event_listeners_registered:
|
||||
return
|
||||
|
||||
@crewai_event_bus.on(HumanInputRequiredEvent)
|
||||
def handle_human_input_required(event: HumanInputRequiredEvent):
|
||||
self._broadcast_event(event)
|
||||
|
||||
@crewai_event_bus.on(HumanInputCompletedEvent)
|
||||
def handle_human_input_completed(event: HumanInputCompletedEvent):
|
||||
self._broadcast_event(event)
|
||||
|
||||
self._event_listeners_registered = True
|
||||
|
||||
def add_websocket_connection(self, execution_id: str, websocket):
|
||||
"""Add a WebSocket connection for an execution"""
|
||||
if execution_id not in self._websocket_connections:
|
||||
self._websocket_connections[execution_id] = set()
|
||||
self._websocket_connections[execution_id].add(websocket)
|
||||
|
||||
def remove_websocket_connection(self, execution_id: str, websocket):
|
||||
"""Remove a WebSocket connection"""
|
||||
if execution_id in self._websocket_connections:
|
||||
self._websocket_connections[execution_id].discard(websocket)
|
||||
if not self._websocket_connections[execution_id]:
|
||||
del self._websocket_connections[execution_id]
|
||||
|
||||
def add_sse_connection(self, execution_id: str, queue):
|
||||
"""Add an SSE connection for an execution"""
|
||||
if execution_id not in self._sse_connections:
|
||||
self._sse_connections[execution_id] = set()
|
||||
self._sse_connections[execution_id].add(queue)
|
||||
|
||||
def remove_sse_connection(self, execution_id: str, queue):
|
||||
"""Remove an SSE connection"""
|
||||
if execution_id in self._sse_connections:
|
||||
self._sse_connections[execution_id].discard(queue)
|
||||
if not self._sse_connections[execution_id]:
|
||||
del self._sse_connections[execution_id]
|
||||
|
||||
def get_polling_events(self, execution_id: str, last_event_id: Optional[str] = None) -> List[Dict]:
|
||||
"""Get events for polling clients"""
|
||||
if execution_id not in self._polling_events:
|
||||
return []
|
||||
|
||||
events = self._polling_events[execution_id]
|
||||
|
||||
if last_event_id:
|
||||
try:
|
||||
last_index = next(
|
||||
i for i, event in enumerate(events)
|
||||
if event.get("event_id") == last_event_id
|
||||
)
|
||||
return events[last_index + 1:]
|
||||
except StopIteration:
|
||||
pass
|
||||
|
||||
return events
|
||||
|
||||
def _broadcast_event(self, event):
|
||||
"""Broadcast event to all relevant connections"""
|
||||
execution_id = getattr(event, 'execution_id', None)
|
||||
if not execution_id:
|
||||
return
|
||||
|
||||
event_data = self._serialize_event(event)
|
||||
|
||||
self._broadcast_websocket(execution_id, event_data)
|
||||
self._broadcast_sse(execution_id, event_data)
|
||||
self._store_polling_event(execution_id, event_data)
|
||||
|
||||
def _serialize_event(self, event) -> Dict:
|
||||
"""Serialize event to dictionary format"""
|
||||
event_dict = event.to_json()
|
||||
|
||||
if not event_dict.get("event_id"):
|
||||
event_dict["event_id"] = str(uuid.uuid4())
|
||||
|
||||
return event_dict
|
||||
|
||||
def _broadcast_websocket(self, execution_id: str, event_data: Dict):
|
||||
"""Broadcast event to WebSocket connections"""
|
||||
if execution_id not in self._websocket_connections:
|
||||
return
|
||||
|
||||
connections_to_remove = set()
|
||||
for websocket in self._websocket_connections[execution_id]:
|
||||
try:
|
||||
asyncio.create_task(websocket.send_text(json.dumps(event_data)))
|
||||
except Exception:
|
||||
connections_to_remove.add(websocket)
|
||||
|
||||
for websocket in connections_to_remove:
|
||||
self.remove_websocket_connection(execution_id, websocket)
|
||||
|
||||
def _broadcast_sse(self, execution_id: str, event_data: Dict):
|
||||
"""Broadcast event to SSE connections"""
|
||||
if execution_id not in self._sse_connections:
|
||||
return
|
||||
|
||||
connections_to_remove = set()
|
||||
for queue in self._sse_connections[execution_id]:
|
||||
try:
|
||||
queue.put_nowait(event_data)
|
||||
except Exception:
|
||||
connections_to_remove.add(queue)
|
||||
|
||||
for queue in connections_to_remove:
|
||||
self.remove_sse_connection(execution_id, queue)
|
||||
|
||||
def _store_polling_event(self, execution_id: str, event_data: Dict):
|
||||
"""Store event for polling clients"""
|
||||
if execution_id not in self._polling_events:
|
||||
self._polling_events[execution_id] = []
|
||||
|
||||
self._polling_events[execution_id].append(event_data)
|
||||
|
||||
if len(self._polling_events[execution_id]) > 100:
|
||||
self._polling_events[execution_id] = self._polling_events[execution_id][-100:]
|
||||
|
||||
def cleanup_execution(self, execution_id: str):
|
||||
"""Clean up all connections and events for an execution"""
|
||||
self._websocket_connections.pop(execution_id, None)
|
||||
self._sse_connections.pop(execution_id, None)
|
||||
self._polling_events.pop(execution_id, None)
|
||||
|
||||
|
||||
event_stream_manager = EventStreamManager()
|
||||
124
src/crewai/server/human_input_server.py
Normal file
124
src/crewai/server/human_input_server.py
Normal file
@@ -0,0 +1,124 @@
|
||||
import asyncio
|
||||
import json
|
||||
import queue
|
||||
import uuid
|
||||
from typing import Optional, Dict, Any
|
||||
from datetime import datetime, timezone
|
||||
|
||||
try:
|
||||
from fastapi import FastAPI, WebSocket, WebSocketDisconnect, HTTPException, Depends, Query
|
||||
from fastapi.responses import StreamingResponse
|
||||
from fastapi.security import HTTPBearer, HTTPAuthorizationCredentials
|
||||
import uvicorn
|
||||
FASTAPI_AVAILABLE = True
|
||||
except ImportError:
|
||||
FASTAPI_AVAILABLE = False
|
||||
|
||||
from .event_stream_manager import event_stream_manager
|
||||
|
||||
|
||||
class HumanInputServer:
|
||||
"""HTTP server for human input event streaming"""
|
||||
|
||||
def __init__(self, host: str = "localhost", port: int = 8000, api_key: Optional[str] = None):
|
||||
if not FASTAPI_AVAILABLE:
|
||||
raise ImportError(
|
||||
"FastAPI dependencies not available. Install with: pip install fastapi uvicorn websockets"
|
||||
)
|
||||
|
||||
self.host = host
|
||||
self.port = port
|
||||
self.api_key = api_key
|
||||
self.app = FastAPI(title="CrewAI Human Input Event Stream API")
|
||||
self.security = HTTPBearer() if api_key else None
|
||||
self._setup_routes()
|
||||
event_stream_manager.register_event_listeners()
|
||||
|
||||
def _setup_routes(self):
|
||||
"""Setup FastAPI routes"""
|
||||
|
||||
@self.app.websocket("/ws/human-input/{execution_id}")
|
||||
async def websocket_endpoint(websocket: WebSocket, execution_id: str):
|
||||
if self.api_key:
|
||||
token = websocket.query_params.get("token")
|
||||
if not token or token != self.api_key:
|
||||
await websocket.close(code=4001, reason="Unauthorized")
|
||||
return
|
||||
|
||||
await websocket.accept()
|
||||
event_stream_manager.add_websocket_connection(execution_id, websocket)
|
||||
|
||||
try:
|
||||
while True:
|
||||
await websocket.receive_text()
|
||||
except WebSocketDisconnect:
|
||||
pass
|
||||
finally:
|
||||
event_stream_manager.remove_websocket_connection(execution_id, websocket)
|
||||
|
||||
@self.app.get("/events/human-input/{execution_id}")
|
||||
async def sse_endpoint(execution_id: str, credentials: Optional[HTTPAuthorizationCredentials] = Depends(self.security) if self.security else None):
|
||||
if self.api_key and (not credentials or credentials.credentials != self.api_key):
|
||||
raise HTTPException(status_code=401, detail="Unauthorized")
|
||||
|
||||
async def event_stream():
|
||||
event_queue = asyncio.Queue()
|
||||
event_stream_manager.add_sse_connection(execution_id, event_queue)
|
||||
|
||||
try:
|
||||
while True:
|
||||
try:
|
||||
event_data = await asyncio.wait_for(event_queue.get(), timeout=30.0)
|
||||
yield f"data: {json.dumps(event_data)}\n\n"
|
||||
except asyncio.TimeoutError:
|
||||
yield "data: {\"type\": \"heartbeat\"}\n\n"
|
||||
except asyncio.CancelledError:
|
||||
pass
|
||||
finally:
|
||||
event_stream_manager.remove_sse_connection(execution_id, event_queue)
|
||||
|
||||
return StreamingResponse(
|
||||
event_stream(),
|
||||
media_type="text/event-stream",
|
||||
headers={
|
||||
"Cache-Control": "no-cache",
|
||||
"Connection": "keep-alive",
|
||||
"Access-Control-Allow-Origin": "*",
|
||||
}
|
||||
)
|
||||
|
||||
@self.app.get("/poll/human-input/{execution_id}")
|
||||
async def polling_endpoint(
|
||||
execution_id: str,
|
||||
last_event_id: Optional[str] = Query(None),
|
||||
credentials: Optional[HTTPAuthorizationCredentials] = Depends(self.security) if self.security else None
|
||||
):
|
||||
if self.api_key and (not credentials or credentials.credentials != self.api_key):
|
||||
raise HTTPException(status_code=401, detail="Unauthorized")
|
||||
|
||||
events = event_stream_manager.get_polling_events(execution_id, last_event_id)
|
||||
return {"events": events}
|
||||
|
||||
@self.app.get("/health")
|
||||
async def health_check():
|
||||
return {"status": "healthy", "timestamp": datetime.now(timezone.utc).isoformat()}
|
||||
|
||||
async def start_async(self):
|
||||
"""Start the server asynchronously"""
|
||||
config = uvicorn.Config(
|
||||
self.app,
|
||||
host=self.host,
|
||||
port=self.port,
|
||||
log_level="info"
|
||||
)
|
||||
server = uvicorn.Server(config)
|
||||
await server.serve()
|
||||
|
||||
def start(self):
|
||||
"""Start the server synchronously"""
|
||||
uvicorn.run(
|
||||
self.app,
|
||||
host=self.host,
|
||||
port=self.port,
|
||||
log_level="info"
|
||||
)
|
||||
@@ -35,6 +35,8 @@ from .llm_guardrail_events import (
|
||||
LLMGuardrailStartedEvent,
|
||||
)
|
||||
from .task_events import (
|
||||
HumanInputCompletedEvent,
|
||||
HumanInputRequiredEvent,
|
||||
TaskCompletedEvent,
|
||||
TaskFailedEvent,
|
||||
TaskStartedEvent,
|
||||
@@ -85,6 +87,8 @@ EventTypes = Union[
|
||||
TaskStartedEvent,
|
||||
TaskCompletedEvent,
|
||||
TaskFailedEvent,
|
||||
HumanInputRequiredEvent,
|
||||
HumanInputCompletedEvent,
|
||||
FlowStartedEvent,
|
||||
FlowFinishedEvent,
|
||||
MethodExecutionStartedEvent,
|
||||
|
||||
@@ -82,3 +82,35 @@ class TaskEvaluationEvent(BaseEvent):
|
||||
and self.task.fingerprint.metadata
|
||||
):
|
||||
self.fingerprint_metadata = self.task.fingerprint.metadata
|
||||
|
||||
|
||||
class HumanInputRequiredEvent(BaseEvent):
|
||||
"""Event emitted when human input is required during task execution"""
|
||||
|
||||
type: str = "human_input_required"
|
||||
execution_id: Optional[str] = None
|
||||
crew_id: Optional[str] = None
|
||||
task_id: Optional[str] = None
|
||||
agent_id: Optional[str] = None
|
||||
prompt: Optional[str] = None
|
||||
context: Optional[str] = None
|
||||
reason_flags: Optional[dict] = None
|
||||
event_id: Optional[str] = None
|
||||
|
||||
def __init__(self, **data):
|
||||
super().__init__(**data)
|
||||
|
||||
|
||||
class HumanInputCompletedEvent(BaseEvent):
|
||||
"""Event emitted when human input is completed"""
|
||||
|
||||
type: str = "human_input_completed"
|
||||
execution_id: Optional[str] = None
|
||||
crew_id: Optional[str] = None
|
||||
task_id: Optional[str] = None
|
||||
agent_id: Optional[str] = None
|
||||
event_id: Optional[str] = None
|
||||
human_feedback: Optional[str] = None
|
||||
|
||||
def __init__(self, **data):
|
||||
super().__init__(**data)
|
||||
|
||||
Reference in New Issue
Block a user