# Source: crewAI/tests/a2a/test_a2a_integration.py
# Snapshot: 2025-05-09 04:13:04 +00:00 (231 lines, 6.6 KiB, Python)

"""Tests for the A2A protocol integration."""
import asyncio
import pytest
from unittest.mock import AsyncMock, MagicMock, patch
from crewai.agent import Agent
from crewai.a2a import A2AAgentIntegration, A2AClient, A2AServer, InMemoryTaskManager
from crewai.task import Task
from crewai.types.a2a import (
Message,
Task as A2ATask,
TaskState,
TaskStatus,
TaskStatusUpdateEvent,
TextPart,
)
@pytest.fixture
def agent():
    """Return an ``Agent`` with the A2A protocol switched on and a local A2A URL."""
    a2a_settings = {
        "a2a_enabled": True,
        "a2a_url": "http://localhost:8000",
    }
    return Agent(
        role="test_agent",
        goal="Test A2A protocol",
        backstory="I am a test agent",
        **a2a_settings,
    )
@pytest.fixture
def task():
    """Return a minimal crewAI ``Task`` that only carries a description."""
    description = "Test task"
    return Task(description=description)
@pytest.fixture
def a2a_task():
    """Return an A2A-protocol task whose history holds one user message."""
    user_message = Message(
        role="user",
        parts=[TextPart(text="Test task description")],
    )
    return A2ATask(id="test_task_id", history=[user_message])
@pytest.fixture
def a2a_integration():
    """Return a default-constructed ``A2AAgentIntegration``."""
    integration = A2AAgentIntegration()
    return integration
@pytest.fixture
def a2a_client():
    """Return an ``A2AClient`` pointed at a local server URL."""
    base_url = "http://localhost:8000"
    return A2AClient(base_url=base_url)
@pytest.fixture
def task_manager():
    """Return a freshly constructed ``InMemoryTaskManager``."""
    manager = InMemoryTaskManager()
    return manager
class TestA2AIntegration:
    """Tests for the A2A protocol integration exposed on ``Agent``."""

    def test_agent_a2a_attributes(self, agent):
        """Test that the agent has A2A attributes."""
        assert agent.a2a_enabled is True
        assert agent.a2a_url == "http://localhost:8000"
        assert agent._a2a_integration is not None

    # The patched integration method is awaited, so the mock must be an
    # AsyncMock whose ``return_value`` is the plain result. ``patch`` already
    # substitutes an AsyncMock for ``async def`` targets. In that case a
    # pre-resolved ``asyncio.Future`` used as ``return_value`` would itself be
    # the awaited result, not "Task result". Building such a Future outside a
    # running loop also relies on implicit event-loop lookup. Passing
    # ``new_callable`` explicitly keeps the test independent of auto-detection.
    @patch(
        "crewai.a2a.agent.A2AAgentIntegration.execute_task_via_a2a",
        new_callable=AsyncMock,
    )
    def test_execute_task_via_a2a(self, mock_execute, agent):
        """Test that the agent delegates to the integration with its URL and defaults."""
        mock_execute.return_value = "Task result"
        result = asyncio.run(
            agent.execute_task_via_a2a(
                task_description="Test task",
                context="Test context",
            )
        )
        assert result == "Task result"
        # The mock is patched onto the class and is not a descriptor, so
        # ``self`` does not appear among the recorded call arguments.
        mock_execute.assert_awaited_once_with(
            agent_url="http://localhost:8000",
            task_description="Test task",
            context="Test context",
            api_key=None,
            timeout=300,
        )

    @patch("crewai.agent.Agent.execute_task")
    def test_handle_a2a_task(self, mock_execute, agent):
        """Test handling an A2A task."""
        mock_execute.return_value = "Task result"
        result = asyncio.run(
            agent.handle_a2a_task(
                task_id="test_task_id",
                task_description="Test task",
                context="Test context",
            )
        )
        assert result == "Task result"
        mock_execute.assert_called_once()
        # The built crewAI Task is expected as the first positional argument,
        # with the context forwarded as a keyword.
        args, kwargs = mock_execute.call_args
        assert args[0].description == "Test task"
        assert kwargs["context"] == "Test context"

    def test_a2a_disabled(self, agent):
        """Test that A2A methods raise ValueError when A2A is disabled."""
        agent.a2a_enabled = False
        with pytest.raises(ValueError, match="A2A protocol is not enabled for this agent"):
            asyncio.run(
                agent.execute_task_via_a2a(
                    task_description="Test task",
                )
            )
        with pytest.raises(ValueError, match="A2A protocol is not enabled for this agent"):
            asyncio.run(
                agent.handle_a2a_task(
                    task_id="test_task_id",
                    task_description="Test task",
                )
            )

    def test_no_agent_url(self, agent):
        """Test that execute_task_via_a2a raises ValueError when no agent URL is provided."""
        agent.a2a_url = None
        with pytest.raises(ValueError, match="No A2A agent URL provided"):
            asyncio.run(
                agent.execute_task_via_a2a(
                    task_description="Test task",
                )
            )
