mod_deepspeech

pull/11742/head
wanyiming 4 years ago
parent 8e1a556e67
commit cdf66c3ae7

@@ -192,7 +192,25 @@ dataset directory structure is as follows:
```
The three *.csv files store the absolute path of the corresponding
data. The three *.csv files will be used in the training and evaluation process.
data. After obtaining the three *.csv files, you should modify the configurations in `src/config.py`.
For the training config, train_manifest should be configured with the path of `libri_train_manifest.csv`, and for the eval config it should be configured
with `libri_test_clean_manifest.csv` or `libri_test_other_manifest.csv`, depending on which dataset is evaluated.
```shell
...
# for training configuration
"DataConfig": {
    train_manifest: 'path_to_csv/libri_train_manifest.csv'
}

# for evaluation configuration
"DataConfig": {
    train_manifest: 'path_to_csv/libri_test_clean_manifest.csv'
}
```
The three *.csv files will be used in the training and evaluation process. Before training, some additional requirements should be installed, including `librosa` and `Levenshtein`.
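As a quick pre-flight check, the two extra packages can be imported before launching training. This is a minimal sketch, not part of the repository; it assumes the requirements were installed from PyPI (e.g. `pip install librosa python-Levenshtein`; the exact name of the Levenshtein distribution may vary):

```python
# Sanity check for the extra requirements (assumes prior installation, e.g.
# `pip install librosa python-Levenshtein`; the package names are an assumption).
import librosa
import Levenshtein

print('librosa version:', librosa.__version__)
# Levenshtein edit distance between two sample strings,
# the kind of distance used when computing WER/CER.
print('edit distance:', Levenshtein.distance('kitten', 'sitting'))  # -> 3
```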
After installing MindSpore via the official website and finishing dataset processing, you can start training as follows:
```shell
@@ -201,7 +219,7 @@ After installing MindSpore via the official website and finishing dataset proces
CUDA_VISIBLE_DEVICES='0' python train.py
# distributed training
CUDA_VISIBLE_DEVICES='0,1,2,3,4,5,6,7' mpirun --allow-run-as-root -n 8 python train.py --is_distributed=True > log 2>&1 &
CUDA_VISIBLE_DEVICES='0,1,2,3,4,5,6,7' mpirun --allow-run-as-root -n 8 python train.py --is_distributed > log 2>&1 &
```
@@ -253,8 +271,8 @@ python export.py --pre_trained_model_path='ckpt_path'
| Dataset | LibriSpeech |
| batch_size | 20 |
| outputs | probability |
| Accuracy(test-clean) | WER: 9.732, CER: 3.270 |
| Accuracy(test-others) | WER: 28.198, CER: 12.253 |
| Accuracy(test-clean) | 2p: WER: 9.902, CER: 3.317; 8p: WER: 11.593, CER: 3.907 |
| Accuracy(test-others) | 2p: WER: 28.693, CER: 12.473; 8p: WER: 31.397, CER: 13.696 |
| Model for inference | 330M (.mindir file) |
# [ModelZoo Homepage](#contents)

@@ -81,6 +81,7 @@ if __name__ == '__main__':
                last_id += 1
                start = count
                count += 1
        split_targets.append(list(targets[start:]))
        out, output_sizes = model(inputs, input_length)
        decoded_output, _ = decoder.decode(out, output_sizes)
        target_strings = target_decoder.convert_to_strings(split_targets)
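The added `split_targets.append(list(targets[start:]))` flushes the labels of the last utterance in the batch, which the loop above only does for utterances that are followed by another utterance id. A standalone sketch of the same splitting logic with hypothetical toy data (not the repository's tensors):

```python
# Toy flattened targets (hypothetical data): column 0 of target_indices holds
# the utterance id, and targets holds one label per row.
target_indices = [[0, 0], [0, 1], [1, 0], [1, 1], [1, 2]]
targets = [7, 8, 3, 4, 5]

split_targets = []
start, count, last_id = 0, 0, 0
for i in range(len(targets)):
    if target_indices[i][0] == last_id:
        count += 1
    else:
        split_targets.append(list(targets[start:count]))
        last_id += 1
        start = count
        count += 1
# Without this final append the last utterance would be dropped,
# which is what the added line in eval.py fixes.
split_targets.append(list(targets[start:]))
print(split_targets)  # [[7, 8], [3, 4, 5]]
```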

@@ -31,10 +31,11 @@ def get_lr(lr_init, total_epochs, steps_per_epoch):
"""
lr_each_step = []
half_epoch = total_epochs // 2
for i in range(total_epochs * steps_per_epoch):
if i < half_epoch:
lr_each_step.append(lr_init)
else:
lr_each_step.append(lr_init / (1.1 ** (i - half_epoch)))
for i in range(total_epochs):
for _ in range(steps_per_epoch):
if i < half_epoch:
lr_each_step.append(lr_init)
else:
lr_each_step.append(lr_init / (1.1 ** (i - half_epoch)))
learning_rate = np.array(lr_each_step).astype(np.float32)
return learning_rate
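The rewritten loop decays the learning rate once per epoch: the old code compared the global step index `i` with `half_epoch`, so the decay kicked in after `half_epoch` steps and then shrank the rate on every step. A small usage sketch of the fixed schedule with hypothetical toy numbers:

```python
import numpy as np

def get_lr(lr_init, total_epochs, steps_per_epoch):
    """Same schedule as the fixed version above: constant for the first half
    of the epochs, then decayed by a factor of 1.1 per epoch past the midpoint."""
    lr_each_step = []
    half_epoch = total_epochs // 2
    for i in range(total_epochs):
        for _ in range(steps_per_epoch):
            if i < half_epoch:
                lr_each_step.append(lr_init)
            else:
                lr_each_step.append(lr_init / (1.1 ** (i - half_epoch)))
    return np.array(lr_each_step).astype(np.float32)

# Hypothetical toy sizes: 4 epochs of 2 steps each.
print(get_lr(0.001, total_epochs=4, steps_per_epoch=2))
# lr_init for epochs 0-2, then lr_init / 1.1**(epoch - 2) from epoch 3 on.
```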

@@ -81,8 +81,8 @@ if __name__ == '__main__':
    optimizer = Adam(weights, learning_rate=config.OptimConfig.learning_rate, eps=config.OptimConfig.epsilon,
                     loss_scale=config.OptimConfig.loss_scale)
    train_net = TrainOneStepCell(loss_net, optimizer)
    if args.pre_trained_model_path is not None:
    train_net.set_train(True)
    if args.pre_trained_model_path != '':
        param_dict = load_checkpoint(args.pre_trained_model_path)
        load_param_into_net(train_net, param_dict)
        print('Successfully loading the pre-trained model')
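The check changes from `is not None` to `!= ''`: if `--pre_trained_model_path` defaults to an empty string, `is not None` is always true and the script would attempt `load_checkpoint('')` even when no checkpoint was supplied. A minimal illustration of the difference; the empty-string default is an assumption mirroring this pattern, not the actual parser code from `train.py`:

```python
# Hypothetical parser mirroring the empty-string default assumed above.
import argparse

parser = argparse.ArgumentParser()
parser.add_argument('--pre_trained_model_path', type=str, default='',
                    help='path to a pre-trained checkpoint; empty means train from scratch')
args = parser.parse_args([])  # simulate launching without a checkpoint

print(args.pre_trained_model_path is not None)  # True  -> old check would still try to load ''
print(args.pre_trained_model_path != '')        # False -> new check correctly skips loading
```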
