mirror of
https://github.com/crewAIInc/crewAI.git
synced 2026-05-06 17:52:35 +00:00
Fix #5607: CrewAI 1.14.2 is incompatible with a2a-sdk v1.0.1+

Breaking changes in a2a-sdk v1.0:
- A2AClientHTTPError renamed to A2AClientError
- Protobuf-based types replace Pydantic models
- Enum values changed to SCREAMING_SNAKE_CASE
- TextPart/DataPart/FilePart removed (Part uses oneof)
- AgentCard.url removed (use supported_interfaces)
- StreamResponse wraps all event types
- model_dump/model_copy replaced with protobuf serialization

Changes:
- Add _compat.py: centralized compatibility layer with helpers
- Update pyproject.toml: a2a-sdk>=1.0.0,<2
- Update all a2a module files to use protobuf API
- Update existing tests for v1.0 patterns
- Add comprehensive test_a2a_sdk_v1_compat.py (46 tests)

Co-Authored-By: João <joao@crewai.com>
322 lines
9.7 KiB
Python
322 lines
9.7 KiB
Python
from __future__ import annotations

import os
import uuid
from unittest.mock import AsyncMock, MagicMock

import pytest
import pytest_asyncio

from a2a.client import ClientFactory
from a2a.types import (
    AgentCapabilities,
    AgentCard,
    AgentInterface,
    Message,
    Part,
    Role,
    StreamResponse,
    Task,
    TaskState,
    TaskStatus,
)
from pydantic import AnyHttpUrl

from crewai.a2a._compat import (
    ROLE_AGENT,
    ROLE_USER,
    TASK_STATE_COMPLETED,
    TASK_STATE_FAILED,
    agent_card_url,
    is_stream_task,
    new_text_message,
    new_text_part,
)
from crewai.a2a.updates.polling.handler import PollingHandler
from crewai.a2a.updates.streaming.handler import StreamingHandler
|
|
|
|
|
|
# Base URL of the A2A test server; override via the A2A_TEST_ENDPOINT env var.
A2A_TEST_ENDPOINT = os.environ.get("A2A_TEST_ENDPOINT", "http://localhost:9999")
|
|
|
|
|
|
@pytest_asyncio.fixture
async def a2a_client():
    """Yield an A2A client connected to the test server and close it on teardown."""
    connected = await ClientFactory.connect(A2A_TEST_ENDPOINT)
    yield connected
    # Teardown: release the client's underlying connection.
    await connected.close()
|
|
|
|
|
|
@pytest.fixture
def test_message() -> Message:
    """Build the user-role text message that the integration tests send."""
    prompt = "What is 2 + 2?"
    return new_text_message(prompt, role=ROLE_USER)
|
|
|
|
|
|
@pytest_asyncio.fixture
async def agent_card(a2a_client) -> AgentCard:
    """Return the agent card advertised by the live test server."""
    card = await a2a_client.get_card()
    return card
|
|
|
|
|
|
class TestA2AAgentCardFetching:
    """Integration tests for agent card fetching."""

    @pytest.mark.vcr()
    @pytest.mark.asyncio
    async def test_fetch_agent_card(self, a2a_client) -> None:
        """Test fetching an agent card from the server."""
        card = await a2a_client.get_card()

        assert card is not None
        assert card.name == "GPT Assistant"
        # Repeated protobuf fields are never None, so ``is not None`` would
        # always pass; require at least one advertised interface instead.
        assert len(card.supported_interfaces) > 0
        # Singular protobuf sub-messages return a default instance when unset,
        # so check field presence explicitly rather than comparing to None.
        assert card.HasField("capabilities")
        assert card.capabilities.streaming is True
