mindspore/model_zoo/research/audio/wavenet/wavenet_vocoder/wavenet.py

# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""WaveNet construction"""
from __future__ import with_statement, print_function, absolute_import
import math
import numpy as np
from mindspore import nn, Tensor
import mindspore.common.dtype as mstype
from mindspore.ops import operations as P
from wavenet_vocoder import upsample
from .modules import Embedding
from .modules import Conv1d1x1
from .modules import ResidualConv1dGLU
from .mixture import sample_from_discretized_mix_logistic
from .mixture import sample_from_mix_gaussian
from .mixture import sample_from_mix_onehotcategorical


class WaveNet(nn.Cell):
"""
WaveNet model definition. Only local condition is supported
Args:
out_channels (int): Output channels. If input_type is mu-law quantized one-hot vecror, it should equal to the
quantize channels. Otherwise, it equals to num_mixtures x 3. Default: 256.
layers (int): Number of ResidualConv1dGLU layers
stacks (int): Number of dilation cycles
residual_channels (int): Residual input / output channels
gate_channels (int): Gated activation channels.
skip_out_channels (int): Skip connection channels.
kernel_size (int): Kernel size .
dropout (float): Dropout rate.
cin_channels (int): Local conditioning channels. If given negative value, local conditioning is disabled.
gin_channels (int): Global conditioning channels. If given negative value, global conditioning is disabled.
n_speakers (int): Number of speakers. This is used when global conditioning is enabled.
upsample_conditional_features (bool): Whether upsampling local conditioning features by resize_nearestneighbor
and conv or not.
scalar_input (Bool): If True, scalar input ([-1, 1]) is expected, otherwise, quantized one-hot vector
is expected.
use_speaker_embedding (Bool): Use speaker embedding or Not.
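
    Example:
        A minimal construction sketch (the values below are illustrative assumptions,
        not the settings shipped with this repository's scripts):

        >>> net = WaveNet(out_channels=256, layers=20, stacks=2, kernel_size=3,
        ...               cin_channels=80)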
"""
def __init__(self, out_channels=256, layers=20, stacks=2,
residual_channels=512,
gate_channels=512,
skip_out_channels=512,
kernel_size=3, dropout=1 - 0.95,
cin_channels=-1, gin_channels=-1, n_speakers=None,
upsample_conditional_features=False,
upsample_net="ConvInUpsampleNetwork",
upsample_params=None,
scalar_input=False,
use_speaker_embedding=False,
output_distribution="Logistic",
cin_pad=0,
):
super(WaveNet, self).__init__()
self.transpose_op = P.Transpose()
self.softmax = P.Softmax(axis=1)
self.reshape_op = P.Reshape()
self.zeros_op = P.Zeros()
self.ones_op = P.Ones()
self.relu_op = P.ReLU()
self.squeeze_op = P.Squeeze()
self.expandim_op = P.ExpandDims()
self.tile_op = P.Tile()
self.scalar_input = scalar_input
self.out_channels = out_channels
self.cin_channels = cin_channels
self.output_distribution = output_distribution
self.fack_data = P.Zeros()
assert layers % stacks == 0
layers_per_stack = layers // stacks
if scalar_input:
self.first_conv = Conv1d1x1(1, residual_channels)
else:
self.first_conv = Conv1d1x1(out_channels, residual_channels)
conv_layers = []
for layer in range(layers):
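            # Dilations cycle within each stack: 1, 2, 4, ..., 2**(layers_per_stack - 1).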
dilation = 2 ** (layer % layers_per_stack)
conv = ResidualConv1dGLU(
residual_channels, gate_channels,
kernel_size=kernel_size,
skip_out_channels=skip_out_channels,
bias=True,
dropout=dropout,
dilation=dilation,
cin_channels=cin_channels,
gin_channels=gin_channels)
conv_layers.append(conv)
self.conv_layers = nn.CellList(conv_layers)
self.last_conv_layers = nn.CellList([
nn.ReLU(),
Conv1d1x1(skip_out_channels, skip_out_channels),
nn.ReLU(),
Conv1d1x1(skip_out_channels, out_channels)])
if gin_channels > 0 and use_speaker_embedding:
assert n_speakers is not None
self.embed_speakers = Embedding(
n_speakers, gin_channels, padding_idx=None, std=0.1)
else:
self.embed_speakers = None
if upsample_conditional_features:
self.upsample_net = getattr(upsample, upsample_net)(**upsample_params)
else:
self.upsample_net = None
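        # Scale the summed skip connections by sqrt(1/N) so their magnitude does not grow with depth.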
self.factor = math.sqrt(1.0 / len(self.conv_layers))

    def _expand_global_features(self, batch_size, time_step, g_fp, is_expand=True):
"""Expand global conditioning features to all time steps
Args:
batch_size (int): Batch size.
time_step (int): Time length.
g_fp (Tensor): Global features, (B x C) or (B x C x 1).
            is_expand (bool): If True, return features in B x C x T layout; otherwise return B x T x C.
Returns:
Tensor: B x C x T or B x T x C or None
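
        Example (shape sketch, assuming g_fp comes in with a singleton batch dimension):
            a g_fp of shape (1, 16) with batch_size=2 and time_step=100 expands to
            (2, 16, 100) when is_expand is True, and to (2, 100, 16) otherwise.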
"""
if g_fp is None:
return None
        if len(g_fp.shape) == 2:
            g_fp = self.expandim_op(g_fp, -1)
if is_expand:
expand_fp = self.tile_op(g_fp, (batch_size, 1, time_step))
return expand_fp
expand_fp = self.tile_op(g_fp, (batch_size, 1, time_step))
expand_fp = self.transpose_op(expand_fp, (0, 2, 1))
return expand_fp

    def construct(self, x, c=None, g=None, softmax=False):
"""
Args:
            x (Tensor): One-hot encoded audio signal, shape B x C x T.
            c (Tensor): Local conditioning feature.
            g (Tensor): Global conditioning feature (currently ignored, see note in the body).
            softmax (bool): Whether to apply softmax to the network output.
Returns:
Tensor: Net output
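
        Example (illustrative shapes only):
            >>> # x: (B, 256, T) one-hot audio, c: (B, cin_channels, T) mel features
            >>> # logits = net(x, c=c, softmax=False)   # -> (B, 256, T)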
"""
g = None
B, _, T = x.shape
if g is not None:
if self.embed_speakers is not None:
g = self.embed_speakers(self.reshape_op(g, (B, -1)))
g = self.transpose_op(g, (0, 2, 1))
g_bct = self._expand_global_features(B, T, g, is_expand=True)
if c is not None and self.upsample_net is not None:
c = self.upsample_net(c)
x = self.first_conv(x)
skips = 0
for f in self.conv_layers:
x, h = f(x, c, g_bct)
skips += h
skips *= self.factor
x = skips
for f in self.last_conv_layers:
x = f(x)
x = self.softmax(x) if softmax else x
return x

    def relu_numpy(self, inX):
"""numpy relu function"""
return np.maximum(0, inX)

    def softmax_numpy(self, x):
        """Numpy softmax function."""
        x = x - np.max(x, axis=1, keepdims=True)  # avoid mutating the caller's array in place
        return np.exp(x) / np.sum(np.exp(x), axis=1, keepdims=True)

    def incremental_forward(self, initial_input=None, c=None, g=None,
T=100, test_inputs=None,
tqdm=lambda x: x, softmax=True, quantize=True,
log_scale_min=-50.0, is_numpy=True):
"""
Incremental forward. Current output depends on last output.
Args:
initial_input (Tensor): Initial input, the shape is B x C x 1
c (Tensor): Local conditioning feature, the shape is B x C x T
g (Tensor): Global conditioning feature, the shape is B x C or B x C x 1
            T (int): Number of time steps to decode.
            test_inputs (Tensor): Teacher-forcing inputs (for debugging).
            tqdm (lambda): Progress-bar wrapper applied to the decoding loop.
            softmax (bool): Whether to apply softmax to each step's output.
            quantize (bool): Whether to sample a quantized one-hot category from the previous step's softmax
                output when decoding the current step.
            log_scale_min (float): Log scale minimum value.
            is_numpy (bool): If True, run the incremental decoding with the NumPy implementations of the layers
                instead of MindSpore ops.
        Returns:
            Tensor: Predicted one-hot encoded samples or scalar vector depending on loss type.
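
        Example (a minimal synthesis sketch; shapes are illustrative assumptions):
            >>> # c: (1, cin_channels, T_mel) local conditioning features
            >>> # y = net.incremental_forward(c=c, T=num_samples, is_numpy=True)
            >>> # y: (1, out_channels, num_samples) array of sampled outputs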
"""
self.clear_buffer()
B = 1
if test_inputs is not None:
if self.scalar_input:
if test_inputs.shape[1] == 1:
test_inputs = self.transpose_op(test_inputs, (0, 2, 1))
else:
if test_inputs.shape[1] == self.out_channels:
test_inputs = self.transpose_op(test_inputs, (0, 2, 1))
B = test_inputs.shape[0]
if T is None:
T = test_inputs.shape[1]
else:
T = max(T, test_inputs.shape[1])
T = int(T)
# Global conditioning
if g is not None:
if self.embed_speakers is not None:
g = self.embed_speakers(self.reshape_op(g, (B, -1)))
g = self.transpose_op(g, (0, 2, 1))
                assert len(g.shape) == 3
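        # B x T x C layout, so each decoding step can slice g_btc[:, t, :] directly.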
g_btc = self._expand_global_features(B, T, g, is_expand=False)
# Local conditioning
if c is not None:
B = c.shape[0]
if self.upsample_net is not None:
c = self.upsample_net(c)
            assert c.shape[-1] == T
            # B x C x T -> B x T x C so each step can slice c[:, t, :]
            c = self.transpose_op(c, (0, 2, 1))
outputs = []
if initial_input is None:
if self.scalar_input:
initial_input = self.zeros_op((B, 1, 1), mstype.float32)
else:
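                # Seed decoding with a one-hot at category 127, the mu-law zero-amplitude center
                # (assumes 256 quantize channels).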
initial_input = np.zeros((B, 1, self.out_channels), np.float32)
initial_input[:, :, 127] = 1
initial_input = Tensor(initial_input)
else:
if initial_input.shape[1] == self.out_channels:
initial_input = self.transpose_op(initial_input, (0, 2, 1))
if is_numpy:
current_input = initial_input.asnumpy()
else:
current_input = initial_input
for t in tqdm(range(T)):
if test_inputs is not None and t < test_inputs.shape[1]:
current_input = self.expandim_op(test_inputs[:, t, :], 1)
else:
if t > 0:
if not is_numpy:
current_input = Tensor(outputs[-1])
else:
current_input = outputs[-1]
# Conditioning features for single time step
ct = None if c is None else self.expandim_op(c[:, t, :], 1)
gt = None if g is None else self.expandim_op(g_btc[:, t, :], 1)
x = current_input
            if is_numpy and ct is not None:
                ct = ct.asnumpy()
x = self.first_conv.incremental_forward(x, is_numpy=is_numpy)
skips = 0
for f in self.conv_layers:
x, h = f.incremental_forward(x, ct, gt, is_numpy=is_numpy)
skips += h
skips *= self.factor
x = skips
for f in self.last_conv_layers:
try:
x = f.incremental_forward(x, is_numpy=is_numpy)
except AttributeError:
if is_numpy:
x = self.relu_numpy(x)
else:
x = self.relu_op(x)
# Generate next input by sampling
if not is_numpy:
x = x.asnumpy()
if self.scalar_input:
if self.output_distribution == "Logistic":
x = sample_from_discretized_mix_logistic(x.reshape((B, -1, 1)), log_scale_min=log_scale_min)
elif self.output_distribution == "Normal":
x = sample_from_mix_gaussian(x.reshape((B, -1, 1)), log_scale_min=log_scale_min)
else:
assert False
else:
x = self.softmax_numpy(np.reshape(x, (B, -1))) if softmax else np.reshape(x, (B, -1))
if quantize:
x = sample_from_mix_onehotcategorical(x)
outputs += [x]
# T x B x C
outputs = np.stack(outputs, 0)
# B x C x T
outputs = np.transpose(outputs, (1, 2, 0))
self.clear_buffer()
return outputs

    def clear_buffer(self):
"""clear buffer"""
self.first_conv.clear_buffer()
for f in self.conv_layers:
f.clear_buffer()
for f in self.last_conv_layers:
try:
f.clear_buffer()
except AttributeError:
pass