!298 bugfix(side effect): fix adding wrong control depend between AllReduce and GetStatus

Merge pull request !298 from gongchen/fix_allreduce_control_depend
pull/298/MERGE
mindspore-ci-bot 5 years ago committed by Gitee
commit e09364df48

@ -370,7 +370,7 @@ class BertTrainOneStepWithLossScaleCell(nn.Cell):
self.parallel_mode = context.get_auto_parallel_context("parallel_mode")
if self.parallel_mode in [ParallelMode.DATA_PARALLEL, ParallelMode.HYBRID_PARALLEL]:
self.reducer_flag = True
self.grad_reducer = None
self.grad_reducer = F.identity
if self.reducer_flag:
mean = context.get_auto_parallel_context("mirror_mean")
degree = get_group_size()
@ -428,9 +428,8 @@ class BertTrainOneStepWithLossScaleCell(nn.Cell):
mstype.float32))
grads = self.hyper_map(F.partial(grad_scale, scaling_sens), grads)
grads = self.clip_gradients(grads, GRADIENT_CLIP_TYPE, GRADIENT_CLIP_VALUE)
if self.reducer_flag:
# apply grad reducer on grads
grads = self.grad_reducer(grads)
# apply grad reducer on grads
grads = self.grad_reducer(grads)
self.get_status(init)
flag_sum = self.reduce_sum(init, (0,))
if self.is_distributed:

@ -220,7 +220,7 @@ class TrainOneStepWithLossScaleCell(Cell):
self.depend_parameter_use = ControlDepend(depend_mode=1)
self.allreduce = P.AllReduce()
self.parallel_mode = _get_parallel_mode()
self.grad_reducer = None
self.grad_reducer = F.identity
self.reducer_flag = self.parallel_mode in [ParallelMode.DATA_PARALLEL, ParallelMode.HYBRID_PARALLEL]
if self.reducer_flag:
mean = _get_mirror_mean()
@ -250,9 +250,8 @@ class TrainOneStepWithLossScaleCell(Cell):
scaling_sens = sens
grads = self.grad(self.network, weights)(data, label, F.cast(scaling_sens, F.dtype(loss)))
grads = self.hyper_map(F.partial(_grad_scale, scaling_sens), grads)
if self.reducer_flag:
# apply grad reducer on grads
grads = self.grad_reducer(grads)
# apply grad reducer on grads
grads = self.grad_reducer(grads)
# get the overflow buffer
if not self.gpu_target:
self.get_status(init)

Loading…
Cancel
Save