diff --git a/model_zoo/official/nlp/gpt/README.md b/model_zoo/official/nlp/gpt/README.md new file mode 100644 index 0000000000..f3677ce77a --- /dev/null +++ b/model_zoo/official/nlp/gpt/README.md @@ -0,0 +1,76 @@ +# It is still under development. +# Contents +- [Contents](#contents) +- [GPT Description](#bert-description) +- [Model Architecture](#model-architecture) +- [Dataset](#dataset) +- [Environment Requirements](#environment-requirements) +- [Quick Start](#quick-start) +- [Script Description](#script-description) + - [Script and Sample Code](#script-and-sample-code) +- [ModelZoo Homepage](#modelzoo-homepage) + +# [GPT Description](#contents) +The GPT network was proposed by OpenAI and it has three versions, i.e., GPT, GPT2 and GPT3. The newest version GPT3 was proposed in Jul 2020 and it is quite a large language model with 175 billion parameters. Stacking many Decoder structure of Transformer and feeding massive amount of training data, GPT3 becomes such a powerful language model that no fine-tuning process is needed. As the papre title says, language models are few-shot learners, GPT3 proves that with a large and well-trained model, we can achieve a similar performance compared to those of fine-tuning methods. + + +[Paper](https://arxiv.org/abs/2005.14165): Tom B.Brown, Benjamin Mann, Nick Ryder et al. [Language Models are Few-Shot Learners]((https://arxiv.org/abs/2005.14165)). arXiv preprint arXiv:2005.14165 + + +# [Model Architecture](#contents) +GPT3 stacks many layers of decoder of transformer. According to the layer numbers and embedding size, GPT3 has several versions. The largest model contains 96 layers with embedding size of 12288 resulting to a total parameter of 175 billion. + +# [Dataset](#contents) +- OpenWebText is utilized as the training data and the training objective is to predict the next token at each position. + +# [Environment Requirements](#contents) +- Hardware(Ascend) + - Prepare hardware environment with Ascend processor. If you want to try Ascend, please send the [application form](https://obs-9be7.obs.cn-east-2.myhuaweicloud.com/file/other/Ascend%20Model%20Zoo%E4%BD%93%E9%AA%8C%E8%B5%84%E6%BA%90%E7%94%B3%E8%AF%B7%E8%A1%A8.docx) to ascend@huawei.com. Once approved, you can get access to the resources. +- Framework + - [MindSpore](https://gitee.com/mindspore/mindspore) +- For more information, please check the resources below: + - [MindSpore Tutorials](https://www.mindspore.cn/tutorial/training/en/master/index.html) + - [MindSpore Python API](https://www.mindspore.cn/doc/api_python/en/master/index.html) + +# [Quick Start](#contents) +After installing MindSpore via the official website, you can start training and evaluation as follows: +```bash +# run standalone training example +bash scripts/run_standalone_train.sh 0 10 /path/dataset + +# run distributed training example +bash scripts/run_distribute_training.sh /path/dataset /path/hccl.json 8 + +# run evaluation example, now only accuracy and perplexity for lambada and wikitext103 are supported +bash scripts/run_evaluation.sh lambada /your/ckpt /your/data acc +``` + +For distributed training, an hccl configuration file with JSON format needs to be created in advance. +Please follow the instructions in the link below: +https:gitee.com/mindspore/mindspore/tree/master/model_zoo/utils/hccl_tools. + +# [Script Description](#contents) + +## [Script and Sample Code](#contents) + +```shell +. +└─gpt + ├─README.md + ├─scripts + ├─run_standalone_train.sh # shell script for standalone training on ascend + ├─run_distribut_train.sh # shell script for distributed training on ascend + └─run_evaluation.sh # shell script for evaluation of ascend + ├─src + ├─gpt_wrapper.py # backbone code of network + ├─gpt.py # backbone code of network + ├─dataset.py # data preprocessing + ├─inference.py # evaluation function + ├─utils.py # util function + ├─train.py # train net for training phase + └─eval.py # eval net for evaluation +``` + +# [ModelZoo Homepage](#contents) + +Please check the official [homepage](https://gitee.com/mindspore/mindspore/tree/master/model_zoo). diff --git a/model_zoo/official/nlp/gpt/eval.py b/model_zoo/official/nlp/gpt/eval.py new file mode 100644 index 0000000000..a6426cf4c4 --- /dev/null +++ b/model_zoo/official/nlp/gpt/eval.py @@ -0,0 +1,155 @@ +# Copyright 2020 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ + +""" +GPT evaluation script. +""" + +import math +import argparse +import numpy as np +from mindspore import context +import mindspore.common.dtype as mstype +from mindspore.common.tensor import Tensor +from mindspore.train.serialization import load_checkpoint, load_param_into_net +from src.inference import generate +from src.dataset import create_dataset +from src.gpt import GPT, EvalNet, GPTWithLoss, CrossEntropyLoss +from src.utils import GPTConfig + +context.set_context(mode=context.GRAPH_MODE) + +def ppl_score(probs, length, is_logsoftmax=True): + """ calculate perplexity with prob or log_prob inputs """ + probs = probs[:length] + if is_logsoftmax: + prob = np.sum(probs) / length + ppl = 1.0 / np.power(np.e, prob) + else: + prob = 1.0 + for p in probs: + prob *= (1.0 / p) + ppl = np.power(prob, 1.0/length) + return ppl + +def get_ppl(model, dataset): + """ calculate perplexity for input dataset """ + PPL = [] + tokens = 0 + for data in dataset: + data = data[0].asnumpy() + input_ids = data + + logits = model(Tensor(input_ids, mstype.int32)).asnumpy() + PPL.append(logits * len(data)) + tokens += len(data) + + val_loss = sum(PPL) / tokens + ppl = math.exp(min(20, val_loss)) + return ppl + +def get_acc(model, dataset): + """ calculate accuracy for input dataset """ + total_num = 0 + acc_num = 0 + for data in dataset: + data = data[0].asnumpy() + input_mask = (data != 0).astype(np.int32) + length = np.sum(input_mask, 1) + label = np.zeros(length.shape) + for i, idx in enumerate(length): + label[i] = data[i][idx-1] + input_mask[i][idx-1] = 0 + data[i][idx-1] = 0 + + length = np.sum(data != 50256, 1) + input_ids = data + logits = model(Tensor(input_ids, mstype.int32)).asnumpy() + logits = logits.reshape(len(length), -1) + + predicted_label = np.zeros(length.shape) + for i, idx in enumerate(length): + predicted_label[i] = logits[i][idx-2] + + total_num += len(label) + acc_num += sum(label == predicted_label) + + acc = acc_num / total_num + return acc + + +def run_eval(): + """ evaluate scripts """ + parser = argparse.ArgumentParser(description="GPT inferencing") + parser.add_argument('--task_type', type=str, default="", help="Evaluation task.") + parser.add_argument('--metrics', type=str, default="acc", choices=["ppl", "acc"], help="Evaluation metrics.") + parser.add_argument('--ckpt_path', type=str, default="", help="path of checkpoint file.") + parser.add_argument('--data_path', type=str, default="", help="path of MindRecord file.") + + args = parser.parse_args() + task = args.task_type + metrics = args.metrics + ckpt_path = args.ckpt_path + if task not in ["generate", "lambada", "wikitext"]: + raise ValueError("{} is not supported now".format(task)) + + if metrics not in ["acc", "ppl"]: + raise ValueError("{} is not supported now".format(metrics)) + + + config = GPTConfig(batch_size=16, + seq_length=1024, + vocab_size=50257, + embedding_size=1024, + num_layers=24, + num_heads=16, + expand_ratio=4, + post_layernorm_residual=False, + dropout_rate=0.0, + compute_dtype=mstype.float16, + use_past=False) + + ckpt_dict = load_checkpoint(ckpt_path) + + gpt = GPT(config) + if task == "generate": + gpt_eval = EvalNet(gpt, generate=True) + elif metrics == "acc": + gpt_eval = EvalNet(gpt, generate=False) + else: + loss = CrossEntropyLoss(config) + gpt_eval = GPTWithLoss(gpt, loss) + + gpt_eval.set_train(False) + load_param_into_net(gpt_eval, ckpt_dict) + + if task == "generate": + start_sentence = [6170, 318, 257] + input_ids = np.array(start_sentence).reshape(1, -1) + outputs = generate(gpt_eval, input_ids, config.seq_length) + output_list = outputs.tolist() + print("output id is ", output_list) + else: + data_path = args.data_path + eval_dataset = create_dataset(config.batch_size, data_path=data_path, drop=False) + if metrics == "acc": + acc = get_acc(gpt_eval, eval_dataset) + print("Accuracy is ", acc) + elif metrics == "ppl": + ppl = get_ppl(gpt_eval, eval_dataset) + print("Perplexity is ", ppl) + +if __name__ == "__main__": + run_eval() diff --git a/model_zoo/official/nlp/gpt/scripts/run_distribute_train.sh b/model_zoo/official/nlp/gpt/scripts/run_distribute_train.sh new file mode 100644 index 0000000000..a97b8aeaaa --- /dev/null +++ b/model_zoo/official/nlp/gpt/scripts/run_distribute_train.sh @@ -0,0 +1,38 @@ +#!/bin/bash +# Copyright 2020 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ + +echo "==============================================================================================================" +echo "Please run the scipt as: " +echo "bash run_distributed_pretrain_ascend.sh DATA_DIR RANK_TABLE_FILE DEVICE_NUM" +echo "for example: bash run_distributed_pretrain_ascend.sh /path/dataset /path/hccl.json 8" +echo "It is better to use absolute path." +echo "==============================================================================================================" + +ROOT_PATH='pwd' +DATA_DIR=$1 +export RANK_TABLE_FILE=$2 +RANK_SIZE=$3 + + +for((i=0;i<=${RANK_SIZE};i++)); +do + rm ${ROOT_PATH}/device$i/ -rf + mkdir ${ROOT_PATH}/device$i + cd ${ROOT_PATH}/device$i || exit + export RANK_ID=$i + export DEVICE_ID=$i + python ${ROOT_PATH}/train.py --distribute=true --device_num=$RANK_SIZE --data_path=$DATA_DIR >log$i.log 2>&1 & +done diff --git a/model_zoo/official/nlp/gpt/scripts/run_evaluation.sh b/model_zoo/official/nlp/gpt/scripts/run_evaluation.sh new file mode 100644 index 0000000000..d750484ae3 --- /dev/null +++ b/model_zoo/official/nlp/gpt/scripts/run_evaluation.sh @@ -0,0 +1,33 @@ +#!/bin/bash +# Copyright 2020 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ + +echo "==============================================================================================================" +echo "Please run the scipt as: " +echo "bash scripts/run_evaluation.sh TASK_TYPE CKPT_PATH DATA_PATH METRICS" +echo "for example: bash scripts/run_evaluation.sh lambada /your/ckpt /your/data acc" +echo "==============================================================================================================" + + +TASK_TYPE=$1 +CKPT_PATH=$2 +DATA_PATH=$3 +METRICS=$4 +python eval.py \ + --task_type=$TASK_TYPE \ + --ckpt_path=$CKPT_PATH \ + --data_path=$DATA_PATH \ + --metrics=$METRICS + diff --git a/model_zoo/official/nlp/gpt/scripts/run_standalone_train.sh b/model_zoo/official/nlp/gpt/scripts/run_standalone_train.sh new file mode 100644 index 0000000000..03dbb91b9c --- /dev/null +++ b/model_zoo/official/nlp/gpt/scripts/run_standalone_train.sh @@ -0,0 +1,33 @@ +#!/bin/bash +# Copyright 2020 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ + +echo "==============================================================================================================" +echo "Please run the scipt as: " +echo "bash run_standalone_pretrain_ascend.sh DEVICE_ID EPOCH_SIZE DATA_DIR" +echo "for example: bash run_standalone_pretrain_ascend.sh 0 40 /path/zh-wiki/" +echo "==============================================================================================================" + +DEVICE_ID=$1 +EPOCH_SIZE=$2 +DATA_DIR=$3 + + +python train.py \ + --distribute="false" \ + --epoch_size=$EPOCH_SIZE \ + --device_id=$DEVICE_ID \ + --data_path=$DATA_DIR \ + --optimizer="adam" > training_log.txt 2>&1 & diff --git a/model_zoo/official/nlp/gpt/src/dataset.py b/model_zoo/official/nlp/gpt/src/dataset.py new file mode 100644 index 0000000000..7088443895 --- /dev/null +++ b/model_zoo/official/nlp/gpt/src/dataset.py @@ -0,0 +1,48 @@ +# Copyright 2020 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ + +""" +Create dataset for training and evaluting +""" + +import os +import mindspore.dataset as ds +import mindspore.dataset.transforms.c_transforms as C +import mindspore.common.dtype as mstype + + +def create_dataset(batch_size, data_path, device_num=1, rank=0, drop=True): + """ + Create dataset + + Inputs: + batch_size: batch size + data_path: path of your MindRecord files + device_num: total device number + rank: current rank id + drop: whether drop remainder + + Returns: + dataset: the dataset for training or evaluating + """ + home_path = os.path.join(os.getcwd(), data_path) + data = [os.path.join(home_path, name) for name in os.listdir(data_path) if name.endswith("mindrecord")] + print(data) + dataset = ds.MindDataset(data, columns_list=["input_ids"], shuffle=True, num_shards=device_num, shard_id=rank) + type_cast_op = C.TypeCast(mstype.int32) + dataset = dataset.map(input_columns="input_ids", operations=type_cast_op) + dataset = dataset.batch(batch_size, drop_remainder=drop) + dataset = dataset.repeat(1) + return dataset diff --git a/model_zoo/official/nlp/gpt/src/gpt.py b/model_zoo/official/nlp/gpt/src/gpt.py new file mode 100644 index 0000000000..31537e4e15 --- /dev/null +++ b/model_zoo/official/nlp/gpt/src/gpt.py @@ -0,0 +1,545 @@ +# Copyright 2020 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ + +"""GPT model""" + +import math +import numpy as np +import mindspore.nn as nn +from mindspore.common.tensor import Tensor +from mindspore.common.parameter import Parameter +import mindspore.common.dtype as mstype +from mindspore.common.initializer import TruncatedNormal, initializer +from mindspore.ops import operations as P +from mindspore.ops import functional as F + +class Mapping(nn.Cell): + """ + A mapping function with a 3d input + + Args: + input_size: the size of the last dimension of the input tensor + output_size: the desired size of the last dimension of the output tensor + dtype: the compute datatype + scale: the scale factor for initialization + + Inputs: + x: the 3d input + + Returns: + output: Tensor, a 3d tensor after projection + """ + def __init__(self, input_size, output_size, dtype, scale=1.0): + super(Mapping, self).__init__() + self.output_size = output_size + self.input_size = input_size + weight = np.random.normal(loc=0.0, scale=0.02*scale, size=(input_size, output_size)) + bias = np.zeros(shape=(output_size,)) + self.weight = Parameter(Tensor(weight, mstype.float32), name="mapping_weight") + self.bias = Parameter(Tensor(bias, mstype.float32), name="mapping_bias") + self.dtype = dtype + self.cast = P.Cast() + + def construct(self, x): + out_shape = P.Shape()(x)[:-1] + (self.output_size,) + x = P.Reshape()(x, (-1, self.input_size)) + x = nn.MatMul()(x, self.cast(self.weight, self.dtype)) + self.cast(self.bias, self.dtype) + output = P.Reshape()(x, out_shape) + return output + + + +class Output(nn.Cell): + """ + The output mapping module for each layer + + Args: + config(GPTConfig): the config of network + scale: scale factor for initialization + + Inputs: + x: output of the self-attention module + + Returns: + output: Tensor, the output of this layer after mapping + """ + def __init__(self, config, scale=1.0): + super(Output, self).__init__() + input_size = config.embedding_size + output_size = config.embedding_size*config.expand_ratio + self.mapping = Mapping(input_size, output_size, config.compute_dtype) + self.projection = Mapping(output_size, input_size, config.compute_dtype, scale) + self.activation = nn.GELU() + self.dropout = nn.Dropout(1-config.dropout_rate) + + def construct(self, x): + hidden = self.activation(self.mapping(x)) + output = self.projection(hidden) + output = self.dropout(output) + return output + +class AttentionMask(nn.Cell): + """ + Get the attention matrix for self-attention module + + Args: + config(GPTConfig): the config of network + + Inputs: + input_mask: the mask indicating whether each position is a valid input + + Returns: + attention_mask: the attention mask matrix with shape (batch_size, 1, seq_length, seq_length) + """ + def __init__(self, config): + super(AttentionMask, self).__init__() + self.reshape = P.Reshape() + self.mul = P.BatchMatMul() + ones = np.ones(shape=(config.seq_length, config.seq_length)) + self.lower_triangle_mask = Tensor(np.tril(ones), mstype.float32) + self.multiply = P.Mul() + + + def construct(self, input_mask): + input_shape = P.Shape()(input_mask) + shape_right = (input_shape[0], 1, input_shape[1]) + shape_left = input_shape + (1,) + mask_left = self.reshape(input_mask, shape_left) + mask_right = self.reshape(input_mask, shape_right) + attention_mask = self.mul(mask_left, mask_right) + lower_traiangle = P.ExpandDims()(self.lower_triangle_mask, 0) + attention_mask = self.multiply(attention_mask, lower_traiangle) #bs seq_length seq_length + return attention_mask + +class EmbeddingLookup(nn.Cell): + """ + The embedding lookup table for vocabulary + + Args: + config(GPTConfig): the config of network + + Inputs: + input_ids: the tokenized inputs with datatype int32 + + Returns: + output: Tensor, the embedding vector for the input with shape (batch_size, seq_length, embedding_size) + self.embedding_table: Tensor, the embedding table for the vocabulary + """ + def __init__(self, config): + super(EmbeddingLookup, self).__init__() + self.vocab_size = config.vocab_size + self.embedding_size = config.embedding_size + self.embedding_table = Parameter(initializer(TruncatedNormal(0.02), [self.vocab_size, self.embedding_size]), + name="embedding_table") + self.gather = P.GatherV2() + self.shape = (-1, config.seq_length, config.embedding_size) + def construct(self, input_ids): + output = self.gather(self.embedding_table, input_ids, 0) + return output, self.embedding_table + + +class Attention(nn.Cell): + """ + Self-Attention module for each layer + + Args: + config(GPTConfig): the config of network + scale: scale factor for initialization + layer_idx: current layer index + """ + def __init__(self, config, scale=1.0, layer_idx=None): + super(Attention, self).__init__() + self.get_attention_mask = AttentionMask(config) + self.expand_mapping = Mapping(config.embedding_size, 3*config.embedding_size, config.compute_dtype) + self.projection = Mapping(config.embedding_size, config.embedding_size, config.compute_dtype, scale) + self.split = P.Split(axis=-1, output_num=3) + self.transpose = P.Transpose() + self.reshape = P.Reshape() + self.n_head = config.num_heads + self.size_per_head = config.embedding_size // self.n_head + self.concat_k = P.Concat(axis=3) + self.concat_v = P.Concat(axis=2) + self.multiply_data = Tensor([-10000.0,], dtype=mstype.float32) + self.batch_matmul = P.BatchMatMul() + self.scale = scale + if self.scale: + self.scale_factor = Tensor(math.sqrt(self.size_per_head)) + if layer_idx is not None: + self.coeff = math.sqrt(layer_idx * math.sqrt(self.size_per_head)) + self.coeff = Tensor(self.coeff) + self.use_past = config.use_past + self.dropout = nn.Dropout(1-config.dropout_rate) + self.prob_dropout = nn.Dropout(1-config.dropout_rate) + self.softmax = nn.Softmax() + + self.dense1 = nn.Dense(config.embedding_size, config.embedding_size).to_float(config.compute_dtype) + self.dense2 = nn.Dense(config.embedding_size, config.embedding_size).to_float(config.compute_dtype) + self.dense3 = nn.Dense(config.embedding_size, config.embedding_size).to_float(config.compute_dtype) + + def construct(self, x, attention_mask, layer_past=None): + """ + self-attention + + Inputs: + x: output of previous layer + attention_mask: the attention mask matrix with shape (batch_size, 1, seq_length, seq_length) + layer_past: the previous feature map + + Returns: + output: Tensor, the output logit of this layer + layer_present: Tensor, the feature map of current layer + """ + + original_shape = F.shape(x) + x = F.reshape(x, (-1, original_shape[-1])) + query = self.dense1(x) + key = self.dense2(x) + value = self.dense3(x) + query = self.transpose(F.reshape(query, (-1, original_shape[1], self.n_head, self.size_per_head)), (0, 2, 1, 3)) + key = self.transpose(F.reshape(key, (-1, original_shape[1], self.n_head, self.size_per_head)), (0, 2, 3, 1)) + value = self.transpose(F.reshape(value, (-1, original_shape[1], self.n_head, self.size_per_head)), (0, 2, 1, 3)) + if self.use_past: + past_value = layer_past[1] + past_key = self.transpose(layer_past[0], (0, 1, 3, 2)) + key = self.concat_k((past_key, key)) + value = self.concat_v(past_value, value) + layer_present = P.Pack()([self.transpose(key, (0, 1, 3, 2)), value]) + attention = self._attn(query, key, value, attention_mask) + attention_merge = self.merge_heads(attention) + output = self.projection(attention_merge) + output = self.dropout(output) + return output, layer_present + + def split_heads(self, x, transpose): + """ + split 3d tensor to 4d and switch certain axes + + Inputs: + x: input tensor + transpose: tuple, the transpose sequence + + Returns: + x_transpose: the 4d output + """ + x_size = P.Shape()(x) + new_x_shape = x_size[:-1] + (self.n_head, self.size_per_head) + x = self.reshape(x, new_x_shape) + x_transpose = self.transpose(x, transpose) + return x_transpose + + def merge_heads(self, x): + """ + convert a 4d input to a 3d output + + Inputs: + x: input tensor + + Returns: + x_merge: the 3d output + """ + x = self.transpose(x, (0, 2, 1, 3)) #bs, seq_length, head, size_per_head + x_shape = P.Shape()(x) + new_shape = x_shape[:-2] + (x_shape[-2]*x_shape[-1],) + x_merge = self.reshape(x, new_shape) + return x_merge + + def _attn(self, query, key, value, attention_mask): + """ + Get the weighted score along the seq_length + + Inputs: + query: the query matrix + key: the key matrix + value: the value matrix + attention_mask: the attention mask matrix with shape (batch_size, 1, seq_length, seq_length) + + Returns: + weighted_values: Tensor, the weighted sum scores + """ + if not self.scale: + query = query / F.cast(self.coeff, F.dtype(query)) + key = key / F.cast(self.coeff, F.dtype(key)) + + score = self.batch_matmul(query, key) + if self.scale: + score = score / P.Cast()(self.scale_factor, P.DType()(score)) + + ori_dtype = P.DType()(score) + score = P.Cast()(score, mstype.float32) + multiplu_out = P.Sub()(P.Cast()(F.tuple_to_array((1.0,)), P.DType()(score)), + P.Cast()(attention_mask, P.DType()(score))) + + adder = P.Mul()(multiplu_out, self.multiply_data) + attention_scores = adder + score + + attention_scores = P.Cast()(attention_scores, ori_dtype) + shape = F.shape(attention_scores) + attention_probs = nn.Softmax()(F.reshape(attention_scores, (-1, shape[-1]))) + attention_probs = F.reshape(attention_probs, shape) + + attention_probs = self.prob_dropout(attention_probs) + weighted_values = self.batch_matmul(attention_probs, value) + return weighted_values + +class Block(nn.Cell): + """ + The basic block of GPT network + + Args: + config(GPTConfig): the config of network + layer_idx: current layer index + + Inputs: + x: the output of previous layer(input_ids for the first layer) + attention_mask: the attention mask matrix with shape (batch_size, 1, seq_length, seq_length) + layer_past: the previous feature map + + Returns: + output: Tensor, the output logit of this layer + layer_present: Tensor, the feature map of current layer + """ + def __init__(self, config, layer_idx): + super(Block, self).__init__() + scale = 1 / math.sqrt(2.0*layer_idx) + self.layernorm1 = nn.LayerNorm((config.embedding_size,)).to_float(config.compute_dtype) + self.attention = Attention(config, scale, layer_idx) + self.layernorm2 = nn.LayerNorm((config.embedding_size,)).to_float(config.compute_dtype) + self.output = Output(config, scale) + self.post_layernorm_residual = config.post_layernorm_residual + + def construct(self, x, attention_mask, layer_past=None): + """basic block of each layer""" + input_x = self.layernorm1(x) + attention, layer_present = self.attention(input_x, attention_mask, layer_past) + if self.post_layernorm_residual: + x = input_x + attention + else: + x = x + attention + + output_x = self.layernorm2(x) + mlp_logit = self.output(output_x) + if self.post_layernorm_residual: + output = output_x + mlp_logit + else: + output = x + mlp_logit + return output, layer_present + +class GPT_Model(nn.Cell): + """ + The backbone of GPT network + + Args: + config(GPTConfig): the config of network + + Inputs: + input_ids: the tokenized inputs with datatype int32 + input_mask: the mask indicating whether each position is a valid input + layer_past: the previous feature map + + Returns: + output_state: Tensor, the output logit of backbone + present_layer: Tensor, the current feature map + embedding_table: Tensor, the embedding table for the vocabulary + """ + def __init__(self, config): + super(GPT_Model, self).__init__() + self.get_attention_mask = AttentionMask(config) + self.word_embedding = EmbeddingLookup(config) + self.position_embedding = nn.Embedding(config.seq_length, config.embedding_size, + embedding_table=TruncatedNormal(0.02)) + self.blocks = nn.CellList() + for i in range(config.num_layers): + self.blocks.append(Block(config, i+1)) + self.layernorm = nn.LayerNorm((config.embedding_size,)).to_float(config.compute_dtype) + self.use_past = config.use_past + self.past = tuple([None]*config.num_layers) + self.num_layers = config.num_layers + + def construct(self, input_ids, input_mask, layer_past=None): + """GPT model""" + if not self.use_past: + layer_past = self.past + + input_embedding, embedding_table = self.word_embedding(input_ids) + + batch_size, seq_length = F.shape(input_ids) + input_position = F.tuple_to_array(F.make_range(seq_length)) + input_position = P.Tile()(input_position, (batch_size, 1)) + + + position_embedding = self.position_embedding(input_position) + hidden_states = input_embedding + position_embedding + + hidden_states = P.Cast()(hidden_states, mstype.float16) + attention_mask = self.get_attention_mask(input_mask) + attention_mask = P.ExpandDims()(attention_mask, 1) + + present_layer = () + for i in range(self.num_layers): + hidden_states, present = self.blocks[i](hidden_states, attention_mask, layer_past) + present_layer = present_layer + (present,) + + output_state = self.layernorm(hidden_states) + return output_state, present_layer, embedding_table + +class GPT_Head(nn.Cell): + """ + Head for GPT to get the logits of each token in the vocab + + Args: + config(GPTConfig): the config of network + + Inputs: + state: the output of the backbone + embedding_table: the embedding table of the vocabulary + + Returns: + logits: Tensor, the logits of the corresponding inputs + """ + def __init__(self, config): + super(GPT_Head, self).__init__() + self.matmul = P.MatMul(transpose_b=True) + self.embedding_size = config.embedding_size + self.log_softmax = P.LogSoftmax(axis=-1) + self.dtype = config.compute_dtype + self.cast = P.Cast() + + def construct(self, state, embedding_table): + state = P.Reshape()(state, (-1, self.embedding_size)) + logits = self.matmul(state, self.cast(embedding_table, self.dtype)) + return logits + +class GPT(nn.Cell): + """ + The GPT network consisting of two parts the backbone and the head + + Args: + config(GPTConfig): the config of network + + Inputs: + input_ids: the tokenized inputs + input_mask: the mask indicating whether each position is a valid input + past: the previous feature map + + Returns: + logits: Tensor: the logits of the corresponding inputs with shape (batch_size, seq_length, vocab_size) + """ + def __init__(self, config): + super(GPT, self).__init__() + self.backbone = GPT_Model(config) + self.head = GPT_Head(config) + + def construct(self, input_ids, input_mask, past=None): + output_states, _, embedding_table = self.backbone(input_ids, input_mask, past) + logits = self.head(output_states, embedding_table) + return logits + +class CrossEntropyLoss(nn.Cell): + """ + Calculate the cross entropy loss + + Args: + config(GPTConfig): the config of the network + + Inputs: + logits: the output logits of the backbone + label: the ground truth label of the sample + input_mask: the mask indicating whether each position is a valid input + + Returns: + loss: Tensor, the corrsponding cross entropy loss + """ + def __init__(self, config): + super(CrossEntropyLoss, self).__init__() + self.log_softmax = nn.LogSoftmax(axis=-1) + self.mean = P.ReduceMean() + self.sum = P.ReduceSum() + self.onehot = P.OneHot() + self.on_value = Tensor(1.0, mstype.float32) + self.off_value = Tensor(0.0, mstype.float32) + self.vocab_size = config.vocab_size + + def construct(self, logits, label, input_mask): + logits = self.log_softmax(P.Cast()(logits, mstype.float32)) + label = P.Reshape()(label, (-1,)) + one_hot_label = self.onehot(label, self.vocab_size, self.on_value, self.off_value) + loss_sum = P.Neg()(self.sum(logits*one_hot_label, (-1,))) + input_mask = P.Reshape()(input_mask, (-1,)) + numerator = self.sum(loss_sum*input_mask) + denominator = self.sum(input_mask) + P.Cast()(F.tuple_to_array((1e-5,)), mstype.float32) + loss = numerator / denominator + return loss + +class GPTWithLoss(nn.Cell): + """ + GPT training loss + + Args: + network: backbone network of GPT2/3 + loss: loss function, e.g., crossentropy + eos_token: the end_of_sentence token + + Inputs: + input_ids: the tokenized inputs + past: the previous feature map + + Returns: + output: Tensor, the loss of the network + """ + def __init__(self, network, loss, eos_token=50256): + super(GPTWithLoss, self).__init__(auto_prefix=False) + self.network = network + self.loss = loss + self.eos_token = eos_token + + def construct(self, input_ids, past=None): + tokens = input_ids[:, :-1] + input_mask = F.cast(F.not_equal(tokens, self.eos_token), mstype.float32) + logits = self.network(tokens, input_mask, past) + labels = input_ids[:, 1:] + output = self.loss(logits, labels, input_mask) + return output + +class EvalNet(nn.Cell): + """ + GPT evaluation net + + Args: + backbone: backbone network of GPT2/3 + generate: enable generate mode + + Inputs: + input_ids: the tokenized inpus + + Returns: + outputs: Tensor, corresponding output for different tasks + """ + def __init__(self, backbone, generate=False): + super(EvalNet, self).__init__(auto_prefix=False) + self.backbone = backbone + self.argmax = P.Argmax() + self.generate = generate + + def construct(self, input_ids): + """evaluation net""" + input_mask = F.cast(F.not_equal(input_ids, 0), mstype.float32) + logits = self.backbone(input_ids, input_mask) + outputs = None + if self.generate: + outputs = nn.LogSoftmax()(logits) + outputs = F.tensor_pow(np.e, outputs) + else: + outputs = self.argmax(logits) + return outputs diff --git a/model_zoo/official/nlp/gpt/src/gpt_wrapcell.py b/model_zoo/official/nlp/gpt/src/gpt_wrapcell.py new file mode 100644 index 0000000000..0d024fabbd --- /dev/null +++ b/model_zoo/official/nlp/gpt/src/gpt_wrapcell.py @@ -0,0 +1,157 @@ +# Copyright 2020 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +"""GPT training wrapper""" + + +import mindspore.nn as nn +from mindspore.ops import operations as P +from mindspore.ops import composite as C +from mindspore.ops import functional as F +from mindspore import context +from mindspore.context import ParallelMode +from mindspore.nn.wrap.grad_reducer import DistributedGradReducer +from mindspore.communication.management import get_group_size +from mindspore.common.tensor import Tensor +import mindspore.common.dtype as mstype +from mindspore.common.parameter import Parameter +from utils import ClipByGlobalNorm + +GRADIENT_CLIP_TYPE = 1 +GRADIENT_CLIP_VALUE = 1.0 +clip_grad = C.MultitypeFuncGraph("clip_grad") + +@clip_grad.register("Number", "Number", "Tensor") +def _clip_grad(clip_type, clip_value, grad): + """ + Clip gradients. + + Inputs: + clip_type (int): The way to clip, 0 for 'value', 1 for 'norm'. + clip_value (float): Specifies how much to clip. + grad (tuple[Tensor]): Gradients. + + Outputs: + tuple[Tensor], clipped gradients. + """ + if clip_type not in [0, 1]: + return grad + dt = F.dtype(grad) + if clip_type == 0: + new_grad = C.clip_by_value(grad, F.cast(F.tuple_to_array((-clip_value,)), dt), + F.cast(F.tuple_to_array((clip_value,)), dt)) + else: + new_grad = nn.ClipByNorm()(grad, F.cast(F.tuple_to_array((clip_value,)), dt)) + return new_grad + +grad_scale = C.MultitypeFuncGraph("grad_scale") +reciprocal = P.Reciprocal() + +@grad_scale.register("Tensor", "Tensor") +def tensor_grad_scale(scale, grad): + return grad * reciprocal(scale) + +class GPTTrainOneStepWithLossScaleCell(nn.Cell): + """ + Encapsulation class of GPT network training. + + Append an optimizer to the training network after that the construct + function can be called to create the backward graph. + + Args: + network (Cell): The training network. Note that loss function should have been added. + optimizer (Optimizer): Optimizer for updating the weights. + scale_update_cell (Cell): Cell to do the loss scale. Default: None. + """ + def __init__(self, network, optimizer, scale_update_cell=None, enable_global_norm=False): + super(GPTTrainOneStepWithLossScaleCell, self).__init__(auto_prefix=False) + self.network = network + self.weights = optimizer.parameters + self.optimizer = optimizer + self.enable_global_norm = enable_global_norm + self.grad = C.GradOperation(get_by_list=True, + sens_param=True) + self.reducer_flag = False + self.allreduce = P.AllReduce() + self.parallel_mode = context.get_auto_parallel_context("parallel_mode") + if self.parallel_mode in [ParallelMode.DATA_PARALLEL, ParallelMode.HYBRID_PARALLEL]: + self.reducer_flag = True + self.grad_reducer = F.identity + self.degree = 1 + if self.reducer_flag: + self.degree = get_group_size() + self.grad_reducer = DistributedGradReducer(optimizer.parameters, False, self.degree) + self.is_distributed = (self.parallel_mode != ParallelMode.STAND_ALONE) + self.cast = P.Cast() + self.alloc_status = P.NPUAllocFloatStatus() + self.get_status = P.NPUGetFloatStatus() + self.clear_before_grad = P.NPUClearFloatStatus() + self.reduce_sum = P.ReduceSum(keep_dims=False) + self.depend_parameter_use = P.ControlDepend(depend_mode=1) + self.base = Tensor(1, mstype.float32) + self.less_equal = P.LessEqual() + self.hyper_map = C.HyperMap() + self.loss_scale = None + self.loss_scaling_manager = scale_update_cell + if scale_update_cell: + self.loss_scale = Parameter(Tensor(scale_update_cell.get_loss_scale(), dtype=mstype.float32), + name="loss_scale") + + @C.add_flags(has_effect=True) + def construct(self, + input_ids, + past=None, + sens=None): + """Defines the computation performed.""" + weights = self.weights + loss = self.network(input_ids, + past) + + if sens is None: + scaling_sens = self.loss_scale + else: + scaling_sens = sens + # alloc status and clear should be right before gradoperation + init = self.alloc_status() + self.clear_before_grad(init) + grads = self.grad(self.network, weights)(input_ids, + past, + self.cast(scaling_sens, + mstype.float32)) + # apply grad reducer on grads + grads = self.grad_reducer(grads) + grads = self.hyper_map(F.partial(grad_scale, scaling_sens * self.degree), grads) + + if self.enable_global_norm: + grads = ClipByGlobalNorm()(grads) + else: + grads = self.hyper_map(F.partial(clip_grad, GRADIENT_CLIP_TYPE, GRADIENT_CLIP_VALUE), grads) + + self.get_status(init) + flag_sum = self.reduce_sum(init, (0,)) + if self.is_distributed: + # sum overflow flag over devices + flag_reduce = self.allreduce(flag_sum) + cond = self.less_equal(self.base, flag_reduce) + else: + cond = self.less_equal(self.base, flag_sum) + overflow = cond + if sens is None: + overflow = self.loss_scaling_manager(self.loss_scale, cond) + if overflow: + succ = False + else: + succ = self.optimizer(grads) + ret = (loss, cond, scaling_sens) + return F.depend(ret, succ) diff --git a/model_zoo/official/nlp/gpt/src/inference.py b/model_zoo/official/nlp/gpt/src/inference.py new file mode 100644 index 0000000000..f08d9fbc57 --- /dev/null +++ b/model_zoo/official/nlp/gpt/src/inference.py @@ -0,0 +1,60 @@ +# Copyright 2020 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ + +""" +TopK for text generation +""" + +import numpy as np +import mindspore.common.dtype as mstype +from mindspore.common.tensor import Tensor + +def generate(model, origin_inputs, seq_length, end_token=50256): + """ + TopK for text generation + + Inputs: + model: the model for inferencing + origin_inputs: the original inputs based on which the model will continue writing + seq_length: seq_length for the model + end_token: end of sentence token id + + Returns: + outputs: the ids for the generated text + """ + TOPK = 5 + seq_length = seq_length + bs, valid_length = origin_inputs.shape + pad_length = seq_length - origin_inputs.shape[-1] + input_ids = np.pad(origin_inputs, ((0, 0), (0, pad_length)), 'constant', constant_values=(0, 0)) + print("input_ids is ", input_ids) + while valid_length < seq_length: + inputs = Tensor(input_ids, mstype.int32) + logits = model(inputs).asnumpy() + logits = logits.reshape(bs, seq_length, -1) + probs = logits[0, valid_length-1, :] + p_args = probs.argsort()[::-1][:TOPK] + + p = probs[p_args] + p = p / sum(p) + target_index = np.random.choice(len(p), p=p) + if p_args[target_index] == end_token or valid_length == seq_length-1: + outputs = input_ids + break + input_ids[0][valid_length] = p_args[target_index] + valid_length += 1 + length = np.sum(outputs != 0) + outputs = outputs[0][:length] + return outputs diff --git a/model_zoo/official/nlp/gpt/src/utils.py b/model_zoo/official/nlp/gpt/src/utils.py new file mode 100644 index 0000000000..ae8a91932a --- /dev/null +++ b/model_zoo/official/nlp/gpt/src/utils.py @@ -0,0 +1,138 @@ +# Copyright 2020 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ + +""" +network config setting, gradient clip function and dynamic learning rate function +""" + +import mindspore.nn as nn +from mindspore.ops import operations as P +from mindspore.ops import composite as C +from mindspore.ops import functional as F +import mindspore.common.dtype as mstype +from mindspore.common.tensor import Tensor +from mindspore.nn.learning_rate_schedule import LearningRateSchedule, PolynomialDecayLR, WarmUpLR, CosineDecayLR +import numpy as np + + +class GPTConfig: + """ + GPT config class which defines the model size + """ + def __init__(self, + batch_size=32, + seq_length=1024, + vocab_size=50257, + embedding_size=768, + num_layers=12, + num_heads=12, + expand_ratio=4, + post_layernorm_residual=False, + dropout_rate=0.1, + compute_dtype=mstype.float16, + use_past=False): + self.batch_size = batch_size + self.seq_length = seq_length + self.vocab_size = vocab_size + self.embedding_size = embedding_size + self.num_layers = num_layers + self.num_heads = num_heads + self.expand_ratio = expand_ratio + self.post_layernorm_residual = post_layernorm_residual + self.dropout_rate = dropout_rate + self.compute_dtype = compute_dtype + self.use_past = use_past + + + + +get_square_sum = C.MultitypeFuncGraph("get_square_sum") +@get_square_sum.register("Tensor") +def _get_square_sum(grad): + norm = P.ReduceSum(False)(F.square(grad), ()) + norm = F.expand_dims(F.cast(norm, mstype.float32), 0) + return norm + + +apply_global_norm = C.MultitypeFuncGraph("apply_global_norm") +@apply_global_norm.register("Tensor", "Tensor", "Tensor") +def _apply_global_norm(clip_norm, global_norm, grad): + grad = grad * clip_norm / global_norm + return grad + +class GlobalNorm(nn.Cell): + """ + Calculate the global norm value of given tensors + """ + def __init__(self): + super(GlobalNorm, self).__init__() + self.norm = nn.Norm() + self.hyper_map = C.HyperMap() + + def construct(self, grads): + square_sum = self.hyper_map(get_square_sum, grads) + global_norms = F.sqrt(F.addn(square_sum) / F.scalar_to_array(len(square_sum))) + return global_norms + +class ClipByGlobalNorm(nn.Cell): + """ + Clip grads by global norm + """ + def __init__(self, clip_norm=1.0): + super(ClipByGlobalNorm, self).__init__() + self.global_norm = GlobalNorm() + self.clip_norm = Tensor([clip_norm], mstype.float32) + self.hyper_map = C.HyperMap() + + def construct(self, grads): + global_norm = self.global_norm(grads) + cond = P.GreaterEqual()(global_norm, self.clip_norm) + global_norm = F.select(cond, global_norm, self.clip_norm) + grads = self.hyper_map(F.partial(apply_global_norm, self.clip_norm, global_norm), grads) + return grads + + +class LearningRate(LearningRateSchedule): + """ + Warmup-decay learning rate for GPT network. + """ + def __init__(self, learning_rate, end_learning_rate, warmup_steps, decay_steps, power=1.0, use_cosine=True): + super(LearningRate, self).__init__() + self.warmup_flag = False + if warmup_steps > 0: + self.warmup_flag = True + self.warmup_lr = WarmUpLR(learning_rate, warmup_steps) + self.decay_lr = PolynomialDecayLR(learning_rate, end_learning_rate, decay_steps, power) + self.cosine_decay_lr = CosineDecayLR(end_learning_rate, learning_rate, decay_steps) + self.warmup_steps = Tensor(np.array([warmup_steps]).astype(np.float32)) + + self.greater = P.Greater() + self.one = Tensor(np.array([1.0]).astype(np.float32)) + self.cast = P.Cast() + self.use_cosine = use_cosine + + def construct(self, global_step): + """dynamic learning rate""" + if not self.use_cosine: + decay_lr = self.decay_lr(global_step) + else: + decay_lr = self.cosine_decay_lr(global_step) + if self.warmup_flag: + is_warmup = self.cast(self.greater(self.warmup_steps, global_step), mstype.float32) + warmup_lr = self.warmup_lr(global_step) + lr = (self.one - is_warmup) * decay_lr + is_warmup * warmup_lr + else: + lr = decay_lr + return lr diff --git a/model_zoo/official/nlp/gpt/train.py b/model_zoo/official/nlp/gpt/train.py new file mode 100644 index 0000000000..8271c2f00a --- /dev/null +++ b/model_zoo/official/nlp/gpt/train.py @@ -0,0 +1,133 @@ +# Copyright 2020 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ + +""" +GPT train script +""" + + +import os +import argparse +from mindspore import context +from mindspore.train.model import Model +import mindspore.communication.management as D +from mindspore.context import ParallelMode +import mindspore.nn as nn +from mindspore.train.callback import TimeMonitor, LossMonitor, ModelCheckpoint, CheckpointConfig +from mindspore.nn.wrap.loss_scale import DynamicLossScaleUpdateCell +import mindspore.common.dtype as mstype +from mindspore.common import set_seed +from src.dataset import create_dataset +from src.gpt import GPT, GPTWithLoss, CrossEntropyLoss +from src.gpt_wrapcell import GPTTrainOneStepWithLossScaleCell +from src.utils import GPTConfig, LearningRate + +def run_train(): + """train function for GPT""" + parser = argparse.ArgumentParser(description="GPT training") + parser.add_argument('--device_id', type=int, default=0, help="Device id, default is 0.") + parser.add_argument("--device_num", type=int, default=1, help="Use device nums, default is 1.") + parser.add_argument("--distribute", type=str, default="false", choices=["true", "false"], + help="Run distribute, default is false.") + parser.add_argument("--optimizer", type=str, default="adam", choices=["adam", "lamb"], + help="select which optimizer to be used, default adam") + parser.add_argument("--epoch_size", type=int, default=10, help="Epoch size, default is 10.") + parser.add_argument("--warmup_step", type=int, default=10000, help="Warmup step, default is 10000.") + parser.add_argument("--data_path", type=str, default="", help="Data path of your MindRecord files.") + parser.add_argument("--start_lr", type=float, default="5e-5", help="Start learning rate, default is 5e-5.") + parser.add_argument("--end_lr", type=float, default="1e-10", help="End learning rate, default is 1e-10.") + parser.add_argument("--sink_size", type=int, default=100, help="Sink size for every iteration, default is 100") + + + args_opt = parser.parse_args() + device_id = int(os.getenv("DEVICE_ID")) + context.set_context(mode=context.GRAPH_MODE, device_target="Ascend", device_id=device_id) + if args_opt.distribute == "true": + D.init() + device_num = args_opt.device_num + rank = device_id % device_num + print("device_id is {}, rank_id is {}".format(device_id, rank)) + + context.reset_auto_parallel_context() + context.set_auto_parallel_context(parallel_mode=ParallelMode.DATA_PARALLEL, gradients_mean=True, + device_num=device_num) + + else: + rank = 0 + device_num = 1 + + config = GPTConfig(batch_size=4, + seq_length=1024, + vocab_size=50257, + embedding_size=1024, + num_layers=24, + num_heads=16, + expand_ratio=4, + post_layernorm_residual=False, + dropout_rate=0.1, + compute_dtype=mstype.float16, + use_past=False) + gpt = GPT(config) + loss = CrossEntropyLoss(config) + gpt_with_loss = GPTWithLoss(gpt, loss) + + ds = create_dataset(config.batch_size, data_path=args_opt.data_path, device_num=device_num, rank=rank) + + + epoch_num = args_opt.epoch_size + step_per_epoch = ds.get_dataset_size() + + lr = LearningRate(learning_rate=args_opt.start_lr, + end_learning_rate=args_opt.end_lr, + warmup_steps=args_opt.warmup_step, + decay_steps=epoch_num*step_per_epoch) + + decay_filter = lambda x: 'layernorm' not in x.name.lower() and "bias" not in x.name.lower() + params = gpt.trainable_params() + decay_params = list(filter(decay_filter, params)) + other_params = list(filter(lambda x: not decay_filter(x), params)) + group_params = [{'params': decay_params, 'weight_decay': 1e-2}, + {'params': other_params, 'weight_decay': 0.0}, + {'order_params': params}] + + if args_opt.optimizer == "lamb": + optimizer = nn.Lamb(group_params, learning_rate=lr) + else: + optimizer = nn.AdamWeightDecay(group_params, learning_rate=lr) + + callback_size = args_opt.sink_size + actual_epoch_num = int(epoch_num * step_per_epoch/callback_size) + callback = [TimeMonitor(callback_size), LossMonitor(callback_size)] + + config_ck = CheckpointConfig(save_checkpoint_steps=step_per_epoch, keep_checkpoint_max=1) + ckpoint_cb = ModelCheckpoint(prefix="GPT2", config=config_ck) + callback.append(ckpoint_cb) + + + update_cell = DynamicLossScaleUpdateCell(loss_scale_value=1024, + scale_factor=2, + scale_window=1000) + + gpt_with_grads = GPTTrainOneStepWithLossScaleCell(gpt_with_loss, optimizer=optimizer, + scale_update_cell=update_cell) + + + model = Model(gpt_with_grads) + model.train(actual_epoch_num, ds, callbacks=callback, sink_size=callback_size) + + +if __name__ == "__main__": + set_seed(12315) + run_train()