mirror of
https://github.com/crewAIInc/crewAI.git
synced 2026-01-15 11:08:33 +00:00
370 lines
12 KiB
Python
370 lines
12 KiB
Python
"""Tests for A2A task utilities."""
|
|
|
|
from __future__ import annotations
|
|
|
|
import asyncio
|
|
from typing import Any
|
|
from unittest.mock import AsyncMock, MagicMock, patch
|
|
|
|
import pytest
|
|
import pytest_asyncio
|
|
from a2a.server.agent_execution import RequestContext
|
|
from a2a.server.events import EventQueue
|
|
from a2a.types import Message, Task as A2ATask, TaskState, TaskStatus
|
|
|
|
from crewai.a2a.utils.task import cancel, cancellable, execute
|
|
|
|
|
|
@pytest.fixture
|
|
def mock_agent() -> MagicMock:
|
|
"""Create a mock CrewAI agent."""
|
|
agent = MagicMock()
|
|
agent.role = "Test Agent"
|
|
agent.tools = []
|
|
agent.aexecute_task = AsyncMock(return_value="Task completed successfully")
|
|
return agent
|
|
|
|
|
|
@pytest.fixture
|
|
def mock_task() -> MagicMock:
|
|
"""Create a mock Task."""
|
|
return MagicMock()
|
|
|
|
|
|
@pytest.fixture
|
|
def mock_context() -> MagicMock:
|
|
"""Create a mock RequestContext."""
|
|
context = MagicMock(spec=RequestContext)
|
|
context.task_id = "test-task-123"
|
|
context.context_id = "test-context-456"
|
|
context.get_user_input.return_value = "Test user message"
|
|
context.message = MagicMock(spec=Message)
|
|
context.current_task = None
|
|
return context
|
|
|
|
|
|
@pytest.fixture
|
|
def mock_event_queue() -> AsyncMock:
|
|
"""Create a mock EventQueue."""
|
|
queue = AsyncMock(spec=EventQueue)
|
|
queue.enqueue_event = AsyncMock()
|
|
return queue
|
|
|
|
|
|
@pytest_asyncio.fixture(autouse=True)
|
|
async def clear_cache(mock_context: MagicMock) -> None:
|
|
"""Clear cancel flag from cache before each test."""
|
|
from aiocache import caches
|
|
|
|
cache = caches.get("default")
|
|
await cache.delete(f"cancel:{mock_context.task_id}")
|
|
|
|
|
|
class TestCancellableDecorator:
|
|
"""Tests for the cancellable decorator."""
|
|
|
|
@pytest.mark.asyncio
|
|
async def test_executes_function_without_context(self) -> None:
|
|
"""Function executes normally when no RequestContext is provided."""
|
|
call_count = 0
|
|
|
|
@cancellable
|
|
async def my_func(value: int) -> int:
|
|
nonlocal call_count
|
|
call_count += 1
|
|
return value * 2
|
|
|
|
result = await my_func(5)
|
|
|
|
assert result == 10
|
|
assert call_count == 1
|
|
|
|
@pytest.mark.asyncio
|
|
async def test_executes_function_with_context(self, mock_context: MagicMock) -> None:
|
|
"""Function executes normally with RequestContext when not cancelled."""
|
|
@cancellable
|
|
async def my_func(context: RequestContext) -> str:
|
|
await asyncio.sleep(0.01)
|
|
return "completed"
|
|
|
|
result = await my_func(mock_context)
|
|
|
|
assert result == "completed"
|
|
|
|
@pytest.mark.asyncio
|
|
async def test_cancellation_raises_cancelled_error(
|
|
self, mock_context: MagicMock
|
|
) -> None:
|
|
"""Function raises CancelledError when cancel flag is set."""
|
|
from aiocache import caches
|
|
|
|
cache = caches.get("default")
|
|
|
|
@cancellable
|
|
async def slow_func(context: RequestContext) -> str:
|
|
await asyncio.sleep(1.0)
|
|
return "should not reach"
|
|
|
|
await cache.set(f"cancel:{mock_context.task_id}", True)
|
|
|
|
with pytest.raises(asyncio.CancelledError):
|
|
await slow_func(mock_context)
|
|
|
|
@pytest.mark.asyncio
|
|
async def test_cleanup_removes_cancel_flag(self, mock_context: MagicMock) -> None:
|
|
"""Cancel flag is cleaned up after execution."""
|
|
from aiocache import caches
|
|
|
|
cache = caches.get("default")
|
|
|
|
@cancellable
|
|
async def quick_func(context: RequestContext) -> str:
|
|
return "done"
|
|
|
|
await quick_func(mock_context)
|
|
|
|
flag = await cache.get(f"cancel:{mock_context.task_id}")
|
|
assert flag is None
|
|
|
|
@pytest.mark.asyncio
|
|
async def test_extracts_context_from_kwargs(self, mock_context: MagicMock) -> None:
|
|
"""Context can be passed as keyword argument."""
|
|
@cancellable
|
|
async def my_func(value: int, context: RequestContext | None = None) -> int:
|
|
return value + 1
|
|
|
|
result = await my_func(10, context=mock_context)
|
|
|
|
assert result == 11
|
|
|
|
|
|
class TestExecute:
|
|
"""Tests for the execute function."""
|
|
|
|
@pytest.mark.asyncio
|
|
async def test_successful_execution(
|
|
self,
|
|
mock_agent: MagicMock,
|
|
mock_context: MagicMock,
|
|
mock_event_queue: AsyncMock,
|
|
mock_task: MagicMock,
|
|
) -> None:
|
|
"""Execute completes successfully and enqueues completed task."""
|
|
with (
|
|
patch("crewai.a2a.utils.task.Task", return_value=mock_task),
|
|
patch("crewai.a2a.utils.task.crewai_event_bus") as mock_bus,
|
|
):
|
|
await execute(mock_agent, mock_context, mock_event_queue)
|
|
|
|
mock_agent.aexecute_task.assert_called_once()
|
|
mock_event_queue.enqueue_event.assert_called_once()
|
|
assert mock_bus.emit.call_count == 2
|
|
|
|
@pytest.mark.asyncio
|
|
async def test_emits_started_event(
|
|
self,
|
|
mock_agent: MagicMock,
|
|
mock_context: MagicMock,
|
|
mock_event_queue: AsyncMock,
|
|
mock_task: MagicMock,
|
|
) -> None:
|
|
"""Execute emits A2AServerTaskStartedEvent."""
|
|
with (
|
|
patch("crewai.a2a.utils.task.Task", return_value=mock_task),
|
|
patch("crewai.a2a.utils.task.crewai_event_bus") as mock_bus,
|
|
):
|
|
await execute(mock_agent, mock_context, mock_event_queue)
|
|
|
|
first_call = mock_bus.emit.call_args_list[0]
|
|
event = first_call[0][1]
|
|
|
|
assert event.type == "a2a_server_task_started"
|
|
assert event.a2a_task_id == mock_context.task_id
|
|
assert event.a2a_context_id == mock_context.context_id
|
|
|
|
@pytest.mark.asyncio
|
|
async def test_emits_completed_event(
|
|
self,
|
|
mock_agent: MagicMock,
|
|
mock_context: MagicMock,
|
|
mock_event_queue: AsyncMock,
|
|
mock_task: MagicMock,
|
|
) -> None:
|
|
"""Execute emits A2AServerTaskCompletedEvent on success."""
|
|
with (
|
|
patch("crewai.a2a.utils.task.Task", return_value=mock_task),
|
|
patch("crewai.a2a.utils.task.crewai_event_bus") as mock_bus,
|
|
):
|
|
await execute(mock_agent, mock_context, mock_event_queue)
|
|
|
|
second_call = mock_bus.emit.call_args_list[1]
|
|
event = second_call[0][1]
|
|
|
|
assert event.type == "a2a_server_task_completed"
|
|
assert event.a2a_task_id == mock_context.task_id
|
|
assert event.result == "Task completed successfully"
|
|
|
|
@pytest.mark.asyncio
|
|
async def test_emits_failed_event_on_exception(
|
|
self,
|
|
mock_agent: MagicMock,
|
|
mock_context: MagicMock,
|
|
mock_event_queue: AsyncMock,
|
|
mock_task: MagicMock,
|
|
) -> None:
|
|
"""Execute emits A2AServerTaskFailedEvent on exception."""
|
|
mock_agent.aexecute_task = AsyncMock(side_effect=ValueError("Test error"))
|
|
|
|
with (
|
|
patch("crewai.a2a.utils.task.Task", return_value=mock_task),
|
|
patch("crewai.a2a.utils.task.crewai_event_bus") as mock_bus,
|
|
):
|
|
with pytest.raises(Exception):
|
|
await execute(mock_agent, mock_context, mock_event_queue)
|
|
|
|
failed_call = mock_bus.emit.call_args_list[1]
|
|
event = failed_call[0][1]
|
|
|
|
assert event.type == "a2a_server_task_failed"
|
|
assert "Test error" in event.error
|
|
|
|
@pytest.mark.asyncio
|
|
async def test_emits_canceled_event_on_cancellation(
|
|
self,
|
|
mock_agent: MagicMock,
|
|
mock_context: MagicMock,
|
|
mock_event_queue: AsyncMock,
|
|
mock_task: MagicMock,
|
|
) -> None:
|
|
"""Execute emits A2AServerTaskCanceledEvent on CancelledError."""
|
|
mock_agent.aexecute_task = AsyncMock(side_effect=asyncio.CancelledError())
|
|
|
|
with (
|
|
patch("crewai.a2a.utils.task.Task", return_value=mock_task),
|
|
patch("crewai.a2a.utils.task.crewai_event_bus") as mock_bus,
|
|
):
|
|
with pytest.raises(asyncio.CancelledError):
|
|
await execute(mock_agent, mock_context, mock_event_queue)
|
|
|
|
canceled_call = mock_bus.emit.call_args_list[1]
|
|
event = canceled_call[0][1]
|
|
|
|
assert event.type == "a2a_server_task_canceled"
|
|
assert event.a2a_task_id == mock_context.task_id
|
|
|
|
|
|
class TestCancel:
|
|
"""Tests for the cancel function."""
|
|
|
|
@pytest.mark.asyncio
|
|
async def test_sets_cancel_flag_in_cache(
|
|
self,
|
|
mock_context: MagicMock,
|
|
mock_event_queue: AsyncMock,
|
|
) -> None:
|
|
"""Cancel sets the cancel flag in cache."""
|
|
from aiocache import caches
|
|
|
|
cache = caches.get("default")
|
|
|
|
await cancel(mock_context, mock_event_queue)
|
|
|
|
flag = await cache.get(f"cancel:{mock_context.task_id}")
|
|
assert flag is True
|
|
|
|
@pytest.mark.asyncio
|
|
async def test_enqueues_task_status_update_event(
|
|
self,
|
|
mock_context: MagicMock,
|
|
mock_event_queue: AsyncMock,
|
|
) -> None:
|
|
"""Cancel enqueues TaskStatusUpdateEvent with canceled state."""
|
|
await cancel(mock_context, mock_event_queue)
|
|
|
|
mock_event_queue.enqueue_event.assert_called_once()
|
|
event = mock_event_queue.enqueue_event.call_args[0][0]
|
|
|
|
assert event.task_id == mock_context.task_id
|
|
assert event.context_id == mock_context.context_id
|
|
assert event.status.state == TaskState.canceled
|
|
assert event.final is True
|
|
|
|
@pytest.mark.asyncio
|
|
async def test_returns_none_when_no_current_task(
|
|
self,
|
|
mock_context: MagicMock,
|
|
mock_event_queue: AsyncMock,
|
|
) -> None:
|
|
"""Cancel returns None when context has no current_task."""
|
|
mock_context.current_task = None
|
|
|
|
result = await cancel(mock_context, mock_event_queue)
|
|
|
|
assert result is None
|
|
|
|
@pytest.mark.asyncio
|
|
async def test_returns_updated_task_when_current_task_exists(
|
|
self,
|
|
mock_context: MagicMock,
|
|
mock_event_queue: AsyncMock,
|
|
) -> None:
|
|
"""Cancel returns updated task when context has current_task."""
|
|
current_task = MagicMock(spec=A2ATask)
|
|
current_task.status = TaskStatus(state=TaskState.working)
|
|
mock_context.current_task = current_task
|
|
|
|
result = await cancel(mock_context, mock_event_queue)
|
|
|
|
assert result is current_task
|
|
assert result.status.state == TaskState.canceled
|
|
|
|
@pytest.mark.asyncio
|
|
async def test_cleanup_after_cancel(
|
|
self,
|
|
mock_context: MagicMock,
|
|
mock_event_queue: AsyncMock,
|
|
) -> None:
|
|
"""Cancel flag persists for cancellable decorator to detect."""
|
|
from aiocache import caches
|
|
|
|
cache = caches.get("default")
|
|
|
|
await cancel(mock_context, mock_event_queue)
|
|
|
|
flag = await cache.get(f"cancel:{mock_context.task_id}")
|
|
assert flag is True
|
|
|
|
await cache.delete(f"cancel:{mock_context.task_id}")
|
|
|
|
|
|
class TestExecuteAndCancelIntegration:
|
|
"""Integration tests for execute and cancel working together."""
|
|
|
|
@pytest.mark.asyncio
|
|
async def test_cancel_stops_running_execute(
|
|
self,
|
|
mock_agent: MagicMock,
|
|
mock_context: MagicMock,
|
|
mock_event_queue: AsyncMock,
|
|
mock_task: MagicMock,
|
|
) -> None:
|
|
"""Calling cancel stops a running execute."""
|
|
async def slow_task(**kwargs: Any) -> str:
|
|
await asyncio.sleep(2.0)
|
|
return "should not complete"
|
|
|
|
mock_agent.aexecute_task = slow_task
|
|
|
|
with (
|
|
patch("crewai.a2a.utils.task.Task", return_value=mock_task),
|
|
patch("crewai.a2a.utils.task.crewai_event_bus"),
|
|
):
|
|
execute_task = asyncio.create_task(
|
|
execute(mock_agent, mock_context, mock_event_queue)
|
|
)
|
|
|
|
await asyncio.sleep(0.1)
|
|
await cancel(mock_context, mock_event_queue)
|
|
|
|
with pytest.raises(asyncio.CancelledError):
|
|
await execute_task |