Compare commits


2 Commits

Author SHA1 Message Date
Devin AI
ca1f8fd7e0 feat: add MCP Discovery tool for dynamic MCP server discovery
This adds a new MCPDiscoveryTool that enables CrewAI agents to dynamically
discover MCP servers based on natural language queries using the MCP
Discovery API.

Features:
- Semantic search for MCP servers using natural language
- Support for constraints (max latency, required features, excluded servers)
- Returns server recommendations with installation instructions, capabilities,
  and performance metrics
- Includes comprehensive test coverage

Closes #4249

Co-Authored-By: João <joao@crewai.com>
2026-01-17 19:38:49 +00:00
Greyson LaLonde
ceef062426 feat: add additional a2a events and enrich event metadata
2026-01-16 16:57:31 -05:00
37 changed files with 6985 additions and 6266 deletions

View File

@@ -89,6 +89,9 @@ from crewai_tools.tools.jina_scrape_website_tool.jina_scrape_website_tool import
from crewai_tools.tools.json_search_tool.json_search_tool import JSONSearchTool
from crewai_tools.tools.linkup.linkup_search_tool import LinkupSearchTool
from crewai_tools.tools.llamaindex_tool.llamaindex_tool import LlamaIndexTool
from crewai_tools.tools.mcp_discovery_tool.mcp_discovery_tool import (
MCPDiscoveryTool,
)
from crewai_tools.tools.mdx_search_tool.mdx_search_tool import MDXSearchTool
from crewai_tools.tools.merge_agent_handler_tool.merge_agent_handler_tool import (
MergeAgentHandlerTool,
@@ -236,6 +239,7 @@ __all__ = [
"JinaScrapeWebsiteTool",
"LinkupSearchTool",
"LlamaIndexTool",
"MCPDiscoveryTool",
"MCPServerAdapter",
"MDXSearchTool",
"MergeAgentHandlerTool",

View File

@@ -0,0 +1,16 @@
from crewai_tools.tools.mcp_discovery_tool.mcp_discovery_tool import (
MCPDiscoveryResult,
MCPDiscoveryTool,
MCPDiscoveryToolSchema,
MCPServerMetrics,
MCPServerRecommendation,
)
__all__ = [
"MCPDiscoveryResult",
"MCPDiscoveryTool",
"MCPDiscoveryToolSchema",
"MCPServerMetrics",
"MCPServerRecommendation",
]
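With the top-level export added to `crewai_tools/__init__.py` above, the new tool should be importable from the package root. A minimal usage sketch (not part of the diff):

```python
from crewai_tools import MCPDiscoveryTool

# Instantiates the agent-facing tool, registered under the name "Discover MCP Server".
discovery_tool = MCPDiscoveryTool()
```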

View File

@@ -0,0 +1,414 @@
"""MCP Discovery Tool for CrewAI agents.
This tool enables agents to dynamically discover MCP servers based on
natural language queries using the MCP Discovery API.
"""
import json
import logging
import os
from typing import Any, TypedDict
from crewai.tools import BaseTool, EnvVar
from pydantic import BaseModel, Field
import requests
logger = logging.getLogger(__name__)
class MCPServerMetrics(TypedDict, total=False):
"""Performance metrics for an MCP server."""
avg_latency_ms: float | None
uptime_pct: float | None
last_checked: str | None
class MCPServerRecommendation(TypedDict, total=False):
"""A recommended MCP server from the discovery API."""
server: str
npm_package: str
install_command: str
confidence: float
description: str
capabilities: list[str]
metrics: MCPServerMetrics
docs_url: str
github_url: str
class MCPDiscoveryResult(TypedDict):
"""Result from the MCP Discovery API."""
recommendations: list[MCPServerRecommendation]
total_found: int
query_time_ms: int
class MCPDiscoveryConstraints(BaseModel):
"""Constraints for MCP server discovery."""
max_latency_ms: int | None = Field(
default=None,
description="Maximum acceptable latency in milliseconds",
)
required_features: list[str] | None = Field(
default=None,
description="List of required features/capabilities",
)
exclude_servers: list[str] | None = Field(
default=None,
description="List of server names to exclude from results",
)
class MCPDiscoveryToolSchema(BaseModel):
"""Input schema for MCPDiscoveryTool."""
need: str = Field(
...,
description=(
"Natural language description of what you need. "
"For example: 'database with authentication', 'email automation', "
"'file storage', 'web scraping'"
),
)
constraints: MCPDiscoveryConstraints | None = Field(
default=None,
description="Optional constraints to filter results",
)
limit: int = Field(
default=5,
description="Maximum number of recommendations to return (1-10)",
ge=1,
le=10,
)
class MCPDiscoveryTool(BaseTool):
"""Tool for discovering MCP servers dynamically.
This tool uses the MCP Discovery API to find MCP servers that match
a natural language description of what the agent needs. It enables
agents to dynamically discover and select the best MCP servers for
their tasks without requiring pre-configuration.
Example:
```python
from crewai import Agent
from crewai_tools import MCPDiscoveryTool
discovery_tool = MCPDiscoveryTool()
agent = Agent(
role='Researcher',
tools=[discovery_tool],
goal='Research and analyze data'
)
# The agent can now discover MCP servers dynamically:
# discover_mcp_server(need="database with authentication")
# Returns: Supabase MCP server with installation instructions
```
Attributes:
name: The name of the tool.
description: A description of what the tool does.
args_schema: The Pydantic model for input validation.
base_url: The base URL for the MCP Discovery API.
timeout: Request timeout in seconds.
"""
name: str = "Discover MCP Server"
description: str = (
"Discover MCP (Model Context Protocol) servers that match your needs. "
"Use this tool to find the best MCP server for any task using natural "
"language. Returns server recommendations with installation instructions, "
"capabilities, and performance metrics."
)
args_schema: type[BaseModel] = MCPDiscoveryToolSchema
base_url: str = "https://mcp-discovery-production.up.railway.app"
timeout: int = 30
env_vars: list[EnvVar] = Field(
default_factory=lambda: [
EnvVar(
name="MCP_DISCOVERY_API_KEY",
description="API key for MCP Discovery (optional for free tier)",
required=False,
),
]
)
def _build_request_payload(
self,
need: str,
constraints: MCPDiscoveryConstraints | None,
limit: int,
) -> dict[str, Any]:
"""Build the request payload for the discovery API.
Args:
need: Natural language description of what is needed.
constraints: Optional constraints to filter results.
limit: Maximum number of recommendations.
Returns:
Dictionary containing the request payload.
"""
payload: dict[str, Any] = {
"need": need,
"limit": limit,
}
if constraints:
constraints_dict: dict[str, Any] = {}
if constraints.max_latency_ms is not None:
constraints_dict["max_latency_ms"] = constraints.max_latency_ms
if constraints.required_features:
constraints_dict["required_features"] = constraints.required_features
if constraints.exclude_servers:
constraints_dict["exclude_servers"] = constraints.exclude_servers
if constraints_dict:
payload["constraints"] = constraints_dict
return payload
def _make_api_request(self, payload: dict[str, Any]) -> dict[str, Any]:
"""Make a request to the MCP Discovery API.
Args:
payload: The request payload.
Returns:
The API response as a dictionary.
Raises:
ValueError: If the API returns an empty response.
requests.exceptions.RequestException: If the request fails.
"""
url = f"{self.base_url}/api/v1/discover"
headers = {
"Content-Type": "application/json",
}
api_key = os.environ.get("MCP_DISCOVERY_API_KEY")
if api_key:
headers["Authorization"] = f"Bearer {api_key}"
response = None
try:
response = requests.post(
url,
headers=headers,
json=payload,
timeout=self.timeout,
)
response.raise_for_status()
results = response.json()
if not results:
logger.error("Empty response from MCP Discovery API")
raise ValueError("Empty response from MCP Discovery API")
return results
except requests.exceptions.RequestException as e:
error_msg = f"Error making request to MCP Discovery API: {e}"
if response is not None and hasattr(response, "content"):
error_msg += (
f"\nResponse content: "
f"{response.content.decode('utf-8', errors='replace')}"
)
logger.error(error_msg)
raise
except json.JSONDecodeError as e:
if response is not None and hasattr(response, "content"):
logger.error(f"Error decoding JSON response: {e}")
logger.error(
f"Response content: "
f"{response.content.decode('utf-8', errors='replace')}"
)
else:
logger.error(
f"Error decoding JSON response: {e} (No response content available)"
)
raise
def _process_single_recommendation(
self, rec: dict[str, Any]
) -> MCPServerRecommendation | None:
"""Process a single recommendation from the API.
Args:
rec: Raw recommendation dictionary from the API.
Returns:
Processed MCPServerRecommendation or None if malformed.
"""
try:
metrics_data = rec.get("metrics", {}) if isinstance(rec, dict) else {}
metrics: MCPServerMetrics = {
"avg_latency_ms": metrics_data.get("avg_latency_ms"),
"uptime_pct": metrics_data.get("uptime_pct"),
"last_checked": metrics_data.get("last_checked"),
}
recommendation: MCPServerRecommendation = {
"server": rec.get("server", ""),
"npm_package": rec.get("npm_package", ""),
"install_command": rec.get("install_command", ""),
"confidence": rec.get("confidence", 0.0),
"description": rec.get("description", ""),
"capabilities": rec.get("capabilities", []),
"metrics": metrics,
"docs_url": rec.get("docs_url", ""),
"github_url": rec.get("github_url", ""),
}
return recommendation
except (KeyError, TypeError, AttributeError) as e:
logger.warning(f"Skipping malformed recommendation: {rec}, error: {e}")
return None
def _process_recommendations(
self, recommendations: list[dict[str, Any]]
) -> list[MCPServerRecommendation]:
"""Process and validate server recommendations.
Args:
recommendations: Raw recommendations from the API.
Returns:
List of processed MCPServerRecommendation objects.
"""
processed: list[MCPServerRecommendation] = []
for rec in recommendations:
result = self._process_single_recommendation(rec)
if result is not None:
processed.append(result)
return processed
def _format_result(self, result: MCPDiscoveryResult) -> str:
"""Format the discovery result as a human-readable string.
Args:
result: The discovery result to format.
Returns:
A formatted string representation of the result.
"""
if not result["recommendations"]:
return "No MCP servers found matching your requirements."
lines = [
f"Found {result['total_found']} MCP server(s) "
f"(query took {result['query_time_ms']}ms):\n"
]
for i, rec in enumerate(result["recommendations"], 1):
confidence_pct = rec.get("confidence", 0) * 100
lines.append(f"{i}. {rec.get('server', 'Unknown')} ({confidence_pct:.0f}% confidence)")
lines.append(f" Description: {rec.get('description', 'N/A')}")
lines.append(f" Capabilities: {', '.join(rec.get('capabilities', []))}")
lines.append(f" Install: {rec.get('install_command', 'N/A')}")
lines.append(f" NPM Package: {rec.get('npm_package', 'N/A')}")
metrics = rec.get("metrics", {})
if metrics.get("avg_latency_ms") is not None:
lines.append(f" Avg Latency: {metrics['avg_latency_ms']}ms")
if metrics.get("uptime_pct") is not None:
lines.append(f" Uptime: {metrics['uptime_pct']}%")
if rec.get("docs_url"):
lines.append(f" Docs: {rec['docs_url']}")
if rec.get("github_url"):
lines.append(f" GitHub: {rec['github_url']}")
lines.append("")
return "\n".join(lines)
def _run(self, **kwargs: Any) -> str:
"""Execute the MCP discovery operation.
Args:
**kwargs: Keyword arguments matching MCPDiscoveryToolSchema.
Returns:
A formatted string with discovered MCP servers.
Raises:
ValueError: If required parameters are missing.
"""
need: str | None = kwargs.get("need")
if not need:
raise ValueError("'need' parameter is required")
constraints_data = kwargs.get("constraints")
constraints: MCPDiscoveryConstraints | None = None
if constraints_data:
if isinstance(constraints_data, dict):
constraints = MCPDiscoveryConstraints(**constraints_data)
elif isinstance(constraints_data, MCPDiscoveryConstraints):
constraints = constraints_data
limit: int = kwargs.get("limit", 5)
payload = self._build_request_payload(need, constraints, limit)
response = self._make_api_request(payload)
recommendations = self._process_recommendations(
response.get("recommendations", [])
)
result: MCPDiscoveryResult = {
"recommendations": recommendations,
"total_found": response.get("total_found", len(recommendations)),
"query_time_ms": response.get("query_time_ms", 0),
}
return self._format_result(result)
def discover(
self,
need: str,
constraints: MCPDiscoveryConstraints | None = None,
limit: int = 5,
) -> MCPDiscoveryResult:
"""Discover MCP servers matching the given requirements.
This is a convenience method that returns structured data instead
of a formatted string.
Args:
need: Natural language description of what is needed.
constraints: Optional constraints to filter results.
limit: Maximum number of recommendations (1-10).
Returns:
MCPDiscoveryResult containing server recommendations.
Example:
```python
tool = MCPDiscoveryTool()
result = tool.discover(
need="database with authentication",
constraints=MCPDiscoveryConstraints(
max_latency_ms=200,
required_features=["auth", "realtime"]
),
limit=3
)
for rec in result["recommendations"]:
print(f"{rec['server']}: {rec['description']}")
```
"""
payload = self._build_request_payload(need, constraints, limit)
response = self._make_api_request(payload)
recommendations = self._process_recommendations(
response.get("recommendations", [])
)
return {
"recommendations": recommendations,
"total_found": response.get("total_found", len(recommendations)),
"query_time_ms": response.get("query_time_ms", 0),
}
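As an offline sanity check (a sketch, not part of the diff), the helpers above can be exercised without any network call. The server entry below is hypothetical; it only illustrates the payload `_build_request_payload` assembles and the text `_format_result` renders:

```python
from crewai_tools.tools.mcp_discovery_tool.mcp_discovery_tool import (
    MCPDiscoveryConstraints,
    MCPDiscoveryTool,
)

tool = MCPDiscoveryTool()

# Payload that would be POSTed to {base_url}/api/v1/discover
payload = tool._build_request_payload(
    need="database with authentication",
    constraints=MCPDiscoveryConstraints(max_latency_ms=200, required_features=["auth"]),
    limit=3,
)
# payload == {
#     "need": "database with authentication",
#     "limit": 3,
#     "constraints": {"max_latency_ms": 200, "required_features": ["auth"]},
# }

# Formatting a hand-built result (hypothetical server entry, no API involved)
formatted = tool._format_result(
    {
        "recommendations": [
            {
                "server": "supabase-server",               # hypothetical
                "npm_package": "@supabase/mcp",            # hypothetical
                "install_command": "npx -y @supabase/mcp",
                "confidence": 0.85,
                "description": "Postgres with built-in auth",
                "capabilities": ["auth", "sql"],
                "metrics": {"avg_latency_ms": 120.0, "uptime_pct": 99.9, "last_checked": None},
                "docs_url": "",
                "github_url": "",
            }
        ],
        "total_found": 1,
        "query_time_ms": 87,
    }
)
# Starts with "Found 1 MCP server(s) (query took 87ms):" followed by a numbered
# entry with confidence, description, capabilities, install command, and metrics.
```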

View File

@@ -0,0 +1 @@

View File

@@ -0,0 +1,452 @@
"""Tests for the MCP Discovery Tool."""
import json
from unittest.mock import MagicMock, patch
import pytest
import requests
from crewai_tools.tools.mcp_discovery_tool.mcp_discovery_tool import (
MCPDiscoveryConstraints,
MCPDiscoveryResult,
MCPDiscoveryTool,
MCPDiscoveryToolSchema,
MCPServerMetrics,
MCPServerRecommendation,
)
@pytest.fixture
def mock_api_response() -> dict:
"""Create a mock API response for testing."""
return {
"recommendations": [
{
"server": "sqlite-server",
"npm_package": "@modelcontextprotocol/server-sqlite",
"install_command": "npx -y @modelcontextprotocol/server-sqlite",
"confidence": 0.38,
"description": "SQLite database server for MCP.",
"capabilities": ["sqlite", "sql", "database", "embedded"],
"metrics": {
"avg_latency_ms": 50.0,
"uptime_pct": 99.9,
"last_checked": "2026-01-17T10:30:00Z",
},
"docs_url": "https://modelcontextprotocol.io/docs/servers/sqlite",
"github_url": "https://github.com/modelcontextprotocol/servers",
},
{
"server": "postgres-server",
"npm_package": "@modelcontextprotocol/server-postgres",
"install_command": "npx -y @modelcontextprotocol/server-postgres",
"confidence": 0.33,
"description": "PostgreSQL database server for MCP.",
"capabilities": ["postgres", "sql", "database", "queries"],
"metrics": {
"avg_latency_ms": None,
"uptime_pct": None,
"last_checked": None,
},
"docs_url": "https://modelcontextprotocol.io/docs/servers/postgres",
"github_url": "https://github.com/modelcontextprotocol/servers",
},
],
"total_found": 2,
"query_time_ms": 245,
}
@pytest.fixture
def discovery_tool() -> MCPDiscoveryTool:
"""Create an MCPDiscoveryTool instance for testing."""
return MCPDiscoveryTool()
class TestMCPDiscoveryToolSchema:
"""Tests for MCPDiscoveryToolSchema."""
def test_schema_with_required_fields(self) -> None:
"""Test schema with only required fields."""
schema = MCPDiscoveryToolSchema(need="database server")
assert schema.need == "database server"
assert schema.constraints is None
assert schema.limit == 5
def test_schema_with_all_fields(self) -> None:
"""Test schema with all fields."""
constraints = MCPDiscoveryConstraints(
max_latency_ms=200,
required_features=["auth", "realtime"],
exclude_servers=["deprecated-server"],
)
schema = MCPDiscoveryToolSchema(
need="database with authentication",
constraints=constraints,
limit=3,
)
assert schema.need == "database with authentication"
assert schema.constraints is not None
assert schema.constraints.max_latency_ms == 200
assert schema.constraints.required_features == ["auth", "realtime"]
assert schema.constraints.exclude_servers == ["deprecated-server"]
assert schema.limit == 3
def test_schema_limit_validation(self) -> None:
"""Test that limit is validated to be between 1 and 10."""
with pytest.raises(ValueError):
MCPDiscoveryToolSchema(need="test", limit=0)
with pytest.raises(ValueError):
MCPDiscoveryToolSchema(need="test", limit=11)
class TestMCPDiscoveryConstraints:
"""Tests for MCPDiscoveryConstraints."""
def test_empty_constraints(self) -> None:
"""Test creating empty constraints."""
constraints = MCPDiscoveryConstraints()
assert constraints.max_latency_ms is None
assert constraints.required_features is None
assert constraints.exclude_servers is None
def test_full_constraints(self) -> None:
"""Test creating constraints with all fields."""
constraints = MCPDiscoveryConstraints(
max_latency_ms=100,
required_features=["feature1", "feature2"],
exclude_servers=["server1", "server2"],
)
assert constraints.max_latency_ms == 100
assert constraints.required_features == ["feature1", "feature2"]
assert constraints.exclude_servers == ["server1", "server2"]
class TestMCPDiscoveryTool:
"""Tests for MCPDiscoveryTool."""
def test_tool_initialization(self, discovery_tool: MCPDiscoveryTool) -> None:
"""Test tool initialization with default values."""
assert discovery_tool.name == "Discover MCP Server"
assert "MCP" in discovery_tool.description
assert discovery_tool.base_url == "https://mcp-discovery-production.up.railway.app"
assert discovery_tool.timeout == 30
def test_build_request_payload_basic(
self, discovery_tool: MCPDiscoveryTool
) -> None:
"""Test building request payload with basic parameters."""
payload = discovery_tool._build_request_payload(
need="database server",
constraints=None,
limit=5,
)
assert payload == {"need": "database server", "limit": 5}
def test_build_request_payload_with_constraints(
self, discovery_tool: MCPDiscoveryTool
) -> None:
"""Test building request payload with constraints."""
constraints = MCPDiscoveryConstraints(
max_latency_ms=200,
required_features=["auth"],
exclude_servers=["old-server"],
)
payload = discovery_tool._build_request_payload(
need="database",
constraints=constraints,
limit=3,
)
assert payload["need"] == "database"
assert payload["limit"] == 3
assert "constraints" in payload
assert payload["constraints"]["max_latency_ms"] == 200
assert payload["constraints"]["required_features"] == ["auth"]
assert payload["constraints"]["exclude_servers"] == ["old-server"]
def test_build_request_payload_partial_constraints(
self, discovery_tool: MCPDiscoveryTool
) -> None:
"""Test building request payload with partial constraints."""
constraints = MCPDiscoveryConstraints(max_latency_ms=100)
payload = discovery_tool._build_request_payload(
need="test",
constraints=constraints,
limit=5,
)
assert payload["constraints"] == {"max_latency_ms": 100}
def test_process_recommendations(
self, discovery_tool: MCPDiscoveryTool, mock_api_response: dict
) -> None:
"""Test processing recommendations from API response."""
recommendations = discovery_tool._process_recommendations(
mock_api_response["recommendations"]
)
assert len(recommendations) == 2
first_rec = recommendations[0]
assert first_rec["server"] == "sqlite-server"
assert first_rec["npm_package"] == "@modelcontextprotocol/server-sqlite"
assert first_rec["confidence"] == 0.38
assert first_rec["capabilities"] == ["sqlite", "sql", "database", "embedded"]
assert first_rec["metrics"]["avg_latency_ms"] == 50.0
assert first_rec["metrics"]["uptime_pct"] == 99.9
second_rec = recommendations[1]
assert second_rec["server"] == "postgres-server"
assert second_rec["metrics"]["avg_latency_ms"] is None
def test_process_recommendations_with_malformed_data(
self, discovery_tool: MCPDiscoveryTool
) -> None:
"""Test processing recommendations with malformed data."""
malformed_recommendations = [
{"server": "valid-server", "confidence": 0.5},
None,
{"invalid": "data"},
]
recommendations = discovery_tool._process_recommendations(
malformed_recommendations
)
assert len(recommendations) >= 1
assert recommendations[0]["server"] == "valid-server"
def test_format_result_with_recommendations(
self, discovery_tool: MCPDiscoveryTool
) -> None:
"""Test formatting results with recommendations."""
result: MCPDiscoveryResult = {
"recommendations": [
{
"server": "test-server",
"npm_package": "@test/server",
"install_command": "npx -y @test/server",
"confidence": 0.85,
"description": "A test server",
"capabilities": ["test", "demo"],
"metrics": {
"avg_latency_ms": 100.0,
"uptime_pct": 99.5,
"last_checked": "2026-01-17T10:00:00Z",
},
"docs_url": "https://example.com/docs",
"github_url": "https://github.com/test/server",
}
],
"total_found": 1,
"query_time_ms": 150,
}
formatted = discovery_tool._format_result(result)
assert "Found 1 MCP server(s)" in formatted
assert "test-server" in formatted
assert "85% confidence" in formatted
assert "A test server" in formatted
assert "test, demo" in formatted
assert "npx -y @test/server" in formatted
assert "100.0ms" in formatted
assert "99.5%" in formatted
def test_format_result_empty(self, discovery_tool: MCPDiscoveryTool) -> None:
"""Test formatting results with no recommendations."""
result: MCPDiscoveryResult = {
"recommendations": [],
"total_found": 0,
"query_time_ms": 50,
}
formatted = discovery_tool._format_result(result)
assert "No MCP servers found" in formatted
@patch("crewai_tools.tools.mcp_discovery_tool.mcp_discovery_tool.requests.post")
def test_make_api_request_success(
self,
mock_post: MagicMock,
discovery_tool: MCPDiscoveryTool,
mock_api_response: dict,
) -> None:
"""Test successful API request."""
mock_response = MagicMock()
mock_response.json.return_value = mock_api_response
mock_response.raise_for_status.return_value = None
mock_post.return_value = mock_response
result = discovery_tool._make_api_request({"need": "database", "limit": 5})
assert result == mock_api_response
mock_post.assert_called_once()
call_args = mock_post.call_args
assert call_args[1]["json"] == {"need": "database", "limit": 5}
assert call_args[1]["timeout"] == 30
@patch("crewai_tools.tools.mcp_discovery_tool.mcp_discovery_tool.requests.post")
def test_make_api_request_with_api_key(
self,
mock_post: MagicMock,
discovery_tool: MCPDiscoveryTool,
mock_api_response: dict,
) -> None:
"""Test API request with API key."""
mock_response = MagicMock()
mock_response.json.return_value = mock_api_response
mock_response.raise_for_status.return_value = None
mock_post.return_value = mock_response
with patch.dict("os.environ", {"MCP_DISCOVERY_API_KEY": "test-key"}):
discovery_tool._make_api_request({"need": "test", "limit": 5})
call_args = mock_post.call_args
assert "Authorization" in call_args[1]["headers"]
assert call_args[1]["headers"]["Authorization"] == "Bearer test-key"
@patch("crewai_tools.tools.mcp_discovery_tool.mcp_discovery_tool.requests.post")
def test_make_api_request_empty_response(
self, mock_post: MagicMock, discovery_tool: MCPDiscoveryTool
) -> None:
"""Test API request with empty response."""
mock_response = MagicMock()
mock_response.json.return_value = {}
mock_response.raise_for_status.return_value = None
mock_post.return_value = mock_response
with pytest.raises(ValueError, match="Empty response"):
discovery_tool._make_api_request({"need": "test", "limit": 5})
@patch("crewai_tools.tools.mcp_discovery_tool.mcp_discovery_tool.requests.post")
def test_make_api_request_network_error(
self, mock_post: MagicMock, discovery_tool: MCPDiscoveryTool
) -> None:
"""Test API request with network error."""
mock_post.side_effect = requests.exceptions.ConnectionError("Network error")
with pytest.raises(requests.exceptions.ConnectionError):
discovery_tool._make_api_request({"need": "test", "limit": 5})
@patch("crewai_tools.tools.mcp_discovery_tool.mcp_discovery_tool.requests.post")
def test_make_api_request_json_decode_error(
self, mock_post: MagicMock, discovery_tool: MCPDiscoveryTool
) -> None:
"""Test API request with JSON decode error."""
mock_response = MagicMock()
mock_response.json.side_effect = json.JSONDecodeError("Error", "", 0)
mock_response.raise_for_status.return_value = None
mock_response.content = b"invalid json"
mock_post.return_value = mock_response
with pytest.raises(json.JSONDecodeError):
discovery_tool._make_api_request({"need": "test", "limit": 5})
@patch("crewai_tools.tools.mcp_discovery_tool.mcp_discovery_tool.requests.post")
def test_run_success(
self,
mock_post: MagicMock,
discovery_tool: MCPDiscoveryTool,
mock_api_response: dict,
) -> None:
"""Test successful _run execution."""
mock_response = MagicMock()
mock_response.json.return_value = mock_api_response
mock_response.raise_for_status.return_value = None
mock_post.return_value = mock_response
result = discovery_tool._run(need="database server")
assert "sqlite-server" in result
assert "postgres-server" in result
assert "Found 2 MCP server(s)" in result
@patch("crewai_tools.tools.mcp_discovery_tool.mcp_discovery_tool.requests.post")
def test_run_with_constraints(
self,
mock_post: MagicMock,
discovery_tool: MCPDiscoveryTool,
mock_api_response: dict,
) -> None:
"""Test _run with constraints."""
mock_response = MagicMock()
mock_response.json.return_value = mock_api_response
mock_response.raise_for_status.return_value = None
mock_post.return_value = mock_response
result = discovery_tool._run(
need="database",
constraints={"max_latency_ms": 100, "required_features": ["sql"]},
limit=3,
)
assert "sqlite-server" in result
call_args = mock_post.call_args
payload = call_args[1]["json"]
assert payload["constraints"]["max_latency_ms"] == 100
assert payload["constraints"]["required_features"] == ["sql"]
assert payload["limit"] == 3
def test_run_missing_need_parameter(
self, discovery_tool: MCPDiscoveryTool
) -> None:
"""Test _run with missing need parameter."""
with pytest.raises(ValueError, match="'need' parameter is required"):
discovery_tool._run()
@patch("crewai_tools.tools.mcp_discovery_tool.mcp_discovery_tool.requests.post")
def test_discover_method(
self,
mock_post: MagicMock,
discovery_tool: MCPDiscoveryTool,
mock_api_response: dict,
) -> None:
"""Test the discover convenience method."""
mock_response = MagicMock()
mock_response.json.return_value = mock_api_response
mock_response.raise_for_status.return_value = None
mock_post.return_value = mock_response
result = discovery_tool.discover(
need="database",
constraints=MCPDiscoveryConstraints(max_latency_ms=200),
limit=3,
)
assert isinstance(result, dict)
assert "recommendations" in result
assert "total_found" in result
assert "query_time_ms" in result
assert len(result["recommendations"]) == 2
assert result["total_found"] == 2
@patch("crewai_tools.tools.mcp_discovery_tool.mcp_discovery_tool.requests.post")
def test_discover_returns_structured_data(
self,
mock_post: MagicMock,
discovery_tool: MCPDiscoveryTool,
mock_api_response: dict,
) -> None:
"""Test that discover returns properly structured data."""
mock_response = MagicMock()
mock_response.json.return_value = mock_api_response
mock_response.raise_for_status.return_value = None
mock_post.return_value = mock_response
result = discovery_tool.discover(need="database")
first_rec = result["recommendations"][0]
assert "server" in first_rec
assert "npm_package" in first_rec
assert "install_command" in first_rec
assert "confidence" in first_rec
assert "description" in first_rec
assert "capabilities" in first_rec
assert "metrics" in first_rec
assert "docs_url" in first_rec
assert "github_url" in first_rec
class TestMCPDiscoveryToolIntegration:
"""Integration tests for MCPDiscoveryTool (requires network)."""
@pytest.mark.skip(reason="Integration test - requires network access")
def test_real_api_call(self) -> None:
"""Test actual API call to MCP Discovery service."""
tool = MCPDiscoveryTool()
result = tool._run(need="database", limit=3)
assert "MCP server" in result or "No MCP servers found" in result

View File

@@ -3,9 +3,10 @@
from __future__ import annotations
from collections.abc import AsyncIterator
from typing import TYPE_CHECKING, TypedDict
from typing import TYPE_CHECKING, Any, TypedDict
import uuid
from a2a.client.errors import A2AClientHTTPError
from a2a.types import (
AgentCard,
Message,
@@ -20,7 +21,10 @@ from a2a.types import (
from typing_extensions import NotRequired
from crewai.events.event_bus import crewai_event_bus
from crewai.events.types.a2a_events import A2AResponseReceivedEvent
from crewai.events.types.a2a_events import (
A2AConnectionErrorEvent,
A2AResponseReceivedEvent,
)
if TYPE_CHECKING:
@@ -55,7 +59,8 @@ class TaskStateResult(TypedDict):
history: list[Message]
result: NotRequired[str]
error: NotRequired[str]
agent_card: NotRequired[AgentCard]
agent_card: NotRequired[dict[str, Any]]
a2a_agent_name: NotRequired[str | None]
def extract_task_result_parts(a2a_task: A2ATask) -> list[str]:
@@ -131,50 +136,69 @@ def process_task_state(
is_multiturn: bool,
agent_role: str | None,
result_parts: list[str] | None = None,
endpoint: str | None = None,
a2a_agent_name: str | None = None,
from_task: Any | None = None,
from_agent: Any | None = None,
is_final: bool = True,
) -> TaskStateResult | None:
"""Process A2A task state and return result dictionary.
Shared logic for both polling and streaming handlers.
Args:
a2a_task: The A2A task to process
new_messages: List to collect messages (modified in place)
agent_card: The agent card
turn_number: Current turn number
is_multiturn: Whether multi-turn conversation
agent_role: Agent role for logging
a2a_task: The A2A task to process.
new_messages: List to collect messages (modified in place).
agent_card: The agent card.
turn_number: Current turn number.
is_multiturn: Whether multi-turn conversation.
agent_role: Agent role for logging.
result_parts: Accumulated result parts (streaming passes accumulated,
polling passes None to extract from task)
polling passes None to extract from task).
endpoint: A2A agent endpoint URL.
a2a_agent_name: Name of the A2A agent from agent card.
from_task: Optional CrewAI Task for event metadata.
from_agent: Optional CrewAI Agent for event metadata.
is_final: Whether this is the final response in the stream.
Returns:
Result dictionary if terminal/actionable state, None otherwise
Result dictionary if terminal/actionable state, None otherwise.
"""
should_extract = result_parts is None
if result_parts is None:
result_parts = []
if a2a_task.status.state == TaskState.completed:
if should_extract:
if not result_parts:
extracted_parts = extract_task_result_parts(a2a_task)
result_parts.extend(extracted_parts)
if a2a_task.history:
new_messages.extend(a2a_task.history)
response_text = " ".join(result_parts) if result_parts else ""
message_id = None
if a2a_task.status and a2a_task.status.message:
message_id = a2a_task.status.message.message_id
crewai_event_bus.emit(
None,
A2AResponseReceivedEvent(
response=response_text,
turn_number=turn_number,
context_id=a2a_task.context_id,
message_id=message_id,
is_multiturn=is_multiturn,
status="completed",
final=is_final,
agent_role=agent_role,
endpoint=endpoint,
a2a_agent_name=a2a_agent_name,
from_task=from_task,
from_agent=from_agent,
),
)
return TaskStateResult(
status=TaskState.completed,
agent_card=agent_card,
agent_card=agent_card.model_dump(exclude_none=True),
result=response_text,
history=new_messages,
)
@@ -194,14 +218,24 @@ def process_task_state(
)
new_messages.append(agent_message)
input_message_id = None
if a2a_task.status and a2a_task.status.message:
input_message_id = a2a_task.status.message.message_id
crewai_event_bus.emit(
None,
A2AResponseReceivedEvent(
response=response_text,
turn_number=turn_number,
context_id=a2a_task.context_id,
message_id=input_message_id,
is_multiturn=is_multiturn,
status="input_required",
final=is_final,
agent_role=agent_role,
endpoint=endpoint,
a2a_agent_name=a2a_agent_name,
from_task=from_task,
from_agent=from_agent,
),
)
@@ -209,7 +243,7 @@ def process_task_state(
status=TaskState.input_required,
error=response_text,
history=new_messages,
agent_card=agent_card,
agent_card=agent_card.model_dump(exclude_none=True),
)
if a2a_task.status.state in {TaskState.failed, TaskState.rejected}:
@@ -248,6 +282,11 @@ async def send_message_and_get_task_id(
turn_number: int,
is_multiturn: bool,
agent_role: str | None,
from_task: Any | None = None,
from_agent: Any | None = None,
endpoint: str | None = None,
a2a_agent_name: str | None = None,
context_id: str | None = None,
) -> str | TaskStateResult:
"""Send message and process initial response.
@@ -262,6 +301,11 @@ async def send_message_and_get_task_id(
turn_number: Current turn number
is_multiturn: Whether multi-turn conversation
agent_role: Agent role for logging
from_task: Optional CrewAI Task object for event metadata.
from_agent: Optional CrewAI Agent object for event metadata.
endpoint: Optional A2A endpoint URL.
a2a_agent_name: Optional A2A agent name.
context_id: Optional A2A context ID for correlation.
Returns:
Task ID string if agent needs polling/waiting, or TaskStateResult if done.
@@ -280,9 +324,16 @@ async def send_message_and_get_task_id(
A2AResponseReceivedEvent(
response=response_text,
turn_number=turn_number,
context_id=event.context_id,
message_id=event.message_id,
is_multiturn=is_multiturn,
status="completed",
final=True,
agent_role=agent_role,
endpoint=endpoint,
a2a_agent_name=a2a_agent_name,
from_task=from_task,
from_agent=from_agent,
),
)
@@ -290,7 +341,7 @@ async def send_message_and_get_task_id(
status=TaskState.completed,
result=response_text,
history=new_messages,
agent_card=agent_card,
agent_card=agent_card.model_dump(exclude_none=True),
)
if isinstance(event, tuple):
@@ -304,6 +355,10 @@ async def send_message_and_get_task_id(
turn_number=turn_number,
is_multiturn=is_multiturn,
agent_role=agent_role,
endpoint=endpoint,
a2a_agent_name=a2a_agent_name,
from_task=from_task,
from_agent=from_agent,
)
if result:
return result
@@ -316,6 +371,99 @@ async def send_message_and_get_task_id(
history=new_messages,
)
except A2AClientHTTPError as e:
error_msg = f"HTTP Error {e.status_code}: {e!s}"
error_message = Message(
role=Role.agent,
message_id=str(uuid.uuid4()),
parts=[Part(root=TextPart(text=error_msg))],
context_id=context_id,
)
new_messages.append(error_message)
crewai_event_bus.emit(
None,
A2AConnectionErrorEvent(
endpoint=endpoint or "",
error=str(e),
error_type="http_error",
status_code=e.status_code,
a2a_agent_name=a2a_agent_name,
operation="send_message",
context_id=context_id,
from_task=from_task,
from_agent=from_agent,
),
)
crewai_event_bus.emit(
None,
A2AResponseReceivedEvent(
response=error_msg,
turn_number=turn_number,
context_id=context_id,
is_multiturn=is_multiturn,
status="failed",
final=True,
agent_role=agent_role,
endpoint=endpoint,
a2a_agent_name=a2a_agent_name,
from_task=from_task,
from_agent=from_agent,
),
)
return TaskStateResult(
status=TaskState.failed,
error=error_msg,
history=new_messages,
)
except Exception as e:
error_msg = f"Unexpected error during send_message: {e!s}"
error_message = Message(
role=Role.agent,
message_id=str(uuid.uuid4()),
parts=[Part(root=TextPart(text=error_msg))],
context_id=context_id,
)
new_messages.append(error_message)
crewai_event_bus.emit(
None,
A2AConnectionErrorEvent(
endpoint=endpoint or "",
error=str(e),
error_type="unexpected_error",
a2a_agent_name=a2a_agent_name,
operation="send_message",
context_id=context_id,
from_task=from_task,
from_agent=from_agent,
),
)
crewai_event_bus.emit(
None,
A2AResponseReceivedEvent(
response=error_msg,
turn_number=turn_number,
context_id=context_id,
is_multiturn=is_multiturn,
status="failed",
final=True,
agent_role=agent_role,
endpoint=endpoint,
a2a_agent_name=a2a_agent_name,
from_task=from_task,
from_agent=from_agent,
),
)
return TaskStateResult(
status=TaskState.failed,
error=error_msg,
history=new_messages,
)
finally:
aclose = getattr(event_stream, "aclose", None)
if aclose:

View File

@@ -22,6 +22,13 @@ class BaseHandlerKwargs(TypedDict, total=False):
turn_number: int
is_multiturn: bool
agent_role: str | None
context_id: str | None
task_id: str | None
endpoint: str | None
agent_branch: Any
a2a_agent_name: str | None
from_task: Any
from_agent: Any
class PollingHandlerKwargs(BaseHandlerKwargs, total=False):
@@ -29,8 +36,6 @@ class PollingHandlerKwargs(BaseHandlerKwargs, total=False):
polling_interval: float
polling_timeout: float
endpoint: str
agent_branch: Any
history_length: int
max_polls: int | None
@@ -38,9 +43,6 @@ class PollingHandlerKwargs(BaseHandlerKwargs, total=False):
class StreamingHandlerKwargs(BaseHandlerKwargs, total=False):
"""Kwargs for streaming handler."""
context_id: str | None
task_id: str | None
class PushNotificationHandlerKwargs(BaseHandlerKwargs, total=False):
"""Kwargs for push notification handler."""
@@ -49,7 +51,6 @@ class PushNotificationHandlerKwargs(BaseHandlerKwargs, total=False):
result_store: PushNotificationResultStore
polling_timeout: float
polling_interval: float
agent_branch: Any
class PushNotificationResultStore(Protocol):

View File

@@ -31,6 +31,7 @@ from crewai.a2a.task_helpers import (
from crewai.a2a.updates.base import PollingHandlerKwargs
from crewai.events.event_bus import crewai_event_bus
from crewai.events.types.a2a_events import (
A2AConnectionErrorEvent,
A2APollingStartedEvent,
A2APollingStatusEvent,
A2AResponseReceivedEvent,
@@ -49,23 +50,33 @@ async def _poll_task_until_complete(
agent_branch: Any | None = None,
history_length: int = 100,
max_polls: int | None = None,
from_task: Any | None = None,
from_agent: Any | None = None,
context_id: str | None = None,
endpoint: str | None = None,
a2a_agent_name: str | None = None,
) -> A2ATask:
"""Poll task status until terminal state reached.
Args:
client: A2A client instance
task_id: Task ID to poll
polling_interval: Seconds between poll attempts
polling_timeout: Max seconds before timeout
agent_branch: Agent tree branch for logging
history_length: Number of messages to retrieve per poll
max_polls: Max number of poll attempts (None = unlimited)
client: A2A client instance.
task_id: Task ID to poll.
polling_interval: Seconds between poll attempts.
polling_timeout: Max seconds before timeout.
agent_branch: Agent tree branch for logging.
history_length: Number of messages to retrieve per poll.
max_polls: Max number of poll attempts (None = unlimited).
from_task: Optional CrewAI Task object for event metadata.
from_agent: Optional CrewAI Agent object for event metadata.
context_id: A2A context ID for correlation.
endpoint: A2A agent endpoint URL.
a2a_agent_name: Name of the A2A agent from agent card.
Returns:
Final task object in terminal state
Final task object in terminal state.
Raises:
A2APollingTimeoutError: If polling exceeds timeout or max_polls
A2APollingTimeoutError: If polling exceeds timeout or max_polls.
"""
start_time = time.monotonic()
poll_count = 0
@@ -77,13 +88,19 @@ async def _poll_task_until_complete(
)
elapsed = time.monotonic() - start_time
effective_context_id = task.context_id or context_id
crewai_event_bus.emit(
agent_branch,
A2APollingStatusEvent(
task_id=task_id,
context_id=effective_context_id,
state=str(task.status.state.value) if task.status.state else "unknown",
elapsed_seconds=elapsed,
poll_count=poll_count,
endpoint=endpoint,
a2a_agent_name=a2a_agent_name,
from_task=from_task,
from_agent=from_agent,
),
)
@@ -137,6 +154,9 @@ class PollingHandler:
max_polls = kwargs.get("max_polls")
context_id = kwargs.get("context_id")
task_id = kwargs.get("task_id")
a2a_agent_name = kwargs.get("a2a_agent_name")
from_task = kwargs.get("from_task")
from_agent = kwargs.get("from_agent")
try:
result_or_task_id = await send_message_and_get_task_id(
@@ -146,6 +166,11 @@ class PollingHandler:
turn_number=turn_number,
is_multiturn=is_multiturn,
agent_role=agent_role,
from_task=from_task,
from_agent=from_agent,
endpoint=endpoint,
a2a_agent_name=a2a_agent_name,
context_id=context_id,
)
if not isinstance(result_or_task_id, str):
@@ -157,8 +182,12 @@ class PollingHandler:
agent_branch,
A2APollingStartedEvent(
task_id=task_id,
context_id=context_id,
polling_interval=polling_interval,
endpoint=endpoint,
a2a_agent_name=a2a_agent_name,
from_task=from_task,
from_agent=from_agent,
),
)
@@ -170,6 +199,11 @@ class PollingHandler:
agent_branch=agent_branch,
history_length=history_length,
max_polls=max_polls,
from_task=from_task,
from_agent=from_agent,
context_id=context_id,
endpoint=endpoint,
a2a_agent_name=a2a_agent_name,
)
result = process_task_state(
@@ -179,6 +213,10 @@ class PollingHandler:
turn_number=turn_number,
is_multiturn=is_multiturn,
agent_role=agent_role,
endpoint=endpoint,
a2a_agent_name=a2a_agent_name,
from_task=from_task,
from_agent=from_agent,
)
if result:
return result
@@ -206,9 +244,15 @@ class PollingHandler:
A2AResponseReceivedEvent(
response=error_msg,
turn_number=turn_number,
context_id=context_id,
is_multiturn=is_multiturn,
status="failed",
final=True,
agent_role=agent_role,
endpoint=endpoint,
a2a_agent_name=a2a_agent_name,
from_task=from_task,
from_agent=from_agent,
),
)
return TaskStateResult(
@@ -229,14 +273,83 @@ class PollingHandler:
)
new_messages.append(error_message)
crewai_event_bus.emit(
agent_branch,
A2AConnectionErrorEvent(
endpoint=endpoint,
error=str(e),
error_type="http_error",
status_code=e.status_code,
a2a_agent_name=a2a_agent_name,
operation="polling",
context_id=context_id,
task_id=task_id,
from_task=from_task,
from_agent=from_agent,
),
)
crewai_event_bus.emit(
agent_branch,
A2AResponseReceivedEvent(
response=error_msg,
turn_number=turn_number,
context_id=context_id,
is_multiturn=is_multiturn,
status="failed",
final=True,
agent_role=agent_role,
endpoint=endpoint,
a2a_agent_name=a2a_agent_name,
from_task=from_task,
from_agent=from_agent,
),
)
return TaskStateResult(
status=TaskState.failed,
error=error_msg,
history=new_messages,
)
except Exception as e:
error_msg = f"Unexpected error during polling: {e!s}"
error_message = Message(
role=Role.agent,
message_id=str(uuid.uuid4()),
parts=[Part(root=TextPart(text=error_msg))],
context_id=context_id,
task_id=task_id,
)
new_messages.append(error_message)
crewai_event_bus.emit(
agent_branch,
A2AConnectionErrorEvent(
endpoint=endpoint or "",
error=str(e),
error_type="unexpected_error",
a2a_agent_name=a2a_agent_name,
operation="polling",
context_id=context_id,
task_id=task_id,
from_task=from_task,
from_agent=from_agent,
),
)
crewai_event_bus.emit(
agent_branch,
A2AResponseReceivedEvent(
response=error_msg,
turn_number=turn_number,
context_id=context_id,
is_multiturn=is_multiturn,
status="failed",
final=True,
agent_role=agent_role,
endpoint=endpoint,
a2a_agent_name=a2a_agent_name,
from_task=from_task,
from_agent=from_agent,
),
)
return TaskStateResult(

View File

@@ -29,6 +29,7 @@ from crewai.a2a.updates.base import (
)
from crewai.events.event_bus import crewai_event_bus
from crewai.events.types.a2a_events import (
A2AConnectionErrorEvent,
A2APushNotificationRegisteredEvent,
A2APushNotificationTimeoutEvent,
A2AResponseReceivedEvent,
@@ -48,6 +49,11 @@ async def _wait_for_push_result(
timeout: float,
poll_interval: float,
agent_branch: Any | None = None,
from_task: Any | None = None,
from_agent: Any | None = None,
context_id: str | None = None,
endpoint: str | None = None,
a2a_agent_name: str | None = None,
) -> A2ATask | None:
"""Wait for push notification result.
@@ -57,6 +63,11 @@ async def _wait_for_push_result(
timeout: Max seconds to wait.
poll_interval: Seconds between polling attempts.
agent_branch: Agent tree branch for logging.
from_task: Optional CrewAI Task object for event metadata.
from_agent: Optional CrewAI Agent object for event metadata.
context_id: A2A context ID for correlation.
endpoint: A2A agent endpoint URL.
a2a_agent_name: Name of the A2A agent.
Returns:
Final task object, or None if timeout.
@@ -72,7 +83,12 @@ async def _wait_for_push_result(
agent_branch,
A2APushNotificationTimeoutEvent(
task_id=task_id,
context_id=context_id,
timeout_seconds=timeout,
endpoint=endpoint,
a2a_agent_name=a2a_agent_name,
from_task=from_task,
from_agent=from_agent,
),
)
@@ -115,18 +131,56 @@ class PushNotificationHandler:
agent_role = kwargs.get("agent_role")
context_id = kwargs.get("context_id")
task_id = kwargs.get("task_id")
endpoint = kwargs.get("endpoint")
a2a_agent_name = kwargs.get("a2a_agent_name")
from_task = kwargs.get("from_task")
from_agent = kwargs.get("from_agent")
if config is None:
error_msg = (
"PushNotificationConfig is required for push notification handler"
)
crewai_event_bus.emit(
agent_branch,
A2AConnectionErrorEvent(
endpoint=endpoint or "",
error=error_msg,
error_type="configuration_error",
a2a_agent_name=a2a_agent_name,
operation="push_notification",
context_id=context_id,
task_id=task_id,
from_task=from_task,
from_agent=from_agent,
),
)
return TaskStateResult(
status=TaskState.failed,
error="PushNotificationConfig is required for push notification handler",
error=error_msg,
history=new_messages,
)
if result_store is None:
error_msg = (
"PushNotificationResultStore is required for push notification handler"
)
crewai_event_bus.emit(
agent_branch,
A2AConnectionErrorEvent(
endpoint=endpoint or "",
error=error_msg,
error_type="configuration_error",
a2a_agent_name=a2a_agent_name,
operation="push_notification",
context_id=context_id,
task_id=task_id,
from_task=from_task,
from_agent=from_agent,
),
)
return TaskStateResult(
status=TaskState.failed,
error="PushNotificationResultStore is required for push notification handler",
error=error_msg,
history=new_messages,
)
@@ -138,6 +192,11 @@ class PushNotificationHandler:
turn_number=turn_number,
is_multiturn=is_multiturn,
agent_role=agent_role,
from_task=from_task,
from_agent=from_agent,
endpoint=endpoint,
a2a_agent_name=a2a_agent_name,
context_id=context_id,
)
if not isinstance(result_or_task_id, str):
@@ -149,7 +208,12 @@ class PushNotificationHandler:
agent_branch,
A2APushNotificationRegisteredEvent(
task_id=task_id,
context_id=context_id,
callback_url=str(config.url),
endpoint=endpoint,
a2a_agent_name=a2a_agent_name,
from_task=from_task,
from_agent=from_agent,
),
)
@@ -165,6 +229,11 @@ class PushNotificationHandler:
timeout=polling_timeout,
poll_interval=polling_interval,
agent_branch=agent_branch,
from_task=from_task,
from_agent=from_agent,
context_id=context_id,
endpoint=endpoint,
a2a_agent_name=a2a_agent_name,
)
if final_task is None:
@@ -181,6 +250,10 @@ class PushNotificationHandler:
turn_number=turn_number,
is_multiturn=is_multiturn,
agent_role=agent_role,
endpoint=endpoint,
a2a_agent_name=a2a_agent_name,
from_task=from_task,
from_agent=from_agent,
)
if result:
return result
@@ -203,14 +276,83 @@ class PushNotificationHandler:
)
new_messages.append(error_message)
crewai_event_bus.emit(
agent_branch,
A2AConnectionErrorEvent(
endpoint=endpoint or "",
error=str(e),
error_type="http_error",
status_code=e.status_code,
a2a_agent_name=a2a_agent_name,
operation="push_notification",
context_id=context_id,
task_id=task_id,
from_task=from_task,
from_agent=from_agent,
),
)
crewai_event_bus.emit(
agent_branch,
A2AResponseReceivedEvent(
response=error_msg,
turn_number=turn_number,
context_id=context_id,
is_multiturn=is_multiturn,
status="failed",
final=True,
agent_role=agent_role,
endpoint=endpoint,
a2a_agent_name=a2a_agent_name,
from_task=from_task,
from_agent=from_agent,
),
)
return TaskStateResult(
status=TaskState.failed,
error=error_msg,
history=new_messages,
)
except Exception as e:
error_msg = f"Unexpected error during push notification: {e!s}"
error_message = Message(
role=Role.agent,
message_id=str(uuid.uuid4()),
parts=[Part(root=TextPart(text=error_msg))],
context_id=context_id,
task_id=task_id,
)
new_messages.append(error_message)
crewai_event_bus.emit(
agent_branch,
A2AConnectionErrorEvent(
endpoint=endpoint or "",
error=str(e),
error_type="unexpected_error",
a2a_agent_name=a2a_agent_name,
operation="push_notification",
context_id=context_id,
task_id=task_id,
from_task=from_task,
from_agent=from_agent,
),
)
crewai_event_bus.emit(
agent_branch,
A2AResponseReceivedEvent(
response=error_msg,
turn_number=turn_number,
context_id=context_id,
is_multiturn=is_multiturn,
status="failed",
final=True,
agent_role=agent_role,
endpoint=endpoint,
a2a_agent_name=a2a_agent_name,
from_task=from_task,
from_agent=from_agent,
),
)
return TaskStateResult(

View File

@@ -26,7 +26,13 @@ from crewai.a2a.task_helpers import (
)
from crewai.a2a.updates.base import StreamingHandlerKwargs
from crewai.events.event_bus import crewai_event_bus
from crewai.events.types.a2a_events import A2AResponseReceivedEvent
from crewai.events.types.a2a_events import (
A2AArtifactReceivedEvent,
A2AConnectionErrorEvent,
A2AResponseReceivedEvent,
A2AStreamingChunkEvent,
A2AStreamingStartedEvent,
)
class StreamingHandler:
@@ -57,19 +63,57 @@ class StreamingHandler:
turn_number = kwargs.get("turn_number", 0)
is_multiturn = kwargs.get("is_multiturn", False)
agent_role = kwargs.get("agent_role")
endpoint = kwargs.get("endpoint")
a2a_agent_name = kwargs.get("a2a_agent_name")
from_task = kwargs.get("from_task")
from_agent = kwargs.get("from_agent")
agent_branch = kwargs.get("agent_branch")
result_parts: list[str] = []
final_result: TaskStateResult | None = None
event_stream = client.send_message(message)
chunk_index = 0
crewai_event_bus.emit(
agent_branch,
A2AStreamingStartedEvent(
task_id=task_id,
context_id=context_id,
endpoint=endpoint or "",
a2a_agent_name=a2a_agent_name,
turn_number=turn_number,
is_multiturn=is_multiturn,
agent_role=agent_role,
from_task=from_task,
from_agent=from_agent,
),
)
try:
async for event in event_stream:
if isinstance(event, Message):
new_messages.append(event)
message_context_id = event.context_id or context_id
for part in event.parts:
if part.root.kind == "text":
text = part.root.text
result_parts.append(text)
crewai_event_bus.emit(
agent_branch,
A2AStreamingChunkEvent(
task_id=event.task_id or task_id,
context_id=message_context_id,
chunk=text,
chunk_index=chunk_index,
endpoint=endpoint,
a2a_agent_name=a2a_agent_name,
turn_number=turn_number,
is_multiturn=is_multiturn,
from_task=from_task,
from_agent=from_agent,
),
)
chunk_index += 1
elif isinstance(event, tuple):
a2a_task, update = event
@@ -81,10 +125,51 @@ class StreamingHandler:
for part in artifact.parts
if part.root.kind == "text"
)
artifact_size = None
if artifact.parts:
artifact_size = sum(
len(p.root.text.encode("utf-8"))
if p.root.kind == "text"
else len(getattr(p.root, "data", b""))
for p in artifact.parts
)
effective_context_id = a2a_task.context_id or context_id
crewai_event_bus.emit(
agent_branch,
A2AArtifactReceivedEvent(
task_id=a2a_task.id,
artifact_id=artifact.artifact_id,
artifact_name=artifact.name,
artifact_description=artifact.description,
mime_type=artifact.parts[0].root.kind
if artifact.parts
else None,
size_bytes=artifact_size,
append=update.append or False,
last_chunk=update.last_chunk or False,
endpoint=endpoint,
a2a_agent_name=a2a_agent_name,
context_id=effective_context_id,
turn_number=turn_number,
is_multiturn=is_multiturn,
from_task=from_task,
from_agent=from_agent,
),
)
is_final_update = False
if isinstance(update, TaskStatusUpdateEvent):
is_final_update = update.final
if (
update.status
and update.status.message
and update.status.message.parts
):
result_parts.extend(
part.root.text
for part in update.status.message.parts
if part.root.kind == "text" and part.root.text
)
if (
not is_final_update
@@ -101,6 +186,11 @@ class StreamingHandler:
is_multiturn=is_multiturn,
agent_role=agent_role,
result_parts=result_parts,
endpoint=endpoint,
a2a_agent_name=a2a_agent_name,
from_task=from_task,
from_agent=from_agent,
is_final=is_final_update,
)
if final_result:
break
@@ -118,13 +208,82 @@ class StreamingHandler:
new_messages.append(error_message)
crewai_event_bus.emit(
None,
agent_branch,
A2AConnectionErrorEvent(
endpoint=endpoint or "",
error=str(e),
error_type="http_error",
status_code=e.status_code,
a2a_agent_name=a2a_agent_name,
operation="streaming",
context_id=context_id,
task_id=task_id,
from_task=from_task,
from_agent=from_agent,
),
)
crewai_event_bus.emit(
agent_branch,
A2AResponseReceivedEvent(
response=error_msg,
turn_number=turn_number,
context_id=context_id,
is_multiturn=is_multiturn,
status="failed",
final=True,
agent_role=agent_role,
endpoint=endpoint,
a2a_agent_name=a2a_agent_name,
from_task=from_task,
from_agent=from_agent,
),
)
return TaskStateResult(
status=TaskState.failed,
error=error_msg,
history=new_messages,
)
except Exception as e:
error_msg = f"Unexpected error during streaming: {e!s}"
error_message = Message(
role=Role.agent,
message_id=str(uuid.uuid4()),
parts=[Part(root=TextPart(text=error_msg))],
context_id=context_id,
task_id=task_id,
)
new_messages.append(error_message)
crewai_event_bus.emit(
agent_branch,
A2AConnectionErrorEvent(
endpoint=endpoint or "",
error=str(e),
error_type="unexpected_error",
a2a_agent_name=a2a_agent_name,
operation="streaming",
context_id=context_id,
task_id=task_id,
from_task=from_task,
from_agent=from_agent,
),
)
crewai_event_bus.emit(
agent_branch,
A2AResponseReceivedEvent(
response=error_msg,
turn_number=turn_number,
context_id=context_id,
is_multiturn=is_multiturn,
status="failed",
final=True,
agent_role=agent_role,
endpoint=endpoint,
a2a_agent_name=a2a_agent_name,
from_task=from_task,
from_agent=from_agent,
),
)
return TaskStateResult(
@@ -136,7 +295,23 @@ class StreamingHandler:
finally:
aclose = getattr(event_stream, "aclose", None)
if aclose:
await aclose()
try:
await aclose()
except Exception as close_error:
crewai_event_bus.emit(
agent_branch,
A2AConnectionErrorEvent(
endpoint=endpoint or "",
error=str(close_error),
error_type="stream_close_error",
a2a_agent_name=a2a_agent_name,
operation="stream_close",
context_id=context_id,
task_id=task_id,
from_task=from_task,
from_agent=from_agent,
),
)
if final_result:
return final_result
@@ -145,5 +320,5 @@ class StreamingHandler:
status=TaskState.completed,
result=" ".join(result_parts) if result_parts else "",
history=new_messages,
agent_card=agent_card,
agent_card=agent_card.model_dump(exclude_none=True),
)
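For reference, a minimal listener sketch for the new streaming and connection-error events emitted above. It assumes the event bus exposes the same `on(...)` decorator used by other CrewAI listeners; the handler names are hypothetical:

```python
from crewai.events.event_bus import crewai_event_bus
from crewai.events.types.a2a_events import (
    A2AConnectionErrorEvent,
    A2AStreamingChunkEvent,
)


@crewai_event_bus.on(A2AStreamingChunkEvent)
def on_chunk(source, event):
    # chunk, chunk_index, endpoint and a2a_agent_name are populated by StreamingHandler above.
    print(f"[{event.a2a_agent_name or event.endpoint}] chunk {event.chunk_index}: {event.chunk!r}")


@crewai_event_bus.on(A2AConnectionErrorEvent)
def on_connection_error(source, event):
    # error_type is one of the strings emitted in this changeset: "http_error",
    # "timeout", "connection_error", "request_error", "unexpected_error",
    # "stream_close_error", or "configuration_error".
    print(f"A2A {event.operation} failed for {event.endpoint}: {event.error_type} - {event.error}")
```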

View File

@@ -23,6 +23,12 @@ from crewai.a2a.auth.utils import (
)
from crewai.a2a.config import A2AServerConfig
from crewai.crew import Crew
from crewai.events.event_bus import crewai_event_bus
from crewai.events.types.a2a_events import (
A2AAgentCardFetchedEvent,
A2AAuthenticationFailedEvent,
A2AConnectionErrorEvent,
)
if TYPE_CHECKING:
@@ -183,6 +189,8 @@ async def _afetch_agent_card_impl(
timeout: int,
) -> AgentCard:
"""Internal async implementation of AgentCard fetching."""
start_time = time.perf_counter()
if "/.well-known/agent-card.json" in endpoint:
base_url = endpoint.replace("/.well-known/agent-card.json", "")
agent_card_path = "/.well-known/agent-card.json"
@@ -217,9 +225,29 @@ async def _afetch_agent_card_impl(
)
response.raise_for_status()
return AgentCard.model_validate(response.json())
agent_card = AgentCard.model_validate(response.json())
fetch_time_ms = (time.perf_counter() - start_time) * 1000
agent_card_dict = agent_card.model_dump(exclude_none=True)
crewai_event_bus.emit(
None,
A2AAgentCardFetchedEvent(
endpoint=endpoint,
a2a_agent_name=agent_card.name,
agent_card=agent_card_dict,
protocol_version=agent_card.protocol_version,
provider=agent_card_dict.get("provider"),
cached=False,
fetch_time_ms=fetch_time_ms,
),
)
return agent_card
except httpx.HTTPStatusError as e:
elapsed_ms = (time.perf_counter() - start_time) * 1000
response_body = e.response.text[:1000] if e.response.text else None
if e.response.status_code == 401:
error_details = ["Authentication failed"]
www_auth = e.response.headers.get("WWW-Authenticate")
@@ -228,7 +256,93 @@ async def _afetch_agent_card_impl(
if not auth:
error_details.append("No auth scheme provided")
msg = " | ".join(error_details)
auth_type = type(auth).__name__ if auth else None
crewai_event_bus.emit(
None,
A2AAuthenticationFailedEvent(
endpoint=endpoint,
auth_type=auth_type,
error=msg,
status_code=401,
metadata={
"elapsed_ms": elapsed_ms,
"response_body": response_body,
"www_authenticate": www_auth,
"request_url": str(e.request.url),
},
),
)
raise A2AClientHTTPError(401, msg) from e
crewai_event_bus.emit(
None,
A2AConnectionErrorEvent(
endpoint=endpoint,
error=str(e),
error_type="http_error",
status_code=e.response.status_code,
operation="fetch_agent_card",
metadata={
"elapsed_ms": elapsed_ms,
"response_body": response_body,
"request_url": str(e.request.url),
},
),
)
raise
except httpx.TimeoutException as e:
elapsed_ms = (time.perf_counter() - start_time) * 1000
crewai_event_bus.emit(
None,
A2AConnectionErrorEvent(
endpoint=endpoint,
error=str(e),
error_type="timeout",
operation="fetch_agent_card",
metadata={
"elapsed_ms": elapsed_ms,
"timeout_config": timeout,
"request_url": str(e.request.url) if e.request else None,
},
),
)
raise
except httpx.ConnectError as e:
elapsed_ms = (time.perf_counter() - start_time) * 1000
crewai_event_bus.emit(
None,
A2AConnectionErrorEvent(
endpoint=endpoint,
error=str(e),
error_type="connection_error",
operation="fetch_agent_card",
metadata={
"elapsed_ms": elapsed_ms,
"request_url": str(e.request.url) if e.request else None,
},
),
)
raise
except httpx.RequestError as e:
elapsed_ms = (time.perf_counter() - start_time) * 1000
crewai_event_bus.emit(
None,
A2AConnectionErrorEvent(
endpoint=endpoint,
error=str(e),
error_type="request_error",
operation="fetch_agent_card",
metadata={
"elapsed_ms": elapsed_ms,
"request_url": str(e.request.url) if e.request else None,
},
),
)
raise

View File

@@ -88,6 +88,9 @@ def execute_a2a_delegation(
response_model: type[BaseModel] | None = None,
turn_number: int | None = None,
updates: UpdateConfig | None = None,
from_task: Any | None = None,
from_agent: Any | None = None,
skill_id: str | None = None,
) -> TaskStateResult:
"""Execute a task delegation to a remote A2A agent synchronously.
@@ -129,6 +132,9 @@ def execute_a2a_delegation(
response_model: Optional Pydantic model for structured outputs.
turn_number: Optional turn number for multi-turn conversations.
updates: Update mechanism config from A2AConfig.updates.
from_task: Optional CrewAI Task object for event metadata.
from_agent: Optional CrewAI Agent object for event metadata.
skill_id: Optional skill ID to target a specific agent capability.
Returns:
TaskStateResult with status, result/error, history, and agent_card.
@@ -156,10 +162,16 @@ def execute_a2a_delegation(
transport_protocol=transport_protocol,
turn_number=turn_number,
updates=updates,
from_task=from_task,
from_agent=from_agent,
skill_id=skill_id,
)
)
finally:
loop.close()
try:
loop.run_until_complete(loop.shutdown_asyncgens())
finally:
loop.close()
async def aexecute_a2a_delegation(
@@ -181,6 +193,9 @@ async def aexecute_a2a_delegation(
response_model: type[BaseModel] | None = None,
turn_number: int | None = None,
updates: UpdateConfig | None = None,
from_task: Any | None = None,
from_agent: Any | None = None,
skill_id: str | None = None,
) -> TaskStateResult:
"""Execute a task delegation to a remote A2A agent asynchronously.
@@ -222,6 +237,9 @@ async def aexecute_a2a_delegation(
response_model: Optional Pydantic model for structured outputs.
turn_number: Optional turn number for multi-turn conversations.
updates: Update mechanism config from A2AConfig.updates.
from_task: Optional CrewAI Task object for event metadata.
from_agent: Optional CrewAI Agent object for event metadata.
skill_id: Optional skill ID to target a specific agent capability.
Returns:
TaskStateResult with status, result/error, history, and agent_card.
@@ -233,17 +251,6 @@ async def aexecute_a2a_delegation(
if turn_number is None:
turn_number = len([m for m in conversation_history if m.role == Role.user]) + 1
crewai_event_bus.emit(
agent_branch,
A2ADelegationStartedEvent(
endpoint=endpoint,
task_description=task_description,
agent_id=agent_id,
is_multiturn=is_multiturn,
turn_number=turn_number,
),
)
result = await _aexecute_a2a_delegation_impl(
endpoint=endpoint,
auth=auth,
@@ -264,15 +271,28 @@ async def aexecute_a2a_delegation(
response_model=response_model,
updates=updates,
transport_protocol=transport_protocol,
from_task=from_task,
from_agent=from_agent,
skill_id=skill_id,
)
agent_card_data: dict[str, Any] = result.get("agent_card") or {}
crewai_event_bus.emit(
agent_branch,
A2ADelegationCompletedEvent(
status=result["status"],
result=result.get("result"),
error=result.get("error"),
context_id=context_id,
is_multiturn=is_multiturn,
endpoint=endpoint,
a2a_agent_name=result.get("a2a_agent_name"),
agent_card=agent_card_data,
provider=agent_card_data.get("provider"),
metadata=metadata,
extensions=list(extensions.keys()) if extensions else None,
from_task=from_task,
from_agent=from_agent,
),
)
@@ -299,6 +319,9 @@ async def _aexecute_a2a_delegation_impl(
agent_role: str | None,
response_model: type[BaseModel] | None,
updates: UpdateConfig | None,
from_task: Any | None = None,
from_agent: Any | None = None,
skill_id: str | None = None,
) -> TaskStateResult:
"""Internal async implementation of A2A delegation."""
if auth:
@@ -331,6 +354,28 @@ async def _aexecute_a2a_delegation_impl(
if agent_card.name:
a2a_agent_name = agent_card.name
agent_card_dict = agent_card.model_dump(exclude_none=True)
crewai_event_bus.emit(
agent_branch,
A2ADelegationStartedEvent(
endpoint=endpoint,
task_description=task_description,
agent_id=agent_id or endpoint,
context_id=context_id,
is_multiturn=is_multiturn,
turn_number=turn_number,
a2a_agent_name=a2a_agent_name,
agent_card=agent_card_dict,
protocol_version=agent_card.protocol_version,
provider=agent_card_dict.get("provider"),
skill_id=skill_id,
metadata=metadata,
extensions=list(extensions.keys()) if extensions else None,
from_task=from_task,
from_agent=from_agent,
),
)
if turn_number == 1:
agent_id_for_event = agent_id or endpoint
crewai_event_bus.emit(
@@ -338,7 +383,17 @@ async def _aexecute_a2a_delegation_impl(
A2AConversationStartedEvent(
agent_id=agent_id_for_event,
endpoint=endpoint,
context_id=context_id,
a2a_agent_name=a2a_agent_name,
agent_card=agent_card_dict,
protocol_version=agent_card.protocol_version,
provider=agent_card_dict.get("provider"),
skill_id=skill_id,
reference_task_ids=reference_task_ids,
metadata=metadata,
extensions=list(extensions.keys()) if extensions else None,
from_task=from_task,
from_agent=from_agent,
),
)
@@ -364,6 +419,10 @@ async def _aexecute_a2a_delegation_impl(
}
)
message_metadata = metadata.copy() if metadata else {}
if skill_id:
message_metadata["skill_id"] = skill_id
message = Message(
role=Role.user,
message_id=str(uuid.uuid4()),
@@ -371,7 +430,7 @@ async def _aexecute_a2a_delegation_impl(
context_id=context_id,
task_id=task_id,
reference_task_ids=reference_task_ids,
metadata=metadata,
metadata=message_metadata if message_metadata else None,
extensions=extensions,
)
@@ -381,8 +440,17 @@ async def _aexecute_a2a_delegation_impl(
A2AMessageSentEvent(
message=message_text,
turn_number=turn_number,
context_id=context_id,
message_id=message.message_id,
is_multiturn=is_multiturn,
agent_role=agent_role,
endpoint=endpoint,
a2a_agent_name=a2a_agent_name,
skill_id=skill_id,
metadata=message_metadata if message_metadata else None,
extensions=list(extensions.keys()) if extensions else None,
from_task=from_task,
from_agent=from_agent,
),
)
@@ -397,6 +465,9 @@ async def _aexecute_a2a_delegation_impl(
"task_id": task_id,
"endpoint": endpoint,
"agent_branch": agent_branch,
"a2a_agent_name": a2a_agent_name,
"from_task": from_task,
"from_agent": from_agent,
}
if isinstance(updates, PollingConfig):
@@ -434,13 +505,16 @@ async def _aexecute_a2a_delegation_impl(
use_polling=use_polling,
push_notification_config=push_config_for_client,
) as client:
return await handler.execute(
result = await handler.execute(
client=client,
message=message,
new_messages=new_messages,
agent_card=agent_card,
**handler_kwargs,
)
result["a2a_agent_name"] = a2a_agent_name
result["agent_card"] = agent_card.model_dump(exclude_none=True)
return result
@asynccontextmanager

View File

@@ -3,11 +3,14 @@
from __future__ import annotations
import asyncio
import base64
from collections.abc import Callable, Coroutine
from datetime import datetime
from functools import wraps
import logging
import os
from typing import TYPE_CHECKING, Any, ParamSpec, TypeVar, cast
from urllib.parse import urlparse
from a2a.server.agent_execution import RequestContext
from a2a.server.events import EventQueue
@@ -45,7 +48,14 @@ T = TypeVar("T")
def _parse_redis_url(url: str) -> dict[str, Any]:
from urllib.parse import urlparse
"""Parse a Redis URL into aiocache configuration.
Args:
url: Redis connection URL (e.g., redis://localhost:6379/0).
Returns:
Configuration dict for aiocache.RedisCache.
"""
parsed = urlparse(url)
config: dict[str, Any] = {
@@ -127,7 +137,7 @@ def cancellable(
async for message in pubsub.listen():
if message["type"] == "message":
return True
except Exception as e:
except (OSError, ConnectionError) as e:
logger.warning("Cancel watcher error for task_id=%s: %s", task_id, e)
return await poll_for_cancel()
return False
@@ -183,7 +193,12 @@ async def execute(
msg = "task_id and context_id are required"
crewai_event_bus.emit(
agent,
A2AServerTaskFailedEvent(a2a_task_id="", a2a_context_id="", error=msg),
A2AServerTaskFailedEvent(
task_id="",
context_id="",
error=msg,
from_agent=agent,
),
)
raise ServerError(InvalidParamsError(message=msg)) from None
@@ -195,7 +210,12 @@ async def execute(
crewai_event_bus.emit(
agent,
A2AServerTaskStartedEvent(a2a_task_id=task_id, a2a_context_id=context_id),
A2AServerTaskStartedEvent(
task_id=task_id,
context_id=context_id,
from_task=task,
from_agent=agent,
),
)
try:
@@ -215,20 +235,33 @@ async def execute(
crewai_event_bus.emit(
agent,
A2AServerTaskCompletedEvent(
a2a_task_id=task_id, a2a_context_id=context_id, result=str(result)
task_id=task_id,
context_id=context_id,
result=str(result),
from_task=task,
from_agent=agent,
),
)
except asyncio.CancelledError:
crewai_event_bus.emit(
agent,
A2AServerTaskCanceledEvent(a2a_task_id=task_id, a2a_context_id=context_id),
A2AServerTaskCanceledEvent(
task_id=task_id,
context_id=context_id,
from_task=task,
from_agent=agent,
),
)
raise
except Exception as e:
crewai_event_bus.emit(
agent,
A2AServerTaskFailedEvent(
a2a_task_id=task_id, a2a_context_id=context_id, error=str(e)
task_id=task_id,
context_id=context_id,
error=str(e),
from_task=task,
from_agent=agent,
),
)
raise ServerError(
@@ -282,3 +315,85 @@ async def cancel(
context.current_task.status = TaskStatus(state=TaskState.canceled)
return context.current_task
return None
def list_tasks(
tasks: list[A2ATask],
context_id: str | None = None,
status: TaskState | None = None,
status_timestamp_after: datetime | None = None,
page_size: int = 50,
page_token: str | None = None,
history_length: int | None = None,
include_artifacts: bool = False,
) -> tuple[list[A2ATask], str | None, int]:
"""Filter and paginate A2A tasks.
Provides filtering by context, status, and timestamp, along with
cursor-based pagination. This is a pure utility function that operates
on an in-memory list of tasks - storage retrieval is handled separately.
Args:
tasks: All tasks to filter.
context_id: Filter by context ID to get tasks in a conversation.
status: Filter by task state (e.g., completed, working).
status_timestamp_after: Filter to tasks updated after this time.
page_size: Maximum tasks per page (default 50).
page_token: Base64-encoded cursor from previous response.
history_length: Limit history messages per task (None = full history).
include_artifacts: Whether to include task artifacts (default False).
Returns:
Tuple of (filtered_tasks, next_page_token, total_count).
- filtered_tasks: Tasks matching filters, paginated and trimmed.
- next_page_token: Token for next page, or None if no more pages.
- total_count: Total number of tasks matching filters (before pagination).
"""
filtered: list[A2ATask] = []
for task in tasks:
if context_id and task.context_id != context_id:
continue
if status and task.status.state != status:
continue
if status_timestamp_after and task.status.timestamp:
ts = datetime.fromisoformat(task.status.timestamp.replace("Z", "+00:00"))
if ts <= status_timestamp_after:
continue
filtered.append(task)
def get_timestamp(t: A2ATask) -> datetime:
"""Extract timestamp from task status for sorting."""
if t.status.timestamp is None:
return datetime.min
return datetime.fromisoformat(t.status.timestamp.replace("Z", "+00:00"))
filtered.sort(key=get_timestamp, reverse=True)
total = len(filtered)
start = 0
if page_token:
try:
cursor_id = base64.b64decode(page_token).decode()
for idx, task in enumerate(filtered):
if task.id == cursor_id:
start = idx + 1
break
except (ValueError, UnicodeDecodeError):
pass
page = filtered[start : start + page_size]
result: list[A2ATask] = []
for task in page:
task = task.model_copy(deep=True)
if history_length is not None and task.history:
task.history = task.history[-history_length:]
if not include_artifacts:
task.artifacts = None
result.append(task)
next_token: str | None = None
if result and len(result) == page_size:
next_token = base64.b64encode(result[-1].id.encode()).decode()
return result, next_token, total
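
A minimal usage sketch of the pagination contract above. This is illustrative only: all_tasks stands in for a list of A2A Task objects already retrieved from the server's task store, and the filter values are assumed.

# Hypothetical inputs: all_tasks is a list[A2ATask] loaded elsewhere.
page, next_token, total = list_tasks(
    tasks=all_tasks,
    context_id="ctx-123",            # only tasks from one conversation
    status=TaskState.completed,
    page_size=10,
)
while next_token:                    # follow the cursor until exhausted
    page, next_token, total = list_tasks(
        tasks=all_tasks,
        context_id="ctx-123",
        status=TaskState.completed,
        page_size=10,
        page_token=next_token,
    )

Each next_token is simply the base64-encoded id of the last task on the current page, which is how the function resumes after that task on the next call.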

View File

@@ -6,9 +6,10 @@ Wraps agent classes with A2A delegation capabilities.
from __future__ import annotations
import asyncio
from collections.abc import Callable, Coroutine
from collections.abc import Callable, Coroutine, Mapping
from concurrent.futures import ThreadPoolExecutor, as_completed
from functools import wraps
import json
from types import MethodType
from typing import TYPE_CHECKING, Any
@@ -189,7 +190,7 @@ def _execute_task_with_a2a(
a2a_agents: list[A2AConfig | A2AClientConfig],
original_fn: Callable[..., str],
task: Task,
agent_response_model: type[BaseModel],
agent_response_model: type[BaseModel] | None,
context: str | None,
tools: list[BaseTool] | None,
extension_registry: ExtensionRegistry,
@@ -277,7 +278,7 @@ def _execute_task_with_a2a(
def _augment_prompt_with_a2a(
a2a_agents: list[A2AConfig | A2AClientConfig],
task_description: str,
agent_cards: dict[str, AgentCard],
agent_cards: Mapping[str, AgentCard | dict[str, Any]],
conversation_history: list[Message] | None = None,
turn_num: int = 0,
max_turns: int | None = None,
@@ -309,7 +310,15 @@ def _augment_prompt_with_a2a(
for config in a2a_agents:
if config.endpoint in agent_cards:
card = agent_cards[config.endpoint]
agents_text += f"\n{card.model_dump_json(indent=2, exclude_none=True, include={'description', 'url', 'skills'})}\n"
if isinstance(card, dict):
filtered = {
k: v
for k, v in card.items()
if k in {"description", "url", "skills"} and v is not None
}
agents_text += f"\n{json.dumps(filtered, indent=2)}\n"
else:
agents_text += f"\n{card.model_dump_json(indent=2, exclude_none=True, include={'description', 'url', 'skills'})}\n"
failed_agents = failed_agents or {}
if failed_agents:
@@ -377,7 +386,7 @@ IMPORTANT: You have the ability to delegate this task to remote A2A agents.
def _parse_agent_response(
raw_result: str | dict[str, Any], agent_response_model: type[BaseModel]
raw_result: str | dict[str, Any], agent_response_model: type[BaseModel] | None
) -> BaseModel | str | dict[str, Any]:
"""Parse LLM output as AgentResponse or return raw agent response."""
if agent_response_model:
@@ -394,6 +403,11 @@ def _parse_agent_response(
def _handle_max_turns_exceeded(
conversation_history: list[Message],
max_turns: int,
from_task: Any | None = None,
from_agent: Any | None = None,
endpoint: str | None = None,
a2a_agent_name: str | None = None,
agent_card: dict[str, Any] | None = None,
) -> str:
"""Handle the case when max turns is exceeded.
@@ -421,6 +435,11 @@ def _handle_max_turns_exceeded(
final_result=final_message,
error=None,
total_turns=max_turns,
from_task=from_task,
from_agent=from_agent,
endpoint=endpoint,
a2a_agent_name=a2a_agent_name,
agent_card=agent_card,
),
)
return final_message
@@ -432,6 +451,11 @@ def _handle_max_turns_exceeded(
final_result=None,
error=f"Conversation exceeded maximum turns ({max_turns})",
total_turns=max_turns,
from_task=from_task,
from_agent=from_agent,
endpoint=endpoint,
a2a_agent_name=a2a_agent_name,
agent_card=agent_card,
),
)
raise Exception(f"A2A conversation exceeded maximum turns ({max_turns})")
@@ -442,7 +466,12 @@ def _process_response_result(
disable_structured_output: bool,
turn_num: int,
agent_role: str,
agent_response_model: type[BaseModel],
agent_response_model: type[BaseModel] | None,
from_task: Any | None = None,
from_agent: Any | None = None,
endpoint: str | None = None,
a2a_agent_name: str | None = None,
agent_card: dict[str, Any] | None = None,
) -> tuple[str | None, str | None]:
"""Process LLM response and determine next action.
@@ -461,6 +490,10 @@ def _process_response_result(
turn_number=final_turn_number,
is_multiturn=True,
agent_role=agent_role,
from_task=from_task,
from_agent=from_agent,
endpoint=endpoint,
a2a_agent_name=a2a_agent_name,
),
)
crewai_event_bus.emit(
@@ -470,6 +503,11 @@ def _process_response_result(
final_result=result_text,
error=None,
total_turns=final_turn_number,
from_task=from_task,
from_agent=from_agent,
endpoint=endpoint,
a2a_agent_name=a2a_agent_name,
agent_card=agent_card,
),
)
return result_text, None
@@ -490,6 +528,10 @@ def _process_response_result(
turn_number=final_turn_number,
is_multiturn=True,
agent_role=agent_role,
from_task=from_task,
from_agent=from_agent,
endpoint=endpoint,
a2a_agent_name=a2a_agent_name,
),
)
crewai_event_bus.emit(
@@ -499,6 +541,11 @@ def _process_response_result(
final_result=str(llm_response.message),
error=None,
total_turns=final_turn_number,
from_task=from_task,
from_agent=from_agent,
endpoint=endpoint,
a2a_agent_name=a2a_agent_name,
agent_card=agent_card,
),
)
return str(llm_response.message), None
@@ -510,13 +557,15 @@ def _process_response_result(
def _prepare_agent_cards_dict(
a2a_result: TaskStateResult,
agent_id: str,
agent_cards: dict[str, AgentCard] | None,
) -> dict[str, AgentCard]:
agent_cards: Mapping[str, AgentCard | dict[str, Any]] | None,
) -> dict[str, AgentCard | dict[str, Any]]:
"""Prepare agent cards dictionary from result and existing cards.
Shared logic for both sync and async response handlers.
"""
agent_cards_dict = agent_cards or {}
agent_cards_dict: dict[str, AgentCard | dict[str, Any]] = (
dict(agent_cards) if agent_cards else {}
)
if "agent_card" in a2a_result and agent_id not in agent_cards_dict:
agent_cards_dict[agent_id] = a2a_result["agent_card"]
return agent_cards_dict
@@ -529,7 +578,7 @@ def _prepare_delegation_context(
original_task_description: str | None,
) -> tuple[
list[A2AConfig | A2AClientConfig],
type[BaseModel],
type[BaseModel] | None,
str,
str,
A2AConfig | A2AClientConfig,
@@ -598,6 +647,11 @@ def _handle_task_completion(
reference_task_ids: list[str],
agent_config: A2AConfig | A2AClientConfig,
turn_num: int,
from_task: Any | None = None,
from_agent: Any | None = None,
endpoint: str | None = None,
a2a_agent_name: str | None = None,
agent_card: dict[str, Any] | None = None,
) -> tuple[str | None, str | None, list[str]]:
"""Handle task completion state including reference task updates.
@@ -624,6 +678,11 @@ def _handle_task_completion(
final_result=result_text,
error=None,
total_turns=final_turn_number,
from_task=from_task,
from_agent=from_agent,
endpoint=endpoint,
a2a_agent_name=a2a_agent_name,
agent_card=agent_card,
),
)
return str(result_text), task_id_config, reference_task_ids
@@ -645,8 +704,11 @@ def _handle_agent_response_and_continue(
original_fn: Callable[..., str],
context: str | None,
tools: list[BaseTool] | None,
agent_response_model: type[BaseModel],
agent_response_model: type[BaseModel] | None,
remote_task_completed: bool = False,
endpoint: str | None = None,
a2a_agent_name: str | None = None,
agent_card: dict[str, Any] | None = None,
) -> tuple[str | None, str | None]:
"""Handle A2A result and get CrewAI agent's response.
@@ -698,6 +760,11 @@ def _handle_agent_response_and_continue(
turn_num=turn_num,
agent_role=self.role,
agent_response_model=agent_response_model,
from_task=task,
from_agent=self,
endpoint=endpoint,
a2a_agent_name=a2a_agent_name,
agent_card=agent_card,
)
@@ -750,6 +817,12 @@ def _delegate_to_a2a(
conversation_history: list[Message] = []
current_agent_card = agent_cards.get(agent_id) if agent_cards else None
current_agent_card_dict = (
current_agent_card.model_dump() if current_agent_card else None
)
current_a2a_agent_name = current_agent_card.name if current_agent_card else None
try:
for turn_num in range(max_turns):
console_formatter = getattr(crewai_event_bus, "_console", None)
@@ -777,6 +850,8 @@ def _delegate_to_a2a(
turn_number=turn_num + 1,
updates=agent_config.updates,
transport_protocol=agent_config.transport_protocol,
from_task=task,
from_agent=self,
)
conversation_history = a2a_result.get("history", [])
@@ -797,6 +872,11 @@ def _delegate_to_a2a(
reference_task_ids,
agent_config,
turn_num,
from_task=task,
from_agent=self,
endpoint=agent_config.endpoint,
a2a_agent_name=current_a2a_agent_name,
agent_card=current_agent_card_dict,
)
)
if trusted_result is not None:
@@ -818,6 +898,9 @@ def _delegate_to_a2a(
tools=tools,
agent_response_model=agent_response_model,
remote_task_completed=(a2a_result["status"] == TaskState.completed),
endpoint=agent_config.endpoint,
a2a_agent_name=current_a2a_agent_name,
agent_card=current_agent_card_dict,
)
if final_result is not None:
@@ -846,6 +929,9 @@ def _delegate_to_a2a(
tools=tools,
agent_response_model=agent_response_model,
remote_task_completed=False,
endpoint=agent_config.endpoint,
a2a_agent_name=current_a2a_agent_name,
agent_card=current_agent_card_dict,
)
if final_result is not None:
@@ -862,11 +948,24 @@ def _delegate_to_a2a(
final_result=None,
error=error_msg,
total_turns=turn_num + 1,
from_task=task,
from_agent=self,
endpoint=agent_config.endpoint,
a2a_agent_name=current_a2a_agent_name,
agent_card=current_agent_card_dict,
),
)
return f"A2A delegation failed: {error_msg}"
return _handle_max_turns_exceeded(conversation_history, max_turns)
return _handle_max_turns_exceeded(
conversation_history,
max_turns,
from_task=task,
from_agent=self,
endpoint=agent_config.endpoint,
a2a_agent_name=current_a2a_agent_name,
agent_card=current_agent_card_dict,
)
finally:
task.description = original_task_description
@@ -916,7 +1015,7 @@ async def _aexecute_task_with_a2a(
a2a_agents: list[A2AConfig | A2AClientConfig],
original_fn: Callable[..., Coroutine[Any, Any, str]],
task: Task,
agent_response_model: type[BaseModel],
agent_response_model: type[BaseModel] | None,
context: str | None,
tools: list[BaseTool] | None,
extension_registry: ExtensionRegistry,
@@ -1001,8 +1100,11 @@ async def _ahandle_agent_response_and_continue(
original_fn: Callable[..., Coroutine[Any, Any, str]],
context: str | None,
tools: list[BaseTool] | None,
agent_response_model: type[BaseModel],
agent_response_model: type[BaseModel] | None,
remote_task_completed: bool = False,
endpoint: str | None = None,
a2a_agent_name: str | None = None,
agent_card: dict[str, Any] | None = None,
) -> tuple[str | None, str | None]:
"""Async version of _handle_agent_response_and_continue."""
agent_cards_dict = _prepare_agent_cards_dict(a2a_result, agent_id, agent_cards)
@@ -1032,6 +1134,11 @@ async def _ahandle_agent_response_and_continue(
turn_num=turn_num,
agent_role=self.role,
agent_response_model=agent_response_model,
from_task=task,
from_agent=self,
endpoint=endpoint,
a2a_agent_name=a2a_agent_name,
agent_card=agent_card,
)
@@ -1066,6 +1173,12 @@ async def _adelegate_to_a2a(
conversation_history: list[Message] = []
current_agent_card = agent_cards.get(agent_id) if agent_cards else None
current_agent_card_dict = (
current_agent_card.model_dump() if current_agent_card else None
)
current_a2a_agent_name = current_agent_card.name if current_agent_card else None
try:
for turn_num in range(max_turns):
console_formatter = getattr(crewai_event_bus, "_console", None)
@@ -1093,6 +1206,8 @@ async def _adelegate_to_a2a(
turn_number=turn_num + 1,
transport_protocol=agent_config.transport_protocol,
updates=agent_config.updates,
from_task=task,
from_agent=self,
)
conversation_history = a2a_result.get("history", [])
@@ -1113,6 +1228,11 @@ async def _adelegate_to_a2a(
reference_task_ids,
agent_config,
turn_num,
from_task=task,
from_agent=self,
endpoint=agent_config.endpoint,
a2a_agent_name=current_a2a_agent_name,
agent_card=current_agent_card_dict,
)
)
if trusted_result is not None:
@@ -1134,6 +1254,9 @@ async def _adelegate_to_a2a(
tools=tools,
agent_response_model=agent_response_model,
remote_task_completed=(a2a_result["status"] == TaskState.completed),
endpoint=agent_config.endpoint,
a2a_agent_name=current_a2a_agent_name,
agent_card=current_agent_card_dict,
)
if final_result is not None:
@@ -1161,6 +1284,9 @@ async def _adelegate_to_a2a(
context=context,
tools=tools,
agent_response_model=agent_response_model,
endpoint=agent_config.endpoint,
a2a_agent_name=current_a2a_agent_name,
agent_card=current_agent_card_dict,
)
if final_result is not None:
@@ -1177,11 +1303,24 @@ async def _adelegate_to_a2a(
final_result=None,
error=error_msg,
total_turns=turn_num + 1,
from_task=task,
from_agent=self,
endpoint=agent_config.endpoint,
a2a_agent_name=current_a2a_agent_name,
agent_card=current_agent_card_dict,
),
)
return f"A2A delegation failed: {error_msg}"
return _handle_max_turns_exceeded(conversation_history, max_turns)
return _handle_max_turns_exceeded(
conversation_history,
max_turns,
from_task=task,
from_agent=self,
endpoint=agent_config.endpoint,
a2a_agent_name=current_a2a_agent_name,
agent_card=current_agent_card_dict,
)
finally:
task.description = original_task_description

View File

@@ -725,17 +725,9 @@ class Agent(BaseAgent):
raw_tools: list[BaseTool] = tools or self.tools or []
parsed_tools = parse_tools(raw_tools)
use_native_tool_calling = (
hasattr(self.llm, "supports_function_calling")
and callable(getattr(self.llm, "supports_function_calling", None))
and self.llm.supports_function_calling()
and len(raw_tools) > 0
)
prompt = Prompts(
agent=self,
has_tools=len(raw_tools) > 0,
use_native_tool_calling=use_native_tool_calling,
i18n=self.i18n,
use_system_prompt=self.use_system_prompt,
system_template=self.system_template,
@@ -743,8 +735,6 @@ class Agent(BaseAgent):
response_template=self.response_template,
).task_execution()
print("prompt", prompt)
stop_words = [self.i18n.slice("observation")]
if self.response_template:

View File

@@ -236,30 +236,14 @@ def process_tool_results(agent: Agent, result: Any) -> Any:
def save_last_messages(agent: Agent) -> None:
"""Save the last messages from agent executor.
Sanitizes messages to be compatible with TaskOutput's LLMMessage type,
which only accepts 'user', 'assistant', 'system' roles and requires
content to be a string or list (not None).
Args:
agent: The agent instance.
"""
if not agent.agent_executor or not hasattr(agent.agent_executor, "messages"):
agent._last_messages = []
return
sanitized_messages = []
for msg in agent.agent_executor.messages:
role = msg.get("role", "")
# Only include messages with valid LLMMessage roles
if role not in ("user", "assistant", "system"):
continue
# Ensure content is not None (can happen with tool call assistant messages)
content = msg.get("content")
if content is None:
content = ""
sanitized_messages.append({"role": role, "content": content})
agent._last_messages = sanitized_messages
agent._last_messages = (
agent.agent_executor.messages.copy()
if agent.agent_executor and hasattr(agent.agent_executor, "messages")
else []
)
def prepare_tools(

View File

@@ -30,7 +30,6 @@ from crewai.hooks.llm_hooks import (
)
from crewai.utilities.agent_utils import (
aget_llm_response,
convert_tools_to_openai_schema,
enforce_rpm_limit,
format_message_for_llm,
get_llm_response,
@@ -216,33 +215,6 @@ class CrewAgentExecutor(CrewAgentExecutorMixin):
def _invoke_loop(self) -> AgentFinish:
"""Execute agent loop until completion.
Checks if the LLM supports native function calling and uses that
approach if available, otherwise falls back to the ReAct text pattern.
Returns:
Final answer from the agent.
"""
# Check if model supports native function calling
use_native_tools = (
hasattr(self.llm, "supports_function_calling")
and callable(getattr(self.llm, "supports_function_calling", None))
and self.llm.supports_function_calling()
and self.original_tools
)
if use_native_tools:
return self._invoke_loop_native_tools()
# Fall back to ReAct text-based pattern
return self._invoke_loop_react()
def _invoke_loop_react(self) -> AgentFinish:
"""Execute agent loop using ReAct text-based pattern.
This is the traditional approach where tool definitions are embedded
in the prompt and the LLM outputs Action/Action Input text that is
parsed to execute tools.
Returns:
Final answer from the agent.
"""
@@ -272,10 +244,6 @@ class CrewAgentExecutor(CrewAgentExecutorMixin):
response_model=self.response_model,
executor_context=self,
)
print("--------------------------------")
print("get_llm_response answer", answer)
print("--------------------------------")
# breakpoint()
if self.response_model is not None:
try:
self.response_model.model_validate_json(answer)
@@ -365,338 +333,6 @@ class CrewAgentExecutor(CrewAgentExecutorMixin):
self._show_logs(formatted_answer)
return formatted_answer
def _invoke_loop_native_tools(self) -> AgentFinish:
"""Execute agent loop using native function calling.
This method uses the LLM's native tool/function calling capability
instead of the text-based ReAct pattern. The LLM directly returns
structured tool calls which are executed and results fed back.
Returns:
Final answer from the agent.
"""
print("--------------------------------")
print("invoke_loop_native_tools")
print("--------------------------------")
# Convert tools to OpenAI schema format
if not self.original_tools:
# No tools available, fall back to simple LLM call
return self._invoke_loop_native_no_tools()
openai_tools, available_functions = convert_tools_to_openai_schema(
self.original_tools
)
while True:
try:
if has_reached_max_iterations(self.iterations, self.max_iter):
formatted_answer = handle_max_iterations_exceeded(
None,
printer=self._printer,
i18n=self._i18n,
messages=self.messages,
llm=self.llm,
callbacks=self.callbacks,
)
self._show_logs(formatted_answer)
return formatted_answer
enforce_rpm_limit(self.request_within_rpm_limit)
# Debug: Show messages being sent to LLM
print("--------------------------------")
print(f"Messages count: {len(self.messages)}")
for i, msg in enumerate(self.messages):
role = msg.get("role", "unknown")
content = msg.get("content", "")
if content:
preview = (
content[:200] + "..." if len(content) > 200 else content
)
else:
preview = "(no content)"
print(f" [{i}] {role}: {preview}")
print("--------------------------------")
# Call LLM with native tools
# Pass available_functions=None so the LLM returns tool_calls
# without executing them. The executor handles tool execution
# via _handle_native_tool_calls to properly manage message history.
answer = get_llm_response(
llm=self.llm,
messages=self.messages,
callbacks=self.callbacks,
printer=self._printer,
tools=openai_tools,
available_functions=None,
from_task=self.task,
from_agent=self.agent,
response_model=self.response_model,
executor_context=self,
)
print("--------------------------------")
print("invoke_loop_native_tools answer", answer)
print("--------------------------------")
# print("get_llm_response answer", answer[:500] + "...")
# Check if the response is a list of tool calls
if (
isinstance(answer, list)
and answer
and self._is_tool_call_list(answer)
):
# Handle tool calls - execute tools and add results to messages
self._handle_native_tool_calls(answer, available_functions)
# Continue loop to let LLM analyze results and decide next steps
continue
# Text or other response - handle as potential final answer
if isinstance(answer, str):
# Text response - this is the final answer
formatted_answer = AgentFinish(
thought="",
output=answer,
text=answer,
)
self._invoke_step_callback(formatted_answer)
self._append_message(answer) # Save final answer to messages
self._show_logs(formatted_answer)
return formatted_answer
# Unexpected response type, treat as final answer
formatted_answer = AgentFinish(
thought="",
output=str(answer),
text=str(answer),
)
self._invoke_step_callback(formatted_answer)
self._append_message(str(answer)) # Save final answer to messages
self._show_logs(formatted_answer)
return formatted_answer
except Exception as e:
if e.__class__.__module__.startswith("litellm"):
raise e
if is_context_length_exceeded(e):
handle_context_length(
respect_context_window=self.respect_context_window,
printer=self._printer,
messages=self.messages,
llm=self.llm,
callbacks=self.callbacks,
i18n=self._i18n,
)
continue
handle_unknown_error(self._printer, e)
raise e
finally:
self.iterations += 1
def _invoke_loop_native_no_tools(self) -> AgentFinish:
"""Execute a simple LLM call when no tools are available.
Returns:
Final answer from the agent.
"""
enforce_rpm_limit(self.request_within_rpm_limit)
answer = get_llm_response(
llm=self.llm,
messages=self.messages,
callbacks=self.callbacks,
printer=self._printer,
from_task=self.task,
from_agent=self.agent,
response_model=self.response_model,
executor_context=self,
)
formatted_answer = AgentFinish(
thought="",
output=str(answer),
text=str(answer),
)
self._show_logs(formatted_answer)
return formatted_answer
def _is_tool_call_list(self, response: list[Any]) -> bool:
"""Check if a response is a list of tool calls.
Args:
response: The response to check.
Returns:
True if the response appears to be a list of tool calls.
"""
if not response:
return False
first_item = response[0]
# OpenAI-style
if hasattr(first_item, "function") or (
isinstance(first_item, dict) and "function" in first_item
):
return True
# Anthropic-style
if (
hasattr(first_item, "type")
and getattr(first_item, "type", None) == "tool_use"
):
return True
if hasattr(first_item, "name") and hasattr(first_item, "input"):
return True
# Gemini-style
if hasattr(first_item, "function_call") and first_item.function_call:
return True
return False
def _handle_native_tool_calls(
self,
tool_calls: list[Any],
available_functions: dict[str, Callable[..., Any]],
) -> None:
"""Handle a single native tool call from the LLM.
Executes only the FIRST tool call and appends the result to message history.
This enables sequential tool execution with reflection after each tool,
allowing the LLM to reason about results before deciding on next steps.
Args:
tool_calls: List of tool calls from the LLM (only first is processed).
available_functions: Dict mapping function names to callables.
"""
from datetime import datetime
import json
from crewai.events import crewai_event_bus
from crewai.events.types.tool_usage_events import (
ToolUsageFinishedEvent,
ToolUsageStartedEvent,
)
if not tool_calls:
return
# Only process the FIRST tool call for sequential execution with reflection
tool_call = tool_calls[0]
# Extract tool call info - handle OpenAI-style, Anthropic-style, and Gemini-style
if hasattr(tool_call, "function"):
# OpenAI-style: has .function.name and .function.arguments
call_id = getattr(tool_call, "id", f"call_{id(tool_call)}")
func_name = tool_call.function.name
func_args = tool_call.function.arguments
elif hasattr(tool_call, "function_call") and tool_call.function_call:
# Gemini-style: has .function_call.name and .function_call.args
call_id = f"call_{id(tool_call)}"
func_name = tool_call.function_call.name
func_args = (
dict(tool_call.function_call.args)
if tool_call.function_call.args
else {}
)
elif hasattr(tool_call, "name") and hasattr(tool_call, "input"):
# Anthropic format: has .name and .input (ToolUseBlock)
call_id = getattr(tool_call, "id", f"call_{id(tool_call)}")
func_name = tool_call.name
func_args = tool_call.input # Already a dict in Anthropic
elif isinstance(tool_call, dict):
call_id = tool_call.get("id", f"call_{id(tool_call)}")
func_info = tool_call.get("function", {})
func_name = func_info.get("name", "") or tool_call.get("name", "")
func_args = func_info.get("arguments", "{}") or tool_call.get("input", {})
else:
return
# Append assistant message with single tool call
assistant_message: LLMMessage = {
"role": "assistant",
"content": None,
"tool_calls": [
{
"id": call_id,
"type": "function",
"function": {
"name": func_name,
"arguments": func_args
if isinstance(func_args, str)
else json.dumps(func_args),
},
}
],
}
self.messages.append(assistant_message)
# Parse arguments for the single tool call
if isinstance(func_args, str):
try:
args_dict = json.loads(func_args)
except json.JSONDecodeError:
args_dict = {}
else:
args_dict = func_args
# Emit tool usage started event
started_at = datetime.now()
crewai_event_bus.emit(
self,
event=ToolUsageStartedEvent(
tool_name=func_name,
tool_args=args_dict,
from_agent=self.agent,
from_task=self.task,
),
)
# Execute the tool
print(f"Using Tool: {func_name}")
result = "Tool not found"
if func_name in available_functions:
try:
tool_func = available_functions[func_name]
result = tool_func(**args_dict)
if not isinstance(result, str):
result = str(result)
except Exception as e:
result = f"Error executing tool: {e}"
# Emit tool usage finished event
crewai_event_bus.emit(
self,
event=ToolUsageFinishedEvent(
output=result,
tool_name=func_name,
tool_args=args_dict,
from_agent=self.agent,
from_task=self.task,
started_at=started_at,
finished_at=datetime.now(),
),
)
# Append tool result message
tool_message: LLMMessage = {
"role": "tool",
"tool_call_id": call_id,
"content": result,
}
self.messages.append(tool_message)
# Log the tool execution
if self.agent and self.agent.verbose:
self._printer.print(
content=f"Tool {func_name} executed with result: {result[:200]}...",
color="green",
)
# Inject post-tool reasoning prompt to enforce analysis
reasoning_prompt = self._i18n.slice("post_tool_reasoning")
reasoning_message: LLMMessage = {
"role": "user",
"content": reasoning_prompt,
}
self.messages.append(reasoning_message)
async def ainvoke(self, inputs: dict[str, Any]) -> dict[str, Any]:
"""Execute the agent asynchronously with given inputs.
@@ -746,29 +382,6 @@ class CrewAgentExecutor(CrewAgentExecutorMixin):
async def _ainvoke_loop(self) -> AgentFinish:
"""Execute agent loop asynchronously until completion.
Checks if the LLM supports native function calling and uses that
approach if available, otherwise falls back to the ReAct text pattern.
Returns:
Final answer from the agent.
"""
# Check if model supports native function calling
use_native_tools = (
hasattr(self.llm, "supports_function_calling")
and callable(getattr(self.llm, "supports_function_calling", None))
and self.llm.supports_function_calling()
and self.original_tools
)
if use_native_tools:
return await self._ainvoke_loop_native_tools()
# Fall back to ReAct text-based pattern
return await self._ainvoke_loop_react()
async def _ainvoke_loop_react(self) -> AgentFinish:
"""Execute agent loop asynchronously using ReAct text-based pattern.
Returns:
Final answer from the agent.
"""
@@ -882,139 +495,6 @@ class CrewAgentExecutor(CrewAgentExecutorMixin):
self._show_logs(formatted_answer)
return formatted_answer
async def _ainvoke_loop_native_tools(self) -> AgentFinish:
"""Execute agent loop asynchronously using native function calling.
This method uses the LLM's native tool/function calling capability
instead of the text-based ReAct pattern.
Returns:
Final answer from the agent.
"""
# Convert tools to OpenAI schema format
if not self.original_tools:
return await self._ainvoke_loop_native_no_tools()
openai_tools, available_functions = convert_tools_to_openai_schema(
self.original_tools
)
while True:
try:
if has_reached_max_iterations(self.iterations, self.max_iter):
formatted_answer = handle_max_iterations_exceeded(
None,
printer=self._printer,
i18n=self._i18n,
messages=self.messages,
llm=self.llm,
callbacks=self.callbacks,
)
self._show_logs(formatted_answer)
return formatted_answer
enforce_rpm_limit(self.request_within_rpm_limit)
# Call LLM with native tools
# Pass available_functions=None so the LLM returns tool_calls
# without executing them. The executor handles tool execution
# via _handle_native_tool_calls to properly manage message history.
answer = await aget_llm_response(
llm=self.llm,
messages=self.messages,
callbacks=self.callbacks,
printer=self._printer,
tools=openai_tools,
available_functions=None,
from_task=self.task,
from_agent=self.agent,
response_model=self.response_model,
executor_context=self,
)
print("--------------------------------")
print("native llm completion answer", answer)
print("--------------------------------")
# Check if the response is a list of tool calls
if (
isinstance(answer, list)
and answer
and self._is_tool_call_list(answer)
):
# Handle tool calls - execute tools and add results to messages
self._handle_native_tool_calls(answer, available_functions)
# Continue loop to let LLM analyze results and decide next steps
continue
# Text or other response - handle as potential final answer
if isinstance(answer, str):
# Text response - this is the final answer
formatted_answer = AgentFinish(
thought="",
output=answer,
text=answer,
)
self._invoke_step_callback(formatted_answer)
self._append_message(answer) # Save final answer to messages
self._show_logs(formatted_answer)
return formatted_answer
# Unexpected response type, treat as final answer
formatted_answer = AgentFinish(
thought="",
output=str(answer),
text=str(answer),
)
self._invoke_step_callback(formatted_answer)
self._append_message(str(answer)) # Save final answer to messages
self._show_logs(formatted_answer)
return formatted_answer
except Exception as e:
if e.__class__.__module__.startswith("litellm"):
raise e
if is_context_length_exceeded(e):
handle_context_length(
respect_context_window=self.respect_context_window,
printer=self._printer,
messages=self.messages,
llm=self.llm,
callbacks=self.callbacks,
i18n=self._i18n,
)
continue
handle_unknown_error(self._printer, e)
raise e
finally:
self.iterations += 1
async def _ainvoke_loop_native_no_tools(self) -> AgentFinish:
"""Execute a simple async LLM call when no tools are available.
Returns:
Final answer from the agent.
"""
enforce_rpm_limit(self.request_within_rpm_limit)
answer = await aget_llm_response(
llm=self.llm,
messages=self.messages,
callbacks=self.callbacks,
printer=self._printer,
from_task=self.task,
from_agent=self.agent,
response_model=self.response_model,
executor_context=self,
)
formatted_answer = AgentFinish(
thought="",
output=str(answer),
text=str(answer),
)
self._show_logs(formatted_answer)
return formatted_answer
def _handle_agent_action(
self, formatted_answer: AgentAction, tool_result: ToolResult
) -> AgentAction | AgentFinish:

View File

@@ -378,12 +378,6 @@ class EventListener(BaseEventListener):
self.formatter.handle_llm_tool_usage_finished(
event.tool_name,
)
else:
self.formatter.handle_tool_usage_finished(
event.tool_name,
event.output,
getattr(event, "run_attempts", None),
)
@crewai_event_bus.on(ToolUsageErrorEvent)
def on_tool_usage_error(source: Any, event: ToolUsageErrorEvent) -> None:

View File

@@ -1,19 +1,28 @@
from crewai.events.types.a2a_events import (
A2AAgentCardFetchedEvent,
A2AArtifactReceivedEvent,
A2AAuthenticationFailedEvent,
A2AConnectionErrorEvent,
A2AConversationCompletedEvent,
A2AConversationStartedEvent,
A2ADelegationCompletedEvent,
A2ADelegationStartedEvent,
A2AMessageSentEvent,
A2AParallelDelegationCompletedEvent,
A2AParallelDelegationStartedEvent,
A2APollingStartedEvent,
A2APollingStatusEvent,
A2APushNotificationReceivedEvent,
A2APushNotificationRegisteredEvent,
A2APushNotificationSentEvent,
A2APushNotificationTimeoutEvent,
A2AResponseReceivedEvent,
A2AServerTaskCanceledEvent,
A2AServerTaskCompletedEvent,
A2AServerTaskFailedEvent,
A2AServerTaskStartedEvent,
A2AStreamingChunkEvent,
A2AStreamingStartedEvent,
)
from crewai.events.types.agent_events import (
AgentExecutionCompletedEvent,
@@ -93,7 +102,11 @@ from crewai.events.types.tool_usage_events import (
EventTypes = (
A2AConversationCompletedEvent
A2AAgentCardFetchedEvent
| A2AArtifactReceivedEvent
| A2AAuthenticationFailedEvent
| A2AConnectionErrorEvent
| A2AConversationCompletedEvent
| A2AConversationStartedEvent
| A2ADelegationCompletedEvent
| A2ADelegationStartedEvent
@@ -102,12 +115,17 @@ EventTypes = (
| A2APollingStatusEvent
| A2APushNotificationReceivedEvent
| A2APushNotificationRegisteredEvent
| A2APushNotificationSentEvent
| A2APushNotificationTimeoutEvent
| A2AResponseReceivedEvent
| A2AServerTaskCanceledEvent
| A2AServerTaskCompletedEvent
| A2AServerTaskFailedEvent
| A2AServerTaskStartedEvent
| A2AStreamingChunkEvent
| A2AStreamingStartedEvent
| A2AParallelDelegationStartedEvent
| A2AParallelDelegationCompletedEvent
| CrewKickoffStartedEvent
| CrewKickoffCompletedEvent
| CrewKickoffFailedEvent

View File

@@ -1,7 +1,7 @@
"""Trace collection listener for orchestrating trace collection."""
import os
from typing import Any, ClassVar, cast
from typing import Any, ClassVar
import uuid
from typing_extensions import Self
@@ -18,6 +18,32 @@ from crewai.events.listeners.tracing.types import TraceEvent
from crewai.events.listeners.tracing.utils import (
safe_serialize_to_dict,
)
from crewai.events.types.a2a_events import (
A2AAgentCardFetchedEvent,
A2AArtifactReceivedEvent,
A2AAuthenticationFailedEvent,
A2AConnectionErrorEvent,
A2AConversationCompletedEvent,
A2AConversationStartedEvent,
A2ADelegationCompletedEvent,
A2ADelegationStartedEvent,
A2AMessageSentEvent,
A2AParallelDelegationCompletedEvent,
A2AParallelDelegationStartedEvent,
A2APollingStartedEvent,
A2APollingStatusEvent,
A2APushNotificationReceivedEvent,
A2APushNotificationRegisteredEvent,
A2APushNotificationSentEvent,
A2APushNotificationTimeoutEvent,
A2AResponseReceivedEvent,
A2AServerTaskCanceledEvent,
A2AServerTaskCompletedEvent,
A2AServerTaskFailedEvent,
A2AServerTaskStartedEvent,
A2AStreamingChunkEvent,
A2AStreamingStartedEvent,
)
from crewai.events.types.agent_events import (
AgentExecutionCompletedEvent,
AgentExecutionErrorEvent,
@@ -105,7 +131,7 @@ class TraceCollectionListener(BaseEventListener):
"""Create or return singleton instance."""
if cls._instance is None:
cls._instance = super().__new__(cls)
return cast(Self, cls._instance)
return cls._instance
def __init__(
self,
@@ -160,6 +186,7 @@ class TraceCollectionListener(BaseEventListener):
self._register_flow_event_handlers(crewai_event_bus)
self._register_context_event_handlers(crewai_event_bus)
self._register_action_event_handlers(crewai_event_bus)
self._register_a2a_event_handlers(crewai_event_bus)
self._register_system_event_handlers(crewai_event_bus)
self._listeners_setup = True
@@ -439,6 +466,147 @@ class TraceCollectionListener(BaseEventListener):
) -> None:
self._handle_action_event("knowledge_query_failed", source, event)
def _register_a2a_event_handlers(self, event_bus: CrewAIEventsBus) -> None:
"""Register handlers for A2A (Agent-to-Agent) events."""
@event_bus.on(A2ADelegationStartedEvent)
def on_a2a_delegation_started(
source: Any, event: A2ADelegationStartedEvent
) -> None:
self._handle_action_event("a2a_delegation_started", source, event)
@event_bus.on(A2ADelegationCompletedEvent)
def on_a2a_delegation_completed(
source: Any, event: A2ADelegationCompletedEvent
) -> None:
self._handle_action_event("a2a_delegation_completed", source, event)
@event_bus.on(A2AConversationStartedEvent)
def on_a2a_conversation_started(
source: Any, event: A2AConversationStartedEvent
) -> None:
self._handle_action_event("a2a_conversation_started", source, event)
@event_bus.on(A2AMessageSentEvent)
def on_a2a_message_sent(source: Any, event: A2AMessageSentEvent) -> None:
self._handle_action_event("a2a_message_sent", source, event)
@event_bus.on(A2AResponseReceivedEvent)
def on_a2a_response_received(
source: Any, event: A2AResponseReceivedEvent
) -> None:
self._handle_action_event("a2a_response_received", source, event)
@event_bus.on(A2AConversationCompletedEvent)
def on_a2a_conversation_completed(
source: Any, event: A2AConversationCompletedEvent
) -> None:
self._handle_action_event("a2a_conversation_completed", source, event)
@event_bus.on(A2APollingStartedEvent)
def on_a2a_polling_started(source: Any, event: A2APollingStartedEvent) -> None:
self._handle_action_event("a2a_polling_started", source, event)
@event_bus.on(A2APollingStatusEvent)
def on_a2a_polling_status(source: Any, event: A2APollingStatusEvent) -> None:
self._handle_action_event("a2a_polling_status", source, event)
@event_bus.on(A2APushNotificationRegisteredEvent)
def on_a2a_push_notification_registered(
source: Any, event: A2APushNotificationRegisteredEvent
) -> None:
self._handle_action_event("a2a_push_notification_registered", source, event)
@event_bus.on(A2APushNotificationReceivedEvent)
def on_a2a_push_notification_received(
source: Any, event: A2APushNotificationReceivedEvent
) -> None:
self._handle_action_event("a2a_push_notification_received", source, event)
@event_bus.on(A2APushNotificationSentEvent)
def on_a2a_push_notification_sent(
source: Any, event: A2APushNotificationSentEvent
) -> None:
self._handle_action_event("a2a_push_notification_sent", source, event)
@event_bus.on(A2APushNotificationTimeoutEvent)
def on_a2a_push_notification_timeout(
source: Any, event: A2APushNotificationTimeoutEvent
) -> None:
self._handle_action_event("a2a_push_notification_timeout", source, event)
@event_bus.on(A2AStreamingStartedEvent)
def on_a2a_streaming_started(
source: Any, event: A2AStreamingStartedEvent
) -> None:
self._handle_action_event("a2a_streaming_started", source, event)
@event_bus.on(A2AStreamingChunkEvent)
def on_a2a_streaming_chunk(source: Any, event: A2AStreamingChunkEvent) -> None:
self._handle_action_event("a2a_streaming_chunk", source, event)
@event_bus.on(A2AAgentCardFetchedEvent)
def on_a2a_agent_card_fetched(
source: Any, event: A2AAgentCardFetchedEvent
) -> None:
self._handle_action_event("a2a_agent_card_fetched", source, event)
@event_bus.on(A2AAuthenticationFailedEvent)
def on_a2a_authentication_failed(
source: Any, event: A2AAuthenticationFailedEvent
) -> None:
self._handle_action_event("a2a_authentication_failed", source, event)
@event_bus.on(A2AArtifactReceivedEvent)
def on_a2a_artifact_received(
source: Any, event: A2AArtifactReceivedEvent
) -> None:
self._handle_action_event("a2a_artifact_received", source, event)
@event_bus.on(A2AConnectionErrorEvent)
def on_a2a_connection_error(
source: Any, event: A2AConnectionErrorEvent
) -> None:
self._handle_action_event("a2a_connection_error", source, event)
@event_bus.on(A2AServerTaskStartedEvent)
def on_a2a_server_task_started(
source: Any, event: A2AServerTaskStartedEvent
) -> None:
self._handle_action_event("a2a_server_task_started", source, event)
@event_bus.on(A2AServerTaskCompletedEvent)
def on_a2a_server_task_completed(
source: Any, event: A2AServerTaskCompletedEvent
) -> None:
self._handle_action_event("a2a_server_task_completed", source, event)
@event_bus.on(A2AServerTaskCanceledEvent)
def on_a2a_server_task_canceled(
source: Any, event: A2AServerTaskCanceledEvent
) -> None:
self._handle_action_event("a2a_server_task_canceled", source, event)
@event_bus.on(A2AServerTaskFailedEvent)
def on_a2a_server_task_failed(
source: Any, event: A2AServerTaskFailedEvent
) -> None:
self._handle_action_event("a2a_server_task_failed", source, event)
@event_bus.on(A2AParallelDelegationStartedEvent)
def on_a2a_parallel_delegation_started(
source: Any, event: A2AParallelDelegationStartedEvent
) -> None:
self._handle_action_event("a2a_parallel_delegation_started", source, event)
@event_bus.on(A2AParallelDelegationCompletedEvent)
def on_a2a_parallel_delegation_completed(
source: Any, event: A2AParallelDelegationCompletedEvent
) -> None:
self._handle_action_event(
"a2a_parallel_delegation_completed", source, event
)
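
The same subscription pattern is available to user code outside the trace listener. A hedged sketch (the handler body is illustrative and not part of this change) of listening for one of the new events:

from crewai.events import crewai_event_bus
from crewai.events.types.a2a_events import A2AAgentCardFetchedEvent

# Illustrative standalone listener; prints two of the new metadata fields.
@crewai_event_bus.on(A2AAgentCardFetchedEvent)
def log_agent_card(source, event: A2AAgentCardFetchedEvent) -> None:
    print(f"fetched card for {event.a2a_agent_name} ({event.fetch_time_ms} ms)")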
def _register_system_event_handlers(self, event_bus: CrewAIEventsBus) -> None:
"""Register handlers for system signal events (SIGTERM, SIGINT, etc.)."""
@@ -570,10 +738,15 @@ class TraceCollectionListener(BaseEventListener):
if event_type not in self.complex_events:
return safe_serialize_to_dict(event)
if event_type == "task_started":
task_name = event.task.name or event.task.description
task_display_name = (
task_name[:80] + "..." if len(task_name) > 80 else task_name
)
return {
"task_description": event.task.description,
"expected_output": event.task.expected_output,
"task_name": event.task.name or event.task.description,
"task_name": task_name,
"task_display_name": task_display_name,
"context": event.context,
"agent_role": source.agent.role,
"task_id": str(event.task.id),

View File

@@ -4,68 +4,120 @@ This module defines events emitted during A2A protocol delegation,
including both single-turn and multiturn conversation flows.
"""
from __future__ import annotations
from typing import Any, Literal
from pydantic import model_validator
from crewai.events.base_events import BaseEvent
class A2AEventBase(BaseEvent):
"""Base class for A2A events with task/agent context."""
from_task: Any | None = None
from_agent: Any | None = None
from_task: Any = None
from_agent: Any = None
def __init__(self, **data: Any) -> None:
"""Initialize A2A event, extracting task and agent metadata."""
if data.get("from_task"):
task = data["from_task"]
@model_validator(mode="before")
@classmethod
def extract_task_and_agent_metadata(cls, data: dict[str, Any]) -> dict[str, Any]:
"""Extract task and agent metadata before validation."""
if task := data.get("from_task"):
data["task_id"] = str(task.id)
data["task_name"] = task.name or task.description
data.setdefault("source_fingerprint", str(task.id))
data.setdefault("source_type", "task")
data.setdefault(
"fingerprint_metadata",
{
"task_id": str(task.id),
"task_name": task.name or task.description,
},
)
data["from_task"] = None
if data.get("from_agent"):
agent = data["from_agent"]
if agent := data.get("from_agent"):
data["agent_id"] = str(agent.id)
data["agent_role"] = agent.role
data.setdefault("source_fingerprint", str(agent.id))
data.setdefault("source_type", "agent")
data.setdefault(
"fingerprint_metadata",
{
"agent_id": str(agent.id),
"agent_role": agent.role,
},
)
data["from_agent"] = None
super().__init__(**data)
return data
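
A small illustration of the validator above, constructing one of the event classes defined below. The Task stand-in is hypothetical and carries only the attributes the validator actually reads (id, name, description); the assertions restate what the code does, assuming the derived fields are defined on BaseEvent: task metadata is copied into the event payload and the raw object is cleared.

from types import SimpleNamespace

# Hypothetical stand-in for a CrewAI Task with just the attributes read above.
fake_task = SimpleNamespace(id="task-1", name="Summarize", description="Summarize the report")

event = A2ADelegationStartedEvent(
    endpoint="https://agent.example.com",
    task_description="Summarize the report",
    agent_id="remote-agent",
    from_task=fake_task,
)
assert event.from_task is None        # raw Task object is dropped before validation
assert event.source_type == "task"    # derived metadata survives on the event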
class A2ADelegationStartedEvent(A2AEventBase):
"""Event emitted when A2A delegation starts.
Attributes:
endpoint: A2A agent endpoint URL (AgentCard URL)
task_description: Task being delegated to the A2A agent
agent_id: A2A agent identifier
is_multiturn: Whether this is part of a multiturn conversation
turn_number: Current turn number (1-indexed, 1 for single-turn)
endpoint: A2A agent endpoint URL (AgentCard URL).
task_description: Task being delegated to the A2A agent.
agent_id: A2A agent identifier.
context_id: A2A context ID grouping related tasks.
is_multiturn: Whether this is part of a multiturn conversation.
turn_number: Current turn number (1-indexed, 1 for single-turn).
a2a_agent_name: Name of the A2A agent from agent card.
agent_card: Full A2A agent card metadata.
protocol_version: A2A protocol version being used.
provider: Agent provider/organization info from agent card.
skill_id: ID of the specific skill being invoked.
metadata: Custom A2A metadata key-value pairs.
extensions: List of A2A extension URIs in use.
"""
type: str = "a2a_delegation_started"
endpoint: str
task_description: str
agent_id: str
context_id: str | None = None
is_multiturn: bool = False
turn_number: int = 1
a2a_agent_name: str | None = None
agent_card: dict[str, Any] | None = None
protocol_version: str | None = None
provider: dict[str, Any] | None = None
skill_id: str | None = None
metadata: dict[str, Any] | None = None
extensions: list[str] | None = None
class A2ADelegationCompletedEvent(A2AEventBase):
"""Event emitted when A2A delegation completes.
Attributes:
status: Completion status (completed, input_required, failed, etc.)
result: Result message if status is completed
error: Error/response message (error for failed, response for input_required)
is_multiturn: Whether this is part of a multiturn conversation
status: Completion status (completed, input_required, failed, etc.).
result: Result message if status is completed.
error: Error/response message (error for failed, response for input_required).
context_id: A2A context ID grouping related tasks.
is_multiturn: Whether this is part of a multiturn conversation.
endpoint: A2A agent endpoint URL.
a2a_agent_name: Name of the A2A agent from agent card.
agent_card: Full A2A agent card metadata.
provider: Agent provider/organization info from agent card.
metadata: Custom A2A metadata key-value pairs.
extensions: List of A2A extension URIs in use.
"""
type: str = "a2a_delegation_completed"
status: str
result: str | None = None
error: str | None = None
context_id: str | None = None
is_multiturn: bool = False
endpoint: str | None = None
a2a_agent_name: str | None = None
agent_card: dict[str, Any] | None = None
provider: dict[str, Any] | None = None
metadata: dict[str, Any] | None = None
extensions: list[str] | None = None
class A2AConversationStartedEvent(A2AEventBase):
@@ -75,51 +127,95 @@ class A2AConversationStartedEvent(A2AEventBase):
before the first message exchange.
Attributes:
agent_id: A2A agent identifier
endpoint: A2A agent endpoint URL
a2a_agent_name: Name of the A2A agent from agent card
agent_id: A2A agent identifier.
endpoint: A2A agent endpoint URL.
context_id: A2A context ID grouping related tasks.
a2a_agent_name: Name of the A2A agent from agent card.
agent_card: Full A2A agent card metadata.
protocol_version: A2A protocol version being used.
provider: Agent provider/organization info from agent card.
skill_id: ID of the specific skill being invoked.
reference_task_ids: Related task IDs for context.
metadata: Custom A2A metadata key-value pairs.
extensions: List of A2A extension URIs in use.
"""
type: str = "a2a_conversation_started"
agent_id: str
endpoint: str
context_id: str | None = None
a2a_agent_name: str | None = None
agent_card: dict[str, Any] | None = None
protocol_version: str | None = None
provider: dict[str, Any] | None = None
skill_id: str | None = None
reference_task_ids: list[str] | None = None
metadata: dict[str, Any] | None = None
extensions: list[str] | None = None
class A2AMessageSentEvent(A2AEventBase):
"""Event emitted when a message is sent to the A2A agent.
Attributes:
message: Message content sent to the A2A agent
turn_number: Current turn number (1-indexed)
is_multiturn: Whether this is part of a multiturn conversation
agent_role: Role of the CrewAI agent sending the message
message: Message content sent to the A2A agent.
turn_number: Current turn number (1-indexed).
context_id: A2A context ID grouping related tasks.
message_id: Unique A2A message identifier.
is_multiturn: Whether this is part of a multiturn conversation.
agent_role: Role of the CrewAI agent sending the message.
endpoint: A2A agent endpoint URL.
a2a_agent_name: Name of the A2A agent from agent card.
skill_id: ID of the specific skill being invoked.
metadata: Custom A2A metadata key-value pairs.
extensions: List of A2A extension URIs in use.
"""
type: str = "a2a_message_sent"
message: str
turn_number: int
context_id: str | None = None
message_id: str | None = None
is_multiturn: bool = False
agent_role: str | None = None
endpoint: str | None = None
a2a_agent_name: str | None = None
skill_id: str | None = None
metadata: dict[str, Any] | None = None
extensions: list[str] | None = None
class A2AResponseReceivedEvent(A2AEventBase):
"""Event emitted when a response is received from the A2A agent.
Attributes:
response: Response content from the A2A agent
turn_number: Current turn number (1-indexed)
is_multiturn: Whether this is part of a multiturn conversation
status: Response status (input_required, completed, etc.)
agent_role: Role of the CrewAI agent (for display)
response: Response content from the A2A agent.
turn_number: Current turn number (1-indexed).
context_id: A2A context ID grouping related tasks.
message_id: Unique A2A message identifier.
is_multiturn: Whether this is part of a multiturn conversation.
status: Response status (input_required, completed, etc.).
final: Whether this is the final response in the stream.
agent_role: Role of the CrewAI agent (for display).
endpoint: A2A agent endpoint URL.
a2a_agent_name: Name of the A2A agent from agent card.
metadata: Custom A2A metadata key-value pairs.
extensions: List of A2A extension URIs in use.
"""
type: str = "a2a_response_received"
response: str
turn_number: int
context_id: str | None = None
message_id: str | None = None
is_multiturn: bool = False
status: str
final: bool = False
agent_role: str | None = None
endpoint: str | None = None
a2a_agent_name: str | None = None
metadata: dict[str, Any] | None = None
extensions: list[str] | None = None
class A2AConversationCompletedEvent(A2AEventBase):
@@ -128,119 +224,433 @@ class A2AConversationCompletedEvent(A2AEventBase):
This is emitted once at the end of a multiturn conversation.
Attributes:
status: Final status (completed, failed, etc.)
final_result: Final result if completed successfully
error: Error message if failed
total_turns: Total number of turns in the conversation
status: Final status (completed, failed, etc.).
final_result: Final result if completed successfully.
error: Error message if failed.
context_id: A2A context ID grouping related tasks.
total_turns: Total number of turns in the conversation.
endpoint: A2A agent endpoint URL.
a2a_agent_name: Name of the A2A agent from agent card.
agent_card: Full A2A agent card metadata.
reference_task_ids: Related task IDs for context.
metadata: Custom A2A metadata key-value pairs.
extensions: List of A2A extension URIs in use.
"""
type: str = "a2a_conversation_completed"
status: Literal["completed", "failed"]
final_result: str | None = None
error: str | None = None
context_id: str | None = None
total_turns: int
endpoint: str | None = None
a2a_agent_name: str | None = None
agent_card: dict[str, Any] | None = None
reference_task_ids: list[str] | None = None
metadata: dict[str, Any] | None = None
extensions: list[str] | None = None
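A hedged sketch of a listener for this event, using the decorator registration style exposed by CrewAI's event bus; the event import path is assumed.

# Illustrative listener sketch; the a2a_events import path is an assumption.
from crewai.events.event_bus import crewai_event_bus
from crewai.events.types.a2a_events import A2AConversationCompletedEvent  # assumed path

@crewai_event_bus.on(A2AConversationCompletedEvent)
def log_conversation_completed(source, event: A2AConversationCompletedEvent) -> None:
    """Log the outcome of a multiturn A2A conversation."""
    if event.status == "completed":
        print(f"[{event.a2a_agent_name}] finished in {event.total_turns} turns")
    else:
        print(f"[{event.a2a_agent_name}] failed after {event.total_turns} turns: {event.error}")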
class A2APollingStartedEvent(A2AEventBase):
"""Event emitted when polling mode begins for A2A delegation.
Attributes:
task_id: A2A task ID being polled
polling_interval: Seconds between poll attempts
endpoint: A2A agent endpoint URL
task_id: A2A task ID being polled.
context_id: A2A context ID grouping related tasks.
polling_interval: Seconds between poll attempts.
endpoint: A2A agent endpoint URL.
a2a_agent_name: Name of the A2A agent from agent card.
metadata: Custom A2A metadata key-value pairs.
"""
type: str = "a2a_polling_started"
task_id: str
context_id: str | None = None
polling_interval: float
endpoint: str
a2a_agent_name: str | None = None
metadata: dict[str, Any] | None = None
class A2APollingStatusEvent(A2AEventBase):
"""Event emitted on each polling iteration.
Attributes:
task_id: A2A task ID being polled
state: Current task state from remote agent
elapsed_seconds: Time since polling started
poll_count: Number of polls completed
task_id: A2A task ID being polled.
context_id: A2A context ID grouping related tasks.
state: Current task state from remote agent.
elapsed_seconds: Time since polling started.
poll_count: Number of polls completed.
endpoint: A2A agent endpoint URL.
a2a_agent_name: Name of the A2A agent from agent card.
metadata: Custom A2A metadata key-value pairs.
"""
type: str = "a2a_polling_status"
task_id: str
context_id: str | None = None
state: str
elapsed_seconds: float
poll_count: int
endpoint: str | None = None
a2a_agent_name: str | None = None
metadata: dict[str, Any] | None = None
class A2APushNotificationRegisteredEvent(A2AEventBase):
"""Event emitted when push notification callback is registered.
Attributes:
task_id: A2A task ID for which callback is registered
callback_url: URL where agent will send push notifications
task_id: A2A task ID for which callback is registered.
context_id: A2A context ID grouping related tasks.
callback_url: URL where agent will send push notifications.
endpoint: A2A agent endpoint URL.
a2a_agent_name: Name of the A2A agent from agent card.
metadata: Custom A2A metadata key-value pairs.
"""
type: str = "a2a_push_notification_registered"
task_id: str
context_id: str | None = None
callback_url: str
endpoint: str | None = None
a2a_agent_name: str | None = None
metadata: dict[str, Any] | None = None
class A2APushNotificationReceivedEvent(A2AEventBase):
"""Event emitted when a push notification is received.
This event should be emitted by the user's webhook handler when it receives
a push notification from the remote A2A agent, before calling
`result_store.store_result()`.
Attributes:
task_id: A2A task ID from the notification
state: Current task state from the notification
task_id: A2A task ID from the notification.
context_id: A2A context ID grouping related tasks.
state: Current task state from the notification.
endpoint: A2A agent endpoint URL.
a2a_agent_name: Name of the A2A agent from agent card.
metadata: Custom A2A metadata key-value pairs.
"""
type: str = "a2a_push_notification_received"
task_id: str
context_id: str | None = None
state: str
endpoint: str | None = None
a2a_agent_name: str | None = None
metadata: dict[str, Any] | None = None
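Following the docstring above, a sketch of a user webhook handler that emits the event before persisting the result. The payload keys, the argument passed to store_result, and the handler wiring are all assumptions for illustration.

# Sketch of a user webhook handler; payload shape and result_store wiring are assumptions.
from crewai.events.event_bus import crewai_event_bus
from crewai.events.types.a2a_events import A2APushNotificationReceivedEvent  # assumed path

def handle_a2a_push(payload: dict, result_store) -> None:
    """Emit the received event, then persist the result as the docstring prescribes."""
    crewai_event_bus.emit(
        None,  # source object; None used here purely for brevity
        event=A2APushNotificationReceivedEvent(
            task_id=payload["taskId"],          # assumed payload key
            context_id=payload.get("contextId"),
            state=payload["status"]["state"],   # assumed payload key
        ),
    )
    result_store.store_result(payload)  # argument is illustrative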
class A2APushNotificationSentEvent(A2AEventBase):
"""Event emitted when a push notification is sent to a callback URL.
Emitted by the A2A server when it sends a task status update to the
client's registered push notification callback URL.
Attributes:
task_id: A2A task ID being notified.
context_id: A2A context ID grouping related tasks.
callback_url: URL the notification was sent to.
state: Task state being reported.
success: Whether the notification was successfully delivered.
error: Error message if delivery failed.
metadata: Custom A2A metadata key-value pairs.
"""
type: str = "a2a_push_notification_sent"
task_id: str
context_id: str | None = None
callback_url: str
state: str
success: bool = True
error: str | None = None
metadata: dict[str, Any] | None = None
class A2APushNotificationTimeoutEvent(A2AEventBase):
"""Event emitted when push notification wait times out.
Attributes:
task_id: A2A task ID that timed out
timeout_seconds: Timeout duration in seconds
task_id: A2A task ID that timed out.
context_id: A2A context ID grouping related tasks.
timeout_seconds: Timeout duration in seconds.
endpoint: A2A agent endpoint URL.
a2a_agent_name: Name of the A2A agent from agent card.
metadata: Custom A2A metadata key-value pairs.
"""
type: str = "a2a_push_notification_timeout"
task_id: str
context_id: str | None = None
timeout_seconds: float
endpoint: str | None = None
a2a_agent_name: str | None = None
metadata: dict[str, Any] | None = None
class A2AStreamingStartedEvent(A2AEventBase):
"""Event emitted when streaming mode begins for A2A delegation.
Attributes:
task_id: A2A task ID for the streaming session.
context_id: A2A context ID grouping related tasks.
endpoint: A2A agent endpoint URL.
a2a_agent_name: Name of the A2A agent from agent card.
turn_number: Current turn number (1-indexed).
is_multiturn: Whether this is part of a multiturn conversation.
agent_role: Role of the CrewAI agent.
metadata: Custom A2A metadata key-value pairs.
extensions: List of A2A extension URIs in use.
"""
type: str = "a2a_streaming_started"
task_id: str | None = None
context_id: str | None = None
endpoint: str
a2a_agent_name: str | None = None
turn_number: int = 1
is_multiturn: bool = False
agent_role: str | None = None
metadata: dict[str, Any] | None = None
extensions: list[str] | None = None
class A2AStreamingChunkEvent(A2AEventBase):
"""Event emitted when a streaming chunk is received.
Attributes:
task_id: A2A task ID for the streaming session.
context_id: A2A context ID grouping related tasks.
chunk: The text content of the chunk.
chunk_index: Index of this chunk in the stream (0-indexed).
final: Whether this is the final chunk in the stream.
endpoint: A2A agent endpoint URL.
a2a_agent_name: Name of the A2A agent from agent card.
turn_number: Current turn number (1-indexed).
is_multiturn: Whether this is part of a multiturn conversation.
metadata: Custom A2A metadata key-value pairs.
extensions: List of A2A extension URIs in use.
"""
type: str = "a2a_streaming_chunk"
task_id: str | None = None
context_id: str | None = None
chunk: str
chunk_index: int
final: bool = False
endpoint: str | None = None
a2a_agent_name: str | None = None
turn_number: int = 1
is_multiturn: bool = False
metadata: dict[str, Any] | None = None
extensions: list[str] | None = None
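A small sketch showing one way to reassemble streamed text from these chunk events; the import path is assumed.

# Illustrative: buffer streamed chunks per task and flush on the final chunk.
from collections import defaultdict
from crewai.events.event_bus import crewai_event_bus
from crewai.events.types.a2a_events import A2AStreamingChunkEvent  # assumed path

stream_buffers: dict[str, list[str]] = defaultdict(list)

@crewai_event_bus.on(A2AStreamingChunkEvent)
def collect_chunk(source, event: A2AStreamingChunkEvent) -> None:
    """Accumulate chunk text by task and print the full response when the stream ends."""
    key = event.task_id or "unknown"
    stream_buffers[key].append(event.chunk)
    if event.final:
        print("".join(stream_buffers.pop(key)))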
class A2AAgentCardFetchedEvent(A2AEventBase):
"""Event emitted when an agent card is successfully fetched.
Attributes:
endpoint: A2A agent endpoint URL.
a2a_agent_name: Name of the A2A agent from agent card.
agent_card: Full A2A agent card metadata.
protocol_version: A2A protocol version from agent card.
provider: Agent provider/organization info from agent card.
cached: Whether the agent card was retrieved from cache.
fetch_time_ms: Time taken to fetch the agent card in milliseconds.
metadata: Custom A2A metadata key-value pairs.
"""
type: str = "a2a_agent_card_fetched"
endpoint: str
a2a_agent_name: str | None = None
agent_card: dict[str, Any] | None = None
protocol_version: str | None = None
provider: dict[str, Any] | None = None
cached: bool = False
fetch_time_ms: float | None = None
metadata: dict[str, Any] | None = None
class A2AAuthenticationFailedEvent(A2AEventBase):
"""Event emitted when authentication to an A2A agent fails.
Attributes:
endpoint: A2A agent endpoint URL.
auth_type: Type of authentication attempted (e.g., bearer, oauth2, api_key).
error: Error message describing the failure.
status_code: HTTP status code if applicable.
a2a_agent_name: Name of the A2A agent if known.
protocol_version: A2A protocol version being used.
metadata: Custom A2A metadata key-value pairs.
"""
type: str = "a2a_authentication_failed"
endpoint: str
auth_type: str | None = None
error: str
status_code: int | None = None
a2a_agent_name: str | None = None
protocol_version: str | None = None
metadata: dict[str, Any] | None = None
class A2AArtifactReceivedEvent(A2AEventBase):
"""Event emitted when an artifact is received from a remote A2A agent.
Attributes:
task_id: A2A task ID the artifact belongs to.
artifact_id: Unique identifier for the artifact.
artifact_name: Name of the artifact.
artifact_description: Purpose description of the artifact.
mime_type: MIME type of the artifact content.
size_bytes: Size of the artifact in bytes.
append: Whether content should be appended to existing artifact.
last_chunk: Whether this is the final chunk of the artifact.
endpoint: A2A agent endpoint URL.
a2a_agent_name: Name of the A2A agent from agent card.
context_id: Context ID for correlation.
turn_number: Current turn number (1-indexed).
is_multiturn: Whether this is part of a multiturn conversation.
metadata: Custom A2A metadata key-value pairs.
extensions: List of A2A extension URIs in use.
"""
type: str = "a2a_artifact_received"
task_id: str
artifact_id: str
artifact_name: str | None = None
artifact_description: str | None = None
mime_type: str | None = None
size_bytes: int | None = None
append: bool = False
last_chunk: bool = False
endpoint: str | None = None
a2a_agent_name: str | None = None
context_id: str | None = None
turn_number: int = 1
is_multiturn: bool = False
metadata: dict[str, Any] | None = None
extensions: list[str] | None = None
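A sketch of a listener that reports artifact progress from these fields; the import path is assumed, and a real handler would persist the artifact content elsewhere.

# Illustrative artifact-progress listener (import path assumed).
from crewai.events.event_bus import crewai_event_bus
from crewai.events.types.a2a_events import A2AArtifactReceivedEvent  # assumed path

@crewai_event_bus.on(A2AArtifactReceivedEvent)
def on_artifact(source, event: A2AArtifactReceivedEvent) -> None:
    """Report artifact chunks as they arrive."""
    size = f"{event.size_bytes} bytes" if event.size_bytes is not None else "unknown size"
    marker = "final chunk" if event.last_chunk else "partial chunk"
    print(f"artifact {event.artifact_id} ({event.mime_type}, {size}) -- {marker}")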
class A2AConnectionErrorEvent(A2AEventBase):
"""Event emitted when a connection error occurs during A2A communication.
Attributes:
endpoint: A2A agent endpoint URL.
error: Error message describing the connection failure.
error_type: Type of error (e.g., timeout, connection_refused, dns_error).
status_code: HTTP status code if applicable.
a2a_agent_name: Name of the A2A agent from agent card.
operation: The operation being attempted when error occurred.
context_id: A2A context ID grouping related tasks.
task_id: A2A task ID if applicable.
metadata: Custom A2A metadata key-value pairs.
"""
type: str = "a2a_connection_error"
endpoint: str
error: str
error_type: str | None = None
status_code: int | None = None
a2a_agent_name: str | None = None
operation: str | None = None
context_id: str | None = None
task_id: str | None = None
metadata: dict[str, Any] | None = None
class A2AServerTaskStartedEvent(A2AEventBase):
"""Event emitted when an A2A server task execution starts."""
"""Event emitted when an A2A server task execution starts.
Attributes:
task_id: A2A task ID for this execution.
context_id: A2A context ID grouping related tasks.
metadata: Custom A2A metadata key-value pairs.
"""
type: str = "a2a_server_task_started"
a2a_task_id: str
a2a_context_id: str
task_id: str
context_id: str
metadata: dict[str, Any] | None = None
class A2AServerTaskCompletedEvent(A2AEventBase):
"""Event emitted when an A2A server task execution completes."""
"""Event emitted when an A2A server task execution completes.
Attributes:
task_id: A2A task ID for this execution.
context_id: A2A context ID grouping related tasks.
result: The task result.
metadata: Custom A2A metadata key-value pairs.
"""
type: str = "a2a_server_task_completed"
a2a_task_id: str
a2a_context_id: str
task_id: str
context_id: str
result: str
metadata: dict[str, Any] | None = None
class A2AServerTaskCanceledEvent(A2AEventBase):
"""Event emitted when an A2A server task execution is canceled."""
"""Event emitted when an A2A server task execution is canceled.
Attributes:
task_id: A2A task ID for this execution.
context_id: A2A context ID grouping related tasks.
metadata: Custom A2A metadata key-value pairs.
"""
type: str = "a2a_server_task_canceled"
a2a_task_id: str
a2a_context_id: str
task_id: str
context_id: str
metadata: dict[str, Any] | None = None
class A2AServerTaskFailedEvent(A2AEventBase):
"""Event emitted when an A2A server task execution fails."""
"""Event emitted when an A2A server task execution fails.
Attributes:
task_id: A2A task ID for this execution.
context_id: A2A context ID grouping related tasks.
error: Error message describing the failure.
metadata: Custom A2A metadata key-value pairs.
"""
type: str = "a2a_server_task_failed"
a2a_task_id: str
a2a_context_id: str
task_id: str
context_id: str
error: str
metadata: dict[str, Any] | None = None
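Taken together, the four server-task events above map onto a simple lifecycle. A sketch of wrapping an arbitrary executor callable with them (import path assumed):

# Illustrative server-side sketch: emit started/completed/failed around task execution.
from crewai.events.event_bus import crewai_event_bus
from crewai.events.types.a2a_events import (  # assumed path
    A2AServerTaskCompletedEvent,
    A2AServerTaskFailedEvent,
    A2AServerTaskStartedEvent,
)

def run_server_task(executor, task_id: str, context_id: str, payload: str) -> str:
    """Run `executor(payload)` and report its lifecycle via A2A server events."""
    crewai_event_bus.emit(
        executor, event=A2AServerTaskStartedEvent(task_id=task_id, context_id=context_id)
    )
    try:
        result = executor(payload)
    except Exception as exc:
        crewai_event_bus.emit(
            executor,
            event=A2AServerTaskFailedEvent(task_id=task_id, context_id=context_id, error=str(exc)),
        )
        raise
    crewai_event_bus.emit(
        executor,
        event=A2AServerTaskCompletedEvent(task_id=task_id, context_id=context_id, result=result),
    )
    return result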
class A2AParallelDelegationStartedEvent(A2AEventBase):
"""Event emitted when parallel delegation to multiple A2A agents begins.
Attributes:
endpoints: List of A2A agent endpoints being delegated to.
task_description: Description of the task being delegated.
"""
type: str = "a2a_parallel_delegation_started"
endpoints: list[str]
task_description: str
class A2AParallelDelegationCompletedEvent(A2AEventBase):
"""Event emitted when parallel delegation to multiple A2A agents completes.
Attributes:
endpoints: List of A2A agent endpoints that were delegated to.
success_count: Number of successful delegations.
failure_count: Number of failed delegations.
results: Summary of results from each agent.
"""
type: str = "a2a_parallel_delegation_completed"
endpoints: list[str]
success_count: int
failure_count: int
results: dict[str, str] | None = None
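A sketch of reporting a fan-out across several endpoints with the two parallel-delegation events; the error-prefix convention used to count failures is purely illustrative.

# Illustrative sketch: summarize a parallel fan-out with the events defined above.
from crewai.events.event_bus import crewai_event_bus
from crewai.events.types.a2a_events import (  # assumed path
    A2AParallelDelegationCompletedEvent,
    A2AParallelDelegationStartedEvent,
)

def report_parallel_delegation(
    source, endpoints: list[str], task_description: str, outcomes: dict[str, str]
) -> None:
    """Emit started/completed events around an already-collected set of per-endpoint outcomes."""
    crewai_event_bus.emit(
        source,
        event=A2AParallelDelegationStartedEvent(
            endpoints=endpoints, task_description=task_description
        ),
    )
    # "error:" prefix is an assumed convention for marking failed outcomes.
    failures = {ep: msg for ep, msg in outcomes.items() if msg.startswith("error:")}
    crewai_event_bus.emit(
        source,
        event=A2AParallelDelegationCompletedEvent(
            endpoints=endpoints,
            success_count=len(outcomes) - len(failures),
            failure_count=len(failures),
            results=outcomes,
        ),
    )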

View File

@@ -366,32 +366,6 @@ To enable tracing, do any one of these:
self.print_panel(content, f"🔧 Tool Execution Started (#{iteration})", "yellow")
def handle_tool_usage_finished(
self,
tool_name: str,
output: str,
run_attempts: int | None = None,
) -> None:
"""Handle tool usage finished event with panel display."""
if not self.verbose:
return
iteration = self.tool_usage_counts.get(tool_name, 1)
content = Text()
content.append("Tool Completed\n", style="green bold")
content.append("Tool: ", style="white")
content.append(f"{tool_name}\n", style="green bold")
if output:
content.append("Output: ", style="white")
content.append(f"{output}\n", style="green")
self.print_panel(
content, f"✅ Tool Execution Completed (#{iteration})", "green"
)
def handle_tool_usage_error(
self,
tool_name: str,

View File

@@ -1,8 +1,6 @@
from __future__ import annotations
from collections.abc import Callable
from datetime import datetime
import json
import threading
from typing import TYPE_CHECKING, Any, Literal, cast
from uuid import uuid4
@@ -19,24 +17,16 @@ from crewai.agents.parser import (
OutputParserError,
)
from crewai.events.event_bus import crewai_event_bus
from crewai.events.listeners.tracing.utils import (
is_tracing_enabled_in_context,
)
from crewai.events.types.logging_events import (
AgentLogsExecutionEvent,
AgentLogsStartedEvent,
)
from crewai.events.types.tool_usage_events import (
ToolUsageFinishedEvent,
ToolUsageStartedEvent,
)
from crewai.flow.flow import Flow, listen, or_, router, start
from crewai.hooks.llm_hooks import (
get_after_llm_call_hooks,
get_before_llm_call_hooks,
)
from crewai.utilities.agent_utils import (
convert_tools_to_openai_schema,
enforce_rpm_limit,
format_message_for_llm,
get_llm_response,
@@ -81,8 +71,6 @@ class AgentReActState(BaseModel):
current_answer: AgentAction | AgentFinish | None = Field(default=None)
is_finished: bool = Field(default=False)
ask_for_human_input: bool = Field(default=False)
use_native_tools: bool = Field(default=False)
pending_tool_calls: list[Any] = Field(default_factory=list)
class CrewAgentExecutorFlow(Flow[AgentReActState], CrewAgentExecutorMixin):
@@ -191,10 +179,6 @@ class CrewAgentExecutorFlow(Flow[AgentReActState], CrewAgentExecutorMixin):
)
)
# Native tool calling support
self._openai_tools: list[dict[str, Any]] = []
self._available_functions: dict[str, Callable[..., Any]] = {}
self._state = AgentReActState()
def _ensure_flow_initialized(self) -> None:
@@ -205,66 +189,14 @@ class CrewAgentExecutorFlow(Flow[AgentReActState], CrewAgentExecutorMixin):
Only the instance that actually executes via invoke() will emit events.
"""
if not self._flow_initialized:
current_tracing = is_tracing_enabled_in_context()
# Now call Flow's __init__ which will replace self._state
# with Flow's managed state. Suppress flow events since this is
# an agent executor, not a user-facing flow.
super().__init__(
suppress_flow_events=True,
tracing=current_tracing if current_tracing else None,
)
self._flow_initialized = True
def _check_native_tool_support(self) -> bool:
"""Check if LLM supports native function calling.
Returns:
True if the LLM supports native function calling and tools are available.
"""
return (
hasattr(self.llm, "supports_function_calling")
and callable(getattr(self.llm, "supports_function_calling", None))
and self.llm.supports_function_calling()
and bool(self.original_tools)
)
def _setup_native_tools(self) -> None:
"""Convert tools to OpenAI schema format for native function calling."""
if self.original_tools:
self._openai_tools, self._available_functions = (
convert_tools_to_openai_schema(self.original_tools)
)
def _is_tool_call_list(self, response: list[Any]) -> bool:
"""Check if a response is a list of tool calls.
Args:
response: The response to check.
Returns:
True if the response appears to be a list of tool calls.
"""
if not response:
return False
first_item = response[0]
# Check for OpenAI-style tool call structure
if hasattr(first_item, "function") or (
isinstance(first_item, dict) and "function" in first_item
):
return True
# Check for Anthropic-style tool call structure (ToolUseBlock)
if (
hasattr(first_item, "type")
and getattr(first_item, "type", None) == "tool_use"
):
return True
if hasattr(first_item, "name") and hasattr(first_item, "input"):
return True
# Check for Gemini-style function call (Part with function_call)
if hasattr(first_item, "function_call") and first_item.function_call:
return True
return False
@property
def use_stop_words(self) -> bool:
"""Check to determine if stop words are being used.
@@ -297,11 +229,6 @@ class CrewAgentExecutorFlow(Flow[AgentReActState], CrewAgentExecutorMixin):
def initialize_reasoning(self) -> Literal["initialized"]:
"""Initialize the reasoning flow and emit agent start logs."""
self._show_start_logs()
# Check for native tool support on first iteration
if self.state.iterations == 0:
self.state.use_native_tools = self._check_native_tool_support()
if self.state.use_native_tools:
self._setup_native_tools()
return "initialized"
@listen("force_final_answer")
@@ -376,69 +303,6 @@ class CrewAgentExecutorFlow(Flow[AgentReActState], CrewAgentExecutorMixin):
handle_unknown_error(self._printer, e)
raise
@listen("continue_reasoning_native")
def call_llm_native_tools(
self,
) -> Literal["native_tool_calls", "native_finished", "context_error"]:
"""Execute LLM call with native function calling.
Returns routing decision based on whether tool calls or final answer.
"""
try:
enforce_rpm_limit(self.request_within_rpm_limit)
# Call LLM with native tools
# Pass available_functions=None so the LLM returns tool_calls
# without executing them. The executor handles tool execution.
answer = get_llm_response(
llm=self.llm,
messages=list(self.state.messages),
callbacks=self.callbacks,
printer=self._printer,
tools=self._openai_tools,
available_functions=None,
from_task=self.task,
from_agent=self.agent,
response_model=self.response_model,
executor_context=self,
)
# Check if the response is a list of tool calls
if isinstance(answer, list) and answer and self._is_tool_call_list(answer):
# Store tool calls for sequential processing
self.state.pending_tool_calls = list(answer)
return "native_tool_calls"
# Text response - this is the final answer
if isinstance(answer, str):
self.state.current_answer = AgentFinish(
thought="",
output=answer,
text=answer,
)
self._invoke_step_callback(self.state.current_answer)
self._append_message_to_state(answer)
return "native_finished"
# Unexpected response type, treat as final answer
self.state.current_answer = AgentFinish(
thought="",
output=str(answer),
text=str(answer),
)
self._invoke_step_callback(self.state.current_answer)
self._append_message_to_state(str(answer))
return "native_finished"
except Exception as e:
if is_context_length_exceeded(e):
self._last_context_error = e
return "context_error"
if e.__class__.__module__.startswith("litellm"):
raise e
handle_unknown_error(self._printer, e)
raise
@router(call_llm_and_parse)
def route_by_answer_type(self) -> Literal["execute_tool", "agent_finished"]:
"""Route based on whether answer is AgentAction or AgentFinish."""
@@ -494,14 +358,6 @@ class CrewAgentExecutorFlow(Flow[AgentReActState], CrewAgentExecutorMixin):
self.state.is_finished = True
return "tool_result_is_final"
# Inject post-tool reasoning prompt to enforce analysis
reasoning_prompt = self._i18n.slice("post_tool_reasoning")
reasoning_message: LLMMessage = {
"role": "user",
"content": reasoning_prompt,
}
self.state.messages.append(reasoning_message)
return "tool_completed"
except Exception as e:
@@ -511,143 +367,6 @@ class CrewAgentExecutorFlow(Flow[AgentReActState], CrewAgentExecutorMixin):
self._console.print(error_text)
raise
@listen("native_tool_calls")
def execute_native_tool(self) -> Literal["native_tool_completed"]:
"""Execute a single native tool call and inject reasoning prompt.
Processes only the FIRST tool call from pending_tool_calls for
sequential execution with reflection after each tool.
"""
if not self.state.pending_tool_calls:
return "native_tool_completed"
tool_call = self.state.pending_tool_calls[0]
self.state.pending_tool_calls = [] # Clear pending calls
# Extract tool call info - handle OpenAI, Anthropic, and Gemini formats
if hasattr(tool_call, "function"):
# OpenAI format: has .function.name and .function.arguments
call_id = getattr(tool_call, "id", f"call_{id(tool_call)}")
func_name = tool_call.function.name
func_args = tool_call.function.arguments
elif hasattr(tool_call, "function_call") and tool_call.function_call:
# Gemini format: has .function_call.name and .function_call.args
call_id = f"call_{id(tool_call)}"
func_name = tool_call.function_call.name
func_args = (
dict(tool_call.function_call.args)
if tool_call.function_call.args
else {}
)
elif hasattr(tool_call, "name") and hasattr(tool_call, "input"):
# Anthropic format: has .name and .input (ToolUseBlock)
call_id = getattr(tool_call, "id", f"call_{id(tool_call)}")
func_name = tool_call.name
func_args = tool_call.input # Already a dict in Anthropic
elif isinstance(tool_call, dict):
call_id = tool_call.get("id", f"call_{id(tool_call)}")
func_info = tool_call.get("function", {})
func_name = func_info.get("name", "") or tool_call.get("name", "")
func_args = func_info.get("arguments", "{}") or tool_call.get("input", {})
else:
return "native_tool_completed"
# Append assistant message with single tool call
assistant_message: LLMMessage = {
"role": "assistant",
"content": None,
"tool_calls": [
{
"id": call_id,
"type": "function",
"function": {
"name": func_name,
"arguments": func_args
if isinstance(func_args, str)
else json.dumps(func_args),
},
}
],
}
self.state.messages.append(assistant_message)
# Parse arguments for the single tool call
if isinstance(func_args, str):
try:
args_dict = json.loads(func_args)
except json.JSONDecodeError:
args_dict = {}
else:
args_dict = func_args
# Emit tool usage started event
started_at = datetime.now()
crewai_event_bus.emit(
self,
event=ToolUsageStartedEvent(
tool_name=func_name,
tool_args=args_dict,
from_agent=self.agent,
from_task=self.task,
),
)
# Execute the tool
result = "Tool not found"
if func_name in self._available_functions:
try:
tool_func = self._available_functions[func_name]
result = tool_func(**args_dict)
if not isinstance(result, str):
result = str(result)
except Exception as e:
result = f"Error executing tool: {e}"
# Emit tool usage finished event
crewai_event_bus.emit(
self,
event=ToolUsageFinishedEvent(
output=result,
tool_name=func_name,
tool_args=args_dict,
from_agent=self.agent,
from_task=self.task,
started_at=started_at,
finished_at=datetime.now(),
),
)
# Append tool result message
tool_message: LLMMessage = {
"role": "tool",
"tool_call_id": call_id,
"content": result,
}
self.state.messages.append(tool_message)
# Log the tool execution
if self.agent and self.agent.verbose:
self._printer.print(
content=f"Tool {func_name} executed with result: {result[:200]}...",
color="green",
)
# Inject post-tool reasoning prompt to enforce analysis
reasoning_prompt = self._i18n.slice("post_tool_reasoning")
reasoning_message: LLMMessage = {
"role": "user",
"content": reasoning_prompt,
}
self.state.messages.append(reasoning_message)
return "native_tool_completed"
@router(execute_native_tool)
def increment_native_and_continue(self) -> Literal["initialized"]:
"""Increment iteration counter after native tool execution."""
self.state.iterations += 1
return "initialized"
@listen("initialized")
def continue_iteration(self) -> Literal["check_iteration"]:
"""Bridge listener that connects iteration loop back to iteration check."""
@@ -656,14 +375,10 @@ class CrewAgentExecutorFlow(Flow[AgentReActState], CrewAgentExecutorMixin):
@router(or_(initialize_reasoning, continue_iteration))
def check_max_iterations(
self,
) -> Literal[
"force_final_answer", "continue_reasoning", "continue_reasoning_native"
]:
) -> Literal["force_final_answer", "continue_reasoning"]:
"""Check if max iterations reached before proceeding with reasoning."""
if has_reached_max_iterations(self.state.iterations, self.max_iter):
return "force_final_answer"
if self.state.use_native_tools:
return "continue_reasoning_native"
return "continue_reasoning"
@router(execute_tool_action)
@@ -672,7 +387,7 @@ class CrewAgentExecutorFlow(Flow[AgentReActState], CrewAgentExecutorMixin):
self.state.iterations += 1
return "initialized"
@listen(or_("agent_finished", "tool_result_is_final", "native_finished"))
@listen(or_("agent_finished", "tool_result_is_final"))
def finalize(self) -> Literal["completed", "skipped"]:
"""Finalize execution and emit completion logs."""
if self.state.current_answer is None:
@@ -760,8 +475,6 @@ class CrewAgentExecutorFlow(Flow[AgentReActState], CrewAgentExecutorMixin):
self.state.iterations = 0
self.state.current_answer = None
self.state.is_finished = False
self.state.use_native_tools = False
self.state.pending_tool_calls = []
if "system" in self.prompt:
prompt = cast("SystemPromptResult", self.prompt)

View File

@@ -931,6 +931,7 @@ class LLM(BaseLLM):
self._handle_streaming_callbacks(callbacks, usage_info, last_chunk)
if not tool_calls or not available_functions:
if response_model and self.is_litellm:
instructor_instance = InternalInstructor(
content=full_response,
@@ -1143,12 +1144,8 @@ class LLM(BaseLLM):
if response_model:
params["response_model"] = response_model
response = litellm.completion(**params)
if (
hasattr(response, "usage")
and not isinstance(response.usage, type)
and response.usage
):
if hasattr(response,"usage") and not isinstance(response.usage, type) and response.usage:
usage_info = response.usage
self._track_token_usage_internal(usage_info)
@@ -1202,19 +1199,16 @@ class LLM(BaseLLM):
)
return text_response
# --- 6) If there are tool calls but no available functions, return the tool calls
# This allows the caller (e.g., executor) to handle tool execution
if tool_calls and not available_functions:
# --- 6) If there are tool calls but no text response and no available functions, return the tool calls
if tool_calls and not available_functions and not text_response:
return tool_calls
# --- 7) Handle tool calls if present (execute when available_functions provided)
if tool_calls and available_functions:
tool_result = self._handle_tool_call(
tool_calls, available_functions, from_task, from_agent
)
if tool_result is not None:
return tool_result
# --- 7) Handle tool calls if present
tool_result = self._handle_tool_call(
tool_calls, available_functions, from_task, from_agent
)
if tool_result is not None:
return tool_result
# --- 8) If tool call handling didn't return a result, emit completion event and return text response
self._handle_emit_call_events(
response=text_response,
@@ -1279,11 +1273,7 @@ class LLM(BaseLLM):
params["response_model"] = response_model
response = await litellm.acompletion(**params)
if (
hasattr(response, "usage")
and not isinstance(response.usage, type)
and response.usage
):
if hasattr(response,"usage") and not isinstance(response.usage, type) and response.usage:
usage_info = response.usage
self._track_token_usage_internal(usage_info)
@@ -1331,18 +1321,14 @@ class LLM(BaseLLM):
)
return text_response
# If there are tool calls but no available functions, return the tool calls
# This allows the caller (e.g., executor) to handle tool execution
if tool_calls and not available_functions:
if tool_calls and not available_functions and not text_response:
return tool_calls
# Handle tool calls if present (execute when available_functions provided)
if tool_calls and available_functions:
tool_result = self._handle_tool_call(
tool_calls, available_functions, from_task, from_agent
)
if tool_result is not None:
return tool_result
tool_result = self._handle_tool_call(
tool_calls, available_functions, from_task, from_agent
)
if tool_result is not None:
return tool_result
self._handle_emit_call_events(
response=text_response,
@@ -1377,7 +1363,7 @@ class LLM(BaseLLM):
"""
full_response = ""
chunk_count = 0
usage_info = None
accumulated_tool_args: defaultdict[int, AccumulatedToolArgs] = defaultdict(
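To make the branch changes above concrete: when available_functions is supplied, the matched tool is executed inside the call and a string comes back; raw tool calls are only returned when there is no text response and no functions to run. A hedged sketch (model name and tool schema are placeholders):

# Illustrative sketch of the call path above; not taken verbatim from this diff.
from crewai.llm import LLM

def get_weather(location: str) -> str:
    """Placeholder tool implementation."""
    return f"Sunny in {location}"

llm = LLM(model="gpt-4o-mini")
answer = llm.call(
    "What is the weather in Paris?",
    tools=[
        {
            "type": "function",
            "function": {
                "name": "get_weather",
                "description": "Get the current weather for a location",
                "parameters": {
                    "type": "object",
                    "properties": {"location": {"type": "string"}},
                    "required": ["location"],
                },
            },
        }
    ],
    available_functions={"get_weather": get_weather},  # tool is executed internally
)
print(answer)  # a plain string containing the tool-informed answer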

View File

@@ -445,7 +445,7 @@ class BaseLLM(ABC):
from_agent=from_agent,
)
return result
return str(result)
except Exception as e:
error_msg = f"Error executing function '{function_name}': {e!s}"

View File

@@ -418,7 +418,6 @@ class AnthropicCompletion(BaseLLM):
- System messages are separate from conversation messages
- Messages must alternate between user and assistant
- First message must be from user
- Tool results must be in user messages with tool_result content blocks
- When thinking is enabled, assistant messages must start with thinking blocks
Args:
@@ -432,7 +431,6 @@ class AnthropicCompletion(BaseLLM):
formatted_messages: list[LLMMessage] = []
system_message: str | None = None
pending_tool_results: list[dict[str, Any]] = []
for message in base_formatted:
role = message.get("role")
@@ -443,47 +441,16 @@ class AnthropicCompletion(BaseLLM):
system_message += f"\n\n{content}"
else:
system_message = cast(str, content)
elif role == "tool":
# Convert OpenAI-style tool message to Anthropic tool_result format
# These will be collected and added as a user message
tool_call_id = message.get("tool_call_id", "")
tool_result = {
"type": "tool_result",
"tool_use_id": tool_call_id,
"content": content if content else "",
}
pending_tool_results.append(tool_result)
elif role == "assistant":
# First, flush any pending tool results as a user message
if pending_tool_results:
formatted_messages.append(
{"role": "user", "content": pending_tool_results}
)
pending_tool_results = []
else:
role_str = role if role is not None else "user"
# Handle assistant message with tool_calls (convert to Anthropic format)
tool_calls = message.get("tool_calls", [])
if tool_calls:
assistant_content: list[dict[str, Any]] = []
for tc in tool_calls:
if isinstance(tc, dict):
func = tc.get("function", {})
tool_use = {
"type": "tool_use",
"id": tc.get("id", ""),
"name": func.get("name", ""),
"input": json.loads(func.get("arguments", "{}"))
if isinstance(func.get("arguments"), str)
else func.get("arguments", {}),
}
assistant_content.append(tool_use)
if assistant_content:
formatted_messages.append(
{"role": "assistant", "content": assistant_content}
)
elif isinstance(content, list):
formatted_messages.append({"role": "assistant", "content": content})
elif self.thinking and self.previous_thinking_blocks:
if isinstance(content, list):
formatted_messages.append({"role": role_str, "content": content})
elif (
role_str == "assistant"
and self.thinking
and self.previous_thinking_blocks
):
structured_content = cast(
list[dict[str, Any]],
[
@@ -492,34 +459,14 @@ class AnthropicCompletion(BaseLLM):
],
)
formatted_messages.append(
LLMMessage(role="assistant", content=structured_content)
LLMMessage(role=role_str, content=structured_content)
)
else:
content_str = content if content is not None else ""
formatted_messages.append(
LLMMessage(role="assistant", content=content_str)
)
else:
# User message - first flush any pending tool results
if pending_tool_results:
formatted_messages.append(
{"role": "user", "content": pending_tool_results}
)
pending_tool_results = []
role_str = role if role is not None else "user"
if isinstance(content, list):
formatted_messages.append({"role": role_str, "content": content})
else:
content_str = content if content is not None else ""
formatted_messages.append(
LLMMessage(role=role_str, content=content_str)
)
# Flush any remaining pending tool results
if pending_tool_results:
formatted_messages.append({"role": "user", "content": pending_tool_results})
# Ensure first message is from user (Anthropic requirement)
if not formatted_messages:
# If no messages, add a default user message
@@ -579,19 +526,13 @@ class AnthropicCompletion(BaseLLM):
return structured_json
# Check if Claude wants to use tools
if response.content:
if response.content and available_functions:
tool_uses = [
block for block in response.content if isinstance(block, ToolUseBlock)
]
if tool_uses:
# If no available_functions, return tool calls for executor to handle
# This allows the executor to manage tool execution with proper
# message history and post-tool reasoning prompts
if not available_functions:
return list(tool_uses)
# Handle tool use conversation flow internally
# Handle tool use conversation flow
return self._handle_tool_use_conversation(
response,
tool_uses,
@@ -755,7 +696,7 @@ class AnthropicCompletion(BaseLLM):
return structured_json
if final_message.content:
if final_message.content and available_functions:
tool_uses = [
block
for block in final_message.content
@@ -763,11 +704,7 @@ class AnthropicCompletion(BaseLLM):
]
if tool_uses:
# If no available_functions, return tool calls for executor to handle
if not available_functions:
return list(tool_uses)
# Handle tool use conversation flow internally
# Handle tool use conversation flow
return self._handle_tool_use_conversation(
final_message,
tool_uses,
@@ -996,16 +933,12 @@ class AnthropicCompletion(BaseLLM):
return structured_json
if response.content:
if response.content and available_functions:
tool_uses = [
block for block in response.content if isinstance(block, ToolUseBlock)
]
if tool_uses:
# If no available_functions, return tool calls for executor to handle
if not available_functions:
return list(tool_uses)
return await self._ahandle_tool_use_conversation(
response,
tool_uses,
@@ -1146,7 +1079,7 @@ class AnthropicCompletion(BaseLLM):
return structured_json
if final_message.content:
if final_message.content and available_functions:
tool_uses = [
block
for block in final_message.content
@@ -1154,10 +1087,6 @@ class AnthropicCompletion(BaseLLM):
]
if tool_uses:
# If no available_functions, return tool calls for executor to handle
if not available_functions:
return list(tool_uses)
return await self._ahandle_tool_use_conversation(
final_message,
tool_uses,
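For reference, the message shapes the removed conversion produced, and which Anthropic's Messages API expects for tool use: a tool_use block on the assistant message and a matching tool_result block inside a user message. Ids and values below are placeholders.

# Reference shapes only (placeholder ids/values), matching the removed conversion logic.
assistant_tool_use = {
    "role": "assistant",
    "content": [
        {
            "type": "tool_use",
            "id": "toolu_123",
            "name": "get_weather",
            "input": {"location": "Paris"},
        }
    ],
}
user_tool_result = {
    "role": "user",
    "content": [
        {"type": "tool_result", "tool_use_id": "toolu_123", "content": "Sunny, 22C"}
    ],
}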

View File

@@ -514,31 +514,10 @@ class AzureCompletion(BaseLLM):
for message in base_formatted:
role = message.get("role", "user") # Default to user if no role
# Handle None content - Azure requires string content
content = message.get("content") or ""
content = message.get("content", "")
# Handle tool role messages - keep as tool role for Azure OpenAI
if role == "tool":
tool_call_id = message.get("tool_call_id", "unknown")
azure_messages.append(
{
"role": "tool",
"tool_call_id": tool_call_id,
"content": content,
}
)
# Handle assistant messages with tool_calls
elif role == "assistant" and message.get("tool_calls"):
tool_calls = message.get("tool_calls", [])
azure_msg: LLMMessage = {
"role": "assistant",
"content": content, # Already defaulted to "" above
"tool_calls": tool_calls,
}
azure_messages.append(azure_msg)
else:
# Azure AI Inference requires both 'role' and 'content'
azure_messages.append({"role": role, "content": content})
# Azure AI Inference requires both 'role' and 'content'
azure_messages.append({"role": role, "content": content})
return azure_messages
@@ -625,11 +604,6 @@ class AzureCompletion(BaseLLM):
from_agent=from_agent,
)
# If there are tool_calls but no available_functions, return the tool_calls
# This allows the caller (e.g., executor) to handle tool execution
if message.tool_calls and not available_functions:
return list(message.tool_calls)
# Handle tool calls
if message.tool_calls and available_functions:
tool_call = message.tool_calls[0] # Handle first tool call
@@ -801,21 +775,6 @@ class AzureCompletion(BaseLLM):
from_agent=from_agent,
)
# If there are tool_calls but no available_functions, return them
# in OpenAI-compatible format for executor to handle
if tool_calls and not available_functions:
return [
{
"id": call_data.get("id", f"call_{idx}"),
"type": "function",
"function": {
"name": call_data["name"],
"arguments": call_data["arguments"],
},
}
for idx, call_data in tool_calls.items()
]
# Handle completed tool calls
if tool_calls and available_functions:
for call_data in tool_calls.values():

View File

@@ -606,17 +606,6 @@ class GeminiCompletion(BaseLLM):
if response.candidates and (self.tools or available_functions):
candidate = response.candidates[0]
if candidate.content and candidate.content.parts:
# Collect function call parts
function_call_parts = [
part for part in candidate.content.parts if part.function_call
]
# If there are function calls but no available_functions,
# return them for the executor to handle (like OpenAI/Anthropic)
if function_call_parts and not available_functions:
return function_call_parts
# Otherwise execute the tools internally
for part in candidate.content.parts:
if part.function_call:
function_name = part.function_call.name
@@ -731,7 +720,7 @@ class GeminiCompletion(BaseLLM):
from_task: Any | None = None,
from_agent: Any | None = None,
response_model: type[BaseModel] | None = None,
) -> str | list[dict[str, Any]]:
) -> str:
"""Finalize streaming response with usage tracking, function execution, and events.
Args:
@@ -749,21 +738,6 @@ class GeminiCompletion(BaseLLM):
"""
self._track_token_usage_internal(usage_data)
# If there are function calls but no available_functions,
# return them for the executor to handle
if function_calls and not available_functions:
return [
{
"id": call_data["id"],
"function": {
"name": call_data["name"],
"arguments": json.dumps(call_data["args"]),
},
"type": "function",
}
for call_data in function_calls.values()
]
# Handle completed function calls
if function_calls and available_functions:
for call_data in function_calls.values():

View File

@@ -428,12 +428,6 @@ class OpenAICompletion(BaseLLM):
choice: Choice = response.choices[0]
message = choice.message
# If there are tool_calls but no available_functions, return the tool_calls
# This allows the caller (e.g., executor) to handle tool execution
if message.tool_calls and not available_functions:
return list(message.tool_calls)
# If there are tool_calls and available_functions, execute the tools
if message.tool_calls and available_functions:
tool_call = message.tool_calls[0]
function_name = tool_call.function.name
@@ -731,15 +725,6 @@ class OpenAICompletion(BaseLLM):
choice: Choice = response.choices[0]
message = choice.message
# If there are tool_calls but no available_functions, return the tool_calls
# This allows the caller (e.g., executor) to handle tool execution
if message.tool_calls and not available_functions:
print("--------------------------------")
print("lorenze tool_calls", list(message.tool_calls))
print("--------------------------------")
return list(message.tool_calls)
# If there are tool_calls and available_functions, execute the tools
if message.tool_calls and available_functions:
tool_call = message.tool_calls[0]
function_name = tool_call.function.name

View File

@@ -11,9 +11,6 @@
"role_playing": "You are {role}. {backstory}\nYour personal goal is: {goal}",
"tools": "\nYou ONLY have access to the following tools, and should NEVER make up tools that are not listed here:\n\n{tools}\n\nIMPORTANT: Use the following format in your response:\n\n```\nThought: you should always think about what to do\nAction: the action to take, only one name of [{tool_names}], just the name, exactly as it's written.\nAction Input: the input to the action, just a simple JSON object, enclosed in curly braces, using \" to wrap keys and values.\nObservation: the result of the action\n```\n\nOnce all necessary information is gathered, return the following format:\n\n```\nThought: I now know the final answer\nFinal Answer: the final answer to the original input question\n```",
"no_tools": "\nTo give my best complete final answer to the task respond using the exact following format:\n\nThought: I now can give a great answer\nFinal Answer: Your final answer must be the great and the most complete as possible, it must be outcome described.\n\nI MUST use these formats, my job depends on it!",
"native_tools": "\nUse available tools to gather information and complete your task.",
"native_task": "\nCurrent Task: {input}\n\nThis is VERY important to you, your job depends on it!",
"post_tool_reasoning": "PAUSE and THINK before responding.\n\nInternally consider (DO NOT output these steps):\n- What key insights did the tool provide?\n- Have I fulfilled ALL requirements from my original instructions (e.g., minimum tool calls, specific sources)?\n- Do I have enough information to fully answer the task?\n\nIF you have NOT met all requirements or need more information: Call another tool now.\n\nIF you have met all requirements and have sufficient information: Provide ONLY your final answer in the format specified by the task's expected output. Do NOT include reasoning steps, analysis sections, or meta-commentary. Just deliver the answer.",
"format": "I MUST either use a tool (use one at time) OR give my best final answer not both at the same time. When responding, I must use the following format:\n\n```\nThought: you should always think about what to do\nAction: the action to take, should be one of [{tool_names}]\nAction Input: the input to the action, dictionary enclosed in curly braces\nObservation: the result of the action\n```\nThis Thought/Action/Action Input/Result can repeat N times. Once I know the final answer, I must return the following format:\n\n```\nThought: I now can give a great answer\nFinal Answer: Your final answer must be the great and the most complete as possible, it must be outcome described\n\n```",
"final_answer_format": "If you don't need to use any more tools, you must give your best complete final answer, make sure it satisfies the expected criteria, use the EXACT format below:\n\n```\nThought: I now can give a great answer\nFinal Answer: my best complete final answer to the task.\n\n```",
"format_without_tools": "\nSorry, I didn't use the right format. I MUST either use a tool (among the available ones), OR give my best final answer.\nHere is the expected format I must follow:\n\n```\nQuestion: the input question you must answer\nThought: you should always think about what to do\nAction: the action to take, should be one of [{tool_names}]\nAction Input: the input to the action\nObservation: the result of the action\n```\n This Thought/Action/Action Input/Result process can repeat N times. Once I know the final answer, I must return the following format:\n\n```\nThought: I now can give a great answer\nFinal Answer: Your final answer must be the great and the most complete as possible, it must be outcome described\n\n```",

View File

@@ -108,65 +108,6 @@ def render_text_description_and_args(
return "\n".join(tool_strings)
def convert_tools_to_openai_schema(
tools: Sequence[BaseTool | CrewStructuredTool],
) -> tuple[list[dict[str, Any]], dict[str, Callable[..., Any]]]:
"""Convert CrewAI tools to OpenAI function calling format.
This function converts CrewAI BaseTool and CrewStructuredTool objects
into the OpenAI-compatible tool schema format that can be passed to
LLM providers for native function calling.
Args:
tools: List of CrewAI tool objects to convert.
Returns:
Tuple containing:
- List of OpenAI-format tool schema dictionaries
- Dict mapping tool names to their callable run() methods
Example:
>>> tools = [CalculatorTool(), SearchTool()]
>>> schemas, functions = convert_tools_to_openai_schema(tools)
>>> # schemas can be passed to llm.call(tools=schemas)
>>> # functions can be passed to llm.call(available_functions=functions)
"""
openai_tools: list[dict[str, Any]] = []
available_functions: dict[str, Callable[..., Any]] = {}
for tool in tools:
# Get the JSON schema for tool parameters
parameters: dict[str, Any] = {}
if hasattr(tool, "args_schema") and tool.args_schema is not None:
try:
parameters = tool.args_schema.model_json_schema()
# Remove title and description from schema root as they're redundant
parameters.pop("title", None)
parameters.pop("description", None)
except Exception:
parameters = {}
# Extract original description from formatted description
# BaseTool formats description as "Tool Name: ...\nTool Arguments: ...\nTool Description: {original}"
description = tool.description
if "Tool Description:" in description:
# Extract the original description after "Tool Description:"
description = description.split("Tool Description:")[-1].strip()
schema: dict[str, Any] = {
"type": "function",
"function": {
"name": tool.name,
"description": description,
"parameters": parameters,
},
}
openai_tools.append(schema)
available_functions[tool.name] = tool.run
return openai_tools, available_functions
def has_reached_max_iterations(iterations: int, max_iterations: int) -> bool:
"""Check if the maximum number of iterations has been reached.
@@ -293,13 +234,11 @@ def get_llm_response(
messages: list[LLMMessage],
callbacks: list[TokenCalcHandler],
printer: Printer,
tools: list[dict[str, Any]] | None = None,
available_functions: dict[str, Callable[..., Any]] | None = None,
from_task: Task | None = None,
from_agent: Agent | LiteAgent | None = None,
response_model: type[BaseModel] | None = None,
executor_context: CrewAgentExecutor | LiteAgent | None = None,
) -> str | Any:
) -> str:
"""Call the LLM and return the response, handling any invalid responses.
Args:
@@ -307,16 +246,13 @@ def get_llm_response(
messages: The messages to send to the LLM.
callbacks: List of callbacks for the LLM call.
printer: Printer instance for output.
tools: Optional list of tool schemas for native function calling.
available_functions: Optional dict mapping function names to callables.
from_task: Optional task context for the LLM call.
from_agent: Optional agent context for the LLM call.
response_model: Optional Pydantic model for structured outputs.
executor_context: Optional executor context for hook invocation.
Returns:
The response from the LLM as a string, or tool call results if
native function calling is used.
The response from the LLM as a string.
Raises:
Exception: If an error occurs.
@@ -331,9 +267,7 @@ def get_llm_response(
try:
answer = llm.call(
messages,
tools=tools,
callbacks=callbacks,
available_functions=available_functions,
from_task=from_task,
from_agent=from_agent, # type: ignore[arg-type]
response_model=response_model,
@@ -355,13 +289,11 @@ async def aget_llm_response(
messages: list[LLMMessage],
callbacks: list[TokenCalcHandler],
printer: Printer,
tools: list[dict[str, Any]] | None = None,
available_functions: dict[str, Callable[..., Any]] | None = None,
from_task: Task | None = None,
from_agent: Agent | LiteAgent | None = None,
response_model: type[BaseModel] | None = None,
executor_context: CrewAgentExecutor | None = None,
) -> str | Any:
) -> str:
"""Call the LLM asynchronously and return the response.
Args:
@@ -369,16 +301,13 @@ async def aget_llm_response(
messages: The messages to send to the LLM.
callbacks: List of callbacks for the LLM call.
printer: Printer instance for output.
tools: Optional list of tool schemas for native function calling.
available_functions: Optional dict mapping function names to callables.
from_task: Optional task context for the LLM call.
from_agent: Optional agent context for the LLM call.
response_model: Optional Pydantic model for structured outputs.
executor_context: Optional executor context for hook invocation.
Returns:
The response from the LLM as a string, or tool call results if
native function calling is used.
The response from the LLM as a string.
Raises:
Exception: If an error occurs.
@@ -392,9 +321,7 @@ async def aget_llm_response(
try:
answer = await llm.acall(
messages,
tools=tools,
callbacks=callbacks,
available_functions=available_functions,
from_task=from_task,
from_agent=from_agent, # type: ignore[arg-type]
response_model=response_model,
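With the native-tool parameters removed, a call through this helper reduces to the trimmed signature. A sketch, with every surrounding object treated as a placeholder supplied by the executor; the import path matches the agent_utils import shown earlier in this diff.

# Illustrative call with the trimmed signature; arguments are whatever the executor already holds.
from crewai.utilities.agent_utils import get_llm_response

def ask(llm, messages, callbacks, printer, task, agent) -> str:
    """Single LLM round-trip without tools or available_functions."""
    return get_llm_response(
        llm=llm,
        messages=messages,
        callbacks=callbacks,
        printer=printer,
        from_task=task,
        from_agent=agent,
        response_model=None,
        executor_context=None,
    )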

View File

@@ -22,9 +22,7 @@ class SystemPromptResult(StandardPromptResult):
user: Annotated[str, "The user prompt component"]
COMPONENTS = Literal[
"role_playing", "tools", "no_tools", "native_tools", "task", "native_task"
]
COMPONENTS = Literal["role_playing", "tools", "no_tools", "task"]
class Prompts(BaseModel):
@@ -38,10 +36,6 @@ class Prompts(BaseModel):
has_tools: bool = Field(
default=False, description="Indicates if the agent has access to tools"
)
use_native_tool_calling: bool = Field(
default=False,
description="Whether to use native function calling instead of ReAct format",
)
system_template: str | None = Field(
default=None, description="Custom system prompt template"
)
@@ -64,24 +58,12 @@ class Prompts(BaseModel):
A dictionary containing the constructed prompt(s).
"""
slices: list[COMPONENTS] = ["role_playing"]
# When using native tool calling with tools, use native_tools instructions
# When using ReAct pattern with tools, use tools instructions
# When no tools are available, use no_tools instructions
if self.has_tools:
if self.use_native_tool_calling:
slices.append("native_tools")
else:
slices.append("tools")
slices.append("tools")
else:
slices.append("no_tools")
system: str = self._build_prompt(slices)
# Use native_task for native tool calling (no "Thought:" prompt)
# Use task for ReAct pattern (includes "Thought:" prompt)
task_slice: COMPONENTS = (
"native_task" if self.use_native_tool_calling else "task"
)
slices.append(task_slice)
slices.append("task")
if (
not self.system_template
@@ -90,7 +72,7 @@ class Prompts(BaseModel):
):
return SystemPromptResult(
system=system,
user=self._build_prompt([task_slice]),
user=self._build_prompt(["task"]),
prompt=self._build_prompt(slices),
)
return StandardPromptResult(

View File

@@ -26,9 +26,13 @@ def mock_agent() -> MagicMock:
@pytest.fixture
def mock_task() -> MagicMock:
def mock_task(mock_context: MagicMock) -> MagicMock:
"""Create a mock Task."""
return MagicMock()
task = MagicMock()
task.id = mock_context.task_id
task.name = "Mock Task"
task.description = "Mock task description"
return task
@pytest.fixture
@@ -179,8 +183,8 @@ class TestExecute:
event = first_call[0][1]
assert event.type == "a2a_server_task_started"
assert event.a2a_task_id == mock_context.task_id
assert event.a2a_context_id == mock_context.context_id
assert event.task_id == mock_context.task_id
assert event.context_id == mock_context.context_id
@pytest.mark.asyncio
async def test_emits_completed_event(
@@ -201,7 +205,7 @@ class TestExecute:
event = second_call[0][1]
assert event.type == "a2a_server_task_completed"
assert event.a2a_task_id == mock_context.task_id
assert event.task_id == mock_context.task_id
assert event.result == "Task completed successfully"
@pytest.mark.asyncio
@@ -250,7 +254,7 @@ class TestExecute:
event = canceled_call[0][1]
assert event.type == "a2a_server_task_canceled"
assert event.a2a_task_id == mock_context.task_id
assert event.task_id == mock_context.task_id
class TestCancel:

View File

@@ -14,6 +14,16 @@ except ImportError:
A2A_SDK_INSTALLED = False
def _create_mock_agent_card(name: str = "Test", url: str = "http://test-endpoint.com/"):
"""Create a mock agent card with proper model_dump behavior."""
mock_card = MagicMock()
mock_card.name = name
mock_card.url = url
mock_card.model_dump.return_value = {"name": name, "url": url}
mock_card.model_dump_json.return_value = f'{{"name": "{name}", "url": "{url}"}}'
return mock_card
@pytest.mark.skipif(not A2A_SDK_INSTALLED, reason="Requires a2a-sdk to be installed")
def test_trust_remote_completion_status_true_returns_directly():
"""When trust_remote_completion_status=True and A2A returns completed, return result directly."""
@@ -44,8 +54,7 @@ def test_trust_remote_completion_status_true_returns_directly():
patch("crewai.a2a.wrapper.execute_a2a_delegation") as mock_execute,
patch("crewai.a2a.wrapper._fetch_agent_cards_concurrently") as mock_fetch,
):
mock_card = MagicMock()
mock_card.name = "Test"
mock_card = _create_mock_agent_card()
mock_fetch.return_value = ({"http://test-endpoint.com/": mock_card}, {})
# A2A returns completed
@@ -110,8 +119,7 @@ def test_trust_remote_completion_status_false_continues_conversation():
patch("crewai.a2a.wrapper.execute_a2a_delegation") as mock_execute,
patch("crewai.a2a.wrapper._fetch_agent_cards_concurrently") as mock_fetch,
):
mock_card = MagicMock()
mock_card.name = "Test"
mock_card = _create_mock_agent_card()
mock_fetch.return_value = ({"http://test-endpoint.com/": mock_card}, {})
# A2A returns completed

View File

@@ -1,479 +0,0 @@
"""Integration tests for native tool calling functionality.
These tests verify that agents can use native function calling
when the LLM supports it, across multiple providers.
"""
from __future__ import annotations
import os
from typing import Any
from unittest.mock import patch, MagicMock
import pytest
from pydantic import BaseModel, Field
from crewai import Agent, Crew, Task
from crewai.llm import LLM
from crewai.tools.base_tool import BaseTool
# Check for optional provider availability
try:
import anthropic
HAS_ANTHROPIC = True
except ImportError:
HAS_ANTHROPIC = False
try:
import google.genai
HAS_GOOGLE_GENAI = True
except ImportError:
HAS_GOOGLE_GENAI = False
try:
import boto3
HAS_BOTO3 = True
except ImportError:
HAS_BOTO3 = False
class CalculatorInput(BaseModel):
"""Input schema for calculator tool."""
expression: str = Field(description="Mathematical expression to evaluate")
class CalculatorTool(BaseTool):
"""A calculator tool that performs mathematical calculations."""
name: str = "calculator"
description: str = "Perform mathematical calculations. Use this for any math operations."
args_schema: type[BaseModel] = CalculatorInput
def _run(self, expression: str) -> str:
"""Execute the calculation."""
try:
# Safe evaluation for basic math
result = eval(expression) # noqa: S307
return f"The result of {expression} is {result}"
except Exception as e:
return f"Error calculating {expression}: {e}"
class WeatherInput(BaseModel):
"""Input schema for weather tool."""
location: str = Field(description="City name to get weather for")
class WeatherTool(BaseTool):
"""A mock weather tool for testing."""
name: str = "get_weather"
description: str = "Get the current weather for a location"
args_schema: type[BaseModel] = WeatherInput
def _run(self, location: str) -> str:
"""Get weather (mock implementation)."""
return f"The weather in {location} is sunny with a temperature of 72°F"
@pytest.fixture
def calculator_tool() -> CalculatorTool:
"""Create a calculator tool for testing."""
return CalculatorTool()
@pytest.fixture
def weather_tool() -> WeatherTool:
"""Create a weather tool for testing."""
return WeatherTool()
# =============================================================================
# OpenAI Provider Tests
# =============================================================================
class TestOpenAINativeToolCalling:
"""Tests for native tool calling with OpenAI models."""
@pytest.mark.vcr()
def test_openai_agent_with_native_tool_calling(
self, calculator_tool: CalculatorTool
) -> None:
"""Test OpenAI agent can use native tool calling."""
agent = Agent(
role="Math Assistant",
goal="Help users with mathematical calculations",
backstory="You are a helpful math assistant.",
tools=[calculator_tool],
llm=LLM(model="gpt-4o-mini"),
verbose=False,
max_iter=3,
)
task = Task(
description="Calculate what is 15 * 8",
expected_output="The result of the calculation",
agent=agent,
)
crew = Crew(agents=[agent], tasks=[task])
result = crew.kickoff()
assert result is not None
assert result.raw is not None
assert "120" in str(result.raw)
def test_openai_agent_kickoff_with_tools_mocked(
self, calculator_tool: CalculatorTool
) -> None:
"""Test OpenAI agent kickoff with mocked LLM call."""
llm = LLM(model="gpt-4o-mini")
with patch.object(llm, "call", return_value="The answer is 120.") as mock_call:
agent = Agent(
role="Math Assistant",
goal="Calculate math",
backstory="You calculate.",
tools=[calculator_tool],
llm=llm,
verbose=False,
)
task = Task(
description="Calculate 15 * 8",
expected_output="Result",
agent=agent,
)
crew = Crew(agents=[agent], tasks=[task])
result = crew.kickoff()
assert mock_call.called
assert result is not None
# =============================================================================
# Anthropic Provider Tests
# =============================================================================
@pytest.mark.skipif(not HAS_ANTHROPIC, reason="anthropic package not installed")
class TestAnthropicNativeToolCalling:
"""Tests for native tool calling with Anthropic models."""
@pytest.fixture(autouse=True)
def mock_anthropic_api_key(self):
"""Mock ANTHROPIC_API_KEY for tests."""
if "ANTHROPIC_API_KEY" not in os.environ:
with patch.dict(os.environ, {"ANTHROPIC_API_KEY": "test-key"}):
yield
else:
yield
@pytest.mark.vcr()
def test_anthropic_agent_with_native_tool_calling(
self, calculator_tool: CalculatorTool
) -> None:
"""Test Anthropic agent can use native tool calling."""
agent = Agent(
role="Math Assistant",
goal="Help users with mathematical calculations",
backstory="You are a helpful math assistant.",
tools=[calculator_tool],
llm=LLM(model="anthropic/claude-3-5-haiku-20241022"),
verbose=False,
max_iter=3,
)
task = Task(
description="Calculate what is 15 * 8",
expected_output="The result of the calculation",
agent=agent,
)
crew = Crew(agents=[agent], tasks=[task])
result = crew.kickoff()
assert result is not None
assert result.raw is not None
def test_anthropic_agent_kickoff_with_tools_mocked(
self, calculator_tool: CalculatorTool
) -> None:
"""Test Anthropic agent kickoff with mocked LLM call."""
llm = LLM(model="anthropic/claude-3-5-haiku-20241022")
with patch.object(llm, "call", return_value="The answer is 120.") as mock_call:
agent = Agent(
role="Math Assistant",
goal="Calculate math",
backstory="You calculate.",
tools=[calculator_tool],
llm=llm,
verbose=False,
)
task = Task(
description="Calculate 15 * 8",
expected_output="Result",
agent=agent,
)
crew = Crew(agents=[agent], tasks=[task])
result = crew.kickoff()
assert mock_call.called
assert result is not None
# =============================================================================
# Google/Gemini Provider Tests
# =============================================================================
@pytest.mark.skipif(not HAS_GOOGLE_GENAI, reason="google-genai package not installed")
class TestGeminiNativeToolCalling:
"""Tests for native tool calling with Gemini models."""
@pytest.fixture(autouse=True)
def mock_google_api_key(self):
"""Mock GOOGLE_API_KEY for tests."""
with patch.dict(os.environ, {"GOOGLE_API_KEY": "test-key"}):
yield
@pytest.mark.vcr()
def test_gemini_agent_with_native_tool_calling(
self, calculator_tool: CalculatorTool
) -> None:
"""Test Gemini agent can use native tool calling."""
agent = Agent(
role="Math Assistant",
goal="Help users with mathematical calculations",
backstory="You are a helpful math assistant.",
tools=[calculator_tool],
llm=LLM(model="gemini/gemini-2.0-flash-001"),
verbose=False,
max_iter=3,
)
task = Task(
description="Calculate what is 15 * 8",
expected_output="The result of the calculation",
agent=agent,
)
crew = Crew(agents=[agent], tasks=[task])
result = crew.kickoff()
assert result is not None
assert result.raw is not None
def test_gemini_agent_kickoff_with_tools_mocked(
self, calculator_tool: CalculatorTool
) -> None:
"""Test Gemini agent kickoff with mocked LLM call."""
llm = LLM(model="gemini/gemini-2.0-flash-001")
with patch.object(llm, "call", return_value="The answer is 120.") as mock_call:
agent = Agent(
role="Math Assistant",
goal="Calculate math",
backstory="You calculate.",
tools=[calculator_tool],
llm=llm,
verbose=False,
)
task = Task(
description="Calculate 15 * 8",
expected_output="Result",
agent=agent,
)
crew = Crew(agents=[agent], tasks=[task])
result = crew.kickoff()
assert mock_call.called
assert result is not None
# =============================================================================
# Azure Provider Tests
# =============================================================================
class TestAzureNativeToolCalling:
"""Tests for native tool calling with Azure OpenAI models."""
@pytest.fixture(autouse=True)
def mock_azure_env(self):
"""Mock Azure environment variables for tests."""
env_vars = {
"AZURE_API_KEY": "test-key",
"AZURE_API_BASE": "https://test.openai.azure.com",
"AZURE_API_VERSION": "2024-02-15-preview",
}
with patch.dict(os.environ, env_vars):
yield
def test_azure_agent_kickoff_with_tools_mocked(
self, calculator_tool: CalculatorTool
) -> None:
"""Test Azure agent kickoff with mocked LLM call."""
llm = LLM(
model="azure/gpt-4o-mini",
api_key="test-key",
base_url="https://test.openai.azure.com",
)
with patch.object(llm, "call", return_value="The answer is 120.") as mock_call:
agent = Agent(
role="Math Assistant",
goal="Calculate math",
backstory="You calculate.",
tools=[calculator_tool],
llm=llm,
verbose=False,
)
task = Task(
description="Calculate 15 * 8",
expected_output="Result",
agent=agent,
)
crew = Crew(agents=[agent], tasks=[task])
result = crew.kickoff()
assert mock_call.called
assert result is not None
# =============================================================================
# Bedrock Provider Tests
# =============================================================================
@pytest.mark.skipif(not HAS_BOTO3, reason="boto3 package not installed")
class TestBedrockNativeToolCalling:
"""Tests for native tool calling with AWS Bedrock models."""
@pytest.fixture(autouse=True)
def mock_aws_env(self):
"""Mock AWS environment variables for tests."""
env_vars = {
"AWS_ACCESS_KEY_ID": "test-key",
"AWS_SECRET_ACCESS_KEY": "test-secret",
"AWS_REGION": "us-east-1",
}
with patch.dict(os.environ, env_vars):
yield
def test_bedrock_agent_kickoff_with_tools_mocked(
self, calculator_tool: CalculatorTool
) -> None:
"""Test Bedrock agent kickoff with mocked LLM call."""
llm = LLM(model="bedrock/anthropic.claude-3-haiku-20240307-v1:0")
with patch.object(llm, "call", return_value="The answer is 120.") as mock_call:
agent = Agent(
role="Math Assistant",
goal="Calculate math",
backstory="You calculate.",
tools=[calculator_tool],
llm=llm,
verbose=False,
)
task = Task(
description="Calculate 15 * 8",
expected_output="Result",
agent=agent,
)
crew = Crew(agents=[agent], tasks=[task])
result = crew.kickoff()
assert mock_call.called
assert result is not None
# =============================================================================
# Cross-Provider Native Tool Calling Behavior Tests
# =============================================================================
class TestNativeToolCallingBehavior:
"""Tests for native tool calling behavior across providers."""
def test_supports_function_calling_check(self) -> None:
"""Test that supports_function_calling() is properly checked."""
# OpenAI should support function calling
openai_llm = LLM(model="gpt-4o-mini")
assert hasattr(openai_llm, "supports_function_calling")
assert openai_llm.supports_function_calling() is True
@pytest.mark.skipif(not HAS_ANTHROPIC, reason="anthropic package not installed")
def test_anthropic_supports_function_calling(self) -> None:
"""Test that Anthropic models support function calling."""
with patch.dict(os.environ, {"ANTHROPIC_API_KEY": "test-key"}):
llm = LLM(model="anthropic/claude-3-5-haiku-20241022")
assert hasattr(llm, "supports_function_calling")
assert llm.supports_function_calling() is True
@pytest.mark.skipif(not HAS_GOOGLE_GENAI, reason="google-genai package not installed")
def test_gemini_supports_function_calling(self) -> None:
"""Test that Gemini models support function calling."""
# with patch.dict(os.environ, {"GOOGLE_API_KEY": "test-key"}):
print("GOOGLE_API_KEY", os.getenv("GOOGLE_API_KEY"))
llm = LLM(model="gemini/gemini-2.5-flash")
assert hasattr(llm, "supports_function_calling")
# Gemini uses supports_tools property
assert llm.supports_function_calling() is True
# =============================================================================
# Token Usage Tests
# =============================================================================
class TestNativeToolCallingTokenUsage:
"""Tests for token usage with native tool calling."""
@pytest.mark.vcr()
def test_openai_native_tool_calling_token_usage(
self, calculator_tool: CalculatorTool
) -> None:
"""Test token usage tracking with OpenAI native tool calling."""
agent = Agent(
role="Calculator",
goal="Perform calculations efficiently",
backstory="You calculate things.",
tools=[calculator_tool],
llm=LLM(model="gpt-4o-mini"),
verbose=False,
max_iter=3,
)
task = Task(
description="What is 100 / 4?",
expected_output="The result",
agent=agent,
)
crew = Crew(agents=[agent], tasks=[task])
result = crew.kickoff()
assert result is not None
assert result.token_usage is not None
assert result.token_usage.total_tokens > 0
assert result.token_usage.successful_requests >= 1
print(f"\n[OPENAI NATIVE TOOL CALLING TOKEN USAGE]")
print(f" Prompt tokens: {result.token_usage.prompt_tokens}")
print(f" Completion tokens: {result.token_usage.completion_tokens}")
print(f" Total tokens: {result.token_usage.total_tokens}")

View File

@@ -1,214 +0,0 @@
"""Tests for agent utility functions."""
from __future__ import annotations
from typing import Any
import pytest
from pydantic import BaseModel, Field
from crewai.tools.base_tool import BaseTool
from crewai.utilities.agent_utils import convert_tools_to_openai_schema
class CalculatorInput(BaseModel):
"""Input schema for calculator tool."""
expression: str = Field(description="Mathematical expression to evaluate")
class CalculatorTool(BaseTool):
"""A simple calculator tool for testing."""
name: str = "calculator"
description: str = "Perform mathematical calculations"
args_schema: type[BaseModel] = CalculatorInput
def _run(self, expression: str) -> str:
"""Execute the calculation."""
try:
result = eval(expression) # noqa: S307
return str(result)
except Exception as e:
return f"Error: {e}"
class SearchInput(BaseModel):
"""Input schema for search tool."""
query: str = Field(description="Search query")
max_results: int = Field(default=10, description="Maximum number of results")
class SearchTool(BaseTool):
"""A search tool for testing."""
name: str = "web_search"
description: str = "Search the web for information"
args_schema: type[BaseModel] = SearchInput
def _run(self, query: str, max_results: int = 10) -> str:
"""Execute the search."""
return f"Search results for '{query}' (max {max_results})"
class NoSchemaTool(BaseTool):
"""A tool without an args schema for testing edge cases."""
name: str = "simple_tool"
description: str = "A simple tool with no schema"
def _run(self, **kwargs: Any) -> str:
"""Execute the tool."""
return "Simple tool executed"
class TestConvertToolsToOpenaiSchema:
"""Tests for convert_tools_to_openai_schema function."""
def test_converts_single_tool(self) -> None:
"""Test converting a single tool to OpenAI schema."""
tools = [CalculatorTool()]
schemas, functions = convert_tools_to_openai_schema(tools)
assert len(schemas) == 1
assert len(functions) == 1
schema = schemas[0]
assert schema["type"] == "function"
assert schema["function"]["name"] == "calculator"
assert schema["function"]["description"] == "Perform mathematical calculations"
assert "properties" in schema["function"]["parameters"]
assert "expression" in schema["function"]["parameters"]["properties"]
def test_converts_multiple_tools(self) -> None:
"""Test converting multiple tools to OpenAI schema."""
tools = [CalculatorTool(), SearchTool()]
schemas, functions = convert_tools_to_openai_schema(tools)
assert len(schemas) == 2
assert len(functions) == 2
# Check calculator
calc_schema = next(s for s in schemas if s["function"]["name"] == "calculator")
assert calc_schema["function"]["description"] == "Perform mathematical calculations"
# Check search
search_schema = next(s for s in schemas if s["function"]["name"] == "web_search")
assert search_schema["function"]["description"] == "Search the web for information"
assert "query" in search_schema["function"]["parameters"]["properties"]
assert "max_results" in search_schema["function"]["parameters"]["properties"]
def test_functions_dict_contains_callables(self) -> None:
"""Test that the functions dict maps names to callable run methods."""
tools = [CalculatorTool(), SearchTool()]
schemas, functions = convert_tools_to_openai_schema(tools)
assert "calculator" in functions
assert "web_search" in functions
assert callable(functions["calculator"])
assert callable(functions["web_search"])
def test_function_can_be_called(self) -> None:
"""Test that the returned function can be called."""
tools = [CalculatorTool()]
schemas, functions = convert_tools_to_openai_schema(tools)
result = functions["calculator"](expression="2 + 2")
assert result == "4"
def test_empty_tools_list(self) -> None:
"""Test with an empty tools list."""
schemas, functions = convert_tools_to_openai_schema([])
assert schemas == []
assert functions == {}
def test_schema_has_required_fields(self) -> None:
"""Test that the schema includes required fields information."""
tools = [SearchTool()]
schemas, functions = convert_tools_to_openai_schema(tools)
schema = schemas[0]
params = schema["function"]["parameters"]
# Should have required array
assert "required" in params
assert "query" in params["required"]
def test_tool_without_args_schema(self) -> None:
"""Test converting a tool that doesn't have an args_schema."""
# Create a minimal tool without args_schema
class MinimalTool(BaseTool):
name: str = "minimal"
description: str = "A minimal tool"
def _run(self) -> str:
return "done"
tools = [MinimalTool()]
schemas, functions = convert_tools_to_openai_schema(tools)
assert len(schemas) == 1
schema = schemas[0]
assert schema["function"]["name"] == "minimal"
# Parameters should be empty dict or have minimal schema
assert isinstance(schema["function"]["parameters"], dict)
def test_schema_structure_matches_openai_format(self) -> None:
"""Test that the schema structure matches OpenAI's expected format."""
tools = [CalculatorTool()]
schemas, functions = convert_tools_to_openai_schema(tools)
schema = schemas[0]
# Top level must have "type": "function"
assert schema["type"] == "function"
# Must have "function" key with nested structure
assert "function" in schema
func = schema["function"]
# Function must have name and description
assert "name" in func
assert "description" in func
assert isinstance(func["name"], str)
assert isinstance(func["description"], str)
# Parameters should be a valid JSON schema
assert "parameters" in func
params = func["parameters"]
assert isinstance(params, dict)
def test_removes_redundant_schema_fields(self) -> None:
"""Test that redundant title and description are removed from parameters."""
tools = [CalculatorTool()]
schemas, functions = convert_tools_to_openai_schema(tools)
params = schemas[0]["function"]["parameters"]
# Title should be removed as it's redundant with function name
assert "title" not in params
def test_preserves_field_descriptions(self) -> None:
"""Test that field descriptions are preserved in the schema."""
tools = [SearchTool()]
schemas, functions = convert_tools_to_openai_schema(tools)
params = schemas[0]["function"]["parameters"]
query_prop = params["properties"]["query"]
# Field description should be preserved
assert "description" in query_prop
assert query_prop["description"] == "Search query"
def test_preserves_default_values(self) -> None:
"""Test that default values are preserved in the schema."""
tools = [SearchTool()]
schemas, functions = convert_tools_to_openai_schema(tools)
params = schemas[0]["function"]["parameters"]
max_results_prop = params["properties"]["max_results"]
# Default value should be preserved
assert "default" in max_results_prop
assert max_results_prop["default"] == 10
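For reference, this is roughly the schema shape these deleted tests assert on for CalculatorTool. The literal below is illustrative only: the real dict is produced by convert_tools_to_openai_schema, and the inner "type" keys are assumed from pydantic's JSON schema output rather than asserted by the tests.
# Assumed shape of convert_tools_to_openai_schema([CalculatorTool()])[0][0].
# The "type": "object" and "type": "string" entries are an assumption from
# pydantic's JSON schema; the tests above only check name, description,
# properties, required, defaults, and the removal of the redundant title.
calculator_schema = {
    "type": "function",
    "function": {
        "name": "calculator",
        "description": "Perform mathematical calculations",
        "parameters": {
            "type": "object",
            "properties": {
                "expression": {
                    "type": "string",
                    "description": "Mathematical expression to evaluate",
                },
            },
            "required": ["expression"],
        },
    },
}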

uv.lock (generated, 8507 changed lines)

File diff suppressed because it is too large.