!11859 Added WaveNet model scripts in CPU mode.

From: @huangbo77
Reviewed-by: 
Signed-off-by:
pull/11859/MERGE
mindspore-ci-bot 4 years ago committed by Gitee
commit 894d93b4fd

@@ -77,6 +77,7 @@ In order to facilitate developers to enjoy the benefits of MindSpore framework,
 - [Audio](https://gitee.com/mindspore/mindspore/tree/master/model_zoo/research/audio)
 - [FCN-4](https://gitee.com/mindspore/mindspore/tree/master/model_zoo/research/audio/fcn-4/README.md)
 - [DeepSpeech2](https://gitee.com/mindspore/mindspore/tree/master/model_zoo/research/audio/deepspeech2/README.md)
+- [Wavenet](https://gitee.com/mindspore/mindspore/tree/master/model_zoo/research/audio/wavenet/README.md)
 - [High Performance Computing](https://gitee.com/mindspore/mindspore/tree/master/model_zoo/research/hpc)
 - [GOMO](https://gitee.com/mindspore/mindspore/tree/master/model_zoo/research/hpc/ocean_model/README.md)
 - [Molecular_Dynamics](https://gitee.com/mindspore/mindspore/tree/master/model_zoo/research/hpc/molecular_dynamics/README.md)

@@ -18,7 +18,7 @@
 # [WaveNet Description](#contents)
-WaveNet is a deep neural network for generating raw audio waveforms. The model is fully probabilistic and autoregressive, with the predictive distribution for each audio sample conditioned on all previous ones. We support training and evaluation on GPU.
+WaveNet is a deep neural network for generating raw audio waveforms. The model is fully probabilistic and autoregressive, with the predictive distribution for each audio sample conditioned on all previous ones. We support training and evaluation on both GPU and CPU.

 [Paper](https://arxiv.org/pdf/1609.03499.pdf): van den Oord A, Dieleman S, Zen H, et al. WaveNet: A generative model for raw audio
@@ -47,8 +47,8 @@ Dataset used: [The LJ Speech Dataset](<https://keithito.com/LJ-Speech-Dataset>)
 # [Environment Requirements](#contents)
-- Hardware (GPU)
-    - Prepare hardware environment with GPU processor.
+- Hardware (GPU/CPU)
+    - Prepare hardware environment with GPU/CPU processor.
 - Framework
     - [MindSpore](https://www.mindspore.cn/install/en)
 - For more information, please check the resources below
@@ -65,37 +65,38 @@ Dataset used: [The LJ Speech Dataset](<https://keithito.com/LJ-Speech-Dataset>)
 .
 ├── audio
     └──wavenet
-        ├──datasets                    // Note the datasets folder should be download from the above link
-        ├──egs                         // Note the egs folder should be download from the above link
-        ├──utils                       // Note the utils folder should be download from the above link
-        ├── audio.py                   // audio utils. Note this script should be download from a third party
-        ├── compute-meanvar-stats.py   // Compute mean-variance normalization stats. Note this script should be download from the above link
-        ├── evaluate.py                // evaluation
-        ├── export.py                  // convert mindspore model to air model
-        ├── hparams.py                 // hyper-parameter configuration. Note this script should be download from the above link
-        ├── mksubset.py                // Make subset of dataset. Note this script should be download from the above link
-        ├── preprocess.py              // Preprocess dataset. Note this script should be download from the above link
-        ├── preprocess_normalize.py    // Perform meanvar normalization to preprocessed features. Note this script should be download from the above link
-        ├── README.md                  // descriptions about WaveNet
-        ├── train.py                   // training scripts
-        ├── train_pytorch.py           // Note this script should be download from the above link. The initial name of this script is train.py in the project from the link
+        ├──datasets                    // Note the datasets folder should be downloaded from the above link
+        ├──egs                         // Note the egs folder should be downloaded from the above link
+        ├──utils                       // Note the utils folder should be downloaded from the above link
+        ├── audio.py                   // Audio utils. Note this script should be downloaded from a third party
+        ├── compute-meanvar-stats.py   // Compute mean-variance normalization stats. Note this script should be downloaded from the above link
+        ├── evaluate.py                // Evaluation
+        ├── export.py                  // Convert mindspore model to air model
+        ├── hparams.py                 // Hyper-parameter configuration. Note this script should be downloaded from the above link
+        ├── lrschedule.py              // Learning rate scheduler. Note this script should be downloaded from the above link
+        ├── mksubset.py                // Make subset of dataset. Note this script should be downloaded from the above link
+        ├── preprocess.py              // Preprocess dataset. Note this script should be downloaded from the above link
+        ├── preprocess_normalize.py    // Perform meanvar normalization to preprocessed features. Note this script should be downloaded from the above link
+        ├── README.md                  // Descriptions about WaveNet
+        ├── train.py                   // Training scripts
+        ├── train_pytorch.py           // Note this script should be downloaded from the above link. The initial name of this script is train.py in the project from the link
         ├── src
         │   ├──__init__.py
-        │   ├──dataset.py              // generate dataloader and data processing entry
-        │   ├──callback.py             // callbacks to monitor the training
-        │   ├──lr_generator.py         // learning rate generator
-        │   └──loss.py                 // loss function definition
+        │   ├──dataset.py              // Generate dataloader and data processing entry
+        │   ├──callback.py             // Callbacks to monitor the training
+        │   ├──lr_generator.py         // Learning rate generator
+        │   └──loss.py                 // Loss function definition
         └── wavenet_vocoder
             ├──__init__.py
-            ├──conv.py                 // extended 1D convolution
-            ├──mixture.py              // loss function for training and sample function for testing
-            ├──modules.py              // modules for Wavenet construction
-            ├──upsample.py             // upsample layer definition
-            ├──util.py                 // utils. Note this script should be download from the above link
+            ├──conv.py                 // Extended 1D convolution
+            ├──mixture.py              // Loss function for training and sample function for testing
+            ├──modules.py              // Modules for Wavenet construction
+            ├──upsample.py             // Upsample layer definition
+            ├──util.py                 // Utils. Note this script should be downloaded from the above link
             ├──wavenet.py              // WaveNet networks
-            └──tfcompat                // Note this script should be download from the above link
+            └──tfcompat                // Note this folder should be downloaded from the above link
                 ├──__init__.py
-                └──hparam.py           // param management tools
+                └──hparam.py           // Param management tools
 ```

 ## [Script Parameters](#contents)
@@ -105,13 +106,15 @@ Dataset used: [The LJ Speech Dataset](<https://keithito.com/LJ-Speech-Dataset>)
 ```text
 usage: train.py [--data_path DATA_PATH] [--preset PRESET]
                 [--checkpoint_dir CHECKPOINT_DIR] [--checkpoint CHECKPOINT]
-                [--speaker_id SPEAKER_ID] [--is_distributed IS_DISTRIBUTED]
+                [--speaker_id SPEAKER_ID] [--platform PLATFORM]
+                [--is_distributed IS_DISTRIBUTED]

 options:
     --data_path          dataset path
     --preset             path of preset parameters (json)
     --checkpoint_dir     directory of saving model checkpoints
     --checkpoint         pre-trained ckpt path, default is "./checkpoints"
     --speaker_id         specific speaker of data in case for multi-speaker datasets, not used currently
+    --platform           specify platform to be used, default is "GPU"
     --is_distributed     whether distributed training or not
 ```
@@ -120,8 +123,9 @@ options:
 ```text
 usage: evaluate.py [--data_path DATA_PATH] [--preset PRESET]
-                   [--pretrain_ckpt PRETRAIN_CKPT] [--output_path OUTPUT_PATH]
-                   [--speaker_id SPEAKER_ID]
+                   [--pretrain_ckpt PRETRAIN_CKPT] [--is_numpy]
+                   [--output_path OUTPUT_PATH] [--speaker_id SPEAKER_ID]
+                   [--platform PLATFORM]

 options:
     --data_path          dataset path
     --preset             path of preset parameters (json)
@@ -129,6 +133,7 @@ options:
     --is_numpy           whether using numpy for inference or not
     --output_path        path to save synthesized audio
     --speaker_id         specific speaker of data in case for multi-speaker datasets, not used currently
+    --platform           specify platform to be used, default is "GPU"
 ```

 More parameters for training and evaluation can be set in file `hparams.py`.
@@ -194,18 +199,19 @@ After the processing, the directory of gaussian will be as follows:
     └──eval
 ```

-The train_no_dev folder contains the final training data. For mulaw256 and mol, the process is the same. When the training data is prepared,
+The train_no_dev folder contains the final training data. For mol and gaussian, the process is the same. When the training data is prepared,
 you can run the following command to train the network:

 ```bash
-# standalone training
-python train.py ----data_path= /path_to_egs/egs/gaussian/dump/lj/logmelspectrogram/norm/ --preset=/path_to_egs/egs/gaussian/conf/gaussian_wavenet.json --checkpoint_dir=path_to_save_ckpt
-# distributed training
-CUDA_VISIBLE_DEVICES='0,1,2,3,4,5,6,7' mpirun --allow-run-as-root -n 8 python train.py ----data_path= /path_to_egs/egs/gaussian/dump/lj/logmelspectrogram/norm/ --preset=/path_to_egs/egs/gaussian/conf/gaussian_wavenet.json --checkpoint_dir=path_to_save_ckpt --is_distributed=True
-# eval
-python evaluate.py ----data_path= /path_to_egs/egs/gaussian/dump/lj/logmelspectrogram/norm/ --preset=/path_to_egs/egs/gaussian/conf/gaussian_wavenet.json --pretrain_ckpt=path_to_load_ckpt --output_path=path_to_save_audio
+# Standalone training
+# GPU:
+python train.py --data_path=/path_to_egs/egs/gaussian/dump/lj/logmelspectrogram/norm/ --preset=/path_to_egs/egs/gaussian/conf/gaussian_wavenet.json --checkpoint_dir=path_to_save_ckpt
+# CPU:
+python train.py --data_path=/path_to_egs/egs/gaussian/dump/lj/logmelspectrogram/norm/ --preset=/path_to_egs/egs/gaussian/conf/gaussian_wavenet.json --checkpoint_dir=path_to_save_ckpt --platform=CPU
+# Distributed training (on GPU only)
+CUDA_VISIBLE_DEVICES='0,1,2,3,4,5,6,7' mpirun --allow-run-as-root -n 8 python train.py --data_path=/path_to_egs/egs/gaussian/dump/lj/logmelspectrogram/norm/ --preset=/path_to_egs/egs/gaussian/conf/gaussian_wavenet.json --checkpoint_dir=path_to_save_ckpt --is_distributed=True
 ```

 ## [Evaluation Process](#contents)
@@ -214,21 +220,29 @@ WaveNet has a process of auto-regression and this process currently cannot be run
 this [link](https://bbs.huaweicloud.com/forum/thread-94852-1-1.html)

 ```bash
-# eval
-python evaluate.py --data_path= /path_to_egs/egs/gaussian/dump/lj/logmelspectrogram/norm/eval --preset=/path_to_egs/egs/gaussian/conf/gaussian_wavenet.json --pretrain_ckpt=path_to_load_ckpt --is_numpy --output_path=path_to_save_audio
+# Evaluation
+# GPU:
+python evaluate.py --data_path=/path_to_egs/egs/gaussian/dump/lj/logmelspectrogram/norm/eval --preset=/path_to_egs/egs/gaussian/conf/gaussian_wavenet.json --pretrain_ckpt=path_to_load_ckpt --is_numpy --output_path=path_to_save_audio
+# CPU:
+python evaluate.py --data_path=/path_to_egs/egs/gaussian/dump/lj/logmelspectrogram/norm/eval --preset=/path_to_egs/egs/gaussian/conf/gaussian_wavenet.json --pretrain_ckpt=path_to_load_ckpt --is_numpy --output_path=path_to_save_audio --platform=CPU
 ```

 ## [Convert Process](#contents)
 ```bash
-python export.py --preset=/path_to_egs/egs/gaussian/conf/gaussian_wavenet.json --pretrain_ckpt=path_to_load_ckpt
+# GPU:
+python export.py --preset=/path_to_egs/egs/gaussian/conf/gaussian_wavenet.json --checkpoint_dir=path_to_dump_hparams --pretrain_ckpt=path_to_load_ckpt
+# CPU:
+python export.py --preset=/path_to_egs/egs/gaussian/conf/gaussian_wavenet.json --checkpoint_dir=path_to_dump_hparams --pretrain_ckpt=path_to_load_ckpt --platform=CPU
 ```

 # [Model Description](#contents)

 ## [Performance](#contents)

-### Training Performance
+### Training Performance on GPU

 | Parameters                 | WaveNet                                                         |
 | -------------------------- | ---------------------------------------------------------------|

@@ -36,10 +36,12 @@ parser.add_argument('--data_path', type=str, required=True, default='',
                     help='Directory contains preprocessed features.')
 parser.add_argument('--preset', type=str, required=True, default='', help='Path of preset parameters (json).')
 parser.add_argument('--pretrain_ckpt', type=str, default='', help='Pretrained checkpoint path')
-parser.add_argument('--is_numpy', action="store_false", default=True, help='Using numpy for inference or not')
+parser.add_argument('--is_numpy', action="store_true", default=False, help='Using numpy for inference or not')
 parser.add_argument('--output_path', type=str, default='./out_wave/', help='Path to save generated audios')
 parser.add_argument('--speaker_id', type=str, default='',
                     help=' Use specific speaker of data in case for multi-speaker datasets.')
+parser.add_argument('--platform', type=str, default='GPU', choices=('GPU', 'CPU'),
+                    help='run platform, support GPU and CPU. Default: GPU')
 args = parser.parse_args()
@@ -183,7 +185,7 @@ def save_ref_audio(hparam, ref, length, target_wav_path_):

 if __name__ == '__main__':
-    context.set_context(mode=context.GRAPH_MODE, device_target='GPU', save_graphs=False)
+    context.set_context(mode=context.GRAPH_MODE, device_target=args.platform, save_graphs=False)
     speaker_id = int(args.speaker_id) if args.speaker_id != '' else None
     if args.preset is not None:
         with open(args.preset) as f:
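The same two-line pattern recurs in every entry script this PR touches: an argparse `--platform` flag restricted to GPU/CPU, and a single `context.set_context` call where `args.platform` replaces the hard-coded `device_target`. A minimal standalone sketch of the pattern, assuming a MindSpore build with both backends available:

```python
# Minimal sketch of the platform-selection pattern added across these scripts.
import argparse
from mindspore import context

parser = argparse.ArgumentParser(description='platform selection sketch')
parser.add_argument('--platform', type=str, default='GPU', choices=('GPU', 'CPU'),
                    help='run platform, support GPU and CPU. Default: GPU')
args = parser.parse_args()

# One set_context call replaces the hard-coded device_target='GPU';
# every operator created after this point dispatches to the chosen backend.
context.set_context(mode=context.GRAPH_MODE, device_target=args.platform,
                    save_graphs=False)
```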

@@ -27,14 +27,18 @@ from src.loss import PredictNet

 parser = argparse.ArgumentParser(description='TTS training')
 parser.add_argument('--preset', type=str, default='', help='Path of preset parameters (json).')
+parser.add_argument('--checkpoint_dir', type=str, default='./checkpoints_test',
+                    help='Directory where to save model checkpoints [default: checkpoints].')
 parser.add_argument('--speaker_id', type=str, default='',
                     help=' Use specific speaker of data in case for multi-speaker datasets.')
 parser.add_argument('--pretrain_ckpt', type=str, default='', help='Pretrained checkpoint path')
+parser.add_argument('--platform', type=str, default='GPU', choices=('GPU', 'CPU'),
+                    help='run platform, support GPU and CPU. Default: GPU')
 args = parser.parse_args()

 if __name__ == '__main__':
-    context.set_context(mode=context.GRAPH_MODE, device_target="GPU", save_graphs=False)
+    context.set_context(mode=context.GRAPH_MODE, device_target=args.platform, save_graphs=False)
     speaker_id = int(args.speaker_id) if args.speaker_id != '' else None
     if args.preset is not None:
@@ -82,13 +86,14 @@ if __name__ == '__main__':
     Net = PredictNet(model)
     Net.set_train(False)
-    receptive_field = model.receptive_field
-    print("Receptive field (samples / ms): {} / {}".format(receptive_field, receptive_field / fs * 1000))
     param_dict = load_checkpoint(args.pretrain_ckpt)
     load_param_into_net(model, param_dict)
     print('Successfully loading the pre-trained model')

-    x = np.array(np.random.random((2, 256, 10240)), dtype=np.float32)
+    if is_mulaw_quantize(hparams.input_type):
+        x = np.array(np.random.random((2, 256, 10240)), dtype=np.float32)
+    else:
+        x = np.array(np.random.random((2, 1, 10240)), dtype=np.float32)
     c = np.array(np.random.random((2, 80, 44)), dtype=np.float32)
     g = np.array([0, 0], dtype=np.int64)
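The dummy waveform used for export now depends on the input type: a mu-law-quantized model consumes one-hot vectors over 256 quantization levels, while a continuous (mol/gaussian) model consumes a single scalar channel per sample. A hedged numpy sketch of that shape choice; the string comparison below is an assumption standing in for `wavenet_vocoder.util.is_mulaw_quantize`:

```python
import numpy as np

def make_dummy_wave(input_type, batch=2, time=10240):
    """Shape the export dummy input as (batch, channels, time)."""
    # Assumption: is_mulaw_quantize() is true for the "mulaw-quantize"
    # input type, where audio is fed one-hot over 256 quantization levels.
    if input_type == "mulaw-quantize":
        return np.random.random((batch, 256, time)).astype(np.float32)
    # Raw / mu-law continuous audio is one scalar channel per sample.
    return np.random.random((batch, 1, time)).astype(np.float32)
```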

@@ -20,6 +20,8 @@ import matplotlib.pyplot as plt
 from mindspore import nn, Tensor
 from mindspore.ops import operations as P
+import mindspore.common.dtype as mstype
+from mindspore import context
 from nnmnkwii import preprocessing as P1
 from wavenet_vocoder.util import is_mulaw_quantize, is_mulaw
@@ -204,6 +206,7 @@ class NetWithLossClass(nn.Cell):
     Returns:
         Tensor, loss tensor.
     """
+
     def __init__(self, network, hparams):
         super(NetWithLossClass, self).__init__(auto_prefix=False)
         self.network = network
@@ -213,6 +216,7 @@ class NetWithLossClass(nn.Cell):
         self.transpose_op = P.Transpose()
         self.reshape_op = P.Reshape()
         self.is_mulaw_quant = is_mulaw_quantize(hparams.input_type)
+        self.cast = P.Cast()

         if self.is_mulaw_quant:
             self.criterion = MaskedCrossEntropyLoss()
@@ -225,13 +229,33 @@ class NetWithLossClass(nn.Cell):
             self.criterion = None
             raise RuntimeError(
                 "Not supported output distribution type: {}".format(hparams.output_distribution))
+        self.device_target = context.get_context("device_target")

     def construct(self, x, y, c, g, input_lengths, mask):
+        """
+        Args:
+            x (Tensor): input
+            y (Tensor): prediction
+            c (Tensor): local_conditioning
+            g (Tensor): global_conditioning
+            input_lengths (Tensor): input_lengths
+            mask (Tensor): Mask
+
+        Returns:
+            Tensor: Loss tensor
+        """
         y_hat = self.network(x, c, g, False)
         if self.is_mulaw_quant:
             y_hat = self.transpose_op(y_hat[:, :, :-1], (0, 2, 1))
             y_hat = self.reshape_op(y_hat, (-1, y_hat.shape[-1]))
-            y = self.reshape_op(y[:, 1:, 0], (-1,))
+            if self.device_target == "CPU":
+                y = self.cast(y, mstype.float32)
+                y = self.reshape_op(y[:, 1:, 0], (-1,))
+                y = self.cast(y, mstype.int32)
+            else:
+                y = self.reshape_op(y[:, 1:, 0], (-1,))
             loss = self.criterion(y_hat, y)
         else:
             loss = self.criterion(y_hat[:, :, :-1], y[:, 1:, :], mask[:, 1:, :])
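On CPU the label tensor takes a detour through float32 before the strided slice and returns to int32 afterwards. The diff does not say why; presumably the CPU slice/reshape kernels of that MindSpore version lacked integer support, which is an inference rather than a documented fact. A standalone sketch of the round-trip under that assumption:

```python
# Sketch of the CPU cast-around used above: slice/reshape the integer label
# tensor in float32, then restore int32 for the loss.
import numpy as np
import mindspore as ms
from mindspore import Tensor, context
from mindspore.ops import operations as P

# PyNative mode on CPU, just to run the ops eagerly for the demo.
context.set_context(mode=context.PYNATIVE_MODE, device_target="CPU")

cast, reshape = P.Cast(), P.Reshape()

y = Tensor(np.random.randint(0, 256, size=(2, 100, 1)), ms.int32)
y_f = cast(y, ms.float32)            # float detour for the slice
y_f = reshape(y_f[:, 1:, 0], (-1,))  # drop the first timestep, flatten
labels = cast(y_f, ms.int32)         # restore int32 for the criterion
print(labels.shape)                  # (198,)
```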

@@ -44,6 +44,8 @@ parser.add_argument('--checkpoint_dir', type=str, default='./checkpoints_test',
 parser.add_argument('--checkpoint', type=str, default='', help='Restore model from checkpoint path if given.')
 parser.add_argument('--speaker_id', type=str, default='',
                     help=' Use specific speaker of data in case for multi-speaker datasets.')
+parser.add_argument('--platform', type=str, default='GPU', choices=('GPU', 'CPU'),
+                    help='run platform, support GPU and CPU. Default: GPU')
 parser.add_argument('--is_distributed', action="store_true", default=False, help='Distributed training')
 args = parser.parse_args()
@@ -57,7 +59,7 @@ if __name__ == '__main__':
         context.set_auto_parallel_context(device_num=get_group_size(), parallel_mode=ParallelMode.DATA_PARALLEL,
                                           gradients_mean=True)
     else:
-        context.set_context(mode=context.GRAPH_MODE, device_target="GPU", save_graphs=False)
+        context.set_context(mode=context.GRAPH_MODE, device_target=args.platform, save_graphs=False)
         rank_id = 0
         group_size = 1
@@ -132,4 +134,4 @@ if __name__ == '__main__':
     config_ck = CheckpointConfig(save_checkpoint_steps=step_size_per_epoch, keep_checkpoint_max=10)
     ckpt_cb = ModelCheckpoint(prefix='wavenet', directory=ckpt_path, config=config_ck)
     callback_list.append(ckpt_cb)
-    model.train(hparams.nepochs, data_loaders, callbacks=callback_list)
+    model.train(hparams.nepochs, data_loaders, callbacks=callback_list, dataset_sink_mode=False)
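Note the added `dataset_sink_mode=False`: sink mode streams batches straight to the device queue, and disabling it (presumably because data sinking was not supported on the CPU backend in MindSpore releases of this period) is what lets the same `model.train` call serve both platforms.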

@@ -18,6 +18,7 @@ import math
 from mindspore import nn, Tensor
 from mindspore.ops import operations as P
 import mindspore.common.dtype as mstype
+from mindspore import context
 import numpy as np

 class Conv1d(nn.Conv1d):
@@ -84,7 +85,12 @@ class Conv1d(nn.Conv1d):
             self.input_buffer = self.concat_op((self.input_buffer[:, 1:, :], inputs[:, 0:1, :]))
             inputs = self.input_buffer
             if dilation > 1:
-                inputs = inputs[:, 0::dilation, :]
+                if context.get_context("device_target") == "CPU":
+                    inputs = self.transpose_op(inputs, (1, 0, 2))
+                    inputs = inputs[0::dilation, :, :]
+                    inputs = self.transpose_op(inputs, (1, 0, 2))
+                else:
+                    inputs = inputs[:, 0::dilation, :]
         output = self.matmul(self.reshape_op(inputs, (bsz, -1)), self.get_weight)

         if self.bias is not None:
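The CPU branch reorders the axes so the strided slice happens on axis 0; evidently a stepped slice on axis 1 was not supported by the CPU backend at the time (an inference from the change, not stated in the code). A numpy sketch checking that the two paths agree:

```python
import numpy as np

inputs = np.random.random((4, 12, 8)).astype(np.float32)  # (batch, time, ch)
dilation = 3

direct = inputs[:, 0::dilation, :]               # GPU path: stride on axis 1
detour = inputs.transpose(1, 0, 2)[0::dilation]  # CPU path: time axis first
detour = detour.transpose(1, 0, 2)               # back to (batch, time, ch)

assert np.array_equal(direct, detour)
```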

@@ -20,6 +20,7 @@ import mindspore as ms
 from mindspore import Tensor
 import mindspore.nn as nn
 import mindspore.ops as P
+from mindspore import context

 class log_sum_exp(nn.Cell):
@@ -41,6 +42,55 @@
         return m + self.log(self.sums(self.exp(x - m2), axis))

+
+class log_softmax(nn.Cell):
+    """
+    Replacement for P.LogSoftmax(-1) in CPU mode.
+    Only supports 2D or 3D input tensors.
+    """
+    def __init__(self):
+        super(log_softmax, self).__init__()
+        self.maxi = P.ReduceMax()
+        self.log = P.Log()
+        self.sums = P.ReduceSum()
+        self.exp = P.Exp()
+        self.axis = -1
+        self.concat = P.Concat(-1)
+        self.expanddims = P.ExpandDims()
+
+    def construct(self, x):
+        """
+        Args:
+            x (Tensor): input
+
+        Returns:
+            Tensor: log_softmax of input
+        """
+        c = self.maxi(x, self.axis)
+        logs, lsm = None, None
+        if len(x.shape) == 2:
+            for j in range(x.shape[-1]):
+                temp = self.expanddims(self.exp(x[:, j] - c), -1)
+                logs = temp if j == 0 else self.concat((logs, temp))
+            sums = self.sums(logs, -1)
+            for i in range(x.shape[-1]):
+                temp = self.expanddims(x[:, i] - c - self.log(sums), -1)
+                lsm = temp if i == 0 else self.concat((lsm, temp))
+            return lsm
+        if len(x.shape) == 3:
+            for j in range(x.shape[-1]):
+                temp = self.expanddims(self.exp(x[:, :, j] - c), -1)
+                logs = temp if j == 0 else self.concat((logs, temp))
+            sums = self.sums(logs, -1)
+            for i in range(x.shape[-1]):
+                temp = self.expanddims(x[:, :, i] - c - self.log(sums), -1)
+                lsm = temp if i == 0 else self.concat((lsm, temp))
+            return lsm
+        return None
+
+
 class Stable_softplus(nn.Cell):
     """Numerically stable softplus
     """
@@ -77,7 +127,6 @@ class discretized_mix_logistic_loss(nn.Cell):
         self.softplus = Stable_softplus()
         self.log = P.Log()
         self.cast = P.Cast()
-        self.logsoftmax = P.LogSoftmax(-1)
         self.expand_dims = P.ExpandDims()
         self.tile = P.Tile()
         self.maximum = P.Maximum()
@@ -85,6 +134,12 @@ class discretized_mix_logistic_loss(nn.Cell):
         self.lse = log_sum_exp()
         self.reshape = P.Reshape()
         self.factor = self.log(Tensor((self.num_classes - 1) / 2, ms.float32))
+        self.tensor_one = Tensor(1., ms.float32)
+
+        if context.get_context("device_target") == "CPU":
+            self.logsoftmax = log_softmax()
+        else:
+            self.logsoftmax = P.LogSoftmax(-1)

     def construct(self, y_hat, y):
         """
@@ -105,7 +160,8 @@
         # (B, T, num_mixtures) x 3
         logit_probs = y_hat[:, :, :nr_mix]
         means = y_hat[:, :, nr_mix:2 * nr_mix]
-        log_scales = self.maximum(y_hat[:, :, 2 * nr_mix:3 * nr_mix], self.log_scale_min)
+        min_cut = self.log_scale_min * self.tile(self.tensor_one, (y_hat.shape[0], y_hat.shape[1], nr_mix))
+        log_scales = self.maximum(y_hat[:, :, 2 * nr_mix:3 * nr_mix], min_cut)

         # B x T x 1 -> B x T x num_mixtures
         y = self.tile(y, (1, 1, nr_mix))
@@ -127,8 +183,9 @@
         log_pdf_mid = mid_in - log_scales - 2. * self.softplus(mid_in)

         inner_inner_cond = self.cast(cdf_delta > 1e-5, ms.float32)
+        min_cut2 = 1e-12 * self.tile(self.tensor_one, cdf_delta.shape)
         inner_inner_out = inner_inner_cond * \
-            self.log(self.maximum(cdf_delta, 1e-12)) + \
+            self.log(self.maximum(cdf_delta, min_cut2)) + \
             (1. - inner_inner_cond) * (log_pdf_mid - self.factor)
         inner_cond = self.cast(y > 0.999, ms.float32)
         inner_out = inner_cond * log_one_minus_cdf_min + (1. - inner_cond) * inner_inner_out
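`min_cut` and `min_cut2` replace what used to be Python-scalar second operands to `P.Maximum`: the scalar is materialized as a tensor of the operand's shape by tiling a one-element tensor. Evidently the CPU Maximum kernel did not accept a tensor/scalar pair; again an inference from the change, not something the diff documents. A standalone sketch of the trick:

```python
import numpy as np
import mindspore as ms
from mindspore import Tensor, context
from mindspore.ops import operations as P

context.set_context(mode=context.PYNATIVE_MODE, device_target="CPU")

tile, maximum = P.Tile(), P.Maximum()

log_scale_min = -16.0                    # illustrative floor value
y_hat = Tensor(np.random.randn(2, 100, 30).astype(np.float32))

one = Tensor(1.0, ms.float32)
min_cut = log_scale_min * tile(one, (2, 100, 30))  # scalar as a full tensor
log_scales = maximum(y_hat, min_cut)               # elementwise floor
```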
@@ -192,15 +249,19 @@ class mix_gaussian_loss(nn.Cell):
         self.maximum = P.Maximum()
         self.tile = P.Tile()
         self.exp = P.Exp()
-        self.logsoftmax = P.LogSoftmax(-1)
         self.expand_dims = P.ExpandDims()
         self.sums = P.ReduceSum()
         self.lse = log_sum_exp()
         self.sq = P.Square()
         self.sqrt = P.Sqrt()
         self.const = P.ScalarToArray()
         self.log = P.Log()
+        self.tensor_one = Tensor(1., ms.float32)
+
+        if context.get_context("device_target") == "CPU":
+            self.logsoftmax = log_softmax()
+        else:
+            self.logsoftmax = P.LogSoftmax(-1)

     def construct(self, y_hat, y):
         """
@@ -225,12 +286,14 @@
         if C == 2:
             logit_probs = None
             means = y_hat[:, :, 0:1]
-            log_scales = self.maximum(y_hat[:, :, 1:2], self.log_scale_min)
+            min_cut = self.log_scale_min * self.tile(self.tensor_one, (y_hat.shape[0], y_hat.shape[1], 1))
+            log_scales = self.maximum(y_hat[:, :, 1:2], min_cut)
         else:
             # (B, T, num_mixtures) x 3
             logit_probs = y_hat[:, :, :nr_mix]
             means = y_hat[:, :, nr_mix:2 * nr_mix]
-            log_scales = self.maximum(y_hat[:, :, 2 * nr_mix:3 * nr_mix], self.log_scale_min)
+            min_cut = self.log_scale_min * self.tile(self.tensor_one, (y_hat.shape[0], y_hat.shape[1], nr_mix))
+            log_scales = self.maximum(y_hat[:, :, 2 * nr_mix:3 * nr_mix], min_cut)

         # B x T x 1 -> B x T x num_mixtures
         y = self.tile(y, (1, 1, nr_mix))
