From d56eb70cbb0f9b816502e9617e8226164d86a9c5 Mon Sep 17 00:00:00 2001 From: GangGreenTemperTatum <104169244+GangGreenTemperTatum@users.noreply.github.com> Date: Tue, 24 Feb 2026 16:48:37 -0500 Subject: [PATCH] feat: add pytorch loader for pickle opcode security analysis MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Static analysis of PyTorch .pt/.pth checkpoint files — scans pickle bytecode for dangerous opcodes (GLOBAL, REDUCE, BUILD, etc.) without executing anything. Detects suspicious module imports (os, subprocess), validates ZIP structure, and flags path traversal in archive entries. Co-Authored-By: Claude Opus 4.6 --- README.md | 2 +- dyana/cli.py | 6 +- dyana/cli_test.py | 4 +- dyana/loaders/base/dyana_test.py | 36 +-- dyana/loaders/loader.py | 4 +- dyana/loaders/pytorch/.gitignore | 3 + dyana/loaders/pytorch/Dockerfile | 12 + dyana/loaders/pytorch/main.py | 346 ++++++++++++++++++++++++++ dyana/loaders/pytorch/pytorch_test.py | 187 ++++++++++++++ dyana/loaders/pytorch/settings.yml | 16 ++ dyana/view.py | 74 +++++- dyana/view_test.py | 85 ++++++- 12 files changed, 738 insertions(+), 37 deletions(-) create mode 100644 dyana/loaders/pytorch/.gitignore create mode 100644 dyana/loaders/pytorch/Dockerfile create mode 100644 dyana/loaders/pytorch/main.py create mode 100644 dyana/loaders/pytorch/pytorch_test.py create mode 100644 dyana/loaders/pytorch/settings.yml diff --git a/README.md b/README.md index c0a87bab..328de4f7 100644 --- a/README.md +++ b/README.md @@ -28,7 +28,7 @@ -Dyana is a sandbox environment using Docker and [Tracee](https://github.com/aquasecurity/tracee) for loading, running and profiling a wide range of files, including machine learning models, ELF executables, Pickle serialized files, Javascripts [and more](https://docs.dreadnode.io/open-source/dyana/topics/loaders). It provides detailed insights into GPU memory usage, filesystem interactions, network requests, and security related events. +Dyana is a sandbox environment using Docker and [Tracee](https://github.com/aquasecurity/tracee) for loading, running and profiling a wide range of files, including machine learning models, PyTorch checkpoints, ELF executables, Pickle serialized files, Javascripts [and more](https://docs.dreadnode.io/open-source/dyana/topics/loaders). It provides detailed insights into GPU memory usage, filesystem interactions, network requests, and security related events. ## Installation diff --git a/dyana/cli.py b/dyana/cli.py index 5eb92fff..47b65493 100644 --- a/dyana/cli.py +++ b/dyana/cli.py @@ -20,6 +20,7 @@ from dyana.view import ( view_disk_events, view_disk_usage, + view_extra, view_gpus, view_header, view_imports, @@ -139,7 +140,9 @@ def trace( except Exception as e: serr = str(e) if "could not select device driver" in serr and "capabilities: [[gpu]]" in serr: - rich_print(":cross_mark: [bold][red]error:[/] [red]GPUs are not available on this system, run with --no-gpu.[/]") + rich_print( + ":cross_mark: [bold][red]error:[/] [red]GPUs are not available on this system, run with --no-gpu.[/]" + ) else: rich_print(f":cross_mark: [bold][red]error:[/] [red]{e}[/]") @@ -187,3 +190,4 @@ def summary(trace_path: pathlib.Path = typer.Option(help="Path to the trace file view_legacy_extra(trace["run"]) else: view_imports(trace["run"]["stages"]) + view_extra(trace["run"]) diff --git a/dyana/cli_test.py b/dyana/cli_test.py index 16fef17f..8bdea9e4 100644 --- a/dyana/cli_test.py +++ b/dyana/cli_test.py @@ -177,9 +177,7 @@ def test_trace_runs_and_saves(self, _mock_run: t.Any, tmp_path: t.Any) -> None: @patch("dyana.cli.Tracer.__init__", _noop_tracer_init) @patch( "dyana.cli.Tracer.run_trace", - side_effect=RuntimeError( - "could not select device driver '' with capabilities: [[gpu]]" - ), + side_effect=RuntimeError("could not select device driver '' with capabilities: [[gpu]]"), ) @patch("dyana.cli.Loader.__init__", _noop_loader_init) def test_trace_gpu_error(self, _mock_run: t.Any) -> None: diff --git a/dyana/loaders/base/dyana_test.py b/dyana/loaders/base/dyana_test.py index 0f223473..5f13cc1b 100644 --- a/dyana/loaders/base/dyana_test.py +++ b/dyana/loaders/base/dyana_test.py @@ -18,62 +18,48 @@ def setup_method(self) -> None: def test_singleton(self) -> None: with patch("dyana.loaders.base.dyana.Stage.create") as mock_create: - mock_create.return_value = Stage( - name="start", timestamp=0, ram=0, disk=0, network={}, imports={} - ) + mock_create.return_value = Stage(name="start", timestamp=0, ram=0, disk=0, network={}, imports={}) p = Profiler() assert Profiler.instance is p def test_on_stage(self) -> None: with patch("dyana.loaders.base.dyana.Stage.create") as mock_create: - mock_create.return_value = Stage( - name="test", timestamp=0, ram=0, disk=0, network={}, imports={} - ) + mock_create.return_value = Stage(name="test", timestamp=0, ram=0, disk=0, network={}, imports={}) p = Profiler() p.on_stage("after_load") assert len(p._stages) == 2 def test_track_error(self) -> None: with patch("dyana.loaders.base.dyana.Stage.create") as mock_create: - mock_create.return_value = Stage( - name="start", timestamp=0, ram=0, disk=0, network={}, imports={} - ) + mock_create.return_value = Stage(name="start", timestamp=0, ram=0, disk=0, network={}, imports={}) p = Profiler() p.track_error("loader", "something broke") assert p._errors == {"loader": "something broke"} def test_track_warning(self) -> None: with patch("dyana.loaders.base.dyana.Stage.create") as mock_create: - mock_create.return_value = Stage( - name="start", timestamp=0, ram=0, disk=0, network={}, imports={} - ) + mock_create.return_value = Stage(name="start", timestamp=0, ram=0, disk=0, network={}, imports={}) p = Profiler() p.track_warning("pip", "could not import") assert p._warnings == {"pip": "could not import"} def test_track_extra(self) -> None: with patch("dyana.loaders.base.dyana.Stage.create") as mock_create: - mock_create.return_value = Stage( - name="start", timestamp=0, ram=0, disk=0, network={}, imports={} - ) + mock_create.return_value = Stage(name="start", timestamp=0, ram=0, disk=0, network={}, imports={}) p = Profiler() p.track_extra("imports", {"os": "/usr/lib"}) assert p._extra == {"imports": {"os": "/usr/lib"}} def test_track(self) -> None: with patch("dyana.loaders.base.dyana.Stage.create") as mock_create: - mock_create.return_value = Stage( - name="start", timestamp=0, ram=0, disk=0, network={}, imports={} - ) + mock_create.return_value = Stage(name="start", timestamp=0, ram=0, disk=0, network={}, imports={}) p = Profiler() p.track("custom_key", "custom_value") assert p._additionals == {"custom_key": "custom_value"} def test_as_dict(self) -> None: with patch("dyana.loaders.base.dyana.Stage.create") as mock_create: - mock_create.return_value = Stage( - name="start", timestamp=0, ram=0, disk=0, network={}, imports={} - ) + mock_create.return_value = Stage(name="start", timestamp=0, ram=0, disk=0, network={}, imports={}) p = Profiler() p.track_error("err", "msg") result = p.as_dict() @@ -85,9 +71,7 @@ def test_as_dict(self) -> None: def test_flush(self, capsys: t.Any) -> None: with patch("dyana.loaders.base.dyana.Stage.create") as mock_create: - mock_create.return_value = Stage( - name="start", timestamp=0, ram=0, disk=0, network={}, imports={} - ) + mock_create.return_value = Stage(name="start", timestamp=0, ram=0, disk=0, network={}, imports={}) Profiler() Profiler.flush() captured = capsys.readouterr() @@ -151,7 +135,9 @@ def test_with_prev_imports(self) -> None: patch("dyana.loaders.base.dyana.get_peak_rss", return_value=1024), patch("dyana.loaders.base.dyana.get_disk_usage", return_value=2048), patch("dyana.loaders.base.dyana.get_network_stats", return_value={}), - patch("dyana.loaders.base.dyana.get_current_imports", return_value={"os": "/a", "sys": "/b", "new_mod": "/c"}), + patch( + "dyana.loaders.base.dyana.get_current_imports", return_value={"os": "/a", "sys": "/b", "new_mod": "/c"} + ), ): stage = Stage.create("test", prev_imports={"os": "/a", "sys": "/b"}) assert "new_mod" in stage.imports diff --git a/dyana/loaders/loader.py b/dyana/loaders/loader.py index e430ffdf..222b7f42 100644 --- a/dyana/loaders/loader.py +++ b/dyana/loaders/loader.py @@ -184,7 +184,9 @@ def run(self, allow_network: bool = False, allow_gpus: bool = True, allow_volume rich_print(":popcorn: [bold]loader[/]: [yellow]required bridged network access[/]") elif allow_network: - rich_print(":popcorn: [bold]loader[/]: [yellow]warning: allowing bridged network access to the container[/]") + rich_print( + ":popcorn: [bold]loader[/]: [yellow]warning: allowing bridged network access to the container[/]" + ) if allow_volume_write: rich_print(":popcorn: [bold]loader[/]: [yellow]warning: allowing volume write to the container[/]") diff --git a/dyana/loaders/pytorch/.gitignore b/dyana/loaders/pytorch/.gitignore new file mode 100644 index 00000000..3d1264e6 --- /dev/null +++ b/dyana/loaders/pytorch/.gitignore @@ -0,0 +1,3 @@ +dyana.py +dyana-requirements.txt +dyana-requirements-gpu.txt diff --git a/dyana/loaders/pytorch/Dockerfile b/dyana/loaders/pytorch/Dockerfile new file mode 100644 index 00000000..07fbfb64 --- /dev/null +++ b/dyana/loaders/pytorch/Dockerfile @@ -0,0 +1,12 @@ +FROM python:3.12-slim + +WORKDIR /app + +RUN apt-get update && apt-get install -y build-essential +COPY dyana.py . +COPY dyana-requirements.txt . +RUN pip install --no-cache-dir --root-user-action=ignore -r dyana-requirements.txt + +COPY main.py . + +ENTRYPOINT ["python3", "-W", "ignore", "main.py"] diff --git a/dyana/loaders/pytorch/main.py b/dyana/loaders/pytorch/main.py new file mode 100644 index 00000000..d30ce28d --- /dev/null +++ b/dyana/loaders/pytorch/main.py @@ -0,0 +1,346 @@ +from __future__ import annotations + +import argparse +import os +import pickletools +import re +import typing as t +import zipfile + +# Opcodes that can execute arbitrary code +DANGEROUS_OPCODES: dict[str, str] = { + "GLOBAL": "imports a module attribute (can import arbitrary code)", + "INST": "instantiates a class (can execute arbitrary constructors)", + "OBJ": "builds an object (can execute arbitrary constructors)", + "NEWOBJ": "creates a new object (can call arbitrary __new__)", + "NEWOBJ_EX": "creates a new object with kwargs (can call arbitrary __new__)", + "REDUCE": "calls a callable (can execute arbitrary functions)", + "BUILD": "calls __setstate__ (can trigger arbitrary code via object reconstruction)", + "STACK_GLOBAL": "imports a module attribute from stack (can import arbitrary code)", +} + +# Known-safe modules/functions commonly seen in legitimate PyTorch checkpoints +KNOWN_SAFE_GLOBALS: set[str] = { + "torch._utils._rebuild_tensor_v2", + "torch._utils._rebuild_parameter", + "torch._utils._rebuild_parameter_with_state", + "torch.FloatStorage", + "torch.LongStorage", + "torch.IntStorage", + "torch.ShortStorage", + "torch.DoubleStorage", + "torch.HalfStorage", + "torch.ByteStorage", + "torch.CharStorage", + "torch.BFloat16Storage", + "torch.ComplexFloatStorage", + "torch.ComplexDoubleStorage", + "torch.storage._load_from_bytes", + "torch.nn.modules.module.Module", + "collections.OrderedDict", + "_codecs.encode", + "torch.Size", + "torch.device", + "torch.dtype", + "torch.float16", + "torch.float32", + "torch.float64", + "torch.bfloat16", + "torch.int8", + "torch.int16", + "torch.int32", + "torch.int64", + "torch.uint8", + "torch.bool", + "torch.complex64", + "torch.complex128", + "torch._utils._rebuild_tensor_v3", + "torch._utils._rebuild_device_tensor_v2", +} + +# Module prefixes that are suspicious +SUSPICIOUS_MODULE_PREFIXES: list[str] = [ + "os.", + "subprocess.", + "shutil.", + "sys.", + "builtins.", + "importlib.", + "ctypes.", + "socket.", + "http.", + "urllib.", + "requests.", + "webbrowser.", + "code.", + "eval", + "exec", + "compile", + "__builtin__.", + "nt.", + "posix.", + "signal.", +] + + +def check_zip_structure(path: str) -> dict[str, t.Any]: + """Check if the file is a valid ZIP archive (PyTorch format) or legacy pickle.""" + errors: list[str] = [] + info: list[str] = [] + is_zip = False + is_legacy_pickle = False + zip_entries: list[dict[str, t.Any]] = [] + + try: + is_zip = zipfile.is_zipfile(path) + except OSError as e: + errors.append(f"cannot read file: {e}") + return { + "is_zip": False, + "is_legacy_pickle": False, + "zip_entries": [], + "errors": errors, + "info": info, + } + + if is_zip: + try: + with zipfile.ZipFile(path, "r") as zf: + for zi in zf.infolist(): + zip_entries.append( + { + "filename": zi.filename, + "file_size": zi.file_size, + "compress_size": zi.compress_size, + } + ) + # check for path traversal in zip entries + if ".." in zi.filename or zi.filename.startswith("/"): + errors.append(f"suspicious zip entry path: '{zi.filename}' (possible path traversal)") + except zipfile.BadZipFile as e: + errors.append(f"corrupted zip file: {e}") + else: + # check if it's a legacy pickle file (pre-zip PyTorch format) + try: + with open(path, "rb") as f: + magic = f.read(2) + # pickle protocol opcodes: \x80 = PROTO + if magic and magic[0] == 0x80: + is_legacy_pickle = True + info.append("file is legacy pickle format (not ZIP-based)") + except OSError as e: + errors.append(f"cannot read file: {e}") + + return { + "is_zip": is_zip, + "is_legacy_pickle": is_legacy_pickle, + "zip_entries": zip_entries, + "errors": errors, + "info": info, + } + + +def scan_pickle_opcodes(data: bytes) -> dict[str, t.Any]: + """Scan pickle bytecode for dangerous opcodes without executing it.""" + errors: list[str] = [] + warnings: list[str] = [] + info: list[str] = [] + + dangerous_ops: list[dict[str, str]] = [] + global_imports: list[str] = [] + all_opcodes: dict[str, int] = {} + + try: + ops = list(pickletools.genops(data)) + except Exception as e: + errors.append(f"failed to disassemble pickle: {e}") + return { + "dangerous_ops": [], + "global_imports": [], + "all_opcodes": {}, + "errors": errors, + "warnings": warnings, + "info": info, + } + + for opcode, arg, _pos in ops: + name = opcode.name + all_opcodes[name] = all_opcodes.get(name, 0) + 1 + + if name in DANGEROUS_OPCODES: + entry: dict[str, str] = { + "opcode": name, + "reason": DANGEROUS_OPCODES[name], + } + if arg is not None: + entry["arg"] = str(arg) + dangerous_ops.append(entry) + + # Track GLOBAL/STACK_GLOBAL imports specifically + if name in ("GLOBAL", "STACK_GLOBAL") and arg is not None: + # pickletools returns "module attr" (space-separated), normalize to "module.attr" + normalized = str(arg).replace(" ", ".") + global_imports.append(normalized) + + # Classify global imports + suspicious_globals: list[str] = [] + unknown_globals: list[str] = [] + + for g in global_imports: + if g in KNOWN_SAFE_GLOBALS: + continue + is_suspicious = False + for prefix in SUSPICIOUS_MODULE_PREFIXES: + if g.startswith(prefix) or g == prefix.rstrip("."): + suspicious_globals.append(g) + is_suspicious = True + break + if not is_suspicious and g not in KNOWN_SAFE_GLOBALS: + unknown_globals.append(g) + + if suspicious_globals: + for g in suspicious_globals: + errors.append(f"suspicious global import: '{g}'") + + if unknown_globals: + for g in unknown_globals: + warnings.append(f"unknown global import: '{g}'") + + # Summarize + safe_count = len(global_imports) - len(suspicious_globals) - len(unknown_globals) + if global_imports: + info.append( + f"{len(global_imports)} global imports: " + f"{safe_count} known-safe, " + f"{len(unknown_globals)} unknown, " + f"{len(suspicious_globals)} suspicious" + ) + + return { + "dangerous_ops": dangerous_ops, + "global_imports": global_imports, + "all_opcodes": all_opcodes, + "errors": errors, + "warnings": warnings, + "info": info, + } + + +def analyze_pytorch_file(path: str) -> dict[str, t.Any]: + """Full analysis of a PyTorch checkpoint file.""" + errors: list[str] = [] + warnings: list[str] = [] + info: list[str] = [] + + file_size = os.path.getsize(path) + zip_result = check_zip_structure(path) + + errors.extend(zip_result["errors"]) + info.extend(zip_result["info"]) + + pickle_scan: dict[str, t.Any] | None = None + data_files: list[dict[str, t.Any]] = [] + + if zip_result["is_zip"]: + try: + with zipfile.ZipFile(path, "r") as zf: + for zi in zf.infolist(): + if zi.filename.endswith(".pkl") or zi.filename.endswith("data.pkl"): + # Scan pickle data inside ZIP + pkl_data = zf.read(zi.filename) + pickle_scan = scan_pickle_opcodes(pkl_data) + errors.extend(pickle_scan["errors"]) + warnings.extend(pickle_scan["warnings"]) + info.extend(pickle_scan["info"]) + elif re.match(r".*/data/\d+$", zi.filename): + data_files.append( + { + "name": zi.filename, + "size": zi.file_size, + } + ) + except Exception as e: + errors.append(f"failed to read zip contents: {e}") + + elif zip_result["is_legacy_pickle"]: + try: + with open(path, "rb") as f: + pkl_data = f.read() + pickle_scan = scan_pickle_opcodes(pkl_data) + errors.extend(pickle_scan["errors"]) + warnings.extend(pickle_scan["warnings"]) + info.extend(pickle_scan["info"]) + except Exception as e: + errors.append(f"failed to read pickle data: {e}") + + else: + errors.append("file is neither a ZIP archive nor a pickle file — not a valid PyTorch checkpoint") + + return { + "file_size": file_size, + "is_zip": zip_result["is_zip"], + "is_legacy_pickle": zip_result["is_legacy_pickle"], + "zip_entries": zip_result["zip_entries"], + "data_files": data_files, + "pickle_scan": pickle_scan, + "errors": errors, + "warnings": warnings, + "info": info, + } + + +if __name__ == "__main__": + from dyana import Profiler # type: ignore[attr-defined] + + parser = argparse.ArgumentParser(description="Analyze PyTorch checkpoint files for security issues") + parser.add_argument("--pytorch", help="Path to PyTorch checkpoint file", required=True) + args = parser.parse_args() + profiler: Profiler = Profiler(gpu=False) + + if not os.path.exists(args.pytorch): + profiler.track_error("pytorch", "PyTorch checkpoint file not found") + else: + # Stage 1: check file structure + profiler.on_stage("checking_structure") + result = analyze_pytorch_file(args.pytorch) + + for error in result["errors"]: + profiler.track_error("pytorch", error) + for warning in result["warnings"]: + profiler.track_warning("pytorch", warning) + + profiler.track_extra( + "file_structure", + { + "file_size": result["file_size"], + "is_zip": result["is_zip"], + "is_legacy_pickle": result["is_legacy_pickle"], + "zip_entries": result["zip_entries"], + "data_files": result["data_files"], + }, + ) + + # Stage 2: report pickle analysis + profiler.on_stage("analyzing_pickle") + if result["pickle_scan"]: + scan = result["pickle_scan"] + + profiler.track_extra( + "pickle_analysis", + { + "global_imports": scan["global_imports"], + "dangerous_ops_count": len(scan["dangerous_ops"]), + "dangerous_ops": scan["dangerous_ops"][:20], + "opcode_distribution": scan["all_opcodes"], + }, + ) + + # Collect findings + profiler.track_extra( + "findings", + { + "errors": result["errors"], + "warnings": result["warnings"], + "info": result["info"], + }, + ) diff --git a/dyana/loaders/pytorch/pytorch_test.py b/dyana/loaders/pytorch/pytorch_test.py new file mode 100644 index 00000000..a0384db1 --- /dev/null +++ b/dyana/loaders/pytorch/pytorch_test.py @@ -0,0 +1,187 @@ +from __future__ import annotations + +import pickle +import typing as t +import zipfile +from pathlib import Path + +from dyana.loaders.loader import Loader +from dyana.loaders.pytorch.main import ( + analyze_pytorch_file, + check_zip_structure, + scan_pickle_opcodes, +) + + +def _make_pytorch_zip(path: Path, state_dict: dict[str, t.Any] | None = None) -> None: + """Create a minimal PyTorch-style ZIP checkpoint.""" + if state_dict is None: + state_dict = {"model.weight": [1.0, 2.0, 3.0]} + + pkl_data = pickle.dumps(state_dict, protocol=2) + tensor_data = b"\x00" * 24 + + with zipfile.ZipFile(path, "w") as zf: + zf.writestr("archive/data.pkl", pkl_data) + zf.writestr("archive/data/0", tensor_data) + + +def _make_legacy_pickle(path: Path, obj: t.Any = None) -> None: + """Create a legacy pickle file (non-ZIP).""" + if obj is None: + obj = {"weight": [1.0, 2.0]} + with open(path, "wb") as f: + pickle.dump(obj, f, protocol=2) + + +class TestPyTorchLoaderSettings: + def test_loader_loads(self) -> None: + loader = Loader(name="pytorch", build=False) + assert loader.settings is not None + assert loader.settings.gpu is False + + def test_correct_arg_structure(self) -> None: + loader = Loader(name="pytorch", build=False) + assert loader.settings is not None + assert loader.settings.args is not None + assert len(loader.settings.args) == 1 + assert loader.settings.args[0].name == "pytorch" + assert loader.settings.args[0].required is True + assert loader.settings.args[0].volume is True + + +class TestCheckZipStructure: + def test_valid_zip(self, tmp_path: Path) -> None: + path = tmp_path / "model.pt" + _make_pytorch_zip(path) + result = check_zip_structure(str(path)) + assert result["is_zip"] is True + assert result["is_legacy_pickle"] is False + assert len(result["zip_entries"]) == 2 + assert result["errors"] == [] + + def test_legacy_pickle(self, tmp_path: Path) -> None: + path = tmp_path / "model.pt" + _make_legacy_pickle(path) + result = check_zip_structure(str(path)) + assert result["is_zip"] is False + assert result["is_legacy_pickle"] is True + assert any("legacy" in i for i in result["info"]) + + def test_invalid_file(self, tmp_path: Path) -> None: + path = tmp_path / "garbage.pt" + path.write_text("this is not a pytorch file") + result = check_zip_structure(str(path)) + assert result["is_zip"] is False + assert result["is_legacy_pickle"] is False + + def test_empty_file(self, tmp_path: Path) -> None: + path = tmp_path / "empty.pt" + path.write_bytes(b"") + result = check_zip_structure(str(path)) + assert result["is_zip"] is False + assert result["is_legacy_pickle"] is False + + def test_zip_path_traversal(self, tmp_path: Path) -> None: + path = tmp_path / "evil.pt" + with zipfile.ZipFile(path, "w") as zf: + zf.writestr("../../etc/passwd", b"root:x:0:0") + result = check_zip_structure(str(path)) + assert any("path traversal" in e for e in result["errors"]) + + +class TestScanPickleOpcodes: + def test_safe_pickle(self) -> None: + data = pickle.dumps({"key": "value"}, protocol=2) + result = scan_pickle_opcodes(data) + assert result["errors"] == [] + + def test_detects_global_import(self) -> None: + # Create pickle that uses GLOBAL opcode + data = pickle.dumps({"key": "value"}, protocol=2) + result = scan_pickle_opcodes(data) + # A simple dict pickle may or may not have GLOBAL ops depending on protocol + # but it should parse without errors + assert "all_opcodes" in result + + def test_dangerous_os_system(self) -> None: + # Manually craft a pickle with os.system call + # PROTO 2, GLOBAL 'os system', SHORT_BINUNICODE 'echo pwned', TUPLE1, REDUCE, STOP + malicious = ( + b"\x80\x02" # PROTO 2 + b"cos\nsystem\n" # GLOBAL os.system + b"\x8c\x0aecho pwned" # SHORT_BINUNICODE 'echo pwned' + b"\x85" # TUPLE1 + b"R" # REDUCE + b"." # STOP + ) + result = scan_pickle_opcodes(malicious) + assert any("os.system" in e for e in result["errors"]) + assert any(op["opcode"] == "GLOBAL" for op in result["dangerous_ops"]) + assert any(op["opcode"] == "REDUCE" for op in result["dangerous_ops"]) + + def test_dangerous_subprocess(self) -> None: + malicious = b"\x80\x02csubprocess\ncheck_output\n\x8c\x02id\x85R." + result = scan_pickle_opcodes(malicious) + assert any("subprocess" in e for e in result["errors"]) + + def test_known_safe_globals_not_flagged(self) -> None: + # Craft pickle with a known-safe torch global + safe = b"\x80\x02ctorch._utils\n_rebuild_tensor_v2\n." + result = scan_pickle_opcodes(safe) + # Should not have errors for known-safe globals + assert not any("suspicious" in e for e in result["errors"]) + + def test_unknown_global_warning(self) -> None: + # Craft pickle with an unknown but not suspicious global + unknown = b"\x80\x02cmy_custom_module\nmy_function\n." + result = scan_pickle_opcodes(unknown) + assert any("unknown" in w for w in result["warnings"]) + + def test_invalid_pickle(self) -> None: + result = scan_pickle_opcodes(b"\xff\xff\xff\xff") + assert len(result["errors"]) > 0 + + def test_opcode_counting(self) -> None: + data = pickle.dumps([1, 2, 3], protocol=2) + result = scan_pickle_opcodes(data) + assert isinstance(result["all_opcodes"], dict) + assert sum(result["all_opcodes"].values()) > 0 + + +class TestAnalyzePytorchFile: + def test_valid_zip_checkpoint(self, tmp_path: Path) -> None: + path = tmp_path / "model.pt" + _make_pytorch_zip(path) + result = analyze_pytorch_file(str(path)) + assert result["is_zip"] is True + assert result["pickle_scan"] is not None + assert result["file_size"] > 0 + + def test_legacy_pickle_checkpoint(self, tmp_path: Path) -> None: + path = tmp_path / "model.pt" + _make_legacy_pickle(path) + result = analyze_pytorch_file(str(path)) + assert result["is_legacy_pickle"] is True + assert result["pickle_scan"] is not None + + def test_invalid_file(self, tmp_path: Path) -> None: + path = tmp_path / "garbage.pt" + path.write_text("not a pytorch file") + result = analyze_pytorch_file(str(path)) + assert any("not a valid PyTorch" in e for e in result["errors"]) + + def test_data_files_detected(self, tmp_path: Path) -> None: + path = tmp_path / "model.pt" + _make_pytorch_zip(path) + result = analyze_pytorch_file(str(path)) + assert len(result["data_files"]) == 1 + assert result["data_files"][0]["name"] == "archive/data/0" + + def test_malicious_zip_checkpoint(self, tmp_path: Path) -> None: + path = tmp_path / "evil.pt" + malicious_pkl = b"\x80\x02cos\nsystem\n\x8c\x0aecho pwned\x85R." + with zipfile.ZipFile(path, "w") as zf: + zf.writestr("archive/data.pkl", malicious_pkl) + result = analyze_pytorch_file(str(path)) + assert any("os.system" in e for e in result["errors"]) diff --git a/dyana/loaders/pytorch/settings.yml b/dyana/loaders/pytorch/settings.yml new file mode 100644 index 00000000..f16b5ff3 --- /dev/null +++ b/dyana/loaders/pytorch/settings.yml @@ -0,0 +1,16 @@ +description: Analyzes PyTorch checkpoint files for suspicious pickle opcodes and structural integrity. + +gpu: false + +args: + - name: pytorch + description: Path to the PyTorch checkpoint file (.pt or .pth) to analyze. + required: true + volume: true + +examples: + - description: "Analyze a PyTorch checkpoint:" + command: dyana trace --loader pytorch --pytorch /path/to/model.pt + + - description: "Analyze with verbose output:" + command: dyana trace --loader pytorch --pytorch /path/to/model.pt --verbose diff --git a/dyana/view.py b/dyana/view.py index aed6a9b5..9c81ee0f 100644 --- a/dyana/view.py +++ b/dyana/view.py @@ -20,7 +20,9 @@ def _view_loader_help_markdown(loader: Loader) -> None: rich_print() rich_print("* **Requires Network:**", "yes" if loader.settings.network else "no") if loader.settings.build_args: - rich_print("* **Optional Build Arguments:**", ", ".join({f"`--{k}`" for k in loader.settings.build_args.keys()})) + rich_print( + "* **Optional Build Arguments:**", ", ".join({f"`--{k}`" for k in loader.settings.build_args.keys()}) + ) if loader.settings.args: rich_print() @@ -34,7 +36,9 @@ def _view_loader_help_markdown(loader: Loader) -> None: "|--------------|---------------------------------------------------------------------|------------------------------|----------|" ) for arg in loader.settings.args: - rich_print(f"| `--{arg.name}` | {arg.description} | `{arg.default}` | {'yes' if arg.required else 'no'} |") + rich_print( + f"| `--{arg.name}` | {arg.description} | `{arg.default}` | {'yes' if arg.required else 'no'} |" + ) if loader.settings.examples: rich_print() @@ -333,7 +337,7 @@ def view_network_events(trace: dict[str, t.Any]) -> None: else: data = [arg["value"] for arg in event["args"] if arg["name"] == "proto_dns"][0] question_names = [q["name"] for q in data["questions"]] - answers = [f'{a["name"]}={a["IP"]}' for a in data["answers"]] + answers = [f"{a['name']}={a['IP']}" for a in data["answers"]] if not answers: line = f" * [[dim]{event['processId']}[/]] {event['processName']} | [bold red]dns[/] | question={', '.join(question_names)}" @@ -422,3 +426,67 @@ def view_security_events(trace: dict[str, t.Any]) -> None: rich_print(f" * {signature} ([dim]{category}[/], {severity_fmt(severity_level)})") rich_print() + + +def _view_pytorch_extra(extra: dict[str, t.Any]) -> None: + file_structure = extra.get("file_structure") + if file_structure: + rich_print("[bold yellow]File Structure:[/]") + rich_print(f" File size : {sizeof_fmt(file_structure['file_size'])}") + if file_structure["is_zip"]: + rich_print(" Format : ZIP archive (modern PyTorch)") + if file_structure.get("zip_entries"): + rich_print(f" ZIP entries : {len(file_structure['zip_entries'])}") + if file_structure.get("data_files"): + total = sum(d["size"] for d in file_structure["data_files"]) + rich_print(f" Data files : {len(file_structure['data_files'])} ({sizeof_fmt(total)})") + elif file_structure["is_legacy_pickle"]: + rich_print(" Format : legacy pickle") + rich_print() + + pickle_analysis = extra.get("pickle_analysis") + if pickle_analysis: + rich_print("[bold yellow]Pickle Analysis:[/]") + rich_print(f" Global imports : {len(pickle_analysis['global_imports'])}") + dangerous = pickle_analysis.get("dangerous_ops_count", 0) + if dangerous > 0: + rich_print(f" Dangerous ops : [bold red]{dangerous}[/]") + else: + rich_print(" Dangerous ops : [green]0[/]") + + if pickle_analysis.get("opcode_distribution"): + top = sorted(pickle_analysis["opcode_distribution"].items(), key=lambda x: x[1], reverse=True)[:8] + dist = ", ".join(f"{k}: {v}" for k, v in top) + rich_print(f" Top opcodes : {dist}") + rich_print() + + if pickle_analysis.get("global_imports"): + rich_print("[bold yellow]Global Imports:[/]") + for g in pickle_analysis["global_imports"][:15]: + rich_print(f" * {g}") + if len(pickle_analysis["global_imports"]) > 15: + rich_print(f" [dim]... and {len(pickle_analysis['global_imports']) - 15} more[/]") + rich_print() + + findings = extra.get("findings") + if findings: + has_any = findings.get("errors") or findings.get("warnings") or findings.get("info") + if has_any: + rich_print("[bold yellow]Findings:[/]") + for error in findings.get("errors", []): + rich_print(f" * [bold red]ERROR[/] : {error}") + for warning in findings.get("warnings", []): + rich_print(f" * [yellow]WARNING[/] : {warning}") + for info_msg in findings.get("info", []): + rich_print(f" * [dim]INFO[/] : {info_msg}") + rich_print() + + +def view_extra(run: dict[str, t.Any]) -> None: + extra = run.get("extra") + if not extra: + return + + loader_name = run.get("loader_name", "") + if loader_name == "pytorch": + _view_pytorch_extra(extra) diff --git a/dyana/view_test.py b/dyana/view_test.py index 02f95c1a..b8eddcfb 100644 --- a/dyana/view_test.py +++ b/dyana/view_test.py @@ -5,6 +5,7 @@ severity_fmt, view_disk_events, view_disk_usage, + view_extra, view_header, view_network_events, view_process_executions, @@ -284,9 +285,7 @@ def test_dedup(self) -> None: "processId": 1, "processName": "curl", "syscall": "connect", - "args": [ - {"name": "remote_addr", "value": {"sa_family": "AF_INET", "sin_addr": "1.2.3.4", "sin_port": 80}} - ], + "args": [{"name": "remote_addr", "value": {"sa_family": "AF_INET", "sin_addr": "1.2.3.4", "sin_port": 80}}], } trace: dict[str, t.Any] = {"events": [event, {**event, "timestamp": 2000}]} with patch("dyana.view.rich_print") as mock_print: @@ -434,3 +433,83 @@ def test_basic(self) -> None: assert "Disk Usage" in output assert "start" in output assert "end" in output + + +class TestViewExtra: + def test_no_extra(self) -> None: + run: dict[str, t.Any] = {"extra": None, "loader_name": "pytorch"} + with patch("dyana.view.rich_print") as mock_print: + view_extra(run) + mock_print.assert_not_called() + + def test_empty_extra(self) -> None: + run: dict[str, t.Any] = {"extra": {}, "loader_name": "pytorch"} + with patch("dyana.view.rich_print") as mock_print: + view_extra(run) + mock_print.assert_not_called() + + def test_pytorch_file_structure_zip(self) -> None: + run: dict[str, t.Any] = { + "loader_name": "pytorch", + "extra": { + "file_structure": { + "file_size": 1048576, + "is_zip": True, + "is_legacy_pickle": False, + "zip_entries": [{"filename": "archive/data.pkl"}, {"filename": "archive/data/0"}], + "data_files": [{"name": "archive/data/0", "size": 1024}], + } + }, + } + with patch("dyana.view.rich_print") as mock_print: + view_extra(run) + output = " ".join(str(c) for c in mock_print.call_args_list) + assert "File Structure" in output + assert "ZIP archive" in output + + def test_pytorch_pickle_analysis(self) -> None: + run: dict[str, t.Any] = { + "loader_name": "pytorch", + "extra": { + "pickle_analysis": { + "global_imports": ["torch._utils._rebuild_tensor_v2", "collections.OrderedDict"], + "dangerous_ops_count": 3, + "dangerous_ops": [{"opcode": "REDUCE", "reason": "calls a callable"}], + "opcode_distribution": {"GLOBAL": 2, "REDUCE": 3, "BINPUT": 10}, + } + }, + } + with patch("dyana.view.rich_print") as mock_print: + view_extra(run) + output = " ".join(str(c) for c in mock_print.call_args_list) + assert "Pickle Analysis" in output + assert "Global imports" in output + assert "torch._utils._rebuild_tensor_v2" in output + + def test_pytorch_findings(self) -> None: + run: dict[str, t.Any] = { + "loader_name": "pytorch", + "extra": { + "findings": { + "errors": ["suspicious global import: 'os.system'"], + "warnings": ["unknown global import: 'custom.module'"], + "info": ["2 global imports"], + } + }, + } + with patch("dyana.view.rich_print") as mock_print: + view_extra(run) + output = " ".join(str(c) for c in mock_print.call_args_list) + assert "ERROR" in output + assert "os.system" in output + assert "WARNING" in output + assert "INFO" in output + + def test_unknown_loader_no_output(self) -> None: + run: dict[str, t.Any] = { + "loader_name": "unknown_loader", + "extra": {"some_key": "some_value"}, + } + with patch("dyana.view.rich_print") as mock_print: + view_extra(run) + mock_print.assert_not_called()