parent
cb61cfd07c
commit
4637744100
File diff suppressed because it is too large
Load Diff
@ -0,0 +1,112 @@
|
||||
# Copyright 2020 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ===========================================================================
|
||||
"""DSCNN eval."""
|
||||
import os
|
||||
import datetime
|
||||
import glob
|
||||
import argparse
|
||||
|
||||
import numpy as np
|
||||
from mindspore import context
|
||||
from mindspore import Tensor, Model
|
||||
from mindspore.common import dtype as mstype
|
||||
|
||||
from src.config import eval_config
|
||||
from src.log import get_logger
|
||||
from src.dataset import audio_dataset
|
||||
from src.ds_cnn import DSCNN
|
||||
from src.models import load_ckpt
|
||||
|
||||
def get_top5_acc(top5_arg, gt_class):
    '''Count the samples whose label appears among their candidate predictions.

    Args:
        top5_arg: iterable of per-sample candidate label collections.
        gt_class: iterable of ground-truth labels, paired positionally with top5_arg.

    Returns:
        int, number of pairs whose label is contained in the matching candidates.
    '''
    return sum(1 for candidates, label in zip(top5_arg, gt_class) if label in candidates)
|
||||
|
||||
|
||||
def val(args, model, test_de):
    '''Run one model over a dataset and log its top-1 / top-5 accuracy.

    Args:
        args: namespace providing ``logger``, ``index``, ``best_acc`` and ``best_index``.
            ``best_acc`` and ``best_index`` are overwritten when this run's top-1
            accuracy is strictly higher than the stored one.
        model: object whose ``predict`` is called on every batch.
        test_de: dataset exposing ``create_tuple_iterator()`` that yields
            (data, labels) batches.
    '''
    sample_cnt = 0
    hit1_cnt = 0
    hit5_cnt = 0
    for batch, labels in test_de.create_tuple_iterator():
        scores = model.predict(Tensor(batch, mstype.float32)).asnumpy()
        best_guess = np.argmax(scores, -1)
        # argsort is ascending, so the last five columns hold the highest scores
        best_five = np.argsort(scores)[:, -5:]
        labels = labels.asnumpy()
        hit1_cnt += np.equal(best_guess, labels).sum()
        hit5_cnt += get_top5_acc(best_five, labels)
        sample_cnt += scores.shape[0]

    # Pass the counters through one numpy array so all three end up as numpy scalars
    counters = np.array([[hit1_cnt], [hit5_cnt], [sample_cnt]])
    top1_correct, top5_correct, img_tot = counters[:, 0]

    acc1 = 100.0 * top1_correct / img_tot
    acc5 = 100.0 * top5_correct / img_tot
    if acc1 > args.best_acc:
        args.best_acc = acc1
        args.best_index = args.index

    summary = 'Eval: top1_cor:{}, top5_cor:{}, tot:{}, acc@1={:.2f}%, acc@5={:.2f}%'.format(
        top1_correct, top5_correct, img_tot, acc1, acc5)
    args.logger.info(summary)
|
||||
|
||||
|
||||
|
||||
def _ckpt_epoch(ckpt_path):
    '''Return the integer used to order checkpoints, parsed from a checkpoint path.

    The file name (without extension) is cut at the first '-', and the text
    following the last 'epoch' in that prefix is converted to int.
    '''
    stem = os.path.splitext(os.path.basename(ckpt_path))[0]
    return int(stem.split('-')[0].split('epoch')[-1])


def main():
    '''Evaluate one checkpoint, or every ``*.ckpt`` in a directory, and log the best.

    ``--model_dir`` may be a directory (all ``*.ckpt`` files in it are evaluated,
    highest parsed epoch number first) or the path of a single checkpoint.

    Raises:
        ValueError: if ``--model_dir`` is a directory that contains no ``*.ckpt`` file.
    '''
    parser = argparse.ArgumentParser()
    parser.add_argument('--device_id', type=int, default=1, help='which device the model will be trained on')
    args, model_settings = eval_config(parser)
    context.set_context(mode=context.GRAPH_MODE, device_target="Davinci", device_id=args.device_id)

    # Logger
    args.outputs_dir = os.path.join(args.log_path, datetime.datetime.now().strftime('%Y-%m-%d_time_%H_%M_%S'))
    args.logger = get_logger(args.outputs_dir)
    # show args
    args.logger.save_args(args)
    # find model path
    if os.path.isdir(args.model_dir):
        models = glob.glob(os.path.join(args.model_dir, '*.ckpt'))
        print(models)
        # reverse=True keeps the stable "highest epoch first" order of the old negated key
        args.models = sorted(models, key=_ckpt_epoch, reverse=True)
    else:
        args.models = [args.model_dir]

    # Without this check an empty directory only failed at the final summary line
    # with an IndexError that did not mention the real cause.
    if not args.models:
        raise ValueError('no checkpoint file (*.ckpt) found in {}'.format(args.model_dir))

    args.best_acc = 0
    args.index = 0
    args.best_index = 0
    for model_path in args.models:
        test_de = audio_dataset(args.feat_dir, 'testing', model_settings['spectrogram_length'],
                                model_settings['dct_coefficient_count'], args.per_batch_size)
        network = DSCNN(model_settings, args.model_size_info)

        load_ckpt(network, model_path, False)
        network.set_train(False)
        model = Model(network)
        args.logger.info('load model {} success'.format(model_path))
        val(args, model, test_de)
        args.index += 1

    args.logger.info('Best model:{} acc:{:.2f}%'.format(args.models[args.best_index], args.best_acc))


if __name__ == "__main__":
    main()
