update control flow in adamweightdecay for bert

pull/10886/head
VectorSL 5 years ago
parent da7ce4a2e9
commit 0c97835662

akg

@@ -1 +1 @@
Subproject commit ae997e27b217d6c8c7a6cbf6ef812186835d2bdf
Subproject commit f4f118a2debd2eacc3f2ab6dc31846f1e04d6e13

@@ -88,7 +88,6 @@ __global__ void IsFinite(const size_t size, const half* input, bool* out) {
template <typename T>
__global__ void FloatStatus(const size_t size, const T* input, T* out) {
  out[0] = 0;
  for (size_t pos = blockIdx.x * blockDim.x + threadIdx.x; pos < (size); pos += blockDim.x * gridDim.x) {
    if (isinf(input[pos]) != 0 || isnan(input[pos])) {
      out[0] = 1;
@@ -98,7 +97,6 @@ __global__ void FloatStatus(const size_t size, const T* input, T* out) {
}
template <>
__global__ void FloatStatus(const size_t size, const half* input, half* out) {
  out[0] = 0;
  for (size_t pos = blockIdx.x * blockDim.x + threadIdx.x; pos < (size); pos += blockDim.x * gridDim.x) {
    if (__hisinf(input[pos]) != 0 || __hisnan(input[pos])) {
      out[0] = 1;

@@ -24,6 +24,7 @@
#include "backend/kernel_compiler/gpu/gpu_kernel.h"
#include "backend/kernel_compiler/gpu/gpu_kernel_factory.h"
#include "backend/kernel_compiler/gpu/cuda_impl/float_status_impl.cuh"
#include "backend/kernel_compiler/gpu/cuda_impl/slice_impl.cuh"
namespace mindspore {
namespace kernel {
@@ -46,6 +47,7 @@ class FloatStatusGpuKernel : public GpuKernel {
    switch (kernel_name_) {
      case OP_STATUS: {
        T *output = GetDeviceAddress<T>(outputs, 0);
        FillDeviceArray(outputs[0]->size / sizeof(T), output, 0.0f, reinterpret_cast<cudaStream_t>(stream_ptr));
        CalFloatStatus(input_size_ / sizeof(T), input, output, reinterpret_cast<cudaStream_t>(stream_ptr));
        break;
      }
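The two GPU hunks above make one change: the unconditional `out[0] = 0` write is removed from the FloatStatus kernels, and the output buffer is instead zero-filled on the host side (FillDeviceArray) before CalFloatStatus is launched, presumably so a late-scheduled thread block can no longer overwrite a flag of 1 with 0. A minimal NumPy sketch of the intended semantics; the function name and the single-element flag shape are illustrative assumptions:

import numpy as np

def float_status_reference(x):
    # Reference semantics of the FloatStatus op: a one-element flag that is 1.0 if the
    # input contains any inf/nan and 0.0 otherwise. The buffer is zeroed up front
    # (as FillDeviceArray now does on device) and the check itself only ever writes 1.
    out = np.zeros(1, dtype=np.float32)
    if np.isinf(x).any() or np.isnan(x).any():
        out[0] = 1.0
    return out

print(float_status_reference(np.array([1.0, float("inf"), 3.0])))  # [1.]
print(float_status_reference(np.array([1.0, 2.0, 3.0])))           # [0.]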

@@ -32,7 +32,8 @@ from mindspore.nn.optim import Lamb, Momentum, AdamWeightDecay
from mindspore import log as logger
from mindspore.common import set_seed
from src import BertNetworkWithLoss, BertTrainOneStepCell, BertTrainOneStepWithLossScaleCell, \
    BertTrainAccumulateStepsWithLossScaleCell
    BertTrainAccumulateStepsWithLossScaleCell, BertTrainOneStepWithLossScaleCellForAdam, \
    AdamWeightDecayForBert
from src.dataset import create_bert_dataset
from src.config import cfg, bert_net_cfg
from src.utils import LossCallBack, BertLearningRate
@@ -83,8 +84,10 @@ def _get_optimizer(args_opt, network):
        group_params = [{'params': decay_params, 'weight_decay': cfg.AdamWeightDecay.weight_decay},
                        {'params': other_params, 'weight_decay': 0.0},
                        {'order_params': params}]
        optimizer = AdamWeightDecay(group_params, learning_rate=lr_schedule, eps=cfg.AdamWeightDecay.eps)
        if args_opt.enable_lossscale == "true":
            optimizer = AdamWeightDecayForBert(group_params, learning_rate=lr_schedule, eps=cfg.AdamWeightDecay.eps)
        else:
            optimizer = AdamWeightDecay(group_params, learning_rate=lr_schedule, eps=cfg.AdamWeightDecay.eps)
    else:
        raise ValueError("Don't support optimizer {}, only support [Lamb, Momentum, AdamWeightDecay]".
                         format(cfg.optimizer))
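The choice depends on enable_lossscale because the stock AdamWeightDecay is invoked as optimizer(grads), while AdamWeightDecayForBert also receives the overflow flag produced by the loss-scale wrapper (see self.optimizer(grads, overflow) in the new cell below), so it is only useful when loss scaling is enabled. A hedged sketch of the two calling conventions; run_update and its arguments are illustrative, not part of this commit:

def run_update(optimizer, grads, overflow, overflow_aware):
    # Illustrative only: how a wrapper cell would invoke each optimizer variant.
    if overflow_aware:
        # AdamWeightDecayForBert path: the optimizer sees the overflow flag and can
        # skip the parameter update inside its own graph.
        return optimizer(grads, overflow)
    # Plain AdamWeightDecay path: the caller decides whether to apply the update.
    return optimizer(grads)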
@@ -206,8 +209,12 @@ def run_pretrain():
                                                 scale_window=cfg.scale_window)
        if args_opt.accumulation_steps <= 1:
            net_with_grads = BertTrainOneStepWithLossScaleCell(net_with_loss, optimizer=optimizer,
                                                               scale_update_cell=update_cell)
            if cfg.optimizer == 'AdamWeightDecay':
                net_with_grads = BertTrainOneStepWithLossScaleCellForAdam(net_with_loss, optimizer=optimizer,
                                                                          scale_update_cell=update_cell)
            else:
                net_with_grads = BertTrainOneStepWithLossScaleCell(net_with_loss, optimizer=optimizer,
                                                                   scale_update_cell=update_cell)
        else:
            accumulation_steps = args_opt.accumulation_steps
            net_with_grads = BertTrainAccumulateStepsWithLossScaleCell(net_with_loss, optimizer=optimizer,
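Only the AdamWeightDecay configuration is routed to the new cell; Lamb and Momentum keep BertTrainOneStepWithLossScaleCell. In plain Python, the shift in control flow looks roughly as follows (both helpers are illustrative, and the pre-existing cell's exact branch is not shown in this diff):

def step_existing_cell(optimizer, grads, overflow):
    # Presumed pre-existing pattern: the wrapper cell branches on overflow and only
    # calls the optimizer when the gradients are finite.
    succ = False
    if not overflow:
        succ = optimizer(grads)
    return succ

def step_for_adam_cell(optimizer, grads, overflow):
    # New pattern: the optimizer is always called and receives the overflow flag,
    # so the skip-the-update decision moves into the optimizer's own graph.
    return optimizer(grads, overflow)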

@@ -21,13 +21,13 @@ from .bert_model import BertAttention, BertConfig, BertEncoderCell, BertModel, \
    BertOutput, BertSelfAttention, BertTransformer, EmbeddingLookup, \
    EmbeddingPostprocessor, RelaPosEmbeddingsGenerator, RelaPosMatrixGenerator, \
    SaturateCast, CreateAttentionMaskFromInputMask
from .adam import AdamWeightDecayForBert
__all__ = [
    "BertNetworkWithLoss", "BertPreTraining", "BertPretrainingLoss",
    "GetMaskedLMOutput", "GetNextSentenceOutput", "BertTrainOneStepCell",
    "BertTrainOneStepWithLossScaleCell", "BertTrainAccumulateStepsWithLossScaleCell",
    "BertAttention", "BertConfig", "BertEncoderCell", "BertModel", "BertOutput",
    "BertSelfAttention", "BertTransformer", "EmbeddingLookup",
    "EmbeddingPostprocessor", "RelaPosEmbeddingsGenerator", "AdamWeightDecayForBert",
    "RelaPosMatrixGenerator", "SaturateCast", "CreateAttentionMaskFromInputMask"
]

File diff suppressed because it is too large
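The suppressed file is presumably src/adam.py, which provides the AdamWeightDecayForBert imported below; its contents are not shown here. As a rough, hypothetical sketch of the control-flow idea the commit title refers to (every detail below is an assumption, not the actual implementation):

import mindspore.nn as nn

class OverflowAwareAdamWeightDecaySketch(nn.Cell):
    # Hypothetical stand-in for AdamWeightDecayForBert: wraps a regular optimizer and
    # skips the parameter update when the loss-scale wrapper reports an overflow.
    def __init__(self, inner_optimizer):
        super().__init__()
        self.inner = inner_optimizer  # e.g. a regular nn.AdamWeightDecay instance

    def construct(self, gradients, overflow):
        succ = True
        if not overflow:  # graph-mode control flow on the overflow condition
            succ = self.inner(gradients)
        return succ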

@@ -440,6 +440,120 @@ class BertTrainOneStepWithLossScaleCell(nn.Cell):
        ret = (loss, cond, scaling_sens)
        return F.depend(ret, succ)

class BertTrainOneStepWithLossScaleCellForAdam(nn.Cell):
    """
    Encapsulation class of bert network training.

    Append an optimizer to the training network after that the construct
    function can be called to create the backward graph.

    Different from BertTrainOneStepWithLossScaleCell, the optimizer takes the overflow
    condition as input.

    Args:
        network (Cell): The training network. Note that loss function should have been added.
        optimizer (Optimizer): Optimizer for updating the weights.
        scale_update_cell (Cell): Cell to do the loss scale. Default: None.
    """
    def __init__(self, network, optimizer, scale_update_cell=None):
        super(BertTrainOneStepWithLossScaleCellForAdam, self).__init__(auto_prefix=False)
        self.network = network
        self.network.set_grad()
        self.weights = optimizer.parameters
        self.optimizer = optimizer
        self.grad = C.GradOperation(get_by_list=True,
                                    sens_param=True)
        self.reducer_flag = False
        self.allreduce = P.AllReduce()
        self.parallel_mode = context.get_auto_parallel_context("parallel_mode")
        if self.parallel_mode in [ParallelMode.DATA_PARALLEL, ParallelMode.HYBRID_PARALLEL]:
            self.reducer_flag = True
        self.grad_reducer = F.identity
        self.degree = 1
        if self.reducer_flag:
            self.degree = get_group_size()
            self.grad_reducer = DistributedGradReducer(optimizer.parameters, False, self.degree)
        self.is_distributed = (self.parallel_mode != ParallelMode.STAND_ALONE)
        self.cast = P.Cast()
        if context.get_context("device_target") == "GPU":
            self.gpu_target = True
            self.float_status = P.FloatStatus()
            self.addn = P.AddN()
            self.reshape = P.Reshape()
        else:
            self.gpu_target = False
            self.alloc_status = P.NPUAllocFloatStatus()
            self.get_status = P.NPUGetFloatStatus()
            self.clear_before_grad = P.NPUClearFloatStatus()
        self.reduce_sum = P.ReduceSum(keep_dims=False)
        self.depend_parameter_use = P.ControlDepend(depend_mode=1)
        self.base = Tensor(1, mstype.float32)
        self.less_equal = P.LessEqual()
        self.hyper_map = C.HyperMap()
        self.loss_scale = None
        self.loss_scaling_manager = scale_update_cell
        if scale_update_cell:
            self.loss_scale = Parameter(Tensor(scale_update_cell.get_loss_scale(), dtype=mstype.float32))

    @C.add_flags(has_effect=True)
    def construct(self,
                  input_ids,
                  input_mask,
                  token_type_id,
                  next_sentence_labels,
                  masked_lm_positions,
                  masked_lm_ids,
                  masked_lm_weights,
                  sens=None):
        """Defines the computation performed."""
        weights = self.weights
        loss = self.network(input_ids,
                            input_mask,
                            token_type_id,
                            next_sentence_labels,
                            masked_lm_positions,
                            masked_lm_ids,
                            masked_lm_weights)
        if sens is None:
            scaling_sens = self.loss_scale
        else:
            scaling_sens = sens
        init = False
        if not self.gpu_target:
            # alloc status and clear should be right before gradoperation
            init = self.alloc_status()
            self.clear_before_grad(init)
        grads = self.grad(self.network, weights)(input_ids,
                                                 input_mask,
                                                 token_type_id,
                                                 next_sentence_labels,
                                                 masked_lm_positions,
                                                 masked_lm_ids,
                                                 masked_lm_weights,
                                                 self.cast(scaling_sens,
                                                           mstype.float32))
        # apply grad reducer on grads
        grads = self.grad_reducer(grads)
        grads = self.hyper_map(F.partial(grad_scale, scaling_sens * self.degree), grads)
        grads = self.hyper_map(F.partial(clip_grad, GRADIENT_CLIP_TYPE, GRADIENT_CLIP_VALUE), grads)
        if not self.gpu_target:
            self.get_status(init)
            flag_sum = self.reduce_sum(init, (0,))
        else:
            flag_sum = self.hyper_map(F.partial(_grad_overflow), grads)
            flag_sum = self.addn(flag_sum)
            flag_sum = self.reshape(flag_sum, (()))
        if self.is_distributed:
            # sum overflow flag over devices
            flag_reduce = self.allreduce(flag_sum)
            cond = self.less_equal(self.base, flag_reduce)
        else:
            cond = self.less_equal(self.base, flag_sum)
        overflow = cond
        if self.loss_scaling_manager is not None:
            overflow = self.loss_scaling_manager(scaling_sens, cond)
        succ = self.optimizer(grads, overflow)
        ret = (loss, cond, scaling_sens)
        return F.depend(ret, succ)
cast = P.Cast()
update_accu_grads = C.MultitypeFuncGraph("update_accu_grads")
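Putting the pieces together, a minimal usage sketch of the new cell, mirroring the wiring in run_pretrain.py above; the learning rate, the loss-scale constants, and the use of nn.DynamicLossScaleUpdateCell are placeholders rather than values taken from this commit:

import mindspore.nn as nn
from src import BertNetworkWithLoss, BertTrainOneStepWithLossScaleCellForAdam, AdamWeightDecayForBert
from src.config import bert_net_cfg

# Network with loss, as constructed in run_pretrain.py (True = training mode).
net_with_loss = BertNetworkWithLoss(bert_net_cfg, True)

# Overflow-aware optimizer variant added by this commit.
optimizer = AdamWeightDecayForBert(net_with_loss.trainable_params(),
                                   learning_rate=1e-5, eps=1e-6)

# Dynamic loss-scale update cell; the constants here are placeholders.
update_cell = nn.DynamicLossScaleUpdateCell(loss_scale_value=2 ** 32,
                                            scale_factor=2, scale_window=1000)

# The training cell forwards the overflow condition into the optimizer.
net_with_grads = BertTrainOneStepWithLossScaleCellForAdam(net_with_loss, optimizer=optimizer,
                                                          scale_update_cell=update_cell)

Compared with BertTrainOneStepWithLossScaleCell, the intended behavioural difference is only where the overflow flag is consumed: inside the optimizer rather than in the wrapper cell.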
