diff --git a/src/crewai/a2a/__init__.py b/src/crewai/a2a/__init__.py index 0ebf9ae20..06e322e48 100644 --- a/src/crewai/a2a/__init__.py +++ b/src/crewai/a2a/__init__.py @@ -2,6 +2,7 @@ from crewai.a2a.agent import A2AAgentIntegration from crewai.a2a.client import A2AClient +from crewai.a2a.config import A2AConfig from crewai.a2a.server import A2AServer from crewai.a2a.task_manager import InMemoryTaskManager, TaskManager @@ -11,4 +12,5 @@ __all__ = [ "A2AServer", "TaskManager", "InMemoryTaskManager", + "A2AConfig", ] diff --git a/src/crewai/a2a/client.py b/src/crewai/a2a/client.py index 743009ca3..8592eefc7 100644 --- a/src/crewai/a2a/client.py +++ b/src/crewai/a2a/client.py @@ -8,11 +8,14 @@ import asyncio import json import logging import os -from typing import Any, Dict, List, Optional, Union, cast +from typing import Any, Dict, List, Optional, TYPE_CHECKING, Union, cast import aiohttp from pydantic import ValidationError +if TYPE_CHECKING: + from crewai.a2a.config import A2AConfig + from crewai.types.a2a import ( A2AClientError, A2AClientHTTPError, @@ -53,7 +56,8 @@ class A2AClient: self, base_url: str, api_key: Optional[str] = None, - timeout: int = 60, + timeout: Optional[int] = None, + config: Optional["A2AConfig"] = None, ): """Initialize the A2A client. @@ -61,10 +65,22 @@ class A2AClient: base_url: The base URL of the A2A server. api_key: The API key to use for authentication. timeout: The timeout for HTTP requests in seconds. + config: The A2A configuration. If provided, other parameters are ignored. """ + if config: + from crewai.a2a.config import A2AConfig + self.config = config + else: + from crewai.a2a.config import A2AConfig + self.config = A2AConfig() + if api_key: + self.config.api_key = api_key + if timeout: + self.config.client_timeout = timeout + self.base_url = base_url.rstrip("/") - self.api_key = api_key or os.environ.get("A2A_API_KEY") - self.timeout = timeout + self.api_key = self.config.api_key or os.environ.get("A2A_API_KEY") + self.timeout = self.config.client_timeout self.logger = logging.getLogger(__name__) async def send_task( @@ -92,7 +108,10 @@ class A2AClient: The created task. Raises: - A2AClientError: If there is an error sending the task. + MissingAPIKeyError: If no API key is provided. + A2AClientHTTPError: If there is an HTTP error. + A2AClientJSONError: If there is an error parsing the JSON response. + A2AClientError: If there is any other error sending the task. """ params = TaskSendParams( id=task_id, @@ -105,15 +124,26 @@ class A2AClient: ) request = SendTaskRequest(params=params) - response = await self._send_jsonrpc_request(request) + + try: + response = await self._send_jsonrpc_request(request) + + if response.error: + raise A2AClientError(f"Error sending task: {response.error.message}") - if response.error: - raise A2AClientError(f"Error sending task: {response.error.message}") + if not response.result: + raise A2AClientError("No result returned from send task request") - if not response.result: - raise A2AClientError("No result returned from send task request") - - return cast(Task, response.result) + if isinstance(response.result, dict): + return Task.model_validate(response.result) + return cast(Task, response.result) + except asyncio.TimeoutError as e: + raise A2AClientError(f"Task request timed out: {e}") + except aiohttp.ClientError as e: + if isinstance(e, aiohttp.ClientResponseError): + raise A2AClientHTTPError(e.status, str(e)) + else: + raise A2AClientError(f"Client error: {e}") async def send_task_streaming( self, @@ -318,6 +348,14 @@ class A2AClient: return JSONRPCResponse.model_validate(data) except ValidationError as e: raise A2AClientError(f"Invalid response: {e}") + except aiohttp.ClientConnectorError as e: + raise A2AClientHTTPError(status=0, message=f"Connection error: {e}") + except aiohttp.ClientOSError as e: + raise A2AClientHTTPError(status=0, message=f"OS error: {e}") + except aiohttp.ServerDisconnectedError as e: + raise A2AClientHTTPError(status=0, message=f"Server disconnected: {e}") + except aiohttp.ClientResponseError as e: + raise A2AClientHTTPError(e.status, str(e)) except aiohttp.ClientError as e: raise A2AClientError(f"HTTP error: {e}") @@ -394,6 +432,14 @@ class A2AClient: await queue.put( A2AClientError(f"Invalid artifact event: {e}") ) + except aiohttp.ClientConnectorError as e: + await queue.put(A2AClientHTTPError(status=0, message=f"Connection error: {e}")) + except aiohttp.ClientOSError as e: + await queue.put(A2AClientHTTPError(status=0, message=f"OS error: {e}")) + except aiohttp.ServerDisconnectedError as e: + await queue.put(A2AClientHTTPError(status=0, message=f"Server disconnected: {e}")) + except aiohttp.ClientResponseError as e: + await queue.put(A2AClientHTTPError(e.status, str(e))) except aiohttp.ClientError as e: await queue.put(A2AClientError(f"HTTP error: {e}")) except asyncio.CancelledError: diff --git a/src/crewai/a2a/config.py b/src/crewai/a2a/config.py new file mode 100644 index 000000000..1d14f4b4a --- /dev/null +++ b/src/crewai/a2a/config.py @@ -0,0 +1,89 @@ +""" +Configuration management for A2A protocol in CrewAI. + +This module provides configuration management for the A2A protocol implementation +in CrewAI, including default values and environment variable support. +""" + +import os +from typing import Dict, Optional, Union + +from pydantic import BaseModel, Field + + +class A2AConfig(BaseModel): + """Configuration for A2A protocol.""" + + server_host: str = Field( + default="0.0.0.0", + description="Host to bind the A2A server to.", + ) + server_port: int = Field( + default=8000, + description="Port to bind the A2A server to.", + ) + enable_cors: bool = Field( + default=True, + description="Whether to enable CORS for the A2A server.", + ) + cors_origins: Optional[list[str]] = Field( + default=None, + description="CORS origins to allow. If None, all origins are allowed.", + ) + + client_timeout: int = Field( + default=60, + description="Timeout for A2A client requests in seconds.", + ) + api_key: Optional[str] = Field( + default=None, + description="API key for A2A authentication.", + ) + + task_ttl: int = Field( + default=3600, + description="Time-to-live for tasks in seconds.", + ) + cleanup_interval: int = Field( + default=300, + description="Interval for cleaning up expired tasks in seconds.", + ) + max_history_length: int = Field( + default=100, + description="Maximum number of messages to include in task history.", + ) + + @classmethod + def from_env(cls) -> "A2AConfig": + """Create a configuration from environment variables. + + Environment variables are prefixed with A2A_ and are uppercase. + For example, A2A_SERVER_PORT=8080 will set server_port to 8080. + + Returns: + A2AConfig: The configuration. + """ + config_dict: Dict[str, Union[str, int, bool, list[str]]] = {} + + if "A2A_SERVER_HOST" in os.environ: + config_dict["server_host"] = os.environ["A2A_SERVER_HOST"] + if "A2A_SERVER_PORT" in os.environ: + config_dict["server_port"] = int(os.environ["A2A_SERVER_PORT"]) + if "A2A_ENABLE_CORS" in os.environ: + config_dict["enable_cors"] = os.environ["A2A_ENABLE_CORS"].lower() == "true" + if "A2A_CORS_ORIGINS" in os.environ: + config_dict["cors_origins"] = os.environ["A2A_CORS_ORIGINS"].split(",") + + if "A2A_CLIENT_TIMEOUT" in os.environ: + config_dict["client_timeout"] = int(os.environ["A2A_CLIENT_TIMEOUT"]) + if "A2A_API_KEY" in os.environ: + config_dict["api_key"] = os.environ["A2A_API_KEY"] + + if "A2A_TASK_TTL" in os.environ: + config_dict["task_ttl"] = int(os.environ["A2A_TASK_TTL"]) + if "A2A_CLEANUP_INTERVAL" in os.environ: + config_dict["cleanup_interval"] = int(os.environ["A2A_CLEANUP_INTERVAL"]) + if "A2A_MAX_HISTORY_LENGTH" in os.environ: + config_dict["max_history_length"] = int(os.environ["A2A_MAX_HISTORY_LENGTH"]) + + return cls(**config_dict) diff --git a/src/crewai/a2a/server.py b/src/crewai/a2a/server.py index a815e3574..09b1c0261 100644 --- a/src/crewai/a2a/server.py +++ b/src/crewai/a2a/server.py @@ -7,13 +7,16 @@ This module implements the server for the A2A protocol in CrewAI. import asyncio import json import logging -from typing import Any, Callable, Dict, List, Optional, Type, Union +from typing import Any, Callable, Dict, List, Optional, Type, TYPE_CHECKING, Union from fastapi import FastAPI, HTTPException, Request, Response from fastapi.middleware.cors import CORSMiddleware from fastapi.responses import JSONResponse, StreamingResponse from pydantic import ValidationError +if TYPE_CHECKING: + from crewai.a2a.config import A2AConfig + from crewai.a2a.task_manager import InMemoryTaskManager, TaskManager from crewai.types.a2a import ( A2ARequest, @@ -58,17 +61,47 @@ class A2AServer: def __init__( self, task_manager: Optional[TaskManager] = None, - enable_cors: bool = True, + enable_cors: Optional[bool] = None, cors_origins: Optional[List[str]] = None, + config: Optional["A2AConfig"] = None, ): """Initialize the A2A server. Args: task_manager: The task manager to use. If None, an InMemoryTaskManager will be created. - enable_cors: Whether to enable CORS. - cors_origins: The CORS origins to allow. + enable_cors: Whether to enable CORS. If None, uses config value. + cors_origins: The CORS origins to allow. If None, uses config value. + config: The A2A configuration. If provided, other parameters are ignored. """ - self.app = FastAPI(title="A2A Server") + from crewai.a2a.config import A2AConfig + self.config = config or A2AConfig.from_env() + + enable_cors = enable_cors if enable_cors is not None else self.config.enable_cors + cors_origins = cors_origins or self.config.cors_origins + + self.app = FastAPI( + title="A2A Protocol Server", + description=""" + A2A (Agent-to-Agent) protocol server for CrewAI. + + This server implements Google's A2A protocol specification, enabling interoperability + between different agent systems. It provides endpoints for task creation, retrieval, + cancellation, and streaming updates. + """, + version="1.0.0", + docs_url="/docs", + redoc_url="/redoc", + openapi_tags=[ + { + "name": "tasks", + "description": "Operations for managing A2A tasks", + }, + { + "name": "jsonrpc", + "description": "JSON-RPC interface for the A2A protocol", + }, + ], + ) self.task_manager = task_manager or InMemoryTaskManager() self.logger = logging.getLogger(__name__) @@ -81,11 +114,125 @@ class A2AServer: allow_headers=["*"], ) - self.app.post("/v1/jsonrpc")(self.handle_jsonrpc) - self.app.post("/v1/tasks/send")(self.handle_send_task) - self.app.post("/v1/tasks/sendSubscribe")(self.handle_send_task_subscribe) - self.app.post("/v1/tasks/{task_id}/cancel")(self.handle_cancel_task) - self.app.get("/v1/tasks/{task_id}")(self.handle_get_task) + @self.app.post( + "/v1/jsonrpc", + summary="Handle JSON-RPC requests", + description=""" + Process JSON-RPC requests for the A2A protocol. + + This endpoint handles all JSON-RPC requests for the A2A protocol, including: + - SendTask: Create a new task + - GetTask: Retrieve a task by ID + - CancelTask: Cancel a running task + - SetTaskPushNotification: Configure push notifications for a task + - GetTaskPushNotification: Retrieve push notification configuration for a task + """, + response_model=JSONRPCResponse, + responses={ + 200: {"description": "Successful response with result or error"}, + 400: {"description": "Invalid request format or parameters"}, + 500: {"description": "Internal server error during processing"}, + }, + tags=["jsonrpc"], + ) + async def handle_jsonrpc(request: Request): + return await self.handle_jsonrpc(request) + + @self.app.post( + "/v1/tasks/send", + summary="Send a task to an agent", + description=""" + Create a new task and send it to an agent for execution. + + This endpoint allows clients to send tasks to agents for processing. + The task is created with the provided parameters and immediately + transitions to the WORKING state. The response includes the created + task with its current status. + """, + response_model=Task, + responses={ + 200: {"description": "Task created successfully and processing started"}, + 400: {"description": "Invalid request format or parameters"}, + 500: {"description": "Internal server error during task creation or processing"}, + }, + tags=["tasks"], + ) + async def handle_send_task(request: Request): + return await self.handle_send_task(request) + + @self.app.post( + "/v1/tasks/sendSubscribe", + summary="Send a task and subscribe to updates", + description=""" + Create a new task and subscribe to status updates via Server-Sent Events (SSE). + + This endpoint allows clients to send tasks to agents and receive real-time + updates as the task progresses. The response is a streaming SSE connection + that provides status updates and artifact notifications until the task + reaches a terminal state (COMPLETED, FAILED, CANCELED, or EXPIRED). + """, + responses={ + 200: { + "description": "Streaming response with task updates", + "content": { + "text/event-stream": { + "schema": {"type": "string"}, + "example": 'event: status\ndata: {"task_id": "123", "status": {"state": "WORKING"}}\n\n', + } + }, + }, + 400: {"description": "Invalid request format or parameters"}, + 500: {"description": "Internal server error during task creation or processing"}, + }, + tags=["tasks"], + ) + async def handle_send_task_subscribe(request: Request): + return await self.handle_send_task_subscribe(request) + + @self.app.post( + "/v1/tasks/{task_id}/cancel", + summary="Cancel a task", + description=""" + Cancel a running task by ID. + + This endpoint allows clients to cancel a task that is currently in progress. + The task must be in a non-terminal state (PENDING, WORKING) to be canceled. + Once canceled, the task transitions to the CANCELED state and cannot be + resumed. The response includes the updated task with its current status. + """, + response_model=Task, + responses={ + 200: {"description": "Task canceled successfully and status updated to CANCELED"}, + 404: {"description": "Task not found or already expired"}, + 409: {"description": "Task cannot be canceled (already in terminal state)"}, + 500: {"description": "Internal server error during task cancellation"}, + }, + tags=["tasks"], + ) + async def handle_cancel_task(task_id: str, request: Request): + return await self.handle_cancel_task(task_id, request) + + @self.app.get( + "/v1/tasks/{task_id}", + summary="Get task details", + description=""" + Retrieve details of a task by ID. + + This endpoint allows clients to retrieve the current state and details of a task. + The response includes the task's status, history, and any associated metadata. + Clients can specify the history_length parameter to limit the number of messages + included in the response. + """, + response_model=Task, + responses={ + 200: {"description": "Task details retrieved successfully with current status"}, + 404: {"description": "Task not found or expired"}, + 500: {"description": "Internal server error during task retrieval"}, + }, + tags=["tasks"], + ) + async def handle_get_task(task_id: str, request: Request): + return await self.handle_get_task(task_id, request) async def handle_jsonrpc(self, request: Request) -> JSONResponse: """Handle JSON-RPC requests. @@ -121,7 +268,7 @@ class A2AServer: return JSONResponse( content=JSONRPCResponse( id=body.get("id") if isinstance(body, dict) else None, - error=InternalError(data=str(e)), + error=InternalError(message="Internal server error"), ).model_dump(), status_code=500, ) @@ -210,7 +357,7 @@ class A2AServer: self.logger.exception(f"Error handling {method} request") return JSONRPCResponse( id=request_id, - error=InternalError(data=str(e)), + error=InternalError(message="Internal server error"), ) async def handle_send_task(self, request: Request) -> JSONResponse: @@ -227,15 +374,15 @@ class A2AServer: params = TaskSendParams.model_validate(body) task = await self._handle_send_task(params) return JSONResponse(content=task.model_dump()) - except ValidationError as e: + except ValidationError: return JSONResponse( - content={"error": str(e)}, + content={"error": "Invalid request format or parameters"}, status_code=400, ) except Exception as e: self.logger.exception("Error handling send task request") return JSONResponse( - content={"error": str(e)}, + content={"error": "Internal server error"}, status_code=500, ) @@ -283,15 +430,15 @@ class A2AServer: self._stream_task_updates(params.id, queue), media_type="text/event-stream", ) - except ValidationError as e: + except ValidationError: return JSONResponse( - content={"error": str(e)}, + content={"error": "Invalid request format or parameters"}, status_code=400, ) except Exception as e: self.logger.exception("Error handling send task subscribe request") return JSONResponse( - content={"error": str(e)}, + content={"error": "Internal server error"}, status_code=500, ) @@ -346,7 +493,7 @@ class A2AServer: raise HTTPException(status_code=404, detail=f"Task {task_id} not found") except Exception as e: self.logger.exception(f"Error handling get task request for {task_id}") - raise HTTPException(status_code=500, detail=str(e)) + raise HTTPException(status_code=500, detail="Internal server error") async def handle_cancel_task(self, task_id: str, request: Request) -> JSONResponse: """Handle cancel task requests. @@ -365,4 +512,4 @@ class A2AServer: raise HTTPException(status_code=404, detail=f"Task {task_id} not found") except Exception as e: self.logger.exception(f"Error handling cancel task request for {task_id}") - raise HTTPException(status_code=500, detail=str(e)) + raise HTTPException(status_code=500, detail="Internal server error") diff --git a/src/crewai/a2a/task_manager.py b/src/crewai/a2a/task_manager.py index 4104c506b..431950fc3 100644 --- a/src/crewai/a2a/task_manager.py +++ b/src/crewai/a2a/task_manager.py @@ -8,9 +8,12 @@ import asyncio import logging from abc import ABC, abstractmethod from datetime import datetime -from typing import Any, Dict, List, Optional, Set, Union +from typing import Any, Dict, List, Optional, Set, TYPE_CHECKING, Union from uuid import uuid4 +if TYPE_CHECKING: + from crewai.a2a.config import A2AConfig + from crewai.types.a2a import ( Artifact, Message, @@ -165,12 +168,37 @@ class TaskManager(ABC): class InMemoryTaskManager(TaskManager): """In-memory implementation of the A2A task manager.""" - def __init__(self): - """Initialize the in-memory task manager.""" + def __init__( + self, + task_ttl: Optional[int] = None, + cleanup_interval: Optional[int] = None, + config: Optional["A2AConfig"] = None, + ): + """Initialize the in-memory task manager. + + Args: + task_ttl: Time to live for tasks in seconds. Default is 1 hour. + cleanup_interval: Interval for cleaning up expired tasks in seconds. Default is 5 minutes. + config: The A2A configuration. If provided, other parameters are ignored. + """ + from crewai.a2a.config import A2AConfig + self.config = config or A2AConfig.from_env() + + self._task_ttl = task_ttl if task_ttl is not None else self.config.task_ttl + self._cleanup_interval = cleanup_interval if cleanup_interval is not None else self.config.cleanup_interval + self._tasks: Dict[str, Task] = {} self._push_notifications: Dict[str, PushNotificationConfig] = {} self._task_subscribers: Dict[str, Set[asyncio.Queue]] = {} + self._task_timestamps: Dict[str, datetime] = {} self._logger = logging.getLogger(__name__) + self._cleanup_task = None + + try: + if asyncio.get_running_loop(): + self._cleanup_task = asyncio.create_task(self._periodic_cleanup()) + except RuntimeError: + self._logger.info("No running event loop, periodic cleanup disabled") async def create_task( self, @@ -198,6 +226,7 @@ class InMemoryTaskManager(TaskManager): state=TaskState.SUBMITTED, message=message, timestamp=datetime.now(), + previous_state=None, # Initial state has no previous state ) task = Task( @@ -211,6 +240,7 @@ class InMemoryTaskManager(TaskManager): self._tasks[task_id] = task self._task_subscribers[task_id] = set() + self._task_timestamps[task_id] = datetime.now() return task async def get_task( @@ -263,20 +293,29 @@ class InMemoryTaskManager(TaskManager): raise KeyError(f"Task {task_id} not found") task = self._tasks[task_id] + task = self._tasks[task_id] + previous_state = task.status.state if task.status else None + + if previous_state and not TaskState.is_valid_transition(previous_state, state): + raise ValueError(f"Invalid state transition from {previous_state} to {state}") + status = TaskStatus( state=state, message=message, timestamp=datetime.now(), + previous_state=previous_state, ) task.status = status if message and task.history is not None: task.history.append(message) + self._task_timestamps[task_id] = datetime.now() + event = TaskStatusUpdateEvent( id=task_id, status=status, - final=state in [TaskState.COMPLETED, TaskState.CANCELED, TaskState.FAILED], + final=state in [TaskState.COMPLETED, TaskState.CANCELED, TaskState.FAILED, TaskState.EXPIRED], metadata=metadata or {}, ) @@ -436,3 +475,48 @@ class InMemoryTaskManager(TaskManager): if task_id in self._task_subscribers: for queue in self._task_subscribers[task_id]: await queue.put(event) + + async def _periodic_cleanup(self) -> None: + """Periodically clean up expired tasks.""" + while True: + try: + await asyncio.sleep(self._cleanup_interval) + await self._cleanup_expired_tasks() + except asyncio.CancelledError: + break + except Exception as e: + self._logger.exception(f"Error during periodic cleanup: {e}") + + async def _cleanup_expired_tasks(self) -> None: + """Clean up expired tasks.""" + now = datetime.now() + expired_tasks = [] + + for task_id, timestamp in self._task_timestamps.items(): + if (now - timestamp).total_seconds() > self._task_ttl: + expired_tasks.append(task_id) + + for task_id in expired_tasks: + self._logger.info(f"Cleaning up expired task: {task_id}") + self._tasks.pop(task_id, None) + self._push_notifications.pop(task_id, None) + self._task_timestamps.pop(task_id, None) + + if task_id in self._task_subscribers: + previous_state = None + if task_id in self._tasks and self._tasks[task_id].status: + previous_state = self._tasks[task_id].status.state + + status = TaskStatus( + state=TaskState.EXPIRED, + timestamp=now, + previous_state=previous_state, + ) + event = TaskStatusUpdateEvent( + task_id=task_id, + status=status, + final=True, + ) + await self._notify_subscribers(task_id, event) + + self._task_subscribers.pop(task_id, None) diff --git a/src/crewai/agent.py b/src/crewai/agent.py index e8ec11b53..3710d1ff7 100644 --- a/src/crewai/agent.py +++ b/src/crewai/agent.py @@ -459,10 +459,11 @@ class Agent(BaseAgent): task = Task( description=task_description, agent=self, + expected_output="text", # Default to text output ) try: - result = self.execute_task(task, context) + result = self.execute_task(task=task, context=context) return result except Exception as e: self._logger.exception(f"Error handling A2A task: {e}") diff --git a/src/crewai/types/a2a.py b/src/crewai/types/a2a.py index ca6ea7a07..6638b3bb7 100644 --- a/src/crewai/types/a2a.py +++ b/src/crewai/types/a2a.py @@ -24,6 +24,42 @@ class TaskState(str, Enum): CANCELED = 'canceled' FAILED = 'failed' UNKNOWN = 'unknown' + EXPIRED = 'expired' + + @classmethod + def valid_transitions(cls) -> Dict[str, List[str]]: + """Get valid state transitions. + + Returns: + A dictionary mapping from state to list of valid next states. + """ + return { + cls.SUBMITTED: [cls.WORKING, cls.CANCELED, cls.FAILED], + cls.WORKING: [cls.INPUT_REQUIRED, cls.COMPLETED, cls.CANCELED, cls.FAILED], + cls.INPUT_REQUIRED: [cls.WORKING, cls.CANCELED, cls.FAILED], + cls.COMPLETED: [], # Terminal state + cls.CANCELED: [], # Terminal state + cls.FAILED: [], # Terminal state + cls.UNKNOWN: [cls.SUBMITTED, cls.WORKING, cls.INPUT_REQUIRED, cls.COMPLETED, cls.CANCELED, cls.FAILED], + cls.EXPIRED: [], # Terminal state + } + + @classmethod + def is_valid_transition(cls, from_state: 'TaskState', to_state: 'TaskState') -> bool: + """Check if a state transition is valid. + + Args: + from_state: The current state. + to_state: The target state. + + Returns: + True if the transition is valid, False otherwise. + """ + if from_state == to_state: + return True + + valid_next_states = cls.valid_transitions().get(from_state, []) + return to_state in valid_next_states class TextPart(BaseModel): @@ -83,11 +119,21 @@ class TaskStatus(BaseModel): state: TaskState message: Optional[Message] = None timestamp: datetime = Field(default_factory=datetime.now) + previous_state: Optional[TaskState] = None @field_serializer('timestamp') def serialize_dt(self, dt: datetime, _info): """Serialize datetime to ISO format.""" return dt.isoformat() + + @model_validator(mode='after') + def validate_state_transition(self) -> Self: + """Validate state transition.""" + if self.previous_state and not TaskState.is_valid_transition(self.previous_state, self.state): + raise ValueError( + f"Invalid state transition from {self.previous_state} to {self.state}" + ) + return self class Artifact(BaseModel): diff --git a/tests/a2a/test_a2a_integration.py b/tests/a2a/test_a2a_integration.py index 7d3360eb0..436434fb6 100644 --- a/tests/a2a/test_a2a_integration.py +++ b/tests/a2a/test_a2a_integration.py @@ -1,13 +1,17 @@ """Tests for the A2A protocol integration.""" import asyncio +from datetime import datetime import pytest from unittest.mock import AsyncMock, MagicMock, patch +pytestmark = pytest.mark.asyncio + from crewai.agent import Agent from crewai.a2a import A2AAgentIntegration, A2AClient, A2AServer, InMemoryTaskManager from crewai.task import Task from crewai.types.a2a import ( + JSONRPCResponse, Message, Task as A2ATask, TaskState, @@ -60,7 +64,7 @@ def a2a_integration(): @pytest.fixture def a2a_client(): """Create an A2A client.""" - return A2AClient(base_url="http://localhost:8000") + return A2AClient(base_url="http://localhost:8000", api_key="test_api_key") @pytest.fixture @@ -81,8 +85,7 @@ class TestA2AIntegration: @patch("crewai.a2a.agent.A2AAgentIntegration.execute_task_via_a2a") def test_execute_task_via_a2a(self, mock_execute, agent): """Test executing a task via A2A.""" - mock_execute.return_value = asyncio.Future() - mock_execute.return_value.set_result("Task result") + mock_execute.return_value = "Task result" result = asyncio.run( agent.execute_task_via_a2a( @@ -116,8 +119,8 @@ class TestA2AIntegration: assert result == "Task result" mock_execute.assert_called_once() args, kwargs = mock_execute.call_args - assert args[0].description == "Test task" assert kwargs["context"] == "Test context" + assert kwargs["task"].description == "Test task" def test_a2a_disabled(self, agent): """Test that A2A methods raise ValueError when A2A is disabled.""" @@ -159,7 +162,7 @@ class TestA2AAgentIntegration: queue = asyncio.Queue() await queue.put( TaskStatusUpdateEvent( - task_id="test_task_id", + id="test_task_id", status=TaskStatus( state=TaskState.COMPLETED, message=Message( @@ -198,23 +201,29 @@ class TestA2AServer: class TestA2AClient: """Tests for the A2AClient class.""" - @patch("aiohttp.ClientSession.post") - async def test_send_task(self, mock_post, a2a_client): + @patch("crewai.a2a.client.A2AClient._send_jsonrpc_request") + async def test_send_task(self, mock_send_request, a2a_client): """Test sending a task.""" - mock_response = MagicMock() - mock_response.status = 200 - mock_response.json = AsyncMock( - return_value={ - "id": "test_task_id", - "history": [ - { - "role": "user", - "parts": [{"text": "Test task description"}], - } + mock_response = JSONRPCResponse( + jsonrpc="2.0", + id="test_request_id", + result=A2ATask( + id="test_task_id", + sessionId="test_session_id", + status=TaskStatus( + state=TaskState.SUBMITTED, + timestamp=datetime.now(), + ), + history=[ + Message( + role="user", + parts=[TextPart(text="Test task description")], + ) ], - } + ) ) - mock_post.return_value.__aenter__.return_value = mock_response + + mock_send_request.return_value = mock_response task = await a2a_client.send_task( task_id="test_task_id", @@ -222,9 +231,10 @@ class TestA2AClient: role="user", parts=[TextPart(text="Test task description")], ), + session_id="test_session_id", ) assert task.id == "test_task_id" assert task.history[0].role == "user" assert task.history[0].parts[0].text == "Test task description" - mock_post.assert_called_once() + mock_send_request.assert_called_once()