pull/7740/head
alouhahaha 5 years ago
parent 2d6c07a367
commit 837f1a160c

@ -0,0 +1,76 @@
# It is still under development.
# Contents
- [Contents](#contents)
- [GPT Description](#gpt-description)
- [Model Architecture](#model-architecture)
- [Dataset](#dataset)
- [Environment Requirements](#environment-requirements)
- [Quick Start](#quick-start)
- [Script Description](#script-description)
- [Script and Sample Code](#script-and-sample-code)
- [ModelZoo Homepage](#modelzoo-homepage)
# [GPT Description](#contents)
The GPT network was proposed by OpenAI and has three versions: GPT, GPT2 and GPT3. The newest version, GPT3, was released in July 2020 and is a very large language model with 175 billion parameters. By stacking many Transformer decoder blocks and feeding in a massive amount of training data, GPT3 becomes a language model powerful enough that no fine-tuning is needed. As the paper title says, language models are few-shot learners: GPT3 shows that a large, well-trained model can reach performance comparable to that of fine-tuning methods.
[Paper](https://arxiv.org/abs/2005.14165): Tom B. Brown, Benjamin Mann, Nick Ryder et al. [Language Models are Few-Shot Learners](https://arxiv.org/abs/2005.14165). arXiv preprint arXiv:2005.14165.
# [Model Architecture](#contents)
GPT3 stacks many Transformer decoder layers. Depending on the number of layers and the embedding size, GPT3 comes in several versions; the largest contains 96 layers with an embedding size of 12288, giving a total of about 175 billion parameters. A rough estimate of this count is sketched below.
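The following sketch is for illustration only and is not part of the repository; `approx_gpt_params` is a hypothetical helper that counts only the dominant attention, MLP and embedding weights, so it only roughly reproduces the 175-billion figure.
```python
def approx_gpt_params(num_layers, hidden_size, vocab_size=50257, seq_length=2048):
    """Rough parameter count of a decoder-only Transformer (illustrative only)."""
    attention = 4 * hidden_size * hidden_size      # Q, K, V and output projections
    mlp = 2 * hidden_size * (4 * hidden_size)      # two dense layers with 4x expansion
    per_layer = attention + mlp                    # layer norms and biases are negligible
    embeddings = (vocab_size + seq_length) * hidden_size
    return num_layers * per_layer + embeddings

# Largest GPT3 configuration: 96 layers, embedding size 12288 -> about 1.75e11 parameters.
print(approx_gpt_params(96, 12288))
```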
# [Dataset](#contents)
- OpenWebText is used as the training data, and the training objective is to predict the next token at each position; a minimal sketch of this objective is given below.
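As a rough illustration (the repository's actual loss is `CrossEntropyLoss` in `src/gpt.py`; `next_token_loss` below is a hypothetical numpy sketch), the objective shifts the input sequence by one position and averages the cross entropy of each predicted next token:
```python
import numpy as np

def next_token_loss(logits, input_ids):
    """Average cross entropy of predicting token t+1 from position t (numpy sketch)."""
    # logits: (seq_length, vocab_size) raw scores; input_ids: (seq_length,) token ids
    log_probs = logits - np.log(np.sum(np.exp(logits), axis=-1, keepdims=True))
    targets = input_ids[1:]                               # the label at position t is token t+1
    picked = log_probs[np.arange(len(targets)), targets]  # log-probability of each true next token
    return -picked.mean()
```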
# [Environment Requirements](#contents)
- Hardware (Ascend)
- Prepare hardware environment with Ascend processor. If you want to try Ascend, please send the [application form](https://obs-9be7.obs.cn-east-2.myhuaweicloud.com/file/other/Ascend%20Model%20Zoo%E4%BD%93%E9%AA%8C%E8%B5%84%E6%BA%90%E7%94%B3%E8%AF%B7%E8%A1%A8.docx) to ascend@huawei.com. Once approved, you can get access to the resources.
- Framework
- [MindSpore](https://gitee.com/mindspore/mindspore)
- For more information, please check the resources below
- [MindSpore Tutorials](https://www.mindspore.cn/tutorial/training/en/master/index.html)
- [MindSpore Python API](https://www.mindspore.cn/doc/api_python/en/master/index.html)
# [Quick Start](#contents)
After installing MindSpore via the official website, you can start training and evaluation as follows:
```bash
# run standalone training example
bash scripts/run_standalone_train.sh 0 10 /path/dataset
# run distributed training example
bash scripts/run_distribute_training.sh /path/dataset /path/hccl.json 8
# run evaluation example; currently only accuracy and perplexity on lambada and wikitext103 are supported
bash scripts/run_evaluation.sh lambada /your/ckpt /your/data acc
```
For distributed training, an hccl configuration file in JSON format needs to be created in advance.
Please follow the instructions in the link below:
https://gitee.com/mindspore/mindspore/tree/master/model_zoo/utils/hccl_tools
# [Script Description](#contents)
## [Script and Sample Code](#contents)
```shell
.
└─gpt
├─README.md
├─scripts
├─run_standalone_train.sh # shell script for standalone training on ascend
├─run_distribut_train.sh # shell script for distributed training on ascend
└─run_evaluation.sh # shell script for evaluation on ascend
├─src
├─gpt_wrapper.py # training wrapper cell with loss scale
├─gpt.py # backbone code of network
├─dataset.py # data preprocessing
├─inference.py # top-k sampling for text generation
├─utils.py # util function
├─train.py # train net for training phase
└─eval.py # eval net for evaluation
```
# [ModelZoo Homepage](#contents)
Please check the official [homepage](https://gitee.com/mindspore/mindspore/tree/master/model_zoo).

@ -0,0 +1,155 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""
GPT evaluation script.
"""
import math
import argparse
import numpy as np
from mindspore import context
import mindspore.common.dtype as mstype
from mindspore.common.tensor import Tensor
from mindspore.train.serialization import load_checkpoint, load_param_into_net
from src.inference import generate
from src.dataset import create_dataset
from src.gpt import GPT, EvalNet, GPTWithLoss, CrossEntropyLoss
from src.utils import GPTConfig
context.set_context(mode=context.GRAPH_MODE)
def ppl_score(probs, length, is_logsoftmax=True):
""" calculate perplexity with prob or log_prob inputs """
probs = probs[:length]
if is_logsoftmax:
prob = np.sum(probs) / length
ppl = 1.0 / np.power(np.e, prob)
else:
prob = 1.0
for p in probs:
prob *= (1.0 / p)
ppl = np.power(prob, 1.0/length)
return ppl
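# Example with hypothetical numbers: for log-softmax scores [-2.0, -2.0, -2.0, -2.0],
# the mean log-probability is -2.0, so the perplexity is e**2 (about 7.39):
# ppl_score(np.array([-2.0, -2.0, -2.0, -2.0]), length=4)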
def get_ppl(model, dataset):
""" calculate perplexity for input dataset """
PPL = []
tokens = 0
for data in dataset:
data = data[0].asnumpy()
input_ids = data
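        # with metrics == "ppl" the model passed in is GPTWithLoss, so this forward
        # pass returns the batch loss rather than per-token logits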
logits = model(Tensor(input_ids, mstype.int32)).asnumpy()
PPL.append(logits * len(data))
tokens += len(data)
val_loss = sum(PPL) / tokens
ppl = math.exp(min(20, val_loss))
return ppl
def get_acc(model, dataset):
""" calculate accuracy for input dataset """
total_num = 0
acc_num = 0
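    # last-token accuracy (e.g. for lambada): the final valid token of each sample is
    # masked out and the model must predict it from the preceding context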
for data in dataset:
data = data[0].asnumpy()
input_mask = (data != 0).astype(np.int32)
length = np.sum(input_mask, 1)
label = np.zeros(length.shape)
for i, idx in enumerate(length):
label[i] = data[i][idx-1]
input_mask[i][idx-1] = 0
data[i][idx-1] = 0
length = np.sum(data != 50256, 1)
input_ids = data
logits = model(Tensor(input_ids, mstype.int32)).asnumpy()
logits = logits.reshape(len(length), -1)
predicted_label = np.zeros(length.shape)
for i, idx in enumerate(length):
predicted_label[i] = logits[i][idx-2]
total_num += len(label)
acc_num += sum(label == predicted_label)
acc = acc_num / total_num
return acc
def run_eval():
""" evaluate scripts """
parser = argparse.ArgumentParser(description="GPT inferencing")
parser.add_argument('--task_type', type=str, default="", help="Evaluation task.")
parser.add_argument('--metrics', type=str, default="acc", choices=["ppl", "acc"], help="Evaluation metrics.")
parser.add_argument('--ckpt_path', type=str, default="", help="path of checkpoint file.")
parser.add_argument('--data_path', type=str, default="", help="path of MindRecord file.")
args = parser.parse_args()
task = args.task_type
metrics = args.metrics
ckpt_path = args.ckpt_path
if task not in ["generate", "lambada", "wikitext"]:
raise ValueError("{} is not supported now".format(task))
if metrics not in ["acc", "ppl"]:
raise ValueError("{} is not supported now".format(metrics))
config = GPTConfig(batch_size=16,
seq_length=1024,
vocab_size=50257,
embedding_size=1024,
num_layers=24,
num_heads=16,
expand_ratio=4,
post_layernorm_residual=False,
dropout_rate=0.0,
compute_dtype=mstype.float16,
use_past=False)
ckpt_dict = load_checkpoint(ckpt_path)
gpt = GPT(config)
if task == "generate":
gpt_eval = EvalNet(gpt, generate=True)
elif metrics == "acc":
gpt_eval = EvalNet(gpt, generate=False)
else:
loss = CrossEntropyLoss(config)
gpt_eval = GPTWithLoss(gpt, loss)
gpt_eval.set_train(False)
load_param_into_net(gpt_eval, ckpt_dict)
if task == "generate":
start_sentence = [6170, 318, 257]
input_ids = np.array(start_sentence).reshape(1, -1)
outputs = generate(gpt_eval, input_ids, config.seq_length)
output_list = outputs.tolist()
print("output id is ", output_list)
else:
data_path = args.data_path
eval_dataset = create_dataset(config.batch_size, data_path=data_path, drop=False)
if metrics == "acc":
acc = get_acc(gpt_eval, eval_dataset)
print("Accuracy is ", acc)
elif metrics == "ppl":
ppl = get_ppl(gpt_eval, eval_dataset)
print("Perplexity is ", ppl)
if __name__ == "__main__":
run_eval()

@ -0,0 +1,38 @@
#!/bin/bash
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
echo "=============================================================================================================="
echo "Please run the scipt as: "
echo "bash run_distributed_pretrain_ascend.sh DATA_DIR RANK_TABLE_FILE DEVICE_NUM"
echo "for example: bash run_distributed_pretrain_ascend.sh /path/dataset /path/hccl.json 8"
echo "It is better to use absolute path."
echo "=============================================================================================================="
ROOT_PATH=$(pwd)
DATA_DIR=$1
export RANK_TABLE_FILE=$2
RANK_SIZE=$3
for((i=0;i<${RANK_SIZE};i++));
do
rm ${ROOT_PATH}/device$i/ -rf
mkdir ${ROOT_PATH}/device$i
cd ${ROOT_PATH}/device$i || exit
export RANK_ID=$i
export DEVICE_ID=$i
python ${ROOT_PATH}/train.py --distribute=true --device_num=$RANK_SIZE --data_path=$DATA_DIR >log$i.log 2>&1 &
done

@ -0,0 +1,33 @@
#!/bin/bash
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
echo "=============================================================================================================="
echo "Please run the scipt as: "
echo "bash scripts/run_evaluation.sh TASK_TYPE CKPT_PATH DATA_PATH METRICS"
echo "for example: bash scripts/run_evaluation.sh lambada /your/ckpt /your/data acc"
echo "=============================================================================================================="
TASK_TYPE=$1
CKPT_PATH=$2
DATA_PATH=$3
METRICS=$4
python eval.py \
--task_type=$TASK_TYPE \
--ckpt_path=$CKPT_PATH \
--data_path=$DATA_PATH \
--metrics=$METRICS

@ -0,0 +1,33 @@
#!/bin/bash
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
echo "=============================================================================================================="
echo "Please run the scipt as: "
echo "bash run_standalone_pretrain_ascend.sh DEVICE_ID EPOCH_SIZE DATA_DIR"
echo "for example: bash run_standalone_pretrain_ascend.sh 0 40 /path/zh-wiki/"
echo "=============================================================================================================="
DEVICE_ID=$1
EPOCH_SIZE=$2
DATA_DIR=$3
python train.py \
--distribute="false" \
--epoch_size=$EPOCH_SIZE \
--device_id=$DEVICE_ID \
--data_path=$DATA_DIR \
--optimizer="adam" > training_log.txt 2>&1 &

@ -0,0 +1,48 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""
Create dataset for training and evaluating
"""
import os
import mindspore.dataset as ds
import mindspore.dataset.transforms.c_transforms as C
import mindspore.common.dtype as mstype
def create_dataset(batch_size, data_path, device_num=1, rank=0, drop=True):
"""
Create dataset
Inputs:
batch_size: batch size
data_path: path of your MindRecord files
device_num: total device number
rank: current rank id
drop: whether drop remainder
Returns:
dataset: the dataset for training or evaluating
"""
home_path = os.path.join(os.getcwd(), data_path)
data = [os.path.join(home_path, name) for name in os.listdir(data_path) if name.endswith("mindrecord")]
print(data)
dataset = ds.MindDataset(data, columns_list=["input_ids"], shuffle=True, num_shards=device_num, shard_id=rank)
type_cast_op = C.TypeCast(mstype.int32)
dataset = dataset.map(input_columns="input_ids", operations=type_cast_op)
dataset = dataset.batch(batch_size, drop_remainder=drop)
dataset = dataset.repeat(1)
return dataset
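# Example usage (hypothetical path): build an evaluation dataset that keeps the last
# incomplete batch, as eval.py does.
# eval_ds = create_dataset(batch_size=16, data_path="/path/to/mindrecord_dir", drop=False)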

File diff suppressed because it is too large

@ -0,0 +1,157 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""GPT training wrapper"""
import mindspore.nn as nn
from mindspore.ops import operations as P
from mindspore.ops import composite as C
from mindspore.ops import functional as F
from mindspore import context
from mindspore.context import ParallelMode
from mindspore.nn.wrap.grad_reducer import DistributedGradReducer
from mindspore.communication.management import get_group_size
from mindspore.common.tensor import Tensor
import mindspore.common.dtype as mstype
from mindspore.common.parameter import Parameter
from src.utils import ClipByGlobalNorm
GRADIENT_CLIP_TYPE = 1
GRADIENT_CLIP_VALUE = 1.0
clip_grad = C.MultitypeFuncGraph("clip_grad")
@clip_grad.register("Number", "Number", "Tensor")
def _clip_grad(clip_type, clip_value, grad):
"""
Clip gradients.
Inputs:
clip_type (int): The way to clip, 0 for 'value', 1 for 'norm'.
clip_value (float): Specifies how much to clip.
grad (tuple[Tensor]): Gradients.
Outputs:
tuple[Tensor], clipped gradients.
"""
if clip_type not in [0, 1]:
return grad
dt = F.dtype(grad)
if clip_type == 0:
new_grad = C.clip_by_value(grad, F.cast(F.tuple_to_array((-clip_value,)), dt),
F.cast(F.tuple_to_array((clip_value,)), dt))
else:
new_grad = nn.ClipByNorm()(grad, F.cast(F.tuple_to_array((clip_value,)), dt))
return new_grad
grad_scale = C.MultitypeFuncGraph("grad_scale")
reciprocal = P.Reciprocal()
@grad_scale.register("Tensor", "Tensor")
def tensor_grad_scale(scale, grad):
return grad * reciprocal(scale)
class GPTTrainOneStepWithLossScaleCell(nn.Cell):
"""
Encapsulation class of GPT network training.
Append an optimizer to the training network after that the construct
function can be called to create the backward graph.
Args:
network (Cell): The training network. Note that loss function should have been added.
optimizer (Optimizer): Optimizer for updating the weights.
scale_update_cell (Cell): Cell to do the loss scale. Default: None.
"""
def __init__(self, network, optimizer, scale_update_cell=None, enable_global_norm=False):
super(GPTTrainOneStepWithLossScaleCell, self).__init__(auto_prefix=False)
self.network = network
self.weights = optimizer.parameters
self.optimizer = optimizer
self.enable_global_norm = enable_global_norm
self.grad = C.GradOperation(get_by_list=True,
sens_param=True)
self.reducer_flag = False
self.allreduce = P.AllReduce()
self.parallel_mode = context.get_auto_parallel_context("parallel_mode")
if self.parallel_mode in [ParallelMode.DATA_PARALLEL, ParallelMode.HYBRID_PARALLEL]:
self.reducer_flag = True
self.grad_reducer = F.identity
self.degree = 1
if self.reducer_flag:
self.degree = get_group_size()
self.grad_reducer = DistributedGradReducer(optimizer.parameters, False, self.degree)
self.is_distributed = (self.parallel_mode != ParallelMode.STAND_ALONE)
self.cast = P.Cast()
self.alloc_status = P.NPUAllocFloatStatus()
self.get_status = P.NPUGetFloatStatus()
self.clear_before_grad = P.NPUClearFloatStatus()
self.reduce_sum = P.ReduceSum(keep_dims=False)
self.depend_parameter_use = P.ControlDepend(depend_mode=1)
self.base = Tensor(1, mstype.float32)
self.less_equal = P.LessEqual()
self.hyper_map = C.HyperMap()
self.loss_scale = None
self.loss_scaling_manager = scale_update_cell
if scale_update_cell:
self.loss_scale = Parameter(Tensor(scale_update_cell.get_loss_scale(), dtype=mstype.float32),
name="loss_scale")
@C.add_flags(has_effect=True)
def construct(self,
input_ids,
past=None,
sens=None):
"""Defines the computation performed."""
weights = self.weights
loss = self.network(input_ids,
past)
if sens is None:
scaling_sens = self.loss_scale
else:
scaling_sens = sens
# alloc status and clear should be right before gradoperation
init = self.alloc_status()
self.clear_before_grad(init)
grads = self.grad(self.network, weights)(input_ids,
past,
self.cast(scaling_sens,
mstype.float32))
# apply grad reducer on grads
grads = self.grad_reducer(grads)
grads = self.hyper_map(F.partial(grad_scale, scaling_sens * self.degree), grads)
if self.enable_global_norm:
grads = ClipByGlobalNorm()(grads)
else:
grads = self.hyper_map(F.partial(clip_grad, GRADIENT_CLIP_TYPE, GRADIENT_CLIP_VALUE), grads)
self.get_status(init)
flag_sum = self.reduce_sum(init, (0,))
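        # a non-zero float status flag means NaN/Inf appeared during the backward pass;
        # in that case the optimizer step below is skipped and the loss-scale update
        # cell can reduce the loss scale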
if self.is_distributed:
# sum overflow flag over devices
flag_reduce = self.allreduce(flag_sum)
cond = self.less_equal(self.base, flag_reduce)
else:
cond = self.less_equal(self.base, flag_sum)
overflow = cond
if sens is None:
overflow = self.loss_scaling_manager(self.loss_scale, cond)
if overflow:
succ = False
else:
succ = self.optimizer(grads)
ret = (loss, cond, scaling_sens)
return F.depend(ret, succ)

@ -0,0 +1,60 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""
TopK for text generation
"""
import numpy as np
import mindspore.common.dtype as mstype
from mindspore.common.tensor import Tensor
def generate(model, origin_inputs, seq_length, end_token=50256):
"""
TopK for text generation
Inputs:
model: the model for inferencing
origin_inputs: the original inputs based on which the model will continue writing
seq_length: seq_length for the model
end_token: end of sentence token id
Returns:
outputs: the ids for the generated text
"""
TOPK = 5
seq_length = seq_length
bs, valid_length = origin_inputs.shape
pad_length = seq_length - origin_inputs.shape[-1]
input_ids = np.pad(origin_inputs, ((0, 0), (0, pad_length)), 'constant', constant_values=(0, 0))
print("input_ids is ", input_ids)
while valid_length < seq_length:
inputs = Tensor(input_ids, mstype.int32)
logits = model(inputs).asnumpy()
logits = logits.reshape(bs, seq_length, -1)
probs = logits[0, valid_length-1, :]
p_args = probs.argsort()[::-1][:TOPK]
p = probs[p_args]
p = p / sum(p)
target_index = np.random.choice(len(p), p=p)
if p_args[target_index] == end_token or valid_length == seq_length-1:
outputs = input_ids
break
input_ids[0][valid_length] = p_args[target_index]
valid_length += 1
length = np.sum(outputs != 0)
outputs = outputs[0][:length]
return outputs
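# Example usage (hypothetical wrapper and prompt): continue a three-token prompt with a
# network wrapped by EvalNet(generate=True); the ids are GPT-2 BPE token ids.
# prompt = np.array([[6170, 318, 257]])
# generated_ids = generate(eval_net, prompt, seq_length=1024)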

@ -0,0 +1,138 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""
network config setting, gradient clip function and dynamic learning rate function
"""
import mindspore.nn as nn
from mindspore.ops import operations as P
from mindspore.ops import composite as C
from mindspore.ops import functional as F
import mindspore.common.dtype as mstype
from mindspore.common.tensor import Tensor
from mindspore.nn.learning_rate_schedule import LearningRateSchedule, PolynomialDecayLR, WarmUpLR, CosineDecayLR
import numpy as np
class GPTConfig:
"""
GPT config class which defines the model size
"""
def __init__(self,
batch_size=32,
seq_length=1024,
vocab_size=50257,
embedding_size=768,
num_layers=12,
num_heads=12,
expand_ratio=4,
post_layernorm_residual=False,
dropout_rate=0.1,
compute_dtype=mstype.float16,
use_past=False):
self.batch_size = batch_size
self.seq_length = seq_length
self.vocab_size = vocab_size
self.embedding_size = embedding_size
self.num_layers = num_layers
self.num_heads = num_heads
self.expand_ratio = expand_ratio
self.post_layernorm_residual = post_layernorm_residual
self.dropout_rate = dropout_rate
self.compute_dtype = compute_dtype
self.use_past = use_past
get_square_sum = C.MultitypeFuncGraph("get_square_sum")
@get_square_sum.register("Tensor")
def _get_square_sum(grad):
norm = P.ReduceSum(False)(F.square(grad), ())
norm = F.expand_dims(F.cast(norm, mstype.float32), 0)
return norm
apply_global_norm = C.MultitypeFuncGraph("apply_global_norm")
@apply_global_norm.register("Tensor", "Tensor", "Tensor")
def _apply_global_norm(clip_norm, global_norm, grad):
grad = grad * clip_norm / global_norm
return grad
class GlobalNorm(nn.Cell):
"""
Calculate the global norm value of given tensors
"""
def __init__(self):
super(GlobalNorm, self).__init__()
self.norm = nn.Norm()
self.hyper_map = C.HyperMap()
def construct(self, grads):
square_sum = self.hyper_map(get_square_sum, grads)
global_norms = F.sqrt(F.addn(square_sum) / F.scalar_to_array(len(square_sum)))
return global_norms
class ClipByGlobalNorm(nn.Cell):
"""
Clip grads by global norm
"""
def __init__(self, clip_norm=1.0):
super(ClipByGlobalNorm, self).__init__()
self.global_norm = GlobalNorm()
self.clip_norm = Tensor([clip_norm], mstype.float32)
self.hyper_map = C.HyperMap()
def construct(self, grads):
global_norm = self.global_norm(grads)
cond = P.GreaterEqual()(global_norm, self.clip_norm)
global_norm = F.select(cond, global_norm, self.clip_norm)
grads = self.hyper_map(F.partial(apply_global_norm, self.clip_norm, global_norm), grads)
return grads
class LearningRate(LearningRateSchedule):
"""
Warmup-decay learning rate for GPT network.
"""
def __init__(self, learning_rate, end_learning_rate, warmup_steps, decay_steps, power=1.0, use_cosine=True):
super(LearningRate, self).__init__()
self.warmup_flag = False
if warmup_steps > 0:
self.warmup_flag = True
self.warmup_lr = WarmUpLR(learning_rate, warmup_steps)
self.decay_lr = PolynomialDecayLR(learning_rate, end_learning_rate, decay_steps, power)
self.cosine_decay_lr = CosineDecayLR(end_learning_rate, learning_rate, decay_steps)
self.warmup_steps = Tensor(np.array([warmup_steps]).astype(np.float32))
self.greater = P.Greater()
self.one = Tensor(np.array([1.0]).astype(np.float32))
self.cast = P.Cast()
self.use_cosine = use_cosine
def construct(self, global_step):
"""dynamic learning rate"""
if not self.use_cosine:
decay_lr = self.decay_lr(global_step)
else:
decay_lr = self.cosine_decay_lr(global_step)
if self.warmup_flag:
is_warmup = self.cast(self.greater(self.warmup_steps, global_step), mstype.float32)
warmup_lr = self.warmup_lr(global_step)
lr = (self.one - is_warmup) * decay_lr + is_warmup * warmup_lr
else:
lr = decay_lr
return lr
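# Example usage (decay_steps here is a hypothetical value; the other numbers mirror the
# defaults in train.py): warm up for 10000 steps, then decay with a cosine schedule.
# lr_schedule = LearningRate(learning_rate=5e-5, end_learning_rate=1e-10,
#                            warmup_steps=10000, decay_steps=100000)
# optimizer = nn.AdamWeightDecay(network.trainable_params(), learning_rate=lr_schedule)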

@ -0,0 +1,133 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""
GPT train script
"""
import os
import argparse
from mindspore import context
from mindspore.train.model import Model
import mindspore.communication.management as D
from mindspore.context import ParallelMode
import mindspore.nn as nn
from mindspore.train.callback import TimeMonitor, LossMonitor, ModelCheckpoint, CheckpointConfig
from mindspore.nn.wrap.loss_scale import DynamicLossScaleUpdateCell
import mindspore.common.dtype as mstype
from mindspore.common import set_seed
from src.dataset import create_dataset
from src.gpt import GPT, GPTWithLoss, CrossEntropyLoss
from src.gpt_wrapcell import GPTTrainOneStepWithLossScaleCell
from src.utils import GPTConfig, LearningRate
def run_train():
"""train function for GPT"""
parser = argparse.ArgumentParser(description="GPT training")
parser.add_argument('--device_id', type=int, default=0, help="Device id, default is 0.")
parser.add_argument("--device_num", type=int, default=1, help="Use device nums, default is 1.")
parser.add_argument("--distribute", type=str, default="false", choices=["true", "false"],
help="Run distribute, default is false.")
parser.add_argument("--optimizer", type=str, default="adam", choices=["adam", "lamb"],
help="select which optimizer to be used, default adam")
parser.add_argument("--epoch_size", type=int, default=10, help="Epoch size, default is 10.")
parser.add_argument("--warmup_step", type=int, default=10000, help="Warmup step, default is 10000.")
parser.add_argument("--data_path", type=str, default="", help="Data path of your MindRecord files.")
parser.add_argument("--start_lr", type=float, default="5e-5", help="Start learning rate, default is 5e-5.")
parser.add_argument("--end_lr", type=float, default="1e-10", help="End learning rate, default is 1e-10.")
parser.add_argument("--sink_size", type=int, default=100, help="Sink size for every iteration, default is 100")
args_opt = parser.parse_args()
    device_id = int(os.getenv("DEVICE_ID", args_opt.device_id))
context.set_context(mode=context.GRAPH_MODE, device_target="Ascend", device_id=device_id)
if args_opt.distribute == "true":
D.init()
device_num = args_opt.device_num
rank = device_id % device_num
print("device_id is {}, rank_id is {}".format(device_id, rank))
context.reset_auto_parallel_context()
context.set_auto_parallel_context(parallel_mode=ParallelMode.DATA_PARALLEL, gradients_mean=True,
device_num=device_num)
else:
rank = 0
device_num = 1
config = GPTConfig(batch_size=4,
seq_length=1024,
vocab_size=50257,
embedding_size=1024,
num_layers=24,
num_heads=16,
expand_ratio=4,
post_layernorm_residual=False,
dropout_rate=0.1,
compute_dtype=mstype.float16,
use_past=False)
gpt = GPT(config)
loss = CrossEntropyLoss(config)
gpt_with_loss = GPTWithLoss(gpt, loss)
ds = create_dataset(config.batch_size, data_path=args_opt.data_path, device_num=device_num, rank=rank)
epoch_num = args_opt.epoch_size
step_per_epoch = ds.get_dataset_size()
lr = LearningRate(learning_rate=args_opt.start_lr,
end_learning_rate=args_opt.end_lr,
warmup_steps=args_opt.warmup_step,
decay_steps=epoch_num*step_per_epoch)
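    # apply weight decay to all parameters except LayerNorm weights and biases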
decay_filter = lambda x: 'layernorm' not in x.name.lower() and "bias" not in x.name.lower()
params = gpt.trainable_params()
decay_params = list(filter(decay_filter, params))
other_params = list(filter(lambda x: not decay_filter(x), params))
group_params = [{'params': decay_params, 'weight_decay': 1e-2},
{'params': other_params, 'weight_decay': 0.0},
{'order_params': params}]
if args_opt.optimizer == "lamb":
optimizer = nn.Lamb(group_params, learning_rate=lr)
else:
optimizer = nn.AdamWeightDecay(group_params, learning_rate=lr)
callback_size = args_opt.sink_size
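    # with data sinking, model.train treats every sink_size steps as one epoch,
    # so the epoch count passed to model.train is rescaled accordingly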
actual_epoch_num = int(epoch_num * step_per_epoch/callback_size)
callback = [TimeMonitor(callback_size), LossMonitor(callback_size)]
config_ck = CheckpointConfig(save_checkpoint_steps=step_per_epoch, keep_checkpoint_max=1)
ckpoint_cb = ModelCheckpoint(prefix="GPT2", config=config_ck)
callback.append(ckpoint_cb)
update_cell = DynamicLossScaleUpdateCell(loss_scale_value=1024,
scale_factor=2,
scale_window=1000)
gpt_with_grads = GPTTrainOneStepWithLossScaleCell(gpt_with_loss, optimizer=optimizer,
scale_update_cell=update_cell)
model = Model(gpt_with_grads)
model.train(actual_epoch_num, ds, callbacks=callback, sink_size=callback_size)
if __name__ == "__main__":
set_seed(12315)
run_train()