-
Notifications
You must be signed in to change notification settings - Fork 1
Expand file tree
/
Copy pathdata.py
More file actions
114 lines (89 loc) · 4.54 KB
/
data.py
File metadata and controls
114 lines (89 loc) · 4.54 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
import os
import cv2
import torch
import torch.distributed as dist
import numpy as np
from torch.utils.data import Dataset
class DataLoader:
    """Enumerates (left, right, disparity) image-path lists for stereo datasets."""

    @staticmethod
    def load(datapath: str, dataset: str) -> tuple:
        """Build training and validation path lists for a stereo dataset.

        Args:
            datapath: Root directory containing ``left`` and ``valid_left`` subdirs.
            dataset: ``'US3D'`` or ``'Gaofen7'``; selects the filename
                substitutions used to derive right/disparity paths from the
                left-image paths.

        Returns:
            ``(train_data, valid_data)``, each a tuple of
            ``(left_paths, right_paths, disp_paths)`` lists.

        Raises:
            ValueError: If ``dataset`` is not a supported dataset name.
        """
        def _get_image_paths(subdir):
            base_path = os.path.join(datapath, subdir)
            # sorted() gives a deterministic ordering; os.listdir returns
            # entries in arbitrary, filesystem-dependent order.
            return sorted(os.path.join(base_path, img) for img in os.listdir(base_path))

        # NOTE: str.replace acts on the *whole* path, so the containing
        # directory name ("left" -> "right"/"disp...") is rewritten too.
        if dataset == 'US3D':
            def _to_right(p):
                return p.replace("LEFT_RGB", "RIGHT_RGB").replace("left", "right")

            def _to_disp(p):
                return p.replace("LEFT_RGB", "LEFT_DSP").replace("left", "disp")
        elif dataset == 'Gaofen7':
            def _to_right(p):
                return p.replace("left", "right")

            def _to_disp(p):
                return p.replace("left", "disparity")
        else:
            # An unknown name previously fell through both branches and
            # crashed with UnboundLocalError; fail with a clear message.
            raise ValueError(
                f"Unsupported dataset: {dataset!r} (expected 'US3D' or 'Gaofen7')")

        left_train = _get_image_paths("left")
        right_train = [_to_right(p) for p in left_train]
        disp_train = [_to_disp(p) for p in left_train]
        left_valid = _get_image_paths("valid_left")
        right_valid = [_to_right(p) for p in left_valid]
        disp_valid = [_to_disp(p) for p in left_valid]
        train_data = (left_train, right_train, disp_train)
        valid_data = (left_valid, right_valid, disp_valid)
        return train_data, valid_data
class StereoDataset(Dataset):
    """Torch dataset yielding (left, right, disparity) arrays per sample.

    Images are read lazily per item with OpenCV. Color images become CHW
    arrays scaled to [-1, 1]; grayscale images are standardized; disparity
    maps are returned as a multi-scale pyramid plus the full-resolution map.
    """

    def __init__(self, left_images, right_images, disp_images, training=True):
        # Parallel lists of file paths: index i across all three lists
        # addresses a single stereo sample.
        self.left = left_images
        self.right = right_images
        self.disp = disp_images
        # Stored but never read inside this class; presumably consumed by
        # training code elsewhere — TODO confirm.
        self.training = training

    def __len__(self):
        return len(self.left)

    def __getitem__(self, index):
        left = self._read_image(self.left[index])
        right = self._read_image(self.right[index])
        # For single-channel disparity files this is a (16x, 8x, 4x, full)
        # pyramid tuple rather than a single array.
        disp = self._read_image(self.disp[index], is_disp=True)
        return left, right, disp

    def _read_image(self, path, is_disp=False):
        """Read one image from `path`; normalization depends on channel
        count and on `is_disp` (see class docstring)."""
        img = cv2.imread(path, cv2.IMREAD_UNCHANGED)
        if img is None:
            # cv2.imread silently returns None for missing/corrupt files;
            # raise a clear error instead of a cryptic AttributeError below.
            raise FileNotFoundError(f"Could not read image: {path}")
        img = img.astype('float32')
        if len(img.shape) == 3:
            # Color image: HWC -> CHW, scaled from [0, 255] to [-1, 1].
            img = np.moveaxis(img, -1, 0) / 127.5 - 1.0
            return img
        if is_disp:
            # Downsampled disparity pyramid; values are divided by the
            # downscale factor so disparities stay in resized-pixel units.
            # NOTE(review): the fixed 64/128/256 targets imply 1024x1024
            # source tiles — confirm against the dataset.
            disp_16x = cv2.resize(img, (64, 64)) / 16.0
            disp_8x = cv2.resize(img, (128, 128)) / 8.0
            disp_4x = cv2.resize(img, (256, 256)) / 4.0
            return disp_16x, disp_8x, disp_4x, img
        # Grayscale non-disparity image: add channel axis and standardize
        # to zero mean / unit variance.
        img = np.expand_dims(img, axis=0)
        img = (img - np.mean(img)) / np.std(img)
        return img
