mirror of
https://github.com/crewAIInc/crewAI.git
synced 2026-05-13 13:08:14 +00:00
Compare commits
1 Commits
main
...
feat/opens
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
516e4fdfc3 |
206
lib/crewai/src/crewai/tools/opensandbox_tool.py
Normal file
206
lib/crewai/src/crewai/tools/opensandbox_tool.py
Normal file
@@ -0,0 +1,206 @@
|
||||
"""OpenSandbox tool for CrewAI agents.
|
||||
|
||||
OpenSandbox (https://open-sandbox.ai) is a self-hosted sandbox platform
|
||||
for running shell commands and managing files inside isolated containers.
|
||||
This tool exposes its core operations to CrewAI agents through a single
|
||||
``OpenSandboxTool`` that lazily creates one sandbox per tool instance and
|
||||
reuses it across calls until ``kill`` is invoked.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import concurrent.futures
|
||||
from datetime import timedelta
|
||||
import os
|
||||
from typing import Any, Literal
|
||||
|
||||
from pydantic import BaseModel, Field, PrivateAttr
|
||||
|
||||
from crewai.tools.base_tool import BaseTool, EnvVar
|
||||
|
||||
|
||||
class OpenSandboxToolSchema(BaseModel):
    """Arguments accepted by ``OpenSandboxTool``."""

    # ``action`` is the only required field. The other three default to None
    # and are only meaningful for the actions named in their descriptions.
    action: Literal["run_command", "read_file", "write_file", "kill"] = Field(
        ...,
        description=(
            "Operation to perform: "
            "run_command (execute shell command), "
            "read_file (read file contents), "
            "write_file (write file contents), "
            "or kill (terminate the sandbox)."
        ),
    )

    # Shell command line, used by 'run_command'.
    command: str | None = Field(
        None,
        description="Shell command to execute. Required when action is 'run_command'.",
    )

    # Target file inside the sandbox, used by 'read_file' and 'write_file'.
    path: str | None = Field(
        None,
        description="Absolute file path. Required for 'read_file' and 'write_file'.",
    )

    # Text payload, used by 'write_file'.
    content: str | None = Field(
        None,
        description="File content to write. Required when action is 'write_file'.",
    )
|
||||
|
||||
|
||||
class OpenSandboxTool(BaseTool):
    """Run shell commands and manage files inside an OpenSandbox sandbox."""

    # Lifecycle: one sandbox is created lazily by the first action that needs
    # it and is cached in ``_sandbox`` until ``action='kill'`` is requested.
    # Every SDK coroutine runs on a single private event loop (``_loop``)
    # driven by a daemon thread, so the cached sandbox object is only ever
    # used from the loop it was created on.

    name: str = "OpenSandbox"
    description: str = (
        "Execute commands and manage files in an isolated OpenSandbox container. "
        "Useful for running untrusted code, scripting, file I/O, or any work that "
        "should be isolated from the host. The same sandbox is reused across "
        "calls; invoke action='kill' to release it."
    )
    args_schema: type[BaseModel] = OpenSandboxToolSchema
    env_vars: list[EnvVar] = Field(
        default_factory=lambda: [
            EnvVar(
                name="OPENSANDBOX_DOMAIN",
                description="Host:port of the OpenSandbox server (e.g. 'localhost:8080').",
                required=True,
            ),
            EnvVar(
                name="OPENSANDBOX_PROTOCOL",
                description="Protocol used to reach the server: 'http' or 'https'.",
                required=False,
                default="http",
            ),
            EnvVar(
                name="OPENSANDBOX_IMAGE",
                description="Container image to launch (e.g. 'python:3.12').",
                required=False,
                default="python:3.12",
            ),
            EnvVar(
                name="OPENSANDBOX_TIMEOUT_MINUTES",
                description="Sandbox idle timeout in minutes before auto-shutdown.",
                required=False,
                default="30",
            ),
            EnvVar(
                name="OPENSANDBOX_API_KEY",
                description="Optional API key if the OpenSandbox server requires auth.",
                required=False,
                default=None,
            ),
        ]
    )

    # Cached SDK sandbox handle; None until first use and again after kill.
    _sandbox: Any = PrivateAttr(default=None)
    # Private event loop and the daemon thread running it. Both are created on
    # demand by ``_ensure_loop`` and torn down by ``_shutdown_loop``.
    _loop: Any = PrivateAttr(default=None)
    _loop_thread: Any = PrivateAttr(default=None)

    def _run(self, **kwargs: Any) -> str:
        """Dispatch one tool call to the matching sandbox operation.

        Args:
            **kwargs: The fields of ``OpenSandboxToolSchema`` (``action``,
                ``command``, ``path``, ``content``).

        Returns:
            The operation's textual result, or a string starting with
            ``"Error:"`` when a required argument is missing or the action is
            not recognised.
        """
        action = kwargs.get("action")
        command = kwargs.get("command")
        path = kwargs.get("path")
        content = kwargs.get("content")

        if action == "kill":
            try:
                return self._run_async(self._kill())
            finally:
                # The sandbox is gone, so the loop that served it can go too;
                # a later call lazily starts a fresh one.
                self._shutdown_loop()
        if action == "run_command":
            if not command:
                return "Error: 'command' is required when action='run_command'."
            return self._run_async(self._run_command(command))
        if action == "read_file":
            if not path:
                return "Error: 'path' is required when action='read_file'."
            return self._run_async(self._read_file(path))
        if action == "write_file":
            if not path:
                return "Error: 'path' is required when action='write_file'."
            # An empty string is valid content (it truncates the file), so
            # only a missing value is rejected.
            if content is None:
                return "Error: 'content' is required when action='write_file'."
            return self._run_async(self._write_file(path, content))
        return f"Error: unknown action '{action}'."

    def _ensure_loop(self) -> asyncio.AbstractEventLoop:
        """Return the tool's private event loop, starting it if necessary."""
        import threading

        if self._loop is None:
            loop = asyncio.new_event_loop()
            thread = threading.Thread(
                target=loop.run_forever,
                name="opensandbox-tool-loop",
                daemon=True,  # never keep the interpreter alive for this loop
            )
            thread.start()
            self._loop = loop
            self._loop_thread = thread
        return self._loop

    def _shutdown_loop(self) -> None:
        """Stop and close the private event loop, if one is running."""
        loop, thread = self._loop, self._loop_thread
        self._loop = None
        self._loop_thread = None
        if loop is None:
            return
        # ``stop`` must be scheduled from outside the loop's own thread.
        loop.call_soon_threadsafe(loop.stop)
        if thread is not None:
            thread.join()
        loop.close()

    def _run_async(self, coro: Any) -> str:
        """Run ``coro`` to completion from a sync context and return its result.

        The coroutine is submitted to the tool's private loop instead of a
        fresh ``asyncio.run`` loop per call. A per-call loop is closed as soon
        as the call returns, while ``_sandbox`` is kept for later calls; the
        SDK object presumably holds loop-bound network resources, so reusing
        it on a different loop is not safe. Submitting to a separate thread
        also works when the caller's thread already runs an event loop; that
        loop is blocked until the result is available.
        """
        loop = self._ensure_loop()
        future: concurrent.futures.Future[str] = asyncio.run_coroutine_threadsafe(
            coro, loop
        )
        return future.result()

    def _build_connection_config(self) -> Any:
        """Build the SDK ``ConnectionConfig`` from ``OPENSANDBOX_*`` variables.

        Raises:
            ValueError: If ``OPENSANDBOX_DOMAIN`` is unset or blank.
        """
        from opensandbox.config.connection import ConnectionConfig

        domain = (os.getenv("OPENSANDBOX_DOMAIN") or "").strip()
        if not domain:
            raise ValueError(
                "OPENSANDBOX_DOMAIN is not set. Configure it to point at your "
                "OpenSandbox server (e.g. 'localhost:8080')."
            )
        # A blank or whitespace-only value falls back to the default.
        protocol = (os.getenv("OPENSANDBOX_PROTOCOL") or "").strip().lower() or "http"
        api_key = os.getenv("OPENSANDBOX_API_KEY") or None
        return ConnectionConfig(domain=domain, protocol=protocol, api_key=api_key)

    async def _ensure_sandbox(self) -> Any:
        """Return the cached sandbox, creating it on first use.

        Raises:
            ValueError: If ``OPENSANDBOX_TIMEOUT_MINUTES`` is not a positive
                whole number, or the connection settings are incomplete.
        """
        if self._sandbox is not None:
            return self._sandbox
        from opensandbox import Sandbox

        image = (os.getenv("OPENSANDBOX_IMAGE") or "").strip() or "python:3.12"
        raw_timeout = (os.getenv("OPENSANDBOX_TIMEOUT_MINUTES") or "").strip() or "30"
        try:
            timeout_minutes = int(raw_timeout)
        except ValueError:
            raise ValueError(
                "OPENSANDBOX_TIMEOUT_MINUTES must be a whole number of minutes, "
                f"got {raw_timeout!r}."
            ) from None
        if timeout_minutes <= 0:
            raise ValueError(
                "OPENSANDBOX_TIMEOUT_MINUTES must be greater than zero, "
                f"got {timeout_minutes}."
            )
        connection_config = self._build_connection_config()
        self._sandbox = await Sandbox.create(
            image,
            timeout=timedelta(minutes=timeout_minutes),
            connection_config=connection_config,
        )
        return self._sandbox

    async def _run_command(self, command: str) -> str:
        """Execute ``command`` in the sandbox and format its output.

        Returns:
            stdout, then ``stderr:`` and ``error:`` sections when present,
            ``"(no output)"`` when everything is empty, or an
            ``"OpenSandbox error ..."`` message if setup or execution raised.
        """
        try:
            sandbox = await self._ensure_sandbox()
            execution = await sandbox.commands.run(command)
        except Exception as exc:
            return f"OpenSandbox error running command: {exc}"

        # Log entries without a ``text`` attribute contribute nothing.
        stdout = "".join(
            getattr(item, "text", "") for item in (execution.logs.stdout or [])
        )
        stderr = "".join(
            getattr(item, "text", "") for item in (execution.logs.stderr or [])
        )
        parts: list[str] = []
        if stdout:
            parts.append(stdout)
        if stderr:
            parts.append(f"stderr:\n{stderr}")
        if getattr(execution, "error", None):
            parts.append(f"error: {execution.error}")
        return "\n".join(parts).strip() or "(no output)"

    async def _read_file(self, path: str) -> str:
        """Return the contents of ``path`` in the sandbox, or an error message."""
        try:
            sandbox = await self._ensure_sandbox()
            return await sandbox.files.read_file(path)
        except Exception as exc:
            return f"OpenSandbox error reading {path}: {exc}"

    async def _write_file(self, path: str, content: str) -> str:
        """Write ``content`` to ``path`` in the sandbox.

        Returns:
            A confirmation with the UTF-8 size of ``content``, or an
            ``"OpenSandbox error ..."`` message if anything raised.
        """
        try:
            # Imported inside the ``try`` so a missing SDK is reported as a
            # message, matching the other actions.
            from opensandbox.models import WriteEntry

            sandbox = await self._ensure_sandbox()
            # NOTE(review): mode is passed as the Python int 0o644 (420 in
            # decimal); confirm this is the form the SDK's WriteEntry expects.
            await sandbox.files.write_files(
                [WriteEntry(path=path, data=content, mode=0o644)]
            )
        except Exception as exc:
            return f"OpenSandbox error writing {path}: {exc}"
        # Report encoded size; ``len(content)`` would count characters.
        written = len(content.encode("utf-8"))
        return f"Wrote {written} bytes to {path}."

    async def _kill(self) -> str:
        """Terminate the cached sandbox, if any, and forget it.

        The cached handle is cleared even when the SDK call fails, so the next
        action starts from a fresh sandbox.
        """
        if self._sandbox is None:
            return "No sandbox to kill."
        sandbox, self._sandbox = self._sandbox, None
        try:
            await sandbox.kill()
        except Exception as exc:
            return f"OpenSandbox error during kill: {exc}"
        return "Sandbox killed."
