diff --git a/.github/workflows/test.yml b/.github/workflows/test.yml new file mode 100644 index 0000000..ebf09a5 --- /dev/null +++ b/.github/workflows/test.yml @@ -0,0 +1,37 @@ +name: Run Tests + +on: + push: + branches: ["main"] + pull_request: + branches: ["main"] + +jobs: + test: + runs-on: ubuntu-latest + + steps: + - uses: actions/checkout@v4 + + # Set up micromamba + - name: Set up micromamba + uses: mamba-org/setup-micromamba@v2 + with: + environment-file: rf_diffusion/environment/ci_environment.yml + init-shell: bash + cache-environment: true + + - name: Install pytest + shell: micromamba-shell {0} + run: | + python -m pip install pytest + + - name: Download weights + run: | + mkdir weights + curl -o weights/train_session2024-07-08_1720455712_BFF_3.00.pt https://files.ipd.uw.edu/pub/2025_RFDpoly/train_session2024-07-08_1720455712_BFF_3.00.pt + + - name: Run tests + shell: micromamba-shell {0} + run: | + python -m pytest test/test_demo.py diff --git a/rf_diffusion/RF2-allatom/rf2aa/SE3Transformer/se3_transformer/model/basis.py b/rf_diffusion/RF2-allatom/rf2aa/SE3Transformer/se3_transformer/model/basis.py index 74f04a0..617e0ed 100644 --- a/rf_diffusion/RF2-allatom/rf2aa/SE3Transformer/se3_transformer/model/basis.py +++ b/rf_diffusion/RF2-allatom/rf2aa/SE3Transformer/se3_transformer/model/basis.py @@ -29,7 +29,7 @@ import torch import torch.nn.functional as F from torch import Tensor -from torch.cuda.nvtx import range as nvtx_range +from se3_transformer.utils.nvtx import nvtx_range from se3_transformer.runtime.utils import degree_to_dim diff --git a/rf_diffusion/RF2-allatom/rf2aa/SE3Transformer/se3_transformer/model/layers/attention.py b/rf_diffusion/RF2-allatom/rf2aa/SE3Transformer/se3_transformer/model/layers/attention.py index 1be3219..4511c4c 100644 --- a/rf_diffusion/RF2-allatom/rf2aa/SE3Transformer/se3_transformer/model/layers/attention.py +++ b/rf_diffusion/RF2-allatom/rf2aa/SE3Transformer/se3_transformer/model/layers/attention.py @@ -34,7 +34,7 @@ from se3_transformer.model.layers.convolution import ConvSE3, ConvSE3FuseLevel from se3_transformer.model.layers.linear import LinearSE3 from se3_transformer.runtime.utils import degree_to_dim, aggregate_residual, unfuse_features -from torch.cuda.nvtx import range as nvtx_range +from se3_transformer.utils.nvtx import nvtx_range class AttentionSE3(nn.Module): diff --git a/rf_diffusion/RF2-allatom/rf2aa/SE3Transformer/se3_transformer/model/layers/convolution.py b/rf_diffusion/RF2-allatom/rf2aa/SE3Transformer/se3_transformer/model/layers/convolution.py index 3ef1693..61aa29b 100644 --- a/rf_diffusion/RF2-allatom/rf2aa/SE3Transformer/se3_transformer/model/layers/convolution.py +++ b/rf_diffusion/RF2-allatom/rf2aa/SE3Transformer/se3_transformer/model/layers/convolution.py @@ -31,7 +31,7 @@ import torch.nn as nn from dgl import DGLGraph from torch import Tensor -from torch.cuda.nvtx import range as nvtx_range +from se3_transformer.utils.nvtx import nvtx_range from se3_transformer.model.fiber import Fiber from se3_transformer.runtime.utils import degree_to_dim, unfuse_features diff --git a/rf_diffusion/RF2-allatom/rf2aa/SE3Transformer/se3_transformer/model/layers/norm.py b/rf_diffusion/RF2-allatom/rf2aa/SE3Transformer/se3_transformer/model/layers/norm.py index acbe23d..4a8ee31 100644 --- a/rf_diffusion/RF2-allatom/rf2aa/SE3Transformer/se3_transformer/model/layers/norm.py +++ b/rf_diffusion/RF2-allatom/rf2aa/SE3Transformer/se3_transformer/model/layers/norm.py @@ -27,7 +27,7 @@ import torch import torch.nn as nn from torch import Tensor -from torch.cuda.nvtx import range as nvtx_range +from se3_transformer.utils.nvtx import nvtx_range from se3_transformer.model.fiber import Fiber diff --git a/rf_diffusion/RF2-allatom/rf2aa/SE3Transformer/se3_transformer/utils/nvtx.py b/rf_diffusion/RF2-allatom/rf2aa/SE3Transformer/se3_transformer/utils/nvtx.py new file mode 100644 index 0000000..d6939a4 --- /dev/null +++ b/rf_diffusion/RF2-allatom/rf2aa/SE3Transformer/se3_transformer/utils/nvtx.py @@ -0,0 +1,32 @@ +# se3_transformer/utils/nvtx.py +from __future__ import annotations + +from contextlib import contextmanager +from typing import Iterator + +@contextmanager +def nvtx_range(message: str) -> Iterator[None]: + """ + Safe NVTX range context manager. + + - If running with CUDA + NVTX support, emits real NVTX ranges. + - Otherwise, becomes a no-op (CPU-only CI, ROCm-only builds, etc). + """ + try: + import torch + + if torch.cuda.is_available() and hasattr(torch.cuda, "nvtx"): + try: + from torch.cuda.nvtx import range as _nvtx_range + with _nvtx_range(message): + yield + return + except Exception: + # CUDA available but NVTX missing/misconfigured -> fall back to no-op + pass + + yield + except Exception: + # torch not importable or other unexpected env issue -> no-op + yield + diff --git a/rf_diffusion/environment/ci_environment.yml b/rf_diffusion/environment/ci_environment.yml new file mode 100644 index 0000000..981a089 --- /dev/null +++ b/rf_diffusion/environment/ci_environment.yml @@ -0,0 +1,242 @@ +name: RFDpoly_env_ci_test +channels: + - pytorch + - pyg + - dglteam + - conda-forge + - bioconda + - defaults +dependencies: + - python=3.10.8 + - pip=23.0.1 + - pytorch=1.13.1 + - cpuonly + - pyg=2.2.0 + - pytorch-scatter=2.1.0 + - pytorch-sparse=0.6.16 + - pytorch-cluster=1.6.0 + - dgl=1.0.1 + - _libgcc_mutex=0.1 + - _openmp_mutex=4.5 + - anyio=3.5.0 + - appdirs=1.4.4 + - argon2-cffi=21.3.0 + - argon2-cffi-bindings=21.2.0 + - asttokens=2.0.5 + - attrs=22.1.0 + - babel=2.11.0 + - backcall=0.2.0 + - beautifulsoup4=4.11.1 + - blas=1.0 + - bleach=4.1.0 + - bottleneck=1.3.5 + - brotli=1.0.9 + - brotli-bin=1.0.9 + - brotlipy=0.7.0 + - bzip2=1.0.8 + - ca-certificates=2022.12.7 + - cairo=1.16.0 + - certifi=2022.12.7 + - cffi=1.15.1 + - charset-normalizer=2.0.4 + - comm=0.1.2 + - conda=23.1.0 + - conda-content-trust=0.1.3 + - conda-package-handling=2.0.2 + - conda-package-streaming=0.7.0 + - contourpy=1.0.5 + - cryptography=39.0.1 + - cycler=0.11.0 + - dbus=1.13.18 + - debugpy=1.5.1 + - decorator=5.1.1 + - defusedxml=0.7.1 + - entrypoints=0.4 + - executing=0.8.3 + - expat=2.4.9 + - flit-core=3.6.0 + - fontconfig=2.14.1 + - fonttools=4.25.0 + - freetype=2.12.1 + - giflib=5.2.1 + - glib=2.69.1 + - gst-plugins-base=1.14.1 + - gstreamer=1.14.1 + - icu=58.2 + - idna=3.4 + - intel-openmp=2021.4.0 + - ipykernel=6.19.2 + - ipython=8.10.0 + - ipython_genutils=0.2.0 + - jedi=0.18.1 + - jinja2=3.1.2 + - joblib=1.1.1 + - jpeg=9e + - json5=0.9.6 + - jsonschema=4.17.3 + - jupyter_client=7.4.9 + - jupyter_core=5.2.0 + - jupyter_server=1.23.4 + - jupyterlab=3.5.3 + - jupyterlab_pygments=0.1.2 + - jupyterlab_server=2.19.0 + - kiwisolver=1.4.4 + - krb5=1.19.4 + - lcms2=2.12 + - ld_impl_linux-64=2.38 + - lerc=3.0 + - libbrotlicommon=1.0.9 + - libbrotlidec=1.0.9 + - libbrotlienc=1.0.9 + - libclang=10.0.1 + - libdeflate=1.17 + - libedit=3.1.20221030 + - libevent=2.1.12 + - libffi=3.4.2 + - libgcc-ng=12.2.0 + - libgfortran-ng=11.2.0 + - libgfortran5=11.2.0 + - libiconv=1.17 + - libllvm10=10.0.1 + - libpng=1.6.39 + - libpq=12.9 + - libsodium=1.0.18 + - libstdcxx-ng=11.2.0 + - libtiff=4.5.0 + - libuuid=1.41.5 + - libwebp=1.2.4 + - libwebp-base=1.2.4 + - libxcb=1.15 + - libxkbcommon=1.0.1 + - libxml2=2.9.14 + - libxslt=1.1.35 + - libzlib=1.2.13 + - llvm-openmp=15.0.7 + - lxml=4.9.1 + - lz4-c=1.9.4 + - markupsafe=2.1.1 + - matplotlib=3.7.0 + - matplotlib-base=3.7.0 + - matplotlib-inline=0.1.6 + - mistune=0.8.4 + - mkl=2021.4.0 + - mkl-service=2.4.0 + - mkl_fft=1.3.1 + - mkl_random=1.2.2 + - munkres=1.1.4 + - nbclassic=0.5.2 + - nbclient=0.5.13 + - nbconvert=6.5.4 + - nbformat=5.7.0 + - ncurses=6.4 + - nest-asyncio=1.5.6 + - networkx=2.8.4 + - notebook=6.5.2 + - notebook-shim=0.2.2 + - nspr=4.33 + - nss=3.74 + - numexpr=2.8.4 + - numpy=1.23.5 + - numpy-base=1.23.5 + - openbabel=3.1.1 + - openssl=1.1.1t + - packaging=22.0 + - pandas=1.5.3 + - pandocfilters=1.5.0 + - parso=0.8.3 + - pcre=8.45 + - pexpect=4.8.0 + - pickleshare=0.7.5 + - pillow=9.4.0 + - pip=23.0.1 + - pixman=0.40.0 + - platformdirs=2.5.2 + - pluggy=1.0.0 + - ply=3.11 + - pooch=1.4.0 + - prometheus_client=0.14.1 + - prompt-toolkit=3.0.36 + - psutil=5.9.0 + - ptyprocess=0.7.0 + - pure_eval=0.2.2 + - pycosat=0.6.4 + - pycparser=2.21 + - pyg=2.2.0 + - pygments=2.11.2 + - pyopenssl=23.0.0 + - pyparsing=3.0.9 + - pyqt=5.15.7 + - pyrsistent=0.18.0 + - pysocks=1.7.1 + - python-dateutil=2.8.2 + - python-fastjsonschema=2.16.2 + - python_abi=3.10 + - pytz=2022.7 + - pyzmq=23.2.0 + - qt-main=5.15.2 + - qt-webengine=5.15.9 + - qtwebkit=5.212 + - readline=8.2 + - requests=2.28.1 + - ruamel.yaml=0.17.21 + - ruamel.yaml.clib=0.2.6 + - scikit-learn=1.2.1 + - scipy=1.10.0 + - seaborn=0.12.2 + - send2trash=1.8.0 + - setuptools=65.5.0 + - sip=6.6.2 + - six=1.16.0 + - sniffio=1.2.0 + - soupsieve=2.3.2.post1 + - sqlite=3.40.1 + - stack_data=0.2.0 + - terminado=0.17.1 + - threadpoolctl=2.2.0 + - tinycss2=1.2.1 + - tk=8.6.12 + - toml=0.10.2 + - tomli=2.0.1 + - toolz=0.12.0 + - tornado=6.2 + - tqdm=4.64.1 + - traitlets=5.7.1 + - typing-extensions=4.4.0 + - typing_extensions=4.4.0 + - tzdata=2022g + - urllib3=1.26.14 + - wcwidth=0.2.5 + - webencodings=0.5.1 + - websocket-client=0.58.0 + - wheel=0.37.1 + - xz=5.2.10 + - zeromq=4.3.4 + - zlib=1.2.13 + - zstandard=0.19.0 + - zstd=1.5.2 + - pip: + - antlr4-python3-runtime==4.9.3 + - assertpy==1.1 + - click==8.1.3 + - colorama==0.4.6 + - deepdiff==6.2.3 + - docker-pycreds==0.4.0 + - e3nn==0.5.1 + - gitdb==4.0.10 + - gitpython==3.1.31 + - hydra-core==1.3.2 + - mpmath==1.3.0 + - omegaconf==2.3.0 + - opt-einsum==3.3.0 + - opt-einsum-fx==0.1.4 + - ordered-set==4.1.0 + - orjson==3.8.7 + - pathtools==0.1.2 + - protobuf==4.22.1 + - pyqt5-sip==12.11.0 + - pyyaml==6.0 + - sentry-sdk==1.16.0 + - setproctitle==1.3.2 + - smmap==5.0.0 + - sympy==1.11.1 + - wandb==0.13.11 diff --git a/rf_diffusion/environment/macos_environment.yml b/rf_diffusion/environment/macos_environment.yml index 649983d..d9b7b11 100644 --- a/rf_diffusion/environment/macos_environment.yml +++ b/rf_diffusion/environment/macos_environment.yml @@ -1,4 +1,4 @@ -name: RFDpoly_env +name: RFDpoly_env_macos channels: - pytorch - conda-forge @@ -18,4 +18,4 @@ dependencies: - pip: - dgl==1.0.1 - e3nn==0.5.1 - - hydra-core==1.3.2 \ No newline at end of file + - hydra-core==1.3.2 diff --git a/test/test_demo.py b/test/test_demo.py new file mode 100644 index 0000000..b886e9d --- /dev/null +++ b/test/test_demo.py @@ -0,0 +1,129 @@ +import subprocess +import sys +import os +import pytest +from pathlib import Path + +TEST_DIR = Path(__file__).resolve().parent +PROJECT_ROOT = (TEST_DIR / "..").resolve() + +DEFAULT_CKPT = "weights/train_session2024-07-08_1720455712_BFF_3.00.pt" + +# Path to rf_diffusion/run_inference.py +RUN_INFERENCE = PROJECT_ROOT / "rf_diffusion" / "run_inference.py" +SCENARIOS = [ + ( + "rna_unconditional", + [ + "diffuser.T=50", + "inference.num_designs=1", + "contigmap.contigs=['9']", + "contigmap.polymer_chains=['rna']", + "inference.output_prefix=demo_outputs/RNA_uncond_standard_settings", + ], + ), + ( + "multi_polymer_unconditional", + [ + "diffuser.T=50", + "inference.num_designs=1", + "contigmap.contigs=['3 3 3']", + "contigmap.polymer_chains=['dna','rna','protein']", + "inference.output_prefix=test_outputs/basic_uncond_test01", + ], + ), + ( + "dna_binder_unconditional", + [ + "diffuser.T=50", + "inference.num_designs=1", + "contigmap.contigs=['2 2 5']", + "contigmap.polymer_chains=['dna','dna','protein']", + "inference.output_prefix=demo_outputs/DNA_prot_uncond_standard_settings", + ], + ), + ( + "rna_secondary_structure", + [ + "diffuser.T=50", + "inference.num_designs=1", + "contigmap.contigs=['9']", + "contigmap.polymer_chains=['rna']", + "scaffoldguided.target_ss_string=555...333", + ], + ), + ( + "motif_scaffolding_v1", + [ + "diffuser.T=50", + "inference.num_designs=1", + "contigmap.contigs=['1,D8-10,1,B8-10,1 1,B18-20,1,D18-20,1 A1-3,0 C1-3,0']", + "contigmap.polymer_chains=['dna','dna','protein','protein']", + "inference.ij_visible=bce-adf", + "inference.input_pdb=test_data/combo_DBP009_DBP010_DBP011_with_DNA_v2.pdb", + "inference.output_prefix=demo_outputs/DNA_binders_scaffolding_test1_standard_settings", + ], + ), + ( + "motif_scaffolding_v2", + [ + "diffuser.T=50", + "inference.num_designs=1", + "contigmap.contigs=['1,D8-10,1,B8-10,1 1,B18-20,1,D18-20,1 A1-3,3,C1-3,0']", + "contigmap.polymer_chains=['dna','dna','protein']", + "scaffoldguided.target_ss_pairs=[\"A1-9,B1-9\"]", + "inference.ij_visible=bce-adf", + "inference.input_pdb=test_data/combo_DBP009_DBP010_DBP011_with_DNA_v2.pdb", + "inference.output_prefix=demo_outputs/DNA_binders_scaffolding_test2_standard_settings", + ], + ), + ( + "dna_pair_specification", + [ + "diffuser.T=50", + "inference.num_designs=1", + "contigmap.contigs=['6 6 6 6']", + "contigmap.polymer_chains=['dna','dna','dna','dna']", + "scaffoldguided.target_ss_pairs=['A1-2,B1-2','A3-4,C3-4','A5-6,D5-6','B3-4,D3-4','B5-6,C5-6','C1-2,D1-2']", + "inference.symmetry=d2", + "inference.output_prefix=demo_outputs/DNA_origami_standard_settings", + ], + ), +] + + +@pytest.mark.parametrize("name, overrides", SCENARIOS) +def test_multi_polymer_scenarios(name, overrides): + """ + Runs rf_diffusion/run_inference.py via subprocess for each scenario + """ + # set path to weights, assuming weights have been downloaded in root directory + os.environ.setdefault("RFDPOLY_CKPT_PATH", DEFAULT_CKPT) + + # Build the command + cmd = [ + sys.executable, + str(RUN_INFERENCE), + # Include this if your Hydra app requires an explicit config name. + # If run_inference.py already sets config_name in @hydra.main, you + # can remove the two lines below. + "--config-name", + "multi_polymer", + *overrides, + ] + + result = subprocess.run( + cmd, + cwd=str(PROJECT_ROOT), + capture_output=True, + text=True, + ) + + # Debug info if something fails + if result.returncode != 0: + print(f"\n=== Scenario {name} failed ===") + print("STDOUT:\n", result.stdout) + print("STDERR:\n", result.stderr) + + assert result.returncode == 0 +