diff --git a/.github/workflows/test.yml b/.github/workflows/test.yml index d9ff48f..528fd4e 100644 --- a/.github/workflows/test.yml +++ b/.github/workflows/test.yml @@ -54,7 +54,7 @@ jobs: - name: Download ANTsXNet data and models run: | pip install -e . - python download_all_data.py --strict --cache-dir ${{ runner.temp }}/ANTsXNet + python download_antsxnet_data.py --strict --cache-dir ${{ runner.temp }}/ANTsXNet - name: Upload data to artifact uses: actions/upload-artifact@v6 diff --git a/docker/Dockerfile b/docker/Dockerfile index 808f4b2..df65ed8 100644 --- a/docker/Dockerfile +++ b/docker/Dockerfile @@ -47,11 +47,17 @@ RUN useradd --create-home antspyuser && \ COPY --from=builder --chown=antspyuser \ /opt/antspydata/* /home/antspyuser/.antspy/ -COPY docker/get_antsxnet_data.py /opt/bin/ - -RUN echo "Install data option: ${install_antsxnet_data}" -RUN . ${VIRTUAL_ENV}/bin/activate && \ - /opt/bin/get_antsxnet_data.py /home/antspyuser/.keras ${INSTALL_ANTSXNET_DATA} && \ +COPY download_antsxnet_data.py /opt/bin/ + +RUN echo "Install data option: ${INSTALL_ANTSXNET_DATA}" && \ + if [ "${INSTALL_ANTSXNET_DATA}" = "1" ]; then \ + . "${VIRTUAL_ENV}/bin/activate" && \ + /opt/bin/download_antsxnet_data.py \ + --cache-dir /home/antspyuser/.keras/ANTsXNet \ + --strict ; \ + else \ + echo "Skipping ANTsXNet data download"; \ + fi && \ chmod -R 0755 /home/antspyuser/.antspy /home/antspyuser/.keras WORKDIR /home/antspyuser diff --git a/docker/README.md b/docker/README.md index a8acee3..ae2b813 100644 --- a/docker/README.md +++ b/docker/README.md @@ -42,7 +42,8 @@ docker build \ ``` This will make the container larger, but all data and pretrained networks will be -available at run time without downloading. +available at run time without downloading. You can also download a subset of data / +networks, see the help for the `download_antsxnet_data.py` script. 
## Downloading data to a local cache @@ -54,19 +55,18 @@ you are using, and preferably make it read-only after populating it. To download and networks, run ``` docker run --rm -it antspynet:latest \ - /opt/bin/get_antsxnet_data.py \ - /path/to/local/cache/dir \ - 1 + /opt/bin/download_antsxnet_data.py \ + --cache-dir /path/to/local/cache/dir/ANTsXNet ``` You can also download a subset of data / networks by providing a list of names in a text file, one per line ``` docker run --rm -it antspynet:latest \ - /opt/bin/get_antsxnet_data.py \ - /path/to/local/cache/dir \ - 1 \ - /path/to/names.txt + /opt/bin/download_antsxnet_data.py \ + --cache-dir /path/to/local/cache/dir/ANTsXNet \ + --data-key-file datakeys.txt \ + --model-key-file modelkeys.txt ``` If the local cache directory is not mounted as `/home/antspyuser/.keras`, runtime scripts diff --git a/docker/get_antsxnet_data.py b/docker/get_antsxnet_data.py deleted file mode 100755 index 6cc1ac8..0000000 --- a/docker/get_antsxnet_data.py +++ /dev/null @@ -1,84 +0,0 @@ -#!/usr/bin/env python - -# -# This gets ANTsXNet data and pretrained networks -# -# Getting all the data ahead of time is optional, by default it is downloaded -# on demand and stored in ~/.keras/ANTsXNet . But for complete reproducibility, -# or for applications lacking Internet access, the data can be downloaded with -# this script. -# -# To use the data in a container, mount the data directory containing keras.json -# as $USER/.keras. -# -# -import sys - -if (len(sys.argv) == 1): - usage = ''' - Usage: {} /path/to/ANTsXNetData [doInstall=1] [dataList.txt] [networkList.txt] - - Second argument can be passed to skip installation in docker files. - - Subsequent arguments, if specified, read a list of things to fetch from a text file. - This can be used to get a subset of the data / networks. - - Downloads ANTsXNet data and networks to the specified directory. 
- - The path MUST be absolute or it will be interpreted relative to - the default ~/.keras -''' - print(usage.format(sys.argv[0])) - - sys.exit(1) - -import antspynet - -# Base output dir, make ANTsXNet/ and keras.json under here -output_dir=sys.argv[1] - -do_install=1 - -if len(sys.argv) > 2: - do_install = int(sys.argv[2]) - -if do_install == 0: - # Exit 0, so docker won't think there's an error - sys.exit(0) - -data_path = f"{output_dir}/ANTsXNet" - -all_data = list() - -if len(sys.argv) > 3: - with open(sys.argv[3]) as f: - all_data = f.read().splitlines() -else: - all_data = list(antspynet.get_antsxnet_data('show')) - all_data.remove('show') - -antspynet.set_antsxnet_cache_directory(data_path) - -for entry in all_data: - print(f"Downloading {entry}") - try: - antspynet.get_antsxnet_data(entry) - except NotImplementedError as e: - print(f"Failed to download {entry}") - -all_networks = list() - -if len(sys.argv) > 4: - with open(sys.argv[4]) as f: - all_networks = f.read().splitlines() -else: - all_networks = list(antspynet.get_pretrained_network('show')) - all_networks.remove('show') - -for entry in all_networks: - print(f"Downloading {entry}") - try: - antspynet.get_pretrained_network(entry) - except NotImplementedError as e: - print(f"Failed to download {entry}") - diff --git a/download_all_data.py b/download_all_data.py deleted file mode 100644 index e85fada..0000000 --- a/download_all_data.py +++ /dev/null @@ -1,63 +0,0 @@ -import antspynet -import argparse -import sys -import tensorflow as tf - -def download_all_data(strict=False, cache_dir=None): - print("Downloading data files from get_antsxnet_data...") - if cache_dir is not None: - antspynet.set_antsxnet_cache_directory(cache_dir) - print(f"Using custom cache directory: {cache_dir}") - try: - data_keys = antspynet.get_antsxnet_data("show") - for key in data_keys: - if key == "show": - continue - try: - print(f" ↳ Downloading data: {key}") - fpath = antspynet.get_antsxnet_data(key) - print(f" ✓ Saved 
to: {fpath}") - except Exception as e: - print(f" ✗ Failed to download {key}: {e}") - if strict: - raise - except Exception as e: - print(f"✗ Failed to retrieve data keys: {e}") - if strict: - sys.exit(1) - - print("\nDownloading model weights from get_pretrained_network...") - try: - model_keys = antspynet.get_pretrained_network("show") - for key in model_keys: - if key == "show": - continue - try: - print(f" ↳ Downloading model: {key}") - fpath = antspynet.get_pretrained_network(key) - print(f" ✓ Saved to: {fpath}") - except Exception as e: - print(f" ✗ Failed to download {key}: {e}") - if strict: - raise - except Exception as e: - print(f"✗ Failed to retrieve model keys: {e}") - if strict: - sys.exit(1) - -if __name__ == "__main__": - parser = argparse.ArgumentParser() - parser.add_argument("--strict", action="store_true", help="Exit on first failed download.") - parser.add_argument("--verbose", action="store_true", help="Enable verbose output showing download progress.") - parser.add_argument("--cache-dir", type=str, help="Custom cache directory for downloads.", default=None) - args = parser.parse_args() - - if not args.verbose: - # This stops download progress logs from clogging the output in non-interactive environments - tf.keras.utils.disable_interactive_logging() - - try: - download_all_data(strict=args.strict, cache_dir=args.cache_dir) - except Exception as e: - print(f"\nAborted due to error: {e}") - sys.exit(1) diff --git a/download_antsxnet_data.py b/download_antsxnet_data.py new file mode 100755 index 0000000..9600dc6 --- /dev/null +++ b/download_antsxnet_data.py @@ -0,0 +1,131 @@ +#!/usr/bin/env python + +import antspynet +import argparse +import os +import sys +import tensorflow as tf + +def download_data(strict=False, cache_dir=None, data_keys=None, model_keys=None, verbose=False): + """Download data and / or networks. + + If called with no arguments, this will attempt to download all data and networks to the default cache directory.
+ + Args: + strict (bool, optional): If True, exit with status 1 on the first failed download instead of continuing. + cache_dir (str, optional): Cache directory to use for downloads. If None, the default cache directory + `~/.keras/ANTsXNet` will be used. + data_keys (list of str, optional): List of data keys to download. If None, all available data will be downloaded. + model_keys (list of str, optional): List of model keys to download. If None, all available models will be + downloaded. + verbose (bool, optional): If True, show more details of downloads. Default is False. + """ + print("Downloading data files from get_antsxnet_data...") + if cache_dir is not None: + antspynet.set_antsxnet_cache_directory(cache_dir) + print(f"Using custom cache directory: {cache_dir}") + try: + if data_keys is None: + data_keys = antspynet.get_antsxnet_data("show") + for key in data_keys: + if key == "show": + continue + try: + print(f" ↳ Downloading data: {key}") + fpath = antspynet.get_antsxnet_data(key) + if verbose: + print(f" ✓ Saved to: {fpath}") + except Exception as e: + print(f" ✗ Failed to download {key}: {e}") + if strict: + sys.exit(1) + except Exception as e: + print(f"✗ Failed to retrieve data keys: {e}") + if strict: + sys.exit(1) + + print("\nDownloading model weights from get_pretrained_network...") + try: + if model_keys is None: + model_keys = antspynet.get_pretrained_network("show") + for key in model_keys: + if key == "show": + continue + try: + print(f" ↳ Downloading model: {key}") + fpath = antspynet.get_pretrained_network(key) + if verbose: + print(f" ✓ Saved to: {fpath}") + except Exception as e: + print(f" ✗ Failed to download {key}: {e}") + if strict: + sys.exit(1) + except Exception as e: + print(f"✗ Failed to retrieve model keys: {e}") + if strict: + sys.exit(1) + + +if __name__ == "__main__": + parser = argparse.ArgumentParser(description=""" + Download ANTsXNet data and / or pretrained models to the local cache directory.
+ + The default behavior is to download all available data and models to the default cache directory (`~/.keras/ANTsXNet`). + + Optionally, the cache directory can be customized, and specific data and model keys can be specified. + + If any download fails, the script will continue by default, but this can be changed with the `--strict` flag to exit on the first failure. + + """) + parser.add_argument("--strict", action="store_true", help="Exit on first failed download.") + parser.add_argument("--verbose", action="store_true", help="Enable verbose output showing download progress.") + parser.add_argument("--cache-dir", type=str, help="Custom cache directory for downloads.", default=None) + parser.add_argument("--data-key-file", type=str, help="Text file containing a list of data keys to download, one per line.", + default=None) + parser.add_argument("--model-key-file", type=str, help="Text file containing a list of model keys to download, one per " + "line.", default=None) + parser.add_argument("--data-keys", nargs='+', type=str, help="One or more data keys to download, separated by spaces.", + default=None) + parser.add_argument("--model-keys", nargs='+',type=str, help="One or more model keys to download, separated by spaces.", + default=None) + args = parser.parse_args() + + if not args.verbose: + # This stops download progress logs from clogging the output in non-interactive environments + tf.keras.utils.disable_interactive_logging() + + # Cache dir must be an absolute path, otherwise keras will interpret it relative to ~/.keras/ + if args.cache_dir is not None: + args.cache_dir = os.path.abspath(args.cache_dir) + + if args.data_key_file is not None and args.data_keys is not None: + print("Error: Cannot specify both --data-key-file and --data-keys.") + sys.exit(1) + if args.model_key_file is not None and args.model_keys is not None: + print("Error: Cannot specify both --model-key-file and --model-keys.") + sys.exit(1) + + data_keys = args.data_keys + + if
args.data_key_file is not None: + if not os.path.isfile(args.data_key_file): + print(f"Error: Data keys file '{args.data_key_file}' does not exist.") + sys.exit(1) + with open(args.data_key_file, "r") as f: + data_keys = [line.strip() for line in f if line.strip()] + + model_keys = args.model_keys + + if args.model_key_file is not None: + if not os.path.isfile(args.model_key_file): + print(f"Error: Model keys file '{args.model_key_file}' does not exist.") + sys.exit(1) + with open(args.model_key_file, "r") as f: + model_keys = [line.strip() for line in f if line.strip()] + + try: + download_data(strict=args.strict, cache_dir=args.cache_dir, data_keys=data_keys, model_keys=model_keys, + verbose=args.verbose) + except Exception as e: + print(f"\nAborted due to error: {e}") + sys.exit(1)