From 1756d084aba85a15fbc3e0eb3de665a826477eef Mon Sep 17 00:00:00 2001 From: hanhuifeng2020 Date: Thu, 30 Jul 2020 21:38:44 +0800 Subject: [PATCH] tinybert script suit for gpu --- model_zoo/official/nlp/tinybert/README.md | 4 +- .../nlp/tinybert/run_general_distill.py | 53 ++++++--- .../official/nlp/tinybert/run_task_distill.py | 96 ++++++++++++--- .../scripts/run_distribute_gd_for_gpu.sh | 40 +++++++ .../nlp/tinybert/scripts/run_standalone_td.sh | 2 +- .../official/nlp/tinybert/src/dataset.py | 3 - .../nlp/tinybert/src/tinybert_for_gd_td.py | 109 +++++++++++++++++- model_zoo/official/nlp/tinybert/src/utils.py | 5 +- 8 files changed, 272 insertions(+), 40 deletions(-) create mode 100644 model_zoo/official/nlp/tinybert/scripts/run_distribute_gd_for_gpu.sh diff --git a/model_zoo/official/nlp/tinybert/README.md b/model_zoo/official/nlp/tinybert/README.md index 3d1e990223..13cfa3d777 100644 --- a/model_zoo/official/nlp/tinybert/README.md +++ b/model_zoo/official/nlp/tinybert/README.md @@ -46,7 +46,7 @@ usage: run_standalone_gd.py [--distribute DISTRIBUTE] [--device_target DEVICE_T options: --distribute whether to run distributely: "true" | "false" - --device_target target device to run, currently only support "Ascend" + --device_target targeted device to run task: "Ascend" | "GPU" --epoch_size epoch size: N, default is 1 --device_id device id: N, default is 0 --enable_data_sink enable data sink: "true" | "false", default is "true" @@ -64,7 +64,7 @@ usage: run_distribute_gd.py [--distribute DISTRIBUTE] [--device_target DEVICE_T options: --distribute whether to run distributely: "true" | "false" - --device_target target device to run, currently only support "Ascend" + --device_target targeted device to run task: "Ascend" | "GPU" --epoch_size epoch size: N, default is 1 --device_id device id: N, default is 0 --device_num device id to run task diff --git a/model_zoo/official/nlp/tinybert/run_general_distill.py b/model_zoo/official/nlp/tinybert/run_general_distill.py index 199ee6adf2..62730b62c5 100644 --- a/model_zoo/official/nlp/tinybert/run_general_distill.py +++ b/model_zoo/official/nlp/tinybert/run_general_distill.py @@ -20,16 +20,20 @@ import argparse import datetime import numpy import mindspore.communication.management as D +import mindspore.common.dtype as mstype from mindspore import context from mindspore.train.model import Model from mindspore.train.callback import TimeMonitor from mindspore.train.parallel_utils import ParallelMode from mindspore.nn.optim import AdamWeightDecay from mindspore.nn.wrap.loss_scale import DynamicLossScaleUpdateCell +from mindspore import log as logger from src.dataset import create_tinybert_dataset from src.utils import LossCallBack, ModelSaveCkpt, BertLearningRate from src.gd_config import common_cfg, bert_teacher_net_cfg, bert_student_net_cfg -from src.tinybert_for_gd_td import BertTrainWithLossScaleCell, BertNetworkWithLoss_gd +from src.tinybert_for_gd_td import BertTrainWithLossScaleCell, BertNetworkWithLoss_gd, BertTrainCell + + def run_general_distill(): """ @@ -53,7 +57,6 @@ def run_general_distill(): parser.add_argument("--schema_dir", type=str, default="", help="Schema path, it is better to use absolute path") args_opt = parser.parse_args() - context.set_context(mode=context.GRAPH_MODE, device_target=args_opt.device_target, device_id=args_opt.device_id) context.set_context(mode=context.GRAPH_MODE, device_target=args_opt.device_target, device_id=args_opt.device_id) context.set_context(reserve_class_name_in_scope=False) context.set_context(variable_memory_max_size="30GB") @@ -61,13 +64,17 @@ def run_general_distill(): save_ckpt_dir = os.path.join(args_opt.save_ckpt_path, datetime.datetime.now().strftime('%Y-%m-%d_time_%H_%M_%S')) - if not os.path.exists(save_ckpt_dir): - os.makedirs(save_ckpt_dir) if args_opt.distribute == "true": - D.init('hccl') - device_num = args_opt.device_num - rank = args_opt.device_id % device_num + if args_opt.device_target == 'Ascend': + D.init('hccl') + device_num = args_opt.device_num + rank = args_opt.device_id % device_num + else: + D.init('nccl') + device_num = D.get_group_size() + rank = D.get_rank() + save_ckpt_dir = save_ckpt_dir + '_ckpt_' + str(rank) context.reset_auto_parallel_context() context.set_auto_parallel_context(parallel_mode=ParallelMode.DATA_PARALLEL, mirror_mean=True, device_num=device_num) @@ -75,6 +82,21 @@ def run_general_distill(): rank = 0 device_num = 1 + if not os.path.exists(save_ckpt_dir): + os.makedirs(save_ckpt_dir) + + enable_loss_scale = True + if args_opt.device_target == "GPU": + if bert_teacher_net_cfg.compute_type != mstype.float32: + logger.warning('GPU only support fp32 temporarily, run with fp32.') + bert_teacher_net_cfg.compute_type = mstype.float32 + if bert_student_net_cfg.compute_type != mstype.float32: + logger.warning('GPU only support fp32 temporarily, run with fp32.') + bert_student_net_cfg.compute_type = mstype.float32 + # Both the forward and backward of the network are calculated using fp32, + # and the loss scale is not necessary + enable_loss_scale = False + netwithloss = BertNetworkWithLoss_gd(teacher_config=bert_teacher_net_cfg, teacher_ckpt=args_opt.load_teacher_ckpt_path, student_config=bert_student_net_cfg, @@ -82,11 +104,11 @@ def run_general_distill(): dataset = create_tinybert_dataset('gd', bert_teacher_net_cfg.batch_size, device_num, rank, args_opt.do_shuffle, args_opt.data_dir, args_opt.schema_dir) - dataset_size = dataset.get_dataset_size() print('dataset size: ', dataset_size) + print("dataset repeatcount: ", dataset.get_repeat_count()) if args_opt.enable_data_sink == "true": - repeat_count = args_opt.epoch_size * dataset.get_dataset_size() // args_opt.data_sink_steps + repeat_count = args_opt.epoch_size * dataset_size // args_opt.data_sink_steps time_monitor_steps = args_opt.data_sink_steps else: repeat_count = args_opt.epoch_size @@ -110,12 +132,13 @@ def run_general_distill(): args_opt.save_ckpt_step, args_opt.max_ckpt_num, save_ckpt_dir)] - - update_cell = DynamicLossScaleUpdateCell(loss_scale_value=common_cfg.loss_scale_value, - scale_factor=common_cfg.scale_factor, - scale_window=common_cfg.scale_window) - - netwithgrads = BertTrainWithLossScaleCell(netwithloss, optimizer=optimizer, scale_update_cell=update_cell) + if enable_loss_scale: + update_cell = DynamicLossScaleUpdateCell(loss_scale_value=common_cfg.loss_scale_value, + scale_factor=common_cfg.scale_factor, + scale_window=common_cfg.scale_window) + netwithgrads = BertTrainWithLossScaleCell(netwithloss, optimizer=optimizer, scale_update_cell=update_cell) + else: + netwithgrads = BertTrainCell(netwithloss, optimizer=optimizer) model = Model(netwithgrads) model.train(repeat_count, dataset, callbacks=callback, dataset_sink_mode=(args_opt.enable_data_sink == "true"), diff --git a/model_zoo/official/nlp/tinybert/run_task_distill.py b/model_zoo/official/nlp/tinybert/run_task_distill.py index fd689f141d..b7eceac8e2 100644 --- a/model_zoo/official/nlp/tinybert/run_task_distill.py +++ b/model_zoo/official/nlp/tinybert/run_task_distill.py @@ -18,6 +18,7 @@ import os import re import argparse +import mindspore.common.dtype as mstype from mindspore import Tensor from mindspore import context from mindspore.train.model import Model @@ -25,11 +26,12 @@ from mindspore.train.callback import TimeMonitor from mindspore.train.serialization import load_checkpoint, load_param_into_net from mindspore.nn.wrap.loss_scale import DynamicLossScaleUpdateCell from mindspore.nn.optim import AdamWeightDecay +from mindspore import log as logger from src.dataset import create_tinybert_dataset from src.utils import LossCallBack, ModelSaveCkpt, EvalCallBack, BertLearningRate from src.assessment_method import Accuracy from src.td_config import phase1_cfg, phase2_cfg, td_teacher_net_cfg, td_student_net_cfg -from src.tinybert_for_gd_td import BertEvaluationCell, BertNetworkWithLoss_td +from src.tinybert_for_gd_td import BertEvaluationWithLossScaleCell, BertNetworkWithLoss_td, BertEvaluationCell from src.tinybert_model import BertModelCLS _cur_dir = os.getcwd() @@ -45,14 +47,14 @@ def parse_args(): parse args """ parser = argparse.ArgumentParser(description='tinybert task distill') - parser.add_argument("--device_target", type=str, default="Ascend", help="NPU device, default is Ascend.") + parser.add_argument("--device_target", type=str, default="Ascend", choices=['Ascend', 'GPU'], + help='device where the code will be implemented. (Default: Ascend)') parser.add_argument("--do_train", type=str, default="true", help="Do train task, default is true.") parser.add_argument("--do_eval", type=str, default="true", help="Do eval task, default is true.") parser.add_argument("--td_phase1_epoch_size", type=int, default=10, help="Epoch size for td phase 1, default is 10.") parser.add_argument("--td_phase2_epoch_size", type=int, default=3, help="Epoch size for td phase 2, default is 3.") parser.add_argument("--device_id", type=int, default=0, help="Device id, default is 0.") - parser.add_argument("--num_labels", type=int, default=2, help="Classfication task, support SST2, QNLI, MNLI.") parser.add_argument("--do_shuffle", type=str, default="true", help="Enable shuffle for dataset, default is true.") parser.add_argument("--enable_data_sink", type=str, default="true", help="Enable data sink, default is true.") parser.add_argument("--save_ckpt_step", type=int, default=100, help="Enable data sink, default is true.") @@ -64,11 +66,43 @@ def parse_args(): parser.add_argument("--train_data_dir", type=str, default="", help="Data path, it is better to use absolute path") parser.add_argument("--eval_data_dir", type=str, default="", help="Data path, it is better to use absolute path") parser.add_argument("--schema_dir", type=str, default="", help="Schema path, it is better to use absolute path") + parser.add_argument("--task_name", type=str, default="", choices=["SST-2", "QNLI", "MNLI"], + help="The name of the task to train.") args = parser.parse_args() return args + args_opt = parse_args() + +DEFAULT_NUM_LABELS = 2 +DEFAULT_SEQ_LENGTH = 128 +task_params = {"SST-2": {"num_labels": 2, "seq_length": 64}, + "QNLI": {"num_labels": 2, "seq_length": 128}, + "MNLI": {"num_labels": 3, "seq_length": 128}} + + +class Task: + """ + Encapsulation class of get the task parameter. + """ + def __init__(self, task_name): + self.task_name = task_name + + @property + def num_labels(self): + if self.task_name in task_params and "num_labels" in task_params[self.task_name]: + return task_params[self.task_name]["num_labels"] + return DEFAULT_NUM_LABELS + + @property + def seq_length(self): + if self.task_name in task_params and "seq_length" in task_params[self.task_name]: + return task_params[self.task_name]["seq_length"] + return DEFAULT_SEQ_LENGTH +task = Task(args_opt.task_name) + + def run_predistill(): """ run predistill @@ -81,7 +115,7 @@ def run_predistill(): netwithloss = BertNetworkWithLoss_td(teacher_config=td_teacher_net_cfg, teacher_ckpt=load_teacher_checkpoint_path, student_config=td_student_net_cfg, student_ckpt=load_student_checkpoint_path, is_training=True, task_type='classification', - num_labels=args_opt.num_labels, is_predistill=True) + num_labels=task.num_labels, is_predistill=True) rank = 0 device_num = 1 @@ -91,8 +125,9 @@ def run_predistill(): dataset_size = dataset.get_dataset_size() print('td1 dataset size: ', dataset_size) + print('td1 dataset repeatcount: ', dataset.get_repeat_count()) if args_opt.enable_data_sink == 'true': - repeat_count = args_opt.td_phase1_epoch_size * dataset.get_dataset_size() // args_opt.data_sink_steps + repeat_count = args_opt.td_phase1_epoch_size * dataset_size // args_opt.data_sink_steps time_monitor_steps = args_opt.data_sink_steps else: repeat_count = args_opt.td_phase1_epoch_size @@ -117,10 +152,14 @@ def run_predistill(): args_opt.save_ckpt_step, args_opt.max_ckpt_num, td_phase1_save_ckpt_dir)] - update_cell = DynamicLossScaleUpdateCell(loss_scale_value=cfg.loss_scale_value, - scale_factor=cfg.scale_factor, - scale_window=cfg.scale_window) - netwithgrads = BertEvaluationCell(netwithloss, optimizer=optimizer, scale_update_cell=update_cell) + if enable_loss_scale: + update_cell = DynamicLossScaleUpdateCell(loss_scale_value=cfg.loss_scale_value, + scale_factor=cfg.scale_factor, + scale_window=cfg.scale_window) + netwithgrads = BertEvaluationWithLossScaleCell(netwithloss, optimizer=optimizer, scale_update_cell=update_cell) + else: + netwithgrads = BertEvaluationCell(netwithloss, optimizer=optimizer) + model = Model(netwithgrads) model.train(repeat_count, dataset, callbacks=callback, dataset_sink_mode=(args_opt.enable_data_sink == 'true'), @@ -139,7 +178,7 @@ def run_task_distill(ckpt_file): netwithloss = BertNetworkWithLoss_td(teacher_config=td_teacher_net_cfg, teacher_ckpt=load_teacher_checkpoint_path, student_config=td_student_net_cfg, student_ckpt=load_student_checkpoint_path, is_training=True, task_type='classification', - num_labels=args_opt.num_labels, is_predistill=False) + num_labels=task.num_labels, is_predistill=False) rank = 0 device_num = 1 @@ -149,6 +188,7 @@ def run_task_distill(ckpt_file): dataset_size = train_dataset.get_dataset_size() print('td2 train dataset size: ', dataset_size) + print('td2 train dataset repeatcount: ', train_dataset.get_repeat_count()) if args_opt.enable_data_sink == 'true': repeat_count = args_opt.td_phase2_epoch_size * train_dataset.get_dataset_size() // args_opt.data_sink_steps time_monitor_steps = args_opt.data_sink_steps @@ -175,6 +215,7 @@ def run_task_distill(ckpt_file): eval_dataset = create_tinybert_dataset('td', td_teacher_net_cfg.batch_size, device_num, rank, args_opt.do_shuffle, args_opt.eval_data_dir, args_opt.schema_dir) + print('td2 eval dataset size: ', eval_dataset.get_dataset_size()) if args_opt.do_eval.lower() == "true": callback = [TimeMonitor(time_monitor_steps), LossCallBack(), @@ -185,11 +226,14 @@ def run_task_distill(ckpt_file): args_opt.save_ckpt_step, args_opt.max_ckpt_num, td_phase2_save_ckpt_dir)] - update_cell = DynamicLossScaleUpdateCell(loss_scale_value=cfg.loss_scale_value, - scale_factor=cfg.scale_factor, - scale_window=cfg.scale_window) + if enable_loss_scale: + update_cell = DynamicLossScaleUpdateCell(loss_scale_value=cfg.loss_scale_value, + scale_factor=cfg.scale_factor, + scale_window=cfg.scale_window) - netwithgrads = BertEvaluationCell(netwithloss, optimizer=optimizer, scale_update_cell=update_cell) + netwithgrads = BertEvaluationWithLossScaleCell(netwithloss, optimizer=optimizer, scale_update_cell=update_cell) + else: + netwithgrads = BertEvaluationCell(netwithloss, optimizer=optimizer) model = Model(netwithgrads) model.train(repeat_count, train_dataset, callbacks=callback, dataset_sink_mode=(args_opt.enable_data_sink == 'true'), @@ -203,7 +247,7 @@ def do_eval_standalone(): if ckpt_file == '': raise ValueError("Student ckpt file should not be None") context.set_context(mode=context.GRAPH_MODE, device_target=args_opt.device_target, device_id=args_opt.device_id) - eval_model = BertModelCLS(td_student_net_cfg, False, args_opt.num_labels, 0.0, phase_type="student") + eval_model = BertModelCLS(td_student_net_cfg, False, task.num_labels, 0.0, phase_type="student") param_dict = load_checkpoint(ckpt_file) new_param_dict = {} for key, value in param_dict.items(): @@ -213,10 +257,13 @@ def do_eval_standalone(): load_param_into_net(eval_model, new_param_dict) eval_model.set_train(False) - eval_dataset = create_tinybert_dataset('td', batch_size=1, + eval_dataset = create_tinybert_dataset('td', batch_size=td_student_net_cfg.batch_size, device_num=1, rank=0, do_shuffle="false", data_dir=args_opt.eval_data_dir, schema_dir=args_opt.schema_dir) + print('eval dataset size: ', eval_dataset.get_dataset_size()) + print('eval dataset batch size: ', eval_dataset.get_batch_size()) + callback = Accuracy() columns_list = ["input_ids", "input_mask", "segment_ids", "label_ids"] for data in eval_dataset.create_dict_iterator(): @@ -231,9 +278,26 @@ def do_eval_standalone(): print("============== acc is {}".format(acc)) print("======================================") + if __name__ == '__main__': if args_opt.do_train.lower() != "true" and args_opt.do_eval.lower() != "true": raise ValueError("do_train or do eval must have one be true, please confirm your config") + + enable_loss_scale = True + if args_opt.device_target == "GPU": + if td_teacher_net_cfg.compute_type != mstype.float32: + logger.warning('GPU only support fp32 temporarily, run with fp32.') + td_teacher_net_cfg.compute_type = mstype.float32 + if td_student_net_cfg.compute_type != mstype.float32: + logger.warning('GPU only support fp32 temporarily, run with fp32.') + td_student_net_cfg.compute_type = mstype.float32 + # Both the forward and backward of the network are calculated using fp32, + # and the loss scale is not necessary + enable_loss_scale = False + + td_teacher_net_cfg.seq_length = task.seq_length + td_student_net_cfg.seq_length = task.seq_length + if args_opt.do_train == "true": # run predistill run_predistill() diff --git a/model_zoo/official/nlp/tinybert/scripts/run_distribute_gd_for_gpu.sh b/model_zoo/official/nlp/tinybert/scripts/run_distribute_gd_for_gpu.sh new file mode 100644 index 0000000000..d345ddef5c --- /dev/null +++ b/model_zoo/official/nlp/tinybert/scripts/run_distribute_gd_for_gpu.sh @@ -0,0 +1,40 @@ +#!/bin/bash +# Copyright 2020 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ + +echo "==============================================================================================================" +echo "Please run the scipt as: " +echo "bash run_distribute_gd_for_gpu.sh DEVICE_NUM EPOCH_SIZE DATA_DIR SCHEMA_DIR TEACHER_CKPT_PATH" +echo "for example: bash run_distribute_gd_for_gpu.sh 8 3 /path/data/ /path/datasetSchema.json /path/bert_base.ckpt" +echo "It is better to use absolute path." +echo "==============================================================================================================" + +RANK_SIZE=$1 +EPOCH_SIZE=$2 +DATA_DIR=$3 +SCHEMA_DIR=$4 +TEACHER_CKPT_PATH=$5 + +PROJECT_DIR=$(cd "$(dirname "$0")" || exit; pwd) + +mpirun --allow-run-as-root -n $RANK_SIZE \ + python ${PROJECT_DIR}/../run_general_distill.py \ + --distribute="true" \ + --device_target="GPU" \ + --epoch_size=$EPOCH_SIZE \ + --save_ckpt_path="" \ + --data_dir=$DATA_DIR \ + --schema_dir=$SCHEMA_DIR \ + --load_teacher_ckpt_path=$TEACHER_CKPT_PATH > log.txt 2>&1 & diff --git a/model_zoo/official/nlp/tinybert/scripts/run_standalone_td.sh b/model_zoo/official/nlp/tinybert/scripts/run_standalone_td.sh index dcc01163db..a8d0d4fc51 100644 --- a/model_zoo/official/nlp/tinybert/scripts/run_standalone_td.sh +++ b/model_zoo/official/nlp/tinybert/scripts/run_standalone_td.sh @@ -32,7 +32,7 @@ python ${PROJECT_DIR}/../run_task_distill.py \ --do_eval="true" \ --td_phase1_epoch_size=10 \ --td_phase2_epoch_size=3 \ - --num_labels=2 \ + --task_name="" \ --do_shuffle="true" \ --enable_data_sink="true" \ --data_sink_steps=100 \ diff --git a/model_zoo/official/nlp/tinybert/src/dataset.py b/model_zoo/official/nlp/tinybert/src/dataset.py index fdc0dfe21e..d4af8ed603 100644 --- a/model_zoo/official/nlp/tinybert/src/dataset.py +++ b/model_zoo/official/nlp/tinybert/src/dataset.py @@ -19,7 +19,6 @@ import os import mindspore.common.dtype as mstype import mindspore.dataset.engine.datasets as de import mindspore.dataset.transforms.c_transforms as C -from mindspore import log as logger def create_tinybert_dataset(task='td', batch_size=32, device_num=1, rank=0, do_shuffle="true", data_dir=None, schema_dir=None): @@ -45,7 +44,5 @@ def create_tinybert_dataset(task='td', batch_size=32, device_num=1, rank=0, ds = ds.map(input_columns="label_ids", operations=type_cast_op) # apply batch operations ds = ds.batch(batch_size, drop_remainder=True) - logger.info("data size: {}".format(ds.get_dataset_size())) - logger.info("repeatcount: {}".format(ds.get_repeat_count())) return ds diff --git a/model_zoo/official/nlp/tinybert/src/tinybert_for_gd_td.py b/model_zoo/official/nlp/tinybert/src/tinybert_for_gd_td.py index 55da0f3db9..f244c5591d 100644 --- a/model_zoo/official/nlp/tinybert/src/tinybert_for_gd_td.py +++ b/model_zoo/official/nlp/tinybert/src/tinybert_for_gd_td.py @@ -292,6 +292,60 @@ class BertTrainWithLossScaleCell(nn.Cell): ret = (loss, cond, scaling_sens) return F.depend(ret, succ) +class BertTrainCell(nn.Cell): + """ + Encapsulation class of bert network training. + + Append an optimizer to the training network after that the construct + function can be called to create the backward graph. + + Args: + network (Cell): The training network. Note that loss function should have been added. + optimizer (Optimizer): Optimizer for updating the weights. + sens (Number): The adjust parameter. Default: 1.0. + """ + def __init__(self, network, optimizer, sens=1.0): + super(BertTrainCell, self).__init__(auto_prefix=False) + self.network = network + self.weights = optimizer.parameters + self.optimizer = optimizer + self.sens = sens + self.grad = C.GradOperation('grad', + get_by_list=True, + sens_param=True) + self.reducer_flag = False + self.parallel_mode = context.get_auto_parallel_context("parallel_mode") + if self.parallel_mode in [ParallelMode.DATA_PARALLEL, ParallelMode.HYBRID_PARALLEL]: + self.reducer_flag = True + self.grad_reducer = F.identity + self.degree = 1 + if self.reducer_flag: + mean = context.get_auto_parallel_context("mirror_mean") + self.degree = get_group_size() + self.grad_reducer = DistributedGradReducer(optimizer.parameters, mean, self.degree) + self.cast = P.Cast() + self.hyper_map = C.HyperMap() + + def construct(self, + input_ids, + input_mask, + token_type_id): + """Defines the computation performed.""" + weights = self.weights + loss = self.network(input_ids, + input_mask, + token_type_id) + grads = self.grad(self.network, weights)(input_ids, + input_mask, + token_type_id, + self.cast(F.tuple_to_array((self.sens,)), + mstype.float32)) + # apply grad reducer on grads + grads = self.grad_reducer(grads) + grads = self.hyper_map(F.partial(clip_grad, GRADIENT_CLIP_TYPE, GRADIENT_CLIP_VALUE), grads) + succ = self.optimizer(grads) + return F.depend(loss, succ) + class BertNetworkWithLoss_td(nn.Cell): """ Provide bert pre-training loss through network. @@ -411,12 +465,12 @@ class BertNetworkWithLoss_td(nn.Cell): total_loss += cls_loss return self.cast(total_loss, mstype.float32) -class BertEvaluationCell(nn.Cell): +class BertEvaluationWithLossScaleCell(nn.Cell): """ Especifically defined for finetuning where only four inputs tensor are needed. """ def __init__(self, network, optimizer, scale_update_cell=None): - super(BertEvaluationCell, self).__init__(auto_prefix=False) + super(BertEvaluationWithLossScaleCell, self).__init__(auto_prefix=False) self.network = network self.weights = optimizer.parameters self.optimizer = optimizer @@ -496,3 +550,54 @@ class BertEvaluationCell(nn.Cell): succ = self.optimizer(grads) ret = (loss, cond, scaling_sens) return F.depend(ret, succ) + + +class BertEvaluationCell(nn.Cell): + """ + Especifically defined for finetuning where only four inputs tensor are needed. + """ + def __init__(self, network, optimizer, sens=1.0): + super(BertEvaluationCell, self).__init__(auto_prefix=False) + self.network = network + self.weights = optimizer.parameters + self.optimizer = optimizer + self.sens = sens + self.grad = C.GradOperation('grad', + get_by_list=True, + sens_param=True) + self.reducer_flag = False + self.parallel_mode = context.get_auto_parallel_context("parallel_mode") + if self.parallel_mode in [ParallelMode.DATA_PARALLEL, ParallelMode.HYBRID_PARALLEL]: + self.reducer_flag = True + self.grad_reducer = F.identity + self.degree = 1 + if self.reducer_flag: + mean = context.get_auto_parallel_context("mirror_mean") + self.degree = get_group_size() + self.grad_reducer = DistributedGradReducer(optimizer.parameters, mean, self.degree) + self.is_distributed = (self.parallel_mode != ParallelMode.STAND_ALONE) + self.cast = P.Cast() + self.hyper_map = C.HyperMap() + + def construct(self, + input_ids, + input_mask, + token_type_id, + label_ids): + """Defines the computation performed.""" + weights = self.weights + loss = self.network(input_ids, + input_mask, + token_type_id, + label_ids) + grads = self.grad(self.network, weights)(input_ids, + input_mask, + token_type_id, + label_ids, + self.cast(F.tuple_to_array((self.sens,)), + mstype.float32)) + # apply grad reducer on grads + grads = self.grad_reducer(grads) + grads = self.hyper_map(F.partial(clip_grad, GRADIENT_CLIP_TYPE, GRADIENT_CLIP_VALUE), grads) + succ = self.optimizer(grads) + return F.depend(loss, succ) diff --git a/model_zoo/official/nlp/tinybert/src/utils.py b/model_zoo/official/nlp/tinybert/src/utils.py index d10fb8642e..5e1e775707 100644 --- a/model_zoo/official/nlp/tinybert/src/utils.py +++ b/model_zoo/official/nlp/tinybert/src/utils.py @@ -110,7 +110,10 @@ class EvalCallBack(Callback): if acc > self.global_acc: self.global_acc = acc print("The best acc is {}".format(acc)) - _exec_save_checkpoint(self.network, "eval_model.ckpt") + eval_model_ckpt_file = "eval_model.ckpt" + if os.path.exists(eval_model_ckpt_file): + os.remove(eval_model_ckpt_file) + _exec_save_checkpoint(self.network, eval_model_ckpt_file) class BertLearningRate(LearningRateSchedule): """