@@ -220,7 +220,7 @@ class TrainOneStepWithLossScaleCell(Cell):
         self.depend_parameter_use = ControlDepend(depend_mode=1)
         self.allreduce = P.AllReduce()
         self.parallel_mode = _get_parallel_mode()
-        self.grad_reducer = None
+        self.grad_reducer = F.identity
         self.reducer_flag = self.parallel_mode in [ParallelMode.DATA_PARALLEL, ParallelMode.HYBRID_PARALLEL]
         if self.reducer_flag:
             mean = _get_mirror_mean()
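
The hunk above defaults the gradient reducer to F.identity instead of None, so later code can call it without first checking reducer_flag. Below is a minimal, framework-free sketch of that pattern; TrainStepSketch and fake_all_reduce_mean are illustrative stand-ins, not MindSpore APIs.

    # Minimal sketch, assuming nothing beyond plain Python. The point of the
    # hunk above: default the reducer to a no-op identity instead of None so
    # callers can invoke it unconditionally.

    def identity(grads):
        """No-op reducer used when running on a single device."""
        return grads

    def fake_all_reduce_mean(grads, world_size=1):
        """Illustrative stand-in for a distributed reducer (not a MindSpore API)."""
        # A real data-parallel reducer would AllReduce across devices and average.
        return tuple(g / world_size for g in grads)

    class TrainStepSketch:
        def __init__(self, data_parallel):
            # Default to the no-op reducer; swap in the real one only when needed.
            self.grad_reducer = identity
            if data_parallel:
                self.grad_reducer = fake_all_reduce_mean

        def step(self, grads):
            # No `if self.reducer_flag:` branch is needed at the call site.
            return self.grad_reducer(grads)

    print(TrainStepSketch(data_parallel=False).step((0.1, 0.2)))  # (0.1, 0.2)
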
@@ -250,9 +250,8 @@ class TrainOneStepWithLossScaleCell(Cell):
         scaling_sens = sens
         grads = self.grad(self.network, weights)(data, label, F.cast(scaling_sens, F.dtype(loss)))
         grads = self.hyper_map(F.partial(_grad_scale, scaling_sens), grads)
-        if self.reducer_flag:
-            # apply grad reducer on grads
-            grads = self.grad_reducer(grads)
+        # apply grad reducer on grads
+        grads = self.grad_reducer(grads)
         # get the overflow buffer
         if not self.gpu_target:
             self.get_status(init)
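
The second hunk drops the reducer_flag guard in construct and applies the reducer unconditionally, which only works because the reducer now defaults to an identity. A short framework-free sketch of that gradient path follows, under the assumption that _grad_scale divides each gradient by the loss scale; grad_scale, hyper_map_sketch, and one_step here are illustrative names, not MindSpore APIs.

    from functools import partial

    def grad_scale(scale, grad):
        # Assumed behaviour of _grad_scale: undo the loss scaling applied before backprop.
        return grad / scale

    def hyper_map_sketch(fn, grads):
        # Apply fn over every gradient in the tuple, like F.hyper_map.
        return tuple(fn(g) for g in grads)

    def one_step(raw_grads, scaling_sens, grad_reducer):
        # 1. Unscale gradients that were computed from the scaled loss.
        grads = hyper_map_sketch(partial(grad_scale, scaling_sens), raw_grads)
        # 2. Reduce unconditionally; on a single device grad_reducer is an identity.
        return grad_reducer(grads)

    print(one_step((128.0, 256.0), 128.0, lambda g: g))  # -> (1.0, 2.0)
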