Merge branch 'main' of github.com:crewAIInc/crewAI into lorenze/enh-decouple-executor-from-crew

This commit is contained in:
lorenzejay
2026-01-15 09:12:38 -08:00
42 changed files with 3981 additions and 1080 deletions

View File

@@ -0,0 +1,325 @@
"""Tests for A2A agent card utilities."""
from __future__ import annotations
from a2a.types import AgentCard, AgentSkill
from crewai import Agent
from crewai.a2a.config import A2AClientConfig, A2AServerConfig
from crewai.a2a.utils.agent_card import inject_a2a_server_methods
class TestInjectA2AServerMethods:
"""Tests for inject_a2a_server_methods function."""
def test_agent_with_server_config_gets_to_agent_card_method(self) -> None:
"""Agent with A2AServerConfig should have to_agent_card method injected."""
agent = Agent(
role="Test Agent",
goal="Test goal",
backstory="Test backstory",
a2a=A2AServerConfig(),
)
assert hasattr(agent, "to_agent_card")
assert callable(agent.to_agent_card)
def test_agent_without_server_config_no_injection(self) -> None:
"""Agent without A2AServerConfig should not get to_agent_card method."""
agent = Agent(
role="Test Agent",
goal="Test goal",
backstory="Test backstory",
a2a=A2AClientConfig(endpoint="http://example.com"),
)
assert not hasattr(agent, "to_agent_card")
def test_agent_without_a2a_no_injection(self) -> None:
"""Agent without any a2a config should not get to_agent_card method."""
agent = Agent(
role="Test Agent",
goal="Test goal",
backstory="Test backstory",
)
assert not hasattr(agent, "to_agent_card")
def test_agent_with_mixed_configs_gets_injection(self) -> None:
"""Agent with list containing A2AServerConfig should get to_agent_card."""
agent = Agent(
role="Test Agent",
goal="Test goal",
backstory="Test backstory",
a2a=[
A2AClientConfig(endpoint="http://example.com"),
A2AServerConfig(name="My Agent"),
],
)
assert hasattr(agent, "to_agent_card")
assert callable(agent.to_agent_card)
def test_manual_injection_on_plain_agent(self) -> None:
"""inject_a2a_server_methods should work when called manually."""
agent = Agent(
role="Test Agent",
goal="Test goal",
backstory="Test backstory",
)
# Manually set server config and inject
object.__setattr__(agent, "a2a", A2AServerConfig())
inject_a2a_server_methods(agent)
assert hasattr(agent, "to_agent_card")
assert callable(agent.to_agent_card)
class TestToAgentCard:
"""Tests for the injected to_agent_card method."""
def test_returns_agent_card(self) -> None:
"""to_agent_card should return an AgentCard instance."""
agent = Agent(
role="Test Agent",
goal="Test goal",
backstory="Test backstory",
a2a=A2AServerConfig(),
)
card = agent.to_agent_card("http://localhost:8000")
assert isinstance(card, AgentCard)
def test_uses_agent_role_as_name(self) -> None:
"""AgentCard name should default to agent role."""
agent = Agent(
role="Data Analyst",
goal="Analyze data",
backstory="Expert analyst",
a2a=A2AServerConfig(),
)
card = agent.to_agent_card("http://localhost:8000")
assert card.name == "Data Analyst"
def test_uses_server_config_name(self) -> None:
"""AgentCard name should prefer A2AServerConfig.name over role."""
agent = Agent(
role="Data Analyst",
goal="Analyze data",
backstory="Expert analyst",
a2a=A2AServerConfig(name="Custom Agent Name"),
)
card = agent.to_agent_card("http://localhost:8000")
assert card.name == "Custom Agent Name"
def test_uses_goal_as_description(self) -> None:
"""AgentCard description should include agent goal."""
agent = Agent(
role="Test Agent",
goal="Accomplish important tasks",
backstory="Has extensive experience",
a2a=A2AServerConfig(),
)
card = agent.to_agent_card("http://localhost:8000")
assert "Accomplish important tasks" in card.description
def test_uses_server_config_description(self) -> None:
"""AgentCard description should prefer A2AServerConfig.description."""
agent = Agent(
role="Test Agent",
goal="Accomplish important tasks",
backstory="Has extensive experience",
a2a=A2AServerConfig(description="Custom description"),
)
card = agent.to_agent_card("http://localhost:8000")
assert card.description == "Custom description"
def test_uses_provided_url(self) -> None:
"""AgentCard url should use the provided URL."""
agent = Agent(
role="Test Agent",
goal="Test goal",
backstory="Test backstory",
a2a=A2AServerConfig(),
)
card = agent.to_agent_card("http://my-server.com:9000")
assert card.url == "http://my-server.com:9000"
def test_uses_server_config_url(self) -> None:
"""AgentCard url should prefer A2AServerConfig.url over provided URL."""
agent = Agent(
role="Test Agent",
goal="Test goal",
backstory="Test backstory",
a2a=A2AServerConfig(url="http://configured-url.com"),
)
card = agent.to_agent_card("http://fallback-url.com")
assert card.url == "http://configured-url.com/"
def test_generates_default_skill(self) -> None:
"""AgentCard should have at least one skill based on agent role."""
agent = Agent(
role="Research Assistant",
goal="Help with research",
backstory="Skilled researcher",
a2a=A2AServerConfig(),
)
card = agent.to_agent_card("http://localhost:8000")
assert len(card.skills) >= 1
skill = card.skills[0]
assert skill.name == "Research Assistant"
assert skill.description == "Help with research"
def test_uses_server_config_skills(self) -> None:
"""AgentCard skills should prefer A2AServerConfig.skills."""
custom_skill = AgentSkill(
id="custom-skill",
name="Custom Skill",
description="A custom skill",
tags=["custom"],
)
agent = Agent(
role="Test Agent",
goal="Test goal",
backstory="Test backstory",
a2a=A2AServerConfig(skills=[custom_skill]),
)
card = agent.to_agent_card("http://localhost:8000")
assert len(card.skills) == 1
assert card.skills[0].id == "custom-skill"
assert card.skills[0].name == "Custom Skill"
def test_includes_custom_version(self) -> None:
"""AgentCard should include version from A2AServerConfig."""
agent = Agent(
role="Test Agent",
goal="Test goal",
backstory="Test backstory",
a2a=A2AServerConfig(version="2.0.0"),
)
card = agent.to_agent_card("http://localhost:8000")
assert card.version == "2.0.0"
def test_default_version(self) -> None:
"""AgentCard should have default version 1.0.0."""
agent = Agent(
role="Test Agent",
goal="Test goal",
backstory="Test backstory",
a2a=A2AServerConfig(),
)
card = agent.to_agent_card("http://localhost:8000")
assert card.version == "1.0.0"
class TestAgentCardJsonStructure:
"""Tests for the JSON structure of AgentCard."""
def test_json_has_required_fields(self) -> None:
"""AgentCard JSON should contain all required A2A protocol fields."""
agent = Agent(
role="Test Agent",
goal="Test goal",
backstory="Test backstory",
a2a=A2AServerConfig(),
)
card = agent.to_agent_card("http://localhost:8000")
json_data = card.model_dump()
assert "name" in json_data
assert "description" in json_data
assert "url" in json_data
assert "version" in json_data
assert "skills" in json_data
assert "capabilities" in json_data
assert "defaultInputModes" in json_data
assert "defaultOutputModes" in json_data
def test_json_skills_structure(self) -> None:
"""Each skill in JSON should have required fields."""
agent = Agent(
role="Test Agent",
goal="Test goal",
backstory="Test backstory",
a2a=A2AServerConfig(),
)
card = agent.to_agent_card("http://localhost:8000")
json_data = card.model_dump()
assert len(json_data["skills"]) >= 1
skill = json_data["skills"][0]
assert "id" in skill
assert "name" in skill
assert "description" in skill
assert "tags" in skill
def test_json_capabilities_structure(self) -> None:
"""Capabilities in JSON should have expected fields."""
agent = Agent(
role="Test Agent",
goal="Test goal",
backstory="Test backstory",
a2a=A2AServerConfig(),
)
card = agent.to_agent_card("http://localhost:8000")
json_data = card.model_dump()
capabilities = json_data["capabilities"]
assert "streaming" in capabilities
assert "pushNotifications" in capabilities
def test_json_serializable(self) -> None:
"""AgentCard should be JSON serializable."""
agent = Agent(
role="Test Agent",
goal="Test goal",
backstory="Test backstory",
a2a=A2AServerConfig(),
)
card = agent.to_agent_card("http://localhost:8000")
json_str = card.model_dump_json()
assert isinstance(json_str, str)
assert "Test Agent" in json_str
assert "http://localhost:8000" in json_str
def test_json_excludes_none_values(self) -> None:
"""AgentCard JSON with exclude_none should omit None fields."""
agent = Agent(
role="Test Agent",
goal="Test goal",
backstory="Test backstory",
a2a=A2AServerConfig(),
)
card = agent.to_agent_card("http://localhost:8000")
json_data = card.model_dump(exclude_none=True)
assert "provider" not in json_data
assert "documentationUrl" not in json_data
assert "iconUrl" not in json_data

