-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathVariational-Autoencoder.py
More file actions
429 lines (362 loc) · 16.3 KB
/
Variational-Autoencoder.py
File metadata and controls
429 lines (362 loc) · 16.3 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
"""
Deep Probabilistic Programming 101: Variational Autoencoder (VAE)
by Sourabh Kulkarni (https://www.github.com/SourabhKul)
Following instructions from MLTrain@UAI 2018 Pyro Workshop (http://pyro.ai/examples/vae.html)
Some basics before we get started:
Problem: Given a dataset, produce more examples from that dataset
Easiest Solution: map input data space to a latent output space through a neural network
use the learnt latent space and neural network to generate data
this is a variational autoencoder!
To learn a VAE we maximize the probability that a model p(x,z) generates data p(x)
Data: MNIST
"""
import os
import numpy as np
import torch
import torchvision.datasets as dset
import torch.nn as nn
import torchvision.transforms as transforms
import pyro
import pyro.distributions as dist
from pyro.infer import SVI, Trace_ELBO
from pyro.optim import Adam
# for loading and batching MNIST dataset
# for loading and batching MNIST dataset
def setup_data_loaders(batch_size=128, use_cuda=False):
    """Create DataLoaders over the MNIST train and test splits.

    Args:
        batch_size: mini-batch size used by both loaders (default 128).
        use_cuda: if True, pin host memory so host-to-GPU copies are faster.

    Returns:
        (train_loader, test_loader): shuffled train loader, ordered test loader.
    """
    root = './data'
    download = True
    trans = transforms.ToTensor()  # maps pixels to float32 in [0, 1]
    train_set = dset.MNIST(root=root, train=True, transform=trans,
                           download=download)
    # Robustness fix: pass download= for the test split too, so this function
    # does not depend on the train-set call above having already fetched the
    # test files into ./data.
    test_set = dset.MNIST(root=root, train=False, transform=trans,
                          download=download)
    kwargs = {'num_workers': 1, 'pin_memory': use_cuda}
    train_loader = torch.utils.data.DataLoader(dataset=train_set,
                                               batch_size=batch_size, shuffle=True, **kwargs)
    test_loader = torch.utils.data.DataLoader(dataset=test_set,
                                              batch_size=batch_size, shuffle=False, **kwargs)
    return train_loader, test_loader
# decoder
# decoder
class Decoder(nn.Module):
    """Maps a latent code z to per-pixel Bernoulli means for a 28x28 image.

    Attribute names (fc1, fc21) are kept stable because Pyro registers the
    module's parameters under these names.
    """

    def __init__(self, z_dim, hidden_dim):
        super(Decoder, self).__init__()
        # two affine maps: latent -> hidden -> 784 pixel logits
        self.fc1 = nn.Linear(z_dim, hidden_dim)
        self.fc21 = nn.Linear(hidden_dim, 784)
        # non-linearities used in forward()
        self.softplus = nn.Softplus()
        self.sigmoid = nn.Sigmoid()

    def forward(self, z):
        """Return Bernoulli means of shape (batch, 784), each in (0, 1)."""
        hidden_units = self.softplus(self.fc1(z))
        pixel_probs = self.sigmoid(self.fc21(hidden_units))
        return pixel_probs
# encoder
# encoder
class Encoder(nn.Module):
    """Amortized inference network: image x -> (mean, scale) of q(z|x).

    Attribute names (fc1, fc21, fc22) are kept stable because Pyro registers
    the module's parameters under these names.
    """

    def __init__(self, z_dim, hidden_dim):
        super(Encoder, self).__init__()
        # shared trunk, then two heads: one for the mean, one for the scale
        self.fc1 = nn.Linear(784, hidden_dim)
        self.fc21 = nn.Linear(hidden_dim, z_dim)
        self.fc22 = nn.Linear(hidden_dim, z_dim)
        # non-linearity for the trunk
        self.softplus = nn.Softplus()

    def forward(self, x):
        """Return (z_loc, z_scale), each of shape (batch, z_dim); z_scale > 0."""
        # flatten whatever image shape arrives into rows of 784 pixels
        flat = x.reshape(-1, 784)
        hidden_units = self.softplus(self.fc1(flat))
        z_loc = self.fc21(hidden_units)
        # exponentiate so the scale head always yields a strictly positive value
        z_scale = torch.exp(self.fc22(hidden_units))
        return z_loc, z_scale
