From 82758263d172311442db2b0d7362d7489f97b6e0 Mon Sep 17 00:00:00 2001 From: Hope Woods Date: Tue, 2 Dec 2025 12:25:46 -0500 Subject: [PATCH 01/11] Transfer test scenarios from test_inference_commands.py to test/test_demo.py --- test/test_demo.py | 129 ++++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 129 insertions(+) create mode 100644 test/test_demo.py diff --git a/test/test_demo.py b/test/test_demo.py new file mode 100644 index 0000000..b886e9d --- /dev/null +++ b/test/test_demo.py @@ -0,0 +1,129 @@ +import subprocess +import sys +import os +import pytest +from pathlib import Path + +TEST_DIR = Path(__file__).resolve().parent +PROJECT_ROOT = (TEST_DIR / "..").resolve() + +DEFAULT_CKPT = "weights/train_session2024-07-08_1720455712_BFF_3.00.pt" + +# Path to rf_diffusion/run_inference.py +RUN_INFERENCE = PROJECT_ROOT / "rf_diffusion" / "run_inference.py" +SCENARIOS = [ + ( + "rna_unconditional", + [ + "diffuser.T=50", + "inference.num_designs=1", + "contigmap.contigs=['9']", + "contigmap.polymer_chains=['rna']", + "inference.output_prefix=demo_outputs/RNA_uncond_standard_settings", + ], + ), + ( + "multi_polymer_unconditional", + [ + "diffuser.T=50", + "inference.num_designs=1", + "contigmap.contigs=['3 3 3']", + "contigmap.polymer_chains=['dna','rna','protein']", + "inference.output_prefix=test_outputs/basic_uncond_test01", + ], + ), + ( + "dna_binder_unconditional", + [ + "diffuser.T=50", + "inference.num_designs=1", + "contigmap.contigs=['2 2 5']", + "contigmap.polymer_chains=['dna','dna','protein']", + "inference.output_prefix=demo_outputs/DNA_prot_uncond_standard_settings", + ], + ), + ( + "rna_secondary_structure", + [ + "diffuser.T=50", + "inference.num_designs=1", + "contigmap.contigs=['9']", + "contigmap.polymer_chains=['rna']", + "scaffoldguided.target_ss_string=555...333", + ], + ), + ( + "motif_scaffolding_v1", + [ + "diffuser.T=50", + "inference.num_designs=1", + "contigmap.contigs=['1,D8-10,1,B8-10,1 1,B18-20,1,D18-20,1 A1-3,0 C1-3,0']", + "contigmap.polymer_chains=['dna','dna','protein','protein']", + "inference.ij_visible=bce-adf", + "inference.input_pdb=test_data/combo_DBP009_DBP010_DBP011_with_DNA_v2.pdb", + "inference.output_prefix=demo_outputs/DNA_binders_scaffolding_test1_standard_settings", + ], + ), + ( + "motif_scaffolding_v2", + [ + "diffuser.T=50", + "inference.num_designs=1", + "contigmap.contigs=['1,D8-10,1,B8-10,1 1,B18-20,1,D18-20,1 A1-3,3,C1-3,0']", + "contigmap.polymer_chains=['dna','dna','protein']", + "scaffoldguided.target_ss_pairs=[\"A1-9,B1-9\"]", + "inference.ij_visible=bce-adf", + "inference.input_pdb=test_data/combo_DBP009_DBP010_DBP011_with_DNA_v2.pdb", + "inference.output_prefix=demo_outputs/DNA_binders_scaffolding_test2_standard_settings", + ], + ), + ( + "dna_pair_specification", + [ + "diffuser.T=50", + "inference.num_designs=1", + "contigmap.contigs=['6 6 6 6']", + "contigmap.polymer_chains=['dna','dna','dna','dna']", + "scaffoldguided.target_ss_pairs=['A1-2,B1-2','A3-4,C3-4','A5-6,D5-6','B3-4,D3-4','B5-6,C5-6','C1-2,D1-2']", + "inference.symmetry=d2", + "inference.output_prefix=demo_outputs/DNA_origami_standard_settings", + ], + ), +] + + +@pytest.mark.parametrize("name, overrides", SCENARIOS) +def test_multi_polymer_scenarios(name, overrides): + """ + Runs rf_diffusion/run_inference.py via subprocess for each scenario + """ + # set path to weights, assuming weights have been downloaded in root directory + os.environ.setdefault("RFDPOLY_CKPT_PATH", DEFAULT_CKPT) + + # Build the command + cmd = [ + sys.executable, + str(RUN_INFERENCE), + # Include this if your Hydra app requires an explicit config name. + # If run_inference.py already sets config_name in @hydra.main, you + # can remove the two lines below. + "--config-name", + "multi_polymer", + *overrides, + ] + + result = subprocess.run( + cmd, + cwd=str(PROJECT_ROOT), + capture_output=True, + text=True, + ) + + # Debug info if something fails + if result.returncode != 0: + print(f"\n=== Scenario {name} failed ===") + print("STDOUT:\n", result.stdout) + print("STDERR:\n", result.stderr) + + assert result.returncode == 0 + From f5b3ad235f91b7cb77c4bbde4b1e9287f01dd8dd Mon Sep 17 00:00:00 2001 From: Hope Woods Date: Tue, 2 Dec 2025 12:27:22 -0500 Subject: [PATCH 02/11] Create new workflow to run test --- .github/workflows/test.yml | 36 ++++++++++++++++++++++++++++++++++++ 1 file changed, 36 insertions(+) create mode 100644 .github/workflows/test.yml diff --git a/.github/workflows/test.yml b/.github/workflows/test.yml new file mode 100644 index 0000000..0033513 --- /dev/null +++ b/.github/workflows/test.yml @@ -0,0 +1,36 @@ +name: Run Tests + +on: + push: + branches: ["main", "github-ci"] + pull_request: + branches: [ "main", "github-ci" ] + +jobs: + test: + runs-on: ubuntu-latest + + steps: + - uses: actions/checkout@v4 + + # Set up micromamba + - name: Set up micromamba + uses: mamba-org/setup-micromamba@v2 + with: + environment-file: rf_diffusion/environment/environment.yml + init-shell: bash + cache-environment: true + + - name: Install pytest + shell: micromamba-shell {0} + run: | + python -m pip install pytest + + - name: Download weights + run: | + curl -O https://files.ipd.uw.edu/pub/2025_RFDpoly/train_session2024-07-08_1720455712_BFF_3.00.pt + + - name: Run tests + shell: micromamba-shell {0} + run: | + python -m pytest test/test_demo.py From 8bbe690a83af289610860bfcddd34e7f8d46d02a Mon Sep 17 00:00:00 2001 From: Hope Woods Date: Tue, 2 Dec 2025 12:36:16 -0500 Subject: [PATCH 03/11] Remove extra branch in test.yml --- .github/workflows/test.yml | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/.github/workflows/test.yml b/.github/workflows/test.yml index 0033513..18fdbd5 100644 --- a/.github/workflows/test.yml +++ b/.github/workflows/test.yml @@ -2,9 +2,9 @@ name: Run Tests on: push: - branches: ["main", "github-ci"] + branches: ["main"] pull_request: - branches: [ "main", "github-ci" ] + branches: ["main"] jobs: test: From b7d92311ece0578e3090fde04f547ea212ca7b78 Mon Sep 17 00:00:00 2001 From: Hope Woods Date: Tue, 2 Dec 2025 13:26:47 -0500 Subject: [PATCH 04/11] Change environment file that is loaded --- .github/workflows/test.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/workflows/test.yml b/.github/workflows/test.yml index 18fdbd5..e1e88dc 100644 --- a/.github/workflows/test.yml +++ b/.github/workflows/test.yml @@ -17,7 +17,7 @@ jobs: - name: Set up micromamba uses: mamba-org/setup-micromamba@v2 with: - environment-file: rf_diffusion/environment/environment.yml + environment-file: rf_diffusion/environment/macos_environment.yml init-shell: bash cache-environment: true From 9afe475ab78430e4bc9230f6a94ea1395129e73a Mon Sep 17 00:00:00 2001 From: Hope Woods Date: Tue, 2 Dec 2025 16:52:36 -0500 Subject: [PATCH 05/11] Added new environment file to decrease size of dependencies for testing --- .github/workflows/test.yml | 2 +- rf_diffusion/environment/macos_environment.yml | 4 ++-- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/.github/workflows/test.yml b/.github/workflows/test.yml index e1e88dc..5015c8a 100644 --- a/.github/workflows/test.yml +++ b/.github/workflows/test.yml @@ -17,7 +17,7 @@ jobs: - name: Set up micromamba uses: mamba-org/setup-micromamba@v2 with: - environment-file: rf_diffusion/environment/macos_environment.yml + environment-file: rf_diffusion/environment/ci_environment.yml init-shell: bash cache-environment: true diff --git a/rf_diffusion/environment/macos_environment.yml b/rf_diffusion/environment/macos_environment.yml index 649983d..d9b7b11 100644 --- a/rf_diffusion/environment/macos_environment.yml +++ b/rf_diffusion/environment/macos_environment.yml @@ -1,4 +1,4 @@ -name: RFDpoly_env +name: RFDpoly_env_macos channels: - pytorch - conda-forge @@ -18,4 +18,4 @@ dependencies: - pip: - dgl==1.0.1 - e3nn==0.5.1 - - hydra-core==1.3.2 \ No newline at end of file + - hydra-core==1.3.2 From 177752fee4edd6dfdeebb6d5e6fd323721fa4795 Mon Sep 17 00:00:00 2001 From: Hope Woods Date: Tue, 2 Dec 2025 18:05:22 -0500 Subject: [PATCH 06/11] Add new environment file with fewer dependencies for tests --- rf_diffusion/environment/ci_environment.yml | 292 ++++++++++++++++++++ 1 file changed, 292 insertions(+) create mode 100644 rf_diffusion/environment/ci_environment.yml diff --git a/rf_diffusion/environment/ci_environment.yml b/rf_diffusion/environment/ci_environment.yml new file mode 100644 index 0000000..69b9fdf --- /dev/null +++ b/rf_diffusion/environment/ci_environment.yml @@ -0,0 +1,292 @@ +name: RFDpoly_env_ci +channels: + - pytorch + - pyg + - dglteam/label/cu117 + - nvidia + - conda-forge + - bioconda + - defaults +dependencies: + - _libgcc_mutex=0.1 + - _openmp_mutex=4.5 + - anyio=3.5.0 + - appdirs=1.4.4 + - argon2-cffi=21.3.0 + - argon2-cffi-bindings=21.2.0 + - asttokens=2.0.5 + - attrs=22.1.0 + - babel=2.11.0 + - backcall=0.2.0 + - beautifulsoup4=4.11.1 + - blas=1.0 + - bleach=4.1.0 + - bottleneck=1.3.5 + - brotli=1.0.9 + - brotli-bin=1.0.9 + - brotlipy=0.7.0 + - bzip2=1.0.8 + - ca-certificates=2022.12.7 + - cairo=1.16.0 + - certifi=2022.12.7 + - cffi=1.15.1 + - charset-normalizer=2.0.4 + - comm=0.1.2 + - conda=23.1.0 + - conda-content-trust=0.1.3 + - conda-package-handling=2.0.2 + - conda-package-streaming=0.7.0 + - contourpy=1.0.5 + - cryptography=39.0.1 + - cuda=11.7.1 + - cuda-cccl=11.7.91 + - cuda-command-line-tools=11.7.1 + - cuda-compiler=11.7.1 + - cuda-cudart=11.7.99 + - cuda-cudart-dev=11.7.99 + - cuda-cuobjdump=11.7.91 + - cuda-cupti=11.7.101 + - cuda-cuxxfilt=11.7.91 + - cuda-demo-suite=12.1.55 + - cuda-documentation=12.1.55 + - cuda-driver-dev=11.7.99 + - cuda-gdb=12.1.55 + - cuda-libraries=11.7.1 + - cuda-libraries-dev=11.7.1 + - cuda-memcheck=11.8.86 + - cuda-nsight=12.1.55 + - cuda-nsight-compute=12.1.0 + - cuda-nvcc=11.7.99 + - cuda-nvdisasm=12.1.55 + - cuda-nvml-dev=11.7.91 + - cuda-nvprof=12.1.55 + - cuda-nvprune=11.7.91 + - cuda-nvrtc=11.7.99 + - cuda-nvrtc-dev=11.7.99 + - cuda-nvtx=11.7.91 + - cuda-nvvp=12.1.55 + - cuda-runtime=11.7.1 + - cuda-sanitizer-api=12.1.55 + - cuda-toolkit=11.7.1 + - cuda-tools=11.7.1 + - cuda-visual-tools=11.7.1 + - cycler=0.11.0 + - dbus=1.13.18 + - debugpy=1.5.1 + - decorator=5.1.1 + - defusedxml=0.7.1 + - dgl=1.0.1.cu117 + - entrypoints=0.4 + - executing=0.8.3 + - expat=2.4.9 + - flit-core=3.6.0 + - fontconfig=2.14.1 + - fonttools=4.25.0 + - freetype=2.12.1 + - gds-tools=1.6.0.25 + - giflib=5.2.1 + - glib=2.69.1 + - gst-plugins-base=1.14.1 + - gstreamer=1.14.1 + - icu=58.2 + - idna=3.4 + - intel-openmp=2021.4.0 + - ipykernel=6.19.2 + - ipython=8.10.0 + - ipython_genutils=0.2.0 + - jedi=0.18.1 + - jinja2=3.1.2 + - joblib=1.1.1 + - jpeg=9e + - json5=0.9.6 + - jsonschema=4.17.3 + - jupyter_client=7.4.9 + - jupyter_core=5.2.0 + - jupyter_server=1.23.4 + - jupyterlab=3.5.3 + - jupyterlab_pygments=0.1.2 + - jupyterlab_server=2.19.0 + - kiwisolver=1.4.4 + - krb5=1.19.4 + - lcms2=2.12 + - ld_impl_linux-64=2.38 + - lerc=3.0 + - libbrotlicommon=1.0.9 + - libbrotlidec=1.0.9 + - libbrotlienc=1.0.9 + - libclang=10.0.1 + - libcublas=11.10.3.66 + - libcublas-dev=11.10.3.66 + - libcufft=10.7.2.124 + - libcufft-dev=10.7.2.124 + - libcufile=1.6.0.25 + - libcufile-dev=1.6.0.25 + - libcurand=10.3.2.56 + - libcurand-dev=10.3.2.56 + - libcusolver=11.4.0.1 + - libcusolver-dev=11.4.0.1 + - libcusparse=11.7.4.91 + - libcusparse-dev=11.7.4.91 + - libdeflate=1.17 + - libedit=3.1.20221030 + - libevent=2.1.12 + - libffi=3.4.2 + - libgcc-ng=12.2.0 + - libgfortran-ng=11.2.0 + - libgfortran5=11.2.0 + - libiconv=1.17 + - libllvm10=10.0.1 + - libnpp=11.7.4.75 + - libnpp-dev=11.7.4.75 + - libnvjpeg=11.8.0.2 + - libnvjpeg-dev=11.8.0.2 + - libpng=1.6.39 + - libpq=12.9 + - libsodium=1.0.18 + - libstdcxx-ng=11.2.0 + - libtiff=4.5.0 + - libuuid=1.41.5 + - libwebp=1.2.4 + - libwebp-base=1.2.4 + - libxcb=1.15 + - libxkbcommon=1.0.1 + - libxml2=2.9.14 + - libxslt=1.1.35 + - libzlib=1.2.13 + - llvm-openmp=15.0.7 + - lxml=4.9.1 + - lz4-c=1.9.4 + - markupsafe=2.1.1 + - matplotlib=3.7.0 + - matplotlib-base=3.7.0 + - matplotlib-inline=0.1.6 + - mistune=0.8.4 + - mkl=2021.4.0 + - mkl-service=2.4.0 + - mkl_fft=1.3.1 + - mkl_random=1.2.2 + - munkres=1.1.4 + - nbclassic=0.5.2 + - nbclient=0.5.13 + - nbconvert=6.5.4 + - nbformat=5.7.0 + - ncurses=6.4 + - nest-asyncio=1.5.6 + - networkx=2.8.4 + - notebook=6.5.2 + - notebook-shim=0.2.2 + - nsight-compute=2023.1.0.15 + - nspr=4.33 + - nss=3.74 + - numexpr=2.8.4 + - numpy=1.23.5 + - numpy-base=1.23.5 + - openbabel=3.1.1 + - openssl=1.1.1t + - packaging=22.0 + - pandas=1.5.3 + - pandocfilters=1.5.0 + - parso=0.8.3 + - pcre=8.45 + - pexpect=4.8.0 + - pickleshare=0.7.5 + - pillow=9.4.0 + - pip=23.0.1 + - pixman=0.40.0 + - platformdirs=2.5.2 + - pluggy=1.0.0 + - ply=3.11 + - pooch=1.4.0 + - prometheus_client=0.14.1 + - prompt-toolkit=3.0.36 + - psutil=5.9.0 + - ptyprocess=0.7.0 + - pure_eval=0.2.2 + - pycosat=0.6.4 + - pycparser=2.21 + - pyg=2.2.0 + - pygments=2.11.2 + - pyopenssl=23.0.0 + - pyparsing=3.0.9 + - pyqt=5.15.7 + - pyrsistent=0.18.0 + - pysocks=1.7.1 + - python=3.10.8 + - python-dateutil=2.8.2 + - python-fastjsonschema=2.16.2 + - python_abi=3.10 + - pytorch=1.13.1 + - pytorch-cluster=1.6.0 + - pytorch-cuda=11.7 + - pytorch-mutex=1.0 + - pytorch-scatter=2.1.0 + - pytorch-sparse=0.6.16 + - pytz=2022.7 + - pyzmq=23.2.0 + - qt-main=5.15.2 + - qt-webengine=5.15.9 + - qtwebkit=5.212 + - readline=8.2 + - requests=2.28.1 + - ruamel.yaml=0.17.21 + - ruamel.yaml.clib=0.2.6 + - scikit-learn=1.2.1 + - scipy=1.10.0 + - seaborn=0.12.2 + - send2trash=1.8.0 + - setuptools=65.5.0 + - sip=6.6.2 + - six=1.16.0 + - sniffio=1.2.0 + - soupsieve=2.3.2.post1 + - sqlite=3.40.1 + - stack_data=0.2.0 + - terminado=0.17.1 + - threadpoolctl=2.2.0 + - tinycss2=1.2.1 + - tk=8.6.12 + - toml=0.10.2 + - tomli=2.0.1 + - toolz=0.12.0 + - tornado=6.2 + - tqdm=4.64.1 + - traitlets=5.7.1 + - typing-extensions=4.4.0 + - typing_extensions=4.4.0 + - tzdata=2022g + - urllib3=1.26.14 + - wcwidth=0.2.5 + - webencodings=0.5.1 + - websocket-client=0.58.0 + - wheel=0.37.1 + - xz=5.2.10 + - zeromq=4.3.4 + - zlib=1.2.13 + - zstandard=0.19.0 + - zstd=1.5.2 + - pip: + - antlr4-python3-runtime==4.9.3 + - assertpy==1.1 + - click==8.1.3 + - colorama==0.4.6 + - deepdiff==6.2.3 + - docker-pycreds==0.4.0 + - e3nn==0.5.1 + - gitdb==4.0.10 + - gitpython==3.1.31 + - hydra-core==1.3.2 + - mpmath==1.3.0 + - omegaconf==2.3.0 + - opt-einsum==3.3.0 + - opt-einsum-fx==0.1.4 + - ordered-set==4.1.0 + - orjson==3.8.7 + - pathtools==0.1.2 + - protobuf==4.22.1 + - pyqt5-sip==12.11.0 + - pyyaml==6.0 + - sentry-sdk==1.16.0 + - setproctitle==1.3.2 + - smmap==5.0.0 + - sympy==1.11.1 + - wandb==0.13.11 From 5d96dc123493ad71953058271a8f006af9be4197 Mon Sep 17 00:00:00 2001 From: Hope Woods Date: Wed, 3 Dec 2025 15:47:14 -0500 Subject: [PATCH 07/11] Make weights directory and specify to download weights there --- .github/workflows/test.yml | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/.github/workflows/test.yml b/.github/workflows/test.yml index 5015c8a..ebf09a5 100644 --- a/.github/workflows/test.yml +++ b/.github/workflows/test.yml @@ -28,7 +28,8 @@ jobs: - name: Download weights run: | - curl -O https://files.ipd.uw.edu/pub/2025_RFDpoly/train_session2024-07-08_1720455712_BFF_3.00.pt + mkdir weights + curl -o weights/train_session2024-07-08_1720455712_BFF_3.00.pt https://files.ipd.uw.edu/pub/2025_RFDpoly/train_session2024-07-08_1720455712_BFF_3.00.pt - name: Run tests shell: micromamba-shell {0} From b3a91ff209a170a07d5100e2b241ebe2b4c28bb9 Mon Sep 17 00:00:00 2001 From: Hope Woods Date: Fri, 30 Jan 2026 18:05:48 -0500 Subject: [PATCH 08/11] Update ci_environment so tests can run on github runners --- rf_diffusion/environment/ci_environment.yml | 72 ++++----------------- 1 file changed, 11 insertions(+), 61 deletions(-) diff --git a/rf_diffusion/environment/ci_environment.yml b/rf_diffusion/environment/ci_environment.yml index 69b9fdf..981a089 100644 --- a/rf_diffusion/environment/ci_environment.yml +++ b/rf_diffusion/environment/ci_environment.yml @@ -1,13 +1,21 @@ -name: RFDpoly_env_ci +name: RFDpoly_env_ci_test channels: - pytorch - pyg - - dglteam/label/cu117 - - nvidia + - dglteam - conda-forge - bioconda - defaults dependencies: + - python=3.10.8 + - pip=23.0.1 + - pytorch=1.13.1 + - cpuonly + - pyg=2.2.0 + - pytorch-scatter=2.1.0 + - pytorch-sparse=0.6.16 + - pytorch-cluster=1.6.0 + - dgl=1.0.1 - _libgcc_mutex=0.1 - _openmp_mutex=4.5 - anyio=3.5.0 @@ -38,44 +46,11 @@ dependencies: - conda-package-streaming=0.7.0 - contourpy=1.0.5 - cryptography=39.0.1 - - cuda=11.7.1 - - cuda-cccl=11.7.91 - - cuda-command-line-tools=11.7.1 - - cuda-compiler=11.7.1 - - cuda-cudart=11.7.99 - - cuda-cudart-dev=11.7.99 - - cuda-cuobjdump=11.7.91 - - cuda-cupti=11.7.101 - - cuda-cuxxfilt=11.7.91 - - cuda-demo-suite=12.1.55 - - cuda-documentation=12.1.55 - - cuda-driver-dev=11.7.99 - - cuda-gdb=12.1.55 - - cuda-libraries=11.7.1 - - cuda-libraries-dev=11.7.1 - - cuda-memcheck=11.8.86 - - cuda-nsight=12.1.55 - - cuda-nsight-compute=12.1.0 - - cuda-nvcc=11.7.99 - - cuda-nvdisasm=12.1.55 - - cuda-nvml-dev=11.7.91 - - cuda-nvprof=12.1.55 - - cuda-nvprune=11.7.91 - - cuda-nvrtc=11.7.99 - - cuda-nvrtc-dev=11.7.99 - - cuda-nvtx=11.7.91 - - cuda-nvvp=12.1.55 - - cuda-runtime=11.7.1 - - cuda-sanitizer-api=12.1.55 - - cuda-toolkit=11.7.1 - - cuda-tools=11.7.1 - - cuda-visual-tools=11.7.1 - cycler=0.11.0 - dbus=1.13.18 - debugpy=1.5.1 - decorator=5.1.1 - defusedxml=0.7.1 - - dgl=1.0.1.cu117 - entrypoints=0.4 - executing=0.8.3 - expat=2.4.9 @@ -83,7 +58,6 @@ dependencies: - fontconfig=2.14.1 - fonttools=4.25.0 - freetype=2.12.1 - - gds-tools=1.6.0.25 - giflib=5.2.1 - glib=2.69.1 - gst-plugins-base=1.14.1 @@ -115,18 +89,6 @@ dependencies: - libbrotlidec=1.0.9 - libbrotlienc=1.0.9 - libclang=10.0.1 - - libcublas=11.10.3.66 - - libcublas-dev=11.10.3.66 - - libcufft=10.7.2.124 - - libcufft-dev=10.7.2.124 - - libcufile=1.6.0.25 - - libcufile-dev=1.6.0.25 - - libcurand=10.3.2.56 - - libcurand-dev=10.3.2.56 - - libcusolver=11.4.0.1 - - libcusolver-dev=11.4.0.1 - - libcusparse=11.7.4.91 - - libcusparse-dev=11.7.4.91 - libdeflate=1.17 - libedit=3.1.20221030 - libevent=2.1.12 @@ -136,10 +98,6 @@ dependencies: - libgfortran5=11.2.0 - libiconv=1.17 - libllvm10=10.0.1 - - libnpp=11.7.4.75 - - libnpp-dev=11.7.4.75 - - libnvjpeg=11.8.0.2 - - libnvjpeg-dev=11.8.0.2 - libpng=1.6.39 - libpq=12.9 - libsodium=1.0.18 @@ -175,7 +133,6 @@ dependencies: - networkx=2.8.4 - notebook=6.5.2 - notebook-shim=0.2.2 - - nsight-compute=2023.1.0.15 - nspr=4.33 - nss=3.74 - numexpr=2.8.4 @@ -211,16 +168,9 @@ dependencies: - pyqt=5.15.7 - pyrsistent=0.18.0 - pysocks=1.7.1 - - python=3.10.8 - python-dateutil=2.8.2 - python-fastjsonschema=2.16.2 - python_abi=3.10 - - pytorch=1.13.1 - - pytorch-cluster=1.6.0 - - pytorch-cuda=11.7 - - pytorch-mutex=1.0 - - pytorch-scatter=2.1.0 - - pytorch-sparse=0.6.16 - pytz=2022.7 - pyzmq=23.2.0 - qt-main=5.15.2 From 167ec9a500c661761d901d1642a399173774664f Mon Sep 17 00:00:00 2001 From: Hope Woods Date: Mon, 2 Feb 2026 11:48:46 -0500 Subject: [PATCH 09/11] Avoid NVTX crashes on CPU-only environments (including github runners) by making profiling optional --- .../se3_transformer/model/basis.py | 17 ++++++++++++++++- 1 file changed, 16 insertions(+), 1 deletion(-) diff --git a/rf_diffusion/RF2-allatom/rf2aa/SE3Transformer/se3_transformer/model/basis.py b/rf_diffusion/RF2-allatom/rf2aa/SE3Transformer/se3_transformer/model/basis.py index 74f04a0..446dd56 100644 --- a/rf_diffusion/RF2-allatom/rf2aa/SE3Transformer/se3_transformer/model/basis.py +++ b/rf_diffusion/RF2-allatom/rf2aa/SE3Transformer/se3_transformer/model/basis.py @@ -33,6 +33,21 @@ from se3_transformer.runtime.utils import degree_to_dim +from contextlib import contextmanager + +@contextmanager +def safe_nvtx_range(msg): + try: + import torch + if torch.cuda.is_available(): + from torch.cuda.nvtx import range as nvtx_range + with nvtx_range(msg): + yield + else: + yield + except Exception: + # NVTX missing or broken → just run the code + yield @lru_cache(maxsize=None) def get_clebsch_gordon(J: int, d_in: int, d_out: int, device) -> Tensor: @@ -54,7 +69,7 @@ def get_all_clebsch_gordon(max_degree: int, device) -> List[List[Tensor]]: def get_spherical_harmonics(relative_pos: Tensor, max_degree: int) -> List[Tensor]: all_degrees = list(range(2 * max_degree + 1)) - with nvtx_range('spherical harmonics'): + with safe_nvtx_range('spherical harmonics'): sh = o3.spherical_harmonics(all_degrees, relative_pos, normalize=True) return torch.split(sh, [degree_to_dim(d) for d in all_degrees], dim=1) From 9fed09f6c09bb5ab1f164ba97abf3d4a500341ac Mon Sep 17 00:00:00 2001 From: Hope Woods Date: Mon, 2 Feb 2026 12:01:58 -0500 Subject: [PATCH 10/11] Include safe NVTX through the rest of basis.py --- .../rf2aa/SE3Transformer/se3_transformer/model/basis.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/rf_diffusion/RF2-allatom/rf2aa/SE3Transformer/se3_transformer/model/basis.py b/rf_diffusion/RF2-allatom/rf2aa/SE3Transformer/se3_transformer/model/basis.py index 446dd56..27ea1ba 100644 --- a/rf_diffusion/RF2-allatom/rf2aa/SE3Transformer/se3_transformer/model/basis.py +++ b/rf_diffusion/RF2-allatom/rf2aa/SE3Transformer/se3_transformer/model/basis.py @@ -178,13 +178,13 @@ def get_basis(relative_pos: Tensor, compute_gradients: bool = False, use_pad_trick: bool = False, amp: bool = False) -> Dict[str, Tensor]: - with nvtx_range('spherical harmonics'): + with safe_nvtx_range('spherical harmonics'): spherical_harmonics = get_spherical_harmonics(relative_pos, max_degree) - with nvtx_range('CB coefficients'): + with safe_nvtx_range('CB coefficients'): clebsch_gordon = get_all_clebsch_gordon(max_degree, relative_pos.device) with torch.autograd.set_grad_enabled(compute_gradients): - with nvtx_range('bases'): + with safe_nvtx_range('bases'): basis = get_basis_script(max_degree=max_degree, use_pad_trick=use_pad_trick, spherical_harmonics=spherical_harmonics, From 87513554a214bb70915d8de489f82fd9a5be3b9f Mon Sep 17 00:00:00 2001 From: Hope Woods Date: Mon, 2 Feb 2026 12:39:46 -0500 Subject: [PATCH 11/11] Adding util function to avoid NVTX crashes on CPU --- .../se3_transformer/model/basis.py | 25 +++------------ .../se3_transformer/model/layers/attention.py | 2 +- .../model/layers/convolution.py | 2 +- .../se3_transformer/model/layers/norm.py | 2 +- .../se3_transformer/utils/nvtx.py | 32 +++++++++++++++++++ 5 files changed, 40 insertions(+), 23 deletions(-) create mode 100644 rf_diffusion/RF2-allatom/rf2aa/SE3Transformer/se3_transformer/utils/nvtx.py diff --git a/rf_diffusion/RF2-allatom/rf2aa/SE3Transformer/se3_transformer/model/basis.py b/rf_diffusion/RF2-allatom/rf2aa/SE3Transformer/se3_transformer/model/basis.py index 27ea1ba..617e0ed 100644 --- a/rf_diffusion/RF2-allatom/rf2aa/SE3Transformer/se3_transformer/model/basis.py +++ b/rf_diffusion/RF2-allatom/rf2aa/SE3Transformer/se3_transformer/model/basis.py @@ -29,25 +29,10 @@ import torch import torch.nn.functional as F from torch import Tensor -from torch.cuda.nvtx import range as nvtx_range +from se3_transformer.utils.nvtx import nvtx_range from se3_transformer.runtime.utils import degree_to_dim -from contextlib import contextmanager - -@contextmanager -def safe_nvtx_range(msg): - try: - import torch - if torch.cuda.is_available(): - from torch.cuda.nvtx import range as nvtx_range - with nvtx_range(msg): - yield - else: - yield - except Exception: - # NVTX missing or broken → just run the code - yield @lru_cache(maxsize=None) def get_clebsch_gordon(J: int, d_in: int, d_out: int, device) -> Tensor: @@ -69,7 +54,7 @@ def get_all_clebsch_gordon(max_degree: int, device) -> List[List[Tensor]]: def get_spherical_harmonics(relative_pos: Tensor, max_degree: int) -> List[Tensor]: all_degrees = list(range(2 * max_degree + 1)) - with safe_nvtx_range('spherical harmonics'): + with nvtx_range('spherical harmonics'): sh = o3.spherical_harmonics(all_degrees, relative_pos, normalize=True) return torch.split(sh, [degree_to_dim(d) for d in all_degrees], dim=1) @@ -178,13 +163,13 @@ def get_basis(relative_pos: Tensor, compute_gradients: bool = False, use_pad_trick: bool = False, amp: bool = False) -> Dict[str, Tensor]: - with safe_nvtx_range('spherical harmonics'): + with nvtx_range('spherical harmonics'): spherical_harmonics = get_spherical_harmonics(relative_pos, max_degree) - with safe_nvtx_range('CB coefficients'): + with nvtx_range('CB coefficients'): clebsch_gordon = get_all_clebsch_gordon(max_degree, relative_pos.device) with torch.autograd.set_grad_enabled(compute_gradients): - with safe_nvtx_range('bases'): + with nvtx_range('bases'): basis = get_basis_script(max_degree=max_degree, use_pad_trick=use_pad_trick, spherical_harmonics=spherical_harmonics, diff --git a/rf_diffusion/RF2-allatom/rf2aa/SE3Transformer/se3_transformer/model/layers/attention.py b/rf_diffusion/RF2-allatom/rf2aa/SE3Transformer/se3_transformer/model/layers/attention.py index 1be3219..4511c4c 100644 --- a/rf_diffusion/RF2-allatom/rf2aa/SE3Transformer/se3_transformer/model/layers/attention.py +++ b/rf_diffusion/RF2-allatom/rf2aa/SE3Transformer/se3_transformer/model/layers/attention.py @@ -34,7 +34,7 @@ from se3_transformer.model.layers.convolution import ConvSE3, ConvSE3FuseLevel from se3_transformer.model.layers.linear import LinearSE3 from se3_transformer.runtime.utils import degree_to_dim, aggregate_residual, unfuse_features -from torch.cuda.nvtx import range as nvtx_range +from se3_transformer.utils.nvtx import nvtx_range class AttentionSE3(nn.Module): diff --git a/rf_diffusion/RF2-allatom/rf2aa/SE3Transformer/se3_transformer/model/layers/convolution.py b/rf_diffusion/RF2-allatom/rf2aa/SE3Transformer/se3_transformer/model/layers/convolution.py index 3ef1693..61aa29b 100644 --- a/rf_diffusion/RF2-allatom/rf2aa/SE3Transformer/se3_transformer/model/layers/convolution.py +++ b/rf_diffusion/RF2-allatom/rf2aa/SE3Transformer/se3_transformer/model/layers/convolution.py @@ -31,7 +31,7 @@ import torch.nn as nn from dgl import DGLGraph from torch import Tensor -from torch.cuda.nvtx import range as nvtx_range +from se3_transformer.utils.nvtx import nvtx_range from se3_transformer.model.fiber import Fiber from se3_transformer.runtime.utils import degree_to_dim, unfuse_features diff --git a/rf_diffusion/RF2-allatom/rf2aa/SE3Transformer/se3_transformer/model/layers/norm.py b/rf_diffusion/RF2-allatom/rf2aa/SE3Transformer/se3_transformer/model/layers/norm.py index acbe23d..4a8ee31 100644 --- a/rf_diffusion/RF2-allatom/rf2aa/SE3Transformer/se3_transformer/model/layers/norm.py +++ b/rf_diffusion/RF2-allatom/rf2aa/SE3Transformer/se3_transformer/model/layers/norm.py @@ -27,7 +27,7 @@ import torch import torch.nn as nn from torch import Tensor -from torch.cuda.nvtx import range as nvtx_range +from se3_transformer.utils.nvtx import nvtx_range from se3_transformer.model.fiber import Fiber diff --git a/rf_diffusion/RF2-allatom/rf2aa/SE3Transformer/se3_transformer/utils/nvtx.py b/rf_diffusion/RF2-allatom/rf2aa/SE3Transformer/se3_transformer/utils/nvtx.py new file mode 100644 index 0000000..d6939a4 --- /dev/null +++ b/rf_diffusion/RF2-allatom/rf2aa/SE3Transformer/se3_transformer/utils/nvtx.py @@ -0,0 +1,32 @@ +# se3_transformer/utils/nvtx.py +from __future__ import annotations + +from contextlib import contextmanager +from typing import Iterator + +@contextmanager +def nvtx_range(message: str) -> Iterator[None]: + """ + Safe NVTX range context manager. + + - If running with CUDA + NVTX support, emits real NVTX ranges. + - Otherwise, becomes a no-op (CPU-only CI, ROCm-only builds, etc). + """ + try: + import torch + + if torch.cuda.is_available() and hasattr(torch.cuda, "nvtx"): + try: + from torch.cuda.nvtx import range as _nvtx_range + with _nvtx_range(message): + yield + return + except Exception: + # CUDA available but NVTX missing/misconfigured -> fall back to no-op + pass + + yield + except Exception: + # torch not importable or other unexpected env issue -> no-op + yield +