From 110640e2ad7dd862a5f2d9ad96f4607a481bc639 Mon Sep 17 00:00:00 2001
From: yangzhenzhang <285824651@qq.com>
Date: Tue, 31 Mar 2020 15:40:43 +0800
Subject: [PATCH] add parallel ops for neg and batchmatmul

---
 mindspore/ccsrc/parallel/dynamic_creator.h    |  2 +
 .../ccsrc/parallel/ops_info/activation_info.h |  7 ++
 .../ccsrc/parallel/ops_info/matmul_info.h     |  8 ++
 mindspore/ccsrc/parallel/ops_info/ops_utils.h |  2 +
 .../ccsrc/parallel/step_auto_parallel.cc      |  2 +
 tests/ut/python/parallel/test_batch_matmul.py | 93 +++++++++++++++++++
 tests/ut/python/parallel/test_neg.py          | 84 +++++++++++++++++
 7 files changed, 198 insertions(+)
 create mode 100644 tests/ut/python/parallel/test_batch_matmul.py
 create mode 100644 tests/ut/python/parallel/test_neg.py

diff --git a/mindspore/ccsrc/parallel/dynamic_creator.h b/mindspore/ccsrc/parallel/dynamic_creator.h
index 59b8722435..e6e1b41d76 100644
--- a/mindspore/ccsrc/parallel/dynamic_creator.h
+++ b/mindspore/ccsrc/parallel/dynamic_creator.h
@@ -123,6 +123,8 @@ REGISTER(ReLUInfo);
 REGISTER(GatherV2Info);
 REGISTER(SqrtInfo);
 REGISTER(GetNextInfo);
+REGISTER(NegInfo);
+REGISTER(BatchMatMulInfo);
 }  // namespace parallel
 }  // namespace mindspore
 
diff --git a/mindspore/ccsrc/parallel/ops_info/activation_info.h b/mindspore/ccsrc/parallel/ops_info/activation_info.h
index d8de19b328..d05f8743b0 100644
--- a/mindspore/ccsrc/parallel/ops_info/activation_info.h
+++ b/mindspore/ccsrc/parallel/ops_info/activation_info.h
@@ -167,6 +167,13 @@ class SqrtInfo : public ActivationOther {
       : ActivationOther(name, inputs_shape, outputs_shape, attrs) {}
   ~SqrtInfo() override = default;
 };
+
+class NegInfo : public ActivationOther {
+ public:
+  NegInfo(const std::string& name, const Shapes& inputs_shape, const Shapes& outputs_shape, const PrimitiveAttrs& attrs)
+      : ActivationOther(name, inputs_shape, outputs_shape, attrs) {}
+  ~NegInfo() override = default;
+};
 }  // namespace parallel
 }  // namespace mindspore
 #endif  // MINDSPORE_CCSRC_OPTIMIZER_OPS_INFO_PARALLEL_ACTIVATION_INFO_H_
diff --git a/mindspore/ccsrc/parallel/ops_info/matmul_info.h b/mindspore/ccsrc/parallel/ops_info/matmul_info.h
index c9feae55b6..b434e4522d 100644
--- a/mindspore/ccsrc/parallel/ops_info/matmul_info.h
+++ b/mindspore/ccsrc/parallel/ops_info/matmul_info.h
@@ -87,6 +87,14 @@ class MatMulInfo : public MatMul {
       : MatMul(name, inputs_shape, outputs_shape, attrs) {}
   ~MatMulInfo() override = default;
 };
+
+class BatchMatMulInfo : public MatMul {
+ public:
+  BatchMatMulInfo(const std::string& name, const Shapes& inputs_shape, const Shapes& outputs_shape,
+                  const PrimitiveAttrs& attrs)
+      : MatMul(name, inputs_shape, outputs_shape, attrs) {}
+  ~BatchMatMulInfo() override = default;
+};
 }  // namespace parallel
 }  // namespace mindspore
 #endif  // MINDSPORE_CCSRC_PARALLEL_OPS_INFO_MATMUL_INFO_H_
diff --git a/mindspore/ccsrc/parallel/ops_info/ops_utils.h b/mindspore/ccsrc/parallel/ops_info/ops_utils.h
index a25200c3c1..2b8fc0ee3f 100644
--- a/mindspore/ccsrc/parallel/ops_info/ops_utils.h
+++ b/mindspore/ccsrc/parallel/ops_info/ops_utils.h
@@ -188,6 +188,8 @@ constexpr char SQRT[] = "Sqrt";
 constexpr char ASSIGN[] = "Assign";
 constexpr char GET_NEXT[] = "GetNext";
 constexpr char SQUEEZE[] = "Squeeze";
+constexpr char NEG[] = "Neg";
+constexpr char BATCH_MATMUL[] = "BatchMatMul";
 
 // Parallel don't care
 constexpr char TUPLE_GETITEM[] = "tuple_getitem";
diff --git a/mindspore/ccsrc/parallel/step_auto_parallel.cc b/mindspore/ccsrc/parallel/step_auto_parallel.cc
index c3e3f5893e..cf388bea40 100644
--- a/mindspore/ccsrc/parallel/step_auto_parallel.cc
+++ b/mindspore/ccsrc/parallel/step_auto_parallel.cc
@@ -101,6 +101,8 @@ std::vector<std::string> splittable_op_ = {MATMUL,
                                            SQRT,
                                            GET_NEXT,
                                            CAST,
+                                           NEG,
+                                           BATCH_MATMUL,
                                            SQUEEZE};
 
 std::vector<std::string> elementwise_op_ = {ACTIVATION, GELU, TANH, SOFTMAX, LOG_SOFTMAX, RELU, SQRT,
diff --git a/tests/ut/python/parallel/test_batch_matmul.py b/tests/ut/python/parallel/test_batch_matmul.py
new file mode 100644
index 0000000000..88ba818c91
--- /dev/null
+++ b/tests/ut/python/parallel/test_batch_matmul.py
@@ -0,0 +1,93 @@
+# Copyright 2020 Huawei Technologies Co., Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import numpy as np
+import mindspore as ms
+from mindspore import context, Tensor, Parameter
+from mindspore.nn import Cell, TrainOneStepCell, Momentum
+from mindspore.ops import operations as P
+from mindspore.common.api import _executor
+
+
+class Net(Cell):
+    def __init__(self, mul_weight, batch_matmul_weight, transpose_b=False, strategy1=None, strategy2=None):
+        super().__init__()
+        self.mul = P.Mul().set_strategy(strategy1)
+        self.batch_matmul = P.BatchMatMul(transpose_b=transpose_b).set_strategy(strategy2)
+        self.mul_weight = Parameter(mul_weight, "w1")
+        self.batch_matmul_weight = Parameter(batch_matmul_weight, "w2")
+
+    def construct(self, x, b):
+        out = self.mul(x, self.mul_weight)
+        out = self.batch_matmul(out, self.batch_matmul_weight)
+        return out
+
+
+_x = Tensor(np.ones([128, 64, 32]), dtype=ms.float32)
+_w1 = Tensor(np.ones([128, 64, 32]), dtype=ms.float32)
+_w2 = Tensor(np.ones([128, 32, 32]), dtype=ms.float32)
+_b = Tensor(np.ones([128, 64, 16]), dtype=ms.float32)
+
+
+def compile_net(net):
+    optimizer = Momentum(net.trainable_params(), learning_rate=0.1, momentum=0.9)
+    train_net = TrainOneStepCell(net, optimizer)
+    _executor.compile(train_net, _x, _b)
+    context.reset_auto_parallel_context()
+
+
+def test_batch_matmul_data_parallel():
+    context.set_auto_parallel_context(parallel_mode="semi_auto_parallel", device_num=16, global_rank=0)
+    strategy1 = ((16, 1, 1), (16, 1, 1))
+    strategy2 = ((16, 1, 1), (16, 1, 1))  # one tuple per input; split the batch dim across all 16 devices
+    net = Net(_w1, _w2, False, strategy1, strategy2)
+    compile_net(net)
+
+
+def test_batch_matmul_model_parallel():
+    context.set_auto_parallel_context(parallel_mode="semi_auto_parallel", device_num=16, global_rank=0)
+    strategy1 = ((1, 1, 1), (1, 1, 1))
+    strategy2 = ((1, 1, 1), (1, 1, 16))  # split only the output column dim of the second input
+    net = Net(_w1, _w2, False, strategy1, strategy2)
+    compile_net(net)
+
+
+def test_batch_matmul_hybrid_parallel():
+    context.set_auto_parallel_context(parallel_mode="semi_auto_parallel", device_num=16, global_rank=0)
+    strategy1 = ((2, 2, 2), (2, 2, 2))
+    strategy2 = ((2, 2, 2), (2, 2, 2))
+    net = Net(_w1, _w2, False, strategy1, strategy2)
+    compile_net(net)
+
+
+def test_batch_matmul_auto_parallel():
+    context.set_auto_parallel_context(parallel_mode="auto_parallel", device_num=16, global_rank=0)
+    net = Net(_w1, _w2, False)
+    compile_net(net)
+
+
+def test_batch_matmul_repeat_calc():
+    context.set_auto_parallel_context(parallel_mode="semi_auto_parallel", device_num=16, global_rank=0)
+    strategy1 = ((2, 2, 4), (2, 2, 4))
+    strategy2 = ((1, 2, 2), (1, 2, 2))  # only 4 of 16 devices hold distinct slices: repeated calculation
+    net = Net(_w1, _w2, False, strategy1, strategy2)
+    compile_net(net)
+
+
+def test_batch_matmul_transpose_b():
+    context.set_auto_parallel_context(parallel_mode="semi_auto_parallel", device_num=16, global_rank=0)
+    strategy1 = ((2, 2, 4), (2, 2, 4))
+    strategy2 = ((1, 2, 2), (1, 2, 2))
+    net = Net(_w1, _w2, True, strategy1, strategy2)
+    compile_net(net)
diff --git a/tests/ut/python/parallel/test_neg.py b/tests/ut/python/parallel/test_neg.py
new file mode 100644
index 0000000000..0e08e8c096
--- /dev/null
+++ b/tests/ut/python/parallel/test_neg.py
@@ -0,0 +1,84 @@
+# Copyright 2020 Huawei Technologies Co., Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import numpy as np
+import mindspore as ms
+from mindspore import context, Tensor, Parameter
+from mindspore.nn import Cell, TrainOneStepCell, Momentum
+from mindspore.ops import operations as P
+from mindspore.common.api import _executor
+
+
+class Net(Cell):
+    def __init__(self, mul_weight, strategy1=None, strategy2=None):
+        super().__init__()
+        self.mul = P.Mul().set_strategy(strategy1)
+        self.neg = P.Neg().set_strategy(strategy2)
+        self.mul_weight = Parameter(mul_weight, "w1")
+
+    def construct(self, x, b):
+        out = self.mul(x, self.mul_weight)
+        out = self.neg(out)
+        return out
+
+
+_x = Tensor(np.ones([128, 64, 32]), dtype=ms.float32)
+_w1 = Tensor(np.ones([128, 64, 32]), dtype=ms.float32)
+_b = Tensor(np.ones([128, 64, 32]), dtype=ms.float32)
+
+
+def compile_net(net):
+    optimizer = Momentum(net.trainable_params(), learning_rate=0.1, momentum=0.9)
+    train_net = TrainOneStepCell(net, optimizer)
+    _executor.compile(train_net, _x, _b)
+    context.reset_auto_parallel_context()
+
+
+def test_neg_data_parallel():
+    context.set_auto_parallel_context(parallel_mode="semi_auto_parallel", device_num=16, global_rank=0)
+    strategy1 = ((16, 1, 1), (16, 1, 1))
+    strategy2 = ((16, 1, 1),)  # Neg has a single input, so a single tuple
+    net = Net(_w1, strategy1, strategy2)
+    compile_net(net)
+
+
+def test_neg_model_parallel():
+    context.set_auto_parallel_context(parallel_mode="semi_auto_parallel", device_num=16, global_rank=0)
+    strategy1 = ((1, 1, 16), (1, 1, 16))
+    strategy2 = ((1, 1, 16),)
+    net = Net(_w1, strategy1, strategy2)
+    compile_net(net)
+
+
+def test_neg_hybrid_parallel():
+    context.set_auto_parallel_context(parallel_mode="semi_auto_parallel", device_num=16, global_rank=0)
+    strategy1 = ((2, 2, 4), (2, 2, 4))
+    strategy2 = ((2, 2, 4),)
+    net = Net(_w1, strategy1, strategy2)
+    compile_net(net)
+
+
+def test_neg_auto_parallel():
+    context.set_auto_parallel_context(parallel_mode="auto_parallel", device_num=16, global_rank=0)
+    net = Net(_w1)
+    compile_net(net)
+
+
+def test_neg_repeat_calc():
+    context.set_auto_parallel_context(parallel_mode="semi_auto_parallel", device_num=16, global_rank=0)
+    strategy1 = ((2, 2, 4), (2, 2, 4))
+    strategy2 = ((1, 2, 2),)  # 4 of 16 devices hold distinct slices; the rest repeat the calculation
+    net = Net(_w1, strategy1, strategy2)
+    compile_net(net)
+
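
Note for reviewers (editorial, not part of the patch): a sharding strategy here is a tuple of tuples, one inner tuple per operator input, giving the number of slices for each tensor dimension; the product of all slice counts must divide device_num, and leftover devices recompute the same slice (the "repeated calculation" case the *_repeat_calc tests exercise). A minimal sketch of the per-device shapes a strategy implies, assuming the shapes used in the tests above; shard_shape is a hypothetical helper for illustration, not a MindSpore API:

    def shard_shape(shape, strategy):
        # Each dimension must divide evenly by its slice count.
        assert len(shape) == len(strategy)
        assert all(dim % cut == 0 for dim, cut in zip(shape, strategy))
        return [dim // cut for dim, cut in zip(shape, strategy)]

    # test_batch_matmul_data_parallel: both inputs split 16-way on the batch dim.
    print(shard_shape([128, 64, 32], (16, 1, 1)))  # -> [8, 64, 32]
    print(shard_shape([128, 32, 32], (16, 1, 1)))  # -> [8, 32, 32]

    # test_neg_repeat_calc: 1 * 2 * 2 = 4 distinct slices on 16 devices,
    # so each slice is computed redundantly by a group of 4 devices.
    print(shard_shape([128, 64, 32], (1, 2, 2)))   # -> [128, 32, 16]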