-
Notifications
You must be signed in to change notification settings - Fork 2
Expand file tree
/
Copy pathtrain.py
More file actions
116 lines (82 loc) · 3.84 KB
/
train.py
File metadata and controls
116 lines (82 loc) · 3.84 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
import os
import csv
import torch
import torch.nn.functional as F
import torch.distributed as dist
from tqdm import tqdm
from utils import create_mask, adjust_learning_rate, masked_cross_entropy_loss
def save_checkpoint(model, epoch, args, local_rank):
    """Persist the model weights for one epoch under ``args.save_ckpt_path``.

    In distributed mode only rank 0 writes, and the DDP wrapper is unwrapped
    via ``model.module`` so the checkpoint holds the plain module's weights.
    In single-process mode the model's state dict is saved directly.
    """
    ckpt_file = os.path.join(args.save_ckpt_path, f'train_ckpt_{epoch}.tar')
    if not args.is_distributed:
        torch.save({'state_dict': model.state_dict()}, ckpt_file)
    elif args.rank == 0:
        # Non-zero ranks intentionally skip saving to avoid duplicate writes.
        torch.save({'state_dict': model.module.state_dict()}, ckpt_file)
def log_results(epoch, train_loss, valid_loss, args, fieldnames):
    """Append one epoch's train/valid losses as a row to the CSV log file.

    The file at ``args.save_csv_file_path`` is opened in append mode, so
    repeated calls accumulate one row per epoch.
    """
    row = {
        'epoch': epoch,
        'train_loss': train_loss,
        'valid_loss': valid_loss,
    }
    with open(args.save_csv_file_path, 'a', newline='') as fh:
        csv.DictWriter(fh, fieldnames=fieldnames).writerow(row)
def train_one_epoch(epoch, model, optimizer, train_loader, local_rank, args):
    """Run one training epoch and return the mean tracked loss per batch.

    Fix: ``F.smooth_l1_loss(..., size_average=True)`` uses a long-deprecated
    argument; ``reduction='mean'`` is the exact modern equivalent.

    Args:
        epoch: current epoch index (used for LR schedule and sampler shuffling).
        model: network returning three disparity predictions plus a class head.
        optimizer: optimizer whose LR is adjusted via ``adjust_learning_rate``.
        train_loader: yields (left, right, disp, cls) batches.
        local_rank: device the tensors are moved to.
        args: expects ``is_distributed``, ``maxdisp``, ``mindisp``, ``rank``
            and (when distributed) ``world_size``.
            NOTE(review): ``args.rank`` is also read in the non-distributed
            path — confirm it is always set by the caller.

    Returns:
        Accumulated loss divided by the number of batches. On non-zero ranks
        nothing is accumulated, so the return value is 0.0 there.
    """
    model.train()
    total_train_loss = 0.0
    adjust_learning_rate(optimizer, epoch)
    if args.is_distributed:
        # Reshuffle the distributed sampler so every epoch sees a new ordering.
        train_loader.sampler.set_epoch(epoch)
    for batch_idx, (left, right, disp, cls) in tqdm(enumerate(train_loader), total=len(train_loader)):
        left, right, disp = [
            tensor.to(local_rank).float() for tensor in
            [left, right, disp]
        ]
        cls = cls.to(local_rank).long()
        optimizer.zero_grad()
        pred_disp1, pred_disp2, pred_disp3, pred_cls = model(left, right)
        # Valid-pixel mask restricted to [mindisp, maxdisp]; losses below are
        # computed on masked pixels only.
        disp, mask = create_mask(disp, args.maxdisp, args.mindisp)
        # Multi-scale smooth-L1 disparity loss with increasing weight toward
        # the final (full-resolution) prediction.
        loss1 = 0.5*F.smooth_l1_loss(pred_disp1[mask], disp[mask], reduction='mean') + \
                0.7*F.smooth_l1_loss(pred_disp2[mask], disp[mask], reduction='mean') + \
                F.smooth_l1_loss(pred_disp3[mask], disp[mask], reduction='mean')
        loss2 = masked_cross_entropy_loss(pred_cls, cls)
        loss = 0.15*loss1 + loss2
        if args.is_distributed:
            # Average the scalar loss across ranks before backward.
            dist.all_reduce(loss, op=dist.ReduceOp.SUM)
            loss = loss / args.world_size
        loss.backward()
        optimizer.step()
        if args.rank == 0:
            # .item() detaches and converts to a plain float in one step,
            # avoiding the detach/cpu/numpy round-trip.
            total_train_loss += loss.item()
    return total_train_loss / len(train_loader)
def validate_one_epoch(epoch, model, valid_loader, local_rank, args):
    """Run one validation epoch (no gradients) and return the mean tracked loss.

    Fix: ``F.smooth_l1_loss(..., size_average=True)`` uses a long-deprecated
    argument; ``reduction='mean'`` is the exact modern equivalent.

    Args:
        epoch: current epoch index (used only to reseed the distributed sampler).
        model: network returning three disparity predictions plus a class head.
        valid_loader: yields (left, right, disp, cls) batches.
        local_rank: device the tensors are moved to.
        args: expects ``is_distributed``, ``maxdisp``, ``mindisp``, ``rank``
            and (when distributed) ``world_size``.
            NOTE(review): ``args.rank`` is also read in the non-distributed
            path — confirm it is always set by the caller.

    Returns:
        Accumulated loss divided by the number of batches. On non-zero ranks
        nothing is accumulated, so the return value is 0.0 there.
    """
    model.eval()
    total_valid_loss = 0.0
    if args.is_distributed:
        valid_loader.sampler.set_epoch(epoch)
    # Disable autograd for the whole validation pass.
    with torch.no_grad():
        for batch_idx, (left, right, disp, cls) in tqdm(enumerate(valid_loader), total=len(valid_loader)):
            left, right, disp = [
                tensor.to(local_rank).float() for tensor in
                [left, right, disp]
            ]
            cls = cls.to(local_rank).long()
            pred_disp1, pred_disp2, pred_disp3, pred_cls = model(left, right)
            # Valid-pixel mask restricted to [mindisp, maxdisp]; losses below
            # are computed on masked pixels only.
            disp, mask = create_mask(disp, args.maxdisp, args.mindisp)
            # Same weighted multi-scale loss as training, for comparable curves.
            loss1 = 0.5*F.smooth_l1_loss(pred_disp1[mask], disp[mask], reduction='mean') + \
                    0.7*F.smooth_l1_loss(pred_disp2[mask], disp[mask], reduction='mean') + \
                    F.smooth_l1_loss(pred_disp3[mask], disp[mask], reduction='mean')
            loss2 = masked_cross_entropy_loss(pred_cls, cls)
            loss = 0.15*loss1 + loss2
            if args.is_distributed:
                # Average the scalar loss across ranks for consistent reporting.
                dist.all_reduce(loss, op=dist.ReduceOp.SUM)
                loss = loss / args.world_size
            if args.rank == 0:
                # .item() converts the scalar tensor to a plain float directly.
                total_valid_loss += loss.item()
    return total_valid_loss / len(valid_loader)