From 0099da2c9940d2f1e1a85fceb1fdd238deed5e09 Mon Sep 17 00:00:00 2001
From: huangdongrun
Date: Fri, 21 Aug 2020 16:31:03 +0800
Subject: [PATCH] add support for tuple parameter transform

add support for pynative pass
add testcases
---
 .../frontend/optimizer/graph_transform.cc     | 144 ++++++++++
 .../frontend/optimizer/graph_transform.h      | 108 ++++++++
 mindspore/ccsrc/frontend/optimizer/irpass.cc  |   5 +
 mindspore/ccsrc/frontend/optimizer/irpass.h   |   3 +
 .../irpass/call_graph_tuple_transform.h       | 246 ++++++++++++++++++
 mindspore/ccsrc/pipeline/jit/action.cc        |   1 +
 mindspore/ccsrc/pipeline/jit/pass.cc          |  44 +++-
 .../pipeline/pynative/pynative_execute.cc     |  39 ++-
 .../st/pynative/test_graph_param_transform.py | 201 ++++++++++++++
 .../pynative_mode/test_graph_param_cases.py   | 136 ++++++++++
 10 files changed, 923 insertions(+), 4 deletions(-)
 create mode 100644 mindspore/ccsrc/frontend/optimizer/graph_transform.cc
 create mode 100644 mindspore/ccsrc/frontend/optimizer/graph_transform.h
 create mode 100644 mindspore/ccsrc/frontend/optimizer/irpass/call_graph_tuple_transform.h
 create mode 100644 tests/st/pynative/test_graph_param_transform.py
 create mode 100644 tests/ut/python/pynative_mode/test_graph_param_cases.py

diff --git a/mindspore/ccsrc/frontend/optimizer/graph_transform.cc b/mindspore/ccsrc/frontend/optimizer/graph_transform.cc
new file mode 100644
index 0000000000..4b1275be98
--- /dev/null
+++ b/mindspore/ccsrc/frontend/optimizer/graph_transform.cc
@@ -0,0 +1,144 @@
+/**
+ * Copyright 2020 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "frontend/optimizer/graph_transform.h"
+#include
+#include
+#include
+#include "ir/graph_utils.h"
+
+namespace mindspore {
+/* namespace to support opt */
+namespace opt {
+// Check whether any of the cnode inputs is a tuple.
+bool CNodeHasTupleInput(const CNodePtr &cnode) {
+  auto &inputs = cnode->inputs();
+  for (size_t i = 1; i < inputs.size(); i++) {
+    if (IsValueNode<FuncGraph>(inputs[i])) {
+      continue;
+    }
+    if (IsValueNode<Primitive>(inputs[i])) {
+      // unexpected high order primitive as cnode input when transform graph
+      MS_LOG(WARNING) << "CheckTupleInput, got unexpected primitive as input" << cnode->DebugString();
+      return false;
+    }
+    auto abs = inputs[i]->abstract();
+    if (abs == nullptr) {
+      MS_LOG(WARNING) << "CheckTupleInput, got abstract nullptr for node:" << cnode->DebugString();
+      return false;
+    }
+    if (abs->isa<abstract::AbstractTuple>()) {
+      return true;
+    }
+  }
+  return false;
+}
+
+bool FuncGraphHasTupleInput(const FuncGraphPtr &fg) {
+  auto &params = fg->parameters();
+  for (auto &param : params) {
+    if (param->abstract()->isa<abstract::AbstractTuple>()) {
+      return true;
+    }
+  }
+  return false;
+}
+
+std::vector<AnfNodePtr> TransformTupleArgument(const FuncGraphPtr &fg, const AnfNodePtr &node,
+                                               const abstract::AbstractTuplePtr &abs) {
+  auto &elements = abs->elements();
+  std::vector<AnfNodePtr> tuple_node_expanded;
+  for (size_t i = 0; i < elements.size(); i++) {
+    auto elem_node = fg->NewCNode({NewValueNode(prim::kPrimTupleGetItem), node, NewValueNode(SizeToInt(i))});
+    elem_node->set_abstract(elements[i]);
+    if (elements[i]->isa<abstract::AbstractTuple>()) {
+      auto nodes = TransformTupleArgument(fg, elem_node, elements[i]->cast<abstract::AbstractTuplePtr>());
+      tuple_node_expanded.insert(tuple_node_expanded.end(), nodes.begin(), nodes.end());
+    } else {
+      tuple_node_expanded.push_back(elem_node);
+    }
+  }
+  return tuple_node_expanded;
+}
+
+AnfNodePtr TransformCallGraph(const FuncGraphPtr &trans_fg, const CNodePtr &cnode) {
+  auto &cinputs = cnode->inputs();
+  auto fg = cnode->func_graph();
+  std::vector<AnfNodePtr> inputs;
+  inputs.push_back(NewValueNode(trans_fg));
+  for (size_t i = 1; i < cinputs.size(); i++) {
+    auto abs = cinputs[i]->abstract();
+    if (abs == nullptr) {
+      MS_LOG(EXCEPTION) << "TransformCallGraph: Node abstract should not be nullptr" << cinputs[i]->DebugString();
+    }
+    if (abs->isa<abstract::AbstractTuple>()) {
+      auto nodes = TransformTupleArgument(fg, cinputs[i], abs->cast<abstract::AbstractTuplePtr>());
+      inputs.insert(inputs.end(), nodes.begin(), nodes.end());
+    } else {
+      inputs.push_back(cinputs[i]);
+    }
+  }
+  auto new_node = fg->NewCNode(inputs);
+  new_node->set_abstract(cnode->abstract());
+  return new_node;
+}
+
+AnfNodePtr TransformPartial(const FuncGraphPtr &trans_fg, const CNodePtr &cnode) {
+  auto &cinputs = cnode->inputs();
+  auto fg = cnode->func_graph();
+  std::vector<AnfNodePtr> inputs;
+  inputs.push_back(NewValueNode(prim::kPrimPartial));
+  inputs.push_back(NewValueNode(trans_fg));
+  for (size_t i = 2; i < cinputs.size(); i++) {
+    auto abs = cinputs[i]->abstract();
+    if (abs == nullptr) {
+      MS_LOG(EXCEPTION) << "TransformPartial: Node abstract should not be nullptr" << cinputs[i]->DebugString();
+    }
+    if (abs->isa<abstract::AbstractTuple>()) {
+      auto nodes = TransformTupleArgument(fg, cinputs[i], abs->cast<abstract::AbstractTuplePtr>());
+      inputs.insert(inputs.end(), nodes.begin(), nodes.end());
+    } else {
+      inputs.push_back(cinputs[i]);
+    }
+  }
+  auto new_node = fg->NewCNode(inputs);
+  new_node->set_abstract(cnode->abstract());
+  return new_node;
+}
+
+AnfNodePtr TransformSwitchCall(const AnfNodePtr &swtich_node, const CNodePtr &cnode) {
+  auto &cinputs = cnode->inputs();
+  auto fg = cnode->func_graph();
+  std::vector<AnfNodePtr> inputs;
+  inputs.push_back(swtich_node);
+  for (size_t i = 1; i <
cinputs.size(); i++) { + auto abs = cinputs[i]->abstract(); + if (abs == nullptr) { + MS_LOG(EXCEPTION) << "TransformSwitchCall:Node abstract should not be nullptr" << cinputs[i]->DebugString(); + } + if (abs->isa()) { + auto nodes = TransformTupleArgument(fg, cinputs[i], abs->cast()); + inputs.insert(inputs.end(), nodes.begin(), nodes.end()); + } else { + inputs.push_back(cinputs[i]); + } + } + auto new_node = fg->NewCNode(inputs); + new_node->set_abstract(cnode->abstract()); + return new_node; +} +} // namespace opt +} // namespace mindspore diff --git a/mindspore/ccsrc/frontend/optimizer/graph_transform.h b/mindspore/ccsrc/frontend/optimizer/graph_transform.h new file mode 100644 index 0000000000..3199d277f5 --- /dev/null +++ b/mindspore/ccsrc/frontend/optimizer/graph_transform.h @@ -0,0 +1,108 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_CCSRC_FRONTEND_OPTIMIZER_GRAPH_TRANSFORM_H +#define MINDSPORE_CCSRC_FRONTEND_OPTIMIZER_GRAPH_TRANSFORM_H + +#include +#include +#include +#include +#include + +#include "frontend/optimizer/optimizer.h" + +namespace mindspore { +namespace opt { + +bool CNodeHasTupleInput(const CNodePtr &cnode); +bool FuncGraphHasTupleInput(const FuncGraphPtr &fg); +std::vector TransformTupleArgument(const FuncGraphPtr &fg, const AnfNodePtr &node, + const abstract::AbstractTuplePtr &abs); +AnfNodePtr TransformCallGraph(const FuncGraphPtr &trans_fg, const CNodePtr &cnode); +AnfNodePtr TransformPartial(const FuncGraphPtr &trans_fg, const CNodePtr &cnode); +AnfNodePtr TransformSwitchCall(const AnfNodePtr &swtich_node, const CNodePtr &cnode); + +class GraphTupleParamTransform { + public: + GraphTupleParamTransform() : cache_() {} + ~GraphTupleParamTransform() { cache_.clear(); } + FuncGraphPtr operator()(const FuncGraphPtr &fg, const FuncGraphManagerPtr &mng) { + if (cache_.find(fg) != cache_.end()) { + return cache_[fg]; + } + auto new_fg = TransformGraphParam(fg, mng); + cache_[fg] = new_fg; + return new_fg; + } + + AnfNodePtr GenerateTupleParams(const abstract::AbstractTuplePtr &tuple_abs, const FuncGraphPtr &fg, + std::vector *params) { + std::vector inputs; + inputs.push_back(NewValueNode(prim::kPrimMakeTuple)); + auto &elements = tuple_abs->elements(); + for (auto &item : elements) { + if (item->isa()) { + inputs.push_back(GenerateTupleParams(item->cast(), fg, params)); + } else { + auto p = std::make_shared(fg); + p->set_abstract(item); + params->push_back(p); + inputs.push_back(params->back()); + } + } + auto node = fg->NewCNode(inputs); + node->set_abstract(tuple_abs); + return node; + } + + FuncGraphPtr TransformGraphParam(const FuncGraphPtr &fg, const FuncGraphManagerPtr &mng) { + Cloner cloner({fg}, false, false, false, std::make_shared(), std::make_shared()); + auto new_fg = cloner[fg]; + auto ¶ms = new_fg->parameters(); + std::vector new_params; + std::unordered_map repl; + for (auto ¶m : params) { + auto abs = param->abstract(); + if (abs != nullptr && abs->isa()) { + 
auto tuple_abs = abs->cast(); + std::vector tuple_params; + repl.emplace(param, GenerateTupleParams(tuple_abs, new_fg, &tuple_params)); + std::transform(tuple_params.begin(), tuple_params.end(), std::back_inserter(new_params), + [](AnfNodePtr p) { return p; }); + } else { + new_params.push_back(param); + } + } + auto tmp_mng = mindspore::Manage(new_fg, false); + auto tr = tmp_mng->Transact(); + for (auto &item : repl) { + bool ret = tr.Replace(item.first, item.second); + if (ret == false) { + MS_LOG(ERROR) << "replace failed" << item.first->DebugString() << " with__" << item.second->DebugString(2); + } + } + tr.SetParameters(new_fg, new_params); + tr.Commit(); + mng->AddFuncGraph(new_fg); + return new_fg; + } + std::unordered_map cache_; +}; + +} // namespace opt +} // namespace mindspore +#endif // MINDSPORE_CCSRC_FRONTEND_OPTIMIZER_GRAPH_TRANSFORM_H diff --git a/mindspore/ccsrc/frontend/optimizer/irpass.cc b/mindspore/ccsrc/frontend/optimizer/irpass.cc index f7e7027664..d112ed4674 100644 --- a/mindspore/ccsrc/frontend/optimizer/irpass.cc +++ b/mindspore/ccsrc/frontend/optimizer/irpass.cc @@ -44,6 +44,7 @@ #include "frontend/optimizer/irpass/row_tensor_eliminate.h" #include "frontend/optimizer/irpass/sparse_tensor_eliminate.h" #include "frontend/optimizer/irpass/switch_layer_defer_inline.h" +#include "frontend/optimizer/irpass/call_graph_tuple_transform.h" namespace mindspore { namespace opt { @@ -158,6 +159,10 @@ OptimizeIRPassLib::OptimizeIRPassLib() { unused_output_eliminate_ = MakeSubstitution(std::make_shared(), "unused_output_eliminate", IsCNodeGraphKernel); + // tuple parameter graph transform + call_graph_tuple_transform_ = + MakeSubstitution(std::make_shared(), "graph_param_transorm", IsCNode); + // AddN eliminate addn_eliminate_ = MakeSubstitution(std::make_shared(), "addn_eliminate", IsCNodeGraphKernel); diff --git a/mindspore/ccsrc/frontend/optimizer/irpass.h b/mindspore/ccsrc/frontend/optimizer/irpass.h index afb485ead8..2dc4acc2db 100644 --- a/mindspore/ccsrc/frontend/optimizer/irpass.h +++ b/mindspore/ccsrc/frontend/optimizer/irpass.h @@ -103,6 +103,9 @@ class OptimizeIRPassLib { SubstitutionPtr unused_parameter_eliminate_; SubstitutionPtr unused_output_eliminate_; + // tuple parameter graph transform + SubstitutionPtr call_graph_tuple_transform_; + // AddN eliminate SubstitutionPtr addn_eliminate_; diff --git a/mindspore/ccsrc/frontend/optimizer/irpass/call_graph_tuple_transform.h b/mindspore/ccsrc/frontend/optimizer/irpass/call_graph_tuple_transform.h new file mode 100644 index 0000000000..44b0780b4a --- /dev/null +++ b/mindspore/ccsrc/frontend/optimizer/irpass/call_graph_tuple_transform.h @@ -0,0 +1,246 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#ifndef MINDSPORE_CCSRC_FRONTEND_OPTIMIZER_IRPASS_CALL_GRAPH_TRANSFORM_H_ +#define MINDSPORE_CCSRC_FRONTEND_OPTIMIZER_IRPASS_CALL_GRAPH_TRANSFORM_H_ + +#include +#include +#include +#include +#include + +#include "ir/func_graph.h" +#include "ir/func_graph_cloner.h" +#include "frontend/optimizer/optimizer_caller.h" +#include "frontend/optimizer/anf_visitor.h" +#include "frontend/operator/ops.h" +#include "frontend/optimizer/irpass.h" +#include "frontend/optimizer/optimizer.h" +#include "frontend/optimizer/graph_transform.h" + +namespace mindspore { +namespace opt { +namespace irpass { +// {G, Xs}-->transform graph call tuple inputs to flat inputs. +class GraphCallTupleTransform : public AnfVisitor { + public: + explicit GraphCallTupleTransform(GraphTupleParamTransform &transformer) : graph_transform_(transformer) {} + ~GraphCallTupleTransform() override = default; + AnfNodePtr operator()(const OptimizerPtr &optimizer, const AnfNodePtr &node) override { + if (!node->isa() || node->func_graph() == nullptr) { + return nullptr; + } + + auto cnode = node->cast(); + MS_EXCEPTION_IF_NULL(cnode); + auto &inputs = cnode->inputs(); + auto fg = GetValueNode(inputs[0]); + if (fg == nullptr) { + return nullptr; + } + if (!CNodeHasTupleInput(node->cast())) { + return nullptr; + } + FuncGraphPtr transformed_fg = graph_transform_(fg, optimizer->manager()); + auto new_node = TransformCallGraph(transformed_fg, node->cast()); + return new_node; + } + + private: + GraphTupleParamTransform &graph_transform_; +}; + +// {{switch, cond, true_branch, false_branch}, Xs} -->transform switch graph call tuple inputs to flat inputs. +class SwitchCallTupleTransform : public AnfVisitor { + public: + explicit SwitchCallTupleTransform(GraphTupleParamTransform &transformer) : graph_transform_(transformer) {} + ~SwitchCallTupleTransform() override = default; + AnfNodePtr operator()(const OptimizerPtr &optimizer, const AnfNodePtr &node) override { + if (!node->isa() || node->func_graph() == nullptr) { + return nullptr; + } + auto switch_call_cnode = node->cast(); + auto call_inputs = switch_call_cnode->inputs(); + if (call_inputs.size() < 1) { + return nullptr; + } + if (!IsPrimitiveCNode(call_inputs[0], prim::kPrimSwitch)) { + return nullptr; + } + auto swich_cnode = call_inputs[0]->cast(); + auto switch_inputs = swich_cnode->inputs(); + if (switch_inputs.size() != 4) { + return nullptr; + } + + AnfNodePtr transformed = nullptr; + bool true_br_changed = TransformBranchNode(switch_inputs[2], optimizer->manager(), &transformed); + if (true_br_changed) { + switch_inputs[2] = transformed; + } + bool false_br_changed = TransformBranchNode(switch_inputs[3], optimizer->manager(), &transformed); + if (false_br_changed) { + switch_inputs[3] = transformed; + } + if (true_br_changed || false_br_changed) { + call_inputs[0] = swich_cnode->func_graph()->NewCNode(switch_inputs); + } + if (CNodeHasTupleInput(switch_call_cnode)) { + return TransformSwitchCall(call_inputs[0], switch_call_cnode); + } + if (true_br_changed || false_br_changed) { + return switch_call_cnode->func_graph()->NewCNode(call_inputs); + } + return nullptr; + } + + bool TransformBranchNode(AnfNodePtr node, FuncGraphManagerPtr mng, AnfNodePtr *trans_node) { + if (IsValueNode(node)) { + FuncGraphPtr fg = GetValueNode(node); + if (FuncGraphHasTupleInput(fg)) { + FuncGraphPtr transformed_fg = graph_transform_(fg, mng); + *trans_node = NewValueNode(transformed_fg); + return true; + } + return false; + } + if (IsPrimitiveCNode(node, prim::kPrimPartial)) { + auto 
partial_inputs = node->cast()->inputs(); + if (IsValueNode(partial_inputs[1])) { + FuncGraphPtr fg = GetValueNode(partial_inputs[1]); + if (FuncGraphHasTupleInput(fg)) { + fg = graph_transform_(fg, mng); + } + if (CNodeHasTupleInput(node->cast())) { + *trans_node = TransformPartial(fg, node->cast()); + return true; + } + } + return false; + } + + MS_LOG(WARNING) << "Got unexpected switch branch node " << node->DebugString(); + return false; + } + + private: + GraphTupleParamTransform &graph_transform_; +}; + +// {{switch_layer, index, {make_tuple, br1, br2,...,}}, Xs} -> +// transform switch layer graph call tuple inputs to flat inputs. +class SwitchLayerCallTupleTransform : public AnfVisitor { + public: + explicit SwitchLayerCallTupleTransform(GraphTupleParamTransform &transformer) : graph_transform_(transformer) {} + ~SwitchLayerCallTupleTransform() override = default; + AnfNodePtr operator()(const OptimizerPtr &optimizer, const AnfNodePtr &node) override { + if (!node->isa() || node->func_graph() == nullptr) { + return nullptr; + } + auto switch_layer_call_cnode = node->cast(); + auto call_inputs = switch_layer_call_cnode->inputs(); + if (call_inputs.size() < 1) { + return nullptr; + } + if (!IsPrimitiveCNode(call_inputs[0], prim::kPrimSwitchLayer)) { + return nullptr; + } + auto swich_layer_cnode = call_inputs[0]->cast(); + auto switch_layer_inputs = swich_layer_cnode->inputs(); + if (switch_layer_inputs.size() != 3) { + return nullptr; + } + + AnfNodePtr transformed = nullptr; + bool layer_changed = TransformLayerNode(switch_layer_inputs[2], optimizer->manager(), &transformed); + if (layer_changed) { + switch_layer_inputs[2] = transformed; + call_inputs[0] = switch_layer_call_cnode->func_graph()->NewCNode(switch_layer_inputs); + } + if (CNodeHasTupleInput(switch_layer_call_cnode)) { + return TransformSwitchCall(call_inputs[0], switch_layer_call_cnode); + } + if (layer_changed) { + return switch_layer_call_cnode->func_graph()->NewCNode(call_inputs); + } + return nullptr; + } + + bool TransformLayerNode(AnfNodePtr node, FuncGraphManagerPtr mng, AnfNodePtr *trans_node) { + if (!IsPrimitiveCNode(node, prim::kPrimMakeTuple)) { + MS_LOG(WARNING) << "SwitchLayer input is not MakeTuple"; + return false; + } + auto tuple_inputs = node->cast()->inputs(); + bool changed = false; + for (size_t i = 1; i < tuple_inputs.size(); i++) { + if (!IsValueNode(tuple_inputs[i])) { + MS_LOG(WARNING) << "SwitchLayer input is not FuncGraph"; + return false; + } + FuncGraphPtr fg = GetValueNode(tuple_inputs[i]); + if (FuncGraphHasTupleInput(fg)) { + FuncGraphPtr transformed_fg = graph_transform_(fg, mng); + tuple_inputs[i] = NewValueNode(transformed_fg); + changed = true; + } + } + if (changed) { + *trans_node = node->func_graph()->NewCNode(tuple_inputs); + } + return changed; + } + + private: + GraphTupleParamTransform &graph_transform_; +}; + +class CallGraphTupleTransform : public OptimizerCaller { + public: + CallGraphTupleTransform() + : graph_transformer_(), + graph_call_transform_(std::make_shared(graph_transformer_)), + switch_call_transform_(std::make_shared(graph_transformer_)), + switch_layer_call_transform_(std::make_shared(graph_transformer_)) { + transformers_.emplace_back(graph_call_transform_); + transformers_.emplace_back(switch_call_transform_); + transformers_.emplace_back(switch_layer_call_transform_); + } + ~CallGraphTupleTransform() = default; + + AnfNodePtr operator()(const OptimizerPtr &optimizer, const AnfNodePtr &node) override { + AnfNodePtr new_node; + for (auto &transform : 
transformers_) { + new_node = (*transform)(optimizer, node); + if (new_node != nullptr) { + return new_node; + } + } + return nullptr; + } + + private: + GraphTupleParamTransform graph_transformer_; + OptimizerCallerPtr graph_call_transform_; + OptimizerCallerPtr switch_call_transform_; + OptimizerCallerPtr switch_layer_call_transform_; + std::vector transformers_{}; +}; +} // namespace irpass +} // namespace opt +} // namespace mindspore +#endif // MINDSPORE_CCSRC_FRONTEND_OPTIMIZER_IRPASS_CALL_GRAPH_TRANSFORM_H_ diff --git a/mindspore/ccsrc/pipeline/jit/action.cc b/mindspore/ccsrc/pipeline/jit/action.cc index 9b1c893851..b2434d5a1c 100644 --- a/mindspore/ccsrc/pipeline/jit/action.cc +++ b/mindspore/ccsrc/pipeline/jit/action.cc @@ -277,6 +277,7 @@ bool OptimizeAction(const ResourcePtr &res, const std::vector &passes) MS_EXCEPTION_IF_NULL(func_graph); func_graph->DumpFuncGraph(fg_name); DumpIR(fg_name + ".ir", func_graph); + ExportIR(fg_name + ".dat", "", func_graph); MS_LOG(DEBUG) << "Dump " << fg_name << " func graph."; } counter++; diff --git a/mindspore/ccsrc/pipeline/jit/pass.cc b/mindspore/ccsrc/pipeline/jit/pass.cc index 0172adb793..6465f0f89f 100644 --- a/mindspore/ccsrc/pipeline/jit/pass.cc +++ b/mindspore/ccsrc/pipeline/jit/pass.cc @@ -33,6 +33,7 @@ #include "frontend/optimizer/clean.h" #include "frontend/optimizer/irpass.h" #include "frontend/optimizer/control_depend.h" +#include "frontend/optimizer/graph_transform.h" #include "frontend/parallel/step_parallel.h" #include "frontend/parallel/step_auto_parallel.h" #include "frontend/parallel/allreduce_fusion/step_allreduce_fusion.h" @@ -166,12 +167,23 @@ OptPassGroupMap GetOptPassesA(const opt::irpass::OptimizeIRPassLib &irpass) { OptPassGroupMap GetOptPassesAfterCconv(const opt::irpass::OptimizeIRPassLib &irpass) { opt::OptPassConfig c_1 = opt::OptPassConfig({ - // Safe inlining + // Safe inlining, irpass.inline_, irpass.partial_eliminate_, }); - OptPassGroupMap map_a({{"c_1", c_1}, {"renormalize", opt::OptPassConfig::Renormalize()}}); + OptPassGroupMap map_a({{"c_1", c_1}, + {"cse", opt::OptPassConfig(opt::CSEPass(false))}, + {"renormalize", opt::OptPassConfig::Renormalize()}}); + + return map_a; +} + +OptPassGroupMap GetOptPassesTransformGraph(const opt::irpass::OptimizeIRPassLib &irpass) { + opt::OptPassConfig d_1 = opt::OptPassConfig({// Safe inlining + irpass.call_graph_tuple_transform_, irpass.item_tuple_eliminate_}); + + OptPassGroupMap map_a({{"d_1", d_1}, {"renormalize", opt::OptPassConfig::Renormalize()}}); return map_a; } @@ -262,6 +274,8 @@ void InitOpt(const ResourcePtr &res) { g_pass_opts["opt_b"] = Optimizer::MakeOptimizer("opt_b", res, GetOptPassesB(irpass), false, true); g_pass_opts["opt_after_cconv"] = Optimizer::MakeOptimizer("opt_after_cconv", res, GetOptPassesAfterCconv(irpass), false, true); + g_pass_opts["opt_trans_graph"] = + Optimizer::MakeOptimizer("opt_trans_graph", res, GetOptPassesTransformGraph(irpass), true, true); g_pass_opts["opt_graph_kernel_a"] = Optimizer::MakeOptimizer("opt_graph_kernel_a", res, GetOptPassesGraphKernelA(irpass), true); g_pass_opts["opt_graph_kernel_b"] = @@ -307,6 +321,7 @@ bool OptPassGroup(const ResourcePtr &res, const std::string &name) { bool OptPassAGroup(const ResourcePtr &res) { return OptPassGroup(res, "opt_a"); } bool OptPassBGroup(const ResourcePtr &res) { return OptPassGroup(res, "opt_b"); } bool OptPassAfterCconvGroup(const ResourcePtr &res) { return OptPassGroup(res, "opt_after_cconv"); } +bool OptPassTransformGraphGroup(const ResourcePtr &res) { return 
OptPassGroup(res, "opt_trans_graph"); } bool OptPassGraphKernelGroupA(const ResourcePtr &res) { return OptPassGroup(res, "opt_graph_kernel_a"); } bool OptPassGraphKernelGroupB(const ResourcePtr &res) { return OptPassGroup(res, "opt_graph_kernel_b"); } bool ControlGroup(const ResourcePtr &res) { return OptPassGroup(res, "opt_control"); } @@ -365,6 +380,24 @@ bool CconvPass(const ResourcePtr &res) { return true; } +bool TransformTopGraphPass(const ResourcePtr &res) { + if (res->func_graph() == nullptr) { + MS_LOG(EXCEPTION) << "Transform top graph error."; + } + FuncGraphPtr func_graph = res->func_graph(); + if (opt::FuncGraphHasTupleInput(func_graph)) { + opt::GraphTupleParamTransform graph_trans; + func_graph = graph_trans(func_graph, res->manager()); + res->set_func_graph(func_graph); + AbstractBasePtrList abs_spec_list; + auto ¶ms = func_graph->parameters(); + std::transform(params.begin(), params.end(), std::back_inserter(abs_spec_list), + [](AnfNodePtr node) { return node->abstract(); }); + res->set_args_spec(abs_spec_list); + } + return true; +} + bool ValidatePass(const ResourcePtr &res) { MS_EXCEPTION_IF_NULL(res->func_graph()); FuncGraphPtr func_graph = res->func_graph(); @@ -388,6 +421,7 @@ std::vector kVmPasses = {{"simplify_data_structures", SimplifyDataStru {"cconv", CconvPass}, {"opt_after_cconv", OptPassAfterCconvGroup}, {"remove_dup_value", RemoveValueNodeDuplicationsPass}, + {"tuple_transform", OptPassTransformGraphGroup}, {"opt_graph_kernel_a", OptPassGraphKernelGroupA}, {"opt_graph_kernel_b", OptPassGraphKernelGroupB}, {"add_control_depend", AddControlDependPass}}; @@ -401,6 +435,10 @@ std::vector kGePasses = {{"simplify_data_structures", SimplifyDataStru {"opt_prepare", PrepareGroup}, {"cconv", CconvPass}}; -std::vector kPynativePasses = {{"opt_a", OptPassAGroup}, {"opt_b", OptPassBGroup}, {"cconv", CconvPass}}; +std::vector kPynativePasses = {{"opt_a", OptPassAGroup}, + {"opt_b", OptPassBGroup}, + {"cconv", CconvPass}, + {"transform_top", TransformTopGraphPass}, + {"transform_graph", OptPassTransformGraphGroup}}; } // namespace pipeline } // namespace mindspore diff --git a/mindspore/ccsrc/pipeline/pynative/pynative_execute.cc b/mindspore/ccsrc/pipeline/pynative/pynative_execute.cc index a4a4870927..a18bbf366c 100644 --- a/mindspore/ccsrc/pipeline/pynative/pynative_execute.cc +++ b/mindspore/ccsrc/pipeline/pynative/pynative_execute.cc @@ -1351,9 +1351,46 @@ void PynativeExecutor::ClearRes() { resource_.reset(); } +size_t GetTupleSize(const py::tuple &args) { + size_t count = 0; + for (size_t i = 0; i < args.size(); i++) { + if (py::isinstance(args[i])) { + count += GetTupleSize(args[i]); + } else { + count += 1; + } + } + return count; +} + +void ConvertTupleArg(py::tuple *res, size_t *index, const py::tuple &arg) { + for (size_t i = 0; i < arg.size(); i++) { + if (py::isinstance(arg[i])) { + ConvertTupleArg(res, index, arg[i]); + } else { + (*res)[(*index)++] = arg[i]; + } + } +} + +py::tuple ConvertArgs(const py::tuple &args) { + size_t tuple_size = GetTupleSize(args); + py::tuple res(tuple_size); + size_t index = 0; + for (size_t i = 0; i < args.size(); i++) { + if (py::isinstance(args[i])) { + ConvertTupleArg(&res, &index, args[i]); + } else { + res[index++] = args[i]; + } + } + return res; +} + py::object PynativeExecutor::Run(const py::tuple &args, const py::object &phase) { VectorRef arg_list; - pipeline::ProcessVmArgInner(args, resource_, &arg_list); + py::tuple converted_args = ConvertArgs(args); + pipeline::ProcessVmArgInner(converted_args, resource_, 
&arg_list); if (resource_->results().find(pipeline::kOutput) == resource_->results().end() || !resource_->results()[pipeline::kOutput].is()) { MS_LOG(EXCEPTION) << "Can't find run graph func for "; diff --git a/tests/st/pynative/test_graph_param_transform.py b/tests/st/pynative/test_graph_param_transform.py new file mode 100644 index 0000000000..647d85cd85 --- /dev/null +++ b/tests/st/pynative/test_graph_param_transform.py @@ -0,0 +1,201 @@ +# Copyright 2020 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +import pytest +import numpy as np +from mindspore import RowTensor +from mindspore import context, nn, Tensor, ParameterTuple +from mindspore.common import dtype as mstype +from mindspore.common import ms_function +from mindspore.ops import operations as P +from mindspore.ops import composite as C + + +def setup_module(): + context.set_context(mode=context.PYNATIVE_MODE, enable_sparse=False) + + +class _Grad(nn.Cell): + def __init__(self, grad, network, wrt_params=False, real_inputs_count=None): + super().__init__() + self.network = network + self.grad = grad + self.sens_param = self.grad.sens_param + self.wrt_params = wrt_params + self.real_inputs_count = real_inputs_count + if self.wrt_params: + self.params = ParameterTuple(self.network.trainable_params()) + + def construct(self, *inputs): + if self.wrt_params: + if self.real_inputs_count is None or self.sens_param is False: + return self.grad(self.network, self.params)(*inputs) + real_inputs = inputs[:self.real_inputs_count] + sense_param_inputs = inputs[self.real_inputs_count:] + return self.grad(self.network, self.params)(*real_inputs, sense_param_inputs) + + if self.real_inputs_count is None or self.sens_param is False: + return self.grad(self.network)(*inputs) + real_inputs = inputs[:self.real_inputs_count] + sense_param_inputs = inputs[self.real_inputs_count:] + return self.grad(self.network)(*real_inputs, sense_param_inputs) + + +class GradOfFirstInput(_Grad): + """ + get grad of first input + """ + + def __init__(self, network, sens_param=True, real_inputs_count=None): + super().__init__(grad=C.GradOperation(sens_param=sens_param), + network=network, real_inputs_count=real_inputs_count) + + +class GradOfAllInputs(_Grad): + """ + get grad of first input + """ + + def __init__(self, network, sens_param=True, real_inputs_count=None): + super().__init__(grad=C.GradOperation(get_all=True, sens_param=sens_param), + network=network, real_inputs_count=real_inputs_count) + + +@pytest.mark.level0 +@pytest.mark.platform_arm_ascend_training +@pytest.mark.platform_x86_ascend_training +@pytest.mark.env_onecard +def test_row_tensor_in_while(): + class RowTensorValuesDouble(nn.Cell): + + def construct(self, x): + indices = x.indices + values = x.values * 2 + dense_shape = x.dense_shape + return RowTensor(indices, values, dense_shape) + + class RowTensorValuesAdd2(nn.Cell): + + def construct(self, x): + indices = x.indices + values = 
x.values + 2 + dense_shape = x.dense_shape + return RowTensor(indices, values, dense_shape) + + class RowTensorWithControlWhile(nn.Cell): + def __init__(self, dense_shape): + super().__init__() + self.op1 = RowTensorValuesDouble() + self.op2 = RowTensorValuesAdd2() + self.dense_shape = dense_shape + + @ms_function + def construct(self, a, b, indices, values): + x = RowTensor(indices, values, self.dense_shape) + x = self.op2(x) + while a > b: + x = self.op1(x) + b = b + 1 + return x.indices, x.values, x.dense_shape + a = Tensor(np.array(3).astype(np.int32)) + b = Tensor(np.array(0).astype(np.int32)) + indices = Tensor(np.array([0, 2]).astype(np.int32)) + values = Tensor(np.ones([2, 2]).astype(np.float32)) + dense_shape = (5, 2) + net = RowTensorWithControlWhile(dense_shape) + out = net(a, b, indices, values) + assert np.allclose(indices.asnumpy(), out[0].asnumpy(), .0, .0) + assert np.allclose(values.asnumpy()*24, out[1].asnumpy(), .0, .0) + assert dense_shape == out[2] + + +@pytest.mark.level0 +@pytest.mark.platform_arm_ascend_training +@pytest.mark.platform_x86_ascend_training +@pytest.mark.env_onecard +def test_parser_switch_layer_inputs_tuple(): + class Add(nn.Cell): + def __init__(self): + super().__init__() + self.op = P.TensorAdd() + + def construct(self, x): + y = self.op(x[0], x[1]) + return self.op(x[0], y) + + class Mul(nn.Cell): + def __init__(self): + super().__init__() + self.op = P.Mul() + + def construct(self, x): + y = self.op(x[0], x[1]) + return self.op(x[0], y) + + class MulTwoInput(nn.Cell): + def __init__(self): + super().__init__() + self.op = P.Mul() + + @ms_function + def construct(self, x, y): + y = self.op(x, y) + return self.op(x, y) + + class TwoInputTupleFinalNet(nn.Cell): + def __init__(self, funcs): + super().__init__() + self.funcs = funcs + + @ms_function + def construct(self, i, inputa, inputb): + inputs = (inputa, inputb) + x = self.funcs[i](inputs) + return x + + func1 = Add() + func2 = Mul() + + funcs = (func1, func2) + net = TwoInputTupleFinalNet(funcs) + + input_data = Tensor(np.random.randn(2, 3, 4, 5).astype(np.float32)) + input2 = Tensor(np.random.randn(2, 3, 4, 5).astype(np.float32)) + i = Tensor(1, mstype.int32) + netout = net(i, input_data, input2) + net_good = MulTwoInput() + goodout = net_good(input_data, input2) + assert np.allclose(goodout.asnumpy(), netout.asnumpy(), 0, 0) + + +@pytest.mark.level0 +@pytest.mark.platform_arm_ascend_training +@pytest.mark.platform_x86_ascend_training +@pytest.mark.env_onecard +def test_imagenet(): + class ImageGradients(nn.Cell): + def __init__(self): + super().__init__() + self.imagegradients = nn.ImageGradients() + + def construct(self, inputs): + return self.imagegradients(inputs) + + net = ImageGradients() + net_me = GradOfFirstInput(net, real_inputs_count=1) + net_me.set_train() + input_data = Tensor(np.ones([32, 16, 8, 8]), dtype=mstype.float32) + output_grad = (Tensor(np.ones([32, 16, 8, 8]), dtype=mstype.float32), + Tensor(np.ones([32, 16, 8, 8]), dtype=mstype.float32)) + net_me(input_data, *output_grad) diff --git a/tests/ut/python/pynative_mode/test_graph_param_cases.py b/tests/ut/python/pynative_mode/test_graph_param_cases.py new file mode 100644 index 0000000000..96a1ab25fc --- /dev/null +++ b/tests/ut/python/pynative_mode/test_graph_param_cases.py @@ -0,0 +1,136 @@ +# Copyright 2020 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +import numpy as np +from mindspore import RowTensor +from mindspore import context, nn, Tensor, ParameterTuple +from mindspore.common import dtype as mstype +from mindspore.common import ms_function +from mindspore.ops import composite as C + + +def setup_module(): + context.set_context(mode=context.PYNATIVE_MODE, enable_sparse=False) + + +class _Grad(nn.Cell): + def __init__(self, grad, network, wrt_params=False, real_inputs_count=None): + super().__init__() + self.network = network + self.grad = grad + self.sens_param = self.grad.sens_param + self.wrt_params = wrt_params + self.real_inputs_count = real_inputs_count + if self.wrt_params: + self.params = ParameterTuple(self.network.trainable_params()) + + def construct(self, *inputs): + if self.wrt_params: + if self.real_inputs_count is None or self.sens_param is False: + return self.grad(self.network, self.params)(*inputs) + real_inputs = inputs[:self.real_inputs_count] + sense_param_inputs = inputs[self.real_inputs_count:] + return self.grad(self.network, self.params)(*real_inputs, sense_param_inputs) + + if self.real_inputs_count is None or self.sens_param is False: + return self.grad(self.network)(*inputs) + real_inputs = inputs[:self.real_inputs_count] + sense_param_inputs = inputs[self.real_inputs_count:] + return self.grad(self.network)(*real_inputs, sense_param_inputs) + + +class GradOfFirstInput(_Grad): + """ + get grad of first input + """ + + def __init__(self, network, sens_param=True, real_inputs_count=None): + super().__init__(grad=C.GradOperation(sens_param=sens_param), + network=network, real_inputs_count=real_inputs_count) + + +class GradOfAllInputs(_Grad): + """ + get grad of first input + """ + + def __init__(self, network, sens_param=True, real_inputs_count=None): + super().__init__(grad=C.GradOperation(get_all=True, sens_param=sens_param), + network=network, real_inputs_count=real_inputs_count) + + +def test_row_tensor_in_while(): + class RowTensorValuesDouble(nn.Cell): + def __init__(self): + super().__init__() + + def construct(self, x): + indices = x.indices + values = x.values * 2 + dense_shape = x.dense_shape + return RowTensor(indices, values, dense_shape) + + class RowTensorValuesAdd2(nn.Cell): + def __init__(self): + super().__init__() + + def construct(self, x): + indices = x.indices + values = x.values + 2 + dense_shape = x.dense_shape + return RowTensor(indices, values, dense_shape) + + class RowTensorWithControlWhile(nn.Cell): + def __init__(self, dense_shape): + super().__init__() + self.op1 = RowTensorValuesDouble() + self.op2 = RowTensorValuesAdd2() + self.dense_shape = dense_shape + + @ms_function + def construct(self, a, b, indices, values): + x = RowTensor(indices, values, self.dense_shape) + x = self.op2(x) + while (a > b): + x = self.op1(x) + b = b + 1 + return x.indices, x.values, x.dense_shape + a = Tensor(np.array(3).astype(np.int32)) + b = Tensor(np.array(0).astype(np.int32)) + indices = Tensor(np.array([0, 2]).astype(np.int32)) + values = Tensor(np.ones([2, 2]).astype(np.float32)) + dense_shape = (5, 2) + + 
+    net = RowTensorWithControlWhile(dense_shape)
+    net(a, b, indices, values)
+
+
+def test_multi_out_sens():
+    class ImageGradients(nn.Cell):
+        def __init__(self):
+            super().__init__()
+
+        def construct(self, x, y, z):
+            resa = x * y
+            resb = y * z
+            resc = x * z
+            return resa, (resb, resc)
+
+    net = ImageGradients()
+    net_me = GradOfAllInputs(net, real_inputs_count=3)
+    net_me.set_train()
+    input_data = Tensor(np.ones([32]), dtype=mstype.float32)
+    output_grad = (Tensor(np.ones([32]), dtype=mstype.float32),
+                   (Tensor(np.ones([32]), dtype=mstype.float32), Tensor(np.ones([32]), dtype=mstype.float32)))
+    net_me(input_data, input_data, input_data, *output_grad)
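
Note on the PyNative argument flattening: ConvertArgs and ConvertTupleArg in pynative_execute.cc expand nested py::tuple arguments into a flat argument pack before ProcessVmArgInner runs, so the flattened top graph produced by TransformTopGraphPass receives one value per parameter. A minimal stand-alone Python sketch of that behaviour (illustrative only, not part of the patch; flatten_args is a hypothetical helper and plain Python tuples stand in for the py::tuple objects the C++ code walks):

def flatten_args(args):
    """Recursively expand nested tuples into a flat tuple of leaf arguments."""
    flat = []
    for arg in args:
        if isinstance(arg, tuple):
            flat.extend(flatten_args(arg))  # recurse into nested tuples, as ConvertTupleArg does
        else:
            flat.append(arg)                # leaf arguments keep their original order
    return tuple(flat)

# (a, (b, c), d) becomes (a, b, c, d), matching the flattened parameter list that
# GraphTupleParamTransform generates for a graph whose parameter was a tuple.
assert flatten_args((1, (2, 3), 4)) == (1, 2, 3, 4)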