From c246b177a62eba0fa8475c7f44181671cf57408c Mon Sep 17 00:00:00 2001
From: jin-xiulang
Date: Thu, 16 Jul 2020 14:08:13 +0800
Subject: [PATCH] Debug for Pynative mode.

Debug test.

pynative debug
---
 mindspore/ccsrc/pipeline/jit/resource.h        |   1 +
 .../pipeline/pynative/pynative_execute.cc      |   8 +-
 .../pipeline/pynative/pynative_execute.h       |   1 +
 tests/st/pynative/test_pynative_mindarmour.py  | 162 ++++++++++++++++++
 4 files changed, 170 insertions(+), 2 deletions(-)
 create mode 100644 tests/st/pynative/test_pynative_mindarmour.py

diff --git a/mindspore/ccsrc/pipeline/jit/resource.h b/mindspore/ccsrc/pipeline/jit/resource.h
index 243e424d03..80f3c729a3 100644
--- a/mindspore/ccsrc/pipeline/jit/resource.h
+++ b/mindspore/ccsrc/pipeline/jit/resource.h
@@ -41,6 +41,7 @@ namespace py = pybind11;
 const char kBackend[] = "backend";
 const char kStepParallelGraph[] = "step_parallel";
 const char kOutput[] = "output";
+const char kPynativeGraphId[] = "graph_id";
 
 class InferenceResource;
 
diff --git a/mindspore/ccsrc/pipeline/pynative/pynative_execute.cc b/mindspore/ccsrc/pipeline/pynative/pynative_execute.cc
index b05ed420f8..f731c22b5d 100644
--- a/mindspore/ccsrc/pipeline/pynative/pynative_execute.cc
+++ b/mindspore/ccsrc/pipeline/pynative/pynative_execute.cc
@@ -68,6 +68,7 @@ static std::shared_ptr session = nullptr;
 PynativeExecutorPtr PynativeExecutor::executor_ = nullptr;
 std::mutex PynativeExecutor::instance_lock_;
 ResourcePtr PynativeExecutor::resource_;
