-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy path: train.py
More file actions
97 lines (73 loc) · 2.82 KB
/
train.py
File metadata and controls
97 lines (73 loc) · 2.82 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
import torch
import torch.backends.cudnn as cudnn
import logging
import timeit
import os.path as osp
from cfgs.default_conf import cfg, load_cfg_fom_args
from common.utils import progress_bar, set_seed, get_lr
from datasets.utils import build_dataloader
from models.utils import build_model
from algorithms.utils import build_algorithm
# ---- configuration, logging, data, and model setup ----------------------
logger = logging.getLogger(__name__)
description = 'TRAI lab code demo'
load_cfg_fom_args(description)

# Seed all RNGs for reproducible training runs.
seed = set_seed(cfg.TRAIN.SEED)

# Construct the train/test dataloaders from the config.
train_loader, test_loader = build_dataloader(cfg=cfg)

# Build the model and move it onto the GPU.
model = build_model(cfg=cfg)
model = model.cuda()

# Wrap the model in DataParallel whenever multiple GPUs are visible.
if torch.cuda.device_count() > 1:
    print("Let's use", torch.cuda.device_count(), "GPUs!")
    model = torch.nn.DataParallel(model)
# Let cuDNN autotune conv kernels (fastest for fixed input shapes).
cudnn.benchmark = True
# Optionally resume from a pretrained checkpoint.
if cfg.MODEL.CKPT_DIR is not None:
    print('==> Resuming from checkpoint..')
    checkpoint = torch.load(cfg.MODEL.CKPT_DIR)
    state_dict = checkpoint['state_dict']
    # DataParallel stores parameters under a "module." prefix. The original
    # code added the prefix unconditionally, which breaks single-GPU resume
    # (model is not wrapped) and double-prefixes checkpoints that were saved
    # from an already-wrapped model. Only add the prefix when the live model
    # is wrapped AND the checkpoint keys do not already carry it.
    if isinstance(model, torch.nn.DataParallel):
        state_dict = {
            (k if k.startswith('module.') else 'module.' + k): v
            for k, v in state_dict.items()
        }
    model.load_state_dict(state_dict)
# Build the training algorithm (owns the optimizer and train/test loops).
trainer = build_algorithm(
    cfg = cfg,
    model = model,
)

best_acc = 0.0   # best eval accuracy observed so far
best_epoch = 0   # 0-based epoch index that produced best_acc
for epoch in range(cfg.TRAIN.EPOCH):
    # Release cached GPU memory between epochs.
    torch.cuda.empty_cache()
    start_time = timeit.default_timer()
    loss = trainer.train(train_loader)
    # trainer.train accumulates loss over batches; average it per batch.
    loss /= len(train_loader)
    stop_time = timeit.default_timer()
    logger.info(f"epoch: {epoch}, execution time: {(stop_time - start_time):.2f}s, lr: {get_lr(trainer.optimizer):.5f}, loss: {loss:.4f}")
    # Evaluate every EVAL_INTERVAL epochs and always on the final epoch.
    if (epoch + 1) % cfg.TRAIN.EVAL_INTERVAL == 0 or (epoch + 1) == cfg.TRAIN.EPOCH:
        test_loss, acc = trainer.test(test_loader)
        # .item() is the supported way to extract a Python scalar; the
        # .data attribute is deprecated and silently detaches from autograd.
        logger.info(f"Eval! loss: {test_loss.item():.4f}, accuracy rate: {acc:.2f}")
        # Optionally checkpoint the model after each evaluation.
        if cfg.TEST.SAVE_MODEL:
            state = {
                'state_dict': model.state_dict(),
                'acc': acc,
                'epoch': (epoch + 1),
            }
            torch.save(state, osp.join(cfg.OUT_DIR, '%s_%s_checkpoint_%s.pth' %
                       (cfg.DATA.NAME, cfg.TRAIN.TRAINER, str(epoch + 1))))
        # Track the best-performing epoch for the final summary.
        if acc > best_acc:
            best_acc = acc
            best_epoch = epoch
        # After the last epoch, log both the best and the final results.
        if (epoch + 1) == cfg.TRAIN.EPOCH:
            logger.info(f"Best epoch --> epoch: {best_epoch+1}, accuracy rate: {best_acc:.2f}")
            logger.info(f"Last epoch --> epoch: {epoch+1}, accuracy rate: {acc:.2f}")