# mindspore/model_zoo/research/audio/wavenet/wavenet_vocoder/mixture.py
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""
loss function for training and sample function for testing
"""
import numpy as np
import mindspore as ms
from mindspore import Tensor
import mindspore.nn as nn
import mindspore.ops as P
from mindspore import context
class log_sum_exp(nn.Cell):
    """Numerically stable log-sum-exp over the last axis.

    Computes log(sum(exp(x), axis=-1)) by subtracting the per-row maximum
    before exponentiating, so large inputs do not overflow in exp.
    """

    def __init__(self):
        super(log_sum_exp, self).__init__()
        self.reduce_max = P.ReduceMax()
        self.reduce_max_keep = P.ReduceMax(keep_dims=True)
        self.reduce_sum = P.ReduceSum()
        self.log = P.Log()
        self.exp = P.Exp()

    def construct(self, x):
        last_axis = len(x.shape) - 1
        # Max without kept dims for the final add; with kept dims so the
        # subtraction below broadcasts against x.
        row_max = self.reduce_max(x, last_axis)
        row_max_keep = self.reduce_max_keep(x, last_axis)
        shifted = self.exp(x - row_max_keep)
        return row_max + self.log(self.reduce_sum(shifted, last_axis))
class log_softmax(nn.Cell):
    """
    replacement of P.LogSoftmax(-1) in CPU mode
    only supports inputs of rank 2 or 3
    """

    def __init__(self):
        super(log_softmax, self).__init__()
        self.maxi = P.ReduceMax()
        self.log = P.Log()
        self.sums = P.ReduceSum()
        self.exp = P.Exp()
        self.axis = -1
        self.concat = P.Concat(-1)
        self.expanddims = P.ExpandDims()

    def construct(self, x):
        """
        Args:
            x (Tensor): input, rank 2 or 3

        Returns:
            Tensor: log_softmax of input over the last axis (None when the
            rank is neither 2 nor 3)
        """
        # Per-row max, subtracted before exp for numerical stability.
        c = self.maxi(x, self.axis)
        logs, lsm = None, None
        if len(x.shape) == 2:
            # exp(x - c), rebuilt one last-axis slice at a time via Concat
            # (NOTE(review): presumably written slice-wise because the fused
            # ops lack CPU kernels in this MindSpore version — confirm).
            for j in range(x.shape[-1]):
                temp = self.expanddims(self.exp(x[:, j] - c), -1)
                logs = temp if j == 0 else self.concat((logs, temp))
            # sum_j exp(x_j - c)
            sums = self.sums(logs, -1)
            # log_softmax = (x - c) - log(sum_j exp(x_j - c))
            for i in range(x.shape[-1]):
                temp = self.expanddims(x[:, i] - c - self.log(sums), -1)
                lsm = temp if i == 0 else self.concat((lsm, temp))
            return lsm
        if len(x.shape) == 3:
            # Same slice-wise computation for rank-3 input.
            for j in range(x.shape[-1]):
                temp = self.expanddims(self.exp(x[:, :, j] - c), -1)
                logs = temp if j == 0 else self.concat((logs, temp))
            sums = self.sums(logs, -1)
            for i in range(x.shape[-1]):
                temp = self.expanddims(x[:, :, i] - c - self.log(sums), -1)
                lsm = temp if i == 0 else self.concat((lsm, temp))
            return lsm
        return None
class Stable_softplus(nn.Cell):
    """Numerically stable softplus: log(1 + exp(x)).

    Uses the identity softplus(x) = log(1 + exp(-|x|)) + relu(x), so the
    exponential never receives a large positive argument.
    """

    def __init__(self):
        super(Stable_softplus, self).__init__()
        self.log_op = P.Log()
        self.abs_op = P.Abs()
        self.relu_op = P.ReLU()
        self.exp_op = P.Exp()

    def construct(self, x):
        neg_abs = -self.abs_op(x)
        small_term = self.log_op(1 + self.exp_op(neg_abs))
        return small_term + self.relu_op(x)
class discretized_mix_logistic_loss(nn.Cell):
    """
    Discretized_mix_logistic_loss (PixelCNN++ style): negative log-likelihood
    of targets under a mixture of logistic distributions whose continuous CDF
    is discretized into `num_classes` bins.

    Args:
        num_classes (int): Num_classes (number of discretized output bins)
        log_scale_min (float): Log scale minimum value (lower clamp on
            predicted log-scales)
        reduce (bool): If True return the summed scalar loss, otherwise the
            per-position loss with a trailing singleton axis.
    """

    def __init__(self, num_classes=256, log_scale_min=-7.0, reduce=True):
        super(discretized_mix_logistic_loss, self).__init__()
        self.num_classes = num_classes
        self.log_scale_min = log_scale_min
        self.reduce = reduce
        self.transpose_op = P.Transpose()
        self.exp = P.Exp()
        self.sigmoid = P.Sigmoid()
        self.softplus = Stable_softplus()
        self.log = P.Log()
        self.cast = P.Cast()
        self.expand_dims = P.ExpandDims()
        self.tile = P.Tile()
        self.maximum = P.Maximum()
        self.sums = P.ReduceSum()
        self.lse = log_sum_exp()
        self.reshape = P.Reshape()
        # log((num_classes - 1) / 2): converts the continuous log-density into
        # a discrete bin log-probability in the underflow fallback below.
        self.factor = self.log(Tensor((self.num_classes - 1) / 2, ms.float32))
        self.tensor_one = Tensor(1., ms.float32)
        # On CPU fall back to the hand-rolled log_softmax cell instead of the
        # P.LogSoftmax primitive.
        if context.get_context("device_target") == "CPU":
            self.logsoftmax = log_softmax()
        else:
            self.logsoftmax = P.LogSoftmax(-1)

    def construct(self, y_hat, y):
        """
        Args:
            y_hat (Tensor): Predicted distribution parameters, shape
                (B, C, T) with C = 3 * num_mixtures laid out as
                [logit_probs, means, log_scales].
            y (Tensor): Target, shape (B, T, 1); edge handling below assumes
                values in [-1, 1].

        Returns:
            Tensor: Discretized_mix_logistic_loss
        """
        nr_mix = y_hat.shape[1] // 3
        # (B x T x C)
        y_hat = self.transpose_op(y_hat, (0, 2, 1))
        # (B, T, num_mixtures) x 3
        logit_probs = y_hat[:, :, :nr_mix]
        means = y_hat[:, :, nr_mix:2 * nr_mix]
        # Clamp log-scales from below so exp(-log_scales) cannot overflow.
        min_cut = self.log_scale_min * self.tile(self.tensor_one, (y_hat.shape[0], y_hat.shape[1], nr_mix))
        log_scales = self.maximum(y_hat[:, :, 2 * nr_mix:3 * nr_mix], min_cut)
        # B x T x 1 -> B x T x num_mixtures
        y = self.tile(y, (1, 1, nr_mix))
        centered_y = y - means
        inv_stdv = self.exp(-log_scales)
        # CDF at the upper and lower edges of the bin around y
        # (bin half-width is 1 / (num_classes - 1)).
        plus_in = inv_stdv * (centered_y + 1. / (self.num_classes - 1))
        cdf_plus = self.sigmoid(plus_in)
        min_in = inv_stdv * (centered_y - 1. / (self.num_classes - 1))
        cdf_min = self.sigmoid(min_in)
        # log CDF at the upper edge — used for the leftmost bin (y near -1).
        log_cdf_plus = plus_in - self.softplus(plus_in)
        # log(1 - CDF) at the lower edge — used for the rightmost bin (y near 1).
        log_one_minus_cdf_min = -self.softplus(min_in)
        # Probability mass falling inside the bin.
        cdf_delta = cdf_plus - cdf_min
        # Log-density at the bin centre: fallback when cdf_delta underflows.
        mid_in = inv_stdv * centered_y
        log_pdf_mid = mid_in - log_scales - 2. * self.softplus(mid_in)
        # Branchless select: log(cdf_delta) where numerically safe (> 1e-5),
        # otherwise approximate with the centre density minus log((K-1)/2).
        inner_inner_cond = self.cast(cdf_delta > 1e-5, ms.float32)
        min_cut2 = 1e-12 * self.tile(self.tensor_one, cdf_delta.shape)
        inner_inner_out = inner_inner_cond * \
            self.log(self.maximum(cdf_delta, min_cut2)) + \
            (1. - inner_inner_cond) * (log_pdf_mid - self.factor)
        # Edge bins: y near +1 uses the right tail, y near -1 the left tail.
        inner_cond = self.cast(y > 0.999, ms.float32)
        inner_out = inner_cond * log_one_minus_cdf_min + (1. - inner_cond) * inner_inner_out
        cond = self.cast(y < -0.999, ms.float32)
        log_probs = cond * log_cdf_plus + (1. - cond) * inner_out
        # Add mixture log-weights; flatten to 2-D for the CPU log_softmax cell.
        a, b, c = logit_probs.shape[0], logit_probs.shape[1], logit_probs.shape[2]
        logit_probs = self.logsoftmax(self.reshape(logit_probs, (-1, c)))
        logit_probs = self.reshape(logit_probs, (a, b, c))
        log_probs = log_probs + logit_probs
        if self.reduce:
            return -self.sums(self.lse(log_probs))
        return self.expand_dims(-self.lse(log_probs), -1)
def sample_from_discretized_mix_logistic(y, log_scale_min=-7.0):
"""
Sample from discretized mixture of logistic distributions
Args:
y (ndarray): B x C x T
log_scale_min (float): Log scale minimum value
Returns:
ndarray
"""
nr_mix = y.shape[1] // 3
# B x T x C
y = np.transpose(y, (0, 2, 1))
logit_probs = y[:, :, :nr_mix]
temp = np.random.uniform(1e-5, 1.0 - 1e-5, logit_probs.shape)
temp = logit_probs - np.log(- np.log(temp))
argmax = np.argmax(temp, axis=-1)
# (B, T) -> (B, T, nr_mix)
one_hot = np.eye(nr_mix)[argmax]
means = np.sum(y[:, :, nr_mix:2 * nr_mix] * one_hot, axis=-1)
log_scales = np.clip(np.sum(
y[:, :, 2 * nr_mix:3 * nr_mix] * one_hot, axis=-1), a_min=log_scale_min, a_max=None)
u = np.random.uniform(1e-5, 1.0 - 1e-5, means.shape)
x = means + np.exp(log_scales) * (np.log(u) - np.log(1. - u))
x = np.clip(x, -1., 1.)
return x.astype(np.float32)
class mix_gaussian_loss(nn.Cell):
    """
    Mix gaussian loss: negative log-likelihood of the target under a single
    Gaussian (C == 2) or a mixture of Gaussians (C == 3 * num_mixtures).

    Args:
        log_scale_min (float): lower clamp applied to predicted log-scales.
        reduce (bool): If True return the summed scalar loss, otherwise the
            per-element negative log-likelihood.
    """

    def __init__(self, log_scale_min=-7.0, reduce=True):
        super(mix_gaussian_loss, self).__init__()
        self.log_scale_min = log_scale_min
        self.reduce = reduce
        self.transpose_op = P.Transpose()
        self.maximum = P.Maximum()
        self.tile = P.Tile()
        self.exp = P.Exp()
        self.expand_dims = P.ExpandDims()
        self.sums = P.ReduceSum()
        self.lse = log_sum_exp()
        self.sq = P.Square()
        self.sqrt = P.Sqrt()
        self.const = P.ScalarToArray()
        self.log = P.Log()
        self.tensor_one = Tensor(1., ms.float32)
        # On CPU fall back to the hand-rolled log_softmax cell instead of the
        # P.LogSoftmax primitive.
        if context.get_context("device_target") == "CPU":
            self.logsoftmax = log_softmax()
        else:
            self.logsoftmax = P.LogSoftmax(-1)

    def construct(self, y_hat, y):
        """
        Args:
            y_hat (Tensor): Predicted distribution parameters, shape (B, C, T).
                C == 2 means a single Gaussian [mean, log_scale]; otherwise
                C = 3 * num_mixtures laid out as
                [logit_probs, means, log_scales].
            y (Tensor): Target, shape (B, T, 1).

        Returns:
            Tensor: Mix_gaussian_loss
        """
        C = y_hat.shape[1]
        if C == 2:
            nr_mix = 1
        else:
            nr_mix = y_hat.shape[1] // 3
        # (B x T x C)
        y_hat = self.transpose_op(y_hat, (0, 2, 1))
        if C == 2:
            # Single-Gaussian case: no mixture weights.
            logit_probs = None
            means = y_hat[:, :, 0:1]
            # Clamp log-scales from below for numerical stability.
            min_cut = self.log_scale_min * self.tile(self.tensor_one, (y_hat.shape[0], y_hat.shape[1], 1))
            log_scales = self.maximum(y_hat[:, :, 1:2], min_cut)
        else:
            # (B, T, num_mixtures) x 3
            logit_probs = y_hat[:, :, :nr_mix]
            means = y_hat[:, :, nr_mix:2 * nr_mix]
            min_cut = self.log_scale_min * self.tile(self.tensor_one, (y_hat.shape[0], y_hat.shape[1], nr_mix))
            log_scales = self.maximum(y_hat[:, :, 2 * nr_mix:3 * nr_mix], min_cut)
        # B x T x 1 -> B x T x num_mixtures
        y = self.tile(y, (1, 1, nr_mix))
        centered_y = y - means
        sd = self.exp(log_scales)
        # Gaussian log-density: -(y - mu)^2 / (2 sd^2) - log(sqrt(2 pi)) - log(sd)
        unnormalized_log_prob = -1. * (self.sq(centered_y - 0.)) / (2. * self.sq(sd))
        neg_normalization = -1. * self.log(self.const(2. * np.pi)) / 2. - self.log(sd)
        log_probs = unnormalized_log_prob + neg_normalization
        if nr_mix > 1:
            # Add mixture log-weights before the log-sum-exp over components.
            log_probs = log_probs + self.logsoftmax(logit_probs)
        if self.reduce:
            if nr_mix == 1:
                return -self.sums(log_probs)
            return -self.sums(self.lse(log_probs))
        if nr_mix == 1:
            return -log_probs
        return self.expand_dims(-self.lse(log_probs), -1)
def sample_from_mix_gaussian(y, log_scale_min=-7.0):
    """
    Sample_from_mix_gaussian: draw one sample per (batch, time) position from
    a single Gaussian (C == 2) or a mixture of Gaussians.

    Args:
        y (ndarray): B x C x T network output. C == 2 means a single Gaussian
            [mean, log_scale]; otherwise C = 3 * num_mixtures laid out as
            [mixture logits, means, log_scales].
        log_scale_min (float): lower clamp applied to the selected log-scales
            (mirrors the `min_cut` clamp in mix_gaussian_loss).

    Returns:
        ndarray: float32 samples of shape B x T, clipped to [-1, 1]
    """
    C = y.shape[1]
    if C == 2:
        nr_mix = 1
    else:
        nr_mix = y.shape[1] // 3
    # B x C x T -> B x T x C
    y = np.transpose(y, (0, 2, 1))
    if C == 2:
        logit_probs = None
    else:
        logit_probs = y[:, :, :nr_mix]

    if nr_mix > 1:
        # Gumbel-max trick: pick one mixture component per position.
        temp = np.random.uniform(1e-5, 1.0 - 1e-5, logit_probs.shape)
        temp = logit_probs - np.log(- np.log(temp))
        argmax = np.argmax(temp, axis=-1)
        # (B, T) -> (B, T, nr_mix)
        one_hot = np.eye(nr_mix)[argmax]
        means = np.sum(y[:, :, nr_mix:2 * nr_mix] * one_hot, axis=-1)
        log_scales = np.sum(y[:, :, 2 * nr_mix:3 * nr_mix] * one_hot, axis=-1)
    else:
        if C == 2:
            means, log_scales = y[:, :, 0], y[:, :, 1]
        elif C == 3:
            means, log_scales = y[:, :, 1], y[:, :, 2]
        else:
            assert False, "shouldn't happen"
    # Bug fix: log_scale_min was accepted but never used. Clamp the selected
    # log-scales from below, consistent with
    # sample_from_discretized_mix_logistic and the loss-side min_cut.
    log_scales = np.clip(log_scales, a_min=log_scale_min, a_max=None)
    scales = np.exp(log_scales)
    x = np.random.normal(loc=means, scale=scales)
    x = np.clip(x, -1., 1.)
    return x.astype(np.float32)
# self-implemented onehotcategorical distribution
# https://zhuanlan.zhihu.com/p/59550457
def sample_from_mix_onehotcategorical(x):
    """
    Sample_from_mix_onehotcategorical: draw a one-hot sample from per-row
    categorical distributions using the Gumbel-max trick.

    Args:
        x (ndarray): Predicted softmax probability with classes on axis 1
            (e.g. B x num_classes); entries must be > 0 since log is taken.

    Returns:
        ndarray: float32 one-hot samples with the same class count as x
    """
    # Generalized from the hard-coded 256 classes; unchanged behavior for
    # the original 256-class inputs.
    num_classes = x.shape[1]
    pi = np.log(x)
    # Gumbel noise; bounds keep -log(-log(u)) finite (u == 0 or 1 would give
    # inf/nan), consistent with the other samplers in this file.
    u = np.random.uniform(1e-5, 1.0 - 1e-5, x.shape)
    g = -np.log(-np.log(u))
    c = np.argmax(pi + g, axis=1)
    return np.array(np.eye(num_classes)[c], dtype=np.float32)