# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
'''
##############train models#################
python train.py
'''
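# Example invocation, assuming a single Ascend device with ID 0; when
# --device_id is omitted, the script falls back to the device_id set in
# src/config.py:
#
#   python train.py --device_id 0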
import argparse

from mindspore import context, nn
from mindspore.train import Model
from mindspore.common import set_seed
from mindspore.train.loss_scale_manager import FixedLossScaleManager
from mindspore.train.serialization import load_checkpoint, load_param_into_net
from mindspore.train.callback import ModelCheckpoint, CheckpointConfig, LossMonitor, TimeMonitor

from src.dataset import create_dataset
from src.musictagger import MusicTaggerCNN
from src.loss import BCELoss
from src.config import music_cfg as cfg


def train(model, dataset_direct, filename, columns_list, num_consumer=4,
          batch=16, epoch=50, save_checkpoint_steps=2172, keep_checkpoint_max=50,
          prefix="model", directory='./'):
    """Train the network and periodically save checkpoints."""
    # Save a checkpoint every `save_checkpoint_steps` steps, keeping at most
    # `keep_checkpoint_max` checkpoint files on disk.
    config_ck = CheckpointConfig(save_checkpoint_steps=save_checkpoint_steps,
                                 keep_checkpoint_max=keep_checkpoint_max)
    ckpoint_cb = ModelCheckpoint(prefix=prefix,
                                 directory=directory,
                                 config=config_ck)
    data_train = create_dataset(dataset_direct, filename, batch, columns_list,
                                num_consumer)

    # Train with checkpointing, loss logging, and timing callbacks.
    model.train(epoch,
                data_train,
                callbacks=[
                    ckpoint_cb,
                    LossMonitor(per_print_times=181),
                    TimeMonitor()
                ],
                dataset_sink_mode=True)


if __name__ == "__main__":
    set_seed(1)

    # The target Ascend device can be chosen from the command line;
    # otherwise the device_id from src/config.py is used.
    parser = argparse.ArgumentParser(description='Train model')
    parser.add_argument('--device_id',
                        type=int,
                        help='device ID',
                        default=None)
    args = parser.parse_args()

    # Run in graph mode on Ascend.
    if args.device_id is not None:
        context.set_context(device_target='Ascend',
                            mode=context.GRAPH_MODE,
                            device_id=args.device_id)
    else:
        context.set_context(device_target='Ascend',
                            mode=context.GRAPH_MODE,
                            device_id=cfg.device_id)

    context.set_context(enable_auto_mixed_precision=cfg.mixed_precision)
    network = MusicTaggerCNN(in_classes=[1, 128, 384, 768, 2048],
                             kernel_size=[3, 3, 3, 3, 3],
                             padding=[0] * 5,
                             maxpool=[(2, 4), (4, 5), (3, 8), (4, 8)],
                             has_bias=True)

    # Optionally warm-start from a pretrained checkpoint.
    if cfg.pre_trained:
        param_dict = load_checkpoint(cfg.checkpoint_path + '/' +
                                     cfg.model_name)
        load_param_into_net(network, param_dict)
    net_loss = BCELoss()

    network.set_train(True)
    net_opt = nn.Adam(params=network.trainable_params(),
                      learning_rate=cfg.lr,
                      loss_scale=cfg.loss_scale)

    # Use a fixed loss scale for mixed-precision training.
    loss_scale_manager = FixedLossScaleManager(loss_scale=cfg.loss_scale,
                                               drop_overflow_update=False)
    net_model = Model(network, net_loss, net_opt, loss_scale_manager=loss_scale_manager)
    # Launch training with the parameters from src/config.py.
    train(model=net_model,
          dataset_direct=cfg.data_dir,
          filename=cfg.train_filename,
          columns_list=['feature', 'label'],
          num_consumer=cfg.num_consumer,
          batch=cfg.batch_size,
          epoch=cfg.epoch_size,
          save_checkpoint_steps=cfg.save_step,
          keep_checkpoint_max=cfg.keep_checkpoint_max,
          prefix=cfg.prefix,
          directory=cfg.checkpoint_path + "_{}".format(cfg.device_id))
    print("train success")