Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
83 changes: 78 additions & 5 deletions src/seclab_taskflows/mcp_servers/container_shell.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,8 @@
# SPDX-License-Identifier: MIT

import atexit
import hashlib
import json
import logging
import os
import subprocess
Expand All @@ -26,8 +28,63 @@
# Image reference handed to ``docker run``; empty string when unset.
CONTAINER_IMAGE = os.getenv("CONTAINER_IMAGE", "")
# Optional host path that is bind-mounted at /workspace inside the container.
CONTAINER_WORKSPACE = os.getenv("CONTAINER_WORKSPACE", "")
# Integer timeout read from the environment, defaulting to 30.
# NOTE(review): presumably seconds for commands run in the container — confirm
# against the use site.
CONTAINER_TIMEOUT = int(os.getenv("CONTAINER_TIMEOUT", "30"))
# Opt-in switch for container reuse: true only for "1", "true" or "yes"
# (compared case-insensitively); anything else, including unset, is false.
CONTAINER_PERSIST = os.getenv("CONTAINER_PERSIST", "").lower() in {"1", "true", "yes"}
# Optional extra text mixed into the persistent container name.
CONTAINER_PERSIST_KEY = os.getenv("CONTAINER_PERSIST_KEY", "")

_DEFAULT_WORKDIR = "/workspace"
# Seconds allowed for each docker CLI invocation made through subprocess.run.
_DOCKER_TIMEOUT = 30


def _persistent_name() -> str:
    """Derive a deterministic container name so later tasks can reuse it.

    The name is ``seclab-persist-`` followed by the first 12 hex digits of a
    SHA-256 digest computed over CONTAINER_IMAGE and, when it is non-empty,
    CONTAINER_PERSIST_KEY. Hashing the full image reference avoids collisions
    between long image names that share a prefix; the optional key separates
    independent runs of the same image.

    Returns:
        The container name, always ``len("seclab-persist-") + 12`` characters.
    """
    key_material = CONTAINER_IMAGE
    if CONTAINER_PERSIST_KEY:
        # Join with NUL rather than ":". A colon is also the tag separator in
        # image references, so image "foo:latest" with no key and image "foo"
        # with key "latest" would otherwise hash to the same name. NUL cannot
        # occur in a value read from the environment, so the join is
        # unambiguous. Names derived without a key are unaffected.
        key_material += f"\0{CONTAINER_PERSIST_KEY}"
    digest = hashlib.sha256(key_material.encode()).hexdigest()[:12]
    return f"seclab-persist-{digest}"


def _is_running(name: str) -> bool:
    """Report whether a container called *name* is currently running.

    Runs ``docker inspect --format json <name>`` and reads ``State.Running``
    from the first element of the JSON array it prints.

    This is a best-effort probe: every failure mode is reported as ``False``
    rather than raised, so the caller falls through to its "start a fresh
    container" path. That covers a non-zero exit status (for example, no such
    container), a timeout after ``_DOCKER_TIMEOUT`` seconds, output that is
    not JSON, and JSON that parses but does not have the expected
    list-of-objects shape.

    Args:
        name: Container name to look up.

    Returns:
        ``True`` only when docker reports a truthy ``State.Running``.
    """
    try:
        result = subprocess.run(
            ["docker", "inspect", "--format", "json", name],
            capture_output=True,
            text=True,
            timeout=_DOCKER_TIMEOUT,
        )
    except subprocess.TimeoutExpired:
        return False
    if result.returncode != 0:
        return False

    try:
        data = json.loads(result.stdout)
    except json.JSONDecodeError:
        return False

    # json.loads may legitimately return a dict, scalar or None. Validate the
    # shape explicitly instead of letting data[0] / .get() raise KeyError,
    # TypeError or AttributeError on unexpected (but valid) JSON.
    if not isinstance(data, list) or not data:
        return False
    entry = data[0]
    if not isinstance(entry, dict):
        return False
    state = entry.get("State")
    if not isinstance(state, dict):
        return False
    return bool(state.get("Running"))


