@@ -440,6 +440,120 @@ class BertTrainOneStepWithLossScaleCell(nn.Cell):
        ret = (loss, cond, scaling_sens)
        return F.depend(ret, succ)


class BertTrainOneStepWithLossScaleCellForAdam(nn.Cell):
    """
    Encapsulation class of bert network training.

    Appends an optimizer to the training network. After that, the construct
    function can be called to create the backward graph.

    Different from BertTrainOneStepWithLossScaleCell, the optimizer takes the overflow
    condition as input.

    Args:
        network (Cell): The training network. Note that the loss function should have been added.
        optimizer (Optimizer): Optimizer for updating the weights.
        scale_update_cell (Cell): Cell to do the loss scale. Default: None.
    """
    def __init__(self, network, optimizer, scale_update_cell=None):
        super(BertTrainOneStepWithLossScaleCellForAdam, self).__init__(auto_prefix=False)
        self.network = network
        self.network.set_grad()
        self.weights = optimizer.parameters
        self.optimizer = optimizer
        self.grad = C.GradOperation(get_by_list=True,
                                    sens_param=True)
        self.reducer_flag = False
        self.allreduce = P.AllReduce()
        self.parallel_mode = context.get_auto_parallel_context("parallel_mode")
        if self.parallel_mode in [ParallelMode.DATA_PARALLEL, ParallelMode.HYBRID_PARALLEL]:
            self.reducer_flag = True
        self.grad_reducer = F.identity
        self.degree = 1
        if self.reducer_flag:
            self.degree = get_group_size()
            self.grad_reducer = DistributedGradReducer(optimizer.parameters, False, self.degree)
        self.is_distributed = (self.parallel_mode != ParallelMode.STAND_ALONE)
        self.cast = P.Cast()
        if context.get_context("device_target") == "GPU":
            self.gpu_target = True
            self.float_status = P.FloatStatus()
            self.addn = P.AddN()
            self.reshape = P.Reshape()
        else:
            self.gpu_target = False
            self.alloc_status = P.NPUAllocFloatStatus()
            self.get_status = P.NPUGetFloatStatus()
            self.clear_before_grad = P.NPUClearFloatStatus()
        self.reduce_sum = P.ReduceSum(keep_dims=False)
        self.depend_parameter_use = P.ControlDepend(depend_mode=1)
        self.base = Tensor(1, mstype.float32)
        self.less_equal = P.LessEqual()
        self.hyper_map = C.HyperMap()
        self.loss_scale = None
        self.loss_scaling_manager = scale_update_cell
        if scale_update_cell:
            self.loss_scale = Parameter(Tensor(scale_update_cell.get_loss_scale(), dtype=mstype.float32))
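    # Overflow handling in construct below: on Ascend the NPU float-status ops
    # (NPUAllocFloatStatus / NPUClearFloatStatus / NPUGetFloatStatus) bracket the backward
    # pass and expose the hardware overflow registers, while on GPU _grad_overflow is
    # mapped over every gradient and the per-gradient flags are summed with AddN. In
    # distributed mode the flag is AllReduced, so an overflow on any device marks the
    # whole step; the resulting flag (optionally refined by scale_update_cell) is passed
    # straight to the optimizer rather than being used to gate the update here.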
    @C.add_flags(has_effect=True)
    def construct(self,
                  input_ids,
                  input_mask,
                  token_type_id,
                  next_sentence_labels,
                  masked_lm_positions,
                  masked_lm_ids,
                  masked_lm_weights,
                  sens=None):
        """Defines the computation performed."""
        weights = self.weights
        loss = self.network(input_ids,
                            input_mask,
                            token_type_id,
                            next_sentence_labels,
                            masked_lm_positions,
                            masked_lm_ids,
                            masked_lm_weights)
        if sens is None:
            scaling_sens = self.loss_scale
        else:
            scaling_sens = sens
        init = False
        if not self.gpu_target:
            # alloc status and clear should be right before gradoperation
            init = self.alloc_status()
            self.clear_before_grad(init)
        grads = self.grad(self.network, weights)(input_ids,
                                                 input_mask,
                                                 token_type_id,
                                                 next_sentence_labels,
                                                 masked_lm_positions,
                                                 masked_lm_ids,
                                                 masked_lm_weights,
                                                 self.cast(scaling_sens,
                                                           mstype.float32))
        # apply grad reducer on grads
        grads = self.grad_reducer(grads)
        grads = self.hyper_map(F.partial(grad_scale, scaling_sens * self.degree), grads)
        grads = self.hyper_map(F.partial(clip_grad, GRADIENT_CLIP_TYPE, GRADIENT_CLIP_VALUE), grads)
        if not self.gpu_target:
            self.get_status(init)
            flag_sum = self.reduce_sum(init, (0,))
        else:
            flag_sum = self.hyper_map(F.partial(_grad_overflow), grads)
            flag_sum = self.addn(flag_sum)
            flag_sum = self.reshape(flag_sum, (()))
        if self.is_distributed:
            # sum overflow flag over devices
            flag_reduce = self.allreduce(flag_sum)
            cond = self.less_equal(self.base, flag_reduce)
        else:
            cond = self.less_equal(self.base, flag_sum)
        overflow = cond
        if self.loss_scaling_manager is not None:
            overflow = self.loss_scaling_manager(scaling_sens, cond)
        succ = self.optimizer(grads, overflow)
        ret = (loss, cond, scaling_sens)
        return F.depend(ret, succ)
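

# Usage sketch (illustrative only, not part of the original script): this cell expects an
# optimizer whose own construct accepts the overflow flag as a second input. The names
# BertNetworkWithLoss, AdamWeightDecayForBert and bert_net_cfg are assumed to come from
# the surrounding project; nn.DynamicLossScaleUpdateCell is the stock MindSpore cell.
#
#   net_with_loss = BertNetworkWithLoss(bert_net_cfg, True)
#   optimizer = AdamWeightDecayForBert(net_with_loss.trainable_params(), learning_rate=5e-5)
#   update_cell = nn.DynamicLossScaleUpdateCell(loss_scale_value=2 ** 32,
#                                               scale_factor=2, scale_window=1000)
#   train_cell = BertTrainOneStepWithLossScaleCellForAdam(net_with_loss, optimizer, update_cell)
#   train_cell.set_train()
#   loss, overflow, loss_scale = train_cell(input_ids, input_mask, token_type_id,
#                                           next_sentence_labels, masked_lm_positions,
#                                           masked_lm_ids, masked_lm_weights)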


cast = P.Cast()
update_accu_grads = C.MultitypeFuncGraph("update_accu_grads")