|
@ -0,0 +1,33 @@
|
||||
# Copyright 2020 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ===========================================================================
|
||||
"""DSCNN export."""
|
||||
import argparse
|
||||
|
||||
import numpy as np
|
||||
from mindspore import Tensor
|
||||
from mindspore.train.serialization import export
|
||||
|
||||
from src.config import eval_config
|
||||
from src.ds_cnn import DSCNN
|
||||
from src.models import load_ckpt
|
||||
|
||||
# Parse the evaluation options, build the network and restore the checkpoint weights.
parser = argparse.ArgumentParser()
args, model_settings = eval_config(parser)

network = DSCNN(model_settings, args.model_size_info)
load_ckpt(network, args.model_dir, False)

# Random single-sample input of shape [1, 1, spectrogram_length, dct_coefficient_count]
input_shape = [1, 1, model_settings['spectrogram_length'], model_settings['dct_coefficient_count']]
x = np.random.uniform(0.0, 1.0, size=input_shape).astype(np.float32)

# Output name: every '.ckpt' in the given path is replaced by '.air'
air_file = args.model_dir.replace('.ckpt', '.air')
export(network, Tensor(x), file_name=air_file, file_format='AIR')
|
@ -0,0 +1,17 @@
|
||||
#!/bin/bash
|
||||
# Copyright 2020 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ===========================================================================
|
||||
|
||||
# Run the data script; its path is relative, so start this from the directory that contains src/.
python src/download_process_data.py
|
@ -0,0 +1,17 @@
|
||||
#!/bin/bash
|
||||
# Copyright 2020 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ===========================================================================
|
||||
|
||||
# Usage: <this script> DEVICE_ID MODEL_DIR
#   DEVICE_ID  forwarded to eval.py as --device_id
#   MODEL_DIR  forwarded to eval.py as --model_dir
if [ $# -lt 2 ]; then
    echo "Usage: $0 DEVICE_ID MODEL_DIR" >&2
    exit 1
fi

# Quote the arguments so paths containing spaces or glob characters reach eval.py intact.
python eval.py --device_id "$1" --model_dir "$2"
|
@ -0,0 +1,17 @@
|
||||
#!/bin/bash
|
||||
# Copyright 2020 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ===========================================================================
|
||||
|
||||
# Usage: <this script> DEVICE_ID
#   DEVICE_ID  forwarded to train.py as --device_id
if [ $# -lt 1 ]; then
    echo "Usage: $0 DEVICE_ID" >&2
    exit 1
fi

# Quote the argument so it is passed to train.py as a single word.
python train.py --device_id "$1"
|
@ -0,0 +1,88 @@
|
||||
# Copyright 2020 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ===========================================================================
|
||||
"""Callback."""
|
||||
import time
|
||||
|
||||
from mindspore.train.callback import ModelCheckpoint
|
||||
from mindspore.train.callback import CheckpointConfig, Callback
|
||||
|
||||
|
||||
class ProgressMonitor(Callback):
    '''Callback that logs loss and throughput at fixed step intervals and at each epoch end.

    Reads ``logger``, ``epoch_cnt``, ``log_interval``, ``per_batch_size`` and
    ``steps_per_epoch`` from the ``args`` object it is given, and increments
    ``args.epoch_cnt`` at the end of every epoch.
    '''

    def __init__(self, args):
        super().__init__()
        self.args = args
        self.epoch_start_time = 0
        self.step_start_time = 0
        self.globe_step_cnt = 0
        self.local_step_cnt = 0
        self.ckpt_history = []

    def begin(self, run_context):
        '''Announce the start of training unless some epochs were already counted.'''
        if self.args.epoch_cnt:
            return
        self.args.logger.info('start network train...')

    def step_begin(self, run_context):
        '''Start the interval timer on the first step of an epoch.'''
        if self.local_step_cnt == 0:
            self.step_start_time = time.time()

    def step_end(self, run_context):
        '''Log loss and mean throughput every ``log_interval`` steps, then count the step.'''
        interval = self.args.log_interval
        log_due = self.local_step_cnt % interval == 0 and self.local_step_cnt > 0
        if log_due:
            cb_params = run_context.original_args()
            elapsed = time.time() - self.step_start_time
            wavs_per_sec = self.args.per_batch_size * interval / elapsed
            iteration = self.globe_step_cnt + self.local_step_cnt
            message = 'epoch[{}], iter[{}], loss:{}, mean_wps:{:.2f} wavs/sec'.format(
                self.args.epoch_cnt, iteration, cb_params.net_outputs, wavs_per_sec)
            self.args.logger.info(message)
            # restart the timer for the next logging interval
            self.step_start_time = time.time()
        self.local_step_cnt += 1

    def epoch_begin(self, run_context):
        '''Remember when the epoch started.'''
        self.epoch_start_time = time.time()

    def epoch_end(self, run_context):
        '''Log the epoch's loss and mean throughput and reset the per-epoch step counter.'''
        cb_params = run_context.original_args()
        steps = self.args.steps_per_epoch
        self.globe_step_cnt = steps * (self.args.epoch_cnt + 1) - 1

        elapsed = time.time() - self.epoch_start_time
        wavs_per_sec = self.args.per_batch_size * steps / elapsed
        message = 'epoch[{}], iter[{}], loss:{}, mean_wps:{:.2f} wavs/sec'.format(
            self.args.epoch_cnt, self.globe_step_cnt, cb_params.net_outputs, wavs_per_sec)
        self.args.logger.info(message)
        self.args.epoch_cnt += 1
        self.local_step_cnt = 0

    def end(self, run_context):
        '''Nothing to do when training ends.'''
|
||||
|
||||
|
||||
def callback_func(args, cb, prefix):
    '''Build the list of training callbacks.

    Args:
        args: namespace; ``rank_save_ckpt_flag`` decides whether a checkpoint
            callback is added, and ``max_epoch``, ``steps_per_epoch``,
            ``ckpt_interval`` and ``outputs_dir`` configure it.
        cb: callback that is always the first element of the result.
        prefix: file-name prefix handed to ``ModelCheckpoint``.

    Returns:
        list, ``[cb]`` or ``[cb, checkpoint_callback]``.
    '''
    if not args.rank_save_ckpt_flag:
        return [cb]

    # keep as many checkpoints as the whole run can produce
    keep_max = args.max_epoch * args.steps_per_epoch // args.ckpt_interval
    config = CheckpointConfig(save_checkpoint_steps=args.ckpt_interval, keep_checkpoint_max=keep_max)
    saver = ModelCheckpoint(config=config, directory=args.outputs_dir, prefix=prefix)
    return [cb, saver]
|
@ -0,0 +1,161 @@
|
||||
# Copyright 2020 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ===========================================================================
|
||||
"""Config setting, will be used in train.py and eval.py"""
|
||||
from src.utils import prepare_words_list
|
||||
|
||||
def data_config(parser):
    '''Register the dataset and feature-extraction options on ``parser``.

    Args:
        parser: ``argparse.ArgumentParser`` that receives the options; it is
            modified in place and nothing is returned.
    '''

    parser.add_argument('--data_url', type=str,
                        default='http://download.tensorflow.org/data/speech_commands_v0.01.tar.gz',
                        help='Location of speech training data archive on the web.')
    parser.add_argument('--data_dir', type=str, default='data',
                        help='Where to download the dataset.')
    parser.add_argument('--feat_dir', type=str, default='feat',
                        help='Where to save the feature of audios')
    parser.add_argument('--background_volume', type=float, default=0.1,
                        help='How loud the background noise should be, between 0 and 1.')
    parser.add_argument('--background_frequency', type=float, default=0.8,
                        help='How many of the training samples have background noise mixed in.')
    parser.add_argument('--silence_percentage', type=float, default=10.0,
                        help='How much of the training data should be silence.')
    parser.add_argument('--unknown_percentage', type=float, default=10.0,
                        help='How much of the training data should be unknown words.')
    parser.add_argument('--time_shift_ms', type=float, default=100.0,
                        help='Range to randomly shift the training audio by in time.')
    parser.add_argument('--testing_percentage', type=int, default=10,
                        help='What percentage of wavs to use as a test set.')
    parser.add_argument('--validation_percentage', type=int, default=10,
                        help='What percentage of wavs to use as a validation set.')
    parser.add_argument('--wanted_words', type=str, default='yes,no,up,down,left,right,on,off,stop,go',
                        help='Words to use (others will be added to an unknown label)')
    parser.add_argument('--sample_rate', type=int, default=16000, help='Expected sample rate of the wavs')
    parser.add_argument('--clip_duration_ms', type=int, default=1000,
                        help='Expected duration in milliseconds of the wavs')
    parser.add_argument('--window_size_ms', type=float, default=40.0, help='How long each spectrogram timeslice is')
    # The stride help used to repeat the window-size text; it describes the hop between slices.
    parser.add_argument('--window_stride_ms', type=float, default=20.0,
                        help='How far to move in time between spectrogram timeslices')
    parser.add_argument('--dct_coefficient_count', type=int, default=20,
                        help='How many bins to use for the MFCC fingerprint')
|
||||
|
||||
|
||||
def train_config(parser):
    '''Register the training options, parse the command line and derive model settings.

    Unknown command-line options are ignored (``parse_known_args``).

    Args:
        parser: ``argparse.ArgumentParser`` to extend and parse with.

    Returns:
        tuple ``(flags, model_settings)``: the parsed namespace, with
        ``dataset_sink_mode`` added and ``lr_epochs`` converted to a list of int,
        and the dict built by ``prepare_model_settings`` plus a ``'dropout1'`` entry.
    '''
    data_config(parser)
    add = parser.add_argument

    # network related
    add('--model_size_info', type=int, nargs="+",
        default=[6, 276, 10, 4, 2, 1, 276, 3, 3, 2, 2] + [276, 3, 3, 1, 1] * 4,
        help='Model dimensions - different for various models')
    add('--drop', type=float, default=0.9, help='dropout')
    add('--pretrained', type=str, default='', help='model_path, local pretrained model to load')

    # training related
    add('--use_graph_mode', default=1, type=int, help='use graph mode or feed mode')
    add('--val_interval', type=int, default=1, help='validate interval')

    # dataset related
    add('--per_batch_size', default=100, type=int, help='batch size for per gpu')

    # optimizer and lr related
    add('--lr_scheduler', default='multistep', type=str,
        help='lr-scheduler, option type: multistep, cosine_annealing')
    add('--lr', default=0.1, type=float, help='learning rate of the training')
    add('--lr_epochs', type=str, default='20,40,60,80', help='epoch of lr changing')
    add('--lr_gamma', type=float, default=0.1,
        help='decrease lr by a factor of exponential lr_scheduler')
    add('--eta_min', type=float, default=0., help='eta_min in cosine_annealing scheduler')
    add('--T_max', type=int, default=80, help='T-max in cosine_annealing scheduler')
    add('--max_epoch', type=int, default=80, help='max epoch num to train the model')
    add('--warmup_epochs', default=0, type=float, help='warmup epoch')
    add('--weight_decay', type=float, default=0.001, help='weight decay')
    add('--momentum', type=float, default=0.98, help='momentum')

    # logging related
    add('--log_interval', type=int, default=100, help='logging interval')
    add('--ckpt_path', type=str, default='train_outputs/', help='checkpoint save location')
    add('--ckpt_interval', type=int, default=100, help='save ckpt_interval')

    opts, _ = parser.parse_known_args()
    opts.dataset_sink_mode = bool(opts.use_graph_mode)
    opts.lr_epochs = [int(epoch) for epoch in opts.lr_epochs.split(',')]

    label_count = len(prepare_words_list(opts.wanted_words.split(',')))
    model_settings = prepare_model_settings(
        label_count, opts.sample_rate, opts.clip_duration_ms,
        opts.window_size_ms, opts.window_stride_ms, opts.dct_coefficient_count)
    model_settings['dropout1'] = opts.drop
    return opts, model_settings
|
||||
|
||||
|
||||
def eval_config(parser):
    '''Register the evaluation options, parse the command line and derive model settings.

    Unknown command-line options are ignored (``parse_known_args``).

    Args:
        parser: ``argparse.ArgumentParser`` to extend and parse with.

    Returns:
        tuple ``(flags, model_settings)``: the parsed namespace and the dict built
        by ``prepare_model_settings`` plus a ``'dropout1'`` entry taken from ``--drop``.
    '''
    parser.add_argument('--feat_dir', type=str, default='feat',
                        help='Where to save the feature of audios')
    parser.add_argument('--model_dir', type=str,
                        default='outputs',
                        help='which folder the models are saved in or specific path of one model')
    parser.add_argument('--wanted_words', type=str, default='yes,no,up,down,left,right,on,off,stop,go',
                        help='Words to use (others will be added to an unknown label)')
    parser.add_argument('--sample_rate', type=int, default=16000, help='Expected sample rate of the wavs')
    parser.add_argument('--clip_duration_ms', type=int, default=1000,
                        help='Expected duration in milliseconds of the wavs')
    parser.add_argument('--window_size_ms', type=float, default=40.0, help='How long each spectrogram timeslice is')
    # The stride help used to repeat the window-size text; it describes the hop between slices.
    parser.add_argument('--window_stride_ms', type=float, default=20.0,
                        help='How far to move in time between spectrogram timeslices')
    parser.add_argument('--dct_coefficient_count', type=int, default=20,
                        help='How many bins to use for the MFCC fingerprint')
    parser.add_argument('--model_size_info', type=int, nargs="+",
                        default=[6, 276, 10, 4, 2, 1, 276, 3, 3, 2, 2, 276, 3, 3, 1, 1, 276, 3, 3, 1, 1, 276, 3, 3, 1,
                                 1, 276, 3, 3, 1, 1],
                        help='Model dimensions - different for various models')

    parser.add_argument('--per_batch_size', default=100, type=int, help='batch size for per gpu')
    parser.add_argument('--drop', type=float, default=0.9, help='dropout')

    # logging related
    parser.add_argument('--log_path', type=str, default='eval_outputs/', help='path to save eval log')

    flags, _ = parser.parse_known_args()
    model_settings = prepare_model_settings(
        len(prepare_words_list(flags.wanted_words.split(','))),
        flags.sample_rate, flags.clip_duration_ms, flags.window_size_ms,
        flags.window_stride_ms, flags.dct_coefficient_count)
    model_settings['dropout1'] = flags.drop
    return flags, model_settings
|
||||
|
||||
|
||||
def prepare_model_settings(label_count, sample_rate, clip_duration_ms,
                           window_size_ms, window_stride_ms,
                           dct_coefficient_count):
    '''Derive sample counts and feature sizes from the audio front-end options.

    Args:
        label_count: number of labels, stored unchanged.
        sample_rate: samples per second.
        clip_duration_ms: clip length in milliseconds.
        window_size_ms: analysis window length in milliseconds.
        window_stride_ms: distance between successive windows in milliseconds.
        dct_coefficient_count: coefficients per window, stored unchanged.

    Returns:
        dict with the keys ``desired_samples``, ``window_size_samples``,
        ``window_stride_samples``, ``spectrogram_length``,
        ``dct_coefficient_count``, ``fingerprint_size``, ``label_count``
        and ``sample_rate``.
    '''
    def to_samples(duration_ms):
        return int(sample_rate * duration_ms / 1000)

    desired_samples = to_samples(clip_duration_ms)
    window_size_samples = to_samples(window_size_ms)
    window_stride_samples = to_samples(window_stride_ms)

    # A clip shorter than one window yields no spectrogram frame at all
    spare_samples = desired_samples - window_size_samples
    spectrogram_length = 0 if spare_samples < 0 else 1 + int(spare_samples / window_stride_samples)

    return dict(
        desired_samples=desired_samples,
        window_size_samples=window_size_samples,
        window_stride_samples=window_stride_samples,
        spectrogram_length=spectrogram_length,
        dct_coefficient_count=dct_coefficient_count,
        fingerprint_size=dct_coefficient_count * spectrogram_length,
        label_count=label_count,
        sample_rate=sample_rate,
    )
|
@ -0,0 +1,47 @@
|
||||
# Copyright 2020 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ===========================================================================
|
||||
"""DSCNN dataset."""
|
||||
import os
|
||||
import numpy as np
|
||||
import mindspore.dataset as de
|
||||
|
||||
|
||||
class NpyDataset:
    '''Random-access dataset backed by ``<data_type>_data.npy`` and ``<data_type>_label.npy``.'''

    def __init__(self, data_dir, data_type, h, w):
        '''Load both arrays from ``data_dir`` and reshape the features to (-1, 1, h, w).'''
        super().__init__()
        feats_file = os.path.join(data_dir, '{}_data.npy'.format(data_type))
        self.data = np.reshape(np.load(feats_file), (-1, 1, h, w))
        labels_file = os.path.join(data_dir, '{}_label.npy'.format(data_type))
        self.label = np.load(labels_file)

    def __len__(self):
        '''Number of samples, i.e. the first dimension of the feature array.'''
        return self.data.shape[0]

    def __getitem__(self, item):
        '''Return ``(features as float32, label as int32)`` for index ``item``.'''
        feats = self.data[item]
        target = self.label[item]
        return feats.astype(np.float32), target.astype(np.int32)
|
||||
|
||||
|
||||
def audio_dataset(data_dir, data_type, h, w, batch_size):
    '''Build a batched dataset from the ``.npy`` files of one data split.

    Args:
        data_dir: directory holding ``<data_type>_data.npy`` / ``<data_type>_label.npy``.
        data_type: name of the split, used as the file-name prefix (e.g. 'testing').
        h: height each sample is reshaped to.
        w: width each sample is reshaped to.
        batch_size: number of samples per batch; the last, smaller batch is kept.

    Returns:
        The batched dataset with columns "feats" and "labels". Samples are
        shuffled unless the split is a testing split.
    '''
    # The split is named by data_type, so that is what must switch shuffling off.
    # The previous check only looked at data_dir, which left the test split shuffled;
    # the data_dir test is kept so paths containing 'testing' behave as before.
    is_testing = 'testing' in data_type or 'testing' in data_dir
    dataset = NpyDataset(data_dir, data_type, h, w)
    de_dataset = de.GeneratorDataset(dataset, ["feats", "labels"], shuffle=not is_testing)
    de_dataset = de_dataset.batch(batch_size, drop_remainder=False)
    return de_dataset
|
File diff suppressed because it is too large
Load Diff
@ -0,0 +1,107 @@
|
||||
# Copyright 2020 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ===========================================================================
|
||||
"""DSCNN network."""
|
||||
import math
|
||||
|
||||
import mindspore.nn as nn
|
||||
from mindspore.ops import operations as P
|
||||
from mindspore.common.initializer import initializer
|
||||
from mindspore import Parameter
|
||||
|
||||
|
||||
class DepthWiseConv(nn.Cell):
    '''Depthwise convolution cell with an optional bias term.

    The weight parameter has shape
    [channel_multiplier, in_planes, kernel_size[0], kernel_size[1]] and is
    initialised with ones; the bias, when enabled, has
    channel_multiplier * in_planes entries initialised with zeros.
    '''

    def __init__(self, in_planes, kernel_size, stride, pad_mode, pad, channel_multiplier=1, has_bias=False):
        super().__init__()
        self.has_bias = has_bias
        self.depthwise_conv = P.DepthwiseConv2dNative(channel_multiplier=channel_multiplier,
                                                      kernel_size=kernel_size,
                                                      stride=stride,
                                                      pad_mode=pad_mode,
                                                      pad=pad)
        self.bias_add = P.BiasAdd()

        kernel_h, kernel_w = kernel_size[0], kernel_size[1]
        self.weight = Parameter(initializer('ones', [channel_multiplier, in_planes, kernel_h, kernel_w]),
                                name='weight')

        self.bias = (Parameter(initializer('zeros', [channel_multiplier * in_planes]), name='bias')
                     if has_bias else None)

    def construct(self, x):
        '''Apply the depthwise convolution, then the bias if one was requested.'''
        out = self.depthwise_conv(x, self.weight)
        if self.has_bias:
            out = self.bias_add(out, self.bias)
        return out
|
||||
|
||||
|
||||
class DSCNN(nn.Cell):
    '''DS-CNN network: one regular convolution followed by depthwise-separable blocks.

    ``model_size_info`` is a flat list: the layer count, then for each layer five
    numbers (output channels, kernel time size, kernel frequency size,
    time stride, frequency stride).
    '''

    def __init__(self, model_settings, model_size_info):
        super().__init__()
        # N C H W
        label_count = model_settings['label_count']
        t_dim = model_settings['spectrogram_length']
        f_dim = model_settings['dct_coefficient_count']

        # Decode every layer description before any cell is created
        num_layers = model_size_info[0]
        layer_specs = [tuple(model_size_info[1 + 5 * layer_no + field] for field in range(5))
                       for layer_no in range(num_layers)]

        cells = []
        in_channel = 1
        for layer_no, (feat, k_t, k_f, s_t, s_f) in enumerate(layer_specs):
            if layer_no == 0:
                # first layer: plain convolution + batch norm
                cells.append(nn.Conv2d(in_channels=in_channel, out_channels=feat,
                                       kernel_size=(k_t, k_f), stride=(s_t, s_f),
                                       pad_mode="same", padding=0, has_bias=False))
                cells.append(nn.BatchNorm2d(num_features=feat, momentum=0.98))
            else:
                # depthwise conv -> BN -> ReLU -> 1x1 conv -> BN -> ReLU
                cells.append(DepthWiseConv(in_planes=in_channel, kernel_size=(k_t, k_f),
                                           stride=(s_t, s_f), pad_mode='same', pad=0))
                cells.append(nn.BatchNorm2d(num_features=in_channel, momentum=0.98))
                cells.append(nn.ReLU())
                cells.append(nn.Conv2d(in_channels=in_channel, out_channels=feat, kernel_size=(1, 1),
                                       pad_mode="same"))
                cells.append(nn.BatchNorm2d(num_features=feat, momentum=0.98))
                cells.append(nn.ReLU())
            in_channel = feat
            # track the feature-map size so the pooling kernel below can cover all of it
            t_dim = math.ceil(t_dim / float(s_t))
            f_dim = math.ceil(f_dim / float(s_f))

        cells.append(nn.AvgPool2d(kernel_size=(t_dim, f_dim)))  # to fix ?
        cells.append(nn.Flatten())
        cells.append(nn.Dropout(model_settings['dropout1']))
        cells.append(nn.Dense(in_channel, label_count))
        self.model = nn.SequentialCell(cells)

    def construct(self, x):
        '''Run the input through the sequential stack.'''
        return self.model(x)
|
@ -0,0 +1,71 @@
|
||||
# Copyright 2020 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ===========================================================================
|
||||
"""Logger."""
|
||||
import os
|
||||
import sys
|
||||
import logging
|
||||
from datetime import datetime
|
||||
|
||||
logger_name_1 = 'ds-cnn'


class LOGGER(logging.Logger):
    '''Logger that writes INFO records to stdout and, once set up, to a log file.'''

    def __init__(self, logger_name):
        super().__init__(logger_name)
        stdout_handler = logging.StreamHandler(sys.stdout)
        stdout_handler.setLevel(logging.INFO)
        stdout_handler.setFormatter(logging.Formatter('%(asctime)s:%(levelname)s:%(message)s'))
        self.addHandler(stdout_handler)

    def setup_logging_file(self, log_dir):
        '''Create ``log_dir`` if needed and attach a file handler named after the current time.'''
        if not os.path.exists(log_dir):
            os.makedirs(log_dir, exist_ok=True)
        stamp = datetime.now().strftime('%Y-%m-%d_time_%H_%M_%S')
        self.log_fn = os.path.join(log_dir, stamp)
        file_handler = logging.FileHandler(self.log_fn)
        file_handler.setLevel(logging.INFO)
        file_handler.setFormatter(logging.Formatter('%(asctime)s:%(levelname)s:%(message)s'))
        self.addHandler(file_handler)

    def info(self, msg, *args, **kwargs):
        '''Log ``msg % args`` at INFO level.'''
        if not self.isEnabledFor(logging.INFO):
            return
        self._log(logging.INFO, msg, args, **kwargs)

    def save_args(self, args):
        '''Log every attribute of ``args`` as a ``--> name: value`` line.'''
        self.info('Args:')
        for key, value in vars(args).items():
            self.info('--> %s: %s', key, value)
        self.info('')

    def important_info(self, msg, *args, **kwargs):
        '''Log ``msg`` framed by rows of asterisks.'''
        if not self.isEnabledFor(logging.INFO):
            return
        line_width = 2
        rule = ('*' * 70 + '\n') * line_width
        margin = ('*' * line_width + '\n') * 2
        banner = ''.join(['\n', rule, margin, '*' * line_width, ' ' * 8, msg, '\n', margin, rule])
        self.info(banner, *args, **kwargs)
|
||||
|
||||
|
||||
def get_logger(path):
    '''Create the project logger and attach a log file located under ``path``.'''
    log = LOGGER(logger_name_1)
    log.setup_logging_file(path)
    return log
|
@ -0,0 +1,39 @@
|
||||
# Copyright 2020 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ===========================================================================
|
||||
"""DSCNN loss."""
|
||||
import mindspore.nn as nn
|
||||
from mindspore.nn.loss.loss import _Loss
|
||||
from mindspore.ops import operations as P
|
||||
from mindspore.ops import functional as F
|
||||
from mindspore import Tensor
|
||||
from mindspore.common import dtype as mstype
|
||||
|
||||
|
||||
class CrossEntropy(_Loss):
    '''Softmax cross-entropy over one-hot labels with optional label smoothing.

    Args:
        smooth_factor (float): Amount removed from the target class's one-hot
            value and spread evenly over the remaining classes. Default: 0.
        num_classes (int): Class count used to compute the per-class share of
            `smooth_factor`. Default: 1000.
    '''
    def __init__(self, smooth_factor=0., num_classes=1000):
        super(CrossEntropy, self).__init__()
        self.onehot = P.OneHot()
        self.on_value = Tensor(1.0 - smooth_factor, mstype.float32)
        # With a single class there is no "other" class to receive the
        # smoothing mass; avoid dividing by zero and use 0 for the off value.
        other_classes = num_classes - 1
        if other_classes != 0:
            off_value = 1.0 * smooth_factor / other_classes
        else:
            off_value = 0.0
        self.off_value = Tensor(off_value, mstype.float32)
        self.ce = nn.SoftmaxCrossEntropyWithLogits()
        self.mean = P.ReduceMean(False)

    def construct(self, logit, label):
        '''Return the loss of `logit` against `label`, averaged over axis 0.

        The one-hot depth is taken from `logit`'s second dimension.
        '''
        one_hot_label = self.onehot(label, F.shape(logit)[1], self.on_value, self.off_value)
        loss = self.ce(logit, one_hot_label)
        loss = self.mean(loss, 0)
        return loss
|
File diff suppressed because it is too large
Load Diff
@ -0,0 +1,26 @@
|
||||
# Copyright 2020 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ===========================================================================
|
||||
"""DSCNN models."""
|
||||
from mindspore.train.serialization import load_checkpoint, load_param_into_net
|
||||
|
||||
def load_ckpt(network, pretrain_ckpt_path, trainable=True):
    """
    Load a checkpoint into `network`, optionally freezing its parameters.

    Args:
        network: Network receiving the loaded parameters.
        pretrain_ckpt_path: Path handed to `load_checkpoint`.
        trainable (bool): When False, every parameter returned by
            `network.get_parameters()` gets `requires_grad = False`.
            Default: True.
    """
    load_param_into_net(network, load_checkpoint(pretrain_ckpt_path))
    if trainable:
        return
    # Freeze the whole network.
    for param in network.get_parameters():
        param.requires_grad = False
|
@ -0,0 +1,74 @@
|
||||
# Copyright 2020 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ===========================================================================
|
||||
"""Logger."""
|
||||
import os
|
||||
import sys
|
||||
import logging
|
||||
from datetime import datetime
|
||||
|
||||
logger_name_1 = 'ds-cnn'
|
||||
|
||||
|
||||
class LOGGER(logging.Logger):
    '''Logger that writes to stdout and, once set up, to a per-rank log file.

    Args:
        logger_name (str): Name passed to `logging.Logger`.
        rank (int): Rank of the calling process. A stdout handler is attached
            only when `rank % 8 == 0`. Default: 0.
    '''
    def __init__(self, logger_name, rank=0):
        super(LOGGER, self).__init__(logger_name)
        # Record the rank immediately: important_info() reads self.rank and
        # must not fail when setup_logging_file() has not been called yet.
        self.rank = rank
        if rank % 8 == 0:
            console = logging.StreamHandler(sys.stdout)
            console.setLevel(logging.INFO)
            formatter = logging.Formatter('%(asctime)s:%(levelname)s:%(message)s')
            console.setFormatter(formatter)
            self.addHandler(console)

    def setup_logging_file(self, log_dir, rank=0):
        '''Attach a file handler writing to a timestamped file in `log_dir`.

        The directory is created when missing. The chosen file path is stored
        in `self.log_fn` and `self.rank` is overwritten with `rank`.
        '''
        self.rank = rank
        if not os.path.exists(log_dir):
            os.makedirs(log_dir, exist_ok=True)
        log_name = datetime.now().strftime('%Y-%m-%d_time_%H_%M_%S') + '_rank_{}.log'.format(rank)
        self.log_fn = os.path.join(log_dir, log_name)
        fh = logging.FileHandler(self.log_fn)
        fh.setLevel(logging.INFO)
        formatter = logging.Formatter('%(asctime)s:%(levelname)s:%(message)s')
        fh.setFormatter(formatter)
        self.addHandler(fh)

    def info(self, msg, *args, **kwargs):
        '''Emit `msg` at INFO level when that level is enabled.'''
        if self.isEnabledFor(logging.INFO):
            self._log(logging.INFO, msg, args, **kwargs)

    def save_args(self, args):
        '''Log every attribute of `args` (as listed by `vars`) at INFO level.'''
        self.info('Args:')
        for key, value in vars(args).items():
            self.info('--> %s: %s', key, value)
        self.info('')

    def important_info(self, msg, *args, **kwargs):
        '''Log `msg` framed by a banner of asterisks; only rank 0 emits it.

        `msg` is concatenated into the banner, so it must be a string.
        '''
        if self.isEnabledFor(logging.INFO) and self.rank == 0:
            line_width = 2
            important_msg = '\n'
            important_msg += ('*' * 70 + '\n') * line_width
            important_msg += ('*' * line_width + '\n') * 2
            important_msg += '*' * line_width + ' ' * 8 + msg + '\n'
            important_msg += ('*' * line_width + '\n') * 2
            important_msg += ('*' * 70 + '\n') * line_width
            self.info(important_msg, *args, **kwargs)
|
||||
|
||||
|
||||
def get_logger(path, rank):
    '''Build a `LOGGER` for `rank` and attach its log file under `path`.

    Returns:
        The configured logger instance.
    '''
    log = LOGGER(logger_name_1, rank)
    log.setup_logging_file(path, rank)
    return log
|
@ -0,0 +1,19 @@
|
||||
# Copyright 2020 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ===========================================================================
|
||||
"""Utils."""
|
||||
# Labels placed ahead of the wanted words by prepare_words_list().
SILENCE_LABEL = '_silence_'
UNKNOWN_WORD_LABEL = '_unknown_'


def prepare_words_list(wanted_words):
    '''Return a new list: the silence and unknown labels followed by `wanted_words`.

    `wanted_words` must be a list (it is joined with list `+`); it is not
    modified.
    '''
    reserved = [SILENCE_LABEL, UNKNOWN_WORD_LABEL]
    return reserved + wanted_words
|
@ -0,0 +1,150 @@
|
||||
# Copyright 2020 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ===========================================================================
|
||||
"""DSCNN train."""
|
||||
import os
|
||||
import datetime
|
||||
import argparse
|
||||
|
||||
import numpy as np
|
||||
from mindspore import context
|
||||
from mindspore import Tensor, Model
|
||||
from mindspore.nn.optim import Momentum
|
||||
from mindspore.common import dtype as mstype
|
||||
from mindspore.train.serialization import load_checkpoint
|
||||
|
||||
from src.config import train_config
|
||||
from src.log import get_logger
|
||||
from src.dataset import audio_dataset
|
||||
from src.ds_cnn import DSCNN
|
||||
from src.loss import CrossEntropy
|
||||
from src.lr_scheduler import MultiStepLR, CosineAnnealingLR
|
||||
from src.callback import ProgressMonitor, callback_func
|
||||
|
||||
|
||||
def get_top5_acc(top5_arg, gt_class):
    '''Count samples whose ground-truth label appears in their candidate row.

    Args:
        top5_arg: Iterable of per-sample candidate label collections.
        gt_class: Iterable of ground-truth labels, paired with `top5_arg`.

    Returns:
        int: Number of pairs where the label is contained in the candidates.
    '''
    return sum(1 for candidates, label in zip(top5_arg, gt_class) if label in candidates)
|
||||
|
||||
|
||||
def val(args, model, val_dataset):
    '''Evaluate `model` on `val_dataset` and log top-1 / top-5 accuracy.

    Updates `args.best_acc` and `args.best_epoch` when the top-1 accuracy
    beats the stored best, then reports the counts through `args.logger`.
    '''
    top1_hits = 0
    top5_hits = 0
    sample_cnt = 0
    for data, gt_classes in val_dataset.create_tuple_iterator():
        logits = model.predict(Tensor(data, mstype.float32)).asnumpy()
        labels = gt_classes.asnumpy()
        # Top-1: arg-max along the last axis; top-5: last five entries of the
        # ascending argsort of each row.
        top1_hits += np.equal(np.argmax(logits, (-1)), labels).sum()
        top5_hits += get_top5_acc(np.argsort(logits)[:, -5:], labels)
        sample_cnt += logits.shape[0]

    # Pass the three counters through one array so they come back as numpy
    # scalars before the accuracy division.
    totals = np.array([[top1_hits], [top5_hits], [sample_cnt]])
    top1_correct, top5_correct, img_tot = totals[:, 0]

    acc1 = 100.0 * top1_correct / img_tot
    acc5 = 100.0 * top5_correct / img_tot
    if acc1 > args.best_acc:
        args.best_acc = acc1
        args.best_epoch = args.epoch_cnt - 1
    summary = 'Eval: top1_cor:{}, top5_cor:{}, tot:{}, acc@1={:.2f}%, acc@5={:.2f}%' \
        .format(top1_correct, top5_correct, img_tot, acc1, acc5)
    args.logger.info(summary)
|
||||
|
||||
|
||||
def trainval(args, model, train_dataset, val_dataset, cb):
    '''Run one training interval, then evaluate on `val_dataset`.

    `args.val_interval` is passed to `model.train` as its first argument, and
    the callbacks come from `callback_func`, tagged with the current
    `args.epoch_cnt`.
    '''
    tag = 'epoch{}'.format(args.epoch_cnt)
    callbacks = callback_func(args, cb, tag)
    model.train(
        args.val_interval,
        train_dataset,
        callbacks=callbacks,
        dataset_sink_mode=args.dataset_sink_mode,
    )
    val(args, model, val_dataset)
|
||||
|
||||
|
||||
def train():
    '''Train a DSCNN network end to end.

    Parses the configuration, sets the MindSpore context, builds the logger,
    datasets, network, loss, learning-rate schedule and optimizer, then
    alternates training and validation through `trainval` and finally logs
    the best epoch and accuracy recorded on `args`.

    Raises:
        NotImplementedError: If `args.lr_scheduler` is neither 'multistep'
            nor 'cosine_annealing'.
    '''
    parser = argparse.ArgumentParser()
    parser.add_argument('--device_id', type=int, default=1, help='which device the model will be trained on')
    # train_config receives the parser and returns the parsed args together
    # with the model settings dictionary.
    args, model_settings = train_config(parser)
    context.set_context(mode=context.GRAPH_MODE, device_id=args.device_id, enable_auto_mixed_precision=True)
    args.rank_save_ckpt_flag = 1

    # Logger: outputs go to a timestamped sub-directory of args.ckpt_path
    args.outputs_dir = os.path.join(args.ckpt_path, datetime.datetime.now().strftime('%Y-%m-%d_time_%H_%M_%S'))
    args.logger = get_logger(args.outputs_dir)

    # Dataloader: train, val
    train_dataset = audio_dataset(args.feat_dir, 'training', model_settings['spectrogram_length'],
                                  model_settings['dct_coefficient_count'], args.per_batch_size)
    # steps_per_epoch is consumed by the LR schedulers below.
    args.steps_per_epoch = train_dataset.get_dataset_size()
    val_dataset = audio_dataset(args.feat_dir, 'validation', model_settings['spectrogram_length'],
                                model_settings['dct_coefficient_count'], args.per_batch_size)

    # show args
    args.logger.save_args(args)

    # Network
    args.logger.important_info('start create network')
    network = DSCNN(model_settings, args.model_size_info)

    # Load pretrain model (skipped when args.pretrained is not an existing file)
    if os.path.isfile(args.pretrained):
        load_checkpoint(args.pretrained, network)
        args.logger.info('load model {} success'.format(args.pretrained))

    # Loss
    criterion = CrossEntropy(num_classes=model_settings['label_count'])

    # LR scheduler
    if args.lr_scheduler == 'multistep':
        lr_scheduler = MultiStepLR(args.lr, args.lr_epochs, args.lr_gamma, args.steps_per_epoch,
                                   args.max_epoch, warmup_epochs=args.warmup_epochs)
    elif args.lr_scheduler == 'cosine_annealing':
        lr_scheduler = CosineAnnealingLR(args.lr, args.T_max, args.steps_per_epoch, args.max_epoch,
                                         warmup_epochs=args.warmup_epochs, eta_min=args.eta_min)
    else:
        raise NotImplementedError(args.lr_scheduler)
    lr_schedule = lr_scheduler.get_lr()

    # Optimizer
    opt = Momentum(params=network.trainable_params(),
                   learning_rate=Tensor(lr_schedule),
                   momentum=args.momentum,
                   weight_decay=args.weight_decay)

    model = Model(network, loss_fn=criterion, optimizer=opt, amp_level='O0')

    # Training
    args.epoch_cnt = 0
    args.best_epoch = 0
    args.best_acc = 0
    progress_cb = ProgressMonitor(args)
    # NOTE(review): args.epoch_cnt is never advanced in this function; the loop
    # presumably relies on ProgressMonitor / callback_func updating it during
    # model.train -- confirm, otherwise this loop does not terminate.
    while args.epoch_cnt + args.val_interval < args.max_epoch:
        trainval(args, model, train_dataset, val_dataset, progress_cb)
    rest_ep = args.max_epoch - args.epoch_cnt
    # NOTE(review): this final call still trains for args.val_interval epochs
    # inside trainval, not rest_ep -- verify that the total cannot exceed
    # args.max_epoch.
    if rest_ep > 0:
        trainval(args, model, train_dataset, val_dataset, progress_cb)

    args.logger.info('Best epoch:{} acc:{:.2f}%'.format(args.best_epoch, args.best_acc))
|
||||
|
||||
# Script entry point: start training only when this file is run directly.
if __name__ == "__main__":
    train()
|
Loading…
Reference in new issue