mirror of
https://github.com/crewAIInc/crewAI.git
synced 2026-05-06 09:42:39 +00:00
- Add make_send_request() helper in _compat.py for v1.0 API - Update all handlers to wrap Message in SendMessageRequest - Fix ServerError(error=...) → ServerError(message) in task.py - Fix MessageToDict parameter name (always_print_fields_with_no_presence) - Update integration tests for v1.0 client API (A2ACardResolver, ClientFactory) - Fix test mocks to use real protobuf Message instead of MagicMock Co-Authored-By: João <joao@crewai.com>
332 lines
10 KiB
Python
332 lines
10 KiB
Python
from __future__ import annotations
|
|
|
|
import os
|
|
import uuid
|
|
|
|
import httpx
|
|
import pytest
|
|
import pytest_asyncio
|
|
|
|
from a2a.client import A2ACardResolver, ClientFactory
|
|
from a2a.types import AgentCapabilities, AgentCard, AgentInterface, Message, Part, Role, TaskState
|
|
|
|
from crewai.a2a._compat import (
|
|
ROLE_AGENT,
|
|
ROLE_USER,
|
|
TASK_STATE_COMPLETED,
|
|
TASK_STATE_FAILED,
|
|
agent_card_url,
|
|
make_send_request,
|
|
new_text_message,
|
|
new_text_part,
|
|
)
|
|
from crewai.a2a.updates.polling.handler import PollingHandler
|
|
from crewai.a2a.updates.streaming.handler import StreamingHandler
|
|
|
|
|
|
# Base URL of the A2A server under test; override with the A2A_TEST_ENDPOINT
# environment variable, otherwise the local default below is used.
A2A_TEST_ENDPOINT = os.environ.get("A2A_TEST_ENDPOINT", "http://localhost:9999")
|
|
|
|
|
|
@pytest_asyncio.fixture
async def card_resolver():
    """Yield an ``A2ACardResolver`` pointed at ``A2A_TEST_ENDPOINT``.

    The underlying ``httpx.AsyncClient`` stays open while the test runs and
    is closed when the fixture is torn down.
    """
    http_client = httpx.AsyncClient()
    async with http_client:
        yield A2ACardResolver(http_client, A2A_TEST_ENDPOINT)
|
|
|
|
|
|
@pytest_asyncio.fixture
async def agent_card(card_resolver) -> AgentCard:
    """Return the agent card obtained through the ``card_resolver`` fixture."""
    fetched_card = await card_resolver.get_agent_card()
    return fetched_card
|
|
|
|
|
|
@pytest_asyncio.fixture
async def a2a_client(agent_card):
    """Yield an A2A client built by a default ``ClientFactory`` for ``agent_card``."""
    yield ClientFactory().create(agent_card)
|
|
|
|
|
|
@pytest.fixture
def test_message() -> Message:
    """Return a user-role text message asking a simple arithmetic question."""
    prompt = "What is 2 + 2?"
    return new_text_message(prompt, role=ROLE_USER)
|
|
|
|
|
|
class TestA2AAgentCardFetching:
    """Integration tests covering agent card retrieval."""

    @pytest.mark.vcr()
    @pytest.mark.asyncio
    async def test_fetch_agent_card(self, card_resolver) -> None:
        """Fetch the card and check its name, interfaces and streaming flag."""
        fetched = await card_resolver.get_agent_card()

        assert fetched is not None
        assert fetched.name == "GPT Assistant"
        assert fetched.supported_interfaces is not None

        capabilities = fetched.capabilities
        assert capabilities is not None
        assert capabilities.streaming is True
|
|
|
|
|
|
class TestA2APollingIntegration:
    """Integration tests exercising ``PollingHandler``."""

    @pytest.mark.vcr()
    @pytest.mark.asyncio
    async def test_polling_completes_task(
        self,
        a2a_client,
        test_message: Message,
        agent_card: AgentCard,
    ) -> None:
        """Run the polling handler and expect a completed result containing "4"."""
        collected: list[Message] = []

        outcome = await PollingHandler.execute(
            client=a2a_client,
            message=test_message,
            new_messages=collected,
            agent_card=agent_card,
            polling_interval=0.5,
            polling_timeout=30.0,
        )

        assert isinstance(outcome, dict)
        assert outcome["status"] == TASK_STATE_COMPLETED

        answer = outcome.get("result")
        assert answer is not None
        assert "4" in answer
|
|
|
|
|
|
class TestA2AStreamingIntegration:
    """Integration tests exercising ``StreamingHandler``."""

    @pytest.mark.vcr()
    @pytest.mark.asyncio
    async def test_streaming_completes_task(
        self,
        a2a_client,
        test_message: Message,
        agent_card: AgentCard,
    ) -> None:
        """Run the streaming handler and expect a completed, non-empty result."""
        collected: list[Message] = []
        target_url = agent_card_url(agent_card)

        outcome = await StreamingHandler.execute(
            client=a2a_client,
            message=test_message,
            new_messages=collected,
            agent_card=agent_card,
            endpoint=target_url,
        )

        assert isinstance(outcome, dict)
        assert outcome["status"] == TASK_STATE_COMPLETED
        assert outcome.get("result") is not None
|
|
|
|
|
|
class TestA2ATaskOperations:
    """Integration tests for direct task operations on the client."""

    @pytest.mark.vcr()
    @pytest.mark.asyncio
    async def test_send_message_and_get_response(
        self,
        a2a_client,
        test_message: Message,
    ) -> None:
        """Send a message and check that the last streamed task is completed."""
        from a2a.types import StreamResponse, Task

        from crewai.a2a._compat import is_stream_task

        request = make_send_request(test_message)

        # Keep only the most recent task-bearing event from the stream.
        last_task: Task | None = None
        async for event in a2a_client.send_message(request):
            if not isinstance(event, StreamResponse):
                continue
            if is_stream_task(event):
                last_task = event.task

        assert last_task is not None
        assert last_task.id != ""

        status = last_task.status
        assert status is not None
        assert status.state == TaskState.TASK_STATE_COMPLETED
