Commit b6b2239ffe (parent 142f9c2d3e) by wanyiming, 4 years ago, on branch pull/10862/head

model_zoo/README.md
@@ -76,6 +76,7 @@ In order to facilitate developers to enjoy the benefits of MindSpore framework,
- [AutoDis](https://gitee.com/mindspore/mindspore/tree/master/model_zoo/research/recommend/autodis/README.md)
- [Audio](https://gitee.com/mindspore/mindspore/tree/master/model_zoo/research/audio)
- [FCN-4](https://gitee.com/mindspore/mindspore/tree/master/model_zoo/research/audio/fcn-4/README.md)
- [DeepSpeech2](https://gitee.com/mindspore/mindspore/tree/master/model_zoo/research/audio/deepspeech2/README.md)
- [High Performance Computing](https://gitee.com/mindspore/mindspore/tree/master/model_zoo/research/hpc)
- [GOMO](https://gitee.com/mindspore/mindspore/tree/master/model_zoo/research/hpc/ocean_model/README.md)
- [Molecular_Dynamics](https://gitee.com/mindspore/mindspore/tree/master/model_zoo/research/hpc/molecular_dynamics/README.md)

File diff suppressed because it is too large (likely model_zoo/research/audio/deepspeech2/README.md).

model_zoo/research/audio/deepspeech2/eval.py
@@ -0,0 +1,112 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ===========================================================================
"""
Eval DeepSpeech2
"""
import argparse
import json
import pickle
import numpy as np
from src.config import eval_config
from src.deepspeech2 import DeepSpeechModel, PredictWithSoftmax
from src.dataset import create_dataset
from src.greedydecoder import MSGreedyDecoder
from mindspore import context
from mindspore.train.serialization import load_checkpoint, load_param_into_net
context.set_context(mode=context.GRAPH_MODE, device_target="GPU", save_graphs=False)
parser = argparse.ArgumentParser(description='DeepSpeech evaluation')
parser.add_argument('--bidirectional', action="store_false", default=True, help='Use bidirectional RNN')
parser.add_argument('--pretrain_ckpt', type=str, default='', help='Pretrained checkpoint path')
args = parser.parse_args()
if __name__ == '__main__':
config = eval_config
with open(config.DataConfig.labels_path) as label_file:
labels = json.load(label_file)
model = PredictWithSoftmax(DeepSpeechModel(batch_size=config.DataConfig.batch_size,
rnn_hidden_size=config.ModelConfig.hidden_size,
nb_layers=config.ModelConfig.hidden_layers,
labels=labels,
rnn_type=config.ModelConfig.rnn_type,
audio_conf=config.DataConfig.SpectConfig,
bidirectional=args.bidirectional))
ds_eval = create_dataset(audio_conf=config.DataConfig.SpectConfig,
manifest_filepath=config.DataConfig.test_manifest,
labels=labels, normalize=True, train_mode=False,
batch_size=config.DataConfig.batch_size, rank=0, group_size=1)
param_dict = load_checkpoint(args.pretrain_ckpt)
load_param_into_net(model, param_dict)
    print('Successfully loaded the pre-trained model')
if config.LMConfig.decoder_type == 'greedy':
decoder = MSGreedyDecoder(labels=labels, blank_index=labels.index('_'))
else:
raise NotImplementedError("Only greedy decoder is supported now")
target_decoder = MSGreedyDecoder(labels, blank_index=labels.index('_'))
model.set_train(False)
total_cer, total_wer, num_tokens, num_chars = 0, 0, 0, 0
output_data = []
for data in ds_eval.create_dict_iterator():
inputs, input_length, target_indices, targets = data['inputs'], data['input_length'], data['target_indices'], \
data['label_values']
split_targets = []
start, count, last_id = 0, 0, 0
target_indices, targets = target_indices.asnumpy(), targets.asnumpy()
        for i in range(np.shape(targets)[0]):
            if target_indices[i, 0] == last_id:
                count += 1
            else:
                split_targets.append(list(targets[start:count]))  # flush on utterance-id change
                last_id += 1
                start = count
                count += 1
        # the loop only flushes on an id change, so append the final utterance's labels
        split_targets.append(list(targets[start:]))
out, output_sizes = model(inputs, input_length)
decoded_output, _ = decoder.decode(out, output_sizes)
target_strings = target_decoder.convert_to_strings(split_targets)
if config.save_output is not None:
output_data.append((out.asnumpy(), output_sizes.asnumpy(), target_strings))
for doutput, toutput in zip(decoded_output, target_strings):
transcript, reference = doutput[0], toutput[0]
wer_inst = decoder.wer(transcript, reference)
cer_inst = decoder.cer(transcript, reference)
total_wer += wer_inst
total_cer += cer_inst
num_tokens += len(reference.split())
num_chars += len(reference.replace(' ', ''))
if config.verbose:
print("Ref:", reference.lower())
print("Hyp:", transcript.lower())
print("WER:", float(wer_inst) / len(reference.split()),
"CER:", float(cer_inst) / len(reference.replace(' ', '')), "\n")
wer = float(total_wer) / num_tokens
cer = float(total_cer) / num_chars
print('Test Summary \t'
'Average WER {wer:.3f}\t'
'Average CER {cer:.3f}\t'.format(wer=wer * 100, cer=cer * 100))
if config.save_output is not None:
with open(config.save_output + '.bin', 'wb') as output:
pickle.dump(output_data, output)
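The loop above reconstructs per-utterance label lists from the flattened (utterance, position) pairs emitted by the dataset. A minimal NumPy sketch of the same reduction, using hypothetical label ids:

# Two utterances flattened the way src/dataset.py emits them:
# each row of target_indices is (utterance_id, position); targets holds label ids.
import numpy as np

target_indices = np.array([[0, 0], [0, 1], [1, 0], [1, 1], [1, 2]])
targets = np.array([3, 7, 2, 2, 9])

split_targets, start, count, last_id = [], 0, 0, 0
for i in range(targets.shape[0]):
    if target_indices[i, 0] == last_id:
        count += 1
    else:
        split_targets.append(list(targets[start:count]))  # flush on id change
        last_id += 1
        start = count
        count += 1
split_targets.append(list(targets[start:]))  # flush the final utterance
print(split_targets)  # [[3, 7], [2, 2, 9]]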

model_zoo/research/audio/deepspeech2/export.py
@@ -0,0 +1,51 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""
export checkpoint file to mindir model
"""
import json
import argparse
import numpy as np
from mindspore import context, Tensor
from mindspore.train.serialization import load_checkpoint, load_param_into_net, export
from src.deepspeech2 import DeepSpeechModel
from src.config import train_config
parser = argparse.ArgumentParser(description='Export DeepSpeech model to Mindir')
parser.add_argument('--pre_trained_model_path', type=str, default='', help='existing checkpoint path')
args = parser.parse_args()
if __name__ == '__main__':
config = train_config
context.set_context(mode=context.GRAPH_MODE, device_target="GPU", save_graphs=False)
with open(config.DataConfig.labels_path) as label_file:
labels = json.load(label_file)
deepspeech_net = DeepSpeechModel(batch_size=1,
rnn_hidden_size=config.ModelConfig.hidden_size,
nb_layers=config.ModelConfig.hidden_layers,
labels=labels,
rnn_type=config.ModelConfig.rnn_type,
audio_conf=config.DataConfig.SpectConfig,
bidirectional=True)
param_dict = load_checkpoint(args.pre_trained_model_path)
load_param_into_net(deepspeech_net, param_dict)
    print('Successfully loaded the pre-trained model')
    # 3500 is the max length in the evaluation dataset (LibriSpeech), consistent with dataset.py.
    # The length is fixed to this value because MindSpore does not support dynamic shape currently.
input_np = np.random.uniform(0.0, 1.0, size=[1, 1, 161, 3500]).astype(np.float32)
length = np.array([15], dtype=np.int32)
export(deepspeech_net, Tensor(input_np), Tensor(length), file_name="deepspeech2.mindir", file_format='MINDIR')
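The fixed export shape [1, 1, 161, 3500] follows from the spectrogram settings in config.py; a quick sketch of the arithmetic, with the values taken from the config above:

# 161 = n_fft // 2 + 1 STFT frequency bins; 3500 = TEST_INPUT_PAD_LENGTH in src/dataset.py
sample_rate = 16000                     # config.DataConfig.SpectConfig.sample_rate
window_size = 0.02                      # seconds, config.DataConfig.SpectConfig.window_size
n_fft = int(sample_rate * window_size)  # 320-sample analysis window
freq_bins = n_fft // 2 + 1
print(n_fft, freq_bins)                 # 320 161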

model_zoo/research/audio/deepspeech2/labels.json
@@ -0,0 +1,31 @@
[
"'",
"A",
"B",
"C",
"D",
"E",
"F",
"G",
"H",
"I",
"J",
"K",
"L",
"M",
"N",
"O",
"P",
"Q",
"R",
"S",
"T",
"U",
"V",
"W",
"X",
"Y",
"Z",
" ",
"_"
]
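labels.json defines the 29-character output alphabet; the trailing '_' serves as the CTC blank (eval.py passes blank_index=labels.index('_') to the decoder). A small sketch of how the list is consumed, assuming it runs next to labels.json:

# Characters map to indices; '_' (index 28) is the CTC blank.
import json

with open('labels.json') as f:
    labels = json.load(f)
labels_map = {labels[i]: i for i in range(len(labels))}  # as built in src/dataset.py
print(len(labels), labels.index('_'))        # 29 28
print([labels_map[c] for c in "HI THERE"])   # a transcript rendered as label ids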

model_zoo/research/audio/deepspeech2/src/__init__.py
@@ -0,0 +1,14 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================

model_zoo/research/audio/deepspeech2/src/callback.py
@@ -0,0 +1,108 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""
Defined callback for DeepSpeech.
"""
import time
from mindspore.train.callback import Callback
from mindspore import Tensor
import numpy as np
class TimeMonitor(Callback):
"""
Time monitor for calculating cost of each epoch.
    Args:
        data_size (int): step size of an epoch.
"""
def __init__(self, data_size):
super(TimeMonitor, self).__init__()
self.data_size = data_size
def epoch_begin(self, run_context):
self.epoch_time = time.time()
def epoch_end(self, run_context):
epoch_mseconds = (time.time() - self.epoch_time) * 1000
per_step_mseconds = epoch_mseconds / self.data_size
print("epoch time: {0}, per step time: {1}".format(epoch_mseconds, per_step_mseconds), flush=True)
def step_begin(self, run_context):
self.step_time = time.time()
def step_end(self, run_context):
step_mseconds = (time.time() - self.step_time) * 1000
print(f"step time {step_mseconds}", flush=True)
class Monitor(Callback):
"""
Monitor loss and time.
Args:
lr_init (numpy array): train lr
Returns:
None
"""
def __init__(self, lr_init=None):
super(Monitor, self).__init__()
self.lr_init = lr_init
self.lr_init_len = len(lr_init)
def epoch_begin(self, run_context):
self.losses = []
self.epoch_time = time.time()
def epoch_end(self, run_context):
cb_params = run_context.original_args()
        epoch_mseconds = (time.time() - self.epoch_time) * 1000
        per_step_mseconds = epoch_mseconds / cb_params.batch_num
        print("epoch time: {:5.3f}, per step time: {:5.3f}, avg loss: {:5.3f}".format(epoch_mseconds,
                                                                                      per_step_mseconds,
                                                                                      np.mean(self.losses)))
def step_begin(self, run_context):
self.step_time = time.time()
def step_end(self, run_context):
"""
Args:
run_context:
Returns:
"""
cb_params = run_context.original_args()
        step_mseconds = (time.time() - self.step_time) * 1000
step_loss = cb_params.net_outputs
if isinstance(step_loss, (tuple, list)) and isinstance(step_loss[0], Tensor):
step_loss = step_loss[0]
if isinstance(step_loss, Tensor):
step_loss = np.mean(step_loss.asnumpy())
self.losses.append(step_loss)
cur_step_in_epoch = (cb_params.cur_step_num - 1) % cb_params.batch_num
print("epoch: [{:3d}/{:3d}], step:[{:5d}/{:5d}], loss:[{:5.3f}/{:5.3f}], time:[{:5.3f}], lr:[{:.9f}]".format(
cb_params.cur_epoch_num -
1, cb_params.epoch_num, cur_step_in_epoch, cb_params.batch_num, step_loss,
np.mean(self.losses), step_mseconds, self.lr_init[cb_params.cur_step_num - 1].asnumpy()))
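Monitor.step_end has to unwrap cb_params.net_outputs, which may be a Tensor or a tuple whose first element is the loss. A standalone sketch of that unwrapping, with a hypothetical output value:

# Reduce a network output to a scalar loss the way Monitor.step_end does.
import numpy as np
from mindspore import Tensor

step_loss = (Tensor(np.array(3.14, np.float32)),)  # hypothetical net_outputs tuple
if isinstance(step_loss, (tuple, list)) and isinstance(step_loss[0], Tensor):
    step_loss = step_loss[0]
if isinstance(step_loss, Tensor):
    step_loss = np.mean(step_loss.asnumpy())
print(step_loss)  # 3.14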

model_zoo/research/audio/deepspeech2/src/config.py
@@ -0,0 +1,113 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ===========================================================================
"""
Network config settings, used in train.py and eval.py.
"""
from easydict import EasyDict as ed
train_config = ed({
"TrainingConfig": {
"epochs": 70,
},
"DataConfig": {
"train_manifest": 'data/libri_train_manifest.csv',
# "val_manifest": 'data/libri_val_manifest.csv',
"batch_size": 20,
"labels_path": "labels.json",
"SpectConfig": {
"sample_rate": 16000,
"window_size": 0.02,
"window_stride": 0.01,
"window": "hamming"
},
"AugmentationConfig": {
"speed_volume_perturb": False,
"spec_augment": False,
"noise_dir": '',
"noise_prob": 0.4,
"noise_min": 0.0,
"noise_max": 0.5,
}
},
"ModelConfig": {
"rnn_type": "LSTM",
"hidden_size": 1024,
"hidden_layers": 5,
"lookahead_context": 20,
},
"OptimConfig": {
"learning_rate": 3e-4,
"learning_anneal": 1.1,
"weight_decay": 1e-5,
"momentum": 0.9,
"eps": 1e-8,
"betas": (0.9, 0.999),
"loss_scale": 1024,
"epsilon": 0.00001
},
"CheckpointConfig": {
"ckpt_file_name_prefix": 'DeepSpeech',
"ckpt_path": './checkpoint',
"keep_checkpoint_max": 10
}
})
eval_config = ed({
"save_output": 'librispeech_val_output',
"verbose": True,
"DataConfig": {
"test_manifest": 'data/libri_test_clean_manifest.csv',
# "test_manifest": 'data/libri_test_other_manifest.csv',
# "test_manifest": 'data/libri_val_manifest.csv',
"batch_size": 20,
"labels_path": "labels.json",
"SpectConfig": {
"sample_rate": 16000,
"window_size": 0.02,
"window_stride": 0.01,
"window": "hanning"
},
},
"ModelConfig": {
"rnn_type": "LSTM",
"hidden_size": 1024,
"hidden_layers": 5,
"lookahead_context": 20,
},
"LMConfig": {
"decoder_type": "greedy",
"lm_path": './3-gram.pruned.3e-7.arpa',
"top_paths": 1,
"alpha": 1.818182,
"beta": 0,
"cutoff_top_n": 40,
"cutoff_prob": 1.0,
"beam_width": 1024,
"lm_workers": 4
},
})
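EasyDict exposes nested dict keys as attributes, which is why the scripts can write config.DataConfig.SpectConfig rather than chained subscripts; a minimal sketch:

# Nested dicts become attribute-accessible objects.
from easydict import EasyDict as ed

cfg = ed({"DataConfig": {"batch_size": 20, "SpectConfig": {"sample_rate": 16000}}})
print(cfg.DataConfig.batch_size)               # 20, same as cfg["DataConfig"]["batch_size"]
print(cfg.DataConfig.SpectConfig.sample_rate)  # 16000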

model_zoo/research/audio/deepspeech2/src/dataset.py
@@ -0,0 +1,215 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""
Create train or eval dataset.
"""
import math
import numpy as np
import mindspore.dataset.engine as de
import librosa
import soundfile as sf
TRAIN_INPUT_PAD_LENGTH = 1501
TRAIN_LABEL_PAD_LENGTH = 350
TEST_INPUT_PAD_LENGTH = 3500
class LoadAudioAndTranscript():
"""
parse audio and transcript
"""
def __init__(self,
audio_conf=None,
normalize=False,
labels=None):
super(LoadAudioAndTranscript, self).__init__()
self.window_stride = audio_conf.window_stride
self.window_size = audio_conf.window_size
self.sample_rate = audio_conf.sample_rate
self.window = audio_conf.window
self.is_normalization = normalize
self.labels = labels
def load_audio(self, path):
"""
load audio
"""
sound, _ = sf.read(path, dtype='int16')
sound = sound.astype('float32') / 32767
if len(sound.shape) > 1:
if sound.shape[1] == 1:
sound = sound.squeeze()
else:
sound = sound.mean(axis=1)
return sound
def parse_audio(self, audio_path):
"""
parse audio
"""
audio = self.load_audio(audio_path)
n_fft = int(self.sample_rate * self.window_size)
win_length = n_fft
hop_length = int(self.sample_rate * self.window_stride)
D = librosa.stft(y=audio, n_fft=n_fft, hop_length=hop_length, win_length=win_length, window=self.window)
mag, _ = librosa.magphase(D)
mag = np.log1p(mag)
if self.is_normalization:
mean = mag.mean()
std = mag.std()
mag = (mag - mean) / std
return mag
def parse_transcript(self, transcript_path):
with open(transcript_path, 'r', encoding='utf8') as transcript_file:
transcript = transcript_file.read().replace('\n', '')
transcript = list(filter(None, [self.labels.get(x) for x in list(transcript)]))
return transcript
class ASRDataset(LoadAudioAndTranscript):
"""
create ASRDataset
Args:
audio_conf: Config containing the sample rate, window and the window length/stride in seconds
manifest_filepath (str): manifest_file path.
labels (list): List containing all the possible characters to map to
normalize: Apply standard mean and deviation normalization to audio tensor
batch_size (int): Dataset batch size (default=32)
"""
def __init__(self, audio_conf=None,
manifest_filepath='',
labels=None,
normalize=False,
batch_size=32,
is_training=True):
with open(manifest_filepath) as f:
ids = f.readlines()
ids = [x.strip().split(',') for x in ids]
self.is_training = is_training
self.ids = ids
self.blank_id = int(labels.index('_'))
self.bins = [ids[i:i + batch_size] for i in range(0, len(ids), batch_size)]
if len(self.ids) % batch_size != 0:
self.bins = self.bins[:-1]
self.bins.append(ids[-batch_size:])
self.size = len(self.bins)
self.batch_size = batch_size
self.labels_map = {labels[i]: i for i in range(len(labels))}
super(ASRDataset, self).__init__(audio_conf, normalize, self.labels_map)
def __getitem__(self, index):
batch_idx = self.bins[index]
batch_size = len(batch_idx)
batch_spect, batch_script, target_indices = [], [], []
input_length = np.zeros(batch_size, np.int32)
for data in batch_idx:
audio_path, transcript_path = data[0], data[1]
spect = self.parse_audio(audio_path)
transcript = self.parse_transcript(transcript_path)
batch_spect.append(spect)
batch_script.append(transcript)
freq_size = np.shape(batch_spect[-1])[0]
if self.is_training:
            # 1501 (TRAIN_INPUT_PAD_LENGTH) is the max length in the train dataset (LibriSpeech).
            # The length is fixed to this value because MindSpore does not support dynamic shape currently.
            inputs = np.zeros((batch_size, 1, freq_size, TRAIN_INPUT_PAD_LENGTH), dtype=np.float32)
            # The target length is fixed for the same reason; 350 (TRAIN_LABEL_PAD_LENGTH) may be
            # greater than the max length of labels in the train dataset (LibriSpeech).
            targets = np.ones((self.batch_size, TRAIN_LABEL_PAD_LENGTH), dtype=np.int32) * self.blank_id
for k, spect_, scripts_ in zip(range(batch_size), batch_spect, batch_script):
seq_length = np.shape(spect_)[1]
input_length[k] = seq_length
script_length = len(scripts_)
targets[k, :script_length] = scripts_
                for m in range(TRAIN_LABEL_PAD_LENGTH):
target_indices.append([k, m])
inputs[k, 0, :, 0:seq_length] = spect_
targets = np.reshape(targets, (-1,))
else:
inputs = np.zeros((batch_size, 1, freq_size, TEST_INPUT_PAD_LENGTH), dtype=np.float32)
targets = []
for k, spect_, scripts_ in zip(range(batch_size), batch_spect, batch_script):
seq_length = np.shape(spect_)[1]
input_length[k] = seq_length
targets.extend(scripts_)
for m in range(len(scripts_)):
target_indices.append([k, m])
inputs[k, 0, :, 0:seq_length] = spect_
return inputs, input_length, np.array(target_indices, dtype=np.int64), np.array(targets, dtype=np.int32)
def __len__(self):
return self.size
class DistributedSampler():
"""
    Distribute and shuffle samples across devices.
"""
def __init__(self, dataset, rank, group_size, shuffle=True, seed=0):
self.dataset = dataset
self.rank = rank
self.group_size = group_size
self.dataset_len = len(self.dataset)
self.num_samplers = int(math.ceil(self.dataset_len * 1.0 / self.group_size))
self.total_size = self.num_samplers * self.group_size
self.shuffle = shuffle
self.seed = seed
def __iter__(self):
if self.shuffle:
self.seed = (self.seed + 1) & 0xffffffff
np.random.seed(self.seed)
indices = np.random.permutation(self.dataset_len).tolist()
else:
indices = list(range(self.dataset_len))
indices += indices[:(self.total_size - len(indices))]
indices = indices[self.rank::self.group_size]
return iter(indices)
def __len__(self):
return self.num_samplers
def create_dataset(audio_conf, manifest_filepath, labels, normalize, batch_size, train_mode=True,
rank=None, group_size=None):
"""
    Create a train or eval dataset.
Args:
audio_conf: Config containing the sample rate, window and the window length/stride in seconds
manifest_filepath (str): manifest_file path.
labels (list): list containing all the possible characters to map to
normalize: Apply standard mean and deviation normalization to audio tensor
        batch_size (int): Dataset batch size
        train_mode (bool): Whether the dataset is used for train or eval (default=True).
rank (int): The shard ID within num_shards (default=None).
group_size (int): Number of shards that the dataset should be divided into (default=None).
Returns:
Dataset.
"""
dataset = ASRDataset(audio_conf=audio_conf, manifest_filepath=manifest_filepath, labels=labels, normalize=normalize,
batch_size=batch_size, is_training=train_mode)
sampler = DistributedSampler(dataset, rank, group_size, shuffle=True)
ds = de.GeneratorDataset(dataset, ["inputs", "input_length", "target_indices", "label_values"], sampler=sampler)
ds = ds.repeat(1)
return ds
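DistributedSampler pads the index list to a multiple of group_size and then takes every group_size-th index starting at this process's rank, so each device iterates a disjoint, equally sized shard. A sketch with hypothetical values:

# 10 samples sharded over 4 devices; rank 1 receives indices [1, 5, 9].
dataset_len, rank, group_size = 10, 1, 4
num_samples = -(-dataset_len // group_size)      # ceil division: 3 per device
total_size = num_samples * group_size            # 12, so two indices wrap around
indices = list(range(dataset_len))
indices += indices[:total_size - len(indices)]   # [0..9, 0, 1]
print(indices[rank::group_size])                 # [1, 5, 9]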

File diff suppressed because it is too large (likely model_zoo/research/audio/deepspeech2/src/deepspeech2.py).

model_zoo/research/audio/deepspeech2/src/greedydecoder.py
@@ -0,0 +1,52 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""
modify GreedyDecoder to adapt to MindSpore
"""
import numpy as np
from deepspeech_pytorch.decoder import GreedyDecoder
class MSGreedyDecoder(GreedyDecoder):
"""
GreedyDecoder used for MindSpore
"""
def process_string(self, sequence, size, remove_repetitions=False):
"""
process string
"""
string = ''
offsets = []
for i in range(size):
char = self.int_to_char[sequence[i].item()]
if char != self.int_to_char[self.blank_index]:
if remove_repetitions and i != 0 and char == self.int_to_char[sequence[i - 1].item()]:
pass
elif char == self.labels[self.space_index]:
string += ' '
offsets.append(i)
else:
string = string + char
offsets.append(i)
return string, offsets
def decode(self, probs, sizes=None):
probs = probs.asnumpy()
sizes = sizes.asnumpy()
max_probs = np.argmax(probs, axis=-1)
strings, offsets = self.convert_to_strings(max_probs, sizes, remove_repetitions=True, return_offsets=True)
return strings, offsets
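decode() above performs plain greedy CTC decoding: take the per-frame argmax, collapse consecutive repeats, then drop blanks. The same reduction on a toy alphabet, with hypothetical frame ids:

# Greedy CTC collapse: repeats merge, but blank-separated repeats survive.
import numpy as np

labels = ["'", "A", "B", "C", "_"]           # toy alphabet; '_' is the blank (id 4)
frame_ids = np.array([1, 1, 4, 1, 2, 2, 4])  # per-frame argmax over the probabilities
collapsed = [i for k, i in enumerate(frame_ids) if k == 0 or i != frame_ids[k - 1]]
decoded = ''.join(labels[i] for i in collapsed if i != 4)
print(decoded)  # AAB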

model_zoo/research/audio/deepspeech2/src/lr_generator.py
@@ -0,0 +1,40 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""learning rate generator"""
import numpy as np
def get_lr(lr_init, total_epochs, steps_per_epoch):
"""
generate learning rate array
Args:
lr_init(float): init learning rate
total_epochs(int): total epoch of training
steps_per_epoch(int): steps of one epoch
Returns:
np.array, learning rate array
"""
lr_each_step = []
half_epoch = total_epochs // 2
    for i in range(total_epochs * steps_per_epoch):
        # anneal by 1.1x per epoch (not per step) once the second half of training starts
        epoch = i // steps_per_epoch
        if epoch < half_epoch:
            lr_each_step.append(lr_init)
        else:
            lr_each_step.append(lr_init / (1.1 ** (epoch - half_epoch)))
learning_rate = np.array(lr_each_step).astype(np.float32)
return learning_rate
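Given the per-epoch annealing above, the training defaults (lr_init=3e-4, 70 epochs) hold the rate flat through the first half and then divide it by 1.1 each epoch. A quick check using get_lr, with a hypothetical steps_per_epoch:

# Two steps per epoch keeps the array small enough to inspect by eye.
lr = get_lr(lr_init=3e-4, total_epochs=70, steps_per_epoch=2)
print(lr[0], lr[69])   # both 3e-4: epochs 0..34 stay at lr_init
print(lr[72], lr[74])  # 3e-4/1.1 and 3e-4/1.1**2: annealing from epoch 36 on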

model_zoo/research/audio/deepspeech2/train.py
@@ -0,0 +1,103 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""train_criteo."""
import os
import json
import argparse
from mindspore import context, Tensor, ParameterTuple
from mindspore.context import ParallelMode
from mindspore.communication.management import init, get_rank, get_group_size
from mindspore.train.callback import ModelCheckpoint, CheckpointConfig
from mindspore.train.serialization import load_checkpoint, load_param_into_net
from mindspore.nn.optim import Adam
from mindspore.nn import TrainOneStepCell
from mindspore.train import Model
from src.deepspeech2 import DeepSpeechModel, NetWithLossClass
from src.lr_generator import get_lr
from src.callback import Monitor
from src.config import train_config
from src.dataset import create_dataset
parser = argparse.ArgumentParser(description='DeepSpeech2 training')
parser.add_argument('--pre_trained_model_path', type=str, default='', help='Pretrained checkpoint path')
parser.add_argument('--is_distributed', action="store_true", default=False, help='Distributed training')
parser.add_argument('--bidirectional', action="store_false", default=True, help='Use bidirectional RNN')
args = parser.parse_args()
if __name__ == '__main__':
rank_id = 0
group_size = 1
config = train_config
if args.is_distributed:
init('nccl')
rank_id = get_rank()
group_size = get_group_size()
context.set_context(mode=context.GRAPH_MODE, device_target='GPU', save_graphs=False)
context.reset_auto_parallel_context()
context.set_auto_parallel_context(device_num=get_group_size(), parallel_mode=ParallelMode.DATA_PARALLEL,
gradients_mean=True)
else:
context.set_context(mode=context.GRAPH_MODE, device_target='GPU', save_graphs=False)
with open(config.DataConfig.labels_path) as label_file:
labels = json.load(label_file)
ds_train = create_dataset(audio_conf=config.DataConfig.SpectConfig,
manifest_filepath=config.DataConfig.train_manifest,
labels=labels, normalize=True, train_mode=True,
batch_size=config.DataConfig.batch_size, rank=rank_id, group_size=group_size)
steps_size = ds_train.get_dataset_size()
lr = get_lr(lr_init=config.OptimConfig.learning_rate, total_epochs=config.TrainingConfig.epochs,
steps_per_epoch=steps_size)
lr = Tensor(lr)
deepspeech_net = DeepSpeechModel(batch_size=config.DataConfig.batch_size,
rnn_hidden_size=config.ModelConfig.hidden_size,
nb_layers=config.ModelConfig.hidden_layers,
labels=labels,
rnn_type=config.ModelConfig.rnn_type,
audio_conf=config.DataConfig.SpectConfig,
bidirectional=True)
loss_net = NetWithLossClass(deepspeech_net)
weights = ParameterTuple(deepspeech_net.trainable_params())
    # drive Adam with the per-step schedule computed above rather than a constant rate
    optimizer = Adam(weights, learning_rate=lr, eps=config.OptimConfig.epsilon,
                     loss_scale=config.OptimConfig.loss_scale)
train_net = TrainOneStepCell(loss_net, optimizer)
    if args.pre_trained_model_path:  # the default is '', which is never None
        param_dict = load_checkpoint(args.pre_trained_model_path)
        load_param_into_net(train_net, param_dict)
        print('Successfully loaded the pre-trained model')
model = Model(train_net)
lr_cb = Monitor(lr)
callback_list = [lr_cb]
if args.is_distributed:
config.CheckpointConfig.ckpt_file_name_prefix = config.CheckpointConfig.ckpt_file_name_prefix + str(get_rank())
config.CheckpointConfig.ckpt_path = os.path.join(config.CheckpointConfig.ckpt_path,
'ckpt_' + str(get_rank()) + '/')
config_ck = CheckpointConfig(save_checkpoint_steps=1,
keep_checkpoint_max=config.CheckpointConfig.keep_checkpoint_max)
ckpt_cb = ModelCheckpoint(prefix=config.CheckpointConfig.ckpt_file_name_prefix,
directory=config.CheckpointConfig.ckpt_path, config=config_ck)
callback_list.append(ckpt_cb)
model.train(config.TrainingConfig.epochs, ds_train, callbacks=callback_list)
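In the distributed branch each rank writes checkpoints under its own prefix and directory; a tiny sketch of the resulting layout for a hypothetical 2-GPU run:

# Per-rank checkpoint naming, mirroring the string handling above.
import os

ckpt_path, prefix = './checkpoint', 'DeepSpeech'
for rank in range(2):
    print(prefix + str(rank), os.path.join(ckpt_path, 'ckpt_' + str(rank) + '/'))
# DeepSpeech0 ./checkpoint/ckpt_0/
# DeepSpeech1 ./checkpoint/ckpt_1/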