class VAE(nn.Module):
    """Variational autoencoder for MNIST, wired for Pyro SVI.

    `model` defines the generative process p(x|z)p(z); `guide` defines the
    variational posterior q(z|x). The two share the sample site name
    "latent", which is how Pyro's ELBO pairs them up — do not rename one
    without the other.
    """
    # by default our latent space is 50-dimensional
    # and we use 400 hidden units
    def __init__(self, z_dim=50, hidden_dim=400, use_cuda=False):
        super(VAE, self).__init__()
        # create the encoder and decoder networks
        self.encoder = Encoder(z_dim, hidden_dim)
        self.decoder = Decoder(z_dim, hidden_dim)
        if use_cuda:
            # calling cuda() here will put all the parameters of
            # the encoder and decoder networks into gpu memory
            self.cuda()
        self.use_cuda = use_cuda
        self.z_dim = z_dim
    # define the model p(x|z)p(z)
    def model(self, x):
        """Generative model: sample z ~ N(0, I), decode, score x as Bernoulli."""
        # register PyTorch module `decoder` with Pyro
        pyro.module("decoder", self.decoder)
        with pyro.plate("data", x.shape[0]):
            # setup hyperparameters for prior p(z); new_zeros/new_ones create
            # the prior tensors on the same device and dtype as x
            z_loc = x.new_zeros(torch.Size((x.shape[0], self.z_dim)))
            z_scale = x.new_ones(torch.Size((x.shape[0], self.z_dim)))
            # sample from prior (value will be sampled by guide when computing the ELBO)
            z = pyro.sample("latent", dist.Normal(z_loc, z_scale).to_event(1))
            # decode the latent code z
            loc_img = self.decoder.forward(z)
            # score against actual images; the 784 pixels are treated as one
            # event via to_event(1). NOTE(review): ToTensor yields pixels in
            # [0,1], not {0,1}; Bernoulli(obs in [0,1]) follows the Pyro
            # tutorial's relaxed usage.
            pyro.sample("obs", dist.Bernoulli(loc_img).to_event(1), obs=x.reshape(-1, 784))
    # define the guide (i.e. variational distribution) q(z|x)
    def guide(self, x):
        """Variational posterior: encode x into (loc, scale) and sample z."""
        # register PyTorch module `encoder` with Pyro
        pyro.module("encoder", self.encoder)
        with pyro.plate("data", x.shape[0]):
            # use the encoder to get the parameters used to define q(z|x)
            z_loc, z_scale = self.encoder.forward(x)
            # sample the latent code z
            pyro.sample("latent", dist.Normal(z_loc, z_scale).to_event(1))
    # define a helper function for reconstructing images
    def reconstruct_img(self, x):
        """Encode x, draw one z from q(z|x), and return the decoded pixel means."""
        # encode image x
        z_loc, z_scale = self.encoder(x)
        # sample in latent space
        z = dist.Normal(z_loc, z_scale).sample()
        # decode the image (note we don't sample in image space)
        loc_img = self.decoder(z)
        return loc_img
# NOTE(review): removed dead top-level setup that was here — `vae`,
# `optimizer` and `svi` were constructed and then immediately discarded,
# because the run section later in this file rebuilds all three (after
# pyro.clear_param_store()). Constructing a throwaway VAE also wasted a
# full round of parameter initialization at import time.
def train(svi, train_loader, use_cuda=False):
    """Run one epoch of SVI training and return the average loss per example.

    Args:
        svi: a pyro SVI object (anything exposing step(batch) -> float).
        train_loader: DataLoader yielding (images, labels) mini-batches.
        use_cuda: if True, move each mini-batch to GPU memory first.
    """
    running_loss = 0.
    for batch, _ in train_loader:
        # move the mini-batch to the GPU when training there
        if use_cuda:
            batch = batch.cuda()
        # one ELBO gradient step; step() returns this mini-batch's loss
        running_loss += svi.step(batch)
    # normalize by dataset size so the value is comparable across batch sizes
    return running_loss / len(train_loader.dataset)
def evaluate(svi, test_loader, use_cuda=False):
    """Return the average ELBO loss per example over the whole test set.

    Args:
        svi: a pyro SVI object (anything exposing evaluate_loss(batch) -> float).
        test_loader: DataLoader yielding (images, labels) mini-batches.
        use_cuda: if True, move each mini-batch to GPU memory first.
    """
    accumulated = 0.
    for images, _ in test_loader:
        # move the mini-batch to the GPU when evaluating there
        if use_cuda:
            images = images.cuda()
        # no gradient step here — only an ELBO estimate for this mini-batch
        accumulated += svi.evaluate_loss(images)
    # normalize by dataset size so the value is comparable across batch sizes
    return accumulated / len(test_loader.dataset)
