Compare commits

...

2 Commits

Author SHA1 Message Date
Devin AI
9bb8854c25 Address PR feedback: Improve error handling, add OpenAPI docs, and verify task management
Co-Authored-By: Joe Moura <joao@crewai.com>
2025-05-09 04:58:19 +00:00
Devin AI
cfabb9fa78 Add A2A protocol support for CrewAI (Issue #2796)
Co-Authored-By: Joe Moura <joao@crewai.com>
2025-05-09 04:13:04 +00:00
13 changed files with 2723 additions and 0 deletions

View File

@@ -0,0 +1,59 @@
"""
Example of using the A2A protocol with CrewAI.
This example demonstrates how to:
1. Create an agent with A2A protocol support
2. Start an A2A server for the agent
3. Execute a task via the A2A protocol
"""
import asyncio
import os
import uvicorn
from threading import Thread
from crewai import Agent
from crewai.a2a import A2AServer, InMemoryTaskManager
agent = Agent(
role="Data Analyst",
goal="Analyze data and provide insights",
backstory="I am a data analyst with expertise in finding patterns and insights in data.",
a2a_enabled=True,
a2a_url="http://localhost:8000",
)
def start_server():
"""Start the A2A server."""
task_manager = InMemoryTaskManager()
server = A2AServer(task_manager=task_manager)
uvicorn.run(server.app, host="0.0.0.0", port=8000)
async def execute_task_via_a2a():
"""Execute a task via the A2A protocol."""
await asyncio.sleep(2)
result = await agent.execute_task_via_a2a(
task_description="Analyze the following data and provide insights: [1, 2, 3, 4, 5]",
context="This is a simple example of using the A2A protocol.",
)
print(f"Task result: {result}")
async def main():
"""Run the example."""
server_thread = Thread(target=start_server)
server_thread.daemon = True
server_thread.start()
await execute_task_via_a2a()
if __name__ == "__main__":
asyncio.run(main())

View File

@@ -23,4 +23,9 @@ __all__ = [
"LLM",
"Flow",
"Knowledge",
"A2AAgentIntegration",
"A2AClient",
"A2AServer",
]
from crewai.a2a import A2AAgentIntegration, A2AClient, A2AServer

View File

@@ -0,0 +1,16 @@
"""A2A protocol implementation for CrewAI."""
from crewai.a2a.agent import A2AAgentIntegration
from crewai.a2a.client import A2AClient
from crewai.a2a.config import A2AConfig
from crewai.a2a.server import A2AServer
from crewai.a2a.task_manager import InMemoryTaskManager, TaskManager
__all__ = [
"A2AAgentIntegration",
"A2AClient",
"A2AServer",
"TaskManager",
"InMemoryTaskManager",
"A2AConfig",
]

223
src/crewai/a2a/agent.py Normal file
View File

@@ -0,0 +1,223 @@
"""
A2A protocol agent integration for CrewAI.
This module implements the integration between CrewAI agents and the A2A protocol.
"""
import asyncio
import json
import logging
import uuid
from typing import Any, Dict, List, Optional, Union
from crewai.a2a.client import A2AClient
from crewai.a2a.task_manager import TaskManager
from crewai.types.a2a import (
Artifact,
DataPart,
FilePart,
Message,
Part,
Task as A2ATask,
TaskArtifactUpdateEvent,
TaskState,
TaskStatusUpdateEvent,
TextPart,
)
class A2AAgentIntegration:
"""Integration between CrewAI agents and the A2A protocol."""
def __init__(
self,
task_manager: Optional[TaskManager] = None,
client: Optional[A2AClient] = None,
):
"""Initialize the A2A agent integration.
Args:
task_manager: The task manager to use for handling A2A tasks.
client: The A2A client to use for sending tasks to other agents.
"""
self.task_manager = task_manager
self.client = client
self.logger = logging.getLogger(__name__)
async def execute_task_via_a2a(
self,
agent_url: str,
task_description: str,
context: Optional[str] = None,
api_key: Optional[str] = None,
timeout: int = 300,
) -> str:
"""Execute a task via the A2A protocol.
Args:
agent_url: The URL of the agent to execute the task.
task_description: The description of the task.
context: Additional context for the task.
api_key: The API key to use for authentication.
timeout: The timeout for the task execution in seconds.
Returns:
The result of the task execution.
Raises:
TimeoutError: If the task execution times out.
Exception: If there is an error executing the task.
"""
if not self.client:
self.client = A2AClient(base_url=agent_url, api_key=api_key)
parts: List[Part] = [TextPart(text=task_description)]
if context:
parts.append(
DataPart(
data={"context": context},
metadata={"type": "context"},
)
)
message = Message(role="user", parts=parts)
task_id = str(uuid.uuid4())
try:
queue = await self.client.send_task_streaming(
task_id=task_id,
message=message,
)
result = await self._wait_for_task_completion(queue, timeout)
return result
except Exception as e:
self.logger.exception(f"Error executing task via A2A: {e}")
raise
async def _wait_for_task_completion(
self, queue: asyncio.Queue, timeout: int
) -> str:
"""Wait for a task to complete.
Args:
queue: The queue to receive task updates from.
timeout: The timeout for the task execution in seconds.
Returns:
The result of the task execution.
Raises:
TimeoutError: If the task execution times out.
Exception: If there is an error executing the task.
"""
result = ""
try:
async def _timeout():
await asyncio.sleep(timeout)
await queue.put(TimeoutError(f"Task execution timed out after {timeout} seconds"))
timeout_task = asyncio.create_task(_timeout())
while True:
event = await queue.get()
if isinstance(event, Exception):
raise event
if isinstance(event, TaskStatusUpdateEvent):
if event.status.state == TaskState.COMPLETED:
if event.status.message:
for part in event.status.message.parts:
if isinstance(part, TextPart):
result += part.text
break
elif event.status.state in [TaskState.FAILED, TaskState.CANCELED]:
error_message = "Task failed"
if event.status.message:
for part in event.status.message.parts:
if isinstance(part, TextPart):
error_message = part.text
raise Exception(error_message)
elif isinstance(event, TaskArtifactUpdateEvent):
for part in event.artifact.parts:
if isinstance(part, TextPart):
result += part.text
finally:
timeout_task.cancel()
return result
async def handle_a2a_task(
self,
task: A2ATask,
agent_execute_func: Any,
context: Optional[str] = None,
) -> None:
"""Handle an A2A task.
Args:
task: The A2A task to handle.
agent_execute_func: The function to execute the task.
context: Additional context for the task.
Raises:
Exception: If there is an error handling the task.
"""
if not self.task_manager:
raise ValueError("Task manager is required to handle A2A tasks")
try:
await self.task_manager.update_task_status(
task_id=task.id,
state=TaskState.WORKING,
)
task_description = ""
task_context = context or ""
if task.history and task.history[-1].role == "user":
message = task.history[-1]
for part in message.parts:
if isinstance(part, TextPart):
task_description += part.text
elif isinstance(part, DataPart) and part.data.get("context"):
task_context += part.data["context"]
try:
result = await agent_execute_func(task_description, task_context)
response_message = Message(
role="agent",
parts=[TextPart(text=result)],
)
await self.task_manager.update_task_status(
task_id=task.id,
state=TaskState.COMPLETED,
message=response_message,
)
artifact = Artifact(
name="result",
parts=[TextPart(text=result)],
)
await self.task_manager.add_task_artifact(
task_id=task.id,
artifact=artifact,
)
except Exception as e:
error_message = Message(
role="agent",
parts=[TextPart(text=str(e))],
)
await self.task_manager.update_task_status(
task_id=task.id,
state=TaskState.FAILED,
message=error_message,
)
raise
except Exception as e:
self.logger.exception(f"Error handling A2A task: {e}")
raise

470
src/crewai/a2a/client.py Normal file
View File

@@ -0,0 +1,470 @@
"""
A2A protocol client for CrewAI.
This module implements the client for the A2A protocol in CrewAI.
"""
import asyncio
import json
import logging
import os
from typing import Any, Dict, List, Optional, TYPE_CHECKING, Union, cast
import aiohttp
from pydantic import ValidationError
if TYPE_CHECKING:
from crewai.a2a.config import A2AConfig
from crewai.types.a2a import (
A2AClientError,
A2AClientHTTPError,
A2AClientJSONError,
Artifact,
CancelTaskRequest,
CancelTaskResponse,
GetTaskPushNotificationRequest,
GetTaskPushNotificationResponse,
GetTaskRequest,
GetTaskResponse,
JSONRPCError,
JSONRPCRequest,
JSONRPCResponse,
Message,
MissingAPIKeyError,
PushNotificationConfig,
SendTaskRequest,
SendTaskResponse,
SendTaskStreamingRequest,
SetTaskPushNotificationRequest,
SetTaskPushNotificationResponse,
Task,
TaskArtifactUpdateEvent,
TaskIdParams,
TaskPushNotificationConfig,
TaskQueryParams,
TaskSendParams,
TaskState,
TaskStatusUpdateEvent,
)
class A2AClient:
"""A2A protocol client implementation."""
def __init__(
self,
base_url: str,
api_key: Optional[str] = None,
timeout: Optional[int] = None,
config: Optional["A2AConfig"] = None,
):
"""Initialize the A2A client.
Args:
base_url: The base URL of the A2A server.
api_key: The API key to use for authentication.
timeout: The timeout for HTTP requests in seconds.
config: The A2A configuration. If provided, other parameters are ignored.
"""
if config:
from crewai.a2a.config import A2AConfig
self.config = config
else:
from crewai.a2a.config import A2AConfig
self.config = A2AConfig()
if api_key:
self.config.api_key = api_key
if timeout:
self.config.client_timeout = timeout
self.base_url = base_url.rstrip("/")
self.api_key = self.config.api_key or os.environ.get("A2A_API_KEY")
self.timeout = self.config.client_timeout
self.logger = logging.getLogger(__name__)
async def send_task(
self,
task_id: str,
message: Message,
session_id: Optional[str] = None,
accepted_output_modes: Optional[List[str]] = None,
push_notification: Optional[PushNotificationConfig] = None,
history_length: Optional[int] = None,
metadata: Optional[Dict[str, Any]] = None,
) -> Task:
"""Send a task to the A2A server.
Args:
task_id: The ID of the task.
message: The message to send.
session_id: The session ID.
accepted_output_modes: The accepted output modes.
push_notification: The push notification configuration.
history_length: The number of messages to include in the history.
metadata: Additional metadata.
Returns:
The created task.
Raises:
MissingAPIKeyError: If no API key is provided.
A2AClientHTTPError: If there is an HTTP error.
A2AClientJSONError: If there is an error parsing the JSON response.
A2AClientError: If there is any other error sending the task.
"""
params = TaskSendParams(
id=task_id,
sessionId=session_id,
message=message,
acceptedOutputModes=accepted_output_modes,
pushNotification=push_notification,
historyLength=history_length,
metadata=metadata,
)
request = SendTaskRequest(params=params)
try:
response = await self._send_jsonrpc_request(request)
if response.error:
raise A2AClientError(f"Error sending task: {response.error.message}")
if not response.result:
raise A2AClientError("No result returned from send task request")
if isinstance(response.result, dict):
return Task.model_validate(response.result)
return cast(Task, response.result)
except asyncio.TimeoutError as e:
raise A2AClientError(f"Task request timed out: {e}")
except aiohttp.ClientError as e:
if isinstance(e, aiohttp.ClientResponseError):
raise A2AClientHTTPError(e.status, str(e))
else:
raise A2AClientError(f"Client error: {e}")
async def send_task_streaming(
self,
task_id: str,
message: Message,
session_id: Optional[str] = None,
accepted_output_modes: Optional[List[str]] = None,
push_notification: Optional[PushNotificationConfig] = None,
history_length: Optional[int] = None,
metadata: Optional[Dict[str, Any]] = None,
) -> asyncio.Queue:
"""Send a task to the A2A server and subscribe to updates.
Args:
task_id: The ID of the task.
message: The message to send.
session_id: The session ID.
accepted_output_modes: The accepted output modes.
push_notification: The push notification configuration.
history_length: The number of messages to include in the history.
metadata: Additional metadata.
Returns:
A queue that will receive task updates.
Raises:
A2AClientError: If there is an error sending the task.
"""
params = TaskSendParams(
id=task_id,
sessionId=session_id,
message=message,
acceptedOutputModes=accepted_output_modes,
pushNotification=push_notification,
historyLength=history_length,
metadata=metadata,
)
queue: asyncio.Queue = asyncio.Queue()
asyncio.create_task(
self._handle_streaming_response(
f"{self.base_url}/v1/tasks/sendSubscribe", params, queue
)
)
return queue
async def get_task(
self, task_id: str, history_length: Optional[int] = None
) -> Task:
"""Get a task from the A2A server.
Args:
task_id: The ID of the task.
history_length: The number of messages to include in the history.
Returns:
The task.
Raises:
A2AClientError: If there is an error getting the task.
"""
params = TaskQueryParams(id=task_id, historyLength=history_length)
request = GetTaskRequest(params=params)
response = await self._send_jsonrpc_request(request)
if response.error:
raise A2AClientError(f"Error getting task: {response.error.message}")
if not response.result:
raise A2AClientError("No result returned from get task request")
return cast(Task, response.result)
async def cancel_task(self, task_id: str) -> Task:
"""Cancel a task on the A2A server.
Args:
task_id: The ID of the task.
Returns:
The canceled task.
Raises:
A2AClientError: If there is an error canceling the task.
"""
params = TaskIdParams(id=task_id)
request = CancelTaskRequest(params=params)
response = await self._send_jsonrpc_request(request)
if response.error:
raise A2AClientError(f"Error canceling task: {response.error.message}")
if not response.result:
raise A2AClientError("No result returned from cancel task request")
return cast(Task, response.result)
async def set_push_notification(
self, task_id: str, config: PushNotificationConfig
) -> PushNotificationConfig:
"""Set push notification for a task.
Args:
task_id: The ID of the task.
config: The push notification configuration.
Returns:
The push notification configuration.
Raises:
A2AClientError: If there is an error setting the push notification.
"""
params = TaskPushNotificationConfig(id=task_id, pushNotificationConfig=config)
request = SetTaskPushNotificationRequest(params=params)
response = await self._send_jsonrpc_request(request)
if response.error:
raise A2AClientError(
f"Error setting push notification: {response.error.message}"
)
if not response.result:
raise A2AClientError(
"No result returned from set push notification request"
)
return cast(TaskPushNotificationConfig, response.result).pushNotificationConfig
async def get_push_notification(
self, task_id: str
) -> Optional[PushNotificationConfig]:
"""Get push notification for a task.
Args:
task_id: The ID of the task.
Returns:
The push notification configuration, or None if not set.
Raises:
A2AClientError: If there is an error getting the push notification.
"""
params = TaskIdParams(id=task_id)
request = GetTaskPushNotificationRequest(params=params)
response = await self._send_jsonrpc_request(request)
if response.error:
raise A2AClientError(
f"Error getting push notification: {response.error.message}"
)
if not response.result:
return None
return cast(TaskPushNotificationConfig, response.result).pushNotificationConfig
async def _send_jsonrpc_request(
self, request: JSONRPCRequest
) -> JSONRPCResponse:
"""Send a JSON-RPC request to the A2A server.
Args:
request: The JSON-RPC request.
Returns:
The JSON-RPC response.
Raises:
A2AClientError: If there is an error sending the request.
"""
if not self.api_key:
raise MissingAPIKeyError(
"API key is required. Set it in the constructor or as the A2A_API_KEY environment variable."
)
headers = {
"Content-Type": "application/json",
"Authorization": f"Bearer {self.api_key}",
}
try:
async with aiohttp.ClientSession() as session:
async with session.post(
f"{self.base_url}/v1/jsonrpc",
headers=headers,
json=request.model_dump(),
timeout=self.timeout,
) as response:
if response.status != 200:
raise A2AClientHTTPError(
response.status, await response.text()
)
try:
data = await response.json()
except json.JSONDecodeError as e:
raise A2AClientJSONError(str(e))
try:
return JSONRPCResponse.model_validate(data)
except ValidationError as e:
raise A2AClientError(f"Invalid response: {e}")
except aiohttp.ClientConnectorError as e:
raise A2AClientHTTPError(status=0, message=f"Connection error: {e}")
except aiohttp.ClientOSError as e:
raise A2AClientHTTPError(status=0, message=f"OS error: {e}")
except aiohttp.ServerDisconnectedError as e:
raise A2AClientHTTPError(status=0, message=f"Server disconnected: {e}")
except aiohttp.ClientResponseError as e:
raise A2AClientHTTPError(e.status, str(e))
except aiohttp.ClientError as e:
raise A2AClientError(f"HTTP error: {e}")
async def _handle_streaming_response(
self,
url: str,
params: TaskSendParams,
queue: asyncio.Queue,
) -> None:
"""Handle a streaming response from the A2A server.
Args:
url: The URL to send the request to.
params: The task send parameters.
queue: The queue to put events into.
"""
if not self.api_key:
await queue.put(
Exception(
"API key is required. Set it in the constructor or as the A2A_API_KEY environment variable."
)
)
return
headers = {
"Content-Type": "application/json",
"Authorization": f"Bearer {self.api_key}",
"Accept": "text/event-stream",
}
try:
async with aiohttp.ClientSession() as session:
async with session.post(
url,
headers=headers,
json=params.model_dump(),
timeout=self.timeout,
) as response:
if response.status != 200:
await queue.put(
A2AClientHTTPError(response.status, await response.text())
)
return
buffer = ""
async for line in response.content:
line = line.decode("utf-8")
buffer += line
if buffer.endswith("\n\n"):
event_data = self._parse_sse_event(buffer)
buffer = ""
if event_data:
event_type = event_data.get("event")
data = event_data.get("data")
if event_type == "status":
try:
event = TaskStatusUpdateEvent.model_validate_json(data)
await queue.put(event)
if event.final:
break
except ValidationError as e:
await queue.put(
A2AClientError(f"Invalid status event: {e}")
)
elif event_type == "artifact":
try:
event = TaskArtifactUpdateEvent.model_validate_json(data)
await queue.put(event)
except ValidationError as e:
await queue.put(
A2AClientError(f"Invalid artifact event: {e}")
)
except aiohttp.ClientConnectorError as e:
await queue.put(A2AClientHTTPError(status=0, message=f"Connection error: {e}"))
except aiohttp.ClientOSError as e:
await queue.put(A2AClientHTTPError(status=0, message=f"OS error: {e}"))
except aiohttp.ServerDisconnectedError as e:
await queue.put(A2AClientHTTPError(status=0, message=f"Server disconnected: {e}"))
except aiohttp.ClientResponseError as e:
await queue.put(A2AClientHTTPError(e.status, str(e)))
except aiohttp.ClientError as e:
await queue.put(A2AClientError(f"HTTP error: {e}"))
except asyncio.CancelledError:
pass
except Exception as e:
await queue.put(A2AClientError(f"Error handling streaming response: {e}"))
def _parse_sse_event(self, data: str) -> Dict[str, str]:
"""Parse an SSE event.
Args:
data: The SSE event data.
Returns:
A dictionary with the event type and data.
"""
result = {}
for line in data.split("\n"):
line = line.strip()
if not line:
continue
if line.startswith("event:"):
result["event"] = line[6:].strip()
elif line.startswith("data:"):
result["data"] = line[5:].strip()
return result

89
src/crewai/a2a/config.py Normal file
View File

@@ -0,0 +1,89 @@
"""
Configuration management for A2A protocol in CrewAI.
This module provides configuration management for the A2A protocol implementation
in CrewAI, including default values and environment variable support.
"""
import os
from typing import Dict, Optional, Union
from pydantic import BaseModel, Field
class A2AConfig(BaseModel):
"""Configuration for A2A protocol."""
server_host: str = Field(
default="0.0.0.0",
description="Host to bind the A2A server to.",
)
server_port: int = Field(
default=8000,
description="Port to bind the A2A server to.",
)
enable_cors: bool = Field(
default=True,
description="Whether to enable CORS for the A2A server.",
)
cors_origins: Optional[list[str]] = Field(
default=None,
description="CORS origins to allow. If None, all origins are allowed.",
)
client_timeout: int = Field(
default=60,
description="Timeout for A2A client requests in seconds.",
)
api_key: Optional[str] = Field(
default=None,
description="API key for A2A authentication.",
)
task_ttl: int = Field(
default=3600,
description="Time-to-live for tasks in seconds.",
)
cleanup_interval: int = Field(
default=300,
description="Interval for cleaning up expired tasks in seconds.",
)
max_history_length: int = Field(
default=100,
description="Maximum number of messages to include in task history.",
)
@classmethod
def from_env(cls) -> "A2AConfig":
"""Create a configuration from environment variables.
Environment variables are prefixed with A2A_ and are uppercase.
For example, A2A_SERVER_PORT=8080 will set server_port to 8080.
Returns:
A2AConfig: The configuration.
"""
config_dict: Dict[str, Union[str, int, bool, list[str]]] = {}
if "A2A_SERVER_HOST" in os.environ:
config_dict["server_host"] = os.environ["A2A_SERVER_HOST"]
if "A2A_SERVER_PORT" in os.environ:
config_dict["server_port"] = int(os.environ["A2A_SERVER_PORT"])
if "A2A_ENABLE_CORS" in os.environ:
config_dict["enable_cors"] = os.environ["A2A_ENABLE_CORS"].lower() == "true"
if "A2A_CORS_ORIGINS" in os.environ:
config_dict["cors_origins"] = os.environ["A2A_CORS_ORIGINS"].split(",")
if "A2A_CLIENT_TIMEOUT" in os.environ:
config_dict["client_timeout"] = int(os.environ["A2A_CLIENT_TIMEOUT"])
if "A2A_API_KEY" in os.environ:
config_dict["api_key"] = os.environ["A2A_API_KEY"]
if "A2A_TASK_TTL" in os.environ:
config_dict["task_ttl"] = int(os.environ["A2A_TASK_TTL"])
if "A2A_CLEANUP_INTERVAL" in os.environ:
config_dict["cleanup_interval"] = int(os.environ["A2A_CLEANUP_INTERVAL"])
if "A2A_MAX_HISTORY_LENGTH" in os.environ:
config_dict["max_history_length"] = int(os.environ["A2A_MAX_HISTORY_LENGTH"])
return cls(**config_dict)

515
src/crewai/a2a/server.py Normal file
View File

@@ -0,0 +1,515 @@
"""
A2A protocol server for CrewAI.
This module implements the server for the A2A protocol in CrewAI.
"""
import asyncio
import json
import logging
from typing import Any, Callable, Dict, List, Optional, Type, TYPE_CHECKING, Union
from fastapi import FastAPI, HTTPException, Request, Response
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import JSONResponse, StreamingResponse
from pydantic import ValidationError
if TYPE_CHECKING:
from crewai.a2a.config import A2AConfig
from crewai.a2a.task_manager import InMemoryTaskManager, TaskManager
from crewai.types.a2a import (
A2ARequest,
CancelTaskRequest,
CancelTaskResponse,
ContentTypeNotSupportedError,
GetTaskPushNotificationRequest,
GetTaskPushNotificationResponse,
GetTaskRequest,
GetTaskResponse,
InternalError,
InvalidParamsError,
InvalidRequestError,
JSONParseError,
JSONRPCError,
JSONRPCRequest,
JSONRPCResponse,
MethodNotFoundError,
SendTaskRequest,
SendTaskResponse,
SendTaskStreamingRequest,
SendTaskStreamingResponse,
SetTaskPushNotificationRequest,
SetTaskPushNotificationResponse,
Task,
TaskArtifactUpdateEvent,
TaskIdParams,
TaskNotCancelableError,
TaskNotFoundError,
TaskPushNotificationConfig,
TaskQueryParams,
TaskSendParams,
TaskState,
TaskStatusUpdateEvent,
UnsupportedOperationError,
)
class A2AServer:
"""A2A protocol server implementation."""
def __init__(
self,
task_manager: Optional[TaskManager] = None,
enable_cors: Optional[bool] = None,
cors_origins: Optional[List[str]] = None,
config: Optional["A2AConfig"] = None,
):
"""Initialize the A2A server.
Args:
task_manager: The task manager to use. If None, an InMemoryTaskManager will be created.
enable_cors: Whether to enable CORS. If None, uses config value.
cors_origins: The CORS origins to allow. If None, uses config value.
config: The A2A configuration. If provided, other parameters are ignored.
"""
from crewai.a2a.config import A2AConfig
self.config = config or A2AConfig.from_env()
enable_cors = enable_cors if enable_cors is not None else self.config.enable_cors
cors_origins = cors_origins or self.config.cors_origins
self.app = FastAPI(
title="A2A Protocol Server",
description="""
A2A (Agent-to-Agent) protocol server for CrewAI.
This server implements Google's A2A protocol specification, enabling interoperability
between different agent systems. It provides endpoints for task creation, retrieval,
cancellation, and streaming updates.
""",
version="1.0.0",
docs_url="/docs",
redoc_url="/redoc",
openapi_tags=[
{
"name": "tasks",
"description": "Operations for managing A2A tasks",
},
{
"name": "jsonrpc",
"description": "JSON-RPC interface for the A2A protocol",
},
],
)
self.task_manager = task_manager or InMemoryTaskManager()
self.logger = logging.getLogger(__name__)
if enable_cors:
self.app.add_middleware(
CORSMiddleware,
allow_origins=cors_origins or ["*"],
allow_credentials=True,
allow_methods=["*"],
allow_headers=["*"],
)
@self.app.post(
"/v1/jsonrpc",
summary="Handle JSON-RPC requests",
description="""
Process JSON-RPC requests for the A2A protocol.
This endpoint handles all JSON-RPC requests for the A2A protocol, including:
- SendTask: Create a new task
- GetTask: Retrieve a task by ID
- CancelTask: Cancel a running task
- SetTaskPushNotification: Configure push notifications for a task
- GetTaskPushNotification: Retrieve push notification configuration for a task
""",
response_model=JSONRPCResponse,
responses={
200: {"description": "Successful response with result or error"},
400: {"description": "Invalid request format or parameters"},
500: {"description": "Internal server error during processing"},
},
tags=["jsonrpc"],
)
async def handle_jsonrpc(request: Request):
return await self.handle_jsonrpc(request)
@self.app.post(
"/v1/tasks/send",
summary="Send a task to an agent",
description="""
Create a new task and send it to an agent for execution.
This endpoint allows clients to send tasks to agents for processing.
The task is created with the provided parameters and immediately
transitions to the WORKING state. The response includes the created
task with its current status.
""",
response_model=Task,
responses={
200: {"description": "Task created successfully and processing started"},
400: {"description": "Invalid request format or parameters"},
500: {"description": "Internal server error during task creation or processing"},
},
tags=["tasks"],
)
async def handle_send_task(request: Request):
return await self.handle_send_task(request)
@self.app.post(
"/v1/tasks/sendSubscribe",
summary="Send a task and subscribe to updates",
description="""
Create a new task and subscribe to status updates via Server-Sent Events (SSE).
This endpoint allows clients to send tasks to agents and receive real-time
updates as the task progresses. The response is a streaming SSE connection
that provides status updates and artifact notifications until the task
reaches a terminal state (COMPLETED, FAILED, CANCELED, or EXPIRED).
""",
responses={
200: {
"description": "Streaming response with task updates",
"content": {
"text/event-stream": {
"schema": {"type": "string"},
"example": 'event: status\ndata: {"task_id": "123", "status": {"state": "WORKING"}}\n\n',
}
},
},
400: {"description": "Invalid request format or parameters"},
500: {"description": "Internal server error during task creation or processing"},
},
tags=["tasks"],
)
async def handle_send_task_subscribe(request: Request):
return await self.handle_send_task_subscribe(request)
@self.app.post(
"/v1/tasks/{task_id}/cancel",
summary="Cancel a task",
description="""
Cancel a running task by ID.
This endpoint allows clients to cancel a task that is currently in progress.
The task must be in a non-terminal state (PENDING, WORKING) to be canceled.
Once canceled, the task transitions to the CANCELED state and cannot be
resumed. The response includes the updated task with its current status.
""",
response_model=Task,
responses={
200: {"description": "Task canceled successfully and status updated to CANCELED"},
404: {"description": "Task not found or already expired"},
409: {"description": "Task cannot be canceled (already in terminal state)"},
500: {"description": "Internal server error during task cancellation"},
},
tags=["tasks"],
)
async def handle_cancel_task(task_id: str, request: Request):
return await self.handle_cancel_task(task_id, request)
@self.app.get(
"/v1/tasks/{task_id}",
summary="Get task details",
description="""
Retrieve details of a task by ID.
This endpoint allows clients to retrieve the current state and details of a task.
The response includes the task's status, history, and any associated metadata.
Clients can specify the history_length parameter to limit the number of messages
included in the response.
""",
response_model=Task,
responses={
200: {"description": "Task details retrieved successfully with current status"},
404: {"description": "Task not found or expired"},
500: {"description": "Internal server error during task retrieval"},
},
tags=["tasks"],
)
async def handle_get_task(task_id: str, request: Request):
return await self.handle_get_task(task_id, request)
async def handle_jsonrpc(self, request: Request) -> JSONResponse:
"""Handle JSON-RPC requests.
Args:
request: The FastAPI request.
Returns:
A JSON response.
"""
try:
body = await request.json()
except json.JSONDecodeError:
return JSONResponse(
content=JSONRPCResponse(
id=None, error=JSONParseError()
).model_dump(),
status_code=400,
)
try:
if isinstance(body, list):
responses = []
for req_data in body:
response = await self._process_jsonrpc_request(req_data)
responses.append(response.model_dump())
return JSONResponse(content=responses)
else:
response = await self._process_jsonrpc_request(body)
return JSONResponse(content=response.model_dump())
except Exception as e:
self.logger.exception("Error processing JSON-RPC request")
return JSONResponse(
content=JSONRPCResponse(
id=body.get("id") if isinstance(body, dict) else None,
error=InternalError(message="Internal server error"),
).model_dump(),
status_code=500,
)
async def _process_jsonrpc_request(
self, request_data: Dict[str, Any]
) -> JSONRPCResponse:
"""Process a JSON-RPC request.
Args:
request_data: The JSON-RPC request data.
Returns:
A JSON-RPC response.
"""
if not isinstance(request_data, dict) or request_data.get("jsonrpc") != "2.0":
return JSONRPCResponse(
id=request_data.get("id") if isinstance(request_data, dict) else None,
error=InvalidRequestError(),
)
request_id = request_data.get("id")
method = request_data.get("method")
if not method:
return JSONRPCResponse(
id=request_id,
error=InvalidRequestError(message="Method is required"),
)
try:
request = A2ARequest.validate_python(request_data)
except ValidationError as e:
return JSONRPCResponse(
id=request_id,
error=InvalidParamsError(data=str(e)),
)
try:
if isinstance(request, SendTaskRequest):
task = await self._handle_send_task(request.params)
return SendTaskResponse(id=request_id, result=task)
elif isinstance(request, GetTaskRequest):
task = await self.task_manager.get_task(
request.params.id, request.params.historyLength
)
return GetTaskResponse(id=request_id, result=task)
elif isinstance(request, CancelTaskRequest):
task = await self.task_manager.cancel_task(request.params.id)
return CancelTaskResponse(id=request_id, result=task)
elif isinstance(request, SetTaskPushNotificationRequest):
config = await self.task_manager.set_push_notification(
request.params.id, request.params.pushNotificationConfig
)
return SetTaskPushNotificationResponse(
id=request_id, result=TaskPushNotificationConfig(id=request.params.id, pushNotificationConfig=config)
)
elif isinstance(request, GetTaskPushNotificationRequest):
config = await self.task_manager.get_push_notification(
request.params.id
)
if config:
return GetTaskPushNotificationResponse(
id=request_id, result=TaskPushNotificationConfig(id=request.params.id, pushNotificationConfig=config)
)
else:
return GetTaskPushNotificationResponse(id=request_id, result=None)
elif isinstance(request, SendTaskStreamingRequest):
return JSONRPCResponse(
id=request_id,
error=UnsupportedOperationError(
message="Streaming requests should be sent to the streaming endpoint"
),
)
else:
return JSONRPCResponse(
id=request_id,
error=MethodNotFoundError(),
)
except KeyError:
return JSONRPCResponse(
id=request_id,
error=TaskNotFoundError(),
)
except Exception as e:
self.logger.exception(f"Error handling {method} request")
return JSONRPCResponse(
id=request_id,
error=InternalError(message="Internal server error"),
)
async def handle_send_task(self, request: Request) -> JSONResponse:
"""Handle send task requests.
Args:
request: The FastAPI request.
Returns:
A JSON response.
"""
try:
body = await request.json()
params = TaskSendParams.model_validate(body)
task = await self._handle_send_task(params)
return JSONResponse(content=task.model_dump())
except ValidationError:
return JSONResponse(
content={"error": "Invalid request format or parameters"},
status_code=400,
)
except Exception as e:
self.logger.exception("Error handling send task request")
return JSONResponse(
content={"error": "Internal server error"},
status_code=500,
)
async def _handle_send_task(self, params: TaskSendParams) -> Task:
"""Handle send task requests.
Args:
params: The task send parameters.
Returns:
The created task.
"""
task = await self.task_manager.create_task(
task_id=params.id,
session_id=params.sessionId,
message=params.message,
metadata=params.metadata,
)
await self.task_manager.update_task_status(
task_id=params.id,
state=TaskState.WORKING,
)
return task
async def handle_send_task_subscribe(self, request: Request) -> StreamingResponse:
"""Handle send task subscribe requests.
Args:
request: The FastAPI request.
Returns:
A streaming response.
"""
try:
body = await request.json()
params = TaskSendParams.model_validate(body)
task = await self._handle_send_task(params)
queue = await self.task_manager.subscribe_to_task(params.id)
return StreamingResponse(
self._stream_task_updates(params.id, queue),
media_type="text/event-stream",
)
except ValidationError:
return JSONResponse(
content={"error": "Invalid request format or parameters"},
status_code=400,
)
except Exception as e:
self.logger.exception("Error handling send task subscribe request")
return JSONResponse(
content={"error": "Internal server error"},
status_code=500,
)
async def _stream_task_updates(
self, task_id: str, queue: asyncio.Queue
) -> None:
"""Stream task updates.
Args:
task_id: The ID of the task.
queue: The queue to receive updates from.
Yields:
SSE formatted events.
"""
try:
while True:
event = await queue.get()
if isinstance(event, TaskStatusUpdateEvent):
event_type = "status"
elif isinstance(event, TaskArtifactUpdateEvent):
event_type = "artifact"
else:
event_type = "unknown"
data = json.dumps(event.model_dump())
yield f"event: {event_type}\ndata: {data}\n\n"
if isinstance(event, TaskStatusUpdateEvent) and event.final:
break
finally:
await self.task_manager.unsubscribe_from_task(task_id, queue)
async def handle_get_task(self, task_id: str, request: Request) -> JSONResponse:
"""Handle get task requests.
Args:
task_id: The ID of the task.
request: The FastAPI request.
Returns:
A JSON response.
"""
try:
history_length = request.query_params.get("historyLength")
history_length = int(history_length) if history_length else None
task = await self.task_manager.get_task(task_id, history_length)
return JSONResponse(content=task.model_dump())
except KeyError:
raise HTTPException(status_code=404, detail=f"Task {task_id} not found")
except Exception as e:
self.logger.exception(f"Error handling get task request for {task_id}")
raise HTTPException(status_code=500, detail="Internal server error")
async def handle_cancel_task(self, task_id: str, request: Request) -> JSONResponse:
"""Handle cancel task requests.
Args:
task_id: The ID of the task.
request: The FastAPI request.
Returns:
A JSON response.
"""
try:
task = await self.task_manager.cancel_task(task_id)
return JSONResponse(content=task.model_dump())
except KeyError:
raise HTTPException(status_code=404, detail=f"Task {task_id} not found")
except Exception as e:
self.logger.exception(f"Error handling cancel task request for {task_id}")
raise HTTPException(status_code=500, detail="Internal server error")

View File

@@ -0,0 +1,522 @@
"""
A2A protocol task manager for CrewAI.
This module implements the task manager for the A2A protocol in CrewAI.
"""
import asyncio
import logging
from abc import ABC, abstractmethod
from datetime import datetime
from typing import Any, Dict, List, Optional, Set, TYPE_CHECKING, Union
from uuid import uuid4
if TYPE_CHECKING:
from crewai.a2a.config import A2AConfig
from crewai.types.a2a import (
Artifact,
Message,
PushNotificationConfig,
Task,
TaskArtifactUpdateEvent,
TaskState,
TaskStatus,
TaskStatusUpdateEvent,
)
class TaskManager(ABC):
"""Abstract base class for A2A task managers."""
@abstractmethod
async def create_task(
self,
task_id: str,
session_id: Optional[str] = None,
message: Optional[Message] = None,
metadata: Optional[Dict[str, Any]] = None,
) -> Task:
"""Create a new task.
Args:
task_id: The ID of the task.
session_id: The session ID.
message: The initial message.
metadata: Additional metadata.
Returns:
The created task.
"""
pass
@abstractmethod
async def get_task(
self, task_id: str, history_length: Optional[int] = None
) -> Task:
"""Get a task by ID.
Args:
task_id: The ID of the task.
history_length: The number of messages to include in the history.
Returns:
The task.
Raises:
KeyError: If the task is not found.
"""
pass
@abstractmethod
async def update_task_status(
self,
task_id: str,
state: TaskState,
message: Optional[Message] = None,
metadata: Optional[Dict[str, Any]] = None,
) -> TaskStatusUpdateEvent:
"""Update the status of a task.
Args:
task_id: The ID of the task.
state: The new state of the task.
message: An optional message to include with the status update.
metadata: Additional metadata.
Returns:
The task status update event.
Raises:
KeyError: If the task is not found.
"""
pass
@abstractmethod
async def add_task_artifact(
self,
task_id: str,
artifact: Artifact,
metadata: Optional[Dict[str, Any]] = None,
) -> TaskArtifactUpdateEvent:
"""Add an artifact to a task.
Args:
task_id: The ID of the task.
artifact: The artifact to add.
metadata: Additional metadata.
Returns:
The task artifact update event.
Raises:
KeyError: If the task is not found.
"""
pass
@abstractmethod
async def cancel_task(self, task_id: str) -> Task:
"""Cancel a task.
Args:
task_id: The ID of the task.
Returns:
The canceled task.
Raises:
KeyError: If the task is not found.
"""
pass
@abstractmethod
async def set_push_notification(
self, task_id: str, config: PushNotificationConfig
) -> PushNotificationConfig:
"""Set push notification for a task.
Args:
task_id: The ID of the task.
config: The push notification configuration.
Returns:
The push notification configuration.
Raises:
KeyError: If the task is not found.
"""
pass
@abstractmethod
async def get_push_notification(
self, task_id: str
) -> Optional[PushNotificationConfig]:
"""Get push notification for a task.
Args:
task_id: The ID of the task.
Returns:
The push notification configuration, or None if not set.
Raises:
KeyError: If the task is not found.
"""
pass
class InMemoryTaskManager(TaskManager):
"""In-memory implementation of the A2A task manager."""
def __init__(
self,
task_ttl: Optional[int] = None,
cleanup_interval: Optional[int] = None,
config: Optional["A2AConfig"] = None,
):
"""Initialize the in-memory task manager.
Args:
task_ttl: Time to live for tasks in seconds. Default is 1 hour.
cleanup_interval: Interval for cleaning up expired tasks in seconds. Default is 5 minutes.
config: The A2A configuration. If provided, other parameters are ignored.
"""
from crewai.a2a.config import A2AConfig
self.config = config or A2AConfig.from_env()
self._task_ttl = task_ttl if task_ttl is not None else self.config.task_ttl
self._cleanup_interval = cleanup_interval if cleanup_interval is not None else self.config.cleanup_interval
self._tasks: Dict[str, Task] = {}
self._push_notifications: Dict[str, PushNotificationConfig] = {}
self._task_subscribers: Dict[str, Set[asyncio.Queue]] = {}
self._task_timestamps: Dict[str, datetime] = {}
self._logger = logging.getLogger(__name__)
self._cleanup_task = None
try:
if asyncio.get_running_loop():
self._cleanup_task = asyncio.create_task(self._periodic_cleanup())
except RuntimeError:
self._logger.info("No running event loop, periodic cleanup disabled")
async def create_task(
self,
task_id: str,
session_id: Optional[str] = None,
message: Optional[Message] = None,
metadata: Optional[Dict[str, Any]] = None,
) -> Task:
"""Create a new task.
Args:
task_id: The ID of the task.
session_id: The session ID.
message: The initial message.
metadata: Additional metadata.
Returns:
The created task.
"""
if task_id in self._tasks:
return self._tasks[task_id]
session_id = session_id or uuid4().hex
status = TaskStatus(
state=TaskState.SUBMITTED,
message=message,
timestamp=datetime.now(),
previous_state=None, # Initial state has no previous state
)
task = Task(
id=task_id,
sessionId=session_id,
status=status,
artifacts=[],
history=[message] if message else [],
metadata=metadata or {},
)
self._tasks[task_id] = task
self._task_subscribers[task_id] = set()
self._task_timestamps[task_id] = datetime.now()
return task
async def get_task(
self, task_id: str, history_length: Optional[int] = None
) -> Task:
"""Get a task by ID.
Args:
task_id: The ID of the task.
history_length: The number of messages to include in the history.
Returns:
The task.
Raises:
KeyError: If the task is not found.
"""
if task_id not in self._tasks:
raise KeyError(f"Task {task_id} not found")
task = self._tasks[task_id]
if history_length is not None and task.history:
task_copy = task.model_copy(deep=True)
task_copy.history = task.history[-history_length:]
return task_copy
return task
async def update_task_status(
self,
task_id: str,
state: TaskState,
message: Optional[Message] = None,
metadata: Optional[Dict[str, Any]] = None,
) -> TaskStatusUpdateEvent:
"""Update the status of a task.
Args:
task_id: The ID of the task.
state: The new state of the task.
message: An optional message to include with the status update.
metadata: Additional metadata.
Returns:
The task status update event.
Raises:
KeyError: If the task is not found.
"""
if task_id not in self._tasks:
raise KeyError(f"Task {task_id} not found")
task = self._tasks[task_id]
task = self._tasks[task_id]
previous_state = task.status.state if task.status else None
if previous_state and not TaskState.is_valid_transition(previous_state, state):
raise ValueError(f"Invalid state transition from {previous_state} to {state}")
status = TaskStatus(
state=state,
message=message,
timestamp=datetime.now(),
previous_state=previous_state,
)
task.status = status
if message and task.history is not None:
task.history.append(message)
self._task_timestamps[task_id] = datetime.now()
event = TaskStatusUpdateEvent(
id=task_id,
status=status,
final=state in [TaskState.COMPLETED, TaskState.CANCELED, TaskState.FAILED, TaskState.EXPIRED],
metadata=metadata or {},
)
await self._notify_subscribers(task_id, event)
return event
async def add_task_artifact(
self,
task_id: str,
artifact: Artifact,
metadata: Optional[Dict[str, Any]] = None,
) -> TaskArtifactUpdateEvent:
"""Add an artifact to a task.
Args:
task_id: The ID of the task.
artifact: The artifact to add.
metadata: Additional metadata.
Returns:
The task artifact update event.
Raises:
KeyError: If the task is not found.
"""
if task_id not in self._tasks:
raise KeyError(f"Task {task_id} not found")
task = self._tasks[task_id]
if task.artifacts is None:
task.artifacts = []
if artifact.append and task.artifacts:
for existing in task.artifacts:
if existing.name == artifact.name:
existing.parts.extend(artifact.parts)
existing.lastChunk = artifact.lastChunk
break
else:
task.artifacts.append(artifact)
else:
task.artifacts.append(artifact)
event = TaskArtifactUpdateEvent(
id=task_id,
artifact=artifact,
metadata=metadata or {},
)
await self._notify_subscribers(task_id, event)
return event
async def cancel_task(self, task_id: str) -> Task:
"""Cancel a task.
Args:
task_id: The ID of the task.
Returns:
The canceled task.
Raises:
KeyError: If the task is not found.
"""
if task_id not in self._tasks:
raise KeyError(f"Task {task_id} not found")
task = self._tasks[task_id]
if task.status.state not in [TaskState.COMPLETED, TaskState.CANCELED, TaskState.FAILED]:
await self.update_task_status(task_id, TaskState.CANCELED)
return task
async def set_push_notification(
self, task_id: str, config: PushNotificationConfig
) -> PushNotificationConfig:
"""Set push notification for a task.
Args:
task_id: The ID of the task.
config: The push notification configuration.
Returns:
The push notification configuration.
Raises:
KeyError: If the task is not found.
"""
if task_id not in self._tasks:
raise KeyError(f"Task {task_id} not found")
self._push_notifications[task_id] = config
return config
async def get_push_notification(
self, task_id: str
) -> Optional[PushNotificationConfig]:
"""Get push notification for a task.
Args:
task_id: The ID of the task.
Returns:
The push notification configuration, or None if not set.
Raises:
KeyError: If the task is not found.
"""
if task_id not in self._tasks:
raise KeyError(f"Task {task_id} not found")
return self._push_notifications.get(task_id)
async def subscribe_to_task(self, task_id: str) -> asyncio.Queue:
"""Subscribe to task updates.
Args:
task_id: The ID of the task.
Returns:
A queue that will receive task updates.
Raises:
KeyError: If the task is not found.
"""
if task_id not in self._tasks:
raise KeyError(f"Task {task_id} not found")
queue: asyncio.Queue = asyncio.Queue()
self._task_subscribers.setdefault(task_id, set()).add(queue)
return queue
async def unsubscribe_from_task(self, task_id: str, queue: asyncio.Queue) -> None:
"""Unsubscribe from task updates.
Args:
task_id: The ID of the task.
queue: The queue to unsubscribe.
"""
if task_id in self._task_subscribers:
self._task_subscribers[task_id].discard(queue)
async def _notify_subscribers(
self,
task_id: str,
event: Union[TaskStatusUpdateEvent, TaskArtifactUpdateEvent],
) -> None:
"""Notify subscribers of a task update.
Args:
task_id: The ID of the task.
event: The event to send to subscribers.
"""
if task_id in self._task_subscribers:
for queue in self._task_subscribers[task_id]:
await queue.put(event)
async def _periodic_cleanup(self) -> None:
"""Periodically clean up expired tasks."""
while True:
try:
await asyncio.sleep(self._cleanup_interval)
await self._cleanup_expired_tasks()
except asyncio.CancelledError:
break
except Exception as e:
self._logger.exception(f"Error during periodic cleanup: {e}")
async def _cleanup_expired_tasks(self) -> None:
"""Clean up expired tasks."""
now = datetime.now()
expired_tasks = []
for task_id, timestamp in self._task_timestamps.items():
if (now - timestamp).total_seconds() > self._task_ttl:
expired_tasks.append(task_id)
for task_id in expired_tasks:
self._logger.info(f"Cleaning up expired task: {task_id}")
self._tasks.pop(task_id, None)
self._push_notifications.pop(task_id, None)
self._task_timestamps.pop(task_id, None)
if task_id in self._task_subscribers:
previous_state = None
if task_id in self._tasks and self._tasks[task_id].status:
previous_state = self._tasks[task_id].status.state
status = TaskStatus(
state=TaskState.EXPIRED,
timestamp=now,
previous_state=previous_state,
)
event = TaskStatusUpdateEvent(
task_id=task_id,
status=status,
final=True,
)
await self._notify_subscribers(task_id, event)
self._task_subscribers.pop(task_id, None)

View File

@@ -5,6 +5,7 @@ from typing import Any, Dict, List, Literal, Optional, Union
from pydantic import Field, InstanceOf, PrivateAttr, model_validator
from crewai.a2a import A2AAgentIntegration
from crewai.agents import CacheHandler
from crewai.agents.agent_builder.base_agent import BaseAgent
from crewai.agents.crew_agent_executor import CrewAgentExecutor
@@ -131,14 +132,29 @@ class Agent(BaseAgent):
default=None,
description="Knowledge sources for the agent.",
)
a2a_enabled: bool = Field(
default=False,
description="Whether the agent supports the A2A protocol.",
)
a2a_url: Optional[str] = Field(
default=None,
description="The URL where the agent's A2A server is hosted.",
)
_knowledge: Optional[Knowledge] = PrivateAttr(
default=None,
)
_a2a_integration: Optional[A2AAgentIntegration] = PrivateAttr(
default=None,
)
@model_validator(mode="after")
def post_init_setup(self):
self._set_knowledge()
self.agent_ops_agent_name = self.role
if self.a2a_enabled:
self._a2a_integration = A2AAgentIntegration()
unaccepted_attributes = [
"AWS_ACCESS_KEY_ID",
"AWS_SECRET_ACCESS_KEY",
@@ -355,6 +371,103 @@ class Agent(BaseAgent):
result = tool_result["result"]
return result
async def execute_task_via_a2a(
self,
task_description: str,
context: Optional[str] = None,
agent_url: Optional[str] = None,
api_key: Optional[str] = None,
timeout: int = 300,
) -> str:
"""Execute a task via the A2A protocol.
Args:
task_description: The description of the task.
context: Additional context for the task.
agent_url: The URL of the agent to execute the task. Defaults to self.a2a_url.
api_key: The API key to use for authentication.
timeout: The timeout for the task execution in seconds.
Returns:
The result of the task execution.
Raises:
ValueError: If A2A is not enabled or no agent URL is provided.
TimeoutError: If the task execution times out.
Exception: If there is an error executing the task.
"""
if not self.a2a_enabled:
raise ValueError("A2A protocol is not enabled for this agent")
if not self._a2a_integration:
self._a2a_integration = A2AAgentIntegration()
url = agent_url or self.a2a_url
if not url:
raise ValueError("No A2A agent URL provided")
try:
import asyncio
if asyncio.get_event_loop().is_running():
return await self._a2a_integration.execute_task_via_a2a(
agent_url=url,
task_description=task_description,
context=context,
api_key=api_key,
timeout=timeout,
)
else:
return asyncio.run(self._a2a_integration.execute_task_via_a2a(
agent_url=url,
task_description=task_description,
context=context,
api_key=api_key,
timeout=timeout,
))
except Exception as e:
self._logger.exception(f"Error executing task via A2A: {e}")
raise
async def handle_a2a_task(
self,
task_id: str,
task_description: str,
context: Optional[str] = None,
) -> str:
"""Handle an A2A task.
Args:
task_id: The ID of the A2A task.
task_description: The description of the task.
context: Additional context for the task.
Returns:
The result of the task execution.
Raises:
ValueError: If A2A is not enabled.
Exception: If there is an error handling the task.
"""
if not self.a2a_enabled:
raise ValueError("A2A protocol is not enabled for this agent")
if not self._a2a_integration:
self._a2a_integration = A2AAgentIntegration()
# Create a Task object from the task description
task = Task(
description=task_description,
agent=self,
expected_output="text", # Default to text output
)
try:
result = self.execute_task(task=task, context=context)
return result
except Exception as e:
self._logger.exception(f"Error handling A2A task: {e}")
raise
def create_agent_executor(
self, tools: Optional[List[BaseTool]] = None, task=None

View File

@@ -0,0 +1 @@
"""Type definitions for CrewAI."""

469
src/crewai/types/a2a.py Normal file
View File

@@ -0,0 +1,469 @@
"""
A2A protocol types for CrewAI.
This module implements the A2A (Agent-to-Agent) protocol types as defined by Google.
The A2A protocol enables interoperability between different agent systems.
For more information, see: https://developers.googleblog.com/en/a2a-a-new-era-of-agent-interoperability/
"""
from datetime import datetime
from enum import Enum
from typing import Annotated, Any, Dict, List, Literal, Optional, Self, Union
from uuid import uuid4
from pydantic import BaseModel, ConfigDict, Field, TypeAdapter, field_serializer, model_validator
class TaskState(str, Enum):
"""Task state in the A2A protocol."""
SUBMITTED = 'submitted'
WORKING = 'working'
INPUT_REQUIRED = 'input-required'
COMPLETED = 'completed'
CANCELED = 'canceled'
FAILED = 'failed'
UNKNOWN = 'unknown'
EXPIRED = 'expired'
@classmethod
def valid_transitions(cls) -> Dict[str, List[str]]:
"""Get valid state transitions.
Returns:
A dictionary mapping from state to list of valid next states.
"""
return {
cls.SUBMITTED: [cls.WORKING, cls.CANCELED, cls.FAILED],
cls.WORKING: [cls.INPUT_REQUIRED, cls.COMPLETED, cls.CANCELED, cls.FAILED],
cls.INPUT_REQUIRED: [cls.WORKING, cls.CANCELED, cls.FAILED],
cls.COMPLETED: [], # Terminal state
cls.CANCELED: [], # Terminal state
cls.FAILED: [], # Terminal state
cls.UNKNOWN: [cls.SUBMITTED, cls.WORKING, cls.INPUT_REQUIRED, cls.COMPLETED, cls.CANCELED, cls.FAILED],
cls.EXPIRED: [], # Terminal state
}
@classmethod
def is_valid_transition(cls, from_state: 'TaskState', to_state: 'TaskState') -> bool:
"""Check if a state transition is valid.
Args:
from_state: The current state.
to_state: The target state.
Returns:
True if the transition is valid, False otherwise.
"""
if from_state == to_state:
return True
valid_next_states = cls.valid_transitions().get(from_state, [])
return to_state in valid_next_states
class TextPart(BaseModel):
"""Text part in the A2A protocol."""
type: Literal['text'] = 'text'
text: str
metadata: Optional[Dict[str, Any]] = None
class FileContent(BaseModel):
"""File content in the A2A protocol."""
name: Optional[str] = None
mimeType: Optional[str] = None
bytes: Optional[str] = None
uri: Optional[str] = None
@model_validator(mode='after')
def check_content(self) -> Self:
"""Validate file content has either bytes or uri."""
if not (self.bytes or self.uri):
raise ValueError(
"Either 'bytes' or 'uri' must be present in the file data"
)
if self.bytes and self.uri:
raise ValueError(
"Only one of 'bytes' or 'uri' can be present in the file data"
)
return self
class FilePart(BaseModel):
"""File part in the A2A protocol."""
type: Literal['file'] = 'file'
file: FileContent
metadata: Optional[Dict[str, Any]] = None
class DataPart(BaseModel):
"""Data part in the A2A protocol."""
type: Literal['data'] = 'data'
data: Dict[str, Any]
metadata: Optional[Dict[str, Any]] = None
Part = Annotated[Union[TextPart, FilePart, DataPart], Field(discriminator='type')]
class Message(BaseModel):
"""Message in the A2A protocol."""
role: Literal['user', 'agent']
parts: List[Part]
metadata: Optional[Dict[str, Any]] = None
class TaskStatus(BaseModel):
"""Task status in the A2A protocol."""
state: TaskState
message: Optional[Message] = None
timestamp: datetime = Field(default_factory=datetime.now)
previous_state: Optional[TaskState] = None
@field_serializer('timestamp')
def serialize_dt(self, dt: datetime, _info):
"""Serialize datetime to ISO format."""
return dt.isoformat()
@model_validator(mode='after')
def validate_state_transition(self) -> Self:
"""Validate state transition."""
if self.previous_state and not TaskState.is_valid_transition(self.previous_state, self.state):
raise ValueError(
f"Invalid state transition from {self.previous_state} to {self.state}"
)
return self
class Artifact(BaseModel):
"""Artifact in the A2A protocol."""
name: Optional[str] = None
description: Optional[str] = None
parts: List[Part]
metadata: Optional[Dict[str, Any]] = None
index: int = 0
append: Optional[bool] = None
lastChunk: Optional[bool] = None
class Task(BaseModel):
"""Task in the A2A protocol."""
id: str
sessionId: Optional[str] = None
status: TaskStatus
artifacts: Optional[List[Artifact]] = None
history: Optional[List[Message]] = None
metadata: Optional[Dict[str, Any]] = None
class TaskStatusUpdateEvent(BaseModel):
"""Task status update event in the A2A protocol."""
id: str
status: TaskStatus
final: bool = False
metadata: Optional[Dict[str, Any]] = None
class TaskArtifactUpdateEvent(BaseModel):
"""Task artifact update event in the A2A protocol."""
id: str
artifact: Artifact
metadata: Optional[Dict[str, Any]] = None
class AuthenticationInfo(BaseModel):
"""Authentication information in the A2A protocol."""
model_config = ConfigDict(extra='allow')
schemes: List[str]
credentials: Optional[str] = None
class PushNotificationConfig(BaseModel):
"""Push notification configuration in the A2A protocol."""
url: str
token: Optional[str] = None
authentication: Optional[AuthenticationInfo] = None
class TaskIdParams(BaseModel):
"""Task ID parameters in the A2A protocol."""
id: str
metadata: Optional[Dict[str, Any]] = None
class TaskQueryParams(TaskIdParams):
"""Task query parameters in the A2A protocol."""
historyLength: Optional[int] = None
class TaskSendParams(BaseModel):
"""Task send parameters in the A2A protocol."""
id: str
sessionId: str = Field(default_factory=lambda: uuid4().hex)
message: Message
acceptedOutputModes: Optional[List[str]] = None
pushNotification: Optional[PushNotificationConfig] = None
historyLength: Optional[int] = None
metadata: Optional[Dict[str, Any]] = None
class TaskPushNotificationConfig(BaseModel):
"""Task push notification configuration in the A2A protocol."""
id: str
pushNotificationConfig: PushNotificationConfig
class JSONRPCMessage(BaseModel):
"""JSON-RPC message in the A2A protocol."""
jsonrpc: Literal['2.0'] = '2.0'
id: Optional[Union[int, str]] = Field(default_factory=lambda: uuid4().hex)
class JSONRPCRequest(JSONRPCMessage):
"""JSON-RPC request in the A2A protocol."""
method: str
params: Optional[Dict[str, Any]] = None
class JSONRPCError(BaseModel):
"""JSON-RPC error in the A2A protocol."""
code: int
message: str
data: Optional[Any] = None
class JSONRPCResponse(JSONRPCMessage):
"""JSON-RPC response in the A2A protocol."""
result: Optional[Any] = None
error: Optional[JSONRPCError] = None
class SendTaskRequest(JSONRPCRequest):
"""Send task request in the A2A protocol."""
method: Literal['tasks/send'] = 'tasks/send'
params: TaskSendParams
class SendTaskResponse(JSONRPCResponse):
"""Send task response in the A2A protocol."""
result: Optional[Task] = None
class SendTaskStreamingRequest(JSONRPCRequest):
"""Send task streaming request in the A2A protocol."""
method: Literal['tasks/sendSubscribe'] = 'tasks/sendSubscribe'
params: TaskSendParams
class SendTaskStreamingResponse(JSONRPCResponse):
"""Send task streaming response in the A2A protocol."""
result: Optional[Union[TaskStatusUpdateEvent, TaskArtifactUpdateEvent]] = None
class GetTaskRequest(JSONRPCRequest):
"""Get task request in the A2A protocol."""
method: Literal['tasks/get'] = 'tasks/get'
params: TaskQueryParams
class GetTaskResponse(JSONRPCResponse):
"""Get task response in the A2A protocol."""
result: Optional[Task] = None
class CancelTaskRequest(JSONRPCRequest):
"""Cancel task request in the A2A protocol."""
method: Literal['tasks/cancel'] = 'tasks/cancel'
params: TaskIdParams
class CancelTaskResponse(JSONRPCResponse):
"""Cancel task response in the A2A protocol."""
result: Optional[Task] = None
class SetTaskPushNotificationRequest(JSONRPCRequest):
"""Set task push notification request in the A2A protocol."""
method: Literal['tasks/pushNotification/set'] = 'tasks/pushNotification/set'
params: TaskPushNotificationConfig
class SetTaskPushNotificationResponse(JSONRPCResponse):
"""Set task push notification response in the A2A protocol."""
result: Optional[TaskPushNotificationConfig] = None
class GetTaskPushNotificationRequest(JSONRPCRequest):
"""Get task push notification request in the A2A protocol."""
method: Literal['tasks/pushNotification/get'] = 'tasks/pushNotification/get'
params: TaskIdParams
class GetTaskPushNotificationResponse(JSONRPCResponse):
"""Get task push notification response in the A2A protocol."""
result: Optional[TaskPushNotificationConfig] = None
class TaskResubscriptionRequest(JSONRPCRequest):
"""Task resubscription request in the A2A protocol."""
method: Literal['tasks/resubscribe'] = 'tasks/resubscribe'
params: TaskIdParams
A2ARequest = TypeAdapter(
Annotated[
Union[
SendTaskRequest,
GetTaskRequest,
CancelTaskRequest,
SetTaskPushNotificationRequest,
GetTaskPushNotificationRequest,
TaskResubscriptionRequest,
SendTaskStreamingRequest,
],
Field(discriminator='method'),
]
)
class JSONParseError(JSONRPCError):
"""JSON parse error in the A2A protocol."""
code: int = -32700
message: str = 'Invalid JSON payload'
data: Optional[Any] = None
class InvalidRequestError(JSONRPCError):
"""Invalid request error in the A2A protocol."""
code: int = -32600
message: str = 'Request payload validation error'
data: Optional[Any] = None
class MethodNotFoundError(JSONRPCError):
"""Method not found error in the A2A protocol."""
code: int = -32601
message: str = 'Method not found'
data: None = None
class InvalidParamsError(JSONRPCError):
"""Invalid parameters error in the A2A protocol."""
code: int = -32602
message: str = 'Invalid parameters'
data: Optional[Any] = None
class InternalError(JSONRPCError):
"""Internal error in the A2A protocol."""
code: int = -32603
message: str = 'Internal error'
data: Optional[Any] = None
class TaskNotFoundError(JSONRPCError):
"""Task not found error in the A2A protocol."""
code: int = -32001
message: str = 'Task not found'
data: None = None
class TaskNotCancelableError(JSONRPCError):
"""Task not cancelable error in the A2A protocol."""
code: int = -32002
message: str = 'Task cannot be canceled'
data: None = None
class PushNotificationNotSupportedError(JSONRPCError):
"""Push notification not supported error in the A2A protocol."""
code: int = -32003
message: str = 'Push Notification is not supported'
data: None = None
class UnsupportedOperationError(JSONRPCError):
"""Unsupported operation error in the A2A protocol."""
code: int = -32004
message: str = 'This operation is not supported'
data: None = None
class ContentTypeNotSupportedError(JSONRPCError):
"""Content type not supported error in the A2A protocol."""
code: int = -32005
message: str = 'Incompatible content types'
data: None = None
class AgentProvider(BaseModel):
"""Agent provider in the A2A protocol."""
organization: str
url: Optional[str] = None
class AgentCapabilities(BaseModel):
"""Agent capabilities in the A2A protocol."""
streaming: bool = False
pushNotifications: bool = False
stateTransitionHistory: bool = False
class AgentAuthentication(BaseModel):
"""Agent authentication in the A2A protocol."""
schemes: List[str]
credentials: Optional[str] = None
class AgentSkill(BaseModel):
"""Agent skill in the A2A protocol."""
id: str
name: str
description: Optional[str] = None
tags: Optional[List[str]] = None
examples: Optional[List[str]] = None
inputModes: Optional[List[str]] = None
outputModes: Optional[List[str]] = None
class AgentCard(BaseModel):
"""Agent card in the A2A protocol."""
name: str
description: Optional[str] = None
url: str
provider: Optional[AgentProvider] = None
version: str
documentationUrl: Optional[str] = None
capabilities: AgentCapabilities
authentication: Optional[AgentAuthentication] = None
defaultInputModes: List[str] = ['text']
defaultOutputModes: List[str] = ['text']
skills: List[AgentSkill]
class A2AClientError(Exception):
"""Base exception for A2A client errors."""
pass
class A2AClientHTTPError(A2AClientError):
"""HTTP error in the A2A client."""
def __init__(self, status_code: int, message: str):
self.status_code = status_code
self.message = message
super().__init__(f'HTTP Error {status_code}: {message}')
class A2AClientJSONError(A2AClientError):
"""JSON error in the A2A client."""
def __init__(self, message: str):
self.message = message
super().__init__(f'JSON Error: {message}')
class MissingAPIKeyError(Exception):
"""Exception for missing API key."""
pass

1
tests/a2a/__init__.py Normal file
View File

@@ -0,0 +1 @@
"""Tests for the A2A protocol implementation."""

View File

@@ -0,0 +1,240 @@
"""Tests for the A2A protocol integration."""
import asyncio
from datetime import datetime
import pytest
from unittest.mock import AsyncMock, MagicMock, patch
pytestmark = pytest.mark.asyncio
from crewai.agent import Agent
from crewai.a2a import A2AAgentIntegration, A2AClient, A2AServer, InMemoryTaskManager
from crewai.task import Task
from crewai.types.a2a import (
JSONRPCResponse,
Message,
Task as A2ATask,
TaskState,
TaskStatus,
TaskStatusUpdateEvent,
TextPart,
)
@pytest.fixture
def agent():
"""Create an agent with A2A enabled."""
return Agent(
role="test_agent",
goal="Test A2A protocol",
backstory="I am a test agent",
a2a_enabled=True,
a2a_url="http://localhost:8000",
)
@pytest.fixture
def task():
"""Create a task."""
return Task(
description="Test task",
)
@pytest.fixture
def a2a_task():
"""Create an A2A task."""
return A2ATask(
id="test_task_id",
history=[
Message(
role="user",
parts=[TextPart(text="Test task description")],
)
],
)
@pytest.fixture
def a2a_integration():
"""Create an A2A integration."""
return A2AAgentIntegration()
@pytest.fixture
def a2a_client():
"""Create an A2A client."""
return A2AClient(base_url="http://localhost:8000", api_key="test_api_key")
@pytest.fixture
def task_manager():
"""Create a task manager."""
return InMemoryTaskManager()
class TestA2AIntegration:
"""Tests for the A2A protocol integration."""
def test_agent_a2a_attributes(self, agent):
"""Test that the agent has A2A attributes."""
assert agent.a2a_enabled is True
assert agent.a2a_url == "http://localhost:8000"
assert agent._a2a_integration is not None
@patch("crewai.a2a.agent.A2AAgentIntegration.execute_task_via_a2a")
def test_execute_task_via_a2a(self, mock_execute, agent):
"""Test executing a task via A2A."""
mock_execute.return_value = "Task result"
result = asyncio.run(
agent.execute_task_via_a2a(
task_description="Test task",
context="Test context",
)
)
assert result == "Task result"
mock_execute.assert_called_once_with(
agent_url="http://localhost:8000",
task_description="Test task",
context="Test context",
api_key=None,
timeout=300,
)
@patch("crewai.agent.Agent.execute_task")
def test_handle_a2a_task(self, mock_execute, agent):
"""Test handling an A2A task."""
mock_execute.return_value = "Task result"
result = asyncio.run(
agent.handle_a2a_task(
task_id="test_task_id",
task_description="Test task",
context="Test context",
)
)
assert result == "Task result"
mock_execute.assert_called_once()
args, kwargs = mock_execute.call_args
assert kwargs["context"] == "Test context"
assert kwargs["task"].description == "Test task"
def test_a2a_disabled(self, agent):
"""Test that A2A methods raise ValueError when A2A is disabled."""
agent.a2a_enabled = False
with pytest.raises(ValueError, match="A2A protocol is not enabled for this agent"):
asyncio.run(
agent.execute_task_via_a2a(
task_description="Test task",
)
)
with pytest.raises(ValueError, match="A2A protocol is not enabled for this agent"):
asyncio.run(
agent.handle_a2a_task(
task_id="test_task_id",
task_description="Test task",
)
)
def test_no_agent_url(self, agent):
"""Test that execute_task_via_a2a raises ValueError when no agent URL is provided."""
agent.a2a_url = None
with pytest.raises(ValueError, match="No A2A agent URL provided"):
asyncio.run(
agent.execute_task_via_a2a(
task_description="Test task",
)
)
class TestA2AAgentIntegration:
"""Tests for the A2AAgentIntegration class."""
@patch("crewai.a2a.client.A2AClient.send_task_streaming")
async def test_execute_task_via_a2a(self, mock_send_task, a2a_integration):
"""Test executing a task via A2A."""
queue = asyncio.Queue()
await queue.put(
TaskStatusUpdateEvent(
id="test_task_id",
status=TaskStatus(
state=TaskState.COMPLETED,
message=Message(
role="agent",
parts=[TextPart(text="Task result")],
),
),
final=True,
)
)
mock_send_task.return_value = queue
result = await a2a_integration.execute_task_via_a2a(
agent_url="http://localhost:8000",
task_description="Test task",
context="Test context",
)
assert result == "Task result"
mock_send_task.assert_called_once()
class TestA2AServer:
"""Tests for the A2AServer class."""
@patch("fastapi.FastAPI.post")
def test_server_initialization(self, mock_post, task_manager):
"""Test server initialization."""
server = A2AServer(task_manager=task_manager)
assert server.task_manager == task_manager
assert server.app is not None
assert mock_post.call_count == 4 # 4 endpoints registered
class TestA2AClient:
"""Tests for the A2AClient class."""
@patch("crewai.a2a.client.A2AClient._send_jsonrpc_request")
async def test_send_task(self, mock_send_request, a2a_client):
"""Test sending a task."""
mock_response = JSONRPCResponse(
jsonrpc="2.0",
id="test_request_id",
result=A2ATask(
id="test_task_id",
sessionId="test_session_id",
status=TaskStatus(
state=TaskState.SUBMITTED,
timestamp=datetime.now(),
),
history=[
Message(
role="user",
parts=[TextPart(text="Test task description")],
)
],
)
)
mock_send_request.return_value = mock_response
task = await a2a_client.send_task(
task_id="test_task_id",
message=Message(
role="user",
parts=[TextPart(text="Test task description")],
),
session_id="test_session_id",
)
assert task.id == "test_task_id"
assert task.history[0].role == "user"
assert task.history[0].parts[0].text == "Test task description"
mock_send_request.assert_called_once()