A beginner-friendly PyTorch image classification project with a clean modular structure, multiple CNN architectures, support for multiple datasets, and automatic input adaptation.
This project started from a simple LeNet training script and was incrementally refactored into a more organized codebase that is easier to extend, experiment with, and understand.
- Modular project structure
- Multiple CNN models:
  - LeNet
  - AlexNet
  - VGG16
  - ResNet18
  - GoogLeNet
- Multiple dataset support:
  - FashionMNIST
  - CIFAR10
- Automatic input adaptation based on `dataset_name` and `model_name`
- Train / validation / test pipeline
- Automatic result saving:
  - best model weights
  - training history CSV
  - training curves image
  - test summary JSON
```
LeNet/
├─ main.py
├─ README.md
├─ .gitignore
│
├─ data/
│  ├─ __init__.py
│  ├─ dataloader.py
│  └─ FashionMNIST/        # downloaded dataset cache
│
├─ eval/
│  ├─ __init__.py
│  └─ evaluate.py
│
├─ models/
│  ├─ __init__.py
│  ├─ LeNet.py
│  ├─ AlexNet.py
│  ├─ VGG.py
│  ├─ ResNet.py
│  └─ GooLeNet.py
│
├─ trainers/
│  ├─ __init__.py
│  └─ trainer.py
│
├─ utils/
│  ├─ __init__.py
│  ├─ device.py
│  ├─ metrics.py
│  ├─ plot.py
│  ├─ results.py
│  └─ seed.py
│
└─ results/
   └─ <dataset_name>/
      └─ <model_name>/
```
The training entry point is `main.py`.
At a high level:
- choose a dataset name
- choose a model name
- build the model through the model factory
- build train/validation/test loaders through the data module
- train with the `Trainer`
- evaluate on the test set
- save results automatically
The project currently supports the following model names:

- `lenet`
- `alexnet`
- `vgg`
- `resnet`
- `googlenet`
Model creation is handled through:

```python
from models import get_model
```

Example:

```python
model = get_model("lenet", "fashionmnist")
model = get_model("resnet", "cifar10")
```

The project currently supports the following dataset names:

- `fashionmnist`
- `cifar10`
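As a rough illustration, a string-keyed factory like `get_model` can be sketched as follows. This is an assumption-laden sketch, not the project's actual code: the real factory in `models/__init__.py` dispatches to several architectures, and the LeNet layer layout below is the classic textbook one.

```python
# Hypothetical sketch of a string-keyed model factory; the real
# get_model() in models/__init__.py may differ in details.
import torch
import torch.nn as nn

def _lenet(in_channels: int, num_classes: int = 10) -> nn.Module:
    # Classic LeNet-5 layout for 32x32 inputs.
    return nn.Sequential(
        nn.Conv2d(in_channels, 6, kernel_size=5), nn.Sigmoid(),
        nn.AvgPool2d(kernel_size=2),
        nn.Conv2d(6, 16, kernel_size=5), nn.Sigmoid(),
        nn.AvgPool2d(kernel_size=2),
        nn.Flatten(),
        nn.Linear(16 * 5 * 5, 120), nn.Sigmoid(),
        nn.Linear(120, 84), nn.Sigmoid(),
        nn.Linear(84, num_classes),
    )

def get_model(model_name: str, dataset_name: str) -> nn.Module:
    # In this project LeNet always sees 1-channel 32x32 input,
    # regardless of the dataset's native format.
    if model_name == "lenet":
        return _lenet(in_channels=1)
    raise ValueError(f"unknown model: {model_name}")
```

Dispatching on plain strings keeps `main.py` free of per-architecture imports.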
Dataset loading is handled in `data/dataloader.py`.
The project automatically chooses the correct preprocessing based on both `dataset_name` and `model_name`.
- **LeNet**
  - uses `32x32` inputs
  - uses `1` input channel
- **AlexNet**, **VGG**, **ResNet**, **GoogLeNet**
  - use `224x224` inputs
  - use `3` input channels
- **FashionMNIST** (grayscale dataset)
  - stays `1` channel for LeNet
  - is converted to `3` channels for larger CNN models
- **CIFAR10** (RGB dataset)
  - stays `3` channels for larger CNN models
  - is converted to `1` channel when used with LeNet
This keeps the training entry simple while still matching each model’s expected input format.
Create and activate a Python environment, then install the required packages.
Typical dependencies used by this project include:
- `torch`
- `torchvision`
- `pandas`
- `matplotlib`
- `torchinfo`
- `torchsummary`
Example:

```bash
pip install torch torchvision pandas matplotlib torchinfo torchsummary
```

Open `main.py` and choose the dataset and model:
```python
dataset_name = "fashionmnist"
model_name = "lenet"
```

Then run:

```bash
python main.py
```

FashionMNIST + LeNet:

```python
dataset_name = "fashionmnist"
model_name = "lenet"
```

FashionMNIST + ResNet:

```python
dataset_name = "fashionmnist"
model_name = "resnet"
```

CIFAR10 + AlexNet:

```python
dataset_name = "cifar10"
model_name = "alexnet"
```

CIFAR10 + VGG:

```python
dataset_name = "cifar10"
model_name = "vgg"
```

The current pipeline is:
- build the model with `get_model(...)`
- build train/validation loaders with `train_val_data_process(...)`
- train for a fixed number of epochs
- evaluate with `test_model_process(...)`
- save weights and training results automatically
This pipeline is intentionally simple and easy to follow.
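To make the loop concrete, here is a compact, self-contained sketch of the kind of train/validate/keep-best loop a `Trainer` like this implements. The function signature and names are illustrative assumptions, not the project's actual code:

```python
# Minimal sketch of a train/validate loop that keeps the best weights
# by validation accuracy; the project's trainers/trainer.py may differ.
import copy
import torch
import torch.nn as nn
from torch.utils.data import DataLoader, TensorDataset

def train(model, train_loader, val_loader, epochs=2, lr=1e-2):
    opt = torch.optim.Adam(model.parameters(), lr=lr)
    loss_fn = nn.CrossEntropyLoss()
    best_acc, best_state, history = 0.0, None, []
    for epoch in range(epochs):
        model.train()
        for xb, yb in train_loader:
            opt.zero_grad()
            loss_fn(model(xb), yb).backward()
            opt.step()
        # Validation pass: compute accuracy and remember the best weights.
        model.eval()
        correct = total = 0
        with torch.no_grad():
            for xb, yb in val_loader:
                correct += (model(xb).argmax(dim=1) == yb).sum().item()
                total += yb.numel()
        acc = correct / total
        history.append({"epoch": epoch, "val_acc": acc})
        if acc >= best_acc:
            best_acc, best_state = acc, copy.deepcopy(model.state_dict())
    model.load_state_dict(best_state)  # restore the best checkpoint
    return model, history
```

The per-epoch `history` list is what later gets written to `history.csv`.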
Training results are saved automatically under:
```
results/<dataset_name>/<model_name>/
```

Example:

```
results/fashionmnist/lenet/
```
Saved files include:
- `best_model.pth`: best model weights from validation performance
- `history.csv`: per-epoch training and validation history
- `training_curves.png`: loss and accuracy curves
- `summary.json`: lightweight summary including dataset name, model name, and test accuracy
A root-level `best_model.pth` is also saved for compatibility with earlier workflows.
```
results/fashionmnist/lenet/
├─ best_model.pth
├─ history.csv
├─ training_curves.png
└─ summary.json
```
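The saving step can be sketched roughly like this (the function name and exact layout are assumptions; the project's `utils/results.py` may organize it differently):

```python
# Illustrative result saving: per-epoch history as CSV plus a small
# JSON summary, under results/<dataset_name>/<model_name>/.
import csv
import json
from pathlib import Path

def save_results(root, dataset_name, model_name, history, test_acc):
    out = Path(root) / dataset_name / model_name
    out.mkdir(parents=True, exist_ok=True)
    # Per-epoch training/validation history as CSV.
    with open(out / "history.csv", "w", newline="") as f:
        writer = csv.DictWriter(f, fieldnames=list(history[0].keys()))
        writer.writeheader()
        writer.writerows(history)
    # Lightweight run summary as JSON.
    with open(out / "summary.json", "w") as f:
        json.dump({"dataset": dataset_name, "model": model_name,
                   "test_acc": test_acc}, f, indent=2)
    return out
```

Keeping one directory per dataset/model pair means reruns with a different model never overwrite earlier results.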
- `models/`: model definitions and the model factory
- `data/`: dataset selection, transform selection, and data loader creation
- `trainers/`: the training loop and training-history logic
- `eval/`: test-set evaluation
- `utils/`: small helper utilities such as:
  - device selection
  - seed setting
  - accuracy calculation
  - results saving
  - plotting helpers
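In the same spirit, minimal versions of such helpers might look like this (function names are illustrative, not necessarily those in `utils/`):

```python
# Sketches of typical utils/ helpers: seeding, device selection,
# and batch accuracy. Names here are assumptions for illustration.
import random
import torch

def set_seed(seed: int) -> None:
    # Seed Python's RNG and PyTorch's (CPU and, if present, GPU).
    random.seed(seed)
    torch.manual_seed(seed)

def get_device() -> torch.device:
    # Prefer CUDA when available, otherwise fall back to CPU.
    return torch.device("cuda" if torch.cuda.is_available() else "cpu")

def accuracy(logits: torch.Tensor, targets: torch.Tensor) -> float:
    # Fraction of samples whose argmax class matches the target.
    return (logits.argmax(dim=1) == targets).float().mean().item()
```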
A few design choices are intentionally simple:
- configuration is currently done directly in `main.py`
- training settings are easy to find and edit
- dataset and model switching uses plain strings
- the code favors readability over abstraction
If you are learning PyTorch, this makes it easier to trace how data and models move through the project.
This project does not aim to be a full training framework. For example, it does not currently include:
- command-line argument parsing
- advanced experiment tracking
- distributed training
- hyperparameter search
- mixed precision training
- checkpoint resume logic
That is intentional. The project focuses on being clear, modular, and easy to extend.
Possible next steps include:
- command-line configuration for dataset and model selection
- improved experiment summaries
- per-dataset/per-model comparison tables
- more datasets and more architectures
- better visualization utilities
Add a license here if you plan to publish the repository publicly.