mirror of
https://github.com/crewAIInc/crewAI.git
synced 2026-01-08 23:58:34 +00:00
Compare commits
6 Commits
devin/1753
...
devin/1754
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
56f8ad5297 | ||
|
|
f133d13a00 | ||
|
|
ea560d0af1 | ||
|
|
88ed91561f | ||
|
|
9a347ad458 | ||
|
|
34c3075fdb |
264
docs/human_input_event_streaming.md
Normal file
264
docs/human_input_event_streaming.md
Normal file
@@ -0,0 +1,264 @@
|
||||
# Human Input Event Streaming
|
||||
|
||||
CrewAI supports real-time event streaming for human input events, allowing clients to receive notifications when human input is required during crew execution. This feature provides an alternative to webhook-only approaches and supports multiple streaming protocols.
|
||||
|
||||
## Overview
|
||||
|
||||
When a task requires human input (`task.human_input=True`), CrewAI emits events that can be consumed via:
|
||||
|
||||
- **WebSocket**: Real-time bidirectional communication
|
||||
- **Server-Sent Events (SSE)**: Unidirectional server-to-client streaming
|
||||
- **Long Polling**: HTTP-based polling for events
|
||||
|
||||
## Event Types
|
||||
|
||||
### HumanInputRequiredEvent
|
||||
|
||||
Emitted when human input is required during task execution.
|
||||
|
||||
```json
|
||||
{
|
||||
"type": "human_input_required",
|
||||
"execution_id": "uuid",
|
||||
"crew_id": "uuid",
|
||||
"task_id": "uuid",
|
||||
"agent_id": "uuid",
|
||||
"prompt": "string",
|
||||
"context": "string",
|
||||
"timestamp": "ISO8601",
|
||||
"event_id": "uuid",
|
||||
"reason_flags": {
|
||||
"ambiguity": true,
|
||||
"missing_field": false
|
||||
}
|
||||
}
|
||||
```
|
||||
|
||||
### HumanInputCompletedEvent
|
||||
|
||||
Emitted when human input is completed.
|
||||
|
||||
```json
|
||||
{
|
||||
"type": "human_input_completed",
|
||||
"execution_id": "uuid",
|
||||
"crew_id": "uuid",
|
||||
"task_id": "uuid",
|
||||
"agent_id": "uuid",
|
||||
"event_id": "uuid",
|
||||
"human_feedback": "string",
|
||||
"timestamp": "ISO8601"
|
||||
}
|
||||
```
|
||||
|
||||
## Server Setup
|
||||
|
||||
### Installation
|
||||
|
||||
Install the server dependencies:
|
||||
|
||||
```bash
|
||||
pip install crewai[server]
|
||||
```
|
||||
|
||||
### Starting the Server
|
||||
|
||||
```python
|
||||
from crewai.server.human_input_server import HumanInputServer
|
||||
|
||||
# Start server with authentication
|
||||
server = HumanInputServer(
|
||||
host="localhost",
|
||||
port=8000,
|
||||
api_key="your-api-key"
|
||||
)
|
||||
|
||||
# Start synchronously
|
||||
server.start()
|
||||
|
||||
# Or start asynchronously
|
||||
await server.start_async()
|
||||
```
|
||||
|
||||
### Configuration Options
|
||||
|
||||
- `host`: Server host (default: "localhost")
|
||||
- `port`: Server port (default: 8000)
|
||||
- `api_key`: Optional API key for authentication
|
||||
|
||||
## Client Integration
|
||||
|
||||
### WebSocket Client
|
||||
|
||||
```python
|
||||
import asyncio
|
||||
import json
|
||||
import websockets
|
||||
|
||||
async def websocket_client(execution_id: str, api_key: str = None):
|
||||
uri = f"ws://localhost:8000/ws/human-input/{execution_id}"
|
||||
if api_key:
|
||||
uri += f"?token={api_key}"
|
||||
|
||||
async with websockets.connect(uri) as websocket:
|
||||
async for message in websocket:
|
||||
event = json.loads(message)
|
||||
|
||||
if event['type'] == 'human_input_required':
|
||||
print(f"Human input needed: {event['prompt']}")
|
||||
print(f"Context: {event['context']}")
|
||||
elif event['type'] == 'human_input_completed':
|
||||
print(f"Input completed: {event['human_feedback']}")
|
||||
|
||||
# Usage
|
||||
asyncio.run(websocket_client("execution-id", "api-key"))
|
||||
```
|
||||
|
||||
### Server-Sent Events (SSE) Client
|
||||
|
||||
```python
|
||||
import httpx
|
||||
import json
|
||||
|
||||
async def sse_client(execution_id: str, api_key: str = None):
|
||||
url = f"http://localhost:8000/events/human-input/{execution_id}"
|
||||
headers = {}
|
||||
if api_key:
|
||||
headers["Authorization"] = f"Bearer {api_key}"
|
||||
|
||||
async with httpx.AsyncClient() as client:
|
||||
async with client.stream("GET", url, headers=headers) as response:
|
||||
async for line in response.aiter_lines():
|
||||
if line.startswith("data: "):
|
||||
event = json.loads(line[6:])
|
||||
if event.get('type') != 'heartbeat':
|
||||
print(f"Received: {event}")
|
||||
|
||||
# Usage
|
||||
asyncio.run(sse_client("execution-id", "api-key"))
|
||||
```
|
||||
|
||||
### Long Polling Client
|
||||
|
||||
```python
|
||||
import httpx
|
||||
import asyncio
|
||||
|
||||
async def polling_client(execution_id: str, api_key: str = None):
|
||||
url = f"http://localhost:8000/poll/human-input/{execution_id}"
|
||||
headers = {}
|
||||
if api_key:
|
||||
headers["Authorization"] = f"Bearer {api_key}"
|
||||
|
||||
last_event_id = None
|
||||
|
||||
async with httpx.AsyncClient() as client:
|
||||
while True:
|
||||
params = {}
|
||||
if last_event_id:
|
||||
params["last_event_id"] = last_event_id
|
||||
|
||||
response = await client.get(url, headers=headers, params=params)
|
||||
data = response.json()
|
||||
|
||||
for event in data.get("events", []):
|
||||
print(f"Received: {event}")
|
||||
last_event_id = event.get('event_id')
|
||||
|
||||
await asyncio.sleep(2) # Poll every 2 seconds
|
||||
|
||||
# Usage
|
||||
asyncio.run(polling_client("execution-id", "api-key"))
|
||||
```
|
||||
|
||||
## API Endpoints
|
||||
|
||||
### WebSocket Endpoint
|
||||
|
||||
- **URL**: `/ws/human-input/{execution_id}`
|
||||
- **Protocol**: WebSocket
|
||||
- **Authentication**: Query parameter `token` (if API key configured)
|
||||
|
||||
### SSE Endpoint
|
||||
|
||||
- **URL**: `/events/human-input/{execution_id}`
|
||||
- **Method**: GET
|
||||
- **Headers**: `Authorization: Bearer <api_key>` (if configured)
|
||||
- **Response**: `text/event-stream`
|
||||
|
||||
### Polling Endpoint
|
||||
|
||||
- **URL**: `/poll/human-input/{execution_id}`
|
||||
- **Method**: GET
|
||||
- **Headers**: `Authorization: Bearer <api_key>` (if configured)
|
||||
- **Query Parameters**:
|
||||
- `last_event_id`: Get events after this ID
|
||||
- **Response**: JSON with `events` array
|
||||
|
||||
### Health Check
|
||||
|
||||
- **URL**: `/health`
|
||||
- **Method**: GET
|
||||
- **Response**: `{"status": "healthy", "timestamp": "..."}`
|
||||
|
||||
## Authentication
|
||||
|
||||
When an API key is configured, clients must authenticate:
|
||||
|
||||
- **WebSocket**: Include `token` query parameter
|
||||
- **SSE/Polling**: Include `Authorization: Bearer <api_key>` header
|
||||
|
||||
## Integration with Crew Execution
|
||||
|
||||
The event streaming works automatically with existing crew execution:
|
||||
|
||||
```python
|
||||
from crewai import Agent, Task, Crew
|
||||
|
||||
# Create crew with human input task
|
||||
agent = Agent(...)
|
||||
task = Task(
|
||||
description="...",
|
||||
human_input=True, # This enables human input
|
||||
agent=agent
|
||||
)
|
||||
crew = Crew(agents=[agent], tasks=[task])
|
||||
|
||||
# Start event server (optional)
|
||||
server = HumanInputServer(port=8000)
|
||||
server_thread = threading.Thread(target=server.start, daemon=True)
|
||||
server_thread.start()
|
||||
|
||||
# Execute crew - events will be emitted automatically
|
||||
result = crew.kickoff()
|
||||
```
|
||||
|
||||
## Error Handling
|
||||
|
||||
- **Connection Errors**: Clients should implement reconnection logic
|
||||
- **Authentication Errors**: Server returns 401 for invalid credentials
|
||||
- **Rate Limiting**: Consider implementing client-side rate limiting for polling
|
||||
|
||||
## Best Practices
|
||||
|
||||
1. **Use WebSocket** for real-time applications requiring immediate notifications
|
||||
2. **Use SSE** for one-way streaming with automatic reconnection support
|
||||
3. **Use Polling** for simple implementations or when WebSocket/SSE aren't available
|
||||
4. **Implement Authentication** in production environments
|
||||
5. **Handle Connection Failures** gracefully with retry logic
|
||||
6. **Filter Events** by execution_id to avoid processing irrelevant events
|
||||
|
||||
## Backward Compatibility
|
||||
|
||||
This feature is fully backward compatible:
|
||||
|
||||
- Existing webhook functionality remains unchanged
|
||||
- Console-based human input continues to work
|
||||
- No breaking changes to existing APIs
|
||||
|
||||
## Example Applications
|
||||
|
||||
- **Web Dashboards**: Real-time crew execution monitoring
|
||||
- **Mobile Apps**: Push notifications for human input requests
|
||||
- **Integration Platforms**: Event-driven workflow automation
|
||||
- **Monitoring Systems**: Real-time alerting and logging
|
||||
241
examples/human_input_event_streaming.py
Normal file
241
examples/human_input_event_streaming.py
Normal file
@@ -0,0 +1,241 @@
|
||||
"""
|
||||
Example demonstrating how to use the human input event streaming feature.
|
||||
|
||||
This example shows how to:
|
||||
1. Start the human input event server
|
||||
2. Connect to WebSocket/SSE/polling endpoints
|
||||
3. Handle human input events in real-time
|
||||
4. Integrate with crew execution
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import json
|
||||
import threading
|
||||
import time
|
||||
from typing import Optional
|
||||
|
||||
try:
|
||||
import websockets
|
||||
import httpx
|
||||
from crewai.server.human_input_server import HumanInputServer
|
||||
DEPENDENCIES_AVAILABLE = True
|
||||
except ImportError:
|
||||
DEPENDENCIES_AVAILABLE = False
|
||||
|
||||
from crewai import Agent, Task, Crew
|
||||
from crewai.llm import LLM
|
||||
|
||||
|
||||
def start_event_server(port: int = 8000, api_key: Optional[str] = None):
|
||||
"""Start the human input event server in a separate thread"""
|
||||
if not DEPENDENCIES_AVAILABLE:
|
||||
print("Server dependencies not available. Install with: pip install crewai[server]")
|
||||
return None
|
||||
|
||||
server = HumanInputServer(host="localhost", port=port, api_key=api_key)
|
||||
|
||||
def run_server():
|
||||
server.start()
|
||||
|
||||
server_thread = threading.Thread(target=run_server, daemon=True)
|
||||
server_thread.start()
|
||||
|
||||
time.sleep(2)
|
||||
print(f"Human input event server started on http://localhost:{port}")
|
||||
return server
|
||||
|
||||
|
||||
async def websocket_client_example(execution_id: str, api_key: Optional[str] = None):
|
||||
"""Example WebSocket client for receiving human input events"""
|
||||
if not DEPENDENCIES_AVAILABLE:
|
||||
print("WebSocket dependencies not available")
|
||||
return
|
||||
|
||||
uri = f"ws://localhost:8000/ws/human-input/{execution_id}"
|
||||
if api_key:
|
||||
uri += f"?token={api_key}"
|
||||
|
||||
try:
|
||||
async with websockets.connect(uri) as websocket:
|
||||
print(f"Connected to WebSocket for execution {execution_id}")
|
||||
|
||||
async for message in websocket:
|
||||
event_data = json.loads(message)
|
||||
print(f"Received WebSocket event: {event_data['type']}")
|
||||
|
||||
if event_data['type'] == 'human_input_required':
|
||||
print(f"Human input required for task: {event_data.get('task_id')}")
|
||||
print(f"Prompt: {event_data.get('prompt')}")
|
||||
print(f"Context: {event_data.get('context')}")
|
||||
elif event_data['type'] == 'human_input_completed':
|
||||
print(f"Human input completed: {event_data.get('human_feedback')}")
|
||||
|
||||
except Exception as e:
|
||||
print(f"WebSocket error: {e}")
|
||||
|
||||
|
||||
async def sse_client_example(execution_id: str, api_key: Optional[str] = None):
|
||||
"""Example SSE client for receiving human input events"""
|
||||
if not DEPENDENCIES_AVAILABLE:
|
||||
print("SSE dependencies not available")
|
||||
return
|
||||
|
||||
url = f"http://localhost:8000/events/human-input/{execution_id}"
|
||||
headers = {}
|
||||
if api_key:
|
||||
headers["Authorization"] = f"Bearer {api_key}"
|
||||
|
||||
try:
|
||||
async with httpx.AsyncClient() as client:
|
||||
async with client.stream("GET", url, headers=headers) as response:
|
||||
print(f"Connected to SSE for execution {execution_id}")
|
||||
|
||||
async for line in response.aiter_lines():
|
||||
if line.startswith("data: "):
|
||||
event_data = json.loads(line[6:])
|
||||
if event_data.get('type') != 'heartbeat':
|
||||
print(f"Received SSE event: {event_data['type']}")
|
||||
|
||||
if event_data['type'] == 'human_input_required':
|
||||
print(f"Human input required for task: {event_data.get('task_id')}")
|
||||
print(f"Prompt: {event_data.get('prompt')}")
|
||||
elif event_data['type'] == 'human_input_completed':
|
||||
print(f"Human input completed: {event_data.get('human_feedback')}")
|
||||
|
||||
except Exception as e:
|
||||
print(f"SSE error: {e}")
|
||||
|
||||
|
||||
async def polling_client_example(execution_id: str, api_key: Optional[str] = None):
|
||||
"""Example polling client for receiving human input events"""
|
||||
if not DEPENDENCIES_AVAILABLE:
|
||||
print("Polling dependencies not available")
|
||||
return
|
||||
|
||||
url = f"http://localhost:8000/poll/human-input/{execution_id}"
|
||||
headers = {}
|
||||
if api_key:
|
||||
headers["Authorization"] = f"Bearer {api_key}"
|
||||
|
||||
last_event_id = None
|
||||
|
||||
try:
|
||||
async with httpx.AsyncClient() as client:
|
||||
print(f"Starting polling for execution {execution_id}")
|
||||
|
||||
while True:
|
||||
params = {}
|
||||
if last_event_id:
|
||||
params["last_event_id"] = last_event_id
|
||||
|
||||
response = await client.get(url, headers=headers, params=params)
|
||||
data = response.json()
|
||||
|
||||
for event in data.get("events", []):
|
||||
print(f"Received polling event: {event['type']}")
|
||||
|
||||
if event['type'] == 'human_input_required':
|
||||
print(f"Human input required for task: {event.get('task_id')}")
|
||||
print(f"Prompt: {event.get('prompt')}")
|
||||
elif event['type'] == 'human_input_completed':
|
||||
print(f"Human input completed: {event.get('human_feedback')}")
|
||||
|
||||
last_event_id = event.get('event_id')
|
||||
|
||||
await asyncio.sleep(2)
|
||||
|
||||
except Exception as e:
|
||||
print(f"Polling error: {e}")
|
||||
|
||||
|
||||
def create_sample_crew():
|
||||
"""Create a sample crew that requires human input"""
|
||||
|
||||
llm = LLM(model="gpt-4o-mini")
|
||||
|
||||
agent = Agent(
|
||||
role="Research Assistant",
|
||||
goal="Help with research tasks and get human feedback",
|
||||
backstory="You are a helpful research assistant that works with humans to complete tasks.",
|
||||
llm=llm,
|
||||
verbose=True
|
||||
)
|
||||
|
||||
task = Task(
|
||||
description="Research the latest trends in AI and provide a summary. Ask for human feedback on the findings.",
|
||||
expected_output="A comprehensive summary of AI trends with human feedback incorporated.",
|
||||
agent=agent,
|
||||
human_input=True
|
||||
)
|
||||
|
||||
crew = Crew(
|
||||
agents=[agent],
|
||||
tasks=[task],
|
||||
verbose=True
|
||||
)
|
||||
|
||||
return crew
|
||||
|
||||
|
||||
async def main():
|
||||
"""Main example function"""
|
||||
print("CrewAI Human Input Event Streaming Example")
|
||||
print("=" * 50)
|
||||
|
||||
api_key = "demo-api-key"
|
||||
|
||||
server = start_event_server(port=8000, api_key=api_key)
|
||||
if not server:
|
||||
return
|
||||
|
||||
crew = create_sample_crew()
|
||||
execution_id = str(crew.id)
|
||||
|
||||
print(f"Crew execution ID: {execution_id}")
|
||||
print("\nStarting event stream clients...")
|
||||
|
||||
websocket_task = asyncio.create_task(
|
||||
websocket_client_example(execution_id, api_key)
|
||||
)
|
||||
|
||||
sse_task = asyncio.create_task(
|
||||
sse_client_example(execution_id, api_key)
|
||||
)
|
||||
|
||||
polling_task = asyncio.create_task(
|
||||
polling_client_example(execution_id, api_key)
|
||||
)
|
||||
|
||||
await asyncio.sleep(1)
|
||||
|
||||
print("\nStarting crew execution...")
|
||||
print("Note: This will prompt for human input in the console.")
|
||||
print("The event streams above will also receive the events in real-time.")
|
||||
|
||||
def run_crew():
|
||||
try:
|
||||
result = crew.kickoff()
|
||||
print(f"\nCrew execution completed: {result}")
|
||||
except Exception as e:
|
||||
print(f"Crew execution error: {e}")
|
||||
|
||||
crew_thread = threading.Thread(target=run_crew)
|
||||
crew_thread.start()
|
||||
|
||||
await asyncio.sleep(30)
|
||||
|
||||
websocket_task.cancel()
|
||||
sse_task.cancel()
|
||||
polling_task.cancel()
|
||||
|
||||
crew_thread.join(timeout=5)
|
||||
|
||||
print("\nExample completed!")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
if DEPENDENCIES_AVAILABLE:
|
||||
asyncio.run(main())
|
||||
else:
|
||||
print("Dependencies not available. Install with: pip install crewai[server]")
|
||||
print("This example requires FastAPI, uvicorn, websockets, and httpx.")
|
||||
@@ -48,7 +48,7 @@ Documentation = "https://docs.crewai.com"
|
||||
Repository = "https://github.com/crewAIInc/crewAI"
|
||||
|
||||
[project.optional-dependencies]
|
||||
tools = ["crewai-tools~=0.58.0"]
|
||||
tools = ["crewai-tools~=0.59.0"]
|
||||
embeddings = [
|
||||
"tiktoken~=0.8.0"
|
||||
]
|
||||
@@ -69,6 +69,11 @@ docling = [
|
||||
aisuite = [
|
||||
"aisuite>=0.1.10",
|
||||
]
|
||||
server = [
|
||||
"fastapi>=0.104.0",
|
||||
"uvicorn>=0.24.0",
|
||||
"websockets>=12.0",
|
||||
]
|
||||
|
||||
[tool.uv]
|
||||
dev-dependencies = [
|
||||
@@ -86,6 +91,10 @@ dev-dependencies = [
|
||||
"pytest-timeout>=2.3.1",
|
||||
"pytest-xdist>=3.6.1",
|
||||
"pytest-split>=0.9.0",
|
||||
"fastapi>=0.104.0",
|
||||
"uvicorn>=0.24.0",
|
||||
"websockets>=12.0",
|
||||
"httpx>=0.25.0",
|
||||
]
|
||||
|
||||
[project.scripts]
|
||||
|
||||
@@ -54,7 +54,7 @@ def _track_install_async():
|
||||
|
||||
_track_install_async()
|
||||
|
||||
__version__ = "0.150.0"
|
||||
__version__ = "0.152.0"
|
||||
__all__ = [
|
||||
"Agent",
|
||||
"Crew",
|
||||
|
||||
@@ -1,4 +1,5 @@
|
||||
import time
|
||||
import uuid
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
from crewai.memory.entity.entity_memory_item import EntityMemoryItem
|
||||
@@ -8,6 +9,8 @@ from crewai.utilities.converter import ConverterError
|
||||
from crewai.utilities.evaluators.task_evaluator import TaskEvaluator
|
||||
from crewai.utilities.printer import Printer
|
||||
from crewai.utilities.events.event_listener import event_listener
|
||||
from crewai.utilities.events.crewai_event_bus import crewai_event_bus
|
||||
from crewai.utilities.events.task_events import HumanInputRequiredEvent, HumanInputCompletedEvent
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from crewai.agents.agent_builder.base_agent import BaseAgent
|
||||
@@ -126,6 +129,42 @@ class CrewAgentExecutorMixin:
|
||||
|
||||
def _ask_human_input(self, final_answer: str) -> str:
|
||||
"""Prompt human input with mode-appropriate messaging."""
|
||||
event_id = str(uuid.uuid4())
|
||||
|
||||
execution_id = getattr(self.crew, 'id', None) if self.crew else None
|
||||
crew_id = str(execution_id) if execution_id else None
|
||||
task_id = str(getattr(self.task, 'id', None)) if self.task else None
|
||||
agent_id = str(getattr(self.agent, 'id', None)) if self.agent else None
|
||||
|
||||
if self.crew and getattr(self.crew, "_train", False):
|
||||
prompt_text = (
|
||||
"## TRAINING MODE: Provide feedback to improve the agent's performance.\n"
|
||||
"This will be used to train better versions of the agent.\n"
|
||||
"Please provide detailed feedback about the result quality and reasoning process."
|
||||
)
|
||||
else:
|
||||
prompt_text = (
|
||||
"## HUMAN FEEDBACK: Provide feedback on the Final Result and Agent's actions.\n"
|
||||
"Please follow these guidelines:\n"
|
||||
" - If you are happy with the result, simply hit Enter without typing anything.\n"
|
||||
" - Otherwise, provide specific improvement requests.\n"
|
||||
" - You can provide multiple rounds of feedback until satisfied."
|
||||
)
|
||||
|
||||
crewai_event_bus.emit(
|
||||
self,
|
||||
HumanInputRequiredEvent(
|
||||
execution_id=execution_id,
|
||||
crew_id=crew_id,
|
||||
task_id=task_id,
|
||||
agent_id=agent_id,
|
||||
prompt=prompt_text,
|
||||
context=final_answer,
|
||||
event_id=event_id,
|
||||
reason_flags={"ambiguity": True, "missing_field": False}
|
||||
)
|
||||
)
|
||||
|
||||
event_listener.formatter.pause_live_updates()
|
||||
|
||||
try:
|
||||
@@ -133,7 +172,6 @@ class CrewAgentExecutorMixin:
|
||||
content=f"\033[1m\033[95m ## Final Result:\033[00m \033[92m{final_answer}\033[00m"
|
||||
)
|
||||
|
||||
# Training mode prompt (single iteration)
|
||||
if self.crew and getattr(self.crew, "_train", False):
|
||||
prompt = (
|
||||
"\n\n=====\n"
|
||||
@@ -142,7 +180,6 @@ class CrewAgentExecutorMixin:
|
||||
"Please provide detailed feedback about the result quality and reasoning process.\n"
|
||||
"=====\n"
|
||||
)
|
||||
# Regular human-in-the-loop prompt (multiple iterations)
|
||||
else:
|
||||
prompt = (
|
||||
"\n\n=====\n"
|
||||
@@ -158,6 +195,19 @@ class CrewAgentExecutorMixin:
|
||||
response = input()
|
||||
if response.strip() != "":
|
||||
self._printer.print(content="\nProcessing your feedback...", color="cyan")
|
||||
|
||||
crewai_event_bus.emit(
|
||||
self,
|
||||
HumanInputCompletedEvent(
|
||||
execution_id=execution_id,
|
||||
crew_id=crew_id,
|
||||
task_id=task_id,
|
||||
agent_id=agent_id,
|
||||
event_id=event_id,
|
||||
human_feedback=response
|
||||
)
|
||||
)
|
||||
|
||||
return response
|
||||
finally:
|
||||
event_listener.formatter.resume_live_updates()
|
||||
|
||||
@@ -3,6 +3,7 @@ from typing import Optional
|
||||
|
||||
import click
|
||||
from crewai.cli.config import Settings
|
||||
from crewai.cli.settings.main import SettingsCommand
|
||||
from crewai.cli.add_crew_to_flow import add_crew_to_flow
|
||||
from crewai.cli.create_crew import create_crew
|
||||
from crewai.cli.create_flow import create_flow
|
||||
@@ -227,7 +228,7 @@ def update():
|
||||
@crewai.command()
|
||||
def login():
|
||||
"""Sign Up/Login to CrewAI Enterprise."""
|
||||
Settings().clear()
|
||||
Settings().clear_user_settings()
|
||||
AuthenticationCommand().login()
|
||||
|
||||
|
||||
@@ -369,8 +370,8 @@ def org():
|
||||
pass
|
||||
|
||||
|
||||
@org.command()
|
||||
def list():
|
||||
@org.command("list")
|
||||
def org_list():
|
||||
"""List available organizations."""
|
||||
org_command = OrganizationCommand()
|
||||
org_command.list()
|
||||
@@ -391,5 +392,34 @@ def current():
|
||||
org_command.current()
|
||||
|
||||
|
||||
@crewai.group()
|
||||
def config():
|
||||
"""CLI Configuration commands."""
|
||||
pass
|
||||
|
||||
|
||||
@config.command("list")
|
||||
def config_list():
|
||||
"""List all CLI configuration parameters."""
|
||||
config_command = SettingsCommand()
|
||||
config_command.list()
|
||||
|
||||
|
||||
@config.command("set")
|
||||
@click.argument("key")
|
||||
@click.argument("value")
|
||||
def config_set(key: str, value: str):
|
||||
"""Set a CLI configuration parameter."""
|
||||
config_command = SettingsCommand()
|
||||
config_command.set(key, value)
|
||||
|
||||
|
||||
@config.command("reset")
|
||||
def config_reset():
|
||||
"""Reset all CLI configuration parameters to default values."""
|
||||
config_command = SettingsCommand()
|
||||
config_command.reset_all_settings()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
crewai()
|
||||
|
||||
@@ -4,10 +4,47 @@ from typing import Optional
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from crewai.cli.constants import DEFAULT_CREWAI_ENTERPRISE_URL
|
||||
|
||||
DEFAULT_CONFIG_PATH = Path.home() / ".config" / "crewai" / "settings.json"
|
||||
|
||||
# Settings that are related to the user's account
|
||||
USER_SETTINGS_KEYS = [
|
||||
"tool_repository_username",
|
||||
"tool_repository_password",
|
||||
"org_name",
|
||||
"org_uuid",
|
||||
]
|
||||
|
||||
# Settings that are related to the CLI
|
||||
CLI_SETTINGS_KEYS = [
|
||||
"enterprise_base_url",
|
||||
]
|
||||
|
||||
# Default values for CLI settings
|
||||
DEFAULT_CLI_SETTINGS = {
|
||||
"enterprise_base_url": DEFAULT_CREWAI_ENTERPRISE_URL,
|
||||
}
|
||||
|
||||
# Readonly settings - cannot be set by the user
|
||||
READONLY_SETTINGS_KEYS = [
|
||||
"org_name",
|
||||
"org_uuid",
|
||||
]
|
||||
|
||||
# Hidden settings - not displayed by the 'list' command and cannot be set by the user
|
||||
HIDDEN_SETTINGS_KEYS = [
|
||||
"config_path",
|
||||
"tool_repository_username",
|
||||
"tool_repository_password",
|
||||
]
|
||||
|
||||
|
||||
class Settings(BaseModel):
|
||||
enterprise_base_url: Optional[str] = Field(
|
||||
default=DEFAULT_CREWAI_ENTERPRISE_URL,
|
||||
description="Base URL of the CrewAI Enterprise instance",
|
||||
)
|
||||
tool_repository_username: Optional[str] = Field(
|
||||
None, description="Username for interacting with the Tool Repository"
|
||||
)
|
||||
@@ -20,7 +57,7 @@ class Settings(BaseModel):
|
||||
org_uuid: Optional[str] = Field(
|
||||
None, description="UUID of the currently active organization"
|
||||
)
|
||||
config_path: Path = Field(default=DEFAULT_CONFIG_PATH, exclude=True)
|
||||
config_path: Path = Field(default=DEFAULT_CONFIG_PATH, frozen=True, exclude=True)
|
||||
|
||||
def __init__(self, config_path: Path = DEFAULT_CONFIG_PATH, **data):
|
||||
"""Load Settings from config path"""
|
||||
@@ -37,9 +74,16 @@ class Settings(BaseModel):
|
||||
merged_data = {**file_data, **data}
|
||||
super().__init__(config_path=config_path, **merged_data)
|
||||
|
||||
def clear(self) -> None:
|
||||
"""Clear all settings"""
|
||||
self.config_path.unlink(missing_ok=True)
|
||||
def clear_user_settings(self) -> None:
|
||||
"""Clear all user settings"""
|
||||
self._reset_user_settings()
|
||||
self.dump()
|
||||
|
||||
def reset(self) -> None:
|
||||
"""Reset all settings to default values"""
|
||||
self._reset_user_settings()
|
||||
self._reset_cli_settings()
|
||||
self.dump()
|
||||
|
||||
def dump(self) -> None:
|
||||
"""Save current settings to settings.json"""
|
||||
@@ -52,3 +96,13 @@ class Settings(BaseModel):
|
||||
updated_data = {**existing_data, **self.model_dump(exclude_unset=True)}
|
||||
with self.config_path.open("w") as f:
|
||||
json.dump(updated_data, f, indent=4)
|
||||
|
||||
def _reset_user_settings(self) -> None:
|
||||
"""Reset all user settings to default values"""
|
||||
for key in USER_SETTINGS_KEYS:
|
||||
setattr(self, key, None)
|
||||
|
||||
def _reset_cli_settings(self) -> None:
|
||||
"""Reset all CLI settings to default values"""
|
||||
for key in CLI_SETTINGS_KEYS:
|
||||
setattr(self, key, DEFAULT_CLI_SETTINGS[key])
|
||||
|
||||
@@ -1,3 +1,5 @@
|
||||
DEFAULT_CREWAI_ENTERPRISE_URL = "https://app.crewai.com"
|
||||
|
||||
ENV_VARS = {
|
||||
"openai": [
|
||||
{
|
||||
@@ -320,5 +322,4 @@ DEFAULT_LLM_MODEL = "gpt-4o-mini"
|
||||
|
||||
JSON_URL = "https://raw.githubusercontent.com/BerriAI/litellm/main/model_prices_and_context_window.json"
|
||||
|
||||
|
||||
LITELLM_PARAMS = ["api_key", "api_base", "api_version"]
|
||||
|
||||
@@ -1,4 +1,3 @@
|
||||
from os import getenv
|
||||
from typing import List, Optional
|
||||
from urllib.parse import urljoin
|
||||
|
||||
@@ -6,6 +5,7 @@ import requests
|
||||
|
||||
from crewai.cli.config import Settings
|
||||
from crewai.cli.version import get_crewai_version
|
||||
from crewai.cli.constants import DEFAULT_CREWAI_ENTERPRISE_URL
|
||||
|
||||
|
||||
class PlusAPI:
|
||||
@@ -29,7 +29,10 @@ class PlusAPI:
|
||||
settings = Settings()
|
||||
if settings.org_uuid:
|
||||
self.headers["X-Crewai-Organization-Id"] = settings.org_uuid
|
||||
self.base_url = getenv("CREWAI_BASE_URL", "https://app.crewai.com")
|
||||
|
||||
self.base_url = (
|
||||
str(settings.enterprise_base_url) or DEFAULT_CREWAI_ENTERPRISE_URL
|
||||
)
|
||||
|
||||
def _make_request(self, method: str, endpoint: str, **kwargs) -> requests.Response:
|
||||
url = urljoin(self.base_url, endpoint)
|
||||
@@ -108,7 +111,6 @@ class PlusAPI:
|
||||
|
||||
def create_crew(self, payload) -> requests.Response:
|
||||
return self._make_request("POST", self.CREWS_RESOURCE, json=payload)
|
||||
|
||||
|
||||
def get_organizations(self) -> requests.Response:
|
||||
return self._make_request("GET", self.ORGANIZATIONS_RESOURCE)
|
||||
|
||||
0
src/crewai/cli/settings/__init__.py
Normal file
0
src/crewai/cli/settings/__init__.py
Normal file
67
src/crewai/cli/settings/main.py
Normal file
67
src/crewai/cli/settings/main.py
Normal file
@@ -0,0 +1,67 @@
|
||||
from rich.console import Console
|
||||
from rich.table import Table
|
||||
from crewai.cli.command import BaseCommand
|
||||
from crewai.cli.config import Settings, READONLY_SETTINGS_KEYS, HIDDEN_SETTINGS_KEYS
|
||||
from typing import Any
|
||||
|
||||
console = Console()
|
||||
|
||||
|
||||
class SettingsCommand(BaseCommand):
|
||||
"""A class to handle CLI configuration commands."""
|
||||
|
||||
def __init__(self, settings_kwargs: dict[str, Any] = {}):
|
||||
super().__init__()
|
||||
self.settings = Settings(**settings_kwargs)
|
||||
|
||||
def list(self) -> None:
|
||||
"""List all CLI configuration parameters."""
|
||||
table = Table(title="CrewAI CLI Configuration")
|
||||
table.add_column("Setting", style="cyan", no_wrap=True)
|
||||
table.add_column("Value", style="green")
|
||||
table.add_column("Description", style="yellow")
|
||||
|
||||
# Add all settings to the table
|
||||
for field_name, field_info in Settings.model_fields.items():
|
||||
if field_name in HIDDEN_SETTINGS_KEYS:
|
||||
# Do not display hidden settings
|
||||
continue
|
||||
|
||||
current_value = getattr(self.settings, field_name)
|
||||
description = field_info.description or "No description available"
|
||||
display_value = (
|
||||
str(current_value) if current_value is not None else "Not set"
|
||||
)
|
||||
|
||||
table.add_row(field_name, display_value, description)
|
||||
|
||||
console.print(table)
|
||||
|
||||
def set(self, key: str, value: str) -> None:
|
||||
"""Set a CLI configuration parameter."""
|
||||
|
||||
readonly_settings = READONLY_SETTINGS_KEYS + HIDDEN_SETTINGS_KEYS
|
||||
|
||||
if not hasattr(self.settings, key) or key in readonly_settings:
|
||||
console.print(
|
||||
f"Error: Unknown or readonly configuration key '{key}'",
|
||||
style="bold red",
|
||||
)
|
||||
console.print("Available keys:", style="yellow")
|
||||
for field_name in Settings.model_fields.keys():
|
||||
if field_name not in readonly_settings:
|
||||
console.print(f" - {field_name}", style="yellow")
|
||||
raise SystemExit(1)
|
||||
|
||||
setattr(self.settings, key, value)
|
||||
self.settings.dump()
|
||||
|
||||
console.print(f"Successfully set '{key}' to '{value}'", style="bold green")
|
||||
|
||||
def reset_all_settings(self) -> None:
|
||||
"""Reset all CLI configuration parameters to default values."""
|
||||
self.settings.reset()
|
||||
console.print(
|
||||
"Successfully reset all configuration parameters to default values. It is recommended to run [bold yellow]'crewai login'[/bold yellow] to re-authenticate.",
|
||||
style="bold green",
|
||||
)
|
||||
@@ -5,7 +5,7 @@ description = "{{name}} using crewAI"
|
||||
authors = [{ name = "Your Name", email = "you@example.com" }]
|
||||
requires-python = ">=3.10,<3.14"
|
||||
dependencies = [
|
||||
"crewai[tools]>=0.150.0,<1.0.0"
|
||||
"crewai[tools]>=0.152.0,<1.0.0"
|
||||
]
|
||||
|
||||
[project.scripts]
|
||||
|
||||
@@ -5,7 +5,7 @@ description = "{{name}} using crewAI"
|
||||
authors = [{ name = "Your Name", email = "you@example.com" }]
|
||||
requires-python = ">=3.10,<3.14"
|
||||
dependencies = [
|
||||
"crewai[tools]>=0.150.0,<1.0.0",
|
||||
"crewai[tools]>=0.152.0,<1.0.0",
|
||||
]
|
||||
|
||||
[project.scripts]
|
||||
|
||||
@@ -5,7 +5,7 @@ description = "Power up your crews with {{folder_name}}"
|
||||
readme = "README.md"
|
||||
requires-python = ">=3.10,<3.14"
|
||||
dependencies = [
|
||||
"crewai[tools]>=0.150.0"
|
||||
"crewai[tools]>=0.152.0"
|
||||
]
|
||||
|
||||
[tool.crewai]
|
||||
|
||||
@@ -308,7 +308,6 @@ class LLM(BaseLLM):
|
||||
api_version: Optional[str] = None,
|
||||
api_key: Optional[str] = None,
|
||||
callbacks: List[Any] = [],
|
||||
reasoning: Optional[bool] = None,
|
||||
reasoning_effort: Optional[Literal["none", "low", "medium", "high"]] = None,
|
||||
stream: bool = False,
|
||||
**kwargs,
|
||||
@@ -333,7 +332,6 @@ class LLM(BaseLLM):
|
||||
self.api_key = api_key
|
||||
self.callbacks = callbacks
|
||||
self.context_window_size = 0
|
||||
self.reasoning = reasoning
|
||||
self.reasoning_effort = reasoning_effort
|
||||
self.additional_params = kwargs
|
||||
self.is_anthropic = self._is_anthropic_model(model)
|
||||
@@ -408,15 +406,10 @@ class LLM(BaseLLM):
|
||||
"api_key": self.api_key,
|
||||
"stream": self.stream,
|
||||
"tools": tools,
|
||||
"reasoning_effort": self.reasoning_effort,
|
||||
**self.additional_params,
|
||||
}
|
||||
|
||||
if self.reasoning is False:
|
||||
# When reasoning is explicitly disabled, don't include reasoning_effort
|
||||
pass
|
||||
elif self.reasoning is True or self.reasoning_effort is not None:
|
||||
params["reasoning_effort"] = self.reasoning_effort
|
||||
|
||||
# Remove None values from params
|
||||
return {k: v for k, v in params.items() if v is not None}
|
||||
|
||||
|
||||
@@ -1,7 +1,8 @@
|
||||
import os
|
||||
from typing import Any, Dict, List
|
||||
|
||||
from collections import defaultdict
|
||||
from mem0 import Memory, MemoryClient
|
||||
from crewai.utilities.chromadb import sanitize_collection_name
|
||||
|
||||
from crewai.memory.storage.interface import Storage
|
||||
|
||||
@@ -70,26 +71,32 @@ class Mem0Storage(Storage):
|
||||
"""
|
||||
Returns:
|
||||
dict: A filter dictionary containing AND conditions for querying data.
|
||||
- Includes user_id if memory_type is 'external'.
|
||||
- Includes user_id and agent_id if both are present.
|
||||
- Includes user_id if only user_id is present.
|
||||
- Includes agent_id if only agent_id is present.
|
||||
- Includes run_id if memory_type is 'short_term' and mem0_run_id is present.
|
||||
"""
|
||||
filter = {
|
||||
"AND": []
|
||||
}
|
||||
filter = defaultdict(list)
|
||||
|
||||
# Add user_id condition if the memory type is external
|
||||
if self.memory_type == "external":
|
||||
filter["AND"].append({"user_id": self.config.get("user_id", "")})
|
||||
|
||||
# Add run_id condition if the memory type is short_term and a run ID is set
|
||||
if self.memory_type == "short_term" and self.mem0_run_id:
|
||||
filter["AND"].append({"run_id": self.mem0_run_id})
|
||||
else:
|
||||
user_id = self.config.get("user_id", "")
|
||||
agent_id = self.config.get("agent_id", "")
|
||||
|
||||
if user_id and agent_id:
|
||||
filter["OR"].append({"user_id": user_id})
|
||||
filter["OR"].append({"agent_id": agent_id})
|
||||
elif user_id:
|
||||
filter["AND"].append({"user_id": user_id})
|
||||
elif agent_id:
|
||||
filter["AND"].append({"agent_id": agent_id})
|
||||
|
||||
return filter
|
||||
|
||||
def save(self, value: Any, metadata: Dict[str, Any]) -> None:
|
||||
user_id = self.config.get("user_id", "")
|
||||
assistant_message = [{"role" : "assistant","content" : value}]
|
||||
assistant_message = [{"role" : "assistant","content" : value}]
|
||||
|
||||
base_metadata = {
|
||||
"short_term": "short_term",
|
||||
@@ -104,31 +111,32 @@ class Mem0Storage(Storage):
|
||||
"infer": self.infer
|
||||
}
|
||||
|
||||
if self.memory_type == "external":
|
||||
# MemoryClient-specific overrides
|
||||
if isinstance(self.memory, MemoryClient):
|
||||
params["includes"] = self.includes
|
||||
params["excludes"] = self.excludes
|
||||
params["output_format"] = "v1.1"
|
||||
params["version"] = "v2"
|
||||
|
||||
if self.memory_type == "short_term" and self.mem0_run_id:
|
||||
params["run_id"] = self.mem0_run_id
|
||||
|
||||
if user_id:
|
||||
params["user_id"] = user_id
|
||||
|
||||
|
||||
if params:
|
||||
# MemoryClient-specific overrides
|
||||
if isinstance(self.memory, MemoryClient):
|
||||
params["includes"] = self.includes
|
||||
params["excludes"] = self.excludes
|
||||
params["output_format"] = "v1.1"
|
||||
params["version"]="v2"
|
||||
if agent_id := self.config.get("agent_id", self._get_agent_name()):
|
||||
params["agent_id"] = agent_id
|
||||
|
||||
if self.memory_type == "short_term":
|
||||
params["run_id"] = self.mem0_run_id
|
||||
|
||||
self.memory.add(assistant_message, **params)
|
||||
self.memory.add(assistant_message, **params)
|
||||
|
||||
def search(self,query: str,limit: int = 3,score_threshold: float = 0.35) -> List[Any]:
|
||||
params = {
|
||||
"query": query,
|
||||
"limit": limit,
|
||||
"query": query,
|
||||
"limit": limit,
|
||||
"version": "v2",
|
||||
"output_format": "v1.1"
|
||||
}
|
||||
|
||||
|
||||
if user_id := self.config.get("user_id", ""):
|
||||
params["user_id"] = user_id
|
||||
|
||||
@@ -138,7 +146,7 @@ class Mem0Storage(Storage):
|
||||
"entities": {"type": "entity"},
|
||||
"external": {"type": "external"},
|
||||
}
|
||||
|
||||
|
||||
if self.memory_type in memory_type_map:
|
||||
params["metadata"] = memory_type_map[self.memory_type]
|
||||
if self.memory_type == "short_term":
|
||||
@@ -151,11 +159,28 @@ class Mem0Storage(Storage):
|
||||
params['threshold'] = score_threshold
|
||||
|
||||
if isinstance(self.memory, Memory):
|
||||
del params["metadata"], params["version"], params["run_id"], params['output_format']
|
||||
del params["metadata"], params["version"], params['output_format']
|
||||
if params.get("run_id"):
|
||||
del params["run_id"]
|
||||
|
||||
results = self.memory.search(**params)
|
||||
return [r for r in results["results"]]
|
||||
|
||||
|
||||
def reset(self):
|
||||
if self.memory:
|
||||
self.memory.reset()
|
||||
|
||||
def _sanitize_role(self, role: str) -> str:
|
||||
"""
|
||||
Sanitizes agent roles to ensure valid directory names.
|
||||
"""
|
||||
return role.replace("\n", "").replace(" ", "_").replace("/", "_")
|
||||
|
||||
def _get_agent_name(self) -> str:
|
||||
if not self.crew:
|
||||
return ""
|
||||
|
||||
agents = self.crew.agents
|
||||
agents = [self._sanitize_role(agent.role) for agent in agents]
|
||||
agents = "_".join(agents)
|
||||
return sanitize_collection_name(name=agents, max_collection_length=MAX_AGENT_ID_LENGTH_MEM0)
|
||||
|
||||
4
src/crewai/server/__init__.py
Normal file
4
src/crewai/server/__init__.py
Normal file
@@ -0,0 +1,4 @@
|
||||
from .human_input_server import HumanInputServer
|
||||
from .event_stream_manager import EventStreamManager
|
||||
|
||||
__all__ = ["HumanInputServer", "EventStreamManager"]
|
||||
147
src/crewai/server/event_stream_manager.py
Normal file
147
src/crewai/server/event_stream_manager.py
Normal file
@@ -0,0 +1,147 @@
|
||||
import asyncio
|
||||
import json
|
||||
import uuid
|
||||
from typing import Dict, List, Optional, Set
|
||||
|
||||
from crewai.utilities.events.crewai_event_bus import crewai_event_bus
|
||||
from crewai.utilities.events.task_events import HumanInputRequiredEvent, HumanInputCompletedEvent
|
||||
|
||||
|
||||
class EventStreamManager:
|
||||
"""Manages event streams for human input events"""
|
||||
|
||||
def __init__(self):
|
||||
self._websocket_connections: Dict[str, Set] = {}
|
||||
self._sse_connections: Dict[str, Set] = {}
|
||||
self._polling_events: Dict[str, List] = {}
|
||||
self._event_listeners_registered = False
|
||||
|
||||
def register_event_listeners(self):
|
||||
"""Register event listeners for human input events"""
|
||||
if self._event_listeners_registered:
|
||||
return
|
||||
|
||||
@crewai_event_bus.on(HumanInputRequiredEvent)
|
||||
def handle_human_input_required(event: HumanInputRequiredEvent):
|
||||
self._broadcast_event(event)
|
||||
|
||||
@crewai_event_bus.on(HumanInputCompletedEvent)
|
||||
def handle_human_input_completed(event: HumanInputCompletedEvent):
|
||||
self._broadcast_event(event)
|
||||
|
||||
self._event_listeners_registered = True
|
||||
|
||||
def add_websocket_connection(self, execution_id: str, websocket):
|
||||
"""Add a WebSocket connection for an execution"""
|
||||
if execution_id not in self._websocket_connections:
|
||||
self._websocket_connections[execution_id] = set()
|
||||
self._websocket_connections[execution_id].add(websocket)
|
||||
|
||||
def remove_websocket_connection(self, execution_id: str, websocket):
|
||||
"""Remove a WebSocket connection"""
|
||||
if execution_id in self._websocket_connections:
|
||||
self._websocket_connections[execution_id].discard(websocket)
|
||||
if not self._websocket_connections[execution_id]:
|
||||
del self._websocket_connections[execution_id]
|
||||
|
||||
def add_sse_connection(self, execution_id: str, queue):
|
||||
"""Add an SSE connection for an execution"""
|
||||
if execution_id not in self._sse_connections:
|
||||
self._sse_connections[execution_id] = set()
|
||||
self._sse_connections[execution_id].add(queue)
|
||||
|
||||
def remove_sse_connection(self, execution_id: str, queue):
|
||||
"""Remove an SSE connection"""
|
||||
if execution_id in self._sse_connections:
|
||||
self._sse_connections[execution_id].discard(queue)
|
||||
if not self._sse_connections[execution_id]:
|
||||
del self._sse_connections[execution_id]
|
||||
|
||||
def get_polling_events(self, execution_id: str, last_event_id: Optional[str] = None) -> List[Dict]:
|
||||
"""Get events for polling clients"""
|
||||
if execution_id not in self._polling_events:
|
||||
return []
|
||||
|
||||
events = self._polling_events[execution_id]
|
||||
|
||||
if last_event_id:
|
||||
try:
|
||||
last_index = next(
|
||||
i for i, event in enumerate(events)
|
||||
if event.get("event_id") == last_event_id
|
||||
)
|
||||
return events[last_index + 1:]
|
||||
except StopIteration:
|
||||
pass
|
||||
|
||||
return events
|
||||
|
||||
def _broadcast_event(self, event):
|
||||
"""Broadcast event to all relevant connections"""
|
||||
execution_id = getattr(event, 'execution_id', None)
|
||||
if not execution_id:
|
||||
return
|
||||
|
||||
event_data = self._serialize_event(event)
|
||||
|
||||
self._broadcast_websocket(execution_id, event_data)
|
||||
self._broadcast_sse(execution_id, event_data)
|
||||
self._store_polling_event(execution_id, event_data)
|
||||
|
||||
def _serialize_event(self, event) -> Dict:
|
||||
"""Serialize event to dictionary format"""
|
||||
event_dict = event.to_json()
|
||||
|
||||
if not event_dict.get("event_id"):
|
||||
event_dict["event_id"] = str(uuid.uuid4())
|
||||
|
||||
return event_dict
|
||||
|
||||
def _broadcast_websocket(self, execution_id: str, event_data: Dict):
|
||||
"""Broadcast event to WebSocket connections"""
|
||||
if execution_id not in self._websocket_connections:
|
||||
return
|
||||
|
||||
connections_to_remove = set()
|
||||
for websocket in self._websocket_connections[execution_id]:
|
||||
try:
|
||||
asyncio.create_task(websocket.send_text(json.dumps(event_data)))
|
||||
except Exception:
|
||||
connections_to_remove.add(websocket)
|
||||
|
||||
for websocket in connections_to_remove:
|
||||
self.remove_websocket_connection(execution_id, websocket)
|
||||
|
||||
def _broadcast_sse(self, execution_id: str, event_data: Dict):
|
||||
"""Broadcast event to SSE connections"""
|
||||
if execution_id not in self._sse_connections:
|
||||
return
|
||||
|
||||
connections_to_remove = set()
|
||||
for queue in self._sse_connections[execution_id]:
|
||||
try:
|
||||
queue.put_nowait(event_data)
|
||||
except Exception:
|
||||
connections_to_remove.add(queue)
|
||||
|
||||
for queue in connections_to_remove:
|
||||
self.remove_sse_connection(execution_id, queue)
|
||||
|
||||
def _store_polling_event(self, execution_id: str, event_data: Dict):
|
||||
"""Store event for polling clients"""
|
||||
if execution_id not in self._polling_events:
|
||||
self._polling_events[execution_id] = []
|
||||
|
||||
self._polling_events[execution_id].append(event_data)
|
||||
|
||||
if len(self._polling_events[execution_id]) > 100:
|
||||
self._polling_events[execution_id] = self._polling_events[execution_id][-100:]
|
||||
|
||||
def cleanup_execution(self, execution_id: str):
|
||||
"""Clean up all connections and events for an execution"""
|
||||
self._websocket_connections.pop(execution_id, None)
|
||||
self._sse_connections.pop(execution_id, None)
|
||||
self._polling_events.pop(execution_id, None)
|
||||
|
||||
|
||||
event_stream_manager = EventStreamManager()
|
||||
122
src/crewai/server/human_input_server.py
Normal file
122
src/crewai/server/human_input_server.py
Normal file
@@ -0,0 +1,122 @@
|
||||
import asyncio
|
||||
import json
|
||||
from typing import Optional
|
||||
from datetime import datetime, timezone
|
||||
|
||||
try:
|
||||
from fastapi import FastAPI, WebSocket, WebSocketDisconnect, HTTPException, Depends, Query
|
||||
from fastapi.responses import StreamingResponse
|
||||
from fastapi.security import HTTPBearer, HTTPAuthorizationCredentials
|
||||
import uvicorn
|
||||
FASTAPI_AVAILABLE = True
|
||||
except ImportError:
|
||||
FASTAPI_AVAILABLE = False
|
||||
|
||||
from .event_stream_manager import event_stream_manager
|
||||
|
||||
|
||||
class HumanInputServer:
|
||||
"""HTTP server for human input event streaming"""
|
||||
|
||||
def __init__(self, host: str = "localhost", port: int = 8000, api_key: Optional[str] = None):
|
||||
if not FASTAPI_AVAILABLE:
|
||||
raise ImportError(
|
||||
"FastAPI dependencies not available. Install with: pip install fastapi uvicorn websockets"
|
||||
)
|
||||
|
||||
self.host = host
|
||||
self.port = port
|
||||
self.api_key = api_key
|
||||
self.app = FastAPI(title="CrewAI Human Input Event Stream API")
|
||||
self.security = HTTPBearer() if api_key else None
|
||||
self._setup_routes()
|
||||
event_stream_manager.register_event_listeners()
|
||||
|
||||
def _setup_routes(self):
|
||||
"""Setup FastAPI routes"""
|
||||
|
||||
@self.app.websocket("/ws/human-input/{execution_id}")
|
||||
async def websocket_endpoint(websocket: WebSocket, execution_id: str):
|
||||
if self.api_key:
|
||||
token = websocket.query_params.get("token")
|
||||
if not token or token != self.api_key:
|
||||
await websocket.close(code=4001, reason="Unauthorized")
|
||||
return
|
||||
|
||||
await websocket.accept()
|
||||
event_stream_manager.add_websocket_connection(execution_id, websocket)
|
||||
|
||||
try:
|
||||
while True:
|
||||
await websocket.receive_text()
|
||||
except WebSocketDisconnect:
|
||||
pass
|
||||
finally:
|
||||
event_stream_manager.remove_websocket_connection(execution_id, websocket)
|
||||
|
||||
@self.app.get("/events/human-input/{execution_id}")
|
||||
async def sse_endpoint(execution_id: str, credentials: Optional[HTTPAuthorizationCredentials] = Depends(self.security) if self.security else None):
|
||||
if self.api_key and (not credentials or credentials.credentials != self.api_key):
|
||||
raise HTTPException(status_code=401, detail="Unauthorized")
|
||||
|
||||
async def event_stream():
|
||||
event_queue = asyncio.Queue()
|
||||
event_stream_manager.add_sse_connection(execution_id, event_queue)
|
||||
|
||||
try:
|
||||
while True:
|
||||
try:
|
||||
event_data = await asyncio.wait_for(event_queue.get(), timeout=30.0)
|
||||
yield f"data: {json.dumps(event_data)}\n\n"
|
||||
except asyncio.TimeoutError:
|
||||
yield "data: {\"type\": \"heartbeat\"}\n\n"
|
||||
except asyncio.CancelledError:
|
||||
pass
|
||||
finally:
|
||||
event_stream_manager.remove_sse_connection(execution_id, event_queue)
|
||||
|
||||
return StreamingResponse(
|
||||
event_stream(),
|
||||
media_type="text/event-stream",
|
||||
headers={
|
||||
"Cache-Control": "no-cache",
|
||||
"Connection": "keep-alive",
|
||||
"Access-Control-Allow-Origin": "*",
|
||||
}
|
||||
)
|
||||
|
||||
@self.app.get("/poll/human-input/{execution_id}")
|
||||
async def polling_endpoint(
|
||||
execution_id: str,
|
||||
last_event_id: Optional[str] = Query(None),
|
||||
credentials: Optional[HTTPAuthorizationCredentials] = Depends(self.security) if self.security else None
|
||||
):
|
||||
if self.api_key and (not credentials or credentials.credentials != self.api_key):
|
||||
raise HTTPException(status_code=401, detail="Unauthorized")
|
||||
|
||||
events = event_stream_manager.get_polling_events(execution_id, last_event_id)
|
||||
return {"events": events}
|
||||
|
||||
@self.app.get("/health")
|
||||
async def health_check():
|
||||
return {"status": "healthy", "timestamp": datetime.now(timezone.utc).isoformat()}
|
||||
|
||||
async def start_async(self):
|
||||
"""Start the server asynchronously"""
|
||||
config = uvicorn.Config(
|
||||
self.app,
|
||||
host=self.host,
|
||||
port=self.port,
|
||||
log_level="info"
|
||||
)
|
||||
server = uvicorn.Server(config)
|
||||
await server.serve()
|
||||
|
||||
def start(self):
|
||||
"""Start the server synchronously"""
|
||||
uvicorn.run(
|
||||
self.app,
|
||||
host=self.host,
|
||||
port=self.port,
|
||||
log_level="info"
|
||||
)
|
||||
@@ -35,6 +35,8 @@ from .llm_guardrail_events import (
|
||||
LLMGuardrailStartedEvent,
|
||||
)
|
||||
from .task_events import (
|
||||
HumanInputCompletedEvent,
|
||||
HumanInputRequiredEvent,
|
||||
TaskCompletedEvent,
|
||||
TaskFailedEvent,
|
||||
TaskStartedEvent,
|
||||
@@ -85,6 +87,8 @@ EventTypes = Union[
|
||||
TaskStartedEvent,
|
||||
TaskCompletedEvent,
|
||||
TaskFailedEvent,
|
||||
HumanInputRequiredEvent,
|
||||
HumanInputCompletedEvent,
|
||||
FlowStartedEvent,
|
||||
FlowFinishedEvent,
|
||||
MethodExecutionStartedEvent,
|
||||
|
||||
@@ -82,3 +82,35 @@ class TaskEvaluationEvent(BaseEvent):
|
||||
and self.task.fingerprint.metadata
|
||||
):
|
||||
self.fingerprint_metadata = self.task.fingerprint.metadata
|
||||
|
||||
|
||||
class HumanInputRequiredEvent(BaseEvent):
|
||||
"""Event emitted when human input is required during task execution"""
|
||||
|
||||
type: str = "human_input_required"
|
||||
execution_id: Optional[str] = None
|
||||
crew_id: Optional[str] = None
|
||||
task_id: Optional[str] = None
|
||||
agent_id: Optional[str] = None
|
||||
prompt: Optional[str] = None
|
||||
context: Optional[str] = None
|
||||
reason_flags: Optional[dict] = None
|
||||
event_id: Optional[str] = None
|
||||
|
||||
def __init__(self, **data):
|
||||
super().__init__(**data)
|
||||
|
||||
|
||||
class HumanInputCompletedEvent(BaseEvent):
|
||||
"""Event emitted when human input is completed"""
|
||||
|
||||
type: str = "human_input_completed"
|
||||
execution_id: Optional[str] = None
|
||||
crew_id: Optional[str] = None
|
||||
task_id: Optional[str] = None
|
||||
agent_id: Optional[str] = None
|
||||
event_id: Optional[str] = None
|
||||
human_feedback: Optional[str] = None
|
||||
|
||||
def __init__(self, **data):
|
||||
super().__init__(**data)
|
||||
|
||||
@@ -2,7 +2,6 @@
|
||||
|
||||
import json
|
||||
import pytest
|
||||
from unittest.mock import patch
|
||||
|
||||
from crewai import Agent, Task
|
||||
from crewai.llm import LLM
|
||||
@@ -260,31 +259,3 @@ def test_agent_with_function_calling_fallback():
|
||||
assert result == "4"
|
||||
assert "Reasoning Plan:" in task.description
|
||||
assert "Invalid JSON that will trigger fallback" in task.description
|
||||
|
||||
|
||||
def test_agent_with_llm_reasoning_disabled():
|
||||
"""Test agent with LLM reasoning disabled."""
|
||||
llm = LLM("gpt-3.5-turbo", reasoning=False)
|
||||
|
||||
agent = Agent(
|
||||
role="Test Agent",
|
||||
goal="To test the LLM reasoning parameter",
|
||||
backstory="I am a test agent created to verify the LLM reasoning parameter works correctly.",
|
||||
llm=llm,
|
||||
reasoning=False,
|
||||
verbose=True
|
||||
)
|
||||
|
||||
task = Task(
|
||||
description="Simple math task: What's 3+3?",
|
||||
expected_output="The answer should be a number.",
|
||||
agent=agent
|
||||
)
|
||||
|
||||
with patch.object(agent.llm, 'call') as mock_call:
|
||||
mock_call.return_value = "6"
|
||||
|
||||
result = agent.execute_task(task)
|
||||
|
||||
assert result == "6"
|
||||
assert "Reasoning Plan:" not in task.description
|
||||
|
||||
@@ -4,7 +4,12 @@ import tempfile
|
||||
import unittest
|
||||
from pathlib import Path
|
||||
|
||||
from crewai.cli.config import Settings
|
||||
from crewai.cli.config import (
|
||||
Settings,
|
||||
USER_SETTINGS_KEYS,
|
||||
CLI_SETTINGS_KEYS,
|
||||
DEFAULT_CLI_SETTINGS,
|
||||
)
|
||||
|
||||
|
||||
class TestSettings(unittest.TestCase):
|
||||
@@ -52,6 +57,30 @@ class TestSettings(unittest.TestCase):
|
||||
self.assertEqual(settings.tool_repository_username, "new_user")
|
||||
self.assertEqual(settings.tool_repository_password, "file_pass")
|
||||
|
||||
def test_clear_user_settings(self):
|
||||
user_settings = {key: f"value_for_{key}" for key in USER_SETTINGS_KEYS}
|
||||
|
||||
settings = Settings(config_path=self.config_path, **user_settings)
|
||||
settings.clear_user_settings()
|
||||
|
||||
for key in user_settings.keys():
|
||||
self.assertEqual(getattr(settings, key), None)
|
||||
|
||||
def test_reset_settings(self):
|
||||
user_settings = {key: f"value_for_{key}" for key in USER_SETTINGS_KEYS}
|
||||
cli_settings = {key: f"value_for_{key}" for key in CLI_SETTINGS_KEYS}
|
||||
|
||||
settings = Settings(
|
||||
config_path=self.config_path, **user_settings, **cli_settings
|
||||
)
|
||||
|
||||
settings.reset()
|
||||
|
||||
for key in user_settings.keys():
|
||||
self.assertEqual(getattr(settings, key), None)
|
||||
for key in cli_settings.keys():
|
||||
self.assertEqual(getattr(settings, key), DEFAULT_CLI_SETTINGS[key])
|
||||
|
||||
def test_dump_new_settings(self):
|
||||
settings = Settings(
|
||||
config_path=self.config_path, tool_repository_username="user1"
|
||||
|
||||
@@ -6,7 +6,7 @@ from click.testing import CliRunner
|
||||
import requests
|
||||
|
||||
from crewai.cli.organization.main import OrganizationCommand
|
||||
from crewai.cli.cli import list, switch, current
|
||||
from crewai.cli.cli import org_list, switch, current
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
@@ -16,44 +16,44 @@ def runner():
|
||||
|
||||
@pytest.fixture
|
||||
def org_command():
|
||||
with patch.object(OrganizationCommand, '__init__', return_value=None):
|
||||
with patch.object(OrganizationCommand, "__init__", return_value=None):
|
||||
command = OrganizationCommand()
|
||||
yield command
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_settings():
|
||||
with patch('crewai.cli.organization.main.Settings') as mock_settings_class:
|
||||
with patch("crewai.cli.organization.main.Settings") as mock_settings_class:
|
||||
mock_settings_instance = MagicMock()
|
||||
mock_settings_class.return_value = mock_settings_instance
|
||||
yield mock_settings_instance
|
||||
|
||||
|
||||
@patch('crewai.cli.cli.OrganizationCommand')
|
||||
@patch("crewai.cli.cli.OrganizationCommand")
|
||||
def test_org_list_command(mock_org_command_class, runner):
|
||||
mock_org_instance = MagicMock()
|
||||
mock_org_command_class.return_value = mock_org_instance
|
||||
|
||||
result = runner.invoke(list)
|
||||
result = runner.invoke(org_list)
|
||||
|
||||
assert result.exit_code == 0
|
||||
mock_org_command_class.assert_called_once()
|
||||
mock_org_instance.list.assert_called_once()
|
||||
|
||||
|
||||
@patch('crewai.cli.cli.OrganizationCommand')
|
||||
@patch("crewai.cli.cli.OrganizationCommand")
|
||||
def test_org_switch_command(mock_org_command_class, runner):
|
||||
mock_org_instance = MagicMock()
|
||||
mock_org_command_class.return_value = mock_org_instance
|
||||
|
||||
result = runner.invoke(switch, ['test-id'])
|
||||
result = runner.invoke(switch, ["test-id"])
|
||||
|
||||
assert result.exit_code == 0
|
||||
mock_org_command_class.assert_called_once()
|
||||
mock_org_instance.switch.assert_called_once_with('test-id')
|
||||
mock_org_instance.switch.assert_called_once_with("test-id")
|
||||
|
||||
|
||||
@patch('crewai.cli.cli.OrganizationCommand')
|
||||
@patch("crewai.cli.cli.OrganizationCommand")
|
||||
def test_org_current_command(mock_org_command_class, runner):
|
||||
mock_org_instance = MagicMock()
|
||||
mock_org_command_class.return_value = mock_org_instance
|
||||
@@ -67,18 +67,18 @@ def test_org_current_command(mock_org_command_class, runner):
|
||||
|
||||
class TestOrganizationCommand(unittest.TestCase):
|
||||
def setUp(self):
|
||||
with patch.object(OrganizationCommand, '__init__', return_value=None):
|
||||
with patch.object(OrganizationCommand, "__init__", return_value=None):
|
||||
self.org_command = OrganizationCommand()
|
||||
self.org_command.plus_api_client = MagicMock()
|
||||
|
||||
@patch('crewai.cli.organization.main.console')
|
||||
@patch('crewai.cli.organization.main.Table')
|
||||
@patch("crewai.cli.organization.main.console")
|
||||
@patch("crewai.cli.organization.main.Table")
|
||||
def test_list_organizations_success(self, mock_table, mock_console):
|
||||
mock_response = MagicMock()
|
||||
mock_response.raise_for_status = MagicMock()
|
||||
mock_response.json.return_value = [
|
||||
{"name": "Org 1", "uuid": "org-123"},
|
||||
{"name": "Org 2", "uuid": "org-456"}
|
||||
{"name": "Org 2", "uuid": "org-456"},
|
||||
]
|
||||
self.org_command.plus_api_client = MagicMock()
|
||||
self.org_command.plus_api_client.get_organizations.return_value = mock_response
|
||||
@@ -89,16 +89,14 @@ class TestOrganizationCommand(unittest.TestCase):
|
||||
|
||||
self.org_command.plus_api_client.get_organizations.assert_called_once()
|
||||
mock_table.assert_called_once_with(title="Your Organizations")
|
||||
mock_table.return_value.add_column.assert_has_calls([
|
||||
call("Name", style="cyan"),
|
||||
call("ID", style="green")
|
||||
])
|
||||
mock_table.return_value.add_row.assert_has_calls([
|
||||
call("Org 1", "org-123"),
|
||||
call("Org 2", "org-456")
|
||||
])
|
||||
mock_table.return_value.add_column.assert_has_calls(
|
||||
[call("Name", style="cyan"), call("ID", style="green")]
|
||||
)
|
||||
mock_table.return_value.add_row.assert_has_calls(
|
||||
[call("Org 1", "org-123"), call("Org 2", "org-456")]
|
||||
)
|
||||
|
||||
@patch('crewai.cli.organization.main.console')
|
||||
@patch("crewai.cli.organization.main.console")
|
||||
def test_list_organizations_empty(self, mock_console):
|
||||
mock_response = MagicMock()
|
||||
mock_response.raise_for_status = MagicMock()
|
||||
@@ -110,33 +108,32 @@ class TestOrganizationCommand(unittest.TestCase):
|
||||
|
||||
self.org_command.plus_api_client.get_organizations.assert_called_once()
|
||||
mock_console.print.assert_called_once_with(
|
||||
"You don't belong to any organizations yet.",
|
||||
style="yellow"
|
||||
"You don't belong to any organizations yet.", style="yellow"
|
||||
)
|
||||
|
||||
@patch('crewai.cli.organization.main.console')
|
||||
@patch("crewai.cli.organization.main.console")
|
||||
def test_list_organizations_api_error(self, mock_console):
|
||||
self.org_command.plus_api_client = MagicMock()
|
||||
self.org_command.plus_api_client.get_organizations.side_effect = requests.exceptions.RequestException("API Error")
|
||||
self.org_command.plus_api_client.get_organizations.side_effect = (
|
||||
requests.exceptions.RequestException("API Error")
|
||||
)
|
||||
|
||||
with pytest.raises(SystemExit):
|
||||
self.org_command.list()
|
||||
|
||||
|
||||
self.org_command.plus_api_client.get_organizations.assert_called_once()
|
||||
mock_console.print.assert_called_once_with(
|
||||
"Failed to retrieve organization list: API Error",
|
||||
style="bold red"
|
||||
"Failed to retrieve organization list: API Error", style="bold red"
|
||||
)
|
||||
|
||||
@patch('crewai.cli.organization.main.console')
|
||||
@patch('crewai.cli.organization.main.Settings')
|
||||
@patch("crewai.cli.organization.main.console")
|
||||
@patch("crewai.cli.organization.main.Settings")
|
||||
def test_switch_organization_success(self, mock_settings_class, mock_console):
|
||||
mock_response = MagicMock()
|
||||
mock_response.raise_for_status = MagicMock()
|
||||
mock_response.json.return_value = [
|
||||
{"name": "Org 1", "uuid": "org-123"},
|
||||
{"name": "Test Org", "uuid": "test-id"}
|
||||
{"name": "Test Org", "uuid": "test-id"},
|
||||
]
|
||||
self.org_command.plus_api_client = MagicMock()
|
||||
self.org_command.plus_api_client.get_organizations.return_value = mock_response
|
||||
@@ -151,17 +148,16 @@ class TestOrganizationCommand(unittest.TestCase):
|
||||
assert mock_settings_instance.org_name == "Test Org"
|
||||
assert mock_settings_instance.org_uuid == "test-id"
|
||||
mock_console.print.assert_called_once_with(
|
||||
"Successfully switched to Test Org (test-id)",
|
||||
style="bold green"
|
||||
"Successfully switched to Test Org (test-id)", style="bold green"
|
||||
)
|
||||
|
||||
@patch('crewai.cli.organization.main.console')
|
||||
@patch("crewai.cli.organization.main.console")
|
||||
def test_switch_organization_not_found(self, mock_console):
|
||||
mock_response = MagicMock()
|
||||
mock_response.raise_for_status = MagicMock()
|
||||
mock_response.json.return_value = [
|
||||
{"name": "Org 1", "uuid": "org-123"},
|
||||
{"name": "Org 2", "uuid": "org-456"}
|
||||
{"name": "Org 2", "uuid": "org-456"},
|
||||
]
|
||||
self.org_command.plus_api_client = MagicMock()
|
||||
self.org_command.plus_api_client.get_organizations.return_value = mock_response
|
||||
@@ -170,12 +166,11 @@ class TestOrganizationCommand(unittest.TestCase):
|
||||
|
||||
self.org_command.plus_api_client.get_organizations.assert_called_once()
|
||||
mock_console.print.assert_called_once_with(
|
||||
"Organization with id 'non-existent-id' not found.",
|
||||
style="bold red"
|
||||
"Organization with id 'non-existent-id' not found.", style="bold red"
|
||||
)
|
||||
|
||||
@patch('crewai.cli.organization.main.console')
|
||||
@patch('crewai.cli.organization.main.Settings')
|
||||
@patch("crewai.cli.organization.main.console")
|
||||
@patch("crewai.cli.organization.main.Settings")
|
||||
def test_current_organization_with_org(self, mock_settings_class, mock_console):
|
||||
mock_settings_instance = MagicMock()
|
||||
mock_settings_instance.org_name = "Test Org"
|
||||
@@ -186,12 +181,11 @@ class TestOrganizationCommand(unittest.TestCase):
|
||||
|
||||
self.org_command.plus_api_client.get_organizations.assert_not_called()
|
||||
mock_console.print.assert_called_once_with(
|
||||
"Currently logged in to organization Test Org (test-id)",
|
||||
style="bold green"
|
||||
"Currently logged in to organization Test Org (test-id)", style="bold green"
|
||||
)
|
||||
|
||||
@patch('crewai.cli.organization.main.console')
|
||||
@patch('crewai.cli.organization.main.Settings')
|
||||
@patch("crewai.cli.organization.main.console")
|
||||
@patch("crewai.cli.organization.main.Settings")
|
||||
def test_current_organization_without_org(self, mock_settings_class, mock_console):
|
||||
mock_settings_instance = MagicMock()
|
||||
mock_settings_instance.org_uuid = None
|
||||
@@ -201,16 +195,14 @@ class TestOrganizationCommand(unittest.TestCase):
|
||||
|
||||
assert mock_console.print.call_count == 3
|
||||
mock_console.print.assert_any_call(
|
||||
"You're not currently logged in to any organization.",
|
||||
style="yellow"
|
||||
"You're not currently logged in to any organization.", style="yellow"
|
||||
)
|
||||
|
||||
@patch('crewai.cli.organization.main.console')
|
||||
@patch("crewai.cli.organization.main.console")
|
||||
def test_list_organizations_unauthorized(self, mock_console):
|
||||
mock_response = MagicMock()
|
||||
mock_http_error = requests.exceptions.HTTPError(
|
||||
"401 Client Error: Unauthorized",
|
||||
response=MagicMock(status_code=401)
|
||||
"401 Client Error: Unauthorized", response=MagicMock(status_code=401)
|
||||
)
|
||||
|
||||
mock_response.raise_for_status.side_effect = mock_http_error
|
||||
@@ -221,15 +213,14 @@ class TestOrganizationCommand(unittest.TestCase):
|
||||
self.org_command.plus_api_client.get_organizations.assert_called_once()
|
||||
mock_console.print.assert_called_once_with(
|
||||
"You are not logged in to any organization. Use 'crewai login' to login.",
|
||||
style="bold red"
|
||||
style="bold red",
|
||||
)
|
||||
|
||||
@patch('crewai.cli.organization.main.console')
|
||||
@patch("crewai.cli.organization.main.console")
|
||||
def test_switch_organization_unauthorized(self, mock_console):
|
||||
mock_response = MagicMock()
|
||||
mock_http_error = requests.exceptions.HTTPError(
|
||||
"401 Client Error: Unauthorized",
|
||||
response=MagicMock(status_code=401)
|
||||
"401 Client Error: Unauthorized", response=MagicMock(status_code=401)
|
||||
)
|
||||
|
||||
mock_response.raise_for_status.side_effect = mock_http_error
|
||||
@@ -240,5 +231,5 @@ class TestOrganizationCommand(unittest.TestCase):
|
||||
self.org_command.plus_api_client.get_organizations.assert_called_once()
|
||||
mock_console.print.assert_called_once_with(
|
||||
"You are not logged in to any organization. Use 'crewai login' to login.",
|
||||
style="bold red"
|
||||
style="bold red",
|
||||
)
|
||||
|
||||
@@ -1,8 +1,8 @@
|
||||
import os
|
||||
import unittest
|
||||
from unittest.mock import MagicMock, patch, ANY
|
||||
|
||||
from crewai.cli.plus_api import PlusAPI
|
||||
from crewai.cli.constants import DEFAULT_CREWAI_ENTERPRISE_URL
|
||||
|
||||
|
||||
class TestPlusAPI(unittest.TestCase):
|
||||
@@ -30,29 +30,41 @@ class TestPlusAPI(unittest.TestCase):
|
||||
)
|
||||
self.assertEqual(response, mock_response)
|
||||
|
||||
def assert_request_with_org_id(self, mock_make_request, method: str, endpoint: str, **kwargs):
|
||||
def assert_request_with_org_id(
|
||||
self, mock_make_request, method: str, endpoint: str, **kwargs
|
||||
):
|
||||
mock_make_request.assert_called_once_with(
|
||||
method, f"https://app.crewai.com{endpoint}", headers={'Authorization': ANY, 'Content-Type': ANY, 'User-Agent': ANY, 'X-Crewai-Version': ANY, 'X-Crewai-Organization-Id': self.org_uuid}, **kwargs
|
||||
method,
|
||||
f"{DEFAULT_CREWAI_ENTERPRISE_URL}{endpoint}",
|
||||
headers={
|
||||
"Authorization": ANY,
|
||||
"Content-Type": ANY,
|
||||
"User-Agent": ANY,
|
||||
"X-Crewai-Version": ANY,
|
||||
"X-Crewai-Organization-Id": self.org_uuid,
|
||||
},
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
@patch("crewai.cli.plus_api.Settings")
|
||||
@patch("requests.Session.request")
|
||||
def test_login_to_tool_repository_with_org_uuid(self, mock_make_request, mock_settings_class):
|
||||
def test_login_to_tool_repository_with_org_uuid(
|
||||
self, mock_make_request, mock_settings_class
|
||||
):
|
||||
mock_settings = MagicMock()
|
||||
mock_settings.org_uuid = self.org_uuid
|
||||
mock_settings.enterprise_base_url = DEFAULT_CREWAI_ENTERPRISE_URL
|
||||
mock_settings_class.return_value = mock_settings
|
||||
# re-initialize Client
|
||||
self.api = PlusAPI(self.api_key)
|
||||
|
||||
|
||||
mock_response = MagicMock()
|
||||
mock_make_request.return_value = mock_response
|
||||
|
||||
response = self.api.login_to_tool_repository()
|
||||
|
||||
self.assert_request_with_org_id(
|
||||
mock_make_request,
|
||||
'POST',
|
||||
'/crewai_plus/api/v1/tools/login'
|
||||
mock_make_request, "POST", "/crewai_plus/api/v1/tools/login"
|
||||
)
|
||||
self.assertEqual(response, mock_response)
|
||||
|
||||
@@ -66,28 +78,27 @@ class TestPlusAPI(unittest.TestCase):
|
||||
"GET", "/crewai_plus/api/v1/agents/test_agent_handle"
|
||||
)
|
||||
self.assertEqual(response, mock_response)
|
||||
|
||||
|
||||
@patch("crewai.cli.plus_api.Settings")
|
||||
@patch("requests.Session.request")
|
||||
def test_get_agent_with_org_uuid(self, mock_make_request, mock_settings_class):
|
||||
mock_settings = MagicMock()
|
||||
mock_settings.org_uuid = self.org_uuid
|
||||
mock_settings.enterprise_base_url = DEFAULT_CREWAI_ENTERPRISE_URL
|
||||
mock_settings_class.return_value = mock_settings
|
||||
# re-initialize Client
|
||||
self.api = PlusAPI(self.api_key)
|
||||
|
||||
|
||||
mock_response = MagicMock()
|
||||
mock_make_request.return_value = mock_response
|
||||
|
||||
response = self.api.get_agent("test_agent_handle")
|
||||
|
||||
self.assert_request_with_org_id(
|
||||
mock_make_request,
|
||||
"GET",
|
||||
"/crewai_plus/api/v1/agents/test_agent_handle"
|
||||
mock_make_request, "GET", "/crewai_plus/api/v1/agents/test_agent_handle"
|
||||
)
|
||||
self.assertEqual(response, mock_response)
|
||||
|
||||
|
||||
@patch("crewai.cli.plus_api.PlusAPI._make_request")
|
||||
def test_get_tool(self, mock_make_request):
|
||||
mock_response = MagicMock()
|
||||
@@ -98,12 +109,13 @@ class TestPlusAPI(unittest.TestCase):
|
||||
"GET", "/crewai_plus/api/v1/tools/test_tool_handle"
|
||||
)
|
||||
self.assertEqual(response, mock_response)
|
||||
|
||||
|
||||
@patch("crewai.cli.plus_api.Settings")
|
||||
@patch("requests.Session.request")
|
||||
def test_get_tool_with_org_uuid(self, mock_make_request, mock_settings_class):
|
||||
mock_settings = MagicMock()
|
||||
mock_settings.org_uuid = self.org_uuid
|
||||
mock_settings.enterprise_base_url = DEFAULT_CREWAI_ENTERPRISE_URL
|
||||
mock_settings_class.return_value = mock_settings
|
||||
# re-initialize Client
|
||||
self.api = PlusAPI(self.api_key)
|
||||
@@ -115,9 +127,7 @@ class TestPlusAPI(unittest.TestCase):
|
||||
response = self.api.get_tool("test_tool_handle")
|
||||
|
||||
self.assert_request_with_org_id(
|
||||
mock_make_request,
|
||||
"GET",
|
||||
"/crewai_plus/api/v1/tools/test_tool_handle"
|
||||
mock_make_request, "GET", "/crewai_plus/api/v1/tools/test_tool_handle"
|
||||
)
|
||||
self.assertEqual(response, mock_response)
|
||||
|
||||
@@ -147,12 +157,13 @@ class TestPlusAPI(unittest.TestCase):
|
||||
"POST", "/crewai_plus/api/v1/tools", json=params
|
||||
)
|
||||
self.assertEqual(response, mock_response)
|
||||
|
||||
|
||||
@patch("crewai.cli.plus_api.Settings")
|
||||
@patch("requests.Session.request")
|
||||
def test_publish_tool_with_org_uuid(self, mock_make_request, mock_settings_class):
|
||||
mock_settings = MagicMock()
|
||||
mock_settings.org_uuid = self.org_uuid
|
||||
mock_settings.enterprise_base_url = DEFAULT_CREWAI_ENTERPRISE_URL
|
||||
mock_settings_class.return_value = mock_settings
|
||||
# re-initialize Client
|
||||
self.api = PlusAPI(self.api_key)
|
||||
@@ -160,7 +171,7 @@ class TestPlusAPI(unittest.TestCase):
|
||||
# Set up mock response
|
||||
mock_response = MagicMock()
|
||||
mock_make_request.return_value = mock_response
|
||||
|
||||
|
||||
handle = "test_tool_handle"
|
||||
public = True
|
||||
version = "1.0.0"
|
||||
@@ -180,12 +191,9 @@ class TestPlusAPI(unittest.TestCase):
|
||||
"description": description,
|
||||
"available_exports": None,
|
||||
}
|
||||
|
||||
|
||||
self.assert_request_with_org_id(
|
||||
mock_make_request,
|
||||
"POST",
|
||||
"/crewai_plus/api/v1/tools",
|
||||
json=expected_params
|
||||
mock_make_request, "POST", "/crewai_plus/api/v1/tools", json=expected_params
|
||||
)
|
||||
self.assertEqual(response, mock_response)
|
||||
|
||||
@@ -311,8 +319,11 @@ class TestPlusAPI(unittest.TestCase):
|
||||
"POST", "/crewai_plus/api/v1/crews", json=payload
|
||||
)
|
||||
|
||||
@patch.dict(os.environ, {"CREWAI_BASE_URL": "https://custom-url.com/api"})
|
||||
def test_custom_base_url(self):
|
||||
@patch("crewai.cli.plus_api.Settings")
|
||||
def test_custom_base_url(self, mock_settings_class):
|
||||
mock_settings = MagicMock()
|
||||
mock_settings.enterprise_base_url = "https://custom-url.com/api"
|
||||
mock_settings_class.return_value = mock_settings
|
||||
custom_api = PlusAPI("test_key")
|
||||
self.assertEqual(
|
||||
custom_api.base_url,
|
||||
|
||||
91
tests/cli/test_settings_command.py
Normal file
91
tests/cli/test_settings_command.py
Normal file
@@ -0,0 +1,91 @@
|
||||
import tempfile
|
||||
import unittest
|
||||
from pathlib import Path
|
||||
from unittest.mock import patch, MagicMock, call
|
||||
|
||||
from crewai.cli.settings.main import SettingsCommand
|
||||
from crewai.cli.config import (
|
||||
Settings,
|
||||
USER_SETTINGS_KEYS,
|
||||
CLI_SETTINGS_KEYS,
|
||||
DEFAULT_CLI_SETTINGS,
|
||||
HIDDEN_SETTINGS_KEYS,
|
||||
READONLY_SETTINGS_KEYS,
|
||||
)
|
||||
import shutil
|
||||
|
||||
|
||||
class TestSettingsCommand(unittest.TestCase):
|
||||
def setUp(self):
|
||||
self.test_dir = Path(tempfile.mkdtemp())
|
||||
self.config_path = self.test_dir / "settings.json"
|
||||
self.settings = Settings(config_path=self.config_path)
|
||||
self.settings_command = SettingsCommand(
|
||||
settings_kwargs={"config_path": self.config_path}
|
||||
)
|
||||
|
||||
def tearDown(self):
|
||||
shutil.rmtree(self.test_dir)
|
||||
|
||||
@patch("crewai.cli.settings.main.console")
|
||||
@patch("crewai.cli.settings.main.Table")
|
||||
def test_list_settings(self, mock_table_class, mock_console):
|
||||
mock_table_instance = MagicMock()
|
||||
mock_table_class.return_value = mock_table_instance
|
||||
|
||||
self.settings_command.list()
|
||||
|
||||
# Tests that the table is created skipping hidden settings
|
||||
mock_table_instance.add_row.assert_has_calls(
|
||||
[
|
||||
call(
|
||||
field_name,
|
||||
getattr(self.settings, field_name) or "Not set",
|
||||
field_info.description,
|
||||
)
|
||||
for field_name, field_info in Settings.model_fields.items()
|
||||
if field_name not in HIDDEN_SETTINGS_KEYS
|
||||
]
|
||||
)
|
||||
|
||||
# Tests that the table is printed
|
||||
mock_console.print.assert_called_once_with(mock_table_instance)
|
||||
|
||||
def test_set_valid_keys(self):
|
||||
valid_keys = Settings.model_fields.keys() - (
|
||||
READONLY_SETTINGS_KEYS + HIDDEN_SETTINGS_KEYS
|
||||
)
|
||||
for key in valid_keys:
|
||||
test_value = f"some_value_for_{key}"
|
||||
self.settings_command.set(key, test_value)
|
||||
self.assertEqual(getattr(self.settings_command.settings, key), test_value)
|
||||
|
||||
def test_set_invalid_key(self):
|
||||
with self.assertRaises(SystemExit):
|
||||
self.settings_command.set("invalid_key", "value")
|
||||
|
||||
def test_set_readonly_keys(self):
|
||||
for key in READONLY_SETTINGS_KEYS:
|
||||
with self.assertRaises(SystemExit):
|
||||
self.settings_command.set(key, "some_readonly_key_value")
|
||||
|
||||
def test_set_hidden_keys(self):
|
||||
for key in HIDDEN_SETTINGS_KEYS:
|
||||
with self.assertRaises(SystemExit):
|
||||
self.settings_command.set(key, "some_hidden_key_value")
|
||||
|
||||
def test_reset_all_settings(self):
|
||||
for key in USER_SETTINGS_KEYS + CLI_SETTINGS_KEYS:
|
||||
setattr(self.settings_command.settings, key, f"custom_value_for_{key}")
|
||||
self.settings_command.settings.dump()
|
||||
|
||||
self.settings_command.reset_all_settings()
|
||||
|
||||
print(USER_SETTINGS_KEYS)
|
||||
for key in USER_SETTINGS_KEYS:
|
||||
self.assertEqual(getattr(self.settings_command.settings, key), None)
|
||||
|
||||
for key in CLI_SETTINGS_KEYS:
|
||||
self.assertEqual(
|
||||
getattr(self.settings_command.settings, key), DEFAULT_CLI_SETTINGS[key]
|
||||
)
|
||||
@@ -711,99 +711,3 @@ def test_ollama_does_not_modify_when_last_is_user(ollama_llm):
|
||||
formatted = ollama_llm._format_messages_for_provider(original_messages)
|
||||
|
||||
assert formatted == original_messages
|
||||
|
||||
|
||||
def test_llm_reasoning_parameter_false():
|
||||
"""Test that reasoning=False disables reasoning mode."""
|
||||
llm = LLM(model="ollama/qwen", reasoning=False)
|
||||
|
||||
with patch("litellm.completion") as mock_completion:
|
||||
mock_message = MagicMock()
|
||||
mock_message.content = "Test response"
|
||||
mock_choice = MagicMock()
|
||||
mock_choice.message = mock_message
|
||||
mock_response = MagicMock()
|
||||
mock_response.choices = [mock_choice]
|
||||
mock_response.usage = {"prompt_tokens": 5, "completion_tokens": 5, "total_tokens": 10}
|
||||
mock_completion.return_value = mock_response
|
||||
|
||||
llm.call("Test message")
|
||||
|
||||
_, kwargs = mock_completion.call_args
|
||||
assert "reasoning_effort" not in kwargs
|
||||
|
||||
|
||||
def test_llm_reasoning_parameter_true():
|
||||
"""Test that reasoning=True enables reasoning mode."""
|
||||
llm = LLM(model="ollama/qwen", reasoning=True, reasoning_effort="medium")
|
||||
|
||||
with patch("litellm.completion") as mock_completion:
|
||||
mock_message = MagicMock()
|
||||
mock_message.content = "Test response"
|
||||
mock_choice = MagicMock()
|
||||
mock_choice.message = mock_message
|
||||
mock_response = MagicMock()
|
||||
mock_response.choices = [mock_choice]
|
||||
mock_response.usage = {"prompt_tokens": 5, "completion_tokens": 5, "total_tokens": 10}
|
||||
mock_completion.return_value = mock_response
|
||||
|
||||
llm.call("Test message")
|
||||
|
||||
_, kwargs = mock_completion.call_args
|
||||
assert kwargs["reasoning_effort"] == "medium"
|
||||
|
||||
|
||||
def test_llm_reasoning_parameter_none_with_reasoning_effort():
|
||||
"""Test that reasoning=None with reasoning_effort still includes reasoning_effort."""
|
||||
llm = LLM(model="ollama/qwen", reasoning=None, reasoning_effort="high")
|
||||
|
||||
with patch("litellm.completion") as mock_completion:
|
||||
mock_message = MagicMock()
|
||||
mock_message.content = "Test response"
|
||||
mock_choice = MagicMock()
|
||||
mock_choice.message = mock_message
|
||||
mock_response = MagicMock()
|
||||
mock_response.choices = [mock_choice]
|
||||
mock_response.usage = {"prompt_tokens": 5, "completion_tokens": 5, "total_tokens": 10}
|
||||
mock_completion.return_value = mock_response
|
||||
|
||||
llm.call("Test message")
|
||||
|
||||
_, kwargs = mock_completion.call_args
|
||||
assert kwargs["reasoning_effort"] == "high"
|
||||
|
||||
|
||||
def test_llm_reasoning_false_overrides_reasoning_effort():
|
||||
"""Test that reasoning=False overrides reasoning_effort."""
|
||||
llm = LLM(model="ollama/qwen", reasoning=False, reasoning_effort="high")
|
||||
|
||||
with patch("litellm.completion") as mock_completion:
|
||||
mock_message = MagicMock()
|
||||
mock_message.content = "Test response"
|
||||
mock_choice = MagicMock()
|
||||
mock_choice.message = mock_message
|
||||
mock_response = MagicMock()
|
||||
mock_response.choices = [mock_choice]
|
||||
mock_response.usage = {"prompt_tokens": 5, "completion_tokens": 5, "total_tokens": 10}
|
||||
mock_completion.return_value = mock_response
|
||||
|
||||
llm.call("Test message")
|
||||
|
||||
_, kwargs = mock_completion.call_args
|
||||
assert "reasoning_effort" not in kwargs
|
||||
|
||||
|
||||
@pytest.mark.vcr(filter_headers=["authorization"])
|
||||
def test_ollama_qwen_with_reasoning_disabled():
|
||||
"""Test Ollama Qwen model with reasoning disabled."""
|
||||
if not os.getenv("OLLAMA_BASE_URL"):
|
||||
pytest.skip("OLLAMA_BASE_URL not set; skipping test.")
|
||||
|
||||
llm = LLM(
|
||||
model="ollama/qwen",
|
||||
base_url=os.getenv("OLLAMA_BASE_URL", "http://localhost:11434"),
|
||||
reasoning=False
|
||||
)
|
||||
result = llm.call("What is 2+2?")
|
||||
assert isinstance(result, str)
|
||||
assert len(result.strip()) > 0
|
||||
|
||||
210
tests/server/test_event_stream_manager.py
Normal file
210
tests/server/test_event_stream_manager.py
Normal file
@@ -0,0 +1,210 @@
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
from crewai.server.event_stream_manager import EventStreamManager
|
||||
from crewai.utilities.events.task_events import HumanInputRequiredEvent
|
||||
|
||||
|
||||
class TestEventStreamManager:
|
||||
"""Test the event stream manager"""
|
||||
|
||||
def setup_method(self):
|
||||
"""Setup test environment"""
|
||||
self.manager = EventStreamManager()
|
||||
self.manager._websocket_connections.clear()
|
||||
self.manager._sse_connections.clear()
|
||||
self.manager._polling_events.clear()
|
||||
|
||||
def test_websocket_connection_management(self):
|
||||
"""Test WebSocket connection management"""
|
||||
execution_id = "test-execution"
|
||||
websocket1 = MagicMock()
|
||||
websocket2 = MagicMock()
|
||||
|
||||
self.manager.add_websocket_connection(execution_id, websocket1)
|
||||
assert execution_id in self.manager._websocket_connections
|
||||
assert websocket1 in self.manager._websocket_connections[execution_id]
|
||||
|
||||
self.manager.add_websocket_connection(execution_id, websocket2)
|
||||
assert len(self.manager._websocket_connections[execution_id]) == 2
|
||||
|
||||
self.manager.remove_websocket_connection(execution_id, websocket1)
|
||||
assert websocket1 not in self.manager._websocket_connections[execution_id]
|
||||
assert websocket2 in self.manager._websocket_connections[execution_id]
|
||||
|
||||
self.manager.remove_websocket_connection(execution_id, websocket2)
|
||||
assert execution_id not in self.manager._websocket_connections
|
||||
|
||||
def test_sse_connection_management(self):
|
||||
"""Test SSE connection management"""
|
||||
execution_id = "test-execution"
|
||||
queue1 = MagicMock()
|
||||
queue2 = MagicMock()
|
||||
|
||||
self.manager.add_sse_connection(execution_id, queue1)
|
||||
assert execution_id in self.manager._sse_connections
|
||||
assert queue1 in self.manager._sse_connections[execution_id]
|
||||
|
||||
self.manager.add_sse_connection(execution_id, queue2)
|
||||
assert len(self.manager._sse_connections[execution_id]) == 2
|
||||
|
||||
self.manager.remove_sse_connection(execution_id, queue1)
|
||||
assert queue1 not in self.manager._sse_connections[execution_id]
|
||||
assert queue2 in self.manager._sse_connections[execution_id]
|
||||
|
||||
self.manager.remove_sse_connection(execution_id, queue2)
|
||||
assert execution_id not in self.manager._sse_connections
|
||||
|
||||
def test_polling_events_storage(self):
|
||||
"""Test polling events storage and retrieval"""
|
||||
execution_id = "test-execution"
|
||||
|
||||
event1 = {"event_id": "event-1", "type": "test", "data": "test1"}
|
||||
event2 = {"event_id": "event-2", "type": "test", "data": "test2"}
|
||||
|
||||
self.manager._store_polling_event(execution_id, event1)
|
||||
self.manager._store_polling_event(execution_id, event2)
|
||||
|
||||
events = self.manager.get_polling_events(execution_id)
|
||||
assert len(events) == 2
|
||||
assert events[0] == event1
|
||||
assert events[1] == event2
|
||||
|
||||
def test_polling_events_with_last_event_id(self):
|
||||
"""Test polling events retrieval with last_event_id"""
|
||||
execution_id = "test-execution"
|
||||
|
||||
event1 = {"event_id": "event-1", "type": "test", "data": "test1"}
|
||||
event2 = {"event_id": "event-2", "type": "test", "data": "test2"}
|
||||
event3 = {"event_id": "event-3", "type": "test", "data": "test3"}
|
||||
|
||||
self.manager._store_polling_event(execution_id, event1)
|
||||
self.manager._store_polling_event(execution_id, event2)
|
||||
self.manager._store_polling_event(execution_id, event3)
|
||||
|
||||
events = self.manager.get_polling_events(execution_id, "event-1")
|
||||
assert len(events) == 2
|
||||
assert events[0] == event2
|
||||
assert events[1] == event3
|
||||
|
||||
def test_polling_events_limit(self):
|
||||
"""Test polling events storage limit"""
|
||||
execution_id = "test-execution"
|
||||
|
||||
for i in range(105):
|
||||
event = {"event_id": f"event-{i}", "type": "test", "data": f"test{i}"}
|
||||
self.manager._store_polling_event(execution_id, event)
|
||||
|
||||
events = self.manager.get_polling_events(execution_id)
|
||||
assert len(events) == 100
|
||||
assert events[0]["event_id"] == "event-5"
|
||||
assert events[-1]["event_id"] == "event-104"
|
||||
|
||||
def test_event_serialization(self):
|
||||
"""Test event serialization"""
|
||||
event = HumanInputRequiredEvent(
|
||||
execution_id="test-execution",
|
||||
crew_id="test-crew",
|
||||
task_id="test-task",
|
||||
prompt="Test prompt"
|
||||
)
|
||||
|
||||
serialized = self.manager._serialize_event(event)
|
||||
assert isinstance(serialized, dict)
|
||||
assert serialized["type"] == "human_input_required"
|
||||
assert serialized["execution_id"] == "test-execution"
|
||||
assert "event_id" in serialized
|
||||
|
||||
def test_broadcast_websocket(self):
|
||||
"""Test WebSocket broadcasting"""
|
||||
execution_id = "test-execution"
|
||||
websocket = MagicMock()
|
||||
|
||||
self.manager.add_websocket_connection(execution_id, websocket)
|
||||
|
||||
event_data = {"type": "test", "data": "test"}
|
||||
|
||||
with patch('asyncio.create_task') as mock_create_task:
|
||||
self.manager._broadcast_websocket(execution_id, event_data)
|
||||
mock_create_task.assert_called_once()
|
||||
|
||||
def test_broadcast_sse(self):
|
||||
"""Test SSE broadcasting"""
|
||||
execution_id = "test-execution"
|
||||
queue = MagicMock()
|
||||
|
||||
self.manager.add_sse_connection(execution_id, queue)
|
||||
|
||||
event_data = {"type": "test", "data": "test"}
|
||||
self.manager._broadcast_sse(execution_id, event_data)
|
||||
|
||||
queue.put_nowait.assert_called_once_with(event_data)
|
||||
|
||||
def test_broadcast_event(self):
|
||||
"""Test complete event broadcasting"""
|
||||
execution_id = "test-execution"
|
||||
|
||||
event = HumanInputRequiredEvent(
|
||||
execution_id=execution_id,
|
||||
crew_id="test-crew",
|
||||
task_id="test-task",
|
||||
prompt="Test prompt"
|
||||
)
|
||||
|
||||
with patch.object(self.manager, '_broadcast_websocket') as mock_ws, \
|
||||
patch.object(self.manager, '_broadcast_sse') as mock_sse, \
|
||||
patch.object(self.manager, '_store_polling_event') as mock_poll:
|
||||
|
||||
self.manager._broadcast_event(event)
|
||||
|
||||
mock_ws.assert_called_once()
|
||||
mock_sse.assert_called_once()
|
||||
mock_poll.assert_called_once()
|
||||
|
||||
def test_cleanup_execution(self):
|
||||
"""Test execution cleanup"""
|
||||
execution_id = "test-execution"
|
||||
|
||||
websocket = MagicMock()
|
||||
queue = MagicMock()
|
||||
event = {"event_id": "event-1", "type": "test"}
|
||||
|
||||
self.manager.add_websocket_connection(execution_id, websocket)
|
||||
self.manager.add_sse_connection(execution_id, queue)
|
||||
self.manager._store_polling_event(execution_id, event)
|
||||
|
||||
assert execution_id in self.manager._websocket_connections
|
||||
assert execution_id in self.manager._sse_connections
|
||||
assert execution_id in self.manager._polling_events
|
||||
|
||||
self.manager.cleanup_execution(execution_id)
|
||||
|
||||
assert execution_id not in self.manager._websocket_connections
|
||||
assert execution_id not in self.manager._sse_connections
|
||||
assert execution_id not in self.manager._polling_events
|
||||
|
||||
def test_register_event_listeners(self):
|
||||
"""Test event listener registration"""
|
||||
with patch('crewai.utilities.events.crewai_event_bus.crewai_event_bus.on') as mock_on:
|
||||
self.manager.register_event_listeners()
|
||||
assert mock_on.call_count == 2
|
||||
|
||||
self.manager.register_event_listeners()
|
||||
assert mock_on.call_count == 2
|
||||
|
||||
def test_broadcast_event_without_execution_id(self):
|
||||
"""Test broadcasting event without execution_id"""
|
||||
event = HumanInputRequiredEvent(
|
||||
crew_id="test-crew",
|
||||
task_id="test-task",
|
||||
prompt="Test prompt"
|
||||
)
|
||||
|
||||
with patch.object(self.manager, '_broadcast_websocket') as mock_ws, \
|
||||
patch.object(self.manager, '_broadcast_sse') as mock_sse, \
|
||||
patch.object(self.manager, '_store_polling_event') as mock_poll:
|
||||
|
||||
self.manager._broadcast_event(event)
|
||||
|
||||
mock_ws.assert_not_called()
|
||||
mock_sse.assert_not_called()
|
||||
mock_poll.assert_not_called()
|
||||
136
tests/server/test_human_input_server.py
Normal file
136
tests/server/test_human_input_server.py
Normal file
@@ -0,0 +1,136 @@
|
||||
import pytest
|
||||
|
||||
try:
|
||||
from fastapi.testclient import TestClient
|
||||
from crewai.server.human_input_server import HumanInputServer
|
||||
from crewai.server.event_stream_manager import event_stream_manager
|
||||
from crewai.utilities.events.task_events import HumanInputRequiredEvent
|
||||
FASTAPI_AVAILABLE = True
|
||||
except ImportError:
|
||||
FASTAPI_AVAILABLE = False
|
||||
|
||||
|
||||
@pytest.mark.skipif(not FASTAPI_AVAILABLE, reason="FastAPI dependencies not available")
|
||||
class TestHumanInputServer:
|
||||
"""Test the human input server endpoints"""
|
||||
|
||||
def setup_method(self):
|
||||
"""Setup test environment"""
|
||||
self.server = HumanInputServer(host="localhost", port=8001, api_key="test-key")
|
||||
self.client = TestClient(self.server.app)
|
||||
event_stream_manager._websocket_connections.clear()
|
||||
event_stream_manager._sse_connections.clear()
|
||||
event_stream_manager._polling_events.clear()
|
||||
|
||||
def test_health_endpoint(self):
|
||||
"""Test health check endpoint"""
|
||||
response = self.client.get("/health")
|
||||
assert response.status_code == 200
|
||||
data = response.json()
|
||||
assert data["status"] == "healthy"
|
||||
assert "timestamp" in data
|
||||
|
||||
def test_polling_endpoint_unauthorized(self):
|
||||
"""Test polling endpoint without authentication"""
|
||||
response = self.client.get("/poll/human-input/test-execution-id")
|
||||
assert response.status_code == 401
|
||||
|
||||
def test_polling_endpoint_authorized(self):
|
||||
"""Test polling endpoint with authentication"""
|
||||
headers = {"Authorization": "Bearer test-key"}
|
||||
response = self.client.get("/poll/human-input/test-execution-id", headers=headers)
|
||||
assert response.status_code == 200
|
||||
data = response.json()
|
||||
assert "events" in data
|
||||
assert isinstance(data["events"], list)
|
||||
|
||||
def test_polling_endpoint_with_events(self):
|
||||
"""Test polling endpoint returns stored events"""
|
||||
execution_id = "test-execution-id"
|
||||
|
||||
event = HumanInputRequiredEvent(
|
||||
execution_id=execution_id,
|
||||
crew_id="test-crew",
|
||||
task_id="test-task",
|
||||
agent_id="test-agent",
|
||||
prompt="Test prompt",
|
||||
context="Test context",
|
||||
event_id="test-event-1"
|
||||
)
|
||||
|
||||
event_stream_manager._store_polling_event(execution_id, event.to_json())
|
||||
|
||||
headers = {"Authorization": "Bearer test-key"}
|
||||
response = self.client.get(f"/poll/human-input/{execution_id}", headers=headers)
|
||||
assert response.status_code == 200
|
||||
data = response.json()
|
||||
assert len(data["events"]) == 1
|
||||
assert data["events"][0]["type"] == "human_input_required"
|
||||
assert data["events"][0]["execution_id"] == execution_id
|
||||
|
||||
def test_polling_endpoint_with_last_event_id(self):
|
||||
"""Test polling endpoint with last_event_id parameter"""
|
||||
execution_id = "test-execution-id"
|
||||
|
||||
event1 = HumanInputRequiredEvent(
|
||||
execution_id=execution_id,
|
||||
event_id="event-1"
|
||||
)
|
||||
event2 = HumanInputRequiredEvent(
|
||||
execution_id=execution_id,
|
||||
event_id="event-2"
|
||||
)
|
||||
|
||||
event_stream_manager._store_polling_event(execution_id, event1.to_json())
|
||||
event_stream_manager._store_polling_event(execution_id, event2.to_json())
|
||||
|
||||
headers = {"Authorization": "Bearer test-key"}
|
||||
response = self.client.get(
|
||||
f"/poll/human-input/{execution_id}?last_event_id=event-1",
|
||||
headers=headers
|
||||
)
|
||||
assert response.status_code == 200
|
||||
data = response.json()
|
||||
assert len(data["events"]) == 1
|
||||
assert data["events"][0]["event_id"] == "event-2"
|
||||
|
||||
def test_sse_endpoint_unauthorized(self):
|
||||
"""Test SSE endpoint without authentication"""
|
||||
response = self.client.get("/events/human-input/test-execution-id")
|
||||
assert response.status_code == 401
|
||||
|
||||
def test_sse_endpoint_authorized(self):
|
||||
"""Test SSE endpoint with authentication"""
|
||||
headers = {"Authorization": "Bearer test-key"}
|
||||
with self.client.stream("GET", "/events/human-input/test-execution-id", headers=headers) as response:
|
||||
assert response.status_code == 200
|
||||
assert response.headers["content-type"] == "text/event-stream; charset=utf-8"
|
||||
|
||||
def test_websocket_endpoint_unauthorized(self):
|
||||
"""Test WebSocket endpoint without authentication"""
|
||||
with pytest.raises(Exception):
|
||||
with self.client.websocket_connect("/ws/human-input/test-execution-id"):
|
||||
pass
|
||||
|
||||
def test_websocket_endpoint_authorized(self):
|
||||
"""Test WebSocket endpoint with authentication"""
|
||||
with self.client.websocket_connect("/ws/human-input/test-execution-id?token=test-key") as websocket:
|
||||
assert websocket is not None
|
||||
|
||||
def test_server_without_api_key(self):
|
||||
"""Test server initialization without API key"""
|
||||
server = HumanInputServer(host="localhost", port=8002)
|
||||
client = TestClient(server.app)
|
||||
|
||||
response = client.get("/poll/human-input/test-execution-id")
|
||||
assert response.status_code == 200
|
||||
|
||||
response = client.get("/events/human-input/test-execution-id")
|
||||
assert response.status_code == 200
|
||||
|
||||
|
||||
@pytest.mark.skipif(FASTAPI_AVAILABLE, reason="Testing import error handling")
|
||||
def test_server_without_fastapi():
|
||||
"""Test server initialization without FastAPI dependencies"""
|
||||
with pytest.raises(ImportError, match="FastAPI dependencies not available"):
|
||||
HumanInputServer()
|
||||
@@ -191,17 +191,39 @@ def test_save_method_with_memory_oss(mem0_storage_with_mocked_config):
|
||||
"""Test save method for different memory types"""
|
||||
mem0_storage, _, _ = mem0_storage_with_mocked_config
|
||||
mem0_storage.memory.add = MagicMock()
|
||||
|
||||
|
||||
# Test short_term memory type (already set in fixture)
|
||||
test_value = "This is a test memory"
|
||||
test_metadata = {"key": "value"}
|
||||
|
||||
|
||||
mem0_storage.save(test_value, test_metadata)
|
||||
|
||||
|
||||
mem0_storage.memory.add.assert_called_once_with(
|
||||
[{'role': 'assistant' , 'content': test_value}],
|
||||
[{"role": "assistant" , "content": test_value}],
|
||||
infer=True,
|
||||
metadata={"type": "short_term", "key": "value"},
|
||||
run_id="my_run_id",
|
||||
user_id="test_user",
|
||||
agent_id='Test_Agent'
|
||||
)
|
||||
|
||||
def test_save_method_with_multiple_agents(mem0_storage_with_mocked_config):
|
||||
mem0_storage, _, _ = mem0_storage_with_mocked_config
|
||||
mem0_storage.crew.agents = [MagicMock(role="Test Agent"), MagicMock(role="Test Agent 2"), MagicMock(role="Test Agent 3")]
|
||||
mem0_storage.memory.add = MagicMock()
|
||||
|
||||
test_value = "This is a test memory"
|
||||
test_metadata = {"key": "value"}
|
||||
|
||||
mem0_storage.save(test_value, test_metadata)
|
||||
|
||||
mem0_storage.memory.add.assert_called_once_with(
|
||||
[{"role": "assistant" , "content": test_value}],
|
||||
infer=True,
|
||||
metadata={"type": "short_term", "key": "value"},
|
||||
run_id="my_run_id",
|
||||
user_id="test_user",
|
||||
agent_id='Test_Agent_Test_Agent_2_Test_Agent_3'
|
||||
)
|
||||
|
||||
|
||||
@@ -209,13 +231,13 @@ def test_save_method_with_memory_client(mem0_storage_with_memory_client_using_co
|
||||
"""Test save method for different memory types"""
|
||||
mem0_storage = mem0_storage_with_memory_client_using_config_from_crew
|
||||
mem0_storage.memory.add = MagicMock()
|
||||
|
||||
|
||||
# Test short_term memory type (already set in fixture)
|
||||
test_value = "This is a test memory"
|
||||
test_metadata = {"key": "value"}
|
||||
|
||||
|
||||
mem0_storage.save(test_value, test_metadata)
|
||||
|
||||
|
||||
mem0_storage.memory.add.assert_called_once_with(
|
||||
[{'role': 'assistant' , 'content': test_value}],
|
||||
infer=True,
|
||||
@@ -224,7 +246,9 @@ def test_save_method_with_memory_client(mem0_storage_with_memory_client_using_co
|
||||
run_id="my_run_id",
|
||||
includes="include1",
|
||||
excludes="exclude1",
|
||||
output_format='v1.1'
|
||||
output_format='v1.1',
|
||||
user_id='test_user',
|
||||
agent_id='Test_Agent'
|
||||
)
|
||||
|
||||
|
||||
@@ -237,10 +261,10 @@ def test_search_method_with_memory_oss(mem0_storage_with_mocked_config):
|
||||
results = mem0_storage.search("test query", limit=5, score_threshold=0.5)
|
||||
|
||||
mem0_storage.memory.search.assert_called_once_with(
|
||||
query="test query",
|
||||
limit=5,
|
||||
query="test query",
|
||||
limit=5,
|
||||
user_id="test_user",
|
||||
filters={'AND': [{'run_id': 'my_run_id'}]},
|
||||
filters={'AND': [{'run_id': 'my_run_id'}]},
|
||||
threshold=0.5
|
||||
)
|
||||
|
||||
@@ -257,8 +281,8 @@ def test_search_method_with_memory_client(mem0_storage_with_memory_client_using_
|
||||
results = mem0_storage.search("test query", limit=5, score_threshold=0.5)
|
||||
|
||||
mem0_storage.memory.search.assert_called_once_with(
|
||||
query="test query",
|
||||
limit=5,
|
||||
query="test query",
|
||||
limit=5,
|
||||
metadata={"type": "short_term"},
|
||||
user_id="test_user",
|
||||
version='v2',
|
||||
@@ -286,4 +310,56 @@ def test_mem0_storage_default_infer_value(mock_mem0_memory_client):
|
||||
)
|
||||
|
||||
mem0_storage = Mem0Storage(type="short_term", crew=crew)
|
||||
assert mem0_storage.infer is True
|
||||
assert mem0_storage.infer is True
|
||||
|
||||
def test_save_memory_using_agent_entity(mock_mem0_memory_client):
|
||||
config = {
|
||||
"agent_id": "agent-123",
|
||||
}
|
||||
|
||||
mock_memory = MagicMock(spec=Memory)
|
||||
with patch.object(Memory, "__new__", return_value=mock_memory):
|
||||
mem0_storage = Mem0Storage(type="external", config=config)
|
||||
mem0_storage.save("test memory", {"key": "value"})
|
||||
mem0_storage.memory.add.assert_called_once_with(
|
||||
[{'role': 'assistant' , 'content': 'test memory'}],
|
||||
infer=True,
|
||||
metadata={"type": "external", "key": "value"},
|
||||
agent_id="agent-123",
|
||||
)
|
||||
|
||||
def test_search_method_with_agent_entity():
|
||||
mem0_storage = Mem0Storage(type="external", config={"agent_id": "agent-123"})
|
||||
mock_results = {"results": [{"score": 0.9, "content": "Result 1"}, {"score": 0.4, "content": "Result 2"}]}
|
||||
mem0_storage.memory.search = MagicMock(return_value=mock_results)
|
||||
|
||||
results = mem0_storage.search("test query", limit=5, score_threshold=0.5)
|
||||
|
||||
mem0_storage.memory.search.assert_called_once_with(
|
||||
query="test query",
|
||||
limit=5,
|
||||
filters={"AND": [{"agent_id": "agent-123"}]},
|
||||
threshold=0.5,
|
||||
)
|
||||
|
||||
assert len(results) == 2
|
||||
assert results[0]["content"] == "Result 1"
|
||||
|
||||
|
||||
def test_search_method_with_agent_id_and_user_id():
|
||||
mem0_storage = Mem0Storage(type="external", config={"agent_id": "agent-123", "user_id": "user-123"})
|
||||
mock_results = {"results": [{"score": 0.9, "content": "Result 1"}, {"score": 0.4, "content": "Result 2"}]}
|
||||
mem0_storage.memory.search = MagicMock(return_value=mock_results)
|
||||
|
||||
results = mem0_storage.search("test query", limit=5, score_threshold=0.5)
|
||||
|
||||
mem0_storage.memory.search.assert_called_once_with(
|
||||
query="test query",
|
||||
limit=5,
|
||||
user_id='user-123',
|
||||
filters={"OR": [{"user_id": "user-123"}, {"agent_id": "agent-123"}]},
|
||||
threshold=0.5,
|
||||
)
|
||||
|
||||
assert len(results) == 2
|
||||
assert results[0]["content"] == "Result 1"
|
||||
|
||||
209
tests/test_human_input_event_integration.py
Normal file
209
tests/test_human_input_event_integration.py
Normal file
@@ -0,0 +1,209 @@
|
||||
import pytest
|
||||
import uuid
|
||||
from unittest.mock import patch, MagicMock
|
||||
|
||||
from crewai.utilities.events.task_events import HumanInputRequiredEvent, HumanInputCompletedEvent
|
||||
from crewai.utilities.events.crewai_event_bus import crewai_event_bus
|
||||
from crewai.agents.agent_builder.base_agent_executor_mixin import CrewAgentExecutorMixin
|
||||
|
||||
|
||||
class TestHumanInputEventIntegration:
|
||||
"""Test integration between human input flow and event system"""
|
||||
|
||||
def setup_method(self):
|
||||
"""Setup test environment"""
|
||||
self.executor = CrewAgentExecutorMixin()
|
||||
self.executor.crew = MagicMock()
|
||||
self.executor.crew.id = str(uuid.uuid4())
|
||||
self.executor.crew._train = False
|
||||
self.executor.task = MagicMock()
|
||||
self.executor.task.id = str(uuid.uuid4())
|
||||
self.executor.agent = MagicMock()
|
||||
self.executor.agent.id = str(uuid.uuid4())
|
||||
self.executor._printer = MagicMock()
|
||||
|
||||
@patch('builtins.input', return_value='test feedback')
|
||||
def test_human_input_emits_required_event(self, mock_input):
|
||||
"""Test that human input emits HumanInputRequiredEvent"""
|
||||
events_captured = []
|
||||
|
||||
def capture_event(event):
|
||||
events_captured.append(event)
|
||||
|
||||
with patch.object(crewai_event_bus, 'emit', side_effect=capture_event):
|
||||
result = self.executor._ask_human_input("Test result")
|
||||
|
||||
assert result == 'test feedback'
|
||||
assert len(events_captured) == 2
|
||||
|
||||
required_event = events_captured[0][1]
|
||||
completed_event = events_captured[1][1]
|
||||
|
||||
assert isinstance(required_event, HumanInputRequiredEvent)
|
||||
assert isinstance(completed_event, HumanInputCompletedEvent)
|
||||
|
||||
assert required_event.execution_id == str(self.executor.crew.id)
|
||||
assert required_event.crew_id == str(self.executor.crew.id)
|
||||
assert required_event.task_id == str(self.executor.task.id)
|
||||
assert required_event.agent_id == str(self.executor.agent.id)
|
||||
assert "HUMAN FEEDBACK" in required_event.prompt
|
||||
assert required_event.context == "Test result"
|
||||
assert required_event.event_id is not None
|
||||
|
||||
assert completed_event.execution_id == str(self.executor.crew.id)
|
||||
assert completed_event.human_feedback == 'test feedback'
|
||||
assert completed_event.event_id == required_event.event_id
|
||||
|
||||
@patch('builtins.input', return_value='training feedback')
|
||||
def test_training_mode_human_input_events(self, mock_input):
|
||||
"""Test human input events in training mode"""
|
||||
self.executor.crew._train = True
|
||||
events_captured = []
|
||||
|
||||
def capture_event(event):
|
||||
events_captured.append(event)
|
||||
|
||||
with patch.object(crewai_event_bus, 'emit', side_effect=capture_event):
|
||||
result = self.executor._ask_human_input("Test result")
|
||||
|
||||
assert result == 'training feedback'
|
||||
assert len(events_captured) == 2
|
||||
|
||||
required_event = events_captured[0][1]
|
||||
assert isinstance(required_event, HumanInputRequiredEvent)
|
||||
assert "TRAINING MODE" in required_event.prompt
|
||||
|
||||
@patch('builtins.input', return_value='')
|
||||
def test_empty_feedback_events(self, mock_input):
|
||||
"""Test events with empty feedback"""
|
||||
events_captured = []
|
||||
|
||||
def capture_event(event):
|
||||
events_captured.append(event)
|
||||
|
||||
with patch.object(crewai_event_bus, 'emit', side_effect=capture_event):
|
||||
result = self.executor._ask_human_input("Test result")
|
||||
|
||||
assert result == ''
|
||||
assert len(events_captured) == 2
|
||||
|
||||
completed_event = events_captured[1][1]
|
||||
assert isinstance(completed_event, HumanInputCompletedEvent)
|
||||
assert completed_event.human_feedback == ''
|
||||
|
||||
def test_event_payload_structure(self):
|
||||
"""Test that event payload matches GitHub issue specification"""
|
||||
event = HumanInputRequiredEvent(
|
||||
execution_id="test-execution-id",
|
||||
crew_id="test-crew-id",
|
||||
task_id="test-task-id",
|
||||
agent_id="test-agent-id",
|
||||
prompt="Test prompt",
|
||||
context="Test context",
|
||||
reason_flags={"ambiguity": True, "missing_field": False},
|
||||
event_id="test-event-id"
|
||||
)
|
||||
|
||||
payload = event.to_json()
|
||||
|
||||
assert payload["type"] == "human_input_required"
|
||||
assert payload["execution_id"] == "test-execution-id"
|
||||
assert payload["crew_id"] == "test-crew-id"
|
||||
assert payload["task_id"] == "test-task-id"
|
||||
assert payload["agent_id"] == "test-agent-id"
|
||||
assert payload["prompt"] == "Test prompt"
|
||||
assert payload["context"] == "Test context"
|
||||
assert payload["reason_flags"]["ambiguity"] is True
|
||||
assert payload["reason_flags"]["missing_field"] is False
|
||||
assert payload["event_id"] == "test-event-id"
|
||||
assert "timestamp" in payload
|
||||
|
||||
def test_completed_event_payload_structure(self):
|
||||
"""Test that completed event payload is correct"""
|
||||
event = HumanInputCompletedEvent(
|
||||
execution_id="test-execution-id",
|
||||
crew_id="test-crew-id",
|
||||
task_id="test-task-id",
|
||||
agent_id="test-agent-id",
|
||||
event_id="test-event-id",
|
||||
human_feedback="User feedback"
|
||||
)
|
||||
|
||||
payload = event.to_json()
|
||||
|
||||
assert payload["type"] == "human_input_completed"
|
||||
assert payload["execution_id"] == "test-execution-id"
|
||||
assert payload["crew_id"] == "test-crew-id"
|
||||
assert payload["task_id"] == "test-task-id"
|
||||
assert payload["agent_id"] == "test-agent-id"
|
||||
assert payload["event_id"] == "test-event-id"
|
||||
assert payload["human_feedback"] == "User feedback"
|
||||
assert "timestamp" in payload
|
||||
|
||||
@patch('builtins.input', side_effect=KeyboardInterrupt("Test interrupt"))
|
||||
def test_human_input_exception_handling(self, mock_input):
|
||||
"""Test that events are still emitted even if input is interrupted"""
|
||||
events_captured = []
|
||||
|
||||
def capture_event(event):
|
||||
events_captured.append(event)
|
||||
|
||||
with patch.object(crewai_event_bus, 'emit', side_effect=capture_event):
|
||||
with pytest.raises(KeyboardInterrupt):
|
||||
self.executor._ask_human_input("Test result")
|
||||
|
||||
assert len(events_captured) == 1
|
||||
required_event = events_captured[0][1]
|
||||
assert isinstance(required_event, HumanInputRequiredEvent)
|
||||
|
||||
def test_human_input_without_crew(self):
|
||||
"""Test human input events when crew is None"""
|
||||
self.executor.crew = None
|
||||
events_captured = []
|
||||
|
||||
def capture_event(event):
|
||||
events_captured.append(event)
|
||||
|
||||
with patch.object(crewai_event_bus, 'emit', side_effect=capture_event), \
|
||||
patch('builtins.input', return_value='test'):
|
||||
|
||||
self.executor._ask_human_input("Test result")
|
||||
|
||||
assert len(events_captured) == 2
|
||||
required_event = events_captured[0][1]
|
||||
assert required_event.execution_id is None
|
||||
assert required_event.crew_id is None
|
||||
|
||||
def test_human_input_without_task(self):
|
||||
"""Test human input events when task is None"""
|
||||
self.executor.task = None
|
||||
events_captured = []
|
||||
|
||||
def capture_event(event):
|
||||
events_captured.append(event)
|
||||
|
||||
with patch.object(crewai_event_bus, 'emit', side_effect=capture_event), \
|
||||
patch('builtins.input', return_value='test'):
|
||||
|
||||
self.executor._ask_human_input("Test result")
|
||||
|
||||
assert len(events_captured) == 2
|
||||
required_event = events_captured[0][1]
|
||||
assert required_event.task_id is None
|
||||
|
||||
def test_human_input_without_agent(self):
|
||||
"""Test human input events when agent is None"""
|
||||
self.executor.agent = None
|
||||
events_captured = []
|
||||
|
||||
def capture_event(event):
|
||||
events_captured.append(event)
|
||||
|
||||
with patch.object(crewai_event_bus, 'emit', side_effect=capture_event), \
|
||||
patch('builtins.input', return_value='test'):
|
||||
|
||||
self.executor._ask_human_input("Test result")
|
||||
|
||||
assert len(events_captured) == 2
|
||||
required_event = events_captured[0][1]
|
||||
assert required_event.agent_id is None
|
||||
Reference in New Issue
Block a user