Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Binary file added docs/_static/HyraxCNN_vs_VGG11.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
2 changes: 0 additions & 2 deletions docs/notebooks.rst
Original file line number Diff line number Diff line change
Expand Up @@ -2,5 +2,3 @@ Notebooks
========================================================================================

.. toctree::

Introducing Jupyter Notebooks <notebooks/intro_notebook>
84 changes: 0 additions & 84 deletions docs/notebooks/intro_notebook.ipynb

This file was deleted.

296 changes: 296 additions & 0 deletions docs/pre_executed/external_dataset_and_model_training_example.ipynb
Original file line number Diff line number Diff line change
@@ -0,0 +1,296 @@
{
"cells": [
{
"cell_type": "markdown",
"id": "42dfb8e1",
"metadata": {},
"source": [
"# External Dataset and Model Usage Example\n",
"\n",
"This notebook demonstrates how to use dataset and model classes **defined outside of the Hyrax package** within Hyrax's training pipeline. It covers:\n",
"\n",
"- Loading the [Galaxy10 dataset](https://astronn.readthedocs.io/en/stable/galaxy10sdss.html) via a custom `HyraxDataset` subclass (`Galaxy10Dataset`)\n",
"- Training Hyrax's built-in `HyraxCNN` model on that dataset\n",
"- Training a locally-defined `VGG11` model (registered with Hyrax) on the same dataset\n",
"- Comparing training loss between the two models"
]
},
{
"cell_type": "markdown",
"id": "411d03bc",
"metadata": {},
"source": [
"## Download the Galaxy10 Dataset\n",
"\n",
"The [Galaxy10 dataset](https://astronn.readthedocs.io/en/stable/galaxy10sdss.html) from [AstroNN](https://github.com/henrysky/astroNN) contains ~22k labeled galaxy images across 10 morphological classes. It serves as a stand-in for any external astronomical image dataset.\n",
"\n",
"The ~200MB HDF5 file will be downloaded and cached locally via `pooch`."
]
},
{
"cell_type": "code",
"execution_count": 2,
"id": "c5817aa3",
"metadata": {},
"outputs": [],
"source": [
"import pooch\n",
"\n",
"file_path = pooch.retrieve(\n",
" url=\"http://www.astro.utoronto.ca/~bovy/Galaxy10/Galaxy10.h5\",\n",
" known_hash=\"sha256:969A6B1CEFCC36E09FFFA86FEBD2F699A4AA19B837BA0427F01B0BC6DED458AF\",\n",
" fname=\"Galaxy10.h5\",\n",
" path=\"./data/galaxy_10\",\n",
")"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "c027bd3b",
"metadata": {},
"outputs": [],
"source": [
"from hyrax import Hyrax\n",
"\n",
"h = Hyrax()"
]
},
{
"cell_type": "markdown",
"id": "ff9b0ab4",
"metadata": {},
"source": [
"## Configure Hyrax\n",
"\n",
"Next we build a `data_request` dictionary that tells Hyrax which dataset class to use, where the data lives, and how to split it between training (80%) and validation (20%).\n",
"\n",
"The key here is `dataset_class` — instead of a built-in Hyrax dataset, we point it at our locally-defined `Galaxy10Dataset` class. Hyrax will import and instantiate it automatically.\n",
"\n",
"We also set `channels_first: True` in the dataset config so the images are provided in `(C, H, W)` order as expected by PyTorch."
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "0e144812",
"metadata": {},
"outputs": [],
"source": [
"data_request = {\n",
" \"train\": {\n",
" \"data\": {\n",
" \"dataset_class\": \"external_hyrax_example.datasets.galaxy10_dataset.Galaxy10Dataset\",\n",
" \"data_location\": \"./data/galaxy_10\",\n",
" \"fields\": [\"image\", \"label\"],\n",
" \"primary_id_field\": \"object_id\",\n",
" \"split_fraction\": 0.8,\n",
" \"dataset_config\": {\n",
" \"external_hyrax_example\": {\n",
" \"galaxy10_dataset\": {\n",
" \"channels_first\": True,\n",
" },\n",
" },\n",
" },\n",
" },\n",
" },\n",
" \"validate\": {\n",
" \"data\": {\n",
" \"dataset_class\": \"external_hyrax_example.datasets.galaxy10_dataset.Galaxy10Dataset\",\n",
" \"data_location\": \"./data/galaxy_10\",\n",
" \"fields\": [\"image\", \"label\"],\n",
" \"primary_id_field\": \"object_id\",\n",
" \"split_fraction\": 0.2,\n",
" \"dataset_config\": {\n",
" \"external_hyrax_example\": {\n",
" \"galaxy10_dataset\": {\n",
" \"channels_first\": True,\n",
" },\n",
" },\n",
" },\n",
" },\n",
" },\n",
"}\n",
"\n",
"h.set_config(\"data_request\", data_request)"
]
},
{
"cell_type": "markdown",
"id": "b53b5634",
"metadata": {},
"source": [
"## Train HyraxCNN\n",
"\n",
"We'll start by training `HyraxCNN`, a lightweight CNN that ships with Hyrax. To do so, we need to set the model name in config."
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "f37265bb",
"metadata": {},
"outputs": [],
"source": [
"h.set_config(\"model.name\", \"HyraxCNN\")"
]
},
{
"cell_type": "markdown",
"id": "ce9b22cf",
"metadata": {},
"source": [
"HyraxCNN's default `prepare_inputs` expects images in `(C, H, W)` format with standard normalization. Because Galaxy10 images are already in `(C, H, W)` order (we set `channels_first: True` earlier), we override `prepare_inputs` to apply the correct normalization and extract the label — replacing the default implementation before training begins."
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "cbc62ec9",
"metadata": {},
"outputs": [],
"source": [
"@staticmethod\n",
"def prepare_inputs(data_dict):\n",
" import numpy as np\n",
"\n",
" if \"data\" not in data_dict:\n",
" raise RuntimeError(\"Unable to find `data` key in data_dict\")\n",
"\n",
" data = data_dict[\"data\"]\n",
" image = np.asarray(data[\"image\"], dtype=np.float32)\n",
"\n",
" # normalize the image to have mean 0.5 and std 0.5\n",
" image = np.asarray(image, dtype=np.float32)\n",
" mean = np.asarray([0.5, 0.5, 0.5], dtype=np.float32)\n",
" std = np.asarray([0.5, 0.5, 0.5], dtype=np.float32)\n",
"\n",
" mean = mean[:, None, None]\n",
" std = std[:, None, None]\n",
"\n",
" normalized_image = (image - mean) / (std + 1e-8)\n",
"\n",
" label = None\n",
" if \"label\" in data:\n",
" label = np.asarray(data[\"label\"], dtype=np.int64)\n",
"\n",
" return (normalized_image, label)"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "1265691e",
"metadata": {},
"outputs": [],
"source": [
"model = h.model()\n",
"model.prepare_inputs = prepare_inputs"
]
},
{
"cell_type": "markdown",
"id": "69a6bcf0",
"metadata": {},
"source": [
"With the model configured, we kick off training. Hyrax handles the training loop, validation, and result logging automatically."
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "f58f025d",
"metadata": {},
"outputs": [],
"source": [
"model = h.train()"
]
},
{
"cell_type": "markdown",
"id": "5793b02d",
"metadata": {},
"source": [
"## Train VGG11\n",
"\n",
"Now we switch to a locally-defined `VGG11` model. This is the key demonstration of using an **externally defined model** with Hyrax — the model is declared in `external_hyrax_example.models.vgg11` and registered with Hyrax via the `@hyrax_model` decorator, so we can reference it by its fully-qualified class path.\n",
"\n",
"We reuse the same `prepare_inputs` override from easier comparison of the results."
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "6116906c",
"metadata": {},
"outputs": [],
"source": [
"h.set_config(\"model.name\", \"external_hyrax_example.models.vgg11.VGG11\")"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "1797a069",
"metadata": {},
"outputs": [],
"source": [
"model = h.model()\n",
"model.prepare_inputs = prepare_inputs"
]
},
{
"cell_type": "markdown",
"id": "d163c8c5",
"metadata": {},
"source": [
"Train VGG11 using the same data request as before."
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "7f4e3d40",
"metadata": {},
"outputs": [],
"source": [
"model = h.train()"
]
},
{
"cell_type": "markdown",
"id": "5bfb204c",
"metadata": {},
"source": [
"## Results\n",
"\n",
"Below is a comparison of training loss between the two models over the same number of epochs. VGG11 is a significantly deeper network and converges to a lower loss on this dataset.\n",
"\n",
"![loss_values](../_static/HyraxCNN_vs_VGG11.png)\n",
"\n",
"*Orange = HyraxCNN, Green = VGG11*"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "external-hyrax-example",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.13.12"
}
},
"nbformat": 4,
"nbformat_minor": 5
}
Loading
Loading