|
|
|
|
|
|
class TestA2APollingIntegration:
    """Integration tests exercising ``PollingHandler`` against the A2A test server."""

    @pytest.mark.vcr()
    @pytest.mark.asyncio
    async def test_polling_completes_task(
        self,
        a2a_client,
        test_message: Message,
        agent_card: AgentCard,
    ) -> None:
        """The polling handler should drive the task to completion and return the answer."""
        collected: list[Message] = []

        outcome = await PollingHandler.execute(
            client=a2a_client,
            message=test_message,
            new_messages=collected,
            agent_card=agent_card,
            polling_interval=0.5,
            polling_timeout=30.0,
        )

        assert isinstance(outcome, dict)
        assert outcome["status"] == TASK_STATE_COMPLETED
        answer = outcome.get("result")
        assert answer is not None
        assert "4" in answer
|
|
|
|
|
|
class TestA2AStreamingIntegration:
    """Integration tests exercising ``StreamingHandler`` against the A2A test server."""

    @pytest.mark.vcr()
    @pytest.mark.asyncio
    async def test_streaming_completes_task(
        self,
        a2a_client,
        test_message: Message,
        agent_card: AgentCard,
    ) -> None:
        """The streaming handler should finish the task and report a result."""
        collected: list[Message] = []
        target_url = agent_card_url(agent_card)

        outcome = await StreamingHandler.execute(
            client=a2a_client,
            message=test_message,
            new_messages=collected,
            agent_card=agent_card,
            endpoint=target_url,
        )

        assert isinstance(outcome, dict)
        assert outcome["status"] == TASK_STATE_COMPLETED
        assert outcome.get("result") is not None
|
|
|
|
|
|
class TestA2ATaskOperations:
    """Integration tests for task operations."""

    @pytest.mark.vcr()
    @pytest.mark.asyncio
    async def test_send_message_and_get_response(
        self,
        a2a_client,
        test_message: Message,
    ) -> None:
        """Test sending a message and getting a response.

        Consumes the whole event stream and keeps the last ``StreamResponse``
        that carries a task payload, then checks that task completed.
        """
        final_task: Task | None = None
        async for event in a2a_client.send_message(test_message):
            if isinstance(event, StreamResponse) and is_stream_task(event):
                final_task = event.task

        assert final_task is not None
        assert final_task.id != ""
        # Protobuf sub-message access never yields None (an unset field returns
        # a default instance), so assert presence explicitly.
        assert final_task.HasField("status")
        assert final_task.status.state == TaskState.TASK_STATE_COMPLETED
