Files
crewAI/lib/crewai/tests/mcp/test_stdio_transport.py
Lorenze Jay 2e36f06732 feat: enhance StdioTransport to prevent environment variable leakage (#5506)
* feat: enhance StdioTransport to prevent environment variable leakage

- Replaced os.environ.copy() with get_default_environment() to ensure only allowed environment variables are passed to the MCP server.
- Added tests to verify that ambient environment variables do not leak and that user-supplied environment variables can override defaults.

* feat: add environment variable filtering hook to StdioTransport

- Introduced an optional `_env_filter_hook` to allow extensions to modify the environment variables passed to MCP servers, enabling features like credential stripping.
- Updated tests to ensure the filtering hook is applied correctly after merging user-supplied and default environment variables.
2026-05-27 13:38:25 -07:00

125 lines
3.9 KiB
Python

"""Tests for stdio transport."""
from unittest.mock import AsyncMock, MagicMock, patch
import pytest
import crewai.mcp.transports.stdio as stdio_transport_module
from crewai.mcp.transports.stdio import StdioTransport
@pytest.mark.asyncio
async def test_ambient_env_does_not_leak_to_server(monkeypatch):
    """Ensure only the SDK default env plus user-supplied vars reach the MCP server.

    Regression guard: StdioTransport used to build the child environment from
    os.environ.copy(), which forwarded every ambient variable (COMPANY_SECRET,
    AWS_*, etc.) to each spawned MCP server.
    """
    planted_secrets = ("COMPANY_SECRET", "AWS_SECRET_ACCESS_KEY")
    # Put secrets into the ambient environment; none of them may be forwarded.
    for secret_name in planted_secrets:
        monkeypatch.setenv(secret_name, "leaked")

    transport = StdioTransport(
        command="python",
        args=["server.py"],
        env={"OPENAI_API_KEY": "sk-test"},
    )

    # Async context manager stand-in for what stdio_client() returns; entering
    # it yields a two-item tuple of mocks (presumably the read/write streams).
    client_cm = MagicMock()
    client_cm.__aenter__ = AsyncMock(return_value=(MagicMock(), MagicMock()))
    client_cm.__aexit__ = AsyncMock(return_value=None)

    captured: dict[str, dict[str, str] | None] = {}

    def record_params(server_params):
        # Keep the env that StdioTransport handed to the MCP SDK.
        captured["env"] = server_params.env
        return client_cm

    sdk_defaults = {"PATH": "/usr/bin", "HOME": "/home/user"}
    with (
        patch("mcp.client.stdio.stdio_client", side_effect=record_params),
        patch(
            "mcp.client.stdio.get_default_environment",
            return_value=sdk_defaults,
        ),
    ):
        await transport.connect()

    server_env = captured["env"]
    assert server_env is not None
    for secret_name in planted_secrets:
        assert secret_name not in server_env
    assert server_env.get("OPENAI_API_KEY") == "sk-test"
    assert server_env.get("PATH") == "/usr/bin"
@pytest.mark.asyncio
async def test_user_env_overrides_default_environment():
    """User-supplied env values must override keys from get_default_environment().

    Only the colliding key is replaced; default keys the user did not supply
    must still reach the MCP server unchanged.
    """
    transport = StdioTransport(
        command="python",
        args=["server.py"],
        env={"PATH": "/custom/bin"},
    )
    captured: dict[str, dict[str, str] | None] = {}

    # Async context manager stand-in for what stdio_client() returns.
    fake_ctx = MagicMock()
    fake_ctx.__aenter__ = AsyncMock(return_value=(MagicMock(), MagicMock()))
    fake_ctx.__aexit__ = AsyncMock(return_value=None)

    def fake_stdio_client(server_params):
        # Record the env StdioTransport passes to the MCP SDK.
        captured["env"] = server_params.env
        return fake_ctx

    with (
        patch("mcp.client.stdio.stdio_client", side_effect=fake_stdio_client),
        patch(
            "mcp.client.stdio.get_default_environment",
            # HOME has no user-supplied counterpart, so it must pass through.
            return_value={"PATH": "/usr/bin", "HOME": "/home/user"},
        ),
    ):
        await transport.connect()

    env = captured.get("env")
    # Fail with a clear assertion (not a TypeError/KeyError) if stdio_client
    # was never invoked or received no env at all.
    assert env is not None
    assert env["PATH"] == "/custom/bin"
    assert env.get("HOME") == "/home/user"
@pytest.mark.asyncio
async def test_env_filter_hook_runs_after_merge():
    """An extension-supplied env filter hook must be applied to the final, merged env.

    The hook must see both the SDK default keys (PATH from
    get_default_environment()) and the user-supplied keys. Its return value
    must be exactly what reaches the MCP server.
    """
    transport = StdioTransport(
        command="python",
        args=["server.py"],
        env={"OPENAI_API_KEY": "sk-test", "AWS_SECRET_ACCESS_KEY": "should-strip"},
    )
    captured: dict[str, dict[str, str] | None] = {}
    hook_inputs: list[dict[str, str]] = []

    # Async context manager stand-in for what stdio_client() returns.
    fake_ctx = MagicMock()
    fake_ctx.__aenter__ = AsyncMock(return_value=(MagicMock(), MagicMock()))
    fake_ctx.__aexit__ = AsyncMock(return_value=None)

    def fake_stdio_client(server_params):
        # Record the env StdioTransport passes to the MCP SDK.
        captured["env"] = server_params.env
        return fake_ctx

    def drop_aws(env):
        # Snapshot the hook's input so the test can prove the merge happened first;
        # checking only the output cannot distinguish "hook then merge".
        hook_inputs.append(dict(env))
        return {k: v for k, v in env.items() if not k.startswith("AWS_")}

    # patch.object restores the original module-level hook on exit, even if
    # connect() raises, replacing a manual save/try/finally.
    with (
        patch.object(stdio_transport_module, "_env_filter_hook", drop_aws),
        patch("mcp.client.stdio.stdio_client", side_effect=fake_stdio_client),
        patch(
            "mcp.client.stdio.get_default_environment",
            return_value={"PATH": "/usr/bin"},
        ),
    ):
        await transport.connect()

    # The hook ran on an env that already combined defaults and user values.
    assert hook_inputs, "env filter hook was never invoked"
    hook_saw = hook_inputs[-1]
    assert hook_saw.get("PATH") == "/usr/bin"
    assert hook_saw.get("OPENAI_API_KEY") == "sk-test"
    assert hook_saw.get("AWS_SECRET_ACCESS_KEY") == "should-strip"

    # ...and its filtered result is what reached the server.
    env = captured.get("env")
    assert env is not None
    assert "AWS_SECRET_ACCESS_KEY" not in env
    assert env.get("OPENAI_API_KEY") == "sk-test"
    assert env.get("PATH") == "/usr/bin"