mirror of
https://github.com/crewAIInc/crewAI.git
synced 2026-01-20 13:28:13 +00:00
feat: implement client-initiated real-time human input event streams
- Add HumanInputRequiredEvent and HumanInputCompletedEvent to task_events.py - Implement HTTP server with WebSocket, SSE, and long-polling endpoints - Add EventStreamManager for connection and event management - Integrate event emission in agent executor human input flow - Add comprehensive tests for server endpoints and event integration - Add optional FastAPI dependencies for server functionality - Include documentation and example usage - Maintain backward compatibility with existing human input flow Addresses issue #3259 for WebSocket/SSE/long-polling human input events Co-Authored-By: João <joao@crewai.com>
This commit is contained in:
264
docs/human_input_event_streaming.md
Normal file
264
docs/human_input_event_streaming.md
Normal file
@@ -0,0 +1,264 @@
|
|||||||
|
# Human Input Event Streaming
|
||||||
|
|
||||||
|
CrewAI supports real-time event streaming for human input events, allowing clients to receive notifications when human input is required during crew execution. This feature provides an alternative to webhook-only approaches and supports multiple streaming protocols.
|
||||||
|
|
||||||
|
## Overview
|
||||||
|
|
||||||
|
When a task requires human input (`task.human_input=True`), CrewAI emits events that can be consumed via:
|
||||||
|
|
||||||
|
- **WebSocket**: Real-time bidirectional communication
|
||||||
|
- **Server-Sent Events (SSE)**: Unidirectional server-to-client streaming
|
||||||
|
- **Long Polling**: HTTP-based polling for events
|
||||||
|
|
||||||
|
## Event Types
|
||||||
|
|
||||||
|
### HumanInputRequiredEvent
|
||||||
|
|
||||||
|
Emitted when human input is required during task execution.
|
||||||
|
|
||||||
|
```json
|
||||||
|
{
|
||||||
|
"type": "human_input_required",
|
||||||
|
"execution_id": "uuid",
|
||||||
|
"crew_id": "uuid",
|
||||||
|
"task_id": "uuid",
|
||||||
|
"agent_id": "uuid",
|
||||||
|
"prompt": "string",
|
||||||
|
"context": "string",
|
||||||
|
"timestamp": "ISO8601",
|
||||||
|
"event_id": "uuid",
|
||||||
|
"reason_flags": {
|
||||||
|
"ambiguity": true,
|
||||||
|
"missing_field": false
|
||||||
|
}
|
||||||
|
}
|
||||||
|
```
|
||||||
|
|
||||||
|
### HumanInputCompletedEvent
|
||||||
|
|
||||||
|
Emitted when human input is completed.
|
||||||
|
|
||||||
|
```json
|
||||||
|
{
|
||||||
|
"type": "human_input_completed",
|
||||||
|
"execution_id": "uuid",
|
||||||
|
"crew_id": "uuid",
|
||||||
|
"task_id": "uuid",
|
||||||
|
"agent_id": "uuid",
|
||||||
|
"event_id": "uuid",
|
||||||
|
"human_feedback": "string",
|
||||||
|
"timestamp": "ISO8601"
|
||||||
|
}
|
||||||
|
```
|
||||||
|
|
||||||
|
## Server Setup
|
||||||
|
|
||||||
|
### Installation
|
||||||
|
|
||||||
|
Install the server dependencies:
|
||||||
|
|
||||||
|
```bash
|
||||||
|
pip install crewai[server]
|
||||||
|
```
|
||||||
|
|
||||||
|
### Starting the Server
|
||||||
|
|
||||||
|
```python
|
||||||
|
from crewai.server.human_input_server import HumanInputServer
|
||||||
|
|
||||||
|
# Start server with authentication
|
||||||
|
server = HumanInputServer(
|
||||||
|
host="localhost",
|
||||||
|
port=8000,
|
||||||
|
api_key="your-api-key"
|
||||||
|
)
|
||||||
|
|
||||||
|
# Start synchronously
|
||||||
|
server.start()
|
||||||
|
|
||||||
|
# Or start asynchronously
|
||||||
|
await server.start_async()
|
||||||
|
```
|
||||||
|
|
||||||
|
### Configuration Options
|
||||||
|
|
||||||
|
- `host`: Server host (default: "localhost")
|
||||||
|
- `port`: Server port (default: 8000)
|
||||||
|
- `api_key`: Optional API key for authentication
|
||||||
|
|
||||||
|
## Client Integration
|
||||||
|
|
||||||
|
### WebSocket Client
|
||||||
|
|
||||||
|
```python
|
||||||
|
import asyncio
|
||||||
|
import json
|
||||||
|
import websockets
|
||||||
|
|
||||||
|
async def websocket_client(execution_id: str, api_key: str = None):
|
||||||
|
uri = f"ws://localhost:8000/ws/human-input/{execution_id}"
|
||||||
|
if api_key:
|
||||||
|
uri += f"?token={api_key}"
|
||||||
|
|
||||||
|
async with websockets.connect(uri) as websocket:
|
||||||
|
async for message in websocket:
|
||||||
|
event = json.loads(message)
|
||||||
|
|
||||||
|
if event['type'] == 'human_input_required':
|
||||||
|
print(f"Human input needed: {event['prompt']}")
|
||||||
|
print(f"Context: {event['context']}")
|
||||||
|
elif event['type'] == 'human_input_completed':
|
||||||
|
print(f"Input completed: {event['human_feedback']}")
|
||||||
|
|
||||||
|
# Usage
|
||||||
|
asyncio.run(websocket_client("execution-id", "api-key"))
|
||||||
|
```
|
||||||
|
|
||||||
|
### Server-Sent Events (SSE) Client
|
||||||
|
|
||||||
|
```python
|
||||||
|
import httpx
|
||||||
|
import json
|
||||||
|
|
||||||
|
async def sse_client(execution_id: str, api_key: str = None):
|
||||||
|
url = f"http://localhost:8000/events/human-input/{execution_id}"
|
||||||
|
headers = {}
|
||||||
|
if api_key:
|
||||||
|
headers["Authorization"] = f"Bearer {api_key}"
|
||||||
|
|
||||||
|
async with httpx.AsyncClient() as client:
|
||||||
|
async with client.stream("GET", url, headers=headers) as response:
|
||||||
|
async for line in response.aiter_lines():
|
||||||
|
if line.startswith("data: "):
|
||||||
|
event = json.loads(line[6:])
|
||||||
|
if event.get('type') != 'heartbeat':
|
||||||
|
print(f"Received: {event}")
|
||||||
|
|
||||||
|
# Usage
|
||||||
|
asyncio.run(sse_client("execution-id", "api-key"))
|
||||||
|
```
|
||||||
|
|
||||||
|
### Long Polling Client
|
||||||
|
|
||||||
|
```python
|
||||||
|
import httpx
|
||||||
|
import asyncio
|
||||||
|
|
||||||
|
async def polling_client(execution_id: str, api_key: str = None):
|
||||||
|
url = f"http://localhost:8000/poll/human-input/{execution_id}"
|
||||||
|
headers = {}
|
||||||
|
if api_key:
|
||||||
|
headers["Authorization"] = f"Bearer {api_key}"
|
||||||
|
|
||||||
|
last_event_id = None
|
||||||
|
|
||||||
|
async with httpx.AsyncClient() as client:
|
||||||
|
while True:
|
||||||
|
params = {}
|
||||||
|
if last_event_id:
|
||||||
|
params["last_event_id"] = last_event_id
|
||||||
|
|
||||||
|
response = await client.get(url, headers=headers, params=params)
|
||||||
|
data = response.json()
|
||||||
|
|
||||||
|
for event in data.get("events", []):
|
||||||
|
print(f"Received: {event}")
|
||||||
|
last_event_id = event.get('event_id')
|
||||||
|
|
||||||
|
await asyncio.sleep(2) # Poll every 2 seconds
|
||||||
|
|
||||||
|
# Usage
|
||||||
|
asyncio.run(polling_client("execution-id", "api-key"))
|
||||||
|
```
|
||||||
|
|
||||||
|
## API Endpoints
|
||||||
|
|
||||||
|
### WebSocket Endpoint
|
||||||
|
|
||||||
|
- **URL**: `/ws/human-input/{execution_id}`
|
||||||
|
- **Protocol**: WebSocket
|
||||||
|
- **Authentication**: Query parameter `token` (if API key configured)
|
||||||
|
|
||||||
|
### SSE Endpoint
|
||||||
|
|
||||||
|
- **URL**: `/events/human-input/{execution_id}`
|
||||||
|
- **Method**: GET
|
||||||
|
- **Headers**: `Authorization: Bearer <api_key>` (if configured)
|
||||||
|
- **Response**: `text/event-stream`
|
||||||
|
|
||||||
|
### Polling Endpoint
|
||||||
|
|
||||||
|
- **URL**: `/poll/human-input/{execution_id}`
|
||||||
|
- **Method**: GET
|
||||||
|
- **Headers**: `Authorization: Bearer <api_key>` (if configured)
|
||||||
|
- **Query Parameters**:
|
||||||
|
- `last_event_id`: Get events after this ID
|
||||||
|
- **Response**: JSON with `events` array
|
||||||
|
|
||||||
|
### Health Check
|
||||||
|
|
||||||
|
- **URL**: `/health`
|
||||||
|
- **Method**: GET
|
||||||
|
- **Response**: `{"status": "healthy", "timestamp": "..."}`
|
||||||
|
|
||||||
|
## Authentication
|
||||||
|
|
||||||
|
When an API key is configured, clients must authenticate:
|
||||||
|
|
||||||
|
- **WebSocket**: Include `token` query parameter
|
||||||
|
- **SSE/Polling**: Include `Authorization: Bearer <api_key>` header
|
||||||
|
|
||||||
|
## Integration with Crew Execution
|
||||||
|
|
||||||
|
The event streaming works automatically with existing crew execution:
|
||||||
|
|
||||||
|
```python
|
||||||
|
from crewai import Agent, Task, Crew
|
||||||
|
|
||||||
|
# Create crew with human input task
|
||||||
|
agent = Agent(...)
|
||||||
|
task = Task(
|
||||||
|
description="...",
|
||||||
|
human_input=True, # This enables human input
|
||||||
|
agent=agent
|
||||||
|
)
|
||||||
|
crew = Crew(agents=[agent], tasks=[task])
|
||||||
|
|
||||||
|
# Start event server (optional)
|
||||||
|
server = HumanInputServer(port=8000)
|
||||||
|
server_thread = threading.Thread(target=server.start, daemon=True)
|
||||||
|
server_thread.start()
|
||||||
|
|
||||||
|
# Execute crew - events will be emitted automatically
|
||||||
|
result = crew.kickoff()
|
||||||
|
```
|
||||||
|
|
||||||
|
## Error Handling
|
||||||
|
|
||||||
|
- **Connection Errors**: Clients should implement reconnection logic
|
||||||
|
- **Authentication Errors**: Server returns 401 for invalid credentials
|
||||||
|
- **Rate Limiting**: Consider implementing client-side rate limiting for polling
|
||||||
|
|
||||||
|
## Best Practices
|
||||||
|
|
||||||
|
1. **Use WebSocket** for real-time applications requiring immediate notifications
|
||||||
|
2. **Use SSE** for one-way streaming with automatic reconnection support
|
||||||
|
3. **Use Polling** for simple implementations or when WebSocket/SSE aren't available
|
||||||
|
4. **Implement Authentication** in production environments
|
||||||
|
5. **Handle Connection Failures** gracefully with retry logic
|
||||||
|
6. **Filter Events** by execution_id to avoid processing irrelevant events
|
||||||
|
|
||||||
|
## Backward Compatibility
|
||||||
|
|
||||||
|
This feature is fully backward compatible:
|
||||||
|
|
||||||
|
- Existing webhook functionality remains unchanged
|
||||||
|
- Console-based human input continues to work
|
||||||
|
- No breaking changes to existing APIs
|
||||||
|
|
||||||
|
## Example Applications
|
||||||
|
|
||||||
|
- **Web Dashboards**: Real-time crew execution monitoring
|
||||||
|
- **Mobile Apps**: Push notifications for human input requests
|
||||||
|
- **Integration Platforms**: Event-driven workflow automation
|
||||||
|
- **Monitoring Systems**: Real-time alerting and logging
|
||||||
242
examples/human_input_event_streaming.py
Normal file
242
examples/human_input_event_streaming.py
Normal file
@@ -0,0 +1,242 @@
|
|||||||
|
"""
|
||||||
|
Example demonstrating how to use the human input event streaming feature.
|
||||||
|
|
||||||
|
This example shows how to:
|
||||||
|
1. Start the human input event server
|
||||||
|
2. Connect to WebSocket/SSE/polling endpoints
|
||||||
|
3. Handle human input events in real-time
|
||||||
|
4. Integrate with crew execution
|
||||||
|
"""
|
||||||
|
|
||||||
|
import asyncio
|
||||||
|
import json
|
||||||
|
import threading
|
||||||
|
import time
|
||||||
|
from typing import Optional
|
||||||
|
|
||||||
|
try:
|
||||||
|
import websockets
|
||||||
|
import httpx
|
||||||
|
from crewai.server.human_input_server import HumanInputServer
|
||||||
|
from crewai.server.event_stream_manager import event_stream_manager
|
||||||
|
DEPENDENCIES_AVAILABLE = True
|
||||||
|
except ImportError:
|
||||||
|
DEPENDENCIES_AVAILABLE = False
|
||||||
|
|
||||||
|
from crewai import Agent, Task, Crew
|
||||||
|
from crewai.llm import LLM
|
||||||
|
|
||||||
|
|
||||||
|
def start_event_server(port: int = 8000, api_key: Optional[str] = None):
|
||||||
|
"""Start the human input event server in a separate thread"""
|
||||||
|
if not DEPENDENCIES_AVAILABLE:
|
||||||
|
print("Server dependencies not available. Install with: pip install crewai[server]")
|
||||||
|
return None
|
||||||
|
|
||||||
|
server = HumanInputServer(host="localhost", port=port, api_key=api_key)
|
||||||
|
|
||||||
|
def run_server():
|
||||||
|
server.start()
|
||||||
|
|
||||||
|
server_thread = threading.Thread(target=run_server, daemon=True)
|
||||||
|
server_thread.start()
|
||||||
|
|
||||||
|
time.sleep(2)
|
||||||
|
print(f"Human input event server started on http://localhost:{port}")
|
||||||
|
return server
|
||||||
|
|
||||||
|
|
||||||
|
async def websocket_client_example(execution_id: str, api_key: Optional[str] = None):
|
||||||
|
"""Example WebSocket client for receiving human input events"""
|
||||||
|
if not DEPENDENCIES_AVAILABLE:
|
||||||
|
print("WebSocket dependencies not available")
|
||||||
|
return
|
||||||
|
|
||||||
|
uri = f"ws://localhost:8000/ws/human-input/{execution_id}"
|
||||||
|
if api_key:
|
||||||
|
uri += f"?token={api_key}"
|
||||||
|
|
||||||
|
try:
|
||||||
|
async with websockets.connect(uri) as websocket:
|
||||||
|
print(f"Connected to WebSocket for execution {execution_id}")
|
||||||
|
|
||||||
|
async for message in websocket:
|
||||||
|
event_data = json.loads(message)
|
||||||
|
print(f"Received WebSocket event: {event_data['type']}")
|
||||||
|
|
||||||
|
if event_data['type'] == 'human_input_required':
|
||||||
|
print(f"Human input required for task: {event_data.get('task_id')}")
|
||||||
|
print(f"Prompt: {event_data.get('prompt')}")
|
||||||
|
print(f"Context: {event_data.get('context')}")
|
||||||
|
elif event_data['type'] == 'human_input_completed':
|
||||||
|
print(f"Human input completed: {event_data.get('human_feedback')}")
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
print(f"WebSocket error: {e}")
|
||||||
|
|
||||||
|
|
||||||
|
async def sse_client_example(execution_id: str, api_key: Optional[str] = None):
|
||||||
|
"""Example SSE client for receiving human input events"""
|
||||||
|
if not DEPENDENCIES_AVAILABLE:
|
||||||
|
print("SSE dependencies not available")
|
||||||
|
return
|
||||||
|
|
||||||
|
url = f"http://localhost:8000/events/human-input/{execution_id}"
|
||||||
|
headers = {}
|
||||||
|
if api_key:
|
||||||
|
headers["Authorization"] = f"Bearer {api_key}"
|
||||||
|
|
||||||
|
try:
|
||||||
|
async with httpx.AsyncClient() as client:
|
||||||
|
async with client.stream("GET", url, headers=headers) as response:
|
||||||
|
print(f"Connected to SSE for execution {execution_id}")
|
||||||
|
|
||||||
|
async for line in response.aiter_lines():
|
||||||
|
if line.startswith("data: "):
|
||||||
|
event_data = json.loads(line[6:])
|
||||||
|
if event_data.get('type') != 'heartbeat':
|
||||||
|
print(f"Received SSE event: {event_data['type']}")
|
||||||
|
|
||||||
|
if event_data['type'] == 'human_input_required':
|
||||||
|
print(f"Human input required for task: {event_data.get('task_id')}")
|
||||||
|
print(f"Prompt: {event_data.get('prompt')}")
|
||||||
|
elif event_data['type'] == 'human_input_completed':
|
||||||
|
print(f"Human input completed: {event_data.get('human_feedback')}")
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
print(f"SSE error: {e}")
|
||||||
|
|
||||||
|
|
||||||
|
async def polling_client_example(execution_id: str, api_key: Optional[str] = None):
|
||||||
|
"""Example polling client for receiving human input events"""
|
||||||
|
if not DEPENDENCIES_AVAILABLE:
|
||||||
|
print("Polling dependencies not available")
|
||||||
|
return
|
||||||
|
|
||||||
|
url = f"http://localhost:8000/poll/human-input/{execution_id}"
|
||||||
|
headers = {}
|
||||||
|
if api_key:
|
||||||
|
headers["Authorization"] = f"Bearer {api_key}"
|
||||||
|
|
||||||
|
last_event_id = None
|
||||||
|
|
||||||
|
try:
|
||||||
|
async with httpx.AsyncClient() as client:
|
||||||
|
print(f"Starting polling for execution {execution_id}")
|
||||||
|
|
||||||
|
while True:
|
||||||
|
params = {}
|
||||||
|
if last_event_id:
|
||||||
|
params["last_event_id"] = last_event_id
|
||||||
|
|
||||||
|
response = await client.get(url, headers=headers, params=params)
|
||||||
|
data = response.json()
|
||||||
|
|
||||||
|
for event in data.get("events", []):
|
||||||
|
print(f"Received polling event: {event['type']}")
|
||||||
|
|
||||||
|
if event['type'] == 'human_input_required':
|
||||||
|
print(f"Human input required for task: {event.get('task_id')}")
|
||||||
|
print(f"Prompt: {event.get('prompt')}")
|
||||||
|
elif event['type'] == 'human_input_completed':
|
||||||
|
print(f"Human input completed: {event.get('human_feedback')}")
|
||||||
|
|
||||||
|
last_event_id = event.get('event_id')
|
||||||
|
|
||||||
|
await asyncio.sleep(2)
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
print(f"Polling error: {e}")
|
||||||
|
|
||||||
|
|
||||||
|
def create_sample_crew():
|
||||||
|
"""Create a sample crew that requires human input"""
|
||||||
|
|
||||||
|
llm = LLM(model="gpt-4o-mini")
|
||||||
|
|
||||||
|
agent = Agent(
|
||||||
|
role="Research Assistant",
|
||||||
|
goal="Help with research tasks and get human feedback",
|
||||||
|
backstory="You are a helpful research assistant that works with humans to complete tasks.",
|
||||||
|
llm=llm,
|
||||||
|
verbose=True
|
||||||
|
)
|
||||||
|
|
||||||
|
task = Task(
|
||||||
|
description="Research the latest trends in AI and provide a summary. Ask for human feedback on the findings.",
|
||||||
|
expected_output="A comprehensive summary of AI trends with human feedback incorporated.",
|
||||||
|
agent=agent,
|
||||||
|
human_input=True
|
||||||
|
)
|
||||||
|
|
||||||
|
crew = Crew(
|
||||||
|
agents=[agent],
|
||||||
|
tasks=[task],
|
||||||
|
verbose=True
|
||||||
|
)
|
||||||
|
|
||||||
|
return crew
|
||||||
|
|
||||||
|
|
||||||
|
async def main():
|
||||||
|
"""Main example function"""
|
||||||
|
print("CrewAI Human Input Event Streaming Example")
|
||||||
|
print("=" * 50)
|
||||||
|
|
||||||
|
api_key = "demo-api-key"
|
||||||
|
|
||||||
|
server = start_event_server(port=8000, api_key=api_key)
|
||||||
|
if not server:
|
||||||
|
return
|
||||||
|
|
||||||
|
crew = create_sample_crew()
|
||||||
|
execution_id = str(crew.id)
|
||||||
|
|
||||||
|
print(f"Crew execution ID: {execution_id}")
|
||||||
|
print("\nStarting event stream clients...")
|
||||||
|
|
||||||
|
websocket_task = asyncio.create_task(
|
||||||
|
websocket_client_example(execution_id, api_key)
|
||||||
|
)
|
||||||
|
|
||||||
|
sse_task = asyncio.create_task(
|
||||||
|
sse_client_example(execution_id, api_key)
|
||||||
|
)
|
||||||
|
|
||||||
|
polling_task = asyncio.create_task(
|
||||||
|
polling_client_example(execution_id, api_key)
|
||||||
|
)
|
||||||
|
|
||||||
|
await asyncio.sleep(1)
|
||||||
|
|
||||||
|
print("\nStarting crew execution...")
|
||||||
|
print("Note: This will prompt for human input in the console.")
|
||||||
|
print("The event streams above will also receive the events in real-time.")
|
||||||
|
|
||||||
|
def run_crew():
|
||||||
|
try:
|
||||||
|
result = crew.kickoff()
|
||||||
|
print(f"\nCrew execution completed: {result}")
|
||||||
|
except Exception as e:
|
||||||
|
print(f"Crew execution error: {e}")
|
||||||
|
|
||||||
|
crew_thread = threading.Thread(target=run_crew)
|
||||||
|
crew_thread.start()
|
||||||
|
|
||||||
|
await asyncio.sleep(30)
|
||||||
|
|
||||||
|
websocket_task.cancel()
|
||||||
|
sse_task.cancel()
|
||||||
|
polling_task.cancel()
|
||||||
|
|
||||||
|
crew_thread.join(timeout=5)
|
||||||
|
|
||||||
|
print("\nExample completed!")
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
if DEPENDENCIES_AVAILABLE:
|
||||||
|
asyncio.run(main())
|
||||||
|
else:
|
||||||
|
print("Dependencies not available. Install with: pip install crewai[server]")
|
||||||
|
print("This example requires FastAPI, uvicorn, websockets, and httpx.")
|
||||||
@@ -69,6 +69,11 @@ docling = [
|
|||||||
aisuite = [
|
aisuite = [
|
||||||
"aisuite>=0.1.10",
|
"aisuite>=0.1.10",
|
||||||
]
|
]
|
||||||
|
server = [
|
||||||
|
"fastapi>=0.104.0",
|
||||||
|
"uvicorn>=0.24.0",
|
||||||
|
"websockets>=12.0",
|
||||||
|
]
|
||||||
|
|
||||||
[tool.uv]
|
[tool.uv]
|
||||||
dev-dependencies = [
|
dev-dependencies = [
|
||||||
@@ -86,6 +91,10 @@ dev-dependencies = [
|
|||||||
"pytest-timeout>=2.3.1",
|
"pytest-timeout>=2.3.1",
|
||||||
"pytest-xdist>=3.6.1",
|
"pytest-xdist>=3.6.1",
|
||||||
"pytest-split>=0.9.0",
|
"pytest-split>=0.9.0",
|
||||||
|
"fastapi>=0.104.0",
|
||||||
|
"uvicorn>=0.24.0",
|
||||||
|
"websockets>=12.0",
|
||||||
|
"httpx>=0.25.0",
|
||||||
]
|
]
|
||||||
|
|
||||||
[project.scripts]
|
[project.scripts]
|
||||||
|
|||||||
@@ -1,4 +1,5 @@
|
|||||||
import time
|
import time
|
||||||
|
import uuid
|
||||||
from typing import TYPE_CHECKING
|
from typing import TYPE_CHECKING
|
||||||
|
|
||||||
from crewai.memory.entity.entity_memory_item import EntityMemoryItem
|
from crewai.memory.entity.entity_memory_item import EntityMemoryItem
|
||||||
@@ -8,6 +9,8 @@ from crewai.utilities.converter import ConverterError
|
|||||||
from crewai.utilities.evaluators.task_evaluator import TaskEvaluator
|
from crewai.utilities.evaluators.task_evaluator import TaskEvaluator
|
||||||
from crewai.utilities.printer import Printer
|
from crewai.utilities.printer import Printer
|
||||||
from crewai.utilities.events.event_listener import event_listener
|
from crewai.utilities.events.event_listener import event_listener
|
||||||
|
from crewai.utilities.events.crewai_event_bus import crewai_event_bus
|
||||||
|
from crewai.utilities.events.task_events import HumanInputRequiredEvent, HumanInputCompletedEvent
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
from crewai.agents.agent_builder.base_agent import BaseAgent
|
from crewai.agents.agent_builder.base_agent import BaseAgent
|
||||||
@@ -126,6 +129,42 @@ class CrewAgentExecutorMixin:
|
|||||||
|
|
||||||
def _ask_human_input(self, final_answer: str) -> str:
|
def _ask_human_input(self, final_answer: str) -> str:
|
||||||
"""Prompt human input with mode-appropriate messaging."""
|
"""Prompt human input with mode-appropriate messaging."""
|
||||||
|
event_id = str(uuid.uuid4())
|
||||||
|
|
||||||
|
execution_id = getattr(self.crew, 'id', None) if self.crew else None
|
||||||
|
crew_id = str(execution_id) if execution_id else None
|
||||||
|
task_id = str(getattr(self.task, 'id', None)) if self.task else None
|
||||||
|
agent_id = str(getattr(self.agent, 'id', None)) if self.agent else None
|
||||||
|
|
||||||
|
if self.crew and getattr(self.crew, "_train", False):
|
||||||
|
prompt_text = (
|
||||||
|
"## TRAINING MODE: Provide feedback to improve the agent's performance.\n"
|
||||||
|
"This will be used to train better versions of the agent.\n"
|
||||||
|
"Please provide detailed feedback about the result quality and reasoning process."
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
prompt_text = (
|
||||||
|
"## HUMAN FEEDBACK: Provide feedback on the Final Result and Agent's actions.\n"
|
||||||
|
"Please follow these guidelines:\n"
|
||||||
|
" - If you are happy with the result, simply hit Enter without typing anything.\n"
|
||||||
|
" - Otherwise, provide specific improvement requests.\n"
|
||||||
|
" - You can provide multiple rounds of feedback until satisfied."
|
||||||
|
)
|
||||||
|
|
||||||
|
crewai_event_bus.emit(
|
||||||
|
self,
|
||||||
|
HumanInputRequiredEvent(
|
||||||
|
execution_id=execution_id,
|
||||||
|
crew_id=crew_id,
|
||||||
|
task_id=task_id,
|
||||||
|
agent_id=agent_id,
|
||||||
|
prompt=prompt_text,
|
||||||
|
context=final_answer,
|
||||||
|
event_id=event_id,
|
||||||
|
reason_flags={"ambiguity": True, "missing_field": False}
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
event_listener.formatter.pause_live_updates()
|
event_listener.formatter.pause_live_updates()
|
||||||
|
|
||||||
try:
|
try:
|
||||||
@@ -133,7 +172,6 @@ class CrewAgentExecutorMixin:
|
|||||||
content=f"\033[1m\033[95m ## Final Result:\033[00m \033[92m{final_answer}\033[00m"
|
content=f"\033[1m\033[95m ## Final Result:\033[00m \033[92m{final_answer}\033[00m"
|
||||||
)
|
)
|
||||||
|
|
||||||
# Training mode prompt (single iteration)
|
|
||||||
if self.crew and getattr(self.crew, "_train", False):
|
if self.crew and getattr(self.crew, "_train", False):
|
||||||
prompt = (
|
prompt = (
|
||||||
"\n\n=====\n"
|
"\n\n=====\n"
|
||||||
@@ -142,7 +180,6 @@ class CrewAgentExecutorMixin:
|
|||||||
"Please provide detailed feedback about the result quality and reasoning process.\n"
|
"Please provide detailed feedback about the result quality and reasoning process.\n"
|
||||||
"=====\n"
|
"=====\n"
|
||||||
)
|
)
|
||||||
# Regular human-in-the-loop prompt (multiple iterations)
|
|
||||||
else:
|
else:
|
||||||
prompt = (
|
prompt = (
|
||||||
"\n\n=====\n"
|
"\n\n=====\n"
|
||||||
@@ -158,6 +195,19 @@ class CrewAgentExecutorMixin:
|
|||||||
response = input()
|
response = input()
|
||||||
if response.strip() != "":
|
if response.strip() != "":
|
||||||
self._printer.print(content="\nProcessing your feedback...", color="cyan")
|
self._printer.print(content="\nProcessing your feedback...", color="cyan")
|
||||||
|
|
||||||
|
crewai_event_bus.emit(
|
||||||
|
self,
|
||||||
|
HumanInputCompletedEvent(
|
||||||
|
execution_id=execution_id,
|
||||||
|
crew_id=crew_id,
|
||||||
|
task_id=task_id,
|
||||||
|
agent_id=agent_id,
|
||||||
|
event_id=event_id,
|
||||||
|
human_feedback=response
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
return response
|
return response
|
||||||
finally:
|
finally:
|
||||||
event_listener.formatter.resume_live_updates()
|
event_listener.formatter.resume_live_updates()
|
||||||
|
|||||||
4
src/crewai/server/__init__.py
Normal file
4
src/crewai/server/__init__.py
Normal file
@@ -0,0 +1,4 @@
|
|||||||
|
from .human_input_server import HumanInputServer
|
||||||
|
from .event_stream_manager import EventStreamManager
|
||||||
|
|
||||||
|
__all__ = ["HumanInputServer", "EventStreamManager"]
|
||||||
148
src/crewai/server/event_stream_manager.py
Normal file
148
src/crewai/server/event_stream_manager.py
Normal file
@@ -0,0 +1,148 @@
|
|||||||
|
import asyncio
|
||||||
|
import json
|
||||||
|
import uuid
|
||||||
|
from typing import Dict, List, Optional, Set
|
||||||
|
from datetime import datetime, timezone
|
||||||
|
|
||||||
|
from crewai.utilities.events.crewai_event_bus import crewai_event_bus
|
||||||
|
from crewai.utilities.events.task_events import HumanInputRequiredEvent, HumanInputCompletedEvent
|
||||||
|
|
||||||
|
|
||||||
|
class EventStreamManager:
|
||||||
|
"""Manages event streams for human input events"""
|
||||||
|
|
||||||
|
def __init__(self):
|
||||||
|
self._websocket_connections: Dict[str, Set] = {}
|
||||||
|
self._sse_connections: Dict[str, Set] = {}
|
||||||
|
self._polling_events: Dict[str, List] = {}
|
||||||
|
self._event_listeners_registered = False
|
||||||
|
|
||||||
|
def register_event_listeners(self):
|
||||||
|
"""Register event listeners for human input events"""
|
||||||
|
if self._event_listeners_registered:
|
||||||
|
return
|
||||||
|
|
||||||
|
@crewai_event_bus.on(HumanInputRequiredEvent)
|
||||||
|
def handle_human_input_required(event: HumanInputRequiredEvent):
|
||||||
|
self._broadcast_event(event)
|
||||||
|
|
||||||
|
@crewai_event_bus.on(HumanInputCompletedEvent)
|
||||||
|
def handle_human_input_completed(event: HumanInputCompletedEvent):
|
||||||
|
self._broadcast_event(event)
|
||||||
|
|
||||||
|
self._event_listeners_registered = True
|
||||||
|
|
||||||
|
def add_websocket_connection(self, execution_id: str, websocket):
|
||||||
|
"""Add a WebSocket connection for an execution"""
|
||||||
|
if execution_id not in self._websocket_connections:
|
||||||
|
self._websocket_connections[execution_id] = set()
|
||||||
|
self._websocket_connections[execution_id].add(websocket)
|
||||||
|
|
||||||
|
def remove_websocket_connection(self, execution_id: str, websocket):
|
||||||
|
"""Remove a WebSocket connection"""
|
||||||
|
if execution_id in self._websocket_connections:
|
||||||
|
self._websocket_connections[execution_id].discard(websocket)
|
||||||
|
if not self._websocket_connections[execution_id]:
|
||||||
|
del self._websocket_connections[execution_id]
|
||||||
|
|
||||||
|
def add_sse_connection(self, execution_id: str, queue):
|
||||||
|
"""Add an SSE connection for an execution"""
|
||||||
|
if execution_id not in self._sse_connections:
|
||||||
|
self._sse_connections[execution_id] = set()
|
||||||
|
self._sse_connections[execution_id].add(queue)
|
||||||
|
|
||||||
|
def remove_sse_connection(self, execution_id: str, queue):
|
||||||
|
"""Remove an SSE connection"""
|
||||||
|
if execution_id in self._sse_connections:
|
||||||
|
self._sse_connections[execution_id].discard(queue)
|
||||||
|
if not self._sse_connections[execution_id]:
|
||||||
|
del self._sse_connections[execution_id]
|
||||||
|
|
||||||
|
def get_polling_events(self, execution_id: str, last_event_id: Optional[str] = None) -> List[Dict]:
|
||||||
|
"""Get events for polling clients"""
|
||||||
|
if execution_id not in self._polling_events:
|
||||||
|
return []
|
||||||
|
|
||||||
|
events = self._polling_events[execution_id]
|
||||||
|
|
||||||
|
if last_event_id:
|
||||||
|
try:
|
||||||
|
last_index = next(
|
||||||
|
i for i, event in enumerate(events)
|
||||||
|
if event.get("event_id") == last_event_id
|
||||||
|
)
|
||||||
|
return events[last_index + 1:]
|
||||||
|
except StopIteration:
|
||||||
|
pass
|
||||||
|
|
||||||
|
return events
|
||||||
|
|
||||||
|
def _broadcast_event(self, event):
|
||||||
|
"""Broadcast event to all relevant connections"""
|
||||||
|
execution_id = getattr(event, 'execution_id', None)
|
||||||
|
if not execution_id:
|
||||||
|
return
|
||||||
|
|
||||||
|
event_data = self._serialize_event(event)
|
||||||
|
|
||||||
|
self._broadcast_websocket(execution_id, event_data)
|
||||||
|
self._broadcast_sse(execution_id, event_data)
|
||||||
|
self._store_polling_event(execution_id, event_data)
|
||||||
|
|
||||||
|
def _serialize_event(self, event) -> Dict:
|
||||||
|
"""Serialize event to dictionary format"""
|
||||||
|
event_dict = event.to_json()
|
||||||
|
|
||||||
|
if not event_dict.get("event_id"):
|
||||||
|
event_dict["event_id"] = str(uuid.uuid4())
|
||||||
|
|
||||||
|
return event_dict
|
||||||
|
|
||||||
|
def _broadcast_websocket(self, execution_id: str, event_data: Dict):
|
||||||
|
"""Broadcast event to WebSocket connections"""
|
||||||
|
if execution_id not in self._websocket_connections:
|
||||||
|
return
|
||||||
|
|
||||||
|
connections_to_remove = set()
|
||||||
|
for websocket in self._websocket_connections[execution_id]:
|
||||||
|
try:
|
||||||
|
asyncio.create_task(websocket.send_text(json.dumps(event_data)))
|
||||||
|
except Exception:
|
||||||
|
connections_to_remove.add(websocket)
|
||||||
|
|
||||||
|
for websocket in connections_to_remove:
|
||||||
|
self.remove_websocket_connection(execution_id, websocket)
|
||||||
|
|
||||||
|
def _broadcast_sse(self, execution_id: str, event_data: Dict):
|
||||||
|
"""Broadcast event to SSE connections"""
|
||||||
|
if execution_id not in self._sse_connections:
|
||||||
|
return
|
||||||
|
|
||||||
|
connections_to_remove = set()
|
||||||
|
for queue in self._sse_connections[execution_id]:
|
||||||
|
try:
|
||||||
|
queue.put_nowait(event_data)
|
||||||
|
except Exception:
|
||||||
|
connections_to_remove.add(queue)
|
||||||
|
|
||||||
|
for queue in connections_to_remove:
|
||||||
|
self.remove_sse_connection(execution_id, queue)
|
||||||
|
|
||||||
|
def _store_polling_event(self, execution_id: str, event_data: Dict):
|
||||||
|
"""Store event for polling clients"""
|
||||||
|
if execution_id not in self._polling_events:
|
||||||
|
self._polling_events[execution_id] = []
|
||||||
|
|
||||||
|
self._polling_events[execution_id].append(event_data)
|
||||||
|
|
||||||
|
if len(self._polling_events[execution_id]) > 100:
|
||||||
|
self._polling_events[execution_id] = self._polling_events[execution_id][-100:]
|
||||||
|
|
||||||
|
def cleanup_execution(self, execution_id: str):
|
||||||
|
"""Clean up all connections and events for an execution"""
|
||||||
|
self._websocket_connections.pop(execution_id, None)
|
||||||
|
self._sse_connections.pop(execution_id, None)
|
||||||
|
self._polling_events.pop(execution_id, None)
|
||||||
|
|
||||||
|
|
||||||
|
event_stream_manager = EventStreamManager()
|
||||||
124
src/crewai/server/human_input_server.py
Normal file
124
src/crewai/server/human_input_server.py
Normal file
@@ -0,0 +1,124 @@
|
|||||||
|
import asyncio
|
||||||
|
import json
|
||||||
|
import queue
|
||||||
|
import uuid
|
||||||
|
from typing import Optional, Dict, Any
|
||||||
|
from datetime import datetime, timezone
|
||||||
|
|
||||||
|
try:
|
||||||
|
from fastapi import FastAPI, WebSocket, WebSocketDisconnect, HTTPException, Depends, Query
|
||||||
|
from fastapi.responses import StreamingResponse
|
||||||
|
from fastapi.security import HTTPBearer, HTTPAuthorizationCredentials
|
||||||
|
import uvicorn
|
||||||
|
FASTAPI_AVAILABLE = True
|
||||||
|
except ImportError:
|
||||||
|
FASTAPI_AVAILABLE = False
|
||||||
|
|
||||||
|
from .event_stream_manager import event_stream_manager
|
||||||
|
|
||||||
|
|
||||||
|
class HumanInputServer:
|
||||||
|
"""HTTP server for human input event streaming"""
|
||||||
|
|
||||||
|
def __init__(self, host: str = "localhost", port: int = 8000, api_key: Optional[str] = None):
|
||||||
|
if not FASTAPI_AVAILABLE:
|
||||||
|
raise ImportError(
|
||||||
|
"FastAPI dependencies not available. Install with: pip install fastapi uvicorn websockets"
|
||||||
|
)
|
||||||
|
|
||||||
|
self.host = host
|
||||||
|
self.port = port
|
||||||
|
self.api_key = api_key
|
||||||
|
self.app = FastAPI(title="CrewAI Human Input Event Stream API")
|
||||||
|
self.security = HTTPBearer() if api_key else None
|
||||||
|
self._setup_routes()
|
||||||
|
event_stream_manager.register_event_listeners()
|
||||||
|
|
||||||
|
def _setup_routes(self):
|
||||||
|
"""Setup FastAPI routes"""
|
||||||
|
|
||||||
|
@self.app.websocket("/ws/human-input/{execution_id}")
|
||||||
|
async def websocket_endpoint(websocket: WebSocket, execution_id: str):
|
||||||
|
if self.api_key:
|
||||||
|
token = websocket.query_params.get("token")
|
||||||
|
if not token or token != self.api_key:
|
||||||
|
await websocket.close(code=4001, reason="Unauthorized")
|
||||||
|
return
|
||||||
|
|
||||||
|
await websocket.accept()
|
||||||
|
event_stream_manager.add_websocket_connection(execution_id, websocket)
|
||||||
|
|
||||||
|
try:
|
||||||
|
while True:
|
||||||
|
await websocket.receive_text()
|
||||||
|
except WebSocketDisconnect:
|
||||||
|
pass
|
||||||
|
finally:
|
||||||
|
event_stream_manager.remove_websocket_connection(execution_id, websocket)
|
||||||
|
|
||||||
|
@self.app.get("/events/human-input/{execution_id}")
|
||||||
|
async def sse_endpoint(execution_id: str, credentials: Optional[HTTPAuthorizationCredentials] = Depends(self.security) if self.security else None):
|
||||||
|
if self.api_key and (not credentials or credentials.credentials != self.api_key):
|
||||||
|
raise HTTPException(status_code=401, detail="Unauthorized")
|
||||||
|
|
||||||
|
async def event_stream():
|
||||||
|
event_queue = asyncio.Queue()
|
||||||
|
event_stream_manager.add_sse_connection(execution_id, event_queue)
|
||||||
|
|
||||||
|
try:
|
||||||
|
while True:
|
||||||
|
try:
|
||||||
|
event_data = await asyncio.wait_for(event_queue.get(), timeout=30.0)
|
||||||
|
yield f"data: {json.dumps(event_data)}\n\n"
|
||||||
|
except asyncio.TimeoutError:
|
||||||
|
yield "data: {\"type\": \"heartbeat\"}\n\n"
|
||||||
|
except asyncio.CancelledError:
|
||||||
|
pass
|
||||||
|
finally:
|
||||||
|
event_stream_manager.remove_sse_connection(execution_id, event_queue)
|
||||||
|
|
||||||
|
return StreamingResponse(
|
||||||
|
event_stream(),
|
||||||
|
media_type="text/event-stream",
|
||||||
|
headers={
|
||||||
|
"Cache-Control": "no-cache",
|
||||||
|
"Connection": "keep-alive",
|
||||||
|
"Access-Control-Allow-Origin": "*",
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
|
@self.app.get("/poll/human-input/{execution_id}")
|
||||||
|
async def polling_endpoint(
|
||||||
|
execution_id: str,
|
||||||
|
last_event_id: Optional[str] = Query(None),
|
||||||
|
credentials: Optional[HTTPAuthorizationCredentials] = Depends(self.security) if self.security else None
|
||||||
|
):
|
||||||
|
if self.api_key and (not credentials or credentials.credentials != self.api_key):
|
||||||
|
raise HTTPException(status_code=401, detail="Unauthorized")
|
||||||
|
|
||||||
|
events = event_stream_manager.get_polling_events(execution_id, last_event_id)
|
||||||
|
return {"events": events}
|
||||||
|
|
||||||
|
@self.app.get("/health")
|
||||||
|
async def health_check():
|
||||||
|
return {"status": "healthy", "timestamp": datetime.now(timezone.utc).isoformat()}
|
||||||
|
|
||||||
|
async def start_async(self):
|
||||||
|
"""Start the server asynchronously"""
|
||||||
|
config = uvicorn.Config(
|
||||||
|
self.app,
|
||||||
|
host=self.host,
|
||||||
|
port=self.port,
|
||||||
|
log_level="info"
|
||||||
|
)
|
||||||
|
server = uvicorn.Server(config)
|
||||||
|
await server.serve()
|
||||||
|
|
||||||
|
def start(self):
|
||||||
|
"""Start the server synchronously"""
|
||||||
|
uvicorn.run(
|
||||||
|
self.app,
|
||||||
|
host=self.host,
|
||||||
|
port=self.port,
|
||||||
|
log_level="info"
|
||||||
|
)
|
||||||
@@ -35,6 +35,8 @@ from .llm_guardrail_events import (
|
|||||||
LLMGuardrailStartedEvent,
|
LLMGuardrailStartedEvent,
|
||||||
)
|
)
|
||||||
from .task_events import (
|
from .task_events import (
|
||||||
|
HumanInputCompletedEvent,
|
||||||
|
HumanInputRequiredEvent,
|
||||||
TaskCompletedEvent,
|
TaskCompletedEvent,
|
||||||
TaskFailedEvent,
|
TaskFailedEvent,
|
||||||
TaskStartedEvent,
|
TaskStartedEvent,
|
||||||
@@ -85,6 +87,8 @@ EventTypes = Union[
|
|||||||
TaskStartedEvent,
|
TaskStartedEvent,
|
||||||
TaskCompletedEvent,
|
TaskCompletedEvent,
|
||||||
TaskFailedEvent,
|
TaskFailedEvent,
|
||||||
|
HumanInputRequiredEvent,
|
||||||
|
HumanInputCompletedEvent,
|
||||||
FlowStartedEvent,
|
FlowStartedEvent,
|
||||||
FlowFinishedEvent,
|
FlowFinishedEvent,
|
||||||
MethodExecutionStartedEvent,
|
MethodExecutionStartedEvent,
|
||||||
|
|||||||
@@ -82,3 +82,35 @@ class TaskEvaluationEvent(BaseEvent):
|
|||||||
and self.task.fingerprint.metadata
|
and self.task.fingerprint.metadata
|
||||||
):
|
):
|
||||||
self.fingerprint_metadata = self.task.fingerprint.metadata
|
self.fingerprint_metadata = self.task.fingerprint.metadata
|
||||||
|
|
||||||
|
|
||||||
|
class HumanInputRequiredEvent(BaseEvent):
|
||||||
|
"""Event emitted when human input is required during task execution"""
|
||||||
|
|
||||||
|
type: str = "human_input_required"
|
||||||
|
execution_id: Optional[str] = None
|
||||||
|
crew_id: Optional[str] = None
|
||||||
|
task_id: Optional[str] = None
|
||||||
|
agent_id: Optional[str] = None
|
||||||
|
prompt: Optional[str] = None
|
||||||
|
context: Optional[str] = None
|
||||||
|
reason_flags: Optional[dict] = None
|
||||||
|
event_id: Optional[str] = None
|
||||||
|
|
||||||
|
def __init__(self, **data):
|
||||||
|
super().__init__(**data)
|
||||||
|
|
||||||
|
|
||||||
|
class HumanInputCompletedEvent(BaseEvent):
|
||||||
|
"""Event emitted when human input is completed"""
|
||||||
|
|
||||||
|
type: str = "human_input_completed"
|
||||||
|
execution_id: Optional[str] = None
|
||||||
|
crew_id: Optional[str] = None
|
||||||
|
task_id: Optional[str] = None
|
||||||
|
agent_id: Optional[str] = None
|
||||||
|
event_id: Optional[str] = None
|
||||||
|
human_feedback: Optional[str] = None
|
||||||
|
|
||||||
|
def __init__(self, **data):
|
||||||
|
super().__init__(**data)
|
||||||
|
|||||||
213
tests/server/test_event_stream_manager.py
Normal file
213
tests/server/test_event_stream_manager.py
Normal file
@@ -0,0 +1,213 @@
|
|||||||
|
import pytest
|
||||||
|
import asyncio
|
||||||
|
import json
|
||||||
|
from unittest.mock import MagicMock, patch
|
||||||
|
|
||||||
|
from crewai.server.event_stream_manager import EventStreamManager
|
||||||
|
from crewai.utilities.events.task_events import HumanInputRequiredEvent, HumanInputCompletedEvent
|
||||||
|
|
||||||
|
|
||||||
|
class TestEventStreamManager:
|
||||||
|
"""Test the event stream manager"""
|
||||||
|
|
||||||
|
def setup_method(self):
|
||||||
|
"""Setup test environment"""
|
||||||
|
self.manager = EventStreamManager()
|
||||||
|
self.manager._websocket_connections.clear()
|
||||||
|
self.manager._sse_connections.clear()
|
||||||
|
self.manager._polling_events.clear()
|
||||||
|
|
||||||
|
def test_websocket_connection_management(self):
|
||||||
|
"""Test WebSocket connection management"""
|
||||||
|
execution_id = "test-execution"
|
||||||
|
websocket1 = MagicMock()
|
||||||
|
websocket2 = MagicMock()
|
||||||
|
|
||||||
|
self.manager.add_websocket_connection(execution_id, websocket1)
|
||||||
|
assert execution_id in self.manager._websocket_connections
|
||||||
|
assert websocket1 in self.manager._websocket_connections[execution_id]
|
||||||
|
|
||||||
|
self.manager.add_websocket_connection(execution_id, websocket2)
|
||||||
|
assert len(self.manager._websocket_connections[execution_id]) == 2
|
||||||
|
|
||||||
|
self.manager.remove_websocket_connection(execution_id, websocket1)
|
||||||
|
assert websocket1 not in self.manager._websocket_connections[execution_id]
|
||||||
|
assert websocket2 in self.manager._websocket_connections[execution_id]
|
||||||
|
|
||||||
|
self.manager.remove_websocket_connection(execution_id, websocket2)
|
||||||
|
assert execution_id not in self.manager._websocket_connections
|
||||||
|
|
||||||
|
def test_sse_connection_management(self):
|
||||||
|
"""Test SSE connection management"""
|
||||||
|
execution_id = "test-execution"
|
||||||
|
queue1 = MagicMock()
|
||||||
|
queue2 = MagicMock()
|
||||||
|
|
||||||
|
self.manager.add_sse_connection(execution_id, queue1)
|
||||||
|
assert execution_id in self.manager._sse_connections
|
||||||
|
assert queue1 in self.manager._sse_connections[execution_id]
|
||||||
|
|
||||||
|
self.manager.add_sse_connection(execution_id, queue2)
|
||||||
|
assert len(self.manager._sse_connections[execution_id]) == 2
|
||||||
|
|
||||||
|
self.manager.remove_sse_connection(execution_id, queue1)
|
||||||
|
assert queue1 not in self.manager._sse_connections[execution_id]
|
||||||
|
assert queue2 in self.manager._sse_connections[execution_id]
|
||||||
|
|
||||||
|
self.manager.remove_sse_connection(execution_id, queue2)
|
||||||
|
assert execution_id not in self.manager._sse_connections
|
||||||
|
|
||||||
|
def test_polling_events_storage(self):
|
||||||
|
"""Test polling events storage and retrieval"""
|
||||||
|
execution_id = "test-execution"
|
||||||
|
|
||||||
|
event1 = {"event_id": "event-1", "type": "test", "data": "test1"}
|
||||||
|
event2 = {"event_id": "event-2", "type": "test", "data": "test2"}
|
||||||
|
|
||||||
|
self.manager._store_polling_event(execution_id, event1)
|
||||||
|
self.manager._store_polling_event(execution_id, event2)
|
||||||
|
|
||||||
|
events = self.manager.get_polling_events(execution_id)
|
||||||
|
assert len(events) == 2
|
||||||
|
assert events[0] == event1
|
||||||
|
assert events[1] == event2
|
||||||
|
|
||||||
|
def test_polling_events_with_last_event_id(self):
|
||||||
|
"""Test polling events retrieval with last_event_id"""
|
||||||
|
execution_id = "test-execution"
|
||||||
|
|
||||||
|
event1 = {"event_id": "event-1", "type": "test", "data": "test1"}
|
||||||
|
event2 = {"event_id": "event-2", "type": "test", "data": "test2"}
|
||||||
|
event3 = {"event_id": "event-3", "type": "test", "data": "test3"}
|
||||||
|
|
||||||
|
self.manager._store_polling_event(execution_id, event1)
|
||||||
|
self.manager._store_polling_event(execution_id, event2)
|
||||||
|
self.manager._store_polling_event(execution_id, event3)
|
||||||
|
|
||||||
|
events = self.manager.get_polling_events(execution_id, "event-1")
|
||||||
|
assert len(events) == 2
|
||||||
|
assert events[0] == event2
|
||||||
|
assert events[1] == event3
|
||||||
|
|
||||||
|
def test_polling_events_limit(self):
|
||||||
|
"""Test polling events storage limit"""
|
||||||
|
execution_id = "test-execution"
|
||||||
|
|
||||||
|
for i in range(105):
|
||||||
|
event = {"event_id": f"event-{i}", "type": "test", "data": f"test{i}"}
|
||||||
|
self.manager._store_polling_event(execution_id, event)
|
||||||
|
|
||||||
|
events = self.manager.get_polling_events(execution_id)
|
||||||
|
assert len(events) == 100
|
||||||
|
assert events[0]["event_id"] == "event-5"
|
||||||
|
assert events[-1]["event_id"] == "event-104"
|
||||||
|
|
||||||
|
def test_event_serialization(self):
|
||||||
|
"""Test event serialization"""
|
||||||
|
event = HumanInputRequiredEvent(
|
||||||
|
execution_id="test-execution",
|
||||||
|
crew_id="test-crew",
|
||||||
|
task_id="test-task",
|
||||||
|
prompt="Test prompt"
|
||||||
|
)
|
||||||
|
|
||||||
|
serialized = self.manager._serialize_event(event)
|
||||||
|
assert isinstance(serialized, dict)
|
||||||
|
assert serialized["type"] == "human_input_required"
|
||||||
|
assert serialized["execution_id"] == "test-execution"
|
||||||
|
assert "event_id" in serialized
|
||||||
|
|
||||||
|
def test_broadcast_websocket(self):
|
||||||
|
"""Test WebSocket broadcasting"""
|
||||||
|
execution_id = "test-execution"
|
||||||
|
websocket = MagicMock()
|
||||||
|
|
||||||
|
self.manager.add_websocket_connection(execution_id, websocket)
|
||||||
|
|
||||||
|
event_data = {"type": "test", "data": "test"}
|
||||||
|
|
||||||
|
with patch('asyncio.create_task') as mock_create_task:
|
||||||
|
self.manager._broadcast_websocket(execution_id, event_data)
|
||||||
|
mock_create_task.assert_called_once()
|
||||||
|
|
||||||
|
def test_broadcast_sse(self):
|
||||||
|
"""Test SSE broadcasting"""
|
||||||
|
execution_id = "test-execution"
|
||||||
|
queue = MagicMock()
|
||||||
|
|
||||||
|
self.manager.add_sse_connection(execution_id, queue)
|
||||||
|
|
||||||
|
event_data = {"type": "test", "data": "test"}
|
||||||
|
self.manager._broadcast_sse(execution_id, event_data)
|
||||||
|
|
||||||
|
queue.put_nowait.assert_called_once_with(event_data)
|
||||||
|
|
||||||
|
def test_broadcast_event(self):
|
||||||
|
"""Test complete event broadcasting"""
|
||||||
|
execution_id = "test-execution"
|
||||||
|
|
||||||
|
event = HumanInputRequiredEvent(
|
||||||
|
execution_id=execution_id,
|
||||||
|
crew_id="test-crew",
|
||||||
|
task_id="test-task",
|
||||||
|
prompt="Test prompt"
|
||||||
|
)
|
||||||
|
|
||||||
|
with patch.object(self.manager, '_broadcast_websocket') as mock_ws, \
|
||||||
|
patch.object(self.manager, '_broadcast_sse') as mock_sse, \
|
||||||
|
patch.object(self.manager, '_store_polling_event') as mock_poll:
|
||||||
|
|
||||||
|
self.manager._broadcast_event(event)
|
||||||
|
|
||||||
|
mock_ws.assert_called_once()
|
||||||
|
mock_sse.assert_called_once()
|
||||||
|
mock_poll.assert_called_once()
|
||||||
|
|
||||||
|
def test_cleanup_execution(self):
|
||||||
|
"""Test execution cleanup"""
|
||||||
|
execution_id = "test-execution"
|
||||||
|
|
||||||
|
websocket = MagicMock()
|
||||||
|
queue = MagicMock()
|
||||||
|
event = {"event_id": "event-1", "type": "test"}
|
||||||
|
|
||||||
|
self.manager.add_websocket_connection(execution_id, websocket)
|
||||||
|
self.manager.add_sse_connection(execution_id, queue)
|
||||||
|
self.manager._store_polling_event(execution_id, event)
|
||||||
|
|
||||||
|
assert execution_id in self.manager._websocket_connections
|
||||||
|
assert execution_id in self.manager._sse_connections
|
||||||
|
assert execution_id in self.manager._polling_events
|
||||||
|
|
||||||
|
self.manager.cleanup_execution(execution_id)
|
||||||
|
|
||||||
|
assert execution_id not in self.manager._websocket_connections
|
||||||
|
assert execution_id not in self.manager._sse_connections
|
||||||
|
assert execution_id not in self.manager._polling_events
|
||||||
|
|
||||||
|
def test_register_event_listeners(self):
|
||||||
|
"""Test event listener registration"""
|
||||||
|
with patch('crewai.utilities.events.crewai_event_bus.crewai_event_bus.on') as mock_on:
|
||||||
|
self.manager.register_event_listeners()
|
||||||
|
assert mock_on.call_count == 2
|
||||||
|
|
||||||
|
self.manager.register_event_listeners()
|
||||||
|
assert mock_on.call_count == 2
|
||||||
|
|
||||||
|
def test_broadcast_event_without_execution_id(self):
|
||||||
|
"""Test broadcasting event without execution_id"""
|
||||||
|
event = HumanInputRequiredEvent(
|
||||||
|
crew_id="test-crew",
|
||||||
|
task_id="test-task",
|
||||||
|
prompt="Test prompt"
|
||||||
|
)
|
||||||
|
|
||||||
|
with patch.object(self.manager, '_broadcast_websocket') as mock_ws, \
|
||||||
|
patch.object(self.manager, '_broadcast_sse') as mock_sse, \
|
||||||
|
patch.object(self.manager, '_store_polling_event') as mock_poll:
|
||||||
|
|
||||||
|
self.manager._broadcast_event(event)
|
||||||
|
|
||||||
|
mock_ws.assert_not_called()
|
||||||
|
mock_sse.assert_not_called()
|
||||||
|
mock_poll.assert_not_called()
|
||||||
139
tests/server/test_human_input_server.py
Normal file
139
tests/server/test_human_input_server.py
Normal file
@@ -0,0 +1,139 @@
|
|||||||
|
import pytest
|
||||||
|
import asyncio
|
||||||
|
import json
|
||||||
|
from unittest.mock import patch, MagicMock
|
||||||
|
|
||||||
|
try:
|
||||||
|
from fastapi.testclient import TestClient
|
||||||
|
from crewai.server.human_input_server import HumanInputServer
|
||||||
|
from crewai.server.event_stream_manager import event_stream_manager
|
||||||
|
from crewai.utilities.events.task_events import HumanInputRequiredEvent
|
||||||
|
FASTAPI_AVAILABLE = True
|
||||||
|
except ImportError:
|
||||||
|
FASTAPI_AVAILABLE = False
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.skipif(not FASTAPI_AVAILABLE, reason="FastAPI dependencies not available")
|
||||||
|
class TestHumanInputServer:
|
||||||
|
"""Test the human input server endpoints"""
|
||||||
|
|
||||||
|
def setup_method(self):
|
||||||
|
"""Setup test environment"""
|
||||||
|
self.server = HumanInputServer(host="localhost", port=8001, api_key="test-key")
|
||||||
|
self.client = TestClient(self.server.app)
|
||||||
|
event_stream_manager._websocket_connections.clear()
|
||||||
|
event_stream_manager._sse_connections.clear()
|
||||||
|
event_stream_manager._polling_events.clear()
|
||||||
|
|
||||||
|
def test_health_endpoint(self):
|
||||||
|
"""Test health check endpoint"""
|
||||||
|
response = self.client.get("/health")
|
||||||
|
assert response.status_code == 200
|
||||||
|
data = response.json()
|
||||||
|
assert data["status"] == "healthy"
|
||||||
|
assert "timestamp" in data
|
||||||
|
|
||||||
|
def test_polling_endpoint_unauthorized(self):
|
||||||
|
"""Test polling endpoint without authentication"""
|
||||||
|
response = self.client.get("/poll/human-input/test-execution-id")
|
||||||
|
assert response.status_code == 401
|
||||||
|
|
||||||
|
def test_polling_endpoint_authorized(self):
|
||||||
|
"""Test polling endpoint with authentication"""
|
||||||
|
headers = {"Authorization": "Bearer test-key"}
|
||||||
|
response = self.client.get("/poll/human-input/test-execution-id", headers=headers)
|
||||||
|
assert response.status_code == 200
|
||||||
|
data = response.json()
|
||||||
|
assert "events" in data
|
||||||
|
assert isinstance(data["events"], list)
|
||||||
|
|
||||||
|
def test_polling_endpoint_with_events(self):
|
||||||
|
"""Test polling endpoint returns stored events"""
|
||||||
|
execution_id = "test-execution-id"
|
||||||
|
|
||||||
|
event = HumanInputRequiredEvent(
|
||||||
|
execution_id=execution_id,
|
||||||
|
crew_id="test-crew",
|
||||||
|
task_id="test-task",
|
||||||
|
agent_id="test-agent",
|
||||||
|
prompt="Test prompt",
|
||||||
|
context="Test context",
|
||||||
|
event_id="test-event-1"
|
||||||
|
)
|
||||||
|
|
||||||
|
event_stream_manager._store_polling_event(execution_id, event.to_json())
|
||||||
|
|
||||||
|
headers = {"Authorization": "Bearer test-key"}
|
||||||
|
response = self.client.get(f"/poll/human-input/{execution_id}", headers=headers)
|
||||||
|
assert response.status_code == 200
|
||||||
|
data = response.json()
|
||||||
|
assert len(data["events"]) == 1
|
||||||
|
assert data["events"][0]["type"] == "human_input_required"
|
||||||
|
assert data["events"][0]["execution_id"] == execution_id
|
||||||
|
|
||||||
|
def test_polling_endpoint_with_last_event_id(self):
|
||||||
|
"""Test polling endpoint with last_event_id parameter"""
|
||||||
|
execution_id = "test-execution-id"
|
||||||
|
|
||||||
|
event1 = HumanInputRequiredEvent(
|
||||||
|
execution_id=execution_id,
|
||||||
|
event_id="event-1"
|
||||||
|
)
|
||||||
|
event2 = HumanInputRequiredEvent(
|
||||||
|
execution_id=execution_id,
|
||||||
|
event_id="event-2"
|
||||||
|
)
|
||||||
|
|
||||||
|
event_stream_manager._store_polling_event(execution_id, event1.to_json())
|
||||||
|
event_stream_manager._store_polling_event(execution_id, event2.to_json())
|
||||||
|
|
||||||
|
headers = {"Authorization": "Bearer test-key"}
|
||||||
|
response = self.client.get(
|
||||||
|
f"/poll/human-input/{execution_id}?last_event_id=event-1",
|
||||||
|
headers=headers
|
||||||
|
)
|
||||||
|
assert response.status_code == 200
|
||||||
|
data = response.json()
|
||||||
|
assert len(data["events"]) == 1
|
||||||
|
assert data["events"][0]["event_id"] == "event-2"
|
||||||
|
|
||||||
|
def test_sse_endpoint_unauthorized(self):
|
||||||
|
"""Test SSE endpoint without authentication"""
|
||||||
|
response = self.client.get("/events/human-input/test-execution-id")
|
||||||
|
assert response.status_code == 401
|
||||||
|
|
||||||
|
def test_sse_endpoint_authorized(self):
|
||||||
|
"""Test SSE endpoint with authentication"""
|
||||||
|
headers = {"Authorization": "Bearer test-key"}
|
||||||
|
with self.client.stream("GET", "/events/human-input/test-execution-id", headers=headers) as response:
|
||||||
|
assert response.status_code == 200
|
||||||
|
assert response.headers["content-type"] == "text/event-stream; charset=utf-8"
|
||||||
|
|
||||||
|
def test_websocket_endpoint_unauthorized(self):
|
||||||
|
"""Test WebSocket endpoint without authentication"""
|
||||||
|
with pytest.raises(Exception):
|
||||||
|
with self.client.websocket_connect("/ws/human-input/test-execution-id"):
|
||||||
|
pass
|
||||||
|
|
||||||
|
def test_websocket_endpoint_authorized(self):
|
||||||
|
"""Test WebSocket endpoint with authentication"""
|
||||||
|
with self.client.websocket_connect("/ws/human-input/test-execution-id?token=test-key") as websocket:
|
||||||
|
assert websocket is not None
|
||||||
|
|
||||||
|
def test_server_without_api_key(self):
|
||||||
|
"""Test server initialization without API key"""
|
||||||
|
server = HumanInputServer(host="localhost", port=8002)
|
||||||
|
client = TestClient(server.app)
|
||||||
|
|
||||||
|
response = client.get("/poll/human-input/test-execution-id")
|
||||||
|
assert response.status_code == 200
|
||||||
|
|
||||||
|
response = client.get("/events/human-input/test-execution-id")
|
||||||
|
assert response.status_code == 200
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.skipif(FASTAPI_AVAILABLE, reason="Testing import error handling")
|
||||||
|
def test_server_without_fastapi():
|
||||||
|
"""Test server initialization without FastAPI dependencies"""
|
||||||
|
with pytest.raises(ImportError, match="FastAPI dependencies not available"):
|
||||||
|
HumanInputServer()
|
||||||
209
tests/test_human_input_event_integration.py
Normal file
209
tests/test_human_input_event_integration.py
Normal file
@@ -0,0 +1,209 @@
|
|||||||
|
import pytest
|
||||||
|
import uuid
|
||||||
|
from unittest.mock import patch, MagicMock
|
||||||
|
|
||||||
|
from crewai.utilities.events.task_events import HumanInputRequiredEvent, HumanInputCompletedEvent
|
||||||
|
from crewai.utilities.events.crewai_event_bus import crewai_event_bus
|
||||||
|
from crewai.agents.agent_builder.base_agent_executor_mixin import CrewAgentExecutorMixin
|
||||||
|
|
||||||
|
|
||||||
|
class TestHumanInputEventIntegration:
|
||||||
|
"""Test integration between human input flow and event system"""
|
||||||
|
|
||||||
|
def setup_method(self):
|
||||||
|
"""Setup test environment"""
|
||||||
|
self.executor = CrewAgentExecutorMixin()
|
||||||
|
self.executor.crew = MagicMock()
|
||||||
|
self.executor.crew.id = str(uuid.uuid4())
|
||||||
|
self.executor.crew._train = False
|
||||||
|
self.executor.task = MagicMock()
|
||||||
|
self.executor.task.id = str(uuid.uuid4())
|
||||||
|
self.executor.agent = MagicMock()
|
||||||
|
self.executor.agent.id = str(uuid.uuid4())
|
||||||
|
self.executor._printer = MagicMock()
|
||||||
|
|
||||||
|
@patch('builtins.input', return_value='test feedback')
|
||||||
|
def test_human_input_emits_required_event(self, mock_input):
|
||||||
|
"""Test that human input emits HumanInputRequiredEvent"""
|
||||||
|
events_captured = []
|
||||||
|
|
||||||
|
def capture_event(event):
|
||||||
|
events_captured.append(event)
|
||||||
|
|
||||||
|
with patch.object(crewai_event_bus, 'emit', side_effect=capture_event):
|
||||||
|
result = self.executor._ask_human_input("Test result")
|
||||||
|
|
||||||
|
assert result == 'test feedback'
|
||||||
|
assert len(events_captured) == 2
|
||||||
|
|
||||||
|
required_event = events_captured[0][1]
|
||||||
|
completed_event = events_captured[1][1]
|
||||||
|
|
||||||
|
assert isinstance(required_event, HumanInputRequiredEvent)
|
||||||
|
assert isinstance(completed_event, HumanInputCompletedEvent)
|
||||||
|
|
||||||
|
assert required_event.execution_id == str(self.executor.crew.id)
|
||||||
|
assert required_event.crew_id == str(self.executor.crew.id)
|
||||||
|
assert required_event.task_id == str(self.executor.task.id)
|
||||||
|
assert required_event.agent_id == str(self.executor.agent.id)
|
||||||
|
assert "HUMAN FEEDBACK" in required_event.prompt
|
||||||
|
assert required_event.context == "Test result"
|
||||||
|
assert required_event.event_id is not None
|
||||||
|
|
||||||
|
assert completed_event.execution_id == str(self.executor.crew.id)
|
||||||
|
assert completed_event.human_feedback == 'test feedback'
|
||||||
|
assert completed_event.event_id == required_event.event_id
|
||||||
|
|
||||||
|
@patch('builtins.input', return_value='training feedback')
|
||||||
|
def test_training_mode_human_input_events(self, mock_input):
|
||||||
|
"""Test human input events in training mode"""
|
||||||
|
self.executor.crew._train = True
|
||||||
|
events_captured = []
|
||||||
|
|
||||||
|
def capture_event(event):
|
||||||
|
events_captured.append(event)
|
||||||
|
|
||||||
|
with patch.object(crewai_event_bus, 'emit', side_effect=capture_event):
|
||||||
|
result = self.executor._ask_human_input("Test result")
|
||||||
|
|
||||||
|
assert result == 'training feedback'
|
||||||
|
assert len(events_captured) == 2
|
||||||
|
|
||||||
|
required_event = events_captured[0][1]
|
||||||
|
assert isinstance(required_event, HumanInputRequiredEvent)
|
||||||
|
assert "TRAINING MODE" in required_event.prompt
|
||||||
|
|
||||||
|
@patch('builtins.input', return_value='')
|
||||||
|
def test_empty_feedback_events(self, mock_input):
|
||||||
|
"""Test events with empty feedback"""
|
||||||
|
events_captured = []
|
||||||
|
|
||||||
|
def capture_event(event):
|
||||||
|
events_captured.append(event)
|
||||||
|
|
||||||
|
with patch.object(crewai_event_bus, 'emit', side_effect=capture_event):
|
||||||
|
result = self.executor._ask_human_input("Test result")
|
||||||
|
|
||||||
|
assert result == ''
|
||||||
|
assert len(events_captured) == 2
|
||||||
|
|
||||||
|
completed_event = events_captured[1][1]
|
||||||
|
assert isinstance(completed_event, HumanInputCompletedEvent)
|
||||||
|
assert completed_event.human_feedback == ''
|
||||||
|
|
||||||
|
def test_event_payload_structure(self):
|
||||||
|
"""Test that event payload matches GitHub issue specification"""
|
||||||
|
event = HumanInputRequiredEvent(
|
||||||
|
execution_id="test-execution-id",
|
||||||
|
crew_id="test-crew-id",
|
||||||
|
task_id="test-task-id",
|
||||||
|
agent_id="test-agent-id",
|
||||||
|
prompt="Test prompt",
|
||||||
|
context="Test context",
|
||||||
|
reason_flags={"ambiguity": True, "missing_field": False},
|
||||||
|
event_id="test-event-id"
|
||||||
|
)
|
||||||
|
|
||||||
|
payload = event.to_json()
|
||||||
|
|
||||||
|
assert payload["type"] == "human_input_required"
|
||||||
|
assert payload["execution_id"] == "test-execution-id"
|
||||||
|
assert payload["crew_id"] == "test-crew-id"
|
||||||
|
assert payload["task_id"] == "test-task-id"
|
||||||
|
assert payload["agent_id"] == "test-agent-id"
|
||||||
|
assert payload["prompt"] == "Test prompt"
|
||||||
|
assert payload["context"] == "Test context"
|
||||||
|
assert payload["reason_flags"]["ambiguity"] is True
|
||||||
|
assert payload["reason_flags"]["missing_field"] is False
|
||||||
|
assert payload["event_id"] == "test-event-id"
|
||||||
|
assert "timestamp" in payload
|
||||||
|
|
||||||
|
def test_completed_event_payload_structure(self):
|
||||||
|
"""Test that completed event payload is correct"""
|
||||||
|
event = HumanInputCompletedEvent(
|
||||||
|
execution_id="test-execution-id",
|
||||||
|
crew_id="test-crew-id",
|
||||||
|
task_id="test-task-id",
|
||||||
|
agent_id="test-agent-id",
|
||||||
|
event_id="test-event-id",
|
||||||
|
human_feedback="User feedback"
|
||||||
|
)
|
||||||
|
|
||||||
|
payload = event.to_json()
|
||||||
|
|
||||||
|
assert payload["type"] == "human_input_completed"
|
||||||
|
assert payload["execution_id"] == "test-execution-id"
|
||||||
|
assert payload["crew_id"] == "test-crew-id"
|
||||||
|
assert payload["task_id"] == "test-task-id"
|
||||||
|
assert payload["agent_id"] == "test-agent-id"
|
||||||
|
assert payload["event_id"] == "test-event-id"
|
||||||
|
assert payload["human_feedback"] == "User feedback"
|
||||||
|
assert "timestamp" in payload
|
||||||
|
|
||||||
|
@patch('builtins.input', side_effect=KeyboardInterrupt("Test interrupt"))
|
||||||
|
def test_human_input_exception_handling(self, mock_input):
|
||||||
|
"""Test that events are still emitted even if input is interrupted"""
|
||||||
|
events_captured = []
|
||||||
|
|
||||||
|
def capture_event(event):
|
||||||
|
events_captured.append(event)
|
||||||
|
|
||||||
|
with patch.object(crewai_event_bus, 'emit', side_effect=capture_event):
|
||||||
|
with pytest.raises(KeyboardInterrupt):
|
||||||
|
self.executor._ask_human_input("Test result")
|
||||||
|
|
||||||
|
assert len(events_captured) == 1
|
||||||
|
required_event = events_captured[0][1]
|
||||||
|
assert isinstance(required_event, HumanInputRequiredEvent)
|
||||||
|
|
||||||
|
def test_human_input_without_crew(self):
|
||||||
|
"""Test human input events when crew is None"""
|
||||||
|
self.executor.crew = None
|
||||||
|
events_captured = []
|
||||||
|
|
||||||
|
def capture_event(event):
|
||||||
|
events_captured.append(event)
|
||||||
|
|
||||||
|
with patch.object(crewai_event_bus, 'emit', side_effect=capture_event), \
|
||||||
|
patch('builtins.input', return_value='test'):
|
||||||
|
|
||||||
|
result = self.executor._ask_human_input("Test result")
|
||||||
|
|
||||||
|
assert len(events_captured) == 2
|
||||||
|
required_event = events_captured[0][1]
|
||||||
|
assert required_event.execution_id is None
|
||||||
|
assert required_event.crew_id is None
|
||||||
|
|
||||||
|
def test_human_input_without_task(self):
|
||||||
|
"""Test human input events when task is None"""
|
||||||
|
self.executor.task = None
|
||||||
|
events_captured = []
|
||||||
|
|
||||||
|
def capture_event(event):
|
||||||
|
events_captured.append(event)
|
||||||
|
|
||||||
|
with patch.object(crewai_event_bus, 'emit', side_effect=capture_event), \
|
||||||
|
patch('builtins.input', return_value='test'):
|
||||||
|
|
||||||
|
result = self.executor._ask_human_input("Test result")
|
||||||
|
|
||||||
|
assert len(events_captured) == 2
|
||||||
|
required_event = events_captured[0][1]
|
||||||
|
assert required_event.task_id is None
|
||||||
|
|
||||||
|
def test_human_input_without_agent(self):
|
||||||
|
"""Test human input events when agent is None"""
|
||||||
|
self.executor.agent = None
|
||||||
|
events_captured = []
|
||||||
|
|
||||||
|
def capture_event(event):
|
||||||
|
events_captured.append(event)
|
||||||
|
|
||||||
|
with patch.object(crewai_event_bus, 'emit', side_effect=capture_event), \
|
||||||
|
patch('builtins.input', return_value='test'):
|
||||||
|
|
||||||
|
result = self.executor._ask_human_input("Test result")
|
||||||
|
|
||||||
|
assert len(events_captured) == 2
|
||||||
|
required_event = events_captured[0][1]
|
||||||
|
assert required_event.agent_id is None
|
||||||
Reference in New Issue
Block a user