feat: a2a task execution utilities

This commit is contained in:
Greyson LaLonde
2026-01-14 22:56:17 -05:00
committed by GitHub
parent 641c336b2c
commit 6a19b0a279
8 changed files with 759 additions and 17 deletions

View File

@@ -8,14 +8,6 @@ from __future__ import annotations
from importlib.metadata import version
from typing import Any, ClassVar, Literal
from a2a.types import (
AgentCapabilities,
AgentCardSignature,
AgentInterface,
AgentProvider,
AgentSkill,
SecurityScheme,
)
from pydantic import BaseModel, ConfigDict, Field
from typing_extensions import deprecated
@@ -24,8 +16,24 @@ from crewai.a2a.types import TransportType, Url
try:
from a2a.types import (
AgentCapabilities,
AgentCardSignature,
AgentInterface,
AgentProvider,
AgentSkill,
SecurityScheme,
)
from crewai.a2a.updates import UpdateConfig
except ImportError:
UpdateConfig = Any
AgentCapabilities = Any
AgentCardSignature = Any
AgentInterface = Any
AgentProvider = Any
SecurityScheme = Any
AgentSkill = Any
UpdateConfig = Any # type: ignore[misc,assignment]

View File

@@ -14,15 +14,19 @@ A2AConfigTypes: TypeAlias = A2AConfig | A2AServerConfig | A2AClientConfig
A2AClientConfigTypes: TypeAlias = A2AConfig | A2AClientConfig
def create_agent_response_model(agent_ids: tuple[str, ...]) -> type[BaseModel]:
def create_agent_response_model(agent_ids: tuple[str, ...]) -> type[BaseModel] | None:
"""Create a dynamic AgentResponse model with Literal types for agent IDs.
Args:
agent_ids: List of available A2A agent IDs.
Returns:
Dynamically created Pydantic model with Literal-constrained a2a_ids field.
Dynamically created Pydantic model with Literal-constrained a2a_ids field,
or None if agent_ids is empty.
"""
if not agent_ids:
return None
DynamicLiteral = create_literals_from_strings(agent_ids) # noqa: N806
return create_model(
@@ -83,7 +87,7 @@ def extract_a2a_agent_ids_from_config(
def get_a2a_agents_and_response_model(
a2a_config: list[A2AConfigTypes] | A2AConfigTypes | None,
) -> tuple[list[A2AClientConfigTypes], type[BaseModel]]:
) -> tuple[list[A2AClientConfigTypes], type[BaseModel] | None]:
"""Get A2A agent configs and response model.
Args:

View File

@@ -0,0 +1,284 @@
"""A2A task utilities for server-side task management."""
from __future__ import annotations
import asyncio
from collections.abc import Callable, Coroutine
from functools import wraps
import logging
import os
from typing import TYPE_CHECKING, Any, ParamSpec, TypeVar, cast
from a2a.server.agent_execution import RequestContext
from a2a.server.events import EventQueue
from a2a.types import (
InternalError,
InvalidParamsError,
Message,
Task as A2ATask,
TaskState,
TaskStatus,
TaskStatusUpdateEvent,
)
from a2a.utils import new_agent_text_message, new_text_artifact
from a2a.utils.errors import ServerError
from aiocache import SimpleMemoryCache, caches # type: ignore[import-untyped]
from crewai.events.event_bus import crewai_event_bus
from crewai.events.types.a2a_events import (
A2AServerTaskCanceledEvent,
A2AServerTaskCompletedEvent,
A2AServerTaskFailedEvent,
A2AServerTaskStartedEvent,
)
from crewai.task import Task
if TYPE_CHECKING:
from crewai.agent import Agent
logger = logging.getLogger(__name__)
P = ParamSpec("P")
T = TypeVar("T")
def _parse_redis_url(url: str) -> dict[str, Any]:
from urllib.parse import urlparse
parsed = urlparse(url)
config: dict[str, Any] = {
"cache": "aiocache.RedisCache",
"endpoint": parsed.hostname or "localhost",
"port": parsed.port or 6379,
}
if parsed.path and parsed.path != "/":
try:
config["db"] = int(parsed.path.lstrip("/"))
except ValueError:
pass
if parsed.password:
config["password"] = parsed.password
return config
_redis_url = os.environ.get("REDIS_URL")
caches.set_config(
{
"default": _parse_redis_url(_redis_url)
if _redis_url
else {
"cache": "aiocache.SimpleMemoryCache",
}
}
)
def cancellable(
fn: Callable[P, Coroutine[Any, Any, T]],
) -> Callable[P, Coroutine[Any, Any, T]]:
"""Decorator that enables cancellation for A2A task execution.
Runs a cancellation watcher concurrently with the wrapped function.
When a cancel event is published, the execution is cancelled.
Args:
fn: The async function to wrap.
Returns:
Wrapped function with cancellation support.
"""
@wraps(fn)
async def wrapper(*args: P.args, **kwargs: P.kwargs) -> T:
"""Wrap function with cancellation monitoring."""
context: RequestContext | None = None
for arg in args:
if isinstance(arg, RequestContext):
context = arg
break
if context is None:
context = cast(RequestContext | None, kwargs.get("context"))
if context is None:
return await fn(*args, **kwargs)
task_id = context.task_id
cache = caches.get("default")
async def poll_for_cancel() -> bool:
"""Poll cache for cancellation flag."""
while True:
if await cache.get(f"cancel:{task_id}"):
return True
await asyncio.sleep(0.1)
async def watch_for_cancel() -> bool:
"""Watch for cancellation events via pub/sub or polling."""
if isinstance(cache, SimpleMemoryCache):
return await poll_for_cancel()
try:
client = cache.client
pubsub = client.pubsub()
await pubsub.subscribe(f"cancel:{task_id}")
async for message in pubsub.listen():
if message["type"] == "message":
return True
except Exception as e:
logger.warning("Cancel watcher error for task_id=%s: %s", task_id, e)
return await poll_for_cancel()
return False
execute_task = asyncio.create_task(fn(*args, **kwargs))
cancel_watch = asyncio.create_task(watch_for_cancel())
try:
done, _ = await asyncio.wait(
[execute_task, cancel_watch],
return_when=asyncio.FIRST_COMPLETED,
)
if cancel_watch in done:
execute_task.cancel()
try:
await execute_task
except asyncio.CancelledError:
pass
raise asyncio.CancelledError(f"Task {task_id} was cancelled")
cancel_watch.cancel()
return execute_task.result()
finally:
await cache.delete(f"cancel:{task_id}")
return wrapper
@cancellable
async def execute(
agent: Agent,
context: RequestContext,
event_queue: EventQueue,
) -> None:
"""Execute an A2A task using a CrewAI agent.
Args:
agent: The CrewAI agent to execute the task.
context: The A2A request context containing the user's message.
event_queue: The event queue for sending responses back.
TODOs:
* need to impl both of structured output and file inputs, depends on `file_inputs` for
`crewai.task.Task`, pass the below two to Task. both utils in `a2a.utils.parts`
* structured outputs ingestion, `structured_inputs = get_data_parts(parts=context.message.parts)`
* file inputs ingestion, `file_inputs = get_file_parts(parts=context.message.parts)`
"""
user_message = context.get_user_input()
task_id = context.task_id
context_id = context.context_id
if task_id is None or context_id is None:
msg = "task_id and context_id are required"
crewai_event_bus.emit(
agent,
A2AServerTaskFailedEvent(a2a_task_id="", a2a_context_id="", error=msg),
)
raise ServerError(InvalidParamsError(message=msg)) from None
task = Task(
description=user_message,
expected_output="Response to the user's request",
agent=agent,
)
crewai_event_bus.emit(
agent,
A2AServerTaskStartedEvent(a2a_task_id=task_id, a2a_context_id=context_id),
)
try:
result = await agent.aexecute_task(task=task, tools=agent.tools)
result_str = str(result)
history: list[Message] = [context.message] if context.message else []
history.append(new_agent_text_message(result_str, context_id, task_id))
await event_queue.enqueue_event(
A2ATask(
id=task_id,
context_id=context_id,
status=TaskStatus(state=TaskState.input_required),
artifacts=[new_text_artifact(result_str, f"result_{task_id}")],
history=history,
)
)
crewai_event_bus.emit(
agent,
A2AServerTaskCompletedEvent(
a2a_task_id=task_id, a2a_context_id=context_id, result=str(result)
),
)
except asyncio.CancelledError:
crewai_event_bus.emit(
agent,
A2AServerTaskCanceledEvent(a2a_task_id=task_id, a2a_context_id=context_id),
)
raise
except Exception as e:
crewai_event_bus.emit(
agent,
A2AServerTaskFailedEvent(
a2a_task_id=task_id, a2a_context_id=context_id, error=str(e)
),
)
raise ServerError(
error=InternalError(message=f"Task execution failed: {e}")
) from e
async def cancel(
context: RequestContext,
event_queue: EventQueue,
) -> A2ATask | None:
"""Cancel an A2A task.
Publishes a cancel event that the cancellable decorator listens for.
Args:
context: The A2A request context containing task information.
event_queue: The event queue for sending the cancellation status.
Returns:
The canceled task with updated status.
"""
task_id = context.task_id
context_id = context.context_id
if task_id is None or context_id is None:
raise ServerError(InvalidParamsError(message="task_id and context_id required"))
if context.current_task and context.current_task.status.state in (
TaskState.completed,
TaskState.failed,
TaskState.canceled,
):
return context.current_task
cache = caches.get("default")
await cache.set(f"cancel:{task_id}", True, ttl=3600)
if not isinstance(cache, SimpleMemoryCache):
await cache.client.publish(f"cancel:{task_id}", "cancel")
await event_queue.enqueue_event(
TaskStatusUpdateEvent(
task_id=task_id,
context_id=context_id,
status=TaskStatus(state=TaskState.canceled),
final=True,
)
)
if context.current_task:
context.current_task.status = TaskStatus(state=TaskState.canceled)
return context.current_task
return None

View File

@@ -17,7 +17,6 @@ from urllib.parse import urlparse
from pydantic import BaseModel, Field, InstanceOf, PrivateAttr, model_validator
from typing_extensions import Self
from crewai.a2a.config import A2AClientConfig, A2AConfig, A2AServerConfig
from crewai.agent.utils import (
ahandle_knowledge_retrieval,
apply_training_data,
@@ -78,6 +77,14 @@ from crewai.utilities.token_counter_callback import TokenCalcHandler
from crewai.utilities.training_handler import CrewTrainingHandler
try:
from crewai.a2a.config import A2AClientConfig, A2AConfig, A2AServerConfig
except ImportError:
A2AClientConfig = Any
A2AConfig = Any
A2AServerConfig = Any
if TYPE_CHECKING:
from crewai_tools import CodeInterpreterTool

View File

@@ -1,3 +1,20 @@
from crewai.events.types.a2a_events import (
A2AConversationCompletedEvent,
A2AConversationStartedEvent,
A2ADelegationCompletedEvent,
A2ADelegationStartedEvent,
A2AMessageSentEvent,
A2APollingStartedEvent,
A2APollingStatusEvent,
A2APushNotificationReceivedEvent,
A2APushNotificationRegisteredEvent,
A2APushNotificationTimeoutEvent,
A2AResponseReceivedEvent,
A2AServerTaskCanceledEvent,
A2AServerTaskCompletedEvent,
A2AServerTaskFailedEvent,
A2AServerTaskStartedEvent,
)
from crewai.events.types.agent_events import (
AgentExecutionCompletedEvent,
AgentExecutionErrorEvent,
@@ -76,7 +93,22 @@ from crewai.events.types.tool_usage_events import (
EventTypes = (
CrewKickoffStartedEvent
A2AConversationCompletedEvent
| A2AConversationStartedEvent
| A2ADelegationCompletedEvent
| A2ADelegationStartedEvent
| A2AMessageSentEvent
| A2APollingStartedEvent
| A2APollingStatusEvent
| A2APushNotificationReceivedEvent
| A2APushNotificationRegisteredEvent
| A2APushNotificationTimeoutEvent
| A2AResponseReceivedEvent
| A2AServerTaskCanceledEvent
| A2AServerTaskCompletedEvent
| A2AServerTaskFailedEvent
| A2AServerTaskStartedEvent
| CrewKickoffStartedEvent
| CrewKickoffCompletedEvent
| CrewKickoffFailedEvent
| CrewTestStartedEvent

View File

@@ -210,3 +210,37 @@ class A2APushNotificationTimeoutEvent(A2AEventBase):
type: str = "a2a_push_notification_timeout"
task_id: str
timeout_seconds: float
class A2AServerTaskStartedEvent(A2AEventBase):
"""Event emitted when an A2A server task execution starts."""
type: str = "a2a_server_task_started"
a2a_task_id: str
a2a_context_id: str
class A2AServerTaskCompletedEvent(A2AEventBase):
"""Event emitted when an A2A server task execution completes."""
type: str = "a2a_server_task_completed"
a2a_task_id: str
a2a_context_id: str
result: str
class A2AServerTaskCanceledEvent(A2AEventBase):
"""Event emitted when an A2A server task execution is canceled."""
type: str = "a2a_server_task_canceled"
a2a_task_id: str
a2a_context_id: str
class A2AServerTaskFailedEvent(A2AEventBase):
"""Event emitted when an A2A server task execution fails."""
type: str = "a2a_server_task_failed"
a2a_task_id: str
a2a_context_id: str
error: str

View File

@@ -1,8 +1,6 @@
"""Utilities for creating and manipulating types."""
from typing import Annotated, Final, Literal
from typing_extensions import TypeAliasType
from typing import Annotated, Final, Literal, cast
_DYNAMIC_LITERAL_ALIAS: Final[Literal["DynamicLiteral"]] = "DynamicLiteral"
@@ -20,6 +18,11 @@ def create_literals_from_strings(
Returns:
Literal type for each A2A agent ID
Raises:
ValueError: If values is empty (Literal requires at least one value)
"""
unique_values: tuple[str, ...] = tuple(dict.fromkeys(values))
return Literal.__getitem__(unique_values)
if not unique_values:
raise ValueError("Cannot create Literal type from empty values")
return cast(type, Literal.__getitem__(unique_values))

View File

@@ -0,0 +1,370 @@
"""Tests for A2A task utilities."""
from __future__ import annotations
import asyncio
from typing import Any
from unittest.mock import AsyncMock, MagicMock, patch
import pytest
import pytest_asyncio
from a2a.server.agent_execution import RequestContext
from a2a.server.events import EventQueue
from a2a.types import Message, Task as A2ATask, TaskState, TaskStatus
from crewai.a2a.utils.task import cancel, cancellable, execute
@pytest.fixture
def mock_agent() -> MagicMock:
"""Create a mock CrewAI agent."""
agent = MagicMock()
agent.role = "Test Agent"
agent.tools = []
agent.aexecute_task = AsyncMock(return_value="Task completed successfully")
return agent
@pytest.fixture
def mock_task() -> MagicMock:
"""Create a mock Task."""
return MagicMock()
@pytest.fixture
def mock_context() -> MagicMock:
"""Create a mock RequestContext."""
context = MagicMock(spec=RequestContext)
context.task_id = "test-task-123"
context.context_id = "test-context-456"
context.get_user_input.return_value = "Test user message"
context.message = MagicMock(spec=Message)
context.current_task = None
return context
@pytest.fixture
def mock_event_queue() -> AsyncMock:
"""Create a mock EventQueue."""
queue = AsyncMock(spec=EventQueue)
queue.enqueue_event = AsyncMock()
return queue
@pytest_asyncio.fixture(autouse=True)
async def clear_cache(mock_context: MagicMock) -> None:
"""Clear cancel flag from cache before each test."""
from aiocache import caches
cache = caches.get("default")
await cache.delete(f"cancel:{mock_context.task_id}")
class TestCancellableDecorator:
"""Tests for the cancellable decorator."""
@pytest.mark.asyncio
async def test_executes_function_without_context(self) -> None:
"""Function executes normally when no RequestContext is provided."""
call_count = 0
@cancellable
async def my_func(value: int) -> int:
nonlocal call_count
call_count += 1
return value * 2
result = await my_func(5)
assert result == 10
assert call_count == 1
@pytest.mark.asyncio
async def test_executes_function_with_context(self, mock_context: MagicMock) -> None:
"""Function executes normally with RequestContext when not cancelled."""
@cancellable
async def my_func(context: RequestContext) -> str:
await asyncio.sleep(0.01)
return "completed"
result = await my_func(mock_context)
assert result == "completed"
@pytest.mark.asyncio
async def test_cancellation_raises_cancelled_error(
self, mock_context: MagicMock
) -> None:
"""Function raises CancelledError when cancel flag is set."""
from aiocache import caches
cache = caches.get("default")
@cancellable
async def slow_func(context: RequestContext) -> str:
await asyncio.sleep(1.0)
return "should not reach"
await cache.set(f"cancel:{mock_context.task_id}", True)
with pytest.raises(asyncio.CancelledError):
await slow_func(mock_context)
@pytest.mark.asyncio
async def test_cleanup_removes_cancel_flag(self, mock_context: MagicMock) -> None:
"""Cancel flag is cleaned up after execution."""
from aiocache import caches
cache = caches.get("default")
@cancellable
async def quick_func(context: RequestContext) -> str:
return "done"
await quick_func(mock_context)
flag = await cache.get(f"cancel:{mock_context.task_id}")
assert flag is None
@pytest.mark.asyncio
async def test_extracts_context_from_kwargs(self, mock_context: MagicMock) -> None:
"""Context can be passed as keyword argument."""
@cancellable
async def my_func(value: int, context: RequestContext | None = None) -> int:
return value + 1
result = await my_func(10, context=mock_context)
assert result == 11
class TestExecute:
"""Tests for the execute function."""
@pytest.mark.asyncio
async def test_successful_execution(
self,
mock_agent: MagicMock,
mock_context: MagicMock,
mock_event_queue: AsyncMock,
mock_task: MagicMock,
) -> None:
"""Execute completes successfully and enqueues completed task."""
with (
patch("crewai.a2a.utils.task.Task", return_value=mock_task),
patch("crewai.a2a.utils.task.crewai_event_bus") as mock_bus,
):
await execute(mock_agent, mock_context, mock_event_queue)
mock_agent.aexecute_task.assert_called_once()
mock_event_queue.enqueue_event.assert_called_once()
assert mock_bus.emit.call_count == 2
@pytest.mark.asyncio
async def test_emits_started_event(
self,
mock_agent: MagicMock,
mock_context: MagicMock,
mock_event_queue: AsyncMock,
mock_task: MagicMock,
) -> None:
"""Execute emits A2AServerTaskStartedEvent."""
with (
patch("crewai.a2a.utils.task.Task", return_value=mock_task),
patch("crewai.a2a.utils.task.crewai_event_bus") as mock_bus,
):
await execute(mock_agent, mock_context, mock_event_queue)
first_call = mock_bus.emit.call_args_list[0]
event = first_call[0][1]
assert event.type == "a2a_server_task_started"
assert event.a2a_task_id == mock_context.task_id
assert event.a2a_context_id == mock_context.context_id
@pytest.mark.asyncio
async def test_emits_completed_event(
self,
mock_agent: MagicMock,
mock_context: MagicMock,
mock_event_queue: AsyncMock,
mock_task: MagicMock,
) -> None:
"""Execute emits A2AServerTaskCompletedEvent on success."""
with (
patch("crewai.a2a.utils.task.Task", return_value=mock_task),
patch("crewai.a2a.utils.task.crewai_event_bus") as mock_bus,
):
await execute(mock_agent, mock_context, mock_event_queue)
second_call = mock_bus.emit.call_args_list[1]
event = second_call[0][1]
assert event.type == "a2a_server_task_completed"
assert event.a2a_task_id == mock_context.task_id
assert event.result == "Task completed successfully"
@pytest.mark.asyncio
async def test_emits_failed_event_on_exception(
self,
mock_agent: MagicMock,
mock_context: MagicMock,
mock_event_queue: AsyncMock,
mock_task: MagicMock,
) -> None:
"""Execute emits A2AServerTaskFailedEvent on exception."""
mock_agent.aexecute_task = AsyncMock(side_effect=ValueError("Test error"))
with (
patch("crewai.a2a.utils.task.Task", return_value=mock_task),
patch("crewai.a2a.utils.task.crewai_event_bus") as mock_bus,
):
with pytest.raises(Exception):
await execute(mock_agent, mock_context, mock_event_queue)
failed_call = mock_bus.emit.call_args_list[1]
event = failed_call[0][1]
assert event.type == "a2a_server_task_failed"
assert "Test error" in event.error
@pytest.mark.asyncio
async def test_emits_canceled_event_on_cancellation(
self,
mock_agent: MagicMock,
mock_context: MagicMock,
mock_event_queue: AsyncMock,
mock_task: MagicMock,
) -> None:
"""Execute emits A2AServerTaskCanceledEvent on CancelledError."""
mock_agent.aexecute_task = AsyncMock(side_effect=asyncio.CancelledError())
with (
patch("crewai.a2a.utils.task.Task", return_value=mock_task),
patch("crewai.a2a.utils.task.crewai_event_bus") as mock_bus,
):
with pytest.raises(asyncio.CancelledError):
await execute(mock_agent, mock_context, mock_event_queue)
canceled_call = mock_bus.emit.call_args_list[1]
event = canceled_call[0][1]
assert event.type == "a2a_server_task_canceled"
assert event.a2a_task_id == mock_context.task_id
class TestCancel:
"""Tests for the cancel function."""
@pytest.mark.asyncio
async def test_sets_cancel_flag_in_cache(
self,
mock_context: MagicMock,
mock_event_queue: AsyncMock,
) -> None:
"""Cancel sets the cancel flag in cache."""
from aiocache import caches
cache = caches.get("default")
await cancel(mock_context, mock_event_queue)
flag = await cache.get(f"cancel:{mock_context.task_id}")
assert flag is True
@pytest.mark.asyncio
async def test_enqueues_task_status_update_event(
self,
mock_context: MagicMock,
mock_event_queue: AsyncMock,
) -> None:
"""Cancel enqueues TaskStatusUpdateEvent with canceled state."""
await cancel(mock_context, mock_event_queue)
mock_event_queue.enqueue_event.assert_called_once()
event = mock_event_queue.enqueue_event.call_args[0][0]
assert event.task_id == mock_context.task_id
assert event.context_id == mock_context.context_id
assert event.status.state == TaskState.canceled
assert event.final is True
@pytest.mark.asyncio
async def test_returns_none_when_no_current_task(
self,
mock_context: MagicMock,
mock_event_queue: AsyncMock,
) -> None:
"""Cancel returns None when context has no current_task."""
mock_context.current_task = None
result = await cancel(mock_context, mock_event_queue)
assert result is None
@pytest.mark.asyncio
async def test_returns_updated_task_when_current_task_exists(
self,
mock_context: MagicMock,
mock_event_queue: AsyncMock,
) -> None:
"""Cancel returns updated task when context has current_task."""
current_task = MagicMock(spec=A2ATask)
current_task.status = TaskStatus(state=TaskState.working)
mock_context.current_task = current_task
result = await cancel(mock_context, mock_event_queue)
assert result is current_task
assert result.status.state == TaskState.canceled
@pytest.mark.asyncio
async def test_cleanup_after_cancel(
self,
mock_context: MagicMock,
mock_event_queue: AsyncMock,
) -> None:
"""Cancel flag persists for cancellable decorator to detect."""
from aiocache import caches
cache = caches.get("default")
await cancel(mock_context, mock_event_queue)
flag = await cache.get(f"cancel:{mock_context.task_id}")
assert flag is True
await cache.delete(f"cancel:{mock_context.task_id}")
class TestExecuteAndCancelIntegration:
"""Integration tests for execute and cancel working together."""
@pytest.mark.asyncio
async def test_cancel_stops_running_execute(
self,
mock_agent: MagicMock,
mock_context: MagicMock,
mock_event_queue: AsyncMock,
mock_task: MagicMock,
) -> None:
"""Calling cancel stops a running execute."""
async def slow_task(**kwargs: Any) -> str:
await asyncio.sleep(2.0)
return "should not complete"
mock_agent.aexecute_task = slow_task
with (
patch("crewai.a2a.utils.task.Task", return_value=mock_task),
patch("crewai.a2a.utils.task.crewai_event_bus"),
):
execute_task = asyncio.create_task(
execute(mock_agent, mock_context, mock_event_queue)
)
await asyncio.sleep(0.1)
await cancel(mock_context, mock_event_queue)
with pytest.raises(asyncio.CancelledError):
await execute_task