Compare commits

...

3 Commits

Author SHA1 Message Date
Devin AI
50d2e4c1b0 fix: inject MCP tools during delegation (fixes #4571)
When an agent with MCP servers configured is used as a sub-agent via
delegation, its MCP tools were not loaded because the Crew's
_prepare_tools() is not called for the delegated-to agent.

This fix adds _inject_mcp_tools() to agent/utils.py and calls it from
prepare_tools(), which is invoked by both execute_task() and
aexecute_task(). MCP tools are now loaded on-demand when the agent has
mcps configured, with deduplication and graceful error handling.

Also adds 11 tests covering:
- MCP tool injection with/without mcps
- Deduplication of existing tools
- Graceful failure handling
- prepare_tools integration
- Full delegation flow

Co-Authored-By: João <joao@crewai.com>
2026-02-23 15:49:58 +00:00
Greyson LaLonde
51754899a2 feat: migrate CLI http client from requests to httpx
Some checks failed
Build uv cache / build-cache (3.10) (push) Has been cancelled
Build uv cache / build-cache (3.11) (push) Has been cancelled
Build uv cache / build-cache (3.12) (push) Has been cancelled
Build uv cache / build-cache (3.13) (push) Has been cancelled
CodeQL Advanced / Analyze (actions) (push) Has been cancelled
CodeQL Advanced / Analyze (python) (push) Has been cancelled
Mark stale issues and pull requests / stale (push) Has been cancelled
2026-02-20 18:21:05 -05:00
Greyson LaLonde
71b4f8402a fix: ensure callbacks are ran/awaited if promise
Some checks failed
CodeQL Advanced / Analyze (actions) (push) Has been cancelled
CodeQL Advanced / Analyze (python) (push) Has been cancelled
Build uv cache / build-cache (3.12) (push) Has been cancelled
Build uv cache / build-cache (3.13) (push) Has been cancelled
Build uv cache / build-cache (3.10) (push) Has been cancelled
Build uv cache / build-cache (3.11) (push) Has been cancelled
2026-02-20 13:15:50 -05:00
22 changed files with 495 additions and 148 deletions

View File

@@ -38,6 +38,7 @@ dependencies = [
"json5~=0.10.0",
"portalocker~=2.7.0",
"pydantic-settings~=2.10.1",
"httpx~=0.28.1",
"mcp~=1.26.0",
"uv~=0.9.13",
"aiosqlite~=0.21.0",

View File

@@ -273,6 +273,46 @@ def save_last_messages(agent: Agent) -> None:
agent._last_messages = sanitized_messages
def _inject_mcp_tools(agent: Agent, tools: list[BaseTool]) -> list[BaseTool]:
"""Inject MCP tools into the tools list if the agent has MCP servers configured.
This ensures MCP tools are available even when the agent is invoked
outside the normal Crew task-execution flow (e.g. via delegation).
Args:
agent: The agent instance that may have MCP servers configured.
tools: Current list of tools.
Returns:
Updated list of tools with MCP tools added (if any).
"""
mcps = getattr(agent, "mcps", None)
if not mcps:
return tools
if not hasattr(agent, "get_mcp_tools"):
return tools
try:
mcp_tools = agent.get_mcp_tools(mcps=mcps)
if mcp_tools:
# Merge without duplicates based on tool name
existing_names = {tool.name for tool in tools}
for tool in mcp_tools:
if tool.name not in existing_names:
tools.append(tool)
existing_names.add(tool.name)
except Exception:
# Log but don't fail task execution if MCP tool loading fails
agent._logger.log(
"warning",
"Failed to load MCP tools during task execution",
color="yellow",
)
return tools
def prepare_tools(
agent: Agent, tools: list[BaseTool] | None, task: Task
) -> list[BaseTool]:
@@ -286,7 +326,13 @@ def prepare_tools(
Returns:
The list of tools to use.
"""
final_tools = tools or agent.tools or []
final_tools = list(tools or agent.tools or [])
# Inject MCP tools when the agent has MCP servers configured.
# This is needed for delegation scenarios where the Crew's
# _prepare_tools() is not called for the delegated-to agent.
final_tools = _inject_mcp_tools(agent, final_tools)
agent.create_agent_executor(tools=final_tools, task=task)
return final_tools

View File

@@ -6,8 +6,10 @@ and memory management.
from __future__ import annotations
import asyncio
from collections.abc import Callable
from concurrent.futures import ThreadPoolExecutor, as_completed
import inspect
import logging
from typing import TYPE_CHECKING, Any, Literal, cast
@@ -736,7 +738,9 @@ class CrewAgentExecutor(CrewAgentExecutorMixin):
] = []
for call_id, func_name, func_args in parsed_calls:
original_tool = original_tools_by_name.get(func_name)
execution_plan.append((call_id, func_name, func_args, original_tool))
execution_plan.append(
(call_id, func_name, func_args, original_tool)
)
self._append_assistant_tool_calls_message(
[
@@ -746,7 +750,9 @@ class CrewAgentExecutor(CrewAgentExecutorMixin):
)
max_workers = min(8, len(execution_plan))
ordered_results: list[dict[str, Any] | None] = [None] * len(execution_plan)
ordered_results: list[dict[str, Any] | None] = [None] * len(
execution_plan
)
with ThreadPoolExecutor(max_workers=max_workers) as pool:
futures = {
pool.submit(
@@ -803,7 +809,7 @@ class CrewAgentExecutor(CrewAgentExecutorMixin):
return tool_finish
reasoning_prompt = self._i18n.slice("post_tool_reasoning")
reasoning_message: LLMMessage = {
reasoning_message = {
"role": "user",
"content": reasoning_prompt,
}
@@ -908,9 +914,9 @@ class CrewAgentExecutor(CrewAgentExecutorMixin):
elif (
should_execute
and original_tool
and getattr(original_tool, "max_usage_count", None) is not None
and getattr(original_tool, "current_usage_count", 0)
>= original_tool.max_usage_count
and (max_count := getattr(original_tool, "max_usage_count", None))
is not None
and getattr(original_tool, "current_usage_count", 0) >= max_count
):
max_usage_reached = True
@@ -989,13 +995,17 @@ class CrewAgentExecutor(CrewAgentExecutorMixin):
and hasattr(original_tool, "cache_function")
and callable(original_tool.cache_function)
):
should_cache = original_tool.cache_function(args_dict, raw_result)
should_cache = original_tool.cache_function(
args_dict, raw_result
)
if should_cache:
self.tools_handler.cache.add(
tool=func_name, input=input_str, output=raw_result
)
result = str(raw_result) if not isinstance(raw_result, str) else raw_result
result = (
str(raw_result) if not isinstance(raw_result, str) else raw_result
)
except Exception as e:
result = f"Error executing tool: {e}"
if self.task:
@@ -1490,7 +1500,9 @@ class CrewAgentExecutor(CrewAgentExecutorMixin):
formatted_answer: Current agent response.
"""
if self.step_callback:
self.step_callback(formatted_answer)
cb_result = self.step_callback(formatted_answer)
if inspect.iscoroutine(cb_result):
asyncio.run(cb_result)
def _append_message(
self, text: str, role: Literal["user", "assistant", "system"] = "assistant"

View File

@@ -2,8 +2,8 @@ import time
from typing import TYPE_CHECKING, Any, TypeVar, cast
import webbrowser
import httpx
from pydantic import BaseModel, Field
import requests
from rich.console import Console
from crewai.cli.authentication.utils import validate_jwt_token
@@ -98,7 +98,7 @@ class AuthenticationCommand:
"scope": " ".join(self.oauth2_provider.get_oauth_scopes()),
"audience": self.oauth2_provider.get_audience(),
}
response = requests.post(
response = httpx.post(
url=self.oauth2_provider.get_authorize_url(),
data=device_code_payload,
timeout=20,
@@ -130,7 +130,7 @@ class AuthenticationCommand:
attempts = 0
while True and attempts < 10:
response = requests.post(
response = httpx.post(
self.oauth2_provider.get_token_url(), data=token_payload, timeout=30
)
token_data = response.json()
@@ -149,7 +149,7 @@ class AuthenticationCommand:
return
if token_data["error"] not in ("authorization_pending", "slow_down"):
raise requests.HTTPError(
raise httpx.HTTPError(
token_data.get("error_description") or token_data.get("error")
)

View File

@@ -1,5 +1,6 @@
import requests
from requests.exceptions import JSONDecodeError
import json
import httpx
from rich.console import Console
from crewai.cli.authentication.token import get_auth_token
@@ -30,16 +31,16 @@ class PlusAPIMixin:
console.print("Run 'crewai login' to sign up/login.", style="bold green")
raise SystemExit from None
def _validate_response(self, response: requests.Response) -> None:
def _validate_response(self, response: httpx.Response) -> None:
"""
Handle and display error messages from API responses.
Args:
response (requests.Response): The response from the Plus API
response (httpx.Response): The response from the Plus API
"""
try:
json_response = response.json()
except (JSONDecodeError, ValueError):
except (json.JSONDecodeError, ValueError):
console.print(
"Failed to parse response from Enterprise API failed. Details:",
style="bold red",
@@ -62,7 +63,7 @@ class PlusAPIMixin:
)
raise SystemExit
if not response.ok:
if not response.is_success:
console.print(
"Request to Enterprise API failed. Details:", style="bold red"
)

View File

@@ -1,7 +1,7 @@
import json
from typing import Any, cast
import requests
from requests.exceptions import JSONDecodeError, RequestException
import httpx
from rich.console import Console
from crewai.cli.authentication.main import Oauth2Settings, ProviderFactory
@@ -47,12 +47,12 @@ class EnterpriseConfigureCommand(BaseCommand):
"User-Agent": f"CrewAI-CLI/{get_crewai_version()}",
"X-Crewai-Version": get_crewai_version(),
}
response = requests.get(oauth_endpoint, timeout=30, headers=headers)
response = httpx.get(oauth_endpoint, timeout=30, headers=headers)
response.raise_for_status()
try:
oauth_config = response.json()
except JSONDecodeError as e:
except json.JSONDecodeError as e:
raise ValueError(f"Invalid JSON response from {oauth_endpoint}") from e
self._validate_oauth_config(oauth_config)
@@ -62,7 +62,7 @@ class EnterpriseConfigureCommand(BaseCommand):
)
return cast(dict[str, Any], oauth_config)
except RequestException as e:
except httpx.HTTPError as e:
raise ValueError(f"Failed to connect to enterprise URL: {e!s}") from e
except Exception as e:
raise ValueError(f"Error fetching OAuth2 configuration: {e!s}") from e

View File

@@ -1,4 +1,4 @@
from requests import HTTPError
from httpx import HTTPStatusError
from rich.console import Console
from rich.table import Table
@@ -10,11 +10,11 @@ console = Console()
class OrganizationCommand(BaseCommand, PlusAPIMixin):
def __init__(self):
def __init__(self) -> None:
BaseCommand.__init__(self)
PlusAPIMixin.__init__(self, telemetry=self._telemetry)
def list(self):
def list(self) -> None:
try:
response = self.plus_api_client.get_organizations()
response.raise_for_status()
@@ -33,7 +33,7 @@ class OrganizationCommand(BaseCommand, PlusAPIMixin):
table.add_row(org["name"], org["uuid"])
console.print(table)
except HTTPError as e:
except HTTPStatusError as e:
if e.response.status_code == 401:
console.print(
"You are not logged in to any organization. Use 'crewai login' to login.",
@@ -50,7 +50,7 @@ class OrganizationCommand(BaseCommand, PlusAPIMixin):
)
raise SystemExit(1) from e
def switch(self, org_id):
def switch(self, org_id: str) -> None:
try:
response = self.plus_api_client.get_organizations()
response.raise_for_status()
@@ -72,7 +72,7 @@ class OrganizationCommand(BaseCommand, PlusAPIMixin):
f"Successfully switched to {org['name']} ({org['uuid']})",
style="bold green",
)
except HTTPError as e:
except HTTPStatusError as e:
if e.response.status_code == 401:
console.print(
"You are not logged in to any organization. Use 'crewai login' to login.",
@@ -87,7 +87,7 @@ class OrganizationCommand(BaseCommand, PlusAPIMixin):
console.print(f"Failed to switch organization: {e!s}", style="bold red")
raise SystemExit(1) from e
def current(self):
def current(self) -> None:
settings = Settings()
if settings.org_uuid:
console.print(

View File

@@ -3,7 +3,6 @@ from typing import Any
from urllib.parse import urljoin
import httpx
import requests
from crewai.cli.config import Settings
from crewai.cli.constants import DEFAULT_CREWAI_ENTERPRISE_URL
@@ -43,16 +42,16 @@ class PlusAPI:
def _make_request(
self, method: str, endpoint: str, **kwargs: Any
) -> requests.Response:
) -> httpx.Response:
url = urljoin(self.base_url, endpoint)
session = requests.Session()
session.trust_env = False
return session.request(method, url, headers=self.headers, **kwargs)
verify = kwargs.pop("verify", True)
with httpx.Client(trust_env=False, verify=verify) as client:
return client.request(method, url, headers=self.headers, **kwargs)
def login_to_tool_repository(self) -> requests.Response:
def login_to_tool_repository(self) -> httpx.Response:
return self._make_request("POST", f"{self.TOOLS_RESOURCE}/login")
def get_tool(self, handle: str) -> requests.Response:
def get_tool(self, handle: str) -> httpx.Response:
return self._make_request("GET", f"{self.TOOLS_RESOURCE}/{handle}")
async def get_agent(self, handle: str) -> httpx.Response:
@@ -68,7 +67,7 @@ class PlusAPI:
description: str | None,
encoded_file: str,
available_exports: list[dict[str, Any]] | None = None,
) -> requests.Response:
) -> httpx.Response:
params = {
"handle": handle,
"public": is_public,
@@ -79,54 +78,52 @@ class PlusAPI:
}
return self._make_request("POST", f"{self.TOOLS_RESOURCE}", json=params)
def deploy_by_name(self, project_name: str) -> requests.Response:
def deploy_by_name(self, project_name: str) -> httpx.Response:
return self._make_request(
"POST", f"{self.CREWS_RESOURCE}/by-name/{project_name}/deploy"
)
def deploy_by_uuid(self, uuid: str) -> requests.Response:
def deploy_by_uuid(self, uuid: str) -> httpx.Response:
return self._make_request("POST", f"{self.CREWS_RESOURCE}/{uuid}/deploy")
def crew_status_by_name(self, project_name: str) -> requests.Response:
def crew_status_by_name(self, project_name: str) -> httpx.Response:
return self._make_request(
"GET", f"{self.CREWS_RESOURCE}/by-name/{project_name}/status"
)
def crew_status_by_uuid(self, uuid: str) -> requests.Response:
def crew_status_by_uuid(self, uuid: str) -> httpx.Response:
return self._make_request("GET", f"{self.CREWS_RESOURCE}/{uuid}/status")
def crew_by_name(
self, project_name: str, log_type: str = "deployment"
) -> requests.Response:
) -> httpx.Response:
return self._make_request(
"GET", f"{self.CREWS_RESOURCE}/by-name/{project_name}/logs/{log_type}"
)
def crew_by_uuid(
self, uuid: str, log_type: str = "deployment"
) -> requests.Response:
def crew_by_uuid(self, uuid: str, log_type: str = "deployment") -> httpx.Response:
return self._make_request(
"GET", f"{self.CREWS_RESOURCE}/{uuid}/logs/{log_type}"
)
def delete_crew_by_name(self, project_name: str) -> requests.Response:
def delete_crew_by_name(self, project_name: str) -> httpx.Response:
return self._make_request(
"DELETE", f"{self.CREWS_RESOURCE}/by-name/{project_name}"
)
def delete_crew_by_uuid(self, uuid: str) -> requests.Response:
def delete_crew_by_uuid(self, uuid: str) -> httpx.Response:
return self._make_request("DELETE", f"{self.CREWS_RESOURCE}/{uuid}")
def list_crews(self) -> requests.Response:
def list_crews(self) -> httpx.Response:
return self._make_request("GET", self.CREWS_RESOURCE)
def create_crew(self, payload: dict[str, Any]) -> requests.Response:
def create_crew(self, payload: dict[str, Any]) -> httpx.Response:
return self._make_request("POST", self.CREWS_RESOURCE, json=payload)
def get_organizations(self) -> requests.Response:
def get_organizations(self) -> httpx.Response:
return self._make_request("GET", self.ORGANIZATIONS_RESOURCE)
def initialize_trace_batch(self, payload: dict[str, Any]) -> requests.Response:
def initialize_trace_batch(self, payload: dict[str, Any]) -> httpx.Response:
return self._make_request(
"POST",
f"{self.TRACING_RESOURCE}/batches",
@@ -136,7 +133,7 @@ class PlusAPI:
def initialize_ephemeral_trace_batch(
self, payload: dict[str, Any]
) -> requests.Response:
) -> httpx.Response:
return self._make_request(
"POST",
f"{self.EPHEMERAL_TRACING_RESOURCE}/batches",
@@ -145,7 +142,7 @@ class PlusAPI:
def send_trace_events(
self, trace_batch_id: str, payload: dict[str, Any]
) -> requests.Response:
) -> httpx.Response:
return self._make_request(
"POST",
f"{self.TRACING_RESOURCE}/batches/{trace_batch_id}/events",
@@ -155,7 +152,7 @@ class PlusAPI:
def send_ephemeral_trace_events(
self, trace_batch_id: str, payload: dict[str, Any]
) -> requests.Response:
) -> httpx.Response:
return self._make_request(
"POST",
f"{self.EPHEMERAL_TRACING_RESOURCE}/batches/{trace_batch_id}/events",
@@ -165,7 +162,7 @@ class PlusAPI:
def finalize_trace_batch(
self, trace_batch_id: str, payload: dict[str, Any]
) -> requests.Response:
) -> httpx.Response:
return self._make_request(
"PATCH",
f"{self.TRACING_RESOURCE}/batches/{trace_batch_id}/finalize",
@@ -175,7 +172,7 @@ class PlusAPI:
def finalize_ephemeral_trace_batch(
self, trace_batch_id: str, payload: dict[str, Any]
) -> requests.Response:
) -> httpx.Response:
return self._make_request(
"PATCH",
f"{self.EPHEMERAL_TRACING_RESOURCE}/batches/{trace_batch_id}/finalize",
@@ -185,7 +182,7 @@ class PlusAPI:
def mark_trace_batch_as_failed(
self, trace_batch_id: str, error_message: str
) -> requests.Response:
) -> httpx.Response:
return self._make_request(
"PATCH",
f"{self.TRACING_RESOURCE}/batches/{trace_batch_id}",
@@ -193,13 +190,11 @@ class PlusAPI:
timeout=30,
)
def get_triggers(self) -> requests.Response:
def get_triggers(self) -> httpx.Response:
"""Get all available triggers from integrations."""
return self._make_request("GET", f"{self.INTEGRATIONS_RESOURCE}/apps")
def get_trigger_payload(
self, app_slug: str, trigger_slug: str
) -> requests.Response:
def get_trigger_payload(self, app_slug: str, trigger_slug: str) -> httpx.Response:
"""Get sample payload for a specific trigger."""
return self._make_request(
"GET", f"{self.INTEGRATIONS_RESOURCE}/{app_slug}/{trigger_slug}/payload"

View File

@@ -8,7 +8,7 @@ from typing import Any
import certifi
import click
import requests
import httpx
from crewai.cli.constants import JSON_URL, MODELS, PROVIDERS
@@ -165,20 +165,20 @@ def fetch_provider_data(cache_file: Path) -> dict[str, Any] | None:
ssl_config = os.environ["SSL_CERT_FILE"] = certifi.where()
try:
response = requests.get(JSON_URL, stream=True, timeout=60, verify=ssl_config)
response.raise_for_status()
data = download_data(response)
with open(cache_file, "w") as f:
json.dump(data, f)
return data
except requests.RequestException as e:
with httpx.stream("GET", JSON_URL, timeout=60, verify=ssl_config) as response:
response.raise_for_status()
data = download_data(response)
with open(cache_file, "w") as f:
json.dump(data, f)
return data
except httpx.HTTPError as e:
click.secho(f"Error fetching provider data: {e}", fg="red")
except json.JSONDecodeError:
click.secho("Error parsing provider data. Invalid JSON format.", fg="red")
return None
def download_data(response: requests.Response) -> dict[str, Any]:
def download_data(response: httpx.Response) -> dict[str, Any]:
"""Downloads data from a given HTTP response and returns the JSON content.
Args:
@@ -194,7 +194,7 @@ def download_data(response: requests.Response) -> dict[str, Any]:
with click.progressbar(
length=total_size, label="Downloading", show_pos=True
) as bar:
for chunk in response.iter_content(block_size):
for chunk in response.iter_bytes(block_size):
if chunk:
data_chunks.append(chunk)
bar.update(len(chunk))

View File

@@ -1,8 +1,10 @@
from __future__ import annotations
import asyncio
from collections.abc import Callable, Coroutine
from concurrent.futures import ThreadPoolExecutor, as_completed
from datetime import datetime
import inspect
import json
import threading
from typing import TYPE_CHECKING, Any, Literal, cast
@@ -778,7 +780,7 @@ class AgentExecutor(Flow[AgentReActState], CrewAgentExecutorMixin):
from_cache = cast(bool, execution_result["from_cache"])
original_tool = execution_result["original_tool"]
tool_message: LLMMessage = {
tool_message = {
"role": "tool",
"tool_call_id": call_id,
"name": func_name,
@@ -1358,7 +1360,9 @@ class AgentExecutor(Flow[AgentReActState], CrewAgentExecutorMixin):
formatted_answer: Current agent response.
"""
if self.step_callback:
self.step_callback(formatted_answer)
cb_result = self.step_callback(formatted_answer)
if inspect.iscoroutine(cb_result):
asyncio.run(cb_result)
def _append_message_to_state(
self, text: str, role: Literal["user", "assistant", "system"] = "assistant"

View File

@@ -1,5 +1,6 @@
from __future__ import annotations
import asyncio
from concurrent.futures import Future
from copy import copy as shallow_copy
import datetime
@@ -624,11 +625,15 @@ class Task(BaseModel):
self.end_time = datetime.datetime.now()
if self.callback:
self.callback(self.output)
cb_result = self.callback(self.output)
if inspect.isawaitable(cb_result):
await cb_result
crew = self.agent.crew # type: ignore[union-attr]
if crew and crew.task_callback and crew.task_callback != self.callback:
crew.task_callback(self.output)
cb_result = crew.task_callback(self.output)
if inspect.isawaitable(cb_result):
await cb_result
if self.output_file:
content = (
@@ -722,11 +727,15 @@ class Task(BaseModel):
self.end_time = datetime.datetime.now()
if self.callback:
self.callback(self.output)
cb_result = self.callback(self.output)
if inspect.iscoroutine(cb_result):
asyncio.run(cb_result)
crew = self.agent.crew # type: ignore[union-attr]
if crew and crew.task_callback and crew.task_callback != self.callback:
crew.task_callback(self.output)
cb_result = crew.task_callback(self.output)
if inspect.iscoroutine(cb_result):
asyncio.run(cb_result)
if self.output_file:
content = (

View File

@@ -3,6 +3,7 @@ from __future__ import annotations
import asyncio
from collections.abc import Callable, Sequence
import concurrent.futures
import inspect
import json
import re
from typing import TYPE_CHECKING, Any, Final, Literal, TypedDict
@@ -501,7 +502,9 @@ def handle_agent_action_core(
- TODO: Remove messages parameter and its usage.
"""
if step_callback:
step_callback(tool_result)
cb_result = step_callback(tool_result)
if inspect.iscoroutine(cb_result):
asyncio.run(cb_result)
formatted_answer.text += f"\nObservation: {tool_result.result}"
formatted_answer.result = tool_result.result

View File

@@ -2,7 +2,7 @@
import asyncio
from typing import Any
from unittest.mock import AsyncMock, MagicMock, patch
from unittest.mock import AsyncMock, MagicMock, Mock, patch
import pytest
@@ -291,6 +291,46 @@ class TestAsyncAgentExecutor:
assert max_concurrent > 1, f"Expected concurrent execution, max concurrent was {max_concurrent}"
class TestInvokeStepCallback:
"""Tests for _invoke_step_callback with sync and async callbacks."""
def test_invoke_step_callback_with_sync_callback(
self, executor: CrewAgentExecutor
) -> None:
"""Test that a sync step callback is called normally."""
callback = Mock()
executor.step_callback = callback
answer = AgentFinish(thought="thinking", output="test", text="final")
executor._invoke_step_callback(answer)
callback.assert_called_once_with(answer)
def test_invoke_step_callback_with_async_callback(
self, executor: CrewAgentExecutor
) -> None:
"""Test that an async step callback is awaited via asyncio.run."""
async_callback = AsyncMock()
executor.step_callback = async_callback
answer = AgentFinish(thought="thinking", output="test", text="final")
with patch("crewai.agents.crew_agent_executor.asyncio.run") as mock_run:
executor._invoke_step_callback(answer)
async_callback.assert_called_once_with(answer)
mock_run.assert_called_once()
def test_invoke_step_callback_with_none(
self, executor: CrewAgentExecutor
) -> None:
"""Test that no error is raised when step_callback is None."""
executor.step_callback = None
answer = AgentFinish(thought="thinking", output="test", text="final")
# Should not raise
executor._invoke_step_callback(answer)
class TestAsyncLLMResponseHelper:
"""Tests for aget_llm_response helper function."""

View File

@@ -2,7 +2,7 @@ from datetime import datetime, timedelta
from unittest.mock import MagicMock, call, patch
import pytest
import requests
import httpx
from crewai.cli.authentication.main import AuthenticationCommand
from crewai.cli.constants import (
CREWAI_ENTERPRISE_DEFAULT_OAUTH2_AUDIENCE,
@@ -220,7 +220,7 @@ class TestAuthenticationCommand:
]
mock_console_print.assert_has_calls(expected_calls)
@patch("requests.post")
@patch("crewai.cli.authentication.main.httpx.post")
def test_get_device_code(self, mock_post):
mock_response = MagicMock()
mock_response.json.return_value = {
@@ -256,7 +256,7 @@ class TestAuthenticationCommand:
"verification_uri_complete": "https://example.com/auth",
}
@patch("requests.post")
@patch("crewai.cli.authentication.main.httpx.post")
@patch("crewai.cli.authentication.main.console.print")
def test_poll_for_token_success(self, mock_console_print, mock_post):
mock_response_success = MagicMock()
@@ -305,7 +305,7 @@ class TestAuthenticationCommand:
]
mock_console_print.assert_has_calls(expected_calls)
@patch("requests.post")
@patch("crewai.cli.authentication.main.httpx.post")
@patch("crewai.cli.authentication.main.console.print")
def test_poll_for_token_timeout(self, mock_console_print, mock_post):
mock_response_pending = MagicMock()
@@ -324,7 +324,7 @@ class TestAuthenticationCommand:
"Timeout: Failed to get the token. Please try again.", style="bold red"
)
@patch("requests.post")
@patch("crewai.cli.authentication.main.httpx.post")
def test_poll_for_token_error(self, mock_post):
"""Test the method to poll for token (error path)."""
# Setup mock to return error
@@ -338,5 +338,5 @@ class TestAuthenticationCommand:
device_code_data = {"device_code": "test_device_code", "interval": 1}
with pytest.raises(requests.HTTPError):
with pytest.raises(httpx.HTTPError):
self.auth_command._poll_for_token(device_code_data)

View File

@@ -4,10 +4,11 @@ from io import StringIO
from unittest.mock import MagicMock, Mock, patch
import pytest
import requests
import json
import httpx
from crewai.cli.deploy.main import DeployCommand
from crewai.cli.utils import parse_toml
from requests.exceptions import JSONDecodeError
class TestDeployCommand(unittest.TestCase):
@@ -37,18 +38,18 @@ class TestDeployCommand(unittest.TestCase):
DeployCommand()
def test_validate_response_successful_response(self):
mock_response = Mock(spec=requests.Response)
mock_response = Mock(spec=httpx.Response)
mock_response.json.return_value = {"message": "Success"}
mock_response.status_code = 200
mock_response.ok = True
mock_response.is_success = True
with patch("sys.stdout", new=StringIO()) as fake_out:
self.deploy_command._validate_response(mock_response)
assert fake_out.getvalue() == ""
def test_validate_response_json_decode_error(self):
mock_response = Mock(spec=requests.Response)
mock_response.json.side_effect = JSONDecodeError("Decode error", "", 0)
mock_response = Mock(spec=httpx.Response)
mock_response.json.side_effect = json.JSONDecodeError("Decode error", "", 0)
mock_response.status_code = 500
mock_response.content = b"Invalid JSON"
@@ -64,13 +65,13 @@ class TestDeployCommand(unittest.TestCase):
assert "Response:\nInvalid JSON" in output
def test_validate_response_422_error(self):
mock_response = Mock(spec=requests.Response)
mock_response = Mock(spec=httpx.Response)
mock_response.json.return_value = {
"field1": ["Error message 1"],
"field2": ["Error message 2"],
}
mock_response.status_code = 422
mock_response.ok = False
mock_response.is_success = False
with patch("sys.stdout", new=StringIO()) as fake_out:
with pytest.raises(SystemExit):
@@ -84,10 +85,10 @@ class TestDeployCommand(unittest.TestCase):
assert "Field2 Error message 2" in output
def test_validate_response_other_error(self):
mock_response = Mock(spec=requests.Response)
mock_response = Mock(spec=httpx.Response)
mock_response.json.return_value = {"error": "Something went wrong"}
mock_response.status_code = 500
mock_response.ok = False
mock_response.is_success = False
with patch("sys.stdout", new=StringIO()) as fake_out:
with pytest.raises(SystemExit):

View File

@@ -3,8 +3,9 @@ import unittest
from pathlib import Path
from unittest.mock import Mock, patch
import requests
from requests.exceptions import JSONDecodeError
import json
import httpx
from crewai.cli.enterprise.main import EnterpriseConfigureCommand
from crewai.cli.settings.main import SettingsCommand
@@ -25,7 +26,7 @@ class TestEnterpriseConfigureCommand(unittest.TestCase):
def tearDown(self):
shutil.rmtree(self.test_dir)
@patch('crewai.cli.enterprise.main.requests.get')
@patch('crewai.cli.enterprise.main.httpx.get')
@patch('crewai.cli.enterprise.main.get_crewai_version')
def test_successful_configuration(self, mock_get_version, mock_requests_get):
mock_get_version.return_value = "1.0.0"
@@ -73,19 +74,23 @@ class TestEnterpriseConfigureCommand(unittest.TestCase):
self.assertEqual(call_args[0], key)
self.assertEqual(call_args[1], value)
@patch('crewai.cli.enterprise.main.requests.get')
@patch('crewai.cli.enterprise.main.httpx.get')
@patch('crewai.cli.enterprise.main.get_crewai_version')
def test_http_error_handling(self, mock_get_version, mock_requests_get):
mock_get_version.return_value = "1.0.0"
mock_response = Mock()
mock_response.raise_for_status.side_effect = requests.HTTPError("404 Not Found")
mock_response.raise_for_status.side_effect = httpx.HTTPStatusError(
"404 Not Found",
request=httpx.Request("GET", "http://test"),
response=httpx.Response(404),
)
mock_requests_get.return_value = mock_response
with self.assertRaises(SystemExit):
self.enterprise_command.configure("https://enterprise.example.com")
@patch('crewai.cli.enterprise.main.requests.get')
@patch('crewai.cli.enterprise.main.httpx.get')
@patch('crewai.cli.enterprise.main.get_crewai_version')
def test_invalid_json_response(self, mock_get_version, mock_requests_get):
mock_get_version.return_value = "1.0.0"
@@ -93,13 +98,13 @@ class TestEnterpriseConfigureCommand(unittest.TestCase):
mock_response = Mock()
mock_response.status_code = 200
mock_response.raise_for_status.return_value = None
mock_response.json.side_effect = JSONDecodeError("Invalid JSON", "", 0)
mock_response.json.side_effect = json.JSONDecodeError("Invalid JSON", "", 0)
mock_requests_get.return_value = mock_response
with self.assertRaises(SystemExit):
self.enterprise_command.configure("https://enterprise.example.com")
@patch('crewai.cli.enterprise.main.requests.get')
@patch('crewai.cli.enterprise.main.httpx.get')
@patch('crewai.cli.enterprise.main.get_crewai_version')
def test_missing_required_fields(self, mock_get_version, mock_requests_get):
mock_get_version.return_value = "1.0.0"
@@ -115,7 +120,7 @@ class TestEnterpriseConfigureCommand(unittest.TestCase):
with self.assertRaises(SystemExit):
self.enterprise_command.configure("https://enterprise.example.com")
@patch('crewai.cli.enterprise.main.requests.get')
@patch('crewai.cli.enterprise.main.httpx.get')
@patch('crewai.cli.enterprise.main.get_crewai_version')
def test_settings_update_error(self, mock_get_version, mock_requests_get):
mock_get_version.return_value = "1.0.0"

View File

@@ -3,7 +3,7 @@ from unittest.mock import MagicMock, patch, call
import pytest
from click.testing import CliRunner
import requests
import httpx
from crewai.cli.organization.main import OrganizationCommand
from crewai.cli.cli import org_list, switch, current
@@ -115,7 +115,7 @@ class TestOrganizationCommand(unittest.TestCase):
def test_list_organizations_api_error(self, mock_console):
self.org_command.plus_api_client = MagicMock()
self.org_command.plus_api_client.get_organizations.side_effect = (
requests.exceptions.RequestException("API Error")
httpx.HTTPError("API Error")
)
with pytest.raises(SystemExit):
@@ -201,8 +201,10 @@ class TestOrganizationCommand(unittest.TestCase):
@patch("crewai.cli.organization.main.console")
def test_list_organizations_unauthorized(self, mock_console):
mock_response = MagicMock()
mock_http_error = requests.exceptions.HTTPError(
"401 Client Error: Unauthorized", response=MagicMock(status_code=401)
mock_http_error = httpx.HTTPStatusError(
"401 Client Error: Unauthorized",
request=httpx.Request("GET", "http://test"),
response=httpx.Response(401),
)
mock_response.raise_for_status.side_effect = mock_http_error
@@ -219,8 +221,10 @@ class TestOrganizationCommand(unittest.TestCase):
@patch("crewai.cli.organization.main.console")
def test_switch_organization_unauthorized(self, mock_console):
mock_response = MagicMock()
mock_http_error = requests.exceptions.HTTPError(
"401 Client Error: Unauthorized", response=MagicMock(status_code=401)
mock_http_error = httpx.HTTPStatusError(
"401 Client Error: Unauthorized",
request=httpx.Request("GET", "http://test"),
response=httpx.Response(401),
)
mock_response.raise_for_status.side_effect = mock_http_error

View File

@@ -33,9 +33,9 @@ class TestPlusAPI(unittest.TestCase):
self.assertEqual(response, mock_response)
def assert_request_with_org_id(
self, mock_make_request, method: str, endpoint: str, **kwargs
self, mock_client_instance, method: str, endpoint: str, **kwargs
):
mock_make_request.assert_called_once_with(
mock_client_instance.request.assert_called_once_with(
method,
f"{os.getenv('CREWAI_PLUS_URL')}{endpoint}",
headers={
@@ -49,24 +49,25 @@ class TestPlusAPI(unittest.TestCase):
)
@patch("crewai.cli.plus_api.Settings")
@patch("requests.Session.request")
@patch("crewai.cli.plus_api.httpx.Client")
def test_login_to_tool_repository_with_org_uuid(
self, mock_make_request, mock_settings_class
self, mock_client_class, mock_settings_class
):
mock_settings = MagicMock()
mock_settings.org_uuid = self.org_uuid
mock_settings.enterprise_base_url = os.getenv('CREWAI_PLUS_URL')
mock_settings_class.return_value = mock_settings
# re-initialize Client
self.api = PlusAPI(self.api_key)
mock_client_instance = MagicMock()
mock_response = MagicMock()
mock_make_request.return_value = mock_response
mock_client_instance.request.return_value = mock_response
mock_client_class.return_value.__enter__.return_value = mock_client_instance
response = self.api.login_to_tool_repository()
self.assert_request_with_org_id(
mock_make_request, "POST", "/crewai_plus/api/v1/tools/login"
mock_client_instance, "POST", "/crewai_plus/api/v1/tools/login"
)
self.assertEqual(response, mock_response)
@@ -82,23 +83,23 @@ class TestPlusAPI(unittest.TestCase):
self.assertEqual(response, mock_response)
@patch("crewai.cli.plus_api.Settings")
@patch("requests.Session.request")
def test_get_tool_with_org_uuid(self, mock_make_request, mock_settings_class):
@patch("crewai.cli.plus_api.httpx.Client")
def test_get_tool_with_org_uuid(self, mock_client_class, mock_settings_class):
mock_settings = MagicMock()
mock_settings.org_uuid = self.org_uuid
mock_settings.enterprise_base_url = os.getenv('CREWAI_PLUS_URL')
mock_settings_class.return_value = mock_settings
# re-initialize Client
self.api = PlusAPI(self.api_key)
# Set up mock response
mock_client_instance = MagicMock()
mock_response = MagicMock()
mock_make_request.return_value = mock_response
mock_client_instance.request.return_value = mock_response
mock_client_class.return_value.__enter__.return_value = mock_client_instance
response = self.api.get_tool("test_tool_handle")
self.assert_request_with_org_id(
mock_make_request, "GET", "/crewai_plus/api/v1/tools/test_tool_handle"
mock_client_instance, "GET", "/crewai_plus/api/v1/tools/test_tool_handle"
)
self.assertEqual(response, mock_response)
@@ -130,18 +131,18 @@ class TestPlusAPI(unittest.TestCase):
self.assertEqual(response, mock_response)
@patch("crewai.cli.plus_api.Settings")
@patch("requests.Session.request")
def test_publish_tool_with_org_uuid(self, mock_make_request, mock_settings_class):
@patch("crewai.cli.plus_api.httpx.Client")
def test_publish_tool_with_org_uuid(self, mock_client_class, mock_settings_class):
mock_settings = MagicMock()
mock_settings.org_uuid = self.org_uuid
mock_settings.enterprise_base_url = os.getenv('CREWAI_PLUS_URL')
mock_settings_class.return_value = mock_settings
# re-initialize Client
self.api = PlusAPI(self.api_key)
# Set up mock response
mock_client_instance = MagicMock()
mock_response = MagicMock()
mock_make_request.return_value = mock_response
mock_client_instance.request.return_value = mock_response
mock_client_class.return_value.__enter__.return_value = mock_client_instance
handle = "test_tool_handle"
public = True
@@ -153,7 +154,6 @@ class TestPlusAPI(unittest.TestCase):
handle, public, version, description, encoded_file
)
# Expected params including organization_uuid
expected_params = {
"handle": handle,
"public": public,
@@ -164,7 +164,7 @@ class TestPlusAPI(unittest.TestCase):
}
self.assert_request_with_org_id(
mock_make_request, "POST", "/crewai_plus/api/v1/tools", json=expected_params
mock_client_instance, "POST", "/crewai_plus/api/v1/tools", json=expected_params
)
self.assertEqual(response, mock_response)
@@ -195,20 +195,19 @@ class TestPlusAPI(unittest.TestCase):
)
self.assertEqual(response, mock_response)
@patch("crewai.cli.plus_api.requests.Session")
def test_make_request(self, mock_session):
@patch("crewai.cli.plus_api.httpx.Client")
def test_make_request(self, mock_client_class):
mock_client_instance = MagicMock()
mock_response = MagicMock()
mock_session_instance = mock_session.return_value
mock_session_instance.request.return_value = mock_response
mock_client_instance.request.return_value = mock_response
mock_client_class.return_value.__enter__.return_value = mock_client_instance
response = self.api._make_request("GET", "test_endpoint")
mock_session.assert_called_once()
mock_session_instance.request.assert_called_once_with(
mock_client_class.assert_called_once_with(trust_env=False, verify=True)
mock_client_instance.request.assert_called_once_with(
"GET", f"{self.api.base_url}/test_endpoint", headers=self.api.headers
)
mock_session_instance.trust_env = False
self.assertEqual(response, mock_response)
@patch("crewai.cli.plus_api.PlusAPI._make_request")

View File

@@ -351,7 +351,7 @@ def test_publish_api_error(
mock_response = MagicMock()
mock_response.status_code = 500
mock_response.json.return_value = {"error": "Internal Server Error"}
mock_response.ok = False
mock_response.is_success = False
mock_publish.return_value = mock_response
with raises(SystemExit):

View File

@@ -3,7 +3,7 @@ import subprocess
import unittest
from unittest.mock import Mock, patch
import requests
import httpx
from crewai.cli.triggers.main import TriggersCommand
@@ -21,7 +21,7 @@ class TestTriggersCommand(unittest.TestCase):
@patch("crewai.cli.triggers.main.console.print")
def test_list_triggers_success(self, mock_console_print):
mock_response = Mock(spec=requests.Response)
mock_response = Mock(spec=httpx.Response)
mock_response.status_code = 200
mock_response.ok = True
mock_response.json.return_value = {
@@ -50,7 +50,7 @@ class TestTriggersCommand(unittest.TestCase):
@patch("crewai.cli.triggers.main.console.print")
def test_list_triggers_no_apps(self, mock_console_print):
mock_response = Mock(spec=requests.Response)
mock_response = Mock(spec=httpx.Response)
mock_response.status_code = 200
mock_response.ok = True
mock_response.json.return_value = {"apps": []}
@@ -81,7 +81,7 @@ class TestTriggersCommand(unittest.TestCase):
@patch("crewai.cli.triggers.main.console.print")
@patch.object(TriggersCommand, "_run_crew_with_payload")
def test_execute_with_trigger_success(self, mock_run_crew, mock_console_print):
mock_response = Mock(spec=requests.Response)
mock_response = Mock(spec=httpx.Response)
mock_response.status_code = 200
mock_response.ok = True
mock_response.json.return_value = {
@@ -99,7 +99,7 @@ class TestTriggersCommand(unittest.TestCase):
@patch("crewai.cli.triggers.main.console.print")
def test_execute_with_trigger_not_found(self, mock_console_print):
mock_response = Mock(spec=requests.Response)
mock_response = Mock(spec=httpx.Response)
mock_response.status_code = 404
mock_response.json.return_value = {"error": "Trigger not found"}
self.mock_client.get_trigger_payload.return_value = mock_response
@@ -159,7 +159,7 @@ class TestTriggersCommand(unittest.TestCase):
@patch("crewai.cli.triggers.main.console.print")
def test_execute_with_trigger_with_default_error_message(self, mock_console_print):
mock_response = Mock(spec=requests.Response)
mock_response = Mock(spec=httpx.Response)
mock_response.status_code = 404
mock_response.json.return_value = {}
self.mock_client.get_trigger_payload.return_value = mock_response

View File

@@ -0,0 +1,225 @@
"""Tests for MCP tools loading during delegation (Issue #4571).
When an agent with MCP servers configured is used as a sub-agent via delegation,
its MCP tools must be loaded even though the Crew's _prepare_tools() is not called
for the delegated-to agent.
"""
from unittest.mock import MagicMock, patch
import pytest
from crewai.agent.core import Agent
from crewai.agent.utils import _inject_mcp_tools, prepare_tools
from crewai.mcp.config import MCPServerHTTP
from crewai.task import Task
from crewai.tools.base_tool import BaseTool
def _make_mock_tool(name: str) -> MagicMock:
"""Create a MagicMock that looks like a BaseTool with the given name."""
tool = MagicMock(spec=BaseTool)
tool.name = name
return tool
@pytest.fixture
def http_config():
"""Create a sample MCPServerHTTP configuration."""
return MCPServerHTTP(url="https://api.example.com/mcp")
@pytest.fixture
def sub_agent_with_mcp(http_config):
"""Create an agent with MCP servers configured (the delegated-to agent)."""
return Agent(
role="MCP Sub Agent",
goal="Execute tasks using MCP tools",
backstory="An agent that uses MCP server tools",
mcps=[http_config],
allow_delegation=False,
)
@pytest.fixture
def sub_agent_without_mcp():
"""Create an agent without MCP servers."""
return Agent(
role="Regular Sub Agent",
goal="Execute tasks normally",
backstory="An agent without MCP tools",
allow_delegation=False,
)
class TestInjectMcpTools:
"""Tests for the _inject_mcp_tools helper function."""
def test_injects_mcp_tools_when_agent_has_mcps(self, sub_agent_with_mcp):
"""MCP tools should be injected when agent has mcps configured."""
mock_mcp_tools = [_make_mock_tool("mcp_search"), _make_mock_tool("mcp_fetch")]
with patch.object(Agent, "get_mcp_tools", return_value=mock_mcp_tools):
tools: list[BaseTool] = []
result = _inject_mcp_tools(sub_agent_with_mcp, tools)
assert len(result) == 2
tool_names = {t.name for t in result}
assert "mcp_search" in tool_names
assert "mcp_fetch" in tool_names
def test_does_not_inject_when_agent_has_no_mcps(self, sub_agent_without_mcp):
"""No MCP tools should be injected when agent has no mcps."""
tools: list[BaseTool] = []
result = _inject_mcp_tools(sub_agent_without_mcp, tools)
assert len(result) == 0
def test_does_not_duplicate_existing_mcp_tools(self, sub_agent_with_mcp):
"""MCP tools already in the list should not be duplicated."""
existing_search = _make_mock_tool("mcp_search")
mock_mcp_tools = [_make_mock_tool("mcp_search"), _make_mock_tool("mcp_fetch")]
with patch.object(Agent, "get_mcp_tools", return_value=mock_mcp_tools):
tools = [existing_search]
result = _inject_mcp_tools(sub_agent_with_mcp, tools)
# Should have 2 tools: existing mcp_search + new mcp_fetch
assert len(result) == 2
tool_names = [t.name for t in result]
assert tool_names.count("mcp_search") == 1
assert tool_names.count("mcp_fetch") == 1
def test_preserves_existing_tools(self, sub_agent_with_mcp):
"""Existing non-MCP tools should be preserved after injection."""
mock_mcp_tools = [_make_mock_tool("mcp_search"), _make_mock_tool("mcp_fetch")]
with patch.object(Agent, "get_mcp_tools", return_value=mock_mcp_tools):
existing_tool = _make_mock_tool("existing_tool")
tools = [existing_tool]
result = _inject_mcp_tools(sub_agent_with_mcp, tools)
assert len(result) == 3 # 1 existing + 2 MCP
tool_names = {t.name for t in result}
assert "existing_tool" in tool_names
assert "mcp_search" in tool_names
assert "mcp_fetch" in tool_names
def test_handles_mcp_loading_failure_gracefully(self, sub_agent_with_mcp):
"""If MCP tool loading fails, existing tools should be returned unmodified."""
with patch.object(
Agent, "get_mcp_tools", side_effect=Exception("Connection failed")
):
existing_tool = _make_mock_tool("my_tool")
tools = [existing_tool]
result = _inject_mcp_tools(sub_agent_with_mcp, tools)
assert len(result) == 1
assert result[0].name == "my_tool"
def test_handles_empty_mcp_tools_list(self, sub_agent_with_mcp):
"""If MCP server returns empty tools list, original tools are unchanged."""
with patch.object(Agent, "get_mcp_tools", return_value=[]):
existing_tool = _make_mock_tool("my_tool")
tools = [existing_tool]
result = _inject_mcp_tools(sub_agent_with_mcp, tools)
assert len(result) == 1
assert result[0].name == "my_tool"
def test_handles_agent_with_empty_mcps_list(self):
"""An agent with an empty mcps list should not trigger MCP loading."""
agent = Agent(
role="Agent",
goal="Test",
backstory="Test",
mcps=[],
allow_delegation=False,
)
tools: list[BaseTool] = []
result = _inject_mcp_tools(agent, tools)
assert len(result) == 0
class TestPrepareToolsWithMcp:
"""Tests for prepare_tools function with MCP integration."""
def test_prepare_tools_injects_mcp_when_tools_is_none(
self, sub_agent_with_mcp
):
"""When tools=None (delegation scenario), MCP tools should be loaded."""
task = Task(
description="Test task for delegation",
agent=sub_agent_with_mcp,
expected_output="Test output",
)
mock_mcp_tools = [_make_mock_tool("mcp_search"), _make_mock_tool("mcp_fetch")]
with patch.object(Agent, "get_mcp_tools", return_value=mock_mcp_tools), \
patch.object(Agent, "create_agent_executor"):
result = prepare_tools(sub_agent_with_mcp, None, task)
tool_names = {t.name for t in result}
assert "mcp_search" in tool_names
assert "mcp_fetch" in tool_names
def test_prepare_tools_no_mcp_when_agent_has_no_mcps(
self, sub_agent_without_mcp
):
"""When agent has no mcps, prepare_tools should behave normally."""
task = Task(
description="Test task",
agent=sub_agent_without_mcp,
expected_output="Test output",
)
with patch.object(Agent, "create_agent_executor"):
result = prepare_tools(sub_agent_without_mcp, None, task)
assert len(result) == 0
def test_prepare_tools_merges_explicit_tools_and_mcp(
self, sub_agent_with_mcp
):
"""When explicit tools are passed + agent has mcps, both should be present."""
task = Task(
description="Test task",
agent=sub_agent_with_mcp,
expected_output="Test output",
)
explicit_tool = _make_mock_tool("custom_tool")
mock_mcp_tools = [_make_mock_tool("mcp_search")]
with patch.object(Agent, "get_mcp_tools", return_value=mock_mcp_tools), \
patch.object(Agent, "create_agent_executor"):
result = prepare_tools(sub_agent_with_mcp, [explicit_tool], task)
tool_names = {t.name for t in result}
assert "custom_tool" in tool_names
assert "mcp_search" in tool_names
class TestDelegationWithMcp:
"""Tests for the full delegation flow with MCP-configured sub-agents."""
def test_delegation_tool_loads_mcp_tools_for_sub_agent(
self, sub_agent_with_mcp
):
"""When DelegateWorkTool delegates to an agent with MCPs,
the MCP tools should be loaded during execute_task."""
task = Task(
description="Search for AI news",
agent=sub_agent_with_mcp,
expected_output="AI news results",
)
mock_mcp_tools = [_make_mock_tool("mcp_search")]
with patch.object(Agent, "get_mcp_tools", return_value=mock_mcp_tools), \
patch.object(Agent, "create_agent_executor"), \
patch.object(Agent, "_execute_without_timeout", return_value="Found AI news"):
# Simulate what DelegateWorkTool does: call execute_task with no tools
result = sub_agent_with_mcp.execute_task(task, "context")
assert result == "Found AI news"

2
uv.lock generated
View File

@@ -1096,6 +1096,7 @@ dependencies = [
{ name = "appdirs" },
{ name = "chromadb" },
{ name = "click" },
{ name = "httpx" },
{ name = "instructor" },
{ name = "json-repair" },
{ name = "json5" },
@@ -1195,6 +1196,7 @@ requires-dist = [
{ name = "crewai-tools", marker = "extra == 'tools'", editable = "lib/crewai-tools" },
{ name = "docling", marker = "extra == 'docling'", specifier = "~=2.63.0" },
{ name = "google-genai", marker = "extra == 'google-genai'", specifier = "~=1.49.0" },
{ name = "httpx", specifier = "~=0.28.1" },
{ name = "httpx-auth", marker = "extra == 'a2a'", specifier = "~=0.23.1" },
{ name = "httpx-sse", marker = "extra == 'a2a'", specifier = "~=0.4.0" },
{ name = "ibm-watsonx-ai", marker = "extra == 'watson'", specifier = "~=1.3.39" },