mirror of
https://github.com/crewAIInc/crewAI.git
synced 2026-07-02 05:38:12 +00:00
feat: enhance StdioTransport to prevent environment variable leakage (#5506)
* feat: enhance StdioTransport to prevent environment variable leakage - Replaced os.environ.copy() with get_default_environment() to ensure only allowed environment variables are passed to the MCP server. - Added tests to verify that ambient environment variables do not leak and that user-supplied environment variables can override defaults. * feat: add environment variable filtering hook to StdioTransport - Introduced an optional `_env_filter_hook` to allow extensions to modify the environment variables passed to MCP servers, enabling features like credential stripping. - Updated tests to ensure the filtering hook is applied correctly after merging user-supplied and default environment variables.
This commit is contained in:
@@ -1,7 +1,7 @@
|
||||
"""Stdio transport for MCP servers running as local processes."""
|
||||
|
||||
import asyncio
|
||||
import os
|
||||
from collections.abc import Callable
|
||||
import subprocess
|
||||
from typing import Any
|
||||
|
||||
@@ -10,6 +10,16 @@ from typing_extensions import Self
|
||||
from crewai.mcp.transports.base import BaseTransport, TransportType
|
||||
|
||||
|
||||
_env_filter_hook: Callable[[dict[str, str]], dict[str, str]] | None = None
|
||||
"""Optional hook to post-process the environment passed to stdio MCP servers.
|
||||
|
||||
Extensions (e.g., enterprise policy) can set this to enforce org-wide rules such
|
||||
as stripping credentials from `env` before the subprocess is spawned. The hook
|
||||
receives the merged env (SDK defaults + user-supplied `env=`) and returns the
|
||||
filtered env. Set to None to disable.
|
||||
"""
|
||||
|
||||
|
||||
class StdioTransport(BaseTransport):
|
||||
"""Stdio transport for connecting to local MCP servers.
|
||||
|
||||
@@ -71,15 +81,18 @@ class StdioTransport(BaseTransport):
|
||||
|
||||
try:
|
||||
from mcp import StdioServerParameters
|
||||
from mcp.client.stdio import stdio_client
|
||||
from mcp.client.stdio import get_default_environment, stdio_client
|
||||
|
||||
process_env = os.environ.copy()
|
||||
process_env = get_default_environment()
|
||||
process_env.update(self.env)
|
||||
|
||||
if _env_filter_hook is not None:
|
||||
process_env = _env_filter_hook(process_env)
|
||||
|
||||
server_params = StdioServerParameters(
|
||||
command=self.command,
|
||||
args=self.args,
|
||||
env=process_env if process_env else None,
|
||||
env=process_env,
|
||||
)
|
||||
self._transport_context = stdio_client(server_params)
|
||||
|
||||
|
||||
124
lib/crewai/tests/mcp/test_stdio_transport.py
Normal file
124
lib/crewai/tests/mcp/test_stdio_transport.py
Normal file
@@ -0,0 +1,124 @@
|
||||
"""Tests for stdio transport."""
|
||||
|
||||
from unittest.mock import AsyncMock, MagicMock, patch
|
||||
|
||||
import pytest
|
||||
|
||||
import crewai.mcp.transports.stdio as stdio_transport_module
|
||||
from crewai.mcp.transports.stdio import StdioTransport
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_ambient_env_does_not_leak_to_server(monkeypatch):
|
||||
"""Ambient env vars outside the MCP SDK's default allowlist must not reach the server.
|
||||
|
||||
Regression guard: previously StdioTransport did os.environ.copy(), which leaked
|
||||
every ambient var (COMPANY_SECRET, AWS_*, etc.) into every spawned MCP server.
|
||||
"""
|
||||
monkeypatch.setenv("COMPANY_SECRET", "leaked")
|
||||
monkeypatch.setenv("AWS_SECRET_ACCESS_KEY", "leaked")
|
||||
|
||||
transport = StdioTransport(
|
||||
command="python",
|
||||
args=["server.py"],
|
||||
env={"OPENAI_API_KEY": "sk-test"},
|
||||
)
|
||||
|
||||
captured: dict[str, dict[str, str] | None] = {}
|
||||
|
||||
fake_ctx = MagicMock()
|
||||
fake_ctx.__aenter__ = AsyncMock(return_value=(MagicMock(), MagicMock()))
|
||||
fake_ctx.__aexit__ = AsyncMock(return_value=None)
|
||||
|
||||
def fake_stdio_client(server_params):
|
||||
captured["env"] = server_params.env
|
||||
return fake_ctx
|
||||
|
||||
with (
|
||||
patch("mcp.client.stdio.stdio_client", side_effect=fake_stdio_client),
|
||||
patch(
|
||||
"mcp.client.stdio.get_default_environment",
|
||||
return_value={"PATH": "/usr/bin", "HOME": "/home/user"},
|
||||
),
|
||||
):
|
||||
await transport.connect()
|
||||
|
||||
env = captured["env"]
|
||||
assert env is not None
|
||||
assert "COMPANY_SECRET" not in env
|
||||
assert "AWS_SECRET_ACCESS_KEY" not in env
|
||||
assert env.get("OPENAI_API_KEY") == "sk-test"
|
||||
assert env.get("PATH") == "/usr/bin"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_user_env_overrides_default_environment():
|
||||
"""User-supplied env values must override keys returned by get_default_environment()."""
|
||||
transport = StdioTransport(
|
||||
command="python",
|
||||
args=["server.py"],
|
||||
env={"PATH": "/custom/bin"},
|
||||
)
|
||||
|
||||
captured: dict[str, dict[str, str] | None] = {}
|
||||
|
||||
fake_ctx = MagicMock()
|
||||
fake_ctx.__aenter__ = AsyncMock(return_value=(MagicMock(), MagicMock()))
|
||||
fake_ctx.__aexit__ = AsyncMock(return_value=None)
|
||||
|
||||
def fake_stdio_client(server_params):
|
||||
captured["env"] = server_params.env
|
||||
return fake_ctx
|
||||
|
||||
with (
|
||||
patch("mcp.client.stdio.stdio_client", side_effect=fake_stdio_client),
|
||||
patch(
|
||||
"mcp.client.stdio.get_default_environment",
|
||||
return_value={"PATH": "/usr/bin"},
|
||||
),
|
||||
):
|
||||
await transport.connect()
|
||||
|
||||
assert captured["env"]["PATH"] == "/custom/bin"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_env_filter_hook_runs_after_merge():
|
||||
"""An extension-supplied env_filter_hook must be applied to the final env."""
|
||||
transport = StdioTransport(
|
||||
command="python",
|
||||
args=["server.py"],
|
||||
env={"OPENAI_API_KEY": "sk-test", "AWS_SECRET_ACCESS_KEY": "should-strip"},
|
||||
)
|
||||
|
||||
captured: dict[str, dict[str, str] | None] = {}
|
||||
|
||||
fake_ctx = MagicMock()
|
||||
fake_ctx.__aenter__ = AsyncMock(return_value=(MagicMock(), MagicMock()))
|
||||
fake_ctx.__aexit__ = AsyncMock(return_value=None)
|
||||
|
||||
def fake_stdio_client(server_params):
|
||||
captured["env"] = server_params.env
|
||||
return fake_ctx
|
||||
|
||||
def drop_aws(env):
|
||||
return {k: v for k, v in env.items() if not k.startswith("AWS_")}
|
||||
|
||||
original_hook = stdio_transport_module._env_filter_hook
|
||||
stdio_transport_module._env_filter_hook = drop_aws
|
||||
try:
|
||||
with (
|
||||
patch("mcp.client.stdio.stdio_client", side_effect=fake_stdio_client),
|
||||
patch(
|
||||
"mcp.client.stdio.get_default_environment",
|
||||
return_value={"PATH": "/usr/bin"},
|
||||
),
|
||||
):
|
||||
await transport.connect()
|
||||
finally:
|
||||
stdio_transport_module._env_filter_hook = original_hook
|
||||
|
||||
env = captured["env"]
|
||||
assert "AWS_SECRET_ACCESS_KEY" not in env
|
||||
assert env.get("OPENAI_API_KEY") == "sk-test"
|
||||
assert env.get("PATH") == "/usr/bin"
|
||||
Reference in New Issue
Block a user