mindspore/model_zoo/research/audio/wavenet/evaluate.py

# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""evaluation"""
import os
from os.path import join
import argparse
import glob
from hparams import hparams, hparams_debug_string
import audio
import numpy as np
from scipy.io import wavfile
from tqdm import tqdm
from mindspore import context, Tensor
from mindspore.train.serialization import load_checkpoint, load_param_into_net
import mindspore.dataset.engine as de
from nnmnkwii import preprocessing as P
from nnmnkwii.datasets import FileSourceDataset
from wavenet_vocoder import WaveNet
from wavenet_vocoder.util import is_mulaw_quantize, is_mulaw, is_scalar_input
from src.dataset import RawAudioDataSource, MelSpecDataSource, DualDataset

parser = argparse.ArgumentParser(description='TTS evaluation')
parser.add_argument('--data_path', type=str, required=True, default='',
                    help='Directory containing preprocessed features.')
parser.add_argument('--preset', type=str, required=True, default='', help='Path of preset parameters (json).')
parser.add_argument('--pretrain_ckpt', type=str, default='', help='Pretrained checkpoint path')
parser.add_argument('--is_numpy', action="store_true", default=False, help='Using numpy for inference or not')
parser.add_argument('--output_path', type=str, default='./out_wave/', help='Path to save generated audios')
parser.add_argument('--speaker_id', type=str, default='',
                    help='Use a specific speaker ID in the case of multi-speaker datasets.')
parser.add_argument('--platform', type=str, default='GPU', choices=('GPU', 'CPU'),
help='run platform, support GPU and CPU. Default: GPU')
args = parser.parse_args()


def get_data_loader(hparam, data_dir):
"""
test data loader
"""
wav_paths = glob.glob(os.path.join(data_dir, "*-wave.npy"))
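    # Raw-audio targets (X) are optional and only used as references;
    # mel-spectrogram conditioning features (C) are required for generation.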
if wav_paths:
X = FileSourceDataset(RawAudioDataSource(data_dir,
hop_size=audio.get_hop_size(),
max_steps=None, cin_pad=hparam.cin_pad))
else:
X = None
C = FileSourceDataset(MelSpecDataSource(data_dir,
hop_size=audio.get_hop_size(),
max_steps=None, cin_pad=hparam.cin_pad))
length_x = np.array(C.file_data_source.lengths)
    if C[0].shape[-1] != hparam.cin_channels:
        raise RuntimeError("Invalid cin_channels {}. Expected to be {}.".format(C[0].shape[-1], hparam.cin_channels))
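    # Pair the (optional) audio targets with their conditioning features and group them into batches.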
dataset = DualDataset(X, C, length_x, batch_size=hparam.batch_size, hparams=hparam)
data_loader = de.GeneratorDataset(dataset, ["x_batch", "y_batch", "c_batch", "g_batch", "input_lengths", "mask"])
return data_loader, dataset


def batch_wavegen(hparam, net, c_input=None, g_input=None, tqdm_=None, is_numpy=True):
"""
generate audio
"""
assert c_input is not None
B = c_input.shape[0]
net.set_train(False)
if hparam.upsample_conditional_features:
length = (c_input.shape[-1] - hparam.cin_pad * 2) * audio.get_hop_size()
else:
        # already duplicated
length = c_input.shape[-1]
y_hat = net.incremental_forward(c=c_input, g=g_input, T=length, tqdm=tqdm_, softmax=True, quantize=True,
log_scale_min=hparam.log_scale_min, is_numpy=is_numpy)
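    # Convert the network output back to a waveform in [-1, 1] according to the configured input type.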
if is_mulaw_quantize(hparam.input_type):
# needs to be float since mulaw_inv returns in range of [-1, 1]
y_hat = np.reshape(np.argmax(y_hat, 1), (B, -1))
y_hat = y_hat.astype(np.float32)
for k in range(B):
y_hat[k] = P.inv_mulaw_quantize(y_hat[k], hparam.quantize_channels - 1)
elif is_mulaw(hparam.input_type):
y_hat = np.reshape(y_hat, (B, -1))
for k in range(B):
y_hat[k] = P.inv_mulaw(y_hat[k], hparam.quantize_channels - 1)
else:
y_hat = np.reshape(y_hat, (B, -1))
if hparam.postprocess is not None and hparam.postprocess not in ["", "none"]:
for k in range(B):
y_hat[k] = getattr(audio, hparam.postprocess)(y_hat[k])
if hparam.global_gain_scale > 0:
for k in range(B):
y_hat[k] /= hparam.global_gain_scale
return y_hat


def to_int16(x_):
"""
convert datatype to int16
"""
if x_.dtype == np.int16:
return x_
assert x_.dtype == np.float32
assert x_.min() >= -1 and x_.max() <= 1.0
return (x_ * 32767).astype(np.int16)


def get_reference_file(hparam, dataset_source, idx):
"""
get reference files
"""
reference_files = []
reference_feats = []
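    # Collect the source .npy paths for the next batch_size items so the references
    # can be matched with the generated audio.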
for _ in range(hparam.batch_size):
        if hasattr(dataset_source, "X"):
            reference_files.append(dataset_source.X.collected_files[idx][0])
if hasattr(dataset_source, "Mel"):
reference_feats.append(dataset_source.Mel.collected_files[idx][0])
else:
reference_feats.append(dataset_source.collected_files[idx][0])
idx += 1
return reference_files, reference_feats, idx


def get_saved_audio_name(has_ref_file_, ref_file, ref_feat, g_fp):
"""get path to save reference audio"""
if has_ref_file_:
target_audio_path = ref_file
name = os.path.splitext(os.path.basename(target_audio_path))[0].replace("-wave", "")
else:
target_feat_path = ref_feat
name = os.path.splitext(os.path.basename(target_feat_path))[0].replace("-feats", "")
# Paths
if g_fp is None:
dst_wav_path_ = join(args.output_path, "{}_gen.wav".format(name))
target_wav_path_ = join(args.output_path, "{}_ref.wav".format(name))
else:
        dst_wav_path_ = join(args.output_path, "speaker{}_{}_gen.wav".format(g_fp, name))
        target_wav_path_ = join(args.output_path, "speaker{}_{}_ref.wav".format(g_fp, name))
return dst_wav_path_, target_wav_path_


def save_ref_audio(hparam, ref, length, target_wav_path_):
"""
save reference audio
"""
if is_mulaw_quantize(hparam.input_type):
ref = np.reshape(np.argmax(ref, 0), (-1))[:length]
ref = ref.astype(np.float32)
else:
ref = np.reshape(ref, (-1))[:length]
if is_mulaw_quantize(hparam.input_type):
ref = P.inv_mulaw_quantize(ref, hparam.quantize_channels - 1)
elif is_mulaw(hparam.input_type):
ref = P.inv_mulaw(ref, hparam.quantize_channels - 1)
if hparam.postprocess is not None and hparam.postprocess not in ["", "none"]:
ref = getattr(audio, hparam.postprocess)(ref)
if hparam.global_gain_scale > 0:
ref /= hparam.global_gain_scale
ref = np.clip(ref, -1.0, 1.0)
wavfile.write(target_wav_path_, hparam.sample_rate, to_int16(ref))


if __name__ == '__main__':
context.set_context(mode=context.GRAPH_MODE, device_target=args.platform, save_graphs=False)
speaker_id = int(args.speaker_id) if args.speaker_id != '' else None
if args.preset is not None:
with open(args.preset) as f:
hparams.parse_json(f.read())
assert hparams.name == "wavenet_vocoder"
print(hparams_debug_string())
fs = hparams.sample_rate
hparams.batch_size = 10
hparams.max_time_sec = None
hparams.max_time_steps = None
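    # Build the evaluation data pipeline from the preprocessed feature directory.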
data_loaders, source_dataset = get_data_loader(hparam=hparams, data_dir=args.data_path)
upsample_params = hparams.upsample_params
upsample_params["cin_channels"] = hparams.cin_channels
upsample_params["cin_pad"] = hparams.cin_pad
model = WaveNet(
out_channels=hparams.out_channels,
layers=hparams.layers,
stacks=hparams.stacks,
residual_channels=hparams.residual_channels,
gate_channels=hparams.gate_channels,
skip_out_channels=hparams.skip_out_channels,
cin_channels=hparams.cin_channels,
gin_channels=hparams.gin_channels,
n_speakers=hparams.n_speakers,
dropout=hparams.dropout,
kernel_size=hparams.kernel_size,
cin_pad=hparams.cin_pad,
upsample_conditional_features=hparams.upsample_conditional_features,
upsample_params=upsample_params,
scalar_input=is_scalar_input(hparams.input_type),
output_distribution=hparams.output_distribution,
)
param_dict = load_checkpoint(args.pretrain_ckpt)
load_param_into_net(model, param_dict)
    print('Successfully loaded the pre-trained model')
os.makedirs(args.output_path, exist_ok=True)
cin_pad = hparams.cin_pad
file_idx = 0
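    # Generate audio batch by batch; for each utterance save both the generated and the reference wav.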
for data in data_loaders.create_dict_iterator():
x, y, c, g, input_lengths = data['x_batch'], data['y_batch'], data['c_batch'], data['g_batch'], data[
'input_lengths']
        if cin_pad > 0:
            # Pad the conditioning features along the time (last) axis only, replicating edge values.
            c = c.asnumpy()
            c = np.pad(c, pad_width=((0, 0), (0, 0), (cin_pad, cin_pad)), mode="edge")
            c = Tensor(c)
ref_files, ref_feats, file_idx = get_reference_file(hparams, source_dataset, file_idx)
# Generate
        y_hats = batch_wavegen(hparams, model, c_input=c, tqdm_=tqdm, is_numpy=args.is_numpy)
x = x.asnumpy()
input_lengths = input_lengths.asnumpy()
# Save each utt.
has_ref_file = bool(ref_files)
        for i, (ref_, gen_, length_) in enumerate(zip(x, y_hats, input_lengths)):
            # speaker_id is None unless --speaker_id was given on the command line
            dst_wav_path, target_wav_path = get_saved_audio_name(has_ref_file_=has_ref_file,
                                                                 ref_file=ref_files[i] if has_ref_file else None,
                                                                 ref_feat=ref_feats[i], g_fp=speaker_id)
save_ref_audio(hparams, ref_, length_, target_wav_path)
gen = gen_[:length_]
gen = np.clip(gen, -1.0, 1.0)
wavfile.write(dst_wav_path, hparams.sample_rate, to_int16(gen))