From 0978bdc3018293e6cf8b4a52e4adb1090d871a67 Mon Sep 17 00:00:00 2001
From: lvliang
Date: Thu, 9 Jul 2020 11:55:56 +0800
Subject: [PATCH] add-st-to-protect-pynative-hook-from-abnormal

In PyNative mode, ConstructOutput now returns the backend node mapped
from the front-end output directly instead of visiting the real
kernels. Also add system tests that exercise hook functions and custom
bprop cells, including deliberately abnormal bprop inputs and outputs.

---
 mindspore/ccsrc/session/session_basic.cc |   5 +
 tests/st/pynative/test_pynative_hook.py  | 198 +++++++++++++++++++++++
 2 files changed, 203 insertions(+)
 create mode 100644 tests/st/pynative/test_pynative_hook.py

diff --git a/mindspore/ccsrc/session/session_basic.cc b/mindspore/ccsrc/session/session_basic.cc
index 91e430182c..8fa68edfca 100644
--- a/mindspore/ccsrc/session/session_basic.cc
+++ b/mindspore/ccsrc/session/session_basic.cc
@@ -931,6 +931,11 @@ CNodePtr SessionBasic::ConstructOutput(const AnfNodePtrList &outputs, const std:
   auto FindEqu = [graph, outputs](const AnfNodePtr &out) -> AnfNodePtr {
     auto backend_anf = graph->GetBackendAnfByFrontAnf(out);
     if (backend_anf != nullptr) {
+      auto context_ptr = MsContext::GetInstance();
+      MS_EXCEPTION_IF_NULL(context_ptr);
+      if (context_ptr->execution_mode() == kPynativeMode) {
+        return backend_anf;
+      }
       auto front_real_kernel = AnfAlgo::VisitKernel(out, 0);
       auto backend_real_kernel = AnfAlgo::VisitKernel(backend_anf, 0);
       MS_EXCEPTION_IF_NULL(out);
diff --git a/tests/st/pynative/test_pynative_hook.py b/tests/st/pynative/test_pynative_hook.py
new file mode 100644
index 0000000000..0ce4ba4f69
--- /dev/null
+++ b/tests/st/pynative/test_pynative_hook.py
@@ -0,0 +1,198 @@
+# Copyright 2020 Huawei Technologies Co., Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+
+import pytest
+import numpy as np
+import mindspore.nn as nn
+import mindspore.common.dtype as mstype
+
+from mindspore import Tensor
+from mindspore import context
+from mindspore import ParameterTuple
+from mindspore.nn import Momentum
+from mindspore.nn import WithLossCell
+from mindspore.ops import composite as C
+from mindspore.ops import operations as P
+from mindspore.common.initializer import TruncatedNormal
+
+context.set_context(mode=context.PYNATIVE_MODE, device_target="Ascend")
+
+
+def weight_variable():
+    """weight initializer"""
+    return TruncatedNormal(0.02)
+
+
+def conv(in_channels, out_channels, kernel_size, stride=1, padding=0):
+    """weight initializer for conv layer"""
+    weight = weight_variable()
+    return nn.Conv2d(in_channels, out_channels,
+                     kernel_size=kernel_size, stride=stride, padding=padding,
+                     weight_init=weight, has_bias=False, pad_mode="valid")
+
+
+def fc_with_initialize(input_channels, out_channels):
+    """weight initializer for fc layer"""
+    weight = weight_variable()
+    bias = weight_variable()
+    return nn.Dense(input_channels, out_channels, weight, bias)
+
+
+class test_custom_hook_function_base():
+    def __init__(self):
+        pass
+
+    def test_custom_hook_function(self, hook_function, cell_hook_function):
+        return hook_function, cell_hook_function
+
+
+def cell_hook_function_print_grad(cell_id, grad_input, grad_output):
+    assert grad_output[0].asnumpy().shape == (32, 6, 14, 14)  # gradient w.r.t. conv2 input (14x14 feature map)
+    assert grad_input[0].asnumpy().shape == (32, 16, 10, 10)  # gradient w.r.t. conv2 output (10x10 feature map)
+
+
+def custom_hook_function_print_and_save_grad(grad_out):
+    assert grad_out[0].asnumpy().shape == (32, 6, 28, 28)  # gradient at the hook point after conv1 + relu
+
+
+class LeNet5(nn.Cell):
+    def __init__(self, hook_function, cell_hook_function, num_class=10):
+        super(LeNet5, self).__init__()
+        self.num_class = num_class
+        self.batch_size = 32
+        self.conv1 = conv(1, 6, 5)
+        self.conv2 = conv(6, 16, 5)
+        self.conv2.register_backward_hook(cell_hook_function)
+        self.fc1 = fc_with_initialize(16 * 5 * 5, 120)
+        self.fc2 = fc_with_initialize(120, 84)
+        self.fc3 = fc_with_initialize(84, self.num_class)
+        self.relu = nn.ReLU()
+        self.max_pool2d = nn.MaxPool2d(kernel_size=2, stride=2)
+        self.reshape = P.Reshape()
+        self.hook = P.HookBackward(hook_function)
+
+    def construct(self, x):
+        x = self.conv1(x)
+        x = self.relu(x)
+        x = self.hook(x)
+        x = self.max_pool2d(x)
+        x = self.conv2(x)
+        x = self.relu(x)
+        x = self.max_pool2d(x)
+        x = self.reshape(x, (self.batch_size, -1))
+        x = self.fc1(x)
+        x = self.relu(x)
+        x = self.fc2(x)
+        x = self.relu(x)
+        x = self.fc3(x)
+        return x
+
+
+class GradWrap(nn.Cell):
+    """ GradWrap definition """
+    def __init__(self, network):
+        super(GradWrap, self).__init__(auto_prefix=False)
+        self.network = network
+        self.weights = ParameterTuple(filter(lambda x: x.requires_grad, network.get_parameters()))
+
+    def construct(self, x, label):
+        weights = self.weights
+        return C.GradOperation('get_by_list', get_by_list=True)(self.network, weights)(x, label)
+
+
+class test_custom_cell_base():
+    def __init__(self):
+        pass
+
+    def test_custom_cell_function(self, cell):
+        return cell
+
+
+class MulAdd(nn.Cell):
+    def __init__(self):
+        super(MulAdd, self).__init__()
+
+    def construct(self, x, y):
+        return 2 * x + y
+
+    def bprop(self, x, y, out, dout):
+        assert x.asnumpy() == 1.0
+        assert y.asnumpy() == 2.0
+        assert out.asnumpy() == 4.0
+        assert dout.asnumpy() == 1.0
+        return dout, y  # deliberately abnormal: returns y itself as the gradient of y
+
+
+class Ms_Cell(nn.Cell):
+    def __init__(self):
+        super(Ms_Cell, self).__init__()
+        self.relu = P.ReLU()
+
+    def construct(self, x):
+        return self.relu(x)
+
+    def bprop(self, x, out, dout):
+        dout = Tensor(np.ones([5, 5]).astype(np.float32))  # deliberately abnormal: gradient shape mismatches the scalar input
+        assert dout.shape == (5, 5)
+        return dout
+
+
+@pytest.mark.level0
+@pytest.mark.platform_arm_ascend_training
+@pytest.mark.platform_x86_ascend_training
+@pytest.mark.env_onecard
+def test_pynative_lenet_train_hook_function_print_and_save_grad():
+    hook = test_custom_hook_function_base()
+    function = hook.test_custom_hook_function(custom_hook_function_print_and_save_grad,
+                                              cell_hook_function_print_grad)
+    net = LeNet5(hook_function=function[0], cell_hook_function=function[1])
+    optimizer = Momentum(filter(lambda x: x.requires_grad, net.get_parameters()), 0.1, 0.9)
+    criterion = nn.SoftmaxCrossEntropyWithLogits(is_grad=False, sparse=False)
+    net_with_criterion = WithLossCell(net, criterion)
+    train_network = GradWrap(net_with_criterion)
+    train_network.set_train()
+
+    input_data = Tensor(np.ones([net.batch_size, 1, 32, 32]).astype(np.float32) * 0.01)
+    label = Tensor(np.ones([net.batch_size, net.num_class]).astype(np.float32))
+    output = net(input_data)
+    criterion(output, label)
+    grads = train_network(input_data, label)
+    success = optimizer(grads)
+    assert success
+
+
+@pytest.mark.level0
+@pytest.mark.platform_arm_ascend_training
+@pytest.mark.platform_x86_ascend_training
+@pytest.mark.env_onecard
+def test_pynative_custom_bprop_and_Cell_MulAdd():
+    custom_cell = test_custom_cell_base()
+    mul_add = custom_cell.test_custom_cell_function(MulAdd())
+    mul_add.bprop_debug = True
+    C.grad_all(mul_add)(Tensor(1, mstype.float32), Tensor(2, mstype.float32))
+    assert C.grad_all(mul_add)(Tensor(1, mstype.float32), Tensor(2, mstype.float32)) == \
+        (Tensor(1.0, mstype.float32), Tensor(2.0, mstype.float32))
+
+
+@pytest.mark.level0
+@pytest.mark.platform_arm_ascend_training
+@pytest.mark.platform_x86_ascend_training
+@pytest.mark.env_onecard
+def test_pynative_custom_bprop_and_Cell_Ms_Cell():
+    custom_cell = test_custom_cell_base()
+    ms_Cell = custom_cell.test_custom_cell_function(Ms_Cell())
+    ms_Cell.bprop_debug = True
+    assert C.grad_all(ms_Cell)(Tensor(1, mstype.float32)) == (Tensor(1.0, mstype.float32),)
+
\ No newline at end of file