@@ -51,8 +51,8 @@ def _init_allreduce_operators(length):
     return opt_list


-@reduce_opt.register("Number", "Bool", "Function", "Bool", "Tensor", "Function")
-def _tensors_allreduce(degree, mean, allgather, allreduce_filter, grad, allreduce):
+@reduce_opt.register("Number", "Bool", "Function", "Bool", "Tensor", "Function", "Bool")
+def _tensors_allreduce(degree, mean, allgather, allreduce_filter, grad, allreduce, ps_parameter):
     """
     Apply allreduce on gradient.

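For context on the decorator change (this relies on the unchanged top of grad_reducer.py, where `reduce_opt` is declared as a `MultitypeFuncGraph`): each string passed to `register` maps positionally onto one parameter of the decorated overload, so adding `ps_parameter` to the Python signature requires the extra trailing "Bool". A sketch of that correspondence, not part of the patch:

from mindspore.ops import composite as C

# Declared near the top of grad_reducer.py; the graph name string is assumed here.
reduce_opt = C.MultitypeFuncGraph("grad_reduce_opt")

# Positional mapping between type strings and parameters of the overload:
#   "Number" -> degree, "Bool" -> mean, "Function" -> allgather,
#   "Bool" -> allreduce_filter, "Tensor" -> grad, "Function" -> allreduce,
#   "Bool" -> ps_parameter (the new trailing entry).
@reduce_opt.register("Number", "Bool", "Function", "Bool", "Tensor", "Function", "Bool")
def _tensors_allreduce(degree, mean, allgather, allreduce_filter, grad, allreduce, ps_parameter):
    ...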
@@ -67,7 +67,7 @@ def _tensors_allreduce(degree, mean, allgather, allreduce_filter, grad, allreduce):
     Returns:
         Tensor, the gradient tensor after operation.
     """
-    if allreduce_filter:
+    if not ps_parameter and allreduce_filter:
         grad = allreduce(grad)
         if mean:
             degree = F.scalar_cast(degree, F.dtype(grad))
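The new guard is the behavioural core of the patch: a gradient goes through AllReduce only when its parameter is not hosted on the parameter server; gradients of PS parameters pass through unchanged and are handled by the parameter-server path instead. A plain-Python sketch of the branch, with floats and a fake operator standing in for Tensors and AllReduce (none of these helpers exist in the patch):

def reduce_one_grad(degree, mean, allreduce_filter, grad, allreduce, ps_parameter):
    # Mirrors the patched condition: skip the collective for PS-hosted parameters.
    if not ps_parameter and allreduce_filter:
        grad = allreduce(grad)
        if mean:
            grad = grad / degree  # stands in for the scalar_cast + divide in the real code
    return grad

fake_allreduce = lambda g: g + 10.0  # pretend the other devices contributed 10.0
print(reduce_one_grad(4, True, True, 2.0, fake_allreduce, ps_parameter=False))  # 3.0 (averaged)
print(reduce_one_grad(4, True, True, 2.0, fake_allreduce, ps_parameter=True))   # 2.0 (left for the PS)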
@@ -258,6 +258,8 @@ class DistributedGradReducer(Cell):
         self.allreduce_filter = tuple(x.layerwise_parallel is False for x in parameters)
         self.opt_list = _init_allreduce_operators(len(parameters))
         self.allgather = AllGather(GlobalComm.WORLD_COMM_GROUP)
+        ps_filter = lambda x: x.is_param_ps
+        self.ps_parameters = tuple(ps_filter(x) for x in parameters)

     def construct(self, grads):
         """
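In the constructor, `ps_parameters` is built with the same shape and ordering as `allreduce_filter` and `opt_list`: one entry per parameter, so the tuples can later be walked in lock step inside `construct`. A minimal sketch with a hypothetical stand-in class (`FakeParam` is not a MindSpore type; only the `is_param_ps` and `layerwise_parallel` attributes mirror the real `Parameter`):

class FakeParam:
    def __init__(self, name, is_param_ps=False, layerwise_parallel=False):
        self.name = name
        self.is_param_ps = is_param_ps
        self.layerwise_parallel = layerwise_parallel

parameters = (FakeParam("fc.weight"),
              FakeParam("embedding.table", is_param_ps=True))

allreduce_filter = tuple(x.layerwise_parallel is False for x in parameters)  # (True, True)
ps_filter = lambda x: x.is_param_ps
ps_parameters = tuple(ps_filter(x) for x in parameters)                      # (False, True)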
@@ -274,7 +276,7 @@ class DistributedGradReducer(Cell):
         datatypes = self.map_(F.partial(_get_datatype), grads)
         grads = self.map_(F.partial(_cast_datatype, mstype.float32), grads)
         new_grad = self.map_(F.partial(reduce_opt, self.degree, self.mean, self.allgather),
-                             self.allreduce_filter, grads, self.opt_list)
+                             self.allreduce_filter, grads, self.opt_list, self.ps_parameters)

         new_grad = self.map_(F.partial(_cast_datatype), datatypes, new_grad)
         return new_grad
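In `construct`, `self.map_` applies `reduce_opt` element-wise across the per-parameter tuples, so passing `self.ps_parameters` as an extra tuple simply hands each call its matching `ps_parameter` flag. A rough out-of-graph equivalent using `map` over plain tuples (the lambdas are placeholders for the per-gradient AllReduce operators in `opt_list`):

from functools import partial

def reduce_opt_sketch(degree, mean, allgather, allreduce_filter, grad, allreduce, ps_parameter):
    # Same argument order as the patched _tensors_allreduce overload.
    if not ps_parameter and allreduce_filter:
        grad = allreduce(grad)
        if mean:
            grad = grad / degree
    return grad

allreduce_filter = (True, True)
grads = (2.0, 2.0)
opt_list = (lambda g: g + 10.0, lambda g: g + 10.0)  # placeholder AllReduce ops
ps_parameters = (False, True)

new_grads = tuple(map(partial(reduce_opt_sketch, 4, True, None),
                      allreduce_filter, grads, opt_list, ps_parameters))
print(new_grads)  # (3.0, 2.0): the second gradient is left to the parameter-server path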