From da142ccfd3dac2d30f9a42de570f45c819cbdd20 Mon Sep 17 00:00:00 2001 From: zongha Date: Fri, 14 Aug 2020 17:05:42 +0800 Subject: [PATCH] modify the codes for bert_thor 2nd for update thor bert --- model_zoo/official/nlp/bert_thor/README.md | 18 +-- .../official/nlp/bert_thor/run_pretrain.py | 13 +- .../scripts/run_distribute_pretrain.sh | 15 +-- .../scripts/run_standalone_pretrain.sh | 52 +++++--- .../bert_thor/src/bert_for_pre_training.py | 6 +- .../official/nlp/bert_thor/src/bert_model.py | 22 ++-- .../official/nlp/bert_thor/src/config.py | 7 +- .../official/nlp/bert_thor/src/dataset.py | 3 +- ..._reducer_thor1.py => grad_reducer_thor.py} | 6 +- .../nlp/bert_thor/src/lr_generator.py | 2 +- .../nlp/bert_thor/src/thor_for_bert.py | 115 +++++------------ .../nlp/bert_thor/src/thor_for_bert_arg.py | 119 +++++------------- .../official/nlp/bert_thor/src/thor_layer.py | 35 +----- 13 files changed, 138 insertions(+), 275 deletions(-) rename model_zoo/official/nlp/bert_thor/src/{grad_reducer_thor1.py => grad_reducer_thor.py} (97%) diff --git a/model_zoo/official/nlp/bert_thor/README.md b/model_zoo/official/nlp/bert_thor/README.md index a3df8b73bb..1949a4c79b 100644 --- a/model_zoo/official/nlp/bert_thor/README.md +++ b/model_zoo/official/nlp/bert_thor/README.md @@ -11,14 +11,14 @@ This is an example of training bert by second-order optimizer THOR. THOR is a no ## Running the Example ### Pre-Training -- Set options in `config.py`, including lossscale, optimizer and network. Click [here](https://www.mindspore.cn/tutorial/zh-CN/master/use/data_preparation/loading_the_datasets.html#tfrecord) for more information about dataset and the json schema file. +- Set options in `config.py`, including optimizer and network. Click [here](https://www.mindspore.cn/tutorial/zh-CN/master/use/data_preparation/loading_the_datasets.html#tfrecord) for more information about dataset and the json schema file. -- Run `run_standalone_pretrain.sh` for non-distributed pre-training of BERT-base and BERT-NEZHA model. +- Run `run_standalone_pretrain.sh` for non-distributed pre-training of BERT-base, BERT-NEZHA and BERT-large model. ``` bash sh scripts/run_standalone_pretrain.sh DEVICE_ID EPOCH_SIZE DATA_DIR SCHEMA_DIR ``` -- Run `run_distribute_pretrain.sh` for distributed pre-training of BERT-base and BERT-NEZHA model. +- Run `run_distribute_pretrain.sh` for distributed pre-training of BERT-base, BERT-NEZHA and BERT-large model. ``` bash sh scripts/run_distribute_pretrain.sh DEVICE_NUM EPOCH_SIZE DATA_DIR SCHEMA_DIR RANK_TABLE_FILE @@ -30,7 +30,7 @@ This is an example of training bert by second-order optimizer THOR. 
THOR is a no usage: run_pretrain.py [--distribute DISTRIBUTE] [--epoch_size N] [----device_num N] [--device_id N] [--enable_save_ckpt ENABLE_SAVE_CKPT] [--enable_lossscale ENABLE_LOSSSCALE] [--do_shuffle DO_SHUFFLE] - [--enable_data_sink ENABLE_DATA_SINK] [--data_sink_steps N] [--checkpoint_path CHECKPOINT_PATH] + [--enable_data_sink ENABLE_DATA_SINK] [--data_sink_steps N] [--save_checkpoint_path CHECKPOINT_PATH] [--save_checkpoint_steps N] [--save_checkpoint_num N] [--data_dir DATA_DIR] [--schema_dir SCHEMA_DIR] @@ -44,7 +44,7 @@ options: --do_shuffle enable shuffle: "true" | "false", default is "true" --enable_data_sink enable data sink: "true" | "false", default is "true" --data_sink_steps set data sink steps: N, default is 1 - --checkpoint_path path to save checkpoint files: PATH, default is "" + --save_checkpoint_path path to save checkpoint files: PATH, default is "" --save_checkpoint_steps steps for saving checkpoint files: N, default is 1000 --save_checkpoint_num number for saving checkpoint files: N, default is 1 --data_dir path to dataset directory: PATH, default is "" @@ -55,7 +55,7 @@ It contains of parameters of BERT model and options for training, which is set i ### Options: ``` config.py: - bert_network version of BERT model: base | nezha, default is base + bert_network version of BERT model: base | nezha | large, default is large optimizer optimizer used in the network: AdamWerigtDecayDynamicLR | Lamb | Momentum | Thor, default is "Thor" ``` @@ -63,7 +63,7 @@ config.py: ### Parameters: ``` Parameters for dataset and network (Pre-Training/Evaluation): - batch_size batch size of input dataset: N, default is 8 + batch_size batch size of input dataset: N, default is 12 seq_length length of input sequence: N, default is 128 vocab_size size of each embedding vector: N, must be consistant with the dataset you use. 
Default is 21136 hidden_size size of bert encoder layers: N, default is 768 @@ -87,7 +87,7 @@ Parameters for optimizer: momentum momentum for the moving average: Q weight_decay weight decay: Q loss_scale loss scale: N - frequency the step interval to update second-order information matrix: N, default is 10 - batch_size batch size of input dataset: N, default is 8 + frequency the step interval to update second-order information matrix: N, default is 100 + batch_size batch size of input dataset: N, default is 12 ``` diff --git a/model_zoo/official/nlp/bert_thor/run_pretrain.py b/model_zoo/official/nlp/bert_thor/run_pretrain.py index 0ec84545db..5c5e1c282f 100644 --- a/model_zoo/official/nlp/bert_thor/run_pretrain.py +++ b/model_zoo/official/nlp/bert_thor/run_pretrain.py @@ -19,7 +19,6 @@ python run_pretrain.py import argparse import os - import numpy from src import BertNetworkWithLoss, BertTrainOneStepCell, BertTrainOneStepWithLossScaleCell from src.bert_net_config import bert_net_cfg @@ -27,10 +26,8 @@ from src.config import cfg from src.dataset import create_bert_dataset from src.lr_generator import get_bert_lr, get_bert_damping from src.model_thor import Model -# from src.thor_for_bert import THOR from src.thor_for_bert_arg import THOR from src.utils import LossCallBack, BertLearningRate - import mindspore.common.dtype as mstype import mindspore.communication.management as D from mindspore import context @@ -69,8 +66,8 @@ def run_pretrain(): parser.add_argument("--schema_dir", type=str, default="", help="Schema path, it is better to use absolute path") args_opt = parser.parse_args() - context.set_context(mode=context.GRAPH_MODE, device_target=args_opt.device_target, device_id=args_opt.device_id, - save_graphs=True) + context.set_context(mode=context.GRAPH_MODE, device_target=args_opt.device_target, + device_id=args_opt.device_id, save_graphs=False) context.set_context(reserve_class_name_in_scope=False) context.set_context(variable_memory_max_size="30GB") ckpt_save_dir = args_opt.save_checkpoint_path @@ -165,15 +162,13 @@ def run_pretrain(): optimizer = THOR(filter(lambda x: x.requires_grad, net_with_loss.get_parameters()), lr, cfg.Thor.momentum, filter(lambda x: 'matrix_A' in x.name, net_with_loss.get_parameters()), filter(lambda x: 'matrix_G' in x.name, net_with_loss.get_parameters()), - filter(lambda x: 'A_inv_max' in x.name, net_with_loss.get_parameters()), - filter(lambda x: 'G_inv_max' in x.name, net_with_loss.get_parameters()), cfg.Thor.weight_decay, cfg.Thor.loss_scale, bert_net_cfg.num_hidden_layers, bert_net_cfg.batch_size, damping) else: - raise ValueError("Don't support optimizer {}, only support [Lamb, Momentum, AdamWeightDecay]". + raise ValueError("Don't support optimizer {}, only support [Lamb, Momentum, AdamWeightDecay, Thor]". 
format(cfg.optimizer)) callback = [TimeMonitor(args_opt.data_sink_steps), LossCallBack()] - if args_opt.enable_save_ckpt == "true": + if args_opt.enable_save_ckpt == "true" and rank == 0: config_ck = CheckpointConfig(save_checkpoint_steps=args_opt.save_checkpoint_steps, keep_checkpoint_max=args_opt.save_checkpoint_num) ckpoint_cb = ModelCheckpoint(prefix='checkpoint_bert', directory=ckpt_save_dir, config=config_ck) diff --git a/model_zoo/official/nlp/bert_thor/scripts/run_distribute_pretrain.sh b/model_zoo/official/nlp/bert_thor/scripts/run_distribute_pretrain.sh index f82151bea0..3ac2db0206 100644 --- a/model_zoo/official/nlp/bert_thor/scripts/run_distribute_pretrain.sh +++ b/model_zoo/official/nlp/bert_thor/scripts/run_distribute_pretrain.sh @@ -37,25 +37,26 @@ do rm -rf LOG$i mkdir ./LOG$i - cp *.py ./LOG$i - cp -r src ./LOG$i + cp ../*.py ./LOG$i + cp -r ../src ./LOG$i cd ./LOG$i || exit - echo "start training for rank $i, device $DEVICE_ID" + echo "start training for rank $RANK_ID, device $DEVICE_ID" env > env.log - python ../run_pretrain.py \ + python run_pretrain.py \ --distribute="true" \ --epoch_size=$EPOCH_SIZE \ --device_id=$DEVICE_ID \ --device_num=$RANK_SIZE \ --enable_save_ckpt="true" \ --enable_lossscale="false" \ - --do_shuffle="true" \ + --do_shuffle="false" \ --enable_data_sink="true" \ --data_sink_steps=1000 \ --load_checkpoint_path="" \ - --save_checkpoint_steps=5000 \ + --save_checkpoint_path='./' \ + --save_checkpoint_steps=1000 \ --save_checkpoint_num=30 \ --data_dir=$DATA_DIR \ --schema_dir=$SCHEMA_DIR > log.txt 2>&1 & cd ../ -done \ No newline at end of file +done diff --git a/model_zoo/official/nlp/bert_thor/scripts/run_standalone_pretrain.sh b/model_zoo/official/nlp/bert_thor/scripts/run_standalone_pretrain.sh index f59eb69601..35d18c2ad0 100644 --- a/model_zoo/official/nlp/bert_thor/scripts/run_standalone_pretrain.sh +++ b/model_zoo/official/nlp/bert_thor/scripts/run_standalone_pretrain.sh @@ -20,27 +20,39 @@ echo "bash run_standalone_pretrain.sh DEVICE_ID EPOCH_SIZE DATA_DIR SCHEMA_DIR" echo "for example: bash run_standalone_pretrain.sh 0 40 /path/zh-wiki/ /path/Schema.json" echo "==============================================================================================================" -DEVICE_ID=$1 EPOCH_SIZE=$2 DATA_DIR=$3 SCHEMA_DIR=$4 -mkdir -p ms_log -PROJECT_DIR=$(cd "$(dirname "$0")" || exit; pwd) -CUR_DIR=`pwd` -export GLOG_log_dir=${CUR_DIR}/ms_log -export GLOG_logtostderr=0 -python ${PROJECT_DIR}/../run_pretrain.py \ - --distribute="false" \ - --epoch_size=$EPOCH_SIZE \ - --device_id=$DEVICE_ID \ - --enable_save_ckpt="true" \ - --enable_lossscale="true" \ - --do_shuffle="true" \ - --enable_data_sink="true" \ - --data_sink_steps=1 \ - --load_checkpoint_path="" \ - --save_checkpoint_steps=10000 \ - --save_checkpoint_num=1 \ - --data_dir=$DATA_DIR \ - --schema_dir=$SCHEMA_DIR > log.txt 2>&1 & +ulimit -u unlimited +export DEVICE_ID=$1 +export RANK_SIZE=1 + +if [ -d "LOG" ]; +then + rm -rf ./LOG +fi +mkdir ./LOG +cp ../*.py ./LOG +cp -r ../src ./LOG +cd ./LOG || exit +echo "start training for device $DEVICE_ID" +env > env.log +python run_pretrain.py \ +--distribute="false" \ +--epoch_size=$EPOCH_SIZE \ +--device_id=$DEVICE_ID \ +--device_num=$RANK_SIZE \ +--enable_save_ckpt="true" \ +--enable_lossscale="false" \ +--do_shuffle="false" \ +--enable_data_sink="true" \ +--data_sink_steps=1000 \ +--load_checkpoint_path="" \ +--save_checkpoint_path='./' \ +--save_checkpoint_steps=5000 \ +--save_checkpoint_num=20 \ +--data_dir=$DATA_DIR \ 
+--schema_dir=$SCHEMA_DIR > log.txt 2>&1 & +cd ../ + diff --git a/model_zoo/official/nlp/bert_thor/src/bert_for_pre_training.py b/model_zoo/official/nlp/bert_thor/src/bert_for_pre_training.py index fb2db14743..7ba00146db 100644 --- a/model_zoo/official/nlp/bert_thor/src/bert_for_pre_training.py +++ b/model_zoo/official/nlp/bert_thor/src/bert_for_pre_training.py @@ -35,6 +35,8 @@ from .thor_layer import Dense_Thor damping = get_bert_damping() loss_scale = cfg.Thor.loss_scale +frequency = cfg.Thor.frequency +batch_size = cfg.Thor.batch_size GRADIENT_CLIP_TYPE = 1 GRADIENT_CLIP_VALUE = 1.0 @@ -91,9 +93,9 @@ class GetMaskedLMOutput(nn.Cell): bias_init='zeros', damping=damping, loss_scale=loss_scale, - frequency=1, + frequency=frequency, activation=config.hidden_act, - batch_size=config.batch_size).to_float(config.compute_type) + batch_size=batch_size).to_float(config.compute_type) self.layernorm = nn.LayerNorm((config.hidden_size,)).to_float(config.compute_type) self.output_bias = Parameter( initializer( diff --git a/model_zoo/official/nlp/bert_thor/src/bert_model.py b/model_zoo/official/nlp/bert_thor/src/bert_model.py index 93b5a9169a..91362f1354 100644 --- a/model_zoo/official/nlp/bert_thor/src/bert_model.py +++ b/model_zoo/official/nlp/bert_thor/src/bert_model.py @@ -34,6 +34,7 @@ from .thor_layer import Dense_Thor, Embedding_Thor damping = get_bert_damping() loss_scale = cfg.Thor.loss_scale +frequency = cfg.Thor.frequency batch_size = cfg.Thor.batch_size @@ -200,11 +201,10 @@ class EmbeddingPostprocessor(nn.Cell): use_one_hot_embeddings=use_one_hot_embeddings, initializer_range=initializer_range, name='embedding_table', - is_expand=False, batch_size=batch_size, damping=damping, loss_scale=loss_scale, - frequency=1) + frequency=frequency) self.shape_flat = (-1,) self.one_hot = P.OneHot() self.on_value = Tensor(1.0, mstype.float32) @@ -225,11 +225,10 @@ class EmbeddingPostprocessor(nn.Cell): use_one_hot_embeddings=use_one_hot_embeddings, initializer_range=initializer_range, name='full_position_embeddings', - is_expand=False, batch_size=batch_size, damping=damping, loss_scale=loss_scale, - frequency=1) + frequency=frequency) self.position_ids = Tensor(np.arange(seq).reshape(-1, seq).astype(np.int32)) self.layernorm = nn.LayerNorm((embedding_size,)) @@ -274,7 +273,7 @@ class BertOutput(nn.Cell): bias_init='zeros', damping=damping, loss_scale=loss_scale, - frequency=1, + frequency=frequency, activation=None, batch_size=batch_size).to_float(compute_type) self.dropout = nn.Dropout(1 - dropout_prob) @@ -488,7 +487,7 @@ class BertAttention(nn.Cell): bias_init='zeros', damping=damping, loss_scale=loss_scale, - frequency=1, + frequency=frequency, activation=query_act, batch_size=batch_size).to_float(compute_type) self.key_layer = Dense_Thor(in_channels=to_tensor_width, @@ -498,7 +497,7 @@ class BertAttention(nn.Cell): bias_init='zeros', damping=damping, loss_scale=loss_scale, - frequency=1, + frequency=frequency, activation=key_act, batch_size=batch_size).to_float(compute_type) self.value_layer = Dense_Thor(in_channels=to_tensor_width, @@ -508,7 +507,7 @@ class BertAttention(nn.Cell): bias_init='zeros', damping=damping, loss_scale=loss_scale, - frequency=1, + frequency=frequency, activation=value_act, batch_size=batch_size).to_float(compute_type) self.shape_from = (batch_size, from_seq_length, num_attention_heads, size_per_head) @@ -764,7 +763,7 @@ class BertEncoderCell(nn.Cell): bias_init='zeros', damping=damping, loss_scale=loss_scale, - frequency=1, + frequency=frequency, activation=hidden_act, 
batch_size=batch_size).to_float(compute_type) self.output = BertOutput(in_channels=intermediate_size, @@ -945,11 +944,10 @@ class BertModel(nn.Cell): use_one_hot_embeddings=use_one_hot_embeddings, initializer_range=config.initializer_range, name='embedding_table', - is_expand=True, batch_size=batch_size, damping=damping, loss_scale=loss_scale, - frequency=1) + frequency=frequency) self.bert_embedding_postprocessor = EmbeddingPostprocessor( embedding_size=self.embedding_size, embedding_shape=output_embedding_shape, @@ -991,7 +989,7 @@ class BertModel(nn.Cell): bias_init='zeros', damping=damping, loss_scale=loss_scale, - frequency=1, + frequency=frequency, activation="tanh", batch_size=batch_size).to_float(config.compute_type) self._create_attention_mask_from_input_mask = CreateAttentionMaskFromInputMask(config) diff --git a/model_zoo/official/nlp/bert_thor/src/config.py b/model_zoo/official/nlp/bert_thor/src/config.py index 9c1d5bf725..b17eecd0f6 100644 --- a/model_zoo/official/nlp/bert_thor/src/config.py +++ b/model_zoo/official/nlp/bert_thor/src/config.py @@ -19,9 +19,6 @@ from easydict import EasyDict as edict cfg = edict({ 'bert_network': 'large', - 'loss_scale_value': 65536, - 'scale_factor': 2, - 'scale_window': 1000, 'optimizer': 'Thor', 'AdamWeightDecay': edict({ 'learning_rate': 3e-5, @@ -49,7 +46,7 @@ cfg = edict({ 'momentum': 0.9, 'weight_decay': 5e-4, 'loss_scale': 1, - 'frequency': 10, - 'batch_size': 8, + 'frequency': 100, + 'batch_size': 12, }), }) diff --git a/model_zoo/official/nlp/bert_thor/src/dataset.py b/model_zoo/official/nlp/bert_thor/src/dataset.py index 889e27694a..fee6c97024 100644 --- a/model_zoo/official/nlp/bert_thor/src/dataset.py +++ b/model_zoo/official/nlp/bert_thor/src/dataset.py @@ -16,7 +16,6 @@ Data operations, will be used in run_pretrain.py """ import os - import mindspore.common.dtype as mstype import mindspore.dataset.engine.datasets as de import mindspore.dataset.transforms.c_transforms as C @@ -37,7 +36,7 @@ def create_bert_dataset(device_num=1, rank=0, do_shuffle="true", data_dir=None, columns_list=["input_ids", "input_mask", "segment_ids", "next_sentence_labels", "masked_lm_positions", "masked_lm_ids", "masked_lm_weights"], shuffle=de.Shuffle.FILES if do_shuffle == "true" else False, - num_shards=device_num, shard_id=rank, shard_equal_rows=True) + num_shards=device_num, shard_id=rank, shard_equal_rows=False) ori_dataset_size = ds.get_dataset_size() print('origin dataset size: ', ori_dataset_size) type_cast_op = C.TypeCast(mstype.int32) diff --git a/model_zoo/official/nlp/bert_thor/src/grad_reducer_thor1.py b/model_zoo/official/nlp/bert_thor/src/grad_reducer_thor.py similarity index 97% rename from model_zoo/official/nlp/bert_thor/src/grad_reducer_thor1.py rename to model_zoo/official/nlp/bert_thor/src/grad_reducer_thor.py index 709b0b73df..d0316e99b2 100644 --- a/model_zoo/official/nlp/bert_thor/src/grad_reducer_thor1.py +++ b/model_zoo/official/nlp/bert_thor/src/grad_reducer_thor.py @@ -80,7 +80,7 @@ def _tensors_cast_datatype(datatype, grad): return F.cast(grad, datatype) -class DistributedGradReducerThor1(Cell): +class DistributedGradReducerThor(Cell): """ A distributed optimizer. 
@@ -154,7 +154,7 @@ class DistributedGradReducerThor1(Cell): """ def __init__(self, parameters, group, mean=True, degree=None): - super(DistributedGradReducerThor1, self).__init__(auto_prefix=False) + super(DistributedGradReducerThor, self).__init__(auto_prefix=False) self.hyper_map = C.HyperMap() self.mul = P.Mul() if degree is None: @@ -168,7 +168,7 @@ class DistributedGradReducerThor1(Cell): _init_optimizer_allreduce(group) def construct(self, grads): - """construct of DistributedGradReducerThor1""" + """construct of DistributedGradReducerThor""" # In some circumstances, the data precision of grads could be mixed with float16 and float32. Thus, the # result of AllReduce is unreliable. To solve the problem, grads should be cast to float32 before AllReduce, # and cast back after the operation. diff --git a/model_zoo/official/nlp/bert_thor/src/lr_generator.py b/model_zoo/official/nlp/bert_thor/src/lr_generator.py index c2416e9b81..d3ca9f458a 100644 --- a/model_zoo/official/nlp/bert_thor/src/lr_generator.py +++ b/model_zoo/official/nlp/bert_thor/src/lr_generator.py @@ -58,7 +58,7 @@ def get_poly_lr(global_step, lr_init, lr_end, lr_max, warmup_steps, total_steps, # bert kfac hyperparam setting def get_bert_lr(): learning_rate = Tensor( - get_poly_lr(global_step=0, lr_init=0.0, lr_end=1e-6, lr_max=4e-4, warmup_steps=0, total_steps=30000, + get_poly_lr(global_step=0, lr_init=0.0, lr_end=1e-6, lr_max=3.1e-3, warmup_steps=0, total_steps=30000, poly_power=1)) return learning_rate diff --git a/model_zoo/official/nlp/bert_thor/src/thor_for_bert.py b/model_zoo/official/nlp/bert_thor/src/thor_for_bert.py index b9e9c46ab4..60e40d41c0 100644 --- a/model_zoo/official/nlp/bert_thor/src/thor_for_bert.py +++ b/model_zoo/official/nlp/bert_thor/src/thor_for_bert.py @@ -46,9 +46,8 @@ def _tensor_apply_decay(weight_decay, if_apply, weight, gradient): class THOR(Optimizer): """THOR""" - - def __init__(self, params, learning_rate, momentum, matrix_A, matrix_G, A_inv_max, G_inv_max, weight_decay=0.0, - loss_scale=1.0, num_hidden_layers=24, batch_size=12, damping=0.03, frequency=10, + def __init__(self, params, learning_rate, momentum, matrix_A, matrix_G, weight_decay=0.0, + loss_scale=1.0, num_hidden_layers=24, batch_size=12, damping=0.03, decay_filter=lambda x: 'layernorm' not in x.name.lower() and 'bias' not in x.name.lower()): super(THOR, self).__init__(learning_rate, params, weight_decay, loss_scale) if isinstance(momentum, float) and momentum < 0.0: @@ -60,8 +59,6 @@ class THOR(Optimizer): self.opt = P.ApplyMomentum() self.matrix_A = ParameterTuple(matrix_A) self.matrix_G = ParameterTuple(matrix_G) - self.A_inv_max = ParameterTuple(A_inv_max) - self.G_inv_max = ParameterTuple(G_inv_max) self.matmul = P.MatMul() self.transpose = P.Transpose() self.shape = P.Shape() @@ -70,16 +67,8 @@ class THOR(Optimizer): self.gather = P.GatherV2() self.matrix_A_inv = () self.matrix_G_inv = () - self.matrix_max_inv = () self.num_hidden_layers = num_hidden_layers - fc_layer_num = num_hidden_layers * 6 + 5 - for i in range(fc_layer_num): - self.matrix_max_inv = self.matrix_max_inv + ( - Parameter(initializer(1, [1], mstype.float32), name="matrix_max" + str(i), requires_grad=False),) - self.log = P.Log() - self.exp = P.Exp() self.sqrt = P.Sqrt() - self.matrix_max_inv = ParameterTuple(self.matrix_max_inv) self.assign = P.Assign() self.cast = P.Cast() self.thor = True @@ -90,7 +79,6 @@ class THOR(Optimizer): self.inv = P.Inv() self.batch_size = batch_size self.damping = damping - self.freq = Tensor(frequency, mstype.int32) 
self.one = Tensor(1, mstype.int32) self.cov_step = Parameter(initializer(0, [1], mstype.int32), name="cov_step", requires_grad=False) @@ -106,26 +94,20 @@ class THOR(Optimizer): g = gradients[em_idx] matrix_idx = em_idx temp_a_ori = self.matrix_A[matrix_idx] - temp_a = self.expand(temp_a_ori, 1) temp_g = self.matrix_G[matrix_idx] - G_max = self.G_inv_max[matrix_idx] - temp_g = self.cast(temp_g, mstype.float32) - matrix_G_inv_max = self.log(G_max) - matrix_G_inv_max = self.mul(matrix_G_inv_max, -1) - matrix_G_inv_max = self.exp(matrix_G_inv_max) - temp_g = self.mul(temp_g, matrix_G_inv_max) - g = self.mul(temp_a, g) - g = self.cast(g, mstype.float16) + temp_a_ori = F.depend(temp_a_ori, g) + temp_g = F.depend(temp_g, g) + temp_a = self.expand(temp_a_ori, 1) + temp_a = self.cast(temp_a, mstype.float16) temp_g = self.cast(temp_g, mstype.float16) + g = self.cast(g, mstype.float16) + g = self.mul(temp_a, g) g = self.matmul(g, temp_g) g = self.cast(g, mstype.float32) - g = self.mul(g, G_max) fake_A = self.assign(self.matrix_A[matrix_idx], temp_a_ori) fake_G = self.assign(self.matrix_G[matrix_idx], temp_g) - fake_max = self.assign(self.matrix_max_inv[matrix_idx], G_max) g = F.depend(g, fake_A) g = F.depend(g, fake_G) - g = F.depend(g, fake_max) new_grads = new_grads + (g,) # process bert_embedding_postprocessor.layernorm grad_idx = 3 @@ -180,32 +162,18 @@ class THOR(Optimizer): matrix_idx = 6 * i + offset_idx + 3 temp_a = self.matrix_A[matrix_idx] temp_g = self.matrix_G[matrix_idx] - temp_a = self.cast(temp_a, mstype.float32) - temp_g = self.cast(temp_g, mstype.float32) - matrix_A_inv_max = self.log(self.A_inv_max[matrix_idx]) - matrix_A_inv_max = self.mul(matrix_A_inv_max, -1) - matrix_A_inv_max = self.exp(matrix_A_inv_max) - temp_a = self.mul(temp_a, matrix_A_inv_max) - matrix_G_inv_max = self.log(self.G_inv_max[matrix_idx]) - matrix_G_inv_max = self.mul(matrix_G_inv_max, -1) - matrix_G_inv_max = self.exp(matrix_G_inv_max) - temp_g = self.mul(temp_g, matrix_G_inv_max) - temp_max = self.mul(self.A_inv_max[matrix_idx], self.G_inv_max[matrix_idx]) + temp_a = F.depend(temp_a, g) + temp_g = F.depend(temp_g, g) temp_a = self.cast(temp_a, mstype.float16) temp_g = self.cast(temp_g, mstype.float16) g = self.cast(g, mstype.float16) - g = self.matmul(temp_g, g) g = self.matmul(g, temp_a) g = self.cast(g, mstype.float32) - g = self.mul(g, temp_max) - fake_A = self.assign(self.matrix_A[matrix_idx], temp_a) fake_G = self.assign(self.matrix_G[matrix_idx], temp_g) - fake_max = self.assign(self.matrix_max_inv[matrix_idx], temp_max) g = F.depend(g, fake_A) g = F.depend(g, fake_G) - g = F.depend(g, fake_max) new_grads = new_grads + (g,) new_grads = new_grads + (gradients[grad_idx + 1],) @@ -216,32 +184,18 @@ class THOR(Optimizer): pooler_bias = gradients[pooler_layer_idx + 1] temp_a = self.matrix_A[matrix_idx] temp_g = self.matrix_G[matrix_idx] - temp_a = self.cast(temp_a, mstype.float32) - temp_g = self.cast(temp_g, mstype.float32) - matrix_A_inv_max = self.log(self.A_inv_max[matrix_idx]) - matrix_A_inv_max = self.mul(matrix_A_inv_max, -1) - matrix_A_inv_max = self.exp(matrix_A_inv_max) - temp_a = self.mul(temp_a, matrix_A_inv_max) - matrix_G_inv_max = self.log(self.G_inv_max[matrix_idx]) - matrix_G_inv_max = self.mul(matrix_G_inv_max, -1) - matrix_G_inv_max = self.exp(matrix_G_inv_max) - temp_g = self.mul(temp_g, matrix_G_inv_max) - temp_max = self.mul(self.A_inv_max[matrix_idx], self.G_inv_max[matrix_idx]) + temp_a = F.depend(temp_a, g) + temp_g = F.depend(temp_g, g) temp_a = self.cast(temp_a, mstype.float16) 
temp_g = self.cast(temp_g, mstype.float16) g = self.cast(g, mstype.float16) - g = self.matmul(temp_g, g) g = self.matmul(g, temp_a) g = self.cast(g, mstype.float32) - g = self.mul(g, temp_max) - fake_A = self.assign(self.matrix_A[matrix_idx], temp_a) fake_G = self.assign(self.matrix_G[matrix_idx], temp_g) - fake_max = self.assign(self.matrix_max_inv[matrix_idx], temp_max) g = F.depend(g, fake_A) g = F.depend(g, fake_G) - g = F.depend(g, fake_max) new_grads = new_grads + (g, pooler_bias) # for cls1 fc layer: mlm @@ -251,38 +205,26 @@ class THOR(Optimizer): mlm_bias = gradients[mlm_fc_idx + 1] temp_a = self.matrix_A[matrix_idx] temp_g = self.matrix_G[matrix_idx] - temp_a = self.cast(temp_a, mstype.float32) - temp_g = self.cast(temp_g, mstype.float32) - matrix_A_inv_max = self.log(self.A_inv_max[matrix_idx]) - matrix_A_inv_max = self.mul(matrix_A_inv_max, -1) - matrix_A_inv_max = self.exp(matrix_A_inv_max) - temp_a = self.mul(temp_a, matrix_A_inv_max) - matrix_G_inv_max = self.log(self.G_inv_max[matrix_idx]) - matrix_G_inv_max = self.mul(matrix_G_inv_max, -1) - matrix_G_inv_max = self.exp(matrix_G_inv_max) - temp_g = self.mul(temp_g, matrix_G_inv_max) - temp_max = self.mul(self.A_inv_max[matrix_idx], self.G_inv_max[matrix_idx]) + temp_a = F.depend(temp_a, g) + temp_g = F.depend(temp_g, g) temp_a = self.cast(temp_a, mstype.float16) temp_g = self.cast(temp_g, mstype.float16) g = self.cast(g, mstype.float16) - g = self.matmul(temp_g, g) g = self.matmul(g, temp_a) g = self.cast(g, mstype.float32) - g = self.mul(g, temp_max) - + # add bert.cls1.output_bias grad fake_A = self.assign(self.matrix_A[matrix_idx], temp_a) fake_G = self.assign(self.matrix_G[matrix_idx], temp_g) - fake_max = self.assign(self.matrix_max_inv[matrix_idx], temp_max) g = F.depend(g, fake_A) g = F.depend(g, fake_G) - g = F.depend(g, fake_max) new_grads = new_grads + (gradients[mlm_fc_idx - 1],) new_grads = new_grads + (g, mlm_bias) # add bert.cls1.layernorm grad begin_idx = mlm_fc_idx + 2 end_idx = mlm_fc_idx + 4 new_grads = new_grads + gradients[begin_idx: end_idx] + lenth = len(gradients) new_grads = new_grads + gradients[lenth - 2: lenth] gradients = new_grads @@ -293,15 +235,16 @@ class THOR(Optimizer): g = gradients[em_idx] matrix_idx = em_idx temp_a = self.matrix_A[matrix_idx] - temp_a = self.expand(temp_a, 1) temp_g = self.matrix_G[matrix_idx] - matrix_max = self.matrix_max_inv[matrix_idx] - g = self.mul(temp_a, g) + temp_a = F.depend(temp_a, g) + temp_g = F.depend(temp_g, g) + temp_a = self.expand(temp_a, 1) + temp_a = self.cast(temp_a, mstype.float16) temp_g = self.cast(temp_g, mstype.float16) g = self.cast(g, mstype.float16) + g = self.mul(temp_a, g) g = self.matmul(g, temp_g) g = self.cast(g, mstype.float32) - g = self.mul(g, matrix_max) new_grads = new_grads + (g,) # process bert_embedding_postprocessor.layernorm grad_idx = 3 @@ -356,15 +299,14 @@ class THOR(Optimizer): matrix_idx = 6 * i + offset_idx + 3 temp_a = self.matrix_A[matrix_idx] temp_g = self.matrix_G[matrix_idx] - matrix_max = self.matrix_max_inv[matrix_idx] + temp_a = F.depend(temp_a, g) + temp_g = F.depend(temp_g, g) temp_a = self.cast(temp_a, mstype.float16) temp_g = self.cast(temp_g, mstype.float16) g = self.cast(g, mstype.float16) - g = self.matmul(temp_g, g) g = self.matmul(g, temp_a) g = self.cast(g, mstype.float32) - g = self.mul(g, matrix_max) new_grads = new_grads + (g,) new_grads = new_grads + (gradients[grad_idx + 1],) @@ -375,15 +317,14 @@ class THOR(Optimizer): pooler_bias = gradients[pooler_layer_idx + 1] temp_a = self.matrix_A[matrix_idx] 
temp_g = self.matrix_G[matrix_idx] - matrix_max = self.matrix_max_inv[matrix_idx] + temp_a = F.depend(temp_a, g) + temp_g = F.depend(temp_g, g) temp_a = self.cast(temp_a, mstype.float16) temp_g = self.cast(temp_g, mstype.float16) g = self.cast(g, mstype.float16) - g = self.matmul(temp_g, g) g = self.matmul(g, temp_a) g = self.cast(g, mstype.float32) - g = self.mul(g, matrix_max) new_grads = new_grads + (g, pooler_bias) # for cls1 fc layer: mlm @@ -393,15 +334,14 @@ class THOR(Optimizer): mlm_bias = gradients[mlm_fc_idx + 1] temp_a = self.matrix_A[matrix_idx] temp_g = self.matrix_G[matrix_idx] - matrix_max = self.matrix_max_inv[matrix_idx] + temp_a = F.depend(temp_a, g) + temp_g = F.depend(temp_g, g) temp_a = self.cast(temp_a, mstype.float16) temp_g = self.cast(temp_g, mstype.float16) g = self.cast(g, mstype.float16) - g = self.matmul(temp_g, g) g = self.matmul(g, temp_a) g = self.cast(g, mstype.float32) - g = self.mul(g, matrix_max) # add bert.cls1.output_bias grad new_grads = new_grads + (gradients[mlm_fc_idx - 1],) new_grads = new_grads + (g, mlm_bias) @@ -409,6 +349,7 @@ class THOR(Optimizer): begin_idx = mlm_fc_idx + 2 end_idx = mlm_fc_idx + 4 new_grads = new_grads + gradients[begin_idx: end_idx] + lenth = len(gradients) new_grads = new_grads + gradients[lenth - 2: lenth] gradients = new_grads diff --git a/model_zoo/official/nlp/bert_thor/src/thor_for_bert_arg.py b/model_zoo/official/nlp/bert_thor/src/thor_for_bert_arg.py index 8cb56d1f70..aeb3cf309f 100644 --- a/model_zoo/official/nlp/bert_thor/src/thor_for_bert_arg.py +++ b/model_zoo/official/nlp/bert_thor/src/thor_for_bert_arg.py @@ -21,7 +21,7 @@ from mindspore.common.tensor import Tensor from mindspore.nn.optim.optimizer import Optimizer from mindspore.ops import functional as F, composite as C, operations as P from mindspore.parallel._utils import _get_device_num, _get_mirror_mean -from .grad_reducer_thor1 import DistributedGradReducerThor1 +from .grad_reducer_thor import DistributedGradReducerThor momentum_opt = C.MultitypeFuncGraph("momentum_opt") @@ -48,9 +48,8 @@ def _tensor_apply_decay(weight_decay, if_apply, weight, gradient): class THOR(Optimizer): """THOR""" - - def __init__(self, params, learning_rate, momentum, matrix_A, matrix_G, A_inv_max, G_inv_max, weight_decay=0.0, - loss_scale=1.0, num_hidden_layers=24, batch_size=12, damping=0.03, frequency=10, + def __init__(self, params, learning_rate, momentum, matrix_A, matrix_G, weight_decay=0.0, + loss_scale=1.0, num_hidden_layers=24, batch_size=12, damping=0.03, decay_filter=lambda x: 'layernorm' not in x.name.lower() and 'bias' not in x.name.lower()): super(THOR, self).__init__(learning_rate, params, weight_decay, loss_scale) if isinstance(momentum, float) and momentum < 0.0: @@ -62,8 +61,6 @@ class THOR(Optimizer): self.opt = P.ApplyMomentum() self.matrix_A = ParameterTuple(matrix_A) self.matrix_G = ParameterTuple(matrix_G) - self.A_inv_max = ParameterTuple(A_inv_max) - self.G_inv_max = ParameterTuple(G_inv_max) self.matmul = P.MatMul() self.transpose = P.Transpose() self.shape = P.Shape() @@ -72,16 +69,8 @@ class THOR(Optimizer): self.gather = P.GatherV2() self.matrix_A_inv = () self.matrix_G_inv = () - self.matrix_max_inv = () self.num_hidden_layers = num_hidden_layers - fc_layer_num = num_hidden_layers * 6 + 5 - for i in range(fc_layer_num): - self.matrix_max_inv = self.matrix_max_inv + ( - Parameter(initializer(1, [1], mstype.float32), name="matrix_max" + str(i), requires_grad=False),) - self.log = P.Log() - self.exp = P.Exp() self.sqrt = P.Sqrt() - 
self.matrix_max_inv = ParameterTuple(self.matrix_max_inv) self.assign = P.Assign() self.cast = P.Cast() self.thor = True @@ -92,12 +81,11 @@ class THOR(Optimizer): self.inv = P.Inv() self.batch_size = batch_size self.damping = damping - self.freq = Tensor(frequency, mstype.int32) self.one = Tensor(1, mstype.int32) self.cov_step = Parameter(initializer(0, [1], mstype.int32), name="cov_step", requires_grad=False) mean = _get_mirror_mean() degree = _get_device_num() - self.grad_reducer_g = DistributedGradReducerThor1(self.parameters, 3, mean, degree) + self.grad_reducer_g = DistributedGradReducerThor(self.parameters, 3, mean, degree) def construct(self, gradients): """construct of THOR""" @@ -111,26 +99,20 @@ class THOR(Optimizer): g = gradients[em_idx] matrix_idx = em_idx temp_a_ori = self.matrix_A[matrix_idx] - temp_a = self.expand(temp_a_ori, 1) temp_g = self.matrix_G[matrix_idx] - G_max = self.G_inv_max[matrix_idx] - temp_g = self.cast(temp_g, mstype.float32) - matrix_G_inv_max = self.log(G_max) - matrix_G_inv_max = self.mul(matrix_G_inv_max, -1) - matrix_G_inv_max = self.exp(matrix_G_inv_max) - temp_g = self.mul(temp_g, matrix_G_inv_max) - g = self.mul(temp_a, g) - g = self.cast(g, mstype.float16) + temp_a_ori = F.depend(temp_a_ori, g) + temp_g = F.depend(temp_g, g) + temp_a = self.expand(temp_a_ori, 1) + temp_a = self.cast(temp_a, mstype.float16) temp_g = self.cast(temp_g, mstype.float16) + g = self.cast(g, mstype.float16) + g = self.mul(temp_a, g) g = self.matmul(g, temp_g) g = self.cast(g, mstype.float32) - g = self.mul(g, G_max) fake_A = self.assign(self.matrix_A[matrix_idx], temp_a_ori) fake_G = self.assign(self.matrix_G[matrix_idx], temp_g) - fake_max = self.assign(self.matrix_max_inv[matrix_idx], G_max) g = F.depend(g, fake_A) g = F.depend(g, fake_G) - g = F.depend(g, fake_max) new_grads = new_grads + (g,) # process bert_embedding_postprocessor.layernorm grad_idx = 3 @@ -185,32 +167,18 @@ class THOR(Optimizer): matrix_idx = 6 * i + offset_idx + 3 temp_a = self.matrix_A[matrix_idx] temp_g = self.matrix_G[matrix_idx] - temp_a = self.cast(temp_a, mstype.float32) - temp_g = self.cast(temp_g, mstype.float32) - matrix_A_inv_max = self.log(self.A_inv_max[matrix_idx]) - matrix_A_inv_max = self.mul(matrix_A_inv_max, -1) - matrix_A_inv_max = self.exp(matrix_A_inv_max) - temp_a = self.mul(temp_a, matrix_A_inv_max) - matrix_G_inv_max = self.log(self.G_inv_max[matrix_idx]) - matrix_G_inv_max = self.mul(matrix_G_inv_max, -1) - matrix_G_inv_max = self.exp(matrix_G_inv_max) - temp_g = self.mul(temp_g, matrix_G_inv_max) - temp_max = self.mul(self.A_inv_max[matrix_idx], self.G_inv_max[matrix_idx]) + temp_a = F.depend(temp_a, g) + temp_g = F.depend(temp_g, g) temp_a = self.cast(temp_a, mstype.float16) temp_g = self.cast(temp_g, mstype.float16) g = self.cast(g, mstype.float16) - g = self.matmul(temp_g, g) g = self.matmul(g, temp_a) g = self.cast(g, mstype.float32) - g = self.mul(g, temp_max) - fake_A = self.assign(self.matrix_A[matrix_idx], temp_a) fake_G = self.assign(self.matrix_G[matrix_idx], temp_g) - fake_max = self.assign(self.matrix_max_inv[matrix_idx], temp_max) g = F.depend(g, fake_A) g = F.depend(g, fake_G) - g = F.depend(g, fake_max) new_grads = new_grads + (g,) new_grads = new_grads + (gradients[grad_idx + 1],) @@ -221,32 +189,18 @@ class THOR(Optimizer): pooler_bias = gradients[pooler_layer_idx + 1] temp_a = self.matrix_A[matrix_idx] temp_g = self.matrix_G[matrix_idx] - temp_a = self.cast(temp_a, mstype.float32) - temp_g = self.cast(temp_g, mstype.float32) - matrix_A_inv_max = 
self.log(self.A_inv_max[matrix_idx]) - matrix_A_inv_max = self.mul(matrix_A_inv_max, -1) - matrix_A_inv_max = self.exp(matrix_A_inv_max) - temp_a = self.mul(temp_a, matrix_A_inv_max) - matrix_G_inv_max = self.log(self.G_inv_max[matrix_idx]) - matrix_G_inv_max = self.mul(matrix_G_inv_max, -1) - matrix_G_inv_max = self.exp(matrix_G_inv_max) - temp_g = self.mul(temp_g, matrix_G_inv_max) - temp_max = self.mul(self.A_inv_max[matrix_idx], self.G_inv_max[matrix_idx]) + temp_a = F.depend(temp_a, g) + temp_g = F.depend(temp_g, g) temp_a = self.cast(temp_a, mstype.float16) temp_g = self.cast(temp_g, mstype.float16) g = self.cast(g, mstype.float16) - g = self.matmul(temp_g, g) g = self.matmul(g, temp_a) g = self.cast(g, mstype.float32) - g = self.mul(g, temp_max) - fake_A = self.assign(self.matrix_A[matrix_idx], temp_a) fake_G = self.assign(self.matrix_G[matrix_idx], temp_g) - fake_max = self.assign(self.matrix_max_inv[matrix_idx], temp_max) g = F.depend(g, fake_A) g = F.depend(g, fake_G) - g = F.depend(g, fake_max) new_grads = new_grads + (g, pooler_bias) # for cls1 fc layer: mlm @@ -256,38 +210,26 @@ class THOR(Optimizer): mlm_bias = gradients[mlm_fc_idx + 1] temp_a = self.matrix_A[matrix_idx] temp_g = self.matrix_G[matrix_idx] - temp_a = self.cast(temp_a, mstype.float32) - temp_g = self.cast(temp_g, mstype.float32) - matrix_A_inv_max = self.log(self.A_inv_max[matrix_idx]) - matrix_A_inv_max = self.mul(matrix_A_inv_max, -1) - matrix_A_inv_max = self.exp(matrix_A_inv_max) - temp_a = self.mul(temp_a, matrix_A_inv_max) - matrix_G_inv_max = self.log(self.G_inv_max[matrix_idx]) - matrix_G_inv_max = self.mul(matrix_G_inv_max, -1) - matrix_G_inv_max = self.exp(matrix_G_inv_max) - temp_g = self.mul(temp_g, matrix_G_inv_max) - temp_max = self.mul(self.A_inv_max[matrix_idx], self.G_inv_max[matrix_idx]) + temp_a = F.depend(temp_a, g) + temp_g = F.depend(temp_g, g) temp_a = self.cast(temp_a, mstype.float16) temp_g = self.cast(temp_g, mstype.float16) g = self.cast(g, mstype.float16) - g = self.matmul(temp_g, g) g = self.matmul(g, temp_a) g = self.cast(g, mstype.float32) - g = self.mul(g, temp_max) - + # add bert.cls1.output_bias grad fake_A = self.assign(self.matrix_A[matrix_idx], temp_a) fake_G = self.assign(self.matrix_G[matrix_idx], temp_g) - fake_max = self.assign(self.matrix_max_inv[matrix_idx], temp_max) g = F.depend(g, fake_A) g = F.depend(g, fake_G) - g = F.depend(g, fake_max) new_grads = new_grads + (gradients[mlm_fc_idx - 1],) new_grads = new_grads + (g, mlm_bias) # add bert.cls1.layernorm grad begin_idx = mlm_fc_idx + 2 end_idx = mlm_fc_idx + 4 new_grads = new_grads + gradients[begin_idx: end_idx] + lenth = len(gradients) new_grads = new_grads + gradients[lenth - 2: lenth] gradients = new_grads @@ -299,15 +241,16 @@ class THOR(Optimizer): g = gradients[em_idx] matrix_idx = em_idx temp_a = self.matrix_A[matrix_idx] - temp_a = self.expand(temp_a, 1) temp_g = self.matrix_G[matrix_idx] - matrix_max = self.matrix_max_inv[matrix_idx] - g = self.mul(temp_a, g) + temp_a = F.depend(temp_a, g) + temp_g = F.depend(temp_g, g) + temp_a = self.expand(temp_a, 1) + temp_a = self.cast(temp_a, mstype.float16) temp_g = self.cast(temp_g, mstype.float16) g = self.cast(g, mstype.float16) + g = self.mul(temp_a, g) g = self.matmul(g, temp_g) g = self.cast(g, mstype.float32) - g = self.mul(g, matrix_max) new_grads = new_grads + (g,) # process bert_embedding_postprocessor.layernorm grad_idx = 3 @@ -362,15 +305,14 @@ class THOR(Optimizer): matrix_idx = 6 * i + offset_idx + 3 temp_a = self.matrix_A[matrix_idx] temp_g = 
self.matrix_G[matrix_idx] - matrix_max = self.matrix_max_inv[matrix_idx] + temp_a = F.depend(temp_a, g) + temp_g = F.depend(temp_g, g) temp_a = self.cast(temp_a, mstype.float16) temp_g = self.cast(temp_g, mstype.float16) g = self.cast(g, mstype.float16) - g = self.matmul(temp_g, g) g = self.matmul(g, temp_a) g = self.cast(g, mstype.float32) - g = self.mul(g, matrix_max) new_grads = new_grads + (g,) new_grads = new_grads + (gradients[grad_idx + 1],) @@ -381,15 +323,14 @@ class THOR(Optimizer): pooler_bias = gradients[pooler_layer_idx + 1] temp_a = self.matrix_A[matrix_idx] temp_g = self.matrix_G[matrix_idx] - matrix_max = self.matrix_max_inv[matrix_idx] + temp_a = F.depend(temp_a, g) + temp_g = F.depend(temp_g, g) temp_a = self.cast(temp_a, mstype.float16) temp_g = self.cast(temp_g, mstype.float16) g = self.cast(g, mstype.float16) - g = self.matmul(temp_g, g) g = self.matmul(g, temp_a) g = self.cast(g, mstype.float32) - g = self.mul(g, matrix_max) new_grads = new_grads + (g, pooler_bias) # for cls1 fc layer: mlm @@ -399,15 +340,14 @@ class THOR(Optimizer): mlm_bias = gradients[mlm_fc_idx + 1] temp_a = self.matrix_A[matrix_idx] temp_g = self.matrix_G[matrix_idx] - matrix_max = self.matrix_max_inv[matrix_idx] + temp_a = F.depend(temp_a, g) + temp_g = F.depend(temp_g, g) temp_a = self.cast(temp_a, mstype.float16) temp_g = self.cast(temp_g, mstype.float16) g = self.cast(g, mstype.float16) - g = self.matmul(temp_g, g) g = self.matmul(g, temp_a) g = self.cast(g, mstype.float32) - g = self.mul(g, matrix_max) # add bert.cls1.output_bias grad new_grads = new_grads + (gradients[mlm_fc_idx - 1],) new_grads = new_grads + (g, mlm_bias) @@ -415,6 +355,7 @@ class THOR(Optimizer): begin_idx = mlm_fc_idx + 2 end_idx = mlm_fc_idx + 4 new_grads = new_grads + gradients[begin_idx: end_idx] + lenth = len(gradients) new_grads = new_grads + gradients[lenth - 2: lenth] gradients = new_grads diff --git a/model_zoo/official/nlp/bert_thor/src/thor_layer.py b/model_zoo/official/nlp/bert_thor/src/thor_layer.py index 8f9e0c0759..dbe3821f53 100644 --- a/model_zoo/official/nlp/bert_thor/src/thor_layer.py +++ b/model_zoo/official/nlp/bert_thor/src/thor_layer.py @@ -14,7 +14,6 @@ # ============================================================================ """thor_layer""" import numpy as np - import mindspore.common.dtype as mstype from mindspore._checkparam import check_bool, check_int_positive from mindspore.common.initializer import TruncatedNormal, initializer @@ -24,7 +23,6 @@ from mindspore.nn.cell import Cell from mindspore.nn.layer.activation import get_activation from mindspore.ops import operations as P - class Embedding_Thor(Cell): """ A embeddings lookup table with a fixed dictionary and size. @@ -37,7 +35,6 @@ class Embedding_Thor(Cell): use_one_hot_embeddings (bool): Specifies whether to use one hot encoding form. Default: False. initializer_range (float): Initialization value of TruncatedNormal. Default: 0.02. 
""" - def __init__(self, vocab_size, embedding_size, @@ -45,11 +42,10 @@ class Embedding_Thor(Cell): use_one_hot_embeddings=False, initializer_range=0.02, name='embedding_table', - is_expand=False, batch_size=12, damping=0.03, loss_scale=1, - frequency=10, + frequency=100, ): super(Embedding_Thor, self).__init__() self.vocab_size = vocab_size @@ -59,7 +55,6 @@ class Embedding_Thor(Cell): [vocab_size, embedding_size]), name=name) self.thor = True - self.is_expand = is_expand self.expand = P.ExpandDims() self.shape_flat = (-1,) self.gather = P.GatherV2() @@ -71,13 +66,11 @@ class Embedding_Thor(Cell): self.em_shape = tuple(embedding_shape) self.shape = P.Shape() self.loss_scale = Tensor(1 / loss_scale, mstype.float16) - self.matrix_A_inv = Parameter(Tensor(np.zeros([vocab_size]).astype(np.float32)), name='matrix_A_inv', - requires_grad=False) + + self.matrix_A_inv = Parameter(Tensor(np.zeros([vocab_size]).astype(np.float16)), + name='matrix_A_inv', requires_grad=False) self.matrix_G_inv = Parameter(Tensor(np.zeros([embedding_size, embedding_size]).astype(np.float16)), name="matrix_G_inv", requires_grad=False) - self.A_inv_max = Parameter(initializer(0, [1], mstype.float32), name="A_inv_max", requires_grad=False) - self.G_inv_max = Parameter(initializer(0, [1], mstype.float32), name="G_inv_max", requires_grad=False) - self.fused_abs_max = P.CusFusedAbsMax1() self.fake_G = Tensor(np.zeros([embedding_size, embedding_size]).astype(np.float16)) self.dampingA = Tensor(np.ones([vocab_size]).astype(np.float32)) self.dampingG = Tensor(np.identity(embedding_size), mstype.float32) @@ -117,9 +110,6 @@ class Embedding_Thor(Cell): matrix_G = matrix_G + damping * dampingG matrix_G_inv = self.cholesky(matrix_G) matrix_G_inv = self.vector_matmul(matrix_G_inv, matrix_G_inv) - matrix_G_inv_max = self.fused_abs_max(matrix_G_inv) - matrix_G_inv_max = self.fused_abs_max(matrix_G_inv_max) - self.G_inv_max = matrix_G_inv_max matrix_G_inv = self.matrix_combine(matrix_G_inv) matrix_G_inv = self.cast(matrix_G_inv, mstype.float16) self.matrix_G_inv = matrix_G_inv @@ -127,8 +117,6 @@ class Embedding_Thor(Cell): def construct(self, input_ids): """construct of Embedding_Thor""" - if self.is_expand: - input_ids = self.expand(input_ids, -1) flat_ids = self.reshape(input_ids, self.shape_flat) if self.use_one_hot_embeddings: one_hot_ids = self.one_hot(flat_ids, self.vocab_size, self.on_value, self.off_value) @@ -146,6 +134,7 @@ class Embedding_Thor(Cell): dampingA = self.cast(self.dampingA, mstype.float32) matrix_A = matrix_A + damping * dampingA matrix_A_inv = self.inv(matrix_A) + matrix_A_inv = self.cast(matrix_A_inv, mstype.float16) self.matrix_A_inv = matrix_A_inv self.matrix_G_inv = self.fake_G output_for_reshape = self.gather(self.embedding_table, flat_ids, 0) @@ -156,11 +145,9 @@ class Embedding_Thor(Cell): output = self.reshape(output_for_reshape, self.em_shape) return output, self.embedding_table - class Dense_Thor(Cell): """Dense_Thor""" - # @cell_attr_register(attrs=['has_bias', 'activation', 'in_channels', 'out_channels']) def __init__(self, in_channels, out_channels, @@ -168,7 +155,7 @@ class Dense_Thor(Cell): bias_init='zeros', damping=0.03, loss_scale=1, - frequency=10, + frequency=100, has_bias=False, activation=None, batch_size=12): @@ -200,9 +187,6 @@ class Dense_Thor(Cell): name='matrix_A_inv', requires_grad=False) self.matrix_G_inv = Parameter(Tensor(np.zeros([out_channels, out_channels]).astype(np.float16)), name="matrix_G_inv", requires_grad=False) - self.A_inv_max = Parameter(initializer(0, [1], 
mstype.float32), name="A_inv_max", requires_grad=False) - self.G_inv_max = Parameter(initializer(0, [1], mstype.float32), name="G_inv_max", requires_grad=False) - self.fused_abs_max = P.CusFusedAbsMax1() self.fake_G = Tensor(np.zeros([out_channels, out_channels]).astype(np.float16)) self.matmul = P.MatMul(transpose_b=True) @@ -250,9 +234,6 @@ class Dense_Thor(Cell): matrix_G = matrix_G + damping * dampingG matrix_G_inv = self.cholesky(matrix_G) matrix_G_inv = self.vector_matmul(matrix_G_inv, matrix_G_inv) - matrix_G_inv_max = self.fused_abs_max(matrix_G_inv) - matrix_G_inv_max = self.fused_abs_max(matrix_G_inv_max) - self.G_inv_max = matrix_G_inv_max matrix_G_inv = self.matrix_combine(matrix_G_inv) matrix_G_inv = self.cast(matrix_G_inv, mstype.float16) self.matrix_G_inv = matrix_G_inv @@ -265,7 +246,6 @@ class Dense_Thor(Cell): shape = self.shape(x) normalizer = self.cast(shape[0], mstype.float32) matrix_A = self.mul(inputs, 1.0 / normalizer) - damping_step = self.gather(self.damping, self.cov_step, self.axis) damping_step = self.cast(damping_step, mstype.float32) damping = self.sqrt(damping_step) @@ -273,9 +253,6 @@ class Dense_Thor(Cell): matrix_A = matrix_A + damping * dampingA matrix_A_inv = self.cholesky(matrix_A) matrix_A_inv = self.vector_matmul(matrix_A_inv, matrix_A_inv) - matrix_A_inv_max = self.fused_abs_max(matrix_A_inv) - matrix_A_inv_max = self.fused_abs_max(matrix_A_inv_max) - self.A_inv_max = matrix_A_inv_max matrix_A_inv = self.matrix_combine(matrix_A_inv) matrix_A_inv = self.cast(matrix_A_inv, mstype.float16) self.matrix_A_inv = matrix_A_inv
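
Note on the math this patch leaves behind: after dropping the `A_inv_max`/`G_inv_max` parameters and the `CusFusedAbsMax1` rescaling, the THOR update in `thor_for_bert_arg.py` reduces to plain fp16 matrix products against the cached inverse covariance factors, i.e. `g <- G_inv @ g @ A_inv` for dense weights and a row-wise diagonal scaling followed by `@ G_inv` for the embedding table, with the factors themselves refreshed every `cfg.Thor.frequency` (now 100) steps inside `Dense_Thor`/`Embedding_Thor`. The following is a minimal NumPy sketch of that preconditioning step, written outside MindSpore for clarity; the helper names (`factor_inverse`, `thor_precondition_dense`, `thor_precondition_embedding`), the plain `np.linalg.inv` in place of the Cholesky-based `CusCholeskyTrsm` path, and the toy shapes are illustrative assumptions, not the optimizer's actual API.

```python
# Illustrative sketch only: mirrors the shape of the THOR preconditioning in this patch,
# not the MindSpore implementation itself.
import numpy as np

def factor_inverse(cov, damping):
    """(cov + sqrt(damping) * I)^-1 -- stand-in for the Cholesky-based inverse in thor_layer.py."""
    n = cov.shape[0]
    return np.linalg.inv(cov + np.sqrt(damping) * np.eye(n, dtype=cov.dtype))

def thor_precondition_dense(grad, a_inv, g_inv):
    """Dense weight update direction: G_inv @ grad @ A_inv, computed in fp16 as in the patch."""
    g16 = grad.astype(np.float16)
    out = g_inv.astype(np.float16) @ g16
    out = out @ a_inv.astype(np.float16)
    return out.astype(np.float32)

def thor_precondition_embedding(grad, a_inv_diag, g_inv):
    """Embedding update: row-wise scale by the diagonal A_inv, then multiply by G_inv."""
    g16 = a_inv_diag[:, None].astype(np.float16) * grad.astype(np.float16)
    return (g16 @ g_inv.astype(np.float16)).astype(np.float32)

if __name__ == "__main__":
    rng = np.random.default_rng(0)
    in_dim, out_dim, damping = 16, 8, 0.03
    acts = rng.standard_normal((32, in_dim)).astype(np.float32)    # layer inputs a
    outs = rng.standard_normal((32, out_dim)).astype(np.float32)   # output-side gradients d
    a_inv = factor_inverse(acts.T @ acts / len(acts), damping)     # input covariance inverse
    g_inv = factor_inverse(outs.T @ outs / len(outs), damping)     # output-grad covariance inverse
    grad = rng.standard_normal((out_dim, in_dim)).astype(np.float32)
    print(thor_precondition_dense(grad, a_inv, g_inv).shape)       # (8, 16)
```

Between refreshes the optimizer reuses the same cached `matrix_A`/`matrix_G` inverses, so raising `frequency` from 10 to 100 in `config.py` trades slightly staler curvature information for far fewer Cholesky inversions per epoch.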