From 23284f0b3587a97c1ceaf4b477d3c1ec1fdaddc4 Mon Sep 17 00:00:00 2001
From: huangxinjing
Date: Wed, 11 Nov 2020 18:20:26 +0800
Subject: [PATCH] Add AllSwap Op

---
 mindspore/core/abstract/infer_functions.h  |  2 +
 mindspore/core/abstract/prim_others.cc     | 39 +++++++++++++
 .../core/abstract/primitive_infer_map.cc   |  1 +
 mindspore/core/base/core_ops.h             |  1 +
 mindspore/ops/_grad/grad_comm_ops.py       | 17 +++++-
 mindspore/ops/operations/__init__.py       |  3 +-
 mindspore/ops/operations/comm_ops.py       | 55 ++++++++++++++++++-
 tests/ut/python/communication/test_comm.py | 30 +++++++++-
 8 files changed, 144 insertions(+), 4 deletions(-)

diff --git a/mindspore/core/abstract/infer_functions.h b/mindspore/core/abstract/infer_functions.h
index b475959046..c40dee6cae 100644
--- a/mindspore/core/abstract/infer_functions.h
+++ b/mindspore/core/abstract/infer_functions.h
@@ -217,6 +217,8 @@ AbstractBasePtr InferImplSparseApplyFtrl(const AnalysisEnginePtr &, const Primit
                                          const AbstractBasePtrList &args_spec_list);
 AbstractBasePtr InferImplSparseApplyProximalAdagrad(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
                                                     const AbstractBasePtrList &args_spec_list);
+AbstractBasePtr InferImplAllSwap(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
+                                 const AbstractBasePtrList &args_spec_list);
 AbstractBasePtr InferImplAllReduce(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
                                    const AbstractBasePtrList &args_spec_list);
 AbstractBasePtr InferImplBroadcast(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
diff --git a/mindspore/core/abstract/prim_others.cc b/mindspore/core/abstract/prim_others.cc
index 08e0449ef8..537929aafc 100644
--- a/mindspore/core/abstract/prim_others.cc
+++ b/mindspore/core/abstract/prim_others.cc
@@ -367,6 +367,45 @@ AbstractBasePtr InferImplSparseTensorGetDenseShape(const AnalysisEnginePtr &, co
   return sparse_tensor->dense_shape();
 }

+AbstractBasePtr InferImplAllSwap(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
+                                 const AbstractBasePtrList &args_spec_list) {
+  const std::string op_name = primitive->name();
+  CheckArgsSize(op_name, args_spec_list, 3);
+  auto tensor_in = CheckArg<AbstractTensor>(op_name, args_spec_list, 0);
+  MS_EXCEPTION_IF_NULL(tensor_in);
+  MS_EXCEPTION_IF_NULL(tensor_in->shape());
+  auto tensor_in_shape = tensor_in->shape()->shape();
+
+  auto send_size = CheckArg<AbstractTensor>(op_name, args_spec_list, 1);
+  MS_EXCEPTION_IF_NULL(send_size);
+  auto recv_size = CheckArg<AbstractTensor>(op_name, args_spec_list, 2);
+  MS_EXCEPTION_IF_NULL(recv_size);
+
+  // Get the content of the recv size
+  auto recv_size_value_ptr = recv_size->BuildValue();
+  MS_EXCEPTION_IF_NULL(recv_size_value_ptr);
+  auto recv_size_tensor = recv_size_value_ptr->cast<tensor::TensorPtr>();
+  MS_EXCEPTION_IF_NULL(recv_size_tensor);
+  auto data_pos = reinterpret_cast<int64_t *>(recv_size_tensor->data_c());
+  MS_EXCEPTION_IF_NULL(data_pos);
+  int64_t infer_max_size = 0;
+  for (int64_t i = 0; i < recv_size_tensor->DataSize(); ++i) {
+    infer_max_size += *(data_pos + i);
+  }
+
+  ShapeVector tensor_out_shape = {Shape::SHP_ANY, tensor_in_shape[1]};
+  ShapeVector min_shape = {1, tensor_in_shape[1]};
+
+  ShapeVector max_shape = {infer_max_size / tensor_in_shape[1], tensor_in_shape[1]};
+
+  auto tensor_out = std::make_shared<AbstractTensor>(tensor_in->element(),
+                                                     std::make_shared<Shape>(tensor_out_shape, min_shape, max_shape));
+
+  AbstractTensorPtr ret = std::make_shared<AbstractTensor>(
+    tensor_out->element(), std::make_shared<Shape>(tensor_out_shape, min_shape, max_shape));
+  return ret;
+}
+
 AbstractBasePtr InferImplAllReduce(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
                                    const AbstractBasePtrList &args_spec_list) {
   const std::string op_name = primitive->name();
diff --git a/mindspore/core/abstract/primitive_infer_map.cc b/mindspore/core/abstract/primitive_infer_map.cc
index b285453d1b..7feb467fd8 100644
--- a/mindspore/core/abstract/primitive_infer_map.cc
+++ b/mindspore/core/abstract/primitive_infer_map.cc
@@ -135,6 +135,7 @@ PrimitiveEvalImplMap &GetPrimitiveToEvalImplMap() {
     {prim::kPrimAllReduce, {InferImplAllReduce, true}},
     {prim::kPrimBroadcast, {InferImplBroadcast, true}},
     {prim::kPrimAllGather, {InferImplAllGather, true}},
+    {prim::kPrimAllSwap, {InferImplAllSwap, true}},
     {prim::kPrimReduceScatter, {InferImplReduceScatter, true}},
     {prim::kPrimMemCpyAsync, {InferImplMemCpyAsync, true}},
     {prim::kPrimCast, {InferImplCast, true}},
diff --git a/mindspore/core/base/core_ops.h b/mindspore/core/base/core_ops.h
index a9f05a023d..402bad385c 100644
--- a/mindspore/core/base/core_ops.h
+++ b/mindspore/core/base/core_ops.h
@@ -186,6 +186,7 @@ inline const PrimitivePtr kPrimMirror = std::make_shared<Primitive>("_MirrorOper
 inline const PrimitivePtr kPrimVirtualDiv = std::make_shared<Primitive>("_VirtualDiv");
 inline const PrimitivePtr kPrimVirtualDataset = std::make_shared<Primitive>("_VirtualDataset");
 inline const PrimitivePtr kPrimAllReduce = std::make_shared<Primitive>("AllReduce");
+inline const PrimitivePtr kPrimAllSwap = std::make_shared<Primitive>("AllSwap");
 inline const PrimitivePtr kPrimBroadcast = std::make_shared<Primitive>("Broadcast");
 inline const PrimitivePtr kPrimAllGather = std::make_shared<Primitive>("AllGather");
 inline const PrimitivePtr kPrimReduceScatter = std::make_shared<Primitive>("ReduceScatter");
diff --git a/mindspore/ops/_grad/grad_comm_ops.py b/mindspore/ops/_grad/grad_comm_ops.py
index c219e76a9e..d17b1f6e53 100644
--- a/mindspore/ops/_grad/grad_comm_ops.py
+++ b/mindspore/ops/_grad/grad_comm_ops.py
@@ -21,7 +21,7 @@ from ...common.tensor import RowTensor
 from ..composite.multitype_ops.zeros_like_impl import zeros_like
 from ..operations.comm_ops import (AllGather, _HostAllGather, AllReduce, _AlltoAll, Broadcast,
                                    _GetTensorSlice, _MirrorOperator, ReduceOp, Send, Receive,
-                                   ReduceScatter, _HostReduceScatter, _VirtualDiv)
+                                   ReduceScatter, _HostReduceScatter, _VirtualDiv, AllSwap)
 from .grad_base import bprop_getters


@@ -155,6 +155,21 @@ def get_bprop_reduce_scatter(self):
     return bprop


+@bprop_getters.register(AllSwap)
+def get_bprop_allswap(self):
+    """Generate bprop for AllSwap."""
+    all_swap_grad = AllSwap(self.group)
+    if self.instance_name:
+        instance_name = "grad" + self.instance_name
+        all_swap_grad.set_prim_instance_name(instance_name)
+
+    def bprop(x, send_size, recv_size, out, dout):
+        dx = all_swap_grad(dout, recv_size, send_size)
+        return (dx,)
+
+    return bprop
+
+
 @bprop_getters.register(_HostReduceScatter)
 def get_bprop_host_reduce_scatter(self):
     """Generate bprop for _HostReduceScatter"""
diff --git a/mindspore/ops/operations/__init__.py b/mindspore/ops/operations/__init__.py
index c003e992b9..e6f02cd56f 100644
--- a/mindspore/ops/operations/__init__.py
+++ b/mindspore/ops/operations/__init__.py
@@ -34,7 +34,7 @@ from .array_ops import (Argmax, Argmin, Cast, Concat, Pack, Unpack,
                         UnsortedSegmentProd, UnsortedSegmentSum, SpaceToDepth, DepthToSpace,
                         SpaceToBatch, BatchToSpace, SpaceToBatchND, BatchToSpaceND, BroadcastTo, InplaceUpdate,
                         ReverseSequence, EmbeddingLookup, Unique, GatherD, Identity, RepeatElements)
-from .comm_ops import (AllGather, AllReduce, _AlltoAll, ReduceScatter, Broadcast,
+from .comm_ops import (AllGather, AllReduce, _AlltoAll, AllSwap, ReduceScatter, Broadcast,
                        _MirrorOperator, ReduceOp, _VirtualDataset,
                       _VirtualDiv, _GetTensorSlice, Send, Receive,
                       _HostAllGather, _HostReduceScatter)
@@ -294,6 +294,7 @@ __all__ = [
     'UnsortedSegmentProd',
     "AllGather",
     "AllReduce",
+    "AllSwap",
     "ReduceScatter",
     "Broadcast",
     "ReduceOp",
diff --git a/mindspore/ops/operations/comm_ops.py b/mindspore/ops/operations/comm_ops.py
index 6ccd4287af..96e9e03293 100644
--- a/mindspore/ops/operations/comm_ops.py
+++ b/mindspore/ops/operations/comm_ops.py
@@ -20,7 +20,7 @@ from ..._checkparam import Validator as validator
 from ..._checkparam import Rel
 from ...communication.management import get_rank, get_group_size, GlobalComm, _get_group
 from ...common import dtype as mstype
-from ..primitive import PrimitiveWithInfer, prim_attr_register
+from ..primitive import PrimitiveWithInfer, PrimitiveWithCheck, prim_attr_register


 class ReduceOp:
@@ -507,6 +507,59 @@ class Broadcast(PrimitiveWithInfer):
         return x_dtype


+class AllSwap(PrimitiveWithCheck):
+    """
+    AllSwap is a collective operation.
+
+    AllSwap sends data from all processes to all processes in the specified group. It has two phases:
+
+    - The scatter phase: On each process, the operand is split into blocks along the 0-th axis
+      according to the send size, and the blocks are scattered to all processes (the i-th block is sent to the i-th process).
+    - The gather phase: Each process concatenates the received blocks along the 0-th axis.
+
+    Note:
+        The tensors must have the same format on all processes of the collective.
+
+    Args:
+        group (str): The communication group name.
+
+    Inputs:
+        tensor_in (tensor): A 2-D tensor to be split into blocks along the 0-th axis according to send_size.
+        send_size (tensor): A 1-D int64 tensor. Each element is the data size to send to the corresponding process.
+        recv_size (tensor): A 1-D int64 tensor. Each element is the data size to receive from the corresponding process.
+
+    Returns:
+        tensor_out (tensor): The result tensor, whose 0-th dimension is determined by recv_size.
+
+    Raises:
+        TypeError: If group is not a string.
+    """
+
+    @prim_attr_register
+    def __init__(self, group=GlobalComm.WORLD_COMM_GROUP):
+        """Initialize AllSwap"""
+        validator.check_value_type('group', _get_group(group), (str,), self.name)
+        self.init_prim_io_names(inputs=['tensor_in', 'send_size', 'recv_size'], outputs=['tensor_out'])
+        self.add_prim_attr('group', _get_group(group))
+
+    def __check__(self, tensor_in, send_size, recv_size):
+        validator.check_subclass("tensor_in", tensor_in['dtype'], mstype.tensor, self.name)
+        validator.check_tensor_dtype_valid("send_size", send_size['dtype'], [mstype.int64],
+                                           self.name)
+        validator.check_tensor_dtype_valid("recv_size", recv_size['dtype'], [mstype.int64],
+                                           self.name)
+
+        validator.check_equal_int(len(tensor_in['shape']), 2, "tensor_in", self.name)
+        validator.check_equal_int(len(send_size['shape']), 1, "send_size", self.name)
+        validator.check_equal_int(len(recv_size['shape']), 1, "recv_size", self.name)
+
+        out_shape = [-1] + [tensor_in['shape'][1]]
+        out = {'shape': out_shape,
+               'dtype': tensor_in['dtype'],
+               'value': None}
+        return out
+
+
 class _AlltoAll(PrimitiveWithInfer):
     """
     AlltoAll is a collective operation.
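[Editor's note, not part of the patch] A minimal usage sketch of the new AllSwap primitive, illustrating the send_size/recv_size layout described in the docstring above. The three-rank group and the concrete sizes are assumptions for illustration; actually executing the op requires an initialized communication group (e.g. mindspore.communication.init() in a multi-device job).

    import numpy as np
    import mindspore
    from mindspore import Tensor
    from mindspore.ops.operations.comm_ops import AllSwap

    # Defaults to GlobalComm.WORLD_COMM_GROUP; here we assume that group has 3 ranks.
    allswap = AllSwap()
    # 2-D input with 6 rows of width 4, i.e. 24 elements on this rank.
    tensor_in = Tensor(np.ones((6, 4)), mindspore.float32)
    # Element i is the number of elements exchanged with rank i; in the unit test
    # below the sizes are multiples of the row width, so whole rows are exchanged.
    send_size = Tensor([8, 8, 8], mindspore.int64)
    recv_size = Tensor([8, 8, 8], mindspore.int64)
    # Output is 2-D with the same row width; its row count is dynamic and bounded
    # at compile time by sum(recv_size) / 4 = 6 rows.
    tensor_out = allswap(tensor_in, send_size, recv_size)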
diff --git a/tests/ut/python/communication/test_comm.py b/tests/ut/python/communication/test_comm.py
index 7df6149e7b..bda7ae80fe 100644
--- a/tests/ut/python/communication/test_comm.py
+++ b/tests/ut/python/communication/test_comm.py
@@ -26,7 +26,9 @@ from mindspore.nn import Momentum
 from mindspore.nn import ReLU
 from mindspore.nn import TrainOneStepCell, WithLossCell
 from mindspore.ops.operations.comm_ops import AllReduce, AllGather, _AlltoAll, ReduceOp, ReduceScatter
-from mindspore.ops.operations.comm_ops import Broadcast
+from mindspore.ops.operations.comm_ops import Broadcast, AllSwap
+from mindspore.ops.operations.math_ops import ReduceSum
+import mindspore

 # pylint: disable=W0212
 # W0212: protected-access
@@ -117,6 +119,25 @@ class AlltoAllNet(nn.Cell):
         return self.relu(x)


+class AllSwapNet(nn.Cell):
+    """AllSwapNet definition"""
+
+    def __init__(self, batch_size, input_channel, out_channel):
+        super(AllSwapNet, self).__init__()
+        self.dense = Dense(input_channel, out_channel)
+        self.allswap = AllSwap()
+        self.relu = ReLU()
+        self.reduce = ReduceSum()
+        part_slice = batch_size // 2
+        self.send_size = Tensor([0, part_slice*out_channel, part_slice*out_channel], mindspore.int64)
+        self.recv_size = Tensor([part_slice*out_channel, part_slice*out_channel, 0], mindspore.int64)
+
+    def construct(self, x):
+        x = self.dense(x)
+        x = self.allswap(x, self.send_size, self.recv_size)
+        x = self.relu(x)
+        return x
+
 def run_allreduce(op):
     """run_allreduce"""
     context.set_context(mode=context.GRAPH_MODE)
@@ -154,6 +175,13 @@ def test_allgather():
     network = TrainOneStepCell(network, optimizer)
     _executor.compile(network, input_tensor, label_tensor)

+def test_allswap():
+    """run_allswap"""
+    context.set_context(mode=context.GRAPH_MODE)
+    input_tensor = Tensor(np.ones((100, 20)), dtype=mindspore.float32)
+    network = AllSwapNet(100, 20, 20)
+    _executor.compile(network, input_tensor)
+

 def run_reducescatter(op):
     """run_reducescatter"""
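[Editor's note, not part of the patch] For clarity, the dynamic output shape computed by InferImplAllSwap above can be re-expressed in Python: the 0-th output dimension is unknown at compile time (Shape::SHP_ANY, i.e. -1), its minimum is 1, and its maximum is the total receive size divided by the fixed row width. The backward pass registered in grad_comm_ops.py reuses AllSwap with send_size and recv_size swapped. The helper name below is hypothetical.

    def infer_allswap_out_shape(tensor_in_shape, recv_size_values):
        """Mirror of the C++ shape inference for AllSwap (illustrative only)."""
        row_width = tensor_in_shape[1]
        infer_max_size = sum(recv_size_values)
        out_shape = [-1, row_width]                      # SHP_ANY on the 0-th axis
        min_shape = [1, row_width]
        max_shape = [infer_max_size // row_width, row_width]
        return out_shape, min_shape, max_shape

    # Same numbers as the unit test above: a (100, 20) input and recv_size of
    # [1000, 1000, 0] elements give out_shape [-1, 20], min [1, 20], max [100, 20].
    print(infer_allswap_out_shape([100, 20], [1000, 1000, 0]))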