def _remove_container(name: str) -> None:
    """Delete a leftover, stopped container called *name*.

    Plain ``docker rm`` is used on purpose: without ``-f`` docker refuses to
    remove a running container, so only stopped leftovers are cleaned up.

    Nothing is raised for an unsuccessful removal. A non-zero exit status is
    logged at debug level together with docker's stderr, and a timeout after
    ``_DOCKER_TIMEOUT`` seconds is logged with its traceback.

    Args:
        name: Container name to remove.
    """
    rm_cmd = ["docker", "rm", name]
    try:
        outcome = subprocess.run(
            rm_cmd,
            capture_output=True,
            text=True,
            timeout=_DOCKER_TIMEOUT,
        )
    except subprocess.TimeoutExpired:
        logging.exception("docker rm timed out for %s after %ds", name, _DOCKER_TIMEOUT)
        return

    if outcome.returncode == 0:
        return
    # Non-zero exit: report what docker said and carry on.
    logging.debug(
        "docker rm skipped for %s: %s", name, outcome.stderr.strip()
    )


def _start_container() -> str:
Expand All @@ -38,25 +95,41 @@ def _start_container() -> str:
if CONTAINER_WORKSPACE and ":" in CONTAINER_WORKSPACE:
msg = f"CONTAINER_WORKSPACE must not contain a colon: {CONTAINER_WORKSPACE!r}"
raise RuntimeError(msg)
name = f"seclab-shell-{uuid.uuid4().hex[:8]}"
cmd = ["docker", "run", "-d", "--rm", "--name", name]

if CONTAINER_PERSIST:
name = _persistent_name()
if _is_running(name):
logging.debug(f"Reusing persistent container: {name}")
return name
# Remove stopped leftover with the same name
_remove_container(name)
else:
name = f"seclab-shell-{uuid.uuid4().hex[:8]}"

cmd = ["docker", "run", "-d", "--name", name]
if not CONTAINER_PERSIST:
cmd.append("--rm")
if CONTAINER_WORKSPACE:
cmd += ["-v", f"{CONTAINER_WORKSPACE}:/workspace"]
cmd += [CONTAINER_IMAGE, "tail", "-f", "/dev/null"]
logging.debug(f"Starting container: {' '.join(cmd)}")
result = subprocess.run(cmd, capture_output=True, text=True, timeout=30)
result = subprocess.run(cmd, capture_output=True, text=True, timeout=_DOCKER_TIMEOUT)
if result.returncode != 0:
msg = f"docker run failed: {result.stderr.strip()}"
raise RuntimeError(msg)
logging.debug(f"Container started: {name}")
logging.debug(f"Container started: {name} (persist={CONTAINER_PERSIST})")
return name


