mirror of
https://github.com/crewAIInc/crewAI.git
synced 2026-01-10 00:28:31 +00:00
Add A2A protocol support for CrewAI (Issue #2796)
Co-Authored-By: Joe Moura <joao@crewai.com>
This commit is contained in:
59
examples/a2a_protocol_example.py
Normal file
59
examples/a2a_protocol_example.py
Normal file
@@ -0,0 +1,59 @@
|
|||||||
|
"""
|
||||||
|
Example of using the A2A protocol with CrewAI.
|
||||||
|
|
||||||
|
This example demonstrates how to:
|
||||||
|
1. Create an agent with A2A protocol support
|
||||||
|
2. Start an A2A server for the agent
|
||||||
|
3. Execute a task via the A2A protocol
|
||||||
|
"""
|
||||||
|
|
||||||
|
import asyncio
|
||||||
|
import os
|
||||||
|
import uvicorn
|
||||||
|
from threading import Thread
|
||||||
|
|
||||||
|
from crewai import Agent
|
||||||
|
from crewai.a2a import A2AServer, InMemoryTaskManager
|
||||||
|
|
||||||
|
|
||||||
|
agent = Agent(
|
||||||
|
role="Data Analyst",
|
||||||
|
goal="Analyze data and provide insights",
|
||||||
|
backstory="I am a data analyst with expertise in finding patterns and insights in data.",
|
||||||
|
a2a_enabled=True,
|
||||||
|
a2a_url="http://localhost:8000",
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def start_server():
|
||||||
|
"""Start the A2A server."""
|
||||||
|
task_manager = InMemoryTaskManager()
|
||||||
|
|
||||||
|
server = A2AServer(task_manager=task_manager)
|
||||||
|
|
||||||
|
uvicorn.run(server.app, host="0.0.0.0", port=8000)
|
||||||
|
|
||||||
|
|
||||||
|
async def execute_task_via_a2a():
|
||||||
|
"""Execute a task via the A2A protocol."""
|
||||||
|
await asyncio.sleep(2)
|
||||||
|
|
||||||
|
result = await agent.execute_task_via_a2a(
|
||||||
|
task_description="Analyze the following data and provide insights: [1, 2, 3, 4, 5]",
|
||||||
|
context="This is a simple example of using the A2A protocol.",
|
||||||
|
)
|
||||||
|
|
||||||
|
print(f"Task result: {result}")
|
||||||
|
|
||||||
|
|
||||||
|
async def main():
|
||||||
|
"""Run the example."""
|
||||||
|
server_thread = Thread(target=start_server)
|
||||||
|
server_thread.daemon = True
|
||||||
|
server_thread.start()
|
||||||
|
|
||||||
|
await execute_task_via_a2a()
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
asyncio.run(main())
|
||||||
@@ -23,4 +23,9 @@ __all__ = [
|
|||||||
"LLM",
|
"LLM",
|
||||||
"Flow",
|
"Flow",
|
||||||
"Knowledge",
|
"Knowledge",
|
||||||
|
"A2AAgentIntegration",
|
||||||
|
"A2AClient",
|
||||||
|
"A2AServer",
|
||||||
]
|
]
|
||||||
|
|
||||||
|
from crewai.a2a import A2AAgentIntegration, A2AClient, A2AServer
|
||||||
|
|||||||
14
src/crewai/a2a/__init__.py
Normal file
14
src/crewai/a2a/__init__.py
Normal file
@@ -0,0 +1,14 @@
|
|||||||
|
"""A2A protocol implementation for CrewAI."""
|
||||||
|
|
||||||
|
from crewai.a2a.agent import A2AAgentIntegration
|
||||||
|
from crewai.a2a.client import A2AClient
|
||||||
|
from crewai.a2a.server import A2AServer
|
||||||
|
from crewai.a2a.task_manager import InMemoryTaskManager, TaskManager
|
||||||
|
|
||||||
|
__all__ = [
|
||||||
|
"A2AAgentIntegration",
|
||||||
|
"A2AClient",
|
||||||
|
"A2AServer",
|
||||||
|
"TaskManager",
|
||||||
|
"InMemoryTaskManager",
|
||||||
|
]
|
||||||
223
src/crewai/a2a/agent.py
Normal file
223
src/crewai/a2a/agent.py
Normal file
@@ -0,0 +1,223 @@
|
|||||||
|
"""
|
||||||
|
A2A protocol agent integration for CrewAI.
|
||||||
|
|
||||||
|
This module implements the integration between CrewAI agents and the A2A protocol.
|
||||||
|
"""
|
||||||
|
|
||||||
|
import asyncio
|
||||||
|
import json
|
||||||
|
import logging
|
||||||
|
import uuid
|
||||||
|
from typing import Any, Dict, List, Optional, Union
|
||||||
|
|
||||||
|
from crewai.a2a.client import A2AClient
|
||||||
|
from crewai.a2a.task_manager import TaskManager
|
||||||
|
from crewai.types.a2a import (
|
||||||
|
Artifact,
|
||||||
|
DataPart,
|
||||||
|
FilePart,
|
||||||
|
Message,
|
||||||
|
Part,
|
||||||
|
Task as A2ATask,
|
||||||
|
TaskArtifactUpdateEvent,
|
||||||
|
TaskState,
|
||||||
|
TaskStatusUpdateEvent,
|
||||||
|
TextPart,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class A2AAgentIntegration:
|
||||||
|
"""Integration between CrewAI agents and the A2A protocol."""
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
task_manager: Optional[TaskManager] = None,
|
||||||
|
client: Optional[A2AClient] = None,
|
||||||
|
):
|
||||||
|
"""Initialize the A2A agent integration.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
task_manager: The task manager to use for handling A2A tasks.
|
||||||
|
client: The A2A client to use for sending tasks to other agents.
|
||||||
|
"""
|
||||||
|
self.task_manager = task_manager
|
||||||
|
self.client = client
|
||||||
|
self.logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
async def execute_task_via_a2a(
|
||||||
|
self,
|
||||||
|
agent_url: str,
|
||||||
|
task_description: str,
|
||||||
|
context: Optional[str] = None,
|
||||||
|
api_key: Optional[str] = None,
|
||||||
|
timeout: int = 300,
|
||||||
|
) -> str:
|
||||||
|
"""Execute a task via the A2A protocol.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
agent_url: The URL of the agent to execute the task.
|
||||||
|
task_description: The description of the task.
|
||||||
|
context: Additional context for the task.
|
||||||
|
api_key: The API key to use for authentication.
|
||||||
|
timeout: The timeout for the task execution in seconds.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
The result of the task execution.
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
TimeoutError: If the task execution times out.
|
||||||
|
Exception: If there is an error executing the task.
|
||||||
|
"""
|
||||||
|
if not self.client:
|
||||||
|
self.client = A2AClient(base_url=agent_url, api_key=api_key)
|
||||||
|
|
||||||
|
parts: List[Part] = [TextPart(text=task_description)]
|
||||||
|
if context:
|
||||||
|
parts.append(
|
||||||
|
DataPart(
|
||||||
|
data={"context": context},
|
||||||
|
metadata={"type": "context"},
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
message = Message(role="user", parts=parts)
|
||||||
|
|
||||||
|
task_id = str(uuid.uuid4())
|
||||||
|
|
||||||
|
try:
|
||||||
|
queue = await self.client.send_task_streaming(
|
||||||
|
task_id=task_id,
|
||||||
|
message=message,
|
||||||
|
)
|
||||||
|
|
||||||
|
result = await self._wait_for_task_completion(queue, timeout)
|
||||||
|
return result
|
||||||
|
except Exception as e:
|
||||||
|
self.logger.exception(f"Error executing task via A2A: {e}")
|
||||||
|
raise
|
||||||
|
|
||||||
|
async def _wait_for_task_completion(
|
||||||
|
self, queue: asyncio.Queue, timeout: int
|
||||||
|
) -> str:
|
||||||
|
"""Wait for a task to complete.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
queue: The queue to receive task updates from.
|
||||||
|
timeout: The timeout for the task execution in seconds.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
The result of the task execution.
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
TimeoutError: If the task execution times out.
|
||||||
|
Exception: If there is an error executing the task.
|
||||||
|
"""
|
||||||
|
result = ""
|
||||||
|
try:
|
||||||
|
async def _timeout():
|
||||||
|
await asyncio.sleep(timeout)
|
||||||
|
await queue.put(TimeoutError(f"Task execution timed out after {timeout} seconds"))
|
||||||
|
|
||||||
|
timeout_task = asyncio.create_task(_timeout())
|
||||||
|
|
||||||
|
while True:
|
||||||
|
event = await queue.get()
|
||||||
|
|
||||||
|
if isinstance(event, Exception):
|
||||||
|
raise event
|
||||||
|
|
||||||
|
if isinstance(event, TaskStatusUpdateEvent):
|
||||||
|
if event.status.state == TaskState.COMPLETED:
|
||||||
|
if event.status.message:
|
||||||
|
for part in event.status.message.parts:
|
||||||
|
if isinstance(part, TextPart):
|
||||||
|
result += part.text
|
||||||
|
break
|
||||||
|
elif event.status.state in [TaskState.FAILED, TaskState.CANCELED]:
|
||||||
|
error_message = "Task failed"
|
||||||
|
if event.status.message:
|
||||||
|
for part in event.status.message.parts:
|
||||||
|
if isinstance(part, TextPart):
|
||||||
|
error_message = part.text
|
||||||
|
raise Exception(error_message)
|
||||||
|
elif isinstance(event, TaskArtifactUpdateEvent):
|
||||||
|
for part in event.artifact.parts:
|
||||||
|
if isinstance(part, TextPart):
|
||||||
|
result += part.text
|
||||||
|
finally:
|
||||||
|
timeout_task.cancel()
|
||||||
|
|
||||||
|
return result
|
||||||
|
|
||||||
|
async def handle_a2a_task(
|
||||||
|
self,
|
||||||
|
task: A2ATask,
|
||||||
|
agent_execute_func: Any,
|
||||||
|
context: Optional[str] = None,
|
||||||
|
) -> None:
|
||||||
|
"""Handle an A2A task.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
task: The A2A task to handle.
|
||||||
|
agent_execute_func: The function to execute the task.
|
||||||
|
context: Additional context for the task.
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
Exception: If there is an error handling the task.
|
||||||
|
"""
|
||||||
|
if not self.task_manager:
|
||||||
|
raise ValueError("Task manager is required to handle A2A tasks")
|
||||||
|
|
||||||
|
try:
|
||||||
|
await self.task_manager.update_task_status(
|
||||||
|
task_id=task.id,
|
||||||
|
state=TaskState.WORKING,
|
||||||
|
)
|
||||||
|
|
||||||
|
task_description = ""
|
||||||
|
task_context = context or ""
|
||||||
|
|
||||||
|
if task.history and task.history[-1].role == "user":
|
||||||
|
message = task.history[-1]
|
||||||
|
for part in message.parts:
|
||||||
|
if isinstance(part, TextPart):
|
||||||
|
task_description += part.text
|
||||||
|
elif isinstance(part, DataPart) and part.data.get("context"):
|
||||||
|
task_context += part.data["context"]
|
||||||
|
|
||||||
|
try:
|
||||||
|
result = await agent_execute_func(task_description, task_context)
|
||||||
|
|
||||||
|
response_message = Message(
|
||||||
|
role="agent",
|
||||||
|
parts=[TextPart(text=result)],
|
||||||
|
)
|
||||||
|
|
||||||
|
await self.task_manager.update_task_status(
|
||||||
|
task_id=task.id,
|
||||||
|
state=TaskState.COMPLETED,
|
||||||
|
message=response_message,
|
||||||
|
)
|
||||||
|
|
||||||
|
artifact = Artifact(
|
||||||
|
name="result",
|
||||||
|
parts=[TextPart(text=result)],
|
||||||
|
)
|
||||||
|
await self.task_manager.add_task_artifact(
|
||||||
|
task_id=task.id,
|
||||||
|
artifact=artifact,
|
||||||
|
)
|
||||||
|
except Exception as e:
|
||||||
|
error_message = Message(
|
||||||
|
role="agent",
|
||||||
|
parts=[TextPart(text=str(e))],
|
||||||
|
)
|
||||||
|
await self.task_manager.update_task_status(
|
||||||
|
task_id=task.id,
|
||||||
|
state=TaskState.FAILED,
|
||||||
|
message=error_message,
|
||||||
|
)
|
||||||
|
raise
|
||||||
|
except Exception as e:
|
||||||
|
self.logger.exception(f"Error handling A2A task: {e}")
|
||||||
|
raise
|
||||||
424
src/crewai/a2a/client.py
Normal file
424
src/crewai/a2a/client.py
Normal file
@@ -0,0 +1,424 @@
|
|||||||
|
"""
|
||||||
|
A2A protocol client for CrewAI.
|
||||||
|
|
||||||
|
This module implements the client for the A2A protocol in CrewAI.
|
||||||
|
"""
|
||||||
|
|
||||||
|
import asyncio
|
||||||
|
import json
|
||||||
|
import logging
|
||||||
|
import os
|
||||||
|
from typing import Any, Dict, List, Optional, Union, cast
|
||||||
|
|
||||||
|
import aiohttp
|
||||||
|
from pydantic import ValidationError
|
||||||
|
|
||||||
|
from crewai.types.a2a import (
|
||||||
|
A2AClientError,
|
||||||
|
A2AClientHTTPError,
|
||||||
|
A2AClientJSONError,
|
||||||
|
Artifact,
|
||||||
|
CancelTaskRequest,
|
||||||
|
CancelTaskResponse,
|
||||||
|
GetTaskPushNotificationRequest,
|
||||||
|
GetTaskPushNotificationResponse,
|
||||||
|
GetTaskRequest,
|
||||||
|
GetTaskResponse,
|
||||||
|
JSONRPCError,
|
||||||
|
JSONRPCRequest,
|
||||||
|
JSONRPCResponse,
|
||||||
|
Message,
|
||||||
|
MissingAPIKeyError,
|
||||||
|
PushNotificationConfig,
|
||||||
|
SendTaskRequest,
|
||||||
|
SendTaskResponse,
|
||||||
|
SendTaskStreamingRequest,
|
||||||
|
SetTaskPushNotificationRequest,
|
||||||
|
SetTaskPushNotificationResponse,
|
||||||
|
Task,
|
||||||
|
TaskArtifactUpdateEvent,
|
||||||
|
TaskIdParams,
|
||||||
|
TaskPushNotificationConfig,
|
||||||
|
TaskQueryParams,
|
||||||
|
TaskSendParams,
|
||||||
|
TaskState,
|
||||||
|
TaskStatusUpdateEvent,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class A2AClient:
|
||||||
|
"""A2A protocol client implementation."""
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
base_url: str,
|
||||||
|
api_key: Optional[str] = None,
|
||||||
|
timeout: int = 60,
|
||||||
|
):
|
||||||
|
"""Initialize the A2A client.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
base_url: The base URL of the A2A server.
|
||||||
|
api_key: The API key to use for authentication.
|
||||||
|
timeout: The timeout for HTTP requests in seconds.
|
||||||
|
"""
|
||||||
|
self.base_url = base_url.rstrip("/")
|
||||||
|
self.api_key = api_key or os.environ.get("A2A_API_KEY")
|
||||||
|
self.timeout = timeout
|
||||||
|
self.logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
async def send_task(
|
||||||
|
self,
|
||||||
|
task_id: str,
|
||||||
|
message: Message,
|
||||||
|
session_id: Optional[str] = None,
|
||||||
|
accepted_output_modes: Optional[List[str]] = None,
|
||||||
|
push_notification: Optional[PushNotificationConfig] = None,
|
||||||
|
history_length: Optional[int] = None,
|
||||||
|
metadata: Optional[Dict[str, Any]] = None,
|
||||||
|
) -> Task:
|
||||||
|
"""Send a task to the A2A server.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
task_id: The ID of the task.
|
||||||
|
message: The message to send.
|
||||||
|
session_id: The session ID.
|
||||||
|
accepted_output_modes: The accepted output modes.
|
||||||
|
push_notification: The push notification configuration.
|
||||||
|
history_length: The number of messages to include in the history.
|
||||||
|
metadata: Additional metadata.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
The created task.
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
A2AClientError: If there is an error sending the task.
|
||||||
|
"""
|
||||||
|
params = TaskSendParams(
|
||||||
|
id=task_id,
|
||||||
|
sessionId=session_id,
|
||||||
|
message=message,
|
||||||
|
acceptedOutputModes=accepted_output_modes,
|
||||||
|
pushNotification=push_notification,
|
||||||
|
historyLength=history_length,
|
||||||
|
metadata=metadata,
|
||||||
|
)
|
||||||
|
|
||||||
|
request = SendTaskRequest(params=params)
|
||||||
|
response = await self._send_jsonrpc_request(request)
|
||||||
|
|
||||||
|
if response.error:
|
||||||
|
raise A2AClientError(f"Error sending task: {response.error.message}")
|
||||||
|
|
||||||
|
if not response.result:
|
||||||
|
raise A2AClientError("No result returned from send task request")
|
||||||
|
|
||||||
|
return cast(Task, response.result)
|
||||||
|
|
||||||
|
async def send_task_streaming(
|
||||||
|
self,
|
||||||
|
task_id: str,
|
||||||
|
message: Message,
|
||||||
|
session_id: Optional[str] = None,
|
||||||
|
accepted_output_modes: Optional[List[str]] = None,
|
||||||
|
push_notification: Optional[PushNotificationConfig] = None,
|
||||||
|
history_length: Optional[int] = None,
|
||||||
|
metadata: Optional[Dict[str, Any]] = None,
|
||||||
|
) -> asyncio.Queue:
|
||||||
|
"""Send a task to the A2A server and subscribe to updates.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
task_id: The ID of the task.
|
||||||
|
message: The message to send.
|
||||||
|
session_id: The session ID.
|
||||||
|
accepted_output_modes: The accepted output modes.
|
||||||
|
push_notification: The push notification configuration.
|
||||||
|
history_length: The number of messages to include in the history.
|
||||||
|
metadata: Additional metadata.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
A queue that will receive task updates.
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
A2AClientError: If there is an error sending the task.
|
||||||
|
"""
|
||||||
|
params = TaskSendParams(
|
||||||
|
id=task_id,
|
||||||
|
sessionId=session_id,
|
||||||
|
message=message,
|
||||||
|
acceptedOutputModes=accepted_output_modes,
|
||||||
|
pushNotification=push_notification,
|
||||||
|
historyLength=history_length,
|
||||||
|
metadata=metadata,
|
||||||
|
)
|
||||||
|
|
||||||
|
queue: asyncio.Queue = asyncio.Queue()
|
||||||
|
|
||||||
|
asyncio.create_task(
|
||||||
|
self._handle_streaming_response(
|
||||||
|
f"{self.base_url}/v1/tasks/sendSubscribe", params, queue
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
return queue
|
||||||
|
|
||||||
|
async def get_task(
|
||||||
|
self, task_id: str, history_length: Optional[int] = None
|
||||||
|
) -> Task:
|
||||||
|
"""Get a task from the A2A server.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
task_id: The ID of the task.
|
||||||
|
history_length: The number of messages to include in the history.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
The task.
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
A2AClientError: If there is an error getting the task.
|
||||||
|
"""
|
||||||
|
params = TaskQueryParams(id=task_id, historyLength=history_length)
|
||||||
|
request = GetTaskRequest(params=params)
|
||||||
|
response = await self._send_jsonrpc_request(request)
|
||||||
|
|
||||||
|
if response.error:
|
||||||
|
raise A2AClientError(f"Error getting task: {response.error.message}")
|
||||||
|
|
||||||
|
if not response.result:
|
||||||
|
raise A2AClientError("No result returned from get task request")
|
||||||
|
|
||||||
|
return cast(Task, response.result)
|
||||||
|
|
||||||
|
async def cancel_task(self, task_id: str) -> Task:
|
||||||
|
"""Cancel a task on the A2A server.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
task_id: The ID of the task.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
The canceled task.
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
A2AClientError: If there is an error canceling the task.
|
||||||
|
"""
|
||||||
|
params = TaskIdParams(id=task_id)
|
||||||
|
request = CancelTaskRequest(params=params)
|
||||||
|
response = await self._send_jsonrpc_request(request)
|
||||||
|
|
||||||
|
if response.error:
|
||||||
|
raise A2AClientError(f"Error canceling task: {response.error.message}")
|
||||||
|
|
||||||
|
if not response.result:
|
||||||
|
raise A2AClientError("No result returned from cancel task request")
|
||||||
|
|
||||||
|
return cast(Task, response.result)
|
||||||
|
|
||||||
|
async def set_push_notification(
|
||||||
|
self, task_id: str, config: PushNotificationConfig
|
||||||
|
) -> PushNotificationConfig:
|
||||||
|
"""Set push notification for a task.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
task_id: The ID of the task.
|
||||||
|
config: The push notification configuration.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
The push notification configuration.
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
A2AClientError: If there is an error setting the push notification.
|
||||||
|
"""
|
||||||
|
params = TaskPushNotificationConfig(id=task_id, pushNotificationConfig=config)
|
||||||
|
request = SetTaskPushNotificationRequest(params=params)
|
||||||
|
response = await self._send_jsonrpc_request(request)
|
||||||
|
|
||||||
|
if response.error:
|
||||||
|
raise A2AClientError(
|
||||||
|
f"Error setting push notification: {response.error.message}"
|
||||||
|
)
|
||||||
|
|
||||||
|
if not response.result:
|
||||||
|
raise A2AClientError(
|
||||||
|
"No result returned from set push notification request"
|
||||||
|
)
|
||||||
|
|
||||||
|
return cast(TaskPushNotificationConfig, response.result).pushNotificationConfig
|
||||||
|
|
||||||
|
async def get_push_notification(
|
||||||
|
self, task_id: str
|
||||||
|
) -> Optional[PushNotificationConfig]:
|
||||||
|
"""Get push notification for a task.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
task_id: The ID of the task.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
The push notification configuration, or None if not set.
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
A2AClientError: If there is an error getting the push notification.
|
||||||
|
"""
|
||||||
|
params = TaskIdParams(id=task_id)
|
||||||
|
request = GetTaskPushNotificationRequest(params=params)
|
||||||
|
response = await self._send_jsonrpc_request(request)
|
||||||
|
|
||||||
|
if response.error:
|
||||||
|
raise A2AClientError(
|
||||||
|
f"Error getting push notification: {response.error.message}"
|
||||||
|
)
|
||||||
|
|
||||||
|
if not response.result:
|
||||||
|
return None
|
||||||
|
|
||||||
|
return cast(TaskPushNotificationConfig, response.result).pushNotificationConfig
|
||||||
|
|
||||||
|
async def _send_jsonrpc_request(
|
||||||
|
self, request: JSONRPCRequest
|
||||||
|
) -> JSONRPCResponse:
|
||||||
|
"""Send a JSON-RPC request to the A2A server.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
request: The JSON-RPC request.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
The JSON-RPC response.
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
A2AClientError: If there is an error sending the request.
|
||||||
|
"""
|
||||||
|
if not self.api_key:
|
||||||
|
raise MissingAPIKeyError(
|
||||||
|
"API key is required. Set it in the constructor or as the A2A_API_KEY environment variable."
|
||||||
|
)
|
||||||
|
|
||||||
|
headers = {
|
||||||
|
"Content-Type": "application/json",
|
||||||
|
"Authorization": f"Bearer {self.api_key}",
|
||||||
|
}
|
||||||
|
|
||||||
|
try:
|
||||||
|
async with aiohttp.ClientSession() as session:
|
||||||
|
async with session.post(
|
||||||
|
f"{self.base_url}/v1/jsonrpc",
|
||||||
|
headers=headers,
|
||||||
|
json=request.model_dump(),
|
||||||
|
timeout=self.timeout,
|
||||||
|
) as response:
|
||||||
|
if response.status != 200:
|
||||||
|
raise A2AClientHTTPError(
|
||||||
|
response.status, await response.text()
|
||||||
|
)
|
||||||
|
|
||||||
|
try:
|
||||||
|
data = await response.json()
|
||||||
|
except json.JSONDecodeError as e:
|
||||||
|
raise A2AClientJSONError(str(e))
|
||||||
|
|
||||||
|
try:
|
||||||
|
return JSONRPCResponse.model_validate(data)
|
||||||
|
except ValidationError as e:
|
||||||
|
raise A2AClientError(f"Invalid response: {e}")
|
||||||
|
except aiohttp.ClientError as e:
|
||||||
|
raise A2AClientError(f"HTTP error: {e}")
|
||||||
|
|
||||||
|
async def _handle_streaming_response(
|
||||||
|
self,
|
||||||
|
url: str,
|
||||||
|
params: TaskSendParams,
|
||||||
|
queue: asyncio.Queue,
|
||||||
|
) -> None:
|
||||||
|
"""Handle a streaming response from the A2A server.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
url: The URL to send the request to.
|
||||||
|
params: The task send parameters.
|
||||||
|
queue: The queue to put events into.
|
||||||
|
"""
|
||||||
|
if not self.api_key:
|
||||||
|
await queue.put(
|
||||||
|
Exception(
|
||||||
|
"API key is required. Set it in the constructor or as the A2A_API_KEY environment variable."
|
||||||
|
)
|
||||||
|
)
|
||||||
|
return
|
||||||
|
|
||||||
|
headers = {
|
||||||
|
"Content-Type": "application/json",
|
||||||
|
"Authorization": f"Bearer {self.api_key}",
|
||||||
|
"Accept": "text/event-stream",
|
||||||
|
}
|
||||||
|
|
||||||
|
try:
|
||||||
|
async with aiohttp.ClientSession() as session:
|
||||||
|
async with session.post(
|
||||||
|
url,
|
||||||
|
headers=headers,
|
||||||
|
json=params.model_dump(),
|
||||||
|
timeout=self.timeout,
|
||||||
|
) as response:
|
||||||
|
if response.status != 200:
|
||||||
|
await queue.put(
|
||||||
|
A2AClientHTTPError(response.status, await response.text())
|
||||||
|
)
|
||||||
|
return
|
||||||
|
|
||||||
|
buffer = ""
|
||||||
|
async for line in response.content:
|
||||||
|
line = line.decode("utf-8")
|
||||||
|
buffer += line
|
||||||
|
|
||||||
|
if buffer.endswith("\n\n"):
|
||||||
|
event_data = self._parse_sse_event(buffer)
|
||||||
|
buffer = ""
|
||||||
|
|
||||||
|
if event_data:
|
||||||
|
event_type = event_data.get("event")
|
||||||
|
data = event_data.get("data")
|
||||||
|
|
||||||
|
if event_type == "status":
|
||||||
|
try:
|
||||||
|
event = TaskStatusUpdateEvent.model_validate_json(data)
|
||||||
|
await queue.put(event)
|
||||||
|
|
||||||
|
if event.final:
|
||||||
|
break
|
||||||
|
except ValidationError as e:
|
||||||
|
await queue.put(
|
||||||
|
A2AClientError(f"Invalid status event: {e}")
|
||||||
|
)
|
||||||
|
elif event_type == "artifact":
|
||||||
|
try:
|
||||||
|
event = TaskArtifactUpdateEvent.model_validate_json(data)
|
||||||
|
await queue.put(event)
|
||||||
|
except ValidationError as e:
|
||||||
|
await queue.put(
|
||||||
|
A2AClientError(f"Invalid artifact event: {e}")
|
||||||
|
)
|
||||||
|
except aiohttp.ClientError as e:
|
||||||
|
await queue.put(A2AClientError(f"HTTP error: {e}"))
|
||||||
|
except asyncio.CancelledError:
|
||||||
|
pass
|
||||||
|
except Exception as e:
|
||||||
|
await queue.put(A2AClientError(f"Error handling streaming response: {e}"))
|
||||||
|
|
||||||
|
def _parse_sse_event(self, data: str) -> Dict[str, str]:
|
||||||
|
"""Parse an SSE event.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
data: The SSE event data.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
A dictionary with the event type and data.
|
||||||
|
"""
|
||||||
|
result = {}
|
||||||
|
for line in data.split("\n"):
|
||||||
|
line = line.strip()
|
||||||
|
if not line:
|
||||||
|
continue
|
||||||
|
|
||||||
|
if line.startswith("event:"):
|
||||||
|
result["event"] = line[6:].strip()
|
||||||
|
elif line.startswith("data:"):
|
||||||
|
result["data"] = line[5:].strip()
|
||||||
|
|
||||||
|
return result
|
||||||
368
src/crewai/a2a/server.py
Normal file
368
src/crewai/a2a/server.py
Normal file
@@ -0,0 +1,368 @@
|
|||||||
|
"""
|
||||||
|
A2A protocol server for CrewAI.
|
||||||
|
|
||||||
|
This module implements the server for the A2A protocol in CrewAI.
|
||||||
|
"""
|
||||||
|
|
||||||
|
import asyncio
|
||||||
|
import json
|
||||||
|
import logging
|
||||||
|
from typing import Any, Callable, Dict, List, Optional, Type, Union
|
||||||
|
|
||||||
|
from fastapi import FastAPI, HTTPException, Request, Response
|
||||||
|
from fastapi.middleware.cors import CORSMiddleware
|
||||||
|
from fastapi.responses import JSONResponse, StreamingResponse
|
||||||
|
from pydantic import ValidationError
|
||||||
|
|
||||||
|
from crewai.a2a.task_manager import InMemoryTaskManager, TaskManager
|
||||||
|
from crewai.types.a2a import (
|
||||||
|
A2ARequest,
|
||||||
|
CancelTaskRequest,
|
||||||
|
CancelTaskResponse,
|
||||||
|
ContentTypeNotSupportedError,
|
||||||
|
GetTaskPushNotificationRequest,
|
||||||
|
GetTaskPushNotificationResponse,
|
||||||
|
GetTaskRequest,
|
||||||
|
GetTaskResponse,
|
||||||
|
InternalError,
|
||||||
|
InvalidParamsError,
|
||||||
|
InvalidRequestError,
|
||||||
|
JSONParseError,
|
||||||
|
JSONRPCError,
|
||||||
|
JSONRPCRequest,
|
||||||
|
JSONRPCResponse,
|
||||||
|
MethodNotFoundError,
|
||||||
|
SendTaskRequest,
|
||||||
|
SendTaskResponse,
|
||||||
|
SendTaskStreamingRequest,
|
||||||
|
SendTaskStreamingResponse,
|
||||||
|
SetTaskPushNotificationRequest,
|
||||||
|
SetTaskPushNotificationResponse,
|
||||||
|
Task,
|
||||||
|
TaskArtifactUpdateEvent,
|
||||||
|
TaskIdParams,
|
||||||
|
TaskNotCancelableError,
|
||||||
|
TaskNotFoundError,
|
||||||
|
TaskPushNotificationConfig,
|
||||||
|
TaskQueryParams,
|
||||||
|
TaskSendParams,
|
||||||
|
TaskState,
|
||||||
|
TaskStatusUpdateEvent,
|
||||||
|
UnsupportedOperationError,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class A2AServer:
|
||||||
|
"""A2A protocol server implementation."""
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
task_manager: Optional[TaskManager] = None,
|
||||||
|
enable_cors: bool = True,
|
||||||
|
cors_origins: Optional[List[str]] = None,
|
||||||
|
):
|
||||||
|
"""Initialize the A2A server.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
task_manager: The task manager to use. If None, an InMemoryTaskManager will be created.
|
||||||
|
enable_cors: Whether to enable CORS.
|
||||||
|
cors_origins: The CORS origins to allow.
|
||||||
|
"""
|
||||||
|
self.app = FastAPI(title="A2A Server")
|
||||||
|
self.task_manager = task_manager or InMemoryTaskManager()
|
||||||
|
self.logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
if enable_cors:
|
||||||
|
self.app.add_middleware(
|
||||||
|
CORSMiddleware,
|
||||||
|
allow_origins=cors_origins or ["*"],
|
||||||
|
allow_credentials=True,
|
||||||
|
allow_methods=["*"],
|
||||||
|
allow_headers=["*"],
|
||||||
|
)
|
||||||
|
|
||||||
|
self.app.post("/v1/jsonrpc")(self.handle_jsonrpc)
|
||||||
|
self.app.post("/v1/tasks/send")(self.handle_send_task)
|
||||||
|
self.app.post("/v1/tasks/sendSubscribe")(self.handle_send_task_subscribe)
|
||||||
|
self.app.post("/v1/tasks/{task_id}/cancel")(self.handle_cancel_task)
|
||||||
|
self.app.get("/v1/tasks/{task_id}")(self.handle_get_task)
|
||||||
|
|
||||||
|
async def handle_jsonrpc(self, request: Request) -> JSONResponse:
|
||||||
|
"""Handle JSON-RPC requests.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
request: The FastAPI request.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
A JSON response.
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
body = await request.json()
|
||||||
|
except json.JSONDecodeError:
|
||||||
|
return JSONResponse(
|
||||||
|
content=JSONRPCResponse(
|
||||||
|
id=None, error=JSONParseError()
|
||||||
|
).model_dump(),
|
||||||
|
status_code=400,
|
||||||
|
)
|
||||||
|
|
||||||
|
try:
|
||||||
|
if isinstance(body, list):
|
||||||
|
responses = []
|
||||||
|
for req_data in body:
|
||||||
|
response = await self._process_jsonrpc_request(req_data)
|
||||||
|
responses.append(response.model_dump())
|
||||||
|
return JSONResponse(content=responses)
|
||||||
|
else:
|
||||||
|
response = await self._process_jsonrpc_request(body)
|
||||||
|
return JSONResponse(content=response.model_dump())
|
||||||
|
except Exception as e:
|
||||||
|
self.logger.exception("Error processing JSON-RPC request")
|
||||||
|
return JSONResponse(
|
||||||
|
content=JSONRPCResponse(
|
||||||
|
id=body.get("id") if isinstance(body, dict) else None,
|
||||||
|
error=InternalError(data=str(e)),
|
||||||
|
).model_dump(),
|
||||||
|
status_code=500,
|
||||||
|
)
|
||||||
|
|
||||||
|
async def _process_jsonrpc_request(
|
||||||
|
self, request_data: Dict[str, Any]
|
||||||
|
) -> JSONRPCResponse:
|
||||||
|
"""Process a JSON-RPC request.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
request_data: The JSON-RPC request data.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
A JSON-RPC response.
|
||||||
|
"""
|
||||||
|
if not isinstance(request_data, dict) or request_data.get("jsonrpc") != "2.0":
|
||||||
|
return JSONRPCResponse(
|
||||||
|
id=request_data.get("id") if isinstance(request_data, dict) else None,
|
||||||
|
error=InvalidRequestError(),
|
||||||
|
)
|
||||||
|
|
||||||
|
request_id = request_data.get("id")
|
||||||
|
method = request_data.get("method")
|
||||||
|
|
||||||
|
if not method:
|
||||||
|
return JSONRPCResponse(
|
||||||
|
id=request_id,
|
||||||
|
error=InvalidRequestError(message="Method is required"),
|
||||||
|
)
|
||||||
|
|
||||||
|
try:
|
||||||
|
request = A2ARequest.validate_python(request_data)
|
||||||
|
except ValidationError as e:
|
||||||
|
return JSONRPCResponse(
|
||||||
|
id=request_id,
|
||||||
|
error=InvalidParamsError(data=str(e)),
|
||||||
|
)
|
||||||
|
|
||||||
|
try:
|
||||||
|
if isinstance(request, SendTaskRequest):
|
||||||
|
task = await self._handle_send_task(request.params)
|
||||||
|
return SendTaskResponse(id=request_id, result=task)
|
||||||
|
elif isinstance(request, GetTaskRequest):
|
||||||
|
task = await self.task_manager.get_task(
|
||||||
|
request.params.id, request.params.historyLength
|
||||||
|
)
|
||||||
|
return GetTaskResponse(id=request_id, result=task)
|
||||||
|
elif isinstance(request, CancelTaskRequest):
|
||||||
|
task = await self.task_manager.cancel_task(request.params.id)
|
||||||
|
return CancelTaskResponse(id=request_id, result=task)
|
||||||
|
elif isinstance(request, SetTaskPushNotificationRequest):
|
||||||
|
config = await self.task_manager.set_push_notification(
|
||||||
|
request.params.id, request.params.pushNotificationConfig
|
||||||
|
)
|
||||||
|
return SetTaskPushNotificationResponse(
|
||||||
|
id=request_id, result=TaskPushNotificationConfig(id=request.params.id, pushNotificationConfig=config)
|
||||||
|
)
|
||||||
|
elif isinstance(request, GetTaskPushNotificationRequest):
|
||||||
|
config = await self.task_manager.get_push_notification(
|
||||||
|
request.params.id
|
||||||
|
)
|
||||||
|
if config:
|
||||||
|
return GetTaskPushNotificationResponse(
|
||||||
|
id=request_id, result=TaskPushNotificationConfig(id=request.params.id, pushNotificationConfig=config)
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
return GetTaskPushNotificationResponse(id=request_id, result=None)
|
||||||
|
elif isinstance(request, SendTaskStreamingRequest):
|
||||||
|
return JSONRPCResponse(
|
||||||
|
id=request_id,
|
||||||
|
error=UnsupportedOperationError(
|
||||||
|
message="Streaming requests should be sent to the streaming endpoint"
|
||||||
|
),
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
return JSONRPCResponse(
|
||||||
|
id=request_id,
|
||||||
|
error=MethodNotFoundError(),
|
||||||
|
)
|
||||||
|
except KeyError:
|
||||||
|
return JSONRPCResponse(
|
||||||
|
id=request_id,
|
||||||
|
error=TaskNotFoundError(),
|
||||||
|
)
|
||||||
|
except Exception as e:
|
||||||
|
self.logger.exception(f"Error handling {method} request")
|
||||||
|
return JSONRPCResponse(
|
||||||
|
id=request_id,
|
||||||
|
error=InternalError(data=str(e)),
|
||||||
|
)
|
||||||
|
|
||||||
|
async def handle_send_task(self, request: Request) -> JSONResponse:
|
||||||
|
"""Handle send task requests.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
request: The FastAPI request.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
A JSON response.
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
body = await request.json()
|
||||||
|
params = TaskSendParams.model_validate(body)
|
||||||
|
task = await self._handle_send_task(params)
|
||||||
|
return JSONResponse(content=task.model_dump())
|
||||||
|
except ValidationError as e:
|
||||||
|
return JSONResponse(
|
||||||
|
content={"error": str(e)},
|
||||||
|
status_code=400,
|
||||||
|
)
|
||||||
|
except Exception as e:
|
||||||
|
self.logger.exception("Error handling send task request")
|
||||||
|
return JSONResponse(
|
||||||
|
content={"error": str(e)},
|
||||||
|
status_code=500,
|
||||||
|
)
|
||||||
|
|
||||||
|
async def _handle_send_task(self, params: TaskSendParams) -> Task:
|
||||||
|
"""Handle send task requests.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
params: The task send parameters.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
The created task.
|
||||||
|
"""
|
||||||
|
task = await self.task_manager.create_task(
|
||||||
|
task_id=params.id,
|
||||||
|
session_id=params.sessionId,
|
||||||
|
message=params.message,
|
||||||
|
metadata=params.metadata,
|
||||||
|
)
|
||||||
|
|
||||||
|
await self.task_manager.update_task_status(
|
||||||
|
task_id=params.id,
|
||||||
|
state=TaskState.WORKING,
|
||||||
|
)
|
||||||
|
|
||||||
|
return task
|
||||||
|
|
||||||
|
async def handle_send_task_subscribe(self, request: Request) -> StreamingResponse:
|
||||||
|
"""Handle send task subscribe requests.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
request: The FastAPI request.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
A streaming response.
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
body = await request.json()
|
||||||
|
params = TaskSendParams.model_validate(body)
|
||||||
|
|
||||||
|
task = await self._handle_send_task(params)
|
||||||
|
|
||||||
|
queue = await self.task_manager.subscribe_to_task(params.id)
|
||||||
|
|
||||||
|
return StreamingResponse(
|
||||||
|
self._stream_task_updates(params.id, queue),
|
||||||
|
media_type="text/event-stream",
|
||||||
|
)
|
||||||
|
except ValidationError as e:
|
||||||
|
return JSONResponse(
|
||||||
|
content={"error": str(e)},
|
||||||
|
status_code=400,
|
||||||
|
)
|
||||||
|
except Exception as e:
|
||||||
|
self.logger.exception("Error handling send task subscribe request")
|
||||||
|
return JSONResponse(
|
||||||
|
content={"error": str(e)},
|
||||||
|
status_code=500,
|
||||||
|
)
|
||||||
|
|
||||||
|
async def _stream_task_updates(
|
||||||
|
self, task_id: str, queue: asyncio.Queue
|
||||||
|
) -> None:
|
||||||
|
"""Stream task updates.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
task_id: The ID of the task.
|
||||||
|
queue: The queue to receive updates from.
|
||||||
|
|
||||||
|
Yields:
|
||||||
|
SSE formatted events.
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
while True:
|
||||||
|
event = await queue.get()
|
||||||
|
|
||||||
|
if isinstance(event, TaskStatusUpdateEvent):
|
||||||
|
event_type = "status"
|
||||||
|
elif isinstance(event, TaskArtifactUpdateEvent):
|
||||||
|
event_type = "artifact"
|
||||||
|
else:
|
||||||
|
event_type = "unknown"
|
||||||
|
|
||||||
|
data = json.dumps(event.model_dump())
|
||||||
|
yield f"event: {event_type}\ndata: {data}\n\n"
|
||||||
|
|
||||||
|
if isinstance(event, TaskStatusUpdateEvent) and event.final:
|
||||||
|
break
|
||||||
|
finally:
|
||||||
|
await self.task_manager.unsubscribe_from_task(task_id, queue)
|
||||||
|
|
||||||
|
async def handle_get_task(self, task_id: str, request: Request) -> JSONResponse:
|
||||||
|
"""Handle get task requests.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
task_id: The ID of the task.
|
||||||
|
request: The FastAPI request.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
A JSON response.
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
history_length = request.query_params.get("historyLength")
|
||||||
|
history_length = int(history_length) if history_length else None
|
||||||
|
|
||||||
|
task = await self.task_manager.get_task(task_id, history_length)
|
||||||
|
return JSONResponse(content=task.model_dump())
|
||||||
|
except KeyError:
|
||||||
|
raise HTTPException(status_code=404, detail=f"Task {task_id} not found")
|
||||||
|
except Exception as e:
|
||||||
|
self.logger.exception(f"Error handling get task request for {task_id}")
|
||||||
|
raise HTTPException(status_code=500, detail=str(e))
|
||||||
|
|
||||||
|
async def handle_cancel_task(self, task_id: str, request: Request) -> JSONResponse:
|
||||||
|
"""Handle cancel task requests.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
task_id: The ID of the task.
|
||||||
|
request: The FastAPI request.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
A JSON response.
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
task = await self.task_manager.cancel_task(task_id)
|
||||||
|
return JSONResponse(content=task.model_dump())
|
||||||
|
except KeyError:
|
||||||
|
raise HTTPException(status_code=404, detail=f"Task {task_id} not found")
|
||||||
|
except Exception as e:
|
||||||
|
self.logger.exception(f"Error handling cancel task request for {task_id}")
|
||||||
|
raise HTTPException(status_code=500, detail=str(e))
|
||||||
438
src/crewai/a2a/task_manager.py
Normal file
438
src/crewai/a2a/task_manager.py
Normal file
@@ -0,0 +1,438 @@
|
|||||||
|
"""
|
||||||
|
A2A protocol task manager for CrewAI.
|
||||||
|
|
||||||
|
This module implements the task manager for the A2A protocol in CrewAI.
|
||||||
|
"""
|
||||||
|
|
||||||
|
import asyncio
|
||||||
|
import logging
|
||||||
|
from abc import ABC, abstractmethod
|
||||||
|
from datetime import datetime
|
||||||
|
from typing import Any, Dict, List, Optional, Set, Union
|
||||||
|
from uuid import uuid4
|
||||||
|
|
||||||
|
from crewai.types.a2a import (
|
||||||
|
Artifact,
|
||||||
|
Message,
|
||||||
|
PushNotificationConfig,
|
||||||
|
Task,
|
||||||
|
TaskArtifactUpdateEvent,
|
||||||
|
TaskState,
|
||||||
|
TaskStatus,
|
||||||
|
TaskStatusUpdateEvent,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class TaskManager(ABC):
|
||||||
|
"""Abstract base class for A2A task managers."""
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
async def create_task(
|
||||||
|
self,
|
||||||
|
task_id: str,
|
||||||
|
session_id: Optional[str] = None,
|
||||||
|
message: Optional[Message] = None,
|
||||||
|
metadata: Optional[Dict[str, Any]] = None,
|
||||||
|
) -> Task:
|
||||||
|
"""Create a new task.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
task_id: The ID of the task.
|
||||||
|
session_id: The session ID.
|
||||||
|
message: The initial message.
|
||||||
|
metadata: Additional metadata.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
The created task.
|
||||||
|
"""
|
||||||
|
pass
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
async def get_task(
|
||||||
|
self, task_id: str, history_length: Optional[int] = None
|
||||||
|
) -> Task:
|
||||||
|
"""Get a task by ID.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
task_id: The ID of the task.
|
||||||
|
history_length: The number of messages to include in the history.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
The task.
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
KeyError: If the task is not found.
|
||||||
|
"""
|
||||||
|
pass
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
async def update_task_status(
|
||||||
|
self,
|
||||||
|
task_id: str,
|
||||||
|
state: TaskState,
|
||||||
|
message: Optional[Message] = None,
|
||||||
|
metadata: Optional[Dict[str, Any]] = None,
|
||||||
|
) -> TaskStatusUpdateEvent:
|
||||||
|
"""Update the status of a task.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
task_id: The ID of the task.
|
||||||
|
state: The new state of the task.
|
||||||
|
message: An optional message to include with the status update.
|
||||||
|
metadata: Additional metadata.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
The task status update event.
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
KeyError: If the task is not found.
|
||||||
|
"""
|
||||||
|
pass
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
async def add_task_artifact(
|
||||||
|
self,
|
||||||
|
task_id: str,
|
||||||
|
artifact: Artifact,
|
||||||
|
metadata: Optional[Dict[str, Any]] = None,
|
||||||
|
) -> TaskArtifactUpdateEvent:
|
||||||
|
"""Add an artifact to a task.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
task_id: The ID of the task.
|
||||||
|
artifact: The artifact to add.
|
||||||
|
metadata: Additional metadata.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
The task artifact update event.
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
KeyError: If the task is not found.
|
||||||
|
"""
|
||||||
|
pass
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
async def cancel_task(self, task_id: str) -> Task:
|
||||||
|
"""Cancel a task.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
task_id: The ID of the task.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
The canceled task.
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
KeyError: If the task is not found.
|
||||||
|
"""
|
||||||
|
pass
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
async def set_push_notification(
|
||||||
|
self, task_id: str, config: PushNotificationConfig
|
||||||
|
) -> PushNotificationConfig:
|
||||||
|
"""Set push notification for a task.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
task_id: The ID of the task.
|
||||||
|
config: The push notification configuration.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
The push notification configuration.
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
KeyError: If the task is not found.
|
||||||
|
"""
|
||||||
|
pass
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
async def get_push_notification(
|
||||||
|
self, task_id: str
|
||||||
|
) -> Optional[PushNotificationConfig]:
|
||||||
|
"""Get push notification for a task.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
task_id: The ID of the task.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
The push notification configuration, or None if not set.
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
KeyError: If the task is not found.
|
||||||
|
"""
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
class InMemoryTaskManager(TaskManager):
|
||||||
|
"""In-memory implementation of the A2A task manager."""
|
||||||
|
|
||||||
|
def __init__(self):
|
||||||
|
"""Initialize the in-memory task manager."""
|
||||||
|
self._tasks: Dict[str, Task] = {}
|
||||||
|
self._push_notifications: Dict[str, PushNotificationConfig] = {}
|
||||||
|
self._task_subscribers: Dict[str, Set[asyncio.Queue]] = {}
|
||||||
|
self._logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
async def create_task(
|
||||||
|
self,
|
||||||
|
task_id: str,
|
||||||
|
session_id: Optional[str] = None,
|
||||||
|
message: Optional[Message] = None,
|
||||||
|
metadata: Optional[Dict[str, Any]] = None,
|
||||||
|
) -> Task:
|
||||||
|
"""Create a new task.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
task_id: The ID of the task.
|
||||||
|
session_id: The session ID.
|
||||||
|
message: The initial message.
|
||||||
|
metadata: Additional metadata.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
The created task.
|
||||||
|
"""
|
||||||
|
if task_id in self._tasks:
|
||||||
|
return self._tasks[task_id]
|
||||||
|
|
||||||
|
session_id = session_id or uuid4().hex
|
||||||
|
status = TaskStatus(
|
||||||
|
state=TaskState.SUBMITTED,
|
||||||
|
message=message,
|
||||||
|
timestamp=datetime.now(),
|
||||||
|
)
|
||||||
|
|
||||||
|
task = Task(
|
||||||
|
id=task_id,
|
||||||
|
sessionId=session_id,
|
||||||
|
status=status,
|
||||||
|
artifacts=[],
|
||||||
|
history=[message] if message else [],
|
||||||
|
metadata=metadata or {},
|
||||||
|
)
|
||||||
|
|
||||||
|
self._tasks[task_id] = task
|
||||||
|
self._task_subscribers[task_id] = set()
|
||||||
|
return task
|
||||||
|
|
||||||
|
async def get_task(
|
||||||
|
self, task_id: str, history_length: Optional[int] = None
|
||||||
|
) -> Task:
|
||||||
|
"""Get a task by ID.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
task_id: The ID of the task.
|
||||||
|
history_length: The number of messages to include in the history.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
The task.
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
KeyError: If the task is not found.
|
||||||
|
"""
|
||||||
|
if task_id not in self._tasks:
|
||||||
|
raise KeyError(f"Task {task_id} not found")
|
||||||
|
|
||||||
|
task = self._tasks[task_id]
|
||||||
|
if history_length is not None and task.history:
|
||||||
|
task_copy = task.model_copy(deep=True)
|
||||||
|
task_copy.history = task.history[-history_length:]
|
||||||
|
return task_copy
|
||||||
|
return task
|
||||||
|
|
||||||
|
async def update_task_status(
|
||||||
|
self,
|
||||||
|
task_id: str,
|
||||||
|
state: TaskState,
|
||||||
|
message: Optional[Message] = None,
|
||||||
|
metadata: Optional[Dict[str, Any]] = None,
|
||||||
|
) -> TaskStatusUpdateEvent:
|
||||||
|
"""Update the status of a task.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
task_id: The ID of the task.
|
||||||
|
state: The new state of the task.
|
||||||
|
message: An optional message to include with the status update.
|
||||||
|
metadata: Additional metadata.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
The task status update event.
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
KeyError: If the task is not found.
|
||||||
|
"""
|
||||||
|
if task_id not in self._tasks:
|
||||||
|
raise KeyError(f"Task {task_id} not found")
|
||||||
|
|
||||||
|
task = self._tasks[task_id]
|
||||||
|
status = TaskStatus(
|
||||||
|
state=state,
|
||||||
|
message=message,
|
||||||
|
timestamp=datetime.now(),
|
||||||
|
)
|
||||||
|
task.status = status
|
||||||
|
|
||||||
|
if message and task.history is not None:
|
||||||
|
task.history.append(message)
|
||||||
|
|
||||||
|
event = TaskStatusUpdateEvent(
|
||||||
|
id=task_id,
|
||||||
|
status=status,
|
||||||
|
final=state in [TaskState.COMPLETED, TaskState.CANCELED, TaskState.FAILED],
|
||||||
|
metadata=metadata or {},
|
||||||
|
)
|
||||||
|
|
||||||
|
await self._notify_subscribers(task_id, event)
|
||||||
|
|
||||||
|
return event
|
||||||
|
|
||||||
|
async def add_task_artifact(
|
||||||
|
self,
|
||||||
|
task_id: str,
|
||||||
|
artifact: Artifact,
|
||||||
|
metadata: Optional[Dict[str, Any]] = None,
|
||||||
|
) -> TaskArtifactUpdateEvent:
|
||||||
|
"""Add an artifact to a task.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
task_id: The ID of the task.
|
||||||
|
artifact: The artifact to add.
|
||||||
|
metadata: Additional metadata.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
The task artifact update event.
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
KeyError: If the task is not found.
|
||||||
|
"""
|
||||||
|
if task_id not in self._tasks:
|
||||||
|
raise KeyError(f"Task {task_id} not found")
|
||||||
|
|
||||||
|
task = self._tasks[task_id]
|
||||||
|
if task.artifacts is None:
|
||||||
|
task.artifacts = []
|
||||||
|
|
||||||
|
if artifact.append and task.artifacts:
|
||||||
|
for existing in task.artifacts:
|
||||||
|
if existing.name == artifact.name:
|
||||||
|
existing.parts.extend(artifact.parts)
|
||||||
|
existing.lastChunk = artifact.lastChunk
|
||||||
|
break
|
||||||
|
else:
|
||||||
|
task.artifacts.append(artifact)
|
||||||
|
else:
|
||||||
|
task.artifacts.append(artifact)
|
||||||
|
|
||||||
|
event = TaskArtifactUpdateEvent(
|
||||||
|
id=task_id,
|
||||||
|
artifact=artifact,
|
||||||
|
metadata=metadata or {},
|
||||||
|
)
|
||||||
|
|
||||||
|
await self._notify_subscribers(task_id, event)
|
||||||
|
|
||||||
|
return event
|
||||||
|
|
||||||
|
async def cancel_task(self, task_id: str) -> Task:
|
||||||
|
"""Cancel a task.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
task_id: The ID of the task.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
The canceled task.
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
KeyError: If the task is not found.
|
||||||
|
"""
|
||||||
|
if task_id not in self._tasks:
|
||||||
|
raise KeyError(f"Task {task_id} not found")
|
||||||
|
|
||||||
|
task = self._tasks[task_id]
|
||||||
|
|
||||||
|
if task.status.state not in [TaskState.COMPLETED, TaskState.CANCELED, TaskState.FAILED]:
|
||||||
|
await self.update_task_status(task_id, TaskState.CANCELED)
|
||||||
|
|
||||||
|
return task
|
||||||
|
|
||||||
|
async def set_push_notification(
|
||||||
|
self, task_id: str, config: PushNotificationConfig
|
||||||
|
) -> PushNotificationConfig:
|
||||||
|
"""Set push notification for a task.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
task_id: The ID of the task.
|
||||||
|
config: The push notification configuration.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
The push notification configuration.
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
KeyError: If the task is not found.
|
||||||
|
"""
|
||||||
|
if task_id not in self._tasks:
|
||||||
|
raise KeyError(f"Task {task_id} not found")
|
||||||
|
|
||||||
|
self._push_notifications[task_id] = config
|
||||||
|
return config
|
||||||
|
|
||||||
|
async def get_push_notification(
|
||||||
|
self, task_id: str
|
||||||
|
) -> Optional[PushNotificationConfig]:
|
||||||
|
"""Get push notification for a task.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
task_id: The ID of the task.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
The push notification configuration, or None if not set.
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
KeyError: If the task is not found.
|
||||||
|
"""
|
||||||
|
if task_id not in self._tasks:
|
||||||
|
raise KeyError(f"Task {task_id} not found")
|
||||||
|
|
||||||
|
return self._push_notifications.get(task_id)
|
||||||
|
|
||||||
|
async def subscribe_to_task(self, task_id: str) -> asyncio.Queue:
|
||||||
|
"""Subscribe to task updates.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
task_id: The ID of the task.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
A queue that will receive task updates.
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
KeyError: If the task is not found.
|
||||||
|
"""
|
||||||
|
if task_id not in self._tasks:
|
||||||
|
raise KeyError(f"Task {task_id} not found")
|
||||||
|
|
||||||
|
queue: asyncio.Queue = asyncio.Queue()
|
||||||
|
self._task_subscribers.setdefault(task_id, set()).add(queue)
|
||||||
|
return queue
|
||||||
|
|
||||||
|
async def unsubscribe_from_task(self, task_id: str, queue: asyncio.Queue) -> None:
|
||||||
|
"""Unsubscribe from task updates.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
task_id: The ID of the task.
|
||||||
|
queue: The queue to unsubscribe.
|
||||||
|
"""
|
||||||
|
if task_id in self._task_subscribers:
|
||||||
|
self._task_subscribers[task_id].discard(queue)
|
||||||
|
|
||||||
|
async def _notify_subscribers(
|
||||||
|
self,
|
||||||
|
task_id: str,
|
||||||
|
event: Union[TaskStatusUpdateEvent, TaskArtifactUpdateEvent],
|
||||||
|
) -> None:
|
||||||
|
"""Notify subscribers of a task update.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
task_id: The ID of the task.
|
||||||
|
event: The event to send to subscribers.
|
||||||
|
"""
|
||||||
|
if task_id in self._task_subscribers:
|
||||||
|
for queue in self._task_subscribers[task_id]:
|
||||||
|
await queue.put(event)
|
||||||
@@ -5,6 +5,7 @@ from typing import Any, Dict, List, Literal, Optional, Union
|
|||||||
|
|
||||||
from pydantic import Field, InstanceOf, PrivateAttr, model_validator
|
from pydantic import Field, InstanceOf, PrivateAttr, model_validator
|
||||||
|
|
||||||
|
from crewai.a2a import A2AAgentIntegration
|
||||||
from crewai.agents import CacheHandler
|
from crewai.agents import CacheHandler
|
||||||
from crewai.agents.agent_builder.base_agent import BaseAgent
|
from crewai.agents.agent_builder.base_agent import BaseAgent
|
||||||
from crewai.agents.crew_agent_executor import CrewAgentExecutor
|
from crewai.agents.crew_agent_executor import CrewAgentExecutor
|
||||||
@@ -131,14 +132,29 @@ class Agent(BaseAgent):
|
|||||||
default=None,
|
default=None,
|
||||||
description="Knowledge sources for the agent.",
|
description="Knowledge sources for the agent.",
|
||||||
)
|
)
|
||||||
|
a2a_enabled: bool = Field(
|
||||||
|
default=False,
|
||||||
|
description="Whether the agent supports the A2A protocol.",
|
||||||
|
)
|
||||||
|
a2a_url: Optional[str] = Field(
|
||||||
|
default=None,
|
||||||
|
description="The URL where the agent's A2A server is hosted.",
|
||||||
|
)
|
||||||
_knowledge: Optional[Knowledge] = PrivateAttr(
|
_knowledge: Optional[Knowledge] = PrivateAttr(
|
||||||
default=None,
|
default=None,
|
||||||
)
|
)
|
||||||
|
_a2a_integration: Optional[A2AAgentIntegration] = PrivateAttr(
|
||||||
|
default=None,
|
||||||
|
)
|
||||||
|
|
||||||
@model_validator(mode="after")
|
@model_validator(mode="after")
|
||||||
def post_init_setup(self):
|
def post_init_setup(self):
|
||||||
self._set_knowledge()
|
self._set_knowledge()
|
||||||
self.agent_ops_agent_name = self.role
|
self.agent_ops_agent_name = self.role
|
||||||
|
|
||||||
|
if self.a2a_enabled:
|
||||||
|
self._a2a_integration = A2AAgentIntegration()
|
||||||
|
|
||||||
unaccepted_attributes = [
|
unaccepted_attributes = [
|
||||||
"AWS_ACCESS_KEY_ID",
|
"AWS_ACCESS_KEY_ID",
|
||||||
"AWS_SECRET_ACCESS_KEY",
|
"AWS_SECRET_ACCESS_KEY",
|
||||||
@@ -356,6 +372,102 @@ class Agent(BaseAgent):
|
|||||||
|
|
||||||
return result
|
return result
|
||||||
|
|
||||||
|
async def execute_task_via_a2a(
|
||||||
|
self,
|
||||||
|
task_description: str,
|
||||||
|
context: Optional[str] = None,
|
||||||
|
agent_url: Optional[str] = None,
|
||||||
|
api_key: Optional[str] = None,
|
||||||
|
timeout: int = 300,
|
||||||
|
) -> str:
|
||||||
|
"""Execute a task via the A2A protocol.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
task_description: The description of the task.
|
||||||
|
context: Additional context for the task.
|
||||||
|
agent_url: The URL of the agent to execute the task. Defaults to self.a2a_url.
|
||||||
|
api_key: The API key to use for authentication.
|
||||||
|
timeout: The timeout for the task execution in seconds.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
The result of the task execution.
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
ValueError: If A2A is not enabled or no agent URL is provided.
|
||||||
|
TimeoutError: If the task execution times out.
|
||||||
|
Exception: If there is an error executing the task.
|
||||||
|
"""
|
||||||
|
if not self.a2a_enabled:
|
||||||
|
raise ValueError("A2A protocol is not enabled for this agent")
|
||||||
|
|
||||||
|
if not self._a2a_integration:
|
||||||
|
self._a2a_integration = A2AAgentIntegration()
|
||||||
|
|
||||||
|
url = agent_url or self.a2a_url
|
||||||
|
if not url:
|
||||||
|
raise ValueError("No A2A agent URL provided")
|
||||||
|
|
||||||
|
try:
|
||||||
|
import asyncio
|
||||||
|
if asyncio.get_event_loop().is_running():
|
||||||
|
return await self._a2a_integration.execute_task_via_a2a(
|
||||||
|
agent_url=url,
|
||||||
|
task_description=task_description,
|
||||||
|
context=context,
|
||||||
|
api_key=api_key,
|
||||||
|
timeout=timeout,
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
return asyncio.run(self._a2a_integration.execute_task_via_a2a(
|
||||||
|
agent_url=url,
|
||||||
|
task_description=task_description,
|
||||||
|
context=context,
|
||||||
|
api_key=api_key,
|
||||||
|
timeout=timeout,
|
||||||
|
))
|
||||||
|
except Exception as e:
|
||||||
|
self._logger.exception(f"Error executing task via A2A: {e}")
|
||||||
|
raise
|
||||||
|
|
||||||
|
async def handle_a2a_task(
|
||||||
|
self,
|
||||||
|
task_id: str,
|
||||||
|
task_description: str,
|
||||||
|
context: Optional[str] = None,
|
||||||
|
) -> str:
|
||||||
|
"""Handle an A2A task.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
task_id: The ID of the A2A task.
|
||||||
|
task_description: The description of the task.
|
||||||
|
context: Additional context for the task.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
The result of the task execution.
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
ValueError: If A2A is not enabled.
|
||||||
|
Exception: If there is an error handling the task.
|
||||||
|
"""
|
||||||
|
if not self.a2a_enabled:
|
||||||
|
raise ValueError("A2A protocol is not enabled for this agent")
|
||||||
|
|
||||||
|
if not self._a2a_integration:
|
||||||
|
self._a2a_integration = A2AAgentIntegration()
|
||||||
|
|
||||||
|
# Create a Task object from the task description
|
||||||
|
task = Task(
|
||||||
|
description=task_description,
|
||||||
|
agent=self,
|
||||||
|
)
|
||||||
|
|
||||||
|
try:
|
||||||
|
result = self.execute_task(task, context)
|
||||||
|
return result
|
||||||
|
except Exception as e:
|
||||||
|
self._logger.exception(f"Error handling A2A task: {e}")
|
||||||
|
raise
|
||||||
|
|
||||||
def create_agent_executor(
|
def create_agent_executor(
|
||||||
self, tools: Optional[List[BaseTool]] = None, task=None
|
self, tools: Optional[List[BaseTool]] = None, task=None
|
||||||
) -> None:
|
) -> None:
|
||||||
|
|||||||
@@ -0,0 +1 @@
|
|||||||
|
"""Type definitions for CrewAI."""
|
||||||
|
|||||||
423
src/crewai/types/a2a.py
Normal file
423
src/crewai/types/a2a.py
Normal file
@@ -0,0 +1,423 @@
|
|||||||
|
"""
|
||||||
|
A2A protocol types for CrewAI.
|
||||||
|
|
||||||
|
This module implements the A2A (Agent-to-Agent) protocol types as defined by Google.
|
||||||
|
The A2A protocol enables interoperability between different agent systems.
|
||||||
|
|
||||||
|
For more information, see: https://developers.googleblog.com/en/a2a-a-new-era-of-agent-interoperability/
|
||||||
|
"""
|
||||||
|
|
||||||
|
from datetime import datetime
|
||||||
|
from enum import Enum
|
||||||
|
from typing import Annotated, Any, Dict, List, Literal, Optional, Self, Union
|
||||||
|
from uuid import uuid4
|
||||||
|
|
||||||
|
from pydantic import BaseModel, ConfigDict, Field, TypeAdapter, field_serializer, model_validator
|
||||||
|
|
||||||
|
|
||||||
|
class TaskState(str, Enum):
|
||||||
|
"""Task state in the A2A protocol."""
|
||||||
|
SUBMITTED = 'submitted'
|
||||||
|
WORKING = 'working'
|
||||||
|
INPUT_REQUIRED = 'input-required'
|
||||||
|
COMPLETED = 'completed'
|
||||||
|
CANCELED = 'canceled'
|
||||||
|
FAILED = 'failed'
|
||||||
|
UNKNOWN = 'unknown'
|
||||||
|
|
||||||
|
|
||||||
|
class TextPart(BaseModel):
|
||||||
|
"""Text part in the A2A protocol."""
|
||||||
|
type: Literal['text'] = 'text'
|
||||||
|
text: str
|
||||||
|
metadata: Optional[Dict[str, Any]] = None
|
||||||
|
|
||||||
|
|
||||||
|
class FileContent(BaseModel):
|
||||||
|
"""File content in the A2A protocol."""
|
||||||
|
name: Optional[str] = None
|
||||||
|
mimeType: Optional[str] = None
|
||||||
|
bytes: Optional[str] = None
|
||||||
|
uri: Optional[str] = None
|
||||||
|
|
||||||
|
@model_validator(mode='after')
|
||||||
|
def check_content(self) -> Self:
|
||||||
|
"""Validate file content has either bytes or uri."""
|
||||||
|
if not (self.bytes or self.uri):
|
||||||
|
raise ValueError(
|
||||||
|
"Either 'bytes' or 'uri' must be present in the file data"
|
||||||
|
)
|
||||||
|
if self.bytes and self.uri:
|
||||||
|
raise ValueError(
|
||||||
|
"Only one of 'bytes' or 'uri' can be present in the file data"
|
||||||
|
)
|
||||||
|
return self
|
||||||
|
|
||||||
|
|
||||||
|
class FilePart(BaseModel):
|
||||||
|
"""File part in the A2A protocol."""
|
||||||
|
type: Literal['file'] = 'file'
|
||||||
|
file: FileContent
|
||||||
|
metadata: Optional[Dict[str, Any]] = None
|
||||||
|
|
||||||
|
|
||||||
|
class DataPart(BaseModel):
|
||||||
|
"""Data part in the A2A protocol."""
|
||||||
|
type: Literal['data'] = 'data'
|
||||||
|
data: Dict[str, Any]
|
||||||
|
metadata: Optional[Dict[str, Any]] = None
|
||||||
|
|
||||||
|
|
||||||
|
Part = Annotated[Union[TextPart, FilePart, DataPart], Field(discriminator='type')]
|
||||||
|
|
||||||
|
|
||||||
|
class Message(BaseModel):
|
||||||
|
"""Message in the A2A protocol."""
|
||||||
|
role: Literal['user', 'agent']
|
||||||
|
parts: List[Part]
|
||||||
|
metadata: Optional[Dict[str, Any]] = None
|
||||||
|
|
||||||
|
|
||||||
|
class TaskStatus(BaseModel):
|
||||||
|
"""Task status in the A2A protocol."""
|
||||||
|
state: TaskState
|
||||||
|
message: Optional[Message] = None
|
||||||
|
timestamp: datetime = Field(default_factory=datetime.now)
|
||||||
|
|
||||||
|
@field_serializer('timestamp')
|
||||||
|
def serialize_dt(self, dt: datetime, _info):
|
||||||
|
"""Serialize datetime to ISO format."""
|
||||||
|
return dt.isoformat()
|
||||||
|
|
||||||
|
|
||||||
|
class Artifact(BaseModel):
|
||||||
|
"""Artifact in the A2A protocol."""
|
||||||
|
name: Optional[str] = None
|
||||||
|
description: Optional[str] = None
|
||||||
|
parts: List[Part]
|
||||||
|
metadata: Optional[Dict[str, Any]] = None
|
||||||
|
index: int = 0
|
||||||
|
append: Optional[bool] = None
|
||||||
|
lastChunk: Optional[bool] = None
|
||||||
|
|
||||||
|
|
||||||
|
class Task(BaseModel):
|
||||||
|
"""Task in the A2A protocol."""
|
||||||
|
id: str
|
||||||
|
sessionId: Optional[str] = None
|
||||||
|
status: TaskStatus
|
||||||
|
artifacts: Optional[List[Artifact]] = None
|
||||||
|
history: Optional[List[Message]] = None
|
||||||
|
metadata: Optional[Dict[str, Any]] = None
|
||||||
|
|
||||||
|
|
||||||
|
class TaskStatusUpdateEvent(BaseModel):
|
||||||
|
"""Task status update event in the A2A protocol."""
|
||||||
|
id: str
|
||||||
|
status: TaskStatus
|
||||||
|
final: bool = False
|
||||||
|
metadata: Optional[Dict[str, Any]] = None
|
||||||
|
|
||||||
|
|
||||||
|
class TaskArtifactUpdateEvent(BaseModel):
|
||||||
|
"""Task artifact update event in the A2A protocol."""
|
||||||
|
id: str
|
||||||
|
artifact: Artifact
|
||||||
|
metadata: Optional[Dict[str, Any]] = None
|
||||||
|
|
||||||
|
|
||||||
|
class AuthenticationInfo(BaseModel):
|
||||||
|
"""Authentication information in the A2A protocol."""
|
||||||
|
model_config = ConfigDict(extra='allow')
|
||||||
|
|
||||||
|
schemes: List[str]
|
||||||
|
credentials: Optional[str] = None
|
||||||
|
|
||||||
|
|
||||||
|
class PushNotificationConfig(BaseModel):
|
||||||
|
"""Push notification configuration in the A2A protocol."""
|
||||||
|
url: str
|
||||||
|
token: Optional[str] = None
|
||||||
|
authentication: Optional[AuthenticationInfo] = None
|
||||||
|
|
||||||
|
|
||||||
|
class TaskIdParams(BaseModel):
|
||||||
|
"""Task ID parameters in the A2A protocol."""
|
||||||
|
id: str
|
||||||
|
metadata: Optional[Dict[str, Any]] = None
|
||||||
|
|
||||||
|
|
||||||
|
class TaskQueryParams(TaskIdParams):
|
||||||
|
"""Task query parameters in the A2A protocol."""
|
||||||
|
historyLength: Optional[int] = None
|
||||||
|
|
||||||
|
|
||||||
|
class TaskSendParams(BaseModel):
|
||||||
|
"""Task send parameters in the A2A protocol."""
|
||||||
|
id: str
|
||||||
|
sessionId: str = Field(default_factory=lambda: uuid4().hex)
|
||||||
|
message: Message
|
||||||
|
acceptedOutputModes: Optional[List[str]] = None
|
||||||
|
pushNotification: Optional[PushNotificationConfig] = None
|
||||||
|
historyLength: Optional[int] = None
|
||||||
|
metadata: Optional[Dict[str, Any]] = None
|
||||||
|
|
||||||
|
|
||||||
|
class TaskPushNotificationConfig(BaseModel):
|
||||||
|
"""Task push notification configuration in the A2A protocol."""
|
||||||
|
id: str
|
||||||
|
pushNotificationConfig: PushNotificationConfig
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
class JSONRPCMessage(BaseModel):
|
||||||
|
"""JSON-RPC message in the A2A protocol."""
|
||||||
|
jsonrpc: Literal['2.0'] = '2.0'
|
||||||
|
id: Optional[Union[int, str]] = Field(default_factory=lambda: uuid4().hex)
|
||||||
|
|
||||||
|
|
||||||
|
class JSONRPCRequest(JSONRPCMessage):
|
||||||
|
"""JSON-RPC request in the A2A protocol."""
|
||||||
|
method: str
|
||||||
|
params: Optional[Dict[str, Any]] = None
|
||||||
|
|
||||||
|
|
||||||
|
class JSONRPCError(BaseModel):
|
||||||
|
"""JSON-RPC error in the A2A protocol."""
|
||||||
|
code: int
|
||||||
|
message: str
|
||||||
|
data: Optional[Any] = None
|
||||||
|
|
||||||
|
|
||||||
|
class JSONRPCResponse(JSONRPCMessage):
|
||||||
|
"""JSON-RPC response in the A2A protocol."""
|
||||||
|
result: Optional[Any] = None
|
||||||
|
error: Optional[JSONRPCError] = None
|
||||||
|
|
||||||
|
|
||||||
|
class SendTaskRequest(JSONRPCRequest):
|
||||||
|
"""Send task request in the A2A protocol."""
|
||||||
|
method: Literal['tasks/send'] = 'tasks/send'
|
||||||
|
params: TaskSendParams
|
||||||
|
|
||||||
|
|
||||||
|
class SendTaskResponse(JSONRPCResponse):
|
||||||
|
"""Send task response in the A2A protocol."""
|
||||||
|
result: Optional[Task] = None
|
||||||
|
|
||||||
|
|
||||||
|
class SendTaskStreamingRequest(JSONRPCRequest):
|
||||||
|
"""Send task streaming request in the A2A protocol."""
|
||||||
|
method: Literal['tasks/sendSubscribe'] = 'tasks/sendSubscribe'
|
||||||
|
params: TaskSendParams
|
||||||
|
|
||||||
|
|
||||||
|
class SendTaskStreamingResponse(JSONRPCResponse):
|
||||||
|
"""Send task streaming response in the A2A protocol."""
|
||||||
|
result: Optional[Union[TaskStatusUpdateEvent, TaskArtifactUpdateEvent]] = None
|
||||||
|
|
||||||
|
|
||||||
|
class GetTaskRequest(JSONRPCRequest):
|
||||||
|
"""Get task request in the A2A protocol."""
|
||||||
|
method: Literal['tasks/get'] = 'tasks/get'
|
||||||
|
params: TaskQueryParams
|
||||||
|
|
||||||
|
|
||||||
|
class GetTaskResponse(JSONRPCResponse):
|
||||||
|
"""Get task response in the A2A protocol."""
|
||||||
|
result: Optional[Task] = None
|
||||||
|
|
||||||
|
|
||||||
|
class CancelTaskRequest(JSONRPCRequest):
|
||||||
|
"""Cancel task request in the A2A protocol."""
|
||||||
|
method: Literal['tasks/cancel'] = 'tasks/cancel'
|
||||||
|
params: TaskIdParams
|
||||||
|
|
||||||
|
|
||||||
|
class CancelTaskResponse(JSONRPCResponse):
|
||||||
|
"""Cancel task response in the A2A protocol."""
|
||||||
|
result: Optional[Task] = None
|
||||||
|
|
||||||
|
|
||||||
|
class SetTaskPushNotificationRequest(JSONRPCRequest):
|
||||||
|
"""Set task push notification request in the A2A protocol."""
|
||||||
|
method: Literal['tasks/pushNotification/set'] = 'tasks/pushNotification/set'
|
||||||
|
params: TaskPushNotificationConfig
|
||||||
|
|
||||||
|
|
||||||
|
class SetTaskPushNotificationResponse(JSONRPCResponse):
|
||||||
|
"""Set task push notification response in the A2A protocol."""
|
||||||
|
result: Optional[TaskPushNotificationConfig] = None
|
||||||
|
|
||||||
|
|
||||||
|
class GetTaskPushNotificationRequest(JSONRPCRequest):
|
||||||
|
"""Get task push notification request in the A2A protocol."""
|
||||||
|
method: Literal['tasks/pushNotification/get'] = 'tasks/pushNotification/get'
|
||||||
|
params: TaskIdParams
|
||||||
|
|
||||||
|
|
||||||
|
class GetTaskPushNotificationResponse(JSONRPCResponse):
|
||||||
|
"""Get task push notification response in the A2A protocol."""
|
||||||
|
result: Optional[TaskPushNotificationConfig] = None
|
||||||
|
|
||||||
|
|
||||||
|
class TaskResubscriptionRequest(JSONRPCRequest):
|
||||||
|
"""Task resubscription request in the A2A protocol."""
|
||||||
|
method: Literal['tasks/resubscribe'] = 'tasks/resubscribe'
|
||||||
|
params: TaskIdParams
|
||||||
|
|
||||||
|
|
||||||
|
A2ARequest = TypeAdapter(
|
||||||
|
Annotated[
|
||||||
|
Union[
|
||||||
|
SendTaskRequest,
|
||||||
|
GetTaskRequest,
|
||||||
|
CancelTaskRequest,
|
||||||
|
SetTaskPushNotificationRequest,
|
||||||
|
GetTaskPushNotificationRequest,
|
||||||
|
TaskResubscriptionRequest,
|
||||||
|
SendTaskStreamingRequest,
|
||||||
|
],
|
||||||
|
Field(discriminator='method'),
|
||||||
|
]
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class JSONParseError(JSONRPCError):
|
||||||
|
"""JSON parse error in the A2A protocol."""
|
||||||
|
code: int = -32700
|
||||||
|
message: str = 'Invalid JSON payload'
|
||||||
|
data: Optional[Any] = None
|
||||||
|
|
||||||
|
|
||||||
|
class InvalidRequestError(JSONRPCError):
|
||||||
|
"""Invalid request error in the A2A protocol."""
|
||||||
|
code: int = -32600
|
||||||
|
message: str = 'Request payload validation error'
|
||||||
|
data: Optional[Any] = None
|
||||||
|
|
||||||
|
|
||||||
|
class MethodNotFoundError(JSONRPCError):
|
||||||
|
"""Method not found error in the A2A protocol."""
|
||||||
|
code: int = -32601
|
||||||
|
message: str = 'Method not found'
|
||||||
|
data: None = None
|
||||||
|
|
||||||
|
|
||||||
|
class InvalidParamsError(JSONRPCError):
|
||||||
|
"""Invalid parameters error in the A2A protocol."""
|
||||||
|
code: int = -32602
|
||||||
|
message: str = 'Invalid parameters'
|
||||||
|
data: Optional[Any] = None
|
||||||
|
|
||||||
|
|
||||||
|
class InternalError(JSONRPCError):
|
||||||
|
"""Internal error in the A2A protocol."""
|
||||||
|
code: int = -32603
|
||||||
|
message: str = 'Internal error'
|
||||||
|
data: Optional[Any] = None
|
||||||
|
|
||||||
|
|
||||||
|
class TaskNotFoundError(JSONRPCError):
|
||||||
|
"""Task not found error in the A2A protocol."""
|
||||||
|
code: int = -32001
|
||||||
|
message: str = 'Task not found'
|
||||||
|
data: None = None
|
||||||
|
|
||||||
|
|
||||||
|
class TaskNotCancelableError(JSONRPCError):
|
||||||
|
"""Task not cancelable error in the A2A protocol."""
|
||||||
|
code: int = -32002
|
||||||
|
message: str = 'Task cannot be canceled'
|
||||||
|
data: None = None
|
||||||
|
|
||||||
|
|
||||||
|
class PushNotificationNotSupportedError(JSONRPCError):
|
||||||
|
"""Push notification not supported error in the A2A protocol."""
|
||||||
|
code: int = -32003
|
||||||
|
message: str = 'Push Notification is not supported'
|
||||||
|
data: None = None
|
||||||
|
|
||||||
|
|
||||||
|
class UnsupportedOperationError(JSONRPCError):
|
||||||
|
"""Unsupported operation error in the A2A protocol."""
|
||||||
|
code: int = -32004
|
||||||
|
message: str = 'This operation is not supported'
|
||||||
|
data: None = None
|
||||||
|
|
||||||
|
|
||||||
|
class ContentTypeNotSupportedError(JSONRPCError):
|
||||||
|
"""Content type not supported error in the A2A protocol."""
|
||||||
|
code: int = -32005
|
||||||
|
message: str = 'Incompatible content types'
|
||||||
|
data: None = None
|
||||||
|
|
||||||
|
|
||||||
|
class AgentProvider(BaseModel):
|
||||||
|
"""Agent provider in the A2A protocol."""
|
||||||
|
organization: str
|
||||||
|
url: Optional[str] = None
|
||||||
|
|
||||||
|
|
||||||
|
class AgentCapabilities(BaseModel):
|
||||||
|
"""Agent capabilities in the A2A protocol."""
|
||||||
|
streaming: bool = False
|
||||||
|
pushNotifications: bool = False
|
||||||
|
stateTransitionHistory: bool = False
|
||||||
|
|
||||||
|
|
||||||
|
class AgentAuthentication(BaseModel):
|
||||||
|
"""Agent authentication in the A2A protocol."""
|
||||||
|
schemes: List[str]
|
||||||
|
credentials: Optional[str] = None
|
||||||
|
|
||||||
|
|
||||||
|
class AgentSkill(BaseModel):
|
||||||
|
"""Agent skill in the A2A protocol."""
|
||||||
|
id: str
|
||||||
|
name: str
|
||||||
|
description: Optional[str] = None
|
||||||
|
tags: Optional[List[str]] = None
|
||||||
|
examples: Optional[List[str]] = None
|
||||||
|
inputModes: Optional[List[str]] = None
|
||||||
|
outputModes: Optional[List[str]] = None
|
||||||
|
|
||||||
|
|
||||||
|
class AgentCard(BaseModel):
|
||||||
|
"""Agent card in the A2A protocol."""
|
||||||
|
name: str
|
||||||
|
description: Optional[str] = None
|
||||||
|
url: str
|
||||||
|
provider: Optional[AgentProvider] = None
|
||||||
|
version: str
|
||||||
|
documentationUrl: Optional[str] = None
|
||||||
|
capabilities: AgentCapabilities
|
||||||
|
authentication: Optional[AgentAuthentication] = None
|
||||||
|
defaultInputModes: List[str] = ['text']
|
||||||
|
defaultOutputModes: List[str] = ['text']
|
||||||
|
skills: List[AgentSkill]
|
||||||
|
|
||||||
|
|
||||||
|
class A2AClientError(Exception):
|
||||||
|
"""Base exception for A2A client errors."""
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
class A2AClientHTTPError(A2AClientError):
|
||||||
|
"""HTTP error in the A2A client."""
|
||||||
|
def __init__(self, status_code: int, message: str):
|
||||||
|
self.status_code = status_code
|
||||||
|
self.message = message
|
||||||
|
super().__init__(f'HTTP Error {status_code}: {message}')
|
||||||
|
|
||||||
|
|
||||||
|
class A2AClientJSONError(A2AClientError):
|
||||||
|
"""JSON error in the A2A client."""
|
||||||
|
def __init__(self, message: str):
|
||||||
|
self.message = message
|
||||||
|
super().__init__(f'JSON Error: {message}')
|
||||||
|
|
||||||
|
|
||||||
|
class MissingAPIKeyError(Exception):
|
||||||
|
"""Exception for missing API key."""
|
||||||
|
pass
|
||||||
1
tests/a2a/__init__.py
Normal file
1
tests/a2a/__init__.py
Normal file
@@ -0,0 +1 @@
|
|||||||
|
"""Tests for the A2A protocol implementation."""
|
||||||
230
tests/a2a/test_a2a_integration.py
Normal file
230
tests/a2a/test_a2a_integration.py
Normal file
@@ -0,0 +1,230 @@
|
|||||||
|
"""Tests for the A2A protocol integration."""
|
||||||
|
|
||||||
|
import asyncio
|
||||||
|
import pytest
|
||||||
|
from unittest.mock import AsyncMock, MagicMock, patch
|
||||||
|
|
||||||
|
from crewai.agent import Agent
|
||||||
|
from crewai.a2a import A2AAgentIntegration, A2AClient, A2AServer, InMemoryTaskManager
|
||||||
|
from crewai.task import Task
|
||||||
|
from crewai.types.a2a import (
|
||||||
|
Message,
|
||||||
|
Task as A2ATask,
|
||||||
|
TaskState,
|
||||||
|
TaskStatus,
|
||||||
|
TaskStatusUpdateEvent,
|
||||||
|
TextPart,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def agent():
|
||||||
|
"""Create an agent with A2A enabled."""
|
||||||
|
return Agent(
|
||||||
|
role="test_agent",
|
||||||
|
goal="Test A2A protocol",
|
||||||
|
backstory="I am a test agent",
|
||||||
|
a2a_enabled=True,
|
||||||
|
a2a_url="http://localhost:8000",
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def task():
|
||||||
|
"""Create a task."""
|
||||||
|
return Task(
|
||||||
|
description="Test task",
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def a2a_task():
|
||||||
|
"""Create an A2A task."""
|
||||||
|
return A2ATask(
|
||||||
|
id="test_task_id",
|
||||||
|
history=[
|
||||||
|
Message(
|
||||||
|
role="user",
|
||||||
|
parts=[TextPart(text="Test task description")],
|
||||||
|
)
|
||||||
|
],
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def a2a_integration():
|
||||||
|
"""Create an A2A integration."""
|
||||||
|
return A2AAgentIntegration()
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def a2a_client():
|
||||||
|
"""Create an A2A client."""
|
||||||
|
return A2AClient(base_url="http://localhost:8000")
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def task_manager():
|
||||||
|
"""Create a task manager."""
|
||||||
|
return InMemoryTaskManager()
|
||||||
|
|
||||||
|
|
||||||
|
class TestA2AIntegration:
|
||||||
|
"""Tests for the A2A protocol integration."""
|
||||||
|
|
||||||
|
def test_agent_a2a_attributes(self, agent):
|
||||||
|
"""Test that the agent has A2A attributes."""
|
||||||
|
assert agent.a2a_enabled is True
|
||||||
|
assert agent.a2a_url == "http://localhost:8000"
|
||||||
|
assert agent._a2a_integration is not None
|
||||||
|
|
||||||
|
@patch("crewai.a2a.agent.A2AAgentIntegration.execute_task_via_a2a")
|
||||||
|
def test_execute_task_via_a2a(self, mock_execute, agent):
|
||||||
|
"""Test executing a task via A2A."""
|
||||||
|
mock_execute.return_value = asyncio.Future()
|
||||||
|
mock_execute.return_value.set_result("Task result")
|
||||||
|
|
||||||
|
result = asyncio.run(
|
||||||
|
agent.execute_task_via_a2a(
|
||||||
|
task_description="Test task",
|
||||||
|
context="Test context",
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
assert result == "Task result"
|
||||||
|
mock_execute.assert_called_once_with(
|
||||||
|
agent_url="http://localhost:8000",
|
||||||
|
task_description="Test task",
|
||||||
|
context="Test context",
|
||||||
|
api_key=None,
|
||||||
|
timeout=300,
|
||||||
|
)
|
||||||
|
|
||||||
|
@patch("crewai.agent.Agent.execute_task")
|
||||||
|
def test_handle_a2a_task(self, mock_execute, agent):
|
||||||
|
"""Test handling an A2A task."""
|
||||||
|
mock_execute.return_value = "Task result"
|
||||||
|
|
||||||
|
result = asyncio.run(
|
||||||
|
agent.handle_a2a_task(
|
||||||
|
task_id="test_task_id",
|
||||||
|
task_description="Test task",
|
||||||
|
context="Test context",
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
assert result == "Task result"
|
||||||
|
mock_execute.assert_called_once()
|
||||||
|
args, kwargs = mock_execute.call_args
|
||||||
|
assert args[0].description == "Test task"
|
||||||
|
assert kwargs["context"] == "Test context"
|
||||||
|
|
||||||
|
def test_a2a_disabled(self, agent):
|
||||||
|
"""Test that A2A methods raise ValueError when A2A is disabled."""
|
||||||
|
agent.a2a_enabled = False
|
||||||
|
|
||||||
|
with pytest.raises(ValueError, match="A2A protocol is not enabled for this agent"):
|
||||||
|
asyncio.run(
|
||||||
|
agent.execute_task_via_a2a(
|
||||||
|
task_description="Test task",
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
with pytest.raises(ValueError, match="A2A protocol is not enabled for this agent"):
|
||||||
|
asyncio.run(
|
||||||
|
agent.handle_a2a_task(
|
||||||
|
task_id="test_task_id",
|
||||||
|
task_description="Test task",
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
def test_no_agent_url(self, agent):
|
||||||
|
"""Test that execute_task_via_a2a raises ValueError when no agent URL is provided."""
|
||||||
|
agent.a2a_url = None
|
||||||
|
|
||||||
|
with pytest.raises(ValueError, match="No A2A agent URL provided"):
|
||||||
|
asyncio.run(
|
||||||
|
agent.execute_task_via_a2a(
|
||||||
|
task_description="Test task",
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class TestA2AAgentIntegration:
|
||||||
|
"""Tests for the A2AAgentIntegration class."""
|
||||||
|
|
||||||
|
@patch("crewai.a2a.client.A2AClient.send_task_streaming")
|
||||||
|
async def test_execute_task_via_a2a(self, mock_send_task, a2a_integration):
|
||||||
|
"""Test executing a task via A2A."""
|
||||||
|
queue = asyncio.Queue()
|
||||||
|
await queue.put(
|
||||||
|
TaskStatusUpdateEvent(
|
||||||
|
task_id="test_task_id",
|
||||||
|
status=TaskStatus(
|
||||||
|
state=TaskState.COMPLETED,
|
||||||
|
message=Message(
|
||||||
|
role="agent",
|
||||||
|
parts=[TextPart(text="Task result")],
|
||||||
|
),
|
||||||
|
),
|
||||||
|
final=True,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
mock_send_task.return_value = queue
|
||||||
|
|
||||||
|
result = await a2a_integration.execute_task_via_a2a(
|
||||||
|
agent_url="http://localhost:8000",
|
||||||
|
task_description="Test task",
|
||||||
|
context="Test context",
|
||||||
|
)
|
||||||
|
|
||||||
|
assert result == "Task result"
|
||||||
|
mock_send_task.assert_called_once()
|
||||||
|
|
||||||
|
|
||||||
|
class TestA2AServer:
|
||||||
|
"""Tests for the A2AServer class."""
|
||||||
|
|
||||||
|
@patch("fastapi.FastAPI.post")
|
||||||
|
def test_server_initialization(self, mock_post, task_manager):
|
||||||
|
"""Test server initialization."""
|
||||||
|
server = A2AServer(task_manager=task_manager)
|
||||||
|
assert server.task_manager == task_manager
|
||||||
|
assert server.app is not None
|
||||||
|
assert mock_post.call_count == 4 # 4 endpoints registered
|
||||||
|
|
||||||
|
|
||||||
|
class TestA2AClient:
|
||||||
|
"""Tests for the A2AClient class."""
|
||||||
|
|
||||||
|
@patch("aiohttp.ClientSession.post")
|
||||||
|
async def test_send_task(self, mock_post, a2a_client):
|
||||||
|
"""Test sending a task."""
|
||||||
|
mock_response = MagicMock()
|
||||||
|
mock_response.status = 200
|
||||||
|
mock_response.json = AsyncMock(
|
||||||
|
return_value={
|
||||||
|
"id": "test_task_id",
|
||||||
|
"history": [
|
||||||
|
{
|
||||||
|
"role": "user",
|
||||||
|
"parts": [{"text": "Test task description"}],
|
||||||
|
}
|
||||||
|
],
|
||||||
|
}
|
||||||
|
)
|
||||||
|
mock_post.return_value.__aenter__.return_value = mock_response
|
||||||
|
|
||||||
|
task = await a2a_client.send_task(
|
||||||
|
task_id="test_task_id",
|
||||||
|
message=Message(
|
||||||
|
role="user",
|
||||||
|
parts=[TextPart(text="Test task description")],
|
||||||
|
),
|
||||||
|
)
|
||||||
|
|
||||||
|
assert task.id == "test_task_id"
|
||||||
|
assert task.history[0].role == "user"
|
||||||
|
assert task.history[0].parts[0].text == "Test task description"
|
||||||
|
mock_post.assert_called_once()
|
||||||
Reference in New Issue
Block a user