mirror of
https://github.com/crewAIInc/crewAI.git
synced 2026-05-02 15:52:34 +00:00
Address PR feedback: Improve error handling, add OpenAPI docs, and verify task management
Co-Authored-By: Joe Moura <joao@crewai.com>
This commit is contained in:
@@ -2,6 +2,7 @@
|
|||||||
|
|
||||||
from crewai.a2a.agent import A2AAgentIntegration
|
from crewai.a2a.agent import A2AAgentIntegration
|
||||||
from crewai.a2a.client import A2AClient
|
from crewai.a2a.client import A2AClient
|
||||||
|
from crewai.a2a.config import A2AConfig
|
||||||
from crewai.a2a.server import A2AServer
|
from crewai.a2a.server import A2AServer
|
||||||
from crewai.a2a.task_manager import InMemoryTaskManager, TaskManager
|
from crewai.a2a.task_manager import InMemoryTaskManager, TaskManager
|
||||||
|
|
||||||
@@ -11,4 +12,5 @@ __all__ = [
|
|||||||
"A2AServer",
|
"A2AServer",
|
||||||
"TaskManager",
|
"TaskManager",
|
||||||
"InMemoryTaskManager",
|
"InMemoryTaskManager",
|
||||||
|
"A2AConfig",
|
||||||
]
|
]
|
||||||
|
|||||||
@@ -8,11 +8,14 @@ import asyncio
|
|||||||
import json
|
import json
|
||||||
import logging
|
import logging
|
||||||
import os
|
import os
|
||||||
from typing import Any, Dict, List, Optional, Union, cast
|
from typing import Any, Dict, List, Optional, TYPE_CHECKING, Union, cast
|
||||||
|
|
||||||
import aiohttp
|
import aiohttp
|
||||||
from pydantic import ValidationError
|
from pydantic import ValidationError
|
||||||
|
|
||||||
|
if TYPE_CHECKING:
|
||||||
|
from crewai.a2a.config import A2AConfig
|
||||||
|
|
||||||
from crewai.types.a2a import (
|
from crewai.types.a2a import (
|
||||||
A2AClientError,
|
A2AClientError,
|
||||||
A2AClientHTTPError,
|
A2AClientHTTPError,
|
||||||
@@ -53,7 +56,8 @@ class A2AClient:
|
|||||||
self,
|
self,
|
||||||
base_url: str,
|
base_url: str,
|
||||||
api_key: Optional[str] = None,
|
api_key: Optional[str] = None,
|
||||||
timeout: int = 60,
|
timeout: Optional[int] = None,
|
||||||
|
config: Optional["A2AConfig"] = None,
|
||||||
):
|
):
|
||||||
"""Initialize the A2A client.
|
"""Initialize the A2A client.
|
||||||
|
|
||||||
@@ -61,10 +65,22 @@ class A2AClient:
|
|||||||
base_url: The base URL of the A2A server.
|
base_url: The base URL of the A2A server.
|
||||||
api_key: The API key to use for authentication.
|
api_key: The API key to use for authentication.
|
||||||
timeout: The timeout for HTTP requests in seconds.
|
timeout: The timeout for HTTP requests in seconds.
|
||||||
|
config: The A2A configuration. If provided, other parameters are ignored.
|
||||||
"""
|
"""
|
||||||
|
if config:
|
||||||
|
from crewai.a2a.config import A2AConfig
|
||||||
|
self.config = config
|
||||||
|
else:
|
||||||
|
from crewai.a2a.config import A2AConfig
|
||||||
|
self.config = A2AConfig()
|
||||||
|
if api_key:
|
||||||
|
self.config.api_key = api_key
|
||||||
|
if timeout:
|
||||||
|
self.config.client_timeout = timeout
|
||||||
|
|
||||||
self.base_url = base_url.rstrip("/")
|
self.base_url = base_url.rstrip("/")
|
||||||
self.api_key = api_key or os.environ.get("A2A_API_KEY")
|
self.api_key = self.config.api_key or os.environ.get("A2A_API_KEY")
|
||||||
self.timeout = timeout
|
self.timeout = self.config.client_timeout
|
||||||
self.logger = logging.getLogger(__name__)
|
self.logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
async def send_task(
|
async def send_task(
|
||||||
@@ -92,7 +108,10 @@ class A2AClient:
|
|||||||
The created task.
|
The created task.
|
||||||
|
|
||||||
Raises:
|
Raises:
|
||||||
A2AClientError: If there is an error sending the task.
|
MissingAPIKeyError: If no API key is provided.
|
||||||
|
A2AClientHTTPError: If there is an HTTP error.
|
||||||
|
A2AClientJSONError: If there is an error parsing the JSON response.
|
||||||
|
A2AClientError: If there is any other error sending the task.
|
||||||
"""
|
"""
|
||||||
params = TaskSendParams(
|
params = TaskSendParams(
|
||||||
id=task_id,
|
id=task_id,
|
||||||
@@ -105,15 +124,26 @@ class A2AClient:
|
|||||||
)
|
)
|
||||||
|
|
||||||
request = SendTaskRequest(params=params)
|
request = SendTaskRequest(params=params)
|
||||||
response = await self._send_jsonrpc_request(request)
|
|
||||||
|
try:
|
||||||
|
response = await self._send_jsonrpc_request(request)
|
||||||
|
|
||||||
|
if response.error:
|
||||||
|
raise A2AClientError(f"Error sending task: {response.error.message}")
|
||||||
|
|
||||||
if response.error:
|
if not response.result:
|
||||||
raise A2AClientError(f"Error sending task: {response.error.message}")
|
raise A2AClientError("No result returned from send task request")
|
||||||
|
|
||||||
if not response.result:
|
if isinstance(response.result, dict):
|
||||||
raise A2AClientError("No result returned from send task request")
|
return Task.model_validate(response.result)
|
||||||
|
return cast(Task, response.result)
|
||||||
return cast(Task, response.result)
|
except asyncio.TimeoutError as e:
|
||||||
|
raise A2AClientError(f"Task request timed out: {e}")
|
||||||
|
except aiohttp.ClientError as e:
|
||||||
|
if isinstance(e, aiohttp.ClientResponseError):
|
||||||
|
raise A2AClientHTTPError(e.status, str(e))
|
||||||
|
else:
|
||||||
|
raise A2AClientError(f"Client error: {e}")
|
||||||
|
|
||||||
async def send_task_streaming(
|
async def send_task_streaming(
|
||||||
self,
|
self,
|
||||||
@@ -318,6 +348,14 @@ class A2AClient:
|
|||||||
return JSONRPCResponse.model_validate(data)
|
return JSONRPCResponse.model_validate(data)
|
||||||
except ValidationError as e:
|
except ValidationError as e:
|
||||||
raise A2AClientError(f"Invalid response: {e}")
|
raise A2AClientError(f"Invalid response: {e}")
|
||||||
|
except aiohttp.ClientConnectorError as e:
|
||||||
|
raise A2AClientHTTPError(status=0, message=f"Connection error: {e}")
|
||||||
|
except aiohttp.ClientOSError as e:
|
||||||
|
raise A2AClientHTTPError(status=0, message=f"OS error: {e}")
|
||||||
|
except aiohttp.ServerDisconnectedError as e:
|
||||||
|
raise A2AClientHTTPError(status=0, message=f"Server disconnected: {e}")
|
||||||
|
except aiohttp.ClientResponseError as e:
|
||||||
|
raise A2AClientHTTPError(e.status, str(e))
|
||||||
except aiohttp.ClientError as e:
|
except aiohttp.ClientError as e:
|
||||||
raise A2AClientError(f"HTTP error: {e}")
|
raise A2AClientError(f"HTTP error: {e}")
|
||||||
|
|
||||||
@@ -394,6 +432,14 @@ class A2AClient:
|
|||||||
await queue.put(
|
await queue.put(
|
||||||
A2AClientError(f"Invalid artifact event: {e}")
|
A2AClientError(f"Invalid artifact event: {e}")
|
||||||
)
|
)
|
||||||
|
except aiohttp.ClientConnectorError as e:
|
||||||
|
await queue.put(A2AClientHTTPError(status=0, message=f"Connection error: {e}"))
|
||||||
|
except aiohttp.ClientOSError as e:
|
||||||
|
await queue.put(A2AClientHTTPError(status=0, message=f"OS error: {e}"))
|
||||||
|
except aiohttp.ServerDisconnectedError as e:
|
||||||
|
await queue.put(A2AClientHTTPError(status=0, message=f"Server disconnected: {e}"))
|
||||||
|
except aiohttp.ClientResponseError as e:
|
||||||
|
await queue.put(A2AClientHTTPError(e.status, str(e)))
|
||||||
except aiohttp.ClientError as e:
|
except aiohttp.ClientError as e:
|
||||||
await queue.put(A2AClientError(f"HTTP error: {e}"))
|
await queue.put(A2AClientError(f"HTTP error: {e}"))
|
||||||
except asyncio.CancelledError:
|
except asyncio.CancelledError:
|
||||||
|
|||||||
89
src/crewai/a2a/config.py
Normal file
89
src/crewai/a2a/config.py
Normal file
@@ -0,0 +1,89 @@
|
|||||||
|
"""
|
||||||
|
Configuration management for A2A protocol in CrewAI.
|
||||||
|
|
||||||
|
This module provides configuration management for the A2A protocol implementation
|
||||||
|
in CrewAI, including default values and environment variable support.
|
||||||
|
"""
|
||||||
|
|
||||||
|
import os
|
||||||
|
from typing import Dict, Optional, Union
|
||||||
|
|
||||||
|
from pydantic import BaseModel, Field
|
||||||
|
|
||||||
|
|
||||||
|
class A2AConfig(BaseModel):
|
||||||
|
"""Configuration for A2A protocol."""
|
||||||
|
|
||||||
|
server_host: str = Field(
|
||||||
|
default="0.0.0.0",
|
||||||
|
description="Host to bind the A2A server to.",
|
||||||
|
)
|
||||||
|
server_port: int = Field(
|
||||||
|
default=8000,
|
||||||
|
description="Port to bind the A2A server to.",
|
||||||
|
)
|
||||||
|
enable_cors: bool = Field(
|
||||||
|
default=True,
|
||||||
|
description="Whether to enable CORS for the A2A server.",
|
||||||
|
)
|
||||||
|
cors_origins: Optional[list[str]] = Field(
|
||||||
|
default=None,
|
||||||
|
description="CORS origins to allow. If None, all origins are allowed.",
|
||||||
|
)
|
||||||
|
|
||||||
|
client_timeout: int = Field(
|
||||||
|
default=60,
|
||||||
|
description="Timeout for A2A client requests in seconds.",
|
||||||
|
)
|
||||||
|
api_key: Optional[str] = Field(
|
||||||
|
default=None,
|
||||||
|
description="API key for A2A authentication.",
|
||||||
|
)
|
||||||
|
|
||||||
|
task_ttl: int = Field(
|
||||||
|
default=3600,
|
||||||
|
description="Time-to-live for tasks in seconds.",
|
||||||
|
)
|
||||||
|
cleanup_interval: int = Field(
|
||||||
|
default=300,
|
||||||
|
description="Interval for cleaning up expired tasks in seconds.",
|
||||||
|
)
|
||||||
|
max_history_length: int = Field(
|
||||||
|
default=100,
|
||||||
|
description="Maximum number of messages to include in task history.",
|
||||||
|
)
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def from_env(cls) -> "A2AConfig":
|
||||||
|
"""Create a configuration from environment variables.
|
||||||
|
|
||||||
|
Environment variables are prefixed with A2A_ and are uppercase.
|
||||||
|
For example, A2A_SERVER_PORT=8080 will set server_port to 8080.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
A2AConfig: The configuration.
|
||||||
|
"""
|
||||||
|
config_dict: Dict[str, Union[str, int, bool, list[str]]] = {}
|
||||||
|
|
||||||
|
if "A2A_SERVER_HOST" in os.environ:
|
||||||
|
config_dict["server_host"] = os.environ["A2A_SERVER_HOST"]
|
||||||
|
if "A2A_SERVER_PORT" in os.environ:
|
||||||
|
config_dict["server_port"] = int(os.environ["A2A_SERVER_PORT"])
|
||||||
|
if "A2A_ENABLE_CORS" in os.environ:
|
||||||
|
config_dict["enable_cors"] = os.environ["A2A_ENABLE_CORS"].lower() == "true"
|
||||||
|
if "A2A_CORS_ORIGINS" in os.environ:
|
||||||
|
config_dict["cors_origins"] = os.environ["A2A_CORS_ORIGINS"].split(",")
|
||||||
|
|
||||||
|
if "A2A_CLIENT_TIMEOUT" in os.environ:
|
||||||
|
config_dict["client_timeout"] = int(os.environ["A2A_CLIENT_TIMEOUT"])
|
||||||
|
if "A2A_API_KEY" in os.environ:
|
||||||
|
config_dict["api_key"] = os.environ["A2A_API_KEY"]
|
||||||
|
|
||||||
|
if "A2A_TASK_TTL" in os.environ:
|
||||||
|
config_dict["task_ttl"] = int(os.environ["A2A_TASK_TTL"])
|
||||||
|
if "A2A_CLEANUP_INTERVAL" in os.environ:
|
||||||
|
config_dict["cleanup_interval"] = int(os.environ["A2A_CLEANUP_INTERVAL"])
|
||||||
|
if "A2A_MAX_HISTORY_LENGTH" in os.environ:
|
||||||
|
config_dict["max_history_length"] = int(os.environ["A2A_MAX_HISTORY_LENGTH"])
|
||||||
|
|
||||||
|
return cls(**config_dict)
|
||||||
@@ -7,13 +7,16 @@ This module implements the server for the A2A protocol in CrewAI.
|
|||||||
import asyncio
|
import asyncio
|
||||||
import json
|
import json
|
||||||
import logging
|
import logging
|
||||||
from typing import Any, Callable, Dict, List, Optional, Type, Union
|
from typing import Any, Callable, Dict, List, Optional, Type, TYPE_CHECKING, Union
|
||||||
|
|
||||||
from fastapi import FastAPI, HTTPException, Request, Response
|
from fastapi import FastAPI, HTTPException, Request, Response
|
||||||
from fastapi.middleware.cors import CORSMiddleware
|
from fastapi.middleware.cors import CORSMiddleware
|
||||||
from fastapi.responses import JSONResponse, StreamingResponse
|
from fastapi.responses import JSONResponse, StreamingResponse
|
||||||
from pydantic import ValidationError
|
from pydantic import ValidationError
|
||||||
|
|
||||||
|
if TYPE_CHECKING:
|
||||||
|
from crewai.a2a.config import A2AConfig
|
||||||
|
|
||||||
from crewai.a2a.task_manager import InMemoryTaskManager, TaskManager
|
from crewai.a2a.task_manager import InMemoryTaskManager, TaskManager
|
||||||
from crewai.types.a2a import (
|
from crewai.types.a2a import (
|
||||||
A2ARequest,
|
A2ARequest,
|
||||||
@@ -58,17 +61,47 @@ class A2AServer:
|
|||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
task_manager: Optional[TaskManager] = None,
|
task_manager: Optional[TaskManager] = None,
|
||||||
enable_cors: bool = True,
|
enable_cors: Optional[bool] = None,
|
||||||
cors_origins: Optional[List[str]] = None,
|
cors_origins: Optional[List[str]] = None,
|
||||||
|
config: Optional["A2AConfig"] = None,
|
||||||
):
|
):
|
||||||
"""Initialize the A2A server.
|
"""Initialize the A2A server.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
task_manager: The task manager to use. If None, an InMemoryTaskManager will be created.
|
task_manager: The task manager to use. If None, an InMemoryTaskManager will be created.
|
||||||
enable_cors: Whether to enable CORS.
|
enable_cors: Whether to enable CORS. If None, uses config value.
|
||||||
cors_origins: The CORS origins to allow.
|
cors_origins: The CORS origins to allow. If None, uses config value.
|
||||||
|
config: The A2A configuration. If provided, other parameters are ignored.
|
||||||
"""
|
"""
|
||||||
self.app = FastAPI(title="A2A Server")
|
from crewai.a2a.config import A2AConfig
|
||||||
|
self.config = config or A2AConfig.from_env()
|
||||||
|
|
||||||
|
enable_cors = enable_cors if enable_cors is not None else self.config.enable_cors
|
||||||
|
cors_origins = cors_origins or self.config.cors_origins
|
||||||
|
|
||||||
|
self.app = FastAPI(
|
||||||
|
title="A2A Protocol Server",
|
||||||
|
description="""
|
||||||
|
A2A (Agent-to-Agent) protocol server for CrewAI.
|
||||||
|
|
||||||
|
This server implements Google's A2A protocol specification, enabling interoperability
|
||||||
|
between different agent systems. It provides endpoints for task creation, retrieval,
|
||||||
|
cancellation, and streaming updates.
|
||||||
|
""",
|
||||||
|
version="1.0.0",
|
||||||
|
docs_url="/docs",
|
||||||
|
redoc_url="/redoc",
|
||||||
|
openapi_tags=[
|
||||||
|
{
|
||||||
|
"name": "tasks",
|
||||||
|
"description": "Operations for managing A2A tasks",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"name": "jsonrpc",
|
||||||
|
"description": "JSON-RPC interface for the A2A protocol",
|
||||||
|
},
|
||||||
|
],
|
||||||
|
)
|
||||||
self.task_manager = task_manager or InMemoryTaskManager()
|
self.task_manager = task_manager or InMemoryTaskManager()
|
||||||
self.logger = logging.getLogger(__name__)
|
self.logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
@@ -81,11 +114,125 @@ class A2AServer:
|
|||||||
allow_headers=["*"],
|
allow_headers=["*"],
|
||||||
)
|
)
|
||||||
|
|
||||||
self.app.post("/v1/jsonrpc")(self.handle_jsonrpc)
|
@self.app.post(
|
||||||
self.app.post("/v1/tasks/send")(self.handle_send_task)
|
"/v1/jsonrpc",
|
||||||
self.app.post("/v1/tasks/sendSubscribe")(self.handle_send_task_subscribe)
|
summary="Handle JSON-RPC requests",
|
||||||
self.app.post("/v1/tasks/{task_id}/cancel")(self.handle_cancel_task)
|
description="""
|
||||||
self.app.get("/v1/tasks/{task_id}")(self.handle_get_task)
|
Process JSON-RPC requests for the A2A protocol.
|
||||||
|
|
||||||
|
This endpoint handles all JSON-RPC requests for the A2A protocol, including:
|
||||||
|
- SendTask: Create a new task
|
||||||
|
- GetTask: Retrieve a task by ID
|
||||||
|
- CancelTask: Cancel a running task
|
||||||
|
- SetTaskPushNotification: Configure push notifications for a task
|
||||||
|
- GetTaskPushNotification: Retrieve push notification configuration for a task
|
||||||
|
""",
|
||||||
|
response_model=JSONRPCResponse,
|
||||||
|
responses={
|
||||||
|
200: {"description": "Successful response with result or error"},
|
||||||
|
400: {"description": "Invalid request format or parameters"},
|
||||||
|
500: {"description": "Internal server error during processing"},
|
||||||
|
},
|
||||||
|
tags=["jsonrpc"],
|
||||||
|
)
|
||||||
|
async def handle_jsonrpc(request: Request):
|
||||||
|
return await self.handle_jsonrpc(request)
|
||||||
|
|
||||||
|
@self.app.post(
|
||||||
|
"/v1/tasks/send",
|
||||||
|
summary="Send a task to an agent",
|
||||||
|
description="""
|
||||||
|
Create a new task and send it to an agent for execution.
|
||||||
|
|
||||||
|
This endpoint allows clients to send tasks to agents for processing.
|
||||||
|
The task is created with the provided parameters and immediately
|
||||||
|
transitions to the WORKING state. The response includes the created
|
||||||
|
task with its current status.
|
||||||
|
""",
|
||||||
|
response_model=Task,
|
||||||
|
responses={
|
||||||
|
200: {"description": "Task created successfully and processing started"},
|
||||||
|
400: {"description": "Invalid request format or parameters"},
|
||||||
|
500: {"description": "Internal server error during task creation or processing"},
|
||||||
|
},
|
||||||
|
tags=["tasks"],
|
||||||
|
)
|
||||||
|
async def handle_send_task(request: Request):
|
||||||
|
return await self.handle_send_task(request)
|
||||||
|
|
||||||
|
@self.app.post(
|
||||||
|
"/v1/tasks/sendSubscribe",
|
||||||
|
summary="Send a task and subscribe to updates",
|
||||||
|
description="""
|
||||||
|
Create a new task and subscribe to status updates via Server-Sent Events (SSE).
|
||||||
|
|
||||||
|
This endpoint allows clients to send tasks to agents and receive real-time
|
||||||
|
updates as the task progresses. The response is a streaming SSE connection
|
||||||
|
that provides status updates and artifact notifications until the task
|
||||||
|
reaches a terminal state (COMPLETED, FAILED, CANCELED, or EXPIRED).
|
||||||
|
""",
|
||||||
|
responses={
|
||||||
|
200: {
|
||||||
|
"description": "Streaming response with task updates",
|
||||||
|
"content": {
|
||||||
|
"text/event-stream": {
|
||||||
|
"schema": {"type": "string"},
|
||||||
|
"example": 'event: status\ndata: {"task_id": "123", "status": {"state": "WORKING"}}\n\n',
|
||||||
|
}
|
||||||
|
},
|
||||||
|
},
|
||||||
|
400: {"description": "Invalid request format or parameters"},
|
||||||
|
500: {"description": "Internal server error during task creation or processing"},
|
||||||
|
},
|
||||||
|
tags=["tasks"],
|
||||||
|
)
|
||||||
|
async def handle_send_task_subscribe(request: Request):
|
||||||
|
return await self.handle_send_task_subscribe(request)
|
||||||
|
|
||||||
|
@self.app.post(
|
||||||
|
"/v1/tasks/{task_id}/cancel",
|
||||||
|
summary="Cancel a task",
|
||||||
|
description="""
|
||||||
|
Cancel a running task by ID.
|
||||||
|
|
||||||
|
This endpoint allows clients to cancel a task that is currently in progress.
|
||||||
|
The task must be in a non-terminal state (PENDING, WORKING) to be canceled.
|
||||||
|
Once canceled, the task transitions to the CANCELED state and cannot be
|
||||||
|
resumed. The response includes the updated task with its current status.
|
||||||
|
""",
|
||||||
|
response_model=Task,
|
||||||
|
responses={
|
||||||
|
200: {"description": "Task canceled successfully and status updated to CANCELED"},
|
||||||
|
404: {"description": "Task not found or already expired"},
|
||||||
|
409: {"description": "Task cannot be canceled (already in terminal state)"},
|
||||||
|
500: {"description": "Internal server error during task cancellation"},
|
||||||
|
},
|
||||||
|
tags=["tasks"],
|
||||||
|
)
|
||||||
|
async def handle_cancel_task(task_id: str, request: Request):
|
||||||
|
return await self.handle_cancel_task(task_id, request)
|
||||||
|
|
||||||
|
@self.app.get(
|
||||||
|
"/v1/tasks/{task_id}",
|
||||||
|
summary="Get task details",
|
||||||
|
description="""
|
||||||
|
Retrieve details of a task by ID.
|
||||||
|
|
||||||
|
This endpoint allows clients to retrieve the current state and details of a task.
|
||||||
|
The response includes the task's status, history, and any associated metadata.
|
||||||
|
Clients can specify the history_length parameter to limit the number of messages
|
||||||
|
included in the response.
|
||||||
|
""",
|
||||||
|
response_model=Task,
|
||||||
|
responses={
|
||||||
|
200: {"description": "Task details retrieved successfully with current status"},
|
||||||
|
404: {"description": "Task not found or expired"},
|
||||||
|
500: {"description": "Internal server error during task retrieval"},
|
||||||
|
},
|
||||||
|
tags=["tasks"],
|
||||||
|
)
|
||||||
|
async def handle_get_task(task_id: str, request: Request):
|
||||||
|
return await self.handle_get_task(task_id, request)
|
||||||
|
|
||||||
async def handle_jsonrpc(self, request: Request) -> JSONResponse:
|
async def handle_jsonrpc(self, request: Request) -> JSONResponse:
|
||||||
"""Handle JSON-RPC requests.
|
"""Handle JSON-RPC requests.
|
||||||
@@ -121,7 +268,7 @@ class A2AServer:
|
|||||||
return JSONResponse(
|
return JSONResponse(
|
||||||
content=JSONRPCResponse(
|
content=JSONRPCResponse(
|
||||||
id=body.get("id") if isinstance(body, dict) else None,
|
id=body.get("id") if isinstance(body, dict) else None,
|
||||||
error=InternalError(data=str(e)),
|
error=InternalError(message="Internal server error"),
|
||||||
).model_dump(),
|
).model_dump(),
|
||||||
status_code=500,
|
status_code=500,
|
||||||
)
|
)
|
||||||
@@ -210,7 +357,7 @@ class A2AServer:
|
|||||||
self.logger.exception(f"Error handling {method} request")
|
self.logger.exception(f"Error handling {method} request")
|
||||||
return JSONRPCResponse(
|
return JSONRPCResponse(
|
||||||
id=request_id,
|
id=request_id,
|
||||||
error=InternalError(data=str(e)),
|
error=InternalError(message="Internal server error"),
|
||||||
)
|
)
|
||||||
|
|
||||||
async def handle_send_task(self, request: Request) -> JSONResponse:
|
async def handle_send_task(self, request: Request) -> JSONResponse:
|
||||||
@@ -227,15 +374,15 @@ class A2AServer:
|
|||||||
params = TaskSendParams.model_validate(body)
|
params = TaskSendParams.model_validate(body)
|
||||||
task = await self._handle_send_task(params)
|
task = await self._handle_send_task(params)
|
||||||
return JSONResponse(content=task.model_dump())
|
return JSONResponse(content=task.model_dump())
|
||||||
except ValidationError as e:
|
except ValidationError:
|
||||||
return JSONResponse(
|
return JSONResponse(
|
||||||
content={"error": str(e)},
|
content={"error": "Invalid request format or parameters"},
|
||||||
status_code=400,
|
status_code=400,
|
||||||
)
|
)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
self.logger.exception("Error handling send task request")
|
self.logger.exception("Error handling send task request")
|
||||||
return JSONResponse(
|
return JSONResponse(
|
||||||
content={"error": str(e)},
|
content={"error": "Internal server error"},
|
||||||
status_code=500,
|
status_code=500,
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -283,15 +430,15 @@ class A2AServer:
|
|||||||
self._stream_task_updates(params.id, queue),
|
self._stream_task_updates(params.id, queue),
|
||||||
media_type="text/event-stream",
|
media_type="text/event-stream",
|
||||||
)
|
)
|
||||||
except ValidationError as e:
|
except ValidationError:
|
||||||
return JSONResponse(
|
return JSONResponse(
|
||||||
content={"error": str(e)},
|
content={"error": "Invalid request format or parameters"},
|
||||||
status_code=400,
|
status_code=400,
|
||||||
)
|
)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
self.logger.exception("Error handling send task subscribe request")
|
self.logger.exception("Error handling send task subscribe request")
|
||||||
return JSONResponse(
|
return JSONResponse(
|
||||||
content={"error": str(e)},
|
content={"error": "Internal server error"},
|
||||||
status_code=500,
|
status_code=500,
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -346,7 +493,7 @@ class A2AServer:
|
|||||||
raise HTTPException(status_code=404, detail=f"Task {task_id} not found")
|
raise HTTPException(status_code=404, detail=f"Task {task_id} not found")
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
self.logger.exception(f"Error handling get task request for {task_id}")
|
self.logger.exception(f"Error handling get task request for {task_id}")
|
||||||
raise HTTPException(status_code=500, detail=str(e))
|
raise HTTPException(status_code=500, detail="Internal server error")
|
||||||
|
|
||||||
async def handle_cancel_task(self, task_id: str, request: Request) -> JSONResponse:
|
async def handle_cancel_task(self, task_id: str, request: Request) -> JSONResponse:
|
||||||
"""Handle cancel task requests.
|
"""Handle cancel task requests.
|
||||||
@@ -365,4 +512,4 @@ class A2AServer:
|
|||||||
raise HTTPException(status_code=404, detail=f"Task {task_id} not found")
|
raise HTTPException(status_code=404, detail=f"Task {task_id} not found")
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
self.logger.exception(f"Error handling cancel task request for {task_id}")
|
self.logger.exception(f"Error handling cancel task request for {task_id}")
|
||||||
raise HTTPException(status_code=500, detail=str(e))
|
raise HTTPException(status_code=500, detail="Internal server error")
|
||||||
|
|||||||
@@ -8,9 +8,12 @@ import asyncio
|
|||||||
import logging
|
import logging
|
||||||
from abc import ABC, abstractmethod
|
from abc import ABC, abstractmethod
|
||||||
from datetime import datetime
|
from datetime import datetime
|
||||||
from typing import Any, Dict, List, Optional, Set, Union
|
from typing import Any, Dict, List, Optional, Set, TYPE_CHECKING, Union
|
||||||
from uuid import uuid4
|
from uuid import uuid4
|
||||||
|
|
||||||
|
if TYPE_CHECKING:
|
||||||
|
from crewai.a2a.config import A2AConfig
|
||||||
|
|
||||||
from crewai.types.a2a import (
|
from crewai.types.a2a import (
|
||||||
Artifact,
|
Artifact,
|
||||||
Message,
|
Message,
|
||||||
@@ -165,12 +168,37 @@ class TaskManager(ABC):
|
|||||||
class InMemoryTaskManager(TaskManager):
|
class InMemoryTaskManager(TaskManager):
|
||||||
"""In-memory implementation of the A2A task manager."""
|
"""In-memory implementation of the A2A task manager."""
|
||||||
|
|
||||||
def __init__(self):
|
def __init__(
|
||||||
"""Initialize the in-memory task manager."""
|
self,
|
||||||
|
task_ttl: Optional[int] = None,
|
||||||
|
cleanup_interval: Optional[int] = None,
|
||||||
|
config: Optional["A2AConfig"] = None,
|
||||||
|
):
|
||||||
|
"""Initialize the in-memory task manager.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
task_ttl: Time to live for tasks in seconds. Default is 1 hour.
|
||||||
|
cleanup_interval: Interval for cleaning up expired tasks in seconds. Default is 5 minutes.
|
||||||
|
config: The A2A configuration. If provided, other parameters are ignored.
|
||||||
|
"""
|
||||||
|
from crewai.a2a.config import A2AConfig
|
||||||
|
self.config = config or A2AConfig.from_env()
|
||||||
|
|
||||||
|
self._task_ttl = task_ttl if task_ttl is not None else self.config.task_ttl
|
||||||
|
self._cleanup_interval = cleanup_interval if cleanup_interval is not None else self.config.cleanup_interval
|
||||||
|
|
||||||
self._tasks: Dict[str, Task] = {}
|
self._tasks: Dict[str, Task] = {}
|
||||||
self._push_notifications: Dict[str, PushNotificationConfig] = {}
|
self._push_notifications: Dict[str, PushNotificationConfig] = {}
|
||||||
self._task_subscribers: Dict[str, Set[asyncio.Queue]] = {}
|
self._task_subscribers: Dict[str, Set[asyncio.Queue]] = {}
|
||||||
|
self._task_timestamps: Dict[str, datetime] = {}
|
||||||
self._logger = logging.getLogger(__name__)
|
self._logger = logging.getLogger(__name__)
|
||||||
|
self._cleanup_task = None
|
||||||
|
|
||||||
|
try:
|
||||||
|
if asyncio.get_running_loop():
|
||||||
|
self._cleanup_task = asyncio.create_task(self._periodic_cleanup())
|
||||||
|
except RuntimeError:
|
||||||
|
self._logger.info("No running event loop, periodic cleanup disabled")
|
||||||
|
|
||||||
async def create_task(
|
async def create_task(
|
||||||
self,
|
self,
|
||||||
@@ -198,6 +226,7 @@ class InMemoryTaskManager(TaskManager):
|
|||||||
state=TaskState.SUBMITTED,
|
state=TaskState.SUBMITTED,
|
||||||
message=message,
|
message=message,
|
||||||
timestamp=datetime.now(),
|
timestamp=datetime.now(),
|
||||||
|
previous_state=None, # Initial state has no previous state
|
||||||
)
|
)
|
||||||
|
|
||||||
task = Task(
|
task = Task(
|
||||||
@@ -211,6 +240,7 @@ class InMemoryTaskManager(TaskManager):
|
|||||||
|
|
||||||
self._tasks[task_id] = task
|
self._tasks[task_id] = task
|
||||||
self._task_subscribers[task_id] = set()
|
self._task_subscribers[task_id] = set()
|
||||||
|
self._task_timestamps[task_id] = datetime.now()
|
||||||
return task
|
return task
|
||||||
|
|
||||||
async def get_task(
|
async def get_task(
|
||||||
@@ -263,20 +293,29 @@ class InMemoryTaskManager(TaskManager):
|
|||||||
raise KeyError(f"Task {task_id} not found")
|
raise KeyError(f"Task {task_id} not found")
|
||||||
|
|
||||||
task = self._tasks[task_id]
|
task = self._tasks[task_id]
|
||||||
|
task = self._tasks[task_id]
|
||||||
|
previous_state = task.status.state if task.status else None
|
||||||
|
|
||||||
|
if previous_state and not TaskState.is_valid_transition(previous_state, state):
|
||||||
|
raise ValueError(f"Invalid state transition from {previous_state} to {state}")
|
||||||
|
|
||||||
status = TaskStatus(
|
status = TaskStatus(
|
||||||
state=state,
|
state=state,
|
||||||
message=message,
|
message=message,
|
||||||
timestamp=datetime.now(),
|
timestamp=datetime.now(),
|
||||||
|
previous_state=previous_state,
|
||||||
)
|
)
|
||||||
task.status = status
|
task.status = status
|
||||||
|
|
||||||
if message and task.history is not None:
|
if message and task.history is not None:
|
||||||
task.history.append(message)
|
task.history.append(message)
|
||||||
|
|
||||||
|
self._task_timestamps[task_id] = datetime.now()
|
||||||
|
|
||||||
event = TaskStatusUpdateEvent(
|
event = TaskStatusUpdateEvent(
|
||||||
id=task_id,
|
id=task_id,
|
||||||
status=status,
|
status=status,
|
||||||
final=state in [TaskState.COMPLETED, TaskState.CANCELED, TaskState.FAILED],
|
final=state in [TaskState.COMPLETED, TaskState.CANCELED, TaskState.FAILED, TaskState.EXPIRED],
|
||||||
metadata=metadata or {},
|
metadata=metadata or {},
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -436,3 +475,48 @@ class InMemoryTaskManager(TaskManager):
|
|||||||
if task_id in self._task_subscribers:
|
if task_id in self._task_subscribers:
|
||||||
for queue in self._task_subscribers[task_id]:
|
for queue in self._task_subscribers[task_id]:
|
||||||
await queue.put(event)
|
await queue.put(event)
|
||||||
|
|
||||||
|
async def _periodic_cleanup(self) -> None:
|
||||||
|
"""Periodically clean up expired tasks."""
|
||||||
|
while True:
|
||||||
|
try:
|
||||||
|
await asyncio.sleep(self._cleanup_interval)
|
||||||
|
await self._cleanup_expired_tasks()
|
||||||
|
except asyncio.CancelledError:
|
||||||
|
break
|
||||||
|
except Exception as e:
|
||||||
|
self._logger.exception(f"Error during periodic cleanup: {e}")
|
||||||
|
|
||||||
|
async def _cleanup_expired_tasks(self) -> None:
|
||||||
|
"""Clean up expired tasks."""
|
||||||
|
now = datetime.now()
|
||||||
|
expired_tasks = []
|
||||||
|
|
||||||
|
for task_id, timestamp in self._task_timestamps.items():
|
||||||
|
if (now - timestamp).total_seconds() > self._task_ttl:
|
||||||
|
expired_tasks.append(task_id)
|
||||||
|
|
||||||
|
for task_id in expired_tasks:
|
||||||
|
self._logger.info(f"Cleaning up expired task: {task_id}")
|
||||||
|
self._tasks.pop(task_id, None)
|
||||||
|
self._push_notifications.pop(task_id, None)
|
||||||
|
self._task_timestamps.pop(task_id, None)
|
||||||
|
|
||||||
|
if task_id in self._task_subscribers:
|
||||||
|
previous_state = None
|
||||||
|
if task_id in self._tasks and self._tasks[task_id].status:
|
||||||
|
previous_state = self._tasks[task_id].status.state
|
||||||
|
|
||||||
|
status = TaskStatus(
|
||||||
|
state=TaskState.EXPIRED,
|
||||||
|
timestamp=now,
|
||||||
|
previous_state=previous_state,
|
||||||
|
)
|
||||||
|
event = TaskStatusUpdateEvent(
|
||||||
|
task_id=task_id,
|
||||||
|
status=status,
|
||||||
|
final=True,
|
||||||
|
)
|
||||||
|
await self._notify_subscribers(task_id, event)
|
||||||
|
|
||||||
|
self._task_subscribers.pop(task_id, None)
|
||||||
|
|||||||
@@ -459,10 +459,11 @@ class Agent(BaseAgent):
|
|||||||
task = Task(
|
task = Task(
|
||||||
description=task_description,
|
description=task_description,
|
||||||
agent=self,
|
agent=self,
|
||||||
|
expected_output="text", # Default to text output
|
||||||
)
|
)
|
||||||
|
|
||||||
try:
|
try:
|
||||||
result = self.execute_task(task, context)
|
result = self.execute_task(task=task, context=context)
|
||||||
return result
|
return result
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
self._logger.exception(f"Error handling A2A task: {e}")
|
self._logger.exception(f"Error handling A2A task: {e}")
|
||||||
|
|||||||
@@ -24,6 +24,42 @@ class TaskState(str, Enum):
|
|||||||
CANCELED = 'canceled'
|
CANCELED = 'canceled'
|
||||||
FAILED = 'failed'
|
FAILED = 'failed'
|
||||||
UNKNOWN = 'unknown'
|
UNKNOWN = 'unknown'
|
||||||
|
EXPIRED = 'expired'
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def valid_transitions(cls) -> Dict[str, List[str]]:
|
||||||
|
"""Get valid state transitions.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
A dictionary mapping from state to list of valid next states.
|
||||||
|
"""
|
||||||
|
return {
|
||||||
|
cls.SUBMITTED: [cls.WORKING, cls.CANCELED, cls.FAILED],
|
||||||
|
cls.WORKING: [cls.INPUT_REQUIRED, cls.COMPLETED, cls.CANCELED, cls.FAILED],
|
||||||
|
cls.INPUT_REQUIRED: [cls.WORKING, cls.CANCELED, cls.FAILED],
|
||||||
|
cls.COMPLETED: [], # Terminal state
|
||||||
|
cls.CANCELED: [], # Terminal state
|
||||||
|
cls.FAILED: [], # Terminal state
|
||||||
|
cls.UNKNOWN: [cls.SUBMITTED, cls.WORKING, cls.INPUT_REQUIRED, cls.COMPLETED, cls.CANCELED, cls.FAILED],
|
||||||
|
cls.EXPIRED: [], # Terminal state
|
||||||
|
}
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def is_valid_transition(cls, from_state: 'TaskState', to_state: 'TaskState') -> bool:
|
||||||
|
"""Check if a state transition is valid.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
from_state: The current state.
|
||||||
|
to_state: The target state.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
True if the transition is valid, False otherwise.
|
||||||
|
"""
|
||||||
|
if from_state == to_state:
|
||||||
|
return True
|
||||||
|
|
||||||
|
valid_next_states = cls.valid_transitions().get(from_state, [])
|
||||||
|
return to_state in valid_next_states
|
||||||
|
|
||||||
|
|
||||||
class TextPart(BaseModel):
|
class TextPart(BaseModel):
|
||||||
@@ -83,11 +119,21 @@ class TaskStatus(BaseModel):
|
|||||||
state: TaskState
|
state: TaskState
|
||||||
message: Optional[Message] = None
|
message: Optional[Message] = None
|
||||||
timestamp: datetime = Field(default_factory=datetime.now)
|
timestamp: datetime = Field(default_factory=datetime.now)
|
||||||
|
previous_state: Optional[TaskState] = None
|
||||||
|
|
||||||
@field_serializer('timestamp')
|
@field_serializer('timestamp')
|
||||||
def serialize_dt(self, dt: datetime, _info):
|
def serialize_dt(self, dt: datetime, _info):
|
||||||
"""Serialize datetime to ISO format."""
|
"""Serialize datetime to ISO format."""
|
||||||
return dt.isoformat()
|
return dt.isoformat()
|
||||||
|
|
||||||
|
@model_validator(mode='after')
|
||||||
|
def validate_state_transition(self) -> Self:
|
||||||
|
"""Validate state transition."""
|
||||||
|
if self.previous_state and not TaskState.is_valid_transition(self.previous_state, self.state):
|
||||||
|
raise ValueError(
|
||||||
|
f"Invalid state transition from {self.previous_state} to {self.state}"
|
||||||
|
)
|
||||||
|
return self
|
||||||
|
|
||||||
|
|
||||||
class Artifact(BaseModel):
|
class Artifact(BaseModel):
|
||||||
|
|||||||
@@ -1,13 +1,17 @@
|
|||||||
"""Tests for the A2A protocol integration."""
|
"""Tests for the A2A protocol integration."""
|
||||||
|
|
||||||
import asyncio
|
import asyncio
|
||||||
|
from datetime import datetime
|
||||||
import pytest
|
import pytest
|
||||||
from unittest.mock import AsyncMock, MagicMock, patch
|
from unittest.mock import AsyncMock, MagicMock, patch
|
||||||
|
|
||||||
|
pytestmark = pytest.mark.asyncio
|
||||||
|
|
||||||
from crewai.agent import Agent
|
from crewai.agent import Agent
|
||||||
from crewai.a2a import A2AAgentIntegration, A2AClient, A2AServer, InMemoryTaskManager
|
from crewai.a2a import A2AAgentIntegration, A2AClient, A2AServer, InMemoryTaskManager
|
||||||
from crewai.task import Task
|
from crewai.task import Task
|
||||||
from crewai.types.a2a import (
|
from crewai.types.a2a import (
|
||||||
|
JSONRPCResponse,
|
||||||
Message,
|
Message,
|
||||||
Task as A2ATask,
|
Task as A2ATask,
|
||||||
TaskState,
|
TaskState,
|
||||||
@@ -60,7 +64,7 @@ def a2a_integration():
|
|||||||
@pytest.fixture
|
@pytest.fixture
|
||||||
def a2a_client():
|
def a2a_client():
|
||||||
"""Create an A2A client."""
|
"""Create an A2A client."""
|
||||||
return A2AClient(base_url="http://localhost:8000")
|
return A2AClient(base_url="http://localhost:8000", api_key="test_api_key")
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture
|
@pytest.fixture
|
||||||
@@ -81,8 +85,7 @@ class TestA2AIntegration:
|
|||||||
@patch("crewai.a2a.agent.A2AAgentIntegration.execute_task_via_a2a")
|
@patch("crewai.a2a.agent.A2AAgentIntegration.execute_task_via_a2a")
|
||||||
def test_execute_task_via_a2a(self, mock_execute, agent):
|
def test_execute_task_via_a2a(self, mock_execute, agent):
|
||||||
"""Test executing a task via A2A."""
|
"""Test executing a task via A2A."""
|
||||||
mock_execute.return_value = asyncio.Future()
|
mock_execute.return_value = "Task result"
|
||||||
mock_execute.return_value.set_result("Task result")
|
|
||||||
|
|
||||||
result = asyncio.run(
|
result = asyncio.run(
|
||||||
agent.execute_task_via_a2a(
|
agent.execute_task_via_a2a(
|
||||||
@@ -116,8 +119,8 @@ class TestA2AIntegration:
|
|||||||
assert result == "Task result"
|
assert result == "Task result"
|
||||||
mock_execute.assert_called_once()
|
mock_execute.assert_called_once()
|
||||||
args, kwargs = mock_execute.call_args
|
args, kwargs = mock_execute.call_args
|
||||||
assert args[0].description == "Test task"
|
|
||||||
assert kwargs["context"] == "Test context"
|
assert kwargs["context"] == "Test context"
|
||||||
|
assert kwargs["task"].description == "Test task"
|
||||||
|
|
||||||
def test_a2a_disabled(self, agent):
|
def test_a2a_disabled(self, agent):
|
||||||
"""Test that A2A methods raise ValueError when A2A is disabled."""
|
"""Test that A2A methods raise ValueError when A2A is disabled."""
|
||||||
@@ -159,7 +162,7 @@ class TestA2AAgentIntegration:
|
|||||||
queue = asyncio.Queue()
|
queue = asyncio.Queue()
|
||||||
await queue.put(
|
await queue.put(
|
||||||
TaskStatusUpdateEvent(
|
TaskStatusUpdateEvent(
|
||||||
task_id="test_task_id",
|
id="test_task_id",
|
||||||
status=TaskStatus(
|
status=TaskStatus(
|
||||||
state=TaskState.COMPLETED,
|
state=TaskState.COMPLETED,
|
||||||
message=Message(
|
message=Message(
|
||||||
@@ -198,23 +201,29 @@ class TestA2AServer:
|
|||||||
class TestA2AClient:
|
class TestA2AClient:
|
||||||
"""Tests for the A2AClient class."""
|
"""Tests for the A2AClient class."""
|
||||||
|
|
||||||
@patch("aiohttp.ClientSession.post")
|
@patch("crewai.a2a.client.A2AClient._send_jsonrpc_request")
|
||||||
async def test_send_task(self, mock_post, a2a_client):
|
async def test_send_task(self, mock_send_request, a2a_client):
|
||||||
"""Test sending a task."""
|
"""Test sending a task."""
|
||||||
mock_response = MagicMock()
|
mock_response = JSONRPCResponse(
|
||||||
mock_response.status = 200
|
jsonrpc="2.0",
|
||||||
mock_response.json = AsyncMock(
|
id="test_request_id",
|
||||||
return_value={
|
result=A2ATask(
|
||||||
"id": "test_task_id",
|
id="test_task_id",
|
||||||
"history": [
|
sessionId="test_session_id",
|
||||||
{
|
status=TaskStatus(
|
||||||
"role": "user",
|
state=TaskState.SUBMITTED,
|
||||||
"parts": [{"text": "Test task description"}],
|
timestamp=datetime.now(),
|
||||||
}
|
),
|
||||||
|
history=[
|
||||||
|
Message(
|
||||||
|
role="user",
|
||||||
|
parts=[TextPart(text="Test task description")],
|
||||||
|
)
|
||||||
],
|
],
|
||||||
}
|
)
|
||||||
)
|
)
|
||||||
mock_post.return_value.__aenter__.return_value = mock_response
|
|
||||||
|
mock_send_request.return_value = mock_response
|
||||||
|
|
||||||
task = await a2a_client.send_task(
|
task = await a2a_client.send_task(
|
||||||
task_id="test_task_id",
|
task_id="test_task_id",
|
||||||
@@ -222,9 +231,10 @@ class TestA2AClient:
|
|||||||
role="user",
|
role="user",
|
||||||
parts=[TextPart(text="Test task description")],
|
parts=[TextPart(text="Test task description")],
|
||||||
),
|
),
|
||||||
|
session_id="test_session_id",
|
||||||
)
|
)
|
||||||
|
|
||||||
assert task.id == "test_task_id"
|
assert task.id == "test_task_id"
|
||||||
assert task.history[0].role == "user"
|
assert task.history[0].role == "user"
|
||||||
assert task.history[0].parts[0].text == "Test task description"
|
assert task.history[0].parts[0].text == "Test task description"
|
||||||
mock_post.assert_called_once()
|
mock_send_request.assert_called_once()
|
||||||
|
|||||||
Reference in New Issue
Block a user