|
|
|
@ -16,18 +16,22 @@
|
|
|
|
|
from mindspore.nn.cell import Cell
|
|
|
|
|
from mindspore.communication.management import GlobalComm, get_group_size
|
|
|
|
|
from mindspore.ops import functional as F, composite as C, operations as P
|
|
|
|
|
from mindspore.ops.operations.comm_ops import AllReduce, ReduceOp
|
|
|
|
|
from mindspore.ops.operations.comm_ops import AllReduce, ReduceOp, AllGather
|
|
|
|
|
import mindspore.common.dtype as mstype
|
|
|
|
|
|
|
|
|
|
reduce_opt = C.MultitypeFuncGraph("reduce_opt")
|
|
|
|
|
|
|
|
|
|
_all_reduce = AllReduce()
|
|
|
|
|
_all_gather = None
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def _init_optimizer_allreduce():
|
|
|
|
|
def _init_optimizer_communication():
|
|
|
|
|
global _all_reduce
|
|
|
|
|
global _all_gather
|
|
|
|
|
|
|
|
|
|
_all_reduce = AllReduce(ReduceOp.SUM, GlobalComm.WORLD_COMM_GROUP)
|
|
|
|
|
_all_reduce.add_prim_attr('fusion', 1)
|
|
|
|
|
_all_gather = AllGather(GlobalComm.WORLD_COMM_GROUP)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
@reduce_opt.register("Function", "Number", "Bool", "Tensor")
|
|
|
|
@ -72,8 +76,8 @@ def _tensors_allreduce_mean_with_sparse(mul, degree, allreduce_filter, grad):
|
|
|
|
|
degree = F.scalar_cast(degree, F.dtype(grad[1]))
|
|
|
|
|
dout = _all_gather(grad[1])
|
|
|
|
|
cast_op = P.Cast()
|
|
|
|
|
dout = mul(dout, cast_op(F.scalar_to_array(1.0/degree), F.dtype(dout)))
|
|
|
|
|
grad = (indices, dout, dout[2])
|
|
|
|
|
dout = mul(dout, cast_op(F.scalar_to_array(1.0 / degree), F.dtype(dout)))
|
|
|
|
|
grad = (indices, dout, grad[2])
|
|
|
|
|
return grad
|
|
|
|
|
|
|
|
|
|
|
|
|
|
@ -110,7 +114,7 @@ def _tensors_allreduce_with_sparse(allreduce_filter, grad):
|
|
|
|
|
if allreduce_filter:
|
|
|
|
|
indices = _all_gather(grad[0])
|
|
|
|
|
dout = _all_gather(grad[1])
|
|
|
|
|
grad = (indices, dout, dout[2])
|
|
|
|
|
grad = (indices, dout, grad[2])
|
|
|
|
|
return grad
|
|
|
|
|
|
|
|
|
|
|
|
|
|
@ -131,6 +135,20 @@ def _tensors_get_datatype(grad):
|
|
|
|
|
return F.dtype(grad)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
@_get_datatype.register("Tuple")
|
|
|
|
|
def _tensors_get_datatype_with_sparse(grad):
|
|
|
|
|
"""
|
|
|
|
|
Acquire gradient datatype.
|
|
|
|
|
|
|
|
|
|
Args:
|
|
|
|
|
grad (Tuple): The gradient tensor before operation.
|
|
|
|
|
|
|
|
|
|
Returns:
|
|
|
|
|
mstype, the datatype of gradient.
|
|
|
|
|
"""
|
|
|
|
|
return F.dtype(grad[1])
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
_cast_datatype = C.MultitypeFuncGraph("_cast_datatype")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
@ -149,6 +167,22 @@ def _tensors_cast_datatype(datatype, grad):
|
|
|
|
|
return F.cast(grad, datatype)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
@_cast_datatype.register("TypeType", "Tuple")
|
|
|
|
|
def _tensors_cast_datatype_with_sparse(datatype, grad):
|
|
|
|
|
"""
|
|
|
|
|
Cast gradient to datatype.
|
|
|
|
|
|
|
|
|
|
Args:
|
|
|
|
|
datatype (mstype): the destination datatype of gradient.
|
|
|
|
|
grad (Tuple): The gradient tensor before operation.
|
|
|
|
|
|
|
|
|
|
Returns:
|
|
|
|
|
Tuple, the gradient tuple after operation.
|
|
|
|
|
"""
|
|
|
|
|
dout = F.cast(grad[1], datatype)
|
|
|
|
|
return (grad[0], dout, grad[2])
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class DistributedGradReducer(Cell):
|
|
|
|
|
"""
|
|
|
|
|
A distributed optimizer.
|
|
|
|
@ -224,7 +258,7 @@ class DistributedGradReducer(Cell):
|
|
|
|
|
|
|
|
|
|
def __init__(self, parameters, mean=True, degree=None):
|
|
|
|
|
super(DistributedGradReducer, self).__init__(auto_prefix=False)
|
|
|
|
|
self.hyper_map = C.HyperMap()
|
|
|
|
|
self.map_ = C.Map()
|
|
|
|
|
self.mul = P.Mul()
|
|
|
|
|
if degree is None:
|
|
|
|
|
self.degree = get_group_size()
|
|
|
|
@ -234,19 +268,27 @@ class DistributedGradReducer(Cell):
|
|
|
|
|
self.degree = degree
|
|
|
|
|
self.mean = mean
|
|
|
|
|
self.allreduce_filter = tuple(x.layerwise_parallel is False for x in parameters)
|
|
|
|
|
_init_optimizer_allreduce()
|
|
|
|
|
_init_optimizer_communication()
|
|
|
|
|
|
|
|
|
|
def construct(self, grads):
|
|
|
|
|
# In some circumstances, the data precision of grads could be mixed with float16 and float32. Thus, the
|
|
|
|
|
# result of AllReduce is unreliable. To solve the problem, grads should be cast to float32 before AllReduce,
|
|
|
|
|
# and cast back after the operation.
|
|
|
|
|
datatypes = self.hyper_map(F.partial(_get_datatype), grads)
|
|
|
|
|
grads = self.hyper_map(F.partial(_cast_datatype, mstype.float32), grads)
|
|
|
|
|
"""
|
|
|
|
|
In some circumstances, the data precision of grads could be mixed with float16 and float32. Thus, the
|
|
|
|
|
result of AllReduce is unreliable. To solve the problem, grads should be cast to float32 before AllReduce,
|
|
|
|
|
and cast back after the operation.
|
|
|
|
|
|
|
|
|
|
Args:
|
|
|
|
|
grads (Union[Tensor, tuple[Tensor]]): The gradient tensor or tuple before operation.
|
|
|
|
|
|
|
|
|
|
Returns:
|
|
|
|
|
new_grads (Union[Tensor, tuple[Tensor]]), the gradient tensor or tuple after operation.
|
|
|
|
|
"""
|
|
|
|
|
datatypes = self.map_(F.partial(_get_datatype), grads)
|
|
|
|
|
grads = self.map_(F.partial(_cast_datatype, mstype.float32), grads)
|
|
|
|
|
|
|
|
|
|
if self.mean:
|
|
|
|
|
new_grad = self.hyper_map(F.partial(reduce_opt, self.mul, self.degree), self.allreduce_filter, grads)
|
|
|
|
|
new_grad = self.map_(F.partial(reduce_opt, self.mul, self.degree), self.allreduce_filter, grads)
|
|
|
|
|
else:
|
|
|
|
|
new_grad = self.hyper_map(F.partial(reduce_opt), self.allreduce_filter, grads)
|
|
|
|
|
new_grad = self.map_(F.partial(reduce_opt), self.allreduce_filter, grads)
|
|
|
|
|
|
|
|
|
|
new_grad = self.hyper_map(F.partial(_cast_datatype), datatypes, new_grad)
|
|
|
|
|
new_grad = self.map_(F.partial(_cast_datatype), datatypes, new_grad)
|
|
|
|
|
return new_grad
|
|
|
|
|