-
Notifications
You must be signed in to change notification settings - Fork 8
Expand file tree
/
Copy pathAutoencoder.py
More file actions
85 lines (63 loc) · 2.54 KB
/
Autoencoder.py
File metadata and controls
85 lines (63 loc) · 2.54 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
# keras
from keras.models import Sequential
from keras.utils import np_utils, generic_utils
from keras.optimizers import RMSprop, SGD
from keras.layers.core import Dense, Activation, AutoEncoder
from keras.regularizers import activity_l1, l2
from keras.preprocessing.image import ImageDataGenerator
from myutils import *
class Autoencoder(object):
    """Single-hidden-layer denoising autoencoder.

    Built on the legacy Keras 0.x API (``Sequential`` containers wrapped in
    a core ``AutoEncoder`` layer).  Training corrupts each input batch with
    ``get_corrupted_output`` (from ``myutils``) and learns to reconstruct the
    clean target under a mean-squared-error loss, optimized with RMSprop.
    """

    def __init__(self, n_in, n_hid,
                 lr=1e-2, l2reg=3e-6, corruption_level=0.3, act='sigmoid'):
        """Build and compile the encoder/decoder pair.

        Parameters
        ----------
        n_in : int
            Dimensionality of the (flattened) input vectors.
        n_hid : int
            Number of hidden (code) units.
        lr : float
            RMSprop learning rate.
        l2reg : float
            L2 weight-decay coefficient applied to both Dense layers.
        corruption_level : float
            Corruption strength passed to ``get_corrupted_output`` for each
            training batch.
        act : str
            Activation name used by both the encoder and decoder layers.
        """
        self.lr = lr
        self.l2reg = l2reg
        self.corruption_level = corruption_level

        # Keras 0.x style: Dense takes (input_dim, output_dim) positionally.
        encoder = Sequential()
        encoder.add(Dense(n_in, n_hid, init='uniform', W_regularizer=l2(l2reg)))
        encoder.add(Activation(act))

        decoder = Sequential()
        decoder.add(Dense(n_hid, n_in, init='uniform', W_regularizer=l2(l2reg)))
        decoder.add(Activation(act))

        self.ae = Sequential()
        # output_reconstruction=True makes predict() return reconstructions
        # of the input rather than the hidden code.
        self.ae.add(AutoEncoder(encoder=encoder, decoder=decoder,
                                output_reconstruction=True))
        opt = RMSprop(lr=lr, rho=0.9, epsilon=1e-6)
        self.ae.compile(loss='mean_squared_error', optimizer=opt)

    def train(self, X_in, X_out, n_epoch=100, batch_size=32,
              filter_imgfile=None, recon_imgfile=None, verbose=True):
        """Train the autoencoder and optionally dump diagnostic images.

        Parameters
        ----------
        X_in : ndarray
            Training inputs; a corrupted copy of each batch is fed to the
            network.
        X_out : ndarray
            Clean reconstruction targets (typically identical to ``X_in``).
        n_epoch : int
            Number of passes over the data.
        batch_size : int
            Mini-batch size for ``ImageDataGenerator.flow``.
        filter_imgfile : str or None
            If given, the first (up to) 100 learned encoder filters are
            written to this file via ``show_images``.
        recon_imgfile : str or None
            If given, reconstructions of 100 randomly chosen inputs are
            written to this file via ``show_images``.
        verbose : bool
            Print per-epoch headers and a progress bar.

        Side effects: sets ``self.loss`` to the final reconstruction loss on
        the full, uncorrupted ``(X_in, X_out)`` pair.  ``self.rec_losses`` is
        initialized to an empty list (kept for interface compatibility; it
        is never appended to — TODO confirm whether it was meant to record
        per-epoch losses).
        """
        # All augmentation switched off: the generator is used purely to
        # shuffle and batch the data.
        gdatagen = ImageDataGenerator(
            featurewise_center=False,             # set input mean to 0 over the dataset
            samplewise_center=False,              # set each sample mean to 0
            featurewise_std_normalization=False,  # divide inputs by std of the dataset
            samplewise_std_normalization=False,   # divide each input by its std
            zca_whitening=False                   # apply ZCA whitening
        )

        self.rec_losses = []
        for e in range(1, n_epoch + 1):
            if verbose:
                print('-' * 40)
                print('Epoch', e)
                print('-' * 40)
                progbar = generic_utils.Progbar(X_in.shape[0])
            for X_batch, Y_batch in gdatagen.flow(X_in, X_out, batch_size=batch_size):
                # Denoising objective: corrupt the input, reconstruct the
                # clean target.
                X_batch = get_corrupted_output(
                    X_batch, corruption_level=self.corruption_level)
                train_score = self.ae.train_on_batch(X_batch, Y_batch)
                if verbose:
                    progbar.add(X_batch.shape[0],
                                values=[("train generative loss", train_score)])

        # Final reconstruction error on the full (uncorrupted) data.
        self.loss = self.ae.evaluate(X_in, X_out, batch_size=1024, verbose=0)

        if filter_imgfile is not None:
            # Visualize up to the first 100 encoder filters: W0 is the first
            # Dense weight matrix; its transposed columns are one image per
            # hidden unit.
            W0 = self.ae.get_weights()[0]
            show_images(np.transpose(W0[:, 0:100], (1, 0)),
                        grayscale=True, filename=filter_imgfile)

        if recon_imgfile is not None:
            # BUG FIX: the original indexed an undefined name ``X`` here,
            # raising NameError whenever recon_imgfile was set.  Sample from
            # the actual inputs ``X_in`` instead.
            idx = np.random.permutation(X_in.shape[0])[:100]
            Xs = X_in[idx]
            # Reconstruct the sampled inputs and dump them as images.
            Xr = self.ae.predict(Xs)
            show_images(Xr, grayscale=True, filename=recon_imgfile)