-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathvrnn_predict.py
More file actions
50 lines (42 loc) · 1.72 KB
/
vrnn_predict.py
File metadata and controls
50 lines (42 loc) · 1.72 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
import os
from callbacks import SavePeriodicCheckpoint
from keras.layers import Input
from keras.layers import Lambda
from keras.layers import LSTM
from keras.layers import Dense
from keras.layers import TimeDistributed
from keras.layers import merge
from keras.models import Model
from keras.optimizers import Adam
from keras import backend as K
from config import parse_args
from math import pi
from utils import audio_amplitudes_gen
from utils import write_audio
from vrnn_model import build_vrnn
def predict(wav_dir, model, write_dir, lstm_size=1000, num_steps=40,
            z_dim=100, batch_size=32, fc_dim=400, wav_dim=200):
    """Run a trained VRNN over a directory of audio and write the results.

    Builds the network with ``build_vrnn(..., mode="predict")`` and loads
    the weights given by ``model``. It then walks every batch produced by
    ``audio_amplitudes_gen`` exactly once (``infinite=False``).

    For the i-th batch, two files are written into ``write_dir``:

    - ``<i>_true.wav``: the ``true`` value yielded by the generator
      alongside the model inputs.
    - ``<i>_pred.wav``: the output of ``vae.predict`` on those inputs.

    Args:
        wav_dir: Directory handed to ``audio_amplitudes_gen`` as ``wavdir``.
        model: Passed straight to ``vae.load_weights`` (presumably a path
            to a saved weights file -- confirm against callers).
        write_dir: Output directory. It is created, including any missing
            parent directories, if it does not already exist.
        lstm_size: Forwarded to ``build_vrnn``.
        num_steps: Forwarded to ``build_vrnn`` and ``audio_amplitudes_gen``.
        z_dim: Forwarded to ``build_vrnn``.
        batch_size: Forwarded to ``build_vrnn``, ``audio_amplitudes_gen``
            and ``vae.predict``.
        fc_dim: Forwarded to ``build_vrnn``.
        wav_dim: Forwarded to ``build_vrnn`` and ``audio_amplitudes_gen``.

    Raises:
        OSError: If ``write_dir`` cannot be created, or exists but is not
            a directory.
    """
    vae = build_vrnn(lstm_size=lstm_size, num_steps=num_steps, z_dim=z_dim,
                     batch_size=batch_size, fc_dim=fc_dim, wav_dim=wav_dim,
                     mode="predict")
    vae.load_weights(model)

    # makedirs (unlike mkdir) also creates missing parent directories.
    # Trying the create and tolerating "already a directory" avoids the
    # exists()-then-mkdir race. exist_ok is not used so this also runs
    # on Python 2.
    try:
        os.makedirs(write_dir)
    except OSError:
        if not os.path.isdir(write_dir):
            raise

    pred_gen = audio_amplitudes_gen(
        wavdir=wav_dir, num_steps=num_steps, batch_size=batch_size,
        wav_dim=wav_dim, infinite=False)
    # The generator yields ((x_t, y_t), true). The two inputs are fed to
    # the model together; `true` is written out unchanged for comparison.
    for counter, ((x_t, y_t), true) in enumerate(pred_gen):
        pred = vae.predict([x_t, y_t], batch_size=batch_size)
        print("Writing audio %d" % counter)
        true_path = os.path.join(write_dir, "%d_true.wav" % counter)
        pred_path = os.path.join(write_dir, "%d_pred.wav" % counter)
        write_audio(true, true_path)
        write_audio(pred, pred_path)
if __name__ == "__main__":
    # Script entry point: read the prediction-mode CLI options, then hand
    # each one to predict(). The keyword arguments are listed in the same
    # order as before, so they are evaluated in the same order.
    args = parse_args(mode="predict")
    predict(
        args.wav_dir,
        args.model,
        args.write_dir,
        z_dim=args.z_dim,
        lstm_size=args.lstm_size,
        num_steps=args.num_steps,
        batch_size=args.batch_size,
        fc_dim=args.fc_dim,
        wav_dim=args.wav_dim,
    )