-
Notifications
You must be signed in to change notification settings - Fork 4
Expand file tree
/
Copy pathsample_visualization.py
More file actions
146 lines (123 loc) · 5.67 KB
/
sample_visualization.py
File metadata and controls
146 lines (123 loc) · 5.67 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
import csv
import os
import sys
from pathlib import Path
try:
import streamlit as st
except ModuleNotFoundError:
pass
import torch
import torchvision
import yaml
from omegaconf import OmegaConf
from specvqgan.util import get_ckpt_path
sys.path.insert(0, '.') # nopep8
import soundfile
from feature_extraction.extract_mel_spectrogram import inv_transforms
from train import instantiate_from_config
from vocoder.modules import Generator
def load_model_from_config(config, sd, gpu=True, eval_mode=True, load_new_first_stage=False):
    """Instantiate a model from an OmegaConf config and optionally load weights.

    Restore-checkpoint paths inside the config are cleared before
    instantiation so that building the model does not trigger a second,
    redundant checkpoint load.

    Args:
        config: OmegaConf node with a `params` section (and, for full
            models, nested `first_stage_config` / `cond_stage_config`).
        sd: state dict to load into the model, or None to skip loading.
        gpu: move the model to CUDA when True.
        eval_mode: switch the model to eval() when True.
        load_new_first_stage: when True, keep the freshly instantiated
            first-stage model instead of the first-stage weights in `sd`.

    Returns:
        {'model': the instantiated (and possibly weight-loaded) model}.
    """
    if "ckpt_path" in config.params:
        print("Deleting the restore-ckpt path from the config...")
        config.params.ckpt_path = None
    if "downsample_cond_size" in config.params:
        print("Deleting downsample-cond-size from the config and setting factor=0.5 instead...")
        config.params.downsample_cond_size = -1
        config.params["downsample_cond_factor"] = 0.5
    # The nested stage configs may be absent (e.g. a plain feature-extractor
    # config); OmegaConf raises an AttributeError subclass on missing
    # attribute access, so only that is swallowed here (the original used a
    # bare `except:` which also hid real errors).
    try:
        if "ckpt_path" in config.params.first_stage_config.params and not load_new_first_stage:
            config.params.first_stage_config.params.ckpt_path = None
            print("Deleting the first-stage restore-ckpt path from the config...")
        if "ckpt_path" in config.params.cond_stage_config.params:
            config.params.cond_stage_config.params.ckpt_path = None
            print("Deleting the cond-stage restore-ckpt path from the config...")
    except AttributeError:
        pass
    model = instantiate_from_config(config)
    if load_new_first_stage:
        # Remember the freshly built first stage so it can be restored after
        # `sd` (which carries the old first-stage weights) is loaded below.
        first_stage_model = model.first_stage_model
    if sd is not None:
        missing, unexpected = model.load_state_dict(sd, strict=False)
        # These names only exist when a state dict was loaded, so the prints
        # live inside this branch (the original guarded them with
        # `try/except NameError` instead).
        print(f"Missing Keys in State Dict: {missing}")
        print(f"Unexpected Keys in State Dict: {unexpected}")
    if load_new_first_stage:
        print('replace with new codebook model')
        model.first_stage_model = first_stage_model
    if gpu:
        model.cuda()
    if eval_mode:
        model.eval()
    return {"model": model}
def load_vocoder(ckpt_vocoder: str, eval_mode: bool):
    """Load the MelGAN vocoder generator from a checkpoint directory.

    The directory is expected to contain `best_netG.pt` (weights) and
    `args.yml` (architecture hyper-parameters saved at training time).
    Returns {'model': the Generator}.
    """
    ckpt_dir = Path(ckpt_vocoder)
    # Architecture hyper-parameters saved alongside the weights.
    # NOTE(review): UnsafeLoader can execute arbitrary constructors — fine
    # for a trusted local checkpoint, never for untrusted input.
    with open(ckpt_dir / 'args.yml', 'r') as fh:
        train_args = yaml.load(fh, Loader=yaml.UnsafeLoader)
    net = Generator(train_args.n_mel_channels, train_args.ngf, train_args.n_residual_layers)
    net.load_state_dict(torch.load(ckpt_dir / 'best_netG.pt', map_location='cpu'))
    if eval_mode:
        net.eval()
    return {'model': net}
