From 9c8d016d6694aa5e0674176d179f638d415ac117 Mon Sep 17 00:00:00 2001 From: LianLiguang Date: Wed, 31 Mar 2021 15:54:30 +0800 Subject: [PATCH] fix mismatch between transdata's src/dst formats and its build info after the transdata has been split --- .../ascend/ascend_backend_optimization.cc | 4 +- ...al_ref_and_split_unsupported_transdata.cc} | 55 +++++++++++-------- ...eal_ref_and_split_unsupported_transdata.h} | 13 +++-- 3 files changed, 41 insertions(+), 31 deletions(-) rename mindspore/ccsrc/backend/optimizer/ascend/format_type/{deal_ref_trans_and_cast.cc => deal_ref_and_split_unsupported_transdata.cc} (77%) rename mindspore/ccsrc/backend/optimizer/ascend/format_type/{deal_ref_trans_and_cast.h => deal_ref_and_split_unsupported_transdata.h} (87%) diff --git a/mindspore/ccsrc/backend/optimizer/ascend/ascend_backend_optimization.cc b/mindspore/ccsrc/backend/optimizer/ascend/ascend_backend_optimization.cc index 4dbdcec55b..39f2a81fc4 100644 --- a/mindspore/ccsrc/backend/optimizer/ascend/ascend_backend_optimization.cc +++ b/mindspore/ccsrc/backend/optimizer/ascend/ascend_backend_optimization.cc @@ -99,7 +99,7 @@ #include "backend/optimizer/ascend/buffer_fusion/conv_bnreduce_fusion_pass.h" #include "backend/optimizer/ascend/buffer_fusion/reduce_eltwise_fusion_pass.h" #include "backend/optimizer/ascend/buffer_fusion/segment_eltwise_fusion_pass.h" -#include "backend/optimizer/ascend/format_type/deal_ref_trans_and_cast.h" +#include "backend/optimizer/ascend/format_type/deal_ref_and_split_unsupported_transdata.h" #include "backend/optimizer/ascend/enhancer/insert_memcpy_async_for_hccl_op.h" #include "backend/optimizer/ascend/enhancer/insert_memcpy_async_for_cascade.h" #include "backend/optimizer/ascend/enhancer/insert_pad_for_nms_with_mask.h" @@ -254,7 +254,7 @@ void AscendMixPrecision(const std::shared_ptr &kernel_grap mixed_precision_pm->AddPass(std::make_shared()); mixed_precision_pm->AddPass(std::make_shared()); mixed_precision_pm->AddPass(std::make_shared()); - 
mixed_precision_pm->AddPass(std::make_shared()); + mixed_precision_pm->AddPass(std::make_shared()); mixed_precision_pm->AddPass(std::make_shared()); mixed_precision_pm->AddPass(std::make_shared()); mixed_precision_pm->AddPass(std::make_shared()); diff --git a/mindspore/ccsrc/backend/optimizer/ascend/format_type/deal_ref_trans_and_cast.cc b/mindspore/ccsrc/backend/optimizer/ascend/format_type/deal_ref_and_split_unsupported_transdata.cc similarity index 77% rename from mindspore/ccsrc/backend/optimizer/ascend/format_type/deal_ref_trans_and_cast.cc rename to mindspore/ccsrc/backend/optimizer/ascend/format_type/deal_ref_and_split_unsupported_transdata.cc index ef44c47d51..d00c002f1f 100644 --- a/mindspore/ccsrc/backend/optimizer/ascend/format_type/deal_ref_trans_and_cast.cc +++ b/mindspore/ccsrc/backend/optimizer/ascend/format_type/deal_ref_and_split_unsupported_transdata.cc @@ -14,7 +14,7 @@ * limitations under the License. */ -#include "backend/optimizer/ascend/format_type/deal_ref_trans_and_cast.h" +#include "backend/optimizer/ascend/format_type/deal_ref_and_split_unsupported_transdata.h" #include #include #include @@ -26,7 +26,7 @@ namespace mindspore { namespace opt { -session::KernelWithIndex DealRefTransAndCast::FindRefOriginNode(const AnfNodePtr &node) const { +session::KernelWithIndex DealRefAndSpiltUnSupportedTransdata::FindRefOriginNode(const AnfNodePtr &node) const { session::KernelWithIndex kernel_with_index = AnfAlgo::VisitKernel(node, 0); AnfNodePtr cur_node = kernel_with_index.first; size_t cur_out_index = kernel_with_index.second; @@ -61,8 +61,9 @@ session::KernelWithIndex DealRefTransAndCast::FindRefOriginNode(const AnfNodePtr return kernel_with_index; } -void DealRefTransAndCast::AddRefNodePairToKernelGraph(const FuncGraphPtr &func_graph, const CNodePtr &cnode, - const size_t output_index, const size_t input_index) const { +void DealRefAndSpiltUnSupportedTransdata::AddRefNodePairToKernelGraph(const FuncGraphPtr &func_graph, + const CNodePtr &cnode, 
const size_t output_index, + const size_t input_index) const { // record the ref_pair auto kernel_graph = func_graph->cast(); MS_EXCEPTION_IF_NULL(kernel_graph); @@ -71,10 +72,10 @@ void DealRefTransAndCast::AddRefNodePairToKernelGraph(const FuncGraphPtr &func_g kernel_graph->AddRefCorrespondPairs(final_pair, kernel_with_index); } -void DealRefTransAndCast::AddRefPairToKernelGraph(const FuncGraphPtr &func_graph, const CNodePtr &cnode, - const AnfNodePtr &get_item, const AnfNodePtr &final_node, - size_t final_index, - const session::KernelWithIndex &origin_pair) const { +void DealRefAndSpiltUnSupportedTransdata::AddRefPairToKernelGraph(const FuncGraphPtr &func_graph, const CNodePtr &cnode, + const AnfNodePtr &get_item, + const AnfNodePtr &final_node, size_t final_index, + const session::KernelWithIndex &origin_pair) const { // record the ref_pair auto kernel_graph = func_graph->cast(); MS_EXCEPTION_IF_NULL(kernel_graph); @@ -95,9 +96,10 @@ void DealRefTransAndCast::AddRefPairToKernelGraph(const FuncGraphPtr &func_graph // if get_item is nullptr, the additional node will link to the cnode // else the additional node will link to the get_item node (the get_item node link to cnode) -CNodePtr DealRefTransAndCast::AddAdditionalToRefOutput(const FuncGraphPtr &func_graph, const CNodePtr &cnode, - size_t output_index, size_t input_index, - const CNodePtr &get_item) const { +CNodePtr DealRefAndSpiltUnSupportedTransdata::AddAdditionalToRefOutput(const FuncGraphPtr &func_graph, + const CNodePtr &cnode, size_t output_index, + size_t input_index, + const CNodePtr &get_item) const { CNodePtr final_node = (get_item == nullptr ? 
cnode : get_item); bool need_refresh_ref_addr = false; size_t final_index = output_index; @@ -149,8 +151,9 @@ CNodePtr DealRefTransAndCast::AddAdditionalToRefOutput(const FuncGraphPtr &func_ return final_node; } -CNodePtr DealRefTransAndCast::MakeDependency(const CNodePtr &get_item, const CNodePtr &final_node, - const CNodePtr &cnode, const FuncGraphPtr &func_graph) const { +CNodePtr DealRefAndSpiltUnSupportedTransdata::MakeDependency(const CNodePtr &get_item, const CNodePtr &final_node, + const CNodePtr &cnode, + const FuncGraphPtr &func_graph) const { std::vector depend_nodes; if (get_item != nullptr) { depend_nodes = std::vector{NewValueNode(prim::kPrimDepend), get_item, final_node}; @@ -159,8 +162,8 @@ CNodePtr DealRefTransAndCast::MakeDependency(const CNodePtr &get_item, const CNo } return func_graph->NewCNode(depend_nodes); } -CNodePtr DealRefTransAndCast::DealRefForMultipleOutput(const FuncGraphPtr &func_graph, const CNodePtr &cnode, - const std::shared_ptr &op_info) const { +CNodePtr DealRefAndSpiltUnSupportedTransdata::DealRefForMultipleOutput( + const FuncGraphPtr &func_graph, const CNodePtr &cnode, const std::shared_ptr &op_info) const { MS_EXCEPTION_IF_NULL(op_info); auto ref_infos = op_info->ref_infos(); std::vector make_tuple_inputs; @@ -185,8 +188,8 @@ CNodePtr DealRefTransAndCast::DealRefForMultipleOutput(const FuncGraphPtr &func_ return make_tuple; } -CNodePtr DealRefTransAndCast::DealRefSigleOutput(const FuncGraphPtr &func_graph, const CNodePtr &cnode, - const std::shared_ptr &op_info) const { +CNodePtr DealRefAndSpiltUnSupportedTransdata::DealRefSigleOutput(const FuncGraphPtr &func_graph, const CNodePtr &cnode, + const std::shared_ptr &op_info) const { MS_EXCEPTION_IF_NULL(cnode); MS_EXCEPTION_IF_NULL(op_info); auto ref_infos = op_info->ref_infos(); @@ -200,13 +203,14 @@ CNodePtr DealRefTransAndCast::DealRefSigleOutput(const FuncGraphPtr &func_graph, return nullptr; } -const BaseRef DealRefTransAndCast::DefinePattern() const { +const BaseRef 
DealRefAndSpiltUnSupportedTransdata::DefinePattern() const { VarPtr V = std::make_shared(UnVisited); VarPtr Xs = std::make_shared(); return VectorRef({V, Xs}); } -void DealRefTransAndCast::DealBroadCastAsRef(const FuncGraphPtr &func_graph, const CNodePtr &cnode) const { +void DealRefAndSpiltUnSupportedTransdata::DealBroadCastAsRef(const FuncGraphPtr &func_graph, + const CNodePtr &cnode) const { if (AnfAlgo::GetCNodeName(cnode) == kBroadcastOpName) { auto input_size = AnfAlgo::GetInputTensorNum(cnode); for (size_t i = 0; i < input_size; ++i) { @@ -219,8 +223,8 @@ void DealRefTransAndCast::DealBroadCastAsRef(const FuncGraphPtr &func_graph, con } } -const AnfNodePtr DealRefTransAndCast::Process(const FuncGraphPtr &graph, const AnfNodePtr &node, - const EquivPtr &) const { +const AnfNodePtr DealRefAndSpiltUnSupportedTransdata::Process(const FuncGraphPtr &graph, const AnfNodePtr &node, + const EquivPtr &) const { if (node == nullptr || !node->isa()) { return nullptr; } @@ -250,11 +254,12 @@ const AnfNodePtr DealRefTransAndCast::Process(const FuncGraphPtr &graph, const A return nullptr; } -CNodePtr DealRefTransAndCast::SplitTransdataIfNotSupported(const FuncGraphPtr &func_graph, - const CNodePtr &cnode) const { +CNodePtr DealRefAndSpiltUnSupportedTransdata::SplitTransdataIfNotSupported(const FuncGraphPtr &func_graph, + const CNodePtr &cnode) const { MS_EXCEPTION_IF_NULL(cnode); auto kernel_info = AnfAlgo::GetSelectKernelBuildInfo(cnode); MS_EXCEPTION_IF_NULL(kernel_info); + // When only one of the input and output formats is a special format, the node just needs to be split into transpose and transdata if (kHWSpecialFormatSet.find(kernel_info->GetInputFormat(0)) == kHWSpecialFormatSet.end() || kHWSpecialFormatSet.find(kernel_info->GetOutputFormat(0)) == kHWSpecialFormatSet.end()) { if (IsFormatInvaild(cnode)) { @@ -262,6 +267,8 @@ CNodePtr DealRefTransAndCast::SplitTransdataIfNotSupported(const FuncGraphPtr &f } return cnode; } + // When both the input and output formats are special formats, + 
// the node should be split into two transdata nodes connected by the default format auto builder_info_to_default = std::make_shared(kernel_info); auto builder_info_to_special_foramt = std::make_shared(kernel_info); builder_info_to_default->SetOutputsFormat({kOpFormat_DEFAULT}); @@ -273,6 +280,8 @@ CNodePtr DealRefTransAndCast::SplitTransdataIfNotSupported(const FuncGraphPtr &f next_trans_node->set_abstract(cnode->abstract()); AnfAlgo::SetSelectKernelBuildInfo(builder_info_to_default->Build(), cnode.get()); AnfAlgo::SetSelectKernelBuildInfo(builder_info_to_special_foramt->Build(), next_trans_node.get()); + RefreshKernelBuildInfo(AnfAlgo::GetInputFormat(cnode, 0), kOpFormat_DEFAULT, cnode); + RefreshKernelBuildInfo(kOpFormat_DEFAULT, AnfAlgo::GetOutputFormat(next_trans_node, 0), next_trans_node); if (IsFormatInvaild(cnode)) { auto after_split_node = DoSplit(func_graph, cnode); AnfAlgo::SetNodeInput(next_trans_node, after_split_node, 0); diff --git a/mindspore/ccsrc/backend/optimizer/ascend/format_type/deal_ref_trans_and_cast.h b/mindspore/ccsrc/backend/optimizer/ascend/format_type/deal_ref_and_split_unsupported_transdata.h similarity index 87% rename from mindspore/ccsrc/backend/optimizer/ascend/format_type/deal_ref_trans_and_cast.h rename to mindspore/ccsrc/backend/optimizer/ascend/format_type/deal_ref_and_split_unsupported_transdata.h index 48c1674ade..67d6820e2b 100644 --- a/mindspore/ccsrc/backend/optimizer/ascend/format_type/deal_ref_trans_and_cast.h +++ b/mindspore/ccsrc/backend/optimizer/ascend/format_type/deal_ref_and_split_unsupported_transdata.h @@ -14,8 +14,8 @@ * limitations under the License. 
*/ -#ifndef MINDSPORE_CCSRC_BACKEND_OPTIMIZER_ASCEND_FORMAT_TYPE_DEAL_REF_TRANS_AND_CAST_H_ -#define MINDSPORE_CCSRC_BACKEND_OPTIMIZER_ASCEND_FORMAT_TYPE_DEAL_REF_TRANS_AND_CAST_H_ +#ifndef MINDSPORE_CCSRC_BACKEND_OPTIMIZER_ASCEND_FORMAT_TYPE_DEAL_REF_AND_SPLIT_UNSUPPORTED_TRANSADATA_H_ +#define MINDSPORE_CCSRC_BACKEND_OPTIMIZER_ASCEND_FORMAT_TYPE_DEAL_REF_AND_SPLIT_UNSUPPORTED_TRANSADATA_H_ #include #include "ir/anf.h" #include "backend/optimizer/common/optimizer.h" @@ -25,10 +25,11 @@ namespace mindspore { namespace opt { -class DealRefTransAndCast : public TransDataSplit { +class DealRefAndSpiltUnSupportedTransdata : public TransDataSplit { public: - explicit DealRefTransAndCast(bool multigraph = true) : TransDataSplit(multigraph, "deal_ref_trans_and_cast") {} - ~DealRefTransAndCast() override = default; + explicit DealRefAndSpiltUnSupportedTransdata(bool multigraph = true) + : TransDataSplit(multigraph, "deal_ref_and_transdata_spilt") {} + ~DealRefAndSpiltUnSupportedTransdata() override = default; const BaseRef DefinePattern() const override; const AnfNodePtr Process(const FuncGraphPtr &, const AnfNodePtr &, const EquivPtr &) const override; @@ -52,4 +53,4 @@ class DealRefTransAndCast : public TransDataSplit { }; } // namespace opt } // namespace mindspore -#endif // MINDSPORE_CCSRC_BACKEND_OPTIMIZER_ASCEND_FORMAT_TYPE_DEAL_REF_TRANS_AND_CAST_H_ +#endif // MINDSPORE_CCSRC_BACKEND_OPTIMIZER_ASCEND_FORMAT_TYPE_DEAL_REF_AND_SPLIT_UNSUPPORTED_TRANSADATA_H_