-
Notifications
You must be signed in to change notification settings - Fork 1
Expand file tree
/
Copy pathmake_encoder.py
More file actions
28 lines (23 loc) · 1.01 KB
/
make_encoder.py
File metadata and controls
28 lines (23 loc) · 1.01 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
import os.path
import torch
import collections
import numpy as np
import matplotlib.pyplot as plt
from torch.autograd import Variable
from define_network import Compression_encoder,AutoEncoder_3
from sample_set import Sample_set
if __name__ == '__main__':
    # Truncate the trained 3-layer autoencoder: copy only the encoder-half
    # parameters into a standalone Compression_encoder and save that to disk.
    path_ = os.path.abspath('.')
    fname = path_ + '/autoencoder_layer3.pth'
    ae = AutoEncoder_3()
    # map_location='cpu' so a checkpoint saved on GPU still loads on a
    # CPU-only machine; ae is a fresh CPU module, so this is always correct.
    ae.load_state_dict(torch.load(fname, map_location='cpu'))
    ce = Compression_encoder()
    # Hoist state_dict() (the original rebuilt it for every copied tensor)
    # and copy the three encoder layers' weight/bias pairs in a loop.
    ae_state = ae.state_dict()
    new_dict = collections.OrderedDict()
    for layer in ('encoder1', 'encoder2', 'encoder3'):
        for param in ('weight', 'bias'):
            key = layer + '.' + param
            new_dict[key] = ae_state[key]
    ce.load_state_dict(new_dict)
    torch.save(ce.state_dict(), path_ + '/compression_encoder.pth')