mirror of
https://github.com/crewAIInc/crewAI.git
synced 2026-05-02 07:42:40 +00:00
feat: Enhance A2A integration with modular architecture
Address review comments from João: - Add TaskInfo dataclass for enhanced task management with status tracking - Add ServerConfig dataclass for improved server configuration management - Add custom exception classes (A2AServerError, TransportError, ExecutionError) - Refactor code to use modular components for better maintainability - Update output conversion to handle JSON data types properly - Improve error handling with granular exception types - All tests pass (30 passed, 6 skipped) Co-Authored-By: João <joao@crewai.com>
This commit is contained in:
@@ -22,11 +22,19 @@ Example:
|
|||||||
try:
|
try:
|
||||||
from .crew_agent_executor import CrewAgentExecutor
|
from .crew_agent_executor import CrewAgentExecutor
|
||||||
from .server import start_a2a_server, create_a2a_app
|
from .server import start_a2a_server, create_a2a_app
|
||||||
|
from .server_config import ServerConfig
|
||||||
|
from .task_info import TaskInfo
|
||||||
|
from .exceptions import A2AServerError, TransportError, ExecutionError
|
||||||
|
|
||||||
__all__ = [
|
__all__ = [
|
||||||
"CrewAgentExecutor",
|
"CrewAgentExecutor",
|
||||||
"start_a2a_server",
|
"start_a2a_server",
|
||||||
"create_a2a_app"
|
"create_a2a_app",
|
||||||
|
"ServerConfig",
|
||||||
|
"TaskInfo",
|
||||||
|
"A2AServerError",
|
||||||
|
"TransportError",
|
||||||
|
"ExecutionError"
|
||||||
]
|
]
|
||||||
except ImportError:
|
except ImportError:
|
||||||
import warnings
|
import warnings
|
||||||
@@ -45,5 +53,10 @@ except ImportError:
|
|||||||
CrewAgentExecutor = _missing_dependency # type: ignore
|
CrewAgentExecutor = _missing_dependency # type: ignore
|
||||||
start_a2a_server = _missing_dependency # type: ignore
|
start_a2a_server = _missing_dependency # type: ignore
|
||||||
create_a2a_app = _missing_dependency # type: ignore
|
create_a2a_app = _missing_dependency # type: ignore
|
||||||
|
ServerConfig = _missing_dependency # type: ignore
|
||||||
|
TaskInfo = _missing_dependency # type: ignore
|
||||||
|
A2AServerError = _missing_dependency # type: ignore
|
||||||
|
TransportError = _missing_dependency # type: ignore
|
||||||
|
ExecutionError = _missing_dependency # type: ignore
|
||||||
|
|
||||||
__all__ = []
|
__all__ = []
|
||||||
|
|||||||
@@ -7,13 +7,13 @@ to participate in the Agent-to-Agent protocol for remote interoperability.
|
|||||||
import asyncio
|
import asyncio
|
||||||
import json
|
import json
|
||||||
import logging
|
import logging
|
||||||
from dataclasses import dataclass
|
|
||||||
from datetime import datetime
|
|
||||||
from typing import Any, Dict, Optional
|
from typing import Any, Dict, Optional
|
||||||
|
|
||||||
from crewai import Crew
|
from crewai import Crew
|
||||||
from crewai.crew import CrewOutput
|
from crewai.crew import CrewOutput
|
||||||
|
|
||||||
|
from .task_info import TaskInfo
|
||||||
|
|
||||||
try:
|
try:
|
||||||
from a2a.server.agent_execution.agent_executor import AgentExecutor
|
from a2a.server.agent_execution.agent_executor import AgentExecutor
|
||||||
from a2a.server.agent_execution.context import RequestContext
|
from a2a.server.agent_execution.context import RequestContext
|
||||||
@@ -36,29 +36,6 @@ except ImportError:
|
|||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
class A2AServerError(Exception):
|
|
||||||
"""Base exception for A2A server errors."""
|
|
||||||
pass
|
|
||||||
|
|
||||||
|
|
||||||
class TransportError(A2AServerError):
|
|
||||||
"""Error related to transport configuration."""
|
|
||||||
pass
|
|
||||||
|
|
||||||
|
|
||||||
class ExecutionError(A2AServerError):
|
|
||||||
"""Error during crew execution."""
|
|
||||||
pass
|
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
|
||||||
class TaskInfo:
|
|
||||||
"""Information about a running task."""
|
|
||||||
task: asyncio.Task
|
|
||||||
started_at: datetime
|
|
||||||
status: str = "running"
|
|
||||||
|
|
||||||
|
|
||||||
class CrewAgentExecutor(AgentExecutor):
|
class CrewAgentExecutor(AgentExecutor):
|
||||||
"""A2A Agent Executor that wraps CrewAI crews for remote interoperability.
|
"""A2A Agent Executor that wraps CrewAI crews for remote interoperability.
|
||||||
|
|
||||||
@@ -134,6 +111,7 @@ class CrewAgentExecutor(AgentExecutor):
|
|||||||
execution_task = asyncio.create_task(
|
execution_task = asyncio.create_task(
|
||||||
self._execute_crew_async(inputs)
|
self._execute_crew_async(inputs)
|
||||||
)
|
)
|
||||||
|
from datetime import datetime
|
||||||
self._running_tasks[task_id] = TaskInfo(
|
self._running_tasks[task_id] = TaskInfo(
|
||||||
task=execution_task,
|
task=execution_task,
|
||||||
started_at=datetime.now(),
|
started_at=datetime.now(),
|
||||||
@@ -206,7 +184,7 @@ class CrewAgentExecutor(AgentExecutor):
|
|||||||
if task_id in self._running_tasks:
|
if task_id in self._running_tasks:
|
||||||
task_info = self._running_tasks[task_id]
|
task_info = self._running_tasks[task_id]
|
||||||
task_info.task.cancel()
|
task_info.task.cancel()
|
||||||
task_info.status = "cancelled"
|
task_info.update_status("cancelled")
|
||||||
|
|
||||||
try:
|
try:
|
||||||
await task_info.task
|
await task_info.task
|
||||||
|
|||||||
16
src/crewai/a2a/exceptions.py
Normal file
16
src/crewai/a2a/exceptions.py
Normal file
@@ -0,0 +1,16 @@
|
|||||||
|
"""Custom exceptions for A2A integration."""
|
||||||
|
|
||||||
|
|
||||||
|
class A2AServerError(Exception):
|
||||||
|
"""Base exception for A2A server errors."""
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
class TransportError(A2AServerError):
|
||||||
|
"""Error related to transport configuration."""
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
class ExecutionError(A2AServerError):
|
||||||
|
"""Error during crew execution."""
|
||||||
|
pass
|
||||||
@@ -5,9 +5,11 @@ crews, supporting multiple transport protocols and configurations.
|
|||||||
"""
|
"""
|
||||||
|
|
||||||
import logging
|
import logging
|
||||||
from dataclasses import dataclass
|
|
||||||
from typing import Optional
|
from typing import Optional
|
||||||
|
|
||||||
|
from .exceptions import TransportError
|
||||||
|
from .server_config import ServerConfig
|
||||||
|
|
||||||
try:
|
try:
|
||||||
from a2a.server.agent_execution.agent_executor import AgentExecutor
|
from a2a.server.agent_execution.agent_executor import AgentExecutor
|
||||||
from a2a.server.apps import A2AStarletteApplication
|
from a2a.server.apps import A2AStarletteApplication
|
||||||
@@ -23,16 +25,6 @@ except ImportError:
|
|||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
|
||||||
class ServerConfig:
|
|
||||||
"""Configuration for A2A server."""
|
|
||||||
host: str = "localhost"
|
|
||||||
port: int = 10001
|
|
||||||
transport: str = "starlette"
|
|
||||||
agent_name: Optional[str] = None
|
|
||||||
agent_description: Optional[str] = None
|
|
||||||
|
|
||||||
|
|
||||||
def start_a2a_server(
|
def start_a2a_server(
|
||||||
agent_executor: AgentExecutor,
|
agent_executor: AgentExecutor,
|
||||||
host: str = "localhost",
|
host: str = "localhost",
|
||||||
@@ -148,7 +140,7 @@ def create_a2a_app(
|
|||||||
request_handler = DefaultRequestHandler(agent_executor, task_store)
|
request_handler = DefaultRequestHandler(agent_executor, task_store)
|
||||||
|
|
||||||
if transport.lower() == "fastapi":
|
if transport.lower() == "fastapi":
|
||||||
raise ValueError("FastAPI transport is not available in the current A2A SDK version")
|
raise TransportError("FastAPI transport is not available in the current A2A SDK version")
|
||||||
else:
|
else:
|
||||||
app_instance = A2AStarletteApplication(
|
app_instance = A2AStarletteApplication(
|
||||||
agent_card=agent_card,
|
agent_card=agent_card,
|
||||||
|
|||||||
25
src/crewai/a2a/server_config.py
Normal file
25
src/crewai/a2a/server_config.py
Normal file
@@ -0,0 +1,25 @@
|
|||||||
|
"""Server configuration for A2A integration."""
|
||||||
|
|
||||||
|
from dataclasses import dataclass
|
||||||
|
from typing import Optional
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class ServerConfig:
|
||||||
|
"""Configuration for A2A server.
|
||||||
|
|
||||||
|
This class encapsulates server settings to improve readability
|
||||||
|
and flexibility for server setups.
|
||||||
|
|
||||||
|
Attributes:
|
||||||
|
host: Host address to bind the server to
|
||||||
|
port: Port number to bind the server to
|
||||||
|
transport: Transport protocol to use ("starlette" or "fastapi")
|
||||||
|
agent_name: Optional name for the agent
|
||||||
|
agent_description: Optional description for the agent
|
||||||
|
"""
|
||||||
|
host: str = "localhost"
|
||||||
|
port: int = 10001
|
||||||
|
transport: str = "starlette"
|
||||||
|
agent_name: Optional[str] = None
|
||||||
|
agent_description: Optional[str] = None
|
||||||
47
src/crewai/a2a/task_info.py
Normal file
47
src/crewai/a2a/task_info.py
Normal file
@@ -0,0 +1,47 @@
|
|||||||
|
"""Task information tracking for A2A integration."""
|
||||||
|
|
||||||
|
from dataclasses import dataclass
|
||||||
|
from datetime import datetime
|
||||||
|
from typing import Optional
|
||||||
|
import asyncio
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class TaskInfo:
|
||||||
|
"""Information about a running task in the A2A executor.
|
||||||
|
|
||||||
|
This class tracks the lifecycle and status of tasks being executed
|
||||||
|
by the CrewAgentExecutor, providing better task management capabilities.
|
||||||
|
|
||||||
|
Attributes:
|
||||||
|
task: The asyncio task being executed
|
||||||
|
started_at: When the task was started
|
||||||
|
status: Current status of the task ("running", "completed", "cancelled", "failed")
|
||||||
|
"""
|
||||||
|
task: asyncio.Task
|
||||||
|
started_at: datetime
|
||||||
|
status: str = "running"
|
||||||
|
|
||||||
|
def update_status(self, new_status: str) -> None:
|
||||||
|
"""Update the task status.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
new_status: The new status to set
|
||||||
|
"""
|
||||||
|
self.status = new_status
|
||||||
|
|
||||||
|
@property
|
||||||
|
def is_running(self) -> bool:
|
||||||
|
"""Check if the task is currently running."""
|
||||||
|
return self.status == "running" and not self.task.done()
|
||||||
|
|
||||||
|
@property
|
||||||
|
def duration(self) -> Optional[float]:
|
||||||
|
"""Get the duration of the task in seconds.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Duration in seconds if task is completed, None if still running
|
||||||
|
"""
|
||||||
|
if self.task.done():
|
||||||
|
return (datetime.now() - self.started_at).total_seconds()
|
||||||
|
return None
|
||||||
@@ -95,7 +95,8 @@ class TestA2AServer:
|
|||||||
|
|
||||||
def test_create_a2a_app_fastapi(self, mock_agent_executor):
|
def test_create_a2a_app_fastapi(self, mock_agent_executor):
|
||||||
"""Test creating A2A app with FastAPI transport raises error."""
|
"""Test creating A2A app with FastAPI transport raises error."""
|
||||||
with pytest.raises(ValueError, match="FastAPI transport is not available"):
|
from crewai.a2a.exceptions import TransportError
|
||||||
|
with pytest.raises(TransportError, match="FastAPI transport is not available"):
|
||||||
create_a2a_app(
|
create_a2a_app(
|
||||||
mock_agent_executor,
|
mock_agent_executor,
|
||||||
transport="fastapi",
|
transport="fastapi",
|
||||||
|
|||||||
Reference in New Issue
Block a user