# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
""" test_pynative_mixed_precision_cells """
import pytest
import numpy as np
import mindspore as ms
import mindspore.nn as nn
import mindspore.ops.operations as P
from mindspore import context
from mindspore.nn import Cell
from mindspore.nn import ReLU
from mindspore.common.tensor import Tensor


class MetaFactory:
    """Mixin that records the configured device target on the instance.

    The rank/device id attributes are only initialised to ``None`` here;
    nothing in this file assigns them afterwards.
    """

    def __init__(self):
        self.device_target = context.get_context('device_target')
        self.rank_size = None
        self.device_id = None
        self.global_rank_id = None


class ReluTanhSoftmax(Cell, MetaFactory):
    """Cell returning ReLU(x), Tanh(ReLU(x)) and Softmax(ReLU(x))."""

    def __init__(self):
        super().__init__()
        # MetaFactory is initialised explicitly in addition to super().__init__()
        # (presumably because Cell.__init__ does not chain to it -- TODO confirm).
        MetaFactory.__init__(self)
        self.relu = ReLU()
        self.tanh = nn.Tanh()
        self.softmax = nn.Softmax()

    def construct(self, x):
        """Return ``(relu_out, tanh_out, softmax_out)``.

        ``x`` is rebound to the ReLU output first, so both tanh and softmax
        consume the ReLU result rather than the raw input.
        """
        x = self.relu(x)
        y = self.tanh(x)
        z = self.softmax(x)
        return x, y, z


class Add(Cell, MetaFactory):
    """Cell wrapping the ``P.Add`` primitive."""

    def __init__(self):
        super().__init__()
        MetaFactory.__init__(self)
        self.add = P.Add()

    def construct(self, x, y):
        """Return the result of ``P.Add`` applied to ``x`` and ``y``."""
        return self.add(x, y)


class ReluTanhAdd(Cell, MetaFactory):
    """Cell returning ``Add(ReLU(x), Tanh(x))`` using the ``Add`` sub-cell."""

    def __init__(self):
        super().__init__()
        MetaFactory.__init__(self)
        self.relu = ReLU()
        self.tanh = nn.Tanh()
        self.add = Add()

    def construct(self, x):
        """Return ReLU(x) + Tanh(x); tanh is applied to the original input."""
        x_1 = self.relu(x)
        y = self.tanh(x)
        x = self.add(x_1, y)
        return x


def _count_unequal_element(data_expected, data_me, rtol, atol):
    """Assert that the fraction of out-of-tolerance elements is below ``rtol``.

    An element is out of tolerance when
    ``|expected - me| > atol + |me| * rtol``.

    Args:
        data_expected: Reference numpy array.
        data_me: Numpy array under test; must have the same shape.
        rtol: Relative tolerance, also used as the allowed mismatch ratio.
        atol: Absolute tolerance.

    Raises:
        AssertionError: If the shapes differ or too many elements mismatch.
    """
    assert data_expected.shape == data_me.shape
    total_count = len(data_expected.flatten())
    error = np.abs(data_expected - data_me)
    greater = np.greater(error, atol + np.abs(data_me) * rtol)
    loss_count = np.count_nonzero(greater)
    # The message is only formatted when the assertion actually fails.
    assert (loss_count / total_count) < rtol, (
        "\ndata_expected_std:{0}\ndata_me_error:{1}\nloss:{2}".format(
            data_expected[greater], data_me[greater], error[greater])
    )


def allclose_nparray(data_expected, data_me, rtol, atol, equal_nan=True):
    """Assert that two numpy arrays agree within the given tolerances.

    If ``data_expected`` contains NaN, a strict ``np.allclose`` is required.
    Otherwise, when ``np.allclose`` fails, a small ratio of mismatching
    elements is still tolerated (see ``_count_unequal_element``).

    Args:
        data_expected: Reference numpy array.
        data_me: Numpy array under test.
        rtol: Relative tolerance.
        atol: Absolute tolerance.
        equal_nan: Forwarded to ``np.allclose``.

    Raises:
        AssertionError: If the arrays do not agree.
    """
    if np.any(np.isnan(data_expected)):
        assert np.allclose(data_expected, data_me, rtol, atol, equal_nan=equal_nan)
    elif not np.allclose(data_expected, data_me, rtol, atol, equal_nan=equal_nan):
        _count_unequal_element(data_expected, data_me, rtol, atol)


def mixed_precision_multiple_cells_temp_01():
    """Run ``ReluTanhSoftmax`` with the net-wide ``to_float`` applied first.

    The whole net is set to float16, then relu is set to float32 and softmax
    to float16. Input is a seeded random float32 array of shape (1, 3, 28, 28).

    Returns:
        The ``(relu, tanh, softmax)`` outputs of the net.
    """
    np.random.seed(1)
    x = np.random.randn(1, 3, 28, 28).astype(np.float32)
    net = ReluTanhSoftmax()
    net.to_float(ms.float16)
    net.relu.to_float(ms.float32)
    net.softmax.to_float(ms.float16)
    out_me_relu_01, out_me_tanh_01, out_me_softmax_01 = net(Tensor(x))
    return out_me_relu_01, out_me_tanh_01, out_me_softmax_01


