allreduce add ps filter

pull/3153/head
jinyaohui 5 years ago
parent bbcefa731d
commit 59e519e8ed

@ -50,8 +50,8 @@ def _init_allreduce_operators(length):
return opt_list return opt_list
@reduce_opt.register("Number", "Bool", "Function", "Bool", "Tensor", "Function") @reduce_opt.register("Number", "Bool", "Function", "Bool", "Tensor", "Function", "Bool")
def _tensors_allreduce(degree, mean, allgather, allreduce_filter, grad, allreduce): def _tensors_allreduce(degree, mean, allgather, allreduce_filter, grad, allreduce, ps_parameter):
""" """
Apply allreduce on gradient. Apply allreduce on gradient.
@ -66,7 +66,7 @@ def _tensors_allreduce(degree, mean, allgather, allreduce_filter, grad, allreduc
Returns: Returns:
Tensor, the gradient tensor after operation. Tensor, the gradient tensor after operation.
""" """
if allreduce_filter: if not ps_parameter and allreduce_filter:
grad = allreduce(grad) grad = allreduce(grad)
if mean: if mean:
degree = F.scalar_cast(degree, F.dtype(grad)) degree = F.scalar_cast(degree, F.dtype(grad))
@ -257,6 +257,8 @@ class DistributedGradReducer(Cell):
self.allreduce_filter = tuple(x.layerwise_parallel is False for x in parameters) self.allreduce_filter = tuple(x.layerwise_parallel is False for x in parameters)
self.opt_list = _init_allreduce_operators(len(parameters)) self.opt_list = _init_allreduce_operators(len(parameters))
self.allgather = AllGather(GlobalComm.WORLD_COMM_GROUP) self.allgather = AllGather(GlobalComm.WORLD_COMM_GROUP)
ps_filter = lambda x: x.is_param_ps
self.ps_parameters = tuple(ps_filter(x) for x in parameters)
def construct(self, grads): def construct(self, grads):
""" """
@ -273,7 +275,7 @@ class DistributedGradReducer(Cell):
datatypes = self.map_(F.partial(_get_datatype), grads) datatypes = self.map_(F.partial(_get_datatype), grads)
grads = self.map_(F.partial(_cast_datatype, mstype.float32), grads) grads = self.map_(F.partial(_cast_datatype, mstype.float32), grads)
new_grad = self.map_(F.partial(reduce_opt, self.degree, self.mean, self.allgather), new_grad = self.map_(F.partial(reduce_opt, self.degree, self.mean, self.allgather),
self.allreduce_filter, grads, self.opt_list) self.allreduce_filter, grads, self.opt_list, self.ps_parameters)
new_grad = self.map_(F.partial(_cast_datatype), datatypes, new_grad) new_grad = self.map_(F.partial(_cast_datatype), datatypes, new_grad)
return new_grad return new_grad

Loading…
Cancel
Save