# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Evaluation: generate waveforms from preprocessed features with a trained WaveNet vocoder."""
import argparse
import glob
import os
from os.path import join

import numpy as np
from scipy.io import wavfile
from tqdm import tqdm

import mindspore.dataset.engine as de
from mindspore import context, Tensor
from mindspore.train.serialization import load_checkpoint, load_param_into_net

from nnmnkwii import preprocessing as P
from nnmnkwii.datasets import FileSourceDataset

import audio
from hparams import hparams, hparams_debug_string
from wavenet_vocoder import WaveNet
from wavenet_vocoder.util import is_mulaw_quantize, is_mulaw, is_scalar_input
from src.dataset import RawAudioDataSource, MelSpecDataSource, DualDataset

parser = argparse.ArgumentParser(description='TTS evaluation')
parser.add_argument('--data_path', type=str, required=True,
                    help='Directory containing the preprocessed features.')
parser.add_argument('--preset', type=str, required=True, help='Path to the preset parameters (json).')
parser.add_argument('--pretrain_ckpt', type=str, default='', help='Pretrained checkpoint path.')
parser.add_argument('--is_numpy', action="store_true", default=False,
                    help='Whether to use NumPy for autoregressive inference.')
parser.add_argument('--output_path', type=str, default='./out_wave/', help='Path to save the generated audio.')
parser.add_argument('--speaker_id', type=str, default='',
                    help='Use a specific speaker of the data in case of a multi-speaker dataset.')
parser.add_argument('--platform', type=str, default='GPU', choices=('GPU', 'CPU'),
                    help='Run platform, supports GPU and CPU. Default: GPU.')
args = parser.parse_args()

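# Illustrative invocation (the paths below are placeholders, not shipped assets):
#   python evaluate.py --data_path ./dump/eval --preset ./wavenet.json \
#       --pretrain_ckpt ./wavenet.ckpt --output_path ./out_wave/
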
def get_data_loader(hparam, data_dir):
    """
    Build the test data loader and its underlying dataset.
    """
    wav_paths = glob.glob(os.path.join(data_dir, "*-wave.npy"))
    if wav_paths:
        X = FileSourceDataset(RawAudioDataSource(data_dir,
                                                 hop_size=audio.get_hop_size(),
                                                 max_steps=None, cin_pad=hparam.cin_pad))
    else:
        X = None
    C = FileSourceDataset(MelSpecDataSource(data_dir,
                                            hop_size=audio.get_hop_size(),
                                            max_steps=None, cin_pad=hparam.cin_pad))

    length_x = np.array(C.file_data_source.lengths)
    if C[0].shape[-1] != hparam.cin_channels:
        raise RuntimeError("Invalid cin_channels {}. Expected to be {}.".format(hparam.cin_channels,
                                                                                C[0].shape[-1]))

    dataset = DualDataset(X, C, length_x, batch_size=hparam.batch_size, hparams=hparam)

    data_loader = de.GeneratorDataset(dataset, ["x_batch", "y_batch", "c_batch", "g_batch", "input_lengths", "mask"])

    return data_loader, dataset


def batch_wavegen(hparam, net, c_input=None, g_input=None, tqdm_=None, is_numpy=True):
    """
    Generate a batch of waveforms from conditional features.
    """
    assert c_input is not None
    B = c_input.shape[0]
    net.set_train(False)

    if hparam.upsample_conditional_features:
        length = (c_input.shape[-1] - hparam.cin_pad * 2) * audio.get_hop_size()
    else:
        # already duplicated
        length = c_input.shape[-1]

    y_hat = net.incremental_forward(c=c_input, g=g_input, T=length, tqdm=tqdm_, softmax=True, quantize=True,
                                    log_scale_min=hparam.log_scale_min, is_numpy=is_numpy)

    if is_mulaw_quantize(hparam.input_type):
        # needs to be float since the mu-law inverse returns values in [-1, 1]
        y_hat = np.reshape(np.argmax(y_hat, 1), (B, -1))
        y_hat = y_hat.astype(np.float32)
        for k in range(B):
            y_hat[k] = P.inv_mulaw_quantize(y_hat[k], hparam.quantize_channels - 1)
    elif is_mulaw(hparam.input_type):
        y_hat = np.reshape(y_hat, (B, -1))
        for k in range(B):
            y_hat[k] = P.inv_mulaw(y_hat[k], hparam.quantize_channels - 1)
    else:
        y_hat = np.reshape(y_hat, (B, -1))

    if hparam.postprocess is not None and hparam.postprocess not in ["", "none"]:
        for k in range(B):
            y_hat[k] = getattr(audio, hparam.postprocess)(y_hat[k])

    if hparam.global_gain_scale > 0:
        for k in range(B):
            y_hat[k] /= hparam.global_gain_scale

    return y_hat


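# For reference, the mu-law inversion above mirrors nnmnkwii's companding API
# (mu = quantize_channels - 1, e.g. 255 for 256-way quantization):
#   q = P.mulaw_quantize(x, 255)          # float in [-1, 1] -> int in [0, 255]
#   x_hat = P.inv_mulaw_quantize(q, 255)  # int in [0, 255] -> float in [-1, 1]
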
def to_int16(x_):
    """
    Convert a float waveform in [-1, 1] to int16 PCM.
    """
    if x_.dtype == np.int16:
        return x_
    assert x_.dtype == np.float32
    assert x_.min() >= -1 and x_.max() <= 1.0
    return (x_ * 32767).astype(np.int16)


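# e.g. to_int16(np.array([0.0, 0.5, -1.0], dtype=np.float32))
#   -> array([0, 16383, -32767], dtype=int16)
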
def get_reference_file(hparam, dataset_source, idx):
    """
    Collect the reference audio and feature file paths for one batch.
    """
    reference_files = []
    reference_feats = []
    for _ in range(hparam.batch_size):
        if hasattr(dataset_source, "X"):
            reference_files.append(dataset_source.X.collected_files[idx][0])
        if hasattr(dataset_source, "Mel"):
            reference_feats.append(dataset_source.Mel.collected_files[idx][0])
        else:
            reference_feats.append(dataset_source.collected_files[idx][0])
        idx += 1
    return reference_files, reference_feats, idx


def get_saved_audio_name(has_ref_file_, ref_file, ref_feat, g_fp):
    """Get the paths used to save the generated and reference audio."""
    if has_ref_file_:
        target_audio_path = ref_file
        name = os.path.splitext(os.path.basename(target_audio_path))[0].replace("-wave", "")
    else:
        target_feat_path = ref_feat
        name = os.path.splitext(os.path.basename(target_feat_path))[0].replace("-feats", "")
    # Paths
    if g_fp is None:
        dst_wav_path_ = join(args.output_path, "{}_gen.wav".format(name))
        target_wav_path_ = join(args.output_path, "{}_ref.wav".format(name))
    else:
        dst_wav_path_ = join(args.output_path, "speaker{}_{}_gen.wav".format(g_fp, name))
        target_wav_path_ = join(args.output_path, "speaker{}_{}_ref.wav".format(g_fp, name))
    return dst_wav_path_, target_wav_path_


def save_ref_audio(hparam, ref, length, target_wav_path_):
    """
    save reference audio
    """
    if is_mulaw_quantize(hparam.input_type):
        ref = np.reshape(np.argmax(ref, 0), (-1))[:length]
        ref = ref.astype(np.float32)
    else:
        ref = np.reshape(ref, (-1))[:length]

    if is_mulaw_quantize(hparam.input_type):
        ref = P.inv_mulaw_quantize(ref, hparam.quantize_channels - 1)
    elif is_mulaw(hparam.input_type):
        ref = P.inv_mulaw(ref, hparam.quantize_channels - 1)
    if hparam.postprocess is not None and hparam.postprocess not in ["", "none"]:
        ref = getattr(audio, hparam.postprocess)(ref)
    if hparam.global_gain_scale > 0:
        ref /= hparam.global_gain_scale

    ref = np.clip(ref, -1.0, 1.0)

    wavfile.write(target_wav_path_, hparam.sample_rate, to_int16(ref))


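# A minimal sketch of the preset json parsed below; the keys are those consumed
# in this file, and the values are illustrative only:
#   {"name": "wavenet_vocoder", "input_type": "mulaw-quantize",
#    "quantize_channels": 256, "sample_rate": 22050, "cin_channels": 80}
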
if __name__ == '__main__':
    context.set_context(mode=context.GRAPH_MODE, device_target=args.platform, save_graphs=False)
    speaker_id = int(args.speaker_id) if args.speaker_id != '' else None
    if args.preset is not None:
        with open(args.preset) as f:
            hparams.parse_json(f.read())

    assert hparams.name == "wavenet_vocoder"
    print(hparams_debug_string())

    fs = hparams.sample_rate
    hparams.batch_size = 10
    hparams.max_time_sec = None
    hparams.max_time_steps = None
    data_loaders, source_dataset = get_data_loader(hparam=hparams, data_dir=args.data_path)

    upsample_params = hparams.upsample_params
    upsample_params["cin_channels"] = hparams.cin_channels
    upsample_params["cin_pad"] = hparams.cin_pad
    model = WaveNet(
        out_channels=hparams.out_channels,
        layers=hparams.layers,
        stacks=hparams.stacks,
        residual_channels=hparams.residual_channels,
        gate_channels=hparams.gate_channels,
        skip_out_channels=hparams.skip_out_channels,
        cin_channels=hparams.cin_channels,
        gin_channels=hparams.gin_channels,
        n_speakers=hparams.n_speakers,
        dropout=hparams.dropout,
        kernel_size=hparams.kernel_size,
        cin_pad=hparams.cin_pad,
        upsample_conditional_features=hparams.upsample_conditional_features,
        upsample_params=upsample_params,
        scalar_input=is_scalar_input(hparams.input_type),
        output_distribution=hparams.output_distribution,
    )

    param_dict = load_checkpoint(args.pretrain_ckpt)
    load_param_into_net(model, param_dict)
    print('Successfully loaded the pre-trained model.')

    os.makedirs(args.output_path, exist_ok=True)
    cin_pad = hparams.cin_pad

    file_idx = 0
    for data in data_loaders.create_dict_iterator():
        x, y, c, g, input_lengths = (data['x_batch'], data['y_batch'], data['c_batch'],
                                     data['g_batch'], data['input_lengths'])
        if cin_pad > 0:
            # Pad only the time axis of the (batch, channels, time) conditional features.
            c = c.asnumpy()
            c = np.pad(c, pad_width=[(0, 0), (0, 0), (cin_pad, cin_pad)], mode="edge")
            c = Tensor(c)

        ref_files, ref_feats, file_idx = get_reference_file(hparams, source_dataset, file_idx)
        # Generate from the (possibly padded) conditional features.
        y_hats = batch_wavegen(hparams, model, c_input=c, tqdm_=tqdm, is_numpy=args.is_numpy)
        x = x.asnumpy()
        input_lengths = input_lengths.asnumpy()
        # Save each utterance.
        has_ref_file = bool(ref_files)
        for i, (ref_, gen_, length_) in enumerate(zip(x, y_hats, input_lengths)):
            dst_wav_path, target_wav_path = get_saved_audio_name(
                has_ref_file_=has_ref_file,
                ref_file=ref_files[i] if has_ref_file else None,
                ref_feat=ref_feats[i], g_fp=g)
            save_ref_audio(hparams, ref_, length_, target_wav_path)

            gen = gen_[:length_]
            gen = np.clip(gen, -1.0, 1.0)
            wavfile.write(dst_wav_path, hparams.sample_rate, to_int16(gen))