-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathtest.py
More file actions
72 lines (56 loc) · 1.99 KB
/
test.py
File metadata and controls
72 lines (56 loc) · 1.99 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
import torch
import torch.backends.cudnn as cudnn
import logging
import os
import numpy as np
from datetime import datetime
from cfgs.default_conf import cfg, load_cfg_fom_args
from common.utils import progress_bar, set_seed
from datasets.utils import build_dataloader
from models.utils import build_model
from algorithms.utils import build_algorithm
# ---------------------------------------------------------------------------
# Evaluation entry point: parse the config, restore a pretrained checkpoint,
# and report test loss / accuracy through the configured algorithm wrapper.
# ---------------------------------------------------------------------------
# get config file and logger
logger = logging.getLogger(__name__)
description = 'TRAI lab code demo'
load_cfg_fom_args(description)

# Fix RNG seeds so the evaluation run is reproducible.
seed = set_seed(cfg.TRAIN.SEED)

# Build dataset and dataloader. train_loader is unused in this eval-only
# script but build_dataloader returns both; kept for interface parity.
train_loader, test_loader = build_dataloader(cfg=cfg)

# Load model and send to device (CUDA assumed available — this script has no
# CPU fallback; TODO confirm that is intentional).
model = build_model(cfg=cfg)
model = model.cuda()

# Wrap in DataParallel when several GPUs are visible.
if torch.cuda.device_count() > 1:
    print("Let's use", torch.cuda.device_count(), "GPUs!")
    model = torch.nn.DataParallel(model)
# Enable cuDNN autotuning unconditionally: the original set this only inside
# the multi-GPU branch, silently skipping the speedup on single-GPU runs.
cudnn.benchmark = True

# Restore the pretrained checkpoint; evaluation without one is an error.
if cfg.MODEL.CKPT_DIR is not None:
    print('==> Resuming from checkpoint..')
    checkpoint = torch.load(cfg.MODEL.CKPT_DIR)
    state_dict = checkpoint['state_dict']
    # Reconcile the optional "module." prefix DataParallel adds to parameter
    # names, so checkpoints load whether or not they (or the current model)
    # were saved wrapped. The original unconditionally loaded raw keys and
    # raised a key-mismatch whenever the model was wrapped above.
    wants_prefix = isinstance(model, torch.nn.DataParallel)
    has_prefix = next(iter(state_dict)).startswith('module.')
    if wants_prefix and not has_prefix:
        state_dict = {'module.' + k: v for k, v in state_dict.items()}
    elif has_prefix and not wants_prefix:
        state_dict = {k[len('module.'):]: v for k, v in state_dict.items()}
    model.load_state_dict(state_dict)
else:
    raise ValueError("You need a pretrain model")

# Build the evaluation algorithm wrapper around the model.
trainer = build_algorithm(
    cfg = cfg,
    model = model,
)

# Run the test pass and log the result. `.item()` replaces the deprecated
# `.data.item()` tensor access; lazy %-args replace the eager f-string.
test_loss, acc = trainer.test(test_loader)
logger.info("Eval! loss: %.4f, accuracy rate: %.2f", test_loss.item(), acc)