Add dscnn network to modelzoo.

pull/7456/head
zhanghuiyao 5 years ago
parent cb61cfd07c
commit 4637744100

File diff suppressed because it is too large.

@@ -0,0 +1,112 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ===========================================================================
"""DSCNN eval."""
import os
import datetime
import glob
import argparse
import numpy as np
from mindspore import context
from mindspore import Tensor, Model
from mindspore.common import dtype as mstype
from src.config import eval_config
from src.log import get_logger
from src.dataset import audio_dataset
from src.ds_cnn import DSCNN
from src.models import load_ckpt
def get_top5_acc(top5_arg, gt_class):
sub_count = 0
for top5, gt in zip(top5_arg, gt_class):
if gt in top5:
sub_count += 1
return sub_count
def val(args, model, test_de):
'''Eval.'''
eval_dataloader = test_de.create_tuple_iterator()
img_tot = 0
top1_correct = 0
top5_correct = 0
for data, gt_classes in eval_dataloader:
output = model.predict(Tensor(data, mstype.float32))
output = output.asnumpy()
top1_output = np.argmax(output, (-1))
top5_output = np.argsort(output)[:, -5:]
gt_classes = gt_classes.asnumpy()
t1_correct = np.equal(top1_output, gt_classes).sum()
top1_correct += t1_correct
top5_correct += get_top5_acc(top5_output, gt_classes)
img_tot += output.shape[0]
results = [[top1_correct], [top5_correct], [img_tot]]
results = np.array(results)
top1_correct = results[0, 0]
top5_correct = results[1, 0]
img_tot = results[2, 0]
acc1 = 100.0 * top1_correct / img_tot
acc5 = 100.0 * top5_correct / img_tot
if acc1 > args.best_acc:
args.best_acc = acc1
args.best_index = args.index
args.logger.info('Eval: top1_cor:{}, top5_cor:{}, tot:{}, acc@1={:.2f}%, acc@5={:.2f}%' \
.format(top1_correct, top5_correct, img_tot, acc1, acc5))
def main():
parser = argparse.ArgumentParser()
    parser.add_argument('--device_id', type=int, default=1, help='which device the model will be evaluated on')
args, model_settings = eval_config(parser)
context.set_context(mode=context.GRAPH_MODE, device_target="Davinci", device_id=args.device_id)
# Logger
args.outputs_dir = os.path.join(args.log_path, datetime.datetime.now().strftime('%Y-%m-%d_time_%H_%M_%S'))
args.logger = get_logger(args.outputs_dir)
# show args
args.logger.save_args(args)
# find model path
if os.path.isdir(args.model_dir):
models = list(glob.glob(os.path.join(args.model_dir, '*.ckpt')))
print(models)
f = lambda x: -1 * int(os.path.splitext(os.path.split(x)[-1])[0].split('-')[0].split('epoch')[-1])
args.models = sorted(models, key=f)
else:
args.models = [args.model_dir]
args.best_acc = 0
args.index = 0
args.best_index = 0
for model_path in args.models:
test_de = audio_dataset(args.feat_dir, 'testing', model_settings['spectrogram_length'],
model_settings['dct_coefficient_count'], args.per_batch_size)
network = DSCNN(model_settings, args.model_size_info)
load_ckpt(network, model_path, False)
network.set_train(False)
model = Model(network)
args.logger.info('load model {} success'.format(model_path))
val(args, model, test_de)
args.index += 1
args.logger.info('Best model:{} acc:{:.2f}%'.format(args.models[args.best_index], args.best_acc))
if __name__ == "__main__":
main()
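
The accuracy bookkeeping in val() above is plain NumPy: argmax gives the top-1 prediction, argsort keeps the five highest-scoring class indices, and get_top5_acc counts how many ground-truth labels land inside those five. A minimal sketch of that counting on a toy 3-sample, 6-class batch (all numbers below are made up for illustration):

import numpy as np

logits = np.array([[0.10, 0.50, 0.20, 0.05, 0.10, 0.05],   # predicted class 1
                   [0.30, 0.10, 0.10, 0.20, 0.20, 0.10],   # predicted class 0
                   [0.05, 0.05, 0.10, 0.10, 0.30, 0.40]])  # predicted class 5
gt = np.array([1, 3, 5])

top1 = np.argmax(logits, -1)              # [1, 0, 5]
top5 = np.argsort(logits)[:, -5:]         # five best class indices per sample
top1_correct = np.equal(top1, gt).sum()   # 2
top5_correct = sum(int(g in row) for row, g in zip(top5, gt))  # 3 (class 3 still makes sample 1's top-5)
print(100.0 * top1_correct / len(gt), 100.0 * top5_correct / len(gt))  # ~66.7 and 100.0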

@@ -0,0 +1,33 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ===========================================================================
"""DSCNN export."""
import argparse
import numpy as np
from mindspore import Tensor
from mindspore.train.serialization import export
from src.config import eval_config
from src.ds_cnn import DSCNN
from src.models import load_ckpt
parser = argparse.ArgumentParser()
args, model_settings = eval_config(parser)
network = DSCNN(model_settings, args.model_size_info)
load_ckpt(network, args.model_dir, False)
x = np.random.uniform(0.0, 1.0, size=[1, 1, model_settings['spectrogram_length'],
model_settings['dct_coefficient_count']]).astype(np.float32)
export(network, Tensor(x), file_name=args.model_dir.replace('.ckpt', '.air'), file_format='AIR')

@@ -0,0 +1,17 @@
#!/bin/bash
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ===========================================================================
python src/download_process_data.py

@@ -0,0 +1,17 @@
#!/bin/bash
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ===========================================================================
python eval.py --device_id $1 --model_dir $2

@@ -0,0 +1,17 @@
#!/bin/bash
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ===========================================================================
python train.py --device_id $1

@@ -0,0 +1,88 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ===========================================================================
"""Callback."""
import time
from mindspore.train.callback import ModelCheckpoint
from mindspore.train.callback import CheckpointConfig, Callback
class ProgressMonitor(Callback):
'''Progress Monitor.'''
def __init__(self, args):
super(ProgressMonitor, self).__init__()
self.args = args
self.epoch_start_time = 0
self.step_start_time = 0
self.globe_step_cnt = 0
self.local_step_cnt = 0
self.ckpt_history = []
def begin(self, run_context):
if not self.args.epoch_cnt:
self.args.logger.info('start network train...')
if run_context is None:
pass
def step_begin(self, run_context):
if self.local_step_cnt == 0:
self.step_start_time = time.time()
if run_context is None:
pass
def step_end(self, run_context):
'''Callback when step end.'''
if self.local_step_cnt % self.args.log_interval == 0 and self.local_step_cnt > 0:
cb_params = run_context.original_args()
time_used = time.time() - self.step_start_time
fps_mean = self.args.per_batch_size * self.args.log_interval / time_used
self.args.logger.info('epoch[{}], iter[{}], loss:{}, mean_wps:{:.2f} wavs/sec'.format(self.args.epoch_cnt,
self.globe_step_cnt +
self.local_step_cnt,
cb_params.net_outputs,
fps_mean))
self.step_start_time = time.time()
self.local_step_cnt += 1
def epoch_begin(self, run_context):
self.epoch_start_time = time.time()
if run_context is None:
pass
def epoch_end(self, run_context):
'''Callback when epoch end.'''
cb_params = run_context.original_args()
self.globe_step_cnt = self.args.steps_per_epoch * (self.args.epoch_cnt + 1) - 1
time_used = time.time() - self.epoch_start_time
fps_mean = self.args.per_batch_size * self.args.steps_per_epoch / time_used
self.args.logger.info(
'epoch[{}], iter[{}], loss:{}, mean_wps:{:.2f} wavs/sec'.format(self.args.epoch_cnt, self.globe_step_cnt,
cb_params.net_outputs, fps_mean))
self.args.epoch_cnt += 1
self.local_step_cnt = 0
def end(self, run_context):
pass
def callback_func(args, cb, prefix):
callbacks = [cb]
if args.rank_save_ckpt_flag:
ckpt_max_num = args.max_epoch * args.steps_per_epoch // args.ckpt_interval
ckpt_config = CheckpointConfig(save_checkpoint_steps=args.ckpt_interval, keep_checkpoint_max=ckpt_max_num)
ckpt_cb = ModelCheckpoint(config=ckpt_config, directory=args.outputs_dir, prefix=prefix)
callbacks.append(ckpt_cb)
return callbacks
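
callback_func sizes keep_checkpoint_max as max_epoch * steps_per_epoch // ckpt_interval, i.e. large enough that every checkpoint written at the configured step interval can be retained. A rough sketch of wiring it together; the field values below are placeholders chosen for illustration, a real run builds args from train_config(parser):

from types import SimpleNamespace
from src.callback import ProgressMonitor, callback_func

# placeholder args object (illustrative values only)
args = SimpleNamespace(rank_save_ckpt_flag=1, max_epoch=80, steps_per_epoch=100,
                       ckpt_interval=100, outputs_dir='train_outputs/demo',
                       per_batch_size=100, log_interval=100, epoch_cnt=0, logger=None)
callbacks = callback_func(args, ProgressMonitor(args), prefix='epoch0')
# -> [ProgressMonitor, ModelCheckpoint(save_checkpoint_steps=100, keep_checkpoint_max=80)]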

@@ -0,0 +1,161 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ===========================================================================
"""Config setting, will be used in train.py and eval.py"""
from src.utils import prepare_words_list
def data_config(parser):
'''config for data.'''
parser.add_argument('--data_url', type=str,
default='http://download.tensorflow.org/data/speech_commands_v0.01.tar.gz',
help='Location of speech training data archive on the web.')
parser.add_argument('--data_dir', type=str, default='data',
help='Where to download the dataset.')
parser.add_argument('--feat_dir', type=str, default='feat',
help='Where to save the feature of audios')
parser.add_argument('--background_volume', type=float, default=0.1,
help='How loud the background noise should be, between 0 and 1.')
parser.add_argument('--background_frequency', type=float, default=0.8,
help='How many of the training samples have background noise mixed in.')
parser.add_argument('--silence_percentage', type=float, default=10.0,
help='How much of the training data should be silence.')
parser.add_argument('--unknown_percentage', type=float, default=10.0,
help='How much of the training data should be unknown words.')
parser.add_argument('--time_shift_ms', type=float, default=100.0,
help='Range to randomly shift the training audio by in time.')
parser.add_argument('--testing_percentage', type=int, default=10,
help='What percentage of wavs to use as a test set.')
parser.add_argument('--validation_percentage', type=int, default=10,
help='What percentage of wavs to use as a validation set.')
parser.add_argument('--wanted_words', type=str, default='yes,no,up,down,left,right,on,off,stop,go',
help='Words to use (others will be added to an unknown label)')
parser.add_argument('--sample_rate', type=int, default=16000, help='Expected sample rate of the wavs')
parser.add_argument('--clip_duration_ms', type=int, default=1000,
help='Expected duration in milliseconds of the wavs')
parser.add_argument('--window_size_ms', type=float, default=40.0, help='How long each spectrogram timeslice is')
    parser.add_argument('--window_stride_ms', type=float, default=20.0, help='How far to move in time between spectrogram timeslices')
parser.add_argument('--dct_coefficient_count', type=int, default=20,
help='How many bins to use for the MFCC fingerprint')
def train_config(parser):
'''config for train.'''
data_config(parser)
# network related
parser.add_argument('--model_size_info', type=int, nargs="+",
default=[6, 276, 10, 4, 2, 1, 276, 3, 3, 2, 2, 276, 3, 3, 1, 1, 276, 3, 3, 1, 1, 276, 3, 3, 1,
1, 276, 3, 3, 1, 1],
help='Model dimensions - different for various models')
    parser.add_argument('--drop', type=float, default=0.9, help='dropout keep probability')
parser.add_argument('--pretrained', type=str, default='', help='model_path, local pretrained model to load')
# training related
parser.add_argument('--use_graph_mode', default=1, type=int, help='use graph mode or feed mode')
parser.add_argument('--val_interval', type=int, default=1, help='validate interval')
# dataset related
    parser.add_argument('--per_batch_size', default=100, type=int, help='batch size per device')
# optimizer and lr related
parser.add_argument('--lr_scheduler', default='multistep', type=str,
help='lr-scheduler, option type: multistep, cosine_annealing')
parser.add_argument('--lr', default=0.1, type=float, help='learning rate of the training')
parser.add_argument('--lr_epochs', type=str, default='20,40,60,80', help='epoch of lr changing')
    parser.add_argument('--lr_gamma', type=float, default=0.1,
                        help='factor by which lr is decayed at each milestone of the multistep scheduler')
parser.add_argument('--eta_min', type=float, default=0., help='eta_min in cosine_annealing scheduler')
parser.add_argument('--T_max', type=int, default=80, help='T-max in cosine_annealing scheduler')
parser.add_argument('--max_epoch', type=int, default=80, help='max epoch num to train the model')
parser.add_argument('--warmup_epochs', default=0, type=float, help='warmup epoch')
parser.add_argument('--weight_decay', type=float, default=0.001, help='weight decay')
parser.add_argument('--momentum', type=float, default=0.98, help='momentum')
# logging related
parser.add_argument('--log_interval', type=int, default=100, help='logging interval')
parser.add_argument('--ckpt_path', type=str, default='train_outputs/', help='checkpoint save location')
    parser.add_argument('--ckpt_interval', type=int, default=100, help='how often (in steps) to save checkpoints')
flags, _ = parser.parse_known_args()
flags.dataset_sink_mode = bool(flags.use_graph_mode)
flags.lr_epochs = list(map(int, flags.lr_epochs.split(',')))
model_settings = prepare_model_settings(
len(prepare_words_list(flags.wanted_words.split(','))),
flags.sample_rate, flags.clip_duration_ms, flags.window_size_ms,
flags.window_stride_ms, flags.dct_coefficient_count)
model_settings['dropout1'] = flags.drop
return flags, model_settings
def eval_config(parser):
'''config for eval.'''
parser.add_argument('--feat_dir', type=str, default='feat',
help='Where to save the feature of audios')
parser.add_argument('--model_dir', type=str,
default='outputs',
help='which folder the models are saved in or specific path of one model')
parser.add_argument('--wanted_words', type=str, default='yes,no,up,down,left,right,on,off,stop,go',
help='Words to use (others will be added to an unknown label)')
parser.add_argument('--sample_rate', type=int, default=16000, help='Expected sample rate of the wavs')
parser.add_argument('--clip_duration_ms', type=int, default=1000,
help='Expected duration in milliseconds of the wavs')
parser.add_argument('--window_size_ms', type=float, default=40.0, help='How long each spectrogram timeslice is')
    parser.add_argument('--window_stride_ms', type=float, default=20.0, help='How far to move in time between spectrogram timeslices')
parser.add_argument('--dct_coefficient_count', type=int, default=20,
help='How many bins to use for the MFCC fingerprint')
parser.add_argument('--model_size_info', type=int, nargs="+",
default=[6, 276, 10, 4, 2, 1, 276, 3, 3, 2, 2, 276, 3, 3, 1, 1, 276, 3, 3, 1, 1, 276, 3, 3, 1,
1, 276, 3, 3, 1, 1],
help='Model dimensions - different for various models')
    parser.add_argument('--per_batch_size', default=100, type=int, help='batch size per device')
    parser.add_argument('--drop', type=float, default=0.9, help='dropout keep probability')
# logging related
parser.add_argument('--log_path', type=str, default='eval_outputs/', help='path to save eval log')
flags, _ = parser.parse_known_args()
model_settings = prepare_model_settings(
len(prepare_words_list(flags.wanted_words.split(','))),
flags.sample_rate, flags.clip_duration_ms, flags.window_size_ms,
flags.window_stride_ms, flags.dct_coefficient_count)
model_settings['dropout1'] = flags.drop
return flags, model_settings
def prepare_model_settings(label_count, sample_rate, clip_duration_ms,
window_size_ms, window_stride_ms,
dct_coefficient_count):
'''Prepare model setting.'''
desired_samples = int(sample_rate * clip_duration_ms / 1000)
window_size_samples = int(sample_rate * window_size_ms / 1000)
window_stride_samples = int(sample_rate * window_stride_ms / 1000)
length_minus_window = (desired_samples - window_size_samples)
if length_minus_window < 0:
spectrogram_length = 0
else:
spectrogram_length = 1 + int(length_minus_window / window_stride_samples)
fingerprint_size = dct_coefficient_count * spectrogram_length
return {
'desired_samples': desired_samples,
'window_size_samples': window_size_samples,
'window_stride_samples': window_stride_samples,
'spectrogram_length': spectrogram_length,
'dct_coefficient_count': dct_coefficient_count,
'fingerprint_size': fingerprint_size,
'label_count': label_count,
'sample_rate': sample_rate,
}
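
For the default flags, prepare_model_settings reduces to simple arithmetic. The worked numbers below are an illustrative check (12 labels = 10 wanted words plus _silence_ and _unknown_) showing where the 49x20 feature shape used throughout the other files comes from:

# desired_samples       = 16000 * 1000 / 1000       = 16000
# window_size_samples   = 16000 * 40 / 1000         = 640
# window_stride_samples = 16000 * 20 / 1000         = 320
# spectrogram_length    = 1 + (16000 - 640) // 320  = 49
# fingerprint_size      = 20 * 49                   = 980
from src.config import prepare_model_settings

settings = prepare_model_settings(12, 16000, 1000, 40.0, 20.0, 20)
assert settings['spectrogram_length'] == 49
assert settings['fingerprint_size'] == 980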

@@ -0,0 +1,47 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ===========================================================================
"""DSCNN dataset."""
import os
import numpy as np
import mindspore.dataset as de
class NpyDataset():
'''Dataset from numpy.'''
def __init__(self, data_dir, data_type, h, w):
super(NpyDataset, self).__init__()
self.data = np.load(os.path.join(data_dir, '{}_data.npy'.format(data_type)))
self.data = np.reshape(self.data, (-1, 1, h, w))
self.label = np.load(os.path.join(data_dir, '{}_label.npy'.format(data_type)))
def __len__(self):
return self.data.shape[0]
def __getitem__(self, item):
data = self.data[item]
label = self.label[item]
return data.astype(np.float32), label.astype(np.int32)
def audio_dataset(data_dir, data_type, h, w, batch_size):
    if 'testing' in data_type:  # keep the test split in a fixed order
shuffle = False
else:
shuffle = True
dataset = NpyDataset(data_dir, data_type, h, w)
de_dataset = de.GeneratorDataset(dataset, ["feats", "labels"], shuffle=shuffle)
de_dataset = de_dataset.batch(batch_size, drop_remainder=False)
return de_dataset
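
NpyDataset expects <split>_data.npy and <split>_label.npy files under the feature directory and reshapes every sample to (1, spectrogram_length, dct_coefficient_count). A small sketch that fabricates a dummy validation split with the default 49x20 shape; 'feat_demo' and the random arrays are illustrative, not files produced by this diff:

import os
import numpy as np
from src.dataset import audio_dataset

os.makedirs('feat_demo', exist_ok=True)
np.save('feat_demo/validation_data.npy', np.random.rand(8, 49 * 20).astype(np.float32))
np.save('feat_demo/validation_label.npy', np.random.randint(0, 12, size=8))
ds = audio_dataset('feat_demo', 'validation', 49, 20, batch_size=4)
for feats, labels in ds.create_tuple_iterator():
    print(feats.shape, labels.shape)   # expected: (4, 1, 49, 20) (4,)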

File diff suppressed because it is too large.

@@ -0,0 +1,107 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ===========================================================================
"""DSCNN network."""
import math
import mindspore.nn as nn
from mindspore.ops import operations as P
from mindspore.common.initializer import initializer
from mindspore import Parameter
class DepthWiseConv(nn.Cell):
'''Build DepthWise conv.'''
def __init__(self, in_planes, kernel_size, stride, pad_mode, pad, channel_multiplier=1, has_bias=False):
super(DepthWiseConv, self).__init__()
self.has_bias = has_bias
self.depthwise_conv = P.DepthwiseConv2dNative(channel_multiplier=channel_multiplier, kernel_size=kernel_size,
stride=stride, pad_mode=pad_mode, pad=pad)
self.bias_add = P.BiasAdd()
weight_shape = [channel_multiplier, in_planes, kernel_size[0], kernel_size[1]]
self.weight = Parameter(initializer('ones', weight_shape), name='weight')
if has_bias:
bias_shape = [channel_multiplier * in_planes]
self.bias = Parameter(initializer('zeros', bias_shape), name='bias')
else:
self.bias = None
def construct(self, x):
output = self.depthwise_conv(x, self.weight)
if self.has_bias:
output = self.bias_add(output, self.bias)
return output
class DSCNN(nn.Cell):
'''Build DSCNN network.'''
def __init__(self, model_settings, model_size_info):
super(DSCNN, self).__init__()
# N C H W
label_count = model_settings['label_count']
input_frequency_size = model_settings['dct_coefficient_count']
input_time_size = model_settings['spectrogram_length']
t_dim = input_time_size
f_dim = input_frequency_size
num_layers = model_size_info[0]
conv_feat = [None] * num_layers
conv_kt = [None] * num_layers
conv_kf = [None] * num_layers
conv_st = [None] * num_layers
conv_sf = [None] * num_layers
i = 1
for layer_no in range(0, num_layers):
conv_feat[layer_no] = model_size_info[i]
i += 1
conv_kt[layer_no] = model_size_info[i]
i += 1
conv_kf[layer_no] = model_size_info[i]
i += 1
conv_st[layer_no] = model_size_info[i]
i += 1
conv_sf[layer_no] = model_size_info[i]
i += 1
seq_cell = []
in_channel = 1
for layer_no in range(0, num_layers):
if layer_no == 0:
seq_cell.append(nn.Conv2d(in_channels=in_channel, out_channels=conv_feat[layer_no],
kernel_size=(conv_kt[layer_no], conv_kf[layer_no]),
stride=(conv_st[layer_no], conv_sf[layer_no]),
pad_mode="same", padding=0, has_bias=False))
seq_cell.append(nn.BatchNorm2d(num_features=conv_feat[layer_no], momentum=0.98))
in_channel = conv_feat[layer_no]
else:
seq_cell.append(DepthWiseConv(in_planes=in_channel, kernel_size=(conv_kt[layer_no], conv_kf[layer_no]),
stride=(conv_st[layer_no], conv_sf[layer_no]), pad_mode='same', pad=0))
seq_cell.append(nn.BatchNorm2d(num_features=in_channel, momentum=0.98))
seq_cell.append(nn.ReLU())
seq_cell.append(nn.Conv2d(in_channels=in_channel, out_channels=conv_feat[layer_no], kernel_size=(1, 1),
pad_mode="same"))
seq_cell.append(nn.BatchNorm2d(num_features=conv_feat[layer_no], momentum=0.98))
seq_cell.append(nn.ReLU())
in_channel = conv_feat[layer_no]
t_dim = math.ceil(t_dim / float(conv_st[layer_no]))
f_dim = math.ceil(f_dim / float(conv_sf[layer_no]))
        seq_cell.append(nn.AvgPool2d(kernel_size=(t_dim, f_dim)))  # global average pool over the remaining time x frequency map
seq_cell.append(nn.Flatten())
seq_cell.append(nn.Dropout(model_settings['dropout1']))
seq_cell.append(nn.Dense(in_channel, label_count))
self.model = nn.SequentialCell(seq_cell)
def construct(self, x):
x = self.model(x)
return x
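
The model_size_info list is decoded positionally: the first entry is the number of conv blocks, and each block then contributes five values (out_channels, kernel_t, kernel_f, stride_t, stride_f). A sketch unpacking the default value from config.py:

info = [6, 276, 10, 4, 2, 1, 276, 3, 3, 2, 2, 276, 3, 3, 1, 1,
        276, 3, 3, 1, 1, 276, 3, 3, 1, 1, 276, 3, 3, 1, 1]
num_layers, rest = info[0], info[1:]
blocks = [tuple(rest[5 * i:5 * i + 5]) for i in range(num_layers)]
# block 0: standard Conv2d, 276 filters, 10x4 kernel, stride (2, 1)
# blocks 1-5: depthwise 3x3 + pointwise 1x1, 276 channels (block 1 strides by 2, the rest by 1)
for c, kt, kf, st, sf in blocks:
    print(c, (kt, kf), (st, sf))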

@@ -0,0 +1,71 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ===========================================================================
"""Logger."""
import os
import sys
import logging
from datetime import datetime
logger_name_1 = 'ds-cnn'
class LOGGER(logging.Logger):
'''Build logger.'''
def __init__(self, logger_name):
super(LOGGER, self).__init__(logger_name)
console = logging.StreamHandler(sys.stdout)
console.setLevel(logging.INFO)
formatter = logging.Formatter('%(asctime)s:%(levelname)s:%(message)s')
console.setFormatter(formatter)
self.addHandler(console)
def setup_logging_file(self, log_dir):
if not os.path.exists(log_dir):
os.makedirs(log_dir, exist_ok=True)
log_name = datetime.now().strftime('%Y-%m-%d_time_%H_%M_%S')
self.log_fn = os.path.join(log_dir, log_name)
fh = logging.FileHandler(self.log_fn)
fh.setLevel(logging.INFO)
formatter = logging.Formatter('%(asctime)s:%(levelname)s:%(message)s')
fh.setFormatter(formatter)
self.addHandler(fh)
def info(self, msg, *args, **kwargs):
if self.isEnabledFor(logging.INFO):
self._log(logging.INFO, msg, args, **kwargs)
def save_args(self, args):
self.info('Args:')
args_dict = vars(args)
for key in args_dict.keys():
self.info('--> %s: %s', key, args_dict[key])
self.info('')
def important_info(self, msg, *args, **kwargs):
if self.isEnabledFor(logging.INFO):
line_width = 2
important_msg = '\n'
important_msg += ('*' * 70 + '\n') * line_width
important_msg += ('*' * line_width + '\n') * 2
important_msg += '*' * line_width + ' ' * 8 + msg + '\n'
important_msg += ('*' * line_width + '\n') * 2
important_msg += ('*' * 70 + '\n') * line_width
self.info(important_msg, *args, **kwargs)
def get_logger(path):
logger = LOGGER(logger_name_1)
logger.setup_logging_file(path)
return logger

@@ -0,0 +1,39 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ===========================================================================
"""DSCNN loss."""
import mindspore.nn as nn
from mindspore.nn.loss.loss import _Loss
from mindspore.ops import operations as P
from mindspore.ops import functional as F
from mindspore import Tensor
from mindspore.common import dtype as mstype
class CrossEntropy(_Loss):
'''Build CrossEntropy Loss.'''
def __init__(self, smooth_factor=0., num_classes=1000):
super(CrossEntropy, self).__init__()
self.onehot = P.OneHot()
self.on_value = Tensor(1.0 - smooth_factor, mstype.float32)
        self.off_value = Tensor(1.0 * smooth_factor / (num_classes - 1), mstype.float32)
self.ce = nn.SoftmaxCrossEntropyWithLogits()
self.mean = P.ReduceMean(False)
def construct(self, logit, label):
one_hot_label = self.onehot(label,
F.shape(logit)[1], self.on_value, self.off_value)
loss = self.ce(logit, one_hot_label)
loss = self.mean(loss, 0)
return loss
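
With label smoothing, the one-hot target puts 1 - smooth_factor on the true class and spreads smooth_factor evenly over the remaining num_classes - 1 classes. The train script instantiates CrossEntropy with the default smooth_factor of 0, so the 0.1 below is purely illustrative:

import numpy as np

smooth_factor, num_classes, true_class = 0.1, 12, 3   # illustrative values
target = np.full(num_classes, smooth_factor / (num_classes - 1))
target[true_class] = 1.0 - smooth_factor
# on_value = 0.9, off_value = 0.1 / 11 ~ 0.00909, and the target still sums to 1
assert abs(target.sum() - 1.0) < 1e-9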

File diff suppressed because it is too large.

@@ -0,0 +1,26 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ===========================================================================
"""DSCNN models."""
from mindspore.train.serialization import load_checkpoint, load_param_into_net
def load_ckpt(network, pretrain_ckpt_path, trainable=True):
"""
incremental_learning or not
"""
param_dict = load_checkpoint(pretrain_ckpt_path)
load_param_into_net(network, param_dict)
if not trainable:
for param in network.get_parameters():
param.requires_grad = False

@@ -0,0 +1,74 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ===========================================================================
"""Logger."""
import os
import sys
import logging
from datetime import datetime
logger_name_1 = 'ds-cnn'
class LOGGER(logging.Logger):
'''Build logger.'''
def __init__(self, logger_name, rank=0):
super(LOGGER, self).__init__(logger_name)
if rank % 8 == 0:
console = logging.StreamHandler(sys.stdout)
console.setLevel(logging.INFO)
formatter = logging.Formatter('%(asctime)s:%(levelname)s:%(message)s')
console.setFormatter(formatter)
self.addHandler(console)
def setup_logging_file(self, log_dir, rank=0):
'''Setup logging file.'''
self.rank = rank
if not os.path.exists(log_dir):
os.makedirs(log_dir, exist_ok=True)
log_name = datetime.now().strftime('%Y-%m-%d_time_%H_%M_%S') + '_rank_{}.log'.format(rank)
self.log_fn = os.path.join(log_dir, log_name)
fh = logging.FileHandler(self.log_fn)
fh.setLevel(logging.INFO)
formatter = logging.Formatter('%(asctime)s:%(levelname)s:%(message)s')
fh.setFormatter(formatter)
self.addHandler(fh)
def info(self, msg, *args, **kwargs):
if self.isEnabledFor(logging.INFO):
self._log(logging.INFO, msg, args, **kwargs)
def save_args(self, args):
self.info('Args:')
args_dict = vars(args)
for key in args_dict.keys():
self.info('--> %s: %s', key, args_dict[key])
self.info('')
def important_info(self, msg, *args, **kwargs):
if self.isEnabledFor(logging.INFO) and self.rank == 0:
line_width = 2
important_msg = '\n'
important_msg += ('*' * 70 + '\n') * line_width
important_msg += ('*' * line_width + '\n') * 2
important_msg += '*' * line_width + ' ' * 8 + msg + '\n'
important_msg += ('*' * line_width + '\n') * 2
important_msg += ('*' * 70 + '\n') * line_width
self.info(important_msg, *args, **kwargs)
def get_logger(path, rank):
logger = LOGGER(logger_name_1, rank)
logger.setup_logging_file(path, rank)
return logger

@@ -0,0 +1,19 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ===========================================================================
"""Utils."""
SILENCE_LABEL = '_silence_'
UNKNOWN_WORD_LABEL = '_unknown_'
def prepare_words_list(wanted_words):
return [SILENCE_LABEL, UNKNOWN_WORD_LABEL] + wanted_words
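
prepare_words_list prepends the two reserved labels, which is how the default ten wanted words become the 12-class label set assumed in the files above:

from src.utils import prepare_words_list

labels = prepare_words_list('yes,no,up,down,left,right,on,off,stop,go'.split(','))
print(len(labels), labels[:3])   # 12 ['_silence_', '_unknown_', 'yes']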

@@ -0,0 +1,150 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ===========================================================================
"""DSCNN train."""
import os
import datetime
import argparse
import numpy as np
from mindspore import context
from mindspore import Tensor, Model
from mindspore.nn.optim import Momentum
from mindspore.common import dtype as mstype
from mindspore.train.serialization import load_checkpoint
from src.config import train_config
from src.log import get_logger
from src.dataset import audio_dataset
from src.ds_cnn import DSCNN
from src.loss import CrossEntropy
from src.lr_scheduler import MultiStepLR, CosineAnnealingLR
from src.callback import ProgressMonitor, callback_func
def get_top5_acc(top5_arg, gt_class):
sub_count = 0
for top5, gt in zip(top5_arg, gt_class):
if gt in top5:
sub_count += 1
return sub_count
def val(args, model, val_dataset):
'''Eval.'''
val_dataloader = val_dataset.create_tuple_iterator()
img_tot = 0
top1_correct = 0
top5_correct = 0
for data, gt_classes in val_dataloader:
output = model.predict(Tensor(data, mstype.float32))
output = output.asnumpy()
top1_output = np.argmax(output, (-1))
top5_output = np.argsort(output)[:, -5:]
gt_classes = gt_classes.asnumpy()
t1_correct = np.equal(top1_output, gt_classes).sum()
top1_correct += t1_correct
top5_correct += get_top5_acc(top5_output, gt_classes)
img_tot += output.shape[0]
results = [[top1_correct], [top5_correct], [img_tot]]
results = np.array(results)
top1_correct = results[0, 0]
top5_correct = results[1, 0]
img_tot = results[2, 0]
acc1 = 100.0 * top1_correct / img_tot
acc5 = 100.0 * top5_correct / img_tot
if acc1 > args.best_acc:
args.best_acc = acc1
args.best_epoch = args.epoch_cnt - 1
args.logger.info('Eval: top1_cor:{}, top5_cor:{}, tot:{}, acc@1={:.2f}%, acc@5={:.2f}%' \
.format(top1_correct, top5_correct, img_tot, acc1, acc5))
def trainval(args, model, train_dataset, val_dataset, cb):
callbacks = callback_func(args, cb, 'epoch{}'.format(args.epoch_cnt))
model.train(args.val_interval, train_dataset, callbacks=callbacks, dataset_sink_mode=args.dataset_sink_mode)
val(args, model, val_dataset)
def train():
'''Train.'''
parser = argparse.ArgumentParser()
parser.add_argument('--device_id', type=int, default=1, help='which device the model will be trained on')
args, model_settings = train_config(parser)
context.set_context(mode=context.GRAPH_MODE, device_id=args.device_id, enable_auto_mixed_precision=True)
args.rank_save_ckpt_flag = 1
# Logger
args.outputs_dir = os.path.join(args.ckpt_path, datetime.datetime.now().strftime('%Y-%m-%d_time_%H_%M_%S'))
args.logger = get_logger(args.outputs_dir)
# Dataloader: train, val
train_dataset = audio_dataset(args.feat_dir, 'training', model_settings['spectrogram_length'],
model_settings['dct_coefficient_count'], args.per_batch_size)
args.steps_per_epoch = train_dataset.get_dataset_size()
val_dataset = audio_dataset(args.feat_dir, 'validation', model_settings['spectrogram_length'],
model_settings['dct_coefficient_count'], args.per_batch_size)
# show args
args.logger.save_args(args)
# Network
args.logger.important_info('start create network')
network = DSCNN(model_settings, args.model_size_info)
# Load pretrain model
if os.path.isfile(args.pretrained):
load_checkpoint(args.pretrained, network)
args.logger.info('load model {} success'.format(args.pretrained))
# Loss
criterion = CrossEntropy(num_classes=model_settings['label_count'])
# LR scheduler
if args.lr_scheduler == 'multistep':
lr_scheduler = MultiStepLR(args.lr, args.lr_epochs, args.lr_gamma, args.steps_per_epoch,
args.max_epoch, warmup_epochs=args.warmup_epochs)
elif args.lr_scheduler == 'cosine_annealing':
lr_scheduler = CosineAnnealingLR(args.lr, args.T_max, args.steps_per_epoch, args.max_epoch,
warmup_epochs=args.warmup_epochs, eta_min=args.eta_min)
else:
raise NotImplementedError(args.lr_scheduler)
lr_schedule = lr_scheduler.get_lr()
# Optimizer
opt = Momentum(params=network.trainable_params(),
learning_rate=Tensor(lr_schedule),
momentum=args.momentum,
weight_decay=args.weight_decay)
model = Model(network, loss_fn=criterion, optimizer=opt, amp_level='O0')
# Training
args.epoch_cnt = 0
args.best_epoch = 0
args.best_acc = 0
progress_cb = ProgressMonitor(args)
while args.epoch_cnt + args.val_interval < args.max_epoch:
trainval(args, model, train_dataset, val_dataset, progress_cb)
rest_ep = args.max_epoch - args.epoch_cnt
if rest_ep > 0:
trainval(args, model, train_dataset, val_dataset, progress_cb)
args.logger.info('Best epoch:{} acc:{:.2f}%'.format(args.best_epoch, args.best_acc))
if __name__ == "__main__":
train()
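
With the default multistep settings (lr 0.1, gamma 0.1, milestones 20/40/60/80), the per-epoch learning rate is expected to drop by a factor of ten at each milestone. The sketch below only mimics that schedule in plain Python to show the expected values; it is not the code of src/lr_scheduler.py, which is only imported here and whose diff is not shown:

lr, gamma, milestones = 0.1, 0.1, [20, 40, 60, 80]
per_epoch = [lr * gamma ** sum(e >= m for m in milestones) for e in range(80)]
print(per_epoch[0], per_epoch[20], per_epoch[40], per_epoch[60])
# expected: 0.1, then 0.01, 0.001 and 1e-4 after the milestones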