def mixed_precision_multiple_cells_temp_02():
    """Run ``ReluTanhSoftmax`` with the net-wide ``to_float`` applied last.

    Same settings as ``mixed_precision_multiple_cells_temp_01`` but the
    per-cell ``to_float`` calls come before the net-wide float16 call.

    Returns:
        The ``(relu, tanh, softmax)`` outputs of the net.
    """
    np.random.seed(1)
    x = np.random.randn(1, 3, 28, 28).astype(np.float32)
    net = ReluTanhSoftmax()
    net.relu.to_float(ms.float32)
    net.softmax.to_float(ms.float16)
    net.to_float(ms.float16)
    out_me_relu_02, out_me_tanh_02, out_me_softmax_02 = net(Tensor(x))
    return out_me_relu_02, out_me_tanh_02, out_me_softmax_02


def mixed_precision_multiple_cells_temp_03():
    """Run ``ReluTanhAdd`` with float16 net-wide and float32 for relu and add.

    Returns:
        The single output tensor of the net.
    """
    np.random.seed(1)
    x = np.random.randn(1, 3, 28, 28).astype(np.float32)
    net = ReluTanhAdd()
    net.to_float(ms.float16)
    net.relu.to_float(ms.float32)
    net.add.to_float(ms.float32)
    out_me = net(Tensor(x))
    return out_me


def _assert_graph_matches_pynative(run_net, rtol=0.001, atol=0.001):
    """Run ``run_net`` in graph mode and PyNative mode and compare outputs.

    Args:
        run_net: Zero-argument callable returning one tensor or a tuple of
            tensors; it is invoked once per execution mode.
        rtol: Relative tolerance passed to ``allclose_nparray``.
        atol: Absolute tolerance passed to ``allclose_nparray``.

    Note:
        The context is left in PyNative mode afterwards, as the three
        callers did before this helper was factored out.
    """
    context.set_context(mode=context.GRAPH_MODE,
                        device_target=context.get_context('device_target'))
    graph_outputs = run_net()
    context.set_context(mode=context.PYNATIVE_MODE,
                        device_target=context.get_context('device_target'))
    pynative_outputs = run_net()

    # Normalise single-tensor results so both shapes share one comparison loop.
    if not isinstance(graph_outputs, tuple):
        graph_outputs = (graph_outputs,)
        pynative_outputs = (pynative_outputs,)

    for graph_out, pynative_out in zip(graph_outputs, pynative_outputs):
        allclose_nparray(graph_out.asnumpy(), pynative_out.asnumpy(), rtol, atol)


def mixed_precision_multiples_cell_01():
    """Compare graph and PyNative results for configuration 01."""
    _assert_graph_matches_pynative(mixed_precision_multiple_cells_temp_01)


def mixed_precision_multiples_cell_02():
    """Compare graph and PyNative results for configuration 02."""
    _assert_graph_matches_pynative(mixed_precision_multiple_cells_temp_02)


def mixed_precision_multiples_cell_03():
    """Compare graph and PyNative results for configuration 03."""
    _assert_graph_matches_pynative(mixed_precision_multiple_cells_temp_03)


@pytest.mark.level0
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.env_onecard
def test_mixed_precision_multiples_cell_ascend_01():
    """Configuration 01 on the Ascend device target."""
    context.set_context(device_target="Ascend")
    mixed_precision_multiples_cell_01()


@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_mixed_precision_multiples_cell_gpu_01():
    """Configuration 01 on the GPU device target."""
    context.set_context(device_target="GPU")
    mixed_precision_multiples_cell_01()


@pytest.mark.level0
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.env_onecard
def test_mixed_precision_multiples_cell_ascend_02():
    """Configuration 02 on the Ascend device target."""
    context.set_context(device_target="Ascend")
    mixed_precision_multiples_cell_02()


@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_mixed_precision_multiples_cell_gpu_02():
    """Configuration 02 on the GPU device target."""
    context.set_context(device_target="GPU")
    mixed_precision_multiples_cell_02()


@pytest.mark.level0
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.env_onecard
def test_mixed_precision_multiples_cell_ascend_03():
    """Configuration 03 on the Ascend device target."""
    context.set_context(device_target="Ascend")
    mixed_precision_multiples_cell_03()


@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_mixed_precision_multiples_cell_gpu_03():
    """Configuration 03 on the GPU device target."""
    context.set_context(device_target="GPU")
    mixed_precision_multiples_cell_03()