-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathdata_loader.py
More file actions
28 lines (20 loc) · 1004 Bytes
/
data_loader.py
File metadata and controls
28 lines (20 loc) · 1004 Bytes
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
import os
import torch
from torch.utils.data import DataLoader
from torchvision import datasets
from torchvision import transforms
def get_emoji_loader(emoji_type, opts, data_root='./emojis'):
    """Create training and test data loaders for one emoji type.

    Images are read with ``torchvision.datasets.ImageFolder`` from
    ``<data_root>/<emoji_type>`` (training) and
    ``<data_root>/Test_<emoji_type>`` (test). Each image is resized,
    converted to a tensor and normalized per channel with mean 0.5 and
    std 0.5, which maps [0, 1] pixel values to [-1, 1].

    Args:
        emoji_type: Name of the sub-directory holding this emoji type.
        opts: Options object. Must expose ``image_size``, ``batch_size``
            and ``num_workers`` attributes.
        data_root: Root directory containing the emoji folders. Defaults
            to ``'./emojis'``, the previously hard-coded location.

    Returns:
        A ``(train_dloader, test_dloader)`` tuple of ``DataLoader``
        objects. The training loader shuffles; the test loader does not.
    """
    transform = transforms.Compose([
        # ``Resize`` replaces the deprecated/removed ``Scale``. With an int
        # it scales the shorter edge to ``opts.image_size``.
        # NOTE(review): non-square images would yield differently shaped
        # tensors and break batching; presumably all emojis are square.
        transforms.Resize(opts.image_size),
        transforms.ToTensor(),
        # Three channel entries: one mean/std per RGB channel.
        transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
    ])

    train_path = os.path.join(data_root, emoji_type)
    test_path = os.path.join(data_root, 'Test_{}'.format(emoji_type))

    train_dataset = datasets.ImageFolder(train_path, transform)
    test_dataset = datasets.ImageFolder(test_path, transform)

    train_dloader = DataLoader(
        dataset=train_dataset,
        batch_size=opts.batch_size,
        shuffle=True,
        num_workers=opts.num_workers,
    )
    test_dloader = DataLoader(
        dataset=test_dataset,
        batch_size=opts.batch_size,
        shuffle=False,
        num_workers=opts.num_workers,
    )
    return train_dloader, test_dloader