class TestA2AAgentIntegration:
    """Tests for the A2AAgentIntegration class."""

    @patch("crewai.a2a.client.A2AClient.send_task_streaming")
    def test_execute_task_via_a2a(self, mock_send_task, a2a_integration):
        """Test that a final COMPLETED status update yields the agent's text as the result."""

        # Plain pytest does not execute ``async def`` test methods without an
        # asyncio plugin/marker; drive the coroutine with asyncio.run, as the
        # other tests in this file do.
        async def _run():
            # Create the queue inside the running loop so it is bound to the
            # loop that asyncio.run creates (relevant on Python < 3.10).
            queue = asyncio.Queue()
            await queue.put(
                TaskStatusUpdateEvent(
                    task_id="test_task_id",
                    status=TaskStatus(
                        state=TaskState.COMPLETED,
                        message=Message(
                            role="agent",
                            parts=[TextPart(text="Task result")],
                        ),
                    ),
                    final=True,
                )
            )
            mock_send_task.return_value = queue
            return await a2a_integration.execute_task_via_a2a(
                agent_url="http://localhost:8000",
                task_description="Test task",
                context="Test context",
            )

        result = asyncio.run(_run())
        assert result == "Task result"
        mock_send_task.assert_called_once()
class TestA2AServer:
    """Unit tests covering construction of :class:`A2AServer`."""

    @patch("fastapi.FastAPI.post")
    def test_server_initialization(self, mock_post, task_manager):
        """Test that constructing the server stores the manager, builds an app and registers routes."""
        # ``FastAPI.post`` is patched, so route registrations made through it
        # are counted on the mock; four endpoints are expected.
        expected_route_count = 4

        server = A2AServer(task_manager=task_manager)

        assert server.task_manager == task_manager
        assert server.app is not None
        assert mock_post.call_count == expected_route_count
class TestA2AClient:
    """Tests for the A2AClient class."""

    @patch("aiohttp.ClientSession.post")
    def test_send_task(self, mock_post, a2a_client):
        """Test that send_task turns a 200 JSON response into a task object."""
        mock_response = MagicMock()
        mock_response.status = 200
        mock_response.json = AsyncMock(
            return_value={
                "id": "test_task_id",
                "history": [
                    {
                        "role": "user",
                        "parts": [{"text": "Test task description"}],
                    }
                ],
            }
        )
        # This setup assumes the client uses ``async with session.post(...)``:
        # MagicMock supports ``__aenter__``, whose awaited value is the fake
        # response.
        mock_post.return_value.__aenter__.return_value = mock_response

        # Plain pytest does not execute ``async def`` test methods without an
        # asyncio plugin/marker; drive the coroutine with asyncio.run, as the
        # other tests in this file do.
        task = asyncio.run(
            a2a_client.send_task(
                task_id="test_task_id",
                message=Message(
                    role="user",
                    parts=[TextPart(text="Test task description")],
                ),
            )
        )

        assert task.id == "test_task_id"
        assert task.history[0].role == "user"
        assert task.history[0].parts[0].text == "Test task description"
        mock_post.assert_called_once()