diff --git a/mindspore/nn/wrap/grad_reducer.py b/mindspore/nn/wrap/grad_reducer.py index 3d754977d4..77a55f69bf 100644 --- a/mindspore/nn/wrap/grad_reducer.py +++ b/mindspore/nn/wrap/grad_reducer.py @@ -50,8 +50,8 @@ def _init_allreduce_operators(length): return opt_list -@reduce_opt.register("Number", "Bool", "Function", "Bool", "Tensor", "Function") -def _tensors_allreduce(degree, mean, allgather, allreduce_filter, grad, allreduce): +@reduce_opt.register("Number", "Bool", "Function", "Bool", "Tensor", "Function", "Bool") +def _tensors_allreduce(degree, mean, allgather, allreduce_filter, grad, allreduce, ps_parameter): """ Apply allreduce on gradient. @@ -66,7 +66,7 @@ def _tensors_allreduce(degree, mean, allgather, allreduce_filter, grad, allreduc Returns: Tensor, the gradient tensor after operation. """ - if allreduce_filter: + if not ps_parameter and allreduce_filter: grad = allreduce(grad) if mean: degree = F.scalar_cast(degree, F.dtype(grad)) @@ -257,6 +257,8 @@ class DistributedGradReducer(Cell): self.allreduce_filter = tuple(x.layerwise_parallel is False for x in parameters) self.opt_list = _init_allreduce_operators(len(parameters)) self.allgather = AllGather(GlobalComm.WORLD_COMM_GROUP) + ps_filter = lambda x: x.is_param_ps + self.ps_parameters = tuple(ps_filter(x) for x in parameters) def construct(self, grads): """ @@ -273,7 +275,7 @@ class DistributedGradReducer(Cell): datatypes = self.map_(F.partial(_get_datatype), grads) grads = self.map_(F.partial(_cast_datatype, mstype.float32), grads) new_grad = self.map_(F.partial(reduce_opt, self.degree, self.mean, self.allgather), - self.allreduce_filter, grads, self.opt_list) + self.allreduce_filter, grads, self.opt_list, self.ps_parameters) new_grad = self.map_(F.partial(_cast_datatype), datatypes, new_grad) return new_grad