|
|
|
|
|
|
class TestA2APushNotificationHandler:
    """Tests for push notification handler.

    These tests use mocks for the result store since webhook callbacks
    are incoming requests that can't be recorded with VCR.
    """

    @staticmethod
    def _streaming_client(task: Task) -> MagicMock:
        """Build a mock A2A client whose ``send_message`` streams ``task`` once.

        ``send_message`` is replaced by a real async generator function so the
        handler can iterate it with ``async for``, as it would a real client.
        """

        async def send_message(*args, **kwargs):
            yield StreamResponse(task=task)

        client = MagicMock()
        client.send_message = send_message
        return client

    @pytest.fixture
    def mock_agent_card(self) -> AgentCard:
        """Create a minimal valid agent card for testing."""
        return AgentCard(
            name="Test Agent",
            description="Test agent for push notification tests",
            supported_interfaces=[
                AgentInterface(
                    url="http://localhost:9999",
                    protocol_binding="JSONRPC",
                ),
            ],
            version="1.0.0",
            capabilities=AgentCapabilities(streaming=True, push_notifications=True),
            default_input_modes=["text"],
            default_output_modes=["text"],
        )

    @pytest.fixture
    def mock_task(self) -> Task:
        """Create a minimal valid task in the WORKING state for testing."""
        return Task(
            id="task-123",
            context_id="ctx-123",
            status=TaskStatus(state=TaskState.TASK_STATE_WORKING),
        )

    @pytest.mark.asyncio
    async def test_push_handler_waits_for_result(
        self,
        mock_agent_card: AgentCard,
        mock_task: Task,
    ) -> None:
        """Test that push handler waits for result from store."""
        from crewai.a2a.updates.push_notifications.config import PushNotificationConfig
        from crewai.a2a.updates.push_notifications.handler import PushNotificationHandler

        completed_task = Task(
            id="task-123",
            context_id="ctx-123",
            status=TaskStatus(state=TaskState.TASK_STATE_COMPLETED),
        )

        mock_store = MagicMock()
        mock_store.wait_for_result = AsyncMock(return_value=completed_task)

        config = PushNotificationConfig(
            url=AnyHttpUrl("http://localhost:8080/a2a/callback"),
            token="secret-token",
            result_store=mock_store,
        )

        test_msg = new_text_message("What is 2+2?", role=ROLE_USER)
        new_messages: list[Message] = []

        result = await PushNotificationHandler.execute(
            client=self._streaming_client(mock_task),
            message=test_msg,
            new_messages=new_messages,
            agent_card=mock_agent_card,
            config=config,
            result_store=mock_store,
            polling_timeout=30.0,
            polling_interval=1.0,
            endpoint=agent_card_url(mock_agent_card),
        )

        # assert_awaited_* (rather than assert_called_*) proves the coroutine
        # was actually awaited, not merely created and discarded.
        mock_store.wait_for_result.assert_awaited_once_with(
            task_id="task-123",
            timeout=30.0,
            poll_interval=1.0,
        )

        assert result["status"] == TASK_STATE_COMPLETED

    @pytest.mark.asyncio
    async def test_push_handler_returns_failure_on_timeout(
        self,
        mock_agent_card: AgentCard,
    ) -> None:
        """Test that push handler returns failure when result store times out."""
        from crewai.a2a.updates.push_notifications.config import PushNotificationConfig
        from crewai.a2a.updates.push_notifications.handler import PushNotificationHandler

        # A result store that yields no result simulates a timeout.
        mock_store = MagicMock()
        mock_store.wait_for_result = AsyncMock(return_value=None)

        working_task = Task(
            id="task-456",
            context_id="ctx-456",
            status=TaskStatus(state=TaskState.TASK_STATE_WORKING),
        )

        config = PushNotificationConfig(
            url=AnyHttpUrl("http://localhost:8080/a2a/callback"),
            token="token",
            result_store=mock_store,
        )

        test_msg = new_text_message("test", role=ROLE_USER)
        new_messages: list[Message] = []

        result = await PushNotificationHandler.execute(
            client=self._streaming_client(working_task),
            message=test_msg,
            new_messages=new_messages,
            agent_card=mock_agent_card,
            config=config,
            result_store=mock_store,
            polling_timeout=5.0,
            polling_interval=0.5,
            endpoint=agent_card_url(mock_agent_card),
        )

        # Ensure the failure came from the result-store path under test.
        mock_store.wait_for_result.assert_awaited_once_with(
            task_id="task-456",
            timeout=5.0,
            poll_interval=0.5,
        )

        assert result["status"] == TASK_STATE_FAILED
        # ``or ""`` also covers an explicit ``"error": None`` entry.
        error = result.get("error") or ""
        assert "timeout" in error.lower()

    @pytest.mark.asyncio
    async def test_push_handler_requires_config(
        self,
        mock_agent_card: AgentCard,
    ) -> None:
        """Test that push handler fails gracefully without config."""
        from crewai.a2a.updates.push_notifications.handler import PushNotificationHandler

        mock_client = MagicMock()
        test_msg = new_text_message("test", role=ROLE_USER)
        new_messages: list[Message] = []

        result = await PushNotificationHandler.execute(
            client=mock_client,
            message=test_msg,
            new_messages=new_messages,
            agent_card=mock_agent_card,
            endpoint=agent_card_url(mock_agent_card),
        )

        assert result["status"] == TASK_STATE_FAILED
        # ``or ""`` also covers an explicit ``"error": None`` entry.
        error = result.get("error") or ""
        assert "config" in error.lower()
|