diff --git a/lib/crewai/src/crewai/a2a/config.py b/lib/crewai/src/crewai/a2a/config.py index 1597ae821..1b8cd7d81 100644 --- a/lib/crewai/src/crewai/a2a/config.py +++ b/lib/crewai/src/crewai/a2a/config.py @@ -8,14 +8,6 @@ from __future__ import annotations from importlib.metadata import version from typing import Any, ClassVar, Literal -from a2a.types import ( - AgentCapabilities, - AgentCardSignature, - AgentInterface, - AgentProvider, - AgentSkill, - SecurityScheme, -) from pydantic import BaseModel, ConfigDict, Field from typing_extensions import deprecated @@ -24,8 +16,24 @@ from crewai.a2a.types import TransportType, Url try: + from a2a.types import ( + AgentCapabilities, + AgentCardSignature, + AgentInterface, + AgentProvider, + AgentSkill, + SecurityScheme, + ) + from crewai.a2a.updates import UpdateConfig except ImportError: + UpdateConfig = Any + AgentCapabilities = Any + AgentCardSignature = Any + AgentInterface = Any + AgentProvider = Any + SecurityScheme = Any + AgentSkill = Any UpdateConfig = Any # type: ignore[misc,assignment] diff --git a/lib/crewai/src/crewai/a2a/utils/response_model.py b/lib/crewai/src/crewai/a2a/utils/response_model.py index 44d8a5ba6..4e65ef2b7 100644 --- a/lib/crewai/src/crewai/a2a/utils/response_model.py +++ b/lib/crewai/src/crewai/a2a/utils/response_model.py @@ -14,15 +14,19 @@ A2AConfigTypes: TypeAlias = A2AConfig | A2AServerConfig | A2AClientConfig A2AClientConfigTypes: TypeAlias = A2AConfig | A2AClientConfig -def create_agent_response_model(agent_ids: tuple[str, ...]) -> type[BaseModel]: +def create_agent_response_model(agent_ids: tuple[str, ...]) -> type[BaseModel] | None: """Create a dynamic AgentResponse model with Literal types for agent IDs. Args: agent_ids: List of available A2A agent IDs. Returns: - Dynamically created Pydantic model with Literal-constrained a2a_ids field. + Dynamically created Pydantic model with Literal-constrained a2a_ids field, + or None if agent_ids is empty. """ + if not agent_ids: + return None + DynamicLiteral = create_literals_from_strings(agent_ids) # noqa: N806 return create_model( @@ -83,7 +87,7 @@ def extract_a2a_agent_ids_from_config( def get_a2a_agents_and_response_model( a2a_config: list[A2AConfigTypes] | A2AConfigTypes | None, -) -> tuple[list[A2AClientConfigTypes], type[BaseModel]]: +) -> tuple[list[A2AClientConfigTypes], type[BaseModel] | None]: """Get A2A agent configs and response model. Args: diff --git a/lib/crewai/src/crewai/a2a/utils/task.py b/lib/crewai/src/crewai/a2a/utils/task.py new file mode 100644 index 000000000..5669e7e4b --- /dev/null +++ b/lib/crewai/src/crewai/a2a/utils/task.py @@ -0,0 +1,284 @@ +"""A2A task utilities for server-side task management.""" + +from __future__ import annotations + +import asyncio +from collections.abc import Callable, Coroutine +from functools import wraps +import logging +import os +from typing import TYPE_CHECKING, Any, ParamSpec, TypeVar, cast + +from a2a.server.agent_execution import RequestContext +from a2a.server.events import EventQueue +from a2a.types import ( + InternalError, + InvalidParamsError, + Message, + Task as A2ATask, + TaskState, + TaskStatus, + TaskStatusUpdateEvent, +) +from a2a.utils import new_agent_text_message, new_text_artifact +from a2a.utils.errors import ServerError +from aiocache import SimpleMemoryCache, caches # type: ignore[import-untyped] + +from crewai.events.event_bus import crewai_event_bus +from crewai.events.types.a2a_events import ( + A2AServerTaskCanceledEvent, + A2AServerTaskCompletedEvent, + A2AServerTaskFailedEvent, + A2AServerTaskStartedEvent, +) +from crewai.task import Task + + +if TYPE_CHECKING: + from crewai.agent import Agent + + +logger = logging.getLogger(__name__) + +P = ParamSpec("P") +T = TypeVar("T") + + +def _parse_redis_url(url: str) -> dict[str, Any]: + from urllib.parse import urlparse + + parsed = urlparse(url) + config: dict[str, Any] = { + "cache": "aiocache.RedisCache", + "endpoint": parsed.hostname or "localhost", + "port": parsed.port or 6379, + } + if parsed.path and parsed.path != "/": + try: + config["db"] = int(parsed.path.lstrip("/")) + except ValueError: + pass + if parsed.password: + config["password"] = parsed.password + return config + + +_redis_url = os.environ.get("REDIS_URL") + +caches.set_config( + { + "default": _parse_redis_url(_redis_url) + if _redis_url + else { + "cache": "aiocache.SimpleMemoryCache", + } + } +) + + +def cancellable( + fn: Callable[P, Coroutine[Any, Any, T]], +) -> Callable[P, Coroutine[Any, Any, T]]: + """Decorator that enables cancellation for A2A task execution. + + Runs a cancellation watcher concurrently with the wrapped function. + When a cancel event is published, the execution is cancelled. + + Args: + fn: The async function to wrap. + + Returns: + Wrapped function with cancellation support. + """ + + @wraps(fn) + async def wrapper(*args: P.args, **kwargs: P.kwargs) -> T: + """Wrap function with cancellation monitoring.""" + context: RequestContext | None = None + for arg in args: + if isinstance(arg, RequestContext): + context = arg + break + if context is None: + context = cast(RequestContext | None, kwargs.get("context")) + + if context is None: + return await fn(*args, **kwargs) + + task_id = context.task_id + cache = caches.get("default") + + async def poll_for_cancel() -> bool: + """Poll cache for cancellation flag.""" + while True: + if await cache.get(f"cancel:{task_id}"): + return True + await asyncio.sleep(0.1) + + async def watch_for_cancel() -> bool: + """Watch for cancellation events via pub/sub or polling.""" + if isinstance(cache, SimpleMemoryCache): + return await poll_for_cancel() + + try: + client = cache.client + pubsub = client.pubsub() + await pubsub.subscribe(f"cancel:{task_id}") + async for message in pubsub.listen(): + if message["type"] == "message": + return True + except Exception as e: + logger.warning("Cancel watcher error for task_id=%s: %s", task_id, e) + return await poll_for_cancel() + return False + + execute_task = asyncio.create_task(fn(*args, **kwargs)) + cancel_watch = asyncio.create_task(watch_for_cancel()) + + try: + done, _ = await asyncio.wait( + [execute_task, cancel_watch], + return_when=asyncio.FIRST_COMPLETED, + ) + + if cancel_watch in done: + execute_task.cancel() + try: + await execute_task + except asyncio.CancelledError: + pass + raise asyncio.CancelledError(f"Task {task_id} was cancelled") + cancel_watch.cancel() + return execute_task.result() + finally: + await cache.delete(f"cancel:{task_id}") + + return wrapper + + +@cancellable +async def execute( + agent: Agent, + context: RequestContext, + event_queue: EventQueue, +) -> None: + """Execute an A2A task using a CrewAI agent. + + Args: + agent: The CrewAI agent to execute the task. + context: The A2A request context containing the user's message. + event_queue: The event queue for sending responses back. + + TODOs: + * need to impl both of structured output and file inputs, depends on `file_inputs` for + `crewai.task.Task`, pass the below two to Task. both utils in `a2a.utils.parts` + * structured outputs ingestion, `structured_inputs = get_data_parts(parts=context.message.parts)` + * file inputs ingestion, `file_inputs = get_file_parts(parts=context.message.parts)` + """ + + user_message = context.get_user_input() + task_id = context.task_id + context_id = context.context_id + if task_id is None or context_id is None: + msg = "task_id and context_id are required" + crewai_event_bus.emit( + agent, + A2AServerTaskFailedEvent(a2a_task_id="", a2a_context_id="", error=msg), + ) + raise ServerError(InvalidParamsError(message=msg)) from None + + task = Task( + description=user_message, + expected_output="Response to the user's request", + agent=agent, + ) + + crewai_event_bus.emit( + agent, + A2AServerTaskStartedEvent(a2a_task_id=task_id, a2a_context_id=context_id), + ) + + try: + result = await agent.aexecute_task(task=task, tools=agent.tools) + result_str = str(result) + history: list[Message] = [context.message] if context.message else [] + history.append(new_agent_text_message(result_str, context_id, task_id)) + await event_queue.enqueue_event( + A2ATask( + id=task_id, + context_id=context_id, + status=TaskStatus(state=TaskState.input_required), + artifacts=[new_text_artifact(result_str, f"result_{task_id}")], + history=history, + ) + ) + crewai_event_bus.emit( + agent, + A2AServerTaskCompletedEvent( + a2a_task_id=task_id, a2a_context_id=context_id, result=str(result) + ), + ) + except asyncio.CancelledError: + crewai_event_bus.emit( + agent, + A2AServerTaskCanceledEvent(a2a_task_id=task_id, a2a_context_id=context_id), + ) + raise + except Exception as e: + crewai_event_bus.emit( + agent, + A2AServerTaskFailedEvent( + a2a_task_id=task_id, a2a_context_id=context_id, error=str(e) + ), + ) + raise ServerError( + error=InternalError(message=f"Task execution failed: {e}") + ) from e + + +async def cancel( + context: RequestContext, + event_queue: EventQueue, +) -> A2ATask | None: + """Cancel an A2A task. + + Publishes a cancel event that the cancellable decorator listens for. + + Args: + context: The A2A request context containing task information. + event_queue: The event queue for sending the cancellation status. + + Returns: + The canceled task with updated status. + """ + task_id = context.task_id + context_id = context.context_id + if task_id is None or context_id is None: + raise ServerError(InvalidParamsError(message="task_id and context_id required")) + + if context.current_task and context.current_task.status.state in ( + TaskState.completed, + TaskState.failed, + TaskState.canceled, + ): + return context.current_task + + cache = caches.get("default") + + await cache.set(f"cancel:{task_id}", True, ttl=3600) + if not isinstance(cache, SimpleMemoryCache): + await cache.client.publish(f"cancel:{task_id}", "cancel") + + await event_queue.enqueue_event( + TaskStatusUpdateEvent( + task_id=task_id, + context_id=context_id, + status=TaskStatus(state=TaskState.canceled), + final=True, + ) + ) + + if context.current_task: + context.current_task.status = TaskStatus(state=TaskState.canceled) + return context.current_task + return None diff --git a/lib/crewai/src/crewai/agent/core.py b/lib/crewai/src/crewai/agent/core.py index 1c7a653ec..bc964754c 100644 --- a/lib/crewai/src/crewai/agent/core.py +++ b/lib/crewai/src/crewai/agent/core.py @@ -17,7 +17,6 @@ from urllib.parse import urlparse from pydantic import BaseModel, Field, InstanceOf, PrivateAttr, model_validator from typing_extensions import Self -from crewai.a2a.config import A2AClientConfig, A2AConfig, A2AServerConfig from crewai.agent.utils import ( ahandle_knowledge_retrieval, apply_training_data, @@ -78,6 +77,14 @@ from crewai.utilities.token_counter_callback import TokenCalcHandler from crewai.utilities.training_handler import CrewTrainingHandler +try: + from crewai.a2a.config import A2AClientConfig, A2AConfig, A2AServerConfig +except ImportError: + A2AClientConfig = Any + A2AConfig = Any + A2AServerConfig = Any + + if TYPE_CHECKING: from crewai_tools import CodeInterpreterTool diff --git a/lib/crewai/src/crewai/events/event_types.py b/lib/crewai/src/crewai/events/event_types.py index ea00aa9ae..b4479021e 100644 --- a/lib/crewai/src/crewai/events/event_types.py +++ b/lib/crewai/src/crewai/events/event_types.py @@ -1,3 +1,20 @@ +from crewai.events.types.a2a_events import ( + A2AConversationCompletedEvent, + A2AConversationStartedEvent, + A2ADelegationCompletedEvent, + A2ADelegationStartedEvent, + A2AMessageSentEvent, + A2APollingStartedEvent, + A2APollingStatusEvent, + A2APushNotificationReceivedEvent, + A2APushNotificationRegisteredEvent, + A2APushNotificationTimeoutEvent, + A2AResponseReceivedEvent, + A2AServerTaskCanceledEvent, + A2AServerTaskCompletedEvent, + A2AServerTaskFailedEvent, + A2AServerTaskStartedEvent, +) from crewai.events.types.agent_events import ( AgentExecutionCompletedEvent, AgentExecutionErrorEvent, @@ -76,7 +93,22 @@ from crewai.events.types.tool_usage_events import ( EventTypes = ( - CrewKickoffStartedEvent + A2AConversationCompletedEvent + | A2AConversationStartedEvent + | A2ADelegationCompletedEvent + | A2ADelegationStartedEvent + | A2AMessageSentEvent + | A2APollingStartedEvent + | A2APollingStatusEvent + | A2APushNotificationReceivedEvent + | A2APushNotificationRegisteredEvent + | A2APushNotificationTimeoutEvent + | A2AResponseReceivedEvent + | A2AServerTaskCanceledEvent + | A2AServerTaskCompletedEvent + | A2AServerTaskFailedEvent + | A2AServerTaskStartedEvent + | CrewKickoffStartedEvent | CrewKickoffCompletedEvent | CrewKickoffFailedEvent | CrewTestStartedEvent diff --git a/lib/crewai/src/crewai/events/types/a2a_events.py b/lib/crewai/src/crewai/events/types/a2a_events.py index 87eb6040b..9f414b333 100644 --- a/lib/crewai/src/crewai/events/types/a2a_events.py +++ b/lib/crewai/src/crewai/events/types/a2a_events.py @@ -210,3 +210,37 @@ class A2APushNotificationTimeoutEvent(A2AEventBase): type: str = "a2a_push_notification_timeout" task_id: str timeout_seconds: float + + +class A2AServerTaskStartedEvent(A2AEventBase): + """Event emitted when an A2A server task execution starts.""" + + type: str = "a2a_server_task_started" + a2a_task_id: str + a2a_context_id: str + + +class A2AServerTaskCompletedEvent(A2AEventBase): + """Event emitted when an A2A server task execution completes.""" + + type: str = "a2a_server_task_completed" + a2a_task_id: str + a2a_context_id: str + result: str + + +class A2AServerTaskCanceledEvent(A2AEventBase): + """Event emitted when an A2A server task execution is canceled.""" + + type: str = "a2a_server_task_canceled" + a2a_task_id: str + a2a_context_id: str + + +class A2AServerTaskFailedEvent(A2AEventBase): + """Event emitted when an A2A server task execution fails.""" + + type: str = "a2a_server_task_failed" + a2a_task_id: str + a2a_context_id: str + error: str diff --git a/lib/crewai/src/crewai/types/utils.py b/lib/crewai/src/crewai/types/utils.py index f46f9795c..afc9f5329 100644 --- a/lib/crewai/src/crewai/types/utils.py +++ b/lib/crewai/src/crewai/types/utils.py @@ -1,8 +1,6 @@ """Utilities for creating and manipulating types.""" -from typing import Annotated, Final, Literal - -from typing_extensions import TypeAliasType +from typing import Annotated, Final, Literal, cast _DYNAMIC_LITERAL_ALIAS: Final[Literal["DynamicLiteral"]] = "DynamicLiteral" @@ -20,6 +18,11 @@ def create_literals_from_strings( Returns: Literal type for each A2A agent ID + + Raises: + ValueError: If values is empty (Literal requires at least one value) """ unique_values: tuple[str, ...] = tuple(dict.fromkeys(values)) - return Literal.__getitem__(unique_values) + if not unique_values: + raise ValueError("Cannot create Literal type from empty values") + return cast(type, Literal.__getitem__(unique_values)) diff --git a/lib/crewai/tests/a2a/utils/test_task.py b/lib/crewai/tests/a2a/utils/test_task.py new file mode 100644 index 000000000..0c01a0afc --- /dev/null +++ b/lib/crewai/tests/a2a/utils/test_task.py @@ -0,0 +1,370 @@ +"""Tests for A2A task utilities.""" + +from __future__ import annotations + +import asyncio +from typing import Any +from unittest.mock import AsyncMock, MagicMock, patch + +import pytest +import pytest_asyncio +from a2a.server.agent_execution import RequestContext +from a2a.server.events import EventQueue +from a2a.types import Message, Task as A2ATask, TaskState, TaskStatus + +from crewai.a2a.utils.task import cancel, cancellable, execute + + +@pytest.fixture +def mock_agent() -> MagicMock: + """Create a mock CrewAI agent.""" + agent = MagicMock() + agent.role = "Test Agent" + agent.tools = [] + agent.aexecute_task = AsyncMock(return_value="Task completed successfully") + return agent + + +@pytest.fixture +def mock_task() -> MagicMock: + """Create a mock Task.""" + return MagicMock() + + +@pytest.fixture +def mock_context() -> MagicMock: + """Create a mock RequestContext.""" + context = MagicMock(spec=RequestContext) + context.task_id = "test-task-123" + context.context_id = "test-context-456" + context.get_user_input.return_value = "Test user message" + context.message = MagicMock(spec=Message) + context.current_task = None + return context + + +@pytest.fixture +def mock_event_queue() -> AsyncMock: + """Create a mock EventQueue.""" + queue = AsyncMock(spec=EventQueue) + queue.enqueue_event = AsyncMock() + return queue + + +@pytest_asyncio.fixture(autouse=True) +async def clear_cache(mock_context: MagicMock) -> None: + """Clear cancel flag from cache before each test.""" + from aiocache import caches + + cache = caches.get("default") + await cache.delete(f"cancel:{mock_context.task_id}") + + +class TestCancellableDecorator: + """Tests for the cancellable decorator.""" + + @pytest.mark.asyncio + async def test_executes_function_without_context(self) -> None: + """Function executes normally when no RequestContext is provided.""" + call_count = 0 + + @cancellable + async def my_func(value: int) -> int: + nonlocal call_count + call_count += 1 + return value * 2 + + result = await my_func(5) + + assert result == 10 + assert call_count == 1 + + @pytest.mark.asyncio + async def test_executes_function_with_context(self, mock_context: MagicMock) -> None: + """Function executes normally with RequestContext when not cancelled.""" + @cancellable + async def my_func(context: RequestContext) -> str: + await asyncio.sleep(0.01) + return "completed" + + result = await my_func(mock_context) + + assert result == "completed" + + @pytest.mark.asyncio + async def test_cancellation_raises_cancelled_error( + self, mock_context: MagicMock + ) -> None: + """Function raises CancelledError when cancel flag is set.""" + from aiocache import caches + + cache = caches.get("default") + + @cancellable + async def slow_func(context: RequestContext) -> str: + await asyncio.sleep(1.0) + return "should not reach" + + await cache.set(f"cancel:{mock_context.task_id}", True) + + with pytest.raises(asyncio.CancelledError): + await slow_func(mock_context) + + @pytest.mark.asyncio + async def test_cleanup_removes_cancel_flag(self, mock_context: MagicMock) -> None: + """Cancel flag is cleaned up after execution.""" + from aiocache import caches + + cache = caches.get("default") + + @cancellable + async def quick_func(context: RequestContext) -> str: + return "done" + + await quick_func(mock_context) + + flag = await cache.get(f"cancel:{mock_context.task_id}") + assert flag is None + + @pytest.mark.asyncio + async def test_extracts_context_from_kwargs(self, mock_context: MagicMock) -> None: + """Context can be passed as keyword argument.""" + @cancellable + async def my_func(value: int, context: RequestContext | None = None) -> int: + return value + 1 + + result = await my_func(10, context=mock_context) + + assert result == 11 + + +class TestExecute: + """Tests for the execute function.""" + + @pytest.mark.asyncio + async def test_successful_execution( + self, + mock_agent: MagicMock, + mock_context: MagicMock, + mock_event_queue: AsyncMock, + mock_task: MagicMock, + ) -> None: + """Execute completes successfully and enqueues completed task.""" + with ( + patch("crewai.a2a.utils.task.Task", return_value=mock_task), + patch("crewai.a2a.utils.task.crewai_event_bus") as mock_bus, + ): + await execute(mock_agent, mock_context, mock_event_queue) + + mock_agent.aexecute_task.assert_called_once() + mock_event_queue.enqueue_event.assert_called_once() + assert mock_bus.emit.call_count == 2 + + @pytest.mark.asyncio + async def test_emits_started_event( + self, + mock_agent: MagicMock, + mock_context: MagicMock, + mock_event_queue: AsyncMock, + mock_task: MagicMock, + ) -> None: + """Execute emits A2AServerTaskStartedEvent.""" + with ( + patch("crewai.a2a.utils.task.Task", return_value=mock_task), + patch("crewai.a2a.utils.task.crewai_event_bus") as mock_bus, + ): + await execute(mock_agent, mock_context, mock_event_queue) + + first_call = mock_bus.emit.call_args_list[0] + event = first_call[0][1] + + assert event.type == "a2a_server_task_started" + assert event.a2a_task_id == mock_context.task_id + assert event.a2a_context_id == mock_context.context_id + + @pytest.mark.asyncio + async def test_emits_completed_event( + self, + mock_agent: MagicMock, + mock_context: MagicMock, + mock_event_queue: AsyncMock, + mock_task: MagicMock, + ) -> None: + """Execute emits A2AServerTaskCompletedEvent on success.""" + with ( + patch("crewai.a2a.utils.task.Task", return_value=mock_task), + patch("crewai.a2a.utils.task.crewai_event_bus") as mock_bus, + ): + await execute(mock_agent, mock_context, mock_event_queue) + + second_call = mock_bus.emit.call_args_list[1] + event = second_call[0][1] + + assert event.type == "a2a_server_task_completed" + assert event.a2a_task_id == mock_context.task_id + assert event.result == "Task completed successfully" + + @pytest.mark.asyncio + async def test_emits_failed_event_on_exception( + self, + mock_agent: MagicMock, + mock_context: MagicMock, + mock_event_queue: AsyncMock, + mock_task: MagicMock, + ) -> None: + """Execute emits A2AServerTaskFailedEvent on exception.""" + mock_agent.aexecute_task = AsyncMock(side_effect=ValueError("Test error")) + + with ( + patch("crewai.a2a.utils.task.Task", return_value=mock_task), + patch("crewai.a2a.utils.task.crewai_event_bus") as mock_bus, + ): + with pytest.raises(Exception): + await execute(mock_agent, mock_context, mock_event_queue) + + failed_call = mock_bus.emit.call_args_list[1] + event = failed_call[0][1] + + assert event.type == "a2a_server_task_failed" + assert "Test error" in event.error + + @pytest.mark.asyncio + async def test_emits_canceled_event_on_cancellation( + self, + mock_agent: MagicMock, + mock_context: MagicMock, + mock_event_queue: AsyncMock, + mock_task: MagicMock, + ) -> None: + """Execute emits A2AServerTaskCanceledEvent on CancelledError.""" + mock_agent.aexecute_task = AsyncMock(side_effect=asyncio.CancelledError()) + + with ( + patch("crewai.a2a.utils.task.Task", return_value=mock_task), + patch("crewai.a2a.utils.task.crewai_event_bus") as mock_bus, + ): + with pytest.raises(asyncio.CancelledError): + await execute(mock_agent, mock_context, mock_event_queue) + + canceled_call = mock_bus.emit.call_args_list[1] + event = canceled_call[0][1] + + assert event.type == "a2a_server_task_canceled" + assert event.a2a_task_id == mock_context.task_id + + +class TestCancel: + """Tests for the cancel function.""" + + @pytest.mark.asyncio + async def test_sets_cancel_flag_in_cache( + self, + mock_context: MagicMock, + mock_event_queue: AsyncMock, + ) -> None: + """Cancel sets the cancel flag in cache.""" + from aiocache import caches + + cache = caches.get("default") + + await cancel(mock_context, mock_event_queue) + + flag = await cache.get(f"cancel:{mock_context.task_id}") + assert flag is True + + @pytest.mark.asyncio + async def test_enqueues_task_status_update_event( + self, + mock_context: MagicMock, + mock_event_queue: AsyncMock, + ) -> None: + """Cancel enqueues TaskStatusUpdateEvent with canceled state.""" + await cancel(mock_context, mock_event_queue) + + mock_event_queue.enqueue_event.assert_called_once() + event = mock_event_queue.enqueue_event.call_args[0][0] + + assert event.task_id == mock_context.task_id + assert event.context_id == mock_context.context_id + assert event.status.state == TaskState.canceled + assert event.final is True + + @pytest.mark.asyncio + async def test_returns_none_when_no_current_task( + self, + mock_context: MagicMock, + mock_event_queue: AsyncMock, + ) -> None: + """Cancel returns None when context has no current_task.""" + mock_context.current_task = None + + result = await cancel(mock_context, mock_event_queue) + + assert result is None + + @pytest.mark.asyncio + async def test_returns_updated_task_when_current_task_exists( + self, + mock_context: MagicMock, + mock_event_queue: AsyncMock, + ) -> None: + """Cancel returns updated task when context has current_task.""" + current_task = MagicMock(spec=A2ATask) + current_task.status = TaskStatus(state=TaskState.working) + mock_context.current_task = current_task + + result = await cancel(mock_context, mock_event_queue) + + assert result is current_task + assert result.status.state == TaskState.canceled + + @pytest.mark.asyncio + async def test_cleanup_after_cancel( + self, + mock_context: MagicMock, + mock_event_queue: AsyncMock, + ) -> None: + """Cancel flag persists for cancellable decorator to detect.""" + from aiocache import caches + + cache = caches.get("default") + + await cancel(mock_context, mock_event_queue) + + flag = await cache.get(f"cancel:{mock_context.task_id}") + assert flag is True + + await cache.delete(f"cancel:{mock_context.task_id}") + + +class TestExecuteAndCancelIntegration: + """Integration tests for execute and cancel working together.""" + + @pytest.mark.asyncio + async def test_cancel_stops_running_execute( + self, + mock_agent: MagicMock, + mock_context: MagicMock, + mock_event_queue: AsyncMock, + mock_task: MagicMock, + ) -> None: + """Calling cancel stops a running execute.""" + async def slow_task(**kwargs: Any) -> str: + await asyncio.sleep(2.0) + return "should not complete" + + mock_agent.aexecute_task = slow_task + + with ( + patch("crewai.a2a.utils.task.Task", return_value=mock_task), + patch("crewai.a2a.utils.task.crewai_event_bus"), + ): + execute_task = asyncio.create_task( + execute(mock_agent, mock_context, mock_event_queue) + ) + + await asyncio.sleep(0.1) + await cancel(mock_context, mock_event_queue) + + with pytest.raises(asyncio.CancelledError): + await execute_task \ No newline at end of file