!3222 add ps filter

Merge pull request !3222 from jinyaohui/master
pull/3222/MERGE
mindspore-ci-bot authored 5 years ago, committed by Gitee
commit 50a784bffe

@@ -57,12 +57,15 @@ def _tensors_allreduce(degree, mean, allgather, allreduce, allreduce_filter, gra
        allreduce (Primitive): The communication operator for gradients.
        allreduce_filter (bool): When it is true, allreduce would apply.
        grad (Tensor): The gradient tensor before operation.
-        ps_parameter(Bool): Use parameter server or not.
+        ps_parameter (bool): Use parameter server or not.
    Returns:
        Tensor, the gradient tensor after operation.
    """
-    if not ps_parameter and allreduce_filter:
+    if ps_parameter:
+        return grad
+    if allreduce_filter:
        grad = allreduce(grad)
        if mean:
            degree = F.scalar_cast(degree, F.dtype(grad))
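
For reference, here is a minimal stand-alone sketch of what the dense path above does after this change: gradients of parameters hosted on the parameter server skip the collective entirely, while everything else goes through allreduce and optional averaging. The fake_allreduce helper, the plain-list "tensors", and the two-worker setup are illustrative stand-ins, not MindSpore APIs.

# Sketch only: mimics the control flow of _tensors_allreduce with plain Python.
def fake_allreduce(grad):
    # Stand-in for the AllReduce primitive: pretend the sum across 2 workers
    # happens to be twice the local gradient.
    return [g * 2 for g in grad]

def tensors_allreduce_sketch(degree, mean, allreduce, allreduce_filter, grad, ps_parameter):
    if ps_parameter:
        # Parameter-server parameters are updated on the server, so their
        # gradients bypass the collective communication.
        return grad
    if allreduce_filter:
        grad = allreduce(grad)
        if mean:
            # Divide the summed gradient by the device count (degree).
            grad = [g / degree for g in grad]
    return grad

local_grad = [1.0, 2.0, 3.0]
print(tensors_allreduce_sketch(2, False, fake_allreduce, True, local_grad, ps_parameter=False))
# [2.0, 4.0, 6.0] - reduced across workers
print(tensors_allreduce_sketch(2, False, fake_allreduce, True, local_grad, ps_parameter=True))
# [1.0, 2.0, 3.0] - untouched: the PS filter returns the gradient as-is
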
@@ -73,8 +76,8 @@ def _tensors_allreduce(degree, mean, allgather, allreduce, allreduce_filter, gra
    return grad
-@reduce_opt.register("Number", "Bool", "Function", "Function", "Bool", "IndexedSlices")
-def _tensors_allreduce_with_sparse(degree, mean, allgather, allreduce, allreduce_filter, grad):
+@reduce_opt.register("Number", "Bool", "Function", "Function", "Bool", "IndexedSlices", "Bool")
+def _tensors_allreduce_with_sparse(degree, mean, allgather, allreduce, allreduce_filter, grad, ps_parameter):
    """
    Apply allgather on gradient instead of allreduce for sparse feature.
    Allgather is a communication operation used for distributed deep learning.
@@ -86,10 +89,14 @@ def _tensors_allreduce_with_sparse(degree, mean, allgather, allreduce, allreduce
        allreduce (Primitive): The communication operator for gradients.
        allreduce_filter (bool): When it is true, allgather would apply.
        grad (tuple): The indices, gradient tensor and tensor_shape before operation.
+        ps_parameter (bool): Use parameter server or not.
    Returns:
        IndexedSlices, the gradient after operation.
    """
+    if ps_parameter:
+        return grad
    if allreduce_filter:
        indices = allgather(grad.indices())
        dout = allgather(grad.values())
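
The sparse variant follows the same pattern, except that IndexedSlices gradients are exchanged with allgather on their indices and values rather than allreduce. A hedged sketch under the same assumptions follows; the SimpleSlices container and fake_allgather helper are made up for illustration, while the real code uses MindSpore's IndexedSlices and the AllGather primitive.

# Sketch only: mimics the control flow of _tensors_allreduce_with_sparse.
from dataclasses import dataclass
from typing import List, Tuple

@dataclass
class SimpleSlices:
    # Toy version of an IndexedSlices gradient: row indices, row values, dense shape.
    indices: List[int]
    values: List[float]
    dense_shape: Tuple[int, ...]

def fake_allgather(chunk, world_size=2):
    # Stand-in for AllGather: pretend every worker contributed the same chunk.
    return chunk * world_size

def tensors_allreduce_with_sparse_sketch(degree, mean, allgather, allreduce_filter, grad, ps_parameter):
    if ps_parameter:
        # Same PS filter as the dense path: server-hosted parameters skip the collective.
        return grad
    if allreduce_filter:
        indices = allgather(grad.indices)
        values = allgather(grad.values)
        if mean:
            # Average the gathered values over the device count (degree).
            values = [v / degree for v in values]
        grad = SimpleSlices(indices, values, grad.dense_shape)
    return grad

g = SimpleSlices(indices=[0, 3], values=[0.5, 1.5], dense_shape=(8,))
print(tensors_allreduce_with_sparse_sketch(2, True, fake_allgather, True, g, ps_parameter=False))
print(tensors_allreduce_with_sparse_sketch(2, True, fake_allgather, True, g, ps_parameter=True))
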