# Run options
LEARNING_RATE = 1.0e-3  # Adam step size
USE_CUDA = False        # set True to train on GPU
# NOTE(review): a stale comment here claimed this runs "only a single
# iteration for testing" — it actually runs a full 100-epoch training job.
NUM_EPOCHS = 100
TEST_FREQUENCY = 5      # evaluate on the test set every 5 epochs
train_loader, test_loader = setup_data_loaders(batch_size=256, use_cuda=USE_CUDA)
# clear param store so re-running the script starts from fresh parameters
pyro.clear_param_store()
# setup the VAE
vae = VAE(use_cuda=USE_CUDA)
# setup the optimizer
adam_args = {"lr": LEARNING_RATE}
optimizer = Adam(adam_args)
# setup the inference algorithm
svi = SVI(vae.model, vae.guide, optimizer, loss=Trace_ELBO())
train_elbo = []  # per-epoch negated train losses, i.e. ELBO estimates
test_elbo = []   # negated test losses, appended every TEST_FREQUENCY epochs
# training loop
for epoch in range(NUM_EPOCHS):
    total_epoch_loss_train = train(svi, train_loader, use_cuda=USE_CUDA)
    # store the negated loss so the list tracks the ELBO (higher is better)
    train_elbo.append(-total_epoch_loss_train)
    print("[epoch %03d] average training loss: %.4f" % (epoch, total_epoch_loss_train))
    if epoch % TEST_FREQUENCY == 0:
        # report test diagnostics
        total_epoch_loss_test = evaluate(svi, test_loader, use_cuda=USE_CUDA)
        test_elbo.append(-total_epoch_loss_test)
        print("[epoch %03d] average test loss: %.4f" % (epoch, total_epoch_loss_test))
