From 0b7570eb53e6581626e8a087d05ee565060a4d85 Mon Sep 17 00:00:00 2001 From: Wan Hanyang Date: Sat, 12 Sep 2020 14:26:40 +0800 Subject: [PATCH] add model with loss, without loso and o2 test case --- .../python/parallel/test_model_with_loss.py | 121 ++++++++++ .../parallel/test_model_without_loss.py | 191 ++++++++++++++++ tests/ut/python/parallel/test_o2_level.py | 208 ++++++++++++++++++ 3 files changed, 520 insertions(+) create mode 100644 tests/ut/python/parallel/test_model_with_loss.py create mode 100644 tests/ut/python/parallel/test_model_without_loss.py create mode 100644 tests/ut/python/parallel/test_o2_level.py diff --git a/tests/ut/python/parallel/test_model_with_loss.py b/tests/ut/python/parallel/test_model_with_loss.py new file mode 100644 index 0000000000..cb6c566228 --- /dev/null +++ b/tests/ut/python/parallel/test_model_with_loss.py @@ -0,0 +1,121 @@ +# Copyright 2020 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +import numpy as np + +import mindspore as ms +from mindspore import context, Tensor, Parameter +from mindspore.nn import Cell, Momentum +from mindspore.nn.loss import SoftmaxCrossEntropyWithLogits +from mindspore.ops import operations as P +from mindspore.train import Model +from tests.dataset_mock import MindData + + +class Dataset(MindData): + def __init__(self, predict, label, length=3): + super(Dataset, self).__init__(size=length) + self.predict = predict + self.label = label + self.index = 0 + self.length = length + + def __iter__(self): + return self + + def __next__(self): + if self.index >= self.length: + raise StopIteration + self.index += 1 + return self.predict, self.label + + def reset(self): + self.index = 0 + + +class Net(Cell): + def __init__(self, mul_weight, strategy1=None, strategy2=None): + super().__init__() + self.mul = P.Mul().shard(strategy1) + self.neg = P.Neg().shard(strategy2) + self.mul_weight = Parameter(mul_weight, "w1") + + def construct(self, x): + out = self.mul(x, self.mul_weight) + out = self.neg(out) + return out + + +_x = Tensor(np.ones([32, 128]), dtype=ms.float32) +_b = Tensor(np.ones([32]), dtype=ms.int32) +_w1 = Tensor(np.ones([512, 128]), dtype=ms.float32) + + +def compile_net(net): + context.set_context(save_graphs=True) + learning_rate = 0.1 + momentum = 0.9 + epoch_size = 2 + dataset = Dataset(_x, _b) + loss = SoftmaxCrossEntropyWithLogits(sparse=True, reduction='mean') + opt = Momentum(net.trainable_params(), learning_rate, momentum) + model = Model(net, loss, optimizer=opt) + model.train(epoch_size, dataset, dataset_sink_mode=False) + context.reset_auto_parallel_context() + + +def test_neg_data_parallel(): + context.set_auto_parallel_context(parallel_mode="semi_auto_parallel", device_num=16, global_rank=0) + strategy1 = ((16, 1), (16, 1)) + strategy2 = ((16, 1),) + net = Net(_w1, strategy1, strategy2) + compile_net(net) + + +def test_neg_model_parallel(): + context.set_auto_parallel_context(parallel_mode="semi_auto_parallel", device_num=16, global_rank=0) + strategy1 = ((1, 16), (1, 16)) + strategy2 = ((1, 16),) + net = Net(_w1, strategy1, strategy2) + compile_net(net) + + +def test_neg_hybrid_parallel(): + context.set_auto_parallel_context(parallel_mode="semi_auto_parallel", device_num=16, global_rank=0) + strategy1 = ((4, 4), (4, 4)) + strategy2 = ((4, 4),) + net = Net(_w1, strategy1, strategy2) + compile_net(net) + + +def test_neg_auto_parallel(): + context.set_auto_parallel_context(parallel_mode="auto_parallel", device_num=16, global_rank=0) + net = Net(_w1) + compile_net(net) + + +def test_neg_repeat_calc(): + context.set_auto_parallel_context(parallel_mode="semi_auto_parallel", device_num=16, global_rank=0) + strategy1 = ((4, 4), (4, 4)) + strategy2 = ((2, 2),) + net = Net(_w1, strategy1, strategy2) + compile_net(net) + + +def test_neg_repeat_calc2(): + context.set_auto_parallel_context(parallel_mode="semi_auto_parallel", device_num=16, global_rank=0) + strategy1 = ((4, 2), (4, 2)) + strategy2 = ((4, 4),) + net = Net(_w1, strategy1, strategy2) + compile_net(net) diff --git a/tests/ut/python/parallel/test_model_without_loss.py b/tests/ut/python/parallel/test_model_without_loss.py new file mode 100644 index 0000000000..39a718dae1 --- /dev/null +++ b/tests/ut/python/parallel/test_model_without_loss.py @@ -0,0 +1,191 @@ +# Copyright 2020 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +import numpy as np + +import mindspore as ms +from mindspore import context, Tensor, Parameter +from mindspore.nn import Cell, Momentum +from mindspore.ops import operations as P +from mindspore.train import Model +from tests.dataset_mock import MindData + + +class Dataset(MindData): + def __init__(self, predict, label, length=3): + super(Dataset, self).__init__(size=length) + self.predict = predict + self.label = label + self.index = 0 + self.length = length + + def __iter__(self): + return self + + def __next__(self): + if self.index >= self.length: + raise StopIteration + self.index += 1 + return self.predict, self.label + + def reset(self): + self.index = 0 + + +class Net(Cell): + def __init__(self, weight, weight2, strategy1=None, strategy2=None, is_parameter=True): + super().__init__() + self.concat = P.Concat(axis=0).shard(strategy1) + if is_parameter: + self.weight = Parameter(weight, "w1") + else: + self.weight = weight + self.mul = P.Mul().shard(strategy2) + self.weight2 = Parameter(weight2, "w2") + + def construct(self, x, b): + out = self.concat((self.weight, self.weight2)) + out = self.mul(x, out) + return out + + +class Net2(Cell): + def __init__(self, weight, strategy1=None, strategy2=None, axis=0): + super().__init__() + self.mul = P.Mul().shard(strategy1) + self.concat = P.Concat(axis=axis).shard(strategy2) + self.weight = Parameter(weight, "w") + + def construct(self, x, b): + out = self.mul(x, x) + out = self.concat((out, self.weight)) + return out + + +class Net3(Cell): + def __init__(self, weight, weight2, weight3, strategy1=None, strategy2=None, is_parameter=True): + super().__init__() + self.concat = P.Concat(axis=0).shard(strategy1) + if is_parameter: + self.weight = Parameter(weight, "w1") + else: + self.weight = weight + self.mul = P.Mul().shard(strategy2) + self.weight2 = Parameter(weight2, "w2") + self.weight3 = Parameter(weight3, "w3") + + def construct(self, x, b): + out = self.concat((self.weight, self.weight2, self.weight3)) + out = self.mul(x, out) + return out + + +_x = Tensor(np.ones([16, 64, 32]), dtype=ms.float32) +_b = Tensor(np.ones([16, 64, 32, 32]), dtype=ms.int32) +_w1 = Tensor(np.ones([96, 64, 32]), dtype=ms.float32) +_w2 = Tensor(np.ones([32, 64, 32]), dtype=ms.float32) +_w3 = Tensor(np.ones([128, 16, 32]), dtype=ms.float32) + +w1 = Tensor(np.ones([48, 64, 32]), dtype=ms.float32) +w2 = Tensor(np.ones([16, 64, 32]), dtype=ms.float32) +w3 = Tensor(np.ones([64, 64, 32]), dtype=ms.float32) + + +def compile_net(net): + context.set_context(save_graphs=True) + learning_rate = 0.1 + momentum = 0.9 + epoch_size = 2 + dataset = Dataset(_x, _b) + opt = Momentum(net.trainable_params(), learning_rate, momentum) + model = Model(net, optimizer=opt, amp_level="O2") + model.train(epoch_size, dataset, dataset_sink_mode=False) + context.reset_auto_parallel_context() + + +def test_concat_parameter(): + context.set_auto_parallel_context( + parallel_mode="semi_auto_parallel", device_num=8, global_rank=0) + strategy1 = ((1, 4, 2), (1, 4, 2)) + strategy2 = ((1, 4, 2), (1, 4, 2)) + net = Net(_w1, _w2, strategy1, strategy2, is_parameter=True) + compile_net(net) + + +def test_concat_parameter_no_full_split(): + context.set_auto_parallel_context( + parallel_mode="semi_auto_parallel", device_num=8, global_rank=0) + strategy1 = ((1, 2, 2), (1, 2, 2)) + strategy2 = ((1, 4, 2), (1, 4, 2)) + net = Net(_w1, _w2, strategy1, strategy2, is_parameter=True) + compile_net(net) + + +def test_concat_tensor_and_parameter(): + context.set_auto_parallel_context( + parallel_mode="semi_auto_parallel", device_num=8, global_rank=0) + strategy1 = ((1, 2, 2), (1, 2, 2)) + strategy2 = ((1, 4, 2), (1, 4, 2)) + net = Net(_w1, _w2, strategy1, strategy2, is_parameter=False) + compile_net(net) + + +def test_concat_output(): + context.set_auto_parallel_context( + parallel_mode="semi_auto_parallel", device_num=8, global_rank=0) + strategy1 = ((2, 2, 2), (2, 2, 2)) + strategy2 = ((1, 4, 2), (1, 4, 2)) + net = Net2(_w1, strategy1, strategy2) + compile_net(net) + + +def test_concat_output_no_full_split(): + context.set_auto_parallel_context( + parallel_mode="semi_auto_parallel", device_num=8, global_rank=0) + strategy1 = ((2, 2, 2), (2, 2, 2)) + strategy2 = ((1, 2, 2), (1, 2, 2)) + net = Net2(_w1, strategy1, strategy2) + compile_net(net) + + +def test_concat_no_strategy(): + context.set_auto_parallel_context( + parallel_mode="semi_auto_parallel", device_num=8, global_rank=0) + strategy1 = ((2, 2, 2), (2, 2, 2)) + strategy2 = None + net = Net2(_w3, strategy1, strategy2, axis=1) + compile_net(net) + + +def test_concat_auto_parallel(): + context.set_auto_parallel_context( + parallel_mode="auto_parallel", device_num=8, global_rank=0) + net = Net2(_w2) + compile_net(net) + + +def test_concat_auto_parallel2(): + context.set_auto_parallel_context( + parallel_mode="auto_parallel", device_num=8, global_rank=0) + strategy1 = None + strategy2 = None + net = Net2(_w3, strategy1, strategy2, axis=1) + compile_net(net) + + +def test_concat_auto_parallel_3_tensor(): + context.set_auto_parallel_context( + parallel_mode="auto_parallel", device_num=8, global_rank=0) + net = Net3(w1, w2, w3) + compile_net(net) diff --git a/tests/ut/python/parallel/test_o2_level.py b/tests/ut/python/parallel/test_o2_level.py new file mode 100644 index 0000000000..813ddd074d --- /dev/null +++ b/tests/ut/python/parallel/test_o2_level.py @@ -0,0 +1,208 @@ +# Copyright 2020 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +import numpy as np +import pytest + +import mindspore as ms +from mindspore import context, Tensor, Parameter +from mindspore.nn import Cell, Momentum +from mindspore.ops import operations as P +from mindspore.train import Model +from tests.dataset_mock import MindData + + +class Dataset(MindData): + def __init__(self, predict, label, length=3): + super(Dataset, self).__init__(size=length) + self.predict = predict + self.label = label + self.index = 0 + self.length = length + + def __iter__(self): + return self + + def __next__(self): + if self.index >= self.length: + raise StopIteration + self.index += 1 + return self.predict, self.label + + def reset(self): + self.index = 0 + + +class Net(Cell): + def __init__(self, weight, w2, begin, end, strides, strategy1=None, strategy2=None, is_parameter=True, mask=0): + super().__init__() + self.mul = P.Mul().shard(strategy1) + self.strided_slice = P.StridedSlice(begin_mask=mask).shard(strategy2) + if is_parameter: + self.weight = Parameter(weight, "w1") + else: + self.weight = weight + self.mul2 = P.Mul() + self.weight2 = Parameter(w2, "w2") + self.begin = begin + self.end = end + self.strides = strides + + def construct(self, x, b): + out = self.strided_slice( + self.weight, self.begin, self.end, self.strides) + out = self.mul(x, out) + out = self.mul2(out, self.weight2) + return out + + +class Net2(Cell): + def __init__(self, weight2, begin, end, strides, strategy1=None, strategy2=None): + super().__init__() + self.mul = P.Mul().shard(strategy1) + self.strided_slice = P.StridedSlice().shard(strategy2) + self.weight2 = Parameter(weight2, "w2") + self.begin = begin + self.end = end + self.strides = strides + + def construct(self, x, b): + out = self.mul(x, self.weight2) + out = self.strided_slice(out, self.begin, self.end, self.strides) + return out + + +_x = Tensor(np.ones([16, 64, 1]), dtype=ms.float32) +_b = Tensor(np.ones([16, 64, 32]), dtype=ms.float32) +_w1 = Tensor(np.ones([256, 64, 32]), dtype=ms.float32) +_w2 = Tensor(np.ones([128, 64, 1]), dtype=ms.float32) + + +def compile_net(net): + context.set_context(save_graphs=True) + learning_rate = 0.1 + momentum = 0.9 + epoch_size = 2 + dataset = Dataset(_x, _b) + opt = Momentum(net.trainable_params(), learning_rate, momentum) + model = Model(net, optimizer=opt, amp_level="O2") + model.train(epoch_size, dataset, dataset_sink_mode=False) + context.reset_auto_parallel_context() + + +def test_stridedslice_no_fully_fetch_split_error(): + context.set_auto_parallel_context( + parallel_mode="semi_auto_parallel", device_num=8, global_rank=0) + strategy1 = ((2, 2, 2), (2, 2, 2)) + strategy2 = ((2, 2, 2),) + net = Net(_w1, _w2, (0, 0, 0), (128, 64, 32), (1, 1, 1), + strategy1, strategy2, is_parameter=True) + with pytest.raises(RuntimeError): + compile_net(net) + + +def test_stridedslice_strides_no_1_split_error(): + context.set_auto_parallel_context( + parallel_mode="semi_auto_parallel", device_num=8, global_rank=0) + strategy1 = ((2, 2, 2), (2, 2, 2)) + strategy2 = ((1, 2, 2),) + net = Net(_w1, _w2, (0, 0, 0), (128, 64, 32), (1, 1, 2), + strategy1, strategy2, is_parameter=True) + with pytest.raises(RuntimeError): + compile_net(net) + + +def test_stridedslice_mask_no_0_split_error(): + context.set_auto_parallel_context( + parallel_mode="semi_auto_parallel", device_num=8, global_rank=0) + strategy1 = ((2, 2, 2), (2, 2, 2)) + strategy2 = ((1, 2, 2),) + net = Net(_w1, _w2, (0, 0, 0), (128, 64, 32), (1, 1, 1), + strategy1, strategy2, is_parameter=True, mask=1) + with pytest.raises(RuntimeError): + compile_net(net) + + +def test_stridedslice_begin_size_smaller(): + context.set_auto_parallel_context( + parallel_mode="semi_auto_parallel", device_num=8, global_rank=0) + strategy1 = ((1, 4, 1), (1, 4, 2)) + strategy2 = ((1, 4, 2),) + net = Net(_w1, _w2, (0, 0), (128, 64), (1, 1), + strategy1, strategy2, is_parameter=True) + compile_net(net) + + +def test_stridedslice_parameter(): + context.set_auto_parallel_context( + parallel_mode="semi_auto_parallel", device_num=8, global_rank=0) + strategy1 = ((1, 4, 1), (1, 4, 2)) + strategy2 = ((1, 4, 2),) + net = Net(_w1, _w2, (0, 0, 0), (128, 64, 32), (1, 1, 1), + strategy1, strategy2, is_parameter=True) + compile_net(net) + + +def test_stridedslice_tensor(): + context.set_auto_parallel_context( + parallel_mode="semi_auto_parallel", device_num=8, global_rank=0) + strategy1 = ((1, 4, 1), (1, 4, 2)) + strategy2 = ((1, 4, 2),) + net = Net(_w1, _w2, (0, 0, 0), (128, 64, 32), (1, 1, 1), + strategy1, strategy2, is_parameter=False) + compile_net(net) + + +def test_stridedslice_parameter_no_full_split(): + context.set_auto_parallel_context( + parallel_mode="semi_auto_parallel", device_num=8, global_rank=0) + strategy1 = ((1, 4, 1), (1, 4, 2)) + strategy2 = ((1, 2, 2),) + net = Net(_w1, _w2, (0, 0, 0), (128, 64, 32), (1, 1, 1), + strategy1, strategy2, is_parameter=True) + compile_net(net) + + +def test_stridedslice_output(): + context.set_auto_parallel_context( + parallel_mode="semi_auto_parallel", device_num=8, global_rank=0) + strategy1 = ((1, 8, 1), (1, 8, 1)) + strategy2 = ((1, 8, 1),) + net = Net2(_w2, (0, 0, 0), (64, 64, 1), (1, 1, 1), strategy1, strategy2) + compile_net(net) + + +def test_stridedslice_output_no_full_split(): + context.set_auto_parallel_context( + parallel_mode="semi_auto_parallel", device_num=8, global_rank=0) + strategy1 = ((1, 8, 1), (1, 8, 1)) + strategy2 = ((1, 4, 1),) + net = Net2(_w2, (0, 0, 0), (64, 64, 1), (1, 1, 1), strategy1, strategy2) + compile_net(net) + + +def test_stridedslice_no_strategy(): + context.set_auto_parallel_context( + parallel_mode="semi_auto_parallel", device_num=8, global_rank=0) + strategy1 = ((1, 8, 1), (1, 8, 1)) + strategy2 = None + net = Net2(_w2, (0, 0, 0), (128, 64, 1), (1, 1, 1), strategy1, strategy2) + compile_net(net) + + +def test_stridedslice_auto_parallel(): + context.set_auto_parallel_context( + parallel_mode="auto_parallel", device_num=8, global_rank=0) + net = Net2(_w2, (0, 0, 0), (32, 64, 1), (1, 1, 1)) + compile_net(net)