|
|
|
@ -487,9 +487,7 @@ class BertTrainAccumulateStepsWithLossScaleCell(nn.Cell):
|
|
|
|
|
self.accu_overflow = Parameter(initializer(0, [1], mstype.int32), name="accu_overflow")
|
|
|
|
|
self.loss = Parameter(initializer(0, [1], mstype.float32), name="accu_loss")
|
|
|
|
|
|
|
|
|
|
self.grad = C.GradOperation('grad',
|
|
|
|
|
get_by_list=True,
|
|
|
|
|
sens_param=True)
|
|
|
|
|
self.grad = C.GradOperation(get_by_list=True, sens_param=True)
|
|
|
|
|
self.reducer_flag = False
|
|
|
|
|
self.parallel_mode = context.get_auto_parallel_context("parallel_mode")
|
|
|
|
|
if self.parallel_mode in [ParallelMode.DATA_PARALLEL, ParallelMode.HYBRID_PARALLEL]:
|
|
|
|
|