Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 6 additions & 3 deletions modelaudit/scanners/pytorch_zip_scanner.py
Original file line number Diff line number Diff line change
Expand Up @@ -988,11 +988,14 @@ def _get_detected_pytorch_version(
return None, source if isinstance(source, str) else None

def _get_installed_pytorch_version(self) -> str | None:
"""Get locally installed PyTorch version when available."""
"""Get PyTorch version from an already-imported module without importing torch."""
try:
import torch
import sys

version = getattr(torch, "__version__", None)
torch_module = sys.modules.get("torch")
if torch_module is None:
return None
Comment on lines +995 to +997
Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

P1 Badge Detect installed torch version without requiring prior import

Return value now depends on torch already being present in sys.modules. In normal scanner runs, torch is often installed but not imported, so this returns None. That causes _select_pytorch_version_for_check() to ignore vulnerable local runtimes (or skip version-gated CVE checks when metadata is missing), creating false negatives in security detection.

Useful? React with 👍 / 👎.

version = getattr(torch_module, "__version__", None)
if isinstance(version, str) and version.strip():
return version.strip()
except Exception:
Expand Down
9 changes: 7 additions & 2 deletions tests/scanners/test_pytorch_zip_scanner.py
Original file line number Diff line number Diff line change
Expand Up @@ -655,12 +655,14 @@ def test_pytorch_zip_version_selection_uses_metadata_when_torch_unavailable(
assert source == "metadata:config.json:pytorch_version"


def test_get_installed_pytorch_version_handles_import_errors(monkeypatch: pytest.MonkeyPatch) -> None:
"""Broken torch imports should degrade to None instead of aborting the scan."""
def test_get_installed_pytorch_version_does_not_import_torch(monkeypatch: pytest.MonkeyPatch) -> None:
"""Scanner should not import torch while collecting version context."""
import builtins
import sys

scanner = PyTorchZipScanner()
real_import = builtins.__import__
import_calls: list[str] = []

def fail_torch_import(
name: str,
Expand All @@ -669,13 +671,16 @@ def fail_torch_import(
fromlist: tuple[str, ...] = (),
level: int = 0,
) -> object:
import_calls.append(name)
if name == "torch":
raise RuntimeError("broken torch import")
return real_import(name, globals, locals, fromlist, level)

monkeypatch.delitem(sys.modules, "torch", raising=False)
monkeypatch.setattr(builtins, "__import__", fail_torch_import)

assert scanner._get_installed_pytorch_version() is None
assert "torch" not in import_calls


def test_pytorch_zip_version_detection_uses_local_torch_when_metadata_missing(
Expand Down
Loading