From 1c296b96c9a25096719e85470cbb3c9470462c33 Mon Sep 17 00:00:00 2001
From: panyifeng
Date: Mon, 10 Aug 2020 11:44:53 +0800
Subject: [PATCH] fix switch layer single prim cell

---
 mindspore/ccsrc/debug/anf_ir_utils.cc         |  3 ++
 .../ccsrc/frontend/optimizer/ad/dfunctor.cc   |  2 +
 mindspore/ccsrc/frontend/optimizer/irpass.cc  |  5 ++
 mindspore/ccsrc/frontend/optimizer/irpass.h   |  3 ++
 .../ccsrc/frontend/optimizer/irpass/inline.h  |  2 +-
 .../irpass/switch_layer_defer_inline.h        | 47 +++++++++++++++++++
 mindspore/ccsrc/pipeline/jit/pass.cc          |  1 +
 mindspore/core/ir/func_graph.cc               |  1 +
 mindspore/core/ir/func_graph.h                |  5 ++
 mindspore/core/ir/func_graph_cloner.cc        |  2 +
 tests/ut/python/ops/test_control_ops.py       | 20 ++++++++
 11 files changed, 90 insertions(+), 1 deletion(-)
 create mode 100644 mindspore/ccsrc/frontend/optimizer/irpass/switch_layer_defer_inline.h

