mirror of
https://github.com/crewAIInc/crewAI.git
synced 2026-01-19 12:58:14 +00:00
Compare commits
2 Commits
lorenze/im
...
devin/1768
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
ca1f8fd7e0 | ||
|
|
ceef062426 |
@@ -89,6 +89,9 @@ from crewai_tools.tools.jina_scrape_website_tool.jina_scrape_website_tool import
|
||||
from crewai_tools.tools.json_search_tool.json_search_tool import JSONSearchTool
|
||||
from crewai_tools.tools.linkup.linkup_search_tool import LinkupSearchTool
|
||||
from crewai_tools.tools.llamaindex_tool.llamaindex_tool import LlamaIndexTool
|
||||
from crewai_tools.tools.mcp_discovery_tool.mcp_discovery_tool import (
|
||||
MCPDiscoveryTool,
|
||||
)
|
||||
from crewai_tools.tools.mdx_search_tool.mdx_search_tool import MDXSearchTool
|
||||
from crewai_tools.tools.merge_agent_handler_tool.merge_agent_handler_tool import (
|
||||
MergeAgentHandlerTool,
|
||||
@@ -236,6 +239,7 @@ __all__ = [
|
||||
"JinaScrapeWebsiteTool",
|
||||
"LinkupSearchTool",
|
||||
"LlamaIndexTool",
|
||||
"MCPDiscoveryTool",
|
||||
"MCPServerAdapter",
|
||||
"MDXSearchTool",
|
||||
"MergeAgentHandlerTool",
|
||||
|
||||
@@ -0,0 +1,16 @@
|
||||
from crewai_tools.tools.mcp_discovery_tool.mcp_discovery_tool import (
|
||||
MCPDiscoveryResult,
|
||||
MCPDiscoveryTool,
|
||||
MCPDiscoveryToolSchema,
|
||||
MCPServerMetrics,
|
||||
MCPServerRecommendation,
|
||||
)
|
||||
|
||||
|
||||
__all__ = [
|
||||
"MCPDiscoveryResult",
|
||||
"MCPDiscoveryTool",
|
||||
"MCPDiscoveryToolSchema",
|
||||
"MCPServerMetrics",
|
||||
"MCPServerRecommendation",
|
||||
]
|
||||
@@ -0,0 +1,414 @@
|
||||
"""MCP Discovery Tool for CrewAI agents.
|
||||
|
||||
This tool enables agents to dynamically discover MCP servers based on
|
||||
natural language queries using the MCP Discovery API.
|
||||
"""
|
||||
|
||||
import json
|
||||
import logging
|
||||
import os
|
||||
from typing import Any, TypedDict
|
||||
|
||||
from crewai.tools import BaseTool, EnvVar
|
||||
from pydantic import BaseModel, Field
|
||||
import requests
|
||||
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class MCPServerMetrics(TypedDict, total=False):
|
||||
"""Performance metrics for an MCP server."""
|
||||
|
||||
avg_latency_ms: float | None
|
||||
uptime_pct: float | None
|
||||
last_checked: str | None
|
||||
|
||||
|
||||
class MCPServerRecommendation(TypedDict, total=False):
|
||||
"""A recommended MCP server from the discovery API."""
|
||||
|
||||
server: str
|
||||
npm_package: str
|
||||
install_command: str
|
||||
confidence: float
|
||||
description: str
|
||||
capabilities: list[str]
|
||||
metrics: MCPServerMetrics
|
||||
docs_url: str
|
||||
github_url: str
|
||||
|
||||
|
||||
class MCPDiscoveryResult(TypedDict):
|
||||
"""Result from the MCP Discovery API."""
|
||||
|
||||
recommendations: list[MCPServerRecommendation]
|
||||
total_found: int
|
||||
query_time_ms: int
|
||||
|
||||
|
||||
class MCPDiscoveryConstraints(BaseModel):
|
||||
"""Constraints for MCP server discovery."""
|
||||
|
||||
max_latency_ms: int | None = Field(
|
||||
default=None,
|
||||
description="Maximum acceptable latency in milliseconds",
|
||||
)
|
||||
required_features: list[str] | None = Field(
|
||||
default=None,
|
||||
description="List of required features/capabilities",
|
||||
)
|
||||
exclude_servers: list[str] | None = Field(
|
||||
default=None,
|
||||
description="List of server names to exclude from results",
|
||||
)
|
||||
|
||||
|
||||
class MCPDiscoveryToolSchema(BaseModel):
|
||||
"""Input schema for MCPDiscoveryTool."""
|
||||
|
||||
need: str = Field(
|
||||
...,
|
||||
description=(
|
||||
"Natural language description of what you need. "
|
||||
"For example: 'database with authentication', 'email automation', "
|
||||
"'file storage', 'web scraping'"
|
||||
),
|
||||
)
|
||||
constraints: MCPDiscoveryConstraints | None = Field(
|
||||
default=None,
|
||||
description="Optional constraints to filter results",
|
||||
)
|
||||
limit: int = Field(
|
||||
default=5,
|
||||
description="Maximum number of recommendations to return (1-10)",
|
||||
ge=1,
|
||||
le=10,
|
||||
)
|
||||
|
||||
|
||||
class MCPDiscoveryTool(BaseTool):
|
||||
"""Tool for discovering MCP servers dynamically.
|
||||
|
||||
This tool uses the MCP Discovery API to find MCP servers that match
|
||||
a natural language description of what the agent needs. It enables
|
||||
agents to dynamically discover and select the best MCP servers for
|
||||
their tasks without requiring pre-configuration.
|
||||
|
||||
Example:
|
||||
```python
|
||||
from crewai import Agent
|
||||
from crewai_tools import MCPDiscoveryTool
|
||||
|
||||
discovery_tool = MCPDiscoveryTool()
|
||||
|
||||
agent = Agent(
|
||||
role='Researcher',
|
||||
tools=[discovery_tool],
|
||||
goal='Research and analyze data'
|
||||
)
|
||||
|
||||
# The agent can now discover MCP servers dynamically:
|
||||
# discover_mcp_server(need="database with authentication")
|
||||
# Returns: Supabase MCP server with installation instructions
|
||||
```
|
||||
|
||||
Attributes:
|
||||
name: The name of the tool.
|
||||
description: A description of what the tool does.
|
||||
args_schema: The Pydantic model for input validation.
|
||||
base_url: The base URL for the MCP Discovery API.
|
||||
timeout: Request timeout in seconds.
|
||||
"""
|
||||
|
||||
name: str = "Discover MCP Server"
|
||||
description: str = (
|
||||
"Discover MCP (Model Context Protocol) servers that match your needs. "
|
||||
"Use this tool to find the best MCP server for any task using natural "
|
||||
"language. Returns server recommendations with installation instructions, "
|
||||
"capabilities, and performance metrics."
|
||||
)
|
||||
args_schema: type[BaseModel] = MCPDiscoveryToolSchema
|
||||
base_url: str = "https://mcp-discovery-production.up.railway.app"
|
||||
timeout: int = 30
|
||||
env_vars: list[EnvVar] = Field(
|
||||
default_factory=lambda: [
|
||||
EnvVar(
|
||||
name="MCP_DISCOVERY_API_KEY",
|
||||
description="API key for MCP Discovery (optional for free tier)",
|
||||
required=False,
|
||||
),
|
||||
]
|
||||
)
|
||||
|
||||
def _build_request_payload(
|
||||
self,
|
||||
need: str,
|
||||
constraints: MCPDiscoveryConstraints | None,
|
||||
limit: int,
|
||||
) -> dict[str, Any]:
|
||||
"""Build the request payload for the discovery API.
|
||||
|
||||
Args:
|
||||
need: Natural language description of what is needed.
|
||||
constraints: Optional constraints to filter results.
|
||||
limit: Maximum number of recommendations.
|
||||
|
||||
Returns:
|
||||
Dictionary containing the request payload.
|
||||
"""
|
||||
payload: dict[str, Any] = {
|
||||
"need": need,
|
||||
"limit": limit,
|
||||
}
|
||||
|
||||
if constraints:
|
||||
constraints_dict: dict[str, Any] = {}
|
||||
if constraints.max_latency_ms is not None:
|
||||
constraints_dict["max_latency_ms"] = constraints.max_latency_ms
|
||||
if constraints.required_features:
|
||||
constraints_dict["required_features"] = constraints.required_features
|
||||
if constraints.exclude_servers:
|
||||
constraints_dict["exclude_servers"] = constraints.exclude_servers
|
||||
if constraints_dict:
|
||||
payload["constraints"] = constraints_dict
|
||||
|
||||
return payload
|
||||
|
||||
def _make_api_request(self, payload: dict[str, Any]) -> dict[str, Any]:
|
||||
"""Make a request to the MCP Discovery API.
|
||||
|
||||
Args:
|
||||
payload: The request payload.
|
||||
|
||||
Returns:
|
||||
The API response as a dictionary.
|
||||
|
||||
Raises:
|
||||
ValueError: If the API returns an empty response.
|
||||
requests.exceptions.RequestException: If the request fails.
|
||||
"""
|
||||
url = f"{self.base_url}/api/v1/discover"
|
||||
headers = {
|
||||
"Content-Type": "application/json",
|
||||
}
|
||||
|
||||
api_key = os.environ.get("MCP_DISCOVERY_API_KEY")
|
||||
if api_key:
|
||||
headers["Authorization"] = f"Bearer {api_key}"
|
||||
|
||||
response = None
|
||||
try:
|
||||
response = requests.post(
|
||||
url,
|
||||
headers=headers,
|
||||
json=payload,
|
||||
timeout=self.timeout,
|
||||
)
|
||||
response.raise_for_status()
|
||||
results = response.json()
|
||||
if not results:
|
||||
logger.error("Empty response from MCP Discovery API")
|
||||
raise ValueError("Empty response from MCP Discovery API")
|
||||
return results
|
||||
except requests.exceptions.RequestException as e:
|
||||
error_msg = f"Error making request to MCP Discovery API: {e}"
|
||||
if response is not None and hasattr(response, "content"):
|
||||
error_msg += (
|
||||
f"\nResponse content: "
|
||||
f"{response.content.decode('utf-8', errors='replace')}"
|
||||
)
|
||||
logger.error(error_msg)
|
||||
raise
|
||||
except json.JSONDecodeError as e:
|
||||
if response is not None and hasattr(response, "content"):
|
||||
logger.error(f"Error decoding JSON response: {e}")
|
||||
logger.error(
|
||||
f"Response content: "
|
||||
f"{response.content.decode('utf-8', errors='replace')}"
|
||||
)
|
||||
else:
|
||||
logger.error(
|
||||
f"Error decoding JSON response: {e} (No response content available)"
|
||||
)
|
||||
raise
|
||||
|
||||
def _process_single_recommendation(
|
||||
self, rec: dict[str, Any]
|
||||
) -> MCPServerRecommendation | None:
|
||||
"""Process a single recommendation from the API.
|
||||
|
||||
Args:
|
||||
rec: Raw recommendation dictionary from the API.
|
||||
|
||||
Returns:
|
||||
Processed MCPServerRecommendation or None if malformed.
|
||||
"""
|
||||
try:
|
||||
metrics_data = rec.get("metrics", {}) if isinstance(rec, dict) else {}
|
||||
metrics: MCPServerMetrics = {
|
||||
"avg_latency_ms": metrics_data.get("avg_latency_ms"),
|
||||
"uptime_pct": metrics_data.get("uptime_pct"),
|
||||
"last_checked": metrics_data.get("last_checked"),
|
||||
}
|
||||
|
||||
recommendation: MCPServerRecommendation = {
|
||||
"server": rec.get("server", ""),
|
||||
"npm_package": rec.get("npm_package", ""),
|
||||
"install_command": rec.get("install_command", ""),
|
||||
"confidence": rec.get("confidence", 0.0),
|
||||
"description": rec.get("description", ""),
|
||||
"capabilities": rec.get("capabilities", []),
|
||||
"metrics": metrics,
|
||||
"docs_url": rec.get("docs_url", ""),
|
||||
"github_url": rec.get("github_url", ""),
|
||||
}
|
||||
return recommendation
|
||||
except (KeyError, TypeError, AttributeError) as e:
|
||||
logger.warning(f"Skipping malformed recommendation: {rec}, error: {e}")
|
||||
return None
|
||||
|
||||
def _process_recommendations(
|
||||
self, recommendations: list[dict[str, Any]]
|
||||
) -> list[MCPServerRecommendation]:
|
||||
"""Process and validate server recommendations.
|
||||
|
||||
Args:
|
||||
recommendations: Raw recommendations from the API.
|
||||
|
||||
Returns:
|
||||
List of processed MCPServerRecommendation objects.
|
||||
"""
|
||||
processed: list[MCPServerRecommendation] = []
|
||||
for rec in recommendations:
|
||||
result = self._process_single_recommendation(rec)
|
||||
if result is not None:
|
||||
processed.append(result)
|
||||
return processed
|
||||
|
||||
def _format_result(self, result: MCPDiscoveryResult) -> str:
|
||||
"""Format the discovery result as a human-readable string.
|
||||
|
||||
Args:
|
||||
result: The discovery result to format.
|
||||
|
||||
Returns:
|
||||
A formatted string representation of the result.
|
||||
"""
|
||||
if not result["recommendations"]:
|
||||
return "No MCP servers found matching your requirements."
|
||||
|
||||
lines = [
|
||||
f"Found {result['total_found']} MCP server(s) "
|
||||
f"(query took {result['query_time_ms']}ms):\n"
|
||||
]
|
||||
|
||||
for i, rec in enumerate(result["recommendations"], 1):
|
||||
confidence_pct = rec.get("confidence", 0) * 100
|
||||
lines.append(f"{i}. {rec.get('server', 'Unknown')} ({confidence_pct:.0f}% confidence)")
|
||||
lines.append(f" Description: {rec.get('description', 'N/A')}")
|
||||
lines.append(f" Capabilities: {', '.join(rec.get('capabilities', []))}")
|
||||
lines.append(f" Install: {rec.get('install_command', 'N/A')}")
|
||||
lines.append(f" NPM Package: {rec.get('npm_package', 'N/A')}")
|
||||
|
||||
metrics = rec.get("metrics", {})
|
||||
if metrics.get("avg_latency_ms") is not None:
|
||||
lines.append(f" Avg Latency: {metrics['avg_latency_ms']}ms")
|
||||
if metrics.get("uptime_pct") is not None:
|
||||
lines.append(f" Uptime: {metrics['uptime_pct']}%")
|
||||
|
||||
if rec.get("docs_url"):
|
||||
lines.append(f" Docs: {rec['docs_url']}")
|
||||
if rec.get("github_url"):
|
||||
lines.append(f" GitHub: {rec['github_url']}")
|
||||
lines.append("")
|
||||
|
||||
return "\n".join(lines)
|
||||
|
||||
def _run(self, **kwargs: Any) -> str:
|
||||
"""Execute the MCP discovery operation.
|
||||
|
||||
Args:
|
||||
**kwargs: Keyword arguments matching MCPDiscoveryToolSchema.
|
||||
|
||||
Returns:
|
||||
A formatted string with discovered MCP servers.
|
||||
|
||||
Raises:
|
||||
ValueError: If required parameters are missing.
|
||||
"""
|
||||
need: str | None = kwargs.get("need")
|
||||
if not need:
|
||||
raise ValueError("'need' parameter is required")
|
||||
|
||||
constraints_data = kwargs.get("constraints")
|
||||
constraints: MCPDiscoveryConstraints | None = None
|
||||
if constraints_data:
|
||||
if isinstance(constraints_data, dict):
|
||||
constraints = MCPDiscoveryConstraints(**constraints_data)
|
||||
elif isinstance(constraints_data, MCPDiscoveryConstraints):
|
||||
constraints = constraints_data
|
||||
|
||||
limit: int = kwargs.get("limit", 5)
|
||||
|
||||
payload = self._build_request_payload(need, constraints, limit)
|
||||
response = self._make_api_request(payload)
|
||||
|
||||
recommendations = self._process_recommendations(
|
||||
response.get("recommendations", [])
|
||||
)
|
||||
|
||||
result: MCPDiscoveryResult = {
|
||||
"recommendations": recommendations,
|
||||
"total_found": response.get("total_found", len(recommendations)),
|
||||
"query_time_ms": response.get("query_time_ms", 0),
|
||||
}
|
||||
|
||||
return self._format_result(result)
|
||||
|
||||
def discover(
|
||||
self,
|
||||
need: str,
|
||||
constraints: MCPDiscoveryConstraints | None = None,
|
||||
limit: int = 5,
|
||||
) -> MCPDiscoveryResult:
|
||||
"""Discover MCP servers matching the given requirements.
|
||||
|
||||
This is a convenience method that returns structured data instead
|
||||
of a formatted string.
|
||||
|
||||
Args:
|
||||
need: Natural language description of what is needed.
|
||||
constraints: Optional constraints to filter results.
|
||||
limit: Maximum number of recommendations (1-10).
|
||||
|
||||
Returns:
|
||||
MCPDiscoveryResult containing server recommendations.
|
||||
|
||||
Example:
|
||||
```python
|
||||
tool = MCPDiscoveryTool()
|
||||
result = tool.discover(
|
||||
need="database with authentication",
|
||||
constraints=MCPDiscoveryConstraints(
|
||||
max_latency_ms=200,
|
||||
required_features=["auth", "realtime"]
|
||||
),
|
||||
limit=3
|
||||
)
|
||||
for rec in result["recommendations"]:
|
||||
print(f"{rec['server']}: {rec['description']}")
|
||||
```
|
||||
"""
|
||||
payload = self._build_request_payload(need, constraints, limit)
|
||||
response = self._make_api_request(payload)
|
||||
|
||||
recommendations = self._process_recommendations(
|
||||
response.get("recommendations", [])
|
||||
)
|
||||
|
||||
return {
|
||||
"recommendations": recommendations,
|
||||
"total_found": response.get("total_found", len(recommendations)),
|
||||
"query_time_ms": response.get("query_time_ms", 0),
|
||||
}
|
||||
@@ -0,0 +1 @@
|
||||
|
||||
@@ -0,0 +1,452 @@
|
||||
"""Tests for the MCP Discovery Tool."""
|
||||
|
||||
import json
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
import pytest
|
||||
import requests
|
||||
|
||||
from crewai_tools.tools.mcp_discovery_tool.mcp_discovery_tool import (
|
||||
MCPDiscoveryConstraints,
|
||||
MCPDiscoveryResult,
|
||||
MCPDiscoveryTool,
|
||||
MCPDiscoveryToolSchema,
|
||||
MCPServerMetrics,
|
||||
MCPServerRecommendation,
|
||||
)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_api_response() -> dict:
|
||||
"""Create a mock API response for testing."""
|
||||
return {
|
||||
"recommendations": [
|
||||
{
|
||||
"server": "sqlite-server",
|
||||
"npm_package": "@modelcontextprotocol/server-sqlite",
|
||||
"install_command": "npx -y @modelcontextprotocol/server-sqlite",
|
||||
"confidence": 0.38,
|
||||
"description": "SQLite database server for MCP.",
|
||||
"capabilities": ["sqlite", "sql", "database", "embedded"],
|
||||
"metrics": {
|
||||
"avg_latency_ms": 50.0,
|
||||
"uptime_pct": 99.9,
|
||||
"last_checked": "2026-01-17T10:30:00Z",
|
||||
},
|
||||
"docs_url": "https://modelcontextprotocol.io/docs/servers/sqlite",
|
||||
"github_url": "https://github.com/modelcontextprotocol/servers",
|
||||
},
|
||||
{
|
||||
"server": "postgres-server",
|
||||
"npm_package": "@modelcontextprotocol/server-postgres",
|
||||
"install_command": "npx -y @modelcontextprotocol/server-postgres",
|
||||
"confidence": 0.33,
|
||||
"description": "PostgreSQL database server for MCP.",
|
||||
"capabilities": ["postgres", "sql", "database", "queries"],
|
||||
"metrics": {
|
||||
"avg_latency_ms": None,
|
||||
"uptime_pct": None,
|
||||
"last_checked": None,
|
||||
},
|
||||
"docs_url": "https://modelcontextprotocol.io/docs/servers/postgres",
|
||||
"github_url": "https://github.com/modelcontextprotocol/servers",
|
||||
},
|
||||
],
|
||||
"total_found": 2,
|
||||
"query_time_ms": 245,
|
||||
}
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def discovery_tool() -> MCPDiscoveryTool:
|
||||
"""Create an MCPDiscoveryTool instance for testing."""
|
||||
return MCPDiscoveryTool()
|
||||
|
||||
|
||||
class TestMCPDiscoveryToolSchema:
|
||||
"""Tests for MCPDiscoveryToolSchema."""
|
||||
|
||||
def test_schema_with_required_fields(self) -> None:
|
||||
"""Test schema with only required fields."""
|
||||
schema = MCPDiscoveryToolSchema(need="database server")
|
||||
assert schema.need == "database server"
|
||||
assert schema.constraints is None
|
||||
assert schema.limit == 5
|
||||
|
||||
def test_schema_with_all_fields(self) -> None:
|
||||
"""Test schema with all fields."""
|
||||
constraints = MCPDiscoveryConstraints(
|
||||
max_latency_ms=200,
|
||||
required_features=["auth", "realtime"],
|
||||
exclude_servers=["deprecated-server"],
|
||||
)
|
||||
schema = MCPDiscoveryToolSchema(
|
||||
need="database with authentication",
|
||||
constraints=constraints,
|
||||
limit=3,
|
||||
)
|
||||
assert schema.need == "database with authentication"
|
||||
assert schema.constraints is not None
|
||||
assert schema.constraints.max_latency_ms == 200
|
||||
assert schema.constraints.required_features == ["auth", "realtime"]
|
||||
assert schema.constraints.exclude_servers == ["deprecated-server"]
|
||||
assert schema.limit == 3
|
||||
|
||||
def test_schema_limit_validation(self) -> None:
|
||||
"""Test that limit is validated to be between 1 and 10."""
|
||||
with pytest.raises(ValueError):
|
||||
MCPDiscoveryToolSchema(need="test", limit=0)
|
||||
|
||||
with pytest.raises(ValueError):
|
||||
MCPDiscoveryToolSchema(need="test", limit=11)
|
||||
|
||||
|
||||
class TestMCPDiscoveryConstraints:
|
||||
"""Tests for MCPDiscoveryConstraints."""
|
||||
|
||||
def test_empty_constraints(self) -> None:
|
||||
"""Test creating empty constraints."""
|
||||
constraints = MCPDiscoveryConstraints()
|
||||
assert constraints.max_latency_ms is None
|
||||
assert constraints.required_features is None
|
||||
assert constraints.exclude_servers is None
|
||||
|
||||
def test_full_constraints(self) -> None:
|
||||
"""Test creating constraints with all fields."""
|
||||
constraints = MCPDiscoveryConstraints(
|
||||
max_latency_ms=100,
|
||||
required_features=["feature1", "feature2"],
|
||||
exclude_servers=["server1", "server2"],
|
||||
)
|
||||
assert constraints.max_latency_ms == 100
|
||||
assert constraints.required_features == ["feature1", "feature2"]
|
||||
assert constraints.exclude_servers == ["server1", "server2"]
|
||||
|
||||
|
||||
class TestMCPDiscoveryTool:
|
||||
"""Tests for MCPDiscoveryTool."""
|
||||
|
||||
def test_tool_initialization(self, discovery_tool: MCPDiscoveryTool) -> None:
|
||||
"""Test tool initialization with default values."""
|
||||
assert discovery_tool.name == "Discover MCP Server"
|
||||
assert "MCP" in discovery_tool.description
|
||||
assert discovery_tool.base_url == "https://mcp-discovery-production.up.railway.app"
|
||||
assert discovery_tool.timeout == 30
|
||||
|
||||
def test_build_request_payload_basic(
|
||||
self, discovery_tool: MCPDiscoveryTool
|
||||
) -> None:
|
||||
"""Test building request payload with basic parameters."""
|
||||
payload = discovery_tool._build_request_payload(
|
||||
need="database server",
|
||||
constraints=None,
|
||||
limit=5,
|
||||
)
|
||||
assert payload == {"need": "database server", "limit": 5}
|
||||
|
||||
def test_build_request_payload_with_constraints(
|
||||
self, discovery_tool: MCPDiscoveryTool
|
||||
) -> None:
|
||||
"""Test building request payload with constraints."""
|
||||
constraints = MCPDiscoveryConstraints(
|
||||
max_latency_ms=200,
|
||||
required_features=["auth"],
|
||||
exclude_servers=["old-server"],
|
||||
)
|
||||
payload = discovery_tool._build_request_payload(
|
||||
need="database",
|
||||
constraints=constraints,
|
||||
limit=3,
|
||||
)
|
||||
assert payload["need"] == "database"
|
||||
assert payload["limit"] == 3
|
||||
assert "constraints" in payload
|
||||
assert payload["constraints"]["max_latency_ms"] == 200
|
||||
assert payload["constraints"]["required_features"] == ["auth"]
|
||||
assert payload["constraints"]["exclude_servers"] == ["old-server"]
|
||||
|
||||
def test_build_request_payload_partial_constraints(
|
||||
self, discovery_tool: MCPDiscoveryTool
|
||||
) -> None:
|
||||
"""Test building request payload with partial constraints."""
|
||||
constraints = MCPDiscoveryConstraints(max_latency_ms=100)
|
||||
payload = discovery_tool._build_request_payload(
|
||||
need="test",
|
||||
constraints=constraints,
|
||||
limit=5,
|
||||
)
|
||||
assert payload["constraints"] == {"max_latency_ms": 100}
|
||||
|
||||
def test_process_recommendations(
|
||||
self, discovery_tool: MCPDiscoveryTool, mock_api_response: dict
|
||||
) -> None:
|
||||
"""Test processing recommendations from API response."""
|
||||
recommendations = discovery_tool._process_recommendations(
|
||||
mock_api_response["recommendations"]
|
||||
)
|
||||
assert len(recommendations) == 2
|
||||
|
||||
first_rec = recommendations[0]
|
||||
assert first_rec["server"] == "sqlite-server"
|
||||
assert first_rec["npm_package"] == "@modelcontextprotocol/server-sqlite"
|
||||
assert first_rec["confidence"] == 0.38
|
||||
assert first_rec["capabilities"] == ["sqlite", "sql", "database", "embedded"]
|
||||
assert first_rec["metrics"]["avg_latency_ms"] == 50.0
|
||||
assert first_rec["metrics"]["uptime_pct"] == 99.9
|
||||
|
||||
second_rec = recommendations[1]
|
||||
assert second_rec["server"] == "postgres-server"
|
||||
assert second_rec["metrics"]["avg_latency_ms"] is None
|
||||
|
||||
def test_process_recommendations_with_malformed_data(
|
||||
self, discovery_tool: MCPDiscoveryTool
|
||||
) -> None:
|
||||
"""Test processing recommendations with malformed data."""
|
||||
malformed_recommendations = [
|
||||
{"server": "valid-server", "confidence": 0.5},
|
||||
None,
|
||||
{"invalid": "data"},
|
||||
]
|
||||
recommendations = discovery_tool._process_recommendations(
|
||||
malformed_recommendations
|
||||
)
|
||||
assert len(recommendations) >= 1
|
||||
assert recommendations[0]["server"] == "valid-server"
|
||||
|
||||
def test_format_result_with_recommendations(
|
||||
self, discovery_tool: MCPDiscoveryTool
|
||||
) -> None:
|
||||
"""Test formatting results with recommendations."""
|
||||
result: MCPDiscoveryResult = {
|
||||
"recommendations": [
|
||||
{
|
||||
"server": "test-server",
|
||||
"npm_package": "@test/server",
|
||||
"install_command": "npx -y @test/server",
|
||||
"confidence": 0.85,
|
||||
"description": "A test server",
|
||||
"capabilities": ["test", "demo"],
|
||||
"metrics": {
|
||||
"avg_latency_ms": 100.0,
|
||||
"uptime_pct": 99.5,
|
||||
"last_checked": "2026-01-17T10:00:00Z",
|
||||
},
|
||||
"docs_url": "https://example.com/docs",
|
||||
"github_url": "https://github.com/test/server",
|
||||
}
|
||||
],
|
||||
"total_found": 1,
|
||||
"query_time_ms": 150,
|
||||
}
|
||||
formatted = discovery_tool._format_result(result)
|
||||
assert "Found 1 MCP server(s)" in formatted
|
||||
assert "test-server" in formatted
|
||||
assert "85% confidence" in formatted
|
||||
assert "A test server" in formatted
|
||||
assert "test, demo" in formatted
|
||||
assert "npx -y @test/server" in formatted
|
||||
assert "100.0ms" in formatted
|
||||
assert "99.5%" in formatted
|
||||
|
||||
def test_format_result_empty(self, discovery_tool: MCPDiscoveryTool) -> None:
|
||||
"""Test formatting results with no recommendations."""
|
||||
result: MCPDiscoveryResult = {
|
||||
"recommendations": [],
|
||||
"total_found": 0,
|
||||
"query_time_ms": 50,
|
||||
}
|
||||
formatted = discovery_tool._format_result(result)
|
||||
assert "No MCP servers found" in formatted
|
||||
|
||||
@patch("crewai_tools.tools.mcp_discovery_tool.mcp_discovery_tool.requests.post")
|
||||
def test_make_api_request_success(
|
||||
self,
|
||||
mock_post: MagicMock,
|
||||
discovery_tool: MCPDiscoveryTool,
|
||||
mock_api_response: dict,
|
||||
) -> None:
|
||||
"""Test successful API request."""
|
||||
mock_response = MagicMock()
|
||||
mock_response.json.return_value = mock_api_response
|
||||
mock_response.raise_for_status.return_value = None
|
||||
mock_post.return_value = mock_response
|
||||
|
||||
result = discovery_tool._make_api_request({"need": "database", "limit": 5})
|
||||
|
||||
assert result == mock_api_response
|
||||
mock_post.assert_called_once()
|
||||
call_args = mock_post.call_args
|
||||
assert call_args[1]["json"] == {"need": "database", "limit": 5}
|
||||
assert call_args[1]["timeout"] == 30
|
||||
|
||||
@patch("crewai_tools.tools.mcp_discovery_tool.mcp_discovery_tool.requests.post")
|
||||
def test_make_api_request_with_api_key(
|
||||
self,
|
||||
mock_post: MagicMock,
|
||||
discovery_tool: MCPDiscoveryTool,
|
||||
mock_api_response: dict,
|
||||
) -> None:
|
||||
"""Test API request with API key."""
|
||||
mock_response = MagicMock()
|
||||
mock_response.json.return_value = mock_api_response
|
||||
mock_response.raise_for_status.return_value = None
|
||||
mock_post.return_value = mock_response
|
||||
|
||||
with patch.dict("os.environ", {"MCP_DISCOVERY_API_KEY": "test-key"}):
|
||||
discovery_tool._make_api_request({"need": "test", "limit": 5})
|
||||
|
||||
call_args = mock_post.call_args
|
||||
assert "Authorization" in call_args[1]["headers"]
|
||||
assert call_args[1]["headers"]["Authorization"] == "Bearer test-key"
|
||||
|
||||
@patch("crewai_tools.tools.mcp_discovery_tool.mcp_discovery_tool.requests.post")
|
||||
def test_make_api_request_empty_response(
|
||||
self, mock_post: MagicMock, discovery_tool: MCPDiscoveryTool
|
||||
) -> None:
|
||||
"""Test API request with empty response."""
|
||||
mock_response = MagicMock()
|
||||
mock_response.json.return_value = {}
|
||||
mock_response.raise_for_status.return_value = None
|
||||
mock_post.return_value = mock_response
|
||||
|
||||
with pytest.raises(ValueError, match="Empty response"):
|
||||
discovery_tool._make_api_request({"need": "test", "limit": 5})
|
||||
|
||||
@patch("crewai_tools.tools.mcp_discovery_tool.mcp_discovery_tool.requests.post")
|
||||
def test_make_api_request_network_error(
|
||||
self, mock_post: MagicMock, discovery_tool: MCPDiscoveryTool
|
||||
) -> None:
|
||||
"""Test API request with network error."""
|
||||
mock_post.side_effect = requests.exceptions.ConnectionError("Network error")
|
||||
|
||||
with pytest.raises(requests.exceptions.ConnectionError):
|
||||
discovery_tool._make_api_request({"need": "test", "limit": 5})
|
||||
|
||||
@patch("crewai_tools.tools.mcp_discovery_tool.mcp_discovery_tool.requests.post")
|
||||
def test_make_api_request_json_decode_error(
|
||||
self, mock_post: MagicMock, discovery_tool: MCPDiscoveryTool
|
||||
) -> None:
|
||||
"""Test API request with JSON decode error."""
|
||||
mock_response = MagicMock()
|
||||
mock_response.json.side_effect = json.JSONDecodeError("Error", "", 0)
|
||||
mock_response.raise_for_status.return_value = None
|
||||
mock_response.content = b"invalid json"
|
||||
mock_post.return_value = mock_response
|
||||
|
||||
with pytest.raises(json.JSONDecodeError):
|
||||
discovery_tool._make_api_request({"need": "test", "limit": 5})
|
||||
|
||||
@patch("crewai_tools.tools.mcp_discovery_tool.mcp_discovery_tool.requests.post")
|
||||
def test_run_success(
|
||||
self,
|
||||
mock_post: MagicMock,
|
||||
discovery_tool: MCPDiscoveryTool,
|
||||
mock_api_response: dict,
|
||||
) -> None:
|
||||
"""Test successful _run execution."""
|
||||
mock_response = MagicMock()
|
||||
mock_response.json.return_value = mock_api_response
|
||||
mock_response.raise_for_status.return_value = None
|
||||
mock_post.return_value = mock_response
|
||||
|
||||
result = discovery_tool._run(need="database server")
|
||||
|
||||
assert "sqlite-server" in result
|
||||
assert "postgres-server" in result
|
||||
assert "Found 2 MCP server(s)" in result
|
||||
|
||||
@patch("crewai_tools.tools.mcp_discovery_tool.mcp_discovery_tool.requests.post")
|
||||
def test_run_with_constraints(
|
||||
self,
|
||||
mock_post: MagicMock,
|
||||
discovery_tool: MCPDiscoveryTool,
|
||||
mock_api_response: dict,
|
||||
) -> None:
|
||||
"""Test _run with constraints."""
|
||||
mock_response = MagicMock()
|
||||
mock_response.json.return_value = mock_api_response
|
||||
mock_response.raise_for_status.return_value = None
|
||||
mock_post.return_value = mock_response
|
||||
|
||||
result = discovery_tool._run(
|
||||
need="database",
|
||||
constraints={"max_latency_ms": 100, "required_features": ["sql"]},
|
||||
limit=3,
|
||||
)
|
||||
|
||||
assert "sqlite-server" in result
|
||||
call_args = mock_post.call_args
|
||||
payload = call_args[1]["json"]
|
||||
assert payload["constraints"]["max_latency_ms"] == 100
|
||||
assert payload["constraints"]["required_features"] == ["sql"]
|
||||
assert payload["limit"] == 3
|
||||
|
||||
def test_run_missing_need_parameter(
|
||||
self, discovery_tool: MCPDiscoveryTool
|
||||
) -> None:
|
||||
"""Test _run with missing need parameter."""
|
||||
with pytest.raises(ValueError, match="'need' parameter is required"):
|
||||
discovery_tool._run()
|
||||
|
||||
@patch("crewai_tools.tools.mcp_discovery_tool.mcp_discovery_tool.requests.post")
|
||||
def test_discover_method(
|
||||
self,
|
||||
mock_post: MagicMock,
|
||||
discovery_tool: MCPDiscoveryTool,
|
||||
mock_api_response: dict,
|
||||
) -> None:
|
||||
"""Test the discover convenience method."""
|
||||
mock_response = MagicMock()
|
||||
mock_response.json.return_value = mock_api_response
|
||||
mock_response.raise_for_status.return_value = None
|
||||
mock_post.return_value = mock_response
|
||||
|
||||
result = discovery_tool.discover(
|
||||
need="database",
|
||||
constraints=MCPDiscoveryConstraints(max_latency_ms=200),
|
||||
limit=3,
|
||||
)
|
||||
|
||||
assert isinstance(result, dict)
|
||||
assert "recommendations" in result
|
||||
assert "total_found" in result
|
||||
assert "query_time_ms" in result
|
||||
assert len(result["recommendations"]) == 2
|
||||
assert result["total_found"] == 2
|
||||
|
||||
@patch("crewai_tools.tools.mcp_discovery_tool.mcp_discovery_tool.requests.post")
|
||||
def test_discover_returns_structured_data(
|
||||
self,
|
||||
mock_post: MagicMock,
|
||||
discovery_tool: MCPDiscoveryTool,
|
||||
mock_api_response: dict,
|
||||
) -> None:
|
||||
"""Test that discover returns properly structured data."""
|
||||
mock_response = MagicMock()
|
||||
mock_response.json.return_value = mock_api_response
|
||||
mock_response.raise_for_status.return_value = None
|
||||
mock_post.return_value = mock_response
|
||||
|
||||
result = discovery_tool.discover(need="database")
|
||||
|
||||
first_rec = result["recommendations"][0]
|
||||
assert "server" in first_rec
|
||||
assert "npm_package" in first_rec
|
||||
assert "install_command" in first_rec
|
||||
assert "confidence" in first_rec
|
||||
assert "description" in first_rec
|
||||
assert "capabilities" in first_rec
|
||||
assert "metrics" in first_rec
|
||||
assert "docs_url" in first_rec
|
||||
assert "github_url" in first_rec
|
||||
|
||||
|
||||
class TestMCPDiscoveryToolIntegration:
|
||||
"""Integration tests for MCPDiscoveryTool (requires network)."""
|
||||
|
||||
@pytest.mark.skip(reason="Integration test - requires network access")
|
||||
def test_real_api_call(self) -> None:
|
||||
"""Test actual API call to MCP Discovery service."""
|
||||
tool = MCPDiscoveryTool()
|
||||
result = tool._run(need="database", limit=3)
|
||||
assert "MCP server" in result or "No MCP servers found" in result
|
||||
@@ -3,9 +3,10 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from collections.abc import AsyncIterator
|
||||
from typing import TYPE_CHECKING, TypedDict
|
||||
from typing import TYPE_CHECKING, Any, TypedDict
|
||||
import uuid
|
||||
|
||||
from a2a.client.errors import A2AClientHTTPError
|
||||
from a2a.types import (
|
||||
AgentCard,
|
||||
Message,
|
||||
@@ -20,7 +21,10 @@ from a2a.types import (
|
||||
from typing_extensions import NotRequired
|
||||
|
||||
from crewai.events.event_bus import crewai_event_bus
|
||||
from crewai.events.types.a2a_events import A2AResponseReceivedEvent
|
||||
from crewai.events.types.a2a_events import (
|
||||
A2AConnectionErrorEvent,
|
||||
A2AResponseReceivedEvent,
|
||||
)
|
||||
|
||||
|
||||
if TYPE_CHECKING:
|
||||
@@ -55,7 +59,8 @@ class TaskStateResult(TypedDict):
|
||||
history: list[Message]
|
||||
result: NotRequired[str]
|
||||
error: NotRequired[str]
|
||||
agent_card: NotRequired[AgentCard]
|
||||
agent_card: NotRequired[dict[str, Any]]
|
||||
a2a_agent_name: NotRequired[str | None]
|
||||
|
||||
|
||||
def extract_task_result_parts(a2a_task: A2ATask) -> list[str]:
|
||||
@@ -131,50 +136,69 @@ def process_task_state(
|
||||
is_multiturn: bool,
|
||||
agent_role: str | None,
|
||||
result_parts: list[str] | None = None,
|
||||
endpoint: str | None = None,
|
||||
a2a_agent_name: str | None = None,
|
||||
from_task: Any | None = None,
|
||||
from_agent: Any | None = None,
|
||||
is_final: bool = True,
|
||||
) -> TaskStateResult | None:
|
||||
"""Process A2A task state and return result dictionary.
|
||||
|
||||
Shared logic for both polling and streaming handlers.
|
||||
|
||||
Args:
|
||||
a2a_task: The A2A task to process
|
||||
new_messages: List to collect messages (modified in place)
|
||||
agent_card: The agent card
|
||||
turn_number: Current turn number
|
||||
is_multiturn: Whether multi-turn conversation
|
||||
agent_role: Agent role for logging
|
||||
a2a_task: The A2A task to process.
|
||||
new_messages: List to collect messages (modified in place).
|
||||
agent_card: The agent card.
|
||||
turn_number: Current turn number.
|
||||
is_multiturn: Whether multi-turn conversation.
|
||||
agent_role: Agent role for logging.
|
||||
result_parts: Accumulated result parts (streaming passes accumulated,
|
||||
polling passes None to extract from task)
|
||||
polling passes None to extract from task).
|
||||
endpoint: A2A agent endpoint URL.
|
||||
a2a_agent_name: Name of the A2A agent from agent card.
|
||||
from_task: Optional CrewAI Task for event metadata.
|
||||
from_agent: Optional CrewAI Agent for event metadata.
|
||||
is_final: Whether this is the final response in the stream.
|
||||
|
||||
Returns:
|
||||
Result dictionary if terminal/actionable state, None otherwise
|
||||
Result dictionary if terminal/actionable state, None otherwise.
|
||||
"""
|
||||
should_extract = result_parts is None
|
||||
if result_parts is None:
|
||||
result_parts = []
|
||||
|
||||
if a2a_task.status.state == TaskState.completed:
|
||||
if should_extract:
|
||||
if not result_parts:
|
||||
extracted_parts = extract_task_result_parts(a2a_task)
|
||||
result_parts.extend(extracted_parts)
|
||||
if a2a_task.history:
|
||||
new_messages.extend(a2a_task.history)
|
||||
|
||||
response_text = " ".join(result_parts) if result_parts else ""
|
||||
message_id = None
|
||||
if a2a_task.status and a2a_task.status.message:
|
||||
message_id = a2a_task.status.message.message_id
|
||||
crewai_event_bus.emit(
|
||||
None,
|
||||
A2AResponseReceivedEvent(
|
||||
response=response_text,
|
||||
turn_number=turn_number,
|
||||
context_id=a2a_task.context_id,
|
||||
message_id=message_id,
|
||||
is_multiturn=is_multiturn,
|
||||
status="completed",
|
||||
final=is_final,
|
||||
agent_role=agent_role,
|
||||
endpoint=endpoint,
|
||||
a2a_agent_name=a2a_agent_name,
|
||||
from_task=from_task,
|
||||
from_agent=from_agent,
|
||||
),
|
||||
)
|
||||
|
||||
return TaskStateResult(
|
||||
status=TaskState.completed,
|
||||
agent_card=agent_card,
|
||||
agent_card=agent_card.model_dump(exclude_none=True),
|
||||
result=response_text,
|
||||
history=new_messages,
|
||||
)
|
||||
@@ -194,14 +218,24 @@ def process_task_state(
|
||||
)
|
||||
new_messages.append(agent_message)
|
||||
|
||||
input_message_id = None
|
||||
if a2a_task.status and a2a_task.status.message:
|
||||
input_message_id = a2a_task.status.message.message_id
|
||||
crewai_event_bus.emit(
|
||||
None,
|
||||
A2AResponseReceivedEvent(
|
||||
response=response_text,
|
||||
turn_number=turn_number,
|
||||
context_id=a2a_task.context_id,
|
||||
message_id=input_message_id,
|
||||
is_multiturn=is_multiturn,
|
||||
status="input_required",
|
||||
final=is_final,
|
||||
agent_role=agent_role,
|
||||
endpoint=endpoint,
|
||||
a2a_agent_name=a2a_agent_name,
|
||||
from_task=from_task,
|
||||
from_agent=from_agent,
|
||||
),
|
||||
)
|
||||
|
||||
@@ -209,7 +243,7 @@ def process_task_state(
|
||||
status=TaskState.input_required,
|
||||
error=response_text,
|
||||
history=new_messages,
|
||||
agent_card=agent_card,
|
||||
agent_card=agent_card.model_dump(exclude_none=True),
|
||||
)
|
||||
|
||||
if a2a_task.status.state in {TaskState.failed, TaskState.rejected}:
|
||||
@@ -248,6 +282,11 @@ async def send_message_and_get_task_id(
|
||||
turn_number: int,
|
||||
is_multiturn: bool,
|
||||
agent_role: str | None,
|
||||
from_task: Any | None = None,
|
||||
from_agent: Any | None = None,
|
||||
endpoint: str | None = None,
|
||||
a2a_agent_name: str | None = None,
|
||||
context_id: str | None = None,
|
||||
) -> str | TaskStateResult:
|
||||
"""Send message and process initial response.
|
||||
|
||||
@@ -262,6 +301,11 @@ async def send_message_and_get_task_id(
|
||||
turn_number: Current turn number
|
||||
is_multiturn: Whether multi-turn conversation
|
||||
agent_role: Agent role for logging
|
||||
from_task: Optional CrewAI Task object for event metadata.
|
||||
from_agent: Optional CrewAI Agent object for event metadata.
|
||||
endpoint: Optional A2A endpoint URL.
|
||||
a2a_agent_name: Optional A2A agent name.
|
||||
context_id: Optional A2A context ID for correlation.
|
||||
|
||||
Returns:
|
||||
Task ID string if agent needs polling/waiting, or TaskStateResult if done.
|
||||
@@ -280,9 +324,16 @@ async def send_message_and_get_task_id(
|
||||
A2AResponseReceivedEvent(
|
||||
response=response_text,
|
||||
turn_number=turn_number,
|
||||
context_id=event.context_id,
|
||||
message_id=event.message_id,
|
||||
is_multiturn=is_multiturn,
|
||||
status="completed",
|
||||
final=True,
|
||||
agent_role=agent_role,
|
||||
endpoint=endpoint,
|
||||
a2a_agent_name=a2a_agent_name,
|
||||
from_task=from_task,
|
||||
from_agent=from_agent,
|
||||
),
|
||||
)
|
||||
|
||||
@@ -290,7 +341,7 @@ async def send_message_and_get_task_id(
|
||||
status=TaskState.completed,
|
||||
result=response_text,
|
||||
history=new_messages,
|
||||
agent_card=agent_card,
|
||||
agent_card=agent_card.model_dump(exclude_none=True),
|
||||
)
|
||||
|
||||
if isinstance(event, tuple):
|
||||
@@ -304,6 +355,10 @@ async def send_message_and_get_task_id(
|
||||
turn_number=turn_number,
|
||||
is_multiturn=is_multiturn,
|
||||
agent_role=agent_role,
|
||||
endpoint=endpoint,
|
||||
a2a_agent_name=a2a_agent_name,
|
||||
from_task=from_task,
|
||||
from_agent=from_agent,
|
||||
)
|
||||
if result:
|
||||
return result
|
||||
@@ -316,6 +371,99 @@ async def send_message_and_get_task_id(
|
||||
history=new_messages,
|
||||
)
|
||||
|
||||
except A2AClientHTTPError as e:
|
||||
error_msg = f"HTTP Error {e.status_code}: {e!s}"
|
||||
|
||||
error_message = Message(
|
||||
role=Role.agent,
|
||||
message_id=str(uuid.uuid4()),
|
||||
parts=[Part(root=TextPart(text=error_msg))],
|
||||
context_id=context_id,
|
||||
)
|
||||
new_messages.append(error_message)
|
||||
|
||||
crewai_event_bus.emit(
|
||||
None,
|
||||
A2AConnectionErrorEvent(
|
||||
endpoint=endpoint or "",
|
||||
error=str(e),
|
||||
error_type="http_error",
|
||||
status_code=e.status_code,
|
||||
a2a_agent_name=a2a_agent_name,
|
||||
operation="send_message",
|
||||
context_id=context_id,
|
||||
from_task=from_task,
|
||||
from_agent=from_agent,
|
||||
),
|
||||
)
|
||||
crewai_event_bus.emit(
|
||||
None,
|
||||
A2AResponseReceivedEvent(
|
||||
response=error_msg,
|
||||
turn_number=turn_number,
|
||||
context_id=context_id,
|
||||
is_multiturn=is_multiturn,
|
||||
status="failed",
|
||||
final=True,
|
||||
agent_role=agent_role,
|
||||
endpoint=endpoint,
|
||||
a2a_agent_name=a2a_agent_name,
|
||||
from_task=from_task,
|
||||
from_agent=from_agent,
|
||||
),
|
||||
)
|
||||
return TaskStateResult(
|
||||
status=TaskState.failed,
|
||||
error=error_msg,
|
||||
history=new_messages,
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
error_msg = f"Unexpected error during send_message: {e!s}"
|
||||
|
||||
error_message = Message(
|
||||
role=Role.agent,
|
||||
message_id=str(uuid.uuid4()),
|
||||
parts=[Part(root=TextPart(text=error_msg))],
|
||||
context_id=context_id,
|
||||
)
|
||||
new_messages.append(error_message)
|
||||
|
||||
crewai_event_bus.emit(
|
||||
None,
|
||||
A2AConnectionErrorEvent(
|
||||
endpoint=endpoint or "",
|
||||
error=str(e),
|
||||
error_type="unexpected_error",
|
||||
a2a_agent_name=a2a_agent_name,
|
||||
operation="send_message",
|
||||
context_id=context_id,
|
||||
from_task=from_task,
|
||||
from_agent=from_agent,
|
||||
),
|
||||
)
|
||||
crewai_event_bus.emit(
|
||||
None,
|
||||
A2AResponseReceivedEvent(
|
||||
response=error_msg,
|
||||
turn_number=turn_number,
|
||||
context_id=context_id,
|
||||
is_multiturn=is_multiturn,
|
||||
status="failed",
|
||||
final=True,
|
||||
agent_role=agent_role,
|
||||
endpoint=endpoint,
|
||||
a2a_agent_name=a2a_agent_name,
|
||||
from_task=from_task,
|
||||
from_agent=from_agent,
|
||||
),
|
||||
)
|
||||
return TaskStateResult(
|
||||
status=TaskState.failed,
|
||||
error=error_msg,
|
||||
history=new_messages,
|
||||
)
|
||||
|
||||
finally:
|
||||
aclose = getattr(event_stream, "aclose", None)
|
||||
if aclose:
|
||||
|
||||
@@ -22,6 +22,13 @@ class BaseHandlerKwargs(TypedDict, total=False):
|
||||
turn_number: int
|
||||
is_multiturn: bool
|
||||
agent_role: str | None
|
||||
context_id: str | None
|
||||
task_id: str | None
|
||||
endpoint: str | None
|
||||
agent_branch: Any
|
||||
a2a_agent_name: str | None
|
||||
from_task: Any
|
||||
from_agent: Any
|
||||
|
||||
|
||||
class PollingHandlerKwargs(BaseHandlerKwargs, total=False):
|
||||
@@ -29,8 +36,6 @@ class PollingHandlerKwargs(BaseHandlerKwargs, total=False):
|
||||
|
||||
polling_interval: float
|
||||
polling_timeout: float
|
||||
endpoint: str
|
||||
agent_branch: Any
|
||||
history_length: int
|
||||
max_polls: int | None
|
||||
|
||||
@@ -38,9 +43,6 @@ class PollingHandlerKwargs(BaseHandlerKwargs, total=False):
|
||||
class StreamingHandlerKwargs(BaseHandlerKwargs, total=False):
|
||||
"""Kwargs for streaming handler."""
|
||||
|
||||
context_id: str | None
|
||||
task_id: str | None
|
||||
|
||||
|
||||
class PushNotificationHandlerKwargs(BaseHandlerKwargs, total=False):
|
||||
"""Kwargs for push notification handler."""
|
||||
@@ -49,7 +51,6 @@ class PushNotificationHandlerKwargs(BaseHandlerKwargs, total=False):
|
||||
result_store: PushNotificationResultStore
|
||||
polling_timeout: float
|
||||
polling_interval: float
|
||||
agent_branch: Any
|
||||
|
||||
|
||||
class PushNotificationResultStore(Protocol):
|
||||
|
||||
@@ -31,6 +31,7 @@ from crewai.a2a.task_helpers import (
|
||||
from crewai.a2a.updates.base import PollingHandlerKwargs
|
||||
from crewai.events.event_bus import crewai_event_bus
|
||||
from crewai.events.types.a2a_events import (
|
||||
A2AConnectionErrorEvent,
|
||||
A2APollingStartedEvent,
|
||||
A2APollingStatusEvent,
|
||||
A2AResponseReceivedEvent,
|
||||
@@ -49,23 +50,33 @@ async def _poll_task_until_complete(
|
||||
agent_branch: Any | None = None,
|
||||
history_length: int = 100,
|
||||
max_polls: int | None = None,
|
||||
from_task: Any | None = None,
|
||||
from_agent: Any | None = None,
|
||||
context_id: str | None = None,
|
||||
endpoint: str | None = None,
|
||||
a2a_agent_name: str | None = None,
|
||||
) -> A2ATask:
|
||||
"""Poll task status until terminal state reached.
|
||||
|
||||
Args:
|
||||
client: A2A client instance
|
||||
task_id: Task ID to poll
|
||||
polling_interval: Seconds between poll attempts
|
||||
polling_timeout: Max seconds before timeout
|
||||
agent_branch: Agent tree branch for logging
|
||||
history_length: Number of messages to retrieve per poll
|
||||
max_polls: Max number of poll attempts (None = unlimited)
|
||||
client: A2A client instance.
|
||||
task_id: Task ID to poll.
|
||||
polling_interval: Seconds between poll attempts.
|
||||
polling_timeout: Max seconds before timeout.
|
||||
agent_branch: Agent tree branch for logging.
|
||||
history_length: Number of messages to retrieve per poll.
|
||||
max_polls: Max number of poll attempts (None = unlimited).
|
||||
from_task: Optional CrewAI Task object for event metadata.
|
||||
from_agent: Optional CrewAI Agent object for event metadata.
|
||||
context_id: A2A context ID for correlation.
|
||||
endpoint: A2A agent endpoint URL.
|
||||
a2a_agent_name: Name of the A2A agent from agent card.
|
||||
|
||||
Returns:
|
||||
Final task object in terminal state
|
||||
Final task object in terminal state.
|
||||
|
||||
Raises:
|
||||
A2APollingTimeoutError: If polling exceeds timeout or max_polls
|
||||
A2APollingTimeoutError: If polling exceeds timeout or max_polls.
|
||||
"""
|
||||
start_time = time.monotonic()
|
||||
poll_count = 0
|
||||
@@ -77,13 +88,19 @@ async def _poll_task_until_complete(
|
||||
)
|
||||
|
||||
elapsed = time.monotonic() - start_time
|
||||
effective_context_id = task.context_id or context_id
|
||||
crewai_event_bus.emit(
|
||||
agent_branch,
|
||||
A2APollingStatusEvent(
|
||||
task_id=task_id,
|
||||
context_id=effective_context_id,
|
||||
state=str(task.status.state.value) if task.status.state else "unknown",
|
||||
elapsed_seconds=elapsed,
|
||||
poll_count=poll_count,
|
||||
endpoint=endpoint,
|
||||
a2a_agent_name=a2a_agent_name,
|
||||
from_task=from_task,
|
||||
from_agent=from_agent,
|
||||
),
|
||||
)
|
||||
|
||||
@@ -137,6 +154,9 @@ class PollingHandler:
|
||||
max_polls = kwargs.get("max_polls")
|
||||
context_id = kwargs.get("context_id")
|
||||
task_id = kwargs.get("task_id")
|
||||
a2a_agent_name = kwargs.get("a2a_agent_name")
|
||||
from_task = kwargs.get("from_task")
|
||||
from_agent = kwargs.get("from_agent")
|
||||
|
||||
try:
|
||||
result_or_task_id = await send_message_and_get_task_id(
|
||||
@@ -146,6 +166,11 @@ class PollingHandler:
|
||||
turn_number=turn_number,
|
||||
is_multiturn=is_multiturn,
|
||||
agent_role=agent_role,
|
||||
from_task=from_task,
|
||||
from_agent=from_agent,
|
||||
endpoint=endpoint,
|
||||
a2a_agent_name=a2a_agent_name,
|
||||
context_id=context_id,
|
||||
)
|
||||
|
||||
if not isinstance(result_or_task_id, str):
|
||||
@@ -157,8 +182,12 @@ class PollingHandler:
|
||||
agent_branch,
|
||||
A2APollingStartedEvent(
|
||||
task_id=task_id,
|
||||
context_id=context_id,
|
||||
polling_interval=polling_interval,
|
||||
endpoint=endpoint,
|
||||
a2a_agent_name=a2a_agent_name,
|
||||
from_task=from_task,
|
||||
from_agent=from_agent,
|
||||
),
|
||||
)
|
||||
|
||||
@@ -170,6 +199,11 @@ class PollingHandler:
|
||||
agent_branch=agent_branch,
|
||||
history_length=history_length,
|
||||
max_polls=max_polls,
|
||||
from_task=from_task,
|
||||
from_agent=from_agent,
|
||||
context_id=context_id,
|
||||
endpoint=endpoint,
|
||||
a2a_agent_name=a2a_agent_name,
|
||||
)
|
||||
|
||||
result = process_task_state(
|
||||
@@ -179,6 +213,10 @@ class PollingHandler:
|
||||
turn_number=turn_number,
|
||||
is_multiturn=is_multiturn,
|
||||
agent_role=agent_role,
|
||||
endpoint=endpoint,
|
||||
a2a_agent_name=a2a_agent_name,
|
||||
from_task=from_task,
|
||||
from_agent=from_agent,
|
||||
)
|
||||
if result:
|
||||
return result
|
||||
@@ -206,9 +244,15 @@ class PollingHandler:
|
||||
A2AResponseReceivedEvent(
|
||||
response=error_msg,
|
||||
turn_number=turn_number,
|
||||
context_id=context_id,
|
||||
is_multiturn=is_multiturn,
|
||||
status="failed",
|
||||
final=True,
|
||||
agent_role=agent_role,
|
||||
endpoint=endpoint,
|
||||
a2a_agent_name=a2a_agent_name,
|
||||
from_task=from_task,
|
||||
from_agent=from_agent,
|
||||
),
|
||||
)
|
||||
return TaskStateResult(
|
||||
@@ -229,14 +273,83 @@ class PollingHandler:
|
||||
)
|
||||
new_messages.append(error_message)
|
||||
|
||||
crewai_event_bus.emit(
|
||||
agent_branch,
|
||||
A2AConnectionErrorEvent(
|
||||
endpoint=endpoint,
|
||||
error=str(e),
|
||||
error_type="http_error",
|
||||
status_code=e.status_code,
|
||||
a2a_agent_name=a2a_agent_name,
|
||||
operation="polling",
|
||||
context_id=context_id,
|
||||
task_id=task_id,
|
||||
from_task=from_task,
|
||||
from_agent=from_agent,
|
||||
),
|
||||
)
|
||||
crewai_event_bus.emit(
|
||||
agent_branch,
|
||||
A2AResponseReceivedEvent(
|
||||
response=error_msg,
|
||||
turn_number=turn_number,
|
||||
context_id=context_id,
|
||||
is_multiturn=is_multiturn,
|
||||
status="failed",
|
||||
final=True,
|
||||
agent_role=agent_role,
|
||||
endpoint=endpoint,
|
||||
a2a_agent_name=a2a_agent_name,
|
||||
from_task=from_task,
|
||||
from_agent=from_agent,
|
||||
),
|
||||
)
|
||||
return TaskStateResult(
|
||||
status=TaskState.failed,
|
||||
error=error_msg,
|
||||
history=new_messages,
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
error_msg = f"Unexpected error during polling: {e!s}"
|
||||
|
||||
error_message = Message(
|
||||
role=Role.agent,
|
||||
message_id=str(uuid.uuid4()),
|
||||
parts=[Part(root=TextPart(text=error_msg))],
|
||||
context_id=context_id,
|
||||
task_id=task_id,
|
||||
)
|
||||
new_messages.append(error_message)
|
||||
|
||||
crewai_event_bus.emit(
|
||||
agent_branch,
|
||||
A2AConnectionErrorEvent(
|
||||
endpoint=endpoint or "",
|
||||
error=str(e),
|
||||
error_type="unexpected_error",
|
||||
a2a_agent_name=a2a_agent_name,
|
||||
operation="polling",
|
||||
context_id=context_id,
|
||||
task_id=task_id,
|
||||
from_task=from_task,
|
||||
from_agent=from_agent,
|
||||
),
|
||||
)
|
||||
crewai_event_bus.emit(
|
||||
agent_branch,
|
||||
A2AResponseReceivedEvent(
|
||||
response=error_msg,
|
||||
turn_number=turn_number,
|
||||
context_id=context_id,
|
||||
is_multiturn=is_multiturn,
|
||||
status="failed",
|
||||
final=True,
|
||||
agent_role=agent_role,
|
||||
endpoint=endpoint,
|
||||
a2a_agent_name=a2a_agent_name,
|
||||
from_task=from_task,
|
||||
from_agent=from_agent,
|
||||
),
|
||||
)
|
||||
return TaskStateResult(
|
||||
|
||||
@@ -29,6 +29,7 @@ from crewai.a2a.updates.base import (
|
||||
)
|
||||
from crewai.events.event_bus import crewai_event_bus
|
||||
from crewai.events.types.a2a_events import (
|
||||
A2AConnectionErrorEvent,
|
||||
A2APushNotificationRegisteredEvent,
|
||||
A2APushNotificationTimeoutEvent,
|
||||
A2AResponseReceivedEvent,
|
||||
@@ -48,6 +49,11 @@ async def _wait_for_push_result(
|
||||
timeout: float,
|
||||
poll_interval: float,
|
||||
agent_branch: Any | None = None,
|
||||
from_task: Any | None = None,
|
||||
from_agent: Any | None = None,
|
||||
context_id: str | None = None,
|
||||
endpoint: str | None = None,
|
||||
a2a_agent_name: str | None = None,
|
||||
) -> A2ATask | None:
|
||||
"""Wait for push notification result.
|
||||
|
||||
@@ -57,6 +63,11 @@ async def _wait_for_push_result(
|
||||
timeout: Max seconds to wait.
|
||||
poll_interval: Seconds between polling attempts.
|
||||
agent_branch: Agent tree branch for logging.
|
||||
from_task: Optional CrewAI Task object for event metadata.
|
||||
from_agent: Optional CrewAI Agent object for event metadata.
|
||||
context_id: A2A context ID for correlation.
|
||||
endpoint: A2A agent endpoint URL.
|
||||
a2a_agent_name: Name of the A2A agent.
|
||||
|
||||
Returns:
|
||||
Final task object, or None if timeout.
|
||||
@@ -72,7 +83,12 @@ async def _wait_for_push_result(
|
||||
agent_branch,
|
||||
A2APushNotificationTimeoutEvent(
|
||||
task_id=task_id,
|
||||
context_id=context_id,
|
||||
timeout_seconds=timeout,
|
||||
endpoint=endpoint,
|
||||
a2a_agent_name=a2a_agent_name,
|
||||
from_task=from_task,
|
||||
from_agent=from_agent,
|
||||
),
|
||||
)
|
||||
|
||||
@@ -115,18 +131,56 @@ class PushNotificationHandler:
|
||||
agent_role = kwargs.get("agent_role")
|
||||
context_id = kwargs.get("context_id")
|
||||
task_id = kwargs.get("task_id")
|
||||
endpoint = kwargs.get("endpoint")
|
||||
a2a_agent_name = kwargs.get("a2a_agent_name")
|
||||
from_task = kwargs.get("from_task")
|
||||
from_agent = kwargs.get("from_agent")
|
||||
|
||||
if config is None:
|
||||
error_msg = (
|
||||
"PushNotificationConfig is required for push notification handler"
|
||||
)
|
||||
crewai_event_bus.emit(
|
||||
agent_branch,
|
||||
A2AConnectionErrorEvent(
|
||||
endpoint=endpoint or "",
|
||||
error=error_msg,
|
||||
error_type="configuration_error",
|
||||
a2a_agent_name=a2a_agent_name,
|
||||
operation="push_notification",
|
||||
context_id=context_id,
|
||||
task_id=task_id,
|
||||
from_task=from_task,
|
||||
from_agent=from_agent,
|
||||
),
|
||||
)
|
||||
return TaskStateResult(
|
||||
status=TaskState.failed,
|
||||
error="PushNotificationConfig is required for push notification handler",
|
||||
error=error_msg,
|
||||
history=new_messages,
|
||||
)
|
||||
|
||||
if result_store is None:
|
||||
error_msg = (
|
||||
"PushNotificationResultStore is required for push notification handler"
|
||||
)
|
||||
crewai_event_bus.emit(
|
||||
agent_branch,
|
||||
A2AConnectionErrorEvent(
|
||||
endpoint=endpoint or "",
|
||||
error=error_msg,
|
||||
error_type="configuration_error",
|
||||
a2a_agent_name=a2a_agent_name,
|
||||
operation="push_notification",
|
||||
context_id=context_id,
|
||||
task_id=task_id,
|
||||
from_task=from_task,
|
||||
from_agent=from_agent,
|
||||
),
|
||||
)
|
||||
return TaskStateResult(
|
||||
status=TaskState.failed,
|
||||
error="PushNotificationResultStore is required for push notification handler",
|
||||
error=error_msg,
|
||||
history=new_messages,
|
||||
)
|
||||
|
||||
@@ -138,6 +192,11 @@ class PushNotificationHandler:
|
||||
turn_number=turn_number,
|
||||
is_multiturn=is_multiturn,
|
||||
agent_role=agent_role,
|
||||
from_task=from_task,
|
||||
from_agent=from_agent,
|
||||
endpoint=endpoint,
|
||||
a2a_agent_name=a2a_agent_name,
|
||||
context_id=context_id,
|
||||
)
|
||||
|
||||
if not isinstance(result_or_task_id, str):
|
||||
@@ -149,7 +208,12 @@ class PushNotificationHandler:
|
||||
agent_branch,
|
||||
A2APushNotificationRegisteredEvent(
|
||||
task_id=task_id,
|
||||
context_id=context_id,
|
||||
callback_url=str(config.url),
|
||||
endpoint=endpoint,
|
||||
a2a_agent_name=a2a_agent_name,
|
||||
from_task=from_task,
|
||||
from_agent=from_agent,
|
||||
),
|
||||
)
|
||||
|
||||
@@ -165,6 +229,11 @@ class PushNotificationHandler:
|
||||
timeout=polling_timeout,
|
||||
poll_interval=polling_interval,
|
||||
agent_branch=agent_branch,
|
||||
from_task=from_task,
|
||||
from_agent=from_agent,
|
||||
context_id=context_id,
|
||||
endpoint=endpoint,
|
||||
a2a_agent_name=a2a_agent_name,
|
||||
)
|
||||
|
||||
if final_task is None:
|
||||
@@ -181,6 +250,10 @@ class PushNotificationHandler:
|
||||
turn_number=turn_number,
|
||||
is_multiturn=is_multiturn,
|
||||
agent_role=agent_role,
|
||||
endpoint=endpoint,
|
||||
a2a_agent_name=a2a_agent_name,
|
||||
from_task=from_task,
|
||||
from_agent=from_agent,
|
||||
)
|
||||
if result:
|
||||
return result
|
||||
@@ -203,14 +276,83 @@ class PushNotificationHandler:
|
||||
)
|
||||
new_messages.append(error_message)
|
||||
|
||||
crewai_event_bus.emit(
|
||||
agent_branch,
|
||||
A2AConnectionErrorEvent(
|
||||
endpoint=endpoint or "",
|
||||
error=str(e),
|
||||
error_type="http_error",
|
||||
status_code=e.status_code,
|
||||
a2a_agent_name=a2a_agent_name,
|
||||
operation="push_notification",
|
||||
context_id=context_id,
|
||||
task_id=task_id,
|
||||
from_task=from_task,
|
||||
from_agent=from_agent,
|
||||
),
|
||||
)
|
||||
crewai_event_bus.emit(
|
||||
agent_branch,
|
||||
A2AResponseReceivedEvent(
|
||||
response=error_msg,
|
||||
turn_number=turn_number,
|
||||
context_id=context_id,
|
||||
is_multiturn=is_multiturn,
|
||||
status="failed",
|
||||
final=True,
|
||||
agent_role=agent_role,
|
||||
endpoint=endpoint,
|
||||
a2a_agent_name=a2a_agent_name,
|
||||
from_task=from_task,
|
||||
from_agent=from_agent,
|
||||
),
|
||||
)
|
||||
return TaskStateResult(
|
||||
status=TaskState.failed,
|
||||
error=error_msg,
|
||||
history=new_messages,
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
error_msg = f"Unexpected error during push notification: {e!s}"
|
||||
|
||||
error_message = Message(
|
||||
role=Role.agent,
|
||||
message_id=str(uuid.uuid4()),
|
||||
parts=[Part(root=TextPart(text=error_msg))],
|
||||
context_id=context_id,
|
||||
task_id=task_id,
|
||||
)
|
||||
new_messages.append(error_message)
|
||||
|
||||
crewai_event_bus.emit(
|
||||
agent_branch,
|
||||
A2AConnectionErrorEvent(
|
||||
endpoint=endpoint or "",
|
||||
error=str(e),
|
||||
error_type="unexpected_error",
|
||||
a2a_agent_name=a2a_agent_name,
|
||||
operation="push_notification",
|
||||
context_id=context_id,
|
||||
task_id=task_id,
|
||||
from_task=from_task,
|
||||
from_agent=from_agent,
|
||||
),
|
||||
)
|
||||
crewai_event_bus.emit(
|
||||
agent_branch,
|
||||
A2AResponseReceivedEvent(
|
||||
response=error_msg,
|
||||
turn_number=turn_number,
|
||||
context_id=context_id,
|
||||
is_multiturn=is_multiturn,
|
||||
status="failed",
|
||||
final=True,
|
||||
agent_role=agent_role,
|
||||
endpoint=endpoint,
|
||||
a2a_agent_name=a2a_agent_name,
|
||||
from_task=from_task,
|
||||
from_agent=from_agent,
|
||||
),
|
||||
)
|
||||
return TaskStateResult(
|
||||
|
||||
@@ -26,7 +26,13 @@ from crewai.a2a.task_helpers import (
)
from crewai.a2a.updates.base import StreamingHandlerKwargs
from crewai.events.event_bus import crewai_event_bus
from crewai.events.types.a2a_events import A2AResponseReceivedEvent
from crewai.events.types.a2a_events import (
A2AArtifactReceivedEvent,
A2AConnectionErrorEvent,
A2AResponseReceivedEvent,
A2AStreamingChunkEvent,
A2AStreamingStartedEvent,
)

class StreamingHandler:
|
||||
@@ -57,19 +63,57 @@ class StreamingHandler:
|
||||
turn_number = kwargs.get("turn_number", 0)
|
||||
is_multiturn = kwargs.get("is_multiturn", False)
|
||||
agent_role = kwargs.get("agent_role")
|
||||
endpoint = kwargs.get("endpoint")
|
||||
a2a_agent_name = kwargs.get("a2a_agent_name")
|
||||
from_task = kwargs.get("from_task")
|
||||
from_agent = kwargs.get("from_agent")
|
||||
agent_branch = kwargs.get("agent_branch")
|
||||
|
||||
result_parts: list[str] = []
|
||||
final_result: TaskStateResult | None = None
|
||||
event_stream = client.send_message(message)
|
||||
chunk_index = 0
|
||||
|
||||
crewai_event_bus.emit(
|
||||
agent_branch,
|
||||
A2AStreamingStartedEvent(
|
||||
task_id=task_id,
|
||||
context_id=context_id,
|
||||
endpoint=endpoint or "",
|
||||
a2a_agent_name=a2a_agent_name,
|
||||
turn_number=turn_number,
|
||||
is_multiturn=is_multiturn,
|
||||
agent_role=agent_role,
|
||||
from_task=from_task,
|
||||
from_agent=from_agent,
|
||||
),
|
||||
)
|
||||
|
||||
try:
|
||||
async for event in event_stream:
|
||||
if isinstance(event, Message):
|
||||
new_messages.append(event)
|
||||
message_context_id = event.context_id or context_id
|
||||
for part in event.parts:
|
||||
if part.root.kind == "text":
|
||||
text = part.root.text
|
||||
result_parts.append(text)
|
||||
crewai_event_bus.emit(
|
||||
agent_branch,
|
||||
A2AStreamingChunkEvent(
|
||||
task_id=event.task_id or task_id,
|
||||
context_id=message_context_id,
|
||||
chunk=text,
|
||||
chunk_index=chunk_index,
|
||||
endpoint=endpoint,
|
||||
a2a_agent_name=a2a_agent_name,
|
||||
turn_number=turn_number,
|
||||
is_multiturn=is_multiturn,
|
||||
from_task=from_task,
|
||||
from_agent=from_agent,
|
||||
),
|
||||
)
|
||||
chunk_index += 1
|
||||
|
||||
elif isinstance(event, tuple):
|
||||
a2a_task, update = event
|
||||
@@ -81,10 +125,51 @@ class StreamingHandler:
|
||||
for part in artifact.parts
|
||||
if part.root.kind == "text"
|
||||
)
|
||||
artifact_size = None
|
||||
if artifact.parts:
|
||||
artifact_size = sum(
|
||||
len(p.root.text.encode("utf-8"))
|
||||
if p.root.kind == "text"
|
||||
else len(getattr(p.root, "data", b""))
|
||||
for p in artifact.parts
|
||||
)
|
||||
effective_context_id = a2a_task.context_id or context_id
|
||||
crewai_event_bus.emit(
|
||||
agent_branch,
|
||||
A2AArtifactReceivedEvent(
|
||||
task_id=a2a_task.id,
|
||||
artifact_id=artifact.artifact_id,
|
||||
artifact_name=artifact.name,
|
||||
artifact_description=artifact.description,
|
||||
mime_type=artifact.parts[0].root.kind
|
||||
if artifact.parts
|
||||
else None,
|
||||
size_bytes=artifact_size,
|
||||
append=update.append or False,
|
||||
last_chunk=update.last_chunk or False,
|
||||
endpoint=endpoint,
|
||||
a2a_agent_name=a2a_agent_name,
|
||||
context_id=effective_context_id,
|
||||
turn_number=turn_number,
|
||||
is_multiturn=is_multiturn,
|
||||
from_task=from_task,
|
||||
from_agent=from_agent,
|
||||
),
|
||||
)
|
||||
|
||||
is_final_update = False
|
||||
if isinstance(update, TaskStatusUpdateEvent):
|
||||
is_final_update = update.final
|
||||
if (
|
||||
update.status
|
||||
and update.status.message
|
||||
and update.status.message.parts
|
||||
):
|
||||
result_parts.extend(
|
||||
part.root.text
|
||||
for part in update.status.message.parts
|
||||
if part.root.kind == "text" and part.root.text
|
||||
)
|
||||
|
||||
if (
|
||||
not is_final_update
|
||||
@@ -101,6 +186,11 @@ class StreamingHandler:
|
||||
is_multiturn=is_multiturn,
|
||||
agent_role=agent_role,
|
||||
result_parts=result_parts,
|
||||
endpoint=endpoint,
|
||||
a2a_agent_name=a2a_agent_name,
|
||||
from_task=from_task,
|
||||
from_agent=from_agent,
|
||||
is_final=is_final_update,
|
||||
)
|
||||
if final_result:
|
||||
break
|
||||
@@ -118,13 +208,82 @@ class StreamingHandler:
|
||||
new_messages.append(error_message)
|
||||
|
||||
crewai_event_bus.emit(
|
||||
None,
|
||||
agent_branch,
|
||||
A2AConnectionErrorEvent(
|
||||
endpoint=endpoint or "",
|
||||
error=str(e),
|
||||
error_type="http_error",
|
||||
status_code=e.status_code,
|
||||
a2a_agent_name=a2a_agent_name,
|
||||
operation="streaming",
|
||||
context_id=context_id,
|
||||
task_id=task_id,
|
||||
from_task=from_task,
|
||||
from_agent=from_agent,
|
||||
),
|
||||
)
|
||||
crewai_event_bus.emit(
|
||||
agent_branch,
|
||||
A2AResponseReceivedEvent(
|
||||
response=error_msg,
|
||||
turn_number=turn_number,
|
||||
context_id=context_id,
|
||||
is_multiturn=is_multiturn,
|
||||
status="failed",
|
||||
final=True,
|
||||
agent_role=agent_role,
|
||||
endpoint=endpoint,
|
||||
a2a_agent_name=a2a_agent_name,
|
||||
from_task=from_task,
|
||||
from_agent=from_agent,
|
||||
),
|
||||
)
|
||||
return TaskStateResult(
|
||||
status=TaskState.failed,
|
||||
error=error_msg,
|
||||
history=new_messages,
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
error_msg = f"Unexpected error during streaming: {e!s}"
|
||||
|
||||
error_message = Message(
|
||||
role=Role.agent,
|
||||
message_id=str(uuid.uuid4()),
|
||||
parts=[Part(root=TextPart(text=error_msg))],
|
||||
context_id=context_id,
|
||||
task_id=task_id,
|
||||
)
|
||||
new_messages.append(error_message)
|
||||
|
||||
crewai_event_bus.emit(
|
||||
agent_branch,
|
||||
A2AConnectionErrorEvent(
|
||||
endpoint=endpoint or "",
|
||||
error=str(e),
|
||||
error_type="unexpected_error",
|
||||
a2a_agent_name=a2a_agent_name,
|
||||
operation="streaming",
|
||||
context_id=context_id,
|
||||
task_id=task_id,
|
||||
from_task=from_task,
|
||||
from_agent=from_agent,
|
||||
),
|
||||
)
|
||||
crewai_event_bus.emit(
|
||||
agent_branch,
|
||||
A2AResponseReceivedEvent(
|
||||
response=error_msg,
|
||||
turn_number=turn_number,
|
||||
context_id=context_id,
|
||||
is_multiturn=is_multiturn,
|
||||
status="failed",
|
||||
final=True,
|
||||
agent_role=agent_role,
|
||||
endpoint=endpoint,
|
||||
a2a_agent_name=a2a_agent_name,
|
||||
from_task=from_task,
|
||||
from_agent=from_agent,
|
||||
),
|
||||
)
|
||||
return TaskStateResult(
|
||||
@@ -136,7 +295,23 @@ class StreamingHandler:
|
||||
finally:
|
||||
aclose = getattr(event_stream, "aclose", None)
|
||||
if aclose:
|
||||
await aclose()
|
||||
try:
|
||||
await aclose()
|
||||
except Exception as close_error:
|
||||
crewai_event_bus.emit(
|
||||
agent_branch,
|
||||
A2AConnectionErrorEvent(
|
||||
endpoint=endpoint or "",
|
||||
error=str(close_error),
|
||||
error_type="stream_close_error",
|
||||
a2a_agent_name=a2a_agent_name,
|
||||
operation="stream_close",
|
||||
context_id=context_id,
|
||||
task_id=task_id,
|
||||
from_task=from_task,
|
||||
from_agent=from_agent,
|
||||
),
|
||||
)
|
||||
|
||||
if final_result:
|
||||
return final_result
|
||||
@@ -145,5 +320,5 @@ class StreamingHandler:
|
||||
status=TaskState.completed,
|
||||
result=" ".join(result_parts) if result_parts else "",
|
||||
history=new_messages,
|
||||
agent_card=agent_card,
|
||||
agent_card=agent_card.model_dump(exclude_none=True),
|
||||
)
|
||||
|
||||
@@ -23,6 +23,12 @@ from crewai.a2a.auth.utils import (
|
||||
)
|
||||
from crewai.a2a.config import A2AServerConfig
|
||||
from crewai.crew import Crew
|
||||
from crewai.events.event_bus import crewai_event_bus
|
||||
from crewai.events.types.a2a_events import (
|
||||
A2AAgentCardFetchedEvent,
|
||||
A2AAuthenticationFailedEvent,
|
||||
A2AConnectionErrorEvent,
|
||||
)
|
||||
|
||||
|
||||
if TYPE_CHECKING:
|
||||
@@ -183,6 +189,8 @@ async def _afetch_agent_card_impl(
|
||||
timeout: int,
|
||||
) -> AgentCard:
|
||||
"""Internal async implementation of AgentCard fetching."""
|
||||
start_time = time.perf_counter()
|
||||
|
||||
if "/.well-known/agent-card.json" in endpoint:
|
||||
base_url = endpoint.replace("/.well-known/agent-card.json", "")
|
||||
agent_card_path = "/.well-known/agent-card.json"
|
||||
@@ -217,9 +225,29 @@ async def _afetch_agent_card_impl(
|
||||
)
|
||||
response.raise_for_status()
|
||||
|
||||
return AgentCard.model_validate(response.json())
|
||||
agent_card = AgentCard.model_validate(response.json())
|
||||
fetch_time_ms = (time.perf_counter() - start_time) * 1000
|
||||
agent_card_dict = agent_card.model_dump(exclude_none=True)
|
||||
|
||||
crewai_event_bus.emit(
|
||||
None,
|
||||
A2AAgentCardFetchedEvent(
|
||||
endpoint=endpoint,
|
||||
a2a_agent_name=agent_card.name,
|
||||
agent_card=agent_card_dict,
|
||||
protocol_version=agent_card.protocol_version,
|
||||
provider=agent_card_dict.get("provider"),
|
||||
cached=False,
|
||||
fetch_time_ms=fetch_time_ms,
|
||||
),
|
||||
)
|
||||
|
||||
return agent_card
|
||||
|
||||
except httpx.HTTPStatusError as e:
|
||||
elapsed_ms = (time.perf_counter() - start_time) * 1000
|
||||
response_body = e.response.text[:1000] if e.response.text else None
|
||||
|
||||
if e.response.status_code == 401:
|
||||
error_details = ["Authentication failed"]
|
||||
www_auth = e.response.headers.get("WWW-Authenticate")
|
||||
@@ -228,7 +256,93 @@ async def _afetch_agent_card_impl(
|
||||
if not auth:
|
||||
error_details.append("No auth scheme provided")
|
||||
msg = " | ".join(error_details)
|
||||
|
||||
auth_type = type(auth).__name__ if auth else None
|
||||
crewai_event_bus.emit(
|
||||
None,
|
||||
A2AAuthenticationFailedEvent(
|
||||
endpoint=endpoint,
|
||||
auth_type=auth_type,
|
||||
error=msg,
|
||||
status_code=401,
|
||||
metadata={
|
||||
"elapsed_ms": elapsed_ms,
|
||||
"response_body": response_body,
|
||||
"www_authenticate": www_auth,
|
||||
"request_url": str(e.request.url),
|
||||
},
|
||||
),
|
||||
)
|
||||
|
||||
raise A2AClientHTTPError(401, msg) from e
|
||||
|
||||
crewai_event_bus.emit(
|
||||
None,
|
||||
A2AConnectionErrorEvent(
|
||||
endpoint=endpoint,
|
||||
error=str(e),
|
||||
error_type="http_error",
|
||||
status_code=e.response.status_code,
|
||||
operation="fetch_agent_card",
|
||||
metadata={
|
||||
"elapsed_ms": elapsed_ms,
|
||||
"response_body": response_body,
|
||||
"request_url": str(e.request.url),
|
||||
},
|
||||
),
|
||||
)
|
||||
raise
|
||||
|
||||
except httpx.TimeoutException as e:
|
||||
elapsed_ms = (time.perf_counter() - start_time) * 1000
|
||||
crewai_event_bus.emit(
|
||||
None,
|
||||
A2AConnectionErrorEvent(
|
||||
endpoint=endpoint,
|
||||
error=str(e),
|
||||
error_type="timeout",
|
||||
operation="fetch_agent_card",
|
||||
metadata={
|
||||
"elapsed_ms": elapsed_ms,
|
||||
"timeout_config": timeout,
|
||||
"request_url": str(e.request.url) if e.request else None,
|
||||
},
|
||||
),
|
||||
)
|
||||
raise
|
||||
|
||||
except httpx.ConnectError as e:
|
||||
elapsed_ms = (time.perf_counter() - start_time) * 1000
|
||||
crewai_event_bus.emit(
|
||||
None,
|
||||
A2AConnectionErrorEvent(
|
||||
endpoint=endpoint,
|
||||
error=str(e),
|
||||
error_type="connection_error",
|
||||
operation="fetch_agent_card",
|
||||
metadata={
|
||||
"elapsed_ms": elapsed_ms,
|
||||
"request_url": str(e.request.url) if e.request else None,
|
||||
},
|
||||
),
|
||||
)
|
||||
raise
|
||||
|
||||
except httpx.RequestError as e:
|
||||
elapsed_ms = (time.perf_counter() - start_time) * 1000
|
||||
crewai_event_bus.emit(
|
||||
None,
|
||||
A2AConnectionErrorEvent(
|
||||
endpoint=endpoint,
|
||||
error=str(e),
|
||||
error_type="request_error",
|
||||
operation="fetch_agent_card",
|
||||
metadata={
|
||||
"elapsed_ms": elapsed_ms,
|
||||
"request_url": str(e.request.url) if e.request else None,
|
||||
},
|
||||
),
|
||||
)
|
||||
raise
|
||||
|
||||
|
||||
|
||||
@@ -88,6 +88,9 @@ def execute_a2a_delegation(
|
||||
response_model: type[BaseModel] | None = None,
|
||||
turn_number: int | None = None,
|
||||
updates: UpdateConfig | None = None,
|
||||
from_task: Any | None = None,
|
||||
from_agent: Any | None = None,
|
||||
skill_id: str | None = None,
|
||||
) -> TaskStateResult:
|
||||
"""Execute a task delegation to a remote A2A agent synchronously.
|
||||
|
||||
@@ -129,6 +132,9 @@ def execute_a2a_delegation(
|
||||
response_model: Optional Pydantic model for structured outputs.
|
||||
turn_number: Optional turn number for multi-turn conversations.
|
||||
updates: Update mechanism config from A2AConfig.updates.
|
||||
from_task: Optional CrewAI Task object for event metadata.
|
||||
from_agent: Optional CrewAI Agent object for event metadata.
|
||||
skill_id: Optional skill ID to target a specific agent capability.
|
||||
|
||||
Returns:
|
||||
TaskStateResult with status, result/error, history, and agent_card.
|
||||
@@ -156,10 +162,16 @@ def execute_a2a_delegation(
|
||||
transport_protocol=transport_protocol,
|
||||
turn_number=turn_number,
|
||||
updates=updates,
|
||||
from_task=from_task,
|
||||
from_agent=from_agent,
|
||||
skill_id=skill_id,
|
||||
)
|
||||
)
|
||||
finally:
|
||||
loop.close()
|
||||
try:
|
||||
loop.run_until_complete(loop.shutdown_asyncgens())
|
||||
finally:
|
||||
loop.close()
|
||||
|
||||
|
||||
async def aexecute_a2a_delegation(
|
||||
@@ -181,6 +193,9 @@ async def aexecute_a2a_delegation(
|
||||
response_model: type[BaseModel] | None = None,
|
||||
turn_number: int | None = None,
|
||||
updates: UpdateConfig | None = None,
|
||||
from_task: Any | None = None,
|
||||
from_agent: Any | None = None,
|
||||
skill_id: str | None = None,
|
||||
) -> TaskStateResult:
|
||||
"""Execute a task delegation to a remote A2A agent asynchronously.
|
||||
|
||||
@@ -222,6 +237,9 @@ async def aexecute_a2a_delegation(
|
||||
response_model: Optional Pydantic model for structured outputs.
|
||||
turn_number: Optional turn number for multi-turn conversations.
|
||||
updates: Update mechanism config from A2AConfig.updates.
|
||||
from_task: Optional CrewAI Task object for event metadata.
|
||||
from_agent: Optional CrewAI Agent object for event metadata.
|
||||
skill_id: Optional skill ID to target a specific agent capability.
|
||||
|
||||
Returns:
|
||||
TaskStateResult with status, result/error, history, and agent_card.
|
||||
@@ -233,17 +251,6 @@ async def aexecute_a2a_delegation(
|
||||
if turn_number is None:
|
||||
turn_number = len([m for m in conversation_history if m.role == Role.user]) + 1
|
||||
|
||||
crewai_event_bus.emit(
|
||||
agent_branch,
|
||||
A2ADelegationStartedEvent(
|
||||
endpoint=endpoint,
|
||||
task_description=task_description,
|
||||
agent_id=agent_id,
|
||||
is_multiturn=is_multiturn,
|
||||
turn_number=turn_number,
|
||||
),
|
||||
)
|
||||
|
||||
result = await _aexecute_a2a_delegation_impl(
|
||||
endpoint=endpoint,
|
||||
auth=auth,
|
||||
@@ -264,15 +271,28 @@ async def aexecute_a2a_delegation(
|
||||
response_model=response_model,
|
||||
updates=updates,
|
||||
transport_protocol=transport_protocol,
|
||||
from_task=from_task,
|
||||
from_agent=from_agent,
|
||||
skill_id=skill_id,
|
||||
)
|
||||
|
||||
agent_card_data: dict[str, Any] = result.get("agent_card") or {}
|
||||
crewai_event_bus.emit(
|
||||
agent_branch,
|
||||
A2ADelegationCompletedEvent(
|
||||
status=result["status"],
|
||||
result=result.get("result"),
|
||||
error=result.get("error"),
|
||||
context_id=context_id,
|
||||
is_multiturn=is_multiturn,
|
||||
endpoint=endpoint,
|
||||
a2a_agent_name=result.get("a2a_agent_name"),
|
||||
agent_card=agent_card_data,
|
||||
provider=agent_card_data.get("provider"),
|
||||
metadata=metadata,
|
||||
extensions=list(extensions.keys()) if extensions else None,
|
||||
from_task=from_task,
|
||||
from_agent=from_agent,
|
||||
),
|
||||
)
|
||||
|
||||
@@ -299,6 +319,9 @@ async def _aexecute_a2a_delegation_impl(
|
||||
agent_role: str | None,
|
||||
response_model: type[BaseModel] | None,
|
||||
updates: UpdateConfig | None,
|
||||
from_task: Any | None = None,
|
||||
from_agent: Any | None = None,
|
||||
skill_id: str | None = None,
|
||||
) -> TaskStateResult:
|
||||
"""Internal async implementation of A2A delegation."""
|
||||
if auth:
|
||||
@@ -331,6 +354,28 @@ async def _aexecute_a2a_delegation_impl(
|
||||
if agent_card.name:
|
||||
a2a_agent_name = agent_card.name
|
||||
|
||||
agent_card_dict = agent_card.model_dump(exclude_none=True)
|
||||
crewai_event_bus.emit(
|
||||
agent_branch,
|
||||
A2ADelegationStartedEvent(
|
||||
endpoint=endpoint,
|
||||
task_description=task_description,
|
||||
agent_id=agent_id or endpoint,
|
||||
context_id=context_id,
|
||||
is_multiturn=is_multiturn,
|
||||
turn_number=turn_number,
|
||||
a2a_agent_name=a2a_agent_name,
|
||||
agent_card=agent_card_dict,
|
||||
protocol_version=agent_card.protocol_version,
|
||||
provider=agent_card_dict.get("provider"),
|
||||
skill_id=skill_id,
|
||||
metadata=metadata,
|
||||
extensions=list(extensions.keys()) if extensions else None,
|
||||
from_task=from_task,
|
||||
from_agent=from_agent,
|
||||
),
|
||||
)
|
||||
|
||||
if turn_number == 1:
|
||||
agent_id_for_event = agent_id or endpoint
|
||||
crewai_event_bus.emit(
|
||||
@@ -338,7 +383,17 @@ async def _aexecute_a2a_delegation_impl(
|
||||
A2AConversationStartedEvent(
|
||||
agent_id=agent_id_for_event,
|
||||
endpoint=endpoint,
|
||||
context_id=context_id,
|
||||
a2a_agent_name=a2a_agent_name,
|
||||
agent_card=agent_card_dict,
|
||||
protocol_version=agent_card.protocol_version,
|
||||
provider=agent_card_dict.get("provider"),
|
||||
skill_id=skill_id,
|
||||
reference_task_ids=reference_task_ids,
|
||||
metadata=metadata,
|
||||
extensions=list(extensions.keys()) if extensions else None,
|
||||
from_task=from_task,
|
||||
from_agent=from_agent,
|
||||
),
|
||||
)
|
||||
|
||||
@@ -364,6 +419,10 @@ async def _aexecute_a2a_delegation_impl(
|
||||
}
|
||||
)
|
||||
|
||||
message_metadata = metadata.copy() if metadata else {}
|
||||
if skill_id:
|
||||
message_metadata["skill_id"] = skill_id
|
||||
|
||||
message = Message(
|
||||
role=Role.user,
|
||||
message_id=str(uuid.uuid4()),
|
||||
@@ -371,7 +430,7 @@ async def _aexecute_a2a_delegation_impl(
|
||||
context_id=context_id,
|
||||
task_id=task_id,
|
||||
reference_task_ids=reference_task_ids,
|
||||
metadata=metadata,
|
||||
metadata=message_metadata if message_metadata else None,
|
||||
extensions=extensions,
|
||||
)
|
||||
|
||||
@@ -381,8 +440,17 @@ async def _aexecute_a2a_delegation_impl(
|
||||
A2AMessageSentEvent(
|
||||
message=message_text,
|
||||
turn_number=turn_number,
|
||||
context_id=context_id,
|
||||
message_id=message.message_id,
|
||||
is_multiturn=is_multiturn,
|
||||
agent_role=agent_role,
|
||||
endpoint=endpoint,
|
||||
a2a_agent_name=a2a_agent_name,
|
||||
skill_id=skill_id,
|
||||
metadata=message_metadata if message_metadata else None,
|
||||
extensions=list(extensions.keys()) if extensions else None,
|
||||
from_task=from_task,
|
||||
from_agent=from_agent,
|
||||
),
|
||||
)
|
||||
|
||||
@@ -397,6 +465,9 @@ async def _aexecute_a2a_delegation_impl(
|
||||
"task_id": task_id,
|
||||
"endpoint": endpoint,
|
||||
"agent_branch": agent_branch,
|
||||
"a2a_agent_name": a2a_agent_name,
|
||||
"from_task": from_task,
|
||||
"from_agent": from_agent,
|
||||
}
|
||||
|
||||
if isinstance(updates, PollingConfig):
|
||||
@@ -434,13 +505,16 @@ async def _aexecute_a2a_delegation_impl(
|
||||
use_polling=use_polling,
|
||||
push_notification_config=push_config_for_client,
|
||||
) as client:
|
||||
return await handler.execute(
|
||||
result = await handler.execute(
|
||||
client=client,
|
||||
message=message,
|
||||
new_messages=new_messages,
|
||||
agent_card=agent_card,
|
||||
**handler_kwargs,
|
||||
)
|
||||
result["a2a_agent_name"] = a2a_agent_name
|
||||
result["agent_card"] = agent_card.model_dump(exclude_none=True)
|
||||
return result
|
||||
|
||||
|
||||
@asynccontextmanager
|
||||
|
||||
@@ -3,11 +3,14 @@
from __future__ import annotations

import asyncio
import base64
from collections.abc import Callable, Coroutine
from datetime import datetime
from functools import wraps
import logging
import os
from typing import TYPE_CHECKING, Any, ParamSpec, TypeVar, cast
from urllib.parse import urlparse

from a2a.server.agent_execution import RequestContext
from a2a.server.events import EventQueue
@@ -45,7 +48,14 @@ T = TypeVar("T")

def _parse_redis_url(url: str) -> dict[str, Any]:
from urllib.parse import urlparse
"""Parse a Redis URL into aiocache configuration.

Args:
url: Redis connection URL (e.g., redis://localhost:6379/0).

Returns:
Configuration dict for aiocache.RedisCache.
"""

parsed = urlparse(url)
config: dict[str, Any] = {
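Aside, between hunks: a minimal sketch of how a helper like _parse_redis_url could map a URL onto a flat config dict. The key names (endpoint, port, db, password) are assumptions for illustration only; the exact keys aiocache.RedisCache expects are not shown in this diff.

from urllib.parse import urlparse

def parse_redis_url_sketch(url: str) -> dict:
    # Illustrative only: split redis://:secret@host:6380/2 into its parts.
    parsed = urlparse(url)
    return {
        "endpoint": parsed.hostname or "localhost",  # assumed key name
        "port": parsed.port or 6379,                 # assumed key name
        "db": int(parsed.path.lstrip("/") or 0),     # "/2" -> database 2
        "password": parsed.password,                 # None when absent
    }

# parse_redis_url_sketch("redis://:secret@cache.internal:6380/2")
# -> {"endpoint": "cache.internal", "port": 6380, "db": 2, "password": "secret"}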
@@ -127,7 +137,7 @@ def cancellable(
|
||||
async for message in pubsub.listen():
|
||||
if message["type"] == "message":
|
||||
return True
|
||||
except Exception as e:
|
||||
except (OSError, ConnectionError) as e:
|
||||
logger.warning("Cancel watcher error for task_id=%s: %s", task_id, e)
|
||||
return await poll_for_cancel()
|
||||
return False
|
||||
@@ -183,7 +193,12 @@ async def execute(
|
||||
msg = "task_id and context_id are required"
|
||||
crewai_event_bus.emit(
|
||||
agent,
|
||||
A2AServerTaskFailedEvent(a2a_task_id="", a2a_context_id="", error=msg),
|
||||
A2AServerTaskFailedEvent(
|
||||
task_id="",
|
||||
context_id="",
|
||||
error=msg,
|
||||
from_agent=agent,
|
||||
),
|
||||
)
|
||||
raise ServerError(InvalidParamsError(message=msg)) from None
|
||||
|
||||
@@ -195,7 +210,12 @@ async def execute(
|
||||
|
||||
crewai_event_bus.emit(
|
||||
agent,
|
||||
A2AServerTaskStartedEvent(a2a_task_id=task_id, a2a_context_id=context_id),
|
||||
A2AServerTaskStartedEvent(
|
||||
task_id=task_id,
|
||||
context_id=context_id,
|
||||
from_task=task,
|
||||
from_agent=agent,
|
||||
),
|
||||
)
|
||||
|
||||
try:
|
||||
@@ -215,20 +235,33 @@ async def execute(
|
||||
crewai_event_bus.emit(
|
||||
agent,
|
||||
A2AServerTaskCompletedEvent(
|
||||
a2a_task_id=task_id, a2a_context_id=context_id, result=str(result)
|
||||
task_id=task_id,
|
||||
context_id=context_id,
|
||||
result=str(result),
|
||||
from_task=task,
|
||||
from_agent=agent,
|
||||
),
|
||||
)
|
||||
except asyncio.CancelledError:
|
||||
crewai_event_bus.emit(
|
||||
agent,
|
||||
A2AServerTaskCanceledEvent(a2a_task_id=task_id, a2a_context_id=context_id),
|
||||
A2AServerTaskCanceledEvent(
|
||||
task_id=task_id,
|
||||
context_id=context_id,
|
||||
from_task=task,
|
||||
from_agent=agent,
|
||||
),
|
||||
)
|
||||
raise
|
||||
except Exception as e:
|
||||
crewai_event_bus.emit(
|
||||
agent,
|
||||
A2AServerTaskFailedEvent(
|
||||
a2a_task_id=task_id, a2a_context_id=context_id, error=str(e)
|
||||
task_id=task_id,
|
||||
context_id=context_id,
|
||||
error=str(e),
|
||||
from_task=task,
|
||||
from_agent=agent,
|
||||
),
|
||||
)
|
||||
raise ServerError(
|
||||
@@ -282,3 +315,85 @@ async def cancel(
|
||||
context.current_task.status = TaskStatus(state=TaskState.canceled)
|
||||
return context.current_task
|
||||
return None
|
||||
|
||||
|
||||
def list_tasks(
tasks: list[A2ATask],
context_id: str | None = None,
status: TaskState | None = None,
status_timestamp_after: datetime | None = None,
page_size: int = 50,
page_token: str | None = None,
history_length: int | None = None,
include_artifacts: bool = False,
) -> tuple[list[A2ATask], str | None, int]:
"""Filter and paginate A2A tasks.

Provides filtering by context, status, and timestamp, along with
cursor-based pagination. This is a pure utility function that operates
on an in-memory list of tasks - storage retrieval is handled separately.

Args:
tasks: All tasks to filter.
context_id: Filter by context ID to get tasks in a conversation.
status: Filter by task state (e.g., completed, working).
status_timestamp_after: Filter to tasks updated after this time.
page_size: Maximum tasks per page (default 50).
page_token: Base64-encoded cursor from previous response.
history_length: Limit history messages per task (None = full history).
include_artifacts: Whether to include task artifacts (default False).

Returns:
Tuple of (filtered_tasks, next_page_token, total_count).
- filtered_tasks: Tasks matching filters, paginated and trimmed.
- next_page_token: Token for next page, or None if no more pages.
- total_count: Total number of tasks matching filters (before pagination).
"""
filtered: list[A2ATask] = []
for task in tasks:
if context_id and task.context_id != context_id:
continue
if status and task.status.state != status:
continue
if status_timestamp_after and task.status.timestamp:
ts = datetime.fromisoformat(task.status.timestamp.replace("Z", "+00:00"))
if ts <= status_timestamp_after:
continue
filtered.append(task)

def get_timestamp(t: A2ATask) -> datetime:
"""Extract timestamp from task status for sorting."""
if t.status.timestamp is None:
return datetime.min
return datetime.fromisoformat(t.status.timestamp.replace("Z", "+00:00"))

filtered.sort(key=get_timestamp, reverse=True)
total = len(filtered)

start = 0
if page_token:
try:
cursor_id = base64.b64decode(page_token).decode()
for idx, task in enumerate(filtered):
if task.id == cursor_id:
start = idx + 1
break
except (ValueError, UnicodeDecodeError):
pass

page = filtered[start : start + page_size]

result: list[A2ATask] = []
for task in page:
task = task.model_copy(deep=True)
if history_length is not None and task.history:
task.history = task.history[-history_length:]
if not include_artifacts:
task.artifacts = None
result.append(task)

next_token: str | None = None
if result and len(result) == page_size:
next_token = base64.b64encode(result[-1].id.encode()).decode()

return result, next_token, total

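A short usage sketch of the pagination contract above, assuming all_tasks is a list of A2ATask objects already loaded from storage and TaskState is imported; the context id is hypothetical.

page_token = None
collected = []
while True:
    # Each call returns one page plus the cursor for the next one.
    page, page_token, total = list_tasks(
        all_tasks,
        context_id="ctx-123",          # hypothetical conversation id
        status=TaskState.completed,
        page_size=20,
        page_token=page_token,
        history_length=5,              # keep only the last 5 messages per task
    )
    collected.extend(page)
    if page_token is None:             # no more pages
        break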
@@ -6,9 +6,10 @@ Wraps agent classes with A2A delegation capabilities.
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
from collections.abc import Callable, Coroutine
|
||||
from collections.abc import Callable, Coroutine, Mapping
|
||||
from concurrent.futures import ThreadPoolExecutor, as_completed
|
||||
from functools import wraps
|
||||
import json
|
||||
from types import MethodType
|
||||
from typing import TYPE_CHECKING, Any
|
||||
|
||||
@@ -189,7 +190,7 @@ def _execute_task_with_a2a(
|
||||
a2a_agents: list[A2AConfig | A2AClientConfig],
|
||||
original_fn: Callable[..., str],
|
||||
task: Task,
|
||||
agent_response_model: type[BaseModel],
|
||||
agent_response_model: type[BaseModel] | None,
|
||||
context: str | None,
|
||||
tools: list[BaseTool] | None,
|
||||
extension_registry: ExtensionRegistry,
|
||||
@@ -277,7 +278,7 @@ def _execute_task_with_a2a(
|
||||
def _augment_prompt_with_a2a(
|
||||
a2a_agents: list[A2AConfig | A2AClientConfig],
|
||||
task_description: str,
|
||||
agent_cards: dict[str, AgentCard],
|
||||
agent_cards: Mapping[str, AgentCard | dict[str, Any]],
|
||||
conversation_history: list[Message] | None = None,
|
||||
turn_num: int = 0,
|
||||
max_turns: int | None = None,
|
||||
@@ -309,7 +310,15 @@ def _augment_prompt_with_a2a(
|
||||
for config in a2a_agents:
|
||||
if config.endpoint in agent_cards:
|
||||
card = agent_cards[config.endpoint]
|
||||
agents_text += f"\n{card.model_dump_json(indent=2, exclude_none=True, include={'description', 'url', 'skills'})}\n"
|
||||
if isinstance(card, dict):
|
||||
filtered = {
|
||||
k: v
|
||||
for k, v in card.items()
|
||||
if k in {"description", "url", "skills"} and v is not None
|
||||
}
|
||||
agents_text += f"\n{json.dumps(filtered, indent=2)}\n"
|
||||
else:
|
||||
agents_text += f"\n{card.model_dump_json(indent=2, exclude_none=True, include={'description', 'url', 'skills'})}\n"
|
||||
|
||||
failed_agents = failed_agents or {}
|
||||
if failed_agents:
|
||||
@@ -377,7 +386,7 @@ IMPORTANT: You have the ability to delegate this task to remote A2A agents.
|
||||
|
||||
|
||||
def _parse_agent_response(
|
||||
raw_result: str | dict[str, Any], agent_response_model: type[BaseModel]
|
||||
raw_result: str | dict[str, Any], agent_response_model: type[BaseModel] | None
|
||||
) -> BaseModel | str | dict[str, Any]:
|
||||
"""Parse LLM output as AgentResponse or return raw agent response."""
|
||||
if agent_response_model:
|
||||
@@ -394,6 +403,11 @@ def _parse_agent_response(
|
||||
def _handle_max_turns_exceeded(
|
||||
conversation_history: list[Message],
|
||||
max_turns: int,
|
||||
from_task: Any | None = None,
|
||||
from_agent: Any | None = None,
|
||||
endpoint: str | None = None,
|
||||
a2a_agent_name: str | None = None,
|
||||
agent_card: dict[str, Any] | None = None,
|
||||
) -> str:
|
||||
"""Handle the case when max turns is exceeded.
|
||||
|
||||
@@ -421,6 +435,11 @@ def _handle_max_turns_exceeded(
|
||||
final_result=final_message,
|
||||
error=None,
|
||||
total_turns=max_turns,
|
||||
from_task=from_task,
|
||||
from_agent=from_agent,
|
||||
endpoint=endpoint,
|
||||
a2a_agent_name=a2a_agent_name,
|
||||
agent_card=agent_card,
|
||||
),
|
||||
)
|
||||
return final_message
|
||||
@@ -432,6 +451,11 @@ def _handle_max_turns_exceeded(
|
||||
final_result=None,
|
||||
error=f"Conversation exceeded maximum turns ({max_turns})",
|
||||
total_turns=max_turns,
|
||||
from_task=from_task,
|
||||
from_agent=from_agent,
|
||||
endpoint=endpoint,
|
||||
a2a_agent_name=a2a_agent_name,
|
||||
agent_card=agent_card,
|
||||
),
|
||||
)
|
||||
raise Exception(f"A2A conversation exceeded maximum turns ({max_turns})")
|
||||
@@ -442,7 +466,12 @@ def _process_response_result(
|
||||
disable_structured_output: bool,
|
||||
turn_num: int,
|
||||
agent_role: str,
|
||||
agent_response_model: type[BaseModel],
|
||||
agent_response_model: type[BaseModel] | None,
|
||||
from_task: Any | None = None,
|
||||
from_agent: Any | None = None,
|
||||
endpoint: str | None = None,
|
||||
a2a_agent_name: str | None = None,
|
||||
agent_card: dict[str, Any] | None = None,
|
||||
) -> tuple[str | None, str | None]:
|
||||
"""Process LLM response and determine next action.
|
||||
|
||||
@@ -461,6 +490,10 @@ def _process_response_result(
|
||||
turn_number=final_turn_number,
|
||||
is_multiturn=True,
|
||||
agent_role=agent_role,
|
||||
from_task=from_task,
|
||||
from_agent=from_agent,
|
||||
endpoint=endpoint,
|
||||
a2a_agent_name=a2a_agent_name,
|
||||
),
|
||||
)
|
||||
crewai_event_bus.emit(
|
||||
@@ -470,6 +503,11 @@ def _process_response_result(
|
||||
final_result=result_text,
|
||||
error=None,
|
||||
total_turns=final_turn_number,
|
||||
from_task=from_task,
|
||||
from_agent=from_agent,
|
||||
endpoint=endpoint,
|
||||
a2a_agent_name=a2a_agent_name,
|
||||
agent_card=agent_card,
|
||||
),
|
||||
)
|
||||
return result_text, None
|
||||
@@ -490,6 +528,10 @@ def _process_response_result(
|
||||
turn_number=final_turn_number,
|
||||
is_multiturn=True,
|
||||
agent_role=agent_role,
|
||||
from_task=from_task,
|
||||
from_agent=from_agent,
|
||||
endpoint=endpoint,
|
||||
a2a_agent_name=a2a_agent_name,
|
||||
),
|
||||
)
|
||||
crewai_event_bus.emit(
|
||||
@@ -499,6 +541,11 @@ def _process_response_result(
|
||||
final_result=str(llm_response.message),
|
||||
error=None,
|
||||
total_turns=final_turn_number,
|
||||
from_task=from_task,
|
||||
from_agent=from_agent,
|
||||
endpoint=endpoint,
|
||||
a2a_agent_name=a2a_agent_name,
|
||||
agent_card=agent_card,
|
||||
),
|
||||
)
|
||||
return str(llm_response.message), None
|
||||
@@ -510,13 +557,15 @@ def _process_response_result(
|
||||
def _prepare_agent_cards_dict(
|
||||
a2a_result: TaskStateResult,
|
||||
agent_id: str,
|
||||
agent_cards: dict[str, AgentCard] | None,
|
||||
) -> dict[str, AgentCard]:
|
||||
agent_cards: Mapping[str, AgentCard | dict[str, Any]] | None,
|
||||
) -> dict[str, AgentCard | dict[str, Any]]:
|
||||
"""Prepare agent cards dictionary from result and existing cards.
|
||||
|
||||
Shared logic for both sync and async response handlers.
|
||||
"""
|
||||
agent_cards_dict = agent_cards or {}
|
||||
agent_cards_dict: dict[str, AgentCard | dict[str, Any]] = (
|
||||
dict(agent_cards) if agent_cards else {}
|
||||
)
|
||||
if "agent_card" in a2a_result and agent_id not in agent_cards_dict:
|
||||
agent_cards_dict[agent_id] = a2a_result["agent_card"]
|
||||
return agent_cards_dict
|
||||
@@ -529,7 +578,7 @@ def _prepare_delegation_context(
|
||||
original_task_description: str | None,
|
||||
) -> tuple[
|
||||
list[A2AConfig | A2AClientConfig],
|
||||
type[BaseModel],
|
||||
type[BaseModel] | None,
|
||||
str,
|
||||
str,
|
||||
A2AConfig | A2AClientConfig,
|
||||
@@ -598,6 +647,11 @@ def _handle_task_completion(
|
||||
reference_task_ids: list[str],
|
||||
agent_config: A2AConfig | A2AClientConfig,
|
||||
turn_num: int,
|
||||
from_task: Any | None = None,
|
||||
from_agent: Any | None = None,
|
||||
endpoint: str | None = None,
|
||||
a2a_agent_name: str | None = None,
|
||||
agent_card: dict[str, Any] | None = None,
|
||||
) -> tuple[str | None, str | None, list[str]]:
|
||||
"""Handle task completion state including reference task updates.
|
||||
|
||||
@@ -624,6 +678,11 @@ def _handle_task_completion(
|
||||
final_result=result_text,
|
||||
error=None,
|
||||
total_turns=final_turn_number,
|
||||
from_task=from_task,
|
||||
from_agent=from_agent,
|
||||
endpoint=endpoint,
|
||||
a2a_agent_name=a2a_agent_name,
|
||||
agent_card=agent_card,
|
||||
),
|
||||
)
|
||||
return str(result_text), task_id_config, reference_task_ids
|
||||
@@ -645,8 +704,11 @@ def _handle_agent_response_and_continue(
|
||||
original_fn: Callable[..., str],
|
||||
context: str | None,
|
||||
tools: list[BaseTool] | None,
|
||||
agent_response_model: type[BaseModel],
|
||||
agent_response_model: type[BaseModel] | None,
|
||||
remote_task_completed: bool = False,
|
||||
endpoint: str | None = None,
|
||||
a2a_agent_name: str | None = None,
|
||||
agent_card: dict[str, Any] | None = None,
|
||||
) -> tuple[str | None, str | None]:
|
||||
"""Handle A2A result and get CrewAI agent's response.
|
||||
|
||||
@@ -698,6 +760,11 @@ def _handle_agent_response_and_continue(
|
||||
turn_num=turn_num,
|
||||
agent_role=self.role,
|
||||
agent_response_model=agent_response_model,
|
||||
from_task=task,
|
||||
from_agent=self,
|
||||
endpoint=endpoint,
|
||||
a2a_agent_name=a2a_agent_name,
|
||||
agent_card=agent_card,
|
||||
)
|
||||
|
||||
|
||||
@@ -750,6 +817,12 @@ def _delegate_to_a2a(
|
||||
|
||||
conversation_history: list[Message] = []
|
||||
|
||||
current_agent_card = agent_cards.get(agent_id) if agent_cards else None
|
||||
current_agent_card_dict = (
|
||||
current_agent_card.model_dump() if current_agent_card else None
|
||||
)
|
||||
current_a2a_agent_name = current_agent_card.name if current_agent_card else None
|
||||
|
||||
try:
|
||||
for turn_num in range(max_turns):
|
||||
console_formatter = getattr(crewai_event_bus, "_console", None)
|
||||
@@ -777,6 +850,8 @@ def _delegate_to_a2a(
|
||||
turn_number=turn_num + 1,
|
||||
updates=agent_config.updates,
|
||||
transport_protocol=agent_config.transport_protocol,
|
||||
from_task=task,
|
||||
from_agent=self,
|
||||
)
|
||||
|
||||
conversation_history = a2a_result.get("history", [])
|
||||
@@ -797,6 +872,11 @@ def _delegate_to_a2a(
|
||||
reference_task_ids,
|
||||
agent_config,
|
||||
turn_num,
|
||||
from_task=task,
|
||||
from_agent=self,
|
||||
endpoint=agent_config.endpoint,
|
||||
a2a_agent_name=current_a2a_agent_name,
|
||||
agent_card=current_agent_card_dict,
|
||||
)
|
||||
)
|
||||
if trusted_result is not None:
|
||||
@@ -818,6 +898,9 @@ def _delegate_to_a2a(
|
||||
tools=tools,
|
||||
agent_response_model=agent_response_model,
|
||||
remote_task_completed=(a2a_result["status"] == TaskState.completed),
|
||||
endpoint=agent_config.endpoint,
|
||||
a2a_agent_name=current_a2a_agent_name,
|
||||
agent_card=current_agent_card_dict,
|
||||
)
|
||||
|
||||
if final_result is not None:
|
||||
@@ -846,6 +929,9 @@ def _delegate_to_a2a(
|
||||
tools=tools,
|
||||
agent_response_model=agent_response_model,
|
||||
remote_task_completed=False,
|
||||
endpoint=agent_config.endpoint,
|
||||
a2a_agent_name=current_a2a_agent_name,
|
||||
agent_card=current_agent_card_dict,
|
||||
)
|
||||
|
||||
if final_result is not None:
|
||||
@@ -862,11 +948,24 @@ def _delegate_to_a2a(
|
||||
final_result=None,
|
||||
error=error_msg,
|
||||
total_turns=turn_num + 1,
|
||||
from_task=task,
|
||||
from_agent=self,
|
||||
endpoint=agent_config.endpoint,
|
||||
a2a_agent_name=current_a2a_agent_name,
|
||||
agent_card=current_agent_card_dict,
|
||||
),
|
||||
)
|
||||
return f"A2A delegation failed: {error_msg}"
|
||||
|
||||
return _handle_max_turns_exceeded(conversation_history, max_turns)
|
||||
return _handle_max_turns_exceeded(
|
||||
conversation_history,
|
||||
max_turns,
|
||||
from_task=task,
|
||||
from_agent=self,
|
||||
endpoint=agent_config.endpoint,
|
||||
a2a_agent_name=current_a2a_agent_name,
|
||||
agent_card=current_agent_card_dict,
|
||||
)
|
||||
|
||||
finally:
|
||||
task.description = original_task_description
|
||||
@@ -916,7 +1015,7 @@ async def _aexecute_task_with_a2a(
|
||||
a2a_agents: list[A2AConfig | A2AClientConfig],
|
||||
original_fn: Callable[..., Coroutine[Any, Any, str]],
|
||||
task: Task,
|
||||
agent_response_model: type[BaseModel],
|
||||
agent_response_model: type[BaseModel] | None,
|
||||
context: str | None,
|
||||
tools: list[BaseTool] | None,
|
||||
extension_registry: ExtensionRegistry,
|
||||
@@ -1001,8 +1100,11 @@ async def _ahandle_agent_response_and_continue(
|
||||
original_fn: Callable[..., Coroutine[Any, Any, str]],
|
||||
context: str | None,
|
||||
tools: list[BaseTool] | None,
|
||||
agent_response_model: type[BaseModel],
|
||||
agent_response_model: type[BaseModel] | None,
|
||||
remote_task_completed: bool = False,
|
||||
endpoint: str | None = None,
|
||||
a2a_agent_name: str | None = None,
|
||||
agent_card: dict[str, Any] | None = None,
|
||||
) -> tuple[str | None, str | None]:
|
||||
"""Async version of _handle_agent_response_and_continue."""
|
||||
agent_cards_dict = _prepare_agent_cards_dict(a2a_result, agent_id, agent_cards)
|
||||
@@ -1032,6 +1134,11 @@ async def _ahandle_agent_response_and_continue(
|
||||
turn_num=turn_num,
|
||||
agent_role=self.role,
|
||||
agent_response_model=agent_response_model,
|
||||
from_task=task,
|
||||
from_agent=self,
|
||||
endpoint=endpoint,
|
||||
a2a_agent_name=a2a_agent_name,
|
||||
agent_card=agent_card,
|
||||
)
|
||||
|
||||
|
||||
@@ -1066,6 +1173,12 @@ async def _adelegate_to_a2a(
|
||||
|
||||
conversation_history: list[Message] = []
|
||||
|
||||
current_agent_card = agent_cards.get(agent_id) if agent_cards else None
|
||||
current_agent_card_dict = (
|
||||
current_agent_card.model_dump() if current_agent_card else None
|
||||
)
|
||||
current_a2a_agent_name = current_agent_card.name if current_agent_card else None
|
||||
|
||||
try:
|
||||
for turn_num in range(max_turns):
|
||||
console_formatter = getattr(crewai_event_bus, "_console", None)
|
||||
@@ -1093,6 +1206,8 @@ async def _adelegate_to_a2a(
|
||||
turn_number=turn_num + 1,
|
||||
transport_protocol=agent_config.transport_protocol,
|
||||
updates=agent_config.updates,
|
||||
from_task=task,
|
||||
from_agent=self,
|
||||
)
|
||||
|
||||
conversation_history = a2a_result.get("history", [])
|
||||
@@ -1113,6 +1228,11 @@ async def _adelegate_to_a2a(
|
||||
reference_task_ids,
|
||||
agent_config,
|
||||
turn_num,
|
||||
from_task=task,
|
||||
from_agent=self,
|
||||
endpoint=agent_config.endpoint,
|
||||
a2a_agent_name=current_a2a_agent_name,
|
||||
agent_card=current_agent_card_dict,
|
||||
)
|
||||
)
|
||||
if trusted_result is not None:
|
||||
@@ -1134,6 +1254,9 @@ async def _adelegate_to_a2a(
|
||||
tools=tools,
|
||||
agent_response_model=agent_response_model,
|
||||
remote_task_completed=(a2a_result["status"] == TaskState.completed),
|
||||
endpoint=agent_config.endpoint,
|
||||
a2a_agent_name=current_a2a_agent_name,
|
||||
agent_card=current_agent_card_dict,
|
||||
)
|
||||
|
||||
if final_result is not None:
|
||||
@@ -1161,6 +1284,9 @@ async def _adelegate_to_a2a(
|
||||
context=context,
|
||||
tools=tools,
|
||||
agent_response_model=agent_response_model,
|
||||
endpoint=agent_config.endpoint,
|
||||
a2a_agent_name=current_a2a_agent_name,
|
||||
agent_card=current_agent_card_dict,
|
||||
)
|
||||
|
||||
if final_result is not None:
|
||||
@@ -1177,11 +1303,24 @@ async def _adelegate_to_a2a(
|
||||
final_result=None,
|
||||
error=error_msg,
|
||||
total_turns=turn_num + 1,
|
||||
from_task=task,
|
||||
from_agent=self,
|
||||
endpoint=agent_config.endpoint,
|
||||
a2a_agent_name=current_a2a_agent_name,
|
||||
agent_card=current_agent_card_dict,
|
||||
),
|
||||
)
|
||||
return f"A2A delegation failed: {error_msg}"
|
||||
|
||||
return _handle_max_turns_exceeded(conversation_history, max_turns)
|
||||
return _handle_max_turns_exceeded(
|
||||
conversation_history,
|
||||
max_turns,
|
||||
from_task=task,
|
||||
from_agent=self,
|
||||
endpoint=agent_config.endpoint,
|
||||
a2a_agent_name=current_a2a_agent_name,
|
||||
agent_card=current_agent_card_dict,
|
||||
)
|
||||
|
||||
finally:
|
||||
task.description = original_task_description
|
||||
|
||||
@@ -725,17 +725,9 @@ class Agent(BaseAgent):
|
||||
raw_tools: list[BaseTool] = tools or self.tools or []
|
||||
parsed_tools = parse_tools(raw_tools)
|
||||
|
||||
use_native_tool_calling = (
|
||||
hasattr(self.llm, "supports_function_calling")
|
||||
and callable(getattr(self.llm, "supports_function_calling", None))
|
||||
and self.llm.supports_function_calling()
|
||||
and len(raw_tools) > 0
|
||||
)
|
||||
|
||||
prompt = Prompts(
|
||||
agent=self,
|
||||
has_tools=len(raw_tools) > 0,
|
||||
use_native_tool_calling=use_native_tool_calling,
|
||||
i18n=self.i18n,
|
||||
use_system_prompt=self.use_system_prompt,
|
||||
system_template=self.system_template,
|
||||
@@ -743,8 +735,6 @@ class Agent(BaseAgent):
|
||||
response_template=self.response_template,
|
||||
).task_execution()
|
||||
|
||||
print("prompt", prompt)
|
||||
|
||||
stop_words = [self.i18n.slice("observation")]
|
||||
|
||||
if self.response_template:
|
||||
|
||||
@@ -236,30 +236,14 @@ def process_tool_results(agent: Agent, result: Any) -> Any:
def save_last_messages(agent: Agent) -> None:
"""Save the last messages from agent executor.

Sanitizes messages to be compatible with TaskOutput's LLMMessage type,
which only accepts 'user', 'assistant', 'system' roles and requires
content to be a string or list (not None).

Args:
agent: The agent instance.
"""
if not agent.agent_executor or not hasattr(agent.agent_executor, "messages"):
agent._last_messages = []
return

sanitized_messages = []
for msg in agent.agent_executor.messages:
role = msg.get("role", "")
# Only include messages with valid LLMMessage roles
if role not in ("user", "assistant", "system"):
continue
# Ensure content is not None (can happen with tool call assistant messages)
content = msg.get("content")
if content is None:
content = ""
sanitized_messages.append({"role": role, "content": content})

agent._last_messages = sanitized_messages
agent._last_messages = (
agent.agent_executor.messages.copy()
if agent.agent_executor and hasattr(agent.agent_executor, "messages")
else []
)

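To make the behavioral change concrete, a small sketch with hypothetical messages (not taken from the codebase): the sanitization removed above dropped non-LLMMessage roles and replaced None content, whereas the new version copies the executor history verbatim, tool messages included.

raw = [
    {"role": "assistant", "content": None, "tool_calls": [{"id": "call_1"}]},
    {"role": "tool", "tool_call_id": "call_1", "content": "42"},
    {"role": "assistant", "content": "The answer is 42."},
]

# Old behavior (removed in this diff): keep user/assistant/system only, None -> "".
sanitized = []
for m in raw:
    if m.get("role") not in ("user", "assistant", "system"):
        continue
    content = m.get("content")
    sanitized.append({"role": m["role"], "content": "" if content is None else content})
# sanitized == [{"role": "assistant", "content": ""},
#               {"role": "assistant", "content": "The answer is 42."}]

# New behavior: agent._last_messages = raw.copy(), i.e. all three entries as-is.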
def prepare_tools(
|
||||
|
||||
@@ -30,7 +30,6 @@ from crewai.hooks.llm_hooks import (
|
||||
)
|
||||
from crewai.utilities.agent_utils import (
|
||||
aget_llm_response,
|
||||
convert_tools_to_openai_schema,
|
||||
enforce_rpm_limit,
|
||||
format_message_for_llm,
|
||||
get_llm_response,
|
||||
@@ -216,33 +215,6 @@ class CrewAgentExecutor(CrewAgentExecutorMixin):
|
||||
def _invoke_loop(self) -> AgentFinish:
|
||||
"""Execute agent loop until completion.
|
||||
|
||||
Checks if the LLM supports native function calling and uses that
|
||||
approach if available, otherwise falls back to the ReAct text pattern.
|
||||
|
||||
Returns:
|
||||
Final answer from the agent.
|
||||
"""
|
||||
# Check if model supports native function calling
|
||||
use_native_tools = (
|
||||
hasattr(self.llm, "supports_function_calling")
|
||||
and callable(getattr(self.llm, "supports_function_calling", None))
|
||||
and self.llm.supports_function_calling()
|
||||
and self.original_tools
|
||||
)
|
||||
|
||||
if use_native_tools:
|
||||
return self._invoke_loop_native_tools()
|
||||
|
||||
# Fall back to ReAct text-based pattern
|
||||
return self._invoke_loop_react()
|
||||
|
||||
def _invoke_loop_react(self) -> AgentFinish:
|
||||
"""Execute agent loop using ReAct text-based pattern.
|
||||
|
||||
This is the traditional approach where tool definitions are embedded
|
||||
in the prompt and the LLM outputs Action/Action Input text that is
|
||||
parsed to execute tools.
|
||||
|
||||
Returns:
|
||||
Final answer from the agent.
|
||||
"""
|
||||
@@ -272,10 +244,6 @@ class CrewAgentExecutor(CrewAgentExecutorMixin):
|
||||
response_model=self.response_model,
|
||||
executor_context=self,
|
||||
)
|
||||
print("--------------------------------")
|
||||
print("get_llm_response answer", answer)
|
||||
print("--------------------------------")
|
||||
# breakpoint()
|
||||
if self.response_model is not None:
|
||||
try:
|
||||
self.response_model.model_validate_json(answer)
|
||||
@@ -365,338 +333,6 @@ class CrewAgentExecutor(CrewAgentExecutorMixin):
|
||||
self._show_logs(formatted_answer)
|
||||
return formatted_answer
|
||||
|
||||
def _invoke_loop_native_tools(self) -> AgentFinish:
|
||||
"""Execute agent loop using native function calling.
|
||||
|
||||
This method uses the LLM's native tool/function calling capability
|
||||
instead of the text-based ReAct pattern. The LLM directly returns
|
||||
structured tool calls which are executed and results fed back.
|
||||
|
||||
Returns:
|
||||
Final answer from the agent.
|
||||
"""
|
||||
print("--------------------------------")
|
||||
print("invoke_loop_native_tools")
|
||||
print("--------------------------------")
|
||||
# Convert tools to OpenAI schema format
|
||||
if not self.original_tools:
|
||||
# No tools available, fall back to simple LLM call
|
||||
return self._invoke_loop_native_no_tools()
|
||||
|
||||
openai_tools, available_functions = convert_tools_to_openai_schema(
|
||||
self.original_tools
|
||||
)
|
||||
|
||||
while True:
|
||||
try:
|
||||
if has_reached_max_iterations(self.iterations, self.max_iter):
|
||||
formatted_answer = handle_max_iterations_exceeded(
|
||||
None,
|
||||
printer=self._printer,
|
||||
i18n=self._i18n,
|
||||
messages=self.messages,
|
||||
llm=self.llm,
|
||||
callbacks=self.callbacks,
|
||||
)
|
||||
self._show_logs(formatted_answer)
|
||||
return formatted_answer
|
||||
|
||||
enforce_rpm_limit(self.request_within_rpm_limit)
|
||||
|
||||
# Debug: Show messages being sent to LLM
|
||||
print("--------------------------------")
|
||||
print(f"Messages count: {len(self.messages)}")
|
||||
for i, msg in enumerate(self.messages):
|
||||
role = msg.get("role", "unknown")
|
||||
content = msg.get("content", "")
|
||||
if content:
|
||||
preview = (
|
||||
content[:200] + "..." if len(content) > 200 else content
|
||||
)
|
||||
else:
|
||||
preview = "(no content)"
|
||||
print(f" [{i}] {role}: {preview}")
|
||||
print("--------------------------------")
|
||||
|
||||
# Call LLM with native tools
|
||||
# Pass available_functions=None so the LLM returns tool_calls
|
||||
# without executing them. The executor handles tool execution
|
||||
# via _handle_native_tool_calls to properly manage message history.
|
||||
answer = get_llm_response(
|
||||
llm=self.llm,
|
||||
messages=self.messages,
|
||||
callbacks=self.callbacks,
|
||||
printer=self._printer,
|
||||
tools=openai_tools,
|
||||
available_functions=None,
|
||||
from_task=self.task,
|
||||
from_agent=self.agent,
|
||||
response_model=self.response_model,
|
||||
executor_context=self,
|
||||
)
|
||||
print("--------------------------------")
|
||||
print("invoke_loop_native_tools answer", answer)
|
||||
print("--------------------------------")
|
||||
# print("get_llm_response answer", answer[:500] + "...")
|
||||
|
||||
# Check if the response is a list of tool calls
|
||||
if (
|
||||
isinstance(answer, list)
|
||||
and answer
|
||||
and self._is_tool_call_list(answer)
|
||||
):
|
||||
# Handle tool calls - execute tools and add results to messages
|
||||
self._handle_native_tool_calls(answer, available_functions)
|
||||
# Continue loop to let LLM analyze results and decide next steps
|
||||
continue
|
||||
|
||||
# Text or other response - handle as potential final answer
|
||||
if isinstance(answer, str):
|
||||
# Text response - this is the final answer
|
||||
formatted_answer = AgentFinish(
|
||||
thought="",
|
||||
output=answer,
|
||||
text=answer,
|
||||
)
|
||||
self._invoke_step_callback(formatted_answer)
|
||||
self._append_message(answer) # Save final answer to messages
|
||||
self._show_logs(formatted_answer)
|
||||
return formatted_answer
|
||||
|
||||
# Unexpected response type, treat as final answer
|
||||
formatted_answer = AgentFinish(
|
||||
thought="",
|
||||
output=str(answer),
|
||||
text=str(answer),
|
||||
)
|
||||
self._invoke_step_callback(formatted_answer)
|
||||
self._append_message(str(answer)) # Save final answer to messages
|
||||
self._show_logs(formatted_answer)
|
||||
return formatted_answer
|
||||
|
||||
except Exception as e:
|
||||
if e.__class__.__module__.startswith("litellm"):
|
||||
raise e
|
||||
if is_context_length_exceeded(e):
|
||||
handle_context_length(
|
||||
respect_context_window=self.respect_context_window,
|
||||
printer=self._printer,
|
||||
messages=self.messages,
|
||||
llm=self.llm,
|
||||
callbacks=self.callbacks,
|
||||
i18n=self._i18n,
|
||||
)
|
||||
continue
|
||||
handle_unknown_error(self._printer, e)
|
||||
raise e
|
||||
finally:
|
||||
self.iterations += 1
|
||||
|
||||
def _invoke_loop_native_no_tools(self) -> AgentFinish:
|
||||
"""Execute a simple LLM call when no tools are available.
|
||||
|
||||
Returns:
|
||||
Final answer from the agent.
|
||||
"""
|
||||
enforce_rpm_limit(self.request_within_rpm_limit)
|
||||
|
||||
answer = get_llm_response(
|
||||
llm=self.llm,
|
||||
messages=self.messages,
|
||||
callbacks=self.callbacks,
|
||||
printer=self._printer,
|
||||
from_task=self.task,
|
||||
from_agent=self.agent,
|
||||
response_model=self.response_model,
|
||||
executor_context=self,
|
||||
)
|
||||
|
||||
formatted_answer = AgentFinish(
|
||||
thought="",
|
||||
output=str(answer),
|
||||
text=str(answer),
|
||||
)
|
||||
self._show_logs(formatted_answer)
|
||||
return formatted_answer
|
||||
|
||||
def _is_tool_call_list(self, response: list[Any]) -> bool:
"""Check if a response is a list of tool calls.

Args:
response: The response to check.

Returns:
True if the response appears to be a list of tool calls.
"""
if not response:
return False
first_item = response[0]
# OpenAI-style
if hasattr(first_item, "function") or (
isinstance(first_item, dict) and "function" in first_item
):
return True
# Anthropic-style
if (
hasattr(first_item, "type")
and getattr(first_item, "type", None) == "tool_use"
):
return True
if hasattr(first_item, "name") and hasattr(first_item, "input"):
return True
# Gemini-style
if hasattr(first_item, "function_call") and first_item.function_call:
return True
return False

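For illustration, the kind of payloads the check above accepts; these literals are hypothetical stand-ins for provider responses, not fixtures from the repository.

# OpenAI-style tool calls carry a "function" entry (object attribute or dict key).
openai_like = [{"id": "call_1", "function": {"name": "search", "arguments": "{}"}}]

# Anthropic-style blocks expose .type == "tool_use" (or .name plus .input),
# Gemini-style responses expose .function_call; a plain string is not a tool call.
plain_text = ["The final answer is 42."]

# With an executor instance in scope:
# executor._is_tool_call_list(openai_like)  -> True
# executor._is_tool_call_list(plain_text)   -> False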
def _handle_native_tool_calls(
|
||||
self,
|
||||
tool_calls: list[Any],
|
||||
available_functions: dict[str, Callable[..., Any]],
|
||||
) -> None:
|
||||
"""Handle a single native tool call from the LLM.
|
||||
|
||||
Executes only the FIRST tool call and appends the result to message history.
|
||||
This enables sequential tool execution with reflection after each tool,
|
||||
allowing the LLM to reason about results before deciding on next steps.
|
||||
|
||||
Args:
|
||||
tool_calls: List of tool calls from the LLM (only first is processed).
|
||||
available_functions: Dict mapping function names to callables.
|
||||
"""
|
||||
from datetime import datetime
|
||||
import json
|
||||
|
||||
from crewai.events import crewai_event_bus
|
||||
from crewai.events.types.tool_usage_events import (
|
||||
ToolUsageFinishedEvent,
|
||||
ToolUsageStartedEvent,
|
||||
)
|
||||
|
||||
if not tool_calls:
|
||||
return
|
||||
|
||||
# Only process the FIRST tool call for sequential execution with reflection
|
||||
tool_call = tool_calls[0]
|
||||
|
||||
# Extract tool call info - handle OpenAI-style, Anthropic-style, and Gemini-style
|
||||
if hasattr(tool_call, "function"):
|
||||
# OpenAI-style: has .function.name and .function.arguments
|
||||
call_id = getattr(tool_call, "id", f"call_{id(tool_call)}")
|
||||
func_name = tool_call.function.name
|
||||
func_args = tool_call.function.arguments
|
||||
elif hasattr(tool_call, "function_call") and tool_call.function_call:
|
||||
# Gemini-style: has .function_call.name and .function_call.args
|
||||
call_id = f"call_{id(tool_call)}"
|
||||
func_name = tool_call.function_call.name
|
||||
func_args = (
|
||||
dict(tool_call.function_call.args)
|
||||
if tool_call.function_call.args
|
||||
else {}
|
||||
)
|
||||
elif hasattr(tool_call, "name") and hasattr(tool_call, "input"):
|
||||
# Anthropic format: has .name and .input (ToolUseBlock)
|
||||
call_id = getattr(tool_call, "id", f"call_{id(tool_call)}")
|
||||
func_name = tool_call.name
|
||||
func_args = tool_call.input # Already a dict in Anthropic
|
||||
elif isinstance(tool_call, dict):
|
||||
call_id = tool_call.get("id", f"call_{id(tool_call)}")
|
||||
func_info = tool_call.get("function", {})
|
||||
func_name = func_info.get("name", "") or tool_call.get("name", "")
|
||||
func_args = func_info.get("arguments", "{}") or tool_call.get("input", {})
|
||||
else:
|
||||
return
|
||||
|
||||
# Append assistant message with single tool call
|
||||
assistant_message: LLMMessage = {
|
||||
"role": "assistant",
|
||||
"content": None,
|
||||
"tool_calls": [
|
||||
{
|
||||
"id": call_id,
|
||||
"type": "function",
|
||||
"function": {
|
||||
"name": func_name,
|
||||
"arguments": func_args
|
||||
if isinstance(func_args, str)
|
||||
else json.dumps(func_args),
|
||||
},
|
||||
}
|
||||
],
|
||||
}
|
||||
|
||||
self.messages.append(assistant_message)
|
||||
|
||||
# Parse arguments for the single tool call
|
||||
if isinstance(func_args, str):
|
||||
try:
|
||||
args_dict = json.loads(func_args)
|
||||
except json.JSONDecodeError:
|
||||
args_dict = {}
|
||||
else:
|
||||
args_dict = func_args
|
||||
|
||||
# Emit tool usage started event
|
||||
started_at = datetime.now()
|
||||
crewai_event_bus.emit(
|
||||
self,
|
||||
event=ToolUsageStartedEvent(
|
||||
tool_name=func_name,
|
||||
tool_args=args_dict,
|
||||
from_agent=self.agent,
|
||||
from_task=self.task,
|
||||
),
|
||||
)
|
||||
|
||||
# Execute the tool
|
||||
print(f"Using Tool: {func_name}")
|
||||
result = "Tool not found"
|
||||
if func_name in available_functions:
|
||||
try:
|
||||
tool_func = available_functions[func_name]
|
||||
result = tool_func(**args_dict)
|
||||
if not isinstance(result, str):
|
||||
result = str(result)
|
||||
except Exception as e:
|
||||
result = f"Error executing tool: {e}"
|
||||
|
||||
# Emit tool usage finished event
|
||||
crewai_event_bus.emit(
|
||||
self,
|
||||
event=ToolUsageFinishedEvent(
|
||||
output=result,
|
||||
tool_name=func_name,
|
||||
tool_args=args_dict,
|
||||
from_agent=self.agent,
|
||||
from_task=self.task,
|
||||
started_at=started_at,
|
||||
finished_at=datetime.now(),
|
||||
),
|
||||
)
|
||||
|
||||
# Append tool result message
|
||||
tool_message: LLMMessage = {
|
||||
"role": "tool",
|
||||
"tool_call_id": call_id,
|
||||
"content": result,
|
||||
}
|
||||
self.messages.append(tool_message)
|
||||
|
||||
# Log the tool execution
|
||||
if self.agent and self.agent.verbose:
|
||||
self._printer.print(
|
||||
content=f"Tool {func_name} executed with result: {result[:200]}...",
|
||||
color="green",
|
||||
)
|
||||
|
||||
# Inject post-tool reasoning prompt to enforce analysis
|
||||
reasoning_prompt = self._i18n.slice("post_tool_reasoning")
|
||||
reasoning_message: LLMMessage = {
|
||||
"role": "user",
|
||||
"content": reasoning_prompt,
|
||||
}
|
||||
self.messages.append(reasoning_message)
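For orientation, after one pass through this handler the message history gains three entries: the assistant's single tool call, the matching tool result, and the injected reasoning nudge. A minimal sketch with hypothetical values (the reasoning text actually comes from the "post_tool_reasoning" i18n slice and is not reproduced here):

history_tail = [
    {
        "role": "assistant",
        "content": None,
        "tool_calls": [
            {
                "id": "call_1",
                "type": "function",
                "function": {"name": "search", "arguments": '{"q": "crewai"}'},
            }
        ],
    },
    {"role": "tool", "tool_call_id": "call_1", "content": "<tool output>"},
    {"role": "user", "content": "<post_tool_reasoning prompt>"},
]
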

    async def ainvoke(self, inputs: dict[str, Any]) -> dict[str, Any]:
        """Execute the agent asynchronously with given inputs.

@@ -746,29 +382,6 @@ class CrewAgentExecutor(CrewAgentExecutorMixin):
    async def _ainvoke_loop(self) -> AgentFinish:
        """Execute agent loop asynchronously until completion.

        Checks if the LLM supports native function calling and uses that
        approach if available, otherwise falls back to the ReAct text pattern.

        Returns:
            Final answer from the agent.
        """
        # Check if model supports native function calling
        use_native_tools = (
            hasattr(self.llm, "supports_function_calling")
            and callable(getattr(self.llm, "supports_function_calling", None))
            and self.llm.supports_function_calling()
            and self.original_tools
        )

        if use_native_tools:
            return await self._ainvoke_loop_native_tools()

        # Fall back to ReAct text-based pattern
        return await self._ainvoke_loop_react()

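The capability probe above appears verbatim in both the sync and async paths; pulled out on its own it is just defensive duck-typing. A standalone sketch (the helper name below is illustrative, not part of the diff):

def _wants_native_tools(llm, tools) -> bool:
    # The attribute may be missing, or present but not callable, on third-party LLM wrappers.
    probe = getattr(llm, "supports_function_calling", None)
    return bool(callable(probe) and probe() and tools)
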
async def _ainvoke_loop_react(self) -> AgentFinish:
|
||||
"""Execute agent loop asynchronously using ReAct text-based pattern.
|
||||
|
||||
Returns:
|
||||
Final answer from the agent.
|
||||
"""
|
||||
@@ -882,139 +495,6 @@ class CrewAgentExecutor(CrewAgentExecutorMixin):
|
||||
self._show_logs(formatted_answer)
|
||||
return formatted_answer
|
||||
|
||||
async def _ainvoke_loop_native_tools(self) -> AgentFinish:
|
||||
"""Execute agent loop asynchronously using native function calling.
|
||||
|
||||
This method uses the LLM's native tool/function calling capability
|
||||
instead of the text-based ReAct pattern.
|
||||
|
||||
Returns:
|
||||
Final answer from the agent.
|
||||
"""
|
||||
# Convert tools to OpenAI schema format
|
||||
if not self.original_tools:
|
||||
return await self._ainvoke_loop_native_no_tools()
|
||||
|
||||
openai_tools, available_functions = convert_tools_to_openai_schema(
|
||||
self.original_tools
|
||||
)
|
||||
|
||||
while True:
|
||||
try:
|
||||
if has_reached_max_iterations(self.iterations, self.max_iter):
|
||||
formatted_answer = handle_max_iterations_exceeded(
|
||||
None,
|
||||
printer=self._printer,
|
||||
i18n=self._i18n,
|
||||
messages=self.messages,
|
||||
llm=self.llm,
|
||||
callbacks=self.callbacks,
|
||||
)
|
||||
self._show_logs(formatted_answer)
|
||||
return formatted_answer
|
||||
|
||||
enforce_rpm_limit(self.request_within_rpm_limit)
|
||||
|
||||
# Call LLM with native tools
|
||||
# Pass available_functions=None so the LLM returns tool_calls
|
||||
# without executing them. The executor handles tool execution
|
||||
# via _handle_native_tool_calls to properly manage message history.
|
||||
answer = await aget_llm_response(
|
||||
llm=self.llm,
|
||||
messages=self.messages,
|
||||
callbacks=self.callbacks,
|
||||
printer=self._printer,
|
||||
tools=openai_tools,
|
||||
available_functions=None,
|
||||
from_task=self.task,
|
||||
from_agent=self.agent,
|
||||
response_model=self.response_model,
|
||||
executor_context=self,
|
||||
)
|
||||
print("--------------------------------")
|
||||
print("native llm completion answer", answer)
|
||||
print("--------------------------------")
|
||||
|
||||
# Check if the response is a list of tool calls
|
||||
if (
|
||||
isinstance(answer, list)
|
||||
and answer
|
||||
and self._is_tool_call_list(answer)
|
||||
):
|
||||
# Handle tool calls - execute tools and add results to messages
|
||||
self._handle_native_tool_calls(answer, available_functions)
|
||||
# Continue loop to let LLM analyze results and decide next steps
|
||||
continue
|
||||
|
||||
# Text or other response - handle as potential final answer
|
||||
if isinstance(answer, str):
|
||||
# Text response - this is the final answer
|
||||
formatted_answer = AgentFinish(
|
||||
thought="",
|
||||
output=answer,
|
||||
text=answer,
|
||||
)
|
||||
self._invoke_step_callback(formatted_answer)
|
||||
self._append_message(answer) # Save final answer to messages
|
||||
self._show_logs(formatted_answer)
|
||||
return formatted_answer
|
||||
|
||||
# Unexpected response type, treat as final answer
|
||||
formatted_answer = AgentFinish(
|
||||
thought="",
|
||||
output=str(answer),
|
||||
text=str(answer),
|
||||
)
|
||||
self._invoke_step_callback(formatted_answer)
|
||||
self._append_message(str(answer)) # Save final answer to messages
|
||||
self._show_logs(formatted_answer)
|
||||
return formatted_answer
|
||||
|
||||
except Exception as e:
|
||||
if e.__class__.__module__.startswith("litellm"):
|
||||
raise e
|
||||
if is_context_length_exceeded(e):
|
||||
handle_context_length(
|
||||
respect_context_window=self.respect_context_window,
|
||||
printer=self._printer,
|
||||
messages=self.messages,
|
||||
llm=self.llm,
|
||||
callbacks=self.callbacks,
|
||||
i18n=self._i18n,
|
||||
)
|
||||
continue
|
||||
handle_unknown_error(self._printer, e)
|
||||
raise e
|
||||
finally:
|
||||
self.iterations += 1
|
||||
|
||||
    async def _ainvoke_loop_native_no_tools(self) -> AgentFinish:
        """Execute a simple async LLM call when no tools are available.

        Returns:
            Final answer from the agent.
        """
        enforce_rpm_limit(self.request_within_rpm_limit)

        answer = await aget_llm_response(
            llm=self.llm,
            messages=self.messages,
            callbacks=self.callbacks,
            printer=self._printer,
            from_task=self.task,
            from_agent=self.agent,
            response_model=self.response_model,
            executor_context=self,
        )

        formatted_answer = AgentFinish(
            thought="",
            output=str(answer),
            text=str(answer),
        )
        self._show_logs(formatted_answer)
        return formatted_answer

def _handle_agent_action(
|
||||
self, formatted_answer: AgentAction, tool_result: ToolResult
|
||||
) -> AgentAction | AgentFinish:
|
||||
|
||||
@@ -378,12 +378,6 @@ class EventListener(BaseEventListener):
|
||||
self.formatter.handle_llm_tool_usage_finished(
|
||||
event.tool_name,
|
||||
)
|
||||
else:
|
||||
self.formatter.handle_tool_usage_finished(
|
||||
event.tool_name,
|
||||
event.output,
|
||||
getattr(event, "run_attempts", None),
|
||||
)
|
||||
|
||||
@crewai_event_bus.on(ToolUsageErrorEvent)
|
||||
def on_tool_usage_error(source: Any, event: ToolUsageErrorEvent) -> None:
|
||||
|
||||
@@ -1,19 +1,28 @@
|
||||
from crewai.events.types.a2a_events import (
|
||||
A2AAgentCardFetchedEvent,
|
||||
A2AArtifactReceivedEvent,
|
||||
A2AAuthenticationFailedEvent,
|
||||
A2AConnectionErrorEvent,
|
||||
A2AConversationCompletedEvent,
|
||||
A2AConversationStartedEvent,
|
||||
A2ADelegationCompletedEvent,
|
||||
A2ADelegationStartedEvent,
|
||||
A2AMessageSentEvent,
|
||||
A2AParallelDelegationCompletedEvent,
|
||||
A2AParallelDelegationStartedEvent,
|
||||
A2APollingStartedEvent,
|
||||
A2APollingStatusEvent,
|
||||
A2APushNotificationReceivedEvent,
|
||||
A2APushNotificationRegisteredEvent,
|
||||
A2APushNotificationSentEvent,
|
||||
A2APushNotificationTimeoutEvent,
|
||||
A2AResponseReceivedEvent,
|
||||
A2AServerTaskCanceledEvent,
|
||||
A2AServerTaskCompletedEvent,
|
||||
A2AServerTaskFailedEvent,
|
||||
A2AServerTaskStartedEvent,
|
||||
A2AStreamingChunkEvent,
|
||||
A2AStreamingStartedEvent,
|
||||
)
|
||||
from crewai.events.types.agent_events import (
|
||||
AgentExecutionCompletedEvent,
|
||||
@@ -93,7 +102,11 @@ from crewai.events.types.tool_usage_events import (
|
||||
|
||||
|
||||
EventTypes = (
|
||||
A2AConversationCompletedEvent
|
||||
A2AAgentCardFetchedEvent
|
||||
| A2AArtifactReceivedEvent
|
||||
| A2AAuthenticationFailedEvent
|
||||
| A2AConnectionErrorEvent
|
||||
| A2AConversationCompletedEvent
|
||||
| A2AConversationStartedEvent
|
||||
| A2ADelegationCompletedEvent
|
||||
| A2ADelegationStartedEvent
|
||||
@@ -102,12 +115,17 @@ EventTypes = (
|
||||
| A2APollingStatusEvent
|
||||
| A2APushNotificationReceivedEvent
|
||||
| A2APushNotificationRegisteredEvent
|
||||
| A2APushNotificationSentEvent
|
||||
| A2APushNotificationTimeoutEvent
|
||||
| A2AResponseReceivedEvent
|
||||
| A2AServerTaskCanceledEvent
|
||||
| A2AServerTaskCompletedEvent
|
||||
| A2AServerTaskFailedEvent
|
||||
| A2AServerTaskStartedEvent
|
||||
| A2AStreamingChunkEvent
|
||||
| A2AStreamingStartedEvent
|
||||
| A2AParallelDelegationStartedEvent
|
||||
| A2AParallelDelegationCompletedEvent
|
||||
| CrewKickoffStartedEvent
|
||||
| CrewKickoffCompletedEvent
|
||||
| CrewKickoffFailedEvent
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
"""Trace collection listener for orchestrating trace collection."""
|
||||
|
||||
import os
|
||||
from typing import Any, ClassVar, cast
|
||||
from typing import Any, ClassVar
|
||||
import uuid
|
||||
|
||||
from typing_extensions import Self
|
||||
@@ -18,6 +18,32 @@ from crewai.events.listeners.tracing.types import TraceEvent
|
||||
from crewai.events.listeners.tracing.utils import (
|
||||
safe_serialize_to_dict,
|
||||
)
|
||||
from crewai.events.types.a2a_events import (
|
||||
A2AAgentCardFetchedEvent,
|
||||
A2AArtifactReceivedEvent,
|
||||
A2AAuthenticationFailedEvent,
|
||||
A2AConnectionErrorEvent,
|
||||
A2AConversationCompletedEvent,
|
||||
A2AConversationStartedEvent,
|
||||
A2ADelegationCompletedEvent,
|
||||
A2ADelegationStartedEvent,
|
||||
A2AMessageSentEvent,
|
||||
A2AParallelDelegationCompletedEvent,
|
||||
A2AParallelDelegationStartedEvent,
|
||||
A2APollingStartedEvent,
|
||||
A2APollingStatusEvent,
|
||||
A2APushNotificationReceivedEvent,
|
||||
A2APushNotificationRegisteredEvent,
|
||||
A2APushNotificationSentEvent,
|
||||
A2APushNotificationTimeoutEvent,
|
||||
A2AResponseReceivedEvent,
|
||||
A2AServerTaskCanceledEvent,
|
||||
A2AServerTaskCompletedEvent,
|
||||
A2AServerTaskFailedEvent,
|
||||
A2AServerTaskStartedEvent,
|
||||
A2AStreamingChunkEvent,
|
||||
A2AStreamingStartedEvent,
|
||||
)
|
||||
from crewai.events.types.agent_events import (
|
||||
AgentExecutionCompletedEvent,
|
||||
AgentExecutionErrorEvent,
|
||||
@@ -105,7 +131,7 @@ class TraceCollectionListener(BaseEventListener):
|
||||
"""Create or return singleton instance."""
|
||||
if cls._instance is None:
|
||||
cls._instance = super().__new__(cls)
|
||||
return cast(Self, cls._instance)
|
||||
return cls._instance
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
@@ -160,6 +186,7 @@ class TraceCollectionListener(BaseEventListener):
|
||||
self._register_flow_event_handlers(crewai_event_bus)
|
||||
self._register_context_event_handlers(crewai_event_bus)
|
||||
self._register_action_event_handlers(crewai_event_bus)
|
||||
self._register_a2a_event_handlers(crewai_event_bus)
|
||||
self._register_system_event_handlers(crewai_event_bus)
|
||||
|
||||
self._listeners_setup = True
|
||||
@@ -439,6 +466,147 @@ class TraceCollectionListener(BaseEventListener):
|
||||
) -> None:
|
||||
self._handle_action_event("knowledge_query_failed", source, event)
|
||||
|
||||
def _register_a2a_event_handlers(self, event_bus: CrewAIEventsBus) -> None:
|
||||
"""Register handlers for A2A (Agent-to-Agent) events."""
|
||||
|
||||
@event_bus.on(A2ADelegationStartedEvent)
|
||||
def on_a2a_delegation_started(
|
||||
source: Any, event: A2ADelegationStartedEvent
|
||||
) -> None:
|
||||
self._handle_action_event("a2a_delegation_started", source, event)
|
||||
|
||||
@event_bus.on(A2ADelegationCompletedEvent)
|
||||
def on_a2a_delegation_completed(
|
||||
source: Any, event: A2ADelegationCompletedEvent
|
||||
) -> None:
|
||||
self._handle_action_event("a2a_delegation_completed", source, event)
|
||||
|
||||
@event_bus.on(A2AConversationStartedEvent)
|
||||
def on_a2a_conversation_started(
|
||||
source: Any, event: A2AConversationStartedEvent
|
||||
) -> None:
|
||||
self._handle_action_event("a2a_conversation_started", source, event)
|
||||
|
||||
@event_bus.on(A2AMessageSentEvent)
|
||||
def on_a2a_message_sent(source: Any, event: A2AMessageSentEvent) -> None:
|
||||
self._handle_action_event("a2a_message_sent", source, event)
|
||||
|
||||
@event_bus.on(A2AResponseReceivedEvent)
|
||||
def on_a2a_response_received(
|
||||
source: Any, event: A2AResponseReceivedEvent
|
||||
) -> None:
|
||||
self._handle_action_event("a2a_response_received", source, event)
|
||||
|
||||
@event_bus.on(A2AConversationCompletedEvent)
|
||||
def on_a2a_conversation_completed(
|
||||
source: Any, event: A2AConversationCompletedEvent
|
||||
) -> None:
|
||||
self._handle_action_event("a2a_conversation_completed", source, event)
|
||||
|
||||
@event_bus.on(A2APollingStartedEvent)
|
||||
def on_a2a_polling_started(source: Any, event: A2APollingStartedEvent) -> None:
|
||||
self._handle_action_event("a2a_polling_started", source, event)
|
||||
|
||||
@event_bus.on(A2APollingStatusEvent)
|
||||
def on_a2a_polling_status(source: Any, event: A2APollingStatusEvent) -> None:
|
||||
self._handle_action_event("a2a_polling_status", source, event)
|
||||
|
||||
@event_bus.on(A2APushNotificationRegisteredEvent)
|
||||
def on_a2a_push_notification_registered(
|
||||
source: Any, event: A2APushNotificationRegisteredEvent
|
||||
) -> None:
|
||||
self._handle_action_event("a2a_push_notification_registered", source, event)
|
||||
|
||||
@event_bus.on(A2APushNotificationReceivedEvent)
|
||||
def on_a2a_push_notification_received(
|
||||
source: Any, event: A2APushNotificationReceivedEvent
|
||||
) -> None:
|
||||
self._handle_action_event("a2a_push_notification_received", source, event)
|
||||
|
||||
@event_bus.on(A2APushNotificationSentEvent)
|
||||
def on_a2a_push_notification_sent(
|
||||
source: Any, event: A2APushNotificationSentEvent
|
||||
) -> None:
|
||||
self._handle_action_event("a2a_push_notification_sent", source, event)
|
||||
|
||||
@event_bus.on(A2APushNotificationTimeoutEvent)
|
||||
def on_a2a_push_notification_timeout(
|
||||
source: Any, event: A2APushNotificationTimeoutEvent
|
||||
) -> None:
|
||||
self._handle_action_event("a2a_push_notification_timeout", source, event)
|
||||
|
||||
@event_bus.on(A2AStreamingStartedEvent)
|
||||
def on_a2a_streaming_started(
|
||||
source: Any, event: A2AStreamingStartedEvent
|
||||
) -> None:
|
||||
self._handle_action_event("a2a_streaming_started", source, event)
|
||||
|
||||
@event_bus.on(A2AStreamingChunkEvent)
|
||||
def on_a2a_streaming_chunk(source: Any, event: A2AStreamingChunkEvent) -> None:
|
||||
self._handle_action_event("a2a_streaming_chunk", source, event)
|
||||
|
||||
@event_bus.on(A2AAgentCardFetchedEvent)
|
||||
def on_a2a_agent_card_fetched(
|
||||
source: Any, event: A2AAgentCardFetchedEvent
|
||||
) -> None:
|
||||
self._handle_action_event("a2a_agent_card_fetched", source, event)
|
||||
|
||||
@event_bus.on(A2AAuthenticationFailedEvent)
|
||||
def on_a2a_authentication_failed(
|
||||
source: Any, event: A2AAuthenticationFailedEvent
|
||||
) -> None:
|
||||
self._handle_action_event("a2a_authentication_failed", source, event)
|
||||
|
||||
@event_bus.on(A2AArtifactReceivedEvent)
|
||||
def on_a2a_artifact_received(
|
||||
source: Any, event: A2AArtifactReceivedEvent
|
||||
) -> None:
|
||||
self._handle_action_event("a2a_artifact_received", source, event)
|
||||
|
||||
@event_bus.on(A2AConnectionErrorEvent)
|
||||
def on_a2a_connection_error(
|
||||
source: Any, event: A2AConnectionErrorEvent
|
||||
) -> None:
|
||||
self._handle_action_event("a2a_connection_error", source, event)
|
||||
|
||||
@event_bus.on(A2AServerTaskStartedEvent)
|
||||
def on_a2a_server_task_started(
|
||||
source: Any, event: A2AServerTaskStartedEvent
|
||||
) -> None:
|
||||
self._handle_action_event("a2a_server_task_started", source, event)
|
||||
|
||||
@event_bus.on(A2AServerTaskCompletedEvent)
|
||||
def on_a2a_server_task_completed(
|
||||
source: Any, event: A2AServerTaskCompletedEvent
|
||||
) -> None:
|
||||
self._handle_action_event("a2a_server_task_completed", source, event)
|
||||
|
||||
@event_bus.on(A2AServerTaskCanceledEvent)
|
||||
def on_a2a_server_task_canceled(
|
||||
source: Any, event: A2AServerTaskCanceledEvent
|
||||
) -> None:
|
||||
self._handle_action_event("a2a_server_task_canceled", source, event)
|
||||
|
||||
@event_bus.on(A2AServerTaskFailedEvent)
|
||||
def on_a2a_server_task_failed(
|
||||
source: Any, event: A2AServerTaskFailedEvent
|
||||
) -> None:
|
||||
self._handle_action_event("a2a_server_task_failed", source, event)
|
||||
|
||||
@event_bus.on(A2AParallelDelegationStartedEvent)
|
||||
def on_a2a_parallel_delegation_started(
|
||||
source: Any, event: A2AParallelDelegationStartedEvent
|
||||
) -> None:
|
||||
self._handle_action_event("a2a_parallel_delegation_started", source, event)
|
||||
|
||||
@event_bus.on(A2AParallelDelegationCompletedEvent)
|
||||
def on_a2a_parallel_delegation_completed(
|
||||
source: Any, event: A2AParallelDelegationCompletedEvent
|
||||
) -> None:
|
||||
self._handle_action_event(
|
||||
"a2a_parallel_delegation_completed", source, event
|
||||
)
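The decorator-based registration used throughout this listener is the same pattern available to user code; a minimal sketch of subscribing to one of these A2A events outside the tracing listener (handler name and body are illustrative only):

from crewai.events import crewai_event_bus
from crewai.events.types.a2a_events import A2ADelegationCompletedEvent

@crewai_event_bus.on(A2ADelegationCompletedEvent)
def log_a2a_delegation(source, event):
    # status/result/error mirror the fields declared on the event class further below.
    print(f"A2A delegation finished: status={event.status}")
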
def _register_system_event_handlers(self, event_bus: CrewAIEventsBus) -> None:
|
||||
"""Register handlers for system signal events (SIGTERM, SIGINT, etc.)."""
|
||||
|
||||
@@ -570,10 +738,15 @@ class TraceCollectionListener(BaseEventListener):
|
||||
if event_type not in self.complex_events:
|
||||
return safe_serialize_to_dict(event)
|
||||
if event_type == "task_started":
|
||||
task_name = event.task.name or event.task.description
|
||||
task_display_name = (
|
||||
task_name[:80] + "..." if len(task_name) > 80 else task_name
|
||||
)
|
||||
return {
|
||||
"task_description": event.task.description,
|
||||
"expected_output": event.task.expected_output,
|
||||
"task_name": event.task.name or event.task.description,
|
||||
"task_name": task_name,
|
||||
"task_display_name": task_display_name,
|
||||
"context": event.context,
|
||||
"agent_role": source.agent.role,
|
||||
"task_id": str(event.task.id),
|
||||
|
||||
@@ -4,68 +4,120 @@ This module defines events emitted during A2A protocol delegation,
|
||||
including both single-turn and multiturn conversation flows.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Any, Literal
|
||||
|
||||
from pydantic import model_validator
|
||||
|
||||
from crewai.events.base_events import BaseEvent
|
||||
|
||||
|
||||
class A2AEventBase(BaseEvent):
|
||||
"""Base class for A2A events with task/agent context."""
|
||||
|
||||
from_task: Any | None = None
|
||||
from_agent: Any | None = None
|
||||
from_task: Any = None
|
||||
from_agent: Any = None
|
||||
|
||||
def __init__(self, **data: Any) -> None:
|
||||
"""Initialize A2A event, extracting task and agent metadata."""
|
||||
if data.get("from_task"):
|
||||
task = data["from_task"]
|
||||
@model_validator(mode="before")
|
||||
@classmethod
|
||||
def extract_task_and_agent_metadata(cls, data: dict[str, Any]) -> dict[str, Any]:
|
||||
"""Extract task and agent metadata before validation."""
|
||||
if task := data.get("from_task"):
|
||||
data["task_id"] = str(task.id)
|
||||
data["task_name"] = task.name or task.description
|
||||
data.setdefault("source_fingerprint", str(task.id))
|
||||
data.setdefault("source_type", "task")
|
||||
data.setdefault(
|
||||
"fingerprint_metadata",
|
||||
{
|
||||
"task_id": str(task.id),
|
||||
"task_name": task.name or task.description,
|
||||
},
|
||||
)
|
||||
data["from_task"] = None
|
||||
|
||||
if data.get("from_agent"):
|
||||
agent = data["from_agent"]
|
||||
if agent := data.get("from_agent"):
|
||||
data["agent_id"] = str(agent.id)
|
||||
data["agent_role"] = agent.role
|
||||
data.setdefault("source_fingerprint", str(agent.id))
|
||||
data.setdefault("source_type", "agent")
|
||||
data.setdefault(
|
||||
"fingerprint_metadata",
|
||||
{
|
||||
"agent_id": str(agent.id),
|
||||
"agent_role": agent.role,
|
||||
},
|
||||
)
|
||||
data["from_agent"] = None
|
||||
|
||||
super().__init__(**data)
|
||||
return data
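Because the extraction now runs as a mode="before" validator, any A2A event constructed with a from_task or from_agent handle gets its derived identifiers filled in automatically. A hedged sketch using a stand-in task object (real Task instances expose the same id/name/description attributes used here; the endpoint and IDs are made up):

from types import SimpleNamespace

from crewai.events.types.a2a_events import A2ADelegationStartedEvent

stub_task = SimpleNamespace(id="42", name="research", description="Research the topic")
event = A2ADelegationStartedEvent(
    endpoint="https://remote-agent.example/.well-known/agent.json",
    task_description="Research the topic",
    agent_id="remote-researcher",
    from_task=stub_task,
)
# task_id, task_name and the fingerprint defaults are copied from stub_task before
# validation runs, and from_task itself is reset to None.
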
class A2ADelegationStartedEvent(A2AEventBase):
|
||||
"""Event emitted when A2A delegation starts.
|
||||
|
||||
Attributes:
|
||||
endpoint: A2A agent endpoint URL (AgentCard URL)
|
||||
task_description: Task being delegated to the A2A agent
|
||||
agent_id: A2A agent identifier
|
||||
is_multiturn: Whether this is part of a multiturn conversation
|
||||
turn_number: Current turn number (1-indexed, 1 for single-turn)
|
||||
endpoint: A2A agent endpoint URL (AgentCard URL).
|
||||
task_description: Task being delegated to the A2A agent.
|
||||
agent_id: A2A agent identifier.
|
||||
context_id: A2A context ID grouping related tasks.
|
||||
is_multiturn: Whether this is part of a multiturn conversation.
|
||||
turn_number: Current turn number (1-indexed, 1 for single-turn).
|
||||
a2a_agent_name: Name of the A2A agent from agent card.
|
||||
agent_card: Full A2A agent card metadata.
|
||||
protocol_version: A2A protocol version being used.
|
||||
provider: Agent provider/organization info from agent card.
|
||||
skill_id: ID of the specific skill being invoked.
|
||||
metadata: Custom A2A metadata key-value pairs.
|
||||
extensions: List of A2A extension URIs in use.
|
||||
"""
|
||||
|
||||
type: str = "a2a_delegation_started"
|
||||
endpoint: str
|
||||
task_description: str
|
||||
agent_id: str
|
||||
context_id: str | None = None
|
||||
is_multiturn: bool = False
|
||||
turn_number: int = 1
|
||||
a2a_agent_name: str | None = None
|
||||
agent_card: dict[str, Any] | None = None
|
||||
protocol_version: str | None = None
|
||||
provider: dict[str, Any] | None = None
|
||||
skill_id: str | None = None
|
||||
metadata: dict[str, Any] | None = None
|
||||
extensions: list[str] | None = None
|
||||
|
||||
|
||||
class A2ADelegationCompletedEvent(A2AEventBase):
|
||||
"""Event emitted when A2A delegation completes.
|
||||
|
||||
Attributes:
|
||||
status: Completion status (completed, input_required, failed, etc.)
|
||||
result: Result message if status is completed
|
||||
error: Error/response message (error for failed, response for input_required)
|
||||
is_multiturn: Whether this is part of a multiturn conversation
|
||||
status: Completion status (completed, input_required, failed, etc.).
|
||||
result: Result message if status is completed.
|
||||
error: Error/response message (error for failed, response for input_required).
|
||||
context_id: A2A context ID grouping related tasks.
|
||||
is_multiturn: Whether this is part of a multiturn conversation.
|
||||
endpoint: A2A agent endpoint URL.
|
||||
a2a_agent_name: Name of the A2A agent from agent card.
|
||||
agent_card: Full A2A agent card metadata.
|
||||
provider: Agent provider/organization info from agent card.
|
||||
metadata: Custom A2A metadata key-value pairs.
|
||||
extensions: List of A2A extension URIs in use.
|
||||
"""
|
||||
|
||||
type: str = "a2a_delegation_completed"
|
||||
status: str
|
||||
result: str | None = None
|
||||
error: str | None = None
|
||||
context_id: str | None = None
|
||||
is_multiturn: bool = False
|
||||
endpoint: str | None = None
|
||||
a2a_agent_name: str | None = None
|
||||
agent_card: dict[str, Any] | None = None
|
||||
provider: dict[str, Any] | None = None
|
||||
metadata: dict[str, Any] | None = None
|
||||
extensions: list[str] | None = None
|
||||
|
||||
|
||||
class A2AConversationStartedEvent(A2AEventBase):
|
||||
@@ -75,51 +127,95 @@ class A2AConversationStartedEvent(A2AEventBase):
|
||||
before the first message exchange.
|
||||
|
||||
Attributes:
|
||||
agent_id: A2A agent identifier
|
||||
endpoint: A2A agent endpoint URL
|
||||
a2a_agent_name: Name of the A2A agent from agent card
|
||||
agent_id: A2A agent identifier.
|
||||
endpoint: A2A agent endpoint URL.
|
||||
context_id: A2A context ID grouping related tasks.
|
||||
a2a_agent_name: Name of the A2A agent from agent card.
|
||||
agent_card: Full A2A agent card metadata.
|
||||
protocol_version: A2A protocol version being used.
|
||||
provider: Agent provider/organization info from agent card.
|
||||
skill_id: ID of the specific skill being invoked.
|
||||
reference_task_ids: Related task IDs for context.
|
||||
metadata: Custom A2A metadata key-value pairs.
|
||||
extensions: List of A2A extension URIs in use.
|
||||
"""
|
||||
|
||||
type: str = "a2a_conversation_started"
|
||||
agent_id: str
|
||||
endpoint: str
|
||||
context_id: str | None = None
|
||||
a2a_agent_name: str | None = None
|
||||
agent_card: dict[str, Any] | None = None
|
||||
protocol_version: str | None = None
|
||||
provider: dict[str, Any] | None = None
|
||||
skill_id: str | None = None
|
||||
reference_task_ids: list[str] | None = None
|
||||
metadata: dict[str, Any] | None = None
|
||||
extensions: list[str] | None = None
|
||||
|
||||
|
||||
class A2AMessageSentEvent(A2AEventBase):
|
||||
"""Event emitted when a message is sent to the A2A agent.
|
||||
|
||||
Attributes:
|
||||
message: Message content sent to the A2A agent
|
||||
turn_number: Current turn number (1-indexed)
|
||||
is_multiturn: Whether this is part of a multiturn conversation
|
||||
agent_role: Role of the CrewAI agent sending the message
|
||||
message: Message content sent to the A2A agent.
|
||||
turn_number: Current turn number (1-indexed).
|
||||
context_id: A2A context ID grouping related tasks.
|
||||
message_id: Unique A2A message identifier.
|
||||
is_multiturn: Whether this is part of a multiturn conversation.
|
||||
agent_role: Role of the CrewAI agent sending the message.
|
||||
endpoint: A2A agent endpoint URL.
|
||||
a2a_agent_name: Name of the A2A agent from agent card.
|
||||
skill_id: ID of the specific skill being invoked.
|
||||
metadata: Custom A2A metadata key-value pairs.
|
||||
extensions: List of A2A extension URIs in use.
|
||||
"""
|
||||
|
||||
type: str = "a2a_message_sent"
|
||||
message: str
|
||||
turn_number: int
|
||||
context_id: str | None = None
|
||||
message_id: str | None = None
|
||||
is_multiturn: bool = False
|
||||
agent_role: str | None = None
|
||||
endpoint: str | None = None
|
||||
a2a_agent_name: str | None = None
|
||||
skill_id: str | None = None
|
||||
metadata: dict[str, Any] | None = None
|
||||
extensions: list[str] | None = None
|
||||
|
||||
|
||||
class A2AResponseReceivedEvent(A2AEventBase):
|
||||
"""Event emitted when a response is received from the A2A agent.
|
||||
|
||||
Attributes:
|
||||
response: Response content from the A2A agent
|
||||
turn_number: Current turn number (1-indexed)
|
||||
is_multiturn: Whether this is part of a multiturn conversation
|
||||
status: Response status (input_required, completed, etc.)
|
||||
agent_role: Role of the CrewAI agent (for display)
|
||||
response: Response content from the A2A agent.
|
||||
turn_number: Current turn number (1-indexed).
|
||||
context_id: A2A context ID grouping related tasks.
|
||||
message_id: Unique A2A message identifier.
|
||||
is_multiturn: Whether this is part of a multiturn conversation.
|
||||
status: Response status (input_required, completed, etc.).
|
||||
final: Whether this is the final response in the stream.
|
||||
agent_role: Role of the CrewAI agent (for display).
|
||||
endpoint: A2A agent endpoint URL.
|
||||
a2a_agent_name: Name of the A2A agent from agent card.
|
||||
metadata: Custom A2A metadata key-value pairs.
|
||||
extensions: List of A2A extension URIs in use.
|
||||
"""
|
||||
|
||||
type: str = "a2a_response_received"
|
||||
response: str
|
||||
turn_number: int
|
||||
context_id: str | None = None
|
||||
message_id: str | None = None
|
||||
is_multiturn: bool = False
|
||||
status: str
|
||||
final: bool = False
|
||||
agent_role: str | None = None
|
||||
endpoint: str | None = None
|
||||
a2a_agent_name: str | None = None
|
||||
metadata: dict[str, Any] | None = None
|
||||
extensions: list[str] | None = None
|
||||
|
||||
|
||||
class A2AConversationCompletedEvent(A2AEventBase):
|
||||
@@ -128,119 +224,433 @@ class A2AConversationCompletedEvent(A2AEventBase):
|
||||
This is emitted once at the end of a multiturn conversation.
|
||||
|
||||
Attributes:
|
||||
status: Final status (completed, failed, etc.)
|
||||
final_result: Final result if completed successfully
|
||||
error: Error message if failed
|
||||
total_turns: Total number of turns in the conversation
|
||||
status: Final status (completed, failed, etc.).
|
||||
final_result: Final result if completed successfully.
|
||||
error: Error message if failed.
|
||||
context_id: A2A context ID grouping related tasks.
|
||||
total_turns: Total number of turns in the conversation.
|
||||
endpoint: A2A agent endpoint URL.
|
||||
a2a_agent_name: Name of the A2A agent from agent card.
|
||||
agent_card: Full A2A agent card metadata.
|
||||
reference_task_ids: Related task IDs for context.
|
||||
metadata: Custom A2A metadata key-value pairs.
|
||||
extensions: List of A2A extension URIs in use.
|
||||
"""
|
||||
|
||||
type: str = "a2a_conversation_completed"
|
||||
status: Literal["completed", "failed"]
|
||||
final_result: str | None = None
|
||||
error: str | None = None
|
||||
context_id: str | None = None
|
||||
total_turns: int
|
||||
endpoint: str | None = None
|
||||
a2a_agent_name: str | None = None
|
||||
agent_card: dict[str, Any] | None = None
|
||||
reference_task_ids: list[str] | None = None
|
||||
metadata: dict[str, Any] | None = None
|
||||
extensions: list[str] | None = None
|
||||
|
||||
|
||||
class A2APollingStartedEvent(A2AEventBase):
|
||||
"""Event emitted when polling mode begins for A2A delegation.
|
||||
|
||||
Attributes:
|
||||
task_id: A2A task ID being polled
|
||||
polling_interval: Seconds between poll attempts
|
||||
endpoint: A2A agent endpoint URL
|
||||
task_id: A2A task ID being polled.
|
||||
context_id: A2A context ID grouping related tasks.
|
||||
polling_interval: Seconds between poll attempts.
|
||||
endpoint: A2A agent endpoint URL.
|
||||
a2a_agent_name: Name of the A2A agent from agent card.
|
||||
metadata: Custom A2A metadata key-value pairs.
|
||||
"""
|
||||
|
||||
type: str = "a2a_polling_started"
|
||||
task_id: str
|
||||
context_id: str | None = None
|
||||
polling_interval: float
|
||||
endpoint: str
|
||||
a2a_agent_name: str | None = None
|
||||
metadata: dict[str, Any] | None = None
|
||||
|
||||
|
||||
class A2APollingStatusEvent(A2AEventBase):
|
||||
"""Event emitted on each polling iteration.
|
||||
|
||||
Attributes:
|
||||
task_id: A2A task ID being polled
|
||||
state: Current task state from remote agent
|
||||
elapsed_seconds: Time since polling started
|
||||
poll_count: Number of polls completed
|
||||
task_id: A2A task ID being polled.
|
||||
context_id: A2A context ID grouping related tasks.
|
||||
state: Current task state from remote agent.
|
||||
elapsed_seconds: Time since polling started.
|
||||
poll_count: Number of polls completed.
|
||||
endpoint: A2A agent endpoint URL.
|
||||
a2a_agent_name: Name of the A2A agent from agent card.
|
||||
metadata: Custom A2A metadata key-value pairs.
|
||||
"""
|
||||
|
||||
type: str = "a2a_polling_status"
|
||||
task_id: str
|
||||
context_id: str | None = None
|
||||
state: str
|
||||
elapsed_seconds: float
|
||||
poll_count: int
|
||||
endpoint: str | None = None
|
||||
a2a_agent_name: str | None = None
|
||||
metadata: dict[str, Any] | None = None
|
||||
|
||||
|
||||
class A2APushNotificationRegisteredEvent(A2AEventBase):
|
||||
"""Event emitted when push notification callback is registered.
|
||||
|
||||
Attributes:
|
||||
task_id: A2A task ID for which callback is registered
|
||||
callback_url: URL where agent will send push notifications
|
||||
task_id: A2A task ID for which callback is registered.
|
||||
context_id: A2A context ID grouping related tasks.
|
||||
callback_url: URL where agent will send push notifications.
|
||||
endpoint: A2A agent endpoint URL.
|
||||
a2a_agent_name: Name of the A2A agent from agent card.
|
||||
metadata: Custom A2A metadata key-value pairs.
|
||||
"""
|
||||
|
||||
type: str = "a2a_push_notification_registered"
|
||||
task_id: str
|
||||
context_id: str | None = None
|
||||
callback_url: str
|
||||
endpoint: str | None = None
|
||||
a2a_agent_name: str | None = None
|
||||
metadata: dict[str, Any] | None = None
|
||||
|
||||
|
||||
class A2APushNotificationReceivedEvent(A2AEventBase):
|
||||
"""Event emitted when a push notification is received.
|
||||
|
||||
This event should be emitted by the user's webhook handler when it receives
|
||||
a push notification from the remote A2A agent, before calling
|
||||
`result_store.store_result()`.
|
||||
|
||||
Attributes:
|
||||
task_id: A2A task ID from the notification
|
||||
state: Current task state from the notification
|
||||
task_id: A2A task ID from the notification.
|
||||
context_id: A2A context ID grouping related tasks.
|
||||
state: Current task state from the notification.
|
||||
endpoint: A2A agent endpoint URL.
|
||||
a2a_agent_name: Name of the A2A agent from agent card.
|
||||
metadata: Custom A2A metadata key-value pairs.
|
||||
"""
|
||||
|
||||
type: str = "a2a_push_notification_received"
|
||||
task_id: str
|
||||
context_id: str | None = None
|
||||
state: str
|
||||
endpoint: str | None = None
|
||||
a2a_agent_name: str | None = None
|
||||
metadata: dict[str, Any] | None = None
|
||||
|
||||
|
||||
class A2APushNotificationSentEvent(A2AEventBase):
|
||||
"""Event emitted when a push notification is sent to a callback URL.
|
||||
|
||||
Emitted by the A2A server when it sends a task status update to the
|
||||
client's registered push notification callback URL.
|
||||
|
||||
Attributes:
|
||||
task_id: A2A task ID being notified.
|
||||
context_id: A2A context ID grouping related tasks.
|
||||
callback_url: URL the notification was sent to.
|
||||
state: Task state being reported.
|
||||
success: Whether the notification was successfully delivered.
|
||||
error: Error message if delivery failed.
|
||||
metadata: Custom A2A metadata key-value pairs.
|
||||
"""
|
||||
|
||||
type: str = "a2a_push_notification_sent"
|
||||
task_id: str
|
||||
context_id: str | None = None
|
||||
callback_url: str
|
||||
state: str
|
||||
success: bool = True
|
||||
error: str | None = None
|
||||
metadata: dict[str, Any] | None = None
|
||||
|
||||
|
||||
class A2APushNotificationTimeoutEvent(A2AEventBase):
|
||||
"""Event emitted when push notification wait times out.
|
||||
|
||||
Attributes:
|
||||
task_id: A2A task ID that timed out
|
||||
timeout_seconds: Timeout duration in seconds
|
||||
task_id: A2A task ID that timed out.
|
||||
context_id: A2A context ID grouping related tasks.
|
||||
timeout_seconds: Timeout duration in seconds.
|
||||
endpoint: A2A agent endpoint URL.
|
||||
a2a_agent_name: Name of the A2A agent from agent card.
|
||||
metadata: Custom A2A metadata key-value pairs.
|
||||
"""
|
||||
|
||||
type: str = "a2a_push_notification_timeout"
|
||||
task_id: str
|
||||
context_id: str | None = None
|
||||
timeout_seconds: float
|
||||
endpoint: str | None = None
|
||||
a2a_agent_name: str | None = None
|
||||
metadata: dict[str, Any] | None = None
|
||||
|
||||
|
||||
class A2AStreamingStartedEvent(A2AEventBase):
|
||||
"""Event emitted when streaming mode begins for A2A delegation.
|
||||
|
||||
Attributes:
|
||||
task_id: A2A task ID for the streaming session.
|
||||
context_id: A2A context ID grouping related tasks.
|
||||
endpoint: A2A agent endpoint URL.
|
||||
a2a_agent_name: Name of the A2A agent from agent card.
|
||||
turn_number: Current turn number (1-indexed).
|
||||
is_multiturn: Whether this is part of a multiturn conversation.
|
||||
agent_role: Role of the CrewAI agent.
|
||||
metadata: Custom A2A metadata key-value pairs.
|
||||
extensions: List of A2A extension URIs in use.
|
||||
"""
|
||||
|
||||
type: str = "a2a_streaming_started"
|
||||
task_id: str | None = None
|
||||
context_id: str | None = None
|
||||
endpoint: str
|
||||
a2a_agent_name: str | None = None
|
||||
turn_number: int = 1
|
||||
is_multiturn: bool = False
|
||||
agent_role: str | None = None
|
||||
metadata: dict[str, Any] | None = None
|
||||
extensions: list[str] | None = None
|
||||
|
||||
|
||||
class A2AStreamingChunkEvent(A2AEventBase):
|
||||
"""Event emitted when a streaming chunk is received.
|
||||
|
||||
Attributes:
|
||||
task_id: A2A task ID for the streaming session.
|
||||
context_id: A2A context ID grouping related tasks.
|
||||
chunk: The text content of the chunk.
|
||||
chunk_index: Index of this chunk in the stream (0-indexed).
|
||||
final: Whether this is the final chunk in the stream.
|
||||
endpoint: A2A agent endpoint URL.
|
||||
a2a_agent_name: Name of the A2A agent from agent card.
|
||||
turn_number: Current turn number (1-indexed).
|
||||
is_multiturn: Whether this is part of a multiturn conversation.
|
||||
metadata: Custom A2A metadata key-value pairs.
|
||||
extensions: List of A2A extension URIs in use.
|
||||
"""
|
||||
|
||||
type: str = "a2a_streaming_chunk"
|
||||
task_id: str | None = None
|
||||
context_id: str | None = None
|
||||
chunk: str
|
||||
chunk_index: int
|
||||
final: bool = False
|
||||
endpoint: str | None = None
|
||||
a2a_agent_name: str | None = None
|
||||
turn_number: int = 1
|
||||
is_multiturn: bool = False
|
||||
metadata: dict[str, Any] | None = None
|
||||
extensions: list[str] | None = None
|
||||
|
||||
|
||||
class A2AAgentCardFetchedEvent(A2AEventBase):
|
||||
"""Event emitted when an agent card is successfully fetched.
|
||||
|
||||
Attributes:
|
||||
endpoint: A2A agent endpoint URL.
|
||||
a2a_agent_name: Name of the A2A agent from agent card.
|
||||
agent_card: Full A2A agent card metadata.
|
||||
protocol_version: A2A protocol version from agent card.
|
||||
provider: Agent provider/organization info from agent card.
|
||||
cached: Whether the agent card was retrieved from cache.
|
||||
fetch_time_ms: Time taken to fetch the agent card in milliseconds.
|
||||
metadata: Custom A2A metadata key-value pairs.
|
||||
"""
|
||||
|
||||
type: str = "a2a_agent_card_fetched"
|
||||
endpoint: str
|
||||
a2a_agent_name: str | None = None
|
||||
agent_card: dict[str, Any] | None = None
|
||||
protocol_version: str | None = None
|
||||
provider: dict[str, Any] | None = None
|
||||
cached: bool = False
|
||||
fetch_time_ms: float | None = None
|
||||
metadata: dict[str, Any] | None = None
|
||||
|
||||
|
||||
class A2AAuthenticationFailedEvent(A2AEventBase):
|
||||
"""Event emitted when authentication to an A2A agent fails.
|
||||
|
||||
Attributes:
|
||||
endpoint: A2A agent endpoint URL.
|
||||
auth_type: Type of authentication attempted (e.g., bearer, oauth2, api_key).
|
||||
error: Error message describing the failure.
|
||||
status_code: HTTP status code if applicable.
|
||||
a2a_agent_name: Name of the A2A agent if known.
|
||||
protocol_version: A2A protocol version being used.
|
||||
metadata: Custom A2A metadata key-value pairs.
|
||||
"""
|
||||
|
||||
type: str = "a2a_authentication_failed"
|
||||
endpoint: str
|
||||
auth_type: str | None = None
|
||||
error: str
|
||||
status_code: int | None = None
|
||||
a2a_agent_name: str | None = None
|
||||
protocol_version: str | None = None
|
||||
metadata: dict[str, Any] | None = None
|
||||
|
||||
|
||||
class A2AArtifactReceivedEvent(A2AEventBase):
|
||||
"""Event emitted when an artifact is received from a remote A2A agent.
|
||||
|
||||
Attributes:
|
||||
task_id: A2A task ID the artifact belongs to.
|
||||
artifact_id: Unique identifier for the artifact.
|
||||
artifact_name: Name of the artifact.
|
||||
artifact_description: Purpose description of the artifact.
|
||||
mime_type: MIME type of the artifact content.
|
||||
size_bytes: Size of the artifact in bytes.
|
||||
append: Whether content should be appended to existing artifact.
|
||||
last_chunk: Whether this is the final chunk of the artifact.
|
||||
endpoint: A2A agent endpoint URL.
|
||||
a2a_agent_name: Name of the A2A agent from agent card.
|
||||
context_id: Context ID for correlation.
|
||||
turn_number: Current turn number (1-indexed).
|
||||
is_multiturn: Whether this is part of a multiturn conversation.
|
||||
metadata: Custom A2A metadata key-value pairs.
|
||||
extensions: List of A2A extension URIs in use.
|
||||
"""
|
||||
|
||||
type: str = "a2a_artifact_received"
|
||||
task_id: str
|
||||
artifact_id: str
|
||||
artifact_name: str | None = None
|
||||
artifact_description: str | None = None
|
||||
mime_type: str | None = None
|
||||
size_bytes: int | None = None
|
||||
append: bool = False
|
||||
last_chunk: bool = False
|
||||
endpoint: str | None = None
|
||||
a2a_agent_name: str | None = None
|
||||
context_id: str | None = None
|
||||
turn_number: int = 1
|
||||
is_multiturn: bool = False
|
||||
metadata: dict[str, Any] | None = None
|
||||
extensions: list[str] | None = None
|
||||
|
||||
|
||||
class A2AConnectionErrorEvent(A2AEventBase):
|
||||
"""Event emitted when a connection error occurs during A2A communication.
|
||||
|
||||
Attributes:
|
||||
endpoint: A2A agent endpoint URL.
|
||||
error: Error message describing the connection failure.
|
||||
error_type: Type of error (e.g., timeout, connection_refused, dns_error).
|
||||
status_code: HTTP status code if applicable.
|
||||
a2a_agent_name: Name of the A2A agent from agent card.
|
||||
operation: The operation being attempted when error occurred.
|
||||
context_id: A2A context ID grouping related tasks.
|
||||
task_id: A2A task ID if applicable.
|
||||
metadata: Custom A2A metadata key-value pairs.
|
||||
"""
|
||||
|
||||
type: str = "a2a_connection_error"
|
||||
endpoint: str
|
||||
error: str
|
||||
error_type: str | None = None
|
||||
status_code: int | None = None
|
||||
a2a_agent_name: str | None = None
|
||||
operation: str | None = None
|
||||
context_id: str | None = None
|
||||
task_id: str | None = None
|
||||
metadata: dict[str, Any] | None = None
|
||||
|
||||
|
||||
class A2AServerTaskStartedEvent(A2AEventBase):
|
||||
"""Event emitted when an A2A server task execution starts."""
|
||||
"""Event emitted when an A2A server task execution starts.
|
||||
|
||||
Attributes:
|
||||
task_id: A2A task ID for this execution.
|
||||
context_id: A2A context ID grouping related tasks.
|
||||
metadata: Custom A2A metadata key-value pairs.
|
||||
"""
|
||||
|
||||
type: str = "a2a_server_task_started"
|
||||
a2a_task_id: str
|
||||
a2a_context_id: str
|
||||
task_id: str
|
||||
context_id: str
|
||||
metadata: dict[str, Any] | None = None
|
||||
|
||||
|
||||
class A2AServerTaskCompletedEvent(A2AEventBase):
|
||||
"""Event emitted when an A2A server task execution completes."""
|
||||
"""Event emitted when an A2A server task execution completes.
|
||||
|
||||
Attributes:
|
||||
task_id: A2A task ID for this execution.
|
||||
context_id: A2A context ID grouping related tasks.
|
||||
result: The task result.
|
||||
metadata: Custom A2A metadata key-value pairs.
|
||||
"""
|
||||
|
||||
type: str = "a2a_server_task_completed"
|
||||
a2a_task_id: str
|
||||
a2a_context_id: str
|
||||
task_id: str
|
||||
context_id: str
|
||||
result: str
|
||||
metadata: dict[str, Any] | None = None
|
||||
|
||||
|
||||
class A2AServerTaskCanceledEvent(A2AEventBase):
|
||||
"""Event emitted when an A2A server task execution is canceled."""
|
||||
"""Event emitted when an A2A server task execution is canceled.
|
||||
|
||||
Attributes:
|
||||
task_id: A2A task ID for this execution.
|
||||
context_id: A2A context ID grouping related tasks.
|
||||
metadata: Custom A2A metadata key-value pairs.
|
||||
"""
|
||||
|
||||
type: str = "a2a_server_task_canceled"
|
||||
a2a_task_id: str
|
||||
a2a_context_id: str
|
||||
task_id: str
|
||||
context_id: str
|
||||
metadata: dict[str, Any] | None = None
|
||||
|
||||
|
||||
class A2AServerTaskFailedEvent(A2AEventBase):
|
||||
"""Event emitted when an A2A server task execution fails."""
|
||||
"""Event emitted when an A2A server task execution fails.
|
||||
|
||||
Attributes:
|
||||
task_id: A2A task ID for this execution.
|
||||
context_id: A2A context ID grouping related tasks.
|
||||
error: Error message describing the failure.
|
||||
metadata: Custom A2A metadata key-value pairs.
|
||||
"""
|
||||
|
||||
type: str = "a2a_server_task_failed"
|
||||
a2a_task_id: str
|
||||
a2a_context_id: str
|
||||
task_id: str
|
||||
context_id: str
|
||||
error: str
|
||||
metadata: dict[str, Any] | None = None
|
||||
|
||||
|
||||
class A2AParallelDelegationStartedEvent(A2AEventBase):
|
||||
"""Event emitted when parallel delegation to multiple A2A agents begins.
|
||||
|
||||
Attributes:
|
||||
endpoints: List of A2A agent endpoints being delegated to.
|
||||
task_description: Description of the task being delegated.
|
||||
"""
|
||||
|
||||
type: str = "a2a_parallel_delegation_started"
|
||||
endpoints: list[str]
|
||||
task_description: str
|
||||
|
||||
|
||||
class A2AParallelDelegationCompletedEvent(A2AEventBase):
|
||||
"""Event emitted when parallel delegation to multiple A2A agents completes.
|
||||
|
||||
Attributes:
|
||||
endpoints: List of A2A agent endpoints that were delegated to.
|
||||
success_count: Number of successful delegations.
|
||||
failure_count: Number of failed delegations.
|
||||
results: Summary of results from each agent.
|
||||
"""
|
||||
|
||||
type: str = "a2a_parallel_delegation_completed"
|
||||
endpoints: list[str]
|
||||
success_count: int
|
||||
failure_count: int
|
||||
results: dict[str, str] | None = None
|
||||
|
||||
@@ -366,32 +366,6 @@ To enable tracing, do any one of these:
|
||||
|
||||
self.print_panel(content, f"🔧 Tool Execution Started (#{iteration})", "yellow")
|
||||
|
||||
    def handle_tool_usage_finished(
        self,
        tool_name: str,
        output: str,
        run_attempts: int | None = None,
    ) -> None:
        """Handle tool usage finished event with panel display."""
        if not self.verbose:
            return

        iteration = self.tool_usage_counts.get(tool_name, 1)

        content = Text()
        content.append("Tool Completed\n", style="green bold")
        content.append("Tool: ", style="white")
        content.append(f"{tool_name}\n", style="green bold")

        if output:
            content.append("Output: ", style="white")

            content.append(f"{output}\n", style="green")

        self.print_panel(
            content, f"✅ Tool Execution Completed (#{iteration})", "green"
        )

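A call such as the following (verbose formatter assumed, values illustrative) would render the green completion panel built above:

formatter.handle_tool_usage_finished("search", "3 results found", run_attempts=1)
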
def handle_tool_usage_error(
|
||||
self,
|
||||
tool_name: str,
|
||||
|
||||
@@ -1,8 +1,6 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from collections.abc import Callable
|
||||
from datetime import datetime
|
||||
import json
|
||||
import threading
|
||||
from typing import TYPE_CHECKING, Any, Literal, cast
|
||||
from uuid import uuid4
|
||||
@@ -19,24 +17,16 @@ from crewai.agents.parser import (
|
||||
OutputParserError,
|
||||
)
|
||||
from crewai.events.event_bus import crewai_event_bus
|
||||
from crewai.events.listeners.tracing.utils import (
|
||||
is_tracing_enabled_in_context,
|
||||
)
|
||||
from crewai.events.types.logging_events import (
|
||||
AgentLogsExecutionEvent,
|
||||
AgentLogsStartedEvent,
|
||||
)
|
||||
from crewai.events.types.tool_usage_events import (
|
||||
ToolUsageFinishedEvent,
|
||||
ToolUsageStartedEvent,
|
||||
)
|
||||
from crewai.flow.flow import Flow, listen, or_, router, start
|
||||
from crewai.hooks.llm_hooks import (
|
||||
get_after_llm_call_hooks,
|
||||
get_before_llm_call_hooks,
|
||||
)
|
||||
from crewai.utilities.agent_utils import (
|
||||
convert_tools_to_openai_schema,
|
||||
enforce_rpm_limit,
|
||||
format_message_for_llm,
|
||||
get_llm_response,
|
||||
@@ -81,8 +71,6 @@ class AgentReActState(BaseModel):
|
||||
current_answer: AgentAction | AgentFinish | None = Field(default=None)
|
||||
is_finished: bool = Field(default=False)
|
||||
ask_for_human_input: bool = Field(default=False)
|
||||
use_native_tools: bool = Field(default=False)
|
||||
pending_tool_calls: list[Any] = Field(default_factory=list)
|
||||
|
||||
|
||||
class CrewAgentExecutorFlow(Flow[AgentReActState], CrewAgentExecutorMixin):
|
||||
@@ -191,10 +179,6 @@ class CrewAgentExecutorFlow(Flow[AgentReActState], CrewAgentExecutorMixin):
|
||||
)
|
||||
)
|
||||
|
||||
# Native tool calling support
|
||||
self._openai_tools: list[dict[str, Any]] = []
|
||||
self._available_functions: dict[str, Callable[..., Any]] = {}
|
||||
|
||||
self._state = AgentReActState()
|
||||
|
||||
def _ensure_flow_initialized(self) -> None:
|
||||
@@ -205,66 +189,14 @@ class CrewAgentExecutorFlow(Flow[AgentReActState], CrewAgentExecutorMixin):
|
||||
Only the instance that actually executes via invoke() will emit events.
|
||||
"""
|
||||
if not self._flow_initialized:
|
||||
current_tracing = is_tracing_enabled_in_context()
|
||||
# Now call Flow's __init__ which will replace self._state
|
||||
# with Flow's managed state. Suppress flow events since this is
|
||||
# an agent executor, not a user-facing flow.
|
||||
super().__init__(
|
||||
suppress_flow_events=True,
|
||||
tracing=current_tracing if current_tracing else None,
|
||||
)
|
||||
self._flow_initialized = True
|
||||
|
||||
    def _check_native_tool_support(self) -> bool:
        """Check if LLM supports native function calling.

        Returns:
            True if the LLM supports native function calling and tools are available.
        """
        return (
            hasattr(self.llm, "supports_function_calling")
            and callable(getattr(self.llm, "supports_function_calling", None))
            and self.llm.supports_function_calling()
            and bool(self.original_tools)
        )

    def _setup_native_tools(self) -> None:
        """Convert tools to OpenAI schema format for native function calling."""
        if self.original_tools:
            self._openai_tools, self._available_functions = (
                convert_tools_to_openai_schema(self.original_tools)
            )

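convert_tools_to_openai_schema returns the (schemas, callables) pair unpacked above. For orientation only, each schema entry follows the OpenAI function-tool layout, roughly like the illustrative example below; the field values are made up and the exact schema is produced by the helper:

example_openai_tool = {
    "type": "function",
    "function": {
        "name": "search",
        "description": "Search the web for a query.",
        "parameters": {
            "type": "object",
            "properties": {"q": {"type": "string"}},
            "required": ["q"],
        },
    },
}
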
def _is_tool_call_list(self, response: list[Any]) -> bool:
|
||||
"""Check if a response is a list of tool calls.
|
||||
|
||||
Args:
|
||||
response: The response to check.
|
||||
|
||||
Returns:
|
||||
True if the response appears to be a list of tool calls.
|
||||
"""
|
||||
if not response:
|
||||
return False
|
||||
first_item = response[0]
|
||||
# Check for OpenAI-style tool call structure
|
||||
if hasattr(first_item, "function") or (
|
||||
isinstance(first_item, dict) and "function" in first_item
|
||||
):
|
||||
return True
|
||||
# Check for Anthropic-style tool call structure (ToolUseBlock)
|
||||
if (
|
||||
hasattr(first_item, "type")
|
||||
and getattr(first_item, "type", None) == "tool_use"
|
||||
):
|
||||
return True
|
||||
if hasattr(first_item, "name") and hasattr(first_item, "input"):
|
||||
return True
|
||||
# Check for Gemini-style function call (Part with function_call)
|
||||
if hasattr(first_item, "function_call") and first_item.function_call:
|
||||
return True
|
||||
return False
|
||||
|
||||
@property
|
||||
def use_stop_words(self) -> bool:
|
||||
"""Check to determine if stop words are being used.
|
||||
@@ -297,11 +229,6 @@ class CrewAgentExecutorFlow(Flow[AgentReActState], CrewAgentExecutorMixin):
|
||||
def initialize_reasoning(self) -> Literal["initialized"]:
|
||||
"""Initialize the reasoning flow and emit agent start logs."""
|
||||
self._show_start_logs()
|
||||
# Check for native tool support on first iteration
|
||||
if self.state.iterations == 0:
|
||||
self.state.use_native_tools = self._check_native_tool_support()
|
||||
if self.state.use_native_tools:
|
||||
self._setup_native_tools()
|
||||
return "initialized"
|
||||
|
||||
@listen("force_final_answer")
|
||||
@@ -376,69 +303,6 @@ class CrewAgentExecutorFlow(Flow[AgentReActState], CrewAgentExecutorMixin):
|
||||
handle_unknown_error(self._printer, e)
|
||||
raise
|
||||
|
||||
@listen("continue_reasoning_native")
|
||||
def call_llm_native_tools(
|
||||
self,
|
||||
) -> Literal["native_tool_calls", "native_finished", "context_error"]:
|
||||
"""Execute LLM call with native function calling.
|
||||
|
||||
Returns routing decision based on whether tool calls or final answer.
|
||||
"""
|
||||
try:
|
||||
enforce_rpm_limit(self.request_within_rpm_limit)
|
||||
|
||||
# Call LLM with native tools
|
||||
# Pass available_functions=None so the LLM returns tool_calls
|
||||
# without executing them. The executor handles tool execution.
|
||||
answer = get_llm_response(
|
||||
llm=self.llm,
|
||||
messages=list(self.state.messages),
|
||||
callbacks=self.callbacks,
|
||||
printer=self._printer,
|
||||
tools=self._openai_tools,
|
||||
available_functions=None,
|
||||
from_task=self.task,
|
||||
from_agent=self.agent,
|
||||
response_model=self.response_model,
|
||||
executor_context=self,
|
||||
)
|
||||
|
||||
# Check if the response is a list of tool calls
|
||||
if isinstance(answer, list) and answer and self._is_tool_call_list(answer):
|
||||
# Store tool calls for sequential processing
|
||||
self.state.pending_tool_calls = list(answer)
|
||||
return "native_tool_calls"
|
||||
|
||||
# Text response - this is the final answer
|
||||
if isinstance(answer, str):
|
||||
self.state.current_answer = AgentFinish(
|
||||
thought="",
|
||||
output=answer,
|
||||
text=answer,
|
||||
)
|
||||
self._invoke_step_callback(self.state.current_answer)
|
||||
self._append_message_to_state(answer)
|
||||
return "native_finished"
|
||||
|
||||
# Unexpected response type, treat as final answer
|
||||
self.state.current_answer = AgentFinish(
|
||||
thought="",
|
||||
output=str(answer),
|
||||
text=str(answer),
|
||||
)
|
||||
self._invoke_step_callback(self.state.current_answer)
|
||||
self._append_message_to_state(str(answer))
|
||||
return "native_finished"
|
||||
|
||||
except Exception as e:
|
||||
if is_context_length_exceeded(e):
|
||||
self._last_context_error = e
|
||||
return "context_error"
|
||||
if e.__class__.__module__.startswith("litellm"):
|
||||
raise e
|
||||
handle_unknown_error(self._printer, e)
|
||||
raise
|
||||
|
||||
@router(call_llm_and_parse)
def route_by_answer_type(self) -> Literal["execute_tool", "agent_finished"]:
"""Route based on whether answer is AgentAction or AgentFinish."""
@@ -494,14 +358,6 @@ class CrewAgentExecutorFlow(Flow[AgentReActState], CrewAgentExecutorMixin):
self.state.is_finished = True
return "tool_result_is_final"

# Inject post-tool reasoning prompt to enforce analysis
reasoning_prompt = self._i18n.slice("post_tool_reasoning")
reasoning_message: LLMMessage = {
"role": "user",
"content": reasoning_prompt,
}
self.state.messages.append(reasoning_message)

return "tool_completed"

except Exception as e:
@@ -511,143 +367,6 @@ class CrewAgentExecutorFlow(Flow[AgentReActState], CrewAgentExecutorMixin):
self._console.print(error_text)
raise

@listen("native_tool_calls")
|
||||
def execute_native_tool(self) -> Literal["native_tool_completed"]:
|
||||
"""Execute a single native tool call and inject reasoning prompt.
|
||||
|
||||
Processes only the FIRST tool call from pending_tool_calls for
|
||||
sequential execution with reflection after each tool.
|
||||
"""
|
||||
if not self.state.pending_tool_calls:
|
||||
return "native_tool_completed"
|
||||
|
||||
tool_call = self.state.pending_tool_calls[0]
|
||||
self.state.pending_tool_calls = [] # Clear pending calls
|
||||
|
||||
# Extract tool call info - handle OpenAI, Anthropic, and Gemini formats
|
||||
if hasattr(tool_call, "function"):
|
||||
# OpenAI format: has .function.name and .function.arguments
|
||||
call_id = getattr(tool_call, "id", f"call_{id(tool_call)}")
|
||||
func_name = tool_call.function.name
|
||||
func_args = tool_call.function.arguments
|
||||
elif hasattr(tool_call, "function_call") and tool_call.function_call:
|
||||
# Gemini format: has .function_call.name and .function_call.args
|
||||
call_id = f"call_{id(tool_call)}"
|
||||
func_name = tool_call.function_call.name
|
||||
func_args = (
|
||||
dict(tool_call.function_call.args)
|
||||
if tool_call.function_call.args
|
||||
else {}
|
||||
)
|
||||
elif hasattr(tool_call, "name") and hasattr(tool_call, "input"):
|
||||
# Anthropic format: has .name and .input (ToolUseBlock)
|
||||
call_id = getattr(tool_call, "id", f"call_{id(tool_call)}")
|
||||
func_name = tool_call.name
|
||||
func_args = tool_call.input # Already a dict in Anthropic
|
||||
elif isinstance(tool_call, dict):
|
||||
call_id = tool_call.get("id", f"call_{id(tool_call)}")
|
||||
func_info = tool_call.get("function", {})
|
||||
func_name = func_info.get("name", "") or tool_call.get("name", "")
|
||||
func_args = func_info.get("arguments", "{}") or tool_call.get("input", {})
|
||||
else:
|
||||
return "native_tool_completed"
|
||||
|
||||
# Append assistant message with single tool call
|
||||
assistant_message: LLMMessage = {
|
||||
"role": "assistant",
|
||||
"content": None,
|
||||
"tool_calls": [
|
||||
{
|
||||
"id": call_id,
|
||||
"type": "function",
|
||||
"function": {
|
||||
"name": func_name,
|
||||
"arguments": func_args
|
||||
if isinstance(func_args, str)
|
||||
else json.dumps(func_args),
|
||||
},
|
||||
}
|
||||
],
|
||||
}
|
||||
self.state.messages.append(assistant_message)
|
||||
|
||||
# Parse arguments for the single tool call
|
||||
if isinstance(func_args, str):
|
||||
try:
|
||||
args_dict = json.loads(func_args)
|
||||
except json.JSONDecodeError:
|
||||
args_dict = {}
|
||||
else:
|
||||
args_dict = func_args
|
||||
|
||||
# Emit tool usage started event
|
||||
started_at = datetime.now()
|
||||
crewai_event_bus.emit(
|
||||
self,
|
||||
event=ToolUsageStartedEvent(
|
||||
tool_name=func_name,
|
||||
tool_args=args_dict,
|
||||
from_agent=self.agent,
|
||||
from_task=self.task,
|
||||
),
|
||||
)
|
||||
|
||||
# Execute the tool
|
||||
result = "Tool not found"
|
||||
if func_name in self._available_functions:
|
||||
try:
|
||||
tool_func = self._available_functions[func_name]
|
||||
result = tool_func(**args_dict)
|
||||
if not isinstance(result, str):
|
||||
result = str(result)
|
||||
except Exception as e:
|
||||
result = f"Error executing tool: {e}"
|
||||
|
||||
# Emit tool usage finished event
|
||||
crewai_event_bus.emit(
|
||||
self,
|
||||
event=ToolUsageFinishedEvent(
|
||||
output=result,
|
||||
tool_name=func_name,
|
||||
tool_args=args_dict,
|
||||
from_agent=self.agent,
|
||||
from_task=self.task,
|
||||
started_at=started_at,
|
||||
finished_at=datetime.now(),
|
||||
),
|
||||
)
|
||||
|
||||
# Append tool result message
|
||||
tool_message: LLMMessage = {
|
||||
"role": "tool",
|
||||
"tool_call_id": call_id,
|
||||
"content": result,
|
||||
}
|
||||
self.state.messages.append(tool_message)
|
||||
|
||||
# Log the tool execution
|
||||
if self.agent and self.agent.verbose:
|
||||
self._printer.print(
|
||||
content=f"Tool {func_name} executed with result: {result[:200]}...",
|
||||
color="green",
|
||||
)
|
||||
|
||||
# Inject post-tool reasoning prompt to enforce analysis
|
||||
reasoning_prompt = self._i18n.slice("post_tool_reasoning")
|
||||
reasoning_message: LLMMessage = {
|
||||
"role": "user",
|
||||
"content": reasoning_prompt,
|
||||
}
|
||||
self.state.messages.append(reasoning_message)
|
||||
|
||||
return "native_tool_completed"
|
||||
|
||||
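The branch ladder in execute_native_tool above normalizes three provider-specific tool-call shapes (OpenAI's .function.name/.function.arguments, Gemini's .function_call.name/.function_call.args, Anthropic's .name/.input, plus already-serialized dicts) into one (call_id, name, args) triple before execution. A minimal standalone sketch of that normalization, using hypothetical stand-in objects rather than real SDK types:

import json
from typing import Any


def normalize_tool_call(tool_call: Any) -> tuple[str, str, dict[str, Any]]:
    """Collapse OpenAI / Gemini / Anthropic / dict tool-call shapes into (id, name, args)."""
    call_id = getattr(tool_call, "id", f"call_{id(tool_call)}")
    if hasattr(tool_call, "function"):  # OpenAI-style: .function.name / .function.arguments
        name = tool_call.function.name
        raw_args = tool_call.function.arguments
    elif getattr(tool_call, "function_call", None):  # Gemini-style: .function_call.name / .args
        name = tool_call.function_call.name
        raw_args = dict(tool_call.function_call.args or {})
    elif hasattr(tool_call, "name") and hasattr(tool_call, "input"):  # Anthropic ToolUseBlock
        name = tool_call.name
        raw_args = tool_call.input
    elif isinstance(tool_call, dict):  # already-serialized call
        call_id = tool_call.get("id", call_id)
        func = tool_call.get("function", {})
        name = func.get("name", "") or tool_call.get("name", "")
        raw_args = func.get("arguments", "{}") or tool_call.get("input", {})
    else:
        raise TypeError(f"Unrecognized tool call object: {tool_call!r}")
    if isinstance(raw_args, str):
        try:
            args = json.loads(raw_args)
        except json.JSONDecodeError:
            args = {}
    else:
        args = dict(raw_args)
    return call_id, name, args
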
@router(execute_native_tool)
def increment_native_and_continue(self) -> Literal["initialized"]:
"""Increment iteration counter after native tool execution."""
self.state.iterations += 1
return "initialized"

@listen("initialized")
def continue_iteration(self) -> Literal["check_iteration"]:
"""Bridge listener that connects iteration loop back to iteration check."""
@@ -656,14 +375,10 @@ class CrewAgentExecutorFlow(Flow[AgentReActState], CrewAgentExecutorMixin):
@router(or_(initialize_reasoning, continue_iteration))
def check_max_iterations(
self,
) -> Literal[
"force_final_answer", "continue_reasoning", "continue_reasoning_native"
]:
) -> Literal["force_final_answer", "continue_reasoning"]:
"""Check if max iterations reached before proceeding with reasoning."""
if has_reached_max_iterations(self.state.iterations, self.max_iter):
return "force_final_answer"
if self.state.use_native_tools:
return "continue_reasoning_native"
return "continue_reasoning"

@router(execute_tool_action)
@@ -672,7 +387,7 @@ class CrewAgentExecutorFlow(Flow[AgentReActState], CrewAgentExecutorMixin):
self.state.iterations += 1
return "initialized"

@listen(or_("agent_finished", "tool_result_is_final", "native_finished"))
@listen(or_("agent_finished", "tool_result_is_final"))
def finalize(self) -> Literal["completed", "skipped"]:
"""Finalize execution and emit completion logs."""
if self.state.current_answer is None:
@@ -760,8 +475,6 @@ class CrewAgentExecutorFlow(Flow[AgentReActState], CrewAgentExecutorMixin):
self.state.iterations = 0
self.state.current_answer = None
self.state.is_finished = False
self.state.use_native_tools = False
self.state.pending_tool_calls = []

if "system" in self.prompt:
prompt = cast("SystemPromptResult", self.prompt)

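The executor hunks above rely on a specific contract with llm.call: when tool schemas are passed and available_functions is None, the provider returns its raw tool-call objects as a list instead of executing anything, and a plain string is treated as the final answer. A minimal sketch of a caller honoring that contract (illustrative only; llm stands for any object with the call signature used by get_llm_response):

def run_one_step(llm, messages, tool_schemas):
    """Ask the LLM for one step: ('tools', calls) when it wants tools, ('final', text) otherwise."""
    answer = llm.call(
        messages,
        tools=tool_schemas,
        available_functions=None,  # the executor will run the tools itself
    )
    if isinstance(answer, list):  # provider-native tool-call objects
        return "tools", answer
    return "final", str(answer)
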
@@ -931,6 +931,7 @@ class LLM(BaseLLM):
|
||||
self._handle_streaming_callbacks(callbacks, usage_info, last_chunk)
|
||||
|
||||
if not tool_calls or not available_functions:
|
||||
|
||||
if response_model and self.is_litellm:
|
||||
instructor_instance = InternalInstructor(
|
||||
content=full_response,
|
||||
@@ -1143,12 +1144,8 @@ class LLM(BaseLLM):
|
||||
if response_model:
|
||||
params["response_model"] = response_model
|
||||
response = litellm.completion(**params)
|
||||
|
||||
if (
|
||||
hasattr(response, "usage")
|
||||
and not isinstance(response.usage, type)
|
||||
and response.usage
|
||||
):
|
||||
|
||||
if hasattr(response,"usage") and not isinstance(response.usage, type) and response.usage:
|
||||
usage_info = response.usage
|
||||
self._track_token_usage_internal(usage_info)
|
||||
|
||||
@@ -1202,19 +1199,16 @@ class LLM(BaseLLM):
|
||||
)
|
||||
return text_response
|
||||
|
||||
# --- 6) If there are tool calls but no available functions, return the tool calls
|
||||
# This allows the caller (e.g., executor) to handle tool execution
|
||||
if tool_calls and not available_functions:
|
||||
# --- 6) If there is no text response, no available functions, but there are tool calls, return the tool calls
|
||||
if tool_calls and not available_functions and not text_response:
|
||||
return tool_calls
|
||||
|
||||
# --- 7) Handle tool calls if present (execute when available_functions provided)
|
||||
if tool_calls and available_functions:
|
||||
tool_result = self._handle_tool_call(
|
||||
tool_calls, available_functions, from_task, from_agent
|
||||
)
|
||||
if tool_result is not None:
|
||||
return tool_result
|
||||
|
||||
# --- 7) Handle tool calls if present
|
||||
tool_result = self._handle_tool_call(
|
||||
tool_calls, available_functions, from_task, from_agent
|
||||
)
|
||||
if tool_result is not None:
|
||||
return tool_result
|
||||
# --- 8) If tool call handling didn't return a result, emit completion event and return text response
|
||||
self._handle_emit_call_events(
|
||||
response=text_response,
|
||||
@@ -1279,11 +1273,7 @@ class LLM(BaseLLM):
|
||||
params["response_model"] = response_model
|
||||
response = await litellm.acompletion(**params)
|
||||
|
||||
if (
|
||||
hasattr(response, "usage")
|
||||
and not isinstance(response.usage, type)
|
||||
and response.usage
|
||||
):
|
||||
if hasattr(response,"usage") and not isinstance(response.usage, type) and response.usage:
|
||||
usage_info = response.usage
|
||||
self._track_token_usage_internal(usage_info)
|
||||
|
||||
@@ -1331,18 +1321,14 @@ class LLM(BaseLLM):
|
||||
)
|
||||
return text_response
|
||||
|
||||
# If there are tool calls but no available functions, return the tool calls
|
||||
# This allows the caller (e.g., executor) to handle tool execution
|
||||
if tool_calls and not available_functions:
|
||||
if tool_calls and not available_functions and not text_response:
|
||||
return tool_calls
|
||||
|
||||
# Handle tool calls if present (execute when available_functions provided)
|
||||
if tool_calls and available_functions:
|
||||
tool_result = self._handle_tool_call(
|
||||
tool_calls, available_functions, from_task, from_agent
|
||||
)
|
||||
if tool_result is not None:
|
||||
return tool_result
|
||||
tool_result = self._handle_tool_call(
|
||||
tool_calls, available_functions, from_task, from_agent
|
||||
)
|
||||
if tool_result is not None:
|
||||
return tool_result
|
||||
|
||||
self._handle_emit_call_events(
|
||||
response=text_response,
|
||||
@@ -1377,7 +1363,7 @@ class LLM(BaseLLM):
|
||||
"""
|
||||
full_response = ""
|
||||
chunk_count = 0
|
||||
|
||||
|
||||
usage_info = None
|
||||
|
||||
accumulated_tool_args: defaultdict[int, AccumulatedToolArgs] = defaultdict(
|
||||
|
||||
@@ -445,7 +445,7 @@ class BaseLLM(ABC):
|
||||
from_agent=from_agent,
|
||||
)
|
||||
|
||||
return result
|
||||
return str(result)
|
||||
|
||||
except Exception as e:
|
||||
error_msg = f"Error executing function '{function_name}': {e!s}"
|
||||
|
||||
@@ -418,7 +418,6 @@ class AnthropicCompletion(BaseLLM):
|
||||
- System messages are separate from conversation messages
|
||||
- Messages must alternate between user and assistant
|
||||
- First message must be from user
|
||||
- Tool results must be in user messages with tool_result content blocks
|
||||
- When thinking is enabled, assistant messages must start with thinking blocks
|
||||
|
||||
Args:
|
||||
@@ -432,7 +431,6 @@ class AnthropicCompletion(BaseLLM):
|
||||
|
||||
formatted_messages: list[LLMMessage] = []
|
||||
system_message: str | None = None
|
||||
pending_tool_results: list[dict[str, Any]] = []
|
||||
|
||||
for message in base_formatted:
|
||||
role = message.get("role")
|
||||
@@ -443,47 +441,16 @@ class AnthropicCompletion(BaseLLM):
|
||||
system_message += f"\n\n{content}"
|
||||
else:
|
||||
system_message = cast(str, content)
|
||||
elif role == "tool":
|
||||
# Convert OpenAI-style tool message to Anthropic tool_result format
|
||||
# These will be collected and added as a user message
|
||||
tool_call_id = message.get("tool_call_id", "")
|
||||
tool_result = {
|
||||
"type": "tool_result",
|
||||
"tool_use_id": tool_call_id,
|
||||
"content": content if content else "",
|
||||
}
|
||||
pending_tool_results.append(tool_result)
|
||||
elif role == "assistant":
|
||||
# First, flush any pending tool results as a user message
|
||||
if pending_tool_results:
|
||||
formatted_messages.append(
|
||||
{"role": "user", "content": pending_tool_results}
|
||||
)
|
||||
pending_tool_results = []
|
||||
else:
|
||||
role_str = role if role is not None else "user"
|
||||
|
||||
# Handle assistant message with tool_calls (convert to Anthropic format)
|
||||
tool_calls = message.get("tool_calls", [])
|
||||
if tool_calls:
|
||||
assistant_content: list[dict[str, Any]] = []
|
||||
for tc in tool_calls:
|
||||
if isinstance(tc, dict):
|
||||
func = tc.get("function", {})
|
||||
tool_use = {
|
||||
"type": "tool_use",
|
||||
"id": tc.get("id", ""),
|
||||
"name": func.get("name", ""),
|
||||
"input": json.loads(func.get("arguments", "{}"))
|
||||
if isinstance(func.get("arguments"), str)
|
||||
else func.get("arguments", {}),
|
||||
}
|
||||
assistant_content.append(tool_use)
|
||||
if assistant_content:
|
||||
formatted_messages.append(
|
||||
{"role": "assistant", "content": assistant_content}
|
||||
)
|
||||
elif isinstance(content, list):
|
||||
formatted_messages.append({"role": "assistant", "content": content})
|
||||
elif self.thinking and self.previous_thinking_blocks:
|
||||
if isinstance(content, list):
|
||||
formatted_messages.append({"role": role_str, "content": content})
|
||||
elif (
|
||||
role_str == "assistant"
|
||||
and self.thinking
|
||||
and self.previous_thinking_blocks
|
||||
):
|
||||
structured_content = cast(
|
||||
list[dict[str, Any]],
|
||||
[
|
||||
@@ -492,34 +459,14 @@ class AnthropicCompletion(BaseLLM):
|
||||
],
|
||||
)
|
||||
formatted_messages.append(
|
||||
LLMMessage(role="assistant", content=structured_content)
|
||||
LLMMessage(role=role_str, content=structured_content)
|
||||
)
|
||||
else:
|
||||
content_str = content if content is not None else ""
|
||||
formatted_messages.append(
|
||||
LLMMessage(role="assistant", content=content_str)
|
||||
)
|
||||
else:
|
||||
# User message - first flush any pending tool results
|
||||
if pending_tool_results:
|
||||
formatted_messages.append(
|
||||
{"role": "user", "content": pending_tool_results}
|
||||
)
|
||||
pending_tool_results = []
|
||||
|
||||
role_str = role if role is not None else "user"
|
||||
if isinstance(content, list):
|
||||
formatted_messages.append({"role": role_str, "content": content})
|
||||
else:
|
||||
content_str = content if content is not None else ""
|
||||
formatted_messages.append(
|
||||
LLMMessage(role=role_str, content=content_str)
|
||||
)
|
||||
|
||||
# Flush any remaining pending tool results
|
||||
if pending_tool_results:
|
||||
formatted_messages.append({"role": "user", "content": pending_tool_results})
|
||||
|
||||
# Ensure first message is from user (Anthropic requirement)
|
||||
if not formatted_messages:
|
||||
# If no messages, add a default user message
|
||||
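The formatting loop above maps OpenAI-style history (assistant tool_calls plus role "tool" results) onto Anthropic's block format: tool_use blocks inside assistant content, with the matching tool_result blocks batched into the following user message. Roughly, the formatter is expected to emit message shapes like the following (a sketch with placeholder ids and content, not actual output):

anthropic_messages = [
    {"role": "user", "content": "Calculate 15 * 8"},
    {
        "role": "assistant",
        "content": [
            {
                "type": "tool_use",
                "id": "toolu_1",  # placeholder tool_use id
                "name": "calculator",
                "input": {"expression": "15 * 8"},
            },
        ],
    },
    {
        "role": "user",
        "content": [
            {"type": "tool_result", "tool_use_id": "toolu_1", "content": "120"},
        ],
    },
]
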
@@ -579,19 +526,13 @@ class AnthropicCompletion(BaseLLM):
|
||||
return structured_json
|
||||
|
||||
# Check if Claude wants to use tools
|
||||
if response.content:
|
||||
if response.content and available_functions:
|
||||
tool_uses = [
|
||||
block for block in response.content if isinstance(block, ToolUseBlock)
|
||||
]
|
||||
|
||||
if tool_uses:
|
||||
# If no available_functions, return tool calls for executor to handle
|
||||
# This allows the executor to manage tool execution with proper
|
||||
# message history and post-tool reasoning prompts
|
||||
if not available_functions:
|
||||
return list(tool_uses)
|
||||
|
||||
# Handle tool use conversation flow internally
|
||||
# Handle tool use conversation flow
|
||||
return self._handle_tool_use_conversation(
|
||||
response,
|
||||
tool_uses,
|
||||
@@ -755,7 +696,7 @@ class AnthropicCompletion(BaseLLM):
|
||||
|
||||
return structured_json
|
||||
|
||||
if final_message.content:
|
||||
if final_message.content and available_functions:
|
||||
tool_uses = [
|
||||
block
|
||||
for block in final_message.content
|
||||
@@ -763,11 +704,7 @@ class AnthropicCompletion(BaseLLM):
|
||||
]
|
||||
|
||||
if tool_uses:
|
||||
# If no available_functions, return tool calls for executor to handle
|
||||
if not available_functions:
|
||||
return list(tool_uses)
|
||||
|
||||
# Handle tool use conversation flow internally
|
||||
# Handle tool use conversation flow
|
||||
return self._handle_tool_use_conversation(
|
||||
final_message,
|
||||
tool_uses,
|
||||
@@ -996,16 +933,12 @@ class AnthropicCompletion(BaseLLM):
|
||||
|
||||
return structured_json
|
||||
|
||||
if response.content:
|
||||
if response.content and available_functions:
|
||||
tool_uses = [
|
||||
block for block in response.content if isinstance(block, ToolUseBlock)
|
||||
]
|
||||
|
||||
if tool_uses:
|
||||
# If no available_functions, return tool calls for executor to handle
|
||||
if not available_functions:
|
||||
return list(tool_uses)
|
||||
|
||||
return await self._ahandle_tool_use_conversation(
|
||||
response,
|
||||
tool_uses,
|
||||
@@ -1146,7 +1079,7 @@ class AnthropicCompletion(BaseLLM):
|
||||
|
||||
return structured_json
|
||||
|
||||
if final_message.content:
|
||||
if final_message.content and available_functions:
|
||||
tool_uses = [
|
||||
block
|
||||
for block in final_message.content
|
||||
@@ -1154,10 +1087,6 @@ class AnthropicCompletion(BaseLLM):
|
||||
]
|
||||
|
||||
if tool_uses:
|
||||
# If no available_functions, return tool calls for executor to handle
|
||||
if not available_functions:
|
||||
return list(tool_uses)
|
||||
|
||||
return await self._ahandle_tool_use_conversation(
|
||||
final_message,
|
||||
tool_uses,
|
||||
|
||||
@@ -514,31 +514,10 @@ class AzureCompletion(BaseLLM):
|
||||
|
||||
for message in base_formatted:
|
||||
role = message.get("role", "user") # Default to user if no role
|
||||
# Handle None content - Azure requires string content
|
||||
content = message.get("content") or ""
|
||||
content = message.get("content", "")
|
||||
|
||||
# Handle tool role messages - keep as tool role for Azure OpenAI
|
||||
if role == "tool":
|
||||
tool_call_id = message.get("tool_call_id", "unknown")
|
||||
azure_messages.append(
|
||||
{
|
||||
"role": "tool",
|
||||
"tool_call_id": tool_call_id,
|
||||
"content": content,
|
||||
}
|
||||
)
|
||||
# Handle assistant messages with tool_calls
|
||||
elif role == "assistant" and message.get("tool_calls"):
|
||||
tool_calls = message.get("tool_calls", [])
|
||||
azure_msg: LLMMessage = {
|
||||
"role": "assistant",
|
||||
"content": content, # Already defaulted to "" above
|
||||
"tool_calls": tool_calls,
|
||||
}
|
||||
azure_messages.append(azure_msg)
|
||||
else:
|
||||
# Azure AI Inference requires both 'role' and 'content'
|
||||
azure_messages.append({"role": role, "content": content})
|
||||
# Azure AI Inference requires both 'role' and 'content'
|
||||
azure_messages.append({"role": role, "content": content})
|
||||
|
||||
return azure_messages
|
||||
|
||||
@@ -625,11 +604,6 @@ class AzureCompletion(BaseLLM):
|
||||
from_agent=from_agent,
|
||||
)
|
||||
|
||||
# If there are tool_calls but no available_functions, return the tool_calls
|
||||
# This allows the caller (e.g., executor) to handle tool execution
|
||||
if message.tool_calls and not available_functions:
|
||||
return list(message.tool_calls)
|
||||
|
||||
# Handle tool calls
|
||||
if message.tool_calls and available_functions:
|
||||
tool_call = message.tool_calls[0] # Handle first tool call
|
||||
@@ -801,21 +775,6 @@ class AzureCompletion(BaseLLM):
|
||||
from_agent=from_agent,
|
||||
)
|
||||
|
||||
# If there are tool_calls but no available_functions, return them
|
||||
# in OpenAI-compatible format for executor to handle
|
||||
if tool_calls and not available_functions:
|
||||
return [
|
||||
{
|
||||
"id": call_data.get("id", f"call_{idx}"),
|
||||
"type": "function",
|
||||
"function": {
|
||||
"name": call_data["name"],
|
||||
"arguments": call_data["arguments"],
|
||||
},
|
||||
}
|
||||
for idx, call_data in tool_calls.items()
|
||||
]
|
||||
|
||||
# Handle completed tool calls
|
||||
if tool_calls and available_functions:
|
||||
for call_data in tool_calls.values():
|
||||
|
||||
@@ -606,17 +606,6 @@ class GeminiCompletion(BaseLLM):
|
||||
if response.candidates and (self.tools or available_functions):
|
||||
candidate = response.candidates[0]
|
||||
if candidate.content and candidate.content.parts:
|
||||
# Collect function call parts
|
||||
function_call_parts = [
|
||||
part for part in candidate.content.parts if part.function_call
|
||||
]
|
||||
|
||||
# If there are function calls but no available_functions,
|
||||
# return them for the executor to handle (like OpenAI/Anthropic)
|
||||
if function_call_parts and not available_functions:
|
||||
return function_call_parts
|
||||
|
||||
# Otherwise execute the tools internally
|
||||
for part in candidate.content.parts:
|
||||
if part.function_call:
|
||||
function_name = part.function_call.name
|
||||
@@ -731,7 +720,7 @@ class GeminiCompletion(BaseLLM):
|
||||
from_task: Any | None = None,
|
||||
from_agent: Any | None = None,
|
||||
response_model: type[BaseModel] | None = None,
|
||||
) -> str | list[dict[str, Any]]:
|
||||
) -> str:
|
||||
"""Finalize streaming response with usage tracking, function execution, and events.
|
||||
|
||||
Args:
|
||||
@@ -749,21 +738,6 @@ class GeminiCompletion(BaseLLM):
|
||||
"""
|
||||
self._track_token_usage_internal(usage_data)
|
||||
|
||||
# If there are function calls but no available_functions,
|
||||
# return them for the executor to handle
|
||||
if function_calls and not available_functions:
|
||||
return [
|
||||
{
|
||||
"id": call_data["id"],
|
||||
"function": {
|
||||
"name": call_data["name"],
|
||||
"arguments": json.dumps(call_data["args"]),
|
||||
},
|
||||
"type": "function",
|
||||
}
|
||||
for call_data in function_calls.values()
|
||||
]
|
||||
|
||||
# Handle completed function calls
|
||||
if function_calls and available_functions:
|
||||
for call_data in function_calls.values():
|
||||
|
||||
@@ -428,12 +428,6 @@ class OpenAICompletion(BaseLLM):
choice: Choice = response.choices[0]
message = choice.message

# If there are tool_calls but no available_functions, return the tool_calls
# This allows the caller (e.g., executor) to handle tool execution
if message.tool_calls and not available_functions:
return list(message.tool_calls)

# If there are tool_calls and available_functions, execute the tools
if message.tool_calls and available_functions:
tool_call = message.tool_calls[0]
function_name = tool_call.function.name
@@ -731,15 +725,6 @@ class OpenAICompletion(BaseLLM):
choice: Choice = response.choices[0]
message = choice.message

# If there are tool_calls but no available_functions, return the tool_calls
# This allows the caller (e.g., executor) to handle tool execution
if message.tool_calls and not available_functions:
print("--------------------------------")
print("lorenze tool_calls", list(message.tool_calls))
print("--------------------------------")
return list(message.tool_calls)

# If there are tool_calls and available_functions, execute the tools
if message.tool_calls and available_functions:
tool_call = message.tool_calls[0]
function_name = tool_call.function.name

@@ -11,9 +11,6 @@
|
||||
"role_playing": "You are {role}. {backstory}\nYour personal goal is: {goal}",
|
||||
"tools": "\nYou ONLY have access to the following tools, and should NEVER make up tools that are not listed here:\n\n{tools}\n\nIMPORTANT: Use the following format in your response:\n\n```\nThought: you should always think about what to do\nAction: the action to take, only one name of [{tool_names}], just the name, exactly as it's written.\nAction Input: the input to the action, just a simple JSON object, enclosed in curly braces, using \" to wrap keys and values.\nObservation: the result of the action\n```\n\nOnce all necessary information is gathered, return the following format:\n\n```\nThought: I now know the final answer\nFinal Answer: the final answer to the original input question\n```",
|
||||
"no_tools": "\nTo give my best complete final answer to the task respond using the exact following format:\n\nThought: I now can give a great answer\nFinal Answer: Your final answer must be the great and the most complete as possible, it must be outcome described.\n\nI MUST use these formats, my job depends on it!",
|
||||
"native_tools": "\nUse available tools to gather information and complete your task.",
|
||||
"native_task": "\nCurrent Task: {input}\n\nThis is VERY important to you, your job depends on it!",
|
||||
"post_tool_reasoning": "PAUSE and THINK before responding.\n\nInternally consider (DO NOT output these steps):\n- What key insights did the tool provide?\n- Have I fulfilled ALL requirements from my original instructions (e.g., minimum tool calls, specific sources)?\n- Do I have enough information to fully answer the task?\n\nIF you have NOT met all requirements or need more information: Call another tool now.\n\nIF you have met all requirements and have sufficient information: Provide ONLY your final answer in the format specified by the task's expected output. Do NOT include reasoning steps, analysis sections, or meta-commentary. Just deliver the answer.",
|
||||
"format": "I MUST either use a tool (use one at time) OR give my best final answer not both at the same time. When responding, I must use the following format:\n\n```\nThought: you should always think about what to do\nAction: the action to take, should be one of [{tool_names}]\nAction Input: the input to the action, dictionary enclosed in curly braces\nObservation: the result of the action\n```\nThis Thought/Action/Action Input/Result can repeat N times. Once I know the final answer, I must return the following format:\n\n```\nThought: I now can give a great answer\nFinal Answer: Your final answer must be the great and the most complete as possible, it must be outcome described\n\n```",
|
||||
"final_answer_format": "If you don't need to use any more tools, you must give your best complete final answer, make sure it satisfies the expected criteria, use the EXACT format below:\n\n```\nThought: I now can give a great answer\nFinal Answer: my best complete final answer to the task.\n\n```",
|
||||
"format_without_tools": "\nSorry, I didn't use the right format. I MUST either use a tool (among the available ones), OR give my best final answer.\nHere is the expected format I must follow:\n\n```\nQuestion: the input question you must answer\nThought: you should always think about what to do\nAction: the action to take, should be one of [{tool_names}]\nAction Input: the input to the action\nObservation: the result of the action\n```\n This Thought/Action/Action Input/Result process can repeat N times. Once I know the final answer, I must return the following format:\n\n```\nThought: I now can give a great answer\nFinal Answer: Your final answer must be the great and the most complete as possible, it must be outcome described\n\n```",
|
||||
|
||||
@@ -108,65 +108,6 @@ def render_text_description_and_args(
return "\n".join(tool_strings)


def convert_tools_to_openai_schema(
tools: Sequence[BaseTool | CrewStructuredTool],
) -> tuple[list[dict[str, Any]], dict[str, Callable[..., Any]]]:
"""Convert CrewAI tools to OpenAI function calling format.

This function converts CrewAI BaseTool and CrewStructuredTool objects
into the OpenAI-compatible tool schema format that can be passed to
LLM providers for native function calling.

Args:
tools: List of CrewAI tool objects to convert.

Returns:
Tuple containing:
- List of OpenAI-format tool schema dictionaries
- Dict mapping tool names to their callable run() methods

Example:
>>> tools = [CalculatorTool(), SearchTool()]
>>> schemas, functions = convert_tools_to_openai_schema(tools)
>>> # schemas can be passed to llm.call(tools=schemas)
>>> # functions can be passed to llm.call(available_functions=functions)
"""
openai_tools: list[dict[str, Any]] = []
available_functions: dict[str, Callable[..., Any]] = {}

for tool in tools:
# Get the JSON schema for tool parameters
parameters: dict[str, Any] = {}
if hasattr(tool, "args_schema") and tool.args_schema is not None:
try:
parameters = tool.args_schema.model_json_schema()
# Remove title and description from schema root as they're redundant
parameters.pop("title", None)
parameters.pop("description", None)
except Exception:
parameters = {}

# Extract original description from formatted description
# BaseTool formats description as "Tool Name: ...\nTool Arguments: ...\nTool Description: {original}"
description = tool.description
if "Tool Description:" in description:
# Extract the original description after "Tool Description:"
description = description.split("Tool Description:")[-1].strip()

schema: dict[str, Any] = {
"type": "function",
"function": {
"name": tool.name,
"description": description,
"parameters": parameters,
},
}
openai_tools.append(schema)
available_functions[tool.name] = tool.run

return openai_tools, available_functions


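convert_tools_to_openai_schema (shown above on one side of this comparison) pairs OpenAI-format schemas with a name-to-callable map so both can feed the same llm.call. A minimal usage sketch, assuming a CalculatorTool like the one defined in the test files further down:

from crewai.utilities.agent_utils import convert_tools_to_openai_schema

# CalculatorTool is the BaseTool subclass from the test module below (assumed here)
schemas, functions = convert_tools_to_openai_schema([CalculatorTool()])
# schemas[0] -> {"type": "function", "function": {"name": "calculator", ...}}
# functions  -> {"calculator": <bound run method of CalculatorTool>}
print(functions["calculator"](expression="2 + 2"))  # prints "4"
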
def has_reached_max_iterations(iterations: int, max_iterations: int) -> bool:
|
||||
"""Check if the maximum number of iterations has been reached.
|
||||
|
||||
@@ -293,13 +234,11 @@ def get_llm_response(
|
||||
messages: list[LLMMessage],
|
||||
callbacks: list[TokenCalcHandler],
|
||||
printer: Printer,
|
||||
tools: list[dict[str, Any]] | None = None,
|
||||
available_functions: dict[str, Callable[..., Any]] | None = None,
|
||||
from_task: Task | None = None,
|
||||
from_agent: Agent | LiteAgent | None = None,
|
||||
response_model: type[BaseModel] | None = None,
|
||||
executor_context: CrewAgentExecutor | LiteAgent | None = None,
|
||||
) -> str | Any:
|
||||
) -> str:
|
||||
"""Call the LLM and return the response, handling any invalid responses.
|
||||
|
||||
Args:
|
||||
@@ -307,16 +246,13 @@ def get_llm_response(
|
||||
messages: The messages to send to the LLM.
|
||||
callbacks: List of callbacks for the LLM call.
|
||||
printer: Printer instance for output.
|
||||
tools: Optional list of tool schemas for native function calling.
|
||||
available_functions: Optional dict mapping function names to callables.
|
||||
from_task: Optional task context for the LLM call.
|
||||
from_agent: Optional agent context for the LLM call.
|
||||
response_model: Optional Pydantic model for structured outputs.
|
||||
executor_context: Optional executor context for hook invocation.
|
||||
|
||||
Returns:
|
||||
The response from the LLM as a string, or tool call results if
|
||||
native function calling is used.
|
||||
The response from the LLM as a string.
|
||||
|
||||
Raises:
|
||||
Exception: If an error occurs.
|
||||
@@ -331,9 +267,7 @@ def get_llm_response(
|
||||
try:
|
||||
answer = llm.call(
|
||||
messages,
|
||||
tools=tools,
|
||||
callbacks=callbacks,
|
||||
available_functions=available_functions,
|
||||
from_task=from_task,
|
||||
from_agent=from_agent, # type: ignore[arg-type]
|
||||
response_model=response_model,
|
||||
@@ -355,13 +289,11 @@ async def aget_llm_response(
|
||||
messages: list[LLMMessage],
|
||||
callbacks: list[TokenCalcHandler],
|
||||
printer: Printer,
|
||||
tools: list[dict[str, Any]] | None = None,
|
||||
available_functions: dict[str, Callable[..., Any]] | None = None,
|
||||
from_task: Task | None = None,
|
||||
from_agent: Agent | LiteAgent | None = None,
|
||||
response_model: type[BaseModel] | None = None,
|
||||
executor_context: CrewAgentExecutor | None = None,
|
||||
) -> str | Any:
|
||||
) -> str:
|
||||
"""Call the LLM asynchronously and return the response.
|
||||
|
||||
Args:
|
||||
@@ -369,16 +301,13 @@ async def aget_llm_response(
|
||||
messages: The messages to send to the LLM.
|
||||
callbacks: List of callbacks for the LLM call.
|
||||
printer: Printer instance for output.
|
||||
tools: Optional list of tool schemas for native function calling.
|
||||
available_functions: Optional dict mapping function names to callables.
|
||||
from_task: Optional task context for the LLM call.
|
||||
from_agent: Optional agent context for the LLM call.
|
||||
response_model: Optional Pydantic model for structured outputs.
|
||||
executor_context: Optional executor context for hook invocation.
|
||||
|
||||
Returns:
|
||||
The response from the LLM as a string, or tool call results if
|
||||
native function calling is used.
|
||||
The response from the LLM as a string.
|
||||
|
||||
Raises:
|
||||
Exception: If an error occurs.
|
||||
@@ -392,9 +321,7 @@ async def aget_llm_response(
|
||||
try:
|
||||
answer = await llm.acall(
|
||||
messages,
|
||||
tools=tools,
|
||||
callbacks=callbacks,
|
||||
available_functions=available_functions,
|
||||
from_task=from_task,
|
||||
from_agent=from_agent, # type: ignore[arg-type]
|
||||
response_model=response_model,
|
||||
|
||||
@@ -22,9 +22,7 @@ class SystemPromptResult(StandardPromptResult):
|
||||
user: Annotated[str, "The user prompt component"]
|
||||
|
||||
|
||||
COMPONENTS = Literal[
|
||||
"role_playing", "tools", "no_tools", "native_tools", "task", "native_task"
|
||||
]
|
||||
COMPONENTS = Literal["role_playing", "tools", "no_tools", "task"]
|
||||
|
||||
|
||||
class Prompts(BaseModel):
|
||||
@@ -38,10 +36,6 @@ class Prompts(BaseModel):
|
||||
has_tools: bool = Field(
|
||||
default=False, description="Indicates if the agent has access to tools"
|
||||
)
|
||||
use_native_tool_calling: bool = Field(
|
||||
default=False,
|
||||
description="Whether to use native function calling instead of ReAct format",
|
||||
)
|
||||
system_template: str | None = Field(
|
||||
default=None, description="Custom system prompt template"
|
||||
)
|
||||
@@ -64,24 +58,12 @@ class Prompts(BaseModel):
|
||||
A dictionary containing the constructed prompt(s).
|
||||
"""
|
||||
slices: list[COMPONENTS] = ["role_playing"]
|
||||
# When using native tool calling with tools, use native_tools instructions
|
||||
# When using ReAct pattern with tools, use tools instructions
|
||||
# When no tools are available, use no_tools instructions
|
||||
if self.has_tools:
|
||||
if self.use_native_tool_calling:
|
||||
slices.append("native_tools")
|
||||
else:
|
||||
slices.append("tools")
|
||||
slices.append("tools")
|
||||
else:
|
||||
slices.append("no_tools")
|
||||
system: str = self._build_prompt(slices)
|
||||
|
||||
# Use native_task for native tool calling (no "Thought:" prompt)
|
||||
# Use task for ReAct pattern (includes "Thought:" prompt)
|
||||
task_slice: COMPONENTS = (
|
||||
"native_task" if self.use_native_tool_calling else "task"
|
||||
)
|
||||
slices.append(task_slice)
|
||||
slices.append("task")
|
||||
|
||||
if (
|
||||
not self.system_template
|
||||
@@ -90,7 +72,7 @@ class Prompts(BaseModel):
|
||||
):
|
||||
return SystemPromptResult(
|
||||
system=system,
|
||||
user=self._build_prompt([task_slice]),
|
||||
user=self._build_prompt(["task"]),
|
||||
prompt=self._build_prompt(slices),
|
||||
)
|
||||
return StandardPromptResult(
|
||||
|
||||
@@ -26,9 +26,13 @@ def mock_agent() -> MagicMock:
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_task() -> MagicMock:
|
||||
def mock_task(mock_context: MagicMock) -> MagicMock:
|
||||
"""Create a mock Task."""
|
||||
return MagicMock()
|
||||
task = MagicMock()
|
||||
task.id = mock_context.task_id
|
||||
task.name = "Mock Task"
|
||||
task.description = "Mock task description"
|
||||
return task
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
@@ -179,8 +183,8 @@ class TestExecute:
|
||||
event = first_call[0][1]
|
||||
|
||||
assert event.type == "a2a_server_task_started"
|
||||
assert event.a2a_task_id == mock_context.task_id
|
||||
assert event.a2a_context_id == mock_context.context_id
|
||||
assert event.task_id == mock_context.task_id
|
||||
assert event.context_id == mock_context.context_id
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_emits_completed_event(
|
||||
@@ -201,7 +205,7 @@ class TestExecute:
|
||||
event = second_call[0][1]
|
||||
|
||||
assert event.type == "a2a_server_task_completed"
|
||||
assert event.a2a_task_id == mock_context.task_id
|
||||
assert event.task_id == mock_context.task_id
|
||||
assert event.result == "Task completed successfully"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@@ -250,7 +254,7 @@ class TestExecute:
|
||||
event = canceled_call[0][1]
|
||||
|
||||
assert event.type == "a2a_server_task_canceled"
|
||||
assert event.a2a_task_id == mock_context.task_id
|
||||
assert event.task_id == mock_context.task_id
|
||||
|
||||
|
||||
class TestCancel:
|
||||
|
||||
@@ -14,6 +14,16 @@ except ImportError:
|
||||
A2A_SDK_INSTALLED = False
|
||||
|
||||
|
||||
def _create_mock_agent_card(name: str = "Test", url: str = "http://test-endpoint.com/"):
|
||||
"""Create a mock agent card with proper model_dump behavior."""
|
||||
mock_card = MagicMock()
|
||||
mock_card.name = name
|
||||
mock_card.url = url
|
||||
mock_card.model_dump.return_value = {"name": name, "url": url}
|
||||
mock_card.model_dump_json.return_value = f'{{"name": "{name}", "url": "{url}"}}'
|
||||
return mock_card
|
||||
|
||||
|
||||
@pytest.mark.skipif(not A2A_SDK_INSTALLED, reason="Requires a2a-sdk to be installed")
|
||||
def test_trust_remote_completion_status_true_returns_directly():
|
||||
"""When trust_remote_completion_status=True and A2A returns completed, return result directly."""
|
||||
@@ -44,8 +54,7 @@ def test_trust_remote_completion_status_true_returns_directly():
|
||||
patch("crewai.a2a.wrapper.execute_a2a_delegation") as mock_execute,
|
||||
patch("crewai.a2a.wrapper._fetch_agent_cards_concurrently") as mock_fetch,
|
||||
):
|
||||
mock_card = MagicMock()
|
||||
mock_card.name = "Test"
|
||||
mock_card = _create_mock_agent_card()
|
||||
mock_fetch.return_value = ({"http://test-endpoint.com/": mock_card}, {})
|
||||
|
||||
# A2A returns completed
|
||||
@@ -110,8 +119,7 @@ def test_trust_remote_completion_status_false_continues_conversation():
|
||||
patch("crewai.a2a.wrapper.execute_a2a_delegation") as mock_execute,
|
||||
patch("crewai.a2a.wrapper._fetch_agent_cards_concurrently") as mock_fetch,
|
||||
):
|
||||
mock_card = MagicMock()
|
||||
mock_card.name = "Test"
|
||||
mock_card = _create_mock_agent_card()
|
||||
mock_fetch.return_value = ({"http://test-endpoint.com/": mock_card}, {})
|
||||
|
||||
# A2A returns completed
|
||||
|
||||
@@ -1,479 +0,0 @@
|
||||
"""Integration tests for native tool calling functionality.
|
||||
|
||||
These tests verify that agents can use native function calling
|
||||
when the LLM supports it, across multiple providers.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import os
|
||||
from typing import Any
|
||||
from unittest.mock import patch, MagicMock
|
||||
|
||||
import pytest
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from crewai import Agent, Crew, Task
|
||||
from crewai.llm import LLM
|
||||
from crewai.tools.base_tool import BaseTool
|
||||
|
||||
|
||||
# Check for optional provider availability
|
||||
try:
|
||||
import anthropic
|
||||
HAS_ANTHROPIC = True
|
||||
except ImportError:
|
||||
HAS_ANTHROPIC = False
|
||||
|
||||
try:
|
||||
import google.genai
|
||||
HAS_GOOGLE_GENAI = True
|
||||
except ImportError:
|
||||
HAS_GOOGLE_GENAI = False
|
||||
|
||||
try:
|
||||
import boto3
|
||||
HAS_BOTO3 = True
|
||||
except ImportError:
|
||||
HAS_BOTO3 = False
|
||||
|
||||
|
||||
class CalculatorInput(BaseModel):
|
||||
"""Input schema for calculator tool."""
|
||||
|
||||
expression: str = Field(description="Mathematical expression to evaluate")
|
||||
|
||||
|
||||
class CalculatorTool(BaseTool):
|
||||
"""A calculator tool that performs mathematical calculations."""
|
||||
|
||||
name: str = "calculator"
|
||||
description: str = "Perform mathematical calculations. Use this for any math operations."
|
||||
args_schema: type[BaseModel] = CalculatorInput
|
||||
|
||||
def _run(self, expression: str) -> str:
|
||||
"""Execute the calculation."""
|
||||
try:
|
||||
# Safe evaluation for basic math
|
||||
result = eval(expression) # noqa: S307
|
||||
return f"The result of {expression} is {result}"
|
||||
except Exception as e:
|
||||
return f"Error calculating {expression}: {e}"
|
||||
|
||||
|
||||
class WeatherInput(BaseModel):
|
||||
"""Input schema for weather tool."""
|
||||
|
||||
location: str = Field(description="City name to get weather for")
|
||||
|
||||
|
||||
class WeatherTool(BaseTool):
|
||||
"""A mock weather tool for testing."""
|
||||
|
||||
name: str = "get_weather"
|
||||
description: str = "Get the current weather for a location"
|
||||
args_schema: type[BaseModel] = WeatherInput
|
||||
|
||||
def _run(self, location: str) -> str:
|
||||
"""Get weather (mock implementation)."""
|
||||
return f"The weather in {location} is sunny with a temperature of 72°F"
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def calculator_tool() -> CalculatorTool:
|
||||
"""Create a calculator tool for testing."""
|
||||
return CalculatorTool()
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def weather_tool() -> WeatherTool:
|
||||
"""Create a weather tool for testing."""
|
||||
return WeatherTool()
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# OpenAI Provider Tests
|
||||
# =============================================================================
|
||||
|
||||
|
||||
class TestOpenAINativeToolCalling:
|
||||
"""Tests for native tool calling with OpenAI models."""
|
||||
|
||||
@pytest.mark.vcr()
|
||||
def test_openai_agent_with_native_tool_calling(
|
||||
self, calculator_tool: CalculatorTool
|
||||
) -> None:
|
||||
"""Test OpenAI agent can use native tool calling."""
|
||||
agent = Agent(
|
||||
role="Math Assistant",
|
||||
goal="Help users with mathematical calculations",
|
||||
backstory="You are a helpful math assistant.",
|
||||
tools=[calculator_tool],
|
||||
llm=LLM(model="gpt-4o-mini"),
|
||||
verbose=False,
|
||||
max_iter=3,
|
||||
)
|
||||
|
||||
task = Task(
|
||||
description="Calculate what is 15 * 8",
|
||||
expected_output="The result of the calculation",
|
||||
agent=agent,
|
||||
)
|
||||
|
||||
crew = Crew(agents=[agent], tasks=[task])
|
||||
result = crew.kickoff()
|
||||
|
||||
assert result is not None
|
||||
assert result.raw is not None
|
||||
assert "120" in str(result.raw)
|
||||
|
||||
def test_openai_agent_kickoff_with_tools_mocked(
|
||||
self, calculator_tool: CalculatorTool
|
||||
) -> None:
|
||||
"""Test OpenAI agent kickoff with mocked LLM call."""
|
||||
llm = LLM(model="gpt-4o-mini")
|
||||
|
||||
with patch.object(llm, "call", return_value="The answer is 120.") as mock_call:
|
||||
agent = Agent(
|
||||
role="Math Assistant",
|
||||
goal="Calculate math",
|
||||
backstory="You calculate.",
|
||||
tools=[calculator_tool],
|
||||
llm=llm,
|
||||
verbose=False,
|
||||
)
|
||||
|
||||
task = Task(
|
||||
description="Calculate 15 * 8",
|
||||
expected_output="Result",
|
||||
agent=agent,
|
||||
)
|
||||
|
||||
crew = Crew(agents=[agent], tasks=[task])
|
||||
result = crew.kickoff()
|
||||
|
||||
assert mock_call.called
|
||||
assert result is not None
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# Anthropic Provider Tests
|
||||
# =============================================================================
|
||||
|
||||
|
||||
@pytest.mark.skipif(not HAS_ANTHROPIC, reason="anthropic package not installed")
|
||||
class TestAnthropicNativeToolCalling:
|
||||
"""Tests for native tool calling with Anthropic models."""
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
def mock_anthropic_api_key(self):
|
||||
"""Mock ANTHROPIC_API_KEY for tests."""
|
||||
if "ANTHROPIC_API_KEY" not in os.environ:
|
||||
with patch.dict(os.environ, {"ANTHROPIC_API_KEY": "test-key"}):
|
||||
yield
|
||||
else:
|
||||
yield
|
||||
|
||||
@pytest.mark.vcr()
|
||||
def test_anthropic_agent_with_native_tool_calling(
|
||||
self, calculator_tool: CalculatorTool
|
||||
) -> None:
|
||||
"""Test Anthropic agent can use native tool calling."""
|
||||
agent = Agent(
|
||||
role="Math Assistant",
|
||||
goal="Help users with mathematical calculations",
|
||||
backstory="You are a helpful math assistant.",
|
||||
tools=[calculator_tool],
|
||||
llm=LLM(model="anthropic/claude-3-5-haiku-20241022"),
|
||||
verbose=False,
|
||||
max_iter=3,
|
||||
)
|
||||
|
||||
task = Task(
|
||||
description="Calculate what is 15 * 8",
|
||||
expected_output="The result of the calculation",
|
||||
agent=agent,
|
||||
)
|
||||
|
||||
crew = Crew(agents=[agent], tasks=[task])
|
||||
result = crew.kickoff()
|
||||
|
||||
assert result is not None
|
||||
assert result.raw is not None
|
||||
|
||||
def test_anthropic_agent_kickoff_with_tools_mocked(
|
||||
self, calculator_tool: CalculatorTool
|
||||
) -> None:
|
||||
"""Test Anthropic agent kickoff with mocked LLM call."""
|
||||
llm = LLM(model="anthropic/claude-3-5-haiku-20241022")
|
||||
|
||||
with patch.object(llm, "call", return_value="The answer is 120.") as mock_call:
|
||||
agent = Agent(
|
||||
role="Math Assistant",
|
||||
goal="Calculate math",
|
||||
backstory="You calculate.",
|
||||
tools=[calculator_tool],
|
||||
llm=llm,
|
||||
verbose=False,
|
||||
)
|
||||
|
||||
task = Task(
|
||||
description="Calculate 15 * 8",
|
||||
expected_output="Result",
|
||||
agent=agent,
|
||||
)
|
||||
|
||||
crew = Crew(agents=[agent], tasks=[task])
|
||||
result = crew.kickoff()
|
||||
|
||||
assert mock_call.called
|
||||
assert result is not None
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# Google/Gemini Provider Tests
|
||||
# =============================================================================
|
||||
|
||||
|
||||
@pytest.mark.skipif(not HAS_GOOGLE_GENAI, reason="google-genai package not installed")
|
||||
class TestGeminiNativeToolCalling:
|
||||
"""Tests for native tool calling with Gemini models."""
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
def mock_google_api_key(self):
|
||||
"""Mock GOOGLE_API_KEY for tests."""
|
||||
with patch.dict(os.environ, {"GOOGLE_API_KEY": "test-key"}):
|
||||
yield
|
||||
|
||||
@pytest.mark.vcr()
|
||||
def test_gemini_agent_with_native_tool_calling(
|
||||
self, calculator_tool: CalculatorTool
|
||||
) -> None:
|
||||
"""Test Gemini agent can use native tool calling."""
|
||||
agent = Agent(
|
||||
role="Math Assistant",
|
||||
goal="Help users with mathematical calculations",
|
||||
backstory="You are a helpful math assistant.",
|
||||
tools=[calculator_tool],
|
||||
llm=LLM(model="gemini/gemini-2.0-flash-001"),
|
||||
verbose=False,
|
||||
max_iter=3,
|
||||
)
|
||||
|
||||
task = Task(
|
||||
description="Calculate what is 15 * 8",
|
||||
expected_output="The result of the calculation",
|
||||
agent=agent,
|
||||
)
|
||||
|
||||
crew = Crew(agents=[agent], tasks=[task])
|
||||
result = crew.kickoff()
|
||||
|
||||
assert result is not None
|
||||
assert result.raw is not None
|
||||
|
||||
def test_gemini_agent_kickoff_with_tools_mocked(
|
||||
self, calculator_tool: CalculatorTool
|
||||
) -> None:
|
||||
"""Test Gemini agent kickoff with mocked LLM call."""
|
||||
llm = LLM(model="gemini/gemini-2.0-flash-001")
|
||||
|
||||
with patch.object(llm, "call", return_value="The answer is 120.") as mock_call:
|
||||
agent = Agent(
|
||||
role="Math Assistant",
|
||||
goal="Calculate math",
|
||||
backstory="You calculate.",
|
||||
tools=[calculator_tool],
|
||||
llm=llm,
|
||||
verbose=False,
|
||||
)
|
||||
|
||||
task = Task(
|
||||
description="Calculate 15 * 8",
|
||||
expected_output="Result",
|
||||
agent=agent,
|
||||
)
|
||||
|
||||
crew = Crew(agents=[agent], tasks=[task])
|
||||
result = crew.kickoff()
|
||||
|
||||
assert mock_call.called
|
||||
assert result is not None
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# Azure Provider Tests
|
||||
# =============================================================================
|
||||
|
||||
|
||||
class TestAzureNativeToolCalling:
|
||||
"""Tests for native tool calling with Azure OpenAI models."""
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
def mock_azure_env(self):
|
||||
"""Mock Azure environment variables for tests."""
|
||||
env_vars = {
|
||||
"AZURE_API_KEY": "test-key",
|
||||
"AZURE_API_BASE": "https://test.openai.azure.com",
|
||||
"AZURE_API_VERSION": "2024-02-15-preview",
|
||||
}
|
||||
with patch.dict(os.environ, env_vars):
|
||||
yield
|
||||
|
||||
def test_azure_agent_kickoff_with_tools_mocked(
|
||||
self, calculator_tool: CalculatorTool
|
||||
) -> None:
|
||||
"""Test Azure agent kickoff with mocked LLM call."""
|
||||
llm = LLM(
|
||||
model="azure/gpt-4o-mini",
|
||||
api_key="test-key",
|
||||
base_url="https://test.openai.azure.com",
|
||||
)
|
||||
|
||||
with patch.object(llm, "call", return_value="The answer is 120.") as mock_call:
|
||||
agent = Agent(
|
||||
role="Math Assistant",
|
||||
goal="Calculate math",
|
||||
backstory="You calculate.",
|
||||
tools=[calculator_tool],
|
||||
llm=llm,
|
||||
verbose=False,
|
||||
)
|
||||
|
||||
task = Task(
|
||||
description="Calculate 15 * 8",
|
||||
expected_output="Result",
|
||||
agent=agent,
|
||||
)
|
||||
|
||||
crew = Crew(agents=[agent], tasks=[task])
|
||||
result = crew.kickoff()
|
||||
|
||||
assert mock_call.called
|
||||
assert result is not None
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# Bedrock Provider Tests
|
||||
# =============================================================================
|
||||
|
||||
|
||||
@pytest.mark.skipif(not HAS_BOTO3, reason="boto3 package not installed")
|
||||
class TestBedrockNativeToolCalling:
|
||||
"""Tests for native tool calling with AWS Bedrock models."""
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
def mock_aws_env(self):
|
||||
"""Mock AWS environment variables for tests."""
|
||||
env_vars = {
|
||||
"AWS_ACCESS_KEY_ID": "test-key",
|
||||
"AWS_SECRET_ACCESS_KEY": "test-secret",
|
||||
"AWS_REGION": "us-east-1",
|
||||
}
|
||||
with patch.dict(os.environ, env_vars):
|
||||
yield
|
||||
|
||||
def test_bedrock_agent_kickoff_with_tools_mocked(
|
||||
self, calculator_tool: CalculatorTool
|
||||
) -> None:
|
||||
"""Test Bedrock agent kickoff with mocked LLM call."""
|
||||
llm = LLM(model="bedrock/anthropic.claude-3-haiku-20240307-v1:0")
|
||||
|
||||
with patch.object(llm, "call", return_value="The answer is 120.") as mock_call:
|
||||
agent = Agent(
|
||||
role="Math Assistant",
|
||||
goal="Calculate math",
|
||||
backstory="You calculate.",
|
||||
tools=[calculator_tool],
|
||||
llm=llm,
|
||||
verbose=False,
|
||||
)
|
||||
|
||||
task = Task(
|
||||
description="Calculate 15 * 8",
|
||||
expected_output="Result",
|
||||
agent=agent,
|
||||
)
|
||||
|
||||
crew = Crew(agents=[agent], tasks=[task])
|
||||
result = crew.kickoff()
|
||||
|
||||
assert mock_call.called
|
||||
assert result is not None
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# Cross-Provider Native Tool Calling Behavior Tests
|
||||
# =============================================================================
|
||||
|
||||
|
||||
class TestNativeToolCallingBehavior:
|
||||
"""Tests for native tool calling behavior across providers."""
|
||||
|
||||
def test_supports_function_calling_check(self) -> None:
|
||||
"""Test that supports_function_calling() is properly checked."""
|
||||
# OpenAI should support function calling
|
||||
openai_llm = LLM(model="gpt-4o-mini")
|
||||
assert hasattr(openai_llm, "supports_function_calling")
|
||||
assert openai_llm.supports_function_calling() is True
|
||||
|
||||
@pytest.mark.skipif(not HAS_ANTHROPIC, reason="anthropic package not installed")
|
||||
def test_anthropic_supports_function_calling(self) -> None:
|
||||
"""Test that Anthropic models support function calling."""
|
||||
with patch.dict(os.environ, {"ANTHROPIC_API_KEY": "test-key"}):
|
||||
llm = LLM(model="anthropic/claude-3-5-haiku-20241022")
|
||||
assert hasattr(llm, "supports_function_calling")
|
||||
assert llm.supports_function_calling() is True
|
||||
|
||||
@pytest.mark.skipif(not HAS_GOOGLE_GENAI, reason="google-genai package not installed")
|
||||
def test_gemini_supports_function_calling(self) -> None:
|
||||
"""Test that Gemini models support function calling."""
|
||||
# with patch.dict(os.environ, {"GOOGLE_API_KEY": "test-key"}):
|
||||
print("GOOGLE_API_KEY", os.getenv("GOOGLE_API_KEY"))
|
||||
llm = LLM(model="gemini/gemini-2.5-flash")
|
||||
assert hasattr(llm, "supports_function_calling")
|
||||
# Gemini uses supports_tools property
|
||||
assert llm.supports_function_calling() is True
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# Token Usage Tests
|
||||
# =============================================================================
|
||||
|
||||
|
||||
class TestNativeToolCallingTokenUsage:
|
||||
"""Tests for token usage with native tool calling."""
|
||||
|
||||
@pytest.mark.vcr()
|
||||
def test_openai_native_tool_calling_token_usage(
|
||||
self, calculator_tool: CalculatorTool
|
||||
) -> None:
|
||||
"""Test token usage tracking with OpenAI native tool calling."""
|
||||
agent = Agent(
|
||||
role="Calculator",
|
||||
goal="Perform calculations efficiently",
|
||||
backstory="You calculate things.",
|
||||
tools=[calculator_tool],
|
||||
llm=LLM(model="gpt-4o-mini"),
|
||||
verbose=False,
|
||||
max_iter=3,
|
||||
)
|
||||
|
||||
task = Task(
|
||||
description="What is 100 / 4?",
|
||||
expected_output="The result",
|
||||
agent=agent,
|
||||
)
|
||||
|
||||
crew = Crew(agents=[agent], tasks=[task])
|
||||
result = crew.kickoff()
|
||||
|
||||
assert result is not None
|
||||
assert result.token_usage is not None
|
||||
assert result.token_usage.total_tokens > 0
|
||||
assert result.token_usage.successful_requests >= 1
|
||||
|
||||
print(f"\n[OPENAI NATIVE TOOL CALLING TOKEN USAGE]")
|
||||
print(f" Prompt tokens: {result.token_usage.prompt_tokens}")
|
||||
print(f" Completion tokens: {result.token_usage.completion_tokens}")
|
||||
print(f" Total tokens: {result.token_usage.total_tokens}")
|
||||
@@ -1,214 +0,0 @@
|
||||
"""Tests for agent utility functions."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Any
|
||||
|
||||
import pytest
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from crewai.tools.base_tool import BaseTool
|
||||
from crewai.utilities.agent_utils import convert_tools_to_openai_schema
|
||||
|
||||
|
||||
class CalculatorInput(BaseModel):
|
||||
"""Input schema for calculator tool."""
|
||||
|
||||
expression: str = Field(description="Mathematical expression to evaluate")
|
||||
|
||||
|
||||
class CalculatorTool(BaseTool):
|
||||
"""A simple calculator tool for testing."""
|
||||
|
||||
name: str = "calculator"
|
||||
description: str = "Perform mathematical calculations"
|
||||
args_schema: type[BaseModel] = CalculatorInput
|
||||
|
||||
def _run(self, expression: str) -> str:
|
||||
"""Execute the calculation."""
|
||||
try:
|
||||
result = eval(expression) # noqa: S307
|
||||
return str(result)
|
||||
except Exception as e:
|
||||
return f"Error: {e}"
|
||||
|
||||
|
||||
class SearchInput(BaseModel):
|
||||
"""Input schema for search tool."""
|
||||
|
||||
query: str = Field(description="Search query")
|
||||
max_results: int = Field(default=10, description="Maximum number of results")
|
||||
|
||||
|
||||
class SearchTool(BaseTool):
|
||||
"""A search tool for testing."""
|
||||
|
||||
name: str = "web_search"
|
||||
description: str = "Search the web for information"
|
||||
args_schema: type[BaseModel] = SearchInput
|
||||
|
||||
def _run(self, query: str, max_results: int = 10) -> str:
|
||||
"""Execute the search."""
|
||||
return f"Search results for '{query}' (max {max_results})"
|
||||
|
||||
|
||||
class NoSchemaTool(BaseTool):
|
||||
"""A tool without an args schema for testing edge cases."""
|
||||
|
||||
name: str = "simple_tool"
|
||||
description: str = "A simple tool with no schema"
|
||||
|
||||
def _run(self, **kwargs: Any) -> str:
|
||||
"""Execute the tool."""
|
||||
return "Simple tool executed"
|
||||
|
||||
|
||||
class TestConvertToolsToOpenaiSchema:
|
||||
"""Tests for convert_tools_to_openai_schema function."""
|
||||
|
||||
def test_converts_single_tool(self) -> None:
|
||||
"""Test converting a single tool to OpenAI schema."""
|
||||
tools = [CalculatorTool()]
|
||||
schemas, functions = convert_tools_to_openai_schema(tools)
|
||||
|
||||
assert len(schemas) == 1
|
||||
assert len(functions) == 1
|
||||
|
||||
schema = schemas[0]
|
||||
assert schema["type"] == "function"
|
||||
assert schema["function"]["name"] == "calculator"
|
||||
assert schema["function"]["description"] == "Perform mathematical calculations"
|
||||
assert "properties" in schema["function"]["parameters"]
|
||||
assert "expression" in schema["function"]["parameters"]["properties"]
|
||||
|
||||
def test_converts_multiple_tools(self) -> None:
|
||||
"""Test converting multiple tools to OpenAI schema."""
|
||||
tools = [CalculatorTool(), SearchTool()]
|
||||
schemas, functions = convert_tools_to_openai_schema(tools)
|
||||
|
||||
assert len(schemas) == 2
|
||||
assert len(functions) == 2
|
||||
|
||||
# Check calculator
|
||||
calc_schema = next(s for s in schemas if s["function"]["name"] == "calculator")
|
||||
assert calc_schema["function"]["description"] == "Perform mathematical calculations"
|
||||
|
||||
# Check search
|
||||
search_schema = next(s for s in schemas if s["function"]["name"] == "web_search")
|
||||
assert search_schema["function"]["description"] == "Search the web for information"
|
||||
assert "query" in search_schema["function"]["parameters"]["properties"]
|
||||
assert "max_results" in search_schema["function"]["parameters"]["properties"]
|
||||
|
||||
def test_functions_dict_contains_callables(self) -> None:
|
||||
"""Test that the functions dict maps names to callable run methods."""
|
||||
tools = [CalculatorTool(), SearchTool()]
|
||||
schemas, functions = convert_tools_to_openai_schema(tools)
|
||||
|
||||
assert "calculator" in functions
|
||||
assert "web_search" in functions
|
||||
assert callable(functions["calculator"])
|
||||
assert callable(functions["web_search"])
|
||||
|
||||
def test_function_can_be_called(self) -> None:
|
||||
"""Test that the returned function can be called."""
|
||||
tools = [CalculatorTool()]
|
||||
schemas, functions = convert_tools_to_openai_schema(tools)
|
||||
|
||||
result = functions["calculator"](expression="2 + 2")
|
||||
assert result == "4"
|
||||
|
||||
def test_empty_tools_list(self) -> None:
|
||||
"""Test with an empty tools list."""
|
||||
schemas, functions = convert_tools_to_openai_schema([])
|
||||
|
||||
assert schemas == []
|
||||
assert functions == {}
|
||||
|
||||
def test_schema_has_required_fields(self) -> None:
|
||||
"""Test that the schema includes required fields information."""
|
||||
tools = [SearchTool()]
|
||||
schemas, functions = convert_tools_to_openai_schema(tools)
|
||||
|
||||
schema = schemas[0]
|
||||
params = schema["function"]["parameters"]
|
||||
|
||||
# Should have required array
|
||||
assert "required" in params
|
||||
assert "query" in params["required"]
|
||||
|
||||
def test_tool_without_args_schema(self) -> None:
|
||||
"""Test converting a tool that doesn't have an args_schema."""
|
||||
# Create a minimal tool without args_schema
|
||||
class MinimalTool(BaseTool):
|
||||
name: str = "minimal"
|
||||
description: str = "A minimal tool"
|
||||
|
||||
def _run(self) -> str:
|
||||
return "done"
|
||||
|
||||
tools = [MinimalTool()]
|
||||
schemas, functions = convert_tools_to_openai_schema(tools)
|
||||
|
||||
assert len(schemas) == 1
|
||||
schema = schemas[0]
|
||||
assert schema["function"]["name"] == "minimal"
|
||||
# Parameters should be empty dict or have minimal schema
|
||||
assert isinstance(schema["function"]["parameters"], dict)
|
||||
|
||||
def test_schema_structure_matches_openai_format(self) -> None:
|
||||
"""Test that the schema structure matches OpenAI's expected format."""
|
||||
tools = [CalculatorTool()]
|
||||
schemas, functions = convert_tools_to_openai_schema(tools)
|
||||
|
||||
schema = schemas[0]
|
||||
|
||||
# Top level must have "type": "function"
|
||||
assert schema["type"] == "function"
|
||||
|
||||
# Must have "function" key with nested structure
|
||||
assert "function" in schema
|
||||
func = schema["function"]
|
||||
|
||||
# Function must have name and description
|
||||
assert "name" in func
|
||||
assert "description" in func
|
||||
assert isinstance(func["name"], str)
|
||||
assert isinstance(func["description"], str)
|
||||
|
||||
# Parameters should be a valid JSON schema
|
||||
assert "parameters" in func
|
||||
params = func["parameters"]
|
||||
assert isinstance(params, dict)
|
||||
|
||||
def test_removes_redundant_schema_fields(self) -> None:
|
||||
"""Test that redundant title and description are removed from parameters."""
|
||||
tools = [CalculatorTool()]
|
||||
schemas, functions = convert_tools_to_openai_schema(tools)
|
||||
|
||||
params = schemas[0]["function"]["parameters"]
|
||||
# Title should be removed as it's redundant with function name
|
||||
assert "title" not in params
|
||||
|
||||
def test_preserves_field_descriptions(self) -> None:
|
||||
"""Test that field descriptions are preserved in the schema."""
|
||||
tools = [SearchTool()]
|
||||
schemas, functions = convert_tools_to_openai_schema(tools)
|
||||
|
||||
params = schemas[0]["function"]["parameters"]
|
||||
query_prop = params["properties"]["query"]
|
||||
|
||||
# Field description should be preserved
|
||||
assert "description" in query_prop
|
||||
assert query_prop["description"] == "Search query"
|
||||
|
||||
def test_preserves_default_values(self) -> None:
|
||||
"""Test that default values are preserved in the schema."""
|
||||
tools = [SearchTool()]
|
||||
schemas, functions = convert_tools_to_openai_schema(tools)
|
||||
|
||||
params = schemas[0]["function"]["parameters"]
|
||||
max_results_prop = params["properties"]["max_results"]
|
||||
|
||||
# Default value should be preserved
|
||||
assert "default" in max_results_prop
|
||||
assert max_results_prop["default"] == 10
|
||||