-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathtrain.py
More file actions
101 lines (83 loc) · 3.56 KB
/
train.py
File metadata and controls
101 lines (83 loc) · 3.56 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
"""
Main script to train a PyTorch model for image classification.
This script allows for running experiments with different models, datasets,
and hyperparameters via command-line arguments. It leverages the modular
components from the `food_vision` package.
"""
import argparse
import os
import torch
import torchvision
from food_vision import data_setup, engine, model_builder, utils
def main(args: argparse.Namespace) -> None:
"""
Sets up and runs a single training experiment.
Args:
args: Command-line arguments parsed by argparse.
"""
# ---- Setup ----
# Set up device-agnostic code, including Apple Silicon (MPS)
if torch.cuda.is_available():
device = "cuda"
elif torch.backends.mps.is_available():
device = "mps"
else:
device = "cpu"
print(f"[INFO] Using device: {device}")
# ---- Download Data ----
data_path = utils.download_and_unzip_data(
source_url=args.data_url,
destination_name=args.data_name,
)
train_dir = data_path / "train"
# The test set is consistent across experiments (10% version)
test_dir = data_path / "test"
# ---- Create DataLoaders ----
# Get the appropriate transforms for the selected model
weights = getattr(torchvision.models, f"EfficientNet_{args.model[6:].upper()}_Weights").DEFAULT
auto_transforms = weights.transforms()
train_dataloader, test_dataloader, class_names = data_setup.create_dataloaders(
train_dir=train_dir,
test_dir=test_dir,
transform=auto_transforms,
batch_size=args.batch_size,
)
# ---- Create Model ----
model = model_builder.EfficientNet(
model_name=args.model,
num_classes=len(class_names),
).to(device)
# ---- Setup Loss, Optimizer, and TensorBoard Writer ----
loss_fn = torch.nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=args.learning_rate)
writer = utils.create_writer(
experiment_name=args.data_name,
model_name=args.model,
extra=f"{args.epochs}_epochs",
)
# ---- Start Training ----
print(f"[INFO] Starting training for {args.model} on {args.data_name} for {args.epochs} epochs.")
engine.train(
model=model,
train_dataloader=train_dataloader,
test_dataloader=test_dataloader,
loss_fn=loss_fn,
optimizer=optimizer,
epochs=args.epochs,
device=device,
writer=writer,
)
# ---- Save the Trained Model ----
save_filepath = f"{args.model}_{args.data_name}_{args.epochs}_epochs.pth"
utils.save_model(model=model, target_dir="models", model_name=save_filepath)
print("-" * 50 + "\n")
if __name__ == "__main__":
parser = argparse.ArgumentParser(description="Train an EfficientNet model on custom data.")
parser.add_argument("--model", type=str, default="effnetb0", choices=["effnetb0", "effnetb2", "effnetb4"], help="Model architecture to use.")
parser.add_argument("--epochs", type=int, default=5, help="Number of epochs to train.")
parser.add_argument("--batch_size", type=int, default=32, help="Batch size for DataLoaders.")
parser.add_argument("--learning_rate", type=float, default=0.001, help="Learning rate for the optimizer.")
parser.add_argument("--data_name", type=str, default="pizza_steak_sushi_10_percent", help="Name of the data directory.")
parser.add_argument("--data_url", type=str, default="https://github.com/GoJo-Rika/PyTorch-FoodVision-Mini/raw/main/data/pizza_steak_sushi.zip", help="URL to download the data from.")
args = parser.parse_args()
main(args)