def _stop_container() -> None:
"""Stop the running container."""
"""Stop the running container (skipped for persistent containers)."""
global _container_name
if _container_name is None:
return
if CONTAINER_PERSIST:
logging.debug(f"Leaving persistent container running: {_container_name}")
_container_name = None
return
logging.debug(f"Stopping container: {_container_name}")
result = subprocess.run(
["docker", "stop", "--time", "5", _container_name],
Expand Down
2 changes: 2 additions & 0 deletions src/seclab_taskflows/toolboxes/container_shell_base.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,8 @@ server_params:
CONTAINER_IMAGE: "seclab-shell-base:latest"
CONTAINER_WORKSPACE: "{{ env('CONTAINER_WORKSPACE', required=False) }}"
CONTAINER_TIMEOUT: "{{ env('CONTAINER_TIMEOUT', '30') }}"
CONTAINER_PERSIST: "{{ env('CONTAINER_PERSIST', required=False) }}"
CONTAINER_PERSIST_KEY: "{{ env('CONTAINER_PERSIST_KEY', required=False) }}"
LOG_DIR: "{{ env('LOG_DIR') }}"

confirm:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,8 @@ server_params:
CONTAINER_IMAGE: "seclab-shell-malware-analysis:latest"
CONTAINER_WORKSPACE: "{{ env('CONTAINER_WORKSPACE', required=False) }}"
CONTAINER_TIMEOUT: "{{ env('CONTAINER_TIMEOUT', '60') }}"
CONTAINER_PERSIST: "{{ env('CONTAINER_PERSIST', required=False) }}"
CONTAINER_PERSIST_KEY: "{{ env('CONTAINER_PERSIST_KEY', required=False) }}"
LOG_DIR: "{{ env('LOG_DIR') }}"

confirm:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,8 @@ server_params:
CONTAINER_IMAGE: "seclab-shell-network-analysis:latest"
CONTAINER_WORKSPACE: "{{ env('CONTAINER_WORKSPACE', required=False) }}"
CONTAINER_TIMEOUT: "{{ env('CONTAINER_TIMEOUT', '30') }}"
CONTAINER_PERSIST: "{{ env('CONTAINER_PERSIST', required=False) }}"
CONTAINER_PERSIST_KEY: "{{ env('CONTAINER_PERSIST_KEY', required=False) }}"
LOG_DIR: "{{ env('LOG_DIR') }}"

confirm:
Expand Down
2 changes: 2 additions & 0 deletions src/seclab_taskflows/toolboxes/container_shell_sast.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,8 @@ server_params:
CONTAINER_IMAGE: "seclab-shell-sast:latest"
CONTAINER_WORKSPACE: "{{ env('CONTAINER_WORKSPACE', required=False) }}"
CONTAINER_TIMEOUT: "{{ env('CONTAINER_TIMEOUT', '60') }}"
CONTAINER_PERSIST: "{{ env('CONTAINER_PERSIST', required=False) }}"
CONTAINER_PERSIST_KEY: "{{ env('CONTAINER_PERSIST_KEY', required=False) }}"
LOG_DIR: "{{ env('LOG_DIR') }}"

confirm:
Expand Down
101 changes: 101 additions & 0 deletions tests/test_container_shell.py
Original file line number Diff line number Diff line change
Expand Up @@ -188,6 +188,107 @@ def test_stop_container_clears_name_on_failure(self):
assert cs_mod._container_name is None


# ---------------------------------------------------------------------------
# Persistent container tests
# ---------------------------------------------------------------------------

class TestPersistentContainer:
    """Tests for the opt-in persistent (reusable) container mode."""

    def setup_method(self):
        """Run the shared ``_reset_container`` helper before each test."""
        _reset_container()

    def test_persistent_name_uses_hash(self):
        """The name is the fixed prefix followed by a 12-character digest."""
        prefix = "seclab-persist-"
        with (
            patch.object(cs_mod, "CONTAINER_IMAGE", "myregistry.io/org/image:v1.2.3"),
            patch.object(cs_mod, "CONTAINER_PERSIST_KEY", ""),
        ):
            derived = cs_mod._persistent_name()
        assert derived.startswith(prefix)
        assert len(derived) == len(prefix) + 12

    def test_persistent_name_varies_with_key(self):
        """Setting a persist key yields a different name for the same image."""
        derived = []
        with patch.object(cs_mod, "CONTAINER_IMAGE", "test-image:latest"):
            for key in ("", "run-42"):
                with patch.object(cs_mod, "CONTAINER_PERSIST_KEY", key):
                    derived.append(cs_mod._persistent_name())
        without_key, with_key = derived
        assert without_key != with_key

    def test_persistent_name_differs_for_different_images(self):
        """Two distinct image references never share a name."""
        derived = []
        with patch.object(cs_mod, "CONTAINER_PERSIST_KEY", ""):
            for image in ("image-a:latest", "image-b:latest"):
                with patch.object(cs_mod, "CONTAINER_IMAGE", image):
                    derived.append(cs_mod._persistent_name())
        first, second = derived
        assert first != second

    def test_start_reuses_running_persistent_container(self):
        """A running persistent container is returned without a new ``docker run``."""
        running_reply = _make_proc(
            returncode=0,
            stdout='[{"State":{"Running":true}}]',
        )
        with (
            patch.object(cs_mod, "CONTAINER_IMAGE", "test-image:latest"),
            patch.object(cs_mod, "CONTAINER_WORKSPACE", ""),
            patch.object(cs_mod, "CONTAINER_PERSIST", True),
            patch.object(cs_mod, "CONTAINER_PERSIST_KEY", ""),
            patch("subprocess.run", return_value=running_reply) as fake_run,
        ):
            started = cs_mod._start_container()
        assert started.startswith("seclab-persist-")
        # Exactly one docker call is expected: the inspect probe, no run.
        assert fake_run.call_count == 1
        issued_cmd = fake_run.call_args.args[0]
        assert issued_cmd == ["docker", "inspect", "--format", "json", started]

    def test_start_persistent_no_rm_flag(self):
        """Persistent containers are started without ``--rm``."""
        # Replies are consumed in call order: inspect, rm, run.
        docker_replies = [
            _make_proc(
                returncode=1,
                stdout="",
            ),
            _make_proc(returncode=0),
            _make_proc(returncode=0),
        ]
        with (
            patch.object(cs_mod, "CONTAINER_IMAGE", "test-image:latest"),
            patch.object(cs_mod, "CONTAINER_WORKSPACE", ""),
            patch.object(cs_mod, "CONTAINER_PERSIST", True),
            patch.object(cs_mod, "CONTAINER_PERSIST_KEY", ""),
            patch("subprocess.run", side_effect=docker_replies) as fake_run,
        ):
            started = cs_mod._start_container()
        assert started.startswith("seclab-persist-")
        # The third subprocess call is the docker run invocation.
        third_call = fake_run.call_args_list[2]
        assert "--rm" not in third_call.args[0]

    def test_stop_skips_persistent_container(self):
        """Stopping in persist mode forgets the name but never calls docker."""
        cs_mod._container_name = "seclab-persist-abc123"
        with (
            patch.object(cs_mod, "CONTAINER_PERSIST", True),
            patch("subprocess.run") as fake_run,
        ):
            cs_mod._stop_container()
        fake_run.assert_not_called()
        assert cs_mod._container_name is None

    def test_remove_container_logs_failure(self):
        """A failing ``docker rm`` produces exactly one debug log call."""
        failed_rm = _make_proc(returncode=1, stderr="conflict")
        with (
            patch("subprocess.run", return_value=failed_rm),
            patch.object(cs_mod.logging, "debug") as debug_log,
        ):
            cs_mod._remove_container("test-name")
        debug_log.assert_called_once()

    def test_remove_container_logs_timeout(self):
        """A timed-out ``docker rm`` is reported via ``logging.exception``."""
        timed_out = subprocess.TimeoutExpired(cmd="docker", timeout=30)
        with (
            patch("subprocess.run", side_effect=timed_out),
            patch.object(cs_mod.logging, "exception") as exception_log,
        ):
            cs_mod._remove_container("test-name")
        exception_log.assert_called_once()

    def test_is_running_returns_false_on_timeout(self):
        """A timed-out inspect probe counts as "not running"."""
        timed_out = subprocess.TimeoutExpired(cmd="docker", timeout=30)
        with patch("subprocess.run", side_effect=timed_out):
            verdict = cs_mod._is_running("test-name")
        assert verdict is False

    def test_is_running_returns_false_on_bad_json(self):
        """Unparseable inspect output counts as "not running"."""
        garbage_reply = _make_proc(returncode=0, stdout="not json")
        with patch("subprocess.run", return_value=garbage_reply):
            verdict = cs_mod._is_running("test-name")
        assert verdict is False


# ---------------------------------------------------------------------------
# Toolbox YAML validation
# ---------------------------------------------------------------------------
Expand Down
Loading