From 7c7da0cb77a4b8ba20b4893dafa93f0a365efd4f Mon Sep 17 00:00:00 2001 From: LianLiguang Date: Thu, 24 Dec 2020 11:33:59 +0800 Subject: [PATCH] split unsupported transdata after ref pass --- .../ascend/ascend_backend_optimization.cc | 2 - .../backend/optimizer/ascend/ascend_helper.cc | 6 +- .../backend/optimizer/ascend/ascend_helper.h | 6 +- .../format_type/deal_ref_trans_and_cast.cc | 70 ++++++++++++++----- .../format_type/deal_ref_trans_and_cast.h | 23 +++++- .../ascend/format_type/insert_trans_op.cc | 9 --- .../split_unsupported_transdata.cc | 66 ----------------- .../format_type/split_unsupported_transdata.h | 37 ---------- .../ascend/ir_fission/transdata_split.cc | 46 ++++++------ .../ascend/ir_fission/transdata_split.h | 14 ++-- .../ccsrc/backend/optimizer/common/helper.cc | 4 +- .../ccsrc/backend/optimizer/common/helper.h | 2 +- 12 files changed, 110 insertions(+), 175 deletions(-) delete mode 100644 mindspore/ccsrc/backend/optimizer/ascend/format_type/split_unsupported_transdata.cc delete mode 100644 mindspore/ccsrc/backend/optimizer/ascend/format_type/split_unsupported_transdata.h diff --git a/mindspore/ccsrc/backend/optimizer/ascend/ascend_backend_optimization.cc b/mindspore/ccsrc/backend/optimizer/ascend/ascend_backend_optimization.cc index b7258a2ca9..743a97e0f8 100644 --- a/mindspore/ccsrc/backend/optimizer/ascend/ascend_backend_optimization.cc +++ b/mindspore/ccsrc/backend/optimizer/ascend/ascend_backend_optimization.cc @@ -71,7 +71,6 @@ #include "backend/optimizer/ascend/format_type/insert_transpose_for_dyanmic_gru_v2.h" #include "backend/optimizer/ascend/format_type/rectify_do_mask_kernel_info.h" #include "backend/optimizer/ascend/format_type/chang_axis_of_reduce_kernel.h" -#include "backend/optimizer/ascend/format_type/split_unsupported_transdata.h" #include "backend/optimizer/ascend/format_type/convert_cast_format.h" #include "backend/optimizer/pass/getitem_tuple.h" #include "backend/optimizer/pass/optimize_dependence.h" @@ -250,7 +249,6 @@ void AscendMixPrecision(const std::shared_ptr &kernel_grap mixed_precision_pm->AddPass(std::make_shared()); mixed_precision_pm->AddPass(std::make_shared()); mixed_precision_pm->AddPass(std::make_shared()); - mixed_precision_pm->AddPass(std::make_shared()); mixed_precision_pm->AddPass(std::make_shared()); mixed_precision_pm->AddPass(std::make_shared()); optimizer->AddPassManager(mixed_precision_pm); diff --git a/mindspore/ccsrc/backend/optimizer/ascend/ascend_helper.cc b/mindspore/ccsrc/backend/optimizer/ascend/ascend_helper.cc index 3783fc7dd5..dc90ea3f35 100644 --- a/mindspore/ccsrc/backend/optimizer/ascend/ascend_helper.cc +++ b/mindspore/ccsrc/backend/optimizer/ascend/ascend_helper.cc @@ -256,9 +256,9 @@ CNodePtr NewTransOpNode(const FuncGraphPtr &func_graph, const AnfNodePtr &input, return trans_node; } -AnfNodePtr AddCastOpNodeToGraph(const FuncGraphPtr &func_graph, const AnfNodePtr &input, const std::string &format, - const TypeId &input_type, const TypeId &output_type, - const std::vector &origin_shape, const TypeId &origin_type) { +CNodePtr AddCastOpNodeToGraph(const FuncGraphPtr &func_graph, const AnfNodePtr &input, const std::string &format, + const TypeId &input_type, const TypeId &output_type, + const std::vector &origin_shape, const TypeId &origin_type) { MS_EXCEPTION_IF_NULL(func_graph); std::string input_format = format; std::string output_format = format; diff --git a/mindspore/ccsrc/backend/optimizer/ascend/ascend_helper.h b/mindspore/ccsrc/backend/optimizer/ascend/ascend_helper.h index 4496354cc9..c4901a7b10 100644 --- a/mindspore/ccsrc/backend/optimizer/ascend/ascend_helper.h +++ b/mindspore/ccsrc/backend/optimizer/ascend/ascend_helper.h @@ -94,9 +94,9 @@ void RefreshKernelBuildInfo(const std::string &input_format, const std::string & CNodePtr NewTransOpNode(const FuncGraphPtr &func_graph, const AnfNodePtr &input, const KernelSelectPtr &kernel_select, const bool need_padding, const std::string &op_name); -AnfNodePtr AddCastOpNodeToGraph(const FuncGraphPtr &func_graph, const AnfNodePtr &input, const std::string &format, - const TypeId &input_type, const TypeId &output_type, - const std::vector &origin_shape, const TypeId &origin_type); +CNodePtr AddCastOpNodeToGraph(const FuncGraphPtr &func_graph, const AnfNodePtr &input, const std::string &format, + const TypeId &input_type, const TypeId &output_type, + const std::vector &origin_shape, const TypeId &origin_type); AnfNodePtr InsertTransOpForInput(const FuncGraphPtr &func_graph, const AnfNodePtr &node, const KernelSelectPtr &kernel_select); diff --git a/mindspore/ccsrc/backend/optimizer/ascend/format_type/deal_ref_trans_and_cast.cc b/mindspore/ccsrc/backend/optimizer/ascend/format_type/deal_ref_trans_and_cast.cc index 376d0e5623..713fb70274 100644 --- a/mindspore/ccsrc/backend/optimizer/ascend/format_type/deal_ref_trans_and_cast.cc +++ b/mindspore/ccsrc/backend/optimizer/ascend/format_type/deal_ref_trans_and_cast.cc @@ -26,8 +26,7 @@ namespace mindspore { namespace opt { -namespace { -session::KernelWithIndex FindRefOriginNode(const AnfNodePtr &node) { +session::KernelWithIndex DealRefTransAndCast::FindRefOriginNode(const AnfNodePtr &node) const { session::KernelWithIndex kernel_with_index = AnfAlgo::VisitKernel(node, 0); AnfNodePtr cur_node = kernel_with_index.first; size_t cur_out_index = kernel_with_index.second; @@ -62,8 +61,8 @@ session::KernelWithIndex FindRefOriginNode(const AnfNodePtr &node) { return kernel_with_index; } -void AddRefNodePairToKernelGraph(const FuncGraphPtr &func_graph, const CNodePtr &cnode, const size_t output_index, - const size_t input_index) { +void DealRefTransAndCast::AddRefNodePairToKernelGraph(const FuncGraphPtr &func_graph, const CNodePtr &cnode, + const size_t output_index, const size_t input_index) const { // record the ref_pair auto kernel_graph = func_graph->cast(); MS_EXCEPTION_IF_NULL(kernel_graph); @@ -72,9 +71,10 @@ void AddRefNodePairToKernelGraph(const FuncGraphPtr &func_graph, const CNodePtr kernel_graph->AddRefCorrespondPairs(final_pair, kernel_with_index); } -void AddRefPairToKernelGraph(const FuncGraphPtr &func_graph, const CNodePtr &cnode, const AnfNodePtr &get_item, - const AnfNodePtr &final_node, size_t final_index, - const session::KernelWithIndex &origin_pair) { +void DealRefTransAndCast::AddRefPairToKernelGraph(const FuncGraphPtr &func_graph, const CNodePtr &cnode, + const AnfNodePtr &get_item, const AnfNodePtr &final_node, + size_t final_index, + const session::KernelWithIndex &origin_pair) const { // record the ref_pair auto kernel_graph = func_graph->cast(); MS_EXCEPTION_IF_NULL(kernel_graph); @@ -95,9 +95,10 @@ void AddRefPairToKernelGraph(const FuncGraphPtr &func_graph, const CNodePtr &cno // if get_item is nullptr, the additional node will link to the cnode // else the additional node will link to the get_item node (the get_item node link to cnode) -AnfNodePtr AddAdditionalToRefOutput(const FuncGraphPtr &func_graph, const CNodePtr &cnode, size_t output_index, - size_t input_index, const AnfNodePtr &get_item) { - AnfNodePtr final_node = (get_item == nullptr ? cnode : get_item); +CNodePtr DealRefTransAndCast::AddAdditionalToRefOutput(const FuncGraphPtr &func_graph, const CNodePtr &cnode, + size_t output_index, size_t input_index, + const CNodePtr &get_item) const { + CNodePtr final_node = (get_item == nullptr ? cnode : get_item); bool need_refresh_ref_addr = false; size_t final_index = output_index; AnfNodePtr input_node = AnfAlgo::GetInputNode(cnode, input_index); @@ -119,6 +120,7 @@ AnfNodePtr AddAdditionalToRefOutput(const FuncGraphPtr &func_graph, const CNodeP auto kernel_select = std::make_shared(); final_node = NewTransOpNode(func_graph, final_node, kernel_select, false, prim::KPrimTransData->name()); RefreshKernelBuildInfo(cur_format, origin_format, final_node, {}, cur_type); + final_node = SplitTransdataIfNotSupported(func_graph, final_node); final_index = 0; need_refresh_ref_addr = true; MS_EXCEPTION_IF_NULL(final_node); @@ -148,15 +150,15 @@ AnfNodePtr AddAdditionalToRefOutput(const FuncGraphPtr &func_graph, const CNodeP return final_node; } -AnfNodePtr DealRefForMultipleOutput(const FuncGraphPtr &func_graph, const CNodePtr &cnode, - const std::shared_ptr &op_info) { +CNodePtr DealRefTransAndCast::DealRefForMultipleOutput(const FuncGraphPtr &func_graph, const CNodePtr &cnode, + const std::shared_ptr &op_info) const { MS_EXCEPTION_IF_NULL(op_info); auto ref_infos = op_info->ref_infos(); std::vector make_tuple_inputs; AbstractBasePtrList abstract_list; make_tuple_inputs.push_back(NewValueNode(prim::kPrimMakeTuple)); for (size_t output_index = 0; output_index < AnfAlgo::GetOutputTensorNum(cnode); ++output_index) { - AnfNodePtr final_node = CreatTupleGetItemNode(func_graph, cnode, output_index); + CNodePtr final_node = CreatTupleGetItemNode(func_graph, cnode, output_index); // deal with ref output if (ref_infos.count(output_index) != 0) { auto input_index = ref_infos.at(output_index); @@ -167,14 +169,14 @@ AnfNodePtr DealRefForMultipleOutput(const FuncGraphPtr &func_graph, const CNodeP make_tuple_inputs.push_back(final_node); } MS_EXCEPTION_IF_NULL(func_graph); - AnfNodePtr make_tuple = func_graph->NewCNode(make_tuple_inputs); + CNodePtr make_tuple = func_graph->NewCNode(make_tuple_inputs); MS_EXCEPTION_IF_NULL(make_tuple); make_tuple->set_abstract(std::make_shared(abstract_list)); return make_tuple; } -AnfNodePtr DealRefSigleOutput(const FuncGraphPtr &func_graph, const CNodePtr &cnode, - const std::shared_ptr &op_info) { +CNodePtr DealRefTransAndCast::DealRefSigleOutput(const FuncGraphPtr &func_graph, const CNodePtr &cnode, + const std::shared_ptr &op_info) const { MS_EXCEPTION_IF_NULL(cnode); MS_EXCEPTION_IF_NULL(op_info); auto ref_infos = op_info->ref_infos(); @@ -187,7 +189,6 @@ AnfNodePtr DealRefSigleOutput(const FuncGraphPtr &func_graph, const CNodePtr &cn } return nullptr; } -} // namespace const BaseRef DealRefTransAndCast::DefinePattern() const { VarPtr V = std::make_shared(UnVisited); @@ -195,7 +196,7 @@ const BaseRef DealRefTransAndCast::DefinePattern() const { return VectorRef({V, Xs}); } -void DealBroadCastAsRef(const FuncGraphPtr &func_graph, const CNodePtr &cnode) { +void DealRefTransAndCast::DealBroadCastAsRef(const FuncGraphPtr &func_graph, const CNodePtr &cnode) const { if (AnfAlgo::GetCNodeName(cnode) == kBroadcastOpName) { auto input_size = AnfAlgo::GetInputTensorNum(cnode); for (size_t i = 0; i < input_size; ++i) { @@ -238,5 +239,38 @@ const AnfNodePtr DealRefTransAndCast::Process(const FuncGraphPtr &graph, const A } return nullptr; } + +CNodePtr DealRefTransAndCast::SplitTransdataIfNotSupported(const FuncGraphPtr &func_graph, + const CNodePtr &cnode) const { + MS_EXCEPTION_IF_NULL(cnode); + auto kernel_info = AnfAlgo::GetSelectKernelBuildInfo(cnode); + MS_EXCEPTION_IF_NULL(kernel_info); + if (kHWSpecialFormatSet.find(kernel_info->GetInputFormat(0)) == kHWSpecialFormatSet.end() || + kHWSpecialFormatSet.find(kernel_info->GetOutputFormat(0)) == kHWSpecialFormatSet.end()) { + if (IsFormatInvaild(cnode)) { + return DoSplit(func_graph, cnode); + } + return cnode; + } + auto builder_info_to_default = std::make_shared(kernel_info); + auto builder_info_to_special_foramt = std::make_shared(kernel_info); + builder_info_to_default->SetOutputsFormat({kOpFormat_DEFAULT}); + builder_info_to_special_foramt->SetInputsFormat({kOpFormat_DEFAULT}); + std::vector next_trans_node_inputs = { + NewValueNode(std::make_shared(prim::KPrimTransData->name())), cnode}; + MS_EXCEPTION_IF_NULL(func_graph); + auto next_trans_node = func_graph->NewCNode(next_trans_node_inputs); + next_trans_node->set_abstract(cnode->abstract()); + AnfAlgo::SetSelectKernelBuildInfo(builder_info_to_default->Build(), cnode.get()); + AnfAlgo::SetSelectKernelBuildInfo(builder_info_to_special_foramt->Build(), next_trans_node.get()); + if (IsFormatInvaild(cnode)) { + auto after_split_node = DoSplit(func_graph, cnode); + AnfAlgo::SetNodeInput(next_trans_node, after_split_node, 0); + } + if (IsFormatInvaild(next_trans_node)) { + return DoSplit(func_graph, next_trans_node); + } + return next_trans_node; +} } // namespace opt } // namespace mindspore diff --git a/mindspore/ccsrc/backend/optimizer/ascend/format_type/deal_ref_trans_and_cast.h b/mindspore/ccsrc/backend/optimizer/ascend/format_type/deal_ref_trans_and_cast.h index da85844db8..fe9bd3e93a 100644 --- a/mindspore/ccsrc/backend/optimizer/ascend/format_type/deal_ref_trans_and_cast.h +++ b/mindspore/ccsrc/backend/optimizer/ascend/format_type/deal_ref_trans_and_cast.h @@ -16,20 +16,37 @@ #ifndef MINDSPORE_CCSRC_BACKEND_OPTIMIZER_ASCEND_FORMAT_TYPE_DEAL_REF_TRANS_AND_CAST_H_ #define MINDSPORE_CCSRC_BACKEND_OPTIMIZER_ASCEND_FORMAT_TYPE_DEAL_REF_TRANS_AND_CAST_H_ - +#include #include "ir/anf.h" #include "backend/optimizer/common/optimizer.h" +#include "backend/optimizer/ascend/ir_fission/transdata_split.h" #include "backend/optimizer/common/pattern_engine.h" #include "backend/optimizer/ascend/ascend_helper.h" namespace mindspore { namespace opt { -class DealRefTransAndCast : public PatternProcessPass { +class DealRefTransAndCast : public TransDataSplit { public: - explicit DealRefTransAndCast(bool multigraph = true) : PatternProcessPass("deal_ref_trans_and_cast", multigraph) {} + explicit DealRefTransAndCast(bool multigraph = true) : TransDataSplit(multigraph, "deal_ref_trans_and_cast") {} ~DealRefTransAndCast() override = default; const BaseRef DefinePattern() const override; const AnfNodePtr Process(const FuncGraphPtr &, const AnfNodePtr &, const EquivPtr &) const override; + + private: + CNodePtr SplitTransdataIfNotSupported(const FuncGraphPtr &func_graph, const CNodePtr &cnode) const; + void DealBroadCastAsRef(const FuncGraphPtr &func_graph, const CNodePtr &cnode) const; + CNodePtr DealRefSigleOutput(const FuncGraphPtr &func_graph, const CNodePtr &cnode, + const std::shared_ptr &op_info) const; + CNodePtr DealRefForMultipleOutput(const FuncGraphPtr &func_graph, const CNodePtr &cnode, + const std::shared_ptr &op_info) const; + CNodePtr AddAdditionalToRefOutput(const FuncGraphPtr &func_graph, const CNodePtr &cnode, size_t output_index, + size_t input_index, const CNodePtr &get_item) const; + void AddRefPairToKernelGraph(const FuncGraphPtr &func_graph, const CNodePtr &cnode, const AnfNodePtr &get_item, + const AnfNodePtr &final_node, size_t final_index, + const session::KernelWithIndex &origin_pair) const; + void AddRefNodePairToKernelGraph(const FuncGraphPtr &func_graph, const CNodePtr &cnode, const size_t output_index, + const size_t input_index) const; + session::KernelWithIndex FindRefOriginNode(const AnfNodePtr &node) const; }; } // namespace opt } // namespace mindspore diff --git a/mindspore/ccsrc/backend/optimizer/ascend/format_type/insert_trans_op.cc b/mindspore/ccsrc/backend/optimizer/ascend/format_type/insert_trans_op.cc index b6a9529229..0e6accb14e 100644 --- a/mindspore/ccsrc/backend/optimizer/ascend/format_type/insert_trans_op.cc +++ b/mindspore/ccsrc/backend/optimizer/ascend/format_type/insert_trans_op.cc @@ -31,15 +31,6 @@ const BaseRef InsertTransOp::DefinePattern() const { return VectorRef({V, Xs}); } -bool IsGraphOutput(const AnfNodePtr &node, const FuncGraphPtr &func_graph) { - auto outputs = AnfAlgo::GetAllOutput(func_graph->output(), {prim::kPrimTupleGetItem}); - auto iter = std::find(outputs.begin(), outputs.end(), node); - if (iter != outputs.end() && GetRealNodeNum(func_graph, node) == 1) { - return true; - } - return false; -} - const AnfNodePtr InsertTransOp::Process(const FuncGraphPtr &func_graph, const AnfNodePtr &node, const EquivPtr &) const { if (node == nullptr || !AnfAlgo::IsRealKernel(node)) { diff --git a/mindspore/ccsrc/backend/optimizer/ascend/format_type/split_unsupported_transdata.cc b/mindspore/ccsrc/backend/optimizer/ascend/format_type/split_unsupported_transdata.cc deleted file mode 100644 index 40be5856f8..0000000000 --- a/mindspore/ccsrc/backend/optimizer/ascend/format_type/split_unsupported_transdata.cc +++ /dev/null @@ -1,66 +0,0 @@ -/** - * Copyright 2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include "backend/optimizer/ascend/format_type/split_unsupported_transdata.h" -#include -#include -#include "backend/session/anf_runtime_algorithm.h" -#include "utils/trace_base.h" - -namespace mindspore { -namespace opt { -const BaseRef SplitUnsupportedTransData::DefinePattern() const { - VarPtr X = std::make_shared(); - return VectorRef({prim::KPrimTransData, X}); -} - -const AnfNodePtr SplitUnsupportedTransData::Process(const FuncGraphPtr &func_graph, const AnfNodePtr &node, - const EquivPtr &) const { - if (node == nullptr || !node->isa() || !AnfAlgo::IsRealKernel(node)) { - return nullptr; - } - auto ori_trans_data = node->cast(); - if (AnfAlgo::GetCNodeName(ori_trans_data) != prim::KPrimTransData->name()) { - return nullptr; - } - auto kernel_info = AnfAlgo::GetSelectKernelBuildInfo(ori_trans_data); - MS_EXCEPTION_IF_NULL(kernel_info); - if (kernel_info->GetInputNum() != 1 || kernel_info->GetOutputNum() != 1) { - MS_LOG(EXCEPTION) << "Transdata node's kernel info's input and output format size is not 1" - << ori_trans_data->DebugString() << trace::DumpSourceLines(node); - } - return SplitTransData(func_graph, ori_trans_data); -} -AnfNodePtr SplitUnsupportedTransData::SplitTransData(const FuncGraphPtr &func_graph, const CNodePtr &trans_node) const { - auto kernel_info = AnfAlgo::GetSelectKernelBuildInfo(trans_node); - if (kHWSpecialFormatSet.find(kernel_info->GetInputFormat(0)) == kHWSpecialFormatSet.end() || - kHWSpecialFormatSet.find(kernel_info->GetOutputFormat(0)) == kHWSpecialFormatSet.end()) { - return trans_node; - } - auto builder_info_to_default = std::make_shared(kernel_info); - auto builder_info_to_special_foramt = std::make_shared(kernel_info); - builder_info_to_default->SetOutputsFormat({kOpFormat_DEFAULT}); - builder_info_to_special_foramt->SetInputsFormat({kOpFormat_DEFAULT}); - std::vector next_trans_node_inputs = { - NewValueNode(std::make_shared(prim::KPrimTransData->name())), trans_node}; - auto next_trans_node = func_graph->NewCNode(next_trans_node_inputs); - next_trans_node->set_abstract(trans_node->abstract()); - AnfAlgo::SetSelectKernelBuildInfo(builder_info_to_default->Build(), trans_node.get()); - AnfAlgo::SetSelectKernelBuildInfo(builder_info_to_special_foramt->Build(), next_trans_node.get()); - return next_trans_node; -} -} // namespace opt -} // namespace mindspore diff --git a/mindspore/ccsrc/backend/optimizer/ascend/format_type/split_unsupported_transdata.h b/mindspore/ccsrc/backend/optimizer/ascend/format_type/split_unsupported_transdata.h deleted file mode 100644 index d4df2b57a8..0000000000 --- a/mindspore/ccsrc/backend/optimizer/ascend/format_type/split_unsupported_transdata.h +++ /dev/null @@ -1,37 +0,0 @@ -/** - * Copyright 2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#ifndef MINDSPORE_CCSRC_BACKEND_OPTIMIZER_ASCEND_FORMAT_TYPE_REMOVE_TRANSDATA_SPILT_H -#define MINDSPORE_CCSRC_BACKEND_OPTIMIZER_ASCEND_FORMAT_TYPE_REMOVE_TRANSDATA_SPILT_H - -#include "backend/optimizer/common/optimizer.h" - -namespace mindspore { -namespace opt { -class SplitUnsupportedTransData : public PatternProcessPass { - public: - explicit SplitUnsupportedTransData(bool multigraph = true) - : PatternProcessPass("split_unsupported_transdata", multigraph) {} - ~SplitUnsupportedTransData() override = default; - const BaseRef DefinePattern() const override; - const AnfNodePtr Process(const FuncGraphPtr &, const AnfNodePtr &, const EquivPtr &) const override; - - private: - AnfNodePtr SplitTransData(const FuncGraphPtr &func_graph, const CNodePtr &trans_node) const; -}; -} // namespace opt -} // namespace mindspore -#endif // MINDSPORE_CCSRC_BACKEND_OPTIMIZER_ASCEND_FORMAT_TYPE_REMOVE_TRANSDATA_SPILT_H diff --git a/mindspore/ccsrc/backend/optimizer/ascend/ir_fission/transdata_split.cc b/mindspore/ccsrc/backend/optimizer/ascend/ir_fission/transdata_split.cc index 94b74e847f..cce479e95a 100644 --- a/mindspore/ccsrc/backend/optimizer/ascend/ir_fission/transdata_split.cc +++ b/mindspore/ccsrc/backend/optimizer/ascend/ir_fission/transdata_split.cc @@ -27,22 +27,20 @@ const std::set> invalid_formats_pair = { {kOpFormat_C1HWNCoC0, kOpFormat_DEFAULT}, {kOpFormat_DEFAULT, kOpFormat_FRACTAL_ZN_LSTM}, {kOpFormat_FRACTAL_ZN_LSTM, kOpFormat_DEFAULT}, {kOpFormat_DEFAULT, kOpFormat_C1HWNCoC0}}; -bool TransDataSplit::Run(const FuncGraphPtr &func_graph) { +const AnfNodePtr TransDataSplit::Process(const FuncGraphPtr &func_graph, const AnfNodePtr &node, + const EquivPtr &) const { MS_EXCEPTION_IF_NULL(func_graph); - bool changed = false; - std::vector node_list = TopoSort(func_graph->get_return()); - for (auto &node : node_list) { - if (node != nullptr && node->isa() && AnfAlgo::GetCNodeName(node) == kTransDataOpName) { - CheckCNodeInputSize(node->cast(), kBackendTransDataInputNum); - if (IsFormatInvaild(node)) { - TraceGuard guard(std::make_shared(node->debug_info())); - changed = DoSplit(func_graph, node); - } + if (node != nullptr && node->isa() && AnfAlgo::GetCNodeName(node) == kTransDataOpName) { + CheckCNodeInputSize(node->cast(), kBackendTransDataInputNum); + if (IsFormatInvaild(node)) { + TraceGuard guard(std::make_shared(node->debug_info())); + return DoSplit(func_graph, node); } } - return changed; + return nullptr; } -bool TransDataSplit::IsFormatInvaild(const AnfNodePtr &node) { + +bool TransDataSplit::IsFormatInvaild(const AnfNodePtr &node) const { MS_EXCEPTION_IF_NULL(node); auto cnode = node->cast(); MS_EXCEPTION_IF_NULL(cnode); @@ -52,8 +50,14 @@ bool TransDataSplit::IsFormatInvaild(const AnfNodePtr &node) { return invalid_formats_pair.find(format_pair) != invalid_formats_pair.end(); } + +const BaseRef TransDataSplit::DefinePattern() const { + VarPtr X = std::make_shared(); + return VectorRef({prim::KPrimTransData, X}); +} + // transdata cannot support frac_z to nchw need split transdata(frac_z-HWCN) and transpose(HWCN-NCHW) -bool TransDataSplit::DoSplit(const FuncGraphPtr &func_graph, const AnfNodePtr &node) { +CNodePtr TransDataSplit::DoSplit(const FuncGraphPtr &func_graph, const AnfNodePtr &node) const { MS_EXCEPTION_IF_NULL(func_graph); MS_EXCEPTION_IF_NULL(node); auto cnode = node->cast(); @@ -63,9 +67,9 @@ bool TransDataSplit::DoSplit(const FuncGraphPtr &func_graph, const AnfNodePtr &n auto input_format = AnfAlgo::GetInputFormat(node, 0); auto output_format = AnfAlgo::GetOutputFormat(node, 0); - AnfNodePtr new_transdata_node = nullptr; - AnfNodePtr new_transpose_node = nullptr; - AnfNodePtr new_replace_node = nullptr; + CNodePtr new_transdata_node = nullptr; + CNodePtr new_transpose_node = nullptr; + CNodePtr new_replace_node = nullptr; auto padding_axis = AnfAlgo::GetOutputReshapeType(node, 0); // if output_format=default transdata need split transdata->transpose else transpose->transdata if (output_format == kOpFormat_DEFAULT || output_format == kOpFormat_NCHW) { @@ -96,16 +100,8 @@ bool TransDataSplit::DoSplit(const FuncGraphPtr &func_graph, const AnfNodePtr &n new_transdata_node->set_abstract(node->abstract()); new_replace_node = new_transdata_node; } - FuncGraphManagerPtr manager = func_graph->manager(); - MS_EXCEPTION_IF_NULL(manager); - manager->AddFuncGraph(func_graph); - - if (!manager->Replace(node, new_replace_node)) { - MS_LOG(EXCEPTION) << "Manager replace node failed" - << " trace: " << trace::DumpSourceLines(node); - } MS_LOG(INFO) << "Transdata node:" << cnode->DebugString() << "split success."; - return true; + return new_replace_node; } } // namespace opt } // namespace mindspore diff --git a/mindspore/ccsrc/backend/optimizer/ascend/ir_fission/transdata_split.h b/mindspore/ccsrc/backend/optimizer/ascend/ir_fission/transdata_split.h index cde03f68f4..765c1da4c7 100644 --- a/mindspore/ccsrc/backend/optimizer/ascend/ir_fission/transdata_split.h +++ b/mindspore/ccsrc/backend/optimizer/ascend/ir_fission/transdata_split.h @@ -29,15 +29,17 @@ namespace mindspore { namespace opt { -class TransDataSplit : public Pass { +class TransDataSplit : public PatternProcessPass { public: - TransDataSplit() : Pass("trans_data_split"), kernel_select_(std::make_shared()) {} + explicit TransDataSplit(bool multigraph = true, const string &name = "trans_data_split") + : PatternProcessPass(name, multigraph), kernel_select_(std::make_shared()) {} ~TransDataSplit() override = default; - bool Run(const FuncGraphPtr &graph) override; + const BaseRef DefinePattern() const override; + const AnfNodePtr Process(const FuncGraphPtr &func_graph, const AnfNodePtr &node, const EquivPtr &) const override; - private: - bool DoSplit(const FuncGraphPtr &func_graph, const AnfNodePtr &node); - bool IsFormatInvaild(const AnfNodePtr &node); + protected: + CNodePtr DoSplit(const FuncGraphPtr &func_graph, const AnfNodePtr &node) const; + bool IsFormatInvaild(const AnfNodePtr &node) const; KernelSelectPtr kernel_select_; }; } // namespace opt diff --git a/mindspore/ccsrc/backend/optimizer/common/helper.cc b/mindspore/ccsrc/backend/optimizer/common/helper.cc index 35ac2489e9..d2ea3742c3 100644 --- a/mindspore/ccsrc/backend/optimizer/common/helper.cc +++ b/mindspore/ccsrc/backend/optimizer/common/helper.cc @@ -481,13 +481,13 @@ bool IsNotRealUsedByOthers(const FuncGraphPtr &graph, const AnfNodePtr &node) { return true; } -AnfNodePtr CreatTupleGetItemNode(const FuncGraphPtr &func_graph, const AnfNodePtr &node, size_t output_idx) { +CNodePtr CreatTupleGetItemNode(const FuncGraphPtr &func_graph, const AnfNodePtr &node, size_t output_idx) { auto idx = NewValueNode(SizeToLong(output_idx)); MS_EXCEPTION_IF_NULL(idx); auto imm = std::make_shared(SizeToLong(output_idx)); auto abstract_scalar = std::make_shared(imm); idx->set_abstract(abstract_scalar); - AnfNodePtr tuple_getitem = func_graph->NewCNode({NewValueNode(prim::kPrimTupleGetItem), node, idx}); + CNodePtr tuple_getitem = func_graph->NewCNode({NewValueNode(prim::kPrimTupleGetItem), node, idx}); MS_EXCEPTION_IF_NULL(tuple_getitem); tuple_getitem->set_scope(node->scope()); std::vector origin_shape = AnfAlgo::GetOutputInferShape(node, output_idx); diff --git a/mindspore/ccsrc/backend/optimizer/common/helper.h b/mindspore/ccsrc/backend/optimizer/common/helper.h index 03b90a8850..55162164ef 100644 --- a/mindspore/ccsrc/backend/optimizer/common/helper.h +++ b/mindspore/ccsrc/backend/optimizer/common/helper.h @@ -169,7 +169,7 @@ void HideNopNode(session::KernelGraph *const graph); void RemoveNopNode(session::KernelGraph *const graph); -AnfNodePtr CreatTupleGetItemNode(const FuncGraphPtr &func_graph, const AnfNodePtr &node, size_t output_idx); +CNodePtr CreatTupleGetItemNode(const FuncGraphPtr &func_graph, const AnfNodePtr &node, size_t output_idx); bool IsUsedByOthers(const FuncGraphPtr &graph, const AnfNodePtr &node);