@@ -13,107 +13,95 @@
 # limitations under the License.
 # ============================================================================
 """grad reducer cell for distributed training"""
+from mindspore import context
 from mindspore.nn.cell import Cell
 from mindspore.communication.management import GlobalComm, get_group_size
 from mindspore.ops import functional as F, composite as C, operations as P
-from mindspore.ops.operations.comm_ops import AllReduce, ReduceOp, AllGather
+from mindspore.ops.operations.comm_ops import AllReduce, AllGather
+from mindspore.parallel._auto_parallel_context import auto_parallel_context
 import mindspore.common.dtype as mstype
 
 reduce_opt = C.MultitypeFuncGraph("reduce_opt")
 
-_all_reduce = AllReduce()
-_all_gather = None
-
-
-def _init_optimizer_communication():
-    global _all_reduce
-    global _all_gather
-
-    _all_reduce = AllReduce(ReduceOp.SUM, GlobalComm.WORLD_COMM_GROUP)
-    _all_reduce.add_prim_attr('fusion', 1)
-    _all_gather = AllGather(GlobalComm.WORLD_COMM_GROUP)
-
-
-@reduce_opt.register("Function", "Number", "Bool", "Tensor")
-def _tensors_allreduce_mean(mul, degree, allreduce_filter, grad):
+def _init_allreduce_operators(length):
+    """ initialize allreduce communication operators"""
+    is_parallel_optimizer = context.get_auto_parallel_context("enable_parallel_optimizer")
+    split_indices = auto_parallel_context().get_all_reduce_fusion_split_indices()
+    if is_parallel_optimizer and split_indices:
+        group = 1
+        fusion = ()
+        for i in range(length):
+            fusion = fusion + (group,)
+            if split_indices[group - 1] <= i + 1:
+                if group >= len(split_indices):
+                    continue
+                group = group + 1
+        index = tuple(range(1, length + 1))
+    else:
+        fusion = (1,) * length
+        index = (0,) * length
+    opt_list = ()
+    for i in range(length):
+        opt = AllReduce('sum', GlobalComm.WORLD_COMM_GROUP)
+        opt.add_prim_attr('fusion', fusion[i])
+        opt.add_prim_attr('index', index[i])
+        opt_list = opt_list + (opt,)
+    return opt_list
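To make the bucketing above concrete, here is a small self-contained sketch (illustrative only: the helper name and the example `split_indices` are invented, not part of the patch) that mirrors the grouping loop. With six parameters and split indices `(2, 4)`, the first two AllReduce operators share fusion group 1 and the remaining four share group 2, so their gradients are fused into two communication buckets; parameters beyond the last split index stay in the last group.

def fusion_groups(length, split_indices):
    """Pure-Python mirror of the grouping loop in _init_allreduce_operators."""
    group, fusion = 1, ()
    for i in range(length):
        fusion = fusion + (group,)
        if split_indices[group - 1] <= i + 1:
            if group >= len(split_indices):
                continue
            group = group + 1
    return fusion


print(fusion_groups(6, (2, 4)))  # (1, 1, 2, 2, 2, 2)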
+
+
+@reduce_opt.register("Number", "Bool", "Function", "Bool", "Tensor", "Function")
+def _tensors_allreduce(degree, mean, allgather, allreduce_filter, grad, allreduce):
     """
-    Apply mean and allreduce on gradient. Allreduce is a communication operation used for distributed deep learning.
+    Apply allreduce on gradient.
 
     Args:
-        mul (Primitive): Div operation.
         degree (int): The mean coefficient.
+        mean (bool): When mean is true, the mean coefficient (degree) would apply on gradients.
+        allgather (Primitive): The communication operator for sparse gradients.
         allreduce_filter (bool): When it is true, allreduce would apply.
         grad (Tensor): The gradient tensor before operation.
+        allreduce (Primitive): The communication operator for gradients.
 
     Returns:
         Tensor, the gradient tensor after operation.
     """
     if allreduce_filter:
-        degree = F.scalar_cast(degree, F.dtype(grad))
-        grad = _all_reduce(grad)
-        cast_op = P.Cast()
-        return mul(grad, cast_op(F.scalar_to_array(1.0/degree), F.dtype(grad)))
+        grad = allreduce(grad)
+        if mean:
+            degree = F.scalar_cast(degree, F.dtype(grad))
+            cast_op = P.Cast()
+            mul_op = P.Mul()
+            grad = mul_op(grad, cast_op(F.scalar_to_array(1.0/degree), F.dtype(grad)))
+        return grad
     return grad
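For intuition about the `mean` branch, a NumPy sketch (the device count and values are made up; this is not MindSpore API) of what every device ends up with: AllReduce sums the per-device gradients, and the multiplication by `1.0/degree` turns that sum into an average.

import numpy as np

# Hypothetical local gradients from a 4-device group (degree = 4).
local_grads = [np.array([0.1, 0.2]), np.array([0.3, 0.4]),
               np.array([0.5, 0.6]), np.array([0.7, 0.8])]
degree = 4

summed = np.sum(local_grads, axis=0)    # what AllReduce('sum') leaves on every device
averaged = summed * (1.0 / degree)      # what the `if mean:` branch applies
print(averaged)                         # [0.4 0.5]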
 
 
-@reduce_opt.register("Function", "Number", "Bool", "Tuple")
-def _tensors_allreduce_mean_with_sparse(mul, degree, allreduce_filter, grad):
+@reduce_opt.register("Number", "Bool", "Function", "Bool", "Tuple", "Function")
+def _tensors_allreduce_with_sparse(degree, mean, allgather, allreduce_filter, grad, allreduce):
     """
-    Apply mean and allgather on gradient instead of allreduce for sparse feature.
+    Apply allgather on gradient instead of allreduce for sparse feature.
     Allgather is a communication operation used for distributed deep learning.
 
     Args:
-        mul (Primitive): Div operation.
         degree (int): The mean coefficient.
+        mean (bool): When mean is true, the mean coefficient (degree) would apply on gradients.
+        allgather (Primitive): The communication operator for sparse gradients.
         allreduce_filter (bool): When it is true, allgather would apply.
-        grad (Tuple): The indices, gradient tensor and tensor_shape before operation.
+        grad (tuple): The indices, gradient tensor and tensor_shape before operation.
+        allreduce (Primitive): The communication operator for gradients.
 
     Returns:
         Tuple, include indices, the gradient tensor and tensor_shape after operation.
     """
     if allreduce_filter:
-        indices = _all_gather(grad[0])
-        degree = F.scalar_cast(degree, F.dtype(grad[1]))
-        dout = _all_gather(grad[1])
-        cast_op = P.Cast()
-        dout = mul(dout, cast_op(F.scalar_to_array(1.0 / degree), F.dtype(dout)))
-        grad = (indices, dout, grad[2])
-    return grad
-
-
-@reduce_opt.register("Bool", "Tensor")
-def _tensors_allreduce(allreduce_filter, grad):
-    """
-    Apply allreduce on gradient.
-
-    Args:
-        allreduce_filter (bool): When it is true, allreduce would apply.
-        grad (Tensor): The gradient tensor before operation.
-
-    Returns:
-        Tensor, the gradient tensor after operation.
-    """
-    if allreduce_filter:
-        return _all_reduce(grad)
-    return grad
-
-
-@reduce_opt.register("Bool", "Tuple")
-def _tensors_allreduce_with_sparse(allreduce_filter, grad):
-    """
-    Apply mean and allgather on gradient instead of allreduce for sparse feature.
-    Allgather is a communication operation used for distributed deep learning.
-
-    Args:
-        allreduce_filter (bool): When it is true, allgather would apply.
-        grad (Tuple): The indices, gradient tensor and tensor_shape before operation.
-
-    Returns:
-        Tuple, include indices, the gradient tensor and tensor_shape after operation.
-    """
-    if allreduce_filter:
-        indices = _all_gather(grad[0])
-        dout = _all_gather(grad[1])
+        indices = allgather(grad[0])
+        dout = allgather(grad[1])
+        if mean:
+            degree = F.scalar_cast(degree, F.dtype(grad[1]))
+            cast_op = P.Cast()
+            mul_op = P.Mul()
+            dout = mul_op(dout, cast_op(F.scalar_to_array(1.0 / degree), F.dtype(dout)))
         grad = (indices, dout, grad[2])
     return grad
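For the sparse path, a NumPy sketch of why AllGather replaces AllReduce (the device count, indices and values are invented): each device produces values for its own subset of rows, so the pieces are concatenated along the first axis rather than summed, and only the values are scaled when `mean` is true.

import numpy as np

degree = 2  # hypothetical 2-device group
dev0_indices, dev0_values = np.array([0, 3]), np.array([[0.1], [0.2]])
dev1_indices, dev1_values = np.array([1, 3]), np.array([[0.3], [0.4]])

indices = np.concatenate([dev0_indices, dev1_indices])  # AllGather on grad[0]
values = np.concatenate([dev0_values, dev1_values])     # AllGather on grad[1]
values = values * (1.0 / degree)                        # the `if mean:` scaling
# Duplicate indices (row 3 here) are kept as-is; combining them is left to
# whatever later consumes the (indices, values, dense_shape) tuple.
print(indices)          # [0 3 1 3]
print(values.ravel())   # [0.05 0.1  0.15 0.2 ]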
@@ -259,7 +247,6 @@ class DistributedGradReducer(Cell):
     def __init__(self, parameters, mean=True, degree=None):
         super(DistributedGradReducer, self).__init__(auto_prefix=False)
         self.map_ = C.Map()
-        self.mul = P.Mul()
         if degree is None:
             self.degree = get_group_size()
         else:
@@ -268,7 +255,8 @@ class DistributedGradReducer(Cell):
             self.degree = degree
         self.mean = mean
         self.allreduce_filter = tuple(x.layerwise_parallel is False for x in parameters)
-        _init_optimizer_communication()
+        self.opt_list = _init_allreduce_operators(len(parameters))
+        self.allgather = AllGather(GlobalComm.WORLD_COMM_GROUP)
 
     def construct(self, grads):
         """
@@ -284,11 +272,8 @@ class DistributedGradReducer(Cell):
         """
         datatypes = self.map_(F.partial(_get_datatype), grads)
         grads = self.map_(F.partial(_cast_datatype, mstype.float32), grads)
 
-        if self.mean:
-            new_grad = self.map_(F.partial(reduce_opt, self.mul, self.degree), self.allreduce_filter, grads)
-        else:
-            new_grad = self.map_(F.partial(reduce_opt), self.allreduce_filter, grads)
+        new_grad = self.map_(F.partial(reduce_opt, self.degree, self.mean, self.allgather),
+                             self.allreduce_filter, grads, self.opt_list)
 
         new_grad = self.map_(F.partial(_cast_datatype), datatypes, new_grad)
         return new_grad
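For orientation, a minimal usage sketch of the reducer inside a custom training wrapper. This is an illustrative assumption, not part of the patch: the wrapper name, the `mean=True` choice and the loss-returning network are made up, and a communication group must already be initialized (for example via mindspore.communication.init()) before get_group_size() can be called.

from mindspore import nn
from mindspore.ops import composite as C, functional as F
from mindspore.communication.management import get_group_size


class TrainingWrapper(nn.Cell):
    """Hypothetical one-step training cell for data-parallel training."""

    def __init__(self, network, optimizer, mean=True):
        super(TrainingWrapper, self).__init__(auto_prefix=False)
        self.network = network            # assumed to return the loss directly
        self.optimizer = optimizer
        self.weights = optimizer.parameters
        # Recent MindSpore releases; older ones also take a name as the
        # first positional argument of GradOperation.
        self.grad = C.GradOperation(get_by_list=True)
        # degree falls back to get_group_size() when omitted.
        self.grad_reducer = nn.DistributedGradReducer(self.weights, mean, get_group_size())

    def construct(self, data, label):
        loss = self.network(data, label)
        grads = self.grad(self.network, self.weights)(data, label)
        # Per-parameter AllReduce (plus 1/degree scaling when mean=True).
        grads = self.grad_reducer(grads)
        return F.depend(loss, self.optimizer(grads))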