Address PR feedback: Improve error handling, add OpenAPI docs, and verify task management

Co-Authored-By: Joe Moura <joao@crewai.com>
This commit is contained in:
Devin AI
2025-05-09 04:58:19 +00:00
parent cfabb9fa78
commit 9bb8854c25
8 changed files with 482 additions and 57 deletions

View File

@@ -2,6 +2,7 @@
from crewai.a2a.agent import A2AAgentIntegration
from crewai.a2a.client import A2AClient
from crewai.a2a.config import A2AConfig
from crewai.a2a.server import A2AServer
from crewai.a2a.task_manager import InMemoryTaskManager, TaskManager
@@ -11,4 +12,5 @@ __all__ = [
"A2AServer",
"TaskManager",
"InMemoryTaskManager",
"A2AConfig",
]

View File

@@ -8,11 +8,14 @@ import asyncio
import json
import logging
import os
from typing import Any, Dict, List, Optional, Union, cast
from typing import Any, Dict, List, Optional, TYPE_CHECKING, Union, cast
import aiohttp
from pydantic import ValidationError
if TYPE_CHECKING:
from crewai.a2a.config import A2AConfig
from crewai.types.a2a import (
A2AClientError,
A2AClientHTTPError,
@@ -53,7 +56,8 @@ class A2AClient:
self,
base_url: str,
api_key: Optional[str] = None,
timeout: int = 60,
timeout: Optional[int] = None,
config: Optional["A2AConfig"] = None,
):
"""Initialize the A2A client.
@@ -61,10 +65,22 @@ class A2AClient:
base_url: The base URL of the A2A server.
api_key: The API key to use for authentication.
timeout: The timeout for HTTP requests in seconds.
config: The A2A configuration. If provided, other parameters are ignored.
"""
if config:
from crewai.a2a.config import A2AConfig
self.config = config
else:
from crewai.a2a.config import A2AConfig
self.config = A2AConfig()
if api_key:
self.config.api_key = api_key
if timeout:
self.config.client_timeout = timeout
self.base_url = base_url.rstrip("/")
self.api_key = api_key or os.environ.get("A2A_API_KEY")
self.timeout = timeout
self.api_key = self.config.api_key or os.environ.get("A2A_API_KEY")
self.timeout = self.config.client_timeout
self.logger = logging.getLogger(__name__)
async def send_task(
@@ -92,7 +108,10 @@ class A2AClient:
The created task.
Raises:
A2AClientError: If there is an error sending the task.
MissingAPIKeyError: If no API key is provided.
A2AClientHTTPError: If there is an HTTP error.
A2AClientJSONError: If there is an error parsing the JSON response.
A2AClientError: If there is any other error sending the task.
"""
params = TaskSendParams(
id=task_id,
@@ -105,15 +124,26 @@ class A2AClient:
)
request = SendTaskRequest(params=params)
response = await self._send_jsonrpc_request(request)
try:
response = await self._send_jsonrpc_request(request)
if response.error:
raise A2AClientError(f"Error sending task: {response.error.message}")
if response.error:
raise A2AClientError(f"Error sending task: {response.error.message}")
if not response.result:
raise A2AClientError("No result returned from send task request")
if not response.result:
raise A2AClientError("No result returned from send task request")
return cast(Task, response.result)
if isinstance(response.result, dict):
return Task.model_validate(response.result)
return cast(Task, response.result)
except asyncio.TimeoutError as e:
raise A2AClientError(f"Task request timed out: {e}")
except aiohttp.ClientError as e:
if isinstance(e, aiohttp.ClientResponseError):
raise A2AClientHTTPError(e.status, str(e))
else:
raise A2AClientError(f"Client error: {e}")
async def send_task_streaming(
self,
@@ -318,6 +348,14 @@ class A2AClient:
return JSONRPCResponse.model_validate(data)
except ValidationError as e:
raise A2AClientError(f"Invalid response: {e}")
except aiohttp.ClientConnectorError as e:
raise A2AClientHTTPError(status=0, message=f"Connection error: {e}")
except aiohttp.ClientOSError as e:
raise A2AClientHTTPError(status=0, message=f"OS error: {e}")
except aiohttp.ServerDisconnectedError as e:
raise A2AClientHTTPError(status=0, message=f"Server disconnected: {e}")
except aiohttp.ClientResponseError as e:
raise A2AClientHTTPError(e.status, str(e))
except aiohttp.ClientError as e:
raise A2AClientError(f"HTTP error: {e}")
@@ -394,6 +432,14 @@ class A2AClient:
await queue.put(
A2AClientError(f"Invalid artifact event: {e}")
)
except aiohttp.ClientConnectorError as e:
await queue.put(A2AClientHTTPError(status=0, message=f"Connection error: {e}"))
except aiohttp.ClientOSError as e:
await queue.put(A2AClientHTTPError(status=0, message=f"OS error: {e}"))
except aiohttp.ServerDisconnectedError as e:
await queue.put(A2AClientHTTPError(status=0, message=f"Server disconnected: {e}"))
except aiohttp.ClientResponseError as e:
await queue.put(A2AClientHTTPError(e.status, str(e)))
except aiohttp.ClientError as e:
await queue.put(A2AClientError(f"HTTP error: {e}"))
except asyncio.CancelledError:

89
src/crewai/a2a/config.py Normal file
View File

@@ -0,0 +1,89 @@
"""
Configuration management for A2A protocol in CrewAI.
This module provides configuration management for the A2A protocol implementation
in CrewAI, including default values and environment variable support.
"""
import os
from typing import Dict, Optional, Union
from pydantic import BaseModel, Field
class A2AConfig(BaseModel):
"""Configuration for A2A protocol."""
server_host: str = Field(
default="0.0.0.0",
description="Host to bind the A2A server to.",
)
server_port: int = Field(
default=8000,
description="Port to bind the A2A server to.",
)
enable_cors: bool = Field(
default=True,
description="Whether to enable CORS for the A2A server.",
)
cors_origins: Optional[list[str]] = Field(
default=None,
description="CORS origins to allow. If None, all origins are allowed.",
)
client_timeout: int = Field(
default=60,
description="Timeout for A2A client requests in seconds.",
)
api_key: Optional[str] = Field(
default=None,
description="API key for A2A authentication.",
)
task_ttl: int = Field(
default=3600,
description="Time-to-live for tasks in seconds.",
)
cleanup_interval: int = Field(
default=300,
description="Interval for cleaning up expired tasks in seconds.",
)
max_history_length: int = Field(
default=100,
description="Maximum number of messages to include in task history.",
)
@classmethod
def from_env(cls) -> "A2AConfig":
"""Create a configuration from environment variables.
Environment variables are prefixed with A2A_ and are uppercase.
For example, A2A_SERVER_PORT=8080 will set server_port to 8080.
Returns:
A2AConfig: The configuration.
"""
config_dict: Dict[str, Union[str, int, bool, list[str]]] = {}
if "A2A_SERVER_HOST" in os.environ:
config_dict["server_host"] = os.environ["A2A_SERVER_HOST"]
if "A2A_SERVER_PORT" in os.environ:
config_dict["server_port"] = int(os.environ["A2A_SERVER_PORT"])
if "A2A_ENABLE_CORS" in os.environ:
config_dict["enable_cors"] = os.environ["A2A_ENABLE_CORS"].lower() == "true"
if "A2A_CORS_ORIGINS" in os.environ:
config_dict["cors_origins"] = os.environ["A2A_CORS_ORIGINS"].split(",")
if "A2A_CLIENT_TIMEOUT" in os.environ:
config_dict["client_timeout"] = int(os.environ["A2A_CLIENT_TIMEOUT"])
if "A2A_API_KEY" in os.environ:
config_dict["api_key"] = os.environ["A2A_API_KEY"]
if "A2A_TASK_TTL" in os.environ:
config_dict["task_ttl"] = int(os.environ["A2A_TASK_TTL"])
if "A2A_CLEANUP_INTERVAL" in os.environ:
config_dict["cleanup_interval"] = int(os.environ["A2A_CLEANUP_INTERVAL"])
if "A2A_MAX_HISTORY_LENGTH" in os.environ:
config_dict["max_history_length"] = int(os.environ["A2A_MAX_HISTORY_LENGTH"])
return cls(**config_dict)

View File

@@ -7,13 +7,16 @@ This module implements the server for the A2A protocol in CrewAI.
import asyncio
import json
import logging
from typing import Any, Callable, Dict, List, Optional, Type, Union
from typing import Any, Callable, Dict, List, Optional, Type, TYPE_CHECKING, Union
from fastapi import FastAPI, HTTPException, Request, Response
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import JSONResponse, StreamingResponse
from pydantic import ValidationError
if TYPE_CHECKING:
from crewai.a2a.config import A2AConfig
from crewai.a2a.task_manager import InMemoryTaskManager, TaskManager
from crewai.types.a2a import (
A2ARequest,
@@ -58,17 +61,47 @@ class A2AServer:
def __init__(
self,
task_manager: Optional[TaskManager] = None,
enable_cors: bool = True,
enable_cors: Optional[bool] = None,
cors_origins: Optional[List[str]] = None,
config: Optional["A2AConfig"] = None,
):
"""Initialize the A2A server.
Args:
task_manager: The task manager to use. If None, an InMemoryTaskManager will be created.
enable_cors: Whether to enable CORS.
cors_origins: The CORS origins to allow.
enable_cors: Whether to enable CORS. If None, uses config value.
cors_origins: The CORS origins to allow. If None, uses config value.
config: The A2A configuration. If provided, other parameters are ignored.
"""
self.app = FastAPI(title="A2A Server")
from crewai.a2a.config import A2AConfig
self.config = config or A2AConfig.from_env()
enable_cors = enable_cors if enable_cors is not None else self.config.enable_cors
cors_origins = cors_origins or self.config.cors_origins
self.app = FastAPI(
title="A2A Protocol Server",
description="""
A2A (Agent-to-Agent) protocol server for CrewAI.
This server implements Google's A2A protocol specification, enabling interoperability
between different agent systems. It provides endpoints for task creation, retrieval,
cancellation, and streaming updates.
""",
version="1.0.0",
docs_url="/docs",
redoc_url="/redoc",
openapi_tags=[
{
"name": "tasks",
"description": "Operations for managing A2A tasks",
},
{
"name": "jsonrpc",
"description": "JSON-RPC interface for the A2A protocol",
},
],
)
self.task_manager = task_manager or InMemoryTaskManager()
self.logger = logging.getLogger(__name__)
@@ -81,11 +114,125 @@ class A2AServer:
allow_headers=["*"],
)
self.app.post("/v1/jsonrpc")(self.handle_jsonrpc)
self.app.post("/v1/tasks/send")(self.handle_send_task)
self.app.post("/v1/tasks/sendSubscribe")(self.handle_send_task_subscribe)
self.app.post("/v1/tasks/{task_id}/cancel")(self.handle_cancel_task)
self.app.get("/v1/tasks/{task_id}")(self.handle_get_task)
@self.app.post(
"/v1/jsonrpc",
summary="Handle JSON-RPC requests",
description="""
Process JSON-RPC requests for the A2A protocol.
This endpoint handles all JSON-RPC requests for the A2A protocol, including:
- SendTask: Create a new task
- GetTask: Retrieve a task by ID
- CancelTask: Cancel a running task
- SetTaskPushNotification: Configure push notifications for a task
- GetTaskPushNotification: Retrieve push notification configuration for a task
""",
response_model=JSONRPCResponse,
responses={
200: {"description": "Successful response with result or error"},
400: {"description": "Invalid request format or parameters"},
500: {"description": "Internal server error during processing"},
},
tags=["jsonrpc"],
)
async def handle_jsonrpc(request: Request):
return await self.handle_jsonrpc(request)
@self.app.post(
"/v1/tasks/send",
summary="Send a task to an agent",
description="""
Create a new task and send it to an agent for execution.
This endpoint allows clients to send tasks to agents for processing.
The task is created with the provided parameters and immediately
transitions to the WORKING state. The response includes the created
task with its current status.
""",
response_model=Task,
responses={
200: {"description": "Task created successfully and processing started"},
400: {"description": "Invalid request format or parameters"},
500: {"description": "Internal server error during task creation or processing"},
},
tags=["tasks"],
)
async def handle_send_task(request: Request):
return await self.handle_send_task(request)
@self.app.post(
"/v1/tasks/sendSubscribe",
summary="Send a task and subscribe to updates",
description="""
Create a new task and subscribe to status updates via Server-Sent Events (SSE).
This endpoint allows clients to send tasks to agents and receive real-time
updates as the task progresses. The response is a streaming SSE connection
that provides status updates and artifact notifications until the task
reaches a terminal state (COMPLETED, FAILED, CANCELED, or EXPIRED).
""",
responses={
200: {
"description": "Streaming response with task updates",
"content": {
"text/event-stream": {
"schema": {"type": "string"},
"example": 'event: status\ndata: {"task_id": "123", "status": {"state": "WORKING"}}\n\n',
}
},
},
400: {"description": "Invalid request format or parameters"},
500: {"description": "Internal server error during task creation or processing"},
},
tags=["tasks"],
)
async def handle_send_task_subscribe(request: Request):
return await self.handle_send_task_subscribe(request)
@self.app.post(
"/v1/tasks/{task_id}/cancel",
summary="Cancel a task",
description="""
Cancel a running task by ID.
This endpoint allows clients to cancel a task that is currently in progress.
The task must be in a non-terminal state (PENDING, WORKING) to be canceled.
Once canceled, the task transitions to the CANCELED state and cannot be
resumed. The response includes the updated task with its current status.
""",
response_model=Task,
responses={
200: {"description": "Task canceled successfully and status updated to CANCELED"},
404: {"description": "Task not found or already expired"},
409: {"description": "Task cannot be canceled (already in terminal state)"},
500: {"description": "Internal server error during task cancellation"},
},
tags=["tasks"],
)
async def handle_cancel_task(task_id: str, request: Request):
return await self.handle_cancel_task(task_id, request)
@self.app.get(
"/v1/tasks/{task_id}",
summary="Get task details",
description="""
Retrieve details of a task by ID.
This endpoint allows clients to retrieve the current state and details of a task.
The response includes the task's status, history, and any associated metadata.
Clients can specify the history_length parameter to limit the number of messages
included in the response.
""",
response_model=Task,
responses={
200: {"description": "Task details retrieved successfully with current status"},
404: {"description": "Task not found or expired"},
500: {"description": "Internal server error during task retrieval"},
},
tags=["tasks"],
)
async def handle_get_task(task_id: str, request: Request):
return await self.handle_get_task(task_id, request)
async def handle_jsonrpc(self, request: Request) -> JSONResponse:
"""Handle JSON-RPC requests.
@@ -121,7 +268,7 @@ class A2AServer:
return JSONResponse(
content=JSONRPCResponse(
id=body.get("id") if isinstance(body, dict) else None,
error=InternalError(data=str(e)),
error=InternalError(message="Internal server error"),
).model_dump(),
status_code=500,
)
@@ -210,7 +357,7 @@ class A2AServer:
self.logger.exception(f"Error handling {method} request")
return JSONRPCResponse(
id=request_id,
error=InternalError(data=str(e)),
error=InternalError(message="Internal server error"),
)
async def handle_send_task(self, request: Request) -> JSONResponse:
@@ -227,15 +374,15 @@ class A2AServer:
params = TaskSendParams.model_validate(body)
task = await self._handle_send_task(params)
return JSONResponse(content=task.model_dump())
except ValidationError as e:
except ValidationError:
return JSONResponse(
content={"error": str(e)},
content={"error": "Invalid request format or parameters"},
status_code=400,
)
except Exception as e:
self.logger.exception("Error handling send task request")
return JSONResponse(
content={"error": str(e)},
content={"error": "Internal server error"},
status_code=500,
)
@@ -283,15 +430,15 @@ class A2AServer:
self._stream_task_updates(params.id, queue),
media_type="text/event-stream",
)
except ValidationError as e:
except ValidationError:
return JSONResponse(
content={"error": str(e)},
content={"error": "Invalid request format or parameters"},
status_code=400,
)
except Exception as e:
self.logger.exception("Error handling send task subscribe request")
return JSONResponse(
content={"error": str(e)},
content={"error": "Internal server error"},
status_code=500,
)
@@ -346,7 +493,7 @@ class A2AServer:
raise HTTPException(status_code=404, detail=f"Task {task_id} not found")
except Exception as e:
self.logger.exception(f"Error handling get task request for {task_id}")
raise HTTPException(status_code=500, detail=str(e))
raise HTTPException(status_code=500, detail="Internal server error")
async def handle_cancel_task(self, task_id: str, request: Request) -> JSONResponse:
"""Handle cancel task requests.
@@ -365,4 +512,4 @@ class A2AServer:
raise HTTPException(status_code=404, detail=f"Task {task_id} not found")
except Exception as e:
self.logger.exception(f"Error handling cancel task request for {task_id}")
raise HTTPException(status_code=500, detail=str(e))
raise HTTPException(status_code=500, detail="Internal server error")

View File

@@ -8,9 +8,12 @@ import asyncio
import logging
from abc import ABC, abstractmethod
from datetime import datetime
from typing import Any, Dict, List, Optional, Set, Union
from typing import Any, Dict, List, Optional, Set, TYPE_CHECKING, Union
from uuid import uuid4
if TYPE_CHECKING:
from crewai.a2a.config import A2AConfig
from crewai.types.a2a import (
Artifact,
Message,
@@ -165,12 +168,37 @@ class TaskManager(ABC):
class InMemoryTaskManager(TaskManager):
"""In-memory implementation of the A2A task manager."""
def __init__(self):
"""Initialize the in-memory task manager."""
def __init__(
self,
task_ttl: Optional[int] = None,
cleanup_interval: Optional[int] = None,
config: Optional["A2AConfig"] = None,
):
"""Initialize the in-memory task manager.
Args:
task_ttl: Time to live for tasks in seconds. Default is 1 hour.
cleanup_interval: Interval for cleaning up expired tasks in seconds. Default is 5 minutes.
config: The A2A configuration. If provided, other parameters are ignored.
"""
from crewai.a2a.config import A2AConfig
self.config = config or A2AConfig.from_env()
self._task_ttl = task_ttl if task_ttl is not None else self.config.task_ttl
self._cleanup_interval = cleanup_interval if cleanup_interval is not None else self.config.cleanup_interval
self._tasks: Dict[str, Task] = {}
self._push_notifications: Dict[str, PushNotificationConfig] = {}
self._task_subscribers: Dict[str, Set[asyncio.Queue]] = {}
self._task_timestamps: Dict[str, datetime] = {}
self._logger = logging.getLogger(__name__)
self._cleanup_task = None
try:
if asyncio.get_running_loop():
self._cleanup_task = asyncio.create_task(self._periodic_cleanup())
except RuntimeError:
self._logger.info("No running event loop, periodic cleanup disabled")
async def create_task(
self,
@@ -198,6 +226,7 @@ class InMemoryTaskManager(TaskManager):
state=TaskState.SUBMITTED,
message=message,
timestamp=datetime.now(),
previous_state=None, # Initial state has no previous state
)
task = Task(
@@ -211,6 +240,7 @@ class InMemoryTaskManager(TaskManager):
self._tasks[task_id] = task
self._task_subscribers[task_id] = set()
self._task_timestamps[task_id] = datetime.now()
return task
async def get_task(
@@ -263,20 +293,29 @@ class InMemoryTaskManager(TaskManager):
raise KeyError(f"Task {task_id} not found")
task = self._tasks[task_id]
task = self._tasks[task_id]
previous_state = task.status.state if task.status else None
if previous_state and not TaskState.is_valid_transition(previous_state, state):
raise ValueError(f"Invalid state transition from {previous_state} to {state}")
status = TaskStatus(
state=state,
message=message,
timestamp=datetime.now(),
previous_state=previous_state,
)
task.status = status
if message and task.history is not None:
task.history.append(message)
self._task_timestamps[task_id] = datetime.now()
event = TaskStatusUpdateEvent(
id=task_id,
status=status,
final=state in [TaskState.COMPLETED, TaskState.CANCELED, TaskState.FAILED],
final=state in [TaskState.COMPLETED, TaskState.CANCELED, TaskState.FAILED, TaskState.EXPIRED],
metadata=metadata or {},
)
@@ -436,3 +475,48 @@ class InMemoryTaskManager(TaskManager):
if task_id in self._task_subscribers:
for queue in self._task_subscribers[task_id]:
await queue.put(event)
async def _periodic_cleanup(self) -> None:
"""Periodically clean up expired tasks."""
while True:
try:
await asyncio.sleep(self._cleanup_interval)
await self._cleanup_expired_tasks()
except asyncio.CancelledError:
break
except Exception as e:
self._logger.exception(f"Error during periodic cleanup: {e}")
async def _cleanup_expired_tasks(self) -> None:
"""Clean up expired tasks."""
now = datetime.now()
expired_tasks = []
for task_id, timestamp in self._task_timestamps.items():
if (now - timestamp).total_seconds() > self._task_ttl:
expired_tasks.append(task_id)
for task_id in expired_tasks:
self._logger.info(f"Cleaning up expired task: {task_id}")
self._tasks.pop(task_id, None)
self._push_notifications.pop(task_id, None)
self._task_timestamps.pop(task_id, None)
if task_id in self._task_subscribers:
previous_state = None
if task_id in self._tasks and self._tasks[task_id].status:
previous_state = self._tasks[task_id].status.state
status = TaskStatus(
state=TaskState.EXPIRED,
timestamp=now,
previous_state=previous_state,
)
event = TaskStatusUpdateEvent(
task_id=task_id,
status=status,
final=True,
)
await self._notify_subscribers(task_id, event)
self._task_subscribers.pop(task_id, None)

View File

@@ -459,10 +459,11 @@ class Agent(BaseAgent):
task = Task(
description=task_description,
agent=self,
expected_output="text", # Default to text output
)
try:
result = self.execute_task(task, context)
result = self.execute_task(task=task, context=context)
return result
except Exception as e:
self._logger.exception(f"Error handling A2A task: {e}")

View File

@@ -24,6 +24,42 @@ class TaskState(str, Enum):
CANCELED = 'canceled'
FAILED = 'failed'
UNKNOWN = 'unknown'
EXPIRED = 'expired'
@classmethod
def valid_transitions(cls) -> Dict[str, List[str]]:
"""Get valid state transitions.
Returns:
A dictionary mapping from state to list of valid next states.
"""
return {
cls.SUBMITTED: [cls.WORKING, cls.CANCELED, cls.FAILED],
cls.WORKING: [cls.INPUT_REQUIRED, cls.COMPLETED, cls.CANCELED, cls.FAILED],
cls.INPUT_REQUIRED: [cls.WORKING, cls.CANCELED, cls.FAILED],
cls.COMPLETED: [], # Terminal state
cls.CANCELED: [], # Terminal state
cls.FAILED: [], # Terminal state
cls.UNKNOWN: [cls.SUBMITTED, cls.WORKING, cls.INPUT_REQUIRED, cls.COMPLETED, cls.CANCELED, cls.FAILED],
cls.EXPIRED: [], # Terminal state
}
@classmethod
def is_valid_transition(cls, from_state: 'TaskState', to_state: 'TaskState') -> bool:
"""Check if a state transition is valid.
Args:
from_state: The current state.
to_state: The target state.
Returns:
True if the transition is valid, False otherwise.
"""
if from_state == to_state:
return True
valid_next_states = cls.valid_transitions().get(from_state, [])
return to_state in valid_next_states
class TextPart(BaseModel):
@@ -83,11 +119,21 @@ class TaskStatus(BaseModel):
state: TaskState
message: Optional[Message] = None
timestamp: datetime = Field(default_factory=datetime.now)
previous_state: Optional[TaskState] = None
@field_serializer('timestamp')
def serialize_dt(self, dt: datetime, _info):
"""Serialize datetime to ISO format."""
return dt.isoformat()
@model_validator(mode='after')
def validate_state_transition(self) -> Self:
"""Validate state transition."""
if self.previous_state and not TaskState.is_valid_transition(self.previous_state, self.state):
raise ValueError(
f"Invalid state transition from {self.previous_state} to {self.state}"
)
return self
class Artifact(BaseModel):

View File

@@ -1,13 +1,17 @@
"""Tests for the A2A protocol integration."""
import asyncio
from datetime import datetime
import pytest
from unittest.mock import AsyncMock, MagicMock, patch
pytestmark = pytest.mark.asyncio
from crewai.agent import Agent
from crewai.a2a import A2AAgentIntegration, A2AClient, A2AServer, InMemoryTaskManager
from crewai.task import Task
from crewai.types.a2a import (
JSONRPCResponse,
Message,
Task as A2ATask,
TaskState,
@@ -60,7 +64,7 @@ def a2a_integration():
@pytest.fixture
def a2a_client():
"""Create an A2A client."""
return A2AClient(base_url="http://localhost:8000")
return A2AClient(base_url="http://localhost:8000", api_key="test_api_key")
@pytest.fixture
@@ -81,8 +85,7 @@ class TestA2AIntegration:
@patch("crewai.a2a.agent.A2AAgentIntegration.execute_task_via_a2a")
def test_execute_task_via_a2a(self, mock_execute, agent):
"""Test executing a task via A2A."""
mock_execute.return_value = asyncio.Future()
mock_execute.return_value.set_result("Task result")
mock_execute.return_value = "Task result"
result = asyncio.run(
agent.execute_task_via_a2a(
@@ -116,8 +119,8 @@ class TestA2AIntegration:
assert result == "Task result"
mock_execute.assert_called_once()
args, kwargs = mock_execute.call_args
assert args[0].description == "Test task"
assert kwargs["context"] == "Test context"
assert kwargs["task"].description == "Test task"
def test_a2a_disabled(self, agent):
"""Test that A2A methods raise ValueError when A2A is disabled."""
@@ -159,7 +162,7 @@ class TestA2AAgentIntegration:
queue = asyncio.Queue()
await queue.put(
TaskStatusUpdateEvent(
task_id="test_task_id",
id="test_task_id",
status=TaskStatus(
state=TaskState.COMPLETED,
message=Message(
@@ -198,23 +201,29 @@ class TestA2AServer:
class TestA2AClient:
"""Tests for the A2AClient class."""
@patch("aiohttp.ClientSession.post")
async def test_send_task(self, mock_post, a2a_client):
@patch("crewai.a2a.client.A2AClient._send_jsonrpc_request")
async def test_send_task(self, mock_send_request, a2a_client):
"""Test sending a task."""
mock_response = MagicMock()
mock_response.status = 200
mock_response.json = AsyncMock(
return_value={
"id": "test_task_id",
"history": [
{
"role": "user",
"parts": [{"text": "Test task description"}],
}
mock_response = JSONRPCResponse(
jsonrpc="2.0",
id="test_request_id",
result=A2ATask(
id="test_task_id",
sessionId="test_session_id",
status=TaskStatus(
state=TaskState.SUBMITTED,
timestamp=datetime.now(),
),
history=[
Message(
role="user",
parts=[TextPart(text="Test task description")],
)
],
}
)
)
mock_post.return_value.__aenter__.return_value = mock_response
mock_send_request.return_value = mock_response
task = await a2a_client.send_task(
task_id="test_task_id",
@@ -222,9 +231,10 @@ class TestA2AClient:
role="user",
parts=[TextPart(text="Test task description")],
),
session_id="test_session_id",
)
assert task.id == "test_task_id"
assert task.history[0].role == "user"
assert task.history[0].parts[0].text == "Test task description"
mock_post.assert_called_once()
mock_send_request.assert_called_once()