+int PynativeExecutor::graph_id_ = 0;
 
 template <typename... Args>
 void PynativeExecutorTry(PynativeExecutor *const executor, void (PynativeExecutor::*method)(Args...), Args &&... args) {
@@ -616,7 +617,8 @@ py::object RunOpWithBackendPolicy(MsBackendPolicy backend_policy, const OpExecIn
 
 ValuePtr PynativeExecutor::GetForwardValue(const OpExecInfoPtr &op_exec_info) {
   auto id = GetOpId(op_exec_info);
-  auto op = id;
+  int graph_id = resource_->results()[pipeline::kPynativeGraphId].cast<int>();
+  auto op = std::to_string(graph_id) + id;
   op.append(std::to_string(op_id_map_[id]));
   auto iter = op_forward_map_.find(op);
   if (iter != op_forward_map_.end()) {
@@ -709,7 +711,8 @@ void PynativeExecutor::MakeCNode(const OpExecInfoPtr &op_exec_info, const py::ob
 
 void PynativeExecutor::SaveOpForwardValue(const OpExecInfoPtr &op_exec_info, const ValuePtr &value) {
   auto id = GetOpId(op_exec_info);
-  auto op = id;
+  int graph_id = resource_->results()[pipeline::kPynativeGraphId].cast<int>();
+  auto op = std::to_string(graph_id) + id;
   op.append(std::to_string(op_id_map_[id]));
   auto iter = op_forward_map_.find(op);
   if (iter != op_forward_map_.end()) {
@@ -942,6 +945,7 @@ void PynativeExecutor::NewGraphInner(const py::object &cell, const py::args &arg
   if (top_g_ == nullptr) {
     top_g_ = curr_g_ = g;
     resource_ = std::make_shared<pipeline::Resource>();
+    resource_->results()[pipeline::kPynativeGraphId] = graph_id_++;
     cell_resource_map_[cell_id] = resource_;
     df_builder_ = std::make_shared();
     MS_LOG(DEBUG) << "First new graph" << top_g_.get();
diff --git a/mindspore/ccsrc/pipeline/pynative/pynative_execute.h b/mindspore/ccsrc/pipeline/pynative/pynative_execute.h
index 11d651c775..6cebddcf45 100644
--- a/mindspore/ccsrc/pipeline/pynative/pynative_execute.h
+++ b/mindspore/ccsrc/pipeline/pynative/pynative_execute.h
@@ -130,6 +130,7 @@ class PynativeExecutor : public std::enable_shared_from_this<PynativeExecutor> {
   static std::shared_ptr<PynativeExecutor> executor_;
   static std::mutex instance_lock_;
   static ResourcePtr resource_;
+  static int graph_id_;
   bool grad_flag_;
   std::unordered_map graph_map_;
   std::unordered_map cell_graph_map_;
diff --git a/tests/st/pynative/test_pynative_mindarmour.py b/tests/st/pynative/test_pynative_mindarmour.py
new file mode 100644
index 0000000000..469964c871
--- /dev/null
+++ b/tests/st/pynative/test_pynative_mindarmour.py
@@ -0,0 +1,162 @@
+# Copyright 2020 Huawei Technologies Co., Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+"""
+This test is used to monitor some features of MindArmour.
+"""
+import numpy as np
+import pytest
+
+import mindspore.nn as nn
+from mindspore import context, Tensor
+from mindspore.nn import Cell, WithLossCell, TrainOneStepCell
+from mindspore.nn.optim.momentum import Momentum
+from mindspore.common.initializer import TruncatedNormal
+from mindspore.ops.composite import GradOperation
+
+
+def weight_variable():
+    """weight initial"""
+    return TruncatedNormal(0.02)
+
+
+def conv(in_channels, out_channels, kernel_size, stride=1, padding=0):
+    """weight initial for conv layer"""
+    weight = weight_variable()
+    return nn.Conv2d(in_channels, out_channels,
+                     kernel_size=kernel_size, stride=stride, padding=padding,
+                     weight_init=weight, has_bias=False, pad_mode="valid")
+
+
+def fc_with_initialize(input_channels, out_channels):
+    """weight initial for fc layer"""
+    weight = weight_variable()
+    bias = weight_variable()
+    return nn.Dense(input_channels, out_channels, weight, bias)
+
+
+class LeNet(nn.Cell):
+    """
+    Lenet network
+    Args:
+        num_class (int): Num classes, Default: 10.
+    Returns:
+        Tensor, output tensor
+    Examples:
+        >>> LeNet(num_class=10)
+    """
+
+    def __init__(self, num_class=10):
+        super(LeNet, self).__init__()
+        self.conv1 = conv(1, 6, 5)
+        self.conv2 = conv(6, 16, 5)
+        self.fc1 = fc_with_initialize(16 * 5 * 5, 120)
+        self.fc2 = fc_with_initialize(120, 84)
+        self.fc3 = fc_with_initialize(84, 10)
+        self.relu = nn.ReLU()
+        self.max_pool2d = nn.MaxPool2d(kernel_size=2, stride=2)
+        self.flatten = nn.Flatten()
+
+    def construct(self, x):
+        x = self.conv1(x)
+        x = self.relu(x)
+        x = self.max_pool2d(x)
+        x = self.conv2(x)
+        x = self.relu(x)
+        x = self.max_pool2d(x)
+        x = self.flatten(x)
+        x = self.fc1(x)
+        x = self.relu(x)
+        x = self.fc2(x)
+        x = self.relu(x)
+        x = self.fc3(x)
+        return x
+
+
+class GradWithSens(Cell):
+    def __init__(self, network):
+        super(GradWithSens, self).__init__()
+        self.grad = GradOperation(name="grad", get_all=False,
+                                  sens_param=True)
+        self.network = network
+
+    def construct(self, inputs, weight):
+        gout = self.grad(self.network)(inputs, weight)
+        return gout
+
+
+class GradWrapWithLoss(Cell):
+    def __init__(self, network):
+        super(GradWrapWithLoss, self).__init__()
+        self._grad_all = GradOperation(name="get_all",
+                                       get_all=True,
+                                       sens_param=False)
+        self._network = network
+
+    def construct(self, inputs, labels):
+        gout = self._grad_all(self._network)(inputs, labels)
+        return gout[0]
+
+
+@pytest.mark.level0
+@pytest.mark.platform_arm_ascend_training
+@pytest.mark.platform_x86_ascend_training
+@pytest.mark.env_onecard
+def test_grad_values_and_infer_shape():
+    context.set_context(mode=context.PYNATIVE_MODE, device_target="Ascend")
+    inputs_np = np.random.rand(32, 1, 32, 32).astype(np.float32)
+    sens = np.ones((inputs_np.shape[0], 10)).astype(np.float32)
+    inputs_np_2 = np.random.rand(64, 1, 32, 32).astype(np.float32)
+
+    net = LeNet()
+    grad_all = GradWithSens(net)
+
+    grad_out = grad_all(Tensor(inputs_np), Tensor(sens)).asnumpy()
+    out_shape = net(Tensor(inputs_np_2)).asnumpy().shape
+    assert np.any(grad_out != 0), 'grad result can not be all zeros'
+    assert out_shape == (64, 10), 'output shape should be (64, 10)'
+
+
+@pytest.mark.level0
+@pytest.mark.platform_arm_ascend_training
+@pytest.mark.platform_x86_ascend_training
+@pytest.mark.env_onecard
+def test_multi_grads():
+    context.set_context(mode=context.PYNATIVE_MODE, device_target="Ascend")
+    sparse = False
+    inputs_np = np.random.rand(32, 1, 32, 32).astype(np.float32)
+    labels_np = np.random.randint(10, size=32).astype(np.int32)
+    inputs_np_2 = np.random.rand(64, 1, 32, 32).astype(np.float32)
+    labels_np_2 = np.random.randint(10, size=64).astype(np.int32)
+    if not sparse:
+        labels_np = np.eye(10)[labels_np].astype(np.float32)
+        labels_np_2 = np.eye(10)[labels_np_2].astype(np.float32)
+
+    net = LeNet()
+
+    # grad operation
+    loss_fn = nn.SoftmaxCrossEntropyWithLogits(is_grad=False, sparse=sparse)
+    with_loss_cell = WithLossCell(net, loss_fn)
+    grad_all = GradWrapWithLoss(with_loss_cell)
+    grad_out = grad_all(Tensor(inputs_np), Tensor(labels_np)).asnumpy()
+    assert np.any(grad_out != 0), 'grad result can not be all zeros'
+
+    # train-one-step operation
+    loss_fn = nn.SoftmaxCrossEntropyWithLogits(is_grad=False, sparse=sparse)
+    optimizer = Momentum(filter(lambda x: x.requires_grad, net.get_parameters()),
+                         0.01, 0.9)
+    loss_net = WithLossCell(net, loss_fn)
+    train_net = TrainOneStepCell(loss_net, optimizer)
+    train_net.set_train()
+    train_net(Tensor(inputs_np_2), Tensor(labels_np_2))