# import argparse
# import numpy as np
# import torch
# import torch.nn as nn
# import visdom
# import pyro
# import pyro.distributions as dist
# from pyro.infer import SVI, JitTrace_ELBO, Trace_ELBO
# from pyro.optim import Adam
# from utils.mnist_cached import MNISTCached as MNIST
# from utils.mnist_cached import setup_data_loaders
# from utils.vae_plots import mnist_test_tsne, plot_llk, plot_vae_samples
# # define the PyTorch module that parameterizes the
# # diagonal gaussian distribution q(z|x)
# class Encoder(nn.Module):
# def __init__(self, z_dim, hidden_dim):
# super(Encoder, self).__init__()
# # setup the three linear transformations used
# self.fc1 = nn.Linear(784, hidden_dim)
# self.fc21 = nn.Linear(hidden_dim, z_dim)
# self.fc22 = nn.Linear(hidden_dim, z_dim)
# # setup the non-linearities
# self.softplus = nn.Softplus()
# def forward(self, x):
# # define the forward computation on the image x
# # first shape the mini-batch to have pixels in the rightmost dimension
# x = x.reshape(-1, 784)
# # then compute the hidden units
# hidden = self.softplus(self.fc1(x))
# # then return a mean vector and a (positive) square root covariance
# # each of size batch_size x z_dim
# z_loc = self.fc21(hidden)
# z_scale = torch.exp(self.fc22(hidden))
# return z_loc, z_scale
# # define the PyTorch module that parameterizes the
# # observation likelihood p(x|z)
# class Decoder(nn.Module):
# def __init__(self, z_dim, hidden_dim):
# super(Decoder, self).__init__()
# # setup the two linear transformations used
# self.fc1 = nn.Linear(z_dim, hidden_dim)
# self.fc21 = nn.Linear(hidden_dim, 784)
# # setup the non-linearities
# self.softplus = nn.Softplus()
# def forward(self, z):
# # define the forward computation on the latent z
# # first compute the hidden units
# hidden = self.softplus(self.fc1(z))
# # return the parameter for the output Bernoulli
# # each is of size batch_size x 784
# loc_img = torch.sigmoid(self.fc21(hidden))
# return loc_img
# # define a PyTorch module for the VAE
# class VAE(nn.Module):
# # by default our latent space is 50-dimensional
# # and we use 400 hidden units
# def __init__(self, z_dim=50, hidden_dim=400, use_cuda=False):
# super(VAE, self).__init__()
# # create the encoder and decoder networks
# self.encoder = Encoder(z_dim, hidden_dim)
# self.decoder = Decoder(z_dim, hidden_dim)
# if use_cuda:
# # calling cuda() here will put all the parameters of
# # the encoder and decoder networks into gpu memory
# self.cuda()
# self.use_cuda = use_cuda
# self.z_dim = z_dim
# # define the model p(x|z)p(z)
# def model(self, x):
# # register PyTorch module `decoder` with Pyro
# pyro.module("decoder", self.decoder)
# with pyro.plate("data", x.shape[0]):
# # setup hyperparameters for prior p(z)
# z_loc = x.new_zeros(torch.Size((x.shape[0], self.z_dim)))
# z_scale = x.new_ones(torch.Size((x.shape[0], self.z_dim)))
# # sample from prior (value will be sampled by guide when computing the ELBO)
# z = pyro.sample("latent", dist.Normal(z_loc, z_scale).to_event(1))
# # decode the latent code z
# loc_img = self.decoder.forward(z)
# # score against actual images
# pyro.sample("obs", dist.Bernoulli(loc_img).to_event(1), obs=x.reshape(-1, 784))
# # return the loc so we can visualize it later
# return loc_img
# # define the guide (i.e. variational distribution) q(z|x)
# def guide(self, x):
# # register PyTorch module `encoder` with Pyro
# pyro.module("encoder", self.encoder)
# with pyro.plate("data", x.shape[0]):
# # use the encoder to get the parameters used to define q(z|x)
# z_loc, z_scale = self.encoder.forward(x)
# # sample the latent code z
# pyro.sample("latent", dist.Normal(z_loc, z_scale).to_event(1))
# # define a helper function for reconstructing images
# def reconstruct_img(self, x):
# # encode image x
# z_loc, z_scale = self.encoder(x)
# # sample in latent space
# z = dist.Normal(z_loc, z_scale).sample()
# # decode the image (note we don't sample in image space)
# loc_img = self.decoder(z)
# return loc_img
# def main(args):
# # clear param store
# pyro.clear_param_store()
# # setup MNIST data loaders
# # train_loader, test_loader
# train_loader, test_loader = setup_data_loaders(MNIST, use_cuda=args.cuda, batch_size=256)
# # setup the VAE
# vae = VAE(use_cuda=args.cuda)
# # setup the optimizer
# adam_args = {"lr": args.learning_rate}
# optimizer = Adam(adam_args)
# # setup the inference algorithm
# elbo = JitTrace_ELBO() if args.jit else Trace_ELBO()
# svi = SVI(vae.model, vae.guide, optimizer, loss=elbo)
# # setup visdom for visualization
# if args.visdom_flag:
# vis = visdom.Visdom()
# train_elbo = []
# test_elbo = []
# # training loop
# for epoch in range(args.num_epochs):
# # initialize loss accumulator
# epoch_loss = 0.
# # do a training epoch over each mini-batch x returned
# # by the data loader
# for x, _ in train_loader:
# # if on GPU put mini-batch into CUDA memory
# if args.cuda:
# x = x.cuda()
# # do ELBO gradient and accumulate loss
# epoch_loss += svi.step(x)
# # report training diagnostics
# normalizer_train = len(train_loader.dataset)
# total_epoch_loss_train = epoch_loss / normalizer_train
# train_elbo.append(total_epoch_loss_train)
# print("[epoch %03d] average training loss: %.4f" % (epoch, total_epoch_loss_train))
# if epoch % args.test_frequency == 0:
# # initialize loss accumulator
# test_loss = 0.
# # compute the loss over the entire test set
# for i, (x, _) in enumerate(test_loader):
# # if on GPU put mini-batch into CUDA memory
# if args.cuda:
# x = x.cuda()
# # compute ELBO estimate and accumulate loss
# test_loss += svi.evaluate_loss(x)
# # pick three random test images from the first mini-batch and
# # visualize how well we're reconstructing them
# if i == 0:
# if args.visdom_flag:
# plot_vae_samples(vae, vis)
# reco_indices = np.random.randint(0, x.shape[0], 3)
# for index in reco_indices:
# test_img = x[index, :]
# reco_img = vae.reconstruct_img(test_img)
# vis.image(test_img.reshape(28, 28).detach().cpu().numpy(),
# opts={'caption': 'test image'})
# vis.image(reco_img.reshape(28, 28).detach().cpu().numpy(),
# opts={'caption': 'reconstructed image'})
# # report test diagnostics
# normalizer_test = len(test_loader.dataset)
# total_epoch_loss_test = test_loss / normalizer_test
# test_elbo.append(total_epoch_loss_test)
# print("[epoch %03d] average test loss: %.4f" % (epoch, total_epoch_loss_test))
# if epoch == args.tsne_iter:
# mnist_test_tsne(vae=vae, test_loader=test_loader)
# plot_llk(np.array(train_elbo), np.array(test_elbo))
# return vae
# if __name__ == '__main__':
# assert pyro.__version__.startswith('0.3.1')
# # parse command line arguments
# parser = argparse.ArgumentParser(description="parse args")
# parser.add_argument('-n', '--num-epochs', default=101, type=int, help='number of training epochs')
# parser.add_argument('-tf', '--test-frequency', default=5, type=int, help='how often we evaluate the test set')
# parser.add_argument('-lr', '--learning-rate', default=1.0e-3, type=float, help='learning rate')
# parser.add_argument('--cuda', action='store_true', default=False, help='whether to use cuda')
# parser.add_argument('--jit', action='store_true', default=False, help='whether to use PyTorch jit')
# parser.add_argument('-visdom', '--visdom_flag', action="store_true", help='Whether plotting in visdom is desired')
# parser.add_argument('-i-tsne', '--tsne_iter', default=100, type=int, help='epoch when tsne visualization runs')
# args = parser.parse_args()
# model = main(args)