View File

@@ -0,0 +1,370 @@
"""Tests for A2A task utilities."""
from __future__ import annotations
import asyncio
from typing import Any
from unittest.mock import AsyncMock, MagicMock, patch
import pytest
import pytest_asyncio
from a2a.server.agent_execution import RequestContext
from a2a.server.events import EventQueue
from a2a.types import Message, Task as A2ATask, TaskState, TaskStatus
from crewai.a2a.utils.task import cancel, cancellable, execute
@pytest.fixture
def mock_agent() -> MagicMock:
"""Create a mock CrewAI agent."""
agent = MagicMock()
agent.role = "Test Agent"
agent.tools = []
agent.aexecute_task = AsyncMock(return_value="Task completed successfully")
return agent
@pytest.fixture
def mock_task() -> MagicMock:
"""Create a mock Task."""
return MagicMock()
@pytest.fixture
def mock_context() -> MagicMock:
"""Create a mock RequestContext."""
context = MagicMock(spec=RequestContext)
context.task_id = "test-task-123"
context.context_id = "test-context-456"
context.get_user_input.return_value = "Test user message"
context.message = MagicMock(spec=Message)
context.current_task = None
return context
@pytest.fixture
def mock_event_queue() -> AsyncMock:
"""Create a mock EventQueue."""
queue = AsyncMock(spec=EventQueue)
queue.enqueue_event = AsyncMock()
return queue
@pytest_asyncio.fixture(autouse=True)
async def clear_cache(mock_context: MagicMock) -> None:
"""Clear cancel flag from cache before each test."""
from aiocache import caches
cache = caches.get("default")
await cache.delete(f"cancel:{mock_context.task_id}")
class TestCancellableDecorator:
"""Tests for the cancellable decorator."""
@pytest.mark.asyncio
async def test_executes_function_without_context(self) -> None:
"""Function executes normally when no RequestContext is provided."""
call_count = 0
@cancellable
async def my_func(value: int) -> int:
nonlocal call_count
call_count += 1
return value * 2
result = await my_func(5)
assert result == 10
assert call_count == 1
@pytest.mark.asyncio
async def test_executes_function_with_context(self, mock_context: MagicMock) -> None:
"""Function executes normally with RequestContext when not cancelled."""
@cancellable
async def my_func(context: RequestContext) -> str:
await asyncio.sleep(0.01)
return "completed"
result = await my_func(mock_context)
assert result == "completed"
@pytest.mark.asyncio
async def test_cancellation_raises_cancelled_error(
self, mock_context: MagicMock
) -> None:
"""Function raises CancelledError when cancel flag is set."""
from aiocache import caches
cache = caches.get("default")
@cancellable
async def slow_func(context: RequestContext) -> str:
await asyncio.sleep(1.0)
return "should not reach"
await cache.set(f"cancel:{mock_context.task_id}", True)
with pytest.raises(asyncio.CancelledError):
await slow_func(mock_context)
@pytest.mark.asyncio
async def test_cleanup_removes_cancel_flag(self, mock_context: MagicMock) -> None:
"""Cancel flag is cleaned up after execution."""
from aiocache import caches
cache = caches.get("default")
@cancellable
async def quick_func(context: RequestContext) -> str:
return "done"
await quick_func(mock_context)
flag = await cache.get(f"cancel:{mock_context.task_id}")
assert flag is None
@pytest.mark.asyncio
async def test_extracts_context_from_kwargs(self, mock_context: MagicMock) -> None:
"""Context can be passed as keyword argument."""
@cancellable
async def my_func(value: int, context: RequestContext | None = None) -> int:
return value + 1
result = await my_func(10, context=mock_context)
assert result == 11
class TestExecute:
"""Tests for the execute function."""
@pytest.mark.asyncio
async def test_successful_execution(
self,
mock_agent: MagicMock,
mock_context: MagicMock,
mock_event_queue: AsyncMock,
mock_task: MagicMock,
) -> None:
"""Execute completes successfully and enqueues completed task."""
with (
patch("crewai.a2a.utils.task.Task", return_value=mock_task),
patch("crewai.a2a.utils.task.crewai_event_bus") as mock_bus,
):
await execute(mock_agent, mock_context, mock_event_queue)
mock_agent.aexecute_task.assert_called_once()
mock_event_queue.enqueue_event.assert_called_once()
assert mock_bus.emit.call_count == 2
@pytest.mark.asyncio
async def test_emits_started_event(
self,
mock_agent: MagicMock,
mock_context: MagicMock,
mock_event_queue: AsyncMock,
mock_task: MagicMock,
) -> None:
"""Execute emits A2AServerTaskStartedEvent."""
with (
patch("crewai.a2a.utils.task.Task", return_value=mock_task),
patch("crewai.a2a.utils.task.crewai_event_bus") as mock_bus,
):
await execute(mock_agent, mock_context, mock_event_queue)
first_call = mock_bus.emit.call_args_list[0]
event = first_call[0][1]
assert event.type == "a2a_server_task_started"
assert event.a2a_task_id == mock_context.task_id
assert event.a2a_context_id == mock_context.context_id
@pytest.mark.asyncio
async def test_emits_completed_event(
self,
mock_agent: MagicMock,
mock_context: MagicMock,
mock_event_queue: AsyncMock,
mock_task: MagicMock,
) -> None:
"""Execute emits A2AServerTaskCompletedEvent on success."""
with (
patch("crewai.a2a.utils.task.Task", return_value=mock_task),
patch("crewai.a2a.utils.task.crewai_event_bus") as mock_bus,
):
await execute(mock_agent, mock_context, mock_event_queue)
second_call = mock_bus.emit.call_args_list[1]
event = second_call[0][1]
assert event.type == "a2a_server_task_completed"
assert event.a2a_task_id == mock_context.task_id
assert event.result == "Task completed successfully"
@pytest.mark.asyncio
async def test_emits_failed_event_on_exception(
self,
mock_agent: MagicMock,
mock_context: MagicMock,
mock_event_queue: AsyncMock,
mock_task: MagicMock,
) -> None:
"""Execute emits A2AServerTaskFailedEvent on exception."""
mock_agent.aexecute_task = AsyncMock(side_effect=ValueError("Test error"))
with (
patch("crewai.a2a.utils.task.Task", return_value=mock_task),
patch("crewai.a2a.utils.task.crewai_event_bus") as mock_bus,
):
with pytest.raises(Exception):
await execute(mock_agent, mock_context, mock_event_queue)
failed_call = mock_bus.emit.call_args_list[1]
event = failed_call[0][1]
assert event.type == "a2a_server_task_failed"
assert "Test error" in event.error
@pytest.mark.asyncio
async def test_emits_canceled_event_on_cancellation(
self,
mock_agent: MagicMock,
mock_context: MagicMock,
mock_event_queue: AsyncMock,
mock_task: MagicMock,
) -> None:
"""Execute emits A2AServerTaskCanceledEvent on CancelledError."""
mock_agent.aexecute_task = AsyncMock(side_effect=asyncio.CancelledError())
with (
patch("crewai.a2a.utils.task.Task", return_value=mock_task),
patch("crewai.a2a.utils.task.crewai_event_bus") as mock_bus,
):
with pytest.raises(asyncio.CancelledError):
await execute(mock_agent, mock_context, mock_event_queue)
canceled_call = mock_bus.emit.call_args_list[1]
event = canceled_call[0][1]
assert event.type == "a2a_server_task_canceled"
assert event.a2a_task_id == mock_context.task_id
class TestCancel:
"""Tests for the cancel function."""
@pytest.mark.asyncio
async def test_sets_cancel_flag_in_cache(
self,
mock_context: MagicMock,
mock_event_queue: AsyncMock,
) -> None:
"""Cancel sets the cancel flag in cache."""
from aiocache import caches
cache = caches.get("default")
await cancel(mock_context, mock_event_queue)
flag = await cache.get(f"cancel:{mock_context.task_id}")
assert flag is True
@pytest.mark.asyncio
async def test_enqueues_task_status_update_event(
self,
mock_context: MagicMock,
mock_event_queue: AsyncMock,
) -> None:
"""Cancel enqueues TaskStatusUpdateEvent with canceled state."""
await cancel(mock_context, mock_event_queue)
mock_event_queue.enqueue_event.assert_called_once()
event = mock_event_queue.enqueue_event.call_args[0][0]
assert event.task_id == mock_context.task_id
assert event.context_id == mock_context.context_id
assert event.status.state == TaskState.canceled
assert event.final is True
@pytest.mark.asyncio
async def test_returns_none_when_no_current_task(
self,
mock_context: MagicMock,
mock_event_queue: AsyncMock,
) -> None:
"""Cancel returns None when context has no current_task."""
mock_context.current_task = None
result = await cancel(mock_context, mock_event_queue)
assert result is None
@pytest.mark.asyncio
async def test_returns_updated_task_when_current_task_exists(
self,
mock_context: MagicMock,
mock_event_queue: AsyncMock,
) -> None:
"""Cancel returns updated task when context has current_task."""
current_task = MagicMock(spec=A2ATask)
current_task.status = TaskStatus(state=TaskState.working)
mock_context.current_task = current_task
result = await cancel(mock_context, mock_event_queue)
assert result is current_task
assert result.status.state == TaskState.canceled
@pytest.mark.asyncio
async def test_cleanup_after_cancel(
self,
mock_context: MagicMock,
mock_event_queue: AsyncMock,
) -> None:
"""Cancel flag persists for cancellable decorator to detect."""
from aiocache import caches
cache = caches.get("default")
await cancel(mock_context, mock_event_queue)
flag = await cache.get(f"cancel:{mock_context.task_id}")
assert flag is True
await cache.delete(f"cancel:{mock_context.task_id}")
class TestExecuteAndCancelIntegration:
"""Integration tests for execute and cancel working together."""
@pytest.mark.asyncio
async def test_cancel_stops_running_execute(
self,
mock_agent: MagicMock,
mock_context: MagicMock,
mock_event_queue: AsyncMock,
mock_task: MagicMock,
) -> None:
"""Calling cancel stops a running execute."""
async def slow_task(**kwargs: Any) -> str:
await asyncio.sleep(2.0)
return "should not complete"
mock_agent.aexecute_task = slow_task
with (
patch("crewai.a2a.utils.task.Task", return_value=mock_task),
patch("crewai.a2a.utils.task.crewai_event_bus"),
):
execute_task = asyncio.create_task(
execute(mock_agent, mock_context, mock_event_queue)
)
await asyncio.sleep(0.1)
await cancel(mock_context, mock_event_queue)
with pytest.raises(asyncio.CancelledError):
await execute_task