|
|
|
|
|
|
class TestA2APushNotificationHandler:
    """Tests for the push notification handler.

    The result store is mocked in these tests: webhook callbacks are incoming
    requests, which can't be recorded with VCR.
    """

    @pytest.fixture
    def mock_agent_card(self) -> AgentCard:
        """Build a small agent card exposing one JSONRPC interface."""
        jsonrpc_interface = AgentInterface(
            url="http://localhost:9999",
            protocol_binding="JSONRPC",
        )
        capabilities = AgentCapabilities(streaming=True, push_notifications=True)
        return AgentCard(
            name="Test Agent",
            description="Test agent for push notification tests",
            supported_interfaces=[jsonrpc_interface],
            version="1.0.0",
            capabilities=capabilities,
            default_input_modes=["text"],
            default_output_modes=["text"],
        )

    @pytest.fixture
    def mock_task(self) -> "Task":
        """Build a task in the WORKING state with fixed task and context ids."""
        from a2a.types import Task, TaskStatus

        working_status = TaskStatus(state=TaskState.TASK_STATE_WORKING)
        return Task(id="task-123", context_id="ctx-123", status=working_status)

    @pytest.mark.asyncio
    async def test_push_handler_waits_for_result(
        self,
        mock_agent_card: AgentCard,
        mock_task,
    ) -> None:
        """The handler awaits the result store and reports the completed state."""
        from unittest.mock import AsyncMock, MagicMock

        from a2a.types import StreamResponse, Task, TaskStatus
        from pydantic import AnyHttpUrl

        from crewai.a2a.updates.push_notifications.config import PushNotificationConfig
        from crewai.a2a.updates.push_notifications.handler import PushNotificationHandler

        finished_task = Task(
            id="task-123",
            context_id="ctx-123",
            status=TaskStatus(state=TaskState.TASK_STATE_COMPLETED),
        )

        # The store hands back the completed task as soon as it is awaited.
        store = MagicMock(wait_for_result=AsyncMock(return_value=finished_task))

        async def fake_send_message(*args, **kwargs):
            # Single event carrying the still-working task from the fixture.
            yield StreamResponse(task=mock_task)

        client = MagicMock(send_message=fake_send_message)

        push_config = PushNotificationConfig(
            url=AnyHttpUrl("http://localhost:8080/a2a/callback"),
            token="secret-token",
            result_store=store,
        )

        outgoing = new_text_message("What is 2+2?", role=ROLE_USER)
        collected: list[Message] = []

        outcome = await PushNotificationHandler.execute(
            client=client,
            message=outgoing,
            new_messages=collected,
            agent_card=mock_agent_card,
            config=push_config,
            result_store=store,
            polling_timeout=30.0,
            polling_interval=1.0,
            endpoint=agent_card_url(mock_agent_card),
        )

        store.wait_for_result.assert_called_once_with(
            task_id="task-123",
            timeout=30.0,
            poll_interval=1.0,
        )
        assert outcome["status"] == TASK_STATE_COMPLETED

    @pytest.mark.asyncio
    async def test_push_handler_returns_failure_on_timeout(
        self,
        mock_agent_card: AgentCard,
    ) -> None:
        """A store that yields no result produces a failed status mentioning timeout."""
        from unittest.mock import AsyncMock, MagicMock

        from a2a.types import StreamResponse, Task, TaskStatus
        from pydantic import AnyHttpUrl

        from crewai.a2a.updates.push_notifications.config import PushNotificationConfig
        from crewai.a2a.updates.push_notifications.handler import PushNotificationHandler

        # Returning None from the store stands in for "no result arrived".
        store = MagicMock(wait_for_result=AsyncMock(return_value=None))

        pending_task = Task(
            id="task-456",
            context_id="ctx-456",
            status=TaskStatus(state=TaskState.TASK_STATE_WORKING),
        )

        async def fake_send_message(*args, **kwargs):
            yield StreamResponse(task=pending_task)

        client = MagicMock(send_message=fake_send_message)

        push_config = PushNotificationConfig(
            url=AnyHttpUrl("http://localhost:8080/a2a/callback"),
            token="token",
            result_store=store,
        )

        outgoing = new_text_message("test", role=ROLE_USER)
        collected: list[Message] = []

        outcome = await PushNotificationHandler.execute(
            client=client,
            message=outgoing,
            new_messages=collected,
            agent_card=mock_agent_card,
            config=push_config,
            result_store=store,
            polling_timeout=5.0,
            polling_interval=0.5,
            endpoint=agent_card_url(mock_agent_card),
        )

        assert outcome["status"] == TASK_STATE_FAILED
        error_text = outcome.get("error", "")
        assert "timeout" in error_text.lower()

    @pytest.mark.asyncio
    async def test_push_handler_requires_config(
        self,
        mock_agent_card: AgentCard,
    ) -> None:
        """Calling the handler without a config yields a failed status mentioning config."""
        from unittest.mock import MagicMock

        from crewai.a2a.updates.push_notifications.handler import PushNotificationHandler

        client = MagicMock()
        outgoing = new_text_message("test", role=ROLE_USER)
        collected: list[Message] = []

        outcome = await PushNotificationHandler.execute(
            client=client,
            message=outgoing,
            new_messages=collected,
            agent_card=mock_agent_card,
            endpoint=agent_card_url(mock_agent_card),
        )

        assert outcome["status"] == TASK_STATE_FAILED
        error_text = outcome.get("error", "")
        assert "config" in error_text.lower()
|