From 699088914b4e92ca99d96a27d7c212d5ccb02927 Mon Sep 17 00:00:00 2001 From: lorenzejay Date: Thu, 16 Apr 2026 11:11:55 -0700 Subject: [PATCH] feat: enhance StdioTransport to prevent environment variable leakage - Replaced os.environ.copy() with get_default_environment() to ensure only allowed environment variables are passed to the MCP server. - Added tests to verify that ambient environment variables do not leak and that user-supplied environment variables can override defaults. --- lib/crewai/src/crewai/mcp/transports/stdio.py | 7 +- lib/crewai/tests/mcp/test_stdio_transport.py | 81 +++++++++++++++++++ 2 files changed, 84 insertions(+), 4 deletions(-) create mode 100644 lib/crewai/tests/mcp/test_stdio_transport.py diff --git a/lib/crewai/src/crewai/mcp/transports/stdio.py b/lib/crewai/src/crewai/mcp/transports/stdio.py index d609daf1d..e7bd69857 100644 --- a/lib/crewai/src/crewai/mcp/transports/stdio.py +++ b/lib/crewai/src/crewai/mcp/transports/stdio.py @@ -1,7 +1,6 @@ """Stdio transport for MCP servers running as local processes.""" import asyncio -import os import subprocess from typing import Any @@ -71,15 +70,15 @@ class StdioTransport(BaseTransport): try: from mcp import StdioServerParameters - from mcp.client.stdio import stdio_client + from mcp.client.stdio import get_default_environment, stdio_client - process_env = os.environ.copy() + process_env = get_default_environment() process_env.update(self.env) server_params = StdioServerParameters( command=self.command, args=self.args, - env=process_env if process_env else None, + env=process_env, ) self._transport_context = stdio_client(server_params) diff --git a/lib/crewai/tests/mcp/test_stdio_transport.py b/lib/crewai/tests/mcp/test_stdio_transport.py new file mode 100644 index 000000000..5326566e5 --- /dev/null +++ b/lib/crewai/tests/mcp/test_stdio_transport.py @@ -0,0 +1,81 @@ +"""Tests for stdio transport.""" + +from unittest.mock import AsyncMock, MagicMock, patch + +import pytest + +from crewai.mcp.transports.stdio import StdioTransport + + +@pytest.mark.asyncio +async def test_ambient_env_does_not_leak_to_server(monkeypatch): + """Ambient env vars outside the MCP SDK's default allowlist must not reach the server. + + Regression guard: previously StdioTransport did os.environ.copy(), which leaked + every ambient var (COMPANY_SECRET, AWS_*, etc.) into every spawned MCP server. + """ + monkeypatch.setenv("COMPANY_SECRET", "leaked") + monkeypatch.setenv("AWS_SECRET_ACCESS_KEY", "leaked") + + transport = StdioTransport( + command="python", + args=["server.py"], + env={"OPENAI_API_KEY": "sk-test"}, + ) + + captured: dict[str, dict[str, str] | None] = {} + + fake_ctx = MagicMock() + fake_ctx.__aenter__ = AsyncMock(return_value=(MagicMock(), MagicMock())) + fake_ctx.__aexit__ = AsyncMock(return_value=None) + + def fake_stdio_client(server_params): + captured["env"] = server_params.env + return fake_ctx + + with ( + patch("mcp.client.stdio.stdio_client", side_effect=fake_stdio_client), + patch( + "mcp.client.stdio.get_default_environment", + return_value={"PATH": "/usr/bin", "HOME": "/home/user"}, + ), + ): + await transport.connect() + + env = captured["env"] + assert env is not None + assert "COMPANY_SECRET" not in env + assert "AWS_SECRET_ACCESS_KEY" not in env + assert env.get("OPENAI_API_KEY") == "sk-test" + assert env.get("PATH") == "/usr/bin" + + +@pytest.mark.asyncio +async def test_user_env_overrides_default_environment(): + """User-supplied env values must override keys returned by get_default_environment().""" + transport = StdioTransport( + command="python", + args=["server.py"], + env={"PATH": "/custom/bin"}, + ) + + captured: dict[str, dict[str, str] | None] = {} + + fake_ctx = MagicMock() + fake_ctx.__aenter__ = AsyncMock(return_value=(MagicMock(), MagicMock())) + fake_ctx.__aexit__ = AsyncMock(return_value=None) + + def fake_stdio_client(server_params): + captured["env"] = server_params.env + return fake_ctx + + with ( + patch("mcp.client.stdio.stdio_client", side_effect=fake_stdio_client), + patch( + "mcp.client.stdio.get_default_environment", + return_value={"PATH": "/usr/bin"}, + ), + ): + await transport.connect() + + assert captured["env"]["PATH"] == "/custom/bin"