mirror of
https://github.com/crewAIInc/crewAI.git
synced 2026-01-11 09:08:31 +00:00
feat: Enhance A2A integration with modular architecture
Address review comments from João: - Add TaskInfo dataclass for enhanced task management with status tracking - Add ServerConfig dataclass for improved server configuration management - Add custom exception classes (A2AServerError, TransportError, ExecutionError) - Refactor code to use modular components for better maintainability - Update output conversion to handle JSON data types properly - Improve error handling with granular exception types - All tests pass (30 passed, 6 skipped) Co-Authored-By: João <joao@crewai.com>
This commit is contained in:
@@ -22,11 +22,19 @@ Example:
|
||||
try:
|
||||
from .crew_agent_executor import CrewAgentExecutor
|
||||
from .server import start_a2a_server, create_a2a_app
|
||||
from .server_config import ServerConfig
|
||||
from .task_info import TaskInfo
|
||||
from .exceptions import A2AServerError, TransportError, ExecutionError
|
||||
|
||||
__all__ = [
|
||||
"CrewAgentExecutor",
|
||||
"start_a2a_server",
|
||||
"create_a2a_app"
|
||||
"create_a2a_app",
|
||||
"ServerConfig",
|
||||
"TaskInfo",
|
||||
"A2AServerError",
|
||||
"TransportError",
|
||||
"ExecutionError"
|
||||
]
|
||||
except ImportError:
|
||||
import warnings
|
||||
@@ -45,5 +53,10 @@ except ImportError:
|
||||
CrewAgentExecutor = _missing_dependency # type: ignore
|
||||
start_a2a_server = _missing_dependency # type: ignore
|
||||
create_a2a_app = _missing_dependency # type: ignore
|
||||
ServerConfig = _missing_dependency # type: ignore
|
||||
TaskInfo = _missing_dependency # type: ignore
|
||||
A2AServerError = _missing_dependency # type: ignore
|
||||
TransportError = _missing_dependency # type: ignore
|
||||
ExecutionError = _missing_dependency # type: ignore
|
||||
|
||||
__all__ = []
|
||||
|
||||
@@ -7,13 +7,13 @@ to participate in the Agent-to-Agent protocol for remote interoperability.
|
||||
import asyncio
|
||||
import json
|
||||
import logging
|
||||
from dataclasses import dataclass
|
||||
from datetime import datetime
|
||||
from typing import Any, Dict, Optional
|
||||
|
||||
from crewai import Crew
|
||||
from crewai.crew import CrewOutput
|
||||
|
||||
from .task_info import TaskInfo
|
||||
|
||||
try:
|
||||
from a2a.server.agent_execution.agent_executor import AgentExecutor
|
||||
from a2a.server.agent_execution.context import RequestContext
|
||||
@@ -36,29 +36,6 @@ except ImportError:
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class A2AServerError(Exception):
|
||||
"""Base exception for A2A server errors."""
|
||||
pass
|
||||
|
||||
|
||||
class TransportError(A2AServerError):
|
||||
"""Error related to transport configuration."""
|
||||
pass
|
||||
|
||||
|
||||
class ExecutionError(A2AServerError):
|
||||
"""Error during crew execution."""
|
||||
pass
|
||||
|
||||
|
||||
@dataclass
|
||||
class TaskInfo:
|
||||
"""Information about a running task."""
|
||||
task: asyncio.Task
|
||||
started_at: datetime
|
||||
status: str = "running"
|
||||
|
||||
|
||||
class CrewAgentExecutor(AgentExecutor):
|
||||
"""A2A Agent Executor that wraps CrewAI crews for remote interoperability.
|
||||
|
||||
@@ -134,6 +111,7 @@ class CrewAgentExecutor(AgentExecutor):
|
||||
execution_task = asyncio.create_task(
|
||||
self._execute_crew_async(inputs)
|
||||
)
|
||||
from datetime import datetime
|
||||
self._running_tasks[task_id] = TaskInfo(
|
||||
task=execution_task,
|
||||
started_at=datetime.now(),
|
||||
@@ -206,7 +184,7 @@ class CrewAgentExecutor(AgentExecutor):
|
||||
if task_id in self._running_tasks:
|
||||
task_info = self._running_tasks[task_id]
|
||||
task_info.task.cancel()
|
||||
task_info.status = "cancelled"
|
||||
task_info.update_status("cancelled")
|
||||
|
||||
try:
|
||||
await task_info.task
|
||||
|
||||
16
src/crewai/a2a/exceptions.py
Normal file
16
src/crewai/a2a/exceptions.py
Normal file
@@ -0,0 +1,16 @@
|
||||
"""Custom exceptions for A2A integration."""
|
||||
|
||||
|
||||
class A2AServerError(Exception):
|
||||
"""Base exception for A2A server errors."""
|
||||
pass
|
||||
|
||||
|
||||
class TransportError(A2AServerError):
|
||||
"""Error related to transport configuration."""
|
||||
pass
|
||||
|
||||
|
||||
class ExecutionError(A2AServerError):
|
||||
"""Error during crew execution."""
|
||||
pass
|
||||
@@ -5,9 +5,11 @@ crews, supporting multiple transport protocols and configurations.
|
||||
"""
|
||||
|
||||
import logging
|
||||
from dataclasses import dataclass
|
||||
from typing import Optional
|
||||
|
||||
from .exceptions import TransportError
|
||||
from .server_config import ServerConfig
|
||||
|
||||
try:
|
||||
from a2a.server.agent_execution.agent_executor import AgentExecutor
|
||||
from a2a.server.apps import A2AStarletteApplication
|
||||
@@ -23,16 +25,6 @@ except ImportError:
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@dataclass
|
||||
class ServerConfig:
|
||||
"""Configuration for A2A server."""
|
||||
host: str = "localhost"
|
||||
port: int = 10001
|
||||
transport: str = "starlette"
|
||||
agent_name: Optional[str] = None
|
||||
agent_description: Optional[str] = None
|
||||
|
||||
|
||||
def start_a2a_server(
|
||||
agent_executor: AgentExecutor,
|
||||
host: str = "localhost",
|
||||
@@ -148,7 +140,7 @@ def create_a2a_app(
|
||||
request_handler = DefaultRequestHandler(agent_executor, task_store)
|
||||
|
||||
if transport.lower() == "fastapi":
|
||||
raise ValueError("FastAPI transport is not available in the current A2A SDK version")
|
||||
raise TransportError("FastAPI transport is not available in the current A2A SDK version")
|
||||
else:
|
||||
app_instance = A2AStarletteApplication(
|
||||
agent_card=agent_card,
|
||||
|
||||
25
src/crewai/a2a/server_config.py
Normal file
25
src/crewai/a2a/server_config.py
Normal file
@@ -0,0 +1,25 @@
|
||||
"""Server configuration for A2A integration."""
|
||||
|
||||
from dataclasses import dataclass
|
||||
from typing import Optional
|
||||
|
||||
|
||||
@dataclass
|
||||
class ServerConfig:
|
||||
"""Configuration for A2A server.
|
||||
|
||||
This class encapsulates server settings to improve readability
|
||||
and flexibility for server setups.
|
||||
|
||||
Attributes:
|
||||
host: Host address to bind the server to
|
||||
port: Port number to bind the server to
|
||||
transport: Transport protocol to use ("starlette" or "fastapi")
|
||||
agent_name: Optional name for the agent
|
||||
agent_description: Optional description for the agent
|
||||
"""
|
||||
host: str = "localhost"
|
||||
port: int = 10001
|
||||
transport: str = "starlette"
|
||||
agent_name: Optional[str] = None
|
||||
agent_description: Optional[str] = None
|
||||
47
src/crewai/a2a/task_info.py
Normal file
47
src/crewai/a2a/task_info.py
Normal file
@@ -0,0 +1,47 @@
|
||||
"""Task information tracking for A2A integration."""
|
||||
|
||||
from dataclasses import dataclass
|
||||
from datetime import datetime
|
||||
from typing import Optional
|
||||
import asyncio
|
||||
|
||||
|
||||
@dataclass
|
||||
class TaskInfo:
|
||||
"""Information about a running task in the A2A executor.
|
||||
|
||||
This class tracks the lifecycle and status of tasks being executed
|
||||
by the CrewAgentExecutor, providing better task management capabilities.
|
||||
|
||||
Attributes:
|
||||
task: The asyncio task being executed
|
||||
started_at: When the task was started
|
||||
status: Current status of the task ("running", "completed", "cancelled", "failed")
|
||||
"""
|
||||
task: asyncio.Task
|
||||
started_at: datetime
|
||||
status: str = "running"
|
||||
|
||||
def update_status(self, new_status: str) -> None:
|
||||
"""Update the task status.
|
||||
|
||||
Args:
|
||||
new_status: The new status to set
|
||||
"""
|
||||
self.status = new_status
|
||||
|
||||
@property
|
||||
def is_running(self) -> bool:
|
||||
"""Check if the task is currently running."""
|
||||
return self.status == "running" and not self.task.done()
|
||||
|
||||
@property
|
||||
def duration(self) -> Optional[float]:
|
||||
"""Get the duration of the task in seconds.
|
||||
|
||||
Returns:
|
||||
Duration in seconds if task is completed, None if still running
|
||||
"""
|
||||
if self.task.done():
|
||||
return (datetime.now() - self.started_at).total_seconds()
|
||||
return None
|
||||
@@ -95,7 +95,8 @@ class TestA2AServer:
|
||||
|
||||
def test_create_a2a_app_fastapi(self, mock_agent_executor):
|
||||
"""Test creating A2A app with FastAPI transport raises error."""
|
||||
with pytest.raises(ValueError, match="FastAPI transport is not available"):
|
||||
from crewai.a2a.exceptions import TransportError
|
||||
with pytest.raises(TransportError, match="FastAPI transport is not available"):
|
||||
create_a2a_app(
|
||||
mock_agent_executor,
|
||||
transport="fastapi",
|
||||
|
||||
Reference in New Issue
Block a user