from __future__ import annotations import os import uuid import httpx import pytest import pytest_asyncio from a2a.client import A2ACardResolver, ClientFactory from a2a.types import AgentCapabilities, AgentCard, AgentInterface, Message, Part, Role, TaskState from crewai.a2a._compat import ( ROLE_AGENT, ROLE_USER, TASK_STATE_COMPLETED, TASK_STATE_FAILED, agent_card_url, make_send_request, new_text_message, new_text_part, ) from crewai.a2a.updates.polling.handler import PollingHandler from crewai.a2a.updates.streaming.handler import StreamingHandler A2A_TEST_ENDPOINT = os.getenv("A2A_TEST_ENDPOINT", "http://localhost:9999") @pytest_asyncio.fixture async def card_resolver(): """Create an A2ACardResolver for the test server.""" async with httpx.AsyncClient() as http_client: resolver = A2ACardResolver(http_client, A2A_TEST_ENDPOINT) yield resolver @pytest_asyncio.fixture async def agent_card(card_resolver) -> AgentCard: """Fetch the real agent card from the server.""" return await card_resolver.get_agent_card() @pytest_asyncio.fixture async def a2a_client(agent_card): """Create A2A client for test server.""" factory = ClientFactory() client = factory.create(agent_card) yield client @pytest.fixture def test_message() -> Message: """Create a simple test message.""" return new_text_message("What is 2 + 2?", role=ROLE_USER) class TestA2AAgentCardFetching: """Integration tests for agent card fetching.""" @pytest.mark.vcr() @pytest.mark.asyncio async def test_fetch_agent_card(self, card_resolver) -> None: """Test fetching an agent card from the server.""" card = await card_resolver.get_agent_card() assert card is not None assert card.name == "GPT Assistant" assert card.supported_interfaces is not None assert card.capabilities is not None assert card.capabilities.streaming is True class TestA2APollingIntegration: """Integration tests for A2A polling handler.""" @pytest.mark.vcr() @pytest.mark.asyncio async def test_polling_completes_task( self, a2a_client, test_message: Message, agent_card: AgentCard, ) -> None: """Test that polling handler completes a task successfully.""" new_messages: list[Message] = [] result = await PollingHandler.execute( client=a2a_client, message=test_message, new_messages=new_messages, agent_card=agent_card, polling_interval=0.5, polling_timeout=30.0, ) assert isinstance(result, dict) assert result["status"] == TASK_STATE_COMPLETED assert result.get("result") is not None assert "4" in result["result"] class TestA2AStreamingIntegration: """Integration tests for A2A streaming handler.""" @pytest.mark.vcr() @pytest.mark.asyncio async def test_streaming_completes_task( self, a2a_client, test_message: Message, agent_card: AgentCard, ) -> None: """Test that streaming handler completes a task successfully.""" new_messages: list[Message] = [] result = await StreamingHandler.execute( client=a2a_client, message=test_message, new_messages=new_messages, agent_card=agent_card, endpoint=agent_card_url(agent_card), ) assert isinstance(result, dict) assert result["status"] == TASK_STATE_COMPLETED assert result.get("result") is not None class TestA2ATaskOperations: """Integration tests for task operations.""" @pytest.mark.vcr() @pytest.mark.asyncio async def test_send_message_and_get_response( self, a2a_client, test_message: Message, ) -> None: """Test sending a message and getting a response.""" from a2a.types import StreamResponse, Task from crewai.a2a._compat import is_stream_task final_task: Task | None = None async for event in a2a_client.send_message(make_send_request(test_message)): if isinstance(event, StreamResponse) and is_stream_task(event): final_task = event.task assert final_task is not None assert final_task.id != "" assert final_task.status is not None assert final_task.status.state == TaskState.TASK_STATE_COMPLETED class TestA2APushNotificationHandler: """Tests for push notification handler. These tests use mocks for the result store since webhook callbacks are incoming requests that can't be recorded with VCR. """ @pytest.fixture def mock_agent_card(self) -> AgentCard: """Create a minimal valid agent card for testing.""" return AgentCard( name="Test Agent", description="Test agent for push notification tests", supported_interfaces=[ AgentInterface( url="http://localhost:9999", protocol_binding="JSONRPC", ), ], version="1.0.0", capabilities=AgentCapabilities(streaming=True, push_notifications=True), default_input_modes=["text"], default_output_modes=["text"], ) @pytest.fixture def mock_task(self) -> "Task": """Create a minimal valid task for testing.""" from a2a.types import Task, TaskStatus return Task( id="task-123", context_id="ctx-123", status=TaskStatus(state=TaskState.TASK_STATE_WORKING), ) @pytest.mark.asyncio async def test_push_handler_waits_for_result( self, mock_agent_card: AgentCard, mock_task, ) -> None: """Test that push handler waits for result from store.""" from unittest.mock import AsyncMock, MagicMock from a2a.types import StreamResponse, Task, TaskStatus from pydantic import AnyHttpUrl from crewai.a2a.updates.push_notifications.config import PushNotificationConfig from crewai.a2a.updates.push_notifications.handler import PushNotificationHandler completed_task = Task( id="task-123", context_id="ctx-123", status=TaskStatus(state=TaskState.TASK_STATE_COMPLETED), ) mock_store = MagicMock() mock_store.wait_for_result = AsyncMock(return_value=completed_task) async def mock_send_message(*args, **kwargs): yield StreamResponse(task=mock_task) mock_client = MagicMock() mock_client.send_message = mock_send_message config = PushNotificationConfig( url=AnyHttpUrl("http://localhost:8080/a2a/callback"), token="secret-token", result_store=mock_store, ) test_msg = new_text_message("What is 2+2?", role=ROLE_USER) new_messages: list[Message] = [] result = await PushNotificationHandler.execute( client=mock_client, message=test_msg, new_messages=new_messages, agent_card=mock_agent_card, config=config, result_store=mock_store, polling_timeout=30.0, polling_interval=1.0, endpoint=agent_card_url(mock_agent_card), ) mock_store.wait_for_result.assert_called_once_with( task_id="task-123", timeout=30.0, poll_interval=1.0, ) assert result["status"] == TASK_STATE_COMPLETED @pytest.mark.asyncio async def test_push_handler_returns_failure_on_timeout( self, mock_agent_card: AgentCard, ) -> None: """Test that push handler returns failure when result store times out.""" from unittest.mock import AsyncMock, MagicMock from a2a.types import StreamResponse, Task, TaskStatus from pydantic import AnyHttpUrl from crewai.a2a.updates.push_notifications.config import PushNotificationConfig from crewai.a2a.updates.push_notifications.handler import PushNotificationHandler mock_store = MagicMock() mock_store.wait_for_result = AsyncMock(return_value=None) working_task = Task( id="task-456", context_id="ctx-456", status=TaskStatus(state=TaskState.TASK_STATE_WORKING), ) async def mock_send_message(*args, **kwargs): yield StreamResponse(task=working_task) mock_client = MagicMock() mock_client.send_message = mock_send_message config = PushNotificationConfig( url=AnyHttpUrl("http://localhost:8080/a2a/callback"), token="token", result_store=mock_store, ) test_msg = new_text_message("test", role=ROLE_USER) new_messages: list[Message] = [] result = await PushNotificationHandler.execute( client=mock_client, message=test_msg, new_messages=new_messages, agent_card=mock_agent_card, config=config, result_store=mock_store, polling_timeout=5.0, polling_interval=0.5, endpoint=agent_card_url(mock_agent_card), ) assert result["status"] == TASK_STATE_FAILED assert "timeout" in result.get("error", "").lower() @pytest.mark.asyncio async def test_push_handler_requires_config( self, mock_agent_card: AgentCard, ) -> None: """Test that push handler fails gracefully without config.""" from unittest.mock import MagicMock from crewai.a2a.updates.push_notifications.handler import PushNotificationHandler mock_client = MagicMock() test_msg = new_text_message("test", role=ROLE_USER) new_messages: list[Message] = [] result = await PushNotificationHandler.execute( client=mock_client, message=test_msg, new_messages=new_messages, agent_card=mock_agent_card, endpoint=agent_card_url(mock_agent_card), ) assert result["status"] == TASK_STATE_FAILED assert "config" in result.get("error", "").lower()