mirror of
https://github.com/crewAIInc/crewAI.git
synced 2026-01-28 01:28:14 +00:00
Compare commits
2 Commits
devin/1768
...
devin/1768
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
ca1f8fd7e0 | ||
|
|
ceef062426 |
@@ -89,6 +89,9 @@ from crewai_tools.tools.jina_scrape_website_tool.jina_scrape_website_tool import
|
|||||||
from crewai_tools.tools.json_search_tool.json_search_tool import JSONSearchTool
|
from crewai_tools.tools.json_search_tool.json_search_tool import JSONSearchTool
|
||||||
from crewai_tools.tools.linkup.linkup_search_tool import LinkupSearchTool
|
from crewai_tools.tools.linkup.linkup_search_tool import LinkupSearchTool
|
||||||
from crewai_tools.tools.llamaindex_tool.llamaindex_tool import LlamaIndexTool
|
from crewai_tools.tools.llamaindex_tool.llamaindex_tool import LlamaIndexTool
|
||||||
|
from crewai_tools.tools.mcp_discovery_tool.mcp_discovery_tool import (
|
||||||
|
MCPDiscoveryTool,
|
||||||
|
)
|
||||||
from crewai_tools.tools.mdx_search_tool.mdx_search_tool import MDXSearchTool
|
from crewai_tools.tools.mdx_search_tool.mdx_search_tool import MDXSearchTool
|
||||||
from crewai_tools.tools.merge_agent_handler_tool.merge_agent_handler_tool import (
|
from crewai_tools.tools.merge_agent_handler_tool.merge_agent_handler_tool import (
|
||||||
MergeAgentHandlerTool,
|
MergeAgentHandlerTool,
|
||||||
@@ -236,6 +239,7 @@ __all__ = [
|
|||||||
"JinaScrapeWebsiteTool",
|
"JinaScrapeWebsiteTool",
|
||||||
"LinkupSearchTool",
|
"LinkupSearchTool",
|
||||||
"LlamaIndexTool",
|
"LlamaIndexTool",
|
||||||
|
"MCPDiscoveryTool",
|
||||||
"MCPServerAdapter",
|
"MCPServerAdapter",
|
||||||
"MDXSearchTool",
|
"MDXSearchTool",
|
||||||
"MergeAgentHandlerTool",
|
"MergeAgentHandlerTool",
|
||||||
|
|||||||
@@ -0,0 +1,16 @@
|
|||||||
|
# Public exports for the MCP Discovery tool package.
from crewai_tools.tools.mcp_discovery_tool.mcp_discovery_tool import (
    MCPDiscoveryResult,
    MCPDiscoveryTool,
    MCPDiscoveryToolSchema,
    MCPServerMetrics,
    MCPServerRecommendation,
)


__all__ = [
    "MCPDiscoveryResult",
    "MCPDiscoveryTool",
    "MCPDiscoveryToolSchema",
    "MCPServerMetrics",
    "MCPServerRecommendation",
]
@@ -0,0 +1,414 @@
|
|||||||
|
"""MCP Discovery Tool for CrewAI agents.
|
||||||
|
|
||||||
|
This tool enables agents to dynamically discover MCP servers based on
|
||||||
|
natural language queries using the MCP Discovery API.
|
||||||
|
"""
|
||||||
|
|
||||||
|
import json
|
||||||
|
import logging
|
||||||
|
import os
|
||||||
|
from typing import Any, TypedDict
|
||||||
|
|
||||||
|
from crewai.tools import BaseTool, EnvVar
|
||||||
|
from pydantic import BaseModel, Field
|
||||||
|
import requests
|
||||||
|
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
class MCPServerMetrics(TypedDict, total=False):
|
||||||
|
"""Performance metrics for an MCP server."""
|
||||||
|
|
||||||
|
avg_latency_ms: float | None
|
||||||
|
uptime_pct: float | None
|
||||||
|
last_checked: str | None
|
||||||
|
|
||||||
|
|
||||||
|
class MCPServerRecommendation(TypedDict, total=False):
    """A recommended MCP server from the discovery API.

    All keys are optional (``total=False``); processing code fills in
    empty-string / empty-list defaults for missing fields.
    """

    server: str  # server name, e.g. "sqlite-server"
    npm_package: str  # npm package identifier
    install_command: str  # shell command to install/run the server
    confidence: float  # match confidence in [0, 1]
    description: str  # human-readable summary
    capabilities: list[str]  # feature tags reported by the API
    metrics: MCPServerMetrics  # performance metrics (may be sparse)
    docs_url: str  # documentation link
    github_url: str  # source repository link
class MCPDiscoveryResult(TypedDict):
    """Result from the MCP Discovery API.

    Unlike the recommendation/metrics types, every key here is required.
    """

    recommendations: list[MCPServerRecommendation]  # processed server matches
    total_found: int  # total matches reported by the API
    query_time_ms: int  # server-side query duration in milliseconds
class MCPDiscoveryConstraints(BaseModel):
    """Optional filters applied to MCP server discovery.

    Every field defaults to ``None``; only fields that are set (and, for
    the list fields, non-empty) are forwarded to the API.
    """

    max_latency_ms: int | None = Field(
        default=None,
        description="Maximum acceptable latency in milliseconds",
    )
    required_features: list[str] | None = Field(
        default=None,
        description="List of required features/capabilities",
    )
    exclude_servers: list[str] | None = Field(
        default=None,
        description="List of server names to exclude from results",
    )
class MCPDiscoveryToolSchema(BaseModel):
    """Input schema for MCPDiscoveryTool.

    ``need`` is the only required field; ``limit`` is validated by
    pydantic to the inclusive range 1-10.
    """

    need: str = Field(
        ...,
        description=(
            "Natural language description of what you need. "
            "For example: 'database with authentication', 'email automation', "
            "'file storage', 'web scraping'"
        ),
    )
    constraints: MCPDiscoveryConstraints | None = Field(
        default=None,
        description="Optional constraints to filter results",
    )
    limit: int = Field(
        default=5,
        description="Maximum number of recommendations to return (1-10)",
        ge=1,
        le=10,
    )
class MCPDiscoveryTool(BaseTool):
    """Tool for discovering MCP servers dynamically.

    This tool uses the MCP Discovery API to find MCP servers that match
    a natural language description of what the agent needs. It enables
    agents to dynamically discover and select the best MCP servers for
    their tasks without requiring pre-configuration.

    Example:
        ```python
        from crewai import Agent
        from crewai_tools import MCPDiscoveryTool

        discovery_tool = MCPDiscoveryTool()

        agent = Agent(
            role='Researcher',
            tools=[discovery_tool],
            goal='Research and analyze data'
        )

        # The agent can now discover MCP servers dynamically:
        # discover_mcp_server(need="database with authentication")
        # Returns: Supabase MCP server with installation instructions
        ```

    Attributes:
        name: The name of the tool.
        description: A description of what the tool does.
        args_schema: The Pydantic model for input validation.
        base_url: The base URL for the MCP Discovery API.
        timeout: Request timeout in seconds.
    """

    name: str = "Discover MCP Server"
    description: str = (
        "Discover MCP (Model Context Protocol) servers that match your needs. "
        "Use this tool to find the best MCP server for any task using natural "
        "language. Returns server recommendations with installation instructions, "
        "capabilities, and performance metrics."
    )
    args_schema: type[BaseModel] = MCPDiscoveryToolSchema
    base_url: str = "https://mcp-discovery-production.up.railway.app"
    timeout: int = 30
    env_vars: list[EnvVar] = Field(
        default_factory=lambda: [
            EnvVar(
                name="MCP_DISCOVERY_API_KEY",
                description="API key for MCP Discovery (optional for free tier)",
                required=False,
            ),
        ]
    )

    def _build_request_payload(
        self,
        need: str,
        constraints: MCPDiscoveryConstraints | None,
        limit: int,
    ) -> dict[str, Any]:
        """Build the request payload for the discovery API.

        Args:
            need: Natural language description of what is needed.
            constraints: Optional constraints to filter results.
            limit: Maximum number of recommendations.

        Returns:
            Dictionary containing the request payload.
        """
        payload: dict[str, Any] = {
            "need": need,
            "limit": limit,
        }

        if constraints:
            # Only forward fields that are actually set; empty lists are
            # treated the same as unset so the API payload stays minimal.
            constraints_dict: dict[str, Any] = {}
            if constraints.max_latency_ms is not None:
                constraints_dict["max_latency_ms"] = constraints.max_latency_ms
            if constraints.required_features:
                constraints_dict["required_features"] = constraints.required_features
            if constraints.exclude_servers:
                constraints_dict["exclude_servers"] = constraints.exclude_servers
            if constraints_dict:
                payload["constraints"] = constraints_dict

        return payload

    def _make_api_request(self, payload: dict[str, Any]) -> dict[str, Any]:
        """Make a request to the MCP Discovery API.

        Args:
            payload: The request payload.

        Returns:
            The API response as a dictionary.

        Raises:
            ValueError: If the API returns an empty response.
            requests.exceptions.RequestException: If the request fails.
        """
        url = f"{self.base_url}/api/v1/discover"
        headers = {
            "Content-Type": "application/json",
        }

        # Authentication is optional (free tier); only attach the header
        # when the key is present in the environment.
        api_key = os.environ.get("MCP_DISCOVERY_API_KEY")
        if api_key:
            headers["Authorization"] = f"Bearer {api_key}"

        response = None
        try:
            response = requests.post(
                url,
                headers=headers,
                json=payload,
                timeout=self.timeout,
            )
            response.raise_for_status()
            results = response.json()
            if not results:
                logger.error("Empty response from MCP Discovery API")
                raise ValueError("Empty response from MCP Discovery API")
            return results
        # FIX: handle JSON decode errors *before* RequestException. With
        # requests >= 2.27, response.json() raises
        # requests.exceptions.JSONDecodeError, which subclasses BOTH
        # json.JSONDecodeError and RequestException; with the
        # RequestException clause first, the dedicated decode-error
        # logging below was unreachable for real API responses.
        except json.JSONDecodeError as e:
            if response is not None and hasattr(response, "content"):
                logger.error(f"Error decoding JSON response: {e}")
                logger.error(
                    f"Response content: "
                    f"{response.content.decode('utf-8', errors='replace')}"
                )
            else:
                logger.error(
                    f"Error decoding JSON response: {e} (No response content available)"
                )
            raise
        except requests.exceptions.RequestException as e:
            error_msg = f"Error making request to MCP Discovery API: {e}"
            # response stays None when the failure happened before a
            # response was received (e.g. connection error).
            if response is not None and hasattr(response, "content"):
                error_msg += (
                    f"\nResponse content: "
                    f"{response.content.decode('utf-8', errors='replace')}"
                )
            logger.error(error_msg)
            raise

    def _process_single_recommendation(
        self, rec: dict[str, Any]
    ) -> MCPServerRecommendation | None:
        """Process a single recommendation from the API.

        Args:
            rec: Raw recommendation dictionary from the API.

        Returns:
            Processed MCPServerRecommendation or None if malformed.
        """
        try:
            # Guard against non-dict entries (e.g. None) in the API list.
            metrics_data = rec.get("metrics", {}) if isinstance(rec, dict) else {}
            metrics: MCPServerMetrics = {
                "avg_latency_ms": metrics_data.get("avg_latency_ms"),
                "uptime_pct": metrics_data.get("uptime_pct"),
                "last_checked": metrics_data.get("last_checked"),
            }

            recommendation: MCPServerRecommendation = {
                "server": rec.get("server", ""),
                "npm_package": rec.get("npm_package", ""),
                "install_command": rec.get("install_command", ""),
                "confidence": rec.get("confidence", 0.0),
                "description": rec.get("description", ""),
                "capabilities": rec.get("capabilities", []),
                "metrics": metrics,
                "docs_url": rec.get("docs_url", ""),
                "github_url": rec.get("github_url", ""),
            }
            return recommendation
        except (KeyError, TypeError, AttributeError) as e:
            # Best-effort: skip malformed entries rather than failing the
            # whole discovery call.
            logger.warning(f"Skipping malformed recommendation: {rec}, error: {e}")
            return None

    def _process_recommendations(
        self, recommendations: list[dict[str, Any]]
    ) -> list[MCPServerRecommendation]:
        """Process and validate server recommendations.

        Args:
            recommendations: Raw recommendations from the API.

        Returns:
            List of processed MCPServerRecommendation objects.
        """
        processed: list[MCPServerRecommendation] = []
        for rec in recommendations:
            result = self._process_single_recommendation(rec)
            if result is not None:
                processed.append(result)
        return processed

    def _format_result(self, result: MCPDiscoveryResult) -> str:
        """Format the discovery result as a human-readable string.

        Args:
            result: The discovery result to format.

        Returns:
            A formatted string representation of the result.
        """
        if not result["recommendations"]:
            return "No MCP servers found matching your requirements."

        lines = [
            f"Found {result['total_found']} MCP server(s) "
            f"(query took {result['query_time_ms']}ms):\n"
        ]

        for i, rec in enumerate(result["recommendations"], 1):
            confidence_pct = rec.get("confidence", 0) * 100
            lines.append(f"{i}. {rec.get('server', 'Unknown')} ({confidence_pct:.0f}% confidence)")
            lines.append(f"   Description: {rec.get('description', 'N/A')}")
            lines.append(f"   Capabilities: {', '.join(rec.get('capabilities', []))}")
            lines.append(f"   Install: {rec.get('install_command', 'N/A')}")
            lines.append(f"   NPM Package: {rec.get('npm_package', 'N/A')}")

            # Metrics and URLs are optional; emit only what is present.
            metrics = rec.get("metrics", {})
            if metrics.get("avg_latency_ms") is not None:
                lines.append(f"   Avg Latency: {metrics['avg_latency_ms']}ms")
            if metrics.get("uptime_pct") is not None:
                lines.append(f"   Uptime: {metrics['uptime_pct']}%")

            if rec.get("docs_url"):
                lines.append(f"   Docs: {rec['docs_url']}")
            if rec.get("github_url"):
                lines.append(f"   GitHub: {rec['github_url']}")
            lines.append("")

        return "\n".join(lines)

    def _discover_result(
        self,
        need: str,
        constraints: MCPDiscoveryConstraints | None,
        limit: int,
    ) -> MCPDiscoveryResult:
        """Run the full discovery pipeline and assemble a structured result.

        Shared implementation behind ``_run`` and ``discover`` (previously
        duplicated in both).

        Args:
            need: Natural language description of what is needed.
            constraints: Optional constraints to filter results.
            limit: Maximum number of recommendations.

        Returns:
            MCPDiscoveryResult containing server recommendations.
        """
        payload = self._build_request_payload(need, constraints, limit)
        response = self._make_api_request(payload)

        recommendations = self._process_recommendations(
            response.get("recommendations", [])
        )

        return {
            "recommendations": recommendations,
            # Fall back to the processed count / zero when the API omits
            # the summary fields.
            "total_found": response.get("total_found", len(recommendations)),
            "query_time_ms": response.get("query_time_ms", 0),
        }

    def _run(self, **kwargs: Any) -> str:
        """Execute the MCP discovery operation.

        Args:
            **kwargs: Keyword arguments matching MCPDiscoveryToolSchema.

        Returns:
            A formatted string with discovered MCP servers.

        Raises:
            ValueError: If required parameters are missing.
        """
        need: str | None = kwargs.get("need")
        if not need:
            raise ValueError("'need' parameter is required")

        # Accept either a dict (raw tool-call args) or an already-built
        # MCPDiscoveryConstraints instance.
        constraints_data = kwargs.get("constraints")
        constraints: MCPDiscoveryConstraints | None = None
        if constraints_data:
            if isinstance(constraints_data, dict):
                constraints = MCPDiscoveryConstraints(**constraints_data)
            elif isinstance(constraints_data, MCPDiscoveryConstraints):
                constraints = constraints_data

        limit: int = kwargs.get("limit", 5)

        return self._format_result(self._discover_result(need, constraints, limit))

    def discover(
        self,
        need: str,
        constraints: MCPDiscoveryConstraints | None = None,
        limit: int = 5,
    ) -> MCPDiscoveryResult:
        """Discover MCP servers matching the given requirements.

        This is a convenience method that returns structured data instead
        of a formatted string.

        Args:
            need: Natural language description of what is needed.
            constraints: Optional constraints to filter results.
            limit: Maximum number of recommendations (1-10).

        Returns:
            MCPDiscoveryResult containing server recommendations.

        Example:
            ```python
            tool = MCPDiscoveryTool()
            result = tool.discover(
                need="database with authentication",
                constraints=MCPDiscoveryConstraints(
                    max_latency_ms=200,
                    required_features=["auth", "realtime"]
                ),
                limit=3
            )
            for rec in result["recommendations"]:
                print(f"{rec['server']}: {rec['description']}")
            ```
        """
        return self._discover_result(need, constraints, limit)
@@ -0,0 +1 @@
|
|||||||
|
|
||||||
@@ -0,0 +1,452 @@
|
|||||||
|
"""Tests for the MCP Discovery Tool."""
|
||||||
|
|
||||||
|
import json
|
||||||
|
from unittest.mock import MagicMock, patch
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
import requests
|
||||||
|
|
||||||
|
from crewai_tools.tools.mcp_discovery_tool.mcp_discovery_tool import (
|
||||||
|
MCPDiscoveryConstraints,
|
||||||
|
MCPDiscoveryResult,
|
||||||
|
MCPDiscoveryTool,
|
||||||
|
MCPDiscoveryToolSchema,
|
||||||
|
MCPServerMetrics,
|
||||||
|
MCPServerRecommendation,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
def mock_api_response() -> dict:
    """Create a mock API response for testing."""
    sqlite_rec = {
        "server": "sqlite-server",
        "npm_package": "@modelcontextprotocol/server-sqlite",
        "install_command": "npx -y @modelcontextprotocol/server-sqlite",
        "confidence": 0.38,
        "description": "SQLite database server for MCP.",
        "capabilities": ["sqlite", "sql", "database", "embedded"],
        "metrics": {
            "avg_latency_ms": 50.0,
            "uptime_pct": 99.9,
            "last_checked": "2026-01-17T10:30:00Z",
        },
        "docs_url": "https://modelcontextprotocol.io/docs/servers/sqlite",
        "github_url": "https://github.com/modelcontextprotocol/servers",
    }
    # Second entry deliberately has all-None metrics to exercise the
    # optional-metrics code paths.
    postgres_rec = {
        "server": "postgres-server",
        "npm_package": "@modelcontextprotocol/server-postgres",
        "install_command": "npx -y @modelcontextprotocol/server-postgres",
        "confidence": 0.33,
        "description": "PostgreSQL database server for MCP.",
        "capabilities": ["postgres", "sql", "database", "queries"],
        "metrics": {
            "avg_latency_ms": None,
            "uptime_pct": None,
            "last_checked": None,
        },
        "docs_url": "https://modelcontextprotocol.io/docs/servers/postgres",
        "github_url": "https://github.com/modelcontextprotocol/servers",
    }
    return {
        "recommendations": [sqlite_rec, postgres_rec],
        "total_found": 2,
        "query_time_ms": 245,
    }
@pytest.fixture
def discovery_tool() -> MCPDiscoveryTool:
    """Provide a fresh MCPDiscoveryTool instance for each test."""
    tool = MCPDiscoveryTool()
    return tool
class TestMCPDiscoveryToolSchema:
    """Tests for MCPDiscoveryToolSchema."""

    def test_schema_with_required_fields(self) -> None:
        """Test schema with only required fields."""
        parsed = MCPDiscoveryToolSchema(need="database server")

        assert parsed.need == "database server"
        assert parsed.constraints is None
        assert parsed.limit == 5

    def test_schema_with_all_fields(self) -> None:
        """Test schema with all fields."""
        filters = MCPDiscoveryConstraints(
            max_latency_ms=200,
            required_features=["auth", "realtime"],
            exclude_servers=["deprecated-server"],
        )

        parsed = MCPDiscoveryToolSchema(
            need="database with authentication",
            constraints=filters,
            limit=3,
        )

        assert parsed.need == "database with authentication"
        assert parsed.limit == 3
        assert parsed.constraints is not None
        assert parsed.constraints.max_latency_ms == 200
        assert parsed.constraints.required_features == ["auth", "realtime"]
        assert parsed.constraints.exclude_servers == ["deprecated-server"]

    def test_schema_limit_validation(self) -> None:
        """Test that limit is validated to be between 1 and 10."""
        # Both values just outside the inclusive [1, 10] bounds must fail.
        for bad_limit in (0, 11):
            with pytest.raises(ValueError):
                MCPDiscoveryToolSchema(need="test", limit=bad_limit)
class TestMCPDiscoveryConstraints:
    """Tests for MCPDiscoveryConstraints."""

    def test_empty_constraints(self) -> None:
        """Test creating empty constraints."""
        empty = MCPDiscoveryConstraints()

        # Every field defaults to None.
        assert empty.max_latency_ms is None
        assert empty.required_features is None
        assert empty.exclude_servers is None

    def test_full_constraints(self) -> None:
        """Test creating constraints with all fields."""
        full = MCPDiscoveryConstraints(
            max_latency_ms=100,
            required_features=["feature1", "feature2"],
            exclude_servers=["server1", "server2"],
        )

        assert full.max_latency_ms == 100
        assert full.required_features == ["feature1", "feature2"]
        assert full.exclude_servers == ["server1", "server2"]
class TestMCPDiscoveryTool:
|
||||||
|
"""Tests for MCPDiscoveryTool."""
|
||||||
|
|
||||||
|
def test_tool_initialization(self, discovery_tool: MCPDiscoveryTool) -> None:
|
||||||
|
"""Test tool initialization with default values."""
|
||||||
|
assert discovery_tool.name == "Discover MCP Server"
|
||||||
|
assert "MCP" in discovery_tool.description
|
||||||
|
assert discovery_tool.base_url == "https://mcp-discovery-production.up.railway.app"
|
||||||
|
assert discovery_tool.timeout == 30
|
||||||
|
|
||||||
|
def test_build_request_payload_basic(
|
||||||
|
self, discovery_tool: MCPDiscoveryTool
|
||||||
|
) -> None:
|
||||||
|
"""Test building request payload with basic parameters."""
|
||||||
|
payload = discovery_tool._build_request_payload(
|
||||||
|
need="database server",
|
||||||
|
constraints=None,
|
||||||
|
limit=5,
|
||||||
|
)
|
||||||
|
assert payload == {"need": "database server", "limit": 5}
|
||||||
|
|
||||||
|
def test_build_request_payload_with_constraints(
|
||||||
|
self, discovery_tool: MCPDiscoveryTool
|
||||||
|
) -> None:
|
||||||
|
"""Test building request payload with constraints."""
|
||||||
|
constraints = MCPDiscoveryConstraints(
|
||||||
|
max_latency_ms=200,
|
||||||
|
required_features=["auth"],
|
||||||
|
exclude_servers=["old-server"],
|
||||||
|
)
|
||||||
|
payload = discovery_tool._build_request_payload(
|
||||||
|
need="database",
|
||||||
|
constraints=constraints,
|
||||||
|
limit=3,
|
||||||
|
)
|
||||||
|
assert payload["need"] == "database"
|
||||||
|
assert payload["limit"] == 3
|
||||||
|
assert "constraints" in payload
|
||||||
|
assert payload["constraints"]["max_latency_ms"] == 200
|
||||||
|
assert payload["constraints"]["required_features"] == ["auth"]
|
||||||
|
assert payload["constraints"]["exclude_servers"] == ["old-server"]
|
||||||
|
|
||||||
|
def test_build_request_payload_partial_constraints(
|
||||||
|
self, discovery_tool: MCPDiscoveryTool
|
||||||
|
) -> None:
|
||||||
|
"""Test building request payload with partial constraints."""
|
||||||
|
constraints = MCPDiscoveryConstraints(max_latency_ms=100)
|
||||||
|
payload = discovery_tool._build_request_payload(
|
||||||
|
need="test",
|
||||||
|
constraints=constraints,
|
||||||
|
limit=5,
|
||||||
|
)
|
||||||
|
assert payload["constraints"] == {"max_latency_ms": 100}
|
||||||
|
|
||||||
|
def test_process_recommendations(
|
||||||
|
self, discovery_tool: MCPDiscoveryTool, mock_api_response: dict
|
||||||
|
) -> None:
|
||||||
|
"""Test processing recommendations from API response."""
|
||||||
|
recommendations = discovery_tool._process_recommendations(
|
||||||
|
mock_api_response["recommendations"]
|
||||||
|
)
|
||||||
|
assert len(recommendations) == 2
|
||||||
|
|
||||||
|
first_rec = recommendations[0]
|
||||||
|
assert first_rec["server"] == "sqlite-server"
|
||||||
|
assert first_rec["npm_package"] == "@modelcontextprotocol/server-sqlite"
|
||||||
|
assert first_rec["confidence"] == 0.38
|
||||||
|
assert first_rec["capabilities"] == ["sqlite", "sql", "database", "embedded"]
|
||||||
|
assert first_rec["metrics"]["avg_latency_ms"] == 50.0
|
||||||
|
assert first_rec["metrics"]["uptime_pct"] == 99.9
|
||||||
|
|
||||||
|
second_rec = recommendations[1]
|
||||||
|
assert second_rec["server"] == "postgres-server"
|
||||||
|
assert second_rec["metrics"]["avg_latency_ms"] is None
|
||||||
|
|
||||||
|
def test_process_recommendations_with_malformed_data(
|
||||||
|
self, discovery_tool: MCPDiscoveryTool
|
||||||
|
) -> None:
|
||||||
|
"""Test processing recommendations with malformed data."""
|
||||||
|
malformed_recommendations = [
|
||||||
|
{"server": "valid-server", "confidence": 0.5},
|
||||||
|
None,
|
||||||
|
{"invalid": "data"},
|
||||||
|
]
|
||||||
|
recommendations = discovery_tool._process_recommendations(
|
||||||
|
malformed_recommendations
|
||||||
|
)
|
||||||
|
assert len(recommendations) >= 1
|
||||||
|
assert recommendations[0]["server"] == "valid-server"
|
||||||
|
|
||||||
|
def test_format_result_with_recommendations(
|
||||||
|
self, discovery_tool: MCPDiscoveryTool
|
||||||
|
) -> None:
|
||||||
|
"""Test formatting results with recommendations."""
|
||||||
|
result: MCPDiscoveryResult = {
|
||||||
|
"recommendations": [
|
||||||
|
{
|
||||||
|
"server": "test-server",
|
||||||
|
"npm_package": "@test/server",
|
||||||
|
"install_command": "npx -y @test/server",
|
||||||
|
"confidence": 0.85,
|
||||||
|
"description": "A test server",
|
||||||
|
"capabilities": ["test", "demo"],
|
||||||
|
"metrics": {
|
||||||
|
"avg_latency_ms": 100.0,
|
||||||
|
"uptime_pct": 99.5,
|
||||||
|
"last_checked": "2026-01-17T10:00:00Z",
|
||||||
|
},
|
||||||
|
"docs_url": "https://example.com/docs",
|
||||||
|
"github_url": "https://github.com/test/server",
|
||||||
|
}
|
||||||
|
],
|
||||||
|
"total_found": 1,
|
||||||
|
"query_time_ms": 150,
|
||||||
|
}
|
||||||
|
formatted = discovery_tool._format_result(result)
|
||||||
|
assert "Found 1 MCP server(s)" in formatted
|
||||||
|
assert "test-server" in formatted
|
||||||
|
assert "85% confidence" in formatted
|
||||||
|
assert "A test server" in formatted
|
||||||
|
assert "test, demo" in formatted
|
||||||
|
assert "npx -y @test/server" in formatted
|
||||||
|
assert "100.0ms" in formatted
|
||||||
|
assert "99.5%" in formatted
|
||||||
|
|
||||||
|
def test_format_result_empty(self, discovery_tool: MCPDiscoveryTool) -> None:
|
||||||
|
"""Test formatting results with no recommendations."""
|
||||||
|
result: MCPDiscoveryResult = {
|
||||||
|
"recommendations": [],
|
||||||
|
"total_found": 0,
|
||||||
|
"query_time_ms": 50,
|
||||||
|
}
|
||||||
|
formatted = discovery_tool._format_result(result)
|
||||||
|
assert "No MCP servers found" in formatted
|
||||||
|
|
||||||
|
@patch("crewai_tools.tools.mcp_discovery_tool.mcp_discovery_tool.requests.post")
|
||||||
|
def test_make_api_request_success(
|
||||||
|
self,
|
||||||
|
mock_post: MagicMock,
|
||||||
|
discovery_tool: MCPDiscoveryTool,
|
||||||
|
mock_api_response: dict,
|
||||||
|
) -> None:
|
||||||
|
"""Test successful API request."""
|
||||||
|
mock_response = MagicMock()
|
||||||
|
mock_response.json.return_value = mock_api_response
|
||||||
|
mock_response.raise_for_status.return_value = None
|
||||||
|
mock_post.return_value = mock_response
|
||||||
|
|
||||||
|
result = discovery_tool._make_api_request({"need": "database", "limit": 5})
|
||||||
|
|
||||||
|
assert result == mock_api_response
|
||||||
|
mock_post.assert_called_once()
|
||||||
|
call_args = mock_post.call_args
|
||||||
|
assert call_args[1]["json"] == {"need": "database", "limit": 5}
|
||||||
|
assert call_args[1]["timeout"] == 30
|
||||||
|
|
||||||
|
@patch("crewai_tools.tools.mcp_discovery_tool.mcp_discovery_tool.requests.post")
|
||||||
|
def test_make_api_request_with_api_key(
|
||||||
|
self,
|
||||||
|
mock_post: MagicMock,
|
||||||
|
discovery_tool: MCPDiscoveryTool,
|
||||||
|
mock_api_response: dict,
|
||||||
|
) -> None:
|
||||||
|
"""Test API request with API key."""
|
||||||
|
mock_response = MagicMock()
|
||||||
|
mock_response.json.return_value = mock_api_response
|
||||||
|
mock_response.raise_for_status.return_value = None
|
||||||
|
mock_post.return_value = mock_response
|
||||||
|
|
||||||
|
with patch.dict("os.environ", {"MCP_DISCOVERY_API_KEY": "test-key"}):
|
||||||
|
discovery_tool._make_api_request({"need": "test", "limit": 5})
|
||||||
|
|
||||||
|
call_args = mock_post.call_args
|
||||||
|
assert "Authorization" in call_args[1]["headers"]
|
||||||
|
assert call_args[1]["headers"]["Authorization"] == "Bearer test-key"
|
||||||
|
|
||||||
|
@patch("crewai_tools.tools.mcp_discovery_tool.mcp_discovery_tool.requests.post")
|
||||||
|
def test_make_api_request_empty_response(
|
||||||
|
self, mock_post: MagicMock, discovery_tool: MCPDiscoveryTool
|
||||||
|
) -> None:
|
||||||
|
"""Test API request with empty response."""
|
||||||
|
mock_response = MagicMock()
|
||||||
|
mock_response.json.return_value = {}
|
||||||
|
mock_response.raise_for_status.return_value = None
|
||||||
|
mock_post.return_value = mock_response
|
||||||
|
|
||||||
|
with pytest.raises(ValueError, match="Empty response"):
|
||||||
|
discovery_tool._make_api_request({"need": "test", "limit": 5})
|
||||||
|
|
||||||
|
@patch("crewai_tools.tools.mcp_discovery_tool.mcp_discovery_tool.requests.post")
|
||||||
|
def test_make_api_request_network_error(
|
||||||
|
self, mock_post: MagicMock, discovery_tool: MCPDiscoveryTool
|
||||||
|
) -> None:
|
||||||
|
"""Test API request with network error."""
|
||||||
|
mock_post.side_effect = requests.exceptions.ConnectionError("Network error")
|
||||||
|
|
||||||
|
with pytest.raises(requests.exceptions.ConnectionError):
|
||||||
|
discovery_tool._make_api_request({"need": "test", "limit": 5})
|
||||||
|
|
||||||
|
@patch("crewai_tools.tools.mcp_discovery_tool.mcp_discovery_tool.requests.post")
|
||||||
|
def test_make_api_request_json_decode_error(
|
||||||
|
self, mock_post: MagicMock, discovery_tool: MCPDiscoveryTool
|
||||||
|
) -> None:
|
||||||
|
"""Test API request with JSON decode error."""
|
||||||
|
mock_response = MagicMock()
|
||||||
|
mock_response.json.side_effect = json.JSONDecodeError("Error", "", 0)
|
||||||
|
mock_response.raise_for_status.return_value = None
|
||||||
|
mock_response.content = b"invalid json"
|
||||||
|
mock_post.return_value = mock_response
|
||||||
|
|
||||||
|
with pytest.raises(json.JSONDecodeError):
|
||||||
|
discovery_tool._make_api_request({"need": "test", "limit": 5})
|
||||||
|
|
||||||
|
@patch("crewai_tools.tools.mcp_discovery_tool.mcp_discovery_tool.requests.post")
|
||||||
|
def test_run_success(
|
||||||
|
self,
|
||||||
|
mock_post: MagicMock,
|
||||||
|
discovery_tool: MCPDiscoveryTool,
|
||||||
|
mock_api_response: dict,
|
||||||
|
) -> None:
|
||||||
|
"""Test successful _run execution."""
|
||||||
|
mock_response = MagicMock()
|
||||||
|
mock_response.json.return_value = mock_api_response
|
||||||
|
mock_response.raise_for_status.return_value = None
|
||||||
|
mock_post.return_value = mock_response
|
||||||
|
|
||||||
|
result = discovery_tool._run(need="database server")
|
||||||
|
|
||||||
|
assert "sqlite-server" in result
|
||||||
|
assert "postgres-server" in result
|
||||||
|
assert "Found 2 MCP server(s)" in result
|
||||||
|
|
||||||
|
@patch("crewai_tools.tools.mcp_discovery_tool.mcp_discovery_tool.requests.post")
|
||||||
|
def test_run_with_constraints(
|
||||||
|
self,
|
||||||
|
mock_post: MagicMock,
|
||||||
|
discovery_tool: MCPDiscoveryTool,
|
||||||
|
mock_api_response: dict,
|
||||||
|
) -> None:
|
||||||
|
"""Test _run with constraints."""
|
||||||
|
mock_response = MagicMock()
|
||||||
|
mock_response.json.return_value = mock_api_response
|
||||||
|
mock_response.raise_for_status.return_value = None
|
||||||
|
mock_post.return_value = mock_response
|
||||||
|
|
||||||
|
result = discovery_tool._run(
|
||||||
|
need="database",
|
||||||
|
constraints={"max_latency_ms": 100, "required_features": ["sql"]},
|
||||||
|
limit=3,
|
||||||
|
)
|
||||||
|
|
||||||
|
assert "sqlite-server" in result
|
||||||
|
call_args = mock_post.call_args
|
||||||
|
payload = call_args[1]["json"]
|
||||||
|
assert payload["constraints"]["max_latency_ms"] == 100
|
||||||
|
assert payload["constraints"]["required_features"] == ["sql"]
|
||||||
|
assert payload["limit"] == 3
|
||||||
|
|
||||||
|
def test_run_missing_need_parameter(
|
||||||
|
self, discovery_tool: MCPDiscoveryTool
|
||||||
|
) -> None:
|
||||||
|
"""Test _run with missing need parameter."""
|
||||||
|
with pytest.raises(ValueError, match="'need' parameter is required"):
|
||||||
|
discovery_tool._run()
|
||||||
|
|
||||||
|
@patch("crewai_tools.tools.mcp_discovery_tool.mcp_discovery_tool.requests.post")
|
||||||
|
def test_discover_method(
|
||||||
|
self,
|
||||||
|
mock_post: MagicMock,
|
||||||
|
discovery_tool: MCPDiscoveryTool,
|
||||||
|
mock_api_response: dict,
|
||||||
|
) -> None:
|
||||||
|
"""Test the discover convenience method."""
|
||||||
|
mock_response = MagicMock()
|
||||||
|
mock_response.json.return_value = mock_api_response
|
||||||
|
mock_response.raise_for_status.return_value = None
|
||||||
|
mock_post.return_value = mock_response
|
||||||
|
|
||||||
|
result = discovery_tool.discover(
|
||||||
|
need="database",
|
||||||
|
constraints=MCPDiscoveryConstraints(max_latency_ms=200),
|
||||||
|
limit=3,
|
||||||
|
)
|
||||||
|
|
||||||
|
assert isinstance(result, dict)
|
||||||
|
assert "recommendations" in result
|
||||||
|
assert "total_found" in result
|
||||||
|
assert "query_time_ms" in result
|
||||||
|
assert len(result["recommendations"]) == 2
|
||||||
|
assert result["total_found"] == 2
|
||||||
|
|
||||||
|
@patch("crewai_tools.tools.mcp_discovery_tool.mcp_discovery_tool.requests.post")
|
||||||
|
def test_discover_returns_structured_data(
|
||||||
|
self,
|
||||||
|
mock_post: MagicMock,
|
||||||
|
discovery_tool: MCPDiscoveryTool,
|
||||||
|
mock_api_response: dict,
|
||||||
|
) -> None:
|
||||||
|
"""Test that discover returns properly structured data."""
|
||||||
|
mock_response = MagicMock()
|
||||||
|
mock_response.json.return_value = mock_api_response
|
||||||
|
mock_response.raise_for_status.return_value = None
|
||||||
|
mock_post.return_value = mock_response
|
||||||
|
|
||||||
|
result = discovery_tool.discover(need="database")
|
||||||
|
|
||||||
|
first_rec = result["recommendations"][0]
|
||||||
|
assert "server" in first_rec
|
||||||
|
assert "npm_package" in first_rec
|
||||||
|
assert "install_command" in first_rec
|
||||||
|
assert "confidence" in first_rec
|
||||||
|
assert "description" in first_rec
|
||||||
|
assert "capabilities" in first_rec
|
||||||
|
assert "metrics" in first_rec
|
||||||
|
assert "docs_url" in first_rec
|
||||||
|
assert "github_url" in first_rec
|
||||||
|
|
||||||
|
|
||||||
|
class TestMCPDiscoveryToolIntegration:
|
||||||
|
"""Integration tests for MCPDiscoveryTool (requires network)."""
|
||||||
|
|
||||||
|
@pytest.mark.skip(reason="Integration test - requires network access")
|
||||||
|
def test_real_api_call(self) -> None:
|
||||||
|
"""Test actual API call to MCP Discovery service."""
|
||||||
|
tool = MCPDiscoveryTool()
|
||||||
|
result = tool._run(need="database", limit=3)
|
||||||
|
assert "MCP server" in result or "No MCP servers found" in result
|
||||||
@@ -3,9 +3,10 @@
|
|||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
from collections.abc import AsyncIterator
|
from collections.abc import AsyncIterator
|
||||||
from typing import TYPE_CHECKING, TypedDict
|
from typing import TYPE_CHECKING, Any, TypedDict
|
||||||
import uuid
|
import uuid
|
||||||
|
|
||||||
|
from a2a.client.errors import A2AClientHTTPError
|
||||||
from a2a.types import (
|
from a2a.types import (
|
||||||
AgentCard,
|
AgentCard,
|
||||||
Message,
|
Message,
|
||||||
@@ -20,7 +21,10 @@ from a2a.types import (
|
|||||||
from typing_extensions import NotRequired
|
from typing_extensions import NotRequired
|
||||||
|
|
||||||
from crewai.events.event_bus import crewai_event_bus
|
from crewai.events.event_bus import crewai_event_bus
|
||||||
from crewai.events.types.a2a_events import A2AResponseReceivedEvent
|
from crewai.events.types.a2a_events import (
|
||||||
|
A2AConnectionErrorEvent,
|
||||||
|
A2AResponseReceivedEvent,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
@@ -55,7 +59,8 @@ class TaskStateResult(TypedDict):
|
|||||||
history: list[Message]
|
history: list[Message]
|
||||||
result: NotRequired[str]
|
result: NotRequired[str]
|
||||||
error: NotRequired[str]
|
error: NotRequired[str]
|
||||||
agent_card: NotRequired[AgentCard]
|
agent_card: NotRequired[dict[str, Any]]
|
||||||
|
a2a_agent_name: NotRequired[str | None]
|
||||||
|
|
||||||
|
|
||||||
def extract_task_result_parts(a2a_task: A2ATask) -> list[str]:
|
def extract_task_result_parts(a2a_task: A2ATask) -> list[str]:
|
||||||
@@ -131,50 +136,69 @@ def process_task_state(
|
|||||||
is_multiturn: bool,
|
is_multiturn: bool,
|
||||||
agent_role: str | None,
|
agent_role: str | None,
|
||||||
result_parts: list[str] | None = None,
|
result_parts: list[str] | None = None,
|
||||||
|
endpoint: str | None = None,
|
||||||
|
a2a_agent_name: str | None = None,
|
||||||
|
from_task: Any | None = None,
|
||||||
|
from_agent: Any | None = None,
|
||||||
|
is_final: bool = True,
|
||||||
) -> TaskStateResult | None:
|
) -> TaskStateResult | None:
|
||||||
"""Process A2A task state and return result dictionary.
|
"""Process A2A task state and return result dictionary.
|
||||||
|
|
||||||
Shared logic for both polling and streaming handlers.
|
Shared logic for both polling and streaming handlers.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
a2a_task: The A2A task to process
|
a2a_task: The A2A task to process.
|
||||||
new_messages: List to collect messages (modified in place)
|
new_messages: List to collect messages (modified in place).
|
||||||
agent_card: The agent card
|
agent_card: The agent card.
|
||||||
turn_number: Current turn number
|
turn_number: Current turn number.
|
||||||
is_multiturn: Whether multi-turn conversation
|
is_multiturn: Whether multi-turn conversation.
|
||||||
agent_role: Agent role for logging
|
agent_role: Agent role for logging.
|
||||||
result_parts: Accumulated result parts (streaming passes accumulated,
|
result_parts: Accumulated result parts (streaming passes accumulated,
|
||||||
polling passes None to extract from task)
|
polling passes None to extract from task).
|
||||||
|
endpoint: A2A agent endpoint URL.
|
||||||
|
a2a_agent_name: Name of the A2A agent from agent card.
|
||||||
|
from_task: Optional CrewAI Task for event metadata.
|
||||||
|
from_agent: Optional CrewAI Agent for event metadata.
|
||||||
|
is_final: Whether this is the final response in the stream.
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
Result dictionary if terminal/actionable state, None otherwise
|
Result dictionary if terminal/actionable state, None otherwise.
|
||||||
"""
|
"""
|
||||||
should_extract = result_parts is None
|
|
||||||
if result_parts is None:
|
if result_parts is None:
|
||||||
result_parts = []
|
result_parts = []
|
||||||
|
|
||||||
if a2a_task.status.state == TaskState.completed:
|
if a2a_task.status.state == TaskState.completed:
|
||||||
if should_extract:
|
if not result_parts:
|
||||||
extracted_parts = extract_task_result_parts(a2a_task)
|
extracted_parts = extract_task_result_parts(a2a_task)
|
||||||
result_parts.extend(extracted_parts)
|
result_parts.extend(extracted_parts)
|
||||||
if a2a_task.history:
|
if a2a_task.history:
|
||||||
new_messages.extend(a2a_task.history)
|
new_messages.extend(a2a_task.history)
|
||||||
|
|
||||||
response_text = " ".join(result_parts) if result_parts else ""
|
response_text = " ".join(result_parts) if result_parts else ""
|
||||||
|
message_id = None
|
||||||
|
if a2a_task.status and a2a_task.status.message:
|
||||||
|
message_id = a2a_task.status.message.message_id
|
||||||
crewai_event_bus.emit(
|
crewai_event_bus.emit(
|
||||||
None,
|
None,
|
||||||
A2AResponseReceivedEvent(
|
A2AResponseReceivedEvent(
|
||||||
response=response_text,
|
response=response_text,
|
||||||
turn_number=turn_number,
|
turn_number=turn_number,
|
||||||
|
context_id=a2a_task.context_id,
|
||||||
|
message_id=message_id,
|
||||||
is_multiturn=is_multiturn,
|
is_multiturn=is_multiturn,
|
||||||
status="completed",
|
status="completed",
|
||||||
|
final=is_final,
|
||||||
agent_role=agent_role,
|
agent_role=agent_role,
|
||||||
|
endpoint=endpoint,
|
||||||
|
a2a_agent_name=a2a_agent_name,
|
||||||
|
from_task=from_task,
|
||||||
|
from_agent=from_agent,
|
||||||
),
|
),
|
||||||
)
|
)
|
||||||
|
|
||||||
return TaskStateResult(
|
return TaskStateResult(
|
||||||
status=TaskState.completed,
|
status=TaskState.completed,
|
||||||
agent_card=agent_card,
|
agent_card=agent_card.model_dump(exclude_none=True),
|
||||||
result=response_text,
|
result=response_text,
|
||||||
history=new_messages,
|
history=new_messages,
|
||||||
)
|
)
|
||||||
@@ -194,14 +218,24 @@ def process_task_state(
|
|||||||
)
|
)
|
||||||
new_messages.append(agent_message)
|
new_messages.append(agent_message)
|
||||||
|
|
||||||
|
input_message_id = None
|
||||||
|
if a2a_task.status and a2a_task.status.message:
|
||||||
|
input_message_id = a2a_task.status.message.message_id
|
||||||
crewai_event_bus.emit(
|
crewai_event_bus.emit(
|
||||||
None,
|
None,
|
||||||
A2AResponseReceivedEvent(
|
A2AResponseReceivedEvent(
|
||||||
response=response_text,
|
response=response_text,
|
||||||
turn_number=turn_number,
|
turn_number=turn_number,
|
||||||
|
context_id=a2a_task.context_id,
|
||||||
|
message_id=input_message_id,
|
||||||
is_multiturn=is_multiturn,
|
is_multiturn=is_multiturn,
|
||||||
status="input_required",
|
status="input_required",
|
||||||
|
final=is_final,
|
||||||
agent_role=agent_role,
|
agent_role=agent_role,
|
||||||
|
endpoint=endpoint,
|
||||||
|
a2a_agent_name=a2a_agent_name,
|
||||||
|
from_task=from_task,
|
||||||
|
from_agent=from_agent,
|
||||||
),
|
),
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -209,7 +243,7 @@ def process_task_state(
|
|||||||
status=TaskState.input_required,
|
status=TaskState.input_required,
|
||||||
error=response_text,
|
error=response_text,
|
||||||
history=new_messages,
|
history=new_messages,
|
||||||
agent_card=agent_card,
|
agent_card=agent_card.model_dump(exclude_none=True),
|
||||||
)
|
)
|
||||||
|
|
||||||
if a2a_task.status.state in {TaskState.failed, TaskState.rejected}:
|
if a2a_task.status.state in {TaskState.failed, TaskState.rejected}:
|
||||||
@@ -248,6 +282,11 @@ async def send_message_and_get_task_id(
|
|||||||
turn_number: int,
|
turn_number: int,
|
||||||
is_multiturn: bool,
|
is_multiturn: bool,
|
||||||
agent_role: str | None,
|
agent_role: str | None,
|
||||||
|
from_task: Any | None = None,
|
||||||
|
from_agent: Any | None = None,
|
||||||
|
endpoint: str | None = None,
|
||||||
|
a2a_agent_name: str | None = None,
|
||||||
|
context_id: str | None = None,
|
||||||
) -> str | TaskStateResult:
|
) -> str | TaskStateResult:
|
||||||
"""Send message and process initial response.
|
"""Send message and process initial response.
|
||||||
|
|
||||||
@@ -262,6 +301,11 @@ async def send_message_and_get_task_id(
|
|||||||
turn_number: Current turn number
|
turn_number: Current turn number
|
||||||
is_multiturn: Whether multi-turn conversation
|
is_multiturn: Whether multi-turn conversation
|
||||||
agent_role: Agent role for logging
|
agent_role: Agent role for logging
|
||||||
|
from_task: Optional CrewAI Task object for event metadata.
|
||||||
|
from_agent: Optional CrewAI Agent object for event metadata.
|
||||||
|
endpoint: Optional A2A endpoint URL.
|
||||||
|
a2a_agent_name: Optional A2A agent name.
|
||||||
|
context_id: Optional A2A context ID for correlation.
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
Task ID string if agent needs polling/waiting, or TaskStateResult if done.
|
Task ID string if agent needs polling/waiting, or TaskStateResult if done.
|
||||||
@@ -280,9 +324,16 @@ async def send_message_and_get_task_id(
|
|||||||
A2AResponseReceivedEvent(
|
A2AResponseReceivedEvent(
|
||||||
response=response_text,
|
response=response_text,
|
||||||
turn_number=turn_number,
|
turn_number=turn_number,
|
||||||
|
context_id=event.context_id,
|
||||||
|
message_id=event.message_id,
|
||||||
is_multiturn=is_multiturn,
|
is_multiturn=is_multiturn,
|
||||||
status="completed",
|
status="completed",
|
||||||
|
final=True,
|
||||||
agent_role=agent_role,
|
agent_role=agent_role,
|
||||||
|
endpoint=endpoint,
|
||||||
|
a2a_agent_name=a2a_agent_name,
|
||||||
|
from_task=from_task,
|
||||||
|
from_agent=from_agent,
|
||||||
),
|
),
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -290,7 +341,7 @@ async def send_message_and_get_task_id(
|
|||||||
status=TaskState.completed,
|
status=TaskState.completed,
|
||||||
result=response_text,
|
result=response_text,
|
||||||
history=new_messages,
|
history=new_messages,
|
||||||
agent_card=agent_card,
|
agent_card=agent_card.model_dump(exclude_none=True),
|
||||||
)
|
)
|
||||||
|
|
||||||
if isinstance(event, tuple):
|
if isinstance(event, tuple):
|
||||||
@@ -304,6 +355,10 @@ async def send_message_and_get_task_id(
|
|||||||
turn_number=turn_number,
|
turn_number=turn_number,
|
||||||
is_multiturn=is_multiturn,
|
is_multiturn=is_multiturn,
|
||||||
agent_role=agent_role,
|
agent_role=agent_role,
|
||||||
|
endpoint=endpoint,
|
||||||
|
a2a_agent_name=a2a_agent_name,
|
||||||
|
from_task=from_task,
|
||||||
|
from_agent=from_agent,
|
||||||
)
|
)
|
||||||
if result:
|
if result:
|
||||||
return result
|
return result
|
||||||
@@ -316,6 +371,99 @@ async def send_message_and_get_task_id(
|
|||||||
history=new_messages,
|
history=new_messages,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
except A2AClientHTTPError as e:
|
||||||
|
error_msg = f"HTTP Error {e.status_code}: {e!s}"
|
||||||
|
|
||||||
|
error_message = Message(
|
||||||
|
role=Role.agent,
|
||||||
|
message_id=str(uuid.uuid4()),
|
||||||
|
parts=[Part(root=TextPart(text=error_msg))],
|
||||||
|
context_id=context_id,
|
||||||
|
)
|
||||||
|
new_messages.append(error_message)
|
||||||
|
|
||||||
|
crewai_event_bus.emit(
|
||||||
|
None,
|
||||||
|
A2AConnectionErrorEvent(
|
||||||
|
endpoint=endpoint or "",
|
||||||
|
error=str(e),
|
||||||
|
error_type="http_error",
|
||||||
|
status_code=e.status_code,
|
||||||
|
a2a_agent_name=a2a_agent_name,
|
||||||
|
operation="send_message",
|
||||||
|
context_id=context_id,
|
||||||
|
from_task=from_task,
|
||||||
|
from_agent=from_agent,
|
||||||
|
),
|
||||||
|
)
|
||||||
|
crewai_event_bus.emit(
|
||||||
|
None,
|
||||||
|
A2AResponseReceivedEvent(
|
||||||
|
response=error_msg,
|
||||||
|
turn_number=turn_number,
|
||||||
|
context_id=context_id,
|
||||||
|
is_multiturn=is_multiturn,
|
||||||
|
status="failed",
|
||||||
|
final=True,
|
||||||
|
agent_role=agent_role,
|
||||||
|
endpoint=endpoint,
|
||||||
|
a2a_agent_name=a2a_agent_name,
|
||||||
|
from_task=from_task,
|
||||||
|
from_agent=from_agent,
|
||||||
|
),
|
||||||
|
)
|
||||||
|
return TaskStateResult(
|
||||||
|
status=TaskState.failed,
|
||||||
|
error=error_msg,
|
||||||
|
history=new_messages,
|
||||||
|
)
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
error_msg = f"Unexpected error during send_message: {e!s}"
|
||||||
|
|
||||||
|
error_message = Message(
|
||||||
|
role=Role.agent,
|
||||||
|
message_id=str(uuid.uuid4()),
|
||||||
|
parts=[Part(root=TextPart(text=error_msg))],
|
||||||
|
context_id=context_id,
|
||||||
|
)
|
||||||
|
new_messages.append(error_message)
|
||||||
|
|
||||||
|
crewai_event_bus.emit(
|
||||||
|
None,
|
||||||
|
A2AConnectionErrorEvent(
|
||||||
|
endpoint=endpoint or "",
|
||||||
|
error=str(e),
|
||||||
|
error_type="unexpected_error",
|
||||||
|
a2a_agent_name=a2a_agent_name,
|
||||||
|
operation="send_message",
|
||||||
|
context_id=context_id,
|
||||||
|
from_task=from_task,
|
||||||
|
from_agent=from_agent,
|
||||||
|
),
|
||||||
|
)
|
||||||
|
crewai_event_bus.emit(
|
||||||
|
None,
|
||||||
|
A2AResponseReceivedEvent(
|
||||||
|
response=error_msg,
|
||||||
|
turn_number=turn_number,
|
||||||
|
context_id=context_id,
|
||||||
|
is_multiturn=is_multiturn,
|
||||||
|
status="failed",
|
||||||
|
final=True,
|
||||||
|
agent_role=agent_role,
|
||||||
|
endpoint=endpoint,
|
||||||
|
a2a_agent_name=a2a_agent_name,
|
||||||
|
from_task=from_task,
|
||||||
|
from_agent=from_agent,
|
||||||
|
),
|
||||||
|
)
|
||||||
|
return TaskStateResult(
|
||||||
|
status=TaskState.failed,
|
||||||
|
error=error_msg,
|
||||||
|
history=new_messages,
|
||||||
|
)
|
||||||
|
|
||||||
finally:
|
finally:
|
||||||
aclose = getattr(event_stream, "aclose", None)
|
aclose = getattr(event_stream, "aclose", None)
|
||||||
if aclose:
|
if aclose:
|
||||||
|
|||||||
@@ -22,6 +22,13 @@ class BaseHandlerKwargs(TypedDict, total=False):
|
|||||||
turn_number: int
|
turn_number: int
|
||||||
is_multiturn: bool
|
is_multiturn: bool
|
||||||
agent_role: str | None
|
agent_role: str | None
|
||||||
|
context_id: str | None
|
||||||
|
task_id: str | None
|
||||||
|
endpoint: str | None
|
||||||
|
agent_branch: Any
|
||||||
|
a2a_agent_name: str | None
|
||||||
|
from_task: Any
|
||||||
|
from_agent: Any
|
||||||
|
|
||||||
|
|
||||||
class PollingHandlerKwargs(BaseHandlerKwargs, total=False):
|
class PollingHandlerKwargs(BaseHandlerKwargs, total=False):
|
||||||
@@ -29,8 +36,6 @@ class PollingHandlerKwargs(BaseHandlerKwargs, total=False):
|
|||||||
|
|
||||||
polling_interval: float
|
polling_interval: float
|
||||||
polling_timeout: float
|
polling_timeout: float
|
||||||
endpoint: str
|
|
||||||
agent_branch: Any
|
|
||||||
history_length: int
|
history_length: int
|
||||||
max_polls: int | None
|
max_polls: int | None
|
||||||
|
|
||||||
@@ -38,9 +43,6 @@ class PollingHandlerKwargs(BaseHandlerKwargs, total=False):
|
|||||||
class StreamingHandlerKwargs(BaseHandlerKwargs, total=False):
|
class StreamingHandlerKwargs(BaseHandlerKwargs, total=False):
|
||||||
"""Kwargs for streaming handler."""
|
"""Kwargs for streaming handler."""
|
||||||
|
|
||||||
context_id: str | None
|
|
||||||
task_id: str | None
|
|
||||||
|
|
||||||
|
|
||||||
class PushNotificationHandlerKwargs(BaseHandlerKwargs, total=False):
|
class PushNotificationHandlerKwargs(BaseHandlerKwargs, total=False):
|
||||||
"""Kwargs for push notification handler."""
|
"""Kwargs for push notification handler."""
|
||||||
@@ -49,7 +51,6 @@ class PushNotificationHandlerKwargs(BaseHandlerKwargs, total=False):
|
|||||||
result_store: PushNotificationResultStore
|
result_store: PushNotificationResultStore
|
||||||
polling_timeout: float
|
polling_timeout: float
|
||||||
polling_interval: float
|
polling_interval: float
|
||||||
agent_branch: Any
|
|
||||||
|
|
||||||
|
|
||||||
class PushNotificationResultStore(Protocol):
|
class PushNotificationResultStore(Protocol):
|
||||||
|
|||||||
@@ -31,6 +31,7 @@ from crewai.a2a.task_helpers import (
|
|||||||
from crewai.a2a.updates.base import PollingHandlerKwargs
|
from crewai.a2a.updates.base import PollingHandlerKwargs
|
||||||
from crewai.events.event_bus import crewai_event_bus
|
from crewai.events.event_bus import crewai_event_bus
|
||||||
from crewai.events.types.a2a_events import (
|
from crewai.events.types.a2a_events import (
|
||||||
|
A2AConnectionErrorEvent,
|
||||||
A2APollingStartedEvent,
|
A2APollingStartedEvent,
|
||||||
A2APollingStatusEvent,
|
A2APollingStatusEvent,
|
||||||
A2AResponseReceivedEvent,
|
A2AResponseReceivedEvent,
|
||||||
@@ -49,23 +50,33 @@ async def _poll_task_until_complete(
|
|||||||
agent_branch: Any | None = None,
|
agent_branch: Any | None = None,
|
||||||
history_length: int = 100,
|
history_length: int = 100,
|
||||||
max_polls: int | None = None,
|
max_polls: int | None = None,
|
||||||
|
from_task: Any | None = None,
|
||||||
|
from_agent: Any | None = None,
|
||||||
|
context_id: str | None = None,
|
||||||
|
endpoint: str | None = None,
|
||||||
|
a2a_agent_name: str | None = None,
|
||||||
) -> A2ATask:
|
) -> A2ATask:
|
||||||
"""Poll task status until terminal state reached.
|
"""Poll task status until terminal state reached.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
client: A2A client instance
|
client: A2A client instance.
|
||||||
task_id: Task ID to poll
|
task_id: Task ID to poll.
|
||||||
polling_interval: Seconds between poll attempts
|
polling_interval: Seconds between poll attempts.
|
||||||
polling_timeout: Max seconds before timeout
|
polling_timeout: Max seconds before timeout.
|
||||||
agent_branch: Agent tree branch for logging
|
agent_branch: Agent tree branch for logging.
|
||||||
history_length: Number of messages to retrieve per poll
|
history_length: Number of messages to retrieve per poll.
|
||||||
max_polls: Max number of poll attempts (None = unlimited)
|
max_polls: Max number of poll attempts (None = unlimited).
|
||||||
|
from_task: Optional CrewAI Task object for event metadata.
|
||||||
|
from_agent: Optional CrewAI Agent object for event metadata.
|
||||||
|
context_id: A2A context ID for correlation.
|
||||||
|
endpoint: A2A agent endpoint URL.
|
||||||
|
a2a_agent_name: Name of the A2A agent from agent card.
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
Final task object in terminal state
|
Final task object in terminal state.
|
||||||
|
|
||||||
Raises:
|
Raises:
|
||||||
A2APollingTimeoutError: If polling exceeds timeout or max_polls
|
A2APollingTimeoutError: If polling exceeds timeout or max_polls.
|
||||||
"""
|
"""
|
||||||
start_time = time.monotonic()
|
start_time = time.monotonic()
|
||||||
poll_count = 0
|
poll_count = 0
|
||||||
@@ -77,13 +88,19 @@ async def _poll_task_until_complete(
|
|||||||
)
|
)
|
||||||
|
|
||||||
elapsed = time.monotonic() - start_time
|
elapsed = time.monotonic() - start_time
|
||||||
|
effective_context_id = task.context_id or context_id
|
||||||
crewai_event_bus.emit(
|
crewai_event_bus.emit(
|
||||||
agent_branch,
|
agent_branch,
|
||||||
A2APollingStatusEvent(
|
A2APollingStatusEvent(
|
||||||
task_id=task_id,
|
task_id=task_id,
|
||||||
|
context_id=effective_context_id,
|
||||||
state=str(task.status.state.value) if task.status.state else "unknown",
|
state=str(task.status.state.value) if task.status.state else "unknown",
|
||||||
elapsed_seconds=elapsed,
|
elapsed_seconds=elapsed,
|
||||||
poll_count=poll_count,
|
poll_count=poll_count,
|
||||||
|
endpoint=endpoint,
|
||||||
|
a2a_agent_name=a2a_agent_name,
|
||||||
|
from_task=from_task,
|
||||||
|
from_agent=from_agent,
|
||||||
),
|
),
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -137,6 +154,9 @@ class PollingHandler:
|
|||||||
max_polls = kwargs.get("max_polls")
|
max_polls = kwargs.get("max_polls")
|
||||||
context_id = kwargs.get("context_id")
|
context_id = kwargs.get("context_id")
|
||||||
task_id = kwargs.get("task_id")
|
task_id = kwargs.get("task_id")
|
||||||
|
a2a_agent_name = kwargs.get("a2a_agent_name")
|
||||||
|
from_task = kwargs.get("from_task")
|
||||||
|
from_agent = kwargs.get("from_agent")
|
||||||
|
|
||||||
try:
|
try:
|
||||||
result_or_task_id = await send_message_and_get_task_id(
|
result_or_task_id = await send_message_and_get_task_id(
|
||||||
@@ -146,6 +166,11 @@ class PollingHandler:
|
|||||||
turn_number=turn_number,
|
turn_number=turn_number,
|
||||||
is_multiturn=is_multiturn,
|
is_multiturn=is_multiturn,
|
||||||
agent_role=agent_role,
|
agent_role=agent_role,
|
||||||
|
from_task=from_task,
|
||||||
|
from_agent=from_agent,
|
||||||
|
endpoint=endpoint,
|
||||||
|
a2a_agent_name=a2a_agent_name,
|
||||||
|
context_id=context_id,
|
||||||
)
|
)
|
||||||
|
|
||||||
if not isinstance(result_or_task_id, str):
|
if not isinstance(result_or_task_id, str):
|
||||||
@@ -157,8 +182,12 @@ class PollingHandler:
|
|||||||
agent_branch,
|
agent_branch,
|
||||||
A2APollingStartedEvent(
|
A2APollingStartedEvent(
|
||||||
task_id=task_id,
|
task_id=task_id,
|
||||||
|
context_id=context_id,
|
||||||
polling_interval=polling_interval,
|
polling_interval=polling_interval,
|
||||||
endpoint=endpoint,
|
endpoint=endpoint,
|
||||||
|
a2a_agent_name=a2a_agent_name,
|
||||||
|
from_task=from_task,
|
||||||
|
from_agent=from_agent,
|
||||||
),
|
),
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -170,6 +199,11 @@ class PollingHandler:
|
|||||||
agent_branch=agent_branch,
|
agent_branch=agent_branch,
|
||||||
history_length=history_length,
|
history_length=history_length,
|
||||||
max_polls=max_polls,
|
max_polls=max_polls,
|
||||||
|
from_task=from_task,
|
||||||
|
from_agent=from_agent,
|
||||||
|
context_id=context_id,
|
||||||
|
endpoint=endpoint,
|
||||||
|
a2a_agent_name=a2a_agent_name,
|
||||||
)
|
)
|
||||||
|
|
||||||
result = process_task_state(
|
result = process_task_state(
|
||||||
@@ -179,6 +213,10 @@ class PollingHandler:
|
|||||||
turn_number=turn_number,
|
turn_number=turn_number,
|
||||||
is_multiturn=is_multiturn,
|
is_multiturn=is_multiturn,
|
||||||
agent_role=agent_role,
|
agent_role=agent_role,
|
||||||
|
endpoint=endpoint,
|
||||||
|
a2a_agent_name=a2a_agent_name,
|
||||||
|
from_task=from_task,
|
||||||
|
from_agent=from_agent,
|
||||||
)
|
)
|
||||||
if result:
|
if result:
|
||||||
return result
|
return result
|
||||||
@@ -206,9 +244,15 @@ class PollingHandler:
|
|||||||
A2AResponseReceivedEvent(
|
A2AResponseReceivedEvent(
|
||||||
response=error_msg,
|
response=error_msg,
|
||||||
turn_number=turn_number,
|
turn_number=turn_number,
|
||||||
|
context_id=context_id,
|
||||||
is_multiturn=is_multiturn,
|
is_multiturn=is_multiturn,
|
||||||
status="failed",
|
status="failed",
|
||||||
|
final=True,
|
||||||
agent_role=agent_role,
|
agent_role=agent_role,
|
||||||
|
endpoint=endpoint,
|
||||||
|
a2a_agent_name=a2a_agent_name,
|
||||||
|
from_task=from_task,
|
||||||
|
from_agent=from_agent,
|
||||||
),
|
),
|
||||||
)
|
)
|
||||||
return TaskStateResult(
|
return TaskStateResult(
|
||||||
@@ -229,14 +273,83 @@ class PollingHandler:
|
|||||||
)
|
)
|
||||||
new_messages.append(error_message)
|
new_messages.append(error_message)
|
||||||
|
|
||||||
|
crewai_event_bus.emit(
|
||||||
|
agent_branch,
|
||||||
|
A2AConnectionErrorEvent(
|
||||||
|
endpoint=endpoint,
|
||||||
|
error=str(e),
|
||||||
|
error_type="http_error",
|
||||||
|
status_code=e.status_code,
|
||||||
|
a2a_agent_name=a2a_agent_name,
|
||||||
|
operation="polling",
|
||||||
|
context_id=context_id,
|
||||||
|
task_id=task_id,
|
||||||
|
from_task=from_task,
|
||||||
|
from_agent=from_agent,
|
||||||
|
),
|
||||||
|
)
|
||||||
crewai_event_bus.emit(
|
crewai_event_bus.emit(
|
||||||
agent_branch,
|
agent_branch,
|
||||||
A2AResponseReceivedEvent(
|
A2AResponseReceivedEvent(
|
||||||
response=error_msg,
|
response=error_msg,
|
||||||
turn_number=turn_number,
|
turn_number=turn_number,
|
||||||
|
context_id=context_id,
|
||||||
is_multiturn=is_multiturn,
|
is_multiturn=is_multiturn,
|
||||||
status="failed",
|
status="failed",
|
||||||
|
final=True,
|
||||||
agent_role=agent_role,
|
agent_role=agent_role,
|
||||||
|
endpoint=endpoint,
|
||||||
|
a2a_agent_name=a2a_agent_name,
|
||||||
|
from_task=from_task,
|
||||||
|
from_agent=from_agent,
|
||||||
|
),
|
||||||
|
)
|
||||||
|
return TaskStateResult(
|
||||||
|
status=TaskState.failed,
|
||||||
|
error=error_msg,
|
||||||
|
history=new_messages,
|
||||||
|
)
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
error_msg = f"Unexpected error during polling: {e!s}"
|
||||||
|
|
||||||
|
error_message = Message(
|
||||||
|
role=Role.agent,
|
||||||
|
message_id=str(uuid.uuid4()),
|
||||||
|
parts=[Part(root=TextPart(text=error_msg))],
|
||||||
|
context_id=context_id,
|
||||||
|
task_id=task_id,
|
||||||
|
)
|
||||||
|
new_messages.append(error_message)
|
||||||
|
|
||||||
|
crewai_event_bus.emit(
|
||||||
|
agent_branch,
|
||||||
|
A2AConnectionErrorEvent(
|
||||||
|
endpoint=endpoint or "",
|
||||||
|
error=str(e),
|
||||||
|
error_type="unexpected_error",
|
||||||
|
a2a_agent_name=a2a_agent_name,
|
||||||
|
operation="polling",
|
||||||
|
context_id=context_id,
|
||||||
|
task_id=task_id,
|
||||||
|
from_task=from_task,
|
||||||
|
from_agent=from_agent,
|
||||||
|
),
|
||||||
|
)
|
||||||
|
crewai_event_bus.emit(
|
||||||
|
agent_branch,
|
||||||
|
A2AResponseReceivedEvent(
|
||||||
|
response=error_msg,
|
||||||
|
turn_number=turn_number,
|
||||||
|
context_id=context_id,
|
||||||
|
is_multiturn=is_multiturn,
|
||||||
|
status="failed",
|
||||||
|
final=True,
|
||||||
|
agent_role=agent_role,
|
||||||
|
endpoint=endpoint,
|
||||||
|
a2a_agent_name=a2a_agent_name,
|
||||||
|
from_task=from_task,
|
||||||
|
from_agent=from_agent,
|
||||||
),
|
),
|
||||||
)
|
)
|
||||||
return TaskStateResult(
|
return TaskStateResult(
|
||||||
|
|||||||
@@ -29,6 +29,7 @@ from crewai.a2a.updates.base import (
|
|||||||
)
|
)
|
||||||
from crewai.events.event_bus import crewai_event_bus
|
from crewai.events.event_bus import crewai_event_bus
|
||||||
from crewai.events.types.a2a_events import (
|
from crewai.events.types.a2a_events import (
|
||||||
|
A2AConnectionErrorEvent,
|
||||||
A2APushNotificationRegisteredEvent,
|
A2APushNotificationRegisteredEvent,
|
||||||
A2APushNotificationTimeoutEvent,
|
A2APushNotificationTimeoutEvent,
|
||||||
A2AResponseReceivedEvent,
|
A2AResponseReceivedEvent,
|
||||||
@@ -48,6 +49,11 @@ async def _wait_for_push_result(
|
|||||||
timeout: float,
|
timeout: float,
|
||||||
poll_interval: float,
|
poll_interval: float,
|
||||||
agent_branch: Any | None = None,
|
agent_branch: Any | None = None,
|
||||||
|
from_task: Any | None = None,
|
||||||
|
from_agent: Any | None = None,
|
||||||
|
context_id: str | None = None,
|
||||||
|
endpoint: str | None = None,
|
||||||
|
a2a_agent_name: str | None = None,
|
||||||
) -> A2ATask | None:
|
) -> A2ATask | None:
|
||||||
"""Wait for push notification result.
|
"""Wait for push notification result.
|
||||||
|
|
||||||
@@ -57,6 +63,11 @@ async def _wait_for_push_result(
|
|||||||
timeout: Max seconds to wait.
|
timeout: Max seconds to wait.
|
||||||
poll_interval: Seconds between polling attempts.
|
poll_interval: Seconds between polling attempts.
|
||||||
agent_branch: Agent tree branch for logging.
|
agent_branch: Agent tree branch for logging.
|
||||||
|
from_task: Optional CrewAI Task object for event metadata.
|
||||||
|
from_agent: Optional CrewAI Agent object for event metadata.
|
||||||
|
context_id: A2A context ID for correlation.
|
||||||
|
endpoint: A2A agent endpoint URL.
|
||||||
|
a2a_agent_name: Name of the A2A agent.
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
Final task object, or None if timeout.
|
Final task object, or None if timeout.
|
||||||
@@ -72,7 +83,12 @@ async def _wait_for_push_result(
|
|||||||
agent_branch,
|
agent_branch,
|
||||||
A2APushNotificationTimeoutEvent(
|
A2APushNotificationTimeoutEvent(
|
||||||
task_id=task_id,
|
task_id=task_id,
|
||||||
|
context_id=context_id,
|
||||||
timeout_seconds=timeout,
|
timeout_seconds=timeout,
|
||||||
|
endpoint=endpoint,
|
||||||
|
a2a_agent_name=a2a_agent_name,
|
||||||
|
from_task=from_task,
|
||||||
|
from_agent=from_agent,
|
||||||
),
|
),
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -115,18 +131,56 @@ class PushNotificationHandler:
|
|||||||
agent_role = kwargs.get("agent_role")
|
agent_role = kwargs.get("agent_role")
|
||||||
context_id = kwargs.get("context_id")
|
context_id = kwargs.get("context_id")
|
||||||
task_id = kwargs.get("task_id")
|
task_id = kwargs.get("task_id")
|
||||||
|
endpoint = kwargs.get("endpoint")
|
||||||
|
a2a_agent_name = kwargs.get("a2a_agent_name")
|
||||||
|
from_task = kwargs.get("from_task")
|
||||||
|
from_agent = kwargs.get("from_agent")
|
||||||
|
|
||||||
if config is None:
|
if config is None:
|
||||||
|
error_msg = (
|
||||||
|
"PushNotificationConfig is required for push notification handler"
|
||||||
|
)
|
||||||
|
crewai_event_bus.emit(
|
||||||
|
agent_branch,
|
||||||
|
A2AConnectionErrorEvent(
|
||||||
|
endpoint=endpoint or "",
|
||||||
|
error=error_msg,
|
||||||
|
error_type="configuration_error",
|
||||||
|
a2a_agent_name=a2a_agent_name,
|
||||||
|
operation="push_notification",
|
||||||
|
context_id=context_id,
|
||||||
|
task_id=task_id,
|
||||||
|
from_task=from_task,
|
||||||
|
from_agent=from_agent,
|
||||||
|
),
|
||||||
|
)
|
||||||
return TaskStateResult(
|
return TaskStateResult(
|
||||||
status=TaskState.failed,
|
status=TaskState.failed,
|
||||||
error="PushNotificationConfig is required for push notification handler",
|
error=error_msg,
|
||||||
history=new_messages,
|
history=new_messages,
|
||||||
)
|
)
|
||||||
|
|
||||||
if result_store is None:
|
if result_store is None:
|
||||||
|
error_msg = (
|
||||||
|
"PushNotificationResultStore is required for push notification handler"
|
||||||
|
)
|
||||||
|
crewai_event_bus.emit(
|
||||||
|
agent_branch,
|
||||||
|
A2AConnectionErrorEvent(
|
||||||
|
endpoint=endpoint or "",
|
||||||
|
error=error_msg,
|
||||||
|
error_type="configuration_error",
|
||||||
|
a2a_agent_name=a2a_agent_name,
|
||||||
|
operation="push_notification",
|
||||||
|
context_id=context_id,
|
||||||
|
task_id=task_id,
|
||||||
|
from_task=from_task,
|
||||||
|
from_agent=from_agent,
|
||||||
|
),
|
||||||
|
)
|
||||||
return TaskStateResult(
|
return TaskStateResult(
|
||||||
status=TaskState.failed,
|
status=TaskState.failed,
|
||||||
error="PushNotificationResultStore is required for push notification handler",
|
error=error_msg,
|
||||||
history=new_messages,
|
history=new_messages,
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -138,6 +192,11 @@ class PushNotificationHandler:
|
|||||||
turn_number=turn_number,
|
turn_number=turn_number,
|
||||||
is_multiturn=is_multiturn,
|
is_multiturn=is_multiturn,
|
||||||
agent_role=agent_role,
|
agent_role=agent_role,
|
||||||
|
from_task=from_task,
|
||||||
|
from_agent=from_agent,
|
||||||
|
endpoint=endpoint,
|
||||||
|
a2a_agent_name=a2a_agent_name,
|
||||||
|
context_id=context_id,
|
||||||
)
|
)
|
||||||
|
|
||||||
if not isinstance(result_or_task_id, str):
|
if not isinstance(result_or_task_id, str):
|
||||||
@@ -149,7 +208,12 @@ class PushNotificationHandler:
|
|||||||
agent_branch,
|
agent_branch,
|
||||||
A2APushNotificationRegisteredEvent(
|
A2APushNotificationRegisteredEvent(
|
||||||
task_id=task_id,
|
task_id=task_id,
|
||||||
|
context_id=context_id,
|
||||||
callback_url=str(config.url),
|
callback_url=str(config.url),
|
||||||
|
endpoint=endpoint,
|
||||||
|
a2a_agent_name=a2a_agent_name,
|
||||||
|
from_task=from_task,
|
||||||
|
from_agent=from_agent,
|
||||||
),
|
),
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -165,6 +229,11 @@ class PushNotificationHandler:
|
|||||||
timeout=polling_timeout,
|
timeout=polling_timeout,
|
||||||
poll_interval=polling_interval,
|
poll_interval=polling_interval,
|
||||||
agent_branch=agent_branch,
|
agent_branch=agent_branch,
|
||||||
|
from_task=from_task,
|
||||||
|
from_agent=from_agent,
|
||||||
|
context_id=context_id,
|
||||||
|
endpoint=endpoint,
|
||||||
|
a2a_agent_name=a2a_agent_name,
|
||||||
)
|
)
|
||||||
|
|
||||||
if final_task is None:
|
if final_task is None:
|
||||||
@@ -181,6 +250,10 @@ class PushNotificationHandler:
|
|||||||
turn_number=turn_number,
|
turn_number=turn_number,
|
||||||
is_multiturn=is_multiturn,
|
is_multiturn=is_multiturn,
|
||||||
agent_role=agent_role,
|
agent_role=agent_role,
|
||||||
|
endpoint=endpoint,
|
||||||
|
a2a_agent_name=a2a_agent_name,
|
||||||
|
from_task=from_task,
|
||||||
|
from_agent=from_agent,
|
||||||
)
|
)
|
||||||
if result:
|
if result:
|
||||||
return result
|
return result
|
||||||
@@ -203,14 +276,83 @@ class PushNotificationHandler:
|
|||||||
)
|
)
|
||||||
new_messages.append(error_message)
|
new_messages.append(error_message)
|
||||||
|
|
||||||
|
crewai_event_bus.emit(
|
||||||
|
agent_branch,
|
||||||
|
A2AConnectionErrorEvent(
|
||||||
|
endpoint=endpoint or "",
|
||||||
|
error=str(e),
|
||||||
|
error_type="http_error",
|
||||||
|
status_code=e.status_code,
|
||||||
|
a2a_agent_name=a2a_agent_name,
|
||||||
|
operation="push_notification",
|
||||||
|
context_id=context_id,
|
||||||
|
task_id=task_id,
|
||||||
|
from_task=from_task,
|
||||||
|
from_agent=from_agent,
|
||||||
|
),
|
||||||
|
)
|
||||||
crewai_event_bus.emit(
|
crewai_event_bus.emit(
|
||||||
agent_branch,
|
agent_branch,
|
||||||
A2AResponseReceivedEvent(
|
A2AResponseReceivedEvent(
|
||||||
response=error_msg,
|
response=error_msg,
|
||||||
turn_number=turn_number,
|
turn_number=turn_number,
|
||||||
|
context_id=context_id,
|
||||||
is_multiturn=is_multiturn,
|
is_multiturn=is_multiturn,
|
||||||
status="failed",
|
status="failed",
|
||||||
|
final=True,
|
||||||
agent_role=agent_role,
|
agent_role=agent_role,
|
||||||
|
endpoint=endpoint,
|
||||||
|
a2a_agent_name=a2a_agent_name,
|
||||||
|
from_task=from_task,
|
||||||
|
from_agent=from_agent,
|
||||||
|
),
|
||||||
|
)
|
||||||
|
return TaskStateResult(
|
||||||
|
status=TaskState.failed,
|
||||||
|
error=error_msg,
|
||||||
|
history=new_messages,
|
||||||
|
)
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
error_msg = f"Unexpected error during push notification: {e!s}"
|
||||||
|
|
||||||
|
error_message = Message(
|
||||||
|
role=Role.agent,
|
||||||
|
message_id=str(uuid.uuid4()),
|
||||||
|
parts=[Part(root=TextPart(text=error_msg))],
|
||||||
|
context_id=context_id,
|
||||||
|
task_id=task_id,
|
||||||
|
)
|
||||||
|
new_messages.append(error_message)
|
||||||
|
|
||||||
|
crewai_event_bus.emit(
|
||||||
|
agent_branch,
|
||||||
|
A2AConnectionErrorEvent(
|
||||||
|
endpoint=endpoint or "",
|
||||||
|
error=str(e),
|
||||||
|
error_type="unexpected_error",
|
||||||
|
a2a_agent_name=a2a_agent_name,
|
||||||
|
operation="push_notification",
|
||||||
|
context_id=context_id,
|
||||||
|
task_id=task_id,
|
||||||
|
from_task=from_task,
|
||||||
|
from_agent=from_agent,
|
||||||
|
),
|
||||||
|
)
|
||||||
|
crewai_event_bus.emit(
|
||||||
|
agent_branch,
|
||||||
|
A2AResponseReceivedEvent(
|
||||||
|
response=error_msg,
|
||||||
|
turn_number=turn_number,
|
||||||
|
context_id=context_id,
|
||||||
|
is_multiturn=is_multiturn,
|
||||||
|
status="failed",
|
||||||
|
final=True,
|
||||||
|
agent_role=agent_role,
|
||||||
|
endpoint=endpoint,
|
||||||
|
a2a_agent_name=a2a_agent_name,
|
||||||
|
from_task=from_task,
|
||||||
|
from_agent=from_agent,
|
||||||
),
|
),
|
||||||
)
|
)
|
||||||
return TaskStateResult(
|
return TaskStateResult(
|
||||||
|
|||||||
@@ -26,7 +26,13 @@ from crewai.a2a.task_helpers import (
|
|||||||
)
|
)
|
||||||
from crewai.a2a.updates.base import StreamingHandlerKwargs
|
from crewai.a2a.updates.base import StreamingHandlerKwargs
|
||||||
from crewai.events.event_bus import crewai_event_bus
|
from crewai.events.event_bus import crewai_event_bus
|
||||||
from crewai.events.types.a2a_events import A2AResponseReceivedEvent
|
from crewai.events.types.a2a_events import (
|
||||||
|
A2AArtifactReceivedEvent,
|
||||||
|
A2AConnectionErrorEvent,
|
||||||
|
A2AResponseReceivedEvent,
|
||||||
|
A2AStreamingChunkEvent,
|
||||||
|
A2AStreamingStartedEvent,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
class StreamingHandler:
|
class StreamingHandler:
|
||||||
@@ -57,19 +63,57 @@ class StreamingHandler:
|
|||||||
turn_number = kwargs.get("turn_number", 0)
|
turn_number = kwargs.get("turn_number", 0)
|
||||||
is_multiturn = kwargs.get("is_multiturn", False)
|
is_multiturn = kwargs.get("is_multiturn", False)
|
||||||
agent_role = kwargs.get("agent_role")
|
agent_role = kwargs.get("agent_role")
|
||||||
|
endpoint = kwargs.get("endpoint")
|
||||||
|
a2a_agent_name = kwargs.get("a2a_agent_name")
|
||||||
|
from_task = kwargs.get("from_task")
|
||||||
|
from_agent = kwargs.get("from_agent")
|
||||||
|
agent_branch = kwargs.get("agent_branch")
|
||||||
|
|
||||||
result_parts: list[str] = []
|
result_parts: list[str] = []
|
||||||
final_result: TaskStateResult | None = None
|
final_result: TaskStateResult | None = None
|
||||||
event_stream = client.send_message(message)
|
event_stream = client.send_message(message)
|
||||||
|
chunk_index = 0
|
||||||
|
|
||||||
|
crewai_event_bus.emit(
|
||||||
|
agent_branch,
|
||||||
|
A2AStreamingStartedEvent(
|
||||||
|
task_id=task_id,
|
||||||
|
context_id=context_id,
|
||||||
|
endpoint=endpoint or "",
|
||||||
|
a2a_agent_name=a2a_agent_name,
|
||||||
|
turn_number=turn_number,
|
||||||
|
is_multiturn=is_multiturn,
|
||||||
|
agent_role=agent_role,
|
||||||
|
from_task=from_task,
|
||||||
|
from_agent=from_agent,
|
||||||
|
),
|
||||||
|
)
|
||||||
|
|
||||||
try:
|
try:
|
||||||
async for event in event_stream:
|
async for event in event_stream:
|
||||||
if isinstance(event, Message):
|
if isinstance(event, Message):
|
||||||
new_messages.append(event)
|
new_messages.append(event)
|
||||||
|
message_context_id = event.context_id or context_id
|
||||||
for part in event.parts:
|
for part in event.parts:
|
||||||
if part.root.kind == "text":
|
if part.root.kind == "text":
|
||||||
text = part.root.text
|
text = part.root.text
|
||||||
result_parts.append(text)
|
result_parts.append(text)
|
||||||
|
crewai_event_bus.emit(
|
||||||
|
agent_branch,
|
||||||
|
A2AStreamingChunkEvent(
|
||||||
|
task_id=event.task_id or task_id,
|
||||||
|
context_id=message_context_id,
|
||||||
|
chunk=text,
|
||||||
|
chunk_index=chunk_index,
|
||||||
|
endpoint=endpoint,
|
||||||
|
a2a_agent_name=a2a_agent_name,
|
||||||
|
turn_number=turn_number,
|
||||||
|
is_multiturn=is_multiturn,
|
||||||
|
from_task=from_task,
|
||||||
|
from_agent=from_agent,
|
||||||
|
),
|
||||||
|
)
|
||||||
|
chunk_index += 1
|
||||||
|
|
||||||
elif isinstance(event, tuple):
|
elif isinstance(event, tuple):
|
||||||
a2a_task, update = event
|
a2a_task, update = event
|
||||||
@@ -81,10 +125,51 @@ class StreamingHandler:
|
|||||||
for part in artifact.parts
|
for part in artifact.parts
|
||||||
if part.root.kind == "text"
|
if part.root.kind == "text"
|
||||||
)
|
)
|
||||||
|
artifact_size = None
|
||||||
|
if artifact.parts:
|
||||||
|
artifact_size = sum(
|
||||||
|
len(p.root.text.encode("utf-8"))
|
||||||
|
if p.root.kind == "text"
|
||||||
|
else len(getattr(p.root, "data", b""))
|
||||||
|
for p in artifact.parts
|
||||||
|
)
|
||||||
|
effective_context_id = a2a_task.context_id or context_id
|
||||||
|
crewai_event_bus.emit(
|
||||||
|
agent_branch,
|
||||||
|
A2AArtifactReceivedEvent(
|
||||||
|
task_id=a2a_task.id,
|
||||||
|
artifact_id=artifact.artifact_id,
|
||||||
|
artifact_name=artifact.name,
|
||||||
|
artifact_description=artifact.description,
|
||||||
|
mime_type=artifact.parts[0].root.kind
|
||||||
|
if artifact.parts
|
||||||
|
else None,
|
||||||
|
size_bytes=artifact_size,
|
||||||
|
append=update.append or False,
|
||||||
|
last_chunk=update.last_chunk or False,
|
||||||
|
endpoint=endpoint,
|
||||||
|
a2a_agent_name=a2a_agent_name,
|
||||||
|
context_id=effective_context_id,
|
||||||
|
turn_number=turn_number,
|
||||||
|
is_multiturn=is_multiturn,
|
||||||
|
from_task=from_task,
|
||||||
|
from_agent=from_agent,
|
||||||
|
),
|
||||||
|
)
|
||||||
|
|
||||||
is_final_update = False
|
is_final_update = False
|
||||||
if isinstance(update, TaskStatusUpdateEvent):
|
if isinstance(update, TaskStatusUpdateEvent):
|
||||||
is_final_update = update.final
|
is_final_update = update.final
|
||||||
|
if (
|
||||||
|
update.status
|
||||||
|
and update.status.message
|
||||||
|
and update.status.message.parts
|
||||||
|
):
|
||||||
|
result_parts.extend(
|
||||||
|
part.root.text
|
||||||
|
for part in update.status.message.parts
|
||||||
|
if part.root.kind == "text" and part.root.text
|
||||||
|
)
|
||||||
|
|
||||||
if (
|
if (
|
||||||
not is_final_update
|
not is_final_update
|
||||||
@@ -101,6 +186,11 @@ class StreamingHandler:
|
|||||||
is_multiturn=is_multiturn,
|
is_multiturn=is_multiturn,
|
||||||
agent_role=agent_role,
|
agent_role=agent_role,
|
||||||
result_parts=result_parts,
|
result_parts=result_parts,
|
||||||
|
endpoint=endpoint,
|
||||||
|
a2a_agent_name=a2a_agent_name,
|
||||||
|
from_task=from_task,
|
||||||
|
from_agent=from_agent,
|
||||||
|
is_final=is_final_update,
|
||||||
)
|
)
|
||||||
if final_result:
|
if final_result:
|
||||||
break
|
break
|
||||||
@@ -118,13 +208,82 @@ class StreamingHandler:
|
|||||||
new_messages.append(error_message)
|
new_messages.append(error_message)
|
||||||
|
|
||||||
crewai_event_bus.emit(
|
crewai_event_bus.emit(
|
||||||
None,
|
agent_branch,
|
||||||
|
A2AConnectionErrorEvent(
|
||||||
|
endpoint=endpoint or "",
|
||||||
|
error=str(e),
|
||||||
|
error_type="http_error",
|
||||||
|
status_code=e.status_code,
|
||||||
|
a2a_agent_name=a2a_agent_name,
|
||||||
|
operation="streaming",
|
||||||
|
context_id=context_id,
|
||||||
|
task_id=task_id,
|
||||||
|
from_task=from_task,
|
||||||
|
from_agent=from_agent,
|
||||||
|
),
|
||||||
|
)
|
||||||
|
crewai_event_bus.emit(
|
||||||
|
agent_branch,
|
||||||
A2AResponseReceivedEvent(
|
A2AResponseReceivedEvent(
|
||||||
response=error_msg,
|
response=error_msg,
|
||||||
turn_number=turn_number,
|
turn_number=turn_number,
|
||||||
|
context_id=context_id,
|
||||||
is_multiturn=is_multiturn,
|
is_multiturn=is_multiturn,
|
||||||
status="failed",
|
status="failed",
|
||||||
|
final=True,
|
||||||
agent_role=agent_role,
|
agent_role=agent_role,
|
||||||
|
endpoint=endpoint,
|
||||||
|
a2a_agent_name=a2a_agent_name,
|
||||||
|
from_task=from_task,
|
||||||
|
from_agent=from_agent,
|
||||||
|
),
|
||||||
|
)
|
||||||
|
return TaskStateResult(
|
||||||
|
status=TaskState.failed,
|
||||||
|
error=error_msg,
|
||||||
|
history=new_messages,
|
||||||
|
)
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
error_msg = f"Unexpected error during streaming: {e!s}"
|
||||||
|
|
||||||
|
error_message = Message(
|
||||||
|
role=Role.agent,
|
||||||
|
message_id=str(uuid.uuid4()),
|
||||||
|
parts=[Part(root=TextPart(text=error_msg))],
|
||||||
|
context_id=context_id,
|
||||||
|
task_id=task_id,
|
||||||
|
)
|
||||||
|
new_messages.append(error_message)
|
||||||
|
|
||||||
|
crewai_event_bus.emit(
|
||||||
|
agent_branch,
|
||||||
|
A2AConnectionErrorEvent(
|
||||||
|
endpoint=endpoint or "",
|
||||||
|
error=str(e),
|
||||||
|
error_type="unexpected_error",
|
||||||
|
a2a_agent_name=a2a_agent_name,
|
||||||
|
operation="streaming",
|
||||||
|
context_id=context_id,
|
||||||
|
task_id=task_id,
|
||||||
|
from_task=from_task,
|
||||||
|
from_agent=from_agent,
|
||||||
|
),
|
||||||
|
)
|
||||||
|
crewai_event_bus.emit(
|
||||||
|
agent_branch,
|
||||||
|
A2AResponseReceivedEvent(
|
||||||
|
response=error_msg,
|
||||||
|
turn_number=turn_number,
|
||||||
|
context_id=context_id,
|
||||||
|
is_multiturn=is_multiturn,
|
||||||
|
status="failed",
|
||||||
|
final=True,
|
||||||
|
agent_role=agent_role,
|
||||||
|
endpoint=endpoint,
|
||||||
|
a2a_agent_name=a2a_agent_name,
|
||||||
|
from_task=from_task,
|
||||||
|
from_agent=from_agent,
|
||||||
),
|
),
|
||||||
)
|
)
|
||||||
return TaskStateResult(
|
return TaskStateResult(
|
||||||
@@ -136,7 +295,23 @@ class StreamingHandler:
|
|||||||
finally:
|
finally:
|
||||||
aclose = getattr(event_stream, "aclose", None)
|
aclose = getattr(event_stream, "aclose", None)
|
||||||
if aclose:
|
if aclose:
|
||||||
await aclose()
|
try:
|
||||||
|
await aclose()
|
||||||
|
except Exception as close_error:
|
||||||
|
crewai_event_bus.emit(
|
||||||
|
agent_branch,
|
||||||
|
A2AConnectionErrorEvent(
|
||||||
|
endpoint=endpoint or "",
|
||||||
|
error=str(close_error),
|
||||||
|
error_type="stream_close_error",
|
||||||
|
a2a_agent_name=a2a_agent_name,
|
||||||
|
operation="stream_close",
|
||||||
|
context_id=context_id,
|
||||||
|
task_id=task_id,
|
||||||
|
from_task=from_task,
|
||||||
|
from_agent=from_agent,
|
||||||
|
),
|
||||||
|
)
|
||||||
|
|
||||||
if final_result:
|
if final_result:
|
||||||
return final_result
|
return final_result
|
||||||
@@ -145,5 +320,5 @@ class StreamingHandler:
|
|||||||
status=TaskState.completed,
|
status=TaskState.completed,
|
||||||
result=" ".join(result_parts) if result_parts else "",
|
result=" ".join(result_parts) if result_parts else "",
|
||||||
history=new_messages,
|
history=new_messages,
|
||||||
agent_card=agent_card,
|
agent_card=agent_card.model_dump(exclude_none=True),
|
||||||
)
|
)
|
||||||
|
|||||||
@@ -23,6 +23,12 @@ from crewai.a2a.auth.utils import (
|
|||||||
)
|
)
|
||||||
from crewai.a2a.config import A2AServerConfig
|
from crewai.a2a.config import A2AServerConfig
|
||||||
from crewai.crew import Crew
|
from crewai.crew import Crew
|
||||||
|
from crewai.events.event_bus import crewai_event_bus
|
||||||
|
from crewai.events.types.a2a_events import (
|
||||||
|
A2AAgentCardFetchedEvent,
|
||||||
|
A2AAuthenticationFailedEvent,
|
||||||
|
A2AConnectionErrorEvent,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
@@ -183,6 +189,8 @@ async def _afetch_agent_card_impl(
|
|||||||
timeout: int,
|
timeout: int,
|
||||||
) -> AgentCard:
|
) -> AgentCard:
|
||||||
"""Internal async implementation of AgentCard fetching."""
|
"""Internal async implementation of AgentCard fetching."""
|
||||||
|
start_time = time.perf_counter()
|
||||||
|
|
||||||
if "/.well-known/agent-card.json" in endpoint:
|
if "/.well-known/agent-card.json" in endpoint:
|
||||||
base_url = endpoint.replace("/.well-known/agent-card.json", "")
|
base_url = endpoint.replace("/.well-known/agent-card.json", "")
|
||||||
agent_card_path = "/.well-known/agent-card.json"
|
agent_card_path = "/.well-known/agent-card.json"
|
||||||
@@ -217,9 +225,29 @@ async def _afetch_agent_card_impl(
|
|||||||
)
|
)
|
||||||
response.raise_for_status()
|
response.raise_for_status()
|
||||||
|
|
||||||
return AgentCard.model_validate(response.json())
|
agent_card = AgentCard.model_validate(response.json())
|
||||||
|
fetch_time_ms = (time.perf_counter() - start_time) * 1000
|
||||||
|
agent_card_dict = agent_card.model_dump(exclude_none=True)
|
||||||
|
|
||||||
|
crewai_event_bus.emit(
|
||||||
|
None,
|
||||||
|
A2AAgentCardFetchedEvent(
|
||||||
|
endpoint=endpoint,
|
||||||
|
a2a_agent_name=agent_card.name,
|
||||||
|
agent_card=agent_card_dict,
|
||||||
|
protocol_version=agent_card.protocol_version,
|
||||||
|
provider=agent_card_dict.get("provider"),
|
||||||
|
cached=False,
|
||||||
|
fetch_time_ms=fetch_time_ms,
|
||||||
|
),
|
||||||
|
)
|
||||||
|
|
||||||
|
return agent_card
|
||||||
|
|
||||||
except httpx.HTTPStatusError as e:
|
except httpx.HTTPStatusError as e:
|
||||||
|
elapsed_ms = (time.perf_counter() - start_time) * 1000
|
||||||
|
response_body = e.response.text[:1000] if e.response.text else None
|
||||||
|
|
||||||
if e.response.status_code == 401:
|
if e.response.status_code == 401:
|
||||||
error_details = ["Authentication failed"]
|
error_details = ["Authentication failed"]
|
||||||
www_auth = e.response.headers.get("WWW-Authenticate")
|
www_auth = e.response.headers.get("WWW-Authenticate")
|
||||||
@@ -228,7 +256,93 @@ async def _afetch_agent_card_impl(
|
|||||||
if not auth:
|
if not auth:
|
||||||
error_details.append("No auth scheme provided")
|
error_details.append("No auth scheme provided")
|
||||||
msg = " | ".join(error_details)
|
msg = " | ".join(error_details)
|
||||||
|
|
||||||
|
auth_type = type(auth).__name__ if auth else None
|
||||||
|
crewai_event_bus.emit(
|
||||||
|
None,
|
||||||
|
A2AAuthenticationFailedEvent(
|
||||||
|
endpoint=endpoint,
|
||||||
|
auth_type=auth_type,
|
||||||
|
error=msg,
|
||||||
|
status_code=401,
|
||||||
|
metadata={
|
||||||
|
"elapsed_ms": elapsed_ms,
|
||||||
|
"response_body": response_body,
|
||||||
|
"www_authenticate": www_auth,
|
||||||
|
"request_url": str(e.request.url),
|
||||||
|
},
|
||||||
|
),
|
||||||
|
)
|
||||||
|
|
||||||
raise A2AClientHTTPError(401, msg) from e
|
raise A2AClientHTTPError(401, msg) from e
|
||||||
|
|
||||||
|
crewai_event_bus.emit(
|
||||||
|
None,
|
||||||
|
A2AConnectionErrorEvent(
|
||||||
|
endpoint=endpoint,
|
||||||
|
error=str(e),
|
||||||
|
error_type="http_error",
|
||||||
|
status_code=e.response.status_code,
|
||||||
|
operation="fetch_agent_card",
|
||||||
|
metadata={
|
||||||
|
"elapsed_ms": elapsed_ms,
|
||||||
|
"response_body": response_body,
|
||||||
|
"request_url": str(e.request.url),
|
||||||
|
},
|
||||||
|
),
|
||||||
|
)
|
||||||
|
raise
|
||||||
|
|
||||||
|
except httpx.TimeoutException as e:
|
||||||
|
elapsed_ms = (time.perf_counter() - start_time) * 1000
|
||||||
|
crewai_event_bus.emit(
|
||||||
|
None,
|
||||||
|
A2AConnectionErrorEvent(
|
||||||
|
endpoint=endpoint,
|
||||||
|
error=str(e),
|
||||||
|
error_type="timeout",
|
||||||
|
operation="fetch_agent_card",
|
||||||
|
metadata={
|
||||||
|
"elapsed_ms": elapsed_ms,
|
||||||
|
"timeout_config": timeout,
|
||||||
|
"request_url": str(e.request.url) if e.request else None,
|
||||||
|
},
|
||||||
|
),
|
||||||
|
)
|
||||||
|
raise
|
||||||
|
|
||||||
|
except httpx.ConnectError as e:
|
||||||
|
elapsed_ms = (time.perf_counter() - start_time) * 1000
|
||||||
|
crewai_event_bus.emit(
|
||||||
|
None,
|
||||||
|
A2AConnectionErrorEvent(
|
||||||
|
endpoint=endpoint,
|
||||||
|
error=str(e),
|
||||||
|
error_type="connection_error",
|
||||||
|
operation="fetch_agent_card",
|
||||||
|
metadata={
|
||||||
|
"elapsed_ms": elapsed_ms,
|
||||||
|
"request_url": str(e.request.url) if e.request else None,
|
||||||
|
},
|
||||||
|
),
|
||||||
|
)
|
||||||
|
raise
|
||||||
|
|
||||||
|
except httpx.RequestError as e:
|
||||||
|
elapsed_ms = (time.perf_counter() - start_time) * 1000
|
||||||
|
crewai_event_bus.emit(
|
||||||
|
None,
|
||||||
|
A2AConnectionErrorEvent(
|
||||||
|
endpoint=endpoint,
|
||||||
|
error=str(e),
|
||||||
|
error_type="request_error",
|
||||||
|
operation="fetch_agent_card",
|
||||||
|
metadata={
|
||||||
|
"elapsed_ms": elapsed_ms,
|
||||||
|
"request_url": str(e.request.url) if e.request else None,
|
||||||
|
},
|
||||||
|
),
|
||||||
|
)
|
||||||
raise
|
raise
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@@ -88,6 +88,9 @@ def execute_a2a_delegation(
|
|||||||
response_model: type[BaseModel] | None = None,
|
response_model: type[BaseModel] | None = None,
|
||||||
turn_number: int | None = None,
|
turn_number: int | None = None,
|
||||||
updates: UpdateConfig | None = None,
|
updates: UpdateConfig | None = None,
|
||||||
|
from_task: Any | None = None,
|
||||||
|
from_agent: Any | None = None,
|
||||||
|
skill_id: str | None = None,
|
||||||
) -> TaskStateResult:
|
) -> TaskStateResult:
|
||||||
"""Execute a task delegation to a remote A2A agent synchronously.
|
"""Execute a task delegation to a remote A2A agent synchronously.
|
||||||
|
|
||||||
@@ -129,6 +132,9 @@ def execute_a2a_delegation(
|
|||||||
response_model: Optional Pydantic model for structured outputs.
|
response_model: Optional Pydantic model for structured outputs.
|
||||||
turn_number: Optional turn number for multi-turn conversations.
|
turn_number: Optional turn number for multi-turn conversations.
|
||||||
updates: Update mechanism config from A2AConfig.updates.
|
updates: Update mechanism config from A2AConfig.updates.
|
||||||
|
from_task: Optional CrewAI Task object for event metadata.
|
||||||
|
from_agent: Optional CrewAI Agent object for event metadata.
|
||||||
|
skill_id: Optional skill ID to target a specific agent capability.
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
TaskStateResult with status, result/error, history, and agent_card.
|
TaskStateResult with status, result/error, history, and agent_card.
|
||||||
@@ -156,10 +162,16 @@ def execute_a2a_delegation(
|
|||||||
transport_protocol=transport_protocol,
|
transport_protocol=transport_protocol,
|
||||||
turn_number=turn_number,
|
turn_number=turn_number,
|
||||||
updates=updates,
|
updates=updates,
|
||||||
|
from_task=from_task,
|
||||||
|
from_agent=from_agent,
|
||||||
|
skill_id=skill_id,
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
finally:
|
finally:
|
||||||
loop.close()
|
try:
|
||||||
|
loop.run_until_complete(loop.shutdown_asyncgens())
|
||||||
|
finally:
|
||||||
|
loop.close()
|
||||||
|
|
||||||
|
|
||||||
async def aexecute_a2a_delegation(
|
async def aexecute_a2a_delegation(
|
||||||
@@ -181,6 +193,9 @@ async def aexecute_a2a_delegation(
|
|||||||
response_model: type[BaseModel] | None = None,
|
response_model: type[BaseModel] | None = None,
|
||||||
turn_number: int | None = None,
|
turn_number: int | None = None,
|
||||||
updates: UpdateConfig | None = None,
|
updates: UpdateConfig | None = None,
|
||||||
|
from_task: Any | None = None,
|
||||||
|
from_agent: Any | None = None,
|
||||||
|
skill_id: str | None = None,
|
||||||
) -> TaskStateResult:
|
) -> TaskStateResult:
|
||||||
"""Execute a task delegation to a remote A2A agent asynchronously.
|
"""Execute a task delegation to a remote A2A agent asynchronously.
|
||||||
|
|
||||||
@@ -222,6 +237,9 @@ async def aexecute_a2a_delegation(
|
|||||||
response_model: Optional Pydantic model for structured outputs.
|
response_model: Optional Pydantic model for structured outputs.
|
||||||
turn_number: Optional turn number for multi-turn conversations.
|
turn_number: Optional turn number for multi-turn conversations.
|
||||||
updates: Update mechanism config from A2AConfig.updates.
|
updates: Update mechanism config from A2AConfig.updates.
|
||||||
|
from_task: Optional CrewAI Task object for event metadata.
|
||||||
|
from_agent: Optional CrewAI Agent object for event metadata.
|
||||||
|
skill_id: Optional skill ID to target a specific agent capability.
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
TaskStateResult with status, result/error, history, and agent_card.
|
TaskStateResult with status, result/error, history, and agent_card.
|
||||||
@@ -233,17 +251,6 @@ async def aexecute_a2a_delegation(
|
|||||||
if turn_number is None:
|
if turn_number is None:
|
||||||
turn_number = len([m for m in conversation_history if m.role == Role.user]) + 1
|
turn_number = len([m for m in conversation_history if m.role == Role.user]) + 1
|
||||||
|
|
||||||
crewai_event_bus.emit(
|
|
||||||
agent_branch,
|
|
||||||
A2ADelegationStartedEvent(
|
|
||||||
endpoint=endpoint,
|
|
||||||
task_description=task_description,
|
|
||||||
agent_id=agent_id,
|
|
||||||
is_multiturn=is_multiturn,
|
|
||||||
turn_number=turn_number,
|
|
||||||
),
|
|
||||||
)
|
|
||||||
|
|
||||||
result = await _aexecute_a2a_delegation_impl(
|
result = await _aexecute_a2a_delegation_impl(
|
||||||
endpoint=endpoint,
|
endpoint=endpoint,
|
||||||
auth=auth,
|
auth=auth,
|
||||||
@@ -264,15 +271,28 @@ async def aexecute_a2a_delegation(
|
|||||||
response_model=response_model,
|
response_model=response_model,
|
||||||
updates=updates,
|
updates=updates,
|
||||||
transport_protocol=transport_protocol,
|
transport_protocol=transport_protocol,
|
||||||
|
from_task=from_task,
|
||||||
|
from_agent=from_agent,
|
||||||
|
skill_id=skill_id,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
agent_card_data: dict[str, Any] = result.get("agent_card") or {}
|
||||||
crewai_event_bus.emit(
|
crewai_event_bus.emit(
|
||||||
agent_branch,
|
agent_branch,
|
||||||
A2ADelegationCompletedEvent(
|
A2ADelegationCompletedEvent(
|
||||||
status=result["status"],
|
status=result["status"],
|
||||||
result=result.get("result"),
|
result=result.get("result"),
|
||||||
error=result.get("error"),
|
error=result.get("error"),
|
||||||
|
context_id=context_id,
|
||||||
is_multiturn=is_multiturn,
|
is_multiturn=is_multiturn,
|
||||||
|
endpoint=endpoint,
|
||||||
|
a2a_agent_name=result.get("a2a_agent_name"),
|
||||||
|
agent_card=agent_card_data,
|
||||||
|
provider=agent_card_data.get("provider"),
|
||||||
|
metadata=metadata,
|
||||||
|
extensions=list(extensions.keys()) if extensions else None,
|
||||||
|
from_task=from_task,
|
||||||
|
from_agent=from_agent,
|
||||||
),
|
),
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -299,6 +319,9 @@ async def _aexecute_a2a_delegation_impl(
|
|||||||
agent_role: str | None,
|
agent_role: str | None,
|
||||||
response_model: type[BaseModel] | None,
|
response_model: type[BaseModel] | None,
|
||||||
updates: UpdateConfig | None,
|
updates: UpdateConfig | None,
|
||||||
|
from_task: Any | None = None,
|
||||||
|
from_agent: Any | None = None,
|
||||||
|
skill_id: str | None = None,
|
||||||
) -> TaskStateResult:
|
) -> TaskStateResult:
|
||||||
"""Internal async implementation of A2A delegation."""
|
"""Internal async implementation of A2A delegation."""
|
||||||
if auth:
|
if auth:
|
||||||
@@ -331,6 +354,28 @@ async def _aexecute_a2a_delegation_impl(
|
|||||||
if agent_card.name:
|
if agent_card.name:
|
||||||
a2a_agent_name = agent_card.name
|
a2a_agent_name = agent_card.name
|
||||||
|
|
||||||
|
agent_card_dict = agent_card.model_dump(exclude_none=True)
|
||||||
|
crewai_event_bus.emit(
|
||||||
|
agent_branch,
|
||||||
|
A2ADelegationStartedEvent(
|
||||||
|
endpoint=endpoint,
|
||||||
|
task_description=task_description,
|
||||||
|
agent_id=agent_id or endpoint,
|
||||||
|
context_id=context_id,
|
||||||
|
is_multiturn=is_multiturn,
|
||||||
|
turn_number=turn_number,
|
||||||
|
a2a_agent_name=a2a_agent_name,
|
||||||
|
agent_card=agent_card_dict,
|
||||||
|
protocol_version=agent_card.protocol_version,
|
||||||
|
provider=agent_card_dict.get("provider"),
|
||||||
|
skill_id=skill_id,
|
||||||
|
metadata=metadata,
|
||||||
|
extensions=list(extensions.keys()) if extensions else None,
|
||||||
|
from_task=from_task,
|
||||||
|
from_agent=from_agent,
|
||||||
|
),
|
||||||
|
)
|
||||||
|
|
||||||
if turn_number == 1:
|
if turn_number == 1:
|
||||||
agent_id_for_event = agent_id or endpoint
|
agent_id_for_event = agent_id or endpoint
|
||||||
crewai_event_bus.emit(
|
crewai_event_bus.emit(
|
||||||
@@ -338,7 +383,17 @@ async def _aexecute_a2a_delegation_impl(
|
|||||||
A2AConversationStartedEvent(
|
A2AConversationStartedEvent(
|
||||||
agent_id=agent_id_for_event,
|
agent_id=agent_id_for_event,
|
||||||
endpoint=endpoint,
|
endpoint=endpoint,
|
||||||
|
context_id=context_id,
|
||||||
a2a_agent_name=a2a_agent_name,
|
a2a_agent_name=a2a_agent_name,
|
||||||
|
agent_card=agent_card_dict,
|
||||||
|
protocol_version=agent_card.protocol_version,
|
||||||
|
provider=agent_card_dict.get("provider"),
|
||||||
|
skill_id=skill_id,
|
||||||
|
reference_task_ids=reference_task_ids,
|
||||||
|
metadata=metadata,
|
||||||
|
extensions=list(extensions.keys()) if extensions else None,
|
||||||
|
from_task=from_task,
|
||||||
|
from_agent=from_agent,
|
||||||
),
|
),
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -364,6 +419,10 @@ async def _aexecute_a2a_delegation_impl(
|
|||||||
}
|
}
|
||||||
)
|
)
|
||||||
|
|
||||||
|
message_metadata = metadata.copy() if metadata else {}
|
||||||
|
if skill_id:
|
||||||
|
message_metadata["skill_id"] = skill_id
|
||||||
|
|
||||||
message = Message(
|
message = Message(
|
||||||
role=Role.user,
|
role=Role.user,
|
||||||
message_id=str(uuid.uuid4()),
|
message_id=str(uuid.uuid4()),
|
||||||
@@ -371,7 +430,7 @@ async def _aexecute_a2a_delegation_impl(
|
|||||||
context_id=context_id,
|
context_id=context_id,
|
||||||
task_id=task_id,
|
task_id=task_id,
|
||||||
reference_task_ids=reference_task_ids,
|
reference_task_ids=reference_task_ids,
|
||||||
metadata=metadata,
|
metadata=message_metadata if message_metadata else None,
|
||||||
extensions=extensions,
|
extensions=extensions,
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -381,8 +440,17 @@ async def _aexecute_a2a_delegation_impl(
|
|||||||
A2AMessageSentEvent(
|
A2AMessageSentEvent(
|
||||||
message=message_text,
|
message=message_text,
|
||||||
turn_number=turn_number,
|
turn_number=turn_number,
|
||||||
|
context_id=context_id,
|
||||||
|
message_id=message.message_id,
|
||||||
is_multiturn=is_multiturn,
|
is_multiturn=is_multiturn,
|
||||||
agent_role=agent_role,
|
agent_role=agent_role,
|
||||||
|
endpoint=endpoint,
|
||||||
|
a2a_agent_name=a2a_agent_name,
|
||||||
|
skill_id=skill_id,
|
||||||
|
metadata=message_metadata if message_metadata else None,
|
||||||
|
extensions=list(extensions.keys()) if extensions else None,
|
||||||
|
from_task=from_task,
|
||||||
|
from_agent=from_agent,
|
||||||
),
|
),
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -397,6 +465,9 @@ async def _aexecute_a2a_delegation_impl(
|
|||||||
"task_id": task_id,
|
"task_id": task_id,
|
||||||
"endpoint": endpoint,
|
"endpoint": endpoint,
|
||||||
"agent_branch": agent_branch,
|
"agent_branch": agent_branch,
|
||||||
|
"a2a_agent_name": a2a_agent_name,
|
||||||
|
"from_task": from_task,
|
||||||
|
"from_agent": from_agent,
|
||||||
}
|
}
|
||||||
|
|
||||||
if isinstance(updates, PollingConfig):
|
if isinstance(updates, PollingConfig):
|
||||||
@@ -434,13 +505,16 @@ async def _aexecute_a2a_delegation_impl(
|
|||||||
use_polling=use_polling,
|
use_polling=use_polling,
|
||||||
push_notification_config=push_config_for_client,
|
push_notification_config=push_config_for_client,
|
||||||
) as client:
|
) as client:
|
||||||
return await handler.execute(
|
result = await handler.execute(
|
||||||
client=client,
|
client=client,
|
||||||
message=message,
|
message=message,
|
||||||
new_messages=new_messages,
|
new_messages=new_messages,
|
||||||
agent_card=agent_card,
|
agent_card=agent_card,
|
||||||
**handler_kwargs,
|
**handler_kwargs,
|
||||||
)
|
)
|
||||||
|
result["a2a_agent_name"] = a2a_agent_name
|
||||||
|
result["agent_card"] = agent_card.model_dump(exclude_none=True)
|
||||||
|
return result
|
||||||
|
|
||||||
|
|
||||||
@asynccontextmanager
|
@asynccontextmanager
|
||||||
|
|||||||
@@ -3,11 +3,14 @@
|
|||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
import asyncio
|
import asyncio
|
||||||
|
import base64
|
||||||
from collections.abc import Callable, Coroutine
|
from collections.abc import Callable, Coroutine
|
||||||
|
from datetime import datetime
|
||||||
from functools import wraps
|
from functools import wraps
|
||||||
import logging
|
import logging
|
||||||
import os
|
import os
|
||||||
from typing import TYPE_CHECKING, Any, ParamSpec, TypeVar, cast
|
from typing import TYPE_CHECKING, Any, ParamSpec, TypeVar, cast
|
||||||
|
from urllib.parse import urlparse
|
||||||
|
|
||||||
from a2a.server.agent_execution import RequestContext
|
from a2a.server.agent_execution import RequestContext
|
||||||
from a2a.server.events import EventQueue
|
from a2a.server.events import EventQueue
|
||||||
@@ -45,7 +48,14 @@ T = TypeVar("T")
|
|||||||
|
|
||||||
|
|
||||||
def _parse_redis_url(url: str) -> dict[str, Any]:
|
def _parse_redis_url(url: str) -> dict[str, Any]:
|
||||||
from urllib.parse import urlparse
|
"""Parse a Redis URL into aiocache configuration.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
url: Redis connection URL (e.g., redis://localhost:6379/0).
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Configuration dict for aiocache.RedisCache.
|
||||||
|
"""
|
||||||
|
|
||||||
parsed = urlparse(url)
|
parsed = urlparse(url)
|
||||||
config: dict[str, Any] = {
|
config: dict[str, Any] = {
|
||||||
@@ -127,7 +137,7 @@ def cancellable(
|
|||||||
async for message in pubsub.listen():
|
async for message in pubsub.listen():
|
||||||
if message["type"] == "message":
|
if message["type"] == "message":
|
||||||
return True
|
return True
|
||||||
except Exception as e:
|
except (OSError, ConnectionError) as e:
|
||||||
logger.warning("Cancel watcher error for task_id=%s: %s", task_id, e)
|
logger.warning("Cancel watcher error for task_id=%s: %s", task_id, e)
|
||||||
return await poll_for_cancel()
|
return await poll_for_cancel()
|
||||||
return False
|
return False
|
||||||
@@ -183,7 +193,12 @@ async def execute(
|
|||||||
msg = "task_id and context_id are required"
|
msg = "task_id and context_id are required"
|
||||||
crewai_event_bus.emit(
|
crewai_event_bus.emit(
|
||||||
agent,
|
agent,
|
||||||
A2AServerTaskFailedEvent(a2a_task_id="", a2a_context_id="", error=msg),
|
A2AServerTaskFailedEvent(
|
||||||
|
task_id="",
|
||||||
|
context_id="",
|
||||||
|
error=msg,
|
||||||
|
from_agent=agent,
|
||||||
|
),
|
||||||
)
|
)
|
||||||
raise ServerError(InvalidParamsError(message=msg)) from None
|
raise ServerError(InvalidParamsError(message=msg)) from None
|
||||||
|
|
||||||
@@ -195,7 +210,12 @@ async def execute(
|
|||||||
|
|
||||||
crewai_event_bus.emit(
|
crewai_event_bus.emit(
|
||||||
agent,
|
agent,
|
||||||
A2AServerTaskStartedEvent(a2a_task_id=task_id, a2a_context_id=context_id),
|
A2AServerTaskStartedEvent(
|
||||||
|
task_id=task_id,
|
||||||
|
context_id=context_id,
|
||||||
|
from_task=task,
|
||||||
|
from_agent=agent,
|
||||||
|
),
|
||||||
)
|
)
|
||||||
|
|
||||||
try:
|
try:
|
||||||
@@ -215,20 +235,33 @@ async def execute(
|
|||||||
crewai_event_bus.emit(
|
crewai_event_bus.emit(
|
||||||
agent,
|
agent,
|
||||||
A2AServerTaskCompletedEvent(
|
A2AServerTaskCompletedEvent(
|
||||||
a2a_task_id=task_id, a2a_context_id=context_id, result=str(result)
|
task_id=task_id,
|
||||||
|
context_id=context_id,
|
||||||
|
result=str(result),
|
||||||
|
from_task=task,
|
||||||
|
from_agent=agent,
|
||||||
),
|
),
|
||||||
)
|
)
|
||||||
except asyncio.CancelledError:
|
except asyncio.CancelledError:
|
||||||
crewai_event_bus.emit(
|
crewai_event_bus.emit(
|
||||||
agent,
|
agent,
|
||||||
A2AServerTaskCanceledEvent(a2a_task_id=task_id, a2a_context_id=context_id),
|
A2AServerTaskCanceledEvent(
|
||||||
|
task_id=task_id,
|
||||||
|
context_id=context_id,
|
||||||
|
from_task=task,
|
||||||
|
from_agent=agent,
|
||||||
|
),
|
||||||
)
|
)
|
||||||
raise
|
raise
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
crewai_event_bus.emit(
|
crewai_event_bus.emit(
|
||||||
agent,
|
agent,
|
||||||
A2AServerTaskFailedEvent(
|
A2AServerTaskFailedEvent(
|
||||||
a2a_task_id=task_id, a2a_context_id=context_id, error=str(e)
|
task_id=task_id,
|
||||||
|
context_id=context_id,
|
||||||
|
error=str(e),
|
||||||
|
from_task=task,
|
||||||
|
from_agent=agent,
|
||||||
),
|
),
|
||||||
)
|
)
|
||||||
raise ServerError(
|
raise ServerError(
|
||||||
@@ -282,3 +315,85 @@ async def cancel(
|
|||||||
context.current_task.status = TaskStatus(state=TaskState.canceled)
|
context.current_task.status = TaskStatus(state=TaskState.canceled)
|
||||||
return context.current_task
|
return context.current_task
|
||||||
return None
|
return None
|
||||||
|
|
||||||
|
|
||||||
|
def list_tasks(
|
||||||
|
tasks: list[A2ATask],
|
||||||
|
context_id: str | None = None,
|
||||||
|
status: TaskState | None = None,
|
||||||
|
status_timestamp_after: datetime | None = None,
|
||||||
|
page_size: int = 50,
|
||||||
|
page_token: str | None = None,
|
||||||
|
history_length: int | None = None,
|
||||||
|
include_artifacts: bool = False,
|
||||||
|
) -> tuple[list[A2ATask], str | None, int]:
|
||||||
|
"""Filter and paginate A2A tasks.
|
||||||
|
|
||||||
|
Provides filtering by context, status, and timestamp, along with
|
||||||
|
cursor-based pagination. This is a pure utility function that operates
|
||||||
|
on an in-memory list of tasks - storage retrieval is handled separately.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
tasks: All tasks to filter.
|
||||||
|
context_id: Filter by context ID to get tasks in a conversation.
|
||||||
|
status: Filter by task state (e.g., completed, working).
|
||||||
|
status_timestamp_after: Filter to tasks updated after this time.
|
||||||
|
page_size: Maximum tasks per page (default 50).
|
||||||
|
page_token: Base64-encoded cursor from previous response.
|
||||||
|
history_length: Limit history messages per task (None = full history).
|
||||||
|
include_artifacts: Whether to include task artifacts (default False).
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Tuple of (filtered_tasks, next_page_token, total_count).
|
||||||
|
- filtered_tasks: Tasks matching filters, paginated and trimmed.
|
||||||
|
- next_page_token: Token for next page, or None if no more pages.
|
||||||
|
- total_count: Total number of tasks matching filters (before pagination).
|
||||||
|
"""
|
||||||
|
filtered: list[A2ATask] = []
|
||||||
|
for task in tasks:
|
||||||
|
if context_id and task.context_id != context_id:
|
||||||
|
continue
|
||||||
|
if status and task.status.state != status:
|
||||||
|
continue
|
||||||
|
if status_timestamp_after and task.status.timestamp:
|
||||||
|
ts = datetime.fromisoformat(task.status.timestamp.replace("Z", "+00:00"))
|
||||||
|
if ts <= status_timestamp_after:
|
||||||
|
continue
|
||||||
|
filtered.append(task)
|
||||||
|
|
||||||
|
def get_timestamp(t: A2ATask) -> datetime:
|
||||||
|
"""Extract timestamp from task status for sorting."""
|
||||||
|
if t.status.timestamp is None:
|
||||||
|
return datetime.min
|
||||||
|
return datetime.fromisoformat(t.status.timestamp.replace("Z", "+00:00"))
|
||||||
|
|
||||||
|
filtered.sort(key=get_timestamp, reverse=True)
|
||||||
|
total = len(filtered)
|
||||||
|
|
||||||
|
start = 0
|
||||||
|
if page_token:
|
||||||
|
try:
|
||||||
|
cursor_id = base64.b64decode(page_token).decode()
|
||||||
|
for idx, task in enumerate(filtered):
|
||||||
|
if task.id == cursor_id:
|
||||||
|
start = idx + 1
|
||||||
|
break
|
||||||
|
except (ValueError, UnicodeDecodeError):
|
||||||
|
pass
|
||||||
|
|
||||||
|
page = filtered[start : start + page_size]
|
||||||
|
|
||||||
|
result: list[A2ATask] = []
|
||||||
|
for task in page:
|
||||||
|
task = task.model_copy(deep=True)
|
||||||
|
if history_length is not None and task.history:
|
||||||
|
task.history = task.history[-history_length:]
|
||||||
|
if not include_artifacts:
|
||||||
|
task.artifacts = None
|
||||||
|
result.append(task)
|
||||||
|
|
||||||
|
next_token: str | None = None
|
||||||
|
if result and len(result) == page_size:
|
||||||
|
next_token = base64.b64encode(result[-1].id.encode()).decode()
|
||||||
|
|
||||||
|
return result, next_token, total
|
||||||
|
|||||||
@@ -6,9 +6,10 @@ Wraps agent classes with A2A delegation capabilities.
|
|||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
import asyncio
|
import asyncio
|
||||||
from collections.abc import Callable, Coroutine
|
from collections.abc import Callable, Coroutine, Mapping
|
||||||
from concurrent.futures import ThreadPoolExecutor, as_completed
|
from concurrent.futures import ThreadPoolExecutor, as_completed
|
||||||
from functools import wraps
|
from functools import wraps
|
||||||
|
import json
|
||||||
from types import MethodType
|
from types import MethodType
|
||||||
from typing import TYPE_CHECKING, Any
|
from typing import TYPE_CHECKING, Any
|
||||||
|
|
||||||
@@ -189,7 +190,7 @@ def _execute_task_with_a2a(
|
|||||||
a2a_agents: list[A2AConfig | A2AClientConfig],
|
a2a_agents: list[A2AConfig | A2AClientConfig],
|
||||||
original_fn: Callable[..., str],
|
original_fn: Callable[..., str],
|
||||||
task: Task,
|
task: Task,
|
||||||
agent_response_model: type[BaseModel],
|
agent_response_model: type[BaseModel] | None,
|
||||||
context: str | None,
|
context: str | None,
|
||||||
tools: list[BaseTool] | None,
|
tools: list[BaseTool] | None,
|
||||||
extension_registry: ExtensionRegistry,
|
extension_registry: ExtensionRegistry,
|
||||||
@@ -277,7 +278,7 @@ def _execute_task_with_a2a(
|
|||||||
def _augment_prompt_with_a2a(
|
def _augment_prompt_with_a2a(
|
||||||
a2a_agents: list[A2AConfig | A2AClientConfig],
|
a2a_agents: list[A2AConfig | A2AClientConfig],
|
||||||
task_description: str,
|
task_description: str,
|
||||||
agent_cards: dict[str, AgentCard],
|
agent_cards: Mapping[str, AgentCard | dict[str, Any]],
|
||||||
conversation_history: list[Message] | None = None,
|
conversation_history: list[Message] | None = None,
|
||||||
turn_num: int = 0,
|
turn_num: int = 0,
|
||||||
max_turns: int | None = None,
|
max_turns: int | None = None,
|
||||||
@@ -309,7 +310,15 @@ def _augment_prompt_with_a2a(
|
|||||||
for config in a2a_agents:
|
for config in a2a_agents:
|
||||||
if config.endpoint in agent_cards:
|
if config.endpoint in agent_cards:
|
||||||
card = agent_cards[config.endpoint]
|
card = agent_cards[config.endpoint]
|
||||||
agents_text += f"\n{card.model_dump_json(indent=2, exclude_none=True, include={'description', 'url', 'skills'})}\n"
|
if isinstance(card, dict):
|
||||||
|
filtered = {
|
||||||
|
k: v
|
||||||
|
for k, v in card.items()
|
||||||
|
if k in {"description", "url", "skills"} and v is not None
|
||||||
|
}
|
||||||
|
agents_text += f"\n{json.dumps(filtered, indent=2)}\n"
|
||||||
|
else:
|
||||||
|
agents_text += f"\n{card.model_dump_json(indent=2, exclude_none=True, include={'description', 'url', 'skills'})}\n"
|
||||||
|
|
||||||
failed_agents = failed_agents or {}
|
failed_agents = failed_agents or {}
|
||||||
if failed_agents:
|
if failed_agents:
|
||||||
@@ -377,7 +386,7 @@ IMPORTANT: You have the ability to delegate this task to remote A2A agents.
|
|||||||
|
|
||||||
|
|
||||||
def _parse_agent_response(
|
def _parse_agent_response(
|
||||||
raw_result: str | dict[str, Any], agent_response_model: type[BaseModel]
|
raw_result: str | dict[str, Any], agent_response_model: type[BaseModel] | None
|
||||||
) -> BaseModel | str | dict[str, Any]:
|
) -> BaseModel | str | dict[str, Any]:
|
||||||
"""Parse LLM output as AgentResponse or return raw agent response."""
|
"""Parse LLM output as AgentResponse or return raw agent response."""
|
||||||
if agent_response_model:
|
if agent_response_model:
|
||||||
@@ -394,6 +403,11 @@ def _parse_agent_response(
|
|||||||
def _handle_max_turns_exceeded(
|
def _handle_max_turns_exceeded(
|
||||||
conversation_history: list[Message],
|
conversation_history: list[Message],
|
||||||
max_turns: int,
|
max_turns: int,
|
||||||
|
from_task: Any | None = None,
|
||||||
|
from_agent: Any | None = None,
|
||||||
|
endpoint: str | None = None,
|
||||||
|
a2a_agent_name: str | None = None,
|
||||||
|
agent_card: dict[str, Any] | None = None,
|
||||||
) -> str:
|
) -> str:
|
||||||
"""Handle the case when max turns is exceeded.
|
"""Handle the case when max turns is exceeded.
|
||||||
|
|
||||||
@@ -421,6 +435,11 @@ def _handle_max_turns_exceeded(
|
|||||||
final_result=final_message,
|
final_result=final_message,
|
||||||
error=None,
|
error=None,
|
||||||
total_turns=max_turns,
|
total_turns=max_turns,
|
||||||
|
from_task=from_task,
|
||||||
|
from_agent=from_agent,
|
||||||
|
endpoint=endpoint,
|
||||||
|
a2a_agent_name=a2a_agent_name,
|
||||||
|
agent_card=agent_card,
|
||||||
),
|
),
|
||||||
)
|
)
|
||||||
return final_message
|
return final_message
|
||||||
@@ -432,6 +451,11 @@ def _handle_max_turns_exceeded(
|
|||||||
final_result=None,
|
final_result=None,
|
||||||
error=f"Conversation exceeded maximum turns ({max_turns})",
|
error=f"Conversation exceeded maximum turns ({max_turns})",
|
||||||
total_turns=max_turns,
|
total_turns=max_turns,
|
||||||
|
from_task=from_task,
|
||||||
|
from_agent=from_agent,
|
||||||
|
endpoint=endpoint,
|
||||||
|
a2a_agent_name=a2a_agent_name,
|
||||||
|
agent_card=agent_card,
|
||||||
),
|
),
|
||||||
)
|
)
|
||||||
raise Exception(f"A2A conversation exceeded maximum turns ({max_turns})")
|
raise Exception(f"A2A conversation exceeded maximum turns ({max_turns})")
|
||||||
@@ -442,7 +466,12 @@ def _process_response_result(
|
|||||||
disable_structured_output: bool,
|
disable_structured_output: bool,
|
||||||
turn_num: int,
|
turn_num: int,
|
||||||
agent_role: str,
|
agent_role: str,
|
||||||
agent_response_model: type[BaseModel],
|
agent_response_model: type[BaseModel] | None,
|
||||||
|
from_task: Any | None = None,
|
||||||
|
from_agent: Any | None = None,
|
||||||
|
endpoint: str | None = None,
|
||||||
|
a2a_agent_name: str | None = None,
|
||||||
|
agent_card: dict[str, Any] | None = None,
|
||||||
) -> tuple[str | None, str | None]:
|
) -> tuple[str | None, str | None]:
|
||||||
"""Process LLM response and determine next action.
|
"""Process LLM response and determine next action.
|
||||||
|
|
||||||
@@ -461,6 +490,10 @@ def _process_response_result(
|
|||||||
turn_number=final_turn_number,
|
turn_number=final_turn_number,
|
||||||
is_multiturn=True,
|
is_multiturn=True,
|
||||||
agent_role=agent_role,
|
agent_role=agent_role,
|
||||||
|
from_task=from_task,
|
||||||
|
from_agent=from_agent,
|
||||||
|
endpoint=endpoint,
|
||||||
|
a2a_agent_name=a2a_agent_name,
|
||||||
),
|
),
|
||||||
)
|
)
|
||||||
crewai_event_bus.emit(
|
crewai_event_bus.emit(
|
||||||
@@ -470,6 +503,11 @@ def _process_response_result(
|
|||||||
final_result=result_text,
|
final_result=result_text,
|
||||||
error=None,
|
error=None,
|
||||||
total_turns=final_turn_number,
|
total_turns=final_turn_number,
|
||||||
|
from_task=from_task,
|
||||||
|
from_agent=from_agent,
|
||||||
|
endpoint=endpoint,
|
||||||
|
a2a_agent_name=a2a_agent_name,
|
||||||
|
agent_card=agent_card,
|
||||||
),
|
),
|
||||||
)
|
)
|
||||||
return result_text, None
|
return result_text, None
|
||||||
@@ -490,6 +528,10 @@ def _process_response_result(
|
|||||||
turn_number=final_turn_number,
|
turn_number=final_turn_number,
|
||||||
is_multiturn=True,
|
is_multiturn=True,
|
||||||
agent_role=agent_role,
|
agent_role=agent_role,
|
||||||
|
from_task=from_task,
|
||||||
|
from_agent=from_agent,
|
||||||
|
endpoint=endpoint,
|
||||||
|
a2a_agent_name=a2a_agent_name,
|
||||||
),
|
),
|
||||||
)
|
)
|
||||||
crewai_event_bus.emit(
|
crewai_event_bus.emit(
|
||||||
@@ -499,6 +541,11 @@ def _process_response_result(
|
|||||||
final_result=str(llm_response.message),
|
final_result=str(llm_response.message),
|
||||||
error=None,
|
error=None,
|
||||||
total_turns=final_turn_number,
|
total_turns=final_turn_number,
|
||||||
|
from_task=from_task,
|
||||||
|
from_agent=from_agent,
|
||||||
|
endpoint=endpoint,
|
||||||
|
a2a_agent_name=a2a_agent_name,
|
||||||
|
agent_card=agent_card,
|
||||||
),
|
),
|
||||||
)
|
)
|
||||||
return str(llm_response.message), None
|
return str(llm_response.message), None
|
||||||
@@ -510,13 +557,15 @@ def _process_response_result(
|
|||||||
def _prepare_agent_cards_dict(
|
def _prepare_agent_cards_dict(
|
||||||
a2a_result: TaskStateResult,
|
a2a_result: TaskStateResult,
|
||||||
agent_id: str,
|
agent_id: str,
|
||||||
agent_cards: dict[str, AgentCard] | None,
|
agent_cards: Mapping[str, AgentCard | dict[str, Any]] | None,
|
||||||
) -> dict[str, AgentCard]:
|
) -> dict[str, AgentCard | dict[str, Any]]:
|
||||||
"""Prepare agent cards dictionary from result and existing cards.
|
"""Prepare agent cards dictionary from result and existing cards.
|
||||||
|
|
||||||
Shared logic for both sync and async response handlers.
|
Shared logic for both sync and async response handlers.
|
||||||
"""
|
"""
|
||||||
agent_cards_dict = agent_cards or {}
|
agent_cards_dict: dict[str, AgentCard | dict[str, Any]] = (
|
||||||
|
dict(agent_cards) if agent_cards else {}
|
||||||
|
)
|
||||||
if "agent_card" in a2a_result and agent_id not in agent_cards_dict:
|
if "agent_card" in a2a_result and agent_id not in agent_cards_dict:
|
||||||
agent_cards_dict[agent_id] = a2a_result["agent_card"]
|
agent_cards_dict[agent_id] = a2a_result["agent_card"]
|
||||||
return agent_cards_dict
|
return agent_cards_dict
|
||||||
@@ -529,7 +578,7 @@ def _prepare_delegation_context(
|
|||||||
original_task_description: str | None,
|
original_task_description: str | None,
|
||||||
) -> tuple[
|
) -> tuple[
|
||||||
list[A2AConfig | A2AClientConfig],
|
list[A2AConfig | A2AClientConfig],
|
||||||
type[BaseModel],
|
type[BaseModel] | None,
|
||||||
str,
|
str,
|
||||||
str,
|
str,
|
||||||
A2AConfig | A2AClientConfig,
|
A2AConfig | A2AClientConfig,
|
||||||
@@ -598,6 +647,11 @@ def _handle_task_completion(
|
|||||||
reference_task_ids: list[str],
|
reference_task_ids: list[str],
|
||||||
agent_config: A2AConfig | A2AClientConfig,
|
agent_config: A2AConfig | A2AClientConfig,
|
||||||
turn_num: int,
|
turn_num: int,
|
||||||
|
from_task: Any | None = None,
|
||||||
|
from_agent: Any | None = None,
|
||||||
|
endpoint: str | None = None,
|
||||||
|
a2a_agent_name: str | None = None,
|
||||||
|
agent_card: dict[str, Any] | None = None,
|
||||||
) -> tuple[str | None, str | None, list[str]]:
|
) -> tuple[str | None, str | None, list[str]]:
|
||||||
"""Handle task completion state including reference task updates.
|
"""Handle task completion state including reference task updates.
|
||||||
|
|
||||||
@@ -624,6 +678,11 @@ def _handle_task_completion(
|
|||||||
final_result=result_text,
|
final_result=result_text,
|
||||||
error=None,
|
error=None,
|
||||||
total_turns=final_turn_number,
|
total_turns=final_turn_number,
|
||||||
|
from_task=from_task,
|
||||||
|
from_agent=from_agent,
|
||||||
|
endpoint=endpoint,
|
||||||
|
a2a_agent_name=a2a_agent_name,
|
||||||
|
agent_card=agent_card,
|
||||||
),
|
),
|
||||||
)
|
)
|
||||||
return str(result_text), task_id_config, reference_task_ids
|
return str(result_text), task_id_config, reference_task_ids
|
||||||
@@ -645,8 +704,11 @@ def _handle_agent_response_and_continue(
|
|||||||
original_fn: Callable[..., str],
|
original_fn: Callable[..., str],
|
||||||
context: str | None,
|
context: str | None,
|
||||||
tools: list[BaseTool] | None,
|
tools: list[BaseTool] | None,
|
||||||
agent_response_model: type[BaseModel],
|
agent_response_model: type[BaseModel] | None,
|
||||||
remote_task_completed: bool = False,
|
remote_task_completed: bool = False,
|
||||||
|
endpoint: str | None = None,
|
||||||
|
a2a_agent_name: str | None = None,
|
||||||
|
agent_card: dict[str, Any] | None = None,
|
||||||
) -> tuple[str | None, str | None]:
|
) -> tuple[str | None, str | None]:
|
||||||
"""Handle A2A result and get CrewAI agent's response.
|
"""Handle A2A result and get CrewAI agent's response.
|
||||||
|
|
||||||
@@ -698,6 +760,11 @@ def _handle_agent_response_and_continue(
|
|||||||
turn_num=turn_num,
|
turn_num=turn_num,
|
||||||
agent_role=self.role,
|
agent_role=self.role,
|
||||||
agent_response_model=agent_response_model,
|
agent_response_model=agent_response_model,
|
||||||
|
from_task=task,
|
||||||
|
from_agent=self,
|
||||||
|
endpoint=endpoint,
|
||||||
|
a2a_agent_name=a2a_agent_name,
|
||||||
|
agent_card=agent_card,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
@@ -750,6 +817,12 @@ def _delegate_to_a2a(
|
|||||||
|
|
||||||
conversation_history: list[Message] = []
|
conversation_history: list[Message] = []
|
||||||
|
|
||||||
|
current_agent_card = agent_cards.get(agent_id) if agent_cards else None
|
||||||
|
current_agent_card_dict = (
|
||||||
|
current_agent_card.model_dump() if current_agent_card else None
|
||||||
|
)
|
||||||
|
current_a2a_agent_name = current_agent_card.name if current_agent_card else None
|
||||||
|
|
||||||
try:
|
try:
|
||||||
for turn_num in range(max_turns):
|
for turn_num in range(max_turns):
|
||||||
console_formatter = getattr(crewai_event_bus, "_console", None)
|
console_formatter = getattr(crewai_event_bus, "_console", None)
|
||||||
@@ -777,6 +850,8 @@ def _delegate_to_a2a(
|
|||||||
turn_number=turn_num + 1,
|
turn_number=turn_num + 1,
|
||||||
updates=agent_config.updates,
|
updates=agent_config.updates,
|
||||||
transport_protocol=agent_config.transport_protocol,
|
transport_protocol=agent_config.transport_protocol,
|
||||||
|
from_task=task,
|
||||||
|
from_agent=self,
|
||||||
)
|
)
|
||||||
|
|
||||||
conversation_history = a2a_result.get("history", [])
|
conversation_history = a2a_result.get("history", [])
|
||||||
@@ -797,6 +872,11 @@ def _delegate_to_a2a(
|
|||||||
reference_task_ids,
|
reference_task_ids,
|
||||||
agent_config,
|
agent_config,
|
||||||
turn_num,
|
turn_num,
|
||||||
|
from_task=task,
|
||||||
|
from_agent=self,
|
||||||
|
endpoint=agent_config.endpoint,
|
||||||
|
a2a_agent_name=current_a2a_agent_name,
|
||||||
|
agent_card=current_agent_card_dict,
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
if trusted_result is not None:
|
if trusted_result is not None:
|
||||||
@@ -818,6 +898,9 @@ def _delegate_to_a2a(
|
|||||||
tools=tools,
|
tools=tools,
|
||||||
agent_response_model=agent_response_model,
|
agent_response_model=agent_response_model,
|
||||||
remote_task_completed=(a2a_result["status"] == TaskState.completed),
|
remote_task_completed=(a2a_result["status"] == TaskState.completed),
|
||||||
|
endpoint=agent_config.endpoint,
|
||||||
|
a2a_agent_name=current_a2a_agent_name,
|
||||||
|
agent_card=current_agent_card_dict,
|
||||||
)
|
)
|
||||||
|
|
||||||
if final_result is not None:
|
if final_result is not None:
|
||||||
@@ -846,6 +929,9 @@ def _delegate_to_a2a(
|
|||||||
tools=tools,
|
tools=tools,
|
||||||
agent_response_model=agent_response_model,
|
agent_response_model=agent_response_model,
|
||||||
remote_task_completed=False,
|
remote_task_completed=False,
|
||||||
|
endpoint=agent_config.endpoint,
|
||||||
|
a2a_agent_name=current_a2a_agent_name,
|
||||||
|
agent_card=current_agent_card_dict,
|
||||||
)
|
)
|
||||||
|
|
||||||
if final_result is not None:
|
if final_result is not None:
|
||||||
@@ -862,11 +948,24 @@ def _delegate_to_a2a(
|
|||||||
final_result=None,
|
final_result=None,
|
||||||
error=error_msg,
|
error=error_msg,
|
||||||
total_turns=turn_num + 1,
|
total_turns=turn_num + 1,
|
||||||
|
from_task=task,
|
||||||
|
from_agent=self,
|
||||||
|
endpoint=agent_config.endpoint,
|
||||||
|
a2a_agent_name=current_a2a_agent_name,
|
||||||
|
agent_card=current_agent_card_dict,
|
||||||
),
|
),
|
||||||
)
|
)
|
||||||
return f"A2A delegation failed: {error_msg}"
|
return f"A2A delegation failed: {error_msg}"
|
||||||
|
|
||||||
return _handle_max_turns_exceeded(conversation_history, max_turns)
|
return _handle_max_turns_exceeded(
|
||||||
|
conversation_history,
|
||||||
|
max_turns,
|
||||||
|
from_task=task,
|
||||||
|
from_agent=self,
|
||||||
|
endpoint=agent_config.endpoint,
|
||||||
|
a2a_agent_name=current_a2a_agent_name,
|
||||||
|
agent_card=current_agent_card_dict,
|
||||||
|
)
|
||||||
|
|
||||||
finally:
|
finally:
|
||||||
task.description = original_task_description
|
task.description = original_task_description
|
||||||
@@ -916,7 +1015,7 @@ async def _aexecute_task_with_a2a(
|
|||||||
a2a_agents: list[A2AConfig | A2AClientConfig],
|
a2a_agents: list[A2AConfig | A2AClientConfig],
|
||||||
original_fn: Callable[..., Coroutine[Any, Any, str]],
|
original_fn: Callable[..., Coroutine[Any, Any, str]],
|
||||||
task: Task,
|
task: Task,
|
||||||
agent_response_model: type[BaseModel],
|
agent_response_model: type[BaseModel] | None,
|
||||||
context: str | None,
|
context: str | None,
|
||||||
tools: list[BaseTool] | None,
|
tools: list[BaseTool] | None,
|
||||||
extension_registry: ExtensionRegistry,
|
extension_registry: ExtensionRegistry,
|
||||||
@@ -1001,8 +1100,11 @@ async def _ahandle_agent_response_and_continue(
|
|||||||
original_fn: Callable[..., Coroutine[Any, Any, str]],
|
original_fn: Callable[..., Coroutine[Any, Any, str]],
|
||||||
context: str | None,
|
context: str | None,
|
||||||
tools: list[BaseTool] | None,
|
tools: list[BaseTool] | None,
|
||||||
agent_response_model: type[BaseModel],
|
agent_response_model: type[BaseModel] | None,
|
||||||
remote_task_completed: bool = False,
|
remote_task_completed: bool = False,
|
||||||
|
endpoint: str | None = None,
|
||||||
|
a2a_agent_name: str | None = None,
|
||||||
|
agent_card: dict[str, Any] | None = None,
|
||||||
) -> tuple[str | None, str | None]:
|
) -> tuple[str | None, str | None]:
|
||||||
"""Async version of _handle_agent_response_and_continue."""
|
"""Async version of _handle_agent_response_and_continue."""
|
||||||
agent_cards_dict = _prepare_agent_cards_dict(a2a_result, agent_id, agent_cards)
|
agent_cards_dict = _prepare_agent_cards_dict(a2a_result, agent_id, agent_cards)
|
||||||
@@ -1032,6 +1134,11 @@ async def _ahandle_agent_response_and_continue(
|
|||||||
turn_num=turn_num,
|
turn_num=turn_num,
|
||||||
agent_role=self.role,
|
agent_role=self.role,
|
||||||
agent_response_model=agent_response_model,
|
agent_response_model=agent_response_model,
|
||||||
|
from_task=task,
|
||||||
|
from_agent=self,
|
||||||
|
endpoint=endpoint,
|
||||||
|
a2a_agent_name=a2a_agent_name,
|
||||||
|
agent_card=agent_card,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
@@ -1066,6 +1173,12 @@ async def _adelegate_to_a2a(
|
|||||||
|
|
||||||
conversation_history: list[Message] = []
|
conversation_history: list[Message] = []
|
||||||
|
|
||||||
|
current_agent_card = agent_cards.get(agent_id) if agent_cards else None
|
||||||
|
current_agent_card_dict = (
|
||||||
|
current_agent_card.model_dump() if current_agent_card else None
|
||||||
|
)
|
||||||
|
current_a2a_agent_name = current_agent_card.name if current_agent_card else None
|
||||||
|
|
||||||
try:
|
try:
|
||||||
for turn_num in range(max_turns):
|
for turn_num in range(max_turns):
|
||||||
console_formatter = getattr(crewai_event_bus, "_console", None)
|
console_formatter = getattr(crewai_event_bus, "_console", None)
|
||||||
@@ -1093,6 +1206,8 @@ async def _adelegate_to_a2a(
|
|||||||
turn_number=turn_num + 1,
|
turn_number=turn_num + 1,
|
||||||
transport_protocol=agent_config.transport_protocol,
|
transport_protocol=agent_config.transport_protocol,
|
||||||
updates=agent_config.updates,
|
updates=agent_config.updates,
|
||||||
|
from_task=task,
|
||||||
|
from_agent=self,
|
||||||
)
|
)
|
||||||
|
|
||||||
conversation_history = a2a_result.get("history", [])
|
conversation_history = a2a_result.get("history", [])
|
||||||
@@ -1113,6 +1228,11 @@ async def _adelegate_to_a2a(
|
|||||||
reference_task_ids,
|
reference_task_ids,
|
||||||
agent_config,
|
agent_config,
|
||||||
turn_num,
|
turn_num,
|
||||||
|
from_task=task,
|
||||||
|
from_agent=self,
|
||||||
|
endpoint=agent_config.endpoint,
|
||||||
|
a2a_agent_name=current_a2a_agent_name,
|
||||||
|
agent_card=current_agent_card_dict,
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
if trusted_result is not None:
|
if trusted_result is not None:
|
||||||
@@ -1134,6 +1254,9 @@ async def _adelegate_to_a2a(
|
|||||||
tools=tools,
|
tools=tools,
|
||||||
agent_response_model=agent_response_model,
|
agent_response_model=agent_response_model,
|
||||||
remote_task_completed=(a2a_result["status"] == TaskState.completed),
|
remote_task_completed=(a2a_result["status"] == TaskState.completed),
|
||||||
|
endpoint=agent_config.endpoint,
|
||||||
|
a2a_agent_name=current_a2a_agent_name,
|
||||||
|
agent_card=current_agent_card_dict,
|
||||||
)
|
)
|
||||||
|
|
||||||
if final_result is not None:
|
if final_result is not None:
|
||||||
@@ -1161,6 +1284,9 @@ async def _adelegate_to_a2a(
|
|||||||
context=context,
|
context=context,
|
||||||
tools=tools,
|
tools=tools,
|
||||||
agent_response_model=agent_response_model,
|
agent_response_model=agent_response_model,
|
||||||
|
endpoint=agent_config.endpoint,
|
||||||
|
a2a_agent_name=current_a2a_agent_name,
|
||||||
|
agent_card=current_agent_card_dict,
|
||||||
)
|
)
|
||||||
|
|
||||||
if final_result is not None:
|
if final_result is not None:
|
||||||
@@ -1177,11 +1303,24 @@ async def _adelegate_to_a2a(
|
|||||||
final_result=None,
|
final_result=None,
|
||||||
error=error_msg,
|
error=error_msg,
|
||||||
total_turns=turn_num + 1,
|
total_turns=turn_num + 1,
|
||||||
|
from_task=task,
|
||||||
|
from_agent=self,
|
||||||
|
endpoint=agent_config.endpoint,
|
||||||
|
a2a_agent_name=current_a2a_agent_name,
|
||||||
|
agent_card=current_agent_card_dict,
|
||||||
),
|
),
|
||||||
)
|
)
|
||||||
return f"A2A delegation failed: {error_msg}"
|
return f"A2A delegation failed: {error_msg}"
|
||||||
|
|
||||||
return _handle_max_turns_exceeded(conversation_history, max_turns)
|
return _handle_max_turns_exceeded(
|
||||||
|
conversation_history,
|
||||||
|
max_turns,
|
||||||
|
from_task=task,
|
||||||
|
from_agent=self,
|
||||||
|
endpoint=agent_config.endpoint,
|
||||||
|
a2a_agent_name=current_a2a_agent_name,
|
||||||
|
agent_card=current_agent_card_dict,
|
||||||
|
)
|
||||||
|
|
||||||
finally:
|
finally:
|
||||||
task.description = original_task_description
|
task.description = original_task_description
|
||||||
|
|||||||
@@ -219,7 +219,6 @@ class CrewAgentExecutor(CrewAgentExecutorMixin):
|
|||||||
Final answer from the agent.
|
Final answer from the agent.
|
||||||
"""
|
"""
|
||||||
formatted_answer = None
|
formatted_answer = None
|
||||||
last_raw_output: str | None = None
|
|
||||||
while not isinstance(formatted_answer, AgentFinish):
|
while not isinstance(formatted_answer, AgentFinish):
|
||||||
try:
|
try:
|
||||||
if has_reached_max_iterations(self.iterations, self.max_iter):
|
if has_reached_max_iterations(self.iterations, self.max_iter):
|
||||||
@@ -245,7 +244,6 @@ class CrewAgentExecutor(CrewAgentExecutorMixin):
|
|||||||
response_model=self.response_model,
|
response_model=self.response_model,
|
||||||
executor_context=self,
|
executor_context=self,
|
||||||
)
|
)
|
||||||
last_raw_output = answer
|
|
||||||
if self.response_model is not None:
|
if self.response_model is not None:
|
||||||
try:
|
try:
|
||||||
self.response_model.model_validate_json(answer)
|
self.response_model.model_validate_json(answer)
|
||||||
@@ -302,8 +300,6 @@ class CrewAgentExecutor(CrewAgentExecutorMixin):
|
|||||||
iterations=self.iterations,
|
iterations=self.iterations,
|
||||||
log_error_after=self.log_error_after,
|
log_error_after=self.log_error_after,
|
||||||
printer=self._printer,
|
printer=self._printer,
|
||||||
raw_output=last_raw_output,
|
|
||||||
agent_role=self.agent.role if self.agent else None,
|
|
||||||
)
|
)
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
@@ -390,7 +386,6 @@ class CrewAgentExecutor(CrewAgentExecutorMixin):
|
|||||||
Final answer from the agent.
|
Final answer from the agent.
|
||||||
"""
|
"""
|
||||||
formatted_answer = None
|
formatted_answer = None
|
||||||
last_raw_output: str | None = None
|
|
||||||
while not isinstance(formatted_answer, AgentFinish):
|
while not isinstance(formatted_answer, AgentFinish):
|
||||||
try:
|
try:
|
||||||
if has_reached_max_iterations(self.iterations, self.max_iter):
|
if has_reached_max_iterations(self.iterations, self.max_iter):
|
||||||
@@ -416,7 +411,6 @@ class CrewAgentExecutor(CrewAgentExecutorMixin):
|
|||||||
response_model=self.response_model,
|
response_model=self.response_model,
|
||||||
executor_context=self,
|
executor_context=self,
|
||||||
)
|
)
|
||||||
last_raw_output = answer
|
|
||||||
|
|
||||||
if self.response_model is not None:
|
if self.response_model is not None:
|
||||||
try:
|
try:
|
||||||
@@ -473,8 +467,6 @@ class CrewAgentExecutor(CrewAgentExecutorMixin):
|
|||||||
iterations=self.iterations,
|
iterations=self.iterations,
|
||||||
log_error_after=self.log_error_after,
|
log_error_after=self.log_error_after,
|
||||||
printer=self._printer,
|
printer=self._printer,
|
||||||
raw_output=last_raw_output,
|
|
||||||
agent_role=self.agent.role if self.agent else None,
|
|
||||||
)
|
)
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
|
|||||||
@@ -1,19 +1,28 @@
|
|||||||
from crewai.events.types.a2a_events import (
|
from crewai.events.types.a2a_events import (
|
||||||
|
A2AAgentCardFetchedEvent,
|
||||||
|
A2AArtifactReceivedEvent,
|
||||||
|
A2AAuthenticationFailedEvent,
|
||||||
|
A2AConnectionErrorEvent,
|
||||||
A2AConversationCompletedEvent,
|
A2AConversationCompletedEvent,
|
||||||
A2AConversationStartedEvent,
|
A2AConversationStartedEvent,
|
||||||
A2ADelegationCompletedEvent,
|
A2ADelegationCompletedEvent,
|
||||||
A2ADelegationStartedEvent,
|
A2ADelegationStartedEvent,
|
||||||
A2AMessageSentEvent,
|
A2AMessageSentEvent,
|
||||||
|
A2AParallelDelegationCompletedEvent,
|
||||||
|
A2AParallelDelegationStartedEvent,
|
||||||
A2APollingStartedEvent,
|
A2APollingStartedEvent,
|
||||||
A2APollingStatusEvent,
|
A2APollingStatusEvent,
|
||||||
A2APushNotificationReceivedEvent,
|
A2APushNotificationReceivedEvent,
|
||||||
A2APushNotificationRegisteredEvent,
|
A2APushNotificationRegisteredEvent,
|
||||||
|
A2APushNotificationSentEvent,
|
||||||
A2APushNotificationTimeoutEvent,
|
A2APushNotificationTimeoutEvent,
|
||||||
A2AResponseReceivedEvent,
|
A2AResponseReceivedEvent,
|
||||||
A2AServerTaskCanceledEvent,
|
A2AServerTaskCanceledEvent,
|
||||||
A2AServerTaskCompletedEvent,
|
A2AServerTaskCompletedEvent,
|
||||||
A2AServerTaskFailedEvent,
|
A2AServerTaskFailedEvent,
|
||||||
A2AServerTaskStartedEvent,
|
A2AServerTaskStartedEvent,
|
||||||
|
A2AStreamingChunkEvent,
|
||||||
|
A2AStreamingStartedEvent,
|
||||||
)
|
)
|
||||||
from crewai.events.types.agent_events import (
|
from crewai.events.types.agent_events import (
|
||||||
AgentExecutionCompletedEvent,
|
AgentExecutionCompletedEvent,
|
||||||
@@ -93,7 +102,11 @@ from crewai.events.types.tool_usage_events import (
|
|||||||
|
|
||||||
|
|
||||||
EventTypes = (
|
EventTypes = (
|
||||||
A2AConversationCompletedEvent
|
A2AAgentCardFetchedEvent
|
||||||
|
| A2AArtifactReceivedEvent
|
||||||
|
| A2AAuthenticationFailedEvent
|
||||||
|
| A2AConnectionErrorEvent
|
||||||
|
| A2AConversationCompletedEvent
|
||||||
| A2AConversationStartedEvent
|
| A2AConversationStartedEvent
|
||||||
| A2ADelegationCompletedEvent
|
| A2ADelegationCompletedEvent
|
||||||
| A2ADelegationStartedEvent
|
| A2ADelegationStartedEvent
|
||||||
@@ -102,12 +115,17 @@ EventTypes = (
|
|||||||
| A2APollingStatusEvent
|
| A2APollingStatusEvent
|
||||||
| A2APushNotificationReceivedEvent
|
| A2APushNotificationReceivedEvent
|
||||||
| A2APushNotificationRegisteredEvent
|
| A2APushNotificationRegisteredEvent
|
||||||
|
| A2APushNotificationSentEvent
|
||||||
| A2APushNotificationTimeoutEvent
|
| A2APushNotificationTimeoutEvent
|
||||||
| A2AResponseReceivedEvent
|
| A2AResponseReceivedEvent
|
||||||
| A2AServerTaskCanceledEvent
|
| A2AServerTaskCanceledEvent
|
||||||
| A2AServerTaskCompletedEvent
|
| A2AServerTaskCompletedEvent
|
||||||
| A2AServerTaskFailedEvent
|
| A2AServerTaskFailedEvent
|
||||||
| A2AServerTaskStartedEvent
|
| A2AServerTaskStartedEvent
|
||||||
|
| A2AStreamingChunkEvent
|
||||||
|
| A2AStreamingStartedEvent
|
||||||
|
| A2AParallelDelegationStartedEvent
|
||||||
|
| A2AParallelDelegationCompletedEvent
|
||||||
| CrewKickoffStartedEvent
|
| CrewKickoffStartedEvent
|
||||||
| CrewKickoffCompletedEvent
|
| CrewKickoffCompletedEvent
|
||||||
| CrewKickoffFailedEvent
|
| CrewKickoffFailedEvent
|
||||||
|
|||||||
@@ -1,7 +1,7 @@
|
|||||||
"""Trace collection listener for orchestrating trace collection."""
|
"""Trace collection listener for orchestrating trace collection."""
|
||||||
|
|
||||||
import os
|
import os
|
||||||
from typing import Any, ClassVar, cast
|
from typing import Any, ClassVar
|
||||||
import uuid
|
import uuid
|
||||||
|
|
||||||
from typing_extensions import Self
|
from typing_extensions import Self
|
||||||
@@ -18,6 +18,32 @@ from crewai.events.listeners.tracing.types import TraceEvent
|
|||||||
from crewai.events.listeners.tracing.utils import (
|
from crewai.events.listeners.tracing.utils import (
|
||||||
safe_serialize_to_dict,
|
safe_serialize_to_dict,
|
||||||
)
|
)
|
||||||
|
from crewai.events.types.a2a_events import (
|
||||||
|
A2AAgentCardFetchedEvent,
|
||||||
|
A2AArtifactReceivedEvent,
|
||||||
|
A2AAuthenticationFailedEvent,
|
||||||
|
A2AConnectionErrorEvent,
|
||||||
|
A2AConversationCompletedEvent,
|
||||||
|
A2AConversationStartedEvent,
|
||||||
|
A2ADelegationCompletedEvent,
|
||||||
|
A2ADelegationStartedEvent,
|
||||||
|
A2AMessageSentEvent,
|
||||||
|
A2AParallelDelegationCompletedEvent,
|
||||||
|
A2AParallelDelegationStartedEvent,
|
||||||
|
A2APollingStartedEvent,
|
||||||
|
A2APollingStatusEvent,
|
||||||
|
A2APushNotificationReceivedEvent,
|
||||||
|
A2APushNotificationRegisteredEvent,
|
||||||
|
A2APushNotificationSentEvent,
|
||||||
|
A2APushNotificationTimeoutEvent,
|
||||||
|
A2AResponseReceivedEvent,
|
||||||
|
A2AServerTaskCanceledEvent,
|
||||||
|
A2AServerTaskCompletedEvent,
|
||||||
|
A2AServerTaskFailedEvent,
|
||||||
|
A2AServerTaskStartedEvent,
|
||||||
|
A2AStreamingChunkEvent,
|
||||||
|
A2AStreamingStartedEvent,
|
||||||
|
)
|
||||||
from crewai.events.types.agent_events import (
|
from crewai.events.types.agent_events import (
|
||||||
AgentExecutionCompletedEvent,
|
AgentExecutionCompletedEvent,
|
||||||
AgentExecutionErrorEvent,
|
AgentExecutionErrorEvent,
|
||||||
@@ -105,7 +131,7 @@ class TraceCollectionListener(BaseEventListener):
|
|||||||
"""Create or return singleton instance."""
|
"""Create or return singleton instance."""
|
||||||
if cls._instance is None:
|
if cls._instance is None:
|
||||||
cls._instance = super().__new__(cls)
|
cls._instance = super().__new__(cls)
|
||||||
return cast(Self, cls._instance)
|
return cls._instance
|
||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
@@ -160,6 +186,7 @@ class TraceCollectionListener(BaseEventListener):
|
|||||||
self._register_flow_event_handlers(crewai_event_bus)
|
self._register_flow_event_handlers(crewai_event_bus)
|
||||||
self._register_context_event_handlers(crewai_event_bus)
|
self._register_context_event_handlers(crewai_event_bus)
|
||||||
self._register_action_event_handlers(crewai_event_bus)
|
self._register_action_event_handlers(crewai_event_bus)
|
||||||
|
self._register_a2a_event_handlers(crewai_event_bus)
|
||||||
self._register_system_event_handlers(crewai_event_bus)
|
self._register_system_event_handlers(crewai_event_bus)
|
||||||
|
|
||||||
self._listeners_setup = True
|
self._listeners_setup = True
|
||||||
@@ -439,6 +466,147 @@ class TraceCollectionListener(BaseEventListener):
|
|||||||
) -> None:
|
) -> None:
|
||||||
self._handle_action_event("knowledge_query_failed", source, event)
|
self._handle_action_event("knowledge_query_failed", source, event)
|
||||||
|
|
||||||
|
def _register_a2a_event_handlers(self, event_bus: CrewAIEventsBus) -> None:
|
||||||
|
"""Register handlers for A2A (Agent-to-Agent) events."""
|
||||||
|
|
||||||
|
@event_bus.on(A2ADelegationStartedEvent)
|
||||||
|
def on_a2a_delegation_started(
|
||||||
|
source: Any, event: A2ADelegationStartedEvent
|
||||||
|
) -> None:
|
||||||
|
self._handle_action_event("a2a_delegation_started", source, event)
|
||||||
|
|
||||||
|
@event_bus.on(A2ADelegationCompletedEvent)
|
||||||
|
def on_a2a_delegation_completed(
|
||||||
|
source: Any, event: A2ADelegationCompletedEvent
|
||||||
|
) -> None:
|
||||||
|
self._handle_action_event("a2a_delegation_completed", source, event)
|
||||||
|
|
||||||
|
@event_bus.on(A2AConversationStartedEvent)
|
||||||
|
def on_a2a_conversation_started(
|
||||||
|
source: Any, event: A2AConversationStartedEvent
|
||||||
|
) -> None:
|
||||||
|
self._handle_action_event("a2a_conversation_started", source, event)
|
||||||
|
|
||||||
|
@event_bus.on(A2AMessageSentEvent)
|
||||||
|
def on_a2a_message_sent(source: Any, event: A2AMessageSentEvent) -> None:
|
||||||
|
self._handle_action_event("a2a_message_sent", source, event)
|
||||||
|
|
||||||
|
@event_bus.on(A2AResponseReceivedEvent)
|
||||||
|
def on_a2a_response_received(
|
||||||
|
source: Any, event: A2AResponseReceivedEvent
|
||||||
|
) -> None:
|
||||||
|
self._handle_action_event("a2a_response_received", source, event)
|
||||||
|
|
||||||
|
@event_bus.on(A2AConversationCompletedEvent)
|
||||||
|
def on_a2a_conversation_completed(
|
||||||
|
source: Any, event: A2AConversationCompletedEvent
|
||||||
|
) -> None:
|
||||||
|
self._handle_action_event("a2a_conversation_completed", source, event)
|
||||||
|
|
||||||
|
@event_bus.on(A2APollingStartedEvent)
|
||||||
|
def on_a2a_polling_started(source: Any, event: A2APollingStartedEvent) -> None:
|
||||||
|
self._handle_action_event("a2a_polling_started", source, event)
|
||||||
|
|
||||||
|
@event_bus.on(A2APollingStatusEvent)
|
||||||
|
def on_a2a_polling_status(source: Any, event: A2APollingStatusEvent) -> None:
|
||||||
|
self._handle_action_event("a2a_polling_status", source, event)
|
||||||
|
|
||||||
|
@event_bus.on(A2APushNotificationRegisteredEvent)
|
||||||
|
def on_a2a_push_notification_registered(
|
||||||
|
source: Any, event: A2APushNotificationRegisteredEvent
|
||||||
|
) -> None:
|
||||||
|
self._handle_action_event("a2a_push_notification_registered", source, event)
|
||||||
|
|
||||||
|
@event_bus.on(A2APushNotificationReceivedEvent)
|
||||||
|
def on_a2a_push_notification_received(
|
||||||
|
source: Any, event: A2APushNotificationReceivedEvent
|
||||||
|
) -> None:
|
||||||
|
self._handle_action_event("a2a_push_notification_received", source, event)
|
||||||
|
|
||||||
|
@event_bus.on(A2APushNotificationSentEvent)
|
||||||
|
def on_a2a_push_notification_sent(
|
||||||
|
source: Any, event: A2APushNotificationSentEvent
|
||||||
|
) -> None:
|
||||||
|
self._handle_action_event("a2a_push_notification_sent", source, event)
|
||||||
|
|
||||||
|
@event_bus.on(A2APushNotificationTimeoutEvent)
|
||||||
|
def on_a2a_push_notification_timeout(
|
||||||
|
source: Any, event: A2APushNotificationTimeoutEvent
|
||||||
|
) -> None:
|
||||||
|
self._handle_action_event("a2a_push_notification_timeout", source, event)
|
||||||
|
|
||||||
|
@event_bus.on(A2AStreamingStartedEvent)
|
||||||
|
def on_a2a_streaming_started(
|
||||||
|
source: Any, event: A2AStreamingStartedEvent
|
||||||
|
) -> None:
|
||||||
|
self._handle_action_event("a2a_streaming_started", source, event)
|
||||||
|
|
||||||
|
@event_bus.on(A2AStreamingChunkEvent)
|
||||||
|
def on_a2a_streaming_chunk(source: Any, event: A2AStreamingChunkEvent) -> None:
|
||||||
|
self._handle_action_event("a2a_streaming_chunk", source, event)
|
||||||
|
|
||||||
|
@event_bus.on(A2AAgentCardFetchedEvent)
|
||||||
|
def on_a2a_agent_card_fetched(
|
||||||
|
source: Any, event: A2AAgentCardFetchedEvent
|
||||||
|
) -> None:
|
||||||
|
self._handle_action_event("a2a_agent_card_fetched", source, event)
|
||||||
|
|
||||||
|
@event_bus.on(A2AAuthenticationFailedEvent)
|
||||||
|
def on_a2a_authentication_failed(
|
||||||
|
source: Any, event: A2AAuthenticationFailedEvent
|
||||||
|
) -> None:
|
||||||
|
self._handle_action_event("a2a_authentication_failed", source, event)
|
||||||
|
|
||||||
|
@event_bus.on(A2AArtifactReceivedEvent)
|
||||||
|
def on_a2a_artifact_received(
|
||||||
|
source: Any, event: A2AArtifactReceivedEvent
|
||||||
|
) -> None:
|
||||||
|
self._handle_action_event("a2a_artifact_received", source, event)
|
||||||
|
|
||||||
|
@event_bus.on(A2AConnectionErrorEvent)
|
||||||
|
def on_a2a_connection_error(
|
||||||
|
source: Any, event: A2AConnectionErrorEvent
|
||||||
|
) -> None:
|
||||||
|
self._handle_action_event("a2a_connection_error", source, event)
|
||||||
|
|
||||||
|
@event_bus.on(A2AServerTaskStartedEvent)
|
||||||
|
def on_a2a_server_task_started(
|
||||||
|
source: Any, event: A2AServerTaskStartedEvent
|
||||||
|
) -> None:
|
||||||
|
self._handle_action_event("a2a_server_task_started", source, event)
|
||||||
|
|
||||||
|
@event_bus.on(A2AServerTaskCompletedEvent)
|
||||||
|
def on_a2a_server_task_completed(
|
||||||
|
source: Any, event: A2AServerTaskCompletedEvent
|
||||||
|
) -> None:
|
||||||
|
self._handle_action_event("a2a_server_task_completed", source, event)
|
||||||
|
|
||||||
|
@event_bus.on(A2AServerTaskCanceledEvent)
|
||||||
|
def on_a2a_server_task_canceled(
|
||||||
|
source: Any, event: A2AServerTaskCanceledEvent
|
||||||
|
) -> None:
|
||||||
|
self._handle_action_event("a2a_server_task_canceled", source, event)
|
||||||
|
|
||||||
|
@event_bus.on(A2AServerTaskFailedEvent)
|
||||||
|
def on_a2a_server_task_failed(
|
||||||
|
source: Any, event: A2AServerTaskFailedEvent
|
||||||
|
) -> None:
|
||||||
|
self._handle_action_event("a2a_server_task_failed", source, event)
|
||||||
|
|
||||||
|
@event_bus.on(A2AParallelDelegationStartedEvent)
|
||||||
|
def on_a2a_parallel_delegation_started(
|
||||||
|
source: Any, event: A2AParallelDelegationStartedEvent
|
||||||
|
) -> None:
|
||||||
|
self._handle_action_event("a2a_parallel_delegation_started", source, event)
|
||||||
|
|
||||||
|
@event_bus.on(A2AParallelDelegationCompletedEvent)
|
||||||
|
def on_a2a_parallel_delegation_completed(
|
||||||
|
source: Any, event: A2AParallelDelegationCompletedEvent
|
||||||
|
) -> None:
|
||||||
|
self._handle_action_event(
|
||||||
|
"a2a_parallel_delegation_completed", source, event
|
||||||
|
)
|
||||||
|
|
||||||
def _register_system_event_handlers(self, event_bus: CrewAIEventsBus) -> None:
|
def _register_system_event_handlers(self, event_bus: CrewAIEventsBus) -> None:
|
||||||
"""Register handlers for system signal events (SIGTERM, SIGINT, etc.)."""
|
"""Register handlers for system signal events (SIGTERM, SIGINT, etc.)."""
|
||||||
|
|
||||||
@@ -570,10 +738,15 @@ class TraceCollectionListener(BaseEventListener):
|
|||||||
if event_type not in self.complex_events:
|
if event_type not in self.complex_events:
|
||||||
return safe_serialize_to_dict(event)
|
return safe_serialize_to_dict(event)
|
||||||
if event_type == "task_started":
|
if event_type == "task_started":
|
||||||
|
task_name = event.task.name or event.task.description
|
||||||
|
task_display_name = (
|
||||||
|
task_name[:80] + "..." if len(task_name) > 80 else task_name
|
||||||
|
)
|
||||||
return {
|
return {
|
||||||
"task_description": event.task.description,
|
"task_description": event.task.description,
|
||||||
"expected_output": event.task.expected_output,
|
"expected_output": event.task.expected_output,
|
||||||
"task_name": event.task.name or event.task.description,
|
"task_name": task_name,
|
||||||
|
"task_display_name": task_display_name,
|
||||||
"context": event.context,
|
"context": event.context,
|
||||||
"agent_role": source.agent.role,
|
"agent_role": source.agent.role,
|
||||||
"task_id": str(event.task.id),
|
"task_id": str(event.task.id),
|
||||||
|
|||||||
@@ -4,68 +4,120 @@ This module defines events emitted during A2A protocol delegation,
|
|||||||
including both single-turn and multiturn conversation flows.
|
including both single-turn and multiturn conversation flows.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
from typing import Any, Literal
|
from typing import Any, Literal
|
||||||
|
|
||||||
|
from pydantic import model_validator
|
||||||
|
|
||||||
from crewai.events.base_events import BaseEvent
|
from crewai.events.base_events import BaseEvent
|
||||||
|
|
||||||
|
|
||||||
class A2AEventBase(BaseEvent):
|
class A2AEventBase(BaseEvent):
|
||||||
"""Base class for A2A events with task/agent context."""
|
"""Base class for A2A events with task/agent context."""
|
||||||
|
|
||||||
from_task: Any | None = None
|
from_task: Any = None
|
||||||
from_agent: Any | None = None
|
from_agent: Any = None
|
||||||
|
|
||||||
def __init__(self, **data: Any) -> None:
|
@model_validator(mode="before")
|
||||||
"""Initialize A2A event, extracting task and agent metadata."""
|
@classmethod
|
||||||
if data.get("from_task"):
|
def extract_task_and_agent_metadata(cls, data: dict[str, Any]) -> dict[str, Any]:
|
||||||
task = data["from_task"]
|
"""Extract task and agent metadata before validation."""
|
||||||
|
if task := data.get("from_task"):
|
||||||
data["task_id"] = str(task.id)
|
data["task_id"] = str(task.id)
|
||||||
data["task_name"] = task.name or task.description
|
data["task_name"] = task.name or task.description
|
||||||
|
data.setdefault("source_fingerprint", str(task.id))
|
||||||
|
data.setdefault("source_type", "task")
|
||||||
|
data.setdefault(
|
||||||
|
"fingerprint_metadata",
|
||||||
|
{
|
||||||
|
"task_id": str(task.id),
|
||||||
|
"task_name": task.name or task.description,
|
||||||
|
},
|
||||||
|
)
|
||||||
data["from_task"] = None
|
data["from_task"] = None
|
||||||
|
|
||||||
if data.get("from_agent"):
|
if agent := data.get("from_agent"):
|
||||||
agent = data["from_agent"]
|
|
||||||
data["agent_id"] = str(agent.id)
|
data["agent_id"] = str(agent.id)
|
||||||
data["agent_role"] = agent.role
|
data["agent_role"] = agent.role
|
||||||
|
data.setdefault("source_fingerprint", str(agent.id))
|
||||||
|
data.setdefault("source_type", "agent")
|
||||||
|
data.setdefault(
|
||||||
|
"fingerprint_metadata",
|
||||||
|
{
|
||||||
|
"agent_id": str(agent.id),
|
||||||
|
"agent_role": agent.role,
|
||||||
|
},
|
||||||
|
)
|
||||||
data["from_agent"] = None
|
data["from_agent"] = None
|
||||||
|
|
||||||
super().__init__(**data)
|
return data
|
||||||
|
|
||||||
|
|
||||||
class A2ADelegationStartedEvent(A2AEventBase):
|
class A2ADelegationStartedEvent(A2AEventBase):
|
||||||
"""Event emitted when A2A delegation starts.
|
"""Event emitted when A2A delegation starts.
|
||||||
|
|
||||||
Attributes:
|
Attributes:
|
||||||
endpoint: A2A agent endpoint URL (AgentCard URL)
|
endpoint: A2A agent endpoint URL (AgentCard URL).
|
||||||
task_description: Task being delegated to the A2A agent
|
task_description: Task being delegated to the A2A agent.
|
||||||
agent_id: A2A agent identifier
|
agent_id: A2A agent identifier.
|
||||||
is_multiturn: Whether this is part of a multiturn conversation
|
context_id: A2A context ID grouping related tasks.
|
||||||
turn_number: Current turn number (1-indexed, 1 for single-turn)
|
is_multiturn: Whether this is part of a multiturn conversation.
|
||||||
|
turn_number: Current turn number (1-indexed, 1 for single-turn).
|
||||||
|
a2a_agent_name: Name of the A2A agent from agent card.
|
||||||
|
agent_card: Full A2A agent card metadata.
|
||||||
|
protocol_version: A2A protocol version being used.
|
||||||
|
provider: Agent provider/organization info from agent card.
|
||||||
|
skill_id: ID of the specific skill being invoked.
|
||||||
|
metadata: Custom A2A metadata key-value pairs.
|
||||||
|
extensions: List of A2A extension URIs in use.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
type: str = "a2a_delegation_started"
|
type: str = "a2a_delegation_started"
|
||||||
endpoint: str
|
endpoint: str
|
||||||
task_description: str
|
task_description: str
|
||||||
agent_id: str
|
agent_id: str
|
||||||
|
context_id: str | None = None
|
||||||
is_multiturn: bool = False
|
is_multiturn: bool = False
|
||||||
turn_number: int = 1
|
turn_number: int = 1
|
||||||
|
a2a_agent_name: str | None = None
|
||||||
|
agent_card: dict[str, Any] | None = None
|
||||||
|
protocol_version: str | None = None
|
||||||
|
provider: dict[str, Any] | None = None
|
||||||
|
skill_id: str | None = None
|
||||||
|
metadata: dict[str, Any] | None = None
|
||||||
|
extensions: list[str] | None = None
|
||||||
|
|
||||||
|
|
||||||
class A2ADelegationCompletedEvent(A2AEventBase):
|
class A2ADelegationCompletedEvent(A2AEventBase):
|
||||||
"""Event emitted when A2A delegation completes.
|
"""Event emitted when A2A delegation completes.
|
||||||
|
|
||||||
Attributes:
|
Attributes:
|
||||||
status: Completion status (completed, input_required, failed, etc.)
|
status: Completion status (completed, input_required, failed, etc.).
|
||||||
result: Result message if status is completed
|
result: Result message if status is completed.
|
||||||
error: Error/response message (error for failed, response for input_required)
|
error: Error/response message (error for failed, response for input_required).
|
||||||
is_multiturn: Whether this is part of a multiturn conversation
|
context_id: A2A context ID grouping related tasks.
|
||||||
|
is_multiturn: Whether this is part of a multiturn conversation.
|
||||||
|
endpoint: A2A agent endpoint URL.
|
||||||
|
a2a_agent_name: Name of the A2A agent from agent card.
|
||||||
|
agent_card: Full A2A agent card metadata.
|
||||||
|
provider: Agent provider/organization info from agent card.
|
||||||
|
metadata: Custom A2A metadata key-value pairs.
|
||||||
|
extensions: List of A2A extension URIs in use.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
type: str = "a2a_delegation_completed"
|
type: str = "a2a_delegation_completed"
|
||||||
status: str
|
status: str
|
||||||
result: str | None = None
|
result: str | None = None
|
||||||
error: str | None = None
|
error: str | None = None
|
||||||
|
context_id: str | None = None
|
||||||
is_multiturn: bool = False
|
is_multiturn: bool = False
|
||||||
|
endpoint: str | None = None
|
||||||
|
a2a_agent_name: str | None = None
|
||||||
|
agent_card: dict[str, Any] | None = None
|
||||||
|
provider: dict[str, Any] | None = None
|
||||||
|
metadata: dict[str, Any] | None = None
|
||||||
|
extensions: list[str] | None = None
|
||||||
|
|
||||||
|
|
||||||
class A2AConversationStartedEvent(A2AEventBase):
|
class A2AConversationStartedEvent(A2AEventBase):
|
||||||
@@ -75,51 +127,95 @@ class A2AConversationStartedEvent(A2AEventBase):
|
|||||||
before the first message exchange.
|
before the first message exchange.
|
||||||
|
|
||||||
Attributes:
|
Attributes:
|
||||||
agent_id: A2A agent identifier
|
agent_id: A2A agent identifier.
|
||||||
endpoint: A2A agent endpoint URL
|
endpoint: A2A agent endpoint URL.
|
||||||
a2a_agent_name: Name of the A2A agent from agent card
|
context_id: A2A context ID grouping related tasks.
|
||||||
|
a2a_agent_name: Name of the A2A agent from agent card.
|
||||||
|
agent_card: Full A2A agent card metadata.
|
||||||
|
protocol_version: A2A protocol version being used.
|
||||||
|
provider: Agent provider/organization info from agent card.
|
||||||
|
skill_id: ID of the specific skill being invoked.
|
||||||
|
reference_task_ids: Related task IDs for context.
|
||||||
|
metadata: Custom A2A metadata key-value pairs.
|
||||||
|
extensions: List of A2A extension URIs in use.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
type: str = "a2a_conversation_started"
|
type: str = "a2a_conversation_started"
|
||||||
agent_id: str
|
agent_id: str
|
||||||
endpoint: str
|
endpoint: str
|
||||||
|
context_id: str | None = None
|
||||||
a2a_agent_name: str | None = None
|
a2a_agent_name: str | None = None
|
||||||
|
agent_card: dict[str, Any] | None = None
|
||||||
|
protocol_version: str | None = None
|
||||||
|
provider: dict[str, Any] | None = None
|
||||||
|
skill_id: str | None = None
|
||||||
|
reference_task_ids: list[str] | None = None
|
||||||
|
metadata: dict[str, Any] | None = None
|
||||||
|
extensions: list[str] | None = None
|
||||||
|
|
||||||
|
|
||||||
class A2AMessageSentEvent(A2AEventBase):
|
class A2AMessageSentEvent(A2AEventBase):
|
||||||
"""Event emitted when a message is sent to the A2A agent.
|
"""Event emitted when a message is sent to the A2A agent.
|
||||||
|
|
||||||
Attributes:
|
Attributes:
|
||||||
message: Message content sent to the A2A agent
|
message: Message content sent to the A2A agent.
|
||||||
turn_number: Current turn number (1-indexed)
|
turn_number: Current turn number (1-indexed).
|
||||||
is_multiturn: Whether this is part of a multiturn conversation
|
context_id: A2A context ID grouping related tasks.
|
||||||
agent_role: Role of the CrewAI agent sending the message
|
message_id: Unique A2A message identifier.
|
||||||
|
is_multiturn: Whether this is part of a multiturn conversation.
|
||||||
|
agent_role: Role of the CrewAI agent sending the message.
|
||||||
|
endpoint: A2A agent endpoint URL.
|
||||||
|
a2a_agent_name: Name of the A2A agent from agent card.
|
||||||
|
skill_id: ID of the specific skill being invoked.
|
||||||
|
metadata: Custom A2A metadata key-value pairs.
|
||||||
|
extensions: List of A2A extension URIs in use.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
type: str = "a2a_message_sent"
|
type: str = "a2a_message_sent"
|
||||||
message: str
|
message: str
|
||||||
turn_number: int
|
turn_number: int
|
||||||
|
context_id: str | None = None
|
||||||
|
message_id: str | None = None
|
||||||
is_multiturn: bool = False
|
is_multiturn: bool = False
|
||||||
agent_role: str | None = None
|
agent_role: str | None = None
|
||||||
|
endpoint: str | None = None
|
||||||
|
a2a_agent_name: str | None = None
|
||||||
|
skill_id: str | None = None
|
||||||
|
metadata: dict[str, Any] | None = None
|
||||||
|
extensions: list[str] | None = None
|
||||||
|
|
||||||
|
|
||||||
class A2AResponseReceivedEvent(A2AEventBase):
|
class A2AResponseReceivedEvent(A2AEventBase):
|
||||||
"""Event emitted when a response is received from the A2A agent.
|
"""Event emitted when a response is received from the A2A agent.
|
||||||
|
|
||||||
Attributes:
|
Attributes:
|
||||||
response: Response content from the A2A agent
|
response: Response content from the A2A agent.
|
||||||
turn_number: Current turn number (1-indexed)
|
turn_number: Current turn number (1-indexed).
|
||||||
is_multiturn: Whether this is part of a multiturn conversation
|
context_id: A2A context ID grouping related tasks.
|
||||||
status: Response status (input_required, completed, etc.)
|
message_id: Unique A2A message identifier.
|
||||||
agent_role: Role of the CrewAI agent (for display)
|
is_multiturn: Whether this is part of a multiturn conversation.
|
||||||
|
status: Response status (input_required, completed, etc.).
|
||||||
|
final: Whether this is the final response in the stream.
|
||||||
|
agent_role: Role of the CrewAI agent (for display).
|
||||||
|
endpoint: A2A agent endpoint URL.
|
||||||
|
a2a_agent_name: Name of the A2A agent from agent card.
|
||||||
|
metadata: Custom A2A metadata key-value pairs.
|
||||||
|
extensions: List of A2A extension URIs in use.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
type: str = "a2a_response_received"
|
type: str = "a2a_response_received"
|
||||||
response: str
|
response: str
|
||||||
turn_number: int
|
turn_number: int
|
||||||
|
context_id: str | None = None
|
||||||
|
message_id: str | None = None
|
||||||
is_multiturn: bool = False
|
is_multiturn: bool = False
|
||||||
status: str
|
status: str
|
||||||
|
final: bool = False
|
||||||
agent_role: str | None = None
|
agent_role: str | None = None
|
||||||
|
endpoint: str | None = None
|
||||||
|
a2a_agent_name: str | None = None
|
||||||
|
metadata: dict[str, Any] | None = None
|
||||||
|
extensions: list[str] | None = None
|
||||||
|
|
||||||
|
|
||||||
class A2AConversationCompletedEvent(A2AEventBase):
|
class A2AConversationCompletedEvent(A2AEventBase):
|
||||||
@@ -128,119 +224,433 @@ class A2AConversationCompletedEvent(A2AEventBase):
|
|||||||
This is emitted once at the end of a multiturn conversation.
|
This is emitted once at the end of a multiturn conversation.
|
||||||
|
|
||||||
Attributes:
|
Attributes:
|
||||||
status: Final status (completed, failed, etc.)
|
status: Final status (completed, failed, etc.).
|
||||||
final_result: Final result if completed successfully
|
final_result: Final result if completed successfully.
|
||||||
error: Error message if failed
|
error: Error message if failed.
|
||||||
total_turns: Total number of turns in the conversation
|
context_id: A2A context ID grouping related tasks.
|
||||||
|
total_turns: Total number of turns in the conversation.
|
||||||
|
endpoint: A2A agent endpoint URL.
|
||||||
|
a2a_agent_name: Name of the A2A agent from agent card.
|
||||||
|
agent_card: Full A2A agent card metadata.
|
||||||
|
reference_task_ids: Related task IDs for context.
|
||||||
|
metadata: Custom A2A metadata key-value pairs.
|
||||||
|
extensions: List of A2A extension URIs in use.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
type: str = "a2a_conversation_completed"
|
type: str = "a2a_conversation_completed"
|
||||||
status: Literal["completed", "failed"]
|
status: Literal["completed", "failed"]
|
||||||
final_result: str | None = None
|
final_result: str | None = None
|
||||||
error: str | None = None
|
error: str | None = None
|
||||||
|
context_id: str | None = None
|
||||||
total_turns: int
|
total_turns: int
|
||||||
|
endpoint: str | None = None
|
||||||
|
a2a_agent_name: str | None = None
|
||||||
|
agent_card: dict[str, Any] | None = None
|
||||||
|
reference_task_ids: list[str] | None = None
|
||||||
|
metadata: dict[str, Any] | None = None
|
||||||
|
extensions: list[str] | None = None
|
||||||
|
|
||||||
|
|
||||||
class A2APollingStartedEvent(A2AEventBase):
|
class A2APollingStartedEvent(A2AEventBase):
|
||||||
"""Event emitted when polling mode begins for A2A delegation.
|
"""Event emitted when polling mode begins for A2A delegation.
|
||||||
|
|
||||||
Attributes:
|
Attributes:
|
||||||
task_id: A2A task ID being polled
|
task_id: A2A task ID being polled.
|
||||||
polling_interval: Seconds between poll attempts
|
context_id: A2A context ID grouping related tasks.
|
||||||
endpoint: A2A agent endpoint URL
|
polling_interval: Seconds between poll attempts.
|
||||||
|
endpoint: A2A agent endpoint URL.
|
||||||
|
a2a_agent_name: Name of the A2A agent from agent card.
|
||||||
|
metadata: Custom A2A metadata key-value pairs.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
type: str = "a2a_polling_started"
|
type: str = "a2a_polling_started"
|
||||||
task_id: str
|
task_id: str
|
||||||
|
context_id: str | None = None
|
||||||
polling_interval: float
|
polling_interval: float
|
||||||
endpoint: str
|
endpoint: str
|
||||||
|
a2a_agent_name: str | None = None
|
||||||
|
metadata: dict[str, Any] | None = None
|
||||||
|
|
||||||
|
|
||||||
class A2APollingStatusEvent(A2AEventBase):
|
class A2APollingStatusEvent(A2AEventBase):
|
||||||
"""Event emitted on each polling iteration.
|
"""Event emitted on each polling iteration.
|
||||||
|
|
||||||
Attributes:
|
Attributes:
|
||||||
task_id: A2A task ID being polled
|
task_id: A2A task ID being polled.
|
||||||
state: Current task state from remote agent
|
context_id: A2A context ID grouping related tasks.
|
||||||
elapsed_seconds: Time since polling started
|
state: Current task state from remote agent.
|
||||||
poll_count: Number of polls completed
|
elapsed_seconds: Time since polling started.
|
||||||
|
poll_count: Number of polls completed.
|
||||||
|
endpoint: A2A agent endpoint URL.
|
||||||
|
a2a_agent_name: Name of the A2A agent from agent card.
|
||||||
|
metadata: Custom A2A metadata key-value pairs.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
type: str = "a2a_polling_status"
|
type: str = "a2a_polling_status"
|
||||||
task_id: str
|
task_id: str
|
||||||
|
context_id: str | None = None
|
||||||
state: str
|
state: str
|
||||||
elapsed_seconds: float
|
elapsed_seconds: float
|
||||||
poll_count: int
|
poll_count: int
|
||||||
|
endpoint: str | None = None
|
||||||
|
a2a_agent_name: str | None = None
|
||||||
|
metadata: dict[str, Any] | None = None
|
||||||
|
|
||||||
|
|
||||||
class A2APushNotificationRegisteredEvent(A2AEventBase):
|
class A2APushNotificationRegisteredEvent(A2AEventBase):
|
||||||
"""Event emitted when push notification callback is registered.
|
"""Event emitted when push notification callback is registered.
|
||||||
|
|
||||||
Attributes:
|
Attributes:
|
||||||
task_id: A2A task ID for which callback is registered
|
task_id: A2A task ID for which callback is registered.
|
||||||
callback_url: URL where agent will send push notifications
|
context_id: A2A context ID grouping related tasks.
|
||||||
|
callback_url: URL where agent will send push notifications.
|
||||||
|
endpoint: A2A agent endpoint URL.
|
||||||
|
a2a_agent_name: Name of the A2A agent from agent card.
|
||||||
|
metadata: Custom A2A metadata key-value pairs.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
type: str = "a2a_push_notification_registered"
|
type: str = "a2a_push_notification_registered"
|
||||||
task_id: str
|
task_id: str
|
||||||
|
context_id: str | None = None
|
||||||
callback_url: str
|
callback_url: str
|
||||||
|
endpoint: str | None = None
|
||||||
|
a2a_agent_name: str | None = None
|
||||||
|
metadata: dict[str, Any] | None = None
|
||||||
|
|
||||||
|
|
||||||
class A2APushNotificationReceivedEvent(A2AEventBase):
|
class A2APushNotificationReceivedEvent(A2AEventBase):
|
||||||
"""Event emitted when a push notification is received.
|
"""Event emitted when a push notification is received.
|
||||||
|
|
||||||
|
This event should be emitted by the user's webhook handler when it receives
|
||||||
|
a push notification from the remote A2A agent, before calling
|
||||||
|
`result_store.store_result()`.
|
||||||
|
|
||||||
Attributes:
|
Attributes:
|
||||||
task_id: A2A task ID from the notification
|
task_id: A2A task ID from the notification.
|
||||||
state: Current task state from the notification
|
context_id: A2A context ID grouping related tasks.
|
||||||
|
state: Current task state from the notification.
|
||||||
|
endpoint: A2A agent endpoint URL.
|
||||||
|
a2a_agent_name: Name of the A2A agent from agent card.
|
||||||
|
metadata: Custom A2A metadata key-value pairs.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
type: str = "a2a_push_notification_received"
|
type: str = "a2a_push_notification_received"
|
||||||
task_id: str
|
task_id: str
|
||||||
|
context_id: str | None = None
|
||||||
state: str
|
state: str
|
||||||
|
endpoint: str | None = None
|
||||||
|
a2a_agent_name: str | None = None
|
||||||
|
metadata: dict[str, Any] | None = None
|
||||||
|
|
||||||
|
|
||||||
|
class A2APushNotificationSentEvent(A2AEventBase):
|
||||||
|
"""Event emitted when a push notification is sent to a callback URL.
|
||||||
|
|
||||||
|
Emitted by the A2A server when it sends a task status update to the
|
||||||
|
client's registered push notification callback URL.
|
||||||
|
|
||||||
|
Attributes:
|
||||||
|
task_id: A2A task ID being notified.
|
||||||
|
context_id: A2A context ID grouping related tasks.
|
||||||
|
callback_url: URL the notification was sent to.
|
||||||
|
state: Task state being reported.
|
||||||
|
success: Whether the notification was successfully delivered.
|
||||||
|
error: Error message if delivery failed.
|
||||||
|
metadata: Custom A2A metadata key-value pairs.
|
||||||
|
"""
|
||||||
|
|
||||||
|
type: str = "a2a_push_notification_sent"
|
||||||
|
task_id: str
|
||||||
|
context_id: str | None = None
|
||||||
|
callback_url: str
|
||||||
|
state: str
|
||||||
|
success: bool = True
|
||||||
|
error: str | None = None
|
||||||
|
metadata: dict[str, Any] | None = None
|
||||||
|
|
||||||
|
|
||||||
class A2APushNotificationTimeoutEvent(A2AEventBase):
|
class A2APushNotificationTimeoutEvent(A2AEventBase):
|
||||||
"""Event emitted when push notification wait times out.
|
"""Event emitted when push notification wait times out.
|
||||||
|
|
||||||
Attributes:
|
Attributes:
|
||||||
task_id: A2A task ID that timed out
|
task_id: A2A task ID that timed out.
|
||||||
timeout_seconds: Timeout duration in seconds
|
context_id: A2A context ID grouping related tasks.
|
||||||
|
timeout_seconds: Timeout duration in seconds.
|
||||||
|
endpoint: A2A agent endpoint URL.
|
||||||
|
a2a_agent_name: Name of the A2A agent from agent card.
|
||||||
|
metadata: Custom A2A metadata key-value pairs.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
type: str = "a2a_push_notification_timeout"
|
type: str = "a2a_push_notification_timeout"
|
||||||
task_id: str
|
task_id: str
|
||||||
|
context_id: str | None = None
|
||||||
timeout_seconds: float
|
timeout_seconds: float
|
||||||
|
endpoint: str | None = None
|
||||||
|
a2a_agent_name: str | None = None
|
||||||
|
metadata: dict[str, Any] | None = None
|
||||||
|
|
||||||
|
|
||||||
|
class A2AStreamingStartedEvent(A2AEventBase):
|
||||||
|
"""Event emitted when streaming mode begins for A2A delegation.
|
||||||
|
|
||||||
|
Attributes:
|
||||||
|
task_id: A2A task ID for the streaming session.
|
||||||
|
context_id: A2A context ID grouping related tasks.
|
||||||
|
endpoint: A2A agent endpoint URL.
|
||||||
|
a2a_agent_name: Name of the A2A agent from agent card.
|
||||||
|
turn_number: Current turn number (1-indexed).
|
||||||
|
is_multiturn: Whether this is part of a multiturn conversation.
|
||||||
|
agent_role: Role of the CrewAI agent.
|
||||||
|
metadata: Custom A2A metadata key-value pairs.
|
||||||
|
extensions: List of A2A extension URIs in use.
|
||||||
|
"""
|
||||||
|
|
||||||
|
type: str = "a2a_streaming_started"
|
||||||
|
task_id: str | None = None
|
||||||
|
context_id: str | None = None
|
||||||
|
endpoint: str
|
||||||
|
a2a_agent_name: str | None = None
|
||||||
|
turn_number: int = 1
|
||||||
|
is_multiturn: bool = False
|
||||||
|
agent_role: str | None = None
|
||||||
|
metadata: dict[str, Any] | None = None
|
||||||
|
extensions: list[str] | None = None
|
||||||
|
|
||||||
|
|
||||||
|
class A2AStreamingChunkEvent(A2AEventBase):
|
||||||
|
"""Event emitted when a streaming chunk is received.
|
||||||
|
|
||||||
|
Attributes:
|
||||||
|
task_id: A2A task ID for the streaming session.
|
||||||
|
context_id: A2A context ID grouping related tasks.
|
||||||
|
chunk: The text content of the chunk.
|
||||||
|
chunk_index: Index of this chunk in the stream (0-indexed).
|
||||||
|
final: Whether this is the final chunk in the stream.
|
||||||
|
endpoint: A2A agent endpoint URL.
|
||||||
|
a2a_agent_name: Name of the A2A agent from agent card.
|
||||||
|
turn_number: Current turn number (1-indexed).
|
||||||
|
is_multiturn: Whether this is part of a multiturn conversation.
|
||||||
|
metadata: Custom A2A metadata key-value pairs.
|
||||||
|
extensions: List of A2A extension URIs in use.
|
||||||
|
"""
|
||||||
|
|
||||||
|
type: str = "a2a_streaming_chunk"
|
||||||
|
task_id: str | None = None
|
||||||
|
context_id: str | None = None
|
||||||
|
chunk: str
|
||||||
|
chunk_index: int
|
||||||
|
final: bool = False
|
||||||
|
endpoint: str | None = None
|
||||||
|
a2a_agent_name: str | None = None
|
||||||
|
turn_number: int = 1
|
||||||
|
is_multiturn: bool = False
|
||||||
|
metadata: dict[str, Any] | None = None
|
||||||
|
extensions: list[str] | None = None
|
||||||
|
|
||||||
|
|
||||||
|
class A2AAgentCardFetchedEvent(A2AEventBase):
|
||||||
|
"""Event emitted when an agent card is successfully fetched.
|
||||||
|
|
||||||
|
Attributes:
|
||||||
|
endpoint: A2A agent endpoint URL.
|
||||||
|
a2a_agent_name: Name of the A2A agent from agent card.
|
||||||
|
agent_card: Full A2A agent card metadata.
|
||||||
|
protocol_version: A2A protocol version from agent card.
|
||||||
|
provider: Agent provider/organization info from agent card.
|
||||||
|
cached: Whether the agent card was retrieved from cache.
|
||||||
|
fetch_time_ms: Time taken to fetch the agent card in milliseconds.
|
||||||
|
metadata: Custom A2A metadata key-value pairs.
|
||||||
|
"""
|
||||||
|
|
||||||
|
type: str = "a2a_agent_card_fetched"
|
||||||
|
endpoint: str
|
||||||
|
a2a_agent_name: str | None = None
|
||||||
|
agent_card: dict[str, Any] | None = None
|
||||||
|
protocol_version: str | None = None
|
||||||
|
provider: dict[str, Any] | None = None
|
||||||
|
cached: bool = False
|
||||||
|
fetch_time_ms: float | None = None
|
||||||
|
metadata: dict[str, Any] | None = None
|
||||||
|
|
||||||
|
|
||||||
|
class A2AAuthenticationFailedEvent(A2AEventBase):
|
||||||
|
"""Event emitted when authentication to an A2A agent fails.
|
||||||
|
|
||||||
|
Attributes:
|
||||||
|
endpoint: A2A agent endpoint URL.
|
||||||
|
auth_type: Type of authentication attempted (e.g., bearer, oauth2, api_key).
|
||||||
|
error: Error message describing the failure.
|
||||||
|
status_code: HTTP status code if applicable.
|
||||||
|
a2a_agent_name: Name of the A2A agent if known.
|
||||||
|
protocol_version: A2A protocol version being used.
|
||||||
|
metadata: Custom A2A metadata key-value pairs.
|
||||||
|
"""
|
||||||
|
|
||||||
|
type: str = "a2a_authentication_failed"
|
||||||
|
endpoint: str
|
||||||
|
auth_type: str | None = None
|
||||||
|
error: str
|
||||||
|
status_code: int | None = None
|
||||||
|
a2a_agent_name: str | None = None
|
||||||
|
protocol_version: str | None = None
|
||||||
|
metadata: dict[str, Any] | None = None
|
||||||
|
|
||||||
|
|
||||||
|
class A2AArtifactReceivedEvent(A2AEventBase):
|
||||||
|
"""Event emitted when an artifact is received from a remote A2A agent.
|
||||||
|
|
||||||
|
Attributes:
|
||||||
|
task_id: A2A task ID the artifact belongs to.
|
||||||
|
artifact_id: Unique identifier for the artifact.
|
||||||
|
artifact_name: Name of the artifact.
|
||||||
|
artifact_description: Purpose description of the artifact.
|
||||||
|
mime_type: MIME type of the artifact content.
|
||||||
|
size_bytes: Size of the artifact in bytes.
|
||||||
|
append: Whether content should be appended to existing artifact.
|
||||||
|
last_chunk: Whether this is the final chunk of the artifact.
|
||||||
|
endpoint: A2A agent endpoint URL.
|
||||||
|
a2a_agent_name: Name of the A2A agent from agent card.
|
||||||
|
context_id: Context ID for correlation.
|
||||||
|
turn_number: Current turn number (1-indexed).
|
||||||
|
is_multiturn: Whether this is part of a multiturn conversation.
|
||||||
|
metadata: Custom A2A metadata key-value pairs.
|
||||||
|
extensions: List of A2A extension URIs in use.
|
||||||
|
"""
|
||||||
|
|
||||||
|
type: str = "a2a_artifact_received"
|
||||||
|
task_id: str
|
||||||
|
artifact_id: str
|
||||||
|
artifact_name: str | None = None
|
||||||
|
artifact_description: str | None = None
|
||||||
|
mime_type: str | None = None
|
||||||
|
size_bytes: int | None = None
|
||||||
|
append: bool = False
|
||||||
|
last_chunk: bool = False
|
||||||
|
endpoint: str | None = None
|
||||||
|
a2a_agent_name: str | None = None
|
||||||
|
context_id: str | None = None
|
||||||
|
turn_number: int = 1
|
||||||
|
is_multiturn: bool = False
|
||||||
|
metadata: dict[str, Any] | None = None
|
||||||
|
extensions: list[str] | None = None
|
||||||
|
|
||||||
|
|
||||||
|
class A2AConnectionErrorEvent(A2AEventBase):
|
||||||
|
"""Event emitted when a connection error occurs during A2A communication.
|
||||||
|
|
||||||
|
Attributes:
|
||||||
|
endpoint: A2A agent endpoint URL.
|
||||||
|
error: Error message describing the connection failure.
|
||||||
|
error_type: Type of error (e.g., timeout, connection_refused, dns_error).
|
||||||
|
status_code: HTTP status code if applicable.
|
||||||
|
a2a_agent_name: Name of the A2A agent from agent card.
|
||||||
|
operation: The operation being attempted when error occurred.
|
||||||
|
context_id: A2A context ID grouping related tasks.
|
||||||
|
task_id: A2A task ID if applicable.
|
||||||
|
metadata: Custom A2A metadata key-value pairs.
|
||||||
|
"""
|
||||||
|
|
||||||
|
type: str = "a2a_connection_error"
|
||||||
|
endpoint: str
|
||||||
|
error: str
|
||||||
|
error_type: str | None = None
|
||||||
|
status_code: int | None = None
|
||||||
|
a2a_agent_name: str | None = None
|
||||||
|
operation: str | None = None
|
||||||
|
context_id: str | None = None
|
||||||
|
task_id: str | None = None
|
||||||
|
metadata: dict[str, Any] | None = None
|
||||||
|
|
||||||
|
|
||||||
class A2AServerTaskStartedEvent(A2AEventBase):
|
class A2AServerTaskStartedEvent(A2AEventBase):
|
||||||
"""Event emitted when an A2A server task execution starts."""
|
"""Event emitted when an A2A server task execution starts.
|
||||||
|
|
||||||
|
Attributes:
|
||||||
|
task_id: A2A task ID for this execution.
|
||||||
|
context_id: A2A context ID grouping related tasks.
|
||||||
|
metadata: Custom A2A metadata key-value pairs.
|
||||||
|
"""
|
||||||
|
|
||||||
type: str = "a2a_server_task_started"
|
type: str = "a2a_server_task_started"
|
||||||
a2a_task_id: str
|
task_id: str
|
||||||
a2a_context_id: str
|
context_id: str
|
||||||
|
metadata: dict[str, Any] | None = None
|
||||||
|
|
||||||
|
|
||||||
class A2AServerTaskCompletedEvent(A2AEventBase):
|
class A2AServerTaskCompletedEvent(A2AEventBase):
|
||||||
"""Event emitted when an A2A server task execution completes."""
|
"""Event emitted when an A2A server task execution completes.
|
||||||
|
|
||||||
|
Attributes:
|
||||||
|
task_id: A2A task ID for this execution.
|
||||||
|
context_id: A2A context ID grouping related tasks.
|
||||||
|
result: The task result.
|
||||||
|
metadata: Custom A2A metadata key-value pairs.
|
||||||
|
"""
|
||||||
|
|
||||||
type: str = "a2a_server_task_completed"
|
type: str = "a2a_server_task_completed"
|
||||||
a2a_task_id: str
|
task_id: str
|
||||||
a2a_context_id: str
|
context_id: str
|
||||||
result: str
|
result: str
|
||||||
|
metadata: dict[str, Any] | None = None
|
||||||
|
|
||||||
|
|
||||||
class A2AServerTaskCanceledEvent(A2AEventBase):
|
class A2AServerTaskCanceledEvent(A2AEventBase):
|
||||||
"""Event emitted when an A2A server task execution is canceled."""
|
"""Event emitted when an A2A server task execution is canceled.
|
||||||
|
|
||||||
|
Attributes:
|
||||||
|
task_id: A2A task ID for this execution.
|
||||||
|
context_id: A2A context ID grouping related tasks.
|
||||||
|
metadata: Custom A2A metadata key-value pairs.
|
||||||
|
"""
|
||||||
|
|
||||||
type: str = "a2a_server_task_canceled"
|
type: str = "a2a_server_task_canceled"
|
||||||
a2a_task_id: str
|
task_id: str
|
||||||
a2a_context_id: str
|
context_id: str
|
||||||
|
metadata: dict[str, Any] | None = None
|
||||||
|
|
||||||
|
|
||||||
class A2AServerTaskFailedEvent(A2AEventBase):
|
class A2AServerTaskFailedEvent(A2AEventBase):
|
||||||
"""Event emitted when an A2A server task execution fails."""
|
"""Event emitted when an A2A server task execution fails.
|
||||||
|
|
||||||
|
Attributes:
|
||||||
|
task_id: A2A task ID for this execution.
|
||||||
|
context_id: A2A context ID grouping related tasks.
|
||||||
|
error: Error message describing the failure.
|
||||||
|
metadata: Custom A2A metadata key-value pairs.
|
||||||
|
"""
|
||||||
|
|
||||||
type: str = "a2a_server_task_failed"
|
type: str = "a2a_server_task_failed"
|
||||||
a2a_task_id: str
|
task_id: str
|
||||||
a2a_context_id: str
|
context_id: str
|
||||||
error: str
|
error: str
|
||||||
|
metadata: dict[str, Any] | None = None
|
||||||
|
|
||||||
|
|
||||||
|
class A2AParallelDelegationStartedEvent(A2AEventBase):
|
||||||
|
"""Event emitted when parallel delegation to multiple A2A agents begins.
|
||||||
|
|
||||||
|
Attributes:
|
||||||
|
endpoints: List of A2A agent endpoints being delegated to.
|
||||||
|
task_description: Description of the task being delegated.
|
||||||
|
"""
|
||||||
|
|
||||||
|
type: str = "a2a_parallel_delegation_started"
|
||||||
|
endpoints: list[str]
|
||||||
|
task_description: str
|
||||||
|
|
||||||
|
|
||||||
|
class A2AParallelDelegationCompletedEvent(A2AEventBase):
|
||||||
|
"""Event emitted when parallel delegation to multiple A2A agents completes.
|
||||||
|
|
||||||
|
Attributes:
|
||||||
|
endpoints: List of A2A agent endpoints that were delegated to.
|
||||||
|
success_count: Number of successful delegations.
|
||||||
|
failure_count: Number of failed delegations.
|
||||||
|
results: Summary of results from each agent.
|
||||||
|
"""
|
||||||
|
|
||||||
|
type: str = "a2a_parallel_delegation_completed"
|
||||||
|
endpoints: list[str]
|
||||||
|
success_count: int
|
||||||
|
failure_count: int
|
||||||
|
results: dict[str, str] | None = None
|
||||||
|
|||||||
@@ -533,7 +533,6 @@ class LiteAgent(FlowTrackable, BaseModel):
|
|||||||
"""
|
"""
|
||||||
# Execute the agent loop
|
# Execute the agent loop
|
||||||
formatted_answer: AgentAction | AgentFinish | None = None
|
formatted_answer: AgentAction | AgentFinish | None = None
|
||||||
last_raw_output: str | None = None
|
|
||||||
while not isinstance(formatted_answer, AgentFinish):
|
while not isinstance(formatted_answer, AgentFinish):
|
||||||
try:
|
try:
|
||||||
if has_reached_max_iterations(self._iterations, self.max_iterations):
|
if has_reached_max_iterations(self._iterations, self.max_iterations):
|
||||||
@@ -557,7 +556,6 @@ class LiteAgent(FlowTrackable, BaseModel):
|
|||||||
from_agent=self,
|
from_agent=self,
|
||||||
executor_context=self,
|
executor_context=self,
|
||||||
)
|
)
|
||||||
last_raw_output = answer
|
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
raise e
|
raise e
|
||||||
@@ -596,8 +594,6 @@ class LiteAgent(FlowTrackable, BaseModel):
|
|||||||
iterations=self._iterations,
|
iterations=self._iterations,
|
||||||
log_error_after=3,
|
log_error_after=3,
|
||||||
printer=self._printer,
|
printer=self._printer,
|
||||||
raw_output=last_raw_output,
|
|
||||||
agent_role=self.role,
|
|
||||||
)
|
)
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
|
|||||||
@@ -2,7 +2,6 @@ from __future__ import annotations
|
|||||||
|
|
||||||
from collections.abc import Callable, Sequence
|
from collections.abc import Callable, Sequence
|
||||||
import json
|
import json
|
||||||
import logging
|
|
||||||
import re
|
import re
|
||||||
from typing import TYPE_CHECKING, Any, Final, Literal, TypedDict
|
from typing import TYPE_CHECKING, Any, Final, Literal, TypedDict
|
||||||
|
|
||||||
@@ -52,8 +51,6 @@ class SummaryContent(TypedDict):
|
|||||||
|
|
||||||
console = Console()
|
console = Console()
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
|
||||||
|
|
||||||
_MULTIPLE_NEWLINES: Final[re.Pattern[str]] = re.compile(r"\n+")
|
_MULTIPLE_NEWLINES: Final[re.Pattern[str]] = re.compile(r"\n+")
|
||||||
|
|
||||||
|
|
||||||
@@ -433,8 +430,6 @@ def handle_output_parser_exception(
|
|||||||
iterations: int,
|
iterations: int,
|
||||||
log_error_after: int = 3,
|
log_error_after: int = 3,
|
||||||
printer: Printer | None = None,
|
printer: Printer | None = None,
|
||||||
raw_output: str | None = None,
|
|
||||||
agent_role: str | None = None,
|
|
||||||
) -> AgentAction:
|
) -> AgentAction:
|
||||||
"""Handle OutputParserError by updating messages and formatted_answer.
|
"""Handle OutputParserError by updating messages and formatted_answer.
|
||||||
|
|
||||||
@@ -444,8 +439,6 @@ def handle_output_parser_exception(
|
|||||||
iterations: Current iteration count
|
iterations: Current iteration count
|
||||||
log_error_after: Number of iterations after which to log errors
|
log_error_after: Number of iterations after which to log errors
|
||||||
printer: Optional printer instance for logging
|
printer: Optional printer instance for logging
|
||||||
raw_output: The raw LLM output that failed to parse
|
|
||||||
agent_role: The role of the agent for logging context
|
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
AgentAction: A formatted answer with the error
|
AgentAction: A formatted answer with the error
|
||||||
@@ -459,27 +452,6 @@ def handle_output_parser_exception(
|
|||||||
thought="",
|
thought="",
|
||||||
)
|
)
|
||||||
|
|
||||||
retry_count = iterations + 1
|
|
||||||
agent_context = f" for agent '{agent_role}'" if agent_role else ""
|
|
||||||
|
|
||||||
logger.debug(
|
|
||||||
"Parse failed%s: %s",
|
|
||||||
agent_context,
|
|
||||||
e.error.split("\n")[0],
|
|
||||||
)
|
|
||||||
|
|
||||||
if raw_output is not None:
|
|
||||||
truncated_output = (
|
|
||||||
raw_output[:500] + "..." if len(raw_output) > 500 else raw_output
|
|
||||||
)
|
|
||||||
logger.debug(
|
|
||||||
"Raw output (truncated)%s: %s",
|
|
||||||
agent_context,
|
|
||||||
truncated_output.replace("\n", "\\n"),
|
|
||||||
)
|
|
||||||
|
|
||||||
logger.debug("Retry %d initiated%s", retry_count, agent_context)
|
|
||||||
|
|
||||||
if iterations > log_error_after and printer:
|
if iterations > log_error_after and printer:
|
||||||
printer.print(
|
printer.print(
|
||||||
content=f"Error parsing LLM output, agent will retry: {e.error}",
|
content=f"Error parsing LLM output, agent will retry: {e.error}",
|
||||||
|
|||||||
@@ -26,9 +26,13 @@ def mock_agent() -> MagicMock:
|
|||||||
|
|
||||||
|
|
||||||
@pytest.fixture
|
@pytest.fixture
|
||||||
def mock_task() -> MagicMock:
|
def mock_task(mock_context: MagicMock) -> MagicMock:
|
||||||
"""Create a mock Task."""
|
"""Create a mock Task."""
|
||||||
return MagicMock()
|
task = MagicMock()
|
||||||
|
task.id = mock_context.task_id
|
||||||
|
task.name = "Mock Task"
|
||||||
|
task.description = "Mock task description"
|
||||||
|
return task
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture
|
@pytest.fixture
|
||||||
@@ -179,8 +183,8 @@ class TestExecute:
|
|||||||
event = first_call[0][1]
|
event = first_call[0][1]
|
||||||
|
|
||||||
assert event.type == "a2a_server_task_started"
|
assert event.type == "a2a_server_task_started"
|
||||||
assert event.a2a_task_id == mock_context.task_id
|
assert event.task_id == mock_context.task_id
|
||||||
assert event.a2a_context_id == mock_context.context_id
|
assert event.context_id == mock_context.context_id
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_emits_completed_event(
|
async def test_emits_completed_event(
|
||||||
@@ -201,7 +205,7 @@ class TestExecute:
|
|||||||
event = second_call[0][1]
|
event = second_call[0][1]
|
||||||
|
|
||||||
assert event.type == "a2a_server_task_completed"
|
assert event.type == "a2a_server_task_completed"
|
||||||
assert event.a2a_task_id == mock_context.task_id
|
assert event.task_id == mock_context.task_id
|
||||||
assert event.result == "Task completed successfully"
|
assert event.result == "Task completed successfully"
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
@@ -250,7 +254,7 @@ class TestExecute:
|
|||||||
event = canceled_call[0][1]
|
event = canceled_call[0][1]
|
||||||
|
|
||||||
assert event.type == "a2a_server_task_canceled"
|
assert event.type == "a2a_server_task_canceled"
|
||||||
assert event.a2a_task_id == mock_context.task_id
|
assert event.task_id == mock_context.task_id
|
||||||
|
|
||||||
|
|
||||||
class TestCancel:
|
class TestCancel:
|
||||||
|
|||||||
@@ -14,6 +14,16 @@ except ImportError:
|
|||||||
A2A_SDK_INSTALLED = False
|
A2A_SDK_INSTALLED = False
|
||||||
|
|
||||||
|
|
||||||
|
def _create_mock_agent_card(name: str = "Test", url: str = "http://test-endpoint.com/"):
|
||||||
|
"""Create a mock agent card with proper model_dump behavior."""
|
||||||
|
mock_card = MagicMock()
|
||||||
|
mock_card.name = name
|
||||||
|
mock_card.url = url
|
||||||
|
mock_card.model_dump.return_value = {"name": name, "url": url}
|
||||||
|
mock_card.model_dump_json.return_value = f'{{"name": "{name}", "url": "{url}"}}'
|
||||||
|
return mock_card
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.skipif(not A2A_SDK_INSTALLED, reason="Requires a2a-sdk to be installed")
|
@pytest.mark.skipif(not A2A_SDK_INSTALLED, reason="Requires a2a-sdk to be installed")
|
||||||
def test_trust_remote_completion_status_true_returns_directly():
|
def test_trust_remote_completion_status_true_returns_directly():
|
||||||
"""When trust_remote_completion_status=True and A2A returns completed, return result directly."""
|
"""When trust_remote_completion_status=True and A2A returns completed, return result directly."""
|
||||||
@@ -44,8 +54,7 @@ def test_trust_remote_completion_status_true_returns_directly():
|
|||||||
patch("crewai.a2a.wrapper.execute_a2a_delegation") as mock_execute,
|
patch("crewai.a2a.wrapper.execute_a2a_delegation") as mock_execute,
|
||||||
patch("crewai.a2a.wrapper._fetch_agent_cards_concurrently") as mock_fetch,
|
patch("crewai.a2a.wrapper._fetch_agent_cards_concurrently") as mock_fetch,
|
||||||
):
|
):
|
||||||
mock_card = MagicMock()
|
mock_card = _create_mock_agent_card()
|
||||||
mock_card.name = "Test"
|
|
||||||
mock_fetch.return_value = ({"http://test-endpoint.com/": mock_card}, {})
|
mock_fetch.return_value = ({"http://test-endpoint.com/": mock_card}, {})
|
||||||
|
|
||||||
# A2A returns completed
|
# A2A returns completed
|
||||||
@@ -110,8 +119,7 @@ def test_trust_remote_completion_status_false_continues_conversation():
|
|||||||
patch("crewai.a2a.wrapper.execute_a2a_delegation") as mock_execute,
|
patch("crewai.a2a.wrapper.execute_a2a_delegation") as mock_execute,
|
||||||
patch("crewai.a2a.wrapper._fetch_agent_cards_concurrently") as mock_fetch,
|
patch("crewai.a2a.wrapper._fetch_agent_cards_concurrently") as mock_fetch,
|
||||||
):
|
):
|
||||||
mock_card = MagicMock()
|
mock_card = _create_mock_agent_card()
|
||||||
mock_card.name = "Test"
|
|
||||||
mock_fetch.return_value = ({"http://test-endpoint.com/": mock_card}, {})
|
mock_fetch.return_value = ({"http://test-endpoint.com/": mock_card}, {})
|
||||||
|
|
||||||
# A2A returns completed
|
# A2A returns completed
|
||||||
|
|||||||
@@ -1,240 +0,0 @@
|
|||||||
"""Tests for agent_utils module, specifically debug logging for OutputParserError."""
|
|
||||||
|
|
||||||
import logging
|
|
||||||
from unittest.mock import MagicMock
|
|
||||||
|
|
||||||
import pytest
|
|
||||||
|
|
||||||
from crewai.agents.parser import AgentAction, OutputParserError
|
|
||||||
from crewai.utilities.agent_utils import handle_output_parser_exception
|
|
||||||
|
|
||||||
|
|
||||||
class TestHandleOutputParserExceptionDebugLogging:
    """Tests for debug logging in handle_output_parser_exception."""

    def test_debug_logging_with_raw_output_and_agent_role(self, caplog: pytest.LogCaptureFixture) -> None:
        """Test that debug logging includes raw output and agent role when provided."""
        error = OutputParserError("Invalid Format: I missed the 'Action:' after 'Thought:'.")
        messages: list[dict[str, str]] = []
        raw_output = "Let me think about this... The answer is..."
        agent_role = "Researcher"

        with caplog.at_level(logging.DEBUG):
            result = handle_output_parser_exception(
                e=error,
                messages=messages,
                iterations=0,
                raw_output=raw_output,
                agent_role=agent_role,
            )

        assert isinstance(result, AgentAction)
        assert "Parse failed for agent 'Researcher'" in caplog.text
        assert "Raw output (truncated) for agent 'Researcher'" in caplog.text
        assert "Let me think about this... The answer is..." in caplog.text
        assert "Retry 1 initiated for agent 'Researcher'" in caplog.text

    def test_debug_logging_without_agent_role(self, caplog: pytest.LogCaptureFixture) -> None:
        """Test that debug logging works without agent role."""
        error = OutputParserError("Invalid Format: I missed the 'Action:' after 'Thought:'.")
        messages: list[dict[str, str]] = []
        raw_output = "Some raw output"

        with caplog.at_level(logging.DEBUG):
            result = handle_output_parser_exception(
                e=error,
                messages=messages,
                iterations=0,
                raw_output=raw_output,
            )

        assert isinstance(result, AgentAction)
        assert "Parse failed:" in caplog.text
        # When no role is supplied, the "for agent" suffix must be absent from
        # the first log line that follows "Parse failed:".
        assert "for agent" not in caplog.text.split("Parse failed:")[1].split("\n")[0]
        assert "Raw output (truncated):" in caplog.text
        assert "Retry 1 initiated" in caplog.text

    def test_debug_logging_without_raw_output(self, caplog: pytest.LogCaptureFixture) -> None:
        """Test that debug logging works without raw output."""
        error = OutputParserError("Invalid Format: I missed the 'Action:' after 'Thought:'.")
        messages: list[dict[str, str]] = []

        with caplog.at_level(logging.DEBUG):
            result = handle_output_parser_exception(
                e=error,
                messages=messages,
                iterations=0,
                agent_role="Researcher",
            )

        assert isinstance(result, AgentAction)
        assert "Parse failed for agent 'Researcher'" in caplog.text
        # No raw output was given, so the truncation log line must not appear.
        assert "Raw output (truncated)" not in caplog.text
        assert "Retry 1 initiated for agent 'Researcher'" in caplog.text

    def test_debug_logging_truncates_long_raw_output(self, caplog: pytest.LogCaptureFixture) -> None:
        """Test that raw output is truncated when longer than 500 characters."""
        error = OutputParserError("Invalid Format")
        messages: list[dict[str, str]] = []
        long_output = "A" * 600

        with caplog.at_level(logging.DEBUG):
            handle_output_parser_exception(
                e=error,
                messages=messages,
                iterations=0,
                raw_output=long_output,
                agent_role="Researcher",
            )

        # 600-char input must be cut at 500 chars with an ellipsis appended.
        assert "A" * 500 + "..." in caplog.text
        assert "A" * 600 not in caplog.text

    def test_debug_logging_does_not_truncate_short_raw_output(self, caplog: pytest.LogCaptureFixture) -> None:
        """Test that short raw output is not truncated."""
        error = OutputParserError("Invalid Format")
        messages: list[dict[str, str]] = []
        short_output = "Short output"

        with caplog.at_level(logging.DEBUG):
            handle_output_parser_exception(
                e=error,
                messages=messages,
                iterations=0,
                raw_output=short_output,
                agent_role="Researcher",
            )

        assert "Short output" in caplog.text
        # No ellipsis on the same log line as the short output.
        assert "..." not in caplog.text.split("Short output")[1].split("\n")[0]

    def test_debug_logging_retry_count_increments(self, caplog: pytest.LogCaptureFixture) -> None:
        """Test that retry count is correctly calculated from iterations."""
        error = OutputParserError("Invalid Format")
        messages: list[dict[str, str]] = []

        with caplog.at_level(logging.DEBUG):
            handle_output_parser_exception(
                e=error,
                messages=messages,
                iterations=4,
                raw_output="test",
                agent_role="Researcher",
            )

        # Retry number is iterations + 1.
        assert "Retry 5 initiated" in caplog.text

    def test_debug_logging_escapes_newlines_in_raw_output(self, caplog: pytest.LogCaptureFixture) -> None:
        """Test that newlines in raw output are escaped for readability."""
        error = OutputParserError("Invalid Format")
        messages: list[dict[str, str]] = []
        output_with_newlines = "Line 1\nLine 2\nLine 3"

        with caplog.at_level(logging.DEBUG):
            handle_output_parser_exception(
                e=error,
                messages=messages,
                iterations=0,
                raw_output=output_with_newlines,
                agent_role="Researcher",
            )

        # Literal backslash-n sequences should appear in the log text.
        assert "Line 1\\nLine 2\\nLine 3" in caplog.text

    def test_debug_logging_extracts_first_line_of_error(self, caplog: pytest.LogCaptureFixture) -> None:
        """Test that only the first line of the error message is logged."""
        error = OutputParserError("First line of error\nSecond line\nThird line")
        messages: list[dict[str, str]] = []

        with caplog.at_level(logging.DEBUG):
            handle_output_parser_exception(
                e=error,
                messages=messages,
                iterations=0,
                agent_role="Researcher",
            )

        assert "First line of error" in caplog.text
        parse_failed_line = [line for line in caplog.text.split("\n") if "Parse failed" in line][0]
        assert "Second line" not in parse_failed_line

    def test_messages_updated_with_error(self) -> None:
        """Test that messages list is updated with the error."""
        error = OutputParserError("Test error message")
        messages: list[dict[str, str]] = []

        handle_output_parser_exception(
            e=error,
            messages=messages,
            iterations=0,
        )

        assert len(messages) == 1
        assert messages[0]["role"] == "user"
        assert messages[0]["content"] == "Test error message"

    def test_returns_agent_action_with_error_text(self) -> None:
        """Test that the function returns an AgentAction with the error text."""
        error = OutputParserError("Test error message")
        messages: list[dict[str, str]] = []

        result = handle_output_parser_exception(
            e=error,
            messages=messages,
            iterations=0,
        )

        assert isinstance(result, AgentAction)
        assert result.text == "Test error message"
        assert result.tool == ""
        assert result.tool_input == ""
        assert result.thought == ""

    def test_printer_logs_after_log_error_after_iterations(self) -> None:
        """Test that printer logs error after log_error_after iterations."""
        error = OutputParserError("Test error")
        messages: list[dict[str, str]] = []
        printer = MagicMock()

        handle_output_parser_exception(
            e=error,
            messages=messages,
            iterations=4,
            log_error_after=3,
            printer=printer,
        )

        printer.print.assert_called_once()
        call_args = printer.print.call_args
        assert "Error parsing LLM output" in call_args.kwargs["content"]
        assert call_args.kwargs["color"] == "red"

    def test_printer_does_not_log_before_log_error_after_iterations(self) -> None:
        """Test that printer does not log before log_error_after iterations."""
        error = OutputParserError("Test error")
        messages: list[dict[str, str]] = []
        printer = MagicMock()

        handle_output_parser_exception(
            e=error,
            messages=messages,
            iterations=2,
            log_error_after=3,
            printer=printer,
        )

        printer.print.assert_not_called()

    def test_backward_compatibility_without_new_parameters(self) -> None:
        """Test that the function works without the new optional parameters."""
        error = OutputParserError("Test error")
        messages: list[dict[str, str]] = []

        result = handle_output_parser_exception(
            e=error,
            messages=messages,
            iterations=0,
        )

        assert isinstance(result, AgentAction)
        assert len(messages) == 1
Reference in New Issue
Block a user