def load_feature_extractor(gpu, eval_mode=True):
    """Build the Melception feature extractor used for evaluation.

    Downloads the melception checkpoint if needed, instantiates the model,
    and builds the input transforms plus a target->label mapping for the
    VGGSound classes.

    Args:
        gpu: move the extractor to CUDA when True (False keeps it on CPU).
        eval_mode: switch the extractor to eval() when True.

    Returns:
        {'model': extractor, 'transforms': composed input transforms,
         'target2label': dict mapping class index -> class name}.
    """
    # Inline config: keeps this script self-contained (no extra yaml file).
    s = '''
    feature_extractor:
        target: evaluation.feature_extractors.melception.Melception
        params:
            num_classes: 309
            features_list: ['logits']
            feature_extractor_weights_path: ./evaluation/logs/21-05-10T09-28-40/melception-21-05-10T09-28-40.pt
    transform_dset_out_to_inception_in:
    - target: evaluation.datasets.transforms.FromMinusOneOneToZeroOne
    - target: specvqgan.modules.losses.vggishish.transforms.StandardNormalizeAudio
      params:
        specs_dir: ./data/vggsound/melspec_10s_22050hz
        cache_path: ./specvqgan/modules/losses/vggishish/data/
    - target: evaluation.datasets.transforms.GetInputFromBatchByKey
      params:
        input_key: image
    - target: evaluation.datasets.transforms.ToFloat32'''
    feat_extractor_cfg = OmegaConf.create(s)
    # downloading the checkpoint for melception
    get_ckpt_path('melception', 'evaluation/logs/21-05-10T09-28-40')
    pl_sd = torch.load(feat_extractor_cfg.feature_extractor.params.feature_extractor_weights_path,
                       map_location="cpu")
    # use gpu=False to compute it on CPU
    feat_extractor = load_model_from_config(
        feat_extractor_cfg.feature_extractor, pl_sd['model'], gpu=gpu, eval_mode=eval_mode)['model']
    if feat_extractor_cfg.transform_dset_out_to_inception_in is not None:
        transforms = [instantiate_from_config(c) for c in feat_extractor_cfg.transform_dset_out_to_inception_in]
    else:
        transforms = [lambda x: x]
    transforms = torchvision.transforms.Compose(transforms)
    # `with` closes the csv handle (the original leaked an open file object).
    with open('./data/vggsound.csv') as f:
        vggsound_meta = list(csv.reader(f, quotechar='"'))
    # Column 2 of the VGGSound csv holds the class label string.
    unique_classes = sorted(list(set(row[2] for row in vggsound_meta)))
    label2target = {label: target for target, label in enumerate(unique_classes)}
    target2label = {target: label for label, target in label2target.items()}
    return {'model': feat_extractor, 'transforms': transforms, 'target2label': target2label}
def show_wave_in_streamlit(wave_npy, sample_rate, caption):
    """Render a waveform as a playable audio widget in Streamlit.

    Streamlit's audio widget cannot consume a raw numpy wave, so the wave
    is written to a temporary WAV file first and removed once the widget
    has read it.

    Args:
        wave_npy: 1-D numpy waveform in [-1, 1].
        sample_rate: sample rate in Hz for the written WAV file.
        caption: text shown above the audio player.
    """
    import tempfile
    # Unique temp path instead of a fixed 'todel.wav': the fixed name
    # collided between concurrent sessions / parallel runs.
    fd, temp_wav_file_path = tempfile.mkstemp(suffix='.wav')
    os.close(fd)  # soundfile reopens the path itself
    try:
        soundfile.write(temp_wav_file_path, wave_npy, sample_rate, 'PCM_24')
        st.text(caption)
        # st.audio reads the file contents immediately, so removing the
        # file afterwards is safe.
        st.audio(temp_wav_file_path, format='audio/wav')
    finally:
        # Clean up even if writing or rendering raised.
        os.remove(temp_wav_file_path)
def spec_to_audio_to_st(x, spec_dir_path, sample_rate, show_griffin_lim, vocoder=None, show_in_st=True):
    """Convert a spectrogram tensor to waveform(s), optionally showing them.

    The spectrogram is assumed to lie in [-1, 1] and is rescaled to
    [0, 1] before vocoding / Griffin-Lim inversion.

    Args:
        x: spectrogram tensor of shape (1, F, T) (batch dim squeezed away).
        spec_dir_path: path whose stem selects the inverse mel transform.
        sample_rate: sample rate for the displayed audio.
        show_griffin_lim: also reconstruct via Griffin-Lim when True.
        vocoder: MelGAN generator; skipped when falsy.
        show_in_st: render the waves as Streamlit audio widgets when True.

    Returns:
        dict with keys 'vocoder' and/or 'inv_transforms' holding numpy waves.
    """
    results = {}
    # rescale from [-1, 1] to [0, 1]
    rescaled = (x.data.squeeze(0) + 1) / 2
    if vocoder:
        # (L,) <- wave: (1, 1, L).squeeze() <- spec: (1, F, T)
        melgan_wave = vocoder(rescaled).squeeze().cpu().numpy()
        results['vocoder'] = melgan_wave
        if show_in_st:
            show_wave_in_streamlit(melgan_wave, sample_rate, 'Reconstructed Wave via MelGAN')
    if show_griffin_lim:
        spec_np = rescaled.squeeze(0).cpu().numpy()
        griffinlim_wave = inv_transforms(spec_np, Path(spec_dir_path).stem)
        results['inv_transforms'] = griffinlim_wave
        if show_in_st:
            show_wave_in_streamlit(griffinlim_wave, sample_rate, 'Reconstructed Wave via Griffin Lim')
    return results