mirror of
https://github.com/crewAIInc/crewAI.git
synced 2026-07-02 13:48:09 +00:00
* feat: enhance StdioTransport to prevent environment variable leakage - Replaced os.environ.copy() with get_default_environment() to ensure only allowed environment variables are passed to the MCP server. - Added tests to verify that ambient environment variables do not leak and that user-supplied environment variables can override defaults. * feat: add environment variable filtering hook to StdioTransport - Introduced an optional `_env_filter_hook` to allow extensions to modify the environment variables passed to MCP servers, enabling features like credential stripping. - Updated tests to ensure the filtering hook is applied correctly after merging user-supplied and default environment variables.
125 lines
3.9 KiB
Python
125 lines
3.9 KiB
Python
"""Tests for stdio transport."""
|
|
|
|
from unittest.mock import AsyncMock, MagicMock, patch
|
|
|
|
import pytest
|
|
|
|
import crewai.mcp.transports.stdio as stdio_transport_module
|
|
from crewai.mcp.transports.stdio import StdioTransport
|
|
|
|
|
|
async def _connect_and_capture_env(
    transport: "StdioTransport",
    default_env: dict[str, str],
) -> dict[str, str]:
    """Connect ``transport`` against a mocked MCP stdio client and return the env it sent.

    ``mcp.client.stdio.stdio_client`` is replaced by a stub that records
    ``server_params.env`` and returns an async context manager yielding dummy
    read/write streams, so the real ``stdio_client`` is never invoked.
    ``mcp.client.stdio.get_default_environment`` is patched to return
    ``default_env``.

    Args:
        transport: The transport under test; ``connect()`` is awaited once.
        default_env: Value the patched ``get_default_environment`` returns.

    Returns:
        The ``env`` mapping that ``StdioTransport`` placed on the server params.

    Raises:
        AssertionError: If ``stdio_client`` was never called, or was called
            with ``env=None``.
    """
    captured: dict[str, dict[str, str] | None] = {}

    fake_ctx = MagicMock()
    fake_ctx.__aenter__ = AsyncMock(return_value=(MagicMock(), MagicMock()))
    fake_ctx.__aexit__ = AsyncMock(return_value=None)

    def fake_stdio_client(server_params, *args, **kwargs):
        # Ignore any extra arguments so the stub does not depend on exactly
        # how StdioTransport calls stdio_client; only the env is of interest.
        captured["env"] = server_params.env
        return fake_ctx

    # NOTE(review): patching module attributes only takes effect if
    # StdioTransport resolves these names from mcp.client.stdio at connect()
    # time rather than binding them at import — confirm against
    # crewai/mcp/transports/stdio.py.
    with (
        patch("mcp.client.stdio.stdio_client", side_effect=fake_stdio_client),
        patch(
            "mcp.client.stdio.get_default_environment",
            return_value=default_env,
        ),
    ):
        await transport.connect()

    assert "env" in captured, "StdioTransport.connect() never called stdio_client"
    env = captured["env"]
    assert env is not None, "StdioTransport passed env=None to the server params"
    return env


@pytest.mark.asyncio
async def test_ambient_env_does_not_leak_to_server(monkeypatch):
    """Ambient env vars outside the MCP SDK's default allowlist must not reach the server.

    Regression guard: previously StdioTransport did os.environ.copy(), which leaked
    every ambient var (COMPANY_SECRET, AWS_*, etc.) into every spawned MCP server.
    """
    monkeypatch.setenv("COMPANY_SECRET", "leaked")
    monkeypatch.setenv("AWS_SECRET_ACCESS_KEY", "leaked")

    transport = StdioTransport(
        command="python",
        args=["server.py"],
        env={"OPENAI_API_KEY": "sk-test"},
    )

    env = await _connect_and_capture_env(
        transport,
        {"PATH": "/usr/bin", "HOME": "/home/user"},
    )

    # The ambient vars set above must not appear; user-supplied and default
    # (allowlisted) vars must.
    assert "COMPANY_SECRET" not in env
    assert "AWS_SECRET_ACCESS_KEY" not in env
    assert env.get("OPENAI_API_KEY") == "sk-test"
    assert env.get("PATH") == "/usr/bin"


@pytest.mark.asyncio
async def test_user_env_overrides_default_environment():
    """User-supplied env values must override keys returned by get_default_environment()."""
    transport = StdioTransport(
        command="python",
        args=["server.py"],
        env={"PATH": "/custom/bin"},
    )

    env = await _connect_and_capture_env(transport, {"PATH": "/usr/bin"})

    assert env["PATH"] == "/custom/bin"


@pytest.mark.asyncio
async def test_env_filter_hook_runs_after_merge(monkeypatch):
    """An extension-supplied env_filter_hook must be applied to the final env.

    The hook must see keys from both the user-supplied env and the default
    environment, i.e. it runs on the merged mapping.
    """
    transport = StdioTransport(
        command="python",
        args=["server.py"],
        env={"OPENAI_API_KEY": "sk-test", "AWS_SECRET_ACCESS_KEY": "should-strip"},
    )

    def drop_aws(env: dict[str, str]) -> dict[str, str]:
        return {k: v for k, v in env.items() if not k.startswith("AWS_")}

    # monkeypatch restores the original module-level hook at teardown, even if
    # connect() raises.
    monkeypatch.setattr(stdio_transport_module, "_env_filter_hook", drop_aws)

    # An AWS_* key in the *default* env checks that the hook sees defaults too,
    # not just the user-supplied env.
    env = await _connect_and_capture_env(
        transport,
        {"PATH": "/usr/bin", "AWS_PROFILE": "default"},
    )

    assert "AWS_SECRET_ACCESS_KEY" not in env
    assert "AWS_PROFILE" not in env
    assert env.get("OPENAI_API_KEY") == "sk-test"
    assert env.get("PATH") == "/usr/bin"
|