mirror of
https://github.com/crewAIInc/crewAI.git
synced 2026-01-24 07:38:14 +00:00
Add A2A protocol support for CrewAI (Issue #2796)
Co-Authored-By: Joe Moura <joao@crewai.com>
This commit is contained in:
230
tests/a2a/test_a2a_integration.py
Normal file
230
tests/a2a/test_a2a_integration.py
Normal file
@@ -0,0 +1,230 @@
|
||||
"""Tests for the A2A protocol integration."""
|
||||
|
||||
import asyncio
|
||||
import pytest
|
||||
from unittest.mock import AsyncMock, MagicMock, patch
|
||||
|
||||
from crewai.agent import Agent
|
||||
from crewai.a2a import A2AAgentIntegration, A2AClient, A2AServer, InMemoryTaskManager
|
||||
from crewai.task import Task
|
||||
from crewai.types.a2a import (
|
||||
Message,
|
||||
Task as A2ATask,
|
||||
TaskState,
|
||||
TaskStatus,
|
||||
TaskStatusUpdateEvent,
|
||||
TextPart,
|
||||
)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def agent():
|
||||
"""Create an agent with A2A enabled."""
|
||||
return Agent(
|
||||
role="test_agent",
|
||||
goal="Test A2A protocol",
|
||||
backstory="I am a test agent",
|
||||
a2a_enabled=True,
|
||||
a2a_url="http://localhost:8000",
|
||||
)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def task():
|
||||
"""Create a task."""
|
||||
return Task(
|
||||
description="Test task",
|
||||
)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def a2a_task():
|
||||
"""Create an A2A task."""
|
||||
return A2ATask(
|
||||
id="test_task_id",
|
||||
history=[
|
||||
Message(
|
||||
role="user",
|
||||
parts=[TextPart(text="Test task description")],
|
||||
)
|
||||
],
|
||||
)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def a2a_integration():
|
||||
"""Create an A2A integration."""
|
||||
return A2AAgentIntegration()
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def a2a_client():
|
||||
"""Create an A2A client."""
|
||||
return A2AClient(base_url="http://localhost:8000")
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def task_manager():
|
||||
"""Create a task manager."""
|
||||
return InMemoryTaskManager()
|
||||
|
||||
|
||||
class TestA2AIntegration:
|
||||
"""Tests for the A2A protocol integration."""
|
||||
|
||||
def test_agent_a2a_attributes(self, agent):
|
||||
"""Test that the agent has A2A attributes."""
|
||||
assert agent.a2a_enabled is True
|
||||
assert agent.a2a_url == "http://localhost:8000"
|
||||
assert agent._a2a_integration is not None
|
||||
|
||||
@patch("crewai.a2a.agent.A2AAgentIntegration.execute_task_via_a2a")
|
||||
def test_execute_task_via_a2a(self, mock_execute, agent):
|
||||
"""Test executing a task via A2A."""
|
||||
mock_execute.return_value = asyncio.Future()
|
||||
mock_execute.return_value.set_result("Task result")
|
||||
|
||||
result = asyncio.run(
|
||||
agent.execute_task_via_a2a(
|
||||
task_description="Test task",
|
||||
context="Test context",
|
||||
)
|
||||
)
|
||||
|
||||
assert result == "Task result"
|
||||
mock_execute.assert_called_once_with(
|
||||
agent_url="http://localhost:8000",
|
||||
task_description="Test task",
|
||||
context="Test context",
|
||||
api_key=None,
|
||||
timeout=300,
|
||||
)
|
||||
|
||||
@patch("crewai.agent.Agent.execute_task")
|
||||
def test_handle_a2a_task(self, mock_execute, agent):
|
||||
"""Test handling an A2A task."""
|
||||
mock_execute.return_value = "Task result"
|
||||
|
||||
result = asyncio.run(
|
||||
agent.handle_a2a_task(
|
||||
task_id="test_task_id",
|
||||
task_description="Test task",
|
||||
context="Test context",
|
||||
)
|
||||
)
|
||||
|
||||
assert result == "Task result"
|
||||
mock_execute.assert_called_once()
|
||||
args, kwargs = mock_execute.call_args
|
||||
assert args[0].description == "Test task"
|
||||
assert kwargs["context"] == "Test context"
|
||||
|
||||
def test_a2a_disabled(self, agent):
|
||||
"""Test that A2A methods raise ValueError when A2A is disabled."""
|
||||
agent.a2a_enabled = False
|
||||
|
||||
with pytest.raises(ValueError, match="A2A protocol is not enabled for this agent"):
|
||||
asyncio.run(
|
||||
agent.execute_task_via_a2a(
|
||||
task_description="Test task",
|
||||
)
|
||||
)
|
||||
|
||||
with pytest.raises(ValueError, match="A2A protocol is not enabled for this agent"):
|
||||
asyncio.run(
|
||||
agent.handle_a2a_task(
|
||||
task_id="test_task_id",
|
||||
task_description="Test task",
|
||||
)
|
||||
)
|
||||
|
||||
def test_no_agent_url(self, agent):
|
||||
"""Test that execute_task_via_a2a raises ValueError when no agent URL is provided."""
|
||||
agent.a2a_url = None
|
||||
|
||||
with pytest.raises(ValueError, match="No A2A agent URL provided"):
|
||||
asyncio.run(
|
||||
agent.execute_task_via_a2a(
|
||||
task_description="Test task",
|
||||
)
|
||||
)
|
||||
|
||||
|
||||
class TestA2AAgentIntegration:
|
||||
"""Tests for the A2AAgentIntegration class."""
|
||||
|
||||
@patch("crewai.a2a.client.A2AClient.send_task_streaming")
|
||||
async def test_execute_task_via_a2a(self, mock_send_task, a2a_integration):
|
||||
"""Test executing a task via A2A."""
|
||||
queue = asyncio.Queue()
|
||||
await queue.put(
|
||||
TaskStatusUpdateEvent(
|
||||
task_id="test_task_id",
|
||||
status=TaskStatus(
|
||||
state=TaskState.COMPLETED,
|
||||
message=Message(
|
||||
role="agent",
|
||||
parts=[TextPart(text="Task result")],
|
||||
),
|
||||
),
|
||||
final=True,
|
||||
)
|
||||
)
|
||||
|
||||
mock_send_task.return_value = queue
|
||||
|
||||
result = await a2a_integration.execute_task_via_a2a(
|
||||
agent_url="http://localhost:8000",
|
||||
task_description="Test task",
|
||||
context="Test context",
|
||||
)
|
||||
|
||||
assert result == "Task result"
|
||||
mock_send_task.assert_called_once()
|
||||
|
||||
|
||||
class TestA2AServer:
|
||||
"""Tests for the A2AServer class."""
|
||||
|
||||
@patch("fastapi.FastAPI.post")
|
||||
def test_server_initialization(self, mock_post, task_manager):
|
||||
"""Test server initialization."""
|
||||
server = A2AServer(task_manager=task_manager)
|
||||
assert server.task_manager == task_manager
|
||||
assert server.app is not None
|
||||
assert mock_post.call_count == 4 # 4 endpoints registered
|
||||
|
||||
|
||||
class TestA2AClient:
|
||||
"""Tests for the A2AClient class."""
|
||||
|
||||
@patch("aiohttp.ClientSession.post")
|
||||
async def test_send_task(self, mock_post, a2a_client):
|
||||
"""Test sending a task."""
|
||||
mock_response = MagicMock()
|
||||
mock_response.status = 200
|
||||
mock_response.json = AsyncMock(
|
||||
return_value={
|
||||
"id": "test_task_id",
|
||||
"history": [
|
||||
{
|
||||
"role": "user",
|
||||
"parts": [{"text": "Test task description"}],
|
||||
}
|
||||
],
|
||||
}
|
||||
)
|
||||
mock_post.return_value.__aenter__.return_value = mock_response
|
||||
|
||||
task = await a2a_client.send_task(
|
||||
task_id="test_task_id",
|
||||
message=Message(
|
||||
role="user",
|
||||
parts=[TextPart(text="Test task description")],
|
||||
),
|
||||
)
|
||||
|
||||
assert task.id == "test_task_id"
|
||||
assert task.history[0].role == "user"
|
||||
assert task.history[0].parts[0].text == "Test task description"
|
||||
mock_post.assert_called_once()
|
||||
Reference in New Issue
Block a user