def generate(dataset, datapath):
    """Create the (train_dataset, valid_dataset) StereoDataset pair.

    Args:
        dataset: Dataset name, ``'US3D'`` or ``'Gaofen7'``.
        datapath: Dataset root; its final path component must equal `dataset`.

    Returns:
        ``(train_dataset, valid_dataset)`` StereoDataset instances.

    Raises:
        ValueError: If `datapath` does not end in `dataset`, or `dataset`
            is unsupported.
    """
    # Validate with real exceptions: `assert` is stripped under `python -O`,
    # and the original silently returned None for unsupported datasets.
    if os.path.basename(datapath) != dataset:
        raise ValueError(
            f"datapath {datapath!r} must end in the dataset name {dataset!r}")
    if dataset not in ("US3D", "Gaofen7"):
        raise ValueError(f"Unsupported dataset: {dataset!r}")
    train_data, valid_data = DataLoader.load(datapath, dataset)
    train_dataset = StereoDataset(
        left_images=train_data[0],
        right_images=train_data[1],
        disp_images=train_data[2],
        training=True,
    )
    valid_dataset = StereoDataset(
        left_images=valid_data[0],
        right_images=valid_data[1],
        disp_images=valid_data[2],
        training=False,
    )
    return train_dataset, valid_dataset
def initialize_dataloaders(args, train_dataset, valid_dataset):
    """Build train/valid torch DataLoaders, distributed-aware.

    Args:
        args: Namespace with ``is_distributed``, ``batch_size``, ``num_workers``.
        train_dataset: Dataset for training batches.
        valid_dataset: Dataset for validation batches.

    Returns:
        ``(train_loader, valid_loader)``.
    """
    if args.is_distributed:
        world_size = dist.get_world_size()
        rank = dist.get_rank()
        train_sampler = torch.utils.data.DistributedSampler(
            train_dataset, num_replicas=world_size, rank=rank)
        # shuffle=False: validation must be deterministic, matching the
        # shuffle=False in the non-distributed branch below
        # (DistributedSampler shuffles by default).
        valid_sampler = torch.utils.data.DistributedSampler(
            valid_dataset, num_replicas=world_size, rank=rank, shuffle=False)
        train_loader = torch.utils.data.DataLoader(
            train_dataset, batch_size=args.batch_size, num_workers=args.num_workers,
            sampler=train_sampler, pin_memory=True)
        valid_loader = torch.utils.data.DataLoader(
            valid_dataset, batch_size=args.batch_size, num_workers=args.num_workers,
            sampler=valid_sampler, pin_memory=True)
    else:
        train_loader = torch.utils.data.DataLoader(
            train_dataset, batch_size=args.batch_size,
            shuffle=True, num_workers=args.num_workers, drop_last=False)
        valid_loader = torch.utils.data.DataLoader(
            valid_dataset, batch_size=args.batch_size,
            shuffle=False, num_workers=args.num_workers, drop_last=False)
    return train_loader, valid_loader