This project trains three image classifiers: a supervised decision tree, a semi-supervised model, and a Convolutional Neural Network (CNN). The models are built with scikit-learn and PyTorch and are trained, validated, and tested on an image dataset. The code covers data preprocessing, model training, validation, testing, and visualization of results.
The following Python libraries are required to run the code:
- torch
- torchvision
- numpy
- matplotlib
- scikit-learn

You can install these libraries using pip:
pip install torch torchvision numpy matplotlib scikit-learn
The dataset should be structured in a way that datasets.ImageFolder can read it. This means the dataset directory should contain one subdirectory per class, and each subdirectory should contain the images for that class.
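As a quick sanity check on the layout described above, the sketch below counts the images in each class subdirectory using only the standard library. The helper names, the class names (`cats`, `dogs`), and the list of file extensions are illustrative assumptions, not part of the project. Like datasets.ImageFolder, it takes class names from the subdirectory names in sorted order.

```python
import tempfile
from pathlib import Path

# Assumed subset of image extensions; adjust to match your files.
IMAGE_EXTENSIONS = {".jpg", ".jpeg", ".png", ".bmp"}


def summarize_image_folder(root):
    """Return {class_name: image_count} for each class subdirectory of root."""
    summary = {}
    # One subdirectory per class, visited in sorted (alphabetical) order.
    for class_dir in sorted(p for p in Path(root).iterdir() if p.is_dir()):
        summary[class_dir.name] = sum(
            1
            for f in class_dir.rglob("*")
            if f.is_file() and f.suffix.lower() in IMAGE_EXTENSIONS
        )
    return summary


def build_demo_dataset(root):
    """Create a tiny placeholder layout with empty files (hypothetical classes)."""
    for class_name, n_images in [("cats", 3), ("dogs", 2)]:
        class_dir = Path(root) / class_name
        class_dir.mkdir(parents=True)
        for i in range(n_images):
            (class_dir / f"img_{i}.png").touch()


if __name__ == "__main__":
    with tempfile.TemporaryDirectory() as tmp:
        build_demo_dataset(tmp)
        print(summarize_image_folder(tmp))
```

A class that reports zero images usually points to a misplaced folder or an unexpected file extension.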
A small sample test dataset is available for download from this repository.
To train the model, follow these steps:
- Place your dataset in a directory (e.g., C:\dataset).
- Adjust the path variable in the code to point to your dataset directory.
- Run the code to start training. The training progress, including loss and accuracy for both training and validation sets, will be printed to the console.
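The steps above can be sketched as a minimal PyTorch training loop that prints loss and accuracy for the training and validation sets after each epoch. This is an illustration under stated assumptions, not the project's actual code: `TinyCNN`, `run_epoch`, `fit`, and `make_random_loader` are hypothetical names, the architecture and hyperparameters are placeholders, and random tensors stand in for the real images so the example runs without a dataset. In practice the loaders would wrap a datasets.ImageFolder built from your dataset path.

```python
import torch
from torch import nn
from torch.utils.data import DataLoader, TensorDataset


class TinyCNN(nn.Module):
    """Hypothetical small CNN; the project's real architecture may differ."""

    def __init__(self, num_classes):
        super().__init__()
        self.features = nn.Sequential(
            nn.Conv2d(3, 8, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.AdaptiveAvgPool2d(1),  # collapse each feature map to one value
        )
        self.classifier = nn.Linear(8, num_classes)

    def forward(self, x):
        return self.classifier(self.features(x).flatten(1))


def run_epoch(model, loader, criterion, optimizer=None):
    """One pass over loader: trains if an optimizer is given, else evaluates."""
    training = optimizer is not None
    model.train(training)
    total_loss, correct, seen = 0.0, 0, 0
    with torch.set_grad_enabled(training):
        for images, labels in loader:
            logits = model(images)
            loss = criterion(logits, labels)
            if training:
                optimizer.zero_grad()
                loss.backward()
                optimizer.step()
            total_loss += loss.item() * labels.size(0)
            correct += (logits.argmax(dim=1) == labels).sum().item()
            seen += labels.size(0)
    return total_loss / seen, correct / seen


def fit(model, train_loader, val_loader, epochs=2, lr=1e-3):
    """Train for a few epochs and return per-epoch metrics."""
    criterion = nn.CrossEntropyLoss()
    optimizer = torch.optim.Adam(model.parameters(), lr=lr)
    history = []
    for epoch in range(1, epochs + 1):
        train_loss, train_acc = run_epoch(model, train_loader, criterion, optimizer)
        val_loss, val_acc = run_epoch(model, val_loader, criterion)
        history.append({"train_loss": train_loss, "train_acc": train_acc,
                        "val_loss": val_loss, "val_acc": val_acc})
        print(f"epoch {epoch}: train loss {train_loss:.4f} acc {train_acc:.3f} | "
              f"val loss {val_loss:.4f} acc {val_acc:.3f}")
    return history


def make_random_loader(n_samples, num_classes=3):
    """Random 16x16 RGB tensors standing in for real images (assumption)."""
    images = torch.randn(n_samples, 3, 16, 16)
    labels = torch.randint(0, num_classes, (n_samples,))
    return DataLoader(TensorDataset(images, labels), batch_size=8)


if __name__ == "__main__":
    fit(TinyCNN(num_classes=3), make_random_loader(32), make_random_loader(16))
```

Sharing one `run_epoch` function between training and validation keeps the two sets of metrics computed in exactly the same way.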
After training, the model will be evaluated on a test set. The test loss and accuracy will be printed, and a confusion matrix and classification report will be generated to show the model's performance.
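A confusion matrix and classification report of this kind can be produced with scikit-learn once the true and predicted labels for the test set have been collected. The sketch below uses made-up labels and hypothetical class names purely for illustration; `summarize_predictions` is not a function from the project.

```python
import numpy as np
from sklearn.metrics import classification_report, confusion_matrix


def summarize_predictions(y_true, y_pred, class_names):
    """Return the confusion matrix and a text report for the given labels."""
    labels = list(range(len(class_names)))
    # Rows are true classes, columns are predicted classes.
    cm = confusion_matrix(y_true, y_pred, labels=labels)
    report = classification_report(
        y_true, y_pred, labels=labels, target_names=class_names, zero_division=0
    )
    return cm, report


# Made-up labels for three hypothetical classes.
y_true = np.array([0, 0, 1, 1, 2, 2])
y_pred = np.array([0, 1, 1, 1, 2, 0])
cm, report = summarize_predictions(y_true, y_pred, ["cat", "dog", "bird"])
print(cm)
print(report)
```

Off-diagonal entries in the matrix show which pairs of classes the model confuses, which the overall accuracy alone does not reveal.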
The code includes functionality to visualize:
- Training and validation loss over epochs
- Training and validation accuracy over epochs
- Confusion matrix
These visualizations help in understanding the model's performance and diagnosing potential issues.
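As an illustration of the loss and accuracy curves listed above, the following matplotlib sketch plots both side by side and saves the figure to disk. The function name, the output file name, and the sample numbers are assumptions made for the example; the project's own plotting code may differ.

```python
import matplotlib

matplotlib.use("Agg")  # headless backend so the script runs without a display
import matplotlib.pyplot as plt


def plot_history(train_loss, val_loss, train_acc, val_acc,
                 out_path="training_curves.png"):
    """Plot loss and accuracy per epoch and save the figure to out_path."""
    epochs = range(1, len(train_loss) + 1)
    fig, (ax_loss, ax_acc) = plt.subplots(1, 2, figsize=(10, 4))

    ax_loss.plot(epochs, train_loss, label="train")
    ax_loss.plot(epochs, val_loss, label="validation")
    ax_loss.set_xlabel("epoch")
    ax_loss.set_ylabel("loss")
    ax_loss.legend()

    ax_acc.plot(epochs, train_acc, label="train")
    ax_acc.plot(epochs, val_acc, label="validation")
    ax_acc.set_xlabel("epoch")
    ax_acc.set_ylabel("accuracy")
    ax_acc.legend()

    fig.tight_layout()
    fig.savefig(out_path)
    plt.close(fig)
    return out_path


if __name__ == "__main__":
    # Made-up numbers, for illustration only.
    plot_history(
        train_loss=[1.2, 0.8, 0.5], val_loss=[1.3, 0.9, 0.7],
        train_acc=[0.45, 0.70, 0.85], val_acc=[0.40, 0.65, 0.75],
    )
```

A validation loss that rises while the training loss keeps falling is a common sign of overfitting, which is the kind of issue these plots help diagnose.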
To run the project end to end:
- Ensure you have installed all dependencies.
- Download and place the sample dataset in the specified directory.
- Run the provided code to train and evaluate the model.
- Check the output for training/validation loss, accuracy, and visualizations.
Notes on the libraries used:
- The project uses scikit-learn for decision trees.
- The project uses PyTorch for building and training the CNN.
- The dataset loading and transformation are handled using torchvision.
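To make the decision tree part of these notes concrete, here is a minimal scikit-learn sketch that flattens each image into a feature vector and fits a `DecisionTreeClassifier`. It is an illustration under assumptions rather than the project's code: scikit-learn's bundled 8x8 digits dataset stands in for the real images, and the function name, tree depth, and split ratio are placeholder choices.

```python
from sklearn.datasets import load_digits
from sklearn.metrics import accuracy_score
from sklearn.model_selection import train_test_split
from sklearn.tree import DecisionTreeClassifier


def train_decision_tree(images, labels, max_depth=10, seed=0):
    """Fit a decision tree on flattened images; return the tree and test accuracy."""
    # Flatten each image into a single row of pixel features.
    features = images.reshape(len(images), -1)
    X_train, X_test, y_train, y_test = train_test_split(
        features, labels, test_size=0.25, random_state=seed, stratify=labels
    )
    tree = DecisionTreeClassifier(max_depth=max_depth, random_state=seed)
    tree.fit(X_train, y_train)
    accuracy = accuracy_score(y_test, tree.predict(X_test))
    return tree, accuracy


# Stand-in data: 8x8 grayscale digit images shipped with scikit-learn.
digits = load_digits()
tree, accuracy = train_decision_tree(digits.images, digits.target)
print(f"Test accuracy: {accuracy:.3f}")
```

Limiting `max_depth` is one simple way to keep the tree from memorizing the training images; the right value depends on the dataset.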