From 516b56cb64c372801522084151731ec5327ef5a7 Mon Sep 17 00:00:00 2001 From: lirongzhen1 Date: Sat, 30 May 2020 19:13:32 +0800 Subject: [PATCH] sparse feature bp --- mindspore/nn/wrap/grad_reducer.py | 45 +++++++ mindspore/ops/_grad/grad_comm_ops.py | 55 ++++++-- .../parallel/test_sparse_feature_bprop.py | 118 ++++++++++++++++++ 3 files changed, 206 insertions(+), 12 deletions(-) create mode 100644 tests/ut/python/parallel/test_sparse_feature_bprop.py diff --git a/mindspore/nn/wrap/grad_reducer.py b/mindspore/nn/wrap/grad_reducer.py index 8383910a60..c66bfbe646 100644 --- a/mindspore/nn/wrap/grad_reducer.py +++ b/mindspore/nn/wrap/grad_reducer.py @@ -52,6 +52,31 @@ def _tensors_allreduce_mean(mul, degree, allreduce_filter, grad): return grad +@reduce_opt.register("Function", "Number", "Bool", "Tuple") +def _tensors_allreduce_mean_with_sparse(mul, degree, allreduce_filter, grad): + """ + Apply mean and allgather on gradient instead of allreduce for sparse feature. + Allgather is a communication operation used for distributed deep learning. + + Args: + mul (Primitive): Div operation. + degree (int): The mean coefficient. + allreduce_filter (bool): When it is true, allgather would apply. + grad (Tuple): The indices, gradient tensor and tensor_shape before operation. + + Returns: + Tuple, include indices, the gradient tensor and tensor_shape after operation. + """ + if allreduce_filter: + indices = _all_gather(grad[0]) + degree = F.scalar_cast(degree, F.dtype(grad[1])) + dout = _all_gather(grad[1]) + cast_op = P.Cast() + dout = mul(dout, cast_op(F.scalar_to_array(1.0/degree), F.dtype(dout))) + grad = (indices, dout, dout[2]) + return grad + + @reduce_opt.register("Bool", "Tensor") def _tensors_allreduce(allreduce_filter, grad): """ @@ -69,6 +94,26 @@ def _tensors_allreduce(allreduce_filter, grad): return grad +@reduce_opt.register("Bool", "Tuple") +def _tensors_allreduce_with_sparse(allreduce_filter, grad): + """ + Apply mean and allgather on gradient instead of allreduce for sparse feature. + Allgather is a communication operation used for distributed deep learning. + + Args: + allreduce_filter (bool): When it is true, allgather would apply. + grad (Tuple): The indices, gradient tensor and tensor_shape before operation. + + Returns: + Tuple, include indices, the gradient tensor and tensor_shape after operation. + """ + if allreduce_filter: + indices = _all_gather(grad[0]) + dout = _all_gather(grad[1]) + grad = (indices, dout, dout[2]) + return grad + + _get_datatype = C.MultitypeFuncGraph("_get_datatype") diff --git a/mindspore/ops/_grad/grad_comm_ops.py b/mindspore/ops/_grad/grad_comm_ops.py index 057d150be1..7477d50895 100644 --- a/mindspore/ops/_grad/grad_comm_ops.py +++ b/mindspore/ops/_grad/grad_comm_ops.py @@ -26,9 +26,10 @@ from .grad_base import bprop_getters @bprop_getters.register(AllReduce) def get_bprop_all_reduce(self): - """Generate bprop for AllReduce.""" + """Generate bprop for AllReduce, do allreduce or allgather, allgather for sparse feature.""" all_reduce_grad = AllReduce(ReduceOp.SUM, self.group) + all_gather = AllGather(group=self.group) if self.instance_name: instance_name = "grad" + self.instance_name all_reduce_grad.set_prim_instance_name(instance_name) @@ -42,15 +43,28 @@ def get_bprop_all_reduce(self): if self.op == ReduceOp.SUM: def bprop(x, out, dout): - dx = all_reduce_grad(dout) + if F.issubclass_(F.typeof(dout), mstype.tensor): + dx = all_reduce_grad(dout) + else: + indices = all_gather(dout[0]) + grad = all_gather(dout[1]) + dx = (indices, grad, dout[2]) return (dx,) else: def bprop(x, out, dout): - dx = all_reduce_grad(dout) - z = equal(x, out) - z = cast(z, dtype(dx)) - dx = mul(dx, z) + if F.issubclass_(F.typeof(dout), mstype.tensor): + dx = all_reduce_grad(dout) + z = equal(x, out) + z = cast(z, dtype(dx)) + dx = mul(dx, z) + else: + indices = all_gather(dout[0]) + grad = all_gather(dout[1]) + z = equal(x, out) + z = cast(z, dtype(grad)) + grad = mul(grad, z) + dx = (indices, grad, dout[2]) return (dx,) return bprop @@ -147,12 +161,16 @@ def get_bprop_all_to_all(self): @bprop_getters.register(_MirrorOperator) def get_bprop_mirror_operator(self): - """Backpropagator for _MirrorOperator, do allreduce for the devices in group(only for one group).""" + """ + Backpropagator for _MirrorOperator, do allreduce or allgather for the devices in group(only for one group), + allgather for sparse feature. + """ group = self.group dev_num = self.dev_num mean_flag = self.mean_flag all_reduce = AllReduce(group=group) + all_gather = AllGather(group=group) mul = P.Mul() cast = P.Cast() @@ -170,12 +188,25 @@ def get_bprop_mirror_operator(self): def bprop(x, out, dout): if mean_flag: - dx = all_reduce(dout) - float_one = F.scalar_cast(1.0, F.dtype(dx)) - num = F.scalar_cast(dev_num, F.dtype(dx)) - dx = mul(dx, cast(F.scalar_to_array(float_one/num), F.dtype(dx))) + if F.issubclass_(F.typeof(dout), mstype.tensor): + dx = all_reduce(dout) + float_one = F.scalar_cast(1.0, F.dtype(dx)) + num = F.scalar_cast(dev_num, F.dtype(dx)) + dx = mul(dx, cast(F.scalar_to_array(float_one/num), F.dtype(dx))) + else: + indices = all_gather(dout[0]) + grad = all_gather(dout[1]) + float_one = F.scalar_cast(1.0, F.dtype(grad)) + num = F.scalar_cast(dev_num, F.dtype(grad)) + grad = mul(grad, cast(F.scalar_to_array(float_one/num), F.dtype(grad))) + dx = (indices, grad, dout[2]) else: - dx = all_reduce(dout) + if F.issubclass_(F.typeof(dout), mstype.tensor): + dx = all_reduce(dout) + else: + indices = all_gather(dout[0]) + grad = all_gather(dout[1]) + dx = (indices, grad, dout[2]) return (dx,) return bprop diff --git a/tests/ut/python/parallel/test_sparse_feature_bprop.py b/tests/ut/python/parallel/test_sparse_feature_bprop.py new file mode 100644 index 0000000000..cd58261dbd --- /dev/null +++ b/tests/ut/python/parallel/test_sparse_feature_bprop.py @@ -0,0 +1,118 @@ +# Copyright 2020 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +""" test sparse feature bprop """ +import numpy as np + +import mindspore as ms +import mindspore.nn as nn +from mindspore import context +from mindspore.common import dtype as mstype +from mindspore.common.tensor import Tensor +from mindspore.ops import composite as C +from mindspore.ops.operations.comm_ops import AllReduce, _MirrorOperator +from mindspore.ops._grad.grad_base import bprop_getters +from mindspore._checkparam import Validator as validator +from mindspore._checkparam import Rel +from mindspore.ops.primitive import prim_attr_register, PrimitiveWithInfer +from mindspore.common.api import _executor +from mindspore.communication.management import HCCL_WORLD_COMM_GROUP + +class GradWrap(nn.Cell): + def __init__(self, network): + super(GradWrap, self).__init__() + self.network = network + + def construct(self, x): + return C.grad_all(self.network)(x) + +class VirtualGatherV2(PrimitiveWithInfer): + @prim_attr_register + def __init__(self): + """init index_select""" + super(VirtualGatherV2, self).__init__('VirtualGatherV2') + self.init_prim_io_names(inputs=['params', 'indices', 'axis'], outputs=['output']) + + def __infer__(self, params, indices, axis): + validator.check_subclass("params", params['dtype'], mstype.tensor, self.name) + validator.check_tensor_type_same({"indices": indices['dtype']}, mstype.int_type, self.name) + validator.check_subclass("axis", axis['dtype'], mstype.int_, self.name) + axis_v = axis['value'] + params_shp = params['shape'] + rank = len(params_shp) + validator.check_int_range("axis", axis_v, -rank, rank, Rel.INC_LEFT, self.name) + if axis_v < 0: + axis_v += rank + out_shape = params_shp[:axis_v] + indices['shape'] + params_shp[axis_v + 1:] + out = {'shape': out_shape, + 'dtype': params['dtype'], + 'value': None} + return out + +@bprop_getters.register(VirtualGatherV2) +def get_bprop_gather_v2(self): + """Generate bprop for GatherV2""" + + def bprop(x, indices, axis, out, dout): + return (indices, dout, x), axis, out + + return bprop + +def test_bprop_with_sparse_feature_allreduce(): + context.set_auto_parallel_context(device_num=8, global_rank=0, parallel_mode="hybrid_parallel") + + class Net(nn.Cell): + def __init__(self, axis=0, shape=None): + super(Net, self).__init__() + if shape is None: + shape = [8, 8] + self.all_reduce = AllReduce() + self.gatherv2 = VirtualGatherV2() + self.index = Tensor(np.ones(shape), dtype=ms.int32) + self.axis = axis + + def construct(self, x): + out = self.all_reduce(x) + out = self.gatherv2(out, self.index, self.axis) + + return out + + net = GradWrap(Net()) + x = Tensor(np.ones([64, 64]), dtype=ms.float32) + + _executor.compile(net, x) + +def test_bprop_with_sparse_feature_mirror(): + context.set_auto_parallel_context(device_num=8, global_rank=0, parallel_mode="hybrid_parallel") + + class Net(nn.Cell): + def __init__(self, axis=0, shape=None): + super(Net, self).__init__() + if shape is None: + shape = [8, 8] + self.mirror = _MirrorOperator(group=HCCL_WORLD_COMM_GROUP) + self.gatherv2 = VirtualGatherV2() + self.index = Tensor(np.ones(shape), dtype=ms.int32) + self.axis = axis + + def construct(self, x): + out = self.mirror(x) + out = self.gatherv2(out, self.index, self.axis) + + return out + + net = GradWrap(Net()) + x = Tensor(np.ones([64, 64]), dtype=ms.float32) + + _executor.compile(net, x)