|
||||
200
lib/crewai/tests/test_opensandbox_tool.py
Normal file
200
lib/crewai/tests/test_opensandbox_tool.py
Normal file
@@ -0,0 +1,200 @@
|
||||
"""Tests for ``OpenSandboxTool``.
|
||||
|
||||
These tests mock the underlying ``opensandbox`` SDK so the suite runs
|
||||
without a real OpenSandbox server. The mocks live inside the
|
||||
``crewai.tools.opensandbox_tool`` module so the locally imported names
|
||||
(``Sandbox``, ``ConnectionConfig``, ``WriteEntry``) are intercepted.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import sys
|
||||
import types
|
||||
from types import SimpleNamespace
|
||||
from unittest.mock import AsyncMock, MagicMock, patch
|
||||
|
||||
import pytest
|
||||
|
||||
from crewai.tools.opensandbox_tool import OpenSandboxTool
|
||||
|
||||
|
||||
def _stub_opensandbox_modules() -> None:
    """Install stub ``opensandbox`` submodules so deferred imports succeed.

    The tool imports ``opensandbox`` lazily inside its helpers; tests patch
    the symbols on those modules directly. The module objects are pre-created
    so ``patch`` can find an attribute to replace. Modules already present in
    ``sys.modules`` and attributes that already exist are left untouched.

    Each stub submodule is also attached to its parent as an attribute.
    ``sys.modules`` entries alone satisfy ``import`` statements, but a dotted
    target such as ``"opensandbox.models.WriteEntry"`` may be resolved by
    attribute lookup from the top-level package.
    """
    pkg = sys.modules.setdefault("opensandbox", types.ModuleType("opensandbox"))
    config_pkg = sys.modules.setdefault(
        "opensandbox.config", types.ModuleType("opensandbox.config")
    )
    connection_mod = sys.modules.setdefault(
        "opensandbox.config.connection",
        types.ModuleType("opensandbox.config.connection"),
    )
    models_mod = sys.modules.setdefault(
        "opensandbox.models", types.ModuleType("opensandbox.models")
    )

    # Symbols the tool imports by name.
    if not hasattr(pkg, "Sandbox"):
        pkg.Sandbox = MagicMock()
    if not hasattr(connection_mod, "ConnectionConfig"):
        connection_mod.ConnectionConfig = MagicMock()
    if not hasattr(models_mod, "WriteEntry"):
        models_mod.WriteEntry = MagicMock()

    # Parent -> child links for attribute-based resolution of dotted paths.
    if not hasattr(pkg, "config"):
        pkg.config = config_pkg
    if not hasattr(pkg, "models"):
        pkg.models = models_mod
    if not hasattr(config_pkg, "connection"):
        config_pkg.connection = connection_mod


# Runs at import time so the stubs exist before any test touches the tool.
_stub_opensandbox_modules()
|
||||
|
||||
|
||||
def _make_execution(stdout: str = "", stderr: str = "", error: str | None = None) -> SimpleNamespace:
|
||||
stdout_items = [SimpleNamespace(text=stdout)] if stdout else []
|
||||
stderr_items = [SimpleNamespace(text=stderr)] if stderr else []
|
||||
return SimpleNamespace(
|
||||
logs=SimpleNamespace(stdout=stdout_items, stderr=stderr_items),
|
||||
error=error,
|
||||
)
|
||||
|
||||
|
||||
@pytest.fixture
def env(monkeypatch):
    """Set the ``OPENSANDBOX_*`` variables to known values and unset the API key."""
    settings = {
        "OPENSANDBOX_DOMAIN": "localhost:8080",
        "OPENSANDBOX_PROTOCOL": "http",
        "OPENSANDBOX_IMAGE": "python:3.12",
        "OPENSANDBOX_TIMEOUT_MINUTES": "30",
    }
    for variable, value in settings.items():
        monkeypatch.setenv(variable, value)
    # The key is optional; make sure a value from the host does not leak in.
    monkeypatch.delenv("OPENSANDBOX_API_KEY", raising=False)
|
||||
|
||||
|
||||
@pytest.fixture
def fake_sandbox():
    """Return a mock sandbox whose SDK entry points are awaitable mocks."""
    double = MagicMock()
    # Default command result: a single stdout chunk and no error.
    default_execution = _make_execution(stdout="hello\n")
    double.commands.run = AsyncMock(return_value=default_execution)
    double.files.read_file = AsyncMock(return_value="file contents")
    double.files.write_files = AsyncMock(return_value=None)
    double.kill = AsyncMock(return_value=None)
    return double
|
||||
|
||||
|
||||
@pytest.fixture
def patched_sandbox_create(fake_sandbox):
    """Patch ``opensandbox.Sandbox`` so ``create`` resolves to ``fake_sandbox``.

    Yields the patched class and the fake sandbox as a pair; the patch is
    undone when the test finishes.
    """
    with patch("opensandbox.Sandbox") as patched_cls:
        patched_cls.create = AsyncMock(return_value=fake_sandbox)
        yield patched_cls, fake_sandbox
|
||||
|
||||
|
||||
def test_missing_domain_raises_value_error(monkeypatch):
    """Building the connection config without a domain raises ``ValueError``."""
    monkeypatch.delenv("OPENSANDBOX_DOMAIN", raising=False)
    sandbox_tool = OpenSandboxTool()
    expected_message = "OPENSANDBOX_DOMAIN is not set"
    with pytest.raises(ValueError, match=expected_message):
        sandbox_tool._build_connection_config()
|
||||
|
||||
|
||||
def test_run_command_returns_stdout(env, patched_sandbox_create):
    """``run_command`` forwards the command and returns the captured stdout."""
    _, sandbox = patched_sandbox_create
    output = OpenSandboxTool().run(action="run_command", command="echo hello")
    assert "hello" in output
    sandbox.commands.run.assert_awaited_once_with("echo hello")
|
||||
|
||||
|
||||
def test_run_command_includes_stderr_and_error(env, fake_sandbox):
    """stdout, stderr and the execution error all appear in the result."""
    noisy_execution = _make_execution(stdout="ok", stderr="warn", error="boom")
    fake_sandbox.commands.run = AsyncMock(return_value=noisy_execution)
    with patch("opensandbox.Sandbox") as patched_cls:
        patched_cls.create = AsyncMock(return_value=fake_sandbox)
        output = OpenSandboxTool().run(action="run_command", command="do_thing")
    assert "ok" in output
    assert "stderr:\nwarn" in output
    assert "error: boom" in output
|
||||
|
||||
|
||||
def test_run_command_requires_command(env):
    """``run_command`` without a command reports the missing argument."""
    message = OpenSandboxTool().run(action="run_command").lower()
    assert "command" in message
    assert "required" in message
|
||||
|
||||
|
||||
def test_read_file(env, patched_sandbox_create):
    """``read_file`` returns exactly what the SDK read for the given path."""
    _, sandbox = patched_sandbox_create
    target = "/tmp/foo.txt"
    contents = OpenSandboxTool().run(action="read_file", path=target)
    assert contents == "file contents"
    sandbox.files.read_file.assert_awaited_once_with(target)
|
||||
|
||||
|
||||
def test_read_file_requires_path(env):
    """``read_file`` without a path reports the missing argument."""
    message = OpenSandboxTool().run(action="read_file").lower()
    assert "path" in message
    assert "required" in message
|
||||
|
||||
|
||||
def test_write_file(env, patched_sandbox_create):
    """``write_file`` sends one entry carrying the requested path and data."""
    _, sandbox = patched_sandbox_create
    with patch("opensandbox.models.WriteEntry") as entry_cls:
        # Have WriteEntry(...) hand back its keyword arguments so the entry
        # given to the SDK can be inspected as a plain dict.
        entry_cls.side_effect = lambda **kwargs: kwargs
        outcome = OpenSandboxTool().run(
            action="write_file", path="/tmp/foo.txt", content="hello"
        )
    assert "Wrote" in outcome
    assert "/tmp/foo.txt" in outcome
    write_files = sandbox.files.write_files
    write_files.assert_awaited_once()
    sent_entries = write_files.await_args.args[0]
    first_entry = sent_entries[0]
    assert first_entry["path"] == "/tmp/foo.txt"
    assert first_entry["data"] == "hello"
|
||||
|
||||
|
||||
def test_write_file_requires_path_and_content(env):
    """``write_file`` names whichever of ``path`` / ``content`` is missing."""
    sandbox_tool = OpenSandboxTool()

    without_path = sandbox_tool.run(action="write_file", content="x")
    assert "path" in without_path.lower()

    without_content = sandbox_tool.run(action="write_file", path="/tmp/x")
    assert "content" in without_content.lower()
|
||||
|
||||
|
||||
def test_kill_when_no_sandbox(env):
    """``kill`` on a tool that never created a sandbox says so."""
    outcome = OpenSandboxTool().run(action="kill")
    assert outcome == "No sandbox to kill."
|
||||
|
||||
|
||||
def test_kill_after_use(env, patched_sandbox_create):
    """``kill`` after a command terminates the sandbox and clears the cache."""
    _, sandbox = patched_sandbox_create
    sandbox_tool = OpenSandboxTool()
    # Any sandbox-backed action will do; it only has to create the sandbox.
    sandbox_tool.run(action="run_command", command="echo hi")
    outcome = sandbox_tool.run(action="kill")
    assert "killed" in outcome.lower()
    sandbox.kill.assert_awaited_once()
    assert sandbox_tool._sandbox is None
|
||||
|
||||
|
||||
def test_sandbox_reused_across_calls(env, patched_sandbox_create):
    """Two actions on one tool create the sandbox only once."""
    patched_cls, _sandbox = patched_sandbox_create
    sandbox_tool = OpenSandboxTool()
    calls = (
        {"action": "run_command", "command": "echo 1"},
        {"action": "read_file", "path": "/tmp/x"},
    )
    for call_kwargs in calls:
        sandbox_tool.run(**call_kwargs)
    assert patched_cls.create.await_count == 1
|
||||
|
||||
|
||||
def test_run_command_wraps_sdk_exception(env, fake_sandbox):
    """An exception raised by the SDK comes back as an error string."""
    failure = RuntimeError("network down")
    fake_sandbox.commands.run = AsyncMock(side_effect=failure)
    with patch("opensandbox.Sandbox") as patched_cls:
        patched_cls.create = AsyncMock(return_value=fake_sandbox)
        outcome = OpenSandboxTool().run(action="run_command", command="echo hi")
    assert "OpenSandbox error" in outcome
    assert "network down" in outcome
|
||||
|
||||
|
||||
def test_unknown_action(env):
    """An action outside the schema's literals yields an "unknown action" error."""
    sandbox_tool = OpenSandboxTool()
    # Call ``_run`` directly: going through ``run`` would hit schema
    # validation before the dispatcher sees the invalid action.
    outcome = sandbox_tool._run(action="not_a_real_action")
    assert "unknown action" in outcome.lower()
|
||||
Reference in New Issue
Block a user