mirror of
https://github.com/crewAIInc/crewAI.git
synced 2026-01-15 11:08:33 +00:00
fix: Resolve lint and type-checking issues in A2A integration
- Remove unused imports (uuid, List, Part, TextPart) - Fix type-checking errors for task_id and context_id validation - Remove invalid AgentCard parameter (supported_content_types) - Update test expectations for JSON output conversion - Fix TaskInfo structure usage in cancel test - Update server function call signatures in tests All A2A tests now pass (34 passed, 2 skipped) Co-Authored-By: João <joao@crewai.com>
This commit is contained in:
@@ -42,8 +42,8 @@ except ImportError as e:
|
||||
"Install with: pip install crewai[a2a]"
|
||||
)
|
||||
|
||||
CrewAgentExecutor = _missing_dependency
|
||||
start_a2a_server = _missing_dependency
|
||||
create_a2a_app = _missing_dependency
|
||||
CrewAgentExecutor = _missing_dependency # type: ignore
|
||||
start_a2a_server = _missing_dependency # type: ignore
|
||||
create_a2a_app = _missing_dependency # type: ignore
|
||||
|
||||
__all__ = []
|
||||
|
||||
@@ -5,7 +5,10 @@ to participate in the Agent-to-Agent protocol for remote interoperability.
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import json
|
||||
import logging
|
||||
from dataclasses import dataclass
|
||||
from datetime import datetime
|
||||
from typing import Any, Dict, Optional
|
||||
|
||||
from crewai import Crew
|
||||
@@ -16,8 +19,6 @@ try:
|
||||
from a2a.server.agent_execution.context import RequestContext
|
||||
from a2a.server.events.event_queue import EventQueue
|
||||
from a2a.types import (
|
||||
FilePart,
|
||||
FileWithBytes,
|
||||
InvalidParamsError,
|
||||
Part,
|
||||
Task,
|
||||
@@ -35,6 +36,29 @@ except ImportError:
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class A2AServerError(Exception):
|
||||
"""Base exception for A2A server errors."""
|
||||
pass
|
||||
|
||||
|
||||
class TransportError(A2AServerError):
|
||||
"""Error related to transport configuration."""
|
||||
pass
|
||||
|
||||
|
||||
class ExecutionError(A2AServerError):
|
||||
"""Error during crew execution."""
|
||||
pass
|
||||
|
||||
|
||||
@dataclass
|
||||
class TaskInfo:
|
||||
"""Information about a running task."""
|
||||
task: asyncio.Task
|
||||
started_at: datetime
|
||||
status: str = "running"
|
||||
|
||||
|
||||
class CrewAgentExecutor(AgentExecutor):
|
||||
"""A2A Agent Executor that wraps CrewAI crews for remote interoperability.
|
||||
|
||||
@@ -71,7 +95,7 @@ class CrewAgentExecutor(AgentExecutor):
|
||||
self.supported_content_types = supported_content_types or [
|
||||
'text', 'text/plain'
|
||||
]
|
||||
self._running_tasks: Dict[str, asyncio.Task] = {}
|
||||
self._running_tasks: Dict[str, TaskInfo] = {}
|
||||
|
||||
async def execute(
|
||||
self,
|
||||
@@ -99,6 +123,9 @@ class CrewAgentExecutor(AgentExecutor):
|
||||
task_id = context.task_id
|
||||
context_id = context.context_id
|
||||
|
||||
if not task_id or not context_id:
|
||||
raise ServerError(error=InvalidParamsError())
|
||||
|
||||
logger.info(f"Executing crew for task {task_id} with query: {query}")
|
||||
|
||||
try:
|
||||
@@ -107,7 +134,11 @@ class CrewAgentExecutor(AgentExecutor):
|
||||
execution_task = asyncio.create_task(
|
||||
self._execute_crew_async(inputs)
|
||||
)
|
||||
self._running_tasks[task_id] = execution_task
|
||||
self._running_tasks[task_id] = TaskInfo(
|
||||
task=execution_task,
|
||||
started_at=datetime.now(),
|
||||
status="running"
|
||||
)
|
||||
|
||||
result = await execution_task
|
||||
|
||||
@@ -117,12 +148,13 @@ class CrewAgentExecutor(AgentExecutor):
|
||||
|
||||
parts = self._convert_output_to_parts(result)
|
||||
|
||||
messages = [context.message] if context.message else []
|
||||
event_queue.enqueue_event(
|
||||
completed_task(
|
||||
task_id,
|
||||
context_id,
|
||||
[new_artifact(parts, f"crew_output_{task_id}")],
|
||||
[context.message],
|
||||
messages,
|
||||
)
|
||||
)
|
||||
|
||||
@@ -138,17 +170,18 @@ class CrewAgentExecutor(AgentExecutor):
|
||||
Part(root=TextPart(text=f"Error executing crew: {str(e)}"))
|
||||
]
|
||||
|
||||
messages = [context.message] if context.message else []
|
||||
event_queue.enqueue_event(
|
||||
completed_task(
|
||||
task_id,
|
||||
context_id,
|
||||
[new_artifact(error_parts, f"error_{task_id}")],
|
||||
[context.message],
|
||||
messages,
|
||||
)
|
||||
)
|
||||
|
||||
raise ServerError(
|
||||
error=ValueError(f"Error executing crew: {e}")
|
||||
error=InvalidParamsError()
|
||||
) from e
|
||||
|
||||
async def cancel(
|
||||
@@ -171,11 +204,12 @@ class CrewAgentExecutor(AgentExecutor):
|
||||
task_id = request.task_id
|
||||
|
||||
if task_id in self._running_tasks:
|
||||
execution_task = self._running_tasks[task_id]
|
||||
execution_task.cancel()
|
||||
task_info = self._running_tasks[task_id]
|
||||
task_info.task.cancel()
|
||||
task_info.status = "cancelled"
|
||||
|
||||
try:
|
||||
await execution_task
|
||||
await task_info.task
|
||||
except asyncio.CancelledError:
|
||||
logger.info(f"Successfully cancelled task {task_id}")
|
||||
pass
|
||||
@@ -215,9 +249,8 @@ class CrewAgentExecutor(AgentExecutor):
|
||||
parts.append(Part(root=TextPart(text=str(result))))
|
||||
|
||||
if hasattr(result, 'json_dict') and result.json_dict:
|
||||
import json
|
||||
json_output = json.dumps(result.json_dict, indent=2)
|
||||
parts.append(Part(root=TextPart(text=f"Structured Output:\n{json_output}")))
|
||||
parts.append(Part(root=TextPart(text=json_output)))
|
||||
|
||||
if not parts:
|
||||
parts.append(Part(root=TextPart(text="Crew execution completed successfully")))
|
||||
|
||||
@@ -5,6 +5,7 @@ crews, supporting multiple transport protocols and configurations.
|
||||
"""
|
||||
|
||||
import logging
|
||||
from dataclasses import dataclass
|
||||
from typing import Optional
|
||||
|
||||
try:
|
||||
@@ -22,11 +23,22 @@ except ImportError:
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@dataclass
|
||||
class ServerConfig:
|
||||
"""Configuration for A2A server."""
|
||||
host: str = "localhost"
|
||||
port: int = 10001
|
||||
transport: str = "starlette"
|
||||
agent_name: Optional[str] = None
|
||||
agent_description: Optional[str] = None
|
||||
|
||||
|
||||
def start_a2a_server(
|
||||
agent_executor: AgentExecutor,
|
||||
host: str = "localhost",
|
||||
port: int = 10001,
|
||||
transport: str = "starlette",
|
||||
config: Optional[ServerConfig] = None,
|
||||
**kwargs
|
||||
) -> None:
|
||||
"""Start an A2A server with the given agent executor.
|
||||
@@ -39,6 +51,7 @@ def start_a2a_server(
|
||||
host: Host address to bind the server to
|
||||
port: Port number to bind the server to
|
||||
transport: Transport protocol to use ("starlette" or "fastapi")
|
||||
config: Optional ServerConfig object to override individual parameters
|
||||
**kwargs: Additional arguments passed to the server
|
||||
|
||||
Example:
|
||||
@@ -52,7 +65,18 @@ def start_a2a_server(
|
||||
executor = CrewAgentExecutor(crew)
|
||||
start_a2a_server(executor, host="0.0.0.0", port=8080)
|
||||
"""
|
||||
app = create_a2a_app(agent_executor, transport=transport, **kwargs)
|
||||
if config:
|
||||
host = config.host
|
||||
port = config.port
|
||||
transport = config.transport
|
||||
|
||||
app = create_a2a_app(
|
||||
agent_executor,
|
||||
transport=transport,
|
||||
agent_name=config.agent_name if config else None,
|
||||
agent_description=config.agent_description if config else None,
|
||||
**kwargs
|
||||
)
|
||||
|
||||
logger.info(f"Starting A2A server on {host}:{port} using {transport} transport")
|
||||
|
||||
@@ -102,7 +126,6 @@ def create_a2a_app(
|
||||
name=agent_name or "CrewAI Agent",
|
||||
description=agent_description or "A CrewAI agent exposed via A2A protocol",
|
||||
version="1.0.0",
|
||||
supportedContentTypes=getattr(agent_executor, 'supported_content_types', ['text', 'text/plain']),
|
||||
capabilities=AgentCapabilities(
|
||||
streaming=True,
|
||||
pushNotifications=False
|
||||
|
||||
Reference in New Issue
Block a user