parent 2d6c07a367
commit 837f1a160c
@@ -0,0 +1,155 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================

"""
GPT evaluation script.
"""

import math
import argparse
import numpy as np
from mindspore import context
import mindspore.common.dtype as mstype
from mindspore.common.tensor import Tensor
from mindspore.train.serialization import load_checkpoint, load_param_into_net
from src.inference import generate
from src.dataset import create_dataset
from src.gpt import GPT, EvalNet, GPTWithLoss, CrossEntropyLoss
from src.utils import GPTConfig

context.set_context(mode=context.GRAPH_MODE)


def ppl_score(probs, length, is_logsoftmax=True):
    """ calculate perplexity with prob or log_prob inputs """
    probs = probs[:length]
    if is_logsoftmax:
        prob = np.sum(probs) / length
        ppl = 1.0 / np.power(np.e, prob)
    else:
        prob = 1.0
        for p in probs:
            prob *= (1.0 / p)
        ppl = np.power(prob, 1.0 / length)
    return ppl
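
# Worked example (illustrative only, not part of the original script): with
# log-softmax inputs, ppl_score computes exp(-mean(log p)). For
#   probs = np.log([0.5, 0.25]) and length = 2,
# the mean log-prob is (log 0.5 + log 0.25) / 2, roughly -1.0397, so
#   ppl_score(probs, 2) = e**1.0397, roughly 2.83,
# matching the direct definition (0.5 * 0.25) ** (-1 / 2).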


def get_ppl(model, dataset):
    """ calculate perplexity for input dataset """
    PPL = []
    tokens = 0
    for data in dataset:
        data = data[0].asnumpy()
        input_ids = data

        logits = model(Tensor(input_ids, mstype.int32)).asnumpy()
        PPL.append(logits * len(data))
        tokens += len(data)

    val_loss = sum(PPL) / tokens
    ppl = math.exp(min(20, val_loss))
    return ppl
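
# get_ppl averages the per-batch loss from the wrapped network (GPTWithLoss),
# weighted by batch size, and returns exp(val_loss); the min(20, ...) clamp
# keeps math.exp from overflowing if the loss diverges.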


def get_acc(model, dataset):
    """ calculate accuracy for input dataset """
    total_num = 0
    acc_num = 0
    for data in dataset:
        data = data[0].asnumpy()
        input_mask = (data != 0).astype(np.int32)
        length = np.sum(input_mask, 1)
        label = np.zeros(length.shape)
        for i, idx in enumerate(length):
            label[i] = data[i][idx - 1]
            input_mask[i][idx - 1] = 0
            data[i][idx - 1] = 0

        length = np.sum(data != 50256, 1)
        input_ids = data
        logits = model(Tensor(input_ids, mstype.int32)).asnumpy()
        logits = logits.reshape(len(length), -1)

        predicted_label = np.zeros(length.shape)
        for i, idx in enumerate(length):
            predicted_label[i] = logits[i][idx - 2]

        total_num += len(label)
        acc_num += sum(label == predicted_label)

    acc = acc_num / total_num
    return acc
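
# How the accuracy metric works (as implemented above): each sample's last
# non-pad token is held out as the label and zeroed in the input, so the
# network cannot see it; the prediction taken at the position just before the
# held-out token (index idx - 2 under the re-computed length) is compared
# against the label. This is the setup used for last-word tasks such as
# lambada.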


def run_eval():
    """ evaluate scripts """
    parser = argparse.ArgumentParser(description="GPT evaluation")
    parser.add_argument('--task_type', type=str, default="", help="Evaluation task.")
    parser.add_argument('--metrics', type=str, default="acc", choices=["ppl", "acc"], help="Evaluation metrics.")
    parser.add_argument('--ckpt_path', type=str, default="", help="Path of the checkpoint file.")
    parser.add_argument('--data_path', type=str, default="", help="Path of the MindRecord files.")

    args = parser.parse_args()
    task = args.task_type
    metrics = args.metrics
    ckpt_path = args.ckpt_path
    if task not in ["generate", "lambada", "wikitext"]:
        raise ValueError("{} is not supported now".format(task))

    if metrics not in ["acc", "ppl"]:
        raise ValueError("{} is not supported now".format(metrics))

    config = GPTConfig(batch_size=16,
                       seq_length=1024,
                       vocab_size=50257,
                       embedding_size=1024,
                       num_layers=24,
                       num_heads=16,
                       expand_ratio=4,
                       post_layernorm_residual=False,
                       dropout_rate=0.0,
                       compute_dtype=mstype.float16,
                       use_past=False)

    ckpt_dict = load_checkpoint(ckpt_path)

    gpt = GPT(config)
    if task == "generate":
        gpt_eval = EvalNet(gpt, generate=True)
    elif metrics == "acc":
        gpt_eval = EvalNet(gpt, generate=False)
    else:
        loss = CrossEntropyLoss(config)
        gpt_eval = GPTWithLoss(gpt, loss)

    gpt_eval.set_train(False)
    load_param_into_net(gpt_eval, ckpt_dict)

    if task == "generate":
        start_sentence = [6170, 318, 257]
        input_ids = np.array(start_sentence).reshape(1, -1)
        outputs = generate(gpt_eval, input_ids, config.seq_length)
        output_list = outputs.tolist()
        print("output id is ", output_list)
    else:
        data_path = args.data_path
        eval_dataset = create_dataset(config.batch_size, data_path=data_path, drop=False)
        if metrics == "acc":
            acc = get_acc(gpt_eval, eval_dataset)
            print("Accuracy is ", acc)
        elif metrics == "ppl":
            ppl = get_ppl(gpt_eval, eval_dataset)
            print("Perplexity is ", ppl)


if __name__ == "__main__":
    run_eval()
@@ -0,0 +1,38 @@
#!/bin/bash
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================

echo "=============================================================================================================="
echo "Please run the script as: "
echo "bash run_distributed_pretrain_ascend.sh DATA_DIR RANK_TABLE_FILE DEVICE_NUM"
echo "for example: bash run_distributed_pretrain_ascend.sh /path/dataset /path/hccl.json 8"
echo "It is better to use absolute path."
echo "=============================================================================================================="

ROOT_PATH=$(pwd)
DATA_DIR=$1
export RANK_TABLE_FILE=$2
RANK_SIZE=$3


for((i=0;i<${RANK_SIZE};i++));
do
    rm -rf ${ROOT_PATH}/device$i/
    mkdir ${ROOT_PATH}/device$i
    cd ${ROOT_PATH}/device$i || exit
    export RANK_ID=$i
    export DEVICE_ID=$i
    python ${ROOT_PATH}/train.py --distribute=true --device_num=$RANK_SIZE --data_path=$DATA_DIR >log$i.log 2>&1 &
done
@@ -0,0 +1,33 @@
#!/bin/bash
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================

echo "=============================================================================================================="
echo "Please run the script as: "
echo "bash scripts/run_evaluation.sh TASK_TYPE CKPT_PATH DATA_PATH METRICS"
echo "for example: bash scripts/run_evaluation.sh lambada /your/ckpt /your/data acc"
echo "=============================================================================================================="


TASK_TYPE=$1
CKPT_PATH=$2
DATA_PATH=$3
METRICS=$4
python eval.py \
    --task_type=$TASK_TYPE \
    --ckpt_path=$CKPT_PATH \
    --data_path=$DATA_PATH \
    --metrics=$METRICS
@@ -0,0 +1,33 @@
#!/bin/bash
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================

echo "=============================================================================================================="
echo "Please run the script as: "
echo "bash run_standalone_pretrain_ascend.sh DEVICE_ID EPOCH_SIZE DATA_DIR"
echo "for example: bash run_standalone_pretrain_ascend.sh 0 40 /path/zh-wiki/"
echo "=============================================================================================================="

DEVICE_ID=$1
EPOCH_SIZE=$2
DATA_DIR=$3


python train.py \
    --distribute="false" \
    --epoch_size=$EPOCH_SIZE \
    --device_id=$DEVICE_ID \
    --data_path=$DATA_DIR \
    --optimizer="adam" > training_log.txt 2>&1 &
@@ -0,0 +1,48 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================

"""
Create dataset for training and evaluating
"""

import os
import mindspore.dataset as ds
import mindspore.dataset.transforms.c_transforms as C
import mindspore.common.dtype as mstype


def create_dataset(batch_size, data_path, device_num=1, rank=0, drop=True):
    """
    Create dataset

    Inputs:
        batch_size: batch size
        data_path: path of your MindRecord files
        device_num: total device number
        rank: current rank id
        drop: whether to drop the remainder batch

    Returns:
        dataset: the dataset for training or evaluating
    """
    home_path = os.path.join(os.getcwd(), data_path)
    data = [os.path.join(home_path, name) for name in os.listdir(home_path) if name.endswith("mindrecord")]
    print(data)
    dataset = ds.MindDataset(data, columns_list=["input_ids"], shuffle=True, num_shards=device_num, shard_id=rank)
    type_cast_op = C.TypeCast(mstype.int32)
    dataset = dataset.map(input_columns="input_ids", operations=type_cast_op)
    dataset = dataset.batch(batch_size, drop_remainder=drop)
    dataset = dataset.repeat(1)
    return dataset
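
# Usage sketch (illustrative; assumes ./data holds *.mindrecord files with an
# "input_ids" column):
#   train_ds = create_dataset(batch_size=4, data_path="data", device_num=8, rank=0)
#   for batch in train_ds.create_tuple_iterator():
#       input_ids = batch[0]   # shape [4, seq_length], dtype int32
# With device_num > 1, MindDataset shards the files so each rank reads a
# disjoint 1/device_num slice of the data.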
File diff suppressed because it is too large
@@ -0,0 +1,157 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""GPT training wrapper"""


import mindspore.nn as nn
from mindspore.ops import operations as P
from mindspore.ops import composite as C
from mindspore.ops import functional as F
from mindspore import context
from mindspore.context import ParallelMode
from mindspore.nn.wrap.grad_reducer import DistributedGradReducer
from mindspore.communication.management import get_group_size
from mindspore.common.tensor import Tensor
import mindspore.common.dtype as mstype
from mindspore.common.parameter import Parameter
from src.utils import ClipByGlobalNorm

GRADIENT_CLIP_TYPE = 1
GRADIENT_CLIP_VALUE = 1.0
clip_grad = C.MultitypeFuncGraph("clip_grad")

@clip_grad.register("Number", "Number", "Tensor")
def _clip_grad(clip_type, clip_value, grad):
    """
    Clip gradients.

    Inputs:
        clip_type (int): The way to clip, 0 for 'value', 1 for 'norm'.
        clip_value (float): Specifies how much to clip.
        grad (tuple[Tensor]): Gradients.

    Outputs:
        tuple[Tensor], clipped gradients.
    """
    if clip_type not in [0, 1]:
        return grad
    dt = F.dtype(grad)
    if clip_type == 0:
        new_grad = C.clip_by_value(grad, F.cast(F.tuple_to_array((-clip_value,)), dt),
                                   F.cast(F.tuple_to_array((clip_value,)), dt))
    else:
        new_grad = nn.ClipByNorm()(grad, F.cast(F.tuple_to_array((clip_value,)), dt))
    return new_grad
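
# clip_grad is applied leaf-wise through HyperMap: with clip_type=1 each
# gradient tensor is clipped to an L2 norm of at most clip_value on its own,
# unlike ClipByGlobalNorm (used when enable_global_norm=True), which rescales
# all gradients by one shared factor.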


grad_scale = C.MultitypeFuncGraph("grad_scale")
reciprocal = P.Reciprocal()

@grad_scale.register("Tensor", "Tensor")
def tensor_grad_scale(scale, grad):
    return grad * reciprocal(scale)

class GPTTrainOneStepWithLossScaleCell(nn.Cell):
    """
    Encapsulation class of GPT network training.

    Append an optimizer to the training network so that the construct
    function can be called to create the backward graph.

    Args:
        network (Cell): The training network. Note that the loss function should have been added.
        optimizer (Optimizer): Optimizer for updating the weights.
        scale_update_cell (Cell): Cell to do the loss scale. Default: None.
    """
    def __init__(self, network, optimizer, scale_update_cell=None, enable_global_norm=False):
        super(GPTTrainOneStepWithLossScaleCell, self).__init__(auto_prefix=False)
        self.network = network
        self.weights = optimizer.parameters
        self.optimizer = optimizer
        self.enable_global_norm = enable_global_norm
        self.grad = C.GradOperation(get_by_list=True,
                                    sens_param=True)
        self.reducer_flag = False
        self.allreduce = P.AllReduce()
        self.parallel_mode = context.get_auto_parallel_context("parallel_mode")
        if self.parallel_mode in [ParallelMode.DATA_PARALLEL, ParallelMode.HYBRID_PARALLEL]:
            self.reducer_flag = True
        self.grad_reducer = F.identity
        self.degree = 1
        if self.reducer_flag:
            self.degree = get_group_size()
            self.grad_reducer = DistributedGradReducer(optimizer.parameters, False, self.degree)
        self.is_distributed = (self.parallel_mode != ParallelMode.STAND_ALONE)
        self.cast = P.Cast()
        self.alloc_status = P.NPUAllocFloatStatus()
        self.get_status = P.NPUGetFloatStatus()
        self.clear_before_grad = P.NPUClearFloatStatus()
        self.reduce_sum = P.ReduceSum(keep_dims=False)
        self.depend_parameter_use = P.ControlDepend(depend_mode=1)
        self.base = Tensor(1, mstype.float32)
        self.less_equal = P.LessEqual()
        self.hyper_map = C.HyperMap()
        self.loss_scale = None
        self.loss_scaling_manager = scale_update_cell
        if scale_update_cell:
            self.loss_scale = Parameter(Tensor(scale_update_cell.get_loss_scale(), dtype=mstype.float32),
                                        name="loss_scale")

    @C.add_flags(has_effect=True)
    def construct(self,
                  input_ids,
                  past=None,
                  sens=None):
        """Defines the computation performed."""
        weights = self.weights
        loss = self.network(input_ids,
                            past)

        if sens is None:
            scaling_sens = self.loss_scale
        else:
            scaling_sens = sens
        # alloc status and clear should be right before gradoperation
        init = self.alloc_status()
        self.clear_before_grad(init)
        grads = self.grad(self.network, weights)(input_ids,
                                                 past,
                                                 self.cast(scaling_sens,
                                                           mstype.float32))
        # apply grad reducer on grads
        grads = self.grad_reducer(grads)
        grads = self.hyper_map(F.partial(grad_scale, scaling_sens * self.degree), grads)

        if self.enable_global_norm:
            grads = ClipByGlobalNorm()(grads)
        else:
            grads = self.hyper_map(F.partial(clip_grad, GRADIENT_CLIP_TYPE, GRADIENT_CLIP_VALUE), grads)

        self.get_status(init)
        flag_sum = self.reduce_sum(init, (0,))
        if self.is_distributed:
            # sum overflow flag over devices
            flag_reduce = self.allreduce(flag_sum)
            cond = self.less_equal(self.base, flag_reduce)
        else:
            cond = self.less_equal(self.base, flag_sum)
        overflow = cond
        if sens is None:
            overflow = self.loss_scaling_manager(self.loss_scale, cond)
        if overflow:
            succ = False
        else:
            succ = self.optimizer(grads)
        ret = (loss, cond, scaling_sens)
        return F.depend(ret, succ)
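
# How one training step flows (informal summary of the cell above):
#   1. clear the NPU float-status register before the backward pass;
#   2. run backprop with sens = loss_scale, so all gradients are scaled up;
#   3. all-reduce the gradients, then unscale by loss_scale * degree;
#   4. clip, then read the float status: a flag sum >= 1 means an inf/nan
#      appeared during backprop (all-reduced across devices when distributed);
#   5. on overflow, skip the optimizer step and let scale_update_cell shrink
#      the loss scale; otherwise apply the gradients.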
@@ -0,0 +1,60 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================

"""
TopK for text generation
"""

import numpy as np
import mindspore.common.dtype as mstype
from mindspore.common.tensor import Tensor

def generate(model, origin_inputs, seq_length, end_token=50256):
    """
    TopK for text generation

    Inputs:
        model: the model used for inference
        origin_inputs: the original inputs based on which the model will continue writing
        seq_length: seq_length for the model
        end_token: end of sentence token id

    Returns:
        outputs: the ids for the generated text
    """
    TOPK = 5
    bs, valid_length = origin_inputs.shape
    pad_length = seq_length - origin_inputs.shape[-1]
    input_ids = np.pad(origin_inputs, ((0, 0), (0, pad_length)), 'constant', constant_values=(0, 0))
    print("input_ids is ", input_ids)
    while valid_length < seq_length:
        inputs = Tensor(input_ids, mstype.int32)
        logits = model(inputs).asnumpy()
        logits = logits.reshape(bs, seq_length, -1)
        probs = logits[0, valid_length - 1, :]
        p_args = probs.argsort()[::-1][:TOPK]

        p = probs[p_args]
        p = p / sum(p)
        target_index = np.random.choice(len(p), p=p)
        if p_args[target_index] == end_token or valid_length == seq_length - 1:
            outputs = input_ids
            break
        input_ids[0][valid_length] = p_args[target_index]
        valid_length += 1
    length = np.sum(outputs != 0)
    outputs = outputs[0][:length]
    return outputs
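
# Worked example of the top-k step above (pure numpy, illustrative values):
#   probs  = np.array([0.05, 0.10, 0.20, 0.30, 0.35])
#   p_args = probs.argsort()[::-1][:TOPK]    # indices of the 5 largest scores
#   p      = probs[p_args]; p = p / sum(p)   # renormalize over the top 5
#   next_id = p_args[np.random.choice(len(p), p=p)]
# Sampling stops when the chosen id equals end_token (50256, GPT-2's
# <|endoftext|>) or the sequence is full.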
@@ -0,0 +1,138 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================

"""
network config setting, gradient clip function and dynamic learning rate function
"""

import numpy as np
import mindspore.nn as nn
from mindspore.ops import operations as P
from mindspore.ops import composite as C
from mindspore.ops import functional as F
import mindspore.common.dtype as mstype
from mindspore.common.tensor import Tensor
from mindspore.nn.learning_rate_schedule import LearningRateSchedule, PolynomialDecayLR, WarmUpLR, CosineDecayLR


class GPTConfig:
    """
    GPT config class which defines the model size
    """
    def __init__(self,
                 batch_size=32,
                 seq_length=1024,
                 vocab_size=50257,
                 embedding_size=768,
                 num_layers=12,
                 num_heads=12,
                 expand_ratio=4,
                 post_layernorm_residual=False,
                 dropout_rate=0.1,
                 compute_dtype=mstype.float16,
                 use_past=False):
        self.batch_size = batch_size
        self.seq_length = seq_length
        self.vocab_size = vocab_size
        self.embedding_size = embedding_size
        self.num_layers = num_layers
        self.num_heads = num_heads
        self.expand_ratio = expand_ratio
        self.post_layernorm_residual = post_layernorm_residual
        self.dropout_rate = dropout_rate
        self.compute_dtype = compute_dtype
        self.use_past = use_past


get_square_sum = C.MultitypeFuncGraph("get_square_sum")
@get_square_sum.register("Tensor")
def _get_square_sum(grad):
    norm = P.ReduceSum(False)(F.square(grad), ())
    norm = F.expand_dims(F.cast(norm, mstype.float32), 0)
    return norm


apply_global_norm = C.MultitypeFuncGraph("apply_global_norm")
@apply_global_norm.register("Tensor", "Tensor", "Tensor")
def _apply_global_norm(clip_norm, global_norm, grad):
    grad = grad * clip_norm / global_norm
    return grad

class GlobalNorm(nn.Cell):
    """
    Calculate the global norm value of given tensors
    """
    def __init__(self):
        super(GlobalNorm, self).__init__()
        self.norm = nn.Norm()
        self.hyper_map = C.HyperMap()

    def construct(self, grads):
        square_sum = self.hyper_map(get_square_sum, grads)
        global_norms = F.sqrt(F.addn(square_sum) / F.scalar_to_array(len(square_sum)))
        return global_norms
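
# Note on the formula above: F.addn(square_sum) sums the squared entries of
# every gradient tensor, but the total is divided by the number of tensors
# before the square root. The returned value is therefore sqrt(sum / N), a
# scaled variant of the textbook global norm sqrt(sum); ClipByGlobalNorm
# below clips against this quantity.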


class ClipByGlobalNorm(nn.Cell):
    """
    Clip grads by global norm
    """
    def __init__(self, clip_norm=1.0):
        super(ClipByGlobalNorm, self).__init__()
        self.global_norm = GlobalNorm()
        self.clip_norm = Tensor([clip_norm], mstype.float32)
        self.hyper_map = C.HyperMap()

    def construct(self, grads):
        global_norm = self.global_norm(grads)
        cond = P.GreaterEqual()(global_norm, self.clip_norm)
        global_norm = F.select(cond, global_norm, self.clip_norm)
        grads = self.hyper_map(F.partial(apply_global_norm, self.clip_norm, global_norm), grads)
        return grads


class LearningRate(LearningRateSchedule):
    """
    Warmup-decay learning rate for GPT network.
    """
    def __init__(self, learning_rate, end_learning_rate, warmup_steps, decay_steps, power=1.0, use_cosine=True):
        super(LearningRate, self).__init__()
        self.warmup_flag = False
        if warmup_steps > 0:
            self.warmup_flag = True
            self.warmup_lr = WarmUpLR(learning_rate, warmup_steps)
        self.decay_lr = PolynomialDecayLR(learning_rate, end_learning_rate, decay_steps, power)
        self.cosine_decay_lr = CosineDecayLR(end_learning_rate, learning_rate, decay_steps)
        self.warmup_steps = Tensor(np.array([warmup_steps]).astype(np.float32))

        self.greater = P.Greater()
        self.one = Tensor(np.array([1.0]).astype(np.float32))
        self.cast = P.Cast()
        self.use_cosine = use_cosine

    def construct(self, global_step):
        """dynamic learning rate"""
        if not self.use_cosine:
            decay_lr = self.decay_lr(global_step)
        else:
            decay_lr = self.cosine_decay_lr(global_step)
        if self.warmup_flag:
            is_warmup = self.cast(self.greater(self.warmup_steps, global_step), mstype.float32)
            warmup_lr = self.warmup_lr(global_step)
            lr = (self.one - is_warmup) * decay_lr + is_warmup * warmup_lr
        else:
            lr = decay_lr
        return lr
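
# The blend above is a branch-free select: is_warmup is 1.0 while
# global_step < warmup_steps and 0.0 afterwards, so
#   lr = (1 - is_warmup) * decay_lr + is_warmup * warmup_lr
# yields the linear warmup value first and the (cosine or polynomial)
# decay value once warmup ends.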
@@ -0,0 +1,133 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================

"""
GPT train script
"""


import os
import argparse
from mindspore import context
from mindspore.train.model import Model
import mindspore.communication.management as D
from mindspore.context import ParallelMode
import mindspore.nn as nn
from mindspore.train.callback import TimeMonitor, LossMonitor, ModelCheckpoint, CheckpointConfig
from mindspore.nn.wrap.loss_scale import DynamicLossScaleUpdateCell
import mindspore.common.dtype as mstype
from mindspore.common import set_seed
from src.dataset import create_dataset
from src.gpt import GPT, GPTWithLoss, CrossEntropyLoss
from src.gpt_wrapcell import GPTTrainOneStepWithLossScaleCell
from src.utils import GPTConfig, LearningRate

def run_train():
    """train function for GPT"""
    parser = argparse.ArgumentParser(description="GPT training")
    parser.add_argument("--device_id", type=int, default=0, help="Device id, default is 0.")
    parser.add_argument("--device_num", type=int, default=1, help="Use device nums, default is 1.")
    parser.add_argument("--distribute", type=str, default="false", choices=["true", "false"],
                        help="Run distributed training, default is false.")
    parser.add_argument("--optimizer", type=str, default="adam", choices=["adam", "lamb"],
                        help="Select which optimizer to use, default is adam.")
    parser.add_argument("--epoch_size", type=int, default=10, help="Epoch size, default is 10.")
    parser.add_argument("--warmup_step", type=int, default=10000, help="Warmup step, default is 10000.")
    parser.add_argument("--data_path", type=str, default="", help="Data path of your MindRecord files.")
    parser.add_argument("--start_lr", type=float, default=5e-5, help="Start learning rate, default is 5e-5.")
    parser.add_argument("--end_lr", type=float, default=1e-10, help="End learning rate, default is 1e-10.")
    parser.add_argument("--sink_size", type=int, default=100, help="Sink size for every iteration, default is 100.")

    args_opt = parser.parse_args()
    device_id = int(os.getenv("DEVICE_ID", str(args_opt.device_id)))
    context.set_context(mode=context.GRAPH_MODE, device_target="Ascend", device_id=device_id)
    if args_opt.distribute == "true":
        D.init()
        device_num = args_opt.device_num
        rank = device_id % device_num
        print("device_id is {}, rank_id is {}".format(device_id, rank))

        context.reset_auto_parallel_context()
        context.set_auto_parallel_context(parallel_mode=ParallelMode.DATA_PARALLEL, gradients_mean=True,
                                          device_num=device_num)

    else:
        rank = 0
        device_num = 1

    config = GPTConfig(batch_size=4,
                       seq_length=1024,
                       vocab_size=50257,
                       embedding_size=1024,
                       num_layers=24,
                       num_heads=16,
                       expand_ratio=4,
                       post_layernorm_residual=False,
                       dropout_rate=0.1,
                       compute_dtype=mstype.float16,
                       use_past=False)
    gpt = GPT(config)
    loss = CrossEntropyLoss(config)
    gpt_with_loss = GPTWithLoss(gpt, loss)

    ds = create_dataset(config.batch_size, data_path=args_opt.data_path, device_num=device_num, rank=rank)

    epoch_num = args_opt.epoch_size
    step_per_epoch = ds.get_dataset_size()

    lr = LearningRate(learning_rate=args_opt.start_lr,
                      end_learning_rate=args_opt.end_lr,
                      warmup_steps=args_opt.warmup_step,
                      decay_steps=epoch_num * step_per_epoch)

    decay_filter = lambda x: 'layernorm' not in x.name.lower() and "bias" not in x.name.lower()
    params = gpt.trainable_params()
    decay_params = list(filter(decay_filter, params))
    other_params = list(filter(lambda x: not decay_filter(x), params))
    group_params = [{'params': decay_params, 'weight_decay': 1e-2},
                    {'params': other_params, 'weight_decay': 0.0},
                    {'order_params': params}]
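
    # Parameter grouping above: weight decay (1e-2) is applied only to
    # parameters whose names contain neither "layernorm" nor "bias"; the
    # rest get weight_decay=0.0, and 'order_params' keeps the optimizer's
    # parameter order aligned with gpt.trainable_params().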

    if args_opt.optimizer == "lamb":
        optimizer = nn.Lamb(group_params, learning_rate=lr)
    else:
        optimizer = nn.AdamWeightDecay(group_params, learning_rate=lr)

    callback_size = args_opt.sink_size
    actual_epoch_num = int(epoch_num * step_per_epoch / callback_size)
    callback = [TimeMonitor(callback_size), LossMonitor(callback_size)]

    config_ck = CheckpointConfig(save_checkpoint_steps=step_per_epoch, keep_checkpoint_max=1)
    ckpoint_cb = ModelCheckpoint(prefix="GPT2", config=config_ck)
    callback.append(ckpoint_cb)

    update_cell = DynamicLossScaleUpdateCell(loss_scale_value=1024,
                                             scale_factor=2,
                                             scale_window=1000)

    gpt_with_grads = GPTTrainOneStepWithLossScaleCell(gpt_with_loss, optimizer=optimizer,
                                                      scale_update_cell=update_cell)

    model = Model(gpt_with_grads)
    model.train(actual_epoch_num, ds, callbacks=callback, sink_size=callback_size)


if __name__ == "__main__":
    set_seed(12315)
    run_train()