diff --git a/mindspore/ccsrc/backend/optimizer/ascend/ascend_backend_optimization.cc b/mindspore/ccsrc/backend/optimizer/ascend/ascend_backend_optimization.cc index 1fc957ba26..b7258a2ca9 100644 --- a/mindspore/ccsrc/backend/optimizer/ascend/ascend_backend_optimization.cc +++ b/mindspore/ccsrc/backend/optimizer/ascend/ascend_backend_optimization.cc @@ -220,8 +220,8 @@ void AscendDataLayout(const std::shared_ptr &kernel_graph) data_layout_pm->AddPass(std::make_shared()); data_layout_pm->AddPass(std::make_shared()); data_layout_pm->AddPass(std::make_shared()); + data_layout_pm->AddPass(std::make_shared()); } - data_layout_pm->AddPass(std::make_shared()); data_layout_pm->AddPass(std::make_shared()); data_layout_pm->AddPass(std::make_shared()); data_layout_pm->AddPass(std::make_shared()); diff --git a/mindspore/ccsrc/backend/optimizer/ascend/ascend_helper.cc b/mindspore/ccsrc/backend/optimizer/ascend/ascend_helper.cc index 89892baa10..3783fc7dd5 100644 --- a/mindspore/ccsrc/backend/optimizer/ascend/ascend_helper.cc +++ b/mindspore/ccsrc/backend/optimizer/ascend/ascend_helper.cc @@ -33,7 +33,6 @@ namespace opt { using KernelBuildInfoBuilder = kernel::KernelBuildInfo::KernelBuildInfoBuilder; namespace { const std::set k3DFormatSet = {kOpFormat_NCDHW, kOpFormat_NDC1HWC0, kOpFormat_FRACTAL_Z_3D}; -const std::set kCommonFormatSet = {kOpFormat_DEFAULT, kOpFormat_ND, kOpFormat_NCHW, kOpFormat_NCDHW}; AnfNodePtr CreateReshapeNode(const FuncGraphPtr &func_graph, const AnfNodePtr &input_node, const KernelSelectPtr &kernel_select, const std::vector &dst_shape) { std::vector trans_inputs; @@ -82,45 +81,18 @@ std::string InitDefaultFormat(const AnfNodePtr &node) { return default_format; } -AnfNodePtr AddTransOpNodeToGraph(const FuncGraphPtr &func_graph, const AnfNodePtr &node, - const KernelSelectPtr &kernel_select, size_t insert_index, bool is_insert_input) { - AnfNodePtr trans_node = nullptr; - CNodePtr trans_data = nullptr; - MS_EXCEPTION_IF_NULL(node); - // Init - std::string default_format = InitDefaultFormat(node); - AnfNodePtr input_node = is_insert_input ? AnfAlgo::GetInputNode(node->cast(), insert_index) : node; - std::string input_format = is_insert_input ? default_format : AnfAlgo::GetOutputFormat(node, insert_index); - std::string dst_format = is_insert_input ? AnfAlgo::GetInputFormat(node, insert_index) : default_format; - std::vector padding_axis = is_insert_input ? AnfAlgo::GetInputReshapeType(node, insert_index) - : AnfAlgo::GetOutputReshapeType(node, insert_index); - auto input_node_out_shape = is_insert_input ? AnfAlgo::GetPrevNodeOutputInferShape(node, insert_index) - : AnfAlgo::GetOutputInferShape(input_node, insert_index); - bool need_padding = is_insert_input ? trans::IsNeedPadding(dst_format, input_node_out_shape.size()) - : trans::IsNeedPadding(input_format, input_node_out_shape.size()); - if (!need_padding) { - // don't need padding insert transdata only - trans_data = NewTransOpNode(func_graph, input_node, kernel_select, need_padding, prim::KPrimTransData->name()); - trans_node = trans_data; - } else if (is_insert_input) { - // if need padding & is input need insert a transdata - // reshape[padding shape] -> transdata[padding shape] -> node - auto padding_shape = - trans::PaddingShapeTo4d(input_node_out_shape, AnfAlgo::GetInputReshapeType(node, insert_index)); - auto reshape_node = CreateReshapeNode(func_graph, input_node, kernel_select, padding_shape); - trans_data = NewTransOpNode(func_graph, reshape_node, kernel_select, need_padding, prim::KPrimTransData->name()); - trans_node = trans_data; - trans_data->set_abstract(input_node->abstract()); - } else { - // if need padding & is output need insert a transdata - // node -> transdata[padding shape] -> reshape[ori_shape] - trans_data = NewTransOpNode(func_graph, input_node, kernel_select, need_padding, prim::KPrimTransData->name()); - auto reshape_node = CreateReshapeNode(func_graph, trans_data, kernel_select, input_node_out_shape); - trans_node = reshape_node; +void ReFreshInferShape(const AnfNodePtr &trans_node, const AnfNodePtr &node) { + MS_EXCEPTION_IF_NULL(trans_node); + auto real_input_node = AnfAlgo::VisitKernelWithReturnType(node, 0).first; + if (!real_input_node->isa()) { + return; + } + auto op_name = AnfAlgo::GetCNodeName(real_input_node); + if (op_name == kBasicLSTMCellWeightGradOpName && AnfAlgo::GetCNodeName(trans_node) == prim::kPrimReshape->name()) { + auto shape = AnfAlgo::GetPrevNodeOutputInferShape(trans_node, 0); + auto type = AnfAlgo::GetPrevNodeOutputInferDataType(trans_node, 0); + AnfAlgo::SetOutputInferTypeAndShape({type}, {{shape[0], shape[1]}}, node.get()); } - // refresh the transdata's format to ori format & dst format - RefreshKernelBuildInfo(input_format, dst_format, trans_data, padding_axis); - return trans_node; } AnfNodePtr GetTransInputNodePtr(const FuncGraphPtr &func_graph, const CNodePtr &node, size_t index, @@ -161,15 +133,6 @@ AnfNodePtr InsertTransOpForSingleOutput(const FuncGraphPtr &func_graph, const An return node; } -void ReFreshInferShape(const AnfNodePtr &node, const std::string &op_name) { - MS_EXCEPTION_IF_NULL(node); - if (op_name == kBasicLSTMCellWeightGradOpName && AnfAlgo::GetCNodeName(node) == prim::kPrimReshape->name()) { - auto shape = AnfAlgo::GetPrevNodeOutputInferShape(node, 0); - auto type = AnfAlgo::GetPrevNodeOutputInferDataType(node, 0); - AnfAlgo::SetOutputInferTypeAndShape({type}, {{shape[0], shape[1]}}, node.get()); - } -} - AnfNodePtr InsertTransOpForMultipleOutput(const FuncGraphPtr &func_graph, const AnfNodePtr &node, const KernelSelectPtr &kernel_select) { MS_EXCEPTION_IF_NULL(func_graph); @@ -177,10 +140,6 @@ AnfNodePtr InsertTransOpForMultipleOutput(const FuncGraphPtr &func_graph, const std::vector make_tuple_inputs = {NewValueNode(prim::kPrimMakeTuple)}; auto kernel_graph = func_graph->cast(); size_t out_num = AnfAlgo::GetOutputTensorNum(node); - std::string op_name; - if (node->isa()) { - op_name = AnfAlgo::GetCNodeName(node); - } for (size_t output_idx = 0; output_idx < out_num; ++output_idx) { std::string output_format = AnfAlgo::GetOutputFormat(node, output_idx); if (output_format == kOpFormat_NC1KHKWHWC0) { @@ -191,7 +150,6 @@ AnfNodePtr InsertTransOpForMultipleOutput(const FuncGraphPtr &func_graph, const std::vector origin_shape = AnfAlgo::GetOutputInferShape(node, output_idx); if (origin_shape.size() > 1 && kCommonFormatSet.find(output_format) == kCommonFormatSet.end()) { auto trans_op = AddTransOpNodeToGraph(func_graph, tuple_getitem, kernel_select, 0, false); - ReFreshInferShape(trans_op, op_name); if (kernel_graph != nullptr && kernel_graph->IsInternalOutput(node, output_idx)) { kernel_graph->ReplaceInternalOutput(node, trans_op, output_idx, 0); } @@ -205,6 +163,50 @@ AnfNodePtr InsertTransOpForMultipleOutput(const FuncGraphPtr &func_graph, const return make_tuple; } } // namespace +AnfNodePtr AddTransOpNodeToGraph(const FuncGraphPtr &func_graph, const AnfNodePtr &node, + const KernelSelectPtr &kernel_select, size_t insert_index, bool is_insert_input) { + AnfNodePtr trans_node = nullptr; + CNodePtr trans_data = nullptr; + MS_EXCEPTION_IF_NULL(node); + // Init + std::string default_format = InitDefaultFormat(node); + AnfNodePtr input_node = is_insert_input ? AnfAlgo::GetInputNode(node->cast(), insert_index) : node; + std::string input_format = is_insert_input ? default_format : AnfAlgo::GetOutputFormat(node, insert_index); + std::string dst_format = is_insert_input ? AnfAlgo::GetInputFormat(node, insert_index) : default_format; + std::vector padding_axis = is_insert_input ? AnfAlgo::GetInputReshapeType(node, insert_index) + : AnfAlgo::GetOutputReshapeType(node, insert_index); + auto input_node_out_shape = is_insert_input ? AnfAlgo::GetPrevNodeOutputInferShape(node, insert_index) + : AnfAlgo::GetOutputInferShape(input_node, insert_index); + bool need_padding = is_insert_input ? trans::IsNeedPadding(dst_format, input_node_out_shape.size()) + : trans::IsNeedPadding(input_format, input_node_out_shape.size()); + if (!need_padding) { + // don't need padding insert transdata only + trans_data = NewTransOpNode(func_graph, input_node, kernel_select, need_padding, prim::KPrimTransData->name()); + trans_node = trans_data; + } else if (is_insert_input) { + // if need padding & is input need insert a transdata + // reshape[padding shape] -> transdata[padding shape] -> node + auto padding_shape = + trans::PaddingShapeTo4d(input_node_out_shape, AnfAlgo::GetInputReshapeType(node, insert_index)); + auto reshape_node = CreateReshapeNode(func_graph, input_node, kernel_select, padding_shape); + trans_data = NewTransOpNode(func_graph, reshape_node, kernel_select, need_padding, prim::KPrimTransData->name()); + trans_node = trans_data; + trans_data->set_abstract(input_node->abstract()); + } else { + // if need padding & is output need insert a transdata + // node -> transdata[padding shape] -> reshape[ori_shape] + trans_data = NewTransOpNode(func_graph, input_node, kernel_select, need_padding, prim::KPrimTransData->name()); + auto reshape_node = CreateReshapeNode(func_graph, trans_data, kernel_select, input_node_out_shape); + trans_node = reshape_node; + } + // refresh the transdata's format to ori format & dst format + RefreshKernelBuildInfo(input_format, dst_format, trans_data, padding_axis); + if (!is_insert_input) { + ReFreshInferShape(trans_node, node); + } + return trans_node; +} + void RefreshKernelBuildInfo(const std::string &input_format, const std::string &output_format, const AnfNodePtr &trans_data, const std::vector &reshape_type, const TypeId &type_id) { diff --git a/mindspore/ccsrc/backend/optimizer/ascend/ascend_helper.h b/mindspore/ccsrc/backend/optimizer/ascend/ascend_helper.h index e2a01cc87d..4496354cc9 100644 --- a/mindspore/ccsrc/backend/optimizer/ascend/ascend_helper.h +++ b/mindspore/ccsrc/backend/optimizer/ascend/ascend_helper.h @@ -18,6 +18,7 @@ #include #include +#include #include #include "runtime/device/ascend/kernel_select_ascend.h" #include "backend/kernel_compiler/kernel_query.h" @@ -106,6 +107,11 @@ AnfNodePtr InsertTransOpForOutput(const FuncGraphPtr &func_graph, const AnfNodeP CNodePtr InsertCastForInput(const FuncGraphPtr &func_graph, const CNodePtr &cnode); AnfNodePtr CreateMemcpyAsyncOp(const FuncGraphPtr &graph, const AnfNodePtr &node); + +AnfNodePtr AddTransOpNodeToGraph(const FuncGraphPtr &func_graph, const AnfNodePtr &node, + const KernelSelectPtr &kernel_select, size_t insert_index, bool is_insert_input); + +const std::set kCommonFormatSet = {kOpFormat_DEFAULT, kOpFormat_ND, kOpFormat_NCHW, kOpFormat_NCDHW}; } // namespace opt } // namespace mindspore #endif // MINDSPORE_CCSRC_BACKEND_OPTIMIZER_ASCEND_ASCEND_HELPER_H_ diff --git a/mindspore/ccsrc/backend/optimizer/ascend/format_type/insert_trans_op.cc b/mindspore/ccsrc/backend/optimizer/ascend/format_type/insert_trans_op.cc index 7ac56078d9..b6a9529229 100644 --- a/mindspore/ccsrc/backend/optimizer/ascend/format_type/insert_trans_op.cc +++ b/mindspore/ccsrc/backend/optimizer/ascend/format_type/insert_trans_op.cc @@ -52,14 +52,6 @@ const AnfNodePtr InsertTransOp::Process(const FuncGraphPtr &func_graph, const An if (kernel_graph != nullptr && kernel_graph->IsInternalOutput(node)) { kernel_graph->ReplaceInternalOutput(node, new_node); } - auto ms_context = MsContext::GetInstance(); - MS_EXCEPTION_IF_NULL(ms_context); - if (ms_context->get_param(MS_CTX_EXECUTION_MODE) == kPynativeMode && - !ms_context->get_param(MS_CTX_ENABLE_PYNATIVE_HOOK)) { - if (IsGraphOutput(node, func_graph)) { - return new_node; - } - } return InsertTransOpForOutput(func_graph, new_node, kernel_select_); } } // namespace opt diff --git a/mindspore/ccsrc/backend/optimizer/ascend/format_type/insert_transdata_for_runop.cc b/mindspore/ccsrc/backend/optimizer/ascend/format_type/insert_transdata_for_runop.cc index a8289d8fef..2a77999d45 100644 --- a/mindspore/ccsrc/backend/optimizer/ascend/format_type/insert_transdata_for_runop.cc +++ b/mindspore/ccsrc/backend/optimizer/ascend/format_type/insert_transdata_for_runop.cc @@ -22,22 +22,46 @@ namespace mindspore { namespace opt { -const BaseRef RunOpInsertTransData::DefinePattern() const { - std::shared_ptr V = std::make_shared(UnVisited); - MS_EXCEPTION_IF_NULL(V); - std::shared_ptr Xs = std::make_shared(); - MS_EXCEPTION_IF_NULL(Xs); - return VectorRef({V, Xs}); -} - -const AnfNodePtr RunOpInsertTransData::Process(const FuncGraphPtr &func_graph, const AnfNodePtr &node, - const EquivPtr &) const { - if (node == nullptr || !AnfAlgo::IsRealKernel(node)) { - return nullptr; +bool RunOpInsertTransData::Run(const FuncGraphPtr &graph) { + MS_EXCEPTION_IF_NULL(graph); + bool changed = false; + std::vector node_list = TopoSort(graph->get_return()); + for (auto &node : node_list) { + bool has_changed = false; + MS_EXCEPTION_IF_NULL(node); + if (!node->cast() || !AnfAlgo::IsRealKernel(node)) { + continue; + } + auto cnode = node->cast(); + for (size_t index = 0; index < AnfAlgo::GetInputTensorNum(cnode); ++index) { + auto prev_input_format = AnfAlgo::GetPrevNodeOutputFormat(cnode, index); + auto prev_node_out_infer_shape = AnfAlgo::GetPrevNodeOutputInferShape(cnode, index); + auto input_format = AnfAlgo::GetInputFormat(cnode, index); + auto input_node = AnfAlgo::GetInputNode(cnode, index); + // convert the format of node's input node to default + if (kCommonFormatSet.find(prev_input_format) == kCommonFormatSet.end() && prev_node_out_infer_shape.size() > 1) { + auto trans_node = AddTransOpNodeToGraph(graph, input_node, kernel_select_, 0, false); + AnfAlgo::SetNodeInput(cnode, trans_node, index); + has_changed = true; + } + // convert node's output format + if (kCommonFormatSet.find(input_format) == kCommonFormatSet.end() && prev_node_out_infer_shape.size() > 1) { + auto trans_node = AddTransOpNodeToGraph(graph, cnode, kernel_select_, index, true); + AnfAlgo::SetNodeInput(cnode, trans_node, index); + has_changed = true; + } + } + if (has_changed) { + auto kernel_graph = graph->cast(); + MS_EXCEPTION_IF_NULL(kernel_graph); + auto new_node = kernel_graph->NewCNode(cnode); + auto manager = kernel_graph->manager(); + MS_EXCEPTION_IF_NULL(manager); + manager->Replace(cnode, new_node); + changed = true; + } } - AnfAlgo::SetNodeAttr(kAttrVisited, MakeValue(true), node); - MS_LOG(DEBUG) << "====process op: " << node->DebugString(); - return InsertTransOpForInput(func_graph, node, kernel_select_); + return changed; } } // namespace opt } // namespace mindspore diff --git a/mindspore/ccsrc/backend/optimizer/ascend/format_type/insert_transdata_for_runop.h b/mindspore/ccsrc/backend/optimizer/ascend/format_type/insert_transdata_for_runop.h index 67b8f823c9..6fb42094cd 100644 --- a/mindspore/ccsrc/backend/optimizer/ascend/format_type/insert_transdata_for_runop.h +++ b/mindspore/ccsrc/backend/optimizer/ascend/format_type/insert_transdata_for_runop.h @@ -16,29 +16,29 @@ #ifndef MINDSPORE_CCSRC_BACKEND_OPTIMIZER_ASCEND_FORMAT_TYPE_INSERT_TRANSDATA_FOR_RUNOP_H_ #define MINDSPORE_CCSRC_BACKEND_OPTIMIZER_ASCEND_FORMAT_TYPE_INSERT_TRANSDATA_FOR_RUNOP_H_ - +#include #include #include #include -#include "backend/optimizer/common/optimizer.h" + +#include "backend/optimizer/common/pass.h" +#include "ir/func_graph.h" +#include "ir/anf.h" #include "backend/optimizer/common/helper.h" +#include "backend/optimizer/common/optimizer.h" #include "backend/optimizer/ascend/ascend_helper.h" namespace mindspore { namespace opt { -class RunOpInsertTransData : public PatternProcessPass { +class RunOpInsertTransData : public Pass { public: - explicit RunOpInsertTransData(bool multigraph = true) - : PatternProcessPass("insert_transdata_for_runop", multigraph), - kernel_select_(std::make_shared()) {} + RunOpInsertTransData() : Pass("insert_transdata_for_runop"), kernel_select_(std::make_shared()) {} ~RunOpInsertTransData() override = default; - const BaseRef DefinePattern() const override; - const AnfNodePtr Process(const FuncGraphPtr &, const AnfNodePtr &, const EquivPtr &) const override; + bool Run(const FuncGraphPtr &graph) override; private: KernelSelectPtr kernel_select_; }; } // namespace opt } // namespace mindspore - #endif // MINDSPORE_CCSRC_BACKEND_OPTIMIZER_ASCEND_FORMAT_TYPE_INSERT_TRANSDATA_FOR_RUNOP_H_ diff --git a/mindspore/ccsrc/backend/session/session_basic.cc b/mindspore/ccsrc/backend/session/session_basic.cc index d8d87fe615..ae9e057c49 100644 --- a/mindspore/ccsrc/backend/session/session_basic.cc +++ b/mindspore/ccsrc/backend/session/session_basic.cc @@ -1601,6 +1601,11 @@ std::shared_ptr SessionBasic::ConstructSingleOpGraph(const OpRunInf // set output CreateOutputNode(cnode, graph); graph->SetInputNodes(); + auto manager = MakeManager({graph}); + if (manager != nullptr) { + manager->AddFuncGraph(graph); + graph->set_manager(manager); + } UnifyMindIR(graph); return graph; }