# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================

'''
Bert finetune script.
'''

import os
from src.utils import BertFinetuneCell, BertCLS, BertNER
from src.finetune_config import cfg, bert_net_cfg, tag_to_index
import mindspore.common.dtype as mstype
import mindspore.communication.management as D
from mindspore import context
import mindspore.dataset as de
import mindspore.dataset.transforms.c_transforms as C
from mindspore.nn.wrap.loss_scale import DynamicLossScaleUpdateCell
from mindspore.nn.optim import AdamWeightDecayDynamicLR, Lamb, Momentum
from mindspore.train.model import Model
from mindspore.train.callback import Callback
from mindspore.train.callback import CheckpointConfig, ModelCheckpoint
from mindspore.train.serialization import load_checkpoint, load_param_into_net


class LossCallBack(Callback):
    '''
    Append the network outputs of training steps to ``./loss.log``.

    Note:
        If per_print_times is 0, nothing is logged.

    Args:
        per_print_times (int): Log once every `per_print_times` steps. Default: 1.

    Raises:
        ValueError: If per_print_times is not an int or is negative.
    '''
    def __init__(self, per_print_times=1):
        super(LossCallBack, self).__init__()
        if not isinstance(per_print_times, int) or per_print_times < 0:
            raise ValueError("print_step must be int and >= 0.")
        self._per_print_times = per_print_times

    def step_end(self, run_context):
        '''
        Write one line with the current epoch, step and net outputs to the log.

        Args:
            run_context: Object whose ``original_args()`` returns the callback
                parameters (cur_epoch_num, cur_step_num, net_outputs).
        '''
        # Honour the logging interval: 0 disables logging altogether.
        if self._per_print_times == 0:
            return
        cb_params = run_context.original_args()
        if cb_params.cur_step_num % self._per_print_times != 0:
            return
        record = "epoch: {}, step: {}, outputs are {}".format(cb_params.cur_epoch_num,
                                                              cb_params.cur_step_num,
                                                              str(cb_params.net_outputs))
        with open("./loss.log", "a+", encoding="utf-8") as log_file:
            log_file.write(record + "\n")


def get_dataset(batch_size=1, repeat_count=1, distribute_file=''):
    '''
    Build the finetune dataset from the TFRecord file named in ``cfg``.

    The pipeline casts every column to int32, repeats the data, shuffles it
    and finally groups it into batches, dropping an incomplete last batch.

    Args:
        batch_size (int): Number of samples per batch. Default: 1.
        repeat_count (int): Number of times the dataset is repeated. Default: 1.
        distribute_file (str): Unused; kept so existing callers keep working.

    Returns:
        The batched dataset object.
    '''
    _ = distribute_file
    columns = ["input_ids", "input_mask", "segment_ids", "label_ids"]
    ds = de.TFRecordDataset([cfg.data_file], cfg.schema_file, columns_list=columns)

    # cast every column to int32 (same order as the original map calls)
    type_cast_op = C.TypeCast(mstype.int32)
    for column in ("segment_ids", "input_mask", "input_ids", "label_ids"):
        ds = ds.map(input_columns=column, operations=type_cast_op)
    ds = ds.repeat(repeat_count)

    # apply shuffle operation
    buffer_size = 960
    ds = ds.shuffle(buffer_size=buffer_size)

    # apply batch operations
    ds = ds.batch(batch_size, drop_remainder=True)
    return ds


def test_train():
    '''
    finetune function
    pytest -s finetune.py::test_train

    Reads the target device from the DEVICE_ID environment variable
    (device 0 when it is not set), builds the task network and optimizer
    selected in ``cfg``, loads the pre-trained checkpoint and runs training.

    Raises:
        ValueError: If ``cfg.optimizer`` names an unsupported optimizer.
    '''
    # Fall back to device 0 instead of failing with an opaque TypeError
    # (int(None)) when DEVICE_ID is not exported.
    devid = int(os.getenv('DEVICE_ID', '0'))
    context.set_context(mode=context.GRAPH_MODE, device_target="Ascend", device_id=devid)

    # BertNER for sequence labeling, BertCLS for classification
    if cfg.task == 'NER':
        if cfg.use_crf:
            netwithloss = BertNER(bert_net_cfg, True, num_labels=len(tag_to_index), use_crf=True,
                                  tag_to_index=tag_to_index, dropout_prob=0.1)
        else:
            netwithloss = BertNER(bert_net_cfg, True, num_labels=cfg.num_labels, dropout_prob=0.1)
    else:
        netwithloss = BertCLS(bert_net_cfg, True, num_labels=cfg.num_labels, dropout_prob=0.1)

    dataset = get_dataset(bert_net_cfg.batch_size, cfg.epoch_num)

    # optimizer
    # NOTE(review): the dataset was already repeated cfg.epoch_num times in
    # get_dataset; if get_dataset_size() reflects that repeat, this value is
    # the total step count rather than steps per epoch -- confirm against the
    # MindSpore version in use.
    steps_per_epoch = dataset.get_dataset_size()
    if cfg.optimizer == 'AdamWeightDecayDynamicLR':
        optimizer = AdamWeightDecayDynamicLR(netwithloss.trainable_params(),
                                             decay_steps=steps_per_epoch * cfg.epoch_num,
                                             learning_rate=cfg.AdamWeightDecayDynamicLR.learning_rate,
                                             end_learning_rate=cfg.AdamWeightDecayDynamicLR.end_learning_rate,
                                             power=cfg.AdamWeightDecayDynamicLR.power,
                                             warmup_steps=steps_per_epoch,
                                             weight_decay=cfg.AdamWeightDecayDynamicLR.weight_decay,
                                             eps=cfg.AdamWeightDecayDynamicLR.eps)
    elif cfg.optimizer == 'Lamb':
        optimizer = Lamb(netwithloss.trainable_params(),
                         decay_steps=steps_per_epoch * cfg.epoch_num,
                         start_learning_rate=cfg.Lamb.start_learning_rate,
                         end_learning_rate=cfg.Lamb.end_learning_rate,
                         power=cfg.Lamb.power,
                         warmup_steps=steps_per_epoch,
                         decay_filter=cfg.Lamb.decay_filter)
    elif cfg.optimizer == 'Momentum':
        optimizer = Momentum(netwithloss.trainable_params(),
                             learning_rate=cfg.Momentum.learning_rate,
                             momentum=cfg.Momentum.momentum)
    else:
        # ValueError is still caught by callers handling the former bare Exception.
        raise ValueError("Optimizer not supported.")

    # checkpoint saving: one checkpoint every steps_per_epoch steps, keep only the latest
    ckpt_config = CheckpointConfig(save_checkpoint_steps=steps_per_epoch, keep_checkpoint_max=1)
    ckpoint_cb = ModelCheckpoint(prefix=cfg.ckpt_prefix, directory=cfg.ckpt_dir, config=ckpt_config)

    # load checkpoint into network
    param_dict = load_checkpoint(cfg.pre_training_ckpt)
    load_param_into_net(netwithloss, param_dict)

    update_cell = DynamicLossScaleUpdateCell(loss_scale_value=2**32, scale_factor=2, scale_window=1000)
    netwithgrads = BertFinetuneCell(netwithloss, optimizer=optimizer, scale_update_cell=update_cell)
    model = Model(netwithgrads)
    model.train(cfg.epoch_num, dataset, callbacks=[LossCallBack(), ckpoint_cb])
    D.release()


if __name__ == "__main__":
    test_train()