From cfabb9fa78ac745274a76c6a69608c571bcc5c93 Mon Sep 17 00:00:00 2001 From: Devin AI <158243242+devin-ai-integration[bot]@users.noreply.github.com> Date: Fri, 9 May 2025 04:13:04 +0000 Subject: [PATCH] Add A2A protocol support for CrewAI (Issue #2796) Co-Authored-By: Joe Moura --- examples/a2a_protocol_example.py | 59 ++++ src/crewai/__init__.py | 5 + src/crewai/a2a/__init__.py | 14 + src/crewai/a2a/agent.py | 223 +++++++++++++++ src/crewai/a2a/client.py | 424 +++++++++++++++++++++++++++++ src/crewai/a2a/server.py | 368 +++++++++++++++++++++++++ src/crewai/a2a/task_manager.py | 438 ++++++++++++++++++++++++++++++ src/crewai/agent.py | 112 ++++++++ src/crewai/types/__init__.py | 1 + src/crewai/types/a2a.py | 423 +++++++++++++++++++++++++++++ tests/a2a/__init__.py | 1 + tests/a2a/test_a2a_integration.py | 230 ++++++++++++++++ 12 files changed, 2298 insertions(+) create mode 100644 examples/a2a_protocol_example.py create mode 100644 src/crewai/a2a/__init__.py create mode 100644 src/crewai/a2a/agent.py create mode 100644 src/crewai/a2a/client.py create mode 100644 src/crewai/a2a/server.py create mode 100644 src/crewai/a2a/task_manager.py create mode 100644 src/crewai/types/a2a.py create mode 100644 tests/a2a/__init__.py create mode 100644 tests/a2a/test_a2a_integration.py diff --git a/examples/a2a_protocol_example.py b/examples/a2a_protocol_example.py new file mode 100644 index 000000000..5e964ced9 --- /dev/null +++ b/examples/a2a_protocol_example.py @@ -0,0 +1,59 @@ +""" +Example of using the A2A protocol with CrewAI. + +This example demonstrates how to: +1. Create an agent with A2A protocol support +2. Start an A2A server for the agent +3. Execute a task via the A2A protocol +""" + +import asyncio +import os +import uvicorn +from threading import Thread + +from crewai import Agent +from crewai.a2a import A2AServer, InMemoryTaskManager + + +agent = Agent( + role="Data Analyst", + goal="Analyze data and provide insights", + backstory="I am a data analyst with expertise in finding patterns and insights in data.", + a2a_enabled=True, + a2a_url="http://localhost:8000", +) + + +def start_server(): + """Start the A2A server.""" + task_manager = InMemoryTaskManager() + + server = A2AServer(task_manager=task_manager) + + uvicorn.run(server.app, host="0.0.0.0", port=8000) + + +async def execute_task_via_a2a(): + """Execute a task via the A2A protocol.""" + await asyncio.sleep(2) + + result = await agent.execute_task_via_a2a( + task_description="Analyze the following data and provide insights: [1, 2, 3, 4, 5]", + context="This is a simple example of using the A2A protocol.", + ) + + print(f"Task result: {result}") + + +async def main(): + """Run the example.""" + server_thread = Thread(target=start_server) + server_thread.daemon = True + server_thread.start() + + await execute_task_via_a2a() + + +if __name__ == "__main__": + asyncio.run(main()) diff --git a/src/crewai/__init__.py b/src/crewai/__init__.py index 0833afd58..09015f901 100644 --- a/src/crewai/__init__.py +++ b/src/crewai/__init__.py @@ -23,4 +23,9 @@ __all__ = [ "LLM", "Flow", "Knowledge", + "A2AAgentIntegration", + "A2AClient", + "A2AServer", ] + +from crewai.a2a import A2AAgentIntegration, A2AClient, A2AServer diff --git a/src/crewai/a2a/__init__.py b/src/crewai/a2a/__init__.py new file mode 100644 index 000000000..0ebf9ae20 --- /dev/null +++ b/src/crewai/a2a/__init__.py @@ -0,0 +1,14 @@ +"""A2A protocol implementation for CrewAI.""" + +from crewai.a2a.agent import A2AAgentIntegration +from crewai.a2a.client import A2AClient +from crewai.a2a.server import A2AServer +from crewai.a2a.task_manager import InMemoryTaskManager, TaskManager + +__all__ = [ + "A2AAgentIntegration", + "A2AClient", + "A2AServer", + "TaskManager", + "InMemoryTaskManager", +] diff --git a/src/crewai/a2a/agent.py b/src/crewai/a2a/agent.py new file mode 100644 index 000000000..03bb35920 --- /dev/null +++ b/src/crewai/a2a/agent.py @@ -0,0 +1,223 @@ +""" +A2A protocol agent integration for CrewAI. + +This module implements the integration between CrewAI agents and the A2A protocol. +""" + +import asyncio +import json +import logging +import uuid +from typing import Any, Dict, List, Optional, Union + +from crewai.a2a.client import A2AClient +from crewai.a2a.task_manager import TaskManager +from crewai.types.a2a import ( + Artifact, + DataPart, + FilePart, + Message, + Part, + Task as A2ATask, + TaskArtifactUpdateEvent, + TaskState, + TaskStatusUpdateEvent, + TextPart, +) + + +class A2AAgentIntegration: + """Integration between CrewAI agents and the A2A protocol.""" + + def __init__( + self, + task_manager: Optional[TaskManager] = None, + client: Optional[A2AClient] = None, + ): + """Initialize the A2A agent integration. + + Args: + task_manager: The task manager to use for handling A2A tasks. + client: The A2A client to use for sending tasks to other agents. + """ + self.task_manager = task_manager + self.client = client + self.logger = logging.getLogger(__name__) + + async def execute_task_via_a2a( + self, + agent_url: str, + task_description: str, + context: Optional[str] = None, + api_key: Optional[str] = None, + timeout: int = 300, + ) -> str: + """Execute a task via the A2A protocol. + + Args: + agent_url: The URL of the agent to execute the task. + task_description: The description of the task. + context: Additional context for the task. + api_key: The API key to use for authentication. + timeout: The timeout for the task execution in seconds. + + Returns: + The result of the task execution. + + Raises: + TimeoutError: If the task execution times out. + Exception: If there is an error executing the task. + """ + if not self.client: + self.client = A2AClient(base_url=agent_url, api_key=api_key) + + parts: List[Part] = [TextPart(text=task_description)] + if context: + parts.append( + DataPart( + data={"context": context}, + metadata={"type": "context"}, + ) + ) + + message = Message(role="user", parts=parts) + + task_id = str(uuid.uuid4()) + + try: + queue = await self.client.send_task_streaming( + task_id=task_id, + message=message, + ) + + result = await self._wait_for_task_completion(queue, timeout) + return result + except Exception as e: + self.logger.exception(f"Error executing task via A2A: {e}") + raise + + async def _wait_for_task_completion( + self, queue: asyncio.Queue, timeout: int + ) -> str: + """Wait for a task to complete. + + Args: + queue: The queue to receive task updates from. + timeout: The timeout for the task execution in seconds. + + Returns: + The result of the task execution. + + Raises: + TimeoutError: If the task execution times out. + Exception: If there is an error executing the task. + """ + result = "" + try: + async def _timeout(): + await asyncio.sleep(timeout) + await queue.put(TimeoutError(f"Task execution timed out after {timeout} seconds")) + + timeout_task = asyncio.create_task(_timeout()) + + while True: + event = await queue.get() + + if isinstance(event, Exception): + raise event + + if isinstance(event, TaskStatusUpdateEvent): + if event.status.state == TaskState.COMPLETED: + if event.status.message: + for part in event.status.message.parts: + if isinstance(part, TextPart): + result += part.text + break + elif event.status.state in [TaskState.FAILED, TaskState.CANCELED]: + error_message = "Task failed" + if event.status.message: + for part in event.status.message.parts: + if isinstance(part, TextPart): + error_message = part.text + raise Exception(error_message) + elif isinstance(event, TaskArtifactUpdateEvent): + for part in event.artifact.parts: + if isinstance(part, TextPart): + result += part.text + finally: + timeout_task.cancel() + + return result + + async def handle_a2a_task( + self, + task: A2ATask, + agent_execute_func: Any, + context: Optional[str] = None, + ) -> None: + """Handle an A2A task. + + Args: + task: The A2A task to handle. + agent_execute_func: The function to execute the task. + context: Additional context for the task. + + Raises: + Exception: If there is an error handling the task. + """ + if not self.task_manager: + raise ValueError("Task manager is required to handle A2A tasks") + + try: + await self.task_manager.update_task_status( + task_id=task.id, + state=TaskState.WORKING, + ) + + task_description = "" + task_context = context or "" + + if task.history and task.history[-1].role == "user": + message = task.history[-1] + for part in message.parts: + if isinstance(part, TextPart): + task_description += part.text + elif isinstance(part, DataPart) and part.data.get("context"): + task_context += part.data["context"] + + try: + result = await agent_execute_func(task_description, task_context) + + response_message = Message( + role="agent", + parts=[TextPart(text=result)], + ) + + await self.task_manager.update_task_status( + task_id=task.id, + state=TaskState.COMPLETED, + message=response_message, + ) + + artifact = Artifact( + name="result", + parts=[TextPart(text=result)], + ) + await self.task_manager.add_task_artifact( + task_id=task.id, + artifact=artifact, + ) + except Exception as e: + error_message = Message( + role="agent", + parts=[TextPart(text=str(e))], + ) + await self.task_manager.update_task_status( + task_id=task.id, + state=TaskState.FAILED, + message=error_message, + ) + raise + except Exception as e: + self.logger.exception(f"Error handling A2A task: {e}") + raise diff --git a/src/crewai/a2a/client.py b/src/crewai/a2a/client.py new file mode 100644 index 000000000..743009ca3 --- /dev/null +++ b/src/crewai/a2a/client.py @@ -0,0 +1,424 @@ +""" +A2A protocol client for CrewAI. + +This module implements the client for the A2A protocol in CrewAI. +""" + +import asyncio +import json +import logging +import os +from typing import Any, Dict, List, Optional, Union, cast + +import aiohttp +from pydantic import ValidationError + +from crewai.types.a2a import ( + A2AClientError, + A2AClientHTTPError, + A2AClientJSONError, + Artifact, + CancelTaskRequest, + CancelTaskResponse, + GetTaskPushNotificationRequest, + GetTaskPushNotificationResponse, + GetTaskRequest, + GetTaskResponse, + JSONRPCError, + JSONRPCRequest, + JSONRPCResponse, + Message, + MissingAPIKeyError, + PushNotificationConfig, + SendTaskRequest, + SendTaskResponse, + SendTaskStreamingRequest, + SetTaskPushNotificationRequest, + SetTaskPushNotificationResponse, + Task, + TaskArtifactUpdateEvent, + TaskIdParams, + TaskPushNotificationConfig, + TaskQueryParams, + TaskSendParams, + TaskState, + TaskStatusUpdateEvent, +) + + +class A2AClient: + """A2A protocol client implementation.""" + + def __init__( + self, + base_url: str, + api_key: Optional[str] = None, + timeout: int = 60, + ): + """Initialize the A2A client. + + Args: + base_url: The base URL of the A2A server. + api_key: The API key to use for authentication. + timeout: The timeout for HTTP requests in seconds. + """ + self.base_url = base_url.rstrip("/") + self.api_key = api_key or os.environ.get("A2A_API_KEY") + self.timeout = timeout + self.logger = logging.getLogger(__name__) + + async def send_task( + self, + task_id: str, + message: Message, + session_id: Optional[str] = None, + accepted_output_modes: Optional[List[str]] = None, + push_notification: Optional[PushNotificationConfig] = None, + history_length: Optional[int] = None, + metadata: Optional[Dict[str, Any]] = None, + ) -> Task: + """Send a task to the A2A server. + + Args: + task_id: The ID of the task. + message: The message to send. + session_id: The session ID. + accepted_output_modes: The accepted output modes. + push_notification: The push notification configuration. + history_length: The number of messages to include in the history. + metadata: Additional metadata. + + Returns: + The created task. + + Raises: + A2AClientError: If there is an error sending the task. + """ + params = TaskSendParams( + id=task_id, + sessionId=session_id, + message=message, + acceptedOutputModes=accepted_output_modes, + pushNotification=push_notification, + historyLength=history_length, + metadata=metadata, + ) + + request = SendTaskRequest(params=params) + response = await self._send_jsonrpc_request(request) + + if response.error: + raise A2AClientError(f"Error sending task: {response.error.message}") + + if not response.result: + raise A2AClientError("No result returned from send task request") + + return cast(Task, response.result) + + async def send_task_streaming( + self, + task_id: str, + message: Message, + session_id: Optional[str] = None, + accepted_output_modes: Optional[List[str]] = None, + push_notification: Optional[PushNotificationConfig] = None, + history_length: Optional[int] = None, + metadata: Optional[Dict[str, Any]] = None, + ) -> asyncio.Queue: + """Send a task to the A2A server and subscribe to updates. + + Args: + task_id: The ID of the task. + message: The message to send. + session_id: The session ID. + accepted_output_modes: The accepted output modes. + push_notification: The push notification configuration. + history_length: The number of messages to include in the history. + metadata: Additional metadata. + + Returns: + A queue that will receive task updates. + + Raises: + A2AClientError: If there is an error sending the task. + """ + params = TaskSendParams( + id=task_id, + sessionId=session_id, + message=message, + acceptedOutputModes=accepted_output_modes, + pushNotification=push_notification, + historyLength=history_length, + metadata=metadata, + ) + + queue: asyncio.Queue = asyncio.Queue() + + asyncio.create_task( + self._handle_streaming_response( + f"{self.base_url}/v1/tasks/sendSubscribe", params, queue + ) + ) + + return queue + + async def get_task( + self, task_id: str, history_length: Optional[int] = None + ) -> Task: + """Get a task from the A2A server. + + Args: + task_id: The ID of the task. + history_length: The number of messages to include in the history. + + Returns: + The task. + + Raises: + A2AClientError: If there is an error getting the task. + """ + params = TaskQueryParams(id=task_id, historyLength=history_length) + request = GetTaskRequest(params=params) + response = await self._send_jsonrpc_request(request) + + if response.error: + raise A2AClientError(f"Error getting task: {response.error.message}") + + if not response.result: + raise A2AClientError("No result returned from get task request") + + return cast(Task, response.result) + + async def cancel_task(self, task_id: str) -> Task: + """Cancel a task on the A2A server. + + Args: + task_id: The ID of the task. + + Returns: + The canceled task. + + Raises: + A2AClientError: If there is an error canceling the task. + """ + params = TaskIdParams(id=task_id) + request = CancelTaskRequest(params=params) + response = await self._send_jsonrpc_request(request) + + if response.error: + raise A2AClientError(f"Error canceling task: {response.error.message}") + + if not response.result: + raise A2AClientError("No result returned from cancel task request") + + return cast(Task, response.result) + + async def set_push_notification( + self, task_id: str, config: PushNotificationConfig + ) -> PushNotificationConfig: + """Set push notification for a task. + + Args: + task_id: The ID of the task. + config: The push notification configuration. + + Returns: + The push notification configuration. + + Raises: + A2AClientError: If there is an error setting the push notification. + """ + params = TaskPushNotificationConfig(id=task_id, pushNotificationConfig=config) + request = SetTaskPushNotificationRequest(params=params) + response = await self._send_jsonrpc_request(request) + + if response.error: + raise A2AClientError( + f"Error setting push notification: {response.error.message}" + ) + + if not response.result: + raise A2AClientError( + "No result returned from set push notification request" + ) + + return cast(TaskPushNotificationConfig, response.result).pushNotificationConfig + + async def get_push_notification( + self, task_id: str + ) -> Optional[PushNotificationConfig]: + """Get push notification for a task. + + Args: + task_id: The ID of the task. + + Returns: + The push notification configuration, or None if not set. + + Raises: + A2AClientError: If there is an error getting the push notification. + """ + params = TaskIdParams(id=task_id) + request = GetTaskPushNotificationRequest(params=params) + response = await self._send_jsonrpc_request(request) + + if response.error: + raise A2AClientError( + f"Error getting push notification: {response.error.message}" + ) + + if not response.result: + return None + + return cast(TaskPushNotificationConfig, response.result).pushNotificationConfig + + async def _send_jsonrpc_request( + self, request: JSONRPCRequest + ) -> JSONRPCResponse: + """Send a JSON-RPC request to the A2A server. + + Args: + request: The JSON-RPC request. + + Returns: + The JSON-RPC response. + + Raises: + A2AClientError: If there is an error sending the request. + """ + if not self.api_key: + raise MissingAPIKeyError( + "API key is required. Set it in the constructor or as the A2A_API_KEY environment variable." + ) + + headers = { + "Content-Type": "application/json", + "Authorization": f"Bearer {self.api_key}", + } + + try: + async with aiohttp.ClientSession() as session: + async with session.post( + f"{self.base_url}/v1/jsonrpc", + headers=headers, + json=request.model_dump(), + timeout=self.timeout, + ) as response: + if response.status != 200: + raise A2AClientHTTPError( + response.status, await response.text() + ) + + try: + data = await response.json() + except json.JSONDecodeError as e: + raise A2AClientJSONError(str(e)) + + try: + return JSONRPCResponse.model_validate(data) + except ValidationError as e: + raise A2AClientError(f"Invalid response: {e}") + except aiohttp.ClientError as e: + raise A2AClientError(f"HTTP error: {e}") + + async def _handle_streaming_response( + self, + url: str, + params: TaskSendParams, + queue: asyncio.Queue, + ) -> None: + """Handle a streaming response from the A2A server. + + Args: + url: The URL to send the request to. + params: The task send parameters. + queue: The queue to put events into. + """ + if not self.api_key: + await queue.put( + Exception( + "API key is required. Set it in the constructor or as the A2A_API_KEY environment variable." + ) + ) + return + + headers = { + "Content-Type": "application/json", + "Authorization": f"Bearer {self.api_key}", + "Accept": "text/event-stream", + } + + try: + async with aiohttp.ClientSession() as session: + async with session.post( + url, + headers=headers, + json=params.model_dump(), + timeout=self.timeout, + ) as response: + if response.status != 200: + await queue.put( + A2AClientHTTPError(response.status, await response.text()) + ) + return + + buffer = "" + async for line in response.content: + line = line.decode("utf-8") + buffer += line + + if buffer.endswith("\n\n"): + event_data = self._parse_sse_event(buffer) + buffer = "" + + if event_data: + event_type = event_data.get("event") + data = event_data.get("data") + + if event_type == "status": + try: + event = TaskStatusUpdateEvent.model_validate_json(data) + await queue.put(event) + + if event.final: + break + except ValidationError as e: + await queue.put( + A2AClientError(f"Invalid status event: {e}") + ) + elif event_type == "artifact": + try: + event = TaskArtifactUpdateEvent.model_validate_json(data) + await queue.put(event) + except ValidationError as e: + await queue.put( + A2AClientError(f"Invalid artifact event: {e}") + ) + except aiohttp.ClientError as e: + await queue.put(A2AClientError(f"HTTP error: {e}")) + except asyncio.CancelledError: + pass + except Exception as e: + await queue.put(A2AClientError(f"Error handling streaming response: {e}")) + + def _parse_sse_event(self, data: str) -> Dict[str, str]: + """Parse an SSE event. + + Args: + data: The SSE event data. + + Returns: + A dictionary with the event type and data. + """ + result = {} + for line in data.split("\n"): + line = line.strip() + if not line: + continue + + if line.startswith("event:"): + result["event"] = line[6:].strip() + elif line.startswith("data:"): + result["data"] = line[5:].strip() + + return result diff --git a/src/crewai/a2a/server.py b/src/crewai/a2a/server.py new file mode 100644 index 000000000..a815e3574 --- /dev/null +++ b/src/crewai/a2a/server.py @@ -0,0 +1,368 @@ +""" +A2A protocol server for CrewAI. + +This module implements the server for the A2A protocol in CrewAI. +""" + +import asyncio +import json +import logging +from typing import Any, Callable, Dict, List, Optional, Type, Union + +from fastapi import FastAPI, HTTPException, Request, Response +from fastapi.middleware.cors import CORSMiddleware +from fastapi.responses import JSONResponse, StreamingResponse +from pydantic import ValidationError + +from crewai.a2a.task_manager import InMemoryTaskManager, TaskManager +from crewai.types.a2a import ( + A2ARequest, + CancelTaskRequest, + CancelTaskResponse, + ContentTypeNotSupportedError, + GetTaskPushNotificationRequest, + GetTaskPushNotificationResponse, + GetTaskRequest, + GetTaskResponse, + InternalError, + InvalidParamsError, + InvalidRequestError, + JSONParseError, + JSONRPCError, + JSONRPCRequest, + JSONRPCResponse, + MethodNotFoundError, + SendTaskRequest, + SendTaskResponse, + SendTaskStreamingRequest, + SendTaskStreamingResponse, + SetTaskPushNotificationRequest, + SetTaskPushNotificationResponse, + Task, + TaskArtifactUpdateEvent, + TaskIdParams, + TaskNotCancelableError, + TaskNotFoundError, + TaskPushNotificationConfig, + TaskQueryParams, + TaskSendParams, + TaskState, + TaskStatusUpdateEvent, + UnsupportedOperationError, +) + + +class A2AServer: + """A2A protocol server implementation.""" + + def __init__( + self, + task_manager: Optional[TaskManager] = None, + enable_cors: bool = True, + cors_origins: Optional[List[str]] = None, + ): + """Initialize the A2A server. + + Args: + task_manager: The task manager to use. If None, an InMemoryTaskManager will be created. + enable_cors: Whether to enable CORS. + cors_origins: The CORS origins to allow. + """ + self.app = FastAPI(title="A2A Server") + self.task_manager = task_manager or InMemoryTaskManager() + self.logger = logging.getLogger(__name__) + + if enable_cors: + self.app.add_middleware( + CORSMiddleware, + allow_origins=cors_origins or ["*"], + allow_credentials=True, + allow_methods=["*"], + allow_headers=["*"], + ) + + self.app.post("/v1/jsonrpc")(self.handle_jsonrpc) + self.app.post("/v1/tasks/send")(self.handle_send_task) + self.app.post("/v1/tasks/sendSubscribe")(self.handle_send_task_subscribe) + self.app.post("/v1/tasks/{task_id}/cancel")(self.handle_cancel_task) + self.app.get("/v1/tasks/{task_id}")(self.handle_get_task) + + async def handle_jsonrpc(self, request: Request) -> JSONResponse: + """Handle JSON-RPC requests. + + Args: + request: The FastAPI request. + + Returns: + A JSON response. + """ + try: + body = await request.json() + except json.JSONDecodeError: + return JSONResponse( + content=JSONRPCResponse( + id=None, error=JSONParseError() + ).model_dump(), + status_code=400, + ) + + try: + if isinstance(body, list): + responses = [] + for req_data in body: + response = await self._process_jsonrpc_request(req_data) + responses.append(response.model_dump()) + return JSONResponse(content=responses) + else: + response = await self._process_jsonrpc_request(body) + return JSONResponse(content=response.model_dump()) + except Exception as e: + self.logger.exception("Error processing JSON-RPC request") + return JSONResponse( + content=JSONRPCResponse( + id=body.get("id") if isinstance(body, dict) else None, + error=InternalError(data=str(e)), + ).model_dump(), + status_code=500, + ) + + async def _process_jsonrpc_request( + self, request_data: Dict[str, Any] + ) -> JSONRPCResponse: + """Process a JSON-RPC request. + + Args: + request_data: The JSON-RPC request data. + + Returns: + A JSON-RPC response. + """ + if not isinstance(request_data, dict) or request_data.get("jsonrpc") != "2.0": + return JSONRPCResponse( + id=request_data.get("id") if isinstance(request_data, dict) else None, + error=InvalidRequestError(), + ) + + request_id = request_data.get("id") + method = request_data.get("method") + + if not method: + return JSONRPCResponse( + id=request_id, + error=InvalidRequestError(message="Method is required"), + ) + + try: + request = A2ARequest.validate_python(request_data) + except ValidationError as e: + return JSONRPCResponse( + id=request_id, + error=InvalidParamsError(data=str(e)), + ) + + try: + if isinstance(request, SendTaskRequest): + task = await self._handle_send_task(request.params) + return SendTaskResponse(id=request_id, result=task) + elif isinstance(request, GetTaskRequest): + task = await self.task_manager.get_task( + request.params.id, request.params.historyLength + ) + return GetTaskResponse(id=request_id, result=task) + elif isinstance(request, CancelTaskRequest): + task = await self.task_manager.cancel_task(request.params.id) + return CancelTaskResponse(id=request_id, result=task) + elif isinstance(request, SetTaskPushNotificationRequest): + config = await self.task_manager.set_push_notification( + request.params.id, request.params.pushNotificationConfig + ) + return SetTaskPushNotificationResponse( + id=request_id, result=TaskPushNotificationConfig(id=request.params.id, pushNotificationConfig=config) + ) + elif isinstance(request, GetTaskPushNotificationRequest): + config = await self.task_manager.get_push_notification( + request.params.id + ) + if config: + return GetTaskPushNotificationResponse( + id=request_id, result=TaskPushNotificationConfig(id=request.params.id, pushNotificationConfig=config) + ) + else: + return GetTaskPushNotificationResponse(id=request_id, result=None) + elif isinstance(request, SendTaskStreamingRequest): + return JSONRPCResponse( + id=request_id, + error=UnsupportedOperationError( + message="Streaming requests should be sent to the streaming endpoint" + ), + ) + else: + return JSONRPCResponse( + id=request_id, + error=MethodNotFoundError(), + ) + except KeyError: + return JSONRPCResponse( + id=request_id, + error=TaskNotFoundError(), + ) + except Exception as e: + self.logger.exception(f"Error handling {method} request") + return JSONRPCResponse( + id=request_id, + error=InternalError(data=str(e)), + ) + + async def handle_send_task(self, request: Request) -> JSONResponse: + """Handle send task requests. + + Args: + request: The FastAPI request. + + Returns: + A JSON response. + """ + try: + body = await request.json() + params = TaskSendParams.model_validate(body) + task = await self._handle_send_task(params) + return JSONResponse(content=task.model_dump()) + except ValidationError as e: + return JSONResponse( + content={"error": str(e)}, + status_code=400, + ) + except Exception as e: + self.logger.exception("Error handling send task request") + return JSONResponse( + content={"error": str(e)}, + status_code=500, + ) + + async def _handle_send_task(self, params: TaskSendParams) -> Task: + """Handle send task requests. + + Args: + params: The task send parameters. + + Returns: + The created task. + """ + task = await self.task_manager.create_task( + task_id=params.id, + session_id=params.sessionId, + message=params.message, + metadata=params.metadata, + ) + + await self.task_manager.update_task_status( + task_id=params.id, + state=TaskState.WORKING, + ) + + return task + + async def handle_send_task_subscribe(self, request: Request) -> StreamingResponse: + """Handle send task subscribe requests. + + Args: + request: The FastAPI request. + + Returns: + A streaming response. + """ + try: + body = await request.json() + params = TaskSendParams.model_validate(body) + + task = await self._handle_send_task(params) + + queue = await self.task_manager.subscribe_to_task(params.id) + + return StreamingResponse( + self._stream_task_updates(params.id, queue), + media_type="text/event-stream", + ) + except ValidationError as e: + return JSONResponse( + content={"error": str(e)}, + status_code=400, + ) + except Exception as e: + self.logger.exception("Error handling send task subscribe request") + return JSONResponse( + content={"error": str(e)}, + status_code=500, + ) + + async def _stream_task_updates( + self, task_id: str, queue: asyncio.Queue + ) -> None: + """Stream task updates. + + Args: + task_id: The ID of the task. + queue: The queue to receive updates from. + + Yields: + SSE formatted events. + """ + try: + while True: + event = await queue.get() + + if isinstance(event, TaskStatusUpdateEvent): + event_type = "status" + elif isinstance(event, TaskArtifactUpdateEvent): + event_type = "artifact" + else: + event_type = "unknown" + + data = json.dumps(event.model_dump()) + yield f"event: {event_type}\ndata: {data}\n\n" + + if isinstance(event, TaskStatusUpdateEvent) and event.final: + break + finally: + await self.task_manager.unsubscribe_from_task(task_id, queue) + + async def handle_get_task(self, task_id: str, request: Request) -> JSONResponse: + """Handle get task requests. + + Args: + task_id: The ID of the task. + request: The FastAPI request. + + Returns: + A JSON response. + """ + try: + history_length = request.query_params.get("historyLength") + history_length = int(history_length) if history_length else None + + task = await self.task_manager.get_task(task_id, history_length) + return JSONResponse(content=task.model_dump()) + except KeyError: + raise HTTPException(status_code=404, detail=f"Task {task_id} not found") + except Exception as e: + self.logger.exception(f"Error handling get task request for {task_id}") + raise HTTPException(status_code=500, detail=str(e)) + + async def handle_cancel_task(self, task_id: str, request: Request) -> JSONResponse: + """Handle cancel task requests. + + Args: + task_id: The ID of the task. + request: The FastAPI request. + + Returns: + A JSON response. + """ + try: + task = await self.task_manager.cancel_task(task_id) + return JSONResponse(content=task.model_dump()) + except KeyError: + raise HTTPException(status_code=404, detail=f"Task {task_id} not found") + except Exception as e: + self.logger.exception(f"Error handling cancel task request for {task_id}") + raise HTTPException(status_code=500, detail=str(e)) diff --git a/src/crewai/a2a/task_manager.py b/src/crewai/a2a/task_manager.py new file mode 100644 index 000000000..4104c506b --- /dev/null +++ b/src/crewai/a2a/task_manager.py @@ -0,0 +1,438 @@ +""" +A2A protocol task manager for CrewAI. + +This module implements the task manager for the A2A protocol in CrewAI. +""" + +import asyncio +import logging +from abc import ABC, abstractmethod +from datetime import datetime +from typing import Any, Dict, List, Optional, Set, Union +from uuid import uuid4 + +from crewai.types.a2a import ( + Artifact, + Message, + PushNotificationConfig, + Task, + TaskArtifactUpdateEvent, + TaskState, + TaskStatus, + TaskStatusUpdateEvent, +) + + +class TaskManager(ABC): + """Abstract base class for A2A task managers.""" + + @abstractmethod + async def create_task( + self, + task_id: str, + session_id: Optional[str] = None, + message: Optional[Message] = None, + metadata: Optional[Dict[str, Any]] = None, + ) -> Task: + """Create a new task. + + Args: + task_id: The ID of the task. + session_id: The session ID. + message: The initial message. + metadata: Additional metadata. + + Returns: + The created task. + """ + pass + + @abstractmethod + async def get_task( + self, task_id: str, history_length: Optional[int] = None + ) -> Task: + """Get a task by ID. + + Args: + task_id: The ID of the task. + history_length: The number of messages to include in the history. + + Returns: + The task. + + Raises: + KeyError: If the task is not found. + """ + pass + + @abstractmethod + async def update_task_status( + self, + task_id: str, + state: TaskState, + message: Optional[Message] = None, + metadata: Optional[Dict[str, Any]] = None, + ) -> TaskStatusUpdateEvent: + """Update the status of a task. + + Args: + task_id: The ID of the task. + state: The new state of the task. + message: An optional message to include with the status update. + metadata: Additional metadata. + + Returns: + The task status update event. + + Raises: + KeyError: If the task is not found. + """ + pass + + @abstractmethod + async def add_task_artifact( + self, + task_id: str, + artifact: Artifact, + metadata: Optional[Dict[str, Any]] = None, + ) -> TaskArtifactUpdateEvent: + """Add an artifact to a task. + + Args: + task_id: The ID of the task. + artifact: The artifact to add. + metadata: Additional metadata. + + Returns: + The task artifact update event. + + Raises: + KeyError: If the task is not found. + """ + pass + + @abstractmethod + async def cancel_task(self, task_id: str) -> Task: + """Cancel a task. + + Args: + task_id: The ID of the task. + + Returns: + The canceled task. + + Raises: + KeyError: If the task is not found. + """ + pass + + @abstractmethod + async def set_push_notification( + self, task_id: str, config: PushNotificationConfig + ) -> PushNotificationConfig: + """Set push notification for a task. + + Args: + task_id: The ID of the task. + config: The push notification configuration. + + Returns: + The push notification configuration. + + Raises: + KeyError: If the task is not found. + """ + pass + + @abstractmethod + async def get_push_notification( + self, task_id: str + ) -> Optional[PushNotificationConfig]: + """Get push notification for a task. + + Args: + task_id: The ID of the task. + + Returns: + The push notification configuration, or None if not set. + + Raises: + KeyError: If the task is not found. + """ + pass + + +class InMemoryTaskManager(TaskManager): + """In-memory implementation of the A2A task manager.""" + + def __init__(self): + """Initialize the in-memory task manager.""" + self._tasks: Dict[str, Task] = {} + self._push_notifications: Dict[str, PushNotificationConfig] = {} + self._task_subscribers: Dict[str, Set[asyncio.Queue]] = {} + self._logger = logging.getLogger(__name__) + + async def create_task( + self, + task_id: str, + session_id: Optional[str] = None, + message: Optional[Message] = None, + metadata: Optional[Dict[str, Any]] = None, + ) -> Task: + """Create a new task. + + Args: + task_id: The ID of the task. + session_id: The session ID. + message: The initial message. + metadata: Additional metadata. + + Returns: + The created task. + """ + if task_id in self._tasks: + return self._tasks[task_id] + + session_id = session_id or uuid4().hex + status = TaskStatus( + state=TaskState.SUBMITTED, + message=message, + timestamp=datetime.now(), + ) + + task = Task( + id=task_id, + sessionId=session_id, + status=status, + artifacts=[], + history=[message] if message else [], + metadata=metadata or {}, + ) + + self._tasks[task_id] = task + self._task_subscribers[task_id] = set() + return task + + async def get_task( + self, task_id: str, history_length: Optional[int] = None + ) -> Task: + """Get a task by ID. + + Args: + task_id: The ID of the task. + history_length: The number of messages to include in the history. + + Returns: + The task. + + Raises: + KeyError: If the task is not found. + """ + if task_id not in self._tasks: + raise KeyError(f"Task {task_id} not found") + + task = self._tasks[task_id] + if history_length is not None and task.history: + task_copy = task.model_copy(deep=True) + task_copy.history = task.history[-history_length:] + return task_copy + return task + + async def update_task_status( + self, + task_id: str, + state: TaskState, + message: Optional[Message] = None, + metadata: Optional[Dict[str, Any]] = None, + ) -> TaskStatusUpdateEvent: + """Update the status of a task. + + Args: + task_id: The ID of the task. + state: The new state of the task. + message: An optional message to include with the status update. + metadata: Additional metadata. + + Returns: + The task status update event. + + Raises: + KeyError: If the task is not found. + """ + if task_id not in self._tasks: + raise KeyError(f"Task {task_id} not found") + + task = self._tasks[task_id] + status = TaskStatus( + state=state, + message=message, + timestamp=datetime.now(), + ) + task.status = status + + if message and task.history is not None: + task.history.append(message) + + event = TaskStatusUpdateEvent( + id=task_id, + status=status, + final=state in [TaskState.COMPLETED, TaskState.CANCELED, TaskState.FAILED], + metadata=metadata or {}, + ) + + await self._notify_subscribers(task_id, event) + + return event + + async def add_task_artifact( + self, + task_id: str, + artifact: Artifact, + metadata: Optional[Dict[str, Any]] = None, + ) -> TaskArtifactUpdateEvent: + """Add an artifact to a task. + + Args: + task_id: The ID of the task. + artifact: The artifact to add. + metadata: Additional metadata. + + Returns: + The task artifact update event. + + Raises: + KeyError: If the task is not found. + """ + if task_id not in self._tasks: + raise KeyError(f"Task {task_id} not found") + + task = self._tasks[task_id] + if task.artifacts is None: + task.artifacts = [] + + if artifact.append and task.artifacts: + for existing in task.artifacts: + if existing.name == artifact.name: + existing.parts.extend(artifact.parts) + existing.lastChunk = artifact.lastChunk + break + else: + task.artifacts.append(artifact) + else: + task.artifacts.append(artifact) + + event = TaskArtifactUpdateEvent( + id=task_id, + artifact=artifact, + metadata=metadata or {}, + ) + + await self._notify_subscribers(task_id, event) + + return event + + async def cancel_task(self, task_id: str) -> Task: + """Cancel a task. + + Args: + task_id: The ID of the task. + + Returns: + The canceled task. + + Raises: + KeyError: If the task is not found. + """ + if task_id not in self._tasks: + raise KeyError(f"Task {task_id} not found") + + task = self._tasks[task_id] + + if task.status.state not in [TaskState.COMPLETED, TaskState.CANCELED, TaskState.FAILED]: + await self.update_task_status(task_id, TaskState.CANCELED) + + return task + + async def set_push_notification( + self, task_id: str, config: PushNotificationConfig + ) -> PushNotificationConfig: + """Set push notification for a task. + + Args: + task_id: The ID of the task. + config: The push notification configuration. + + Returns: + The push notification configuration. + + Raises: + KeyError: If the task is not found. + """ + if task_id not in self._tasks: + raise KeyError(f"Task {task_id} not found") + + self._push_notifications[task_id] = config + return config + + async def get_push_notification( + self, task_id: str + ) -> Optional[PushNotificationConfig]: + """Get push notification for a task. + + Args: + task_id: The ID of the task. + + Returns: + The push notification configuration, or None if not set. + + Raises: + KeyError: If the task is not found. + """ + if task_id not in self._tasks: + raise KeyError(f"Task {task_id} not found") + + return self._push_notifications.get(task_id) + + async def subscribe_to_task(self, task_id: str) -> asyncio.Queue: + """Subscribe to task updates. + + Args: + task_id: The ID of the task. + + Returns: + A queue that will receive task updates. + + Raises: + KeyError: If the task is not found. + """ + if task_id not in self._tasks: + raise KeyError(f"Task {task_id} not found") + + queue: asyncio.Queue = asyncio.Queue() + self._task_subscribers.setdefault(task_id, set()).add(queue) + return queue + + async def unsubscribe_from_task(self, task_id: str, queue: asyncio.Queue) -> None: + """Unsubscribe from task updates. + + Args: + task_id: The ID of the task. + queue: The queue to unsubscribe. + """ + if task_id in self._task_subscribers: + self._task_subscribers[task_id].discard(queue) + + async def _notify_subscribers( + self, + task_id: str, + event: Union[TaskStatusUpdateEvent, TaskArtifactUpdateEvent], + ) -> None: + """Notify subscribers of a task update. + + Args: + task_id: The ID of the task. + event: The event to send to subscribers. + """ + if task_id in self._task_subscribers: + for queue in self._task_subscribers[task_id]: + await queue.put(event) diff --git a/src/crewai/agent.py b/src/crewai/agent.py index 999d1d800..e8ec11b53 100644 --- a/src/crewai/agent.py +++ b/src/crewai/agent.py @@ -5,6 +5,7 @@ from typing import Any, Dict, List, Literal, Optional, Union from pydantic import Field, InstanceOf, PrivateAttr, model_validator +from crewai.a2a import A2AAgentIntegration from crewai.agents import CacheHandler from crewai.agents.agent_builder.base_agent import BaseAgent from crewai.agents.crew_agent_executor import CrewAgentExecutor @@ -131,14 +132,29 @@ class Agent(BaseAgent): default=None, description="Knowledge sources for the agent.", ) + a2a_enabled: bool = Field( + default=False, + description="Whether the agent supports the A2A protocol.", + ) + a2a_url: Optional[str] = Field( + default=None, + description="The URL where the agent's A2A server is hosted.", + ) _knowledge: Optional[Knowledge] = PrivateAttr( default=None, ) + _a2a_integration: Optional[A2AAgentIntegration] = PrivateAttr( + default=None, + ) @model_validator(mode="after") def post_init_setup(self): self._set_knowledge() self.agent_ops_agent_name = self.role + + if self.a2a_enabled: + self._a2a_integration = A2AAgentIntegration() + unaccepted_attributes = [ "AWS_ACCESS_KEY_ID", "AWS_SECRET_ACCESS_KEY", @@ -355,6 +371,102 @@ class Agent(BaseAgent): result = tool_result["result"] return result + + async def execute_task_via_a2a( + self, + task_description: str, + context: Optional[str] = None, + agent_url: Optional[str] = None, + api_key: Optional[str] = None, + timeout: int = 300, + ) -> str: + """Execute a task via the A2A protocol. + + Args: + task_description: The description of the task. + context: Additional context for the task. + agent_url: The URL of the agent to execute the task. Defaults to self.a2a_url. + api_key: The API key to use for authentication. + timeout: The timeout for the task execution in seconds. + + Returns: + The result of the task execution. + + Raises: + ValueError: If A2A is not enabled or no agent URL is provided. + TimeoutError: If the task execution times out. + Exception: If there is an error executing the task. + """ + if not self.a2a_enabled: + raise ValueError("A2A protocol is not enabled for this agent") + + if not self._a2a_integration: + self._a2a_integration = A2AAgentIntegration() + + url = agent_url or self.a2a_url + if not url: + raise ValueError("No A2A agent URL provided") + + try: + import asyncio + if asyncio.get_event_loop().is_running(): + return await self._a2a_integration.execute_task_via_a2a( + agent_url=url, + task_description=task_description, + context=context, + api_key=api_key, + timeout=timeout, + ) + else: + return asyncio.run(self._a2a_integration.execute_task_via_a2a( + agent_url=url, + task_description=task_description, + context=context, + api_key=api_key, + timeout=timeout, + )) + except Exception as e: + self._logger.exception(f"Error executing task via A2A: {e}") + raise + + async def handle_a2a_task( + self, + task_id: str, + task_description: str, + context: Optional[str] = None, + ) -> str: + """Handle an A2A task. + + Args: + task_id: The ID of the A2A task. + task_description: The description of the task. + context: Additional context for the task. + + Returns: + The result of the task execution. + + Raises: + ValueError: If A2A is not enabled. + Exception: If there is an error handling the task. + """ + if not self.a2a_enabled: + raise ValueError("A2A protocol is not enabled for this agent") + + if not self._a2a_integration: + self._a2a_integration = A2AAgentIntegration() + + # Create a Task object from the task description + task = Task( + description=task_description, + agent=self, + ) + + try: + result = self.execute_task(task, context) + return result + except Exception as e: + self._logger.exception(f"Error handling A2A task: {e}") + raise def create_agent_executor( self, tools: Optional[List[BaseTool]] = None, task=None diff --git a/src/crewai/types/__init__.py b/src/crewai/types/__init__.py index e69de29bb..d86580548 100644 --- a/src/crewai/types/__init__.py +++ b/src/crewai/types/__init__.py @@ -0,0 +1 @@ +"""Type definitions for CrewAI.""" diff --git a/src/crewai/types/a2a.py b/src/crewai/types/a2a.py new file mode 100644 index 000000000..ca6ea7a07 --- /dev/null +++ b/src/crewai/types/a2a.py @@ -0,0 +1,423 @@ +""" +A2A protocol types for CrewAI. + +This module implements the A2A (Agent-to-Agent) protocol types as defined by Google. +The A2A protocol enables interoperability between different agent systems. + +For more information, see: https://developers.googleblog.com/en/a2a-a-new-era-of-agent-interoperability/ +""" + +from datetime import datetime +from enum import Enum +from typing import Annotated, Any, Dict, List, Literal, Optional, Self, Union +from uuid import uuid4 + +from pydantic import BaseModel, ConfigDict, Field, TypeAdapter, field_serializer, model_validator + + +class TaskState(str, Enum): + """Task state in the A2A protocol.""" + SUBMITTED = 'submitted' + WORKING = 'working' + INPUT_REQUIRED = 'input-required' + COMPLETED = 'completed' + CANCELED = 'canceled' + FAILED = 'failed' + UNKNOWN = 'unknown' + + +class TextPart(BaseModel): + """Text part in the A2A protocol.""" + type: Literal['text'] = 'text' + text: str + metadata: Optional[Dict[str, Any]] = None + + +class FileContent(BaseModel): + """File content in the A2A protocol.""" + name: Optional[str] = None + mimeType: Optional[str] = None + bytes: Optional[str] = None + uri: Optional[str] = None + + @model_validator(mode='after') + def check_content(self) -> Self: + """Validate file content has either bytes or uri.""" + if not (self.bytes or self.uri): + raise ValueError( + "Either 'bytes' or 'uri' must be present in the file data" + ) + if self.bytes and self.uri: + raise ValueError( + "Only one of 'bytes' or 'uri' can be present in the file data" + ) + return self + + +class FilePart(BaseModel): + """File part in the A2A protocol.""" + type: Literal['file'] = 'file' + file: FileContent + metadata: Optional[Dict[str, Any]] = None + + +class DataPart(BaseModel): + """Data part in the A2A protocol.""" + type: Literal['data'] = 'data' + data: Dict[str, Any] + metadata: Optional[Dict[str, Any]] = None + + +Part = Annotated[Union[TextPart, FilePart, DataPart], Field(discriminator='type')] + + +class Message(BaseModel): + """Message in the A2A protocol.""" + role: Literal['user', 'agent'] + parts: List[Part] + metadata: Optional[Dict[str, Any]] = None + + +class TaskStatus(BaseModel): + """Task status in the A2A protocol.""" + state: TaskState + message: Optional[Message] = None + timestamp: datetime = Field(default_factory=datetime.now) + + @field_serializer('timestamp') + def serialize_dt(self, dt: datetime, _info): + """Serialize datetime to ISO format.""" + return dt.isoformat() + + +class Artifact(BaseModel): + """Artifact in the A2A protocol.""" + name: Optional[str] = None + description: Optional[str] = None + parts: List[Part] + metadata: Optional[Dict[str, Any]] = None + index: int = 0 + append: Optional[bool] = None + lastChunk: Optional[bool] = None + + +class Task(BaseModel): + """Task in the A2A protocol.""" + id: str + sessionId: Optional[str] = None + status: TaskStatus + artifacts: Optional[List[Artifact]] = None + history: Optional[List[Message]] = None + metadata: Optional[Dict[str, Any]] = None + + +class TaskStatusUpdateEvent(BaseModel): + """Task status update event in the A2A protocol.""" + id: str + status: TaskStatus + final: bool = False + metadata: Optional[Dict[str, Any]] = None + + +class TaskArtifactUpdateEvent(BaseModel): + """Task artifact update event in the A2A protocol.""" + id: str + artifact: Artifact + metadata: Optional[Dict[str, Any]] = None + + +class AuthenticationInfo(BaseModel): + """Authentication information in the A2A protocol.""" + model_config = ConfigDict(extra='allow') + + schemes: List[str] + credentials: Optional[str] = None + + +class PushNotificationConfig(BaseModel): + """Push notification configuration in the A2A protocol.""" + url: str + token: Optional[str] = None + authentication: Optional[AuthenticationInfo] = None + + +class TaskIdParams(BaseModel): + """Task ID parameters in the A2A protocol.""" + id: str + metadata: Optional[Dict[str, Any]] = None + + +class TaskQueryParams(TaskIdParams): + """Task query parameters in the A2A protocol.""" + historyLength: Optional[int] = None + + +class TaskSendParams(BaseModel): + """Task send parameters in the A2A protocol.""" + id: str + sessionId: str = Field(default_factory=lambda: uuid4().hex) + message: Message + acceptedOutputModes: Optional[List[str]] = None + pushNotification: Optional[PushNotificationConfig] = None + historyLength: Optional[int] = None + metadata: Optional[Dict[str, Any]] = None + + +class TaskPushNotificationConfig(BaseModel): + """Task push notification configuration in the A2A protocol.""" + id: str + pushNotificationConfig: PushNotificationConfig + + + +class JSONRPCMessage(BaseModel): + """JSON-RPC message in the A2A protocol.""" + jsonrpc: Literal['2.0'] = '2.0' + id: Optional[Union[int, str]] = Field(default_factory=lambda: uuid4().hex) + + +class JSONRPCRequest(JSONRPCMessage): + """JSON-RPC request in the A2A protocol.""" + method: str + params: Optional[Dict[str, Any]] = None + + +class JSONRPCError(BaseModel): + """JSON-RPC error in the A2A protocol.""" + code: int + message: str + data: Optional[Any] = None + + +class JSONRPCResponse(JSONRPCMessage): + """JSON-RPC response in the A2A protocol.""" + result: Optional[Any] = None + error: Optional[JSONRPCError] = None + + +class SendTaskRequest(JSONRPCRequest): + """Send task request in the A2A protocol.""" + method: Literal['tasks/send'] = 'tasks/send' + params: TaskSendParams + + +class SendTaskResponse(JSONRPCResponse): + """Send task response in the A2A protocol.""" + result: Optional[Task] = None + + +class SendTaskStreamingRequest(JSONRPCRequest): + """Send task streaming request in the A2A protocol.""" + method: Literal['tasks/sendSubscribe'] = 'tasks/sendSubscribe' + params: TaskSendParams + + +class SendTaskStreamingResponse(JSONRPCResponse): + """Send task streaming response in the A2A protocol.""" + result: Optional[Union[TaskStatusUpdateEvent, TaskArtifactUpdateEvent]] = None + + +class GetTaskRequest(JSONRPCRequest): + """Get task request in the A2A protocol.""" + method: Literal['tasks/get'] = 'tasks/get' + params: TaskQueryParams + + +class GetTaskResponse(JSONRPCResponse): + """Get task response in the A2A protocol.""" + result: Optional[Task] = None + + +class CancelTaskRequest(JSONRPCRequest): + """Cancel task request in the A2A protocol.""" + method: Literal['tasks/cancel'] = 'tasks/cancel' + params: TaskIdParams + + +class CancelTaskResponse(JSONRPCResponse): + """Cancel task response in the A2A protocol.""" + result: Optional[Task] = None + + +class SetTaskPushNotificationRequest(JSONRPCRequest): + """Set task push notification request in the A2A protocol.""" + method: Literal['tasks/pushNotification/set'] = 'tasks/pushNotification/set' + params: TaskPushNotificationConfig + + +class SetTaskPushNotificationResponse(JSONRPCResponse): + """Set task push notification response in the A2A protocol.""" + result: Optional[TaskPushNotificationConfig] = None + + +class GetTaskPushNotificationRequest(JSONRPCRequest): + """Get task push notification request in the A2A protocol.""" + method: Literal['tasks/pushNotification/get'] = 'tasks/pushNotification/get' + params: TaskIdParams + + +class GetTaskPushNotificationResponse(JSONRPCResponse): + """Get task push notification response in the A2A protocol.""" + result: Optional[TaskPushNotificationConfig] = None + + +class TaskResubscriptionRequest(JSONRPCRequest): + """Task resubscription request in the A2A protocol.""" + method: Literal['tasks/resubscribe'] = 'tasks/resubscribe' + params: TaskIdParams + + +A2ARequest = TypeAdapter( + Annotated[ + Union[ + SendTaskRequest, + GetTaskRequest, + CancelTaskRequest, + SetTaskPushNotificationRequest, + GetTaskPushNotificationRequest, + TaskResubscriptionRequest, + SendTaskStreamingRequest, + ], + Field(discriminator='method'), + ] +) + + +class JSONParseError(JSONRPCError): + """JSON parse error in the A2A protocol.""" + code: int = -32700 + message: str = 'Invalid JSON payload' + data: Optional[Any] = None + + +class InvalidRequestError(JSONRPCError): + """Invalid request error in the A2A protocol.""" + code: int = -32600 + message: str = 'Request payload validation error' + data: Optional[Any] = None + + +class MethodNotFoundError(JSONRPCError): + """Method not found error in the A2A protocol.""" + code: int = -32601 + message: str = 'Method not found' + data: None = None + + +class InvalidParamsError(JSONRPCError): + """Invalid parameters error in the A2A protocol.""" + code: int = -32602 + message: str = 'Invalid parameters' + data: Optional[Any] = None + + +class InternalError(JSONRPCError): + """Internal error in the A2A protocol.""" + code: int = -32603 + message: str = 'Internal error' + data: Optional[Any] = None + + +class TaskNotFoundError(JSONRPCError): + """Task not found error in the A2A protocol.""" + code: int = -32001 + message: str = 'Task not found' + data: None = None + + +class TaskNotCancelableError(JSONRPCError): + """Task not cancelable error in the A2A protocol.""" + code: int = -32002 + message: str = 'Task cannot be canceled' + data: None = None + + +class PushNotificationNotSupportedError(JSONRPCError): + """Push notification not supported error in the A2A protocol.""" + code: int = -32003 + message: str = 'Push Notification is not supported' + data: None = None + + +class UnsupportedOperationError(JSONRPCError): + """Unsupported operation error in the A2A protocol.""" + code: int = -32004 + message: str = 'This operation is not supported' + data: None = None + + +class ContentTypeNotSupportedError(JSONRPCError): + """Content type not supported error in the A2A protocol.""" + code: int = -32005 + message: str = 'Incompatible content types' + data: None = None + + +class AgentProvider(BaseModel): + """Agent provider in the A2A protocol.""" + organization: str + url: Optional[str] = None + + +class AgentCapabilities(BaseModel): + """Agent capabilities in the A2A protocol.""" + streaming: bool = False + pushNotifications: bool = False + stateTransitionHistory: bool = False + + +class AgentAuthentication(BaseModel): + """Agent authentication in the A2A protocol.""" + schemes: List[str] + credentials: Optional[str] = None + + +class AgentSkill(BaseModel): + """Agent skill in the A2A protocol.""" + id: str + name: str + description: Optional[str] = None + tags: Optional[List[str]] = None + examples: Optional[List[str]] = None + inputModes: Optional[List[str]] = None + outputModes: Optional[List[str]] = None + + +class AgentCard(BaseModel): + """Agent card in the A2A protocol.""" + name: str + description: Optional[str] = None + url: str + provider: Optional[AgentProvider] = None + version: str + documentationUrl: Optional[str] = None + capabilities: AgentCapabilities + authentication: Optional[AgentAuthentication] = None + defaultInputModes: List[str] = ['text'] + defaultOutputModes: List[str] = ['text'] + skills: List[AgentSkill] + + +class A2AClientError(Exception): + """Base exception for A2A client errors.""" + pass + + +class A2AClientHTTPError(A2AClientError): + """HTTP error in the A2A client.""" + def __init__(self, status_code: int, message: str): + self.status_code = status_code + self.message = message + super().__init__(f'HTTP Error {status_code}: {message}') + + +class A2AClientJSONError(A2AClientError): + """JSON error in the A2A client.""" + def __init__(self, message: str): + self.message = message + super().__init__(f'JSON Error: {message}') + + +class MissingAPIKeyError(Exception): + """Exception for missing API key.""" + pass diff --git a/tests/a2a/__init__.py b/tests/a2a/__init__.py new file mode 100644 index 000000000..081d1d037 --- /dev/null +++ b/tests/a2a/__init__.py @@ -0,0 +1 @@ +"""Tests for the A2A protocol implementation.""" diff --git a/tests/a2a/test_a2a_integration.py b/tests/a2a/test_a2a_integration.py new file mode 100644 index 000000000..7d3360eb0 --- /dev/null +++ b/tests/a2a/test_a2a_integration.py @@ -0,0 +1,230 @@ +"""Tests for the A2A protocol integration.""" + +import asyncio +import pytest +from unittest.mock import AsyncMock, MagicMock, patch + +from crewai.agent import Agent +from crewai.a2a import A2AAgentIntegration, A2AClient, A2AServer, InMemoryTaskManager +from crewai.task import Task +from crewai.types.a2a import ( + Message, + Task as A2ATask, + TaskState, + TaskStatus, + TaskStatusUpdateEvent, + TextPart, +) + + +@pytest.fixture +def agent(): + """Create an agent with A2A enabled.""" + return Agent( + role="test_agent", + goal="Test A2A protocol", + backstory="I am a test agent", + a2a_enabled=True, + a2a_url="http://localhost:8000", + ) + + +@pytest.fixture +def task(): + """Create a task.""" + return Task( + description="Test task", + ) + + +@pytest.fixture +def a2a_task(): + """Create an A2A task.""" + return A2ATask( + id="test_task_id", + history=[ + Message( + role="user", + parts=[TextPart(text="Test task description")], + ) + ], + ) + + +@pytest.fixture +def a2a_integration(): + """Create an A2A integration.""" + return A2AAgentIntegration() + + +@pytest.fixture +def a2a_client(): + """Create an A2A client.""" + return A2AClient(base_url="http://localhost:8000") + + +@pytest.fixture +def task_manager(): + """Create a task manager.""" + return InMemoryTaskManager() + + +class TestA2AIntegration: + """Tests for the A2A protocol integration.""" + + def test_agent_a2a_attributes(self, agent): + """Test that the agent has A2A attributes.""" + assert agent.a2a_enabled is True + assert agent.a2a_url == "http://localhost:8000" + assert agent._a2a_integration is not None + + @patch("crewai.a2a.agent.A2AAgentIntegration.execute_task_via_a2a") + def test_execute_task_via_a2a(self, mock_execute, agent): + """Test executing a task via A2A.""" + mock_execute.return_value = asyncio.Future() + mock_execute.return_value.set_result("Task result") + + result = asyncio.run( + agent.execute_task_via_a2a( + task_description="Test task", + context="Test context", + ) + ) + + assert result == "Task result" + mock_execute.assert_called_once_with( + agent_url="http://localhost:8000", + task_description="Test task", + context="Test context", + api_key=None, + timeout=300, + ) + + @patch("crewai.agent.Agent.execute_task") + def test_handle_a2a_task(self, mock_execute, agent): + """Test handling an A2A task.""" + mock_execute.return_value = "Task result" + + result = asyncio.run( + agent.handle_a2a_task( + task_id="test_task_id", + task_description="Test task", + context="Test context", + ) + ) + + assert result == "Task result" + mock_execute.assert_called_once() + args, kwargs = mock_execute.call_args + assert args[0].description == "Test task" + assert kwargs["context"] == "Test context" + + def test_a2a_disabled(self, agent): + """Test that A2A methods raise ValueError when A2A is disabled.""" + agent.a2a_enabled = False + + with pytest.raises(ValueError, match="A2A protocol is not enabled for this agent"): + asyncio.run( + agent.execute_task_via_a2a( + task_description="Test task", + ) + ) + + with pytest.raises(ValueError, match="A2A protocol is not enabled for this agent"): + asyncio.run( + agent.handle_a2a_task( + task_id="test_task_id", + task_description="Test task", + ) + ) + + def test_no_agent_url(self, agent): + """Test that execute_task_via_a2a raises ValueError when no agent URL is provided.""" + agent.a2a_url = None + + with pytest.raises(ValueError, match="No A2A agent URL provided"): + asyncio.run( + agent.execute_task_via_a2a( + task_description="Test task", + ) + ) + + +class TestA2AAgentIntegration: + """Tests for the A2AAgentIntegration class.""" + + @patch("crewai.a2a.client.A2AClient.send_task_streaming") + async def test_execute_task_via_a2a(self, mock_send_task, a2a_integration): + """Test executing a task via A2A.""" + queue = asyncio.Queue() + await queue.put( + TaskStatusUpdateEvent( + task_id="test_task_id", + status=TaskStatus( + state=TaskState.COMPLETED, + message=Message( + role="agent", + parts=[TextPart(text="Task result")], + ), + ), + final=True, + ) + ) + + mock_send_task.return_value = queue + + result = await a2a_integration.execute_task_via_a2a( + agent_url="http://localhost:8000", + task_description="Test task", + context="Test context", + ) + + assert result == "Task result" + mock_send_task.assert_called_once() + + +class TestA2AServer: + """Tests for the A2AServer class.""" + + @patch("fastapi.FastAPI.post") + def test_server_initialization(self, mock_post, task_manager): + """Test server initialization.""" + server = A2AServer(task_manager=task_manager) + assert server.task_manager == task_manager + assert server.app is not None + assert mock_post.call_count == 4 # 4 endpoints registered + + +class TestA2AClient: + """Tests for the A2AClient class.""" + + @patch("aiohttp.ClientSession.post") + async def test_send_task(self, mock_post, a2a_client): + """Test sending a task.""" + mock_response = MagicMock() + mock_response.status = 200 + mock_response.json = AsyncMock( + return_value={ + "id": "test_task_id", + "history": [ + { + "role": "user", + "parts": [{"text": "Test task description"}], + } + ], + } + ) + mock_post.return_value.__aenter__.return_value = mock_response + + task = await a2a_client.send_task( + task_id="test_task_id", + message=Message( + role="user", + parts=[TextPart(text="Test task description")], + ), + ) + + assert task.id == "test_task_id" + assert task.history[0].role == "user" + assert task.history[0].parts[0].text == "Test task description" + mock_post.assert_called_once()