diff --git a/mindspore/ccsrc/backend/optimizer/ascend/ascend_backend_optimization.cc b/mindspore/ccsrc/backend/optimizer/ascend/ascend_backend_optimization.cc
index 40e7a29c92..2636def192 100644
--- a/mindspore/ccsrc/backend/optimizer/ascend/ascend_backend_optimization.cc
+++ b/mindspore/ccsrc/backend/optimizer/ascend/ascend_backend_optimization.cc
@@ -96,6 +96,7 @@
 #include "backend/optimizer/ascend/format_type/modify_ops_attrs.h"
 #include "backend/optimizer/ascend/format_type/remove_no_use_reshape_op.h"
 #include "backend/optimizer/ascend/ir_fusion/add_input_to_output.h"
+#include "backend/optimizer/ascend/format_type/remove_internal_output.h"
 #include "utils/context/ms_context.h"
 #include "utils/config_manager.h"
 #include "debug/anf_ir_dump.h"
@@ -199,6 +200,7 @@ void AscendDataLayout(const std::shared_ptr &kernel_graph)
   data_layout_pm->AddPass(std::make_shared());
   data_layout_pm->AddPass(std::make_shared());
   data_layout_pm->AddPass(std::make_shared());
+  data_layout_pm->AddPass(std::make_shared<RemoveInternalOutputTransOp>());
   optimizer->AddPassManager(data_layout_pm);
   (void)optimizer->Optimize(kernel_graph);
   kernel_graph->SetExecOrderByDefault();