diff --git a/mindspore/ccsrc/debug/anf_ir_utils.cc b/mindspore/ccsrc/debug/anf_ir_utils.cc
index 4d6edd18cb..cc682fe378 100644
--- a/mindspore/ccsrc/debug/anf_ir_utils.cc
+++ b/mindspore/ccsrc/debug/anf_ir_utils.cc
@@ -607,6 +607,9 @@ void AnfExporter::ExportOneFuncGraph(std::ofstream &ofs, const FuncGraphPtr &fun
   std::vector<AnfNodePtr> parameters = func_graph->parameters();
   OrderedMap param_map;
 
+  if (*(func_graph->switch_layer_input())) {
+    ofs << "switch_layer_input: " << *(func_graph->switch_layer_input()) << "\n";
+  }
   ofs << "# [No." << (exported.size() + 1) << "] " << func_graph->DumpText() << "." << func_graph->debug_info()->get_id() << "\n";
   if (label_manage::GetGlobalTraceLabelType() == label_manage::TraceLabelType::kWithUniqueId) {
diff --git a/mindspore/ccsrc/frontend/optimizer/ad/dfunctor.cc b/mindspore/ccsrc/frontend/optimizer/ad/dfunctor.cc
index d4fe201710..909d9bc6d3 100644
--- a/mindspore/ccsrc/frontend/optimizer/ad/dfunctor.cc
+++ b/mindspore/ccsrc/frontend/optimizer/ad/dfunctor.cc
@@ -49,6 +49,8 @@ DFunctor::DFunctor(const FuncGraphPtr &primal_graph, const pipeline::ResourceBas
     std::string grad_op_name = GetValue<std::string>(primal_graph->get_attr(FUNC_GRAPH_ATTR_GRAPH_KERNEL));
     k_graph_->set_attr(FUNC_GRAPH_ATTR_GRAPH_KERNEL, MakeValue(grad_op_name));
   }
+  // To keep switch_layer's inputs from being inlined
+  k_graph_->set_switch_layer_input(primal_graph->switch_layer_input());
   TraceManager::EndTrace();
 
   TraceManager::DebugTrace(std::make_shared(primal_graph->debug_info()));
diff --git a/mindspore/ccsrc/frontend/optimizer/irpass.cc b/mindspore/ccsrc/frontend/optimizer/irpass.cc
index b41c3081b4..9291748311 100644
--- a/mindspore/ccsrc/frontend/optimizer/irpass.cc
+++ b/mindspore/ccsrc/frontend/optimizer/irpass.cc
@@ -45,6 +45,7 @@
 #include "frontend/optimizer/opt.h"
 #include "frontend/optimizer/irpass/row_tensor_eliminate.h"
 #include "frontend/optimizer/irpass/sparse_tensor_eliminate.h"
+#include "frontend/optimizer/irpass/switch_layer_defer_inline.h"
 
 namespace mindspore {
 namespace opt {
@@ -170,6 +171,10 @@ OptimizeIRPassLib::OptimizeIRPassLib() {
   // Value_Based Eliminate
   value_based_eliminate_ = MakeSubstitution(std::make_shared<ValueBasedEliminate>(), "value_based_eliminate",
                                             {prim::kPrimSelect, prim::kPrimMinimum, prim::kPrimMaximum});
+
+  // switch_layer defer inline
+  switch_layer_defer_inline_ =
+    MakeSubstitution(std::make_shared<SwitchLayerDeferInline>(), "switch_layer_defer_inline", prim::kPrimSwitchLayer);
 }
 
 ResolveIRPassLib::ResolveIRPassLib() {
diff --git a/mindspore/ccsrc/frontend/optimizer/irpass.h b/mindspore/ccsrc/frontend/optimizer/irpass.h
index 5a0f2ed5b7..835c4e5cf5 100644
--- a/mindspore/ccsrc/frontend/optimizer/irpass.h
+++ b/mindspore/ccsrc/frontend/optimizer/irpass.h
@@ -113,6 +113,9 @@ class OptimizeIRPassLib {
   // Value_Based Eliminate
   SubstitutionPtr value_based_eliminate_;
+
+  // SwitchLayer defer inline
+  SubstitutionPtr switch_layer_defer_inline_;
 };
 
 // the collection of irpass for resolve action
diff --git a/mindspore/ccsrc/frontend/optimizer/irpass/inline.h b/mindspore/ccsrc/frontend/optimizer/irpass/inline.h
index 0be228f44b..47be4abfbc 100644
--- a/mindspore/ccsrc/frontend/optimizer/irpass/inline.h
+++ b/mindspore/ccsrc/frontend/optimizer/irpass/inline.h
@@ -39,7 +39,7 @@ class ReplaceApplicator : public AnfVisitor {
     }
 
     auto fg = GetValueNode<FuncGraphPtr>(node);
-    if (fg->has_flag(FUNC_GRAPH_FLAG_DEFER_INLINE) || fg->stub()) {
+    if (fg->has_flag(FUNC_GRAPH_FLAG_DEFER_INLINE) || fg->stub() || *(fg->switch_layer_input())) {
       return nullptr;
     }
diff --git a/mindspore/ccsrc/frontend/optimizer/irpass/switch_layer_defer_inline.h b/mindspore/ccsrc/frontend/optimizer/irpass/switch_layer_defer_inline.h
new file mode 100644
index 0000000000..01472c2fd9
--- /dev/null
+++ b/mindspore/ccsrc/frontend/optimizer/irpass/switch_layer_defer_inline.h
@@ -0,0 +1,47 @@
+/**
+ * Copyright 2020 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef MINDSPORE_CCSRC_FRONTEND_OPTIMIZER_IRPASS_SWITCH_LAYER_DEFER_INLINE_H_
+#define MINDSPORE_CCSRC_FRONTEND_OPTIMIZER_IRPASS_SWITCH_LAYER_DEFER_INLINE_H_
+
+#include
+#include
+
+#include "frontend/optimizer/irpass.h"
+#include "frontend/optimizer/optimizer.h"
+#include "frontend/optimizer/anf_visitor.h"
+#include "frontend/operator/ops.h"
+
+namespace mindspore {
+namespace opt {
+namespace irpass {
+// {prim::kPrimSwitchLayer, {Index, layers}}
+class SwitchLayerDeferInline : public AnfVisitor {
+ public:
+  AnfNodePtr operator()(const OptimizerPtr &, const AnfNodePtr &node) override {
+    auto cnode = node->cast<CNodePtr>();
+    auto tuple = dyn_cast<abstract::AbstractTuple>(cnode->inputs()[2]->abstract());
+    for (auto elem : tuple->elements()) {
+      auto abstract = dyn_cast<abstract::FuncGraphAbstractClosure>(elem);
+      *(abstract->func_graph()->switch_layer_input()) = true;
+    }
+    return nullptr;
+  }
+};
+}  // namespace irpass
+}  // namespace opt
+}  // namespace mindspore
+#endif  // MINDSPORE_CCSRC_FRONTEND_OPTIMIZER_IRPASS_SWITCH_LAYER_DEFER_INLINE_H_
diff --git a/mindspore/ccsrc/pipeline/jit/pass.cc b/mindspore/ccsrc/pipeline/jit/pass.cc
index 0c27ba7c48..3a5fa3128b 100644
--- a/mindspore/ccsrc/pipeline/jit/pass.cc
+++ b/mindspore/ccsrc/pipeline/jit/pass.cc
@@ -90,6 +90,7 @@ bool CleanAfterOptAPass(const ResourcePtr &res) {
 namespace {
 OptPassGroupMap GetOptPassesA(const opt::irpass::OptimizeIRPassLib &irpass) {
   opt::OptPassConfig a_1 = opt::OptPassConfig({
+    irpass.switch_layer_defer_inline_,
     irpass.switch_simplify_,
 
     // Safe inlining
diff --git a/mindspore/core/ir/func_graph.cc b/mindspore/core/ir/func_graph.cc
index 6832a94e82..9ec8a6e1c5 100644
--- a/mindspore/core/ir/func_graph.cc
+++ b/mindspore/core/ir/func_graph.cc
@@ -48,6 +48,7 @@ FuncGraph::FuncGraph()
       manager_(std::weak_ptr<FuncGraphManager>()),
       stub_(false) {
   debug_info_ = std::make_shared<GraphDebugInfo>();
+  switch_layer_input_ = std::make_shared<bool>(false);
 }
 
 abstract::AbstractBasePtr FuncGraph::ToAbstract() {
diff --git a/mindspore/core/ir/func_graph.h b/mindspore/core/ir/func_graph.h
index 8bcbd3fdc1..e9694c6596 100644
--- a/mindspore/core/ir/func_graph.h
+++ b/mindspore/core/ir/func_graph.h
@@ -353,6 +353,8 @@ class FuncGraph : public FuncGraphBase {
   bool stub() const { return stub_; }
   void set_stub(bool stub) { stub_ = stub; }
   static void set_drawer(Drawer drawer) { drawer_ = drawer; }
+  std::shared_ptr<bool> switch_layer_input() const { return switch_layer_input_; }
+  void set_switch_layer_input(std::shared_ptr<bool> switch_layer_input) { switch_layer_input_ = switch_layer_input; }
 
  private:
   // graph is manipulated by manager and others
@@ -414,6 +416,9 @@ class FuncGraph : public FuncGraphBase {
   std::list<CNodePtr> order_;
   bool stub_;
   inline static Drawer drawer_ = nullptr;
+  // Design switch_layer_input as a ptr to
+  // share between derived backpropagator and cloned graphs
+  std::shared_ptr<bool> switch_layer_input_;
 };
 
 inline CNodePtr NewCNode(const std::vector<AnfNodePtr> &inputs, const FuncGraphPtr &fg) {
diff --git a/mindspore/core/ir/func_graph_cloner.cc b/mindspore/core/ir/func_graph_cloner.cc
index 010c931afa..51d7b34e28 100644
--- a/mindspore/core/ir/func_graph_cloner.cc
+++ b/mindspore/core/ir/func_graph_cloner.cc
@@ -228,6 +228,7 @@ void Cloner::SetFuncGraphInfo(const FuncGraphPtr &func_graph, FuncGraphPtr *cons
   (*target_func_graph)->set_hyper_param_count(func_graph->hyper_param_count());
   (*target_func_graph)->set_is_generate(func_graph->is_generated());
   (*target_func_graph)->set_stub(func_graph->stub());
+  (*target_func_graph)->set_switch_layer_input(func_graph->switch_layer_input());
   TraceManager::EndTrace();
 }
 
@@ -645,6 +646,7 @@ FuncGraphPtr TransformableClone(const FuncGraphPtr &func_graph, const TraceInfoP
   new_func_graph->set_hyper_param_count(func_graph->hyper_param_count());
   new_func_graph->set_is_generate(func_graph->is_generated());
   new_func_graph->set_stub(func_graph->stub());
+  new_func_graph->set_switch_layer_input(func_graph->switch_layer_input());
   for (auto &item : func_graph->parameter_default_value()) {
     new_func_graph->set_param_default_value(item.first, cloner[item.second]);
   }
diff --git a/tests/ut/python/ops/test_control_ops.py b/tests/ut/python/ops/test_control_ops.py
index 369fe5f9b1..ac31420f6b 100644
--- a/tests/ut/python/ops/test_control_ops.py
+++ b/tests/ut/python/ops/test_control_ops.py
@@ -444,6 +444,26 @@ def test_index_to_switch_layer():
     C.grad_all(net)(index, Tensor(np.full([128, 96], 0.6, dtype=np.float32)))
 
 
+def test_switch_layer_with_single_prim():
+    class SwitchLayerCell(nn.Cell):
+        def __init__(self):
+            super(SwitchLayerCell, self).__init__()
+            self.layers = (nn.ReLU(), nn.ReLU())
+            self.z3 = Parameter(
+                Tensor(np.full([128, 96], 0.6, dtype=np.float32)), name='z3')
+
+        def construct(self, index, x):
+            ret = self.layers[index](x) * self.z3
+            return ret
+
+    index = Tensor(0, dtype=mstype.int32)
+    net = SwitchLayerCell()
+    net(index, Tensor(np.full([128, 96], 0.6, dtype=np.float32)))
+    C.grad_by_list(net, ParameterTuple(net.trainable_params()))(index,
+                                                                Tensor(np.full([128, 96], 0.6, dtype=np.float32)))
+    C.grad_all(net)(index, Tensor(np.full([128, 96], 0.6, dtype=np.float32)))
+
+
 def test_control_depend_check():
     with pytest.raises(TypeError) as e:
         P.ControlDepend(0.0)