Investigating external dataset usage. #6
Merged
6 commits
4a849c8 (drewoldag) WIP - Investigating external dataset usage.
43ea880 (drewoldag) Additional polishing after 0.7.0 release.
09534ec (drewoldag) Cleaning up the model usage example.
963db3d (drewoldag) Moving some of the dependencies out of `dev` into primary.
7835629 (drewoldag) Update src/external_hyrax_example/datasets/galaxy10_dataset.py
c8c64b5 (drewoldag) Responding to PR comments.
docs/pre_executed/external_dataset_and_model_training_example.ipynb (296 additions, 0 deletions)
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,296 @@ | ||
| { | ||
| "cells": [ | ||
| { | ||
| "cell_type": "markdown", | ||
| "id": "42dfb8e1", | ||
| "metadata": {}, | ||
| "source": [ | ||
| "# External Dataset and Model Usage Example\n", | ||
| "\n", | ||
| "This notebook demonstrates how to use dataset and model classes **defined outside of the Hyrax package** within Hyrax's training pipeline. It covers:\n", | ||
| "\n", | ||
| "- Loading the [Galaxy10 dataset](https://astronn.readthedocs.io/en/stable/galaxy10sdss.html) via a custom `HyraxDataset` subclass (`Galaxy10Dataset`)\n", | ||
| "- Training Hyrax's built-in `HyraxCNN` model on that dataset\n", | ||
| "- Training a locally-defined `VGG11` model (registered with Hyrax) on the same dataset\n", | ||
| "- Comparing training loss between the two models" | ||
| ] | ||
| }, | ||
| { | ||
| "cell_type": "markdown", | ||
| "id": "411d03bc", | ||
| "metadata": {}, | ||
| "source": [ | ||
| "## Download the Galaxy10 Dataset\n", | ||
| "\n", | ||
| "The [Galaxy10 dataset](https://astronn.readthedocs.io/en/stable/galaxy10sdss.html) from [AstroNN](https://github.com/henrysky/astroNN) contains ~22k labeled galaxy images across 10 morphological classes. It serves as a stand-in for any external astronomical image dataset.\n", | ||
| "\n", | ||
| "The ~200MB HDF5 file will be downloaded and cached locally via `pooch`." | ||
| ] | ||
| }, | ||
| { | ||
| "cell_type": "code", | ||
| "execution_count": 2, | ||
| "id": "c5817aa3", | ||
| "metadata": {}, | ||
| "outputs": [], | ||
| "source": [ | ||
| "import pooch\n", | ||
| "\n", | ||
| "file_path = pooch.retrieve(\n", | ||
| " url=\"http://www.astro.utoronto.ca/~bovy/Galaxy10/Galaxy10.h5\",\n", | ||
| " known_hash=\"sha256:969A6B1CEFCC36E09FFFA86FEBD2F699A4AA19B837BA0427F01B0BC6DED458AF\",\n", | ||
| " fname=\"Galaxy10.h5\",\n", | ||
| " path=\"./data/galaxy_10\",\n", | ||
| ")" | ||
| ] | ||
| }, | ||
| { | ||
| "cell_type": "code", | ||
| "execution_count": null, | ||
| "id": "c027bd3b", | ||
| "metadata": {}, | ||
| "outputs": [], | ||
| "source": [ | ||
| "from hyrax import Hyrax\n", | ||
| "\n", | ||
| "h = Hyrax()" | ||
| ] | ||
| }, | ||
| { | ||
| "cell_type": "markdown", | ||
| "id": "ff9b0ab4", | ||
| "metadata": {}, | ||
| "source": [ | ||
| "## Configure Hyrax\n", | ||
| "\n", | ||
| "Next we build a `data_request` dictionary that tells Hyrax which dataset class to use, where the data lives, and how to split it between training (80%) and validation (20%).\n", | ||
| "\n", | ||
| "The key here is `dataset_class` — instead of a built-in Hyrax dataset, we point it at our locally-defined `Galaxy10Dataset` class. Hyrax will import and instantiate it automatically.\n", | ||
| "\n", | ||
| "We also set `channels_first: True` in the dataset config so the images are provided in `(C, H, W)` order as expected by PyTorch." | ||
| ] | ||
| }, | ||
| { | ||
| "cell_type": "code", | ||
| "execution_count": null, | ||
| "id": "0e144812", | ||
| "metadata": {}, | ||
| "outputs": [], | ||
| "source": [ | ||
| "data_request = {\n", | ||
| " \"train\": {\n", | ||
| " \"data\": {\n", | ||
| " \"dataset_class\": \"external_hyrax_example.datasets.galaxy10_dataset.Galaxy10Dataset\",\n", | ||
| " \"data_location\": \"./data/galaxy_10\",\n", | ||
| " \"fields\": [\"image\", \"label\"],\n", | ||
| " \"primary_id_field\": \"object_id\",\n", | ||
| " \"split_fraction\": 0.8,\n", | ||
| " \"dataset_config\": {\n", | ||
| " \"external_hyrax_example\": {\n", | ||
| " \"galaxy10_dataset\": {\n", | ||
| " \"channels_first\": True,\n", | ||
| " },\n", | ||
| " },\n", | ||
| " },\n", | ||
| " },\n", | ||
| " },\n", | ||
| " \"validate\": {\n", | ||
| " \"data\": {\n", | ||
| " \"dataset_class\": \"external_hyrax_example.datasets.galaxy10_dataset.Galaxy10Dataset\",\n", | ||
| " \"data_location\": \"./data/galaxy_10\",\n", | ||
| " \"fields\": [\"image\", \"label\"],\n", | ||
| " \"primary_id_field\": \"object_id\",\n", | ||
| " \"split_fraction\": 0.2,\n", | ||
| " \"dataset_config\": {\n", | ||
| " \"external_hyrax_example\": {\n", | ||
| " \"galaxy10_dataset\": {\n", | ||
| " \"channels_first\": True,\n", | ||
| " },\n", | ||
| " },\n", | ||
| " },\n", | ||
| " },\n", | ||
| " },\n", | ||
| "}\n", | ||
| "\n", | ||
| "h.set_config(\"data_request\", data_request)" | ||
| ] | ||
| }, | ||
| { | ||
| "cell_type": "markdown", | ||
| "id": "b53b5634", | ||
| "metadata": {}, | ||
| "source": [ | ||
| "## Train HyraxCNN\n", | ||
| "\n", | ||
| "We'll start by training `HyraxCNN`, a lightweight CNN that ships with Hyrax. To do so, we need to set the model name in config." | ||
| ] | ||
| }, | ||
| { | ||
| "cell_type": "code", | ||
| "execution_count": null, | ||
| "id": "f37265bb", | ||
| "metadata": {}, | ||
| "outputs": [], | ||
| "source": [ | ||
| "h.set_config(\"model.name\", \"HyraxCNN\")" | ||
| ] | ||
| }, | ||
| { | ||
| "cell_type": "markdown", | ||
| "id": "ce9b22cf", | ||
| "metadata": {}, | ||
| "source": [ | ||
| "HyraxCNN's default `prepare_inputs` expects images in `(C, H, W)` format with standard normalization. Because Galaxy10 images are already in `(C, H, W)` order (we set `channels_first: True` earlier), we override `prepare_inputs` to apply the correct normalization and extract the label — replacing the default implementation before training begins." | ||
| ] | ||
| }, | ||
| { | ||
| "cell_type": "code", | ||
| "execution_count": null, | ||
| "id": "cbc62ec9", | ||
| "metadata": {}, | ||
| "outputs": [], | ||
| "source": [ | ||
| "@staticmethod\n", | ||
| "def prepare_inputs(data_dict):\n", | ||
| " import numpy as np\n", | ||
| "\n", | ||
| " if \"data\" not in data_dict:\n", | ||
| " raise RuntimeError(\"Unable to find `data` key in data_dict\")\n", | ||
| "\n", | ||
| " data = data_dict[\"data\"]\n", | ||
| " image = np.asarray(data[\"image\"], dtype=np.float32)\n", | ||
| "\n", | ||
| " # normalize the image to have mean 0.5 and std 0.5\n", | ||
| " mean = np.asarray([0.5, 0.5, 0.5], dtype=np.float32)\n", | ||
| " std = np.asarray([0.5, 0.5, 0.5], dtype=np.float32)\n", | ||
| "\n", | ||
| " mean = mean[:, None, None]\n", | ||
| " std = std[:, None, None]\n", | ||
| "\n", | ||
| " normalized_image = (image - mean) / (std + 1e-8)\n", | ||
| "\n", | ||
| " label = None\n", | ||
| " if \"label\" in data:\n", | ||
| " label = np.asarray(data[\"label\"], dtype=np.int64)\n", | ||
| "\n", | ||
| " return (normalized_image, label)" | ||
| ] | ||
| }, | ||
| { | ||
| "cell_type": "code", | ||
| "execution_count": null, | ||
| "id": "1265691e", | ||
| "metadata": {}, | ||
| "outputs": [], | ||
| "source": [ | ||
| "model = h.model()\n", | ||
| "model.prepare_inputs = prepare_inputs" | ||
| ] | ||
| }, | ||
| { | ||
| "cell_type": "markdown", | ||
| "id": "69a6bcf0", | ||
| "metadata": {}, | ||
| "source": [ | ||
| "With the model configured, we kick off training. Hyrax handles the training loop, validation, and result logging automatically." | ||
| ] | ||
| }, | ||
| { | ||
| "cell_type": "code", | ||
| "execution_count": null, | ||
| "id": "f58f025d", | ||
| "metadata": {}, | ||
| "outputs": [], | ||
| "source": [ | ||
| "model = h.train()" | ||
| ] | ||
| }, | ||
| { | ||
| "cell_type": "markdown", | ||
| "id": "5793b02d", | ||
| "metadata": {}, | ||
| "source": [ | ||
| "## Train VGG11\n", | ||
| "\n", | ||
| "Now we switch to a locally-defined `VGG11` model. This is the key demonstration of using an **externally defined model** with Hyrax — the model is declared in `external_hyrax_example.models.vgg11` and registered with Hyrax via the `@hyrax_model` decorator, so we can reference it by its fully-qualified class path.\n", | ||
| "\n", | ||
| "We reuse the same `prepare_inputs` override for easier comparison of the results." | ||
| ] | ||
| }, | ||
| { | ||
| "cell_type": "code", | ||
| "execution_count": null, | ||
| "id": "6116906c", | ||
| "metadata": {}, | ||
| "outputs": [], | ||
| "source": [ | ||
| "h.set_config(\"model.name\", \"external_hyrax_example.models.vgg11.VGG11\")" | ||
| ] | ||
| }, | ||
| { | ||
| "cell_type": "code", | ||
| "execution_count": null, | ||
| "id": "1797a069", | ||
| "metadata": {}, | ||
| "outputs": [], | ||
| "source": [ | ||
| "model = h.model()\n", | ||
| "model.prepare_inputs = prepare_inputs" | ||
| ] | ||
| }, | ||
| { | ||
| "cell_type": "markdown", | ||
| "id": "d163c8c5", | ||
| "metadata": {}, | ||
| "source": [ | ||
| "Train VGG11 using the same data request as before." | ||
| ] | ||
| }, | ||
| { | ||
| "cell_type": "code", | ||
| "execution_count": null, | ||
| "id": "7f4e3d40", | ||
| "metadata": {}, | ||
| "outputs": [], | ||
| "source": [ | ||
| "model = h.train()" | ||
| ] | ||
| }, | ||
| { | ||
| "cell_type": "markdown", | ||
| "id": "5bfb204c", | ||
| "metadata": {}, | ||
| "source": [ | ||
| "## Results\n", | ||
| "\n", | ||
| "Below is a comparison of training loss between the two models over the same number of epochs. VGG11 is a significantly deeper network and converges to a lower loss on this dataset.\n", | ||
| "\n", | ||
| "\n", | ||
| "\n", | ||
| "*Orange = HyraxCNN, Green = VGG11*" | ||
| ] | ||
| } | ||
| ], | ||
| "metadata": { | ||
| "kernelspec": { | ||
| "display_name": "external-hyrax-example", | ||
| "language": "python", | ||
| "name": "python3" | ||
| }, | ||
| "language_info": { | ||
| "codemirror_mode": { | ||
| "name": "ipython", | ||
| "version": 3 | ||
| }, | ||
| "file_extension": ".py", | ||
| "mimetype": "text/x-python", | ||
| "name": "python", | ||
| "nbconvert_exporter": "python", | ||
| "pygments_lexer": "ipython3", | ||
| "version": "3.13.12" | ||
| } | ||
| }, | ||
| "nbformat": 4, | ||
| "nbformat_minor": 5 | ||
| } | ||
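The notebook's `prepare_inputs` relies on NumPy broadcasting to apply per-channel statistics to a `(C, H, W)` image, after the dataset's `channels_first: True` option has moved the channel axis to the front. A minimal standalone sketch of both steps (the array shape and values here are synthetic stand-ins, not Galaxy10 data):

```python
import numpy as np

# Synthetic (H, W, C) image with values in [0, 1]; a placeholder for one
# Galaxy10 image as stored channels-last in the HDF5 file.
rng = np.random.default_rng(0)
hwc = rng.random((69, 69, 3), dtype=np.float32)

# channels_first: True corresponds to moving the channel axis to the front,
# yielding the (C, H, W) layout PyTorch expects.
chw = np.transpose(hwc, (2, 0, 1))
assert chw.shape == (3, 69, 69)

# Per-channel normalization via broadcasting, mirroring prepare_inputs:
# reshape (3,) stats to (3, 1, 1) so they align with the channel axis.
mean = np.asarray([0.5, 0.5, 0.5], dtype=np.float32)[:, None, None]
std = np.asarray([0.5, 0.5, 0.5], dtype=np.float32)[:, None, None]
normalized = (chw - mean) / (std + 1e-8)
```

The `[:, None, None]` reshape is what lets a length-3 vector scale a full `(3, H, W)` array without an explicit loop; the `1e-8` epsilon guards against division by zero, matching the notebook's code.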
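The notebook says the external `VGG11` is "registered with Hyrax via the `@hyrax_model` decorator" but never shows that mechanism. The general decorator-registration pattern it relies on can be sketched in plain Python; every name below (`hyrax_model_sketch`, `MODEL_REGISTRY`, `resolve_model`, `VGG11Sketch`) is hypothetical and does not reflect Hyrax's actual API:

```python
# Hypothetical sketch of decorator-based model registration, in the spirit
# of (but not identical to) Hyrax's @hyrax_model decorator.
MODEL_REGISTRY = {}

def hyrax_model_sketch(cls):
    # Register the class under its fully-qualified path so a config string
    # like "external_hyrax_example.models.vgg11.VGG11" can resolve to it.
    MODEL_REGISTRY[f"{cls.__module__}.{cls.__qualname__}"] = cls
    return cls

@hyrax_model_sketch
class VGG11Sketch:
    def __init__(self, num_classes=10):
        self.num_classes = num_classes

def resolve_model(path):
    # A real framework would typically fall back to importlib for paths
    # that were never imported (and hence never registered).
    return MODEL_REGISTRY[path]
```

This is why `h.set_config("model.name", "external_hyrax_example.models.vgg11.VGG11")` in the notebook can swap models with no other code changes: the config string is just a lookup key for a registered class.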