LeNet Project

A beginner-friendly PyTorch image classification project with a clean modular structure, multiple CNN architectures, support for multiple datasets, and automatic input adaptation.

This project started from a simple LeNet training script and was incrementally refactored into a more organized codebase that is easier to extend, experiment with, and understand.

Features

  • Modular project structure
  • Multiple CNN models:
    • LeNet
    • AlexNet
    • VGG16
    • ResNet18
    • GoogLeNet
  • Multiple dataset support:
    • FashionMNIST
    • CIFAR10
  • Automatic input adaptation based on dataset_name and model_name
  • Train / validation / test pipeline
  • Automatic result saving:
    • best model weights
    • training history CSV
    • training curves image
    • test summary JSON

Project Structure

LeNet/
├─ main.py
├─ README.md
├─ .gitignore
│
├─ data/
│  ├─ __init__.py
│  ├─ dataloader.py
│  └─ FashionMNIST/              # downloaded dataset cache
│
├─ eval/
│  ├─ __init__.py
│  └─ evaluate.py
│
├─ models/
│  ├─ __init__.py
│  ├─ LeNet.py
│  ├─ AlexNet.py
│  ├─ VGG.py
│  ├─ ResNet.py
│  └─ GoogLeNet.py
│
├─ trainers/
│  ├─ __init__.py
│  └─ trainer.py
│
├─ utils/
│  ├─ __init__.py
│  ├─ device.py
│  ├─ metrics.py
│  ├─ plot.py
│  ├─ results.py
│  └─ seed.py
│
└─ results/
   └─ <dataset_name>/
      └─ <model_name>/

How It Works

The training entry point is main.py.

At a high level:

  1. choose a dataset name
  2. choose a model name
  3. build the model through the model factory
  4. build train/validation/test loaders through the data module
  5. train with the Trainer
  6. evaluate on the test set
  7. save results automatically

Supported Models

The project currently supports:

  • lenet
  • alexnet
  • vgg
  • resnet
  • googlenet

Model creation is handled through:

from models import get_model

Example:

model = get_model("lenet", "fashionmnist")
model = get_model("resnet", "cifar10")
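A common way to implement such a factory is a string-keyed registry that maps model names to constructors and dataset names to input specs. The sketch below is a hypothetical, minimal version of that pattern (with a toy stand-in model); the real architectures and signatures live in the models/ package and may differ.

```python
import torch
from torch import nn

# Hypothetical input specs per dataset; the real project derives these
# inside the model factory.
DATASET_SPECS = {
    "fashionmnist": {"in_channels": 1, "num_classes": 10},
    "cifar10": {"in_channels": 3, "num_classes": 10},
}

class TinyLeNet(nn.Module):
    """Toy stand-in for the real LeNet definition in models/LeNet.py."""
    def __init__(self, in_channels, num_classes):
        super().__init__()
        self.features = nn.Sequential(
            nn.Conv2d(in_channels, 6, kernel_size=5), nn.ReLU(),
            nn.AdaptiveAvgPool2d(1),
        )
        self.classifier = nn.Linear(6, num_classes)

    def forward(self, x):
        return self.classifier(self.features(x).flatten(1))

MODEL_REGISTRY = {"lenet": TinyLeNet}

def get_model(model_name, dataset_name):
    spec = DATASET_SPECS[dataset_name]
    return MODEL_REGISTRY[model_name](**spec)

model = get_model("lenet", "fashionmnist")
out = model(torch.randn(2, 1, 32, 32))  # batch of 2 grayscale 32x32 images
```

The registry keeps main.py free of per-model branching: adding a new architecture means adding one entry to the dictionary.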

Supported Datasets

The project currently supports:

  • fashionmnist
  • cifar10

Dataset loading is handled in data/dataloader.py.
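A typical way to produce train/validation loaders from a single training set is torch.utils.data.random_split. The sketch below uses a toy tensor dataset in place of the torchvision downloads that data/dataloader.py performs; the split ratio and batch size are illustrative, not the project's actual values.

```python
import torch
from torch.utils.data import DataLoader, TensorDataset, random_split

# Toy stand-in for a downloaded dataset (100 grayscale 32x32 images).
full = TensorDataset(torch.randn(100, 1, 32, 32), torch.randint(0, 10, (100,)))

# 80/20 train/validation split.
n_val = int(0.2 * len(full))
train_set, val_set = random_split(full, [len(full) - n_val, n_val])

train_loader = DataLoader(train_set, batch_size=32, shuffle=True)
val_loader = DataLoader(val_set, batch_size=32)
```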

Automatic Input Adaptation

The project automatically chooses the correct preprocessing based on both:

  • dataset_name
  • model_name

Current behavior

  • LeNet
    • uses 32x32 inputs
    • uses 1 input channel
  • AlexNet, VGG, ResNet, GoogLeNet
    • use 224x224 inputs
    • use 3 input channels

Dataset-specific adaptation

  • FashionMNIST
    • grayscale dataset
    • stays 1 channel for LeNet
    • is converted to 3 channels for the larger CNN models
  • CIFAR10
    • RGB dataset
    • stays 3 channels for the larger CNN models
    • is converted to 1 channel when used with LeNet

This keeps the training entry point simple while still matching each model’s expected input format.

Installation

Create and activate a Python environment, then install the required packages.

Typical dependencies used by this project include:

  • torch
  • torchvision
  • pandas
  • matplotlib
  • torchinfo
  • torchsummary

Example:

pip install torch torchvision pandas matplotlib torchinfo torchsummary

Usage

Open main.py and choose the dataset and model:

dataset_name = "fashionmnist"
model_name = "lenet"

Then run:

python main.py

Example configurations

FashionMNIST + LeNet:

dataset_name = "fashionmnist"
model_name = "lenet"

FashionMNIST + ResNet:

dataset_name = "fashionmnist"
model_name = "resnet"

CIFAR10 + AlexNet:

dataset_name = "cifar10"
model_name = "alexnet"

CIFAR10 + VGG:

dataset_name = "cifar10"
model_name = "vgg"

Training Pipeline

The current pipeline is:

  • build model with get_model(...)
  • build train/validation loaders with train_val_data_process(...)
  • train for a fixed number of epochs
  • evaluate with test_model_process(...)
  • save weights and training results automatically

This pipeline is intentionally simple and easy to follow.
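Compressed into one runnable block, the pipeline looks roughly like the following. A toy model and random tensors stand in for get_model(...) and the real loaders, so this is a shape-level sketch rather than the project's code.

```python
import torch
from torch import nn

# Toy stand-ins for get_model(...) and train_val_data_process(...).
model = nn.Sequential(nn.Flatten(), nn.Linear(32 * 32, 10))
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
criterion = nn.CrossEntropyLoss()

x = torch.randn(64, 1, 32, 32)           # one "training batch"
y = torch.randint(0, 10, (64,))

# Train for a fixed number of epochs.
for epoch in range(3):
    optimizer.zero_grad()
    loss = criterion(model(x), y)
    loss.backward()
    optimizer.step()

# Stand-in for test_model_process(...): evaluate without gradients.
with torch.no_grad():
    acc = (model(x).argmax(1) == y).float().mean().item()
```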

Results Saving

Training results are saved automatically under:

results/<dataset_name>/<model_name>/

Example:

results/fashionmnist/lenet/

Saved files include:

  • best_model.pth
    • best model weights, selected by validation performance
  • history.csv
    • per-epoch training and validation history
  • training_curves.png
    • loss and accuracy curves
  • summary.json
    • lightweight summary including dataset name, model name, and test accuracy

A root-level best_model.pth is also saved for compatibility with the earlier workflow.
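Writing summary.json takes only a few lines with the standard json module. The helper below is hypothetical (the actual code lives in utils/results.py and its keys may differ); it simply shows the shape of the idea.

```python
import json
import os
import tempfile

def save_summary(results_dir, dataset_name, model_name, test_acc):
    """Write a lightweight run summary; keys are illustrative."""
    os.makedirs(results_dir, exist_ok=True)
    summary = {
        "dataset": dataset_name,
        "model": model_name,
        "test_accuracy": round(test_acc, 4),
    }
    path = os.path.join(results_dir, "summary.json")
    with open(path, "w") as f:
        json.dump(summary, f, indent=2)
    return path

# Write to a temporary directory instead of results/<dataset>/<model>/.
path = save_summary(tempfile.mkdtemp(), "fashionmnist", "lenet", 0.8912)
```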

Example Output Files

results/fashionmnist/lenet/
├─ best_model.pth
├─ history.csv
├─ training_curves.png
└─ summary.json

Main Modules

models/

Contains model definitions and the model factory.

data/

Contains dataset selection, transform selection, and data loader creation.

trainers/

Contains the training loop and training history logic.

eval/

Contains test-set evaluation.

utils/

Contains small helper utilities such as:

  • device selection
  • seed setting
  • accuracy calculation
  • results saving
  • plotting helpers
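As one example, a seed-setting helper in the style of utils/seed.py usually just seeds every random number generator in play. This is a sketch of the common pattern, not necessarily the project's exact implementation (which may also set cuDNN determinism flags).

```python
import random

import numpy as np
import torch

def set_seed(seed=42):
    """Seed Python, NumPy, and PyTorch RNGs for reproducible runs."""
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)  # no-op when CUDA is unavailable

# Same seed -> same random tensors.
set_seed(0)
a = torch.rand(3)
set_seed(0)
b = torch.rand(3)
```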

Beginner Notes

A few design choices are intentionally simple:

  • configuration is currently done directly in main.py
  • training settings are easy to find and edit
  • dataset and model switching uses plain strings
  • the code favors readability over abstraction

If you are learning PyTorch, this makes it easier to trace how data and models move through the project.

What This Project Does Not Try To Be

This project does not aim to be a full training framework. For example, it does not currently include:

  • command-line argument parsing
  • advanced experiment tracking
  • distributed training
  • hyperparameter search
  • mixed precision training
  • checkpoint resume logic

That is intentional. The project focuses on being clear, modular, and easy to extend.

Future Improvements

Possible next steps include:

  • command-line configuration for dataset and model selection
  • improved experiment summaries
  • per-dataset/per-model comparison tables
  • more datasets and more architectures
  • better visualization utilities

License

Add a license here if you plan to publish the repository publicly.