@@ -220,6 +222,7 @@ void AscendMixPrecision(const std::shared_ptr &kernel_grap
   mixed_precision_pm->AddPass(std::make_shared());
   mixed_precision_pm->AddPass(std::make_shared());
   mixed_precision_pm->AddPass(std::make_shared());
+  mixed_precision_pm->AddPass(std::make_shared<RemoveInternalOutputCast>());
   optimizer->AddPassManager(mixed_precision_pm);
   (void)optimizer->Optimize(kernel_graph);
   kernel_graph->SetExecOrderByDefault();
diff --git a/mindspore/ccsrc/backend/optimizer/ascend/ascend_helper.cc b/mindspore/ccsrc/backend/optimizer/ascend/ascend_helper.cc
index fd4c0e5952..9e1f6234b9 100644
--- a/mindspore/ccsrc/backend/optimizer/ascend/ascend_helper.cc
+++ b/mindspore/ccsrc/backend/optimizer/ascend/ascend_helper.cc
@@ -142,6 +142,7 @@ AnfNodePtr InsertTransOpForMultipleOutput(const FuncGraphPtr &func_graph, const
   MS_EXCEPTION_IF_NULL(node);
   std::vector<AnfNodePtr> make_tuple_inputs;
   make_tuple_inputs.push_back(NewValueNode(prim::kPrimMakeTuple));
+  auto kernel_graph = func_graph->cast<KernelGraphPtr>();
   for (size_t output_idx = 0; output_idx < AnfAlgo::GetOutputTensorNum(node); ++output_idx) {
     std::string output_format = AnfAlgo::GetOutputFormat(node, output_idx);
     if (output_format == kOpFormat_NC1KHKWHWC0) {
@@ -151,7 +152,11 @@
     auto tuple_getitem = CreatTupleGetItemNode(func_graph, node, output_idx);
     std::vector<size_t> origin_shape = AnfAlgo::GetOutputInferShape(node, output_idx);
     if (kCommonFormatSet.find(output_format) == kCommonFormatSet.end() && origin_shape.size() > 1) {
-      make_tuple_inputs.emplace_back(AddTransOpNodeToGraph(func_graph, tuple_getitem, kernel_select, 0, false));
+      auto trans_op = AddTransOpNodeToGraph(func_graph, tuple_getitem, kernel_select, 0, false);
+      if (kernel_graph != nullptr && kernel_graph->IsInternalOutput(node)) {
+        kernel_graph->ReplaceInternalOutput(node, trans_op, output_idx, 0);
+      }
+      make_tuple_inputs.emplace_back(trans_op);
     } else {
       // No need insert trans op.
       make_tuple_inputs.push_back(tuple_getitem);
@@ -249,9 +254,14 @@ AnfNodePtr InsertTransOpForOutput(const FuncGraphPtr &func_graph, const AnfNodeP
   if (outputs_num == 0) {
     return node;
   }
+  auto kernel_graph = func_graph->cast<KernelGraphPtr>();
   // Single output
   if (outputs_num == 1 && (!AnfAlgo::IsTupleOutput(node))) {
-    return InsertTransOpForSingleOutput(func_graph, node, kernel_select);
+    auto new_node = InsertTransOpForSingleOutput(func_graph, node, kernel_select);
+    if (kernel_graph != nullptr && kernel_graph->IsInternalOutput(node)) {
+      kernel_graph->ReplaceInternalOutput(node, new_node);
+    }
+    return new_node;
   }
   // Multiple output
   return InsertTransOpForMultipleOutput(func_graph, node, kernel_select);
diff --git a/mindspore/ccsrc/backend/optimizer/ascend/format_type/insert_cast.cc b/mindspore/ccsrc/backend/optimizer/ascend/format_type/insert_cast.cc
index c3f7900645..bc68511bb2 100644
--- a/mindspore/ccsrc/backend/optimizer/ascend/format_type/insert_cast.cc
+++ b/mindspore/ccsrc/backend/optimizer/ascend/format_type/insert_cast.cc
@@ -40,6 +40,7 @@ AnfNodePtr InsertCastForMultipleOutput(const FuncGraphPtr &func_graph, const CNo
   std::vector<AnfNodePtr> make_tuple_inputs;
   AbstractBasePtrList abstract_list;
   make_tuple_inputs.push_back(NewValueNode(prim::kPrimMakeTuple));
+  auto kernel_graph = func_graph->cast<KernelGraphPtr>();
   for (size_t output_idx = 0; output_idx < AnfAlgo::GetOutputTensorNum(cnode); ++output_idx) {
     AnfNodePtr replace_node = nullptr;
     const auto origin_shape = AnfAlgo::GetOutputInferShape(cnode, output_idx);
@@ -64,6 +65,9 @@ AnfNodePtr InsertCastForMultipleOutput(const FuncGraphPtr &func_graph, const CNo
       MS_EXCEPTION_IF_NULL(replace_node);
       replace_node->set_scope(cnode->scope());
       AnfAlgo::SetNodeAttr(kAttrVisited, MakeValue(true), replace_node);
+      if (kernel_graph != nullptr && kernel_graph->IsInternalOutput(cnode)) {
+        kernel_graph->ReplaceInternalOutput(cnode, replace_node, output_idx, 0);
+      }
     } else {
       replace_node = getitem;
     }
@@ -87,6 +91,7 @@ AnfNodePtr InsertCastForOutput(const FuncGraphPtr &func_graph, const CNodePtr &c
     return cnode;
   }
   MS_EXCEPTION_IF_NULL(cnode->Type());
+  auto kernel_graph = func_graph->cast<KernelGraphPtr>();
   // Single output
   if (!cnode->Type()->isa<Tuple>()) {
     if (!need_insert_cast[0]) {
@@ -109,6 +114,9 @@ AnfNodePtr InsertCastForOutput(const FuncGraphPtr &func_graph, const CNodePtr &c
     MS_EXCEPTION_IF_NULL(replace_node);
     replace_node->set_scope(cnode->scope());
     AnfAlgo::SetNodeAttr(kAttrVisited, MakeValue(true), replace_node);
+    if (kernel_graph != nullptr && kernel_graph->IsInternalOutput(cnode)) {
+      kernel_graph->ReplaceInternalOutput(cnode, replace_node);
+    }
   }
   return replace_node;
 }
@@ -188,6 +196,10 @@ const AnfNodePtr InsertCast::Process(const FuncGraphPtr &func_graph, const AnfNo
   CNodePtr cnode = node->cast<CNodePtr>();
   MS_EXCEPTION_IF_NULL(cnode);
   auto new_node = InsertCastForInput(func_graph, cnode);
+  auto kernel_graph = func_graph->cast<std::shared_ptr<session::KernelGraph>>();
+  if (kernel_graph != nullptr && kernel_graph->IsInternalOutput(node)) {
+    kernel_graph->ReplaceInternalOutput(node, new_node);
+  }
   // process output
   return InsertCastForOutput(func_graph, new_node, std::vector<bool>(AnfAlgo::GetOutputTensorNum(new_node), true));
 }
diff --git a/mindspore/ccsrc/backend/optimizer/ascend/format_type/insert_trans_op.cc b/mindspore/ccsrc/backend/optimizer/ascend/format_type/insert_trans_op.cc
index a22a1faa5f..8f0d5dd48e 100644
--- a/mindspore/ccsrc/backend/optimizer/ascend/format_type/insert_trans_op.cc
+++ b/mindspore/ccsrc/backend/optimizer/ascend/format_type/insert_trans_op.cc
@@ -46,14 +46,13 @@ const AnfNodePtr InsertTransOp::Process(const FuncGraphPtr &func_graph, const An
   if (node == nullptr || !AnfAlgo::IsRealKernel(node)) {
     return nullptr;
   }
-  AnfNodePtr front_node;
+  AnfAlgo::SetNodeAttr(kAttrVisited, MakeValue(true), node);
+  MS_LOG(DEBUG) << "process op: " << node->DebugString();
+  AnfNodePtr new_node = InsertTransOpForInput(func_graph, node, kernel_select_);
   auto kernel_graph = func_graph->cast<std::shared_ptr<session::KernelGraph>>();
   if (kernel_graph != nullptr && kernel_graph->IsInternalOutput(node)) {
-    front_node = kernel_graph->GetFrontNodeByInternalOutput(node);
+    kernel_graph->ReplaceInternalOutput(node, new_node);
   }
-  AnfAlgo::SetNodeAttr(kAttrVisited, MakeValue(true), node);
-  MS_LOG(DEBUG) << "====process op: " << node->DebugString();
-  AnfNodePtr new_node = InsertTransOpForInput(func_graph, node, kernel_select_);
   auto ms_context = MsContext::GetInstance();
   MS_EXCEPTION_IF_NULL(ms_context);
   if (ms_context->execution_mode() == kPynativeMode && !ms_context->enable_pynative_hook()) {
@@ -61,12 +60,7 @@ const AnfNodePtr InsertTransOp::Process(const FuncGraphPtr &func_graph, const An
       return new_node;
     }
   }
-  auto final_node = InsertTransOpForOutput(func_graph, new_node, kernel_select_);
-  if (kernel_graph != nullptr && front_node != nullptr) {
-    auto old_node = kernel_graph->GetInternalOutputByFrontNode(front_node);
-    kernel_graph->ReplaceInternalOutput(old_node, final_node);
-  }
-  return final_node;
+  return InsertTransOpForOutput(func_graph, new_node, kernel_select_);
 }
 }  // namespace opt
 }  // namespace mindspore
diff --git a/mindspore/ccsrc/backend/optimizer/ascend/format_type/remove_internal_output.cc b/mindspore/ccsrc/backend/optimizer/ascend/format_type/remove_internal_output.cc
new file mode 100644
index 0000000000..e9238fe006
--- /dev/null
+++ b/mindspore/ccsrc/backend/optimizer/ascend/format_type/remove_internal_output.cc
@@ -0,0 +1,83 @@
+/**
+ * Copyright 2020 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "backend/optimizer/ascend/format_type/remove_internal_output.h"
+#include <memory>
+#include "backend/session/anf_runtime_algorithm.h"
+
+namespace mindspore {
+namespace opt {
+namespace {
+bool UsedForOutputOnly(const FuncGraphPtr &func_graph, const AnfNodePtr &node) {
+  MS_EXCEPTION_IF_NULL(func_graph);
+  auto manager = func_graph->manager();
+  MS_EXCEPTION_IF_NULL(manager);
+  auto &node_users = manager->node_users();
+  auto iter = node_users.find(node);
+  if (iter == node_users.end()) {
+    return false;
+  }
+  const auto &node_set = iter->second;
+  for (const auto &node_index : node_set) {
+    if (!AnfAlgo::CheckPrimitiveType(node_index.first, prim::kPrimMakeTuple)) {
+      return false;
+    }
+  }
+  return true;
+}
+}  // namespace
+
+const BaseRef RemoveInternalOutputTransOp::DefinePattern() const {
+  VarPtr X = std::make_shared<Var>();
+  auto prim = std::make_shared<Primitive>(kTransDataOpName);
+  return VectorRef({prim, X});
+}
+
+const BaseRef RemoveInternalOutputCast::DefinePattern() const {
+  VarPtr X = std::make_shared<Var>();
+  return VectorRef({prim::kPrimCast, X});
+}
+
+const AnfNodePtr RemoveInternalOutput::Process(const FuncGraphPtr &func_graph, const AnfNodePtr &node,
+                                               const EquivPtr &) const {
+  MS_EXCEPTION_IF_NULL(func_graph);
+  MS_EXCEPTION_IF_NULL(node);
+  auto kernel_graph = func_graph->cast<KernelGraphPtr>();
+  if (kernel_graph == nullptr) {
+    return nullptr;
+  }
+  if (!kernel_graph->IsInternalOutput(node)) {
+    return nullptr;
+  }
+  if (!UsedForOutputOnly(func_graph, node)) {
+    return nullptr;
+  }
+  auto cnode = node->cast<CNodePtr>();
+  MS_EXCEPTION_IF_NULL(cnode);
+  CheckCNodeInputSize(cnode, kTransOpInputNum);
+  auto input_node = cnode->input(1);
+  if (!AnfAlgo::CheckPrimitiveType(input_node, prim::kPrimTupleGetItem)) {
+    kernel_graph->ReplaceInternalOutput(node, input_node);
+  } else {
+    auto tuple_getitem = input_node->cast<CNodePtr>();
+    MS_EXCEPTION_IF_NULL(tuple_getitem);
+    int idx = AnfAlgo::GetTupleGetItemOutIndex(tuple_getitem);
+    AnfNodePtr real_input_node = AnfAlgo::GetTupleGetItemRealInput(tuple_getitem);
+    kernel_graph->ReplaceInternalOutput(node, real_input_node, 0, idx);
+  }
+  return input_node;
+}
+}  // namespace opt
+}  // namespace mindspore
diff --git a/mindspore/ccsrc/backend/optimizer/ascend/format_type/remove_internal_output.h b/mindspore/ccsrc/backend/optimizer/ascend/format_type/remove_internal_output.h
new file mode 100644
index 0000000000..6fa9b7421c
--- /dev/null
+++ b/mindspore/ccsrc/backend/optimizer/ascend/format_type/remove_internal_output.h
@@ -0,0 +1,51 @@
+/**
+ * Copyright 2020 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef MINDSPORE_CCSRC_PRE_ACTIVATE_ASCEND_FORMAT_TYPE_REMOVE_INTERNAL_OUTPUT_H_
+#define MINDSPORE_CCSRC_PRE_ACTIVATE_ASCEND_FORMAT_TYPE_REMOVE_INTERNAL_OUTPUT_H_
+
+#include <string>
+#include "backend/optimizer/common/optimizer.h"
+
+namespace mindspore {
+namespace opt {
+class RemoveInternalOutput : public PatternProcessPass {
+ public:
+  explicit RemoveInternalOutput(const std::string &name, bool multigraph = true)
+      : PatternProcessPass(name, multigraph) {}
+  ~RemoveInternalOutput() override = default;
+  const AnfNodePtr Process(const FuncGraphPtr &, const AnfNodePtr &, const EquivPtr &) const override;
+};
+
+class RemoveInternalOutputTransOp : public RemoveInternalOutput {
+ public:
+  explicit RemoveInternalOutputTransOp(bool multigraph = true)
+      : RemoveInternalOutput("remove_internal_output_trans_op", multigraph) {}
+  ~RemoveInternalOutputTransOp() override = default;
+  const BaseRef DefinePattern() const override;
+};
+
+class RemoveInternalOutputCast : public RemoveInternalOutput {
+ public:
+  explicit RemoveInternalOutputCast(bool multigraph = true)
+      : RemoveInternalOutput("remove_internal_output_cast", multigraph) {}
+  ~RemoveInternalOutputCast() override = default;
+  const BaseRef DefinePattern() const override;
+};
+}  // namespace opt
+}  // namespace mindspore
+
+#endif  // MINDSPORE_CCSRC_PRE_ACTIVATE_ASCEND_FORMAT_TYPE_REMOVE_INTERNAL_OUTPUT_H_
diff --git a/mindspore/ccsrc/backend/session/kernel_graph.cc b/mindspore/ccsrc/backend/session/kernel_graph.cc
index b1cae89a40..3e462ca618 100644
--- a/mindspore/ccsrc/backend/session/kernel_graph.cc
+++ b/mindspore/ccsrc/backend/session/kernel_graph.cc
@@ -929,10 +929,15 @@ void KernelGraph::AddInternalOutput(const AnfNodePtr &front_node, const AnfNodeP
   }
   MS_LOG(INFO) << "Add internal node " << node->DebugString() << " with front node " << front_node->DebugString();
   front_to_internal_outputs_map_[front_node] = node;
-  internal_outputs_to_front_map_[node] = front_node;
+  int output_idx = 0;
+  if (AnfAlgo::CheckPrimitiveType(front_node, prim::kPrimTupleGetItem)) {
+    output_idx = AnfAlgo::GetTupleGetItemOutIndex(front_node->cast<CNodePtr>());
+  }
+  internal_outputs_to_front_map_[node][output_idx] = front_node;
 }
 
-void KernelGraph::ReplaceInternalOutput(const AnfNodePtr &node, const AnfNodePtr &new_node) {
+void KernelGraph::ReplaceInternalOutput(const AnfNodePtr &node, const AnfNodePtr &new_node, int src_output_idx,
+                                        int dst_output_idx) {
   if (new_node == nullptr || node == nullptr) {
     MS_LOG(INFO) << "New node or node is nullptr";
     return;
@@ -947,9 +952,30 @@ void KernelGraph::ReplaceInternalOutput(const AnfNodePtr &node, const AnfNodePtr
     return;
   }
   MS_LOG(INFO) << "Replace internal node " << node->DebugString() << " To " << new_node->DebugString();
-  internal_outputs_to_front_map_[new_node] = iter->second;
-  front_to_internal_outputs_map_[iter->second] = new_node;
-  internal_outputs_to_front_map_.erase(iter);
+  auto &front_nodes = iter->second;
+  // Move all front nodes to new node mapping
+  if (src_output_idx == -1) {
+    internal_outputs_to_front_map_[new_node] = front_nodes;
+    for (const auto &front_node_iter : front_nodes) {
+      front_to_internal_outputs_map_[front_node_iter.second] = new_node;
+    }
+    internal_outputs_to_front_map_.erase(iter);
+    return;
+  }
+  // Move specified front node to new node mapping
+  int index = SizeToInt(src_output_idx);
+  auto front_node_iter = front_nodes.find(index);
+  if (front_node_iter == front_nodes.end()) {
+    MS_LOG(INFO) << "The output " << src_output_idx << " of node " << node->DebugString() << " is not an internal node";
+    return;
+  }
+  auto front_node = front_node_iter->second;
+  internal_outputs_to_front_map_[new_node][dst_output_idx] = front_node;
+  front_to_internal_outputs_map_[front_node] = new_node;
+  front_nodes.erase(index);
+  if (front_nodes.empty()) {
+    internal_outputs_to_front_map_.erase(iter);
+  }
 }
 
 AnfNodePtr KernelGraph::GetInternalOutputByFrontNode(const AnfNodePtr &front_node) const {
@@ -967,14 +993,6 @@ bool KernelGraph::IsInternalOutput(const AnfNodePtr &node) const {
   return false;
 }
 
-AnfNodePtr KernelGraph::GetFrontNodeByInternalOutput(const AnfNodePtr &node) const {
-  auto iter = internal_outputs_to_front_map_.find(node);
-  if (iter != internal_outputs_to_front_map_.end()) {
-    return iter->second;
-  }
-  return nullptr;
-}
-
 void KernelGraph::AddFinalOutputKernel(const AnfNodePtr &node) {
   if (node == nullptr) {
     return;
diff --git a/mindspore/ccsrc/backend/session/kernel_graph.h b/mindspore/ccsrc/backend/session/kernel_graph.h
index 48df351120..3ba5f333da 100644
--- a/mindspore/ccsrc/backend/session/kernel_graph.h
+++ b/mindspore/ccsrc/backend/session/kernel_graph.h
@@ -148,10 +148,10 @@ class KernelGraph : public FuncGraph {
   const std::map> &summary_nodes() const { return summary_nodes_; }
   void set_summary_nodes(const std::map> &nodes) { summary_nodes_ = nodes; }
   void AddInternalOutput(const AnfNodePtr &front_node, const AnfNodePtr &node);
-  void ReplaceInternalOutput(const AnfNodePtr &node, const AnfNodePtr &new_node);
+  void ReplaceInternalOutput(const AnfNodePtr &node, const AnfNodePtr &new_node, int src_output_idx = -1,
+                             int dst_output_idx = -1);
   AnfNodePtr GetInternalOutputByFrontNode(const AnfNodePtr &front_node) const;
   bool IsInternalOutput(const AnfNodePtr &node) const;
-  AnfNodePtr GetFrontNodeByInternalOutput(const AnfNodePtr &node) const;
   void AddFinalOutputKernel(const AnfNodePtr &node);
   bool IsFinalOutputKernel(const AnfNodePtr &node) const;
   uint32_t current_epoch() const { return current_epoch_; }
@@ -223,7 +223,7 @@ class KernelGraph : public FuncGraph {
   CNodePtr end_goto_;
   bool null_output_;
   std::unordered_map<AnfNodePtr, AnfNodePtr> front_to_internal_outputs_map_;
-  std::unordered_map<AnfNodePtr, AnfNodePtr> internal_outputs_to_front_map_;
+  std::unordered_map<AnfNodePtr, std::unordered_map<int, AnfNodePtr>> internal_outputs_to_front_map_;
   std::set<AnfNodePtr> final_output_kernels_;
   uint32_t current_epoch_;
 };
diff --git a/mindspore/ccsrc/backend/session/session_basic.cc b/mindspore/ccsrc/backend/session/session_basic.cc
index fa55b07fe5..80777482dd 100644
--- a/mindspore/ccsrc/backend/session/session_basic.cc
+++ b/mindspore/ccsrc/backend/session/session_basic.cc
@@ -300,7 +300,11 @@ void SessionBasic::InitInternalOutputParameter(const AnfNodePtr &out_node, const
     MS_LOG(INFO) << "No corresponding internal output for output node";
     return;
   }
-  auto real_kernel = AnfAlgo::VisitKernel(ref_node, 0);
+  size_t output_idx = 0;
+  if (AnfAlgo::CheckPrimitiveType(out_node, prim::kPrimTupleGetItem)) {
+    output_idx = AnfAlgo::GetTupleGetItemOutIndex(out_node->cast<CNodePtr>());
+  }
+  auto real_kernel = AnfAlgo::VisitKernel(ref_node, output_idx);
   auto ref_real_node = real_kernel.first;
   auto ref_real_node_index = real_kernel.second;
   if (ref_real_node->isa<CNode>() && node_graph->IsInternalOutput(ref_real_node) &&
@@ -325,6 +329,7 @@ void SessionBasic::InitInternalOutputParameter(const AnfNodePtr &out_node, const
     builder.SetOutputsFormat({format});
     d_kernel_info->set_select_kernel_build_info(builder.Build());
     AnfAlgo::SetOutputAddr(address, 0, parameter.get());
+    AnfAlgo::SetOutputInferTypeAndShape({type}, {AnfAlgo::GetOutputInferShape(parameter, 0)}, parameter.get());
   }
 }
diff --git a/tests/st/host_device/test_host_device_lenet.py b/tests/st/host_device/test_host_device_lenet.py
new file mode 100644
index 0000000000..d1c49dc1e4
--- /dev/null
+++ b/tests/st/host_device/test_host_device_lenet.py
@@ -0,0 +1,89 @@
+# Copyright 2020 Huawei Technologies Co., Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+import numpy as np
+import pytest
+
+import mindspore.context as context
+import mindspore.nn as nn
+from mindspore import Tensor
+from mindspore.nn import TrainOneStepCell, WithLossCell
+from mindspore.nn.optim import Momentum
+from mindspore.ops import operations as P
+
+context.set_context(mode=context.GRAPH_MODE, device_target="Ascend")
+
+
+class LeNet(nn.Cell):
+    def __init__(self):
+        super(LeNet, self).__init__()
+        self.relu = P.ReLU()
+        self.batch_size = 32
+
+        self.conv1 = nn.Conv2d(1, 6, kernel_size=5, stride=1, padding=0, has_bias=False, pad_mode='valid')
+        self.conv2 = nn.Conv2d(6, 16, kernel_size=5, stride=1, padding=0, has_bias=False, pad_mode='valid')
+        self.pool = nn.MaxPool2d(kernel_size=2, stride=2)
+        self.reshape = P.Reshape()
+        self.fc1 = nn.Dense(400, 120)
+        self.fc1.matmul.add_prim_attr("primitive_target", "CPU")
+        self.fc1.bias_add.add_prim_attr("primitive_target", "CPU")
+        self.fc2 = nn.Dense(120, 84)
+        self.fc2.matmul.add_prim_attr("primitive_target", "CPU")
+        self.fc2.bias_add.add_prim_attr("primitive_target", "CPU")
+        self.fc3 = nn.Dense(84, 10)
+        self.fc3.matmul.add_prim_attr("primitive_target", "CPU")
+        self.fc3.bias_add.add_prim_attr("primitive_target", "CPU")
+
+    def construct(self, input_x):
+        output = self.conv1(input_x)
+        output = self.relu(output)
+        output = self.pool(output)
+        output = self.conv2(output)
+        output = self.relu(output)
+        output = self.pool(output)
+        output = self.reshape(output, (self.batch_size, -1))
+        output = self.fc1(output)
+        output = self.relu(output)
+        output = self.fc2(output)
+        output = self.relu(output)
+        output = self.fc3(output)
+        return output
+
+
+def train(net, data, label):
+    learning_rate = 0.01
+    momentum = 0.9
+
+    optimizer = Momentum(filter(lambda x: x.requires_grad, net.get_parameters()), learning_rate, momentum)
+    criterion = nn.SoftmaxCrossEntropyWithLogits(is_grad=False, sparse=True)
+    net_with_criterion = WithLossCell(net, criterion)
+    train_network = TrainOneStepCell(net_with_criterion, optimizer)  # optimizer
+    train_network.set_train()
+    res = train_network(data, label)
+    print("+++++++++Loss+++++++++++++")
+    print(res)
+    print("+++++++++++++++++++++++++++")
+    diff = res.asnumpy()[0] - 2.3025851
+    assert np.all(diff < 1.e-7)
+
+
+@pytest.mark.level0
+@pytest.mark.platform_arm_ascend_training
+@pytest.mark.platform_x86_ascend_training
+@pytest.mark.env_onecard
+def test_lenet():
+    data = Tensor(np.ones([32, 1, 32, 32]).astype(np.float32) * 0.01)
+    label = Tensor(np.ones([32]).astype(np.int32))
+    net = LeNet()
+    train(net, data, label)
diff --git a/tests/st/ops/cpu/test_sparse_apply_adam_op.py b/tests/st/ops/cpu/test_sparse_apply_adam_op.py
index 06b4a70b39..6dd866e96c 100644
--- a/tests/st/ops/cpu/test_sparse_apply_adam_op.py
+++ b/tests/st/ops/cpu/test_sparse_apply_adam_op.py
@@ -14,6 +14,7 @@
 # ============================================================================
 
 import numpy as np
+import pytest
 import mindspore.context as context
 import mindspore.nn as nn
 from mindspore import Tensor
@@ -43,6 +44,9 @@ class Net(nn.Cell):
         return out
 
 
+@pytest.mark.level0
+@pytest.mark.platform_x86_cpu
+@pytest.mark.env_onecard
 def test_net():
     gradient = Tensor(np.ones([3, 3, 3]).astype(np.float32))
     indices = Tensor([0, 1, 2], mstype.int32)
diff --git a/tests/st/ops/cpu/test_sparse_apply_ftrl_op.py b/tests/st/ops/cpu/test_sparse_apply_ftrl_op.py
index babaefbd86..dca5cf7a77 100644
--- a/tests/st/ops/cpu/test_sparse_apply_ftrl_op.py
+++ b/tests/st/ops/cpu/test_sparse_apply_ftrl_op.py
@@ -14,6 +14,7 @@
 # ============================================================================
 
 import numpy as np
+import pytest
 import mindspore.context as context
 import mindspore.nn as nn
 from mindspore import Tensor
@@ -35,6 +36,9 @@ class Net(nn.Cell):
         return out
 
 
+@pytest.mark.level0
+@pytest.mark.platform_x86_cpu
+@pytest.mark.env_onecard
 def test_net():
     gradient = Tensor(np.ones([3, 3, 3]).astype(np.float32))
     indices = Tensor([0, 1, 2], mstype.int32)
diff --git a/tests/st/ops/cpu/test_sparse_apply_proximal_adagrad_op.py b/tests/st/ops/cpu/test_sparse_apply_proximal_adagrad_op.py
index c2a129a86c..5d52e71896 100644
--- a/tests/st/ops/cpu/test_sparse_apply_proximal_adagrad_op.py
+++ b/tests/st/ops/cpu/test_sparse_apply_proximal_adagrad_op.py
@@ -14,6 +14,7 @@
 # ============================================================================
 
 import numpy as np
+import pytest
 import mindspore.context as context
 import mindspore.nn as nn
 from mindspore import Tensor
@@ -37,6 +38,9 @@ class Net(nn.Cell):
         return out
 
 
+@pytest.mark.level0
+@pytest.mark.platform_x86_cpu
+@pytest.mark.env_onecard
 def test_net():
     gradient = Tensor(np.ones([3, 3, 3]).astype(np.float32))
     indices = Tensor([0, 1, 2], mstype.int32)
diff --git a/tests/ut/cpp/pre_activate/ascend/format_type/remove_internal_output_test.cc b/tests/ut/cpp/pre_activate/ascend/format_type/remove_internal_output_test.cc
new file mode 100644
index 0000000000..72b7c6e361
--- /dev/null
+++ b/tests/ut/cpp/pre_activate/ascend/format_type/remove_internal_output_test.cc
@@ -0,0 +1,174 @@
+/**
+ * Copyright 2020 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+#include "common/backend_common_test.h"
+#include "debug/anf_ir_dump.h"
+#include "common/py_func_graph_fetcher.h"
+#include "backend/optimizer/ascend/format_type/remove_internal_output.h"
+
+#define private public
+#define protected public
+#include "backend/optimizer/ascend/format_type/insert_trans_op.h"
+#undef private
+#undef protected
+
+namespace mindspore {
+namespace opt {
+using KernelBuildInfoBuilder = kernel::KernelBuildInfo::KernelBuildInfoBuilder;
+
+class TestHWRemoveInternalOutput : public BackendCommon {
+ public:
+  TestHWRemoveInternalOutput() : getPyFun_("gtest_input.pre_activate.remove_internal_output_test", true) {}
+  ~TestHWRemoveInternalOutput() override = default;
+
+  AnfNodePtr GetMakeTuple(const KernelGraphPtr &kg) {
+    auto ret = kg->get_return();
+    MS_EXCEPTION_IF_NULL(ret);
+    auto make_tuple = ret->input(1);
+    return make_tuple;
+  }
+
+  KernelGraphPtr GetSingleOutputGraph(const std::string &func_name, const std::string &sub_func_name) {
+    FuncGraphPtr g = getPyFun_.CallAndParseRet(func_name, sub_func_name);
+    std::vector<int> shp{2, 32, 224, 224};
+    auto x_abstract = std::make_shared<abstract::AbstractTensor>(kFloat32, shp);
+    AbstractBasePtrList args_spec_list{x_abstract, x_abstract};
+    auto kg = GetKernelGraph(g, args_spec_list);
+    auto make_tuple = GetMakeTuple(kg);
+    auto add = make_tuple->cast<CNodePtr>()->input(1);
+    MS_EXCEPTION_IF_NULL(add);
+    kg->AddInternalOutput(add, add);
+    KernelBuildInfoBuilder builder;
+    builder.SetInputsFormat({kOpFormat_DEFAULT, kOpFormat_DEFAULT});
+    builder.SetInputsDeviceType({kFloat32->type_id(), kFloat32->type_id()});
+    builder.SetOutputsFormat({kOpFormat_NC1HWC0});
+    builder.SetOutputsDeviceType({kFloat16->type_id()});
+    add->set_kernel_info(std::make_shared<device::KernelInfo>());
+    AnfAlgo::SetSelectKernelBuildInfo(builder.Build(), add.get());
+    return kg;
+  }
+
+  KernelGraphPtr GetMutilpleOutputGraph(const std::string &func_name, const std::string &sub_func_name) {
+    FuncGraphPtr g = getPyFun_.CallAndParseRet(func_name, sub_func_name);
+    std::vector<int> shp{2, 32, 224, 224};
+    auto x_abstract = std::make_shared<abstract::AbstractTensor>(kFloat32, shp);
+    AbstractBasePtrList args_spec_list{x_abstract};
+    auto kg = GetKernelGraph(g, args_spec_list);
+    auto output_make_tuple = GetMakeTuple(kg);
+    auto make_tuple = output_make_tuple->cast<CNodePtr>()->input(1);
+    MS_EXCEPTION_IF_NULL(make_tuple);
+    auto tuple_getitem1 = make_tuple->cast<CNodePtr>()->input(1);
+    MS_EXCEPTION_IF_NULL(tuple_getitem1);
+    auto tuple_getitem2 = make_tuple->cast<CNodePtr>()->input(2);
+    MS_EXCEPTION_IF_NULL(tuple_getitem2);
+    auto max_pool = tuple_getitem1->cast<CNodePtr>()->input(1);
+    MS_EXCEPTION_IF_NULL(max_pool);
+    kg->AddInternalOutput(tuple_getitem1, max_pool);
+    kg->AddInternalOutput(tuple_getitem2, max_pool);
+    KernelBuildInfoBuilder builder;
+    builder.SetInputsFormat({kOpFormat_DEFAULT});
+    builder.SetInputsDeviceType({kFloat32->type_id()});
+    builder.SetOutputsFormat({kOpFormat_NC1HWC0, kOpFormat_NC1HWC0});
+    builder.SetOutputsDeviceType({kFloat16->type_id(), kFloat16->type_id()});
+    max_pool->set_kernel_info(std::make_shared<device::KernelInfo>());
+    AnfAlgo::SetSelectKernelBuildInfo(builder.Build(), max_pool.get());
+    return kg;
+  }
+  UT::PyFuncGraphFetcher getPyFun_;
+};
+
+class MockRemoveInternalOutputTransOpKernelSelect : public KernelSelect {
+ public:
+  MockRemoveInternalOutputTransOpKernelSelect() = default;
+  ~MockRemoveInternalOutputTransOpKernelSelect() override = default;
+  void SelectKernel(const CNodePtr &cnode) override {
+    KernelBuildInfoBuilder builder;
+    builder.SetInputsFormat({kOpFormat_NC1HWC0});
+    builder.SetInputsDeviceType({kFloat16->type_id()});
+    builder.SetOutputsFormat({kOpFormat_DEFAULT});
+    builder.SetOutputsDeviceType({kFloat32->type_id()});
+    AnfAlgo::SetSelectKernelBuildInfo(builder.Build(), cnode.get());
+  }
+};
+
+TEST_F(TestHWRemoveInternalOutput, test_remove_internal_output_trans_op_for_single_output) {
+  auto ms_context = MsContext::GetInstance();
+  MS_EXCEPTION_IF_NULL(ms_context);
+  ms_context->set_execution_mode(kGraphMode);
+  auto kg = GetSingleOutputGraph("test_remove_internal_output_trans_op_for_single_output", "before");
+  // insert trans op for output
+  auto graph_optimizer = std::make_shared<GraphOptimizer>();
+  auto pass_manager = std::make_shared<PassManager>();
+  auto insert_trans_op_pass = std::make_shared<InsertTransOp>();
+  insert_trans_op_pass->kernel_select_ = std::make_shared<MockRemoveInternalOutputTransOpKernelSelect>();
+  pass_manager->AddPass(insert_trans_op_pass);
+  graph_optimizer->AddPassManager(pass_manager);
+  auto new_g = graph_optimizer->Optimize(kg);
+  FuncGraphPtr g_after =
+    getPyFun_.CallAndParseRet("test_remove_internal_output_trans_op_for_single_output", "after_insert_trans_op");
+  EXPECT_TRUE(CheckEqualGraph(g_after, new_g));
+
+  auto make_tuple = GetMakeTuple(kg);
+  auto trans_data = make_tuple->cast<CNodePtr>()->input(1);
+  EXPECT_TRUE(kg->IsInternalOutput(trans_data));
+
+  // remove trans op for internal output
+  auto graph_optimizer1 = std::make_shared<GraphOptimizer>();
+  auto pass_manager1 = std::make_shared<PassManager>();
+  auto remove_internal_output_trans_op_pass = std::make_shared<RemoveInternalOutputTransOp>();
+  pass_manager1->AddPass(remove_internal_output_trans_op_pass);
+  graph_optimizer1->AddPassManager(pass_manager1);
+  auto new_g1 = graph_optimizer1->Optimize(new_g);
+  FuncGraphPtr g_after1 = getPyFun_.CallAndParseRet("test_remove_internal_output_trans_op_for_single_output",
+                                                    "after_remove_internal_output_trans_op");
+  EXPECT_TRUE(CheckEqualGraph(g_after1, new_g1));
+}
+
+TEST_F(TestHWRemoveInternalOutput, test_remove_internal_output_trans_op_for_multiple_output) {
+  auto kg = GetMutilpleOutputGraph("test_remove_internal_output_trans_op_for_multiple_output", "before");
+  // insert trans op for output
+  auto graph_optimizer = std::make_shared<GraphOptimizer>();
+  auto pass_manager = std::make_shared<PassManager>();
+  auto insert_trans_op_pass = std::make_shared<InsertTransOp>();
+  insert_trans_op_pass->kernel_select_ = std::make_shared<MockRemoveInternalOutputTransOpKernelSelect>();
+  pass_manager->AddPass(insert_trans_op_pass);
+  graph_optimizer->AddPassManager(pass_manager);
+  auto new_g = graph_optimizer->Optimize(kg);
+  FuncGraphPtr g_after =
+    getPyFun_.CallAndParseRet("test_remove_internal_output_trans_op_for_multiple_output", "after_insert_trans_op");
+  EXPECT_TRUE(CheckEqualGraph(g_after, new_g));
+
+  auto output_make_tuple = GetMakeTuple(kg);
+  auto make_tuple = output_make_tuple->cast<CNodePtr>()->input(1);
+  auto tuple_getitem = make_tuple->cast<CNodePtr>()->input(1);
+  auto make_tuple1 = tuple_getitem->cast<CNodePtr>()->input(1);
+  auto trans_data1 = make_tuple1->cast<CNodePtr>()->input(1);
+  auto trans_data2 = make_tuple1->cast<CNodePtr>()->input(2);
+  EXPECT_TRUE(kg->IsInternalOutput(trans_data1));
+  EXPECT_TRUE(kg->IsInternalOutput(trans_data2));
+
+  // remove trans op for internal output
+  auto graph_optimizer1 = std::make_shared<GraphOptimizer>();
+  auto pass_manager1 = std::make_shared<PassManager>();
+  auto remove_internal_output_trans_op_pass = std::make_shared<RemoveInternalOutputTransOp>();
+  pass_manager1->AddPass(remove_internal_output_trans_op_pass);
+  graph_optimizer1->AddPassManager(pass_manager1);
+  auto new_g1 = graph_optimizer1->Optimize(new_g);
+  FuncGraphPtr g_after1 = getPyFun_.CallAndParseRet("test_remove_internal_output_trans_op_for_multiple_output",
+                                                    "after_remove_internal_output_trans_op");
+  EXPECT_TRUE(CheckEqualGraph(g_after1, new_g1));
+}
+}  // namespace opt
+}  // namespace mindspore
diff --git a/tests/ut/cpp/python_input/gtest_input/pre_activate/remove_internal_output_test.py b/tests/ut/cpp/python_input/gtest_input/pre_activate/remove_internal_output_test.py
new file mode 100644
index 0000000000..0c02864816
--- /dev/null
+++ b/tests/ut/cpp/python_input/gtest_input/pre_activate/remove_internal_output_test.py
@@ -0,0 +1,83 @@
+# Copyright 2020 Huawei Technologies Co., Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+from mindspore.ops import Primitive
+from mindspore.ops import operations as P
+
+tuple_getitem = Primitive('tuple_getitem')
+add = P.TensorAdd()
+max_pool = P.MaxPoolWithArgmax(padding="same", ksize=3, strides=2)
+make_tuple = Primitive('make_tuple')
+trans_data = Primitive("TransData")
+
+
+class FnDict:
+    def __init__(self):
+        self.fnDict = {}
+
+    def __call__(self, fn):
+        self.fnDict[fn.__name__] = fn
+
+    def __getitem__(self, name):
+        return self.fnDict[name]
+
+
+def test_remove_internal_output_trans_op_for_single_output(tag):
+    fns = FnDict()
+
+    @fns
+    def before(x, y):
+        res = add(x, y)
+        return res
+
+    @fns
+    def after_insert_trans_op(x, y):
+        output = add(x, y)
+        res = trans_data(output)
+        return make_tuple(res)
+
+    @fns
+    def after_remove_internal_output_trans_op(x, y):
+        res = add(x, y)
+        return make_tuple(res)
+
+    return fns[tag]
+
+
+def test_remove_internal_output_trans_op_for_multiple_output(tag):
+    fns = FnDict()
+
+    @fns
+    def before(x):
+        max_pool_res = max_pool(x)
+        res = make_tuple(tuple_getitem(max_pool_res, 0), tuple_getitem(max_pool_res, 1))
+        return res
+
+    @fns
+    def after_insert_trans_op(x):
+        output = max_pool(x)
+        trans_data0 = trans_data(tuple_getitem(output, 0))
+        trans_data1 = trans_data(tuple_getitem(output, 1))
+        new_make_tuple = make_tuple(trans_data0, trans_data1)
+        res = make_tuple(tuple_getitem(new_make_tuple, 0), tuple_getitem(new_make_tuple, 1))
+        return make_tuple(res)
+
+    @fns
+    def after_remove_internal_output_trans_op(x):
+        output = max_pool(x)
+        new_make_tuple = make_tuple(tuple_getitem(output, 0), tuple_getitem(output, 1))
+        res = make_tuple(tuple_getitem(new_make_tuple, 0), tuple_getitem(new_make_tuple, 1))
+        return make_tuple(res)
+
+    return fns[tag]