@ -0,0 +1,129 @@

# TinyBERT Example

## Description

This example implements the general distill and task distill phases of [BERT-base](https://github.com/google-research/bert) (the base version of the BERT model).

## Requirements

- Install [MindSpore](https://www.mindspore.cn/install/en).
- Download datasets for general distill and task distill, such as GLUE.
- Prepare a pre-trained BERT model for general distill and a BERT model fine-tuned on the specific task, such as a GLUE task, for task distill.

## Running the Example

### General Distill

- Set options in `src/gd_config.py`, including loss scale, optimizer and network config.

- Set options in `scripts/run_standalone_gd.sh`, including device target, data sink config, checkpoint config and dataset. Click [here](https://www.mindspore.cn/tutorial/zh-CN/master/use/data_preparation/loading_the_datasets.html#tfrecord) for more information about the dataset and the json schema file.

- Run `run_standalone_gd.sh` for non-distributed general distill of the BERT-base model.

``` bash
bash scripts/run_standalone_gd.sh
```

- Run `run_distribute_gd.sh` for distributed general distill of the BERT-base model.

``` bash
bash scripts/run_distribute_gd.sh DEVICE_NUM EPOCH_SIZE MINDSPORE_HCCL_CONFIG_PATH
```

### Task Distill

Task distill has two phases: pre-distill (phase 1) and task distill (phase 2).

- Set options in `src/td_config.py`, including loss scale, optimizer config of phase 1 and phase 2, as well as network config.

- Run `run_standalone_td.sh` for task distill of the BERT-base model.

```bash
bash scripts/run_standalone_td.sh
```

## Usage

### General Distill

```
usage: run_standalone_gd.py [--distribute DISTRIBUTE] [--device_target DEVICE_TARGET]
                            [--epoch_size N] [--device_id N]
                            [--enable_data_sink ENABLE_DATA_SINK] [--data_sink_steps N]
                            [--save_ckpt_step N] [--max_ckpt_num N]
                            [--load_teacher_ckpt_path LOAD_TEACHER_CKPT_PATH]
                            [--data_dir DATA_DIR] [--schema_dir SCHEMA_DIR]

options:
    --distribute               whether to run in distributed mode: "true" | "false"
    --device_target            target device to run, currently only support "Ascend"
    --epoch_size               epoch size: N, default is 3
    --device_id                device id: N, default is 0
    --enable_data_sink         enable data sink: "true" | "false", default is "true"
    --data_sink_steps          set data sink steps: N, default is 1
    --save_ckpt_step           steps between saving checkpoints: N, default is 100
    --max_ckpt_num             maximum number of checkpoints to keep: N, default is 1
    --load_teacher_ckpt_path   path of teacher checkpoint to load: PATH, default is ""
    --data_dir                 path to dataset directory: PATH, default is ""
    --schema_dir               path to schema.json file: PATH, default is ""

usage: run_distribute_gd.py [--distribute DISTRIBUTE] [--device_target DEVICE_TARGET]
                            [--epoch_size N] [--device_id N] [--device_num N]
                            [--enable_data_sink ENABLE_DATA_SINK] [--data_sink_steps N]
                            [--save_ckpt_step N] [--max_ckpt_num N]
                            [--load_teacher_ckpt_path LOAD_TEACHER_CKPT_PATH]
                            [--data_dir DATA_DIR] [--schema_dir SCHEMA_DIR]

options:
    --distribute               whether to run in distributed mode: "true" | "false"
    --device_target            target device to run, currently only support "Ascend"
    --epoch_size               epoch size: N, default is 3
    --device_id                device id: N, default is 0
    --device_num               number of devices to use: N, default is 1
    --enable_data_sink         enable data sink: "true" | "false", default is "true"
    --data_sink_steps          set data sink steps: N, default is 1
    --save_ckpt_step           steps between saving checkpoints: N, default is 100
    --max_ckpt_num             maximum number of checkpoints to keep: N, default is 1
    --load_teacher_ckpt_path   path of teacher checkpoint to load: PATH, default is ""
    --data_dir                 path to dataset directory: PATH, default is ""
    --schema_dir               path to schema.json file: PATH, default is ""
```

## Options and Parameters

`gd_config.py` and `td_config.py` contain parameters of the BERT model and options for the optimizer and loss scale.

### Options:

```
Parameters for lossscale:
    loss_scale_value                initial value of loss scale: N, default is 2^8 for td phase 1 and 2^16 for general distill and td phase 2
    scale_factor                    factor used to update loss scale: N, default is 2
    scale_window                    steps between two updates of loss scale: N, default is 50 (1000 for general distill)

Parameters for task-specific config:
    load_teacher_ckpt_path          teacher checkpoint to load
    load_student_ckpt_path          student checkpoint to load
    data_dir                        training data dir
    eval_data_dir                   evaluation data dir
    schema_dir                      data schema path
```
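
These three options drive MindSpore's `DynamicLossScaleUpdateCell`, which the training scripts build from each config: the loss scale is divided by `scale_factor` whenever a gradient overflow is detected, and multiplied by `scale_factor` again after `scale_window` consecutive overflow-free steps. A minimal plain-Python sketch of that update rule (for illustration only, not the MindSpore implementation):

```python
def update_loss_scale(loss_scale, good_steps, overflow, scale_factor=2, scale_window=50):
    """Sketch of the dynamic loss-scale rule; returns (new_scale, new_good_steps)."""
    if overflow:
        # gradients overflowed: shrink the scale and restart the window counter
        return max(loss_scale / scale_factor, 1), 0
    good_steps += 1
    if good_steps >= scale_window:
        # a full window without overflow: it is safe to grow the scale again
        return loss_scale * scale_factor, 0
    return loss_scale, good_steps
```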

### Parameters:

```
Parameters for bert network:
    batch_size                      batch size of input dataset: N, default is 32
    seq_length                      length of input sequence: N, default is 128
    vocab_size                      size of the vocabulary: N, must be consistent with the dataset you use. Default is 30522
    hidden_size                     size of bert encoder layers: N
    num_hidden_layers               number of hidden layers: N
    num_attention_heads             number of attention heads: N, default is 12
    intermediate_size               size of intermediate layer: N
    hidden_act                      activation function used: ACTIVATION, default is "gelu"
    hidden_dropout_prob             dropout probability for BertOutput: Q
    attention_probs_dropout_prob    dropout probability for BertAttention: Q
    max_position_embeddings         maximum length of sequences: N, default is 512
    save_ckpt_step                  steps between saving checkpoints: N, default is 100
    max_ckpt_num                    maximum number of checkpoints to keep: N, default is 1
    type_vocab_size                 size of token type vocab: N, default is 2
    initializer_range               initialization value of TruncatedNormal: Q, default is 0.02
    use_relative_positions          use relative positions or not: True | False, default is False
    input_mask_from_dataset         use the input mask loaded from dataset or not: True | False, default is True
    token_type_ids_from_dataset     use the token type ids loaded from dataset or not: True | False, default is True
    dtype                           data type of input: mstype.float16 | mstype.float32, default is mstype.float32
    compute_type                    compute type in BertTransformer: mstype.float16 | mstype.float32, default is mstype.float16
    enable_fused_layernorm          use batchnorm instead of layernorm to improve performance, default is False

Parameters for optimizer:
    optimizer                       optimizer used in the network: AdamWeightDecay
    learning_rate                   value of learning rate: Q
    end_learning_rate               value of end learning rate: Q, must be positive
    power                           power of the polynomial decay: Q
    weight_decay                    weight decay: Q
    eps                             term added to the denominator to improve numerical stability: Q
```
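
The optimizer follows a warmup-then-polynomial-decay schedule (`BertLearningRate` in `src/utils.py`): the learning rate ramps up linearly over the warmup steps, then decays polynomially with exponent `power` towards `end_learning_rate`. A short sketch of the resulting schedule, assuming linear warmup as in the training scripts:

```python
def bert_lr(step, learning_rate, end_learning_rate, warmup_steps, decay_steps, power):
    """Sketch of the warmup + polynomial-decay schedule of BertLearningRate."""
    if step < warmup_steps:
        # linear warmup from 0 up to the base learning rate
        return learning_rate * step / warmup_steps
    # polynomial decay from the base rate towards end_learning_rate
    progress = min(step, decay_steps) / decay_steps
    return (learning_rate - end_learning_rate) * (1 - progress) ** power + end_learning_rate
```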
@ -0,0 +1,124 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================

"""general distill script"""

import os
import argparse
import datetime
import numpy
import mindspore.communication.management as D
from mindspore import context
from mindspore.train.model import Model
from mindspore.train.callback import TimeMonitor
from mindspore.train.parallel_utils import ParallelMode
from mindspore.nn.optim import AdamWeightDecay
from mindspore.nn.wrap.loss_scale import DynamicLossScaleUpdateCell
from src.dataset import create_tinybert_dataset
from src.utils import LossCallBack, ModelSaveCkpt, BertLearningRate
from src.gd_config import common_cfg, bert_teacher_net_cfg, bert_student_net_cfg
from src.tinybert_for_gd_td import BertTrainWithLossScaleCell, BertNetworkWithLoss_gd


def run_general_distill():
    """
    run general distill
    """
    parser = argparse.ArgumentParser(description='tinybert general distill')
    parser.add_argument('--device_target', type=str, default='Ascend', choices=['Ascend', 'GPU'],
                        help='device where the code will be implemented. (Default: Ascend)')
    parser.add_argument("--distribute", type=str, default="false", help="Run distribute, default is false.")
    parser.add_argument("--epoch_size", type=int, default=3, help="Epoch size, default is 3.")
    parser.add_argument("--device_id", type=int, default=0, help="Device id, default is 0.")
    parser.add_argument("--device_num", type=int, default=1, help="Number of devices to use, default is 1.")
    parser.add_argument("--save_ckpt_step", type=int, default=100, help="Steps between checkpoint saves, default is 100.")
    parser.add_argument("--max_ckpt_num", type=int, default=1, help="Maximum number of checkpoints to keep, default is 1.")
    parser.add_argument("--do_shuffle", type=str, default="true", help="Enable shuffle for dataset, default is true.")
    parser.add_argument("--enable_data_sink", type=str, default="true", help="Enable data sink, default is true.")
    parser.add_argument("--data_sink_steps", type=int, default=1, help="Sink steps for each epoch, default is 1.")
    parser.add_argument("--save_ckpt_path", type=str, default="", help="Save checkpoint path")
    parser.add_argument("--load_teacher_ckpt_path", type=str, default="", help="Load checkpoint file path")
    parser.add_argument("--data_dir", type=str, default="", help="Data path, it is better to use absolute path")
    parser.add_argument("--schema_dir", type=str, default="", help="Schema path, it is better to use absolute path")
    args_opt = parser.parse_args()

    context.set_context(mode=context.GRAPH_MODE, device_target=args_opt.device_target, device_id=args_opt.device_id)
    context.set_context(reserve_class_name_in_scope=False)
    context.set_context(variable_memory_max_size="30GB")

    save_ckpt_dir = os.path.join(args_opt.save_ckpt_path,
                                 datetime.datetime.now().strftime('%Y-%m-%d_time_%H_%M_%S'))

    if not os.path.exists(save_ckpt_dir):
        os.makedirs(save_ckpt_dir)

    if args_opt.distribute == "true":
        D.init('hccl')
        device_num = args_opt.device_num
        rank = args_opt.device_id % device_num
        context.reset_auto_parallel_context()
        context.set_auto_parallel_context(parallel_mode=ParallelMode.DATA_PARALLEL, mirror_mean=True,
                                          device_num=device_num)
    else:
        rank = 0
        device_num = 1

    netwithloss = BertNetworkWithLoss_gd(teacher_config=bert_teacher_net_cfg,
                                         teacher_ckpt=args_opt.load_teacher_ckpt_path,
                                         student_config=bert_student_net_cfg,
                                         is_training=True, use_one_hot_embeddings=False)

    dataset = create_tinybert_dataset('gd', bert_teacher_net_cfg.batch_size, device_num, rank,
                                      args_opt.do_shuffle, args_opt.data_dir, args_opt.schema_dir)

    dataset_size = dataset.get_dataset_size()
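
    # With data sink enabled, Model.train() treats one "epoch" as sink_size
    # (= data_sink_steps) steps, so the count passed to train() is scaled to
    # epoch_size * dataset_size // data_sink_steps to keep the total number
    # of training steps unchanged.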
    if args_opt.enable_data_sink == "true":
        repeat_count = args_opt.epoch_size * dataset.get_dataset_size() // args_opt.data_sink_steps
    else:
        repeat_count = args_opt.epoch_size

    lr_schedule = BertLearningRate(learning_rate=common_cfg.AdamWeightDecay.learning_rate,
                                   end_learning_rate=common_cfg.AdamWeightDecay.end_learning_rate,
                                   warmup_steps=int(dataset_size * args_opt.epoch_size / 10),
                                   decay_steps=int(dataset_size * args_opt.epoch_size),
                                   power=common_cfg.AdamWeightDecay.power)
    params = netwithloss.trainable_params()
    decay_params = list(filter(common_cfg.AdamWeightDecay.decay_filter, params))
    other_params = list(filter(lambda x: x not in decay_params, params))
    group_params = [{'params': decay_params, 'weight_decay': common_cfg.AdamWeightDecay.weight_decay},
                    {'params': other_params, 'weight_decay': 0.0},
                    {'order_params': params}]

    optimizer = AdamWeightDecay(group_params, learning_rate=lr_schedule, eps=common_cfg.AdamWeightDecay.eps)

    callback = [TimeMonitor(dataset_size), LossCallBack(), ModelSaveCkpt(netwithloss.bert,
                                                                         args_opt.save_ckpt_step,
                                                                         args_opt.max_ckpt_num,
                                                                         save_ckpt_dir)]

    update_cell = DynamicLossScaleUpdateCell(loss_scale_value=common_cfg.loss_scale_value,
                                             scale_factor=common_cfg.scale_factor,
                                             scale_window=common_cfg.scale_window)

    netwithgrads = BertTrainWithLossScaleCell(netwithloss, optimizer=optimizer, scale_update_cell=update_cell)
    model = Model(netwithgrads)
    model.train(repeat_count, dataset, callbacks=callback,
                dataset_sink_mode=(args_opt.enable_data_sink == "true"),
                sink_size=args_opt.data_sink_steps)


if __name__ == '__main__':
    numpy.random.seed(0)
    run_general_distill()
@ -0,0 +1,249 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================

"""task distill script"""

import os
import re
import argparse
from mindspore import Tensor
from mindspore import context
from mindspore.train.model import Model
from mindspore.train.callback import TimeMonitor
from mindspore.train.serialization import load_checkpoint, load_param_into_net
from mindspore.nn.wrap.loss_scale import DynamicLossScaleUpdateCell
from mindspore.nn.optim import AdamWeightDecay
from src.dataset import create_tinybert_dataset
from src.utils import LossCallBack, ModelSaveCkpt, EvalCallBack, BertLearningRate
from src.assessment_method import Accuracy
from src.td_config import phase1_cfg, phase2_cfg, td_teacher_net_cfg, td_student_net_cfg
from src.tinybert_for_gd_td import BertEvaluationCell, BertNetworkWithLoss_td
from src.tinybert_model import BertModelCLS


_cur_dir = os.getcwd()
td_phase1_save_ckpt_dir = os.path.join(_cur_dir, 'tinybert_td_phase1_save_ckpt')
td_phase2_save_ckpt_dir = os.path.join(_cur_dir, 'tinybert_td_phase2_save_ckpt')
if not os.path.exists(td_phase1_save_ckpt_dir):
    os.makedirs(td_phase1_save_ckpt_dir)
if not os.path.exists(td_phase2_save_ckpt_dir):
    os.makedirs(td_phase2_save_ckpt_dir)


def parse_args():
    """
    parse args
    """
    parser = argparse.ArgumentParser(description='tinybert task distill')
    parser.add_argument("--device_target", type=str, default="Ascend", help="NPU device, default is Ascend.")
    parser.add_argument("--do_train", type=str, default="true", help="Do train task, default is true.")
    parser.add_argument("--do_eval", type=str, default="true", help="Do eval task, default is true.")
    parser.add_argument("--td_phase1_epoch_size", type=int, default=10,
                        help="Epoch size for td phase 1, default is 10.")
    parser.add_argument("--td_phase2_epoch_size", type=int, default=3, help="Epoch size for td phase 2, default is 3.")
    parser.add_argument("--device_id", type=int, default=0, help="Device id, default is 0.")
    parser.add_argument("--num_labels", type=int, default=2, help="Number of classification labels; supports SST-2, QNLI, MNLI.")
    parser.add_argument("--do_shuffle", type=str, default="true", help="Enable shuffle for dataset, default is true.")
    parser.add_argument("--enable_data_sink", type=str, default="true", help="Enable data sink, default is true.")
    parser.add_argument("--save_ckpt_step", type=int, default=100, help="Steps between checkpoint saves, default is 100.")
    parser.add_argument("--max_ckpt_num", type=int, default=1, help="Maximum number of checkpoints to keep, default is 1.")
    parser.add_argument("--data_sink_steps", type=int, default=1, help="Sink steps for each epoch, default is 1.")
    parser.add_argument("--load_teacher_ckpt_path", type=str, default="", help="Load checkpoint file path")
    parser.add_argument("--load_gd_ckpt_path", type=str, default="", help="Load checkpoint file path")
    parser.add_argument("--load_td1_ckpt_path", type=str, default="", help="Load checkpoint file path")
    parser.add_argument("--train_data_dir", type=str, default="", help="Data path, it is better to use absolute path")
    parser.add_argument("--eval_data_dir", type=str, default="", help="Data path, it is better to use absolute path")
    parser.add_argument("--schema_dir", type=str, default="", help="Schema path, it is better to use absolute path")

    args = parser.parse_args()
    return args


args_opt = parse_args()


def run_predistill():
    """
    run predistill
    """
    cfg = phase1_cfg
    context.set_context(mode=context.GRAPH_MODE, device_target=args_opt.device_target, device_id=args_opt.device_id)
    context.set_context(reserve_class_name_in_scope=False)
    load_teacher_checkpoint_path = args_opt.load_teacher_ckpt_path
    load_student_checkpoint_path = args_opt.load_gd_ckpt_path
    netwithloss = BertNetworkWithLoss_td(teacher_config=td_teacher_net_cfg, teacher_ckpt=load_teacher_checkpoint_path,
                                         student_config=td_student_net_cfg, student_ckpt=load_student_checkpoint_path,
                                         is_training=True, task_type='classification',
                                         num_labels=args_opt.num_labels, is_predistill=True)

    rank = 0
    device_num = 1
    dataset = create_tinybert_dataset('td', td_teacher_net_cfg.batch_size,
                                      device_num, rank, args_opt.do_shuffle,
                                      args_opt.train_data_dir, args_opt.schema_dir)

    dataset_size = dataset.get_dataset_size()
    if args_opt.enable_data_sink == 'true':
        repeat_count = args_opt.td_phase1_epoch_size * dataset.get_dataset_size() // args_opt.data_sink_steps
    else:
        repeat_count = args_opt.td_phase1_epoch_size

    optimizer_cfg = cfg.optimizer_cfg

    lr_schedule = BertLearningRate(learning_rate=optimizer_cfg.AdamWeightDecay.learning_rate,
                                   end_learning_rate=optimizer_cfg.AdamWeightDecay.end_learning_rate,
                                   warmup_steps=int(dataset_size / 10),
                                   decay_steps=int(dataset_size * args_opt.td_phase1_epoch_size),
                                   power=optimizer_cfg.AdamWeightDecay.power)
    params = netwithloss.trainable_params()
    decay_params = list(filter(optimizer_cfg.AdamWeightDecay.decay_filter, params))
    other_params = list(filter(lambda x: x not in decay_params, params))
    group_params = [{'params': decay_params, 'weight_decay': optimizer_cfg.AdamWeightDecay.weight_decay},
                    {'params': other_params, 'weight_decay': 0.0},
                    {'order_params': params}]

    optimizer = AdamWeightDecay(group_params, learning_rate=lr_schedule, eps=optimizer_cfg.AdamWeightDecay.eps)
    callback = [TimeMonitor(dataset_size), LossCallBack(), ModelSaveCkpt(netwithloss.bert,
                                                                         args_opt.save_ckpt_step,
                                                                         args_opt.max_ckpt_num,
                                                                         td_phase1_save_ckpt_dir)]
    update_cell = DynamicLossScaleUpdateCell(loss_scale_value=cfg.loss_scale_value,
                                             scale_factor=cfg.scale_factor,
                                             scale_window=cfg.scale_window)
    netwithgrads = BertEvaluationCell(netwithloss, optimizer=optimizer, scale_update_cell=update_cell)
    model = Model(netwithgrads)
    model.train(repeat_count, dataset, callbacks=callback,
                dataset_sink_mode=(args_opt.enable_data_sink == 'true'),
                sink_size=args_opt.data_sink_steps)


def run_task_distill(ckpt_file):
    """
    run task distill
    """
    if ckpt_file == '':
        raise ValueError("Student checkpoint file path must not be empty")
    cfg = phase2_cfg
    context.set_context(mode=context.GRAPH_MODE, device_target=args_opt.device_target, device_id=args_opt.device_id)
    load_teacher_checkpoint_path = args_opt.load_teacher_ckpt_path
    load_student_checkpoint_path = ckpt_file
    netwithloss = BertNetworkWithLoss_td(teacher_config=td_teacher_net_cfg, teacher_ckpt=load_teacher_checkpoint_path,
                                         student_config=td_student_net_cfg, student_ckpt=load_student_checkpoint_path,
                                         is_training=True, task_type='classification',
                                         num_labels=args_opt.num_labels, is_predistill=False)

    rank = 0
    device_num = 1
    train_dataset = create_tinybert_dataset('td', td_teacher_net_cfg.batch_size,
                                            device_num, rank, args_opt.do_shuffle,
                                            args_opt.train_data_dir, args_opt.schema_dir)

    dataset_size = train_dataset.get_dataset_size()
    if args_opt.enable_data_sink == 'true':
        repeat_count = args_opt.td_phase2_epoch_size * train_dataset.get_dataset_size() // args_opt.data_sink_steps
    else:
        repeat_count = args_opt.td_phase2_epoch_size

    optimizer_cfg = cfg.optimizer_cfg

    lr_schedule = BertLearningRate(learning_rate=optimizer_cfg.AdamWeightDecay.learning_rate,
                                   end_learning_rate=optimizer_cfg.AdamWeightDecay.end_learning_rate,
                                   warmup_steps=int(dataset_size * args_opt.td_phase2_epoch_size / 10),
                                   decay_steps=int(dataset_size * args_opt.td_phase2_epoch_size),
                                   power=optimizer_cfg.AdamWeightDecay.power)
    params = netwithloss.trainable_params()
    decay_params = list(filter(optimizer_cfg.AdamWeightDecay.decay_filter, params))
    other_params = list(filter(lambda x: x not in decay_params, params))
    group_params = [{'params': decay_params, 'weight_decay': optimizer_cfg.AdamWeightDecay.weight_decay},
                    {'params': other_params, 'weight_decay': 0.0},
                    {'order_params': params}]

    optimizer = AdamWeightDecay(group_params, learning_rate=lr_schedule, eps=optimizer_cfg.AdamWeightDecay.eps)

    eval_dataset = create_tinybert_dataset('td', td_teacher_net_cfg.batch_size,
                                           device_num, rank, args_opt.do_shuffle,
                                           args_opt.eval_data_dir, args_opt.schema_dir)
    if args_opt.do_eval.lower() == "true":
        callback = [TimeMonitor(dataset_size), LossCallBack(),
                    ModelSaveCkpt(netwithloss.bert,
                                  args_opt.save_ckpt_step,
                                  args_opt.max_ckpt_num,
                                  td_phase2_save_ckpt_dir),
                    EvalCallBack(netwithloss.bert, eval_dataset)]
    else:
        callback = [TimeMonitor(dataset_size), LossCallBack(),
                    ModelSaveCkpt(netwithloss.bert,
                                  args_opt.save_ckpt_step,
                                  args_opt.max_ckpt_num,
                                  td_phase2_save_ckpt_dir)]
    update_cell = DynamicLossScaleUpdateCell(loss_scale_value=cfg.loss_scale_value,
                                             scale_factor=cfg.scale_factor,
                                             scale_window=cfg.scale_window)

    netwithgrads = BertEvaluationCell(netwithloss, optimizer=optimizer, scale_update_cell=update_cell)
    model = Model(netwithgrads)
    model.train(repeat_count, train_dataset, callbacks=callback,
                dataset_sink_mode=(args_opt.enable_data_sink == 'true'),
                sink_size=args_opt.data_sink_steps)


def do_eval_standalone():
    """
    do eval standalone
    """
    ckpt_file = args_opt.load_td1_ckpt_path
    if ckpt_file == '':
        raise ValueError("Student checkpoint file path must not be empty")
    context.set_context(mode=context.GRAPH_MODE, device_target=args_opt.device_target, device_id=args_opt.device_id)
    eval_model = BertModelCLS(td_student_net_cfg, False, args_opt.num_labels, 0.0, phase_type="student")
    param_dict = load_checkpoint(ckpt_file)
    new_param_dict = {}
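    # distillation checkpoints prefix parameter names with 'tinybert_' and
    # 'bert.'; rename the keys so they match the parameter names of the
    # standalone BertModelCLS evaluation network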
    for key, value in param_dict.items():
        new_key = re.sub('tinybert_', 'bert_', key)
        new_key = re.sub('^bert.', '', new_key)
        new_param_dict[new_key] = value
    load_param_into_net(eval_model, new_param_dict)
    eval_model.set_train(False)

    eval_dataset = create_tinybert_dataset('td', batch_size=1,
                                           device_num=1, rank=0, do_shuffle="false",
                                           data_dir=args_opt.eval_data_dir,
                                           schema_dir=args_opt.schema_dir)
    callback = Accuracy()
    columns_list = ["input_ids", "input_mask", "segment_ids", "label_ids"]
    for data in eval_dataset.create_dict_iterator():
        input_data = []
        for i in columns_list:
            input_data.append(Tensor(data[i]))
        input_ids, input_mask, token_type_id, label_ids = input_data
        logits = eval_model(input_ids, token_type_id, input_mask)
        callback.update(logits[3], label_ids)
    acc = callback.acc_num / callback.total_num
    print("======================================")
    print("============== acc is {}".format(acc))
    print("======================================")


if __name__ == '__main__':
    if args_opt.do_train.lower() != "true" and args_opt.do_eval.lower() != "true":
        raise ValueError("At least one of do_train and do_eval must be \"true\", please confirm your config")
    if args_opt.do_train.lower() == "true":
        # run predistill
        run_predistill()
        lists = os.listdir(td_phase1_save_ckpt_dir)
        if lists:
            lists.sort(key=lambda fn: os.path.getmtime(td_phase1_save_ckpt_dir+'/'+fn))
            name_ext = os.path.splitext(lists[-1])
            if name_ext[-1] != ".ckpt":
                raise ValueError("Invalid file, checkpoint file should be a .ckpt file")
            newest_ckpt_file = os.path.join(td_phase1_save_ckpt_dir, lists[-1])
            # run task distill
            run_task_distill(newest_ckpt_file)
        else:
            raise ValueError("Checkpoint file does not exist, please make sure the ckpt file has been saved")
    else:
        do_eval_standalone()
@ -0,0 +1,72 @@
#!/bin/bash
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================

echo "=============================================================================================================="
echo "Please run the script as: "
echo "bash scripts/run_distribute_gd.sh DEVICE_NUM EPOCH_SIZE MINDSPORE_HCCL_CONFIG_PATH"
echo "for example: bash scripts/run_distribute_gd.sh 8 40 /path/hccl.json"
echo "It is better to use absolute path."
echo "running....... please see details in LOG{}/log.txt"
echo "=============================================================================================================="

EPOCH_SIZE=$2

PROJECT_DIR=$(cd "$(dirname "$0")" || exit; pwd)
export MINDSPORE_HCCL_CONFIG_PATH=$3
export RANK_TABLE_FILE=$3
export RANK_SIZE=$1
cores=`cat /proc/cpuinfo|grep "processor" |wc -l`
echo "the number of logical cores:" $cores
avg_core_per_rank=`expr $cores \/ $RANK_SIZE`
core_gap=`expr $avg_core_per_rank \- 1`
echo "avg_core_per_rank" $avg_core_per_rank
echo "core_gap" $core_gap
for((i=0;i<RANK_SIZE;i++))
do
    start=`expr $i \* $avg_core_per_rank`
    export DEVICE_ID=$i
    export RANK_ID=$i
    export DEPLOY_MODE=0
    export GE_USE_STATIC_MEMORY=1
    end=`expr $start \+ $core_gap`
    cmdopt=$start"-"$end

    rm -rf LOG$i
    mkdir ./LOG$i
    cp *.py ./LOG$i
    cd ./LOG$i || exit
    echo "start training for rank $i, device $DEVICE_ID"
    mkdir -p ms_log
    CUR_DIR=`pwd`
    export GLOG_log_dir=${CUR_DIR}/ms_log
    export GLOG_logtostderr=0
    env > env.log
    taskset -c $cmdopt python ${PROJECT_DIR}/../run_general_distill.py \
        --distribute="true" \
        --device_target="Ascend" \
        --epoch_size=$EPOCH_SIZE \
        --device_id=$DEVICE_ID \
        --device_num=$RANK_SIZE \
        --enable_data_sink="true" \
        --data_sink_steps=100 \
        --save_ckpt_step=100 \
        --max_ckpt_num=1 \
        --save_ckpt_path="" \
        --load_teacher_ckpt_path="" \
        --data_dir="" \
        --schema_dir="" > log.txt 2>&1 &
    cd ../
done
@ -0,0 +1,42 @@
#!/bin/bash
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================

echo "=============================================================================================================="
echo "Please run the script as: "
echo "bash scripts/run_standalone_gd.sh"
echo "for example: bash scripts/run_standalone_gd.sh"
echo "running....... please see details in log.txt"
echo "=============================================================================================================="

mkdir -p ms_log
PROJECT_DIR=$(cd "$(dirname "$0")" || exit; pwd)
CUR_DIR=`pwd`
export GLOG_log_dir=${CUR_DIR}/ms_log
export GLOG_logtostderr=0
python ${PROJECT_DIR}/../run_general_distill.py \
    --distribute="false" \
    --device_target="Ascend" \
    --epoch_size=3 \
    --device_id=0 \
    --enable_data_sink="true" \
    --data_sink_steps=100 \
    --save_ckpt_step=100 \
    --max_ckpt_num=1 \
    --save_ckpt_path="" \
    --load_teacher_ckpt_path="" \
    --data_dir="" \
    --schema_dir="" > log.txt 2>&1 &
@ -0,0 +1,47 @@
#!/bin/bash
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================

echo "=============================================================================================================="
echo "Please run the script as: "
echo "bash scripts/run_standalone_td.sh"
echo "for example: bash scripts/run_standalone_td.sh"
echo "=============================================================================================================="

mkdir -p ms_log
PROJECT_DIR=$(cd "$(dirname "$0")" || exit; pwd)
CUR_DIR=`pwd`
export GLOG_log_dir=${CUR_DIR}/ms_log
export GLOG_logtostderr=0
python ${PROJECT_DIR}/../run_task_distill.py \
    --device_target="Ascend" \
    --device_id=0 \
    --do_train="true" \
    --do_eval="true" \
    --td_phase1_epoch_size=10 \
    --td_phase2_epoch_size=3 \
    --num_labels=2 \
    --do_shuffle="true" \
    --enable_data_sink="true" \
    --data_sink_steps=100 \
    --save_ckpt_step=100 \
    --max_ckpt_num=1 \
    --load_teacher_ckpt_path="" \
    --load_gd_ckpt_path="" \
    --load_td1_ckpt_path="" \
    --train_data_dir="" \
    --eval_data_dir="" \
    --schema_dir="" > log.txt 2>&1 &
@ -0,0 +1,54 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================

"""assessment methods"""

import numpy as np


class Accuracy():
    """Accuracy"""
    def __init__(self):
        self.acc_num = 0
        self.total_num = 0

    def update(self, logits, labels):
        labels = labels.asnumpy()
        labels = np.reshape(labels, -1)
        logits = logits.asnumpy()
        logit_id = np.argmax(logits, axis=-1)
        self.acc_num += np.sum(labels == logit_id)
        self.total_num += len(labels)


class F1():
    """F1"""
    def __init__(self):
        self.TP = 0
        self.FP = 0
        self.FN = 0

    def update(self, logits, labels):
        """Update F1 score"""
        labels = labels.asnumpy()
        labels = np.reshape(labels, -1)
        logits = logits.asnumpy()
        logit_id = np.argmax(logits, axis=-1)
        logit_id = np.reshape(logit_id, -1)
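        # label ids 2-7 are treated as the positive classes when counting
        # true/false positives and false negatives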
        pos_eva = np.isin(logit_id, [2, 3, 4, 5, 6, 7])
        pos_label = np.isin(labels, [2, 3, 4, 5, 6, 7])
        self.TP += np.sum(pos_eva & pos_label)
        self.FP += np.sum(pos_eva & (~pos_label))
        self.FN += np.sum((~pos_eva) & pos_label)
        print("-----------------precision is ", self.TP / (self.TP + self.FP))
        print("-----------------recall is ", self.TP / (self.TP + self.FN))
@ -0,0 +1,54 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================

"""create tinybert dataset"""

import os
import mindspore.common.dtype as mstype
import mindspore.dataset.engine.datasets as de
import mindspore.dataset.transforms.c_transforms as C
from mindspore import log as logger


def create_tinybert_dataset(task='td', batch_size=32, device_num=1, rank=0,
                            do_shuffle="true", data_dir=None, schema_dir=None):
    """create tinybert dataset"""
    files = os.listdir(data_dir)
    data_files = []
    for file_name in files:
        if "record" in file_name:
            data_files.append(os.path.join(data_dir, file_name))
    if task == "td":
        columns_list = ["input_ids", "input_mask", "segment_ids", "label_ids"]
    else:
        columns_list = ["input_ids", "input_mask", "segment_ids"]

    ds = de.TFRecordDataset(data_files, schema_dir, columns_list=columns_list,
                            shuffle=(do_shuffle == "true"), num_shards=device_num, shard_id=rank,
                            shard_equal_rows=True)
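    # shard_equal_rows=True makes every shard return the same number of rows,
    # so all devices in data-parallel training run the same number of steps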

    ori_dataset_size = ds.get_dataset_size()
    print('original dataset size: ', ori_dataset_size)
    type_cast_op = C.TypeCast(mstype.int32)
    ds = ds.map(input_columns="segment_ids", operations=type_cast_op)
    ds = ds.map(input_columns="input_mask", operations=type_cast_op)
    ds = ds.map(input_columns="input_ids", operations=type_cast_op)
    if task == "td":
        ds = ds.map(input_columns="label_ids", operations=type_cast_op)
    # apply batch operations
    ds = ds.batch(batch_size, drop_remainder=True)
    logger.info("data size: {}".format(ds.get_dataset_size()))
    logger.info("repeat count: {}".format(ds.get_repeat_count()))

    return ds
@ -0,0 +1,122 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""fused layernorm"""
from mindspore.ops import operations as P
from mindspore.ops import functional as F
from mindspore.common.parameter import Parameter
from mindspore.common.initializer import initializer
from mindspore.ops.primitive import constexpr
import mindspore.common.dtype as mstype
from mindspore.nn.cell import Cell

import numpy as np


__all__ = ['FusedLayerNorm']

@constexpr
def get_shape_for_norm(x_shape, begin_norm_axis):
    print("input_shape: ", x_shape)
    norm_shape = x_shape[begin_norm_axis:]
    output_shape = (1, -1, 1, int(np.prod(norm_shape)))
    print("output_shape: ", output_shape)
    return output_shape


class FusedLayerNorm(Cell):
    r"""
    Applies Layer Normalization over a mini-batch of inputs.

    Layer normalization is widely used in recurrent neural networks. It applies
    normalization over a mini-batch of inputs for each single training case as described
    in the paper `Layer Normalization <https://arxiv.org/pdf/1607.06450.pdf>`_. Unlike batch
    normalization, layer normalization performs exactly the same computation at training and
    testing times. It can be described using the following formula, applied across all channels
    and pixels of each single sample.

    .. math::
        y = \frac{x - \mathrm{E}[x]}{\sqrt{\mathrm{Var}[x] + \epsilon}} * \gamma + \beta

    Args:
        normalized_shape (Union(tuple[int], list[int])): The normalization is performed over axes
            `begin_norm_axis ... R - 1`.
        begin_norm_axis (int): The first normalization dimension: normalization will be performed along dimensions
            `begin_norm_axis: rank(inputs)`, the value should be in [-1, rank(input)). Default: -1.
        begin_params_axis (int): The first parameter (beta, gamma) dimension: scale and centering parameters
            will have dimensions `begin_params_axis: rank(inputs)` and will be broadcast with
            the normalized inputs accordingly, the value should be in [-1, rank(input)). Default: -1.
        gamma_init (Union[Tensor, str, Initializer, numbers.Number]): Initializer for the gamma weight.
            The values of str refer to the function `initializer` including 'zeros', 'ones', 'xavier_uniform',
            'he_uniform', etc. Default: 'ones'.
        beta_init (Union[Tensor, str, Initializer, numbers.Number]): Initializer for the beta weight.
            The values of str refer to the function `initializer` including 'zeros', 'ones', 'xavier_uniform',
            'he_uniform', etc. Default: 'zeros'.
        use_batch_norm (bool): Whether to use batch normalization to compute the statistics. Default: False.

    Inputs:
        - **input_x** (Tensor) - The shape of 'input_x' is :math:`(x_1, x_2, ..., x_R)`,
          and `input_shape[begin_norm_axis:]` is equal to `normalized_shape`.

    Outputs:
        Tensor, the normalized and scaled offset tensor, has the same shape and data type as the `input_x`.

    Examples:
        >>> x = Tensor(np.ones([20, 5, 10, 10]), mindspore.float32)
        >>> shape1 = x.shape[1:]
        >>> m = FusedLayerNorm(shape1, begin_norm_axis=1, begin_params_axis=1)
        >>> m(x)
    """
    def __init__(self,
                 normalized_shape,
                 begin_norm_axis=-1,
                 begin_params_axis=-1,
                 gamma_init='ones',
                 beta_init='zeros',
                 use_batch_norm=False):
        super(FusedLayerNorm, self).__init__()
        if not isinstance(normalized_shape, (tuple, list)):
            raise TypeError("The type of 'normalized_shape' should be tuple[int] or list[int], but '{}' type is {}."
                            .format(normalized_shape, type(normalized_shape)))
        self.normalized_shape = normalized_shape
        self.begin_norm_axis = begin_norm_axis
        self.begin_params_axis = begin_params_axis
        self.gamma = Parameter(initializer(
            gamma_init, normalized_shape), name="gamma")
        self.beta = Parameter(initializer(
            beta_init, normalized_shape), name="beta")
        self.layer_norm = P.LayerNorm(begin_norm_axis=self.begin_norm_axis, begin_params_axis=self.begin_params_axis)

        self.batch_norm = P.BatchNorm(is_training=True, epsilon=1e-5)
        self.use_batch_norm = use_batch_norm

    def construct(self, input_x):
        """fusedlayernorm"""
        if self.use_batch_norm and self.training:
            ones = P.Fill()(mstype.float32, F.shape(input_x)[:self.begin_norm_axis], 1.0)
            zeros = P.Fill()(mstype.float32, F.shape(input_x)[:self.begin_norm_axis], 0.0)
            shape_x = F.shape(input_x)
            norm_shape = get_shape_for_norm(shape_x, self.begin_norm_axis)
            input_x = F.reshape(input_x, norm_shape)
            output, _, _, _, _, _ = self.batch_norm(input_x, ones, zeros, None, None)
            output = F.reshape(output, shape_x)
            y = output * self.gamma + self.beta
        else:
            y, _, _ = self.layer_norm(input_x, self.gamma, self.beta)
        return y

    def extend_repr(self):
        """Display instance object as string."""
        s = 'normalized_shape={}, begin_norm_axis={}, begin_params_axis={}, gamma={}, beta={}'.format(
            self.normalized_shape, self.begin_norm_axis, self.begin_params_axis, self.gamma, self.beta)
        return s
@ -0,0 +1,81 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""
network config setting, will be used in dataset.py, run_general_distill.py and run_task_distill.py
"""
import mindspore.common.dtype as mstype
from easydict import EasyDict as edict
from .tinybert_model import BertConfig

common_cfg = edict({
    'loss_scale_value': 2 ** 16,
    'scale_factor': 2,
    'scale_window': 1000,
    'AdamWeightDecay': edict({
        'learning_rate': 5e-5,
        'end_learning_rate': 1e-14,
        'power': 1.0,
        'weight_decay': 1e-4,
        'eps': 1e-6,
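        # exclude LayerNorm and bias parameters from weight decay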
        'decay_filter': lambda x: 'layernorm' not in x.name.lower() and 'bias' not in x.name.lower(),
    }),
})
'''
Two network configurations are included:
teacher network: the BERT-base network.
student network: the smaller network distilled from the teacher network.
'''
bert_teacher_net_cfg = BertConfig(
    batch_size=32,
    seq_length=128,
    vocab_size=30522,
    hidden_size=768,
    num_hidden_layers=12,
    num_attention_heads=12,
    intermediate_size=3072,
    hidden_act="gelu",
    hidden_dropout_prob=0.1,
    attention_probs_dropout_prob=0.1,
    max_position_embeddings=512,
    type_vocab_size=2,
    initializer_range=0.02,
    use_relative_positions=False,
    input_mask_from_dataset=True,
    token_type_ids_from_dataset=True,
    dtype=mstype.float32,
    compute_type=mstype.float16,
    enable_fused_layernorm=False
)
bert_student_net_cfg = BertConfig(
    batch_size=32,
    seq_length=128,
    vocab_size=30522,
    hidden_size=384,
    num_hidden_layers=4,
    num_attention_heads=12,
    intermediate_size=1536,
    hidden_act="gelu",
    hidden_dropout_prob=0.1,
    attention_probs_dropout_prob=0.1,
    max_position_embeddings=512,
    type_vocab_size=2,
    initializer_range=0.02,
    use_relative_positions=False,
    input_mask_from_dataset=True,
    token_type_ids_from_dataset=True,
    dtype=mstype.float32,
    compute_type=mstype.float16,
    enable_fused_layernorm=False
)
@ -0,0 +1,100 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================

"""config script for task distill"""

import mindspore.common.dtype as mstype
from easydict import EasyDict as edict
from .tinybert_model import BertConfig

phase1_cfg = edict({
    'loss_scale_value': 2 ** 8,
    'scale_factor': 2,
    'scale_window': 50,
    'optimizer_cfg': edict({
        'AdamWeightDecay': edict({
            'learning_rate': 5e-5,
            'end_learning_rate': 1e-14,
            'power': 1.0,
            'weight_decay': 1e-4,
            'eps': 1e-6,
            'decay_filter': lambda x: 'layernorm' not in x.name.lower() and 'bias' not in x.name.lower(),
        }),
    }),
})

phase2_cfg = edict({
    'loss_scale_value': 2 ** 16,
    'scale_factor': 2,
    'scale_window': 50,
    'optimizer_cfg': edict({
        'AdamWeightDecay': edict({
            'learning_rate': 2e-5,
            'end_learning_rate': 1e-14,
            'power': 1.0,
            'weight_decay': 1e-4,
            'eps': 1e-6,
            'decay_filter': lambda x: 'layernorm' not in x.name.lower() and 'bias' not in x.name.lower(),
        }),
    }),
})

'''
Two network configurations are included:
teacher network: the fine-tuned BERT-base network.
student network: the network produced by the general distill (GD) phase.
'''
td_teacher_net_cfg = BertConfig(
    batch_size=32,
    seq_length=128,
    vocab_size=30522,
    hidden_size=768,
    num_hidden_layers=12,
    num_attention_heads=12,
    intermediate_size=3072,
    hidden_act="gelu",
    hidden_dropout_prob=0.1,
    attention_probs_dropout_prob=0.1,
    max_position_embeddings=512,
    type_vocab_size=2,
    initializer_range=0.02,
    use_relative_positions=False,
    input_mask_from_dataset=True,
    token_type_ids_from_dataset=True,
    dtype=mstype.float32,
    compute_type=mstype.float16,
    enable_fused_layernorm=False
)
td_student_net_cfg = BertConfig(
    batch_size=32,
    seq_length=128,
    vocab_size=30522,
    hidden_size=384,
    num_hidden_layers=4,
    num_attention_heads=12,
    intermediate_size=1536,
    hidden_act="gelu",
    hidden_dropout_prob=0.1,
    attention_probs_dropout_prob=0.1,
    max_position_embeddings=512,
    type_vocab_size=2,
    initializer_range=0.02,
    use_relative_positions=False,
    input_mask_from_dataset=True,
    token_type_ids_from_dataset=True,
    dtype=mstype.float32,
    compute_type=mstype.float16,
    enable_fused_layernorm=False
)
@ -0,0 +1,140 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================

"""tinybert utils"""

import os
import numpy as np
from mindspore import Tensor
from mindspore.common import dtype as mstype
from mindspore.train.callback import Callback
from mindspore.train.serialization import _exec_save_checkpoint
from mindspore.ops import operations as P
from mindspore.nn.learning_rate_schedule import LearningRateSchedule, PolynomialDecayLR, WarmUpLR
from .assessment_method import Accuracy


class ModelSaveCkpt(Callback):
    """
    Saves checkpoints during training.

    Args:
        network (Network): The network being trained.
        save_ckpt_step (int): Steps between checkpoint saves.
        max_ckpt_num (int): The maximum number of checkpoints to keep.
        output_dir (str): Directory to write checkpoints to.
    """
    def __init__(self, network, save_ckpt_step, max_ckpt_num, output_dir):
        super(ModelSaveCkpt, self).__init__()
        self.count = 0
        self.network = network
        self.save_ckpt_step = save_ckpt_step
        self.max_ckpt_num = max_ckpt_num
        self.output_dir = output_dir

    def step_end(self, run_context):
        """step end and save ckpt"""
        cb_params = run_context.original_args()
        if cb_params.cur_step_num % self.save_ckpt_step == 0:
            saved_ckpt_num = cb_params.cur_step_num / self.save_ckpt_step
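            # keep at most max_ckpt_num checkpoints on disk: when the budget
            # is exceeded, delete the oldest one before writing the new file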
            if saved_ckpt_num > self.max_ckpt_num:
                oldest_ckpt_index = saved_ckpt_num - self.max_ckpt_num
                path = os.path.join(self.output_dir, "tiny_bert_{}_{}.ckpt".format(int(oldest_ckpt_index),
                                                                                   self.save_ckpt_step))
                if os.path.exists(path):
                    os.remove(path)
            _exec_save_checkpoint(self.network, os.path.join(self.output_dir,
                                                             "tiny_bert_{}_{}.ckpt".format(int(saved_ckpt_num),
                                                                                           self.save_ckpt_step)))


class LossCallBack(Callback):
    """
    Monitor the loss in training.
    If the loss is NAN or INF, terminate training.
    Note:
        If per_print_times is 0, do not print the loss.
    Args:
        per_print_times (int): Print the loss every `per_print_times` steps. Default: 1.
    """
    def __init__(self, per_print_times=1):
        super(LossCallBack, self).__init__()
        if not isinstance(per_print_times, int) or per_print_times < 0:
            raise ValueError("print_step must be an int and >= 0")
        self._per_print_times = per_print_times

    def step_end(self, run_context):
        """step end and print loss"""
        cb_params = run_context.original_args()
        print("epoch: {}, step: {}, outputs are {}".format(cb_params.cur_epoch_num,
                                                           cb_params.cur_step_num,
                                                           str(cb_params.net_outputs)))


class EvalCallBack(Callback):
    """Evaluation callback"""
    def __init__(self, network, dataset):
        super(EvalCallBack, self).__init__()
        self.network = network
        self.global_acc = 0.0
        self.dataset = dataset

    def step_end(self, run_context):
        """step end and do evaluation"""
        cb_params = run_context.original_args()
        if cb_params.cur_step_num % 100 == 0:
            callback = Accuracy()
            columns_list = ["input_ids", "input_mask", "segment_ids", "label_ids"]
            for data in self.dataset.create_dict_iterator():
                input_data = []
                for i in columns_list:
                    input_data.append(Tensor(data[i]))
                input_ids, input_mask, token_type_id, label_ids = input_data
                self.network.set_train(False)
                logits = self.network(input_ids, token_type_id, input_mask)
                callback.update(logits[3], label_ids)
            acc = callback.acc_num / callback.total_num
            with open("./eval.log", "a+") as f:
                f.write("acc_num {}, total_num {}, accuracy {:.6f}".format(callback.acc_num, callback.total_num,
                                                                           callback.acc_num / callback.total_num))
                f.write('\n')

            if acc > self.global_acc:
                self.global_acc = acc
                print("The best acc is {}".format(acc))
                _exec_save_checkpoint(self.network, "eval_model.ckpt")


class BertLearningRate(LearningRateSchedule):
    """
    Warmup-decay learning rate for Bert network.
    """
    def __init__(self, learning_rate, end_learning_rate, warmup_steps, decay_steps, power):
        super(BertLearningRate, self).__init__()
        self.warmup_flag = False
        if warmup_steps > 0:
            self.warmup_flag = True
            self.warmup_lr = WarmUpLR(learning_rate, warmup_steps)
        self.decay_lr = PolynomialDecayLR(learning_rate, end_learning_rate, decay_steps, power)
        self.warmup_steps = Tensor(np.array([warmup_steps]).astype(np.float32))

        self.greater = P.Greater()
        self.one = Tensor(np.array([1.0]).astype(np.float32))
        self.cast = P.Cast()

    def construct(self, global_step):
        decay_lr = self.decay_lr(global_step)
        if self.warmup_flag:
            is_warmup = self.cast(self.greater(self.warmup_steps, global_step), mstype.float32)
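            # is_warmup is 1.0 while global_step < warmup_steps and 0.0 after;
            # blending the two schedules arithmetically avoids tensor-dependent
            # control flow in graph mode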
            warmup_lr = self.warmup_lr(global_step)
            lr = (self.one - is_warmup) * decay_lr + is_warmup * warmup_lr
        else:
            lr = decay_lr
        return lr