From 5e041966f132dda07b5486fdf32e8e6ea20bc3f2 Mon Sep 17 00:00:00 2001 From: Xiaoda Zhang Date: Fri, 8 May 2020 15:08:36 +0800 Subject: [PATCH] add a new vritualdataset cell for three inputs --- mindspore/nn/wrap/__init__.py | 5 ++-- mindspore/nn/wrap/cell_wrapper.py | 30 +++++++++++++++++++ .../parallel/test_virtual_dataset_3_input.py | 24 +++++++++++++++ 3 files changed, 57 insertions(+), 2 deletions(-) diff --git a/mindspore/nn/wrap/__init__.py b/mindspore/nn/wrap/__init__.py index a07fc51a1f..813c8bf766 100644 --- a/mindspore/nn/wrap/__init__.py +++ b/mindspore/nn/wrap/__init__.py @@ -18,7 +18,7 @@ Wrap cells for networks. Use the Wrapper to combine the loss or build the training steps. """ from .cell_wrapper import TrainOneStepCell, WithLossCell, WithGradCell, WithEvalCell, DataWrapper, \ - ParameterUpdate, GetNextSingleOp + ParameterUpdate, GetNextSingleOp, VirtualDatasetCellTriple from .loss_scale import TrainOneStepWithLossScaleCell, DynamicLossScaleUpdateCell, FixedLossScaleUpdateCell from .grad_reducer import DistributedGradReducer @@ -33,5 +33,6 @@ __all__ = [ "DistributedGradReducer", "ParameterUpdate", "DynamicLossScaleUpdateCell", - "FixedLossScaleUpdateCell" + "FixedLossScaleUpdateCell", + "VirtualDatasetCellTriple" ] diff --git a/mindspore/nn/wrap/cell_wrapper.py b/mindspore/nn/wrap/cell_wrapper.py index 499d85b34b..fe69a2a6ea 100644 --- a/mindspore/nn/wrap/cell_wrapper.py +++ b/mindspore/nn/wrap/cell_wrapper.py @@ -278,6 +278,36 @@ class _VirtualDatasetCell(Cell): return self._backbone(data_, label_) +class VirtualDatasetCellTriple(Cell): + """ + Wrap the network with virtual dataset to convert data parallel layout to model parallel layout. + + VirtualDatasetCellTriple is a virtual Primitive, it does not exist in the final executing graph. Inputs and outputs + of VirtualDatasetCellTriple are distributed in data parallel pattern, tensor redistribution Primitives is inserted + dynamically during the graph compile process. + + Note: + Only used in semi auto parallel and auto parallel mode. There are three inputs, as contrary to two inputs in + _VirtualDatasetCell. + + Args: + backbone (Cell): The target network to wrap. + + Examples: + >>> net = Net() + >>> net = VirtualDatasetCellTriple(net) + """ + + def __init__(self, backbone): + super(VirtualDatasetCellTriple, self).__init__(auto_prefix=False) + self._backbone = backbone + self._virtual_dataset = _VirtualDataset() + + def construct(self, a, b, c): + a_, b_, c_ = self._virtual_dataset(a, b, c) + return self._backbone(a_, b_, c_) + + class WithEvalCell(Cell): r""" Cell that returns loss, output and label for evaluation. diff --git a/tests/ut/python/parallel/test_virtual_dataset_3_input.py b/tests/ut/python/parallel/test_virtual_dataset_3_input.py index 382195e3b9..484e31c21e 100644 --- a/tests/ut/python/parallel/test_virtual_dataset_3_input.py +++ b/tests/ut/python/parallel/test_virtual_dataset_3_input.py @@ -21,6 +21,7 @@ import mindspore as ms from mindspore.common.api import _executor from mindspore.ops import composite as C from mindspore.ops.operations.comm_ops import _VirtualDataset +from mindspore.nn.wrap.cell_wrapper import VirtualDatasetCellTriple from mindspore import context @@ -73,6 +74,29 @@ def test_virtual_dataset_3_input(): net.set_auto_parallel() _executor.compile(net, x, y, b) +def test_virtualdataset_cell_3_inputs(): + class Net(nn.Cell): + def __init__(self, strategy0, strategy1, strategy2, strategy3): + super().__init__() + self.matmul1 = P.MatMul().set_strategy(strategy1) + self.matmul2 = P.MatMul().set_strategy(strategy2) + self.gelu = P.Gelu().set_strategy(strategy3) + + def construct(self, x, y, b): + out = self.gelu(self.matmul1(x, y)) + out = self.matmul2(out, b) + return out + + net = GradWrap(VirtualDatasetCellTriple(NetWithLoss(Net(None, None, None, None)))) + context.set_context(save_graphs=True) + context.set_auto_parallel_context(parallel_mode="auto_parallel") + context.set_auto_parallel_context(device_num=8, global_rank=0) + x = Tensor(np.ones([128, 32]), dtype=ms.float32) + y = Tensor(np.ones([32, 64]), dtype=ms.float32) + b = Tensor(np.ones([64, 2048]), dtype=ms.float32) + net.set_auto_parallel() + _executor.compile(net, x, y, b) + if __name__ == '__main__': test_virtual_dataset_3_input()