From 842c1bfc0f8148b0de4ba5007fda9fb8e37e12d9 Mon Sep 17 00:00:00 2001 From: yangruoqi713 Date: Thu, 24 Dec 2020 19:47:05 +0800 Subject: [PATCH] [MSLITE][DEVELOP] modify npu transpose pass --- .../agent/npu/optimizer/npu_fusion_pass.cc | 59 ++++++++-------- .../agent/npu/optimizer/npu_fusion_pass.h | 5 +- .../optimizer/npu_insert_transform_pass.cc | 69 ++++++++++--------- .../npu/optimizer/npu_insert_transform_pass.h | 8 +-- .../agent/npu/optimizer/npu_pass_utils.cc | 50 +++++++------- .../agent/npu/optimizer/npu_pass_utils.h | 10 +-- .../agent/npu/optimizer/npu_transform_pass.cc | 59 ++++++++-------- .../agent/npu/optimizer/npu_transform_pass.h | 4 +- 8 files changed, 140 insertions(+), 124 deletions(-) diff --git a/mindspore/lite/src/runtime/agent/npu/optimizer/npu_fusion_pass.cc b/mindspore/lite/src/runtime/agent/npu/optimizer/npu_fusion_pass.cc index d6dfc3c2ea..beab4db676 100644 --- a/mindspore/lite/src/runtime/agent/npu/optimizer/npu_fusion_pass.cc +++ b/mindspore/lite/src/runtime/agent/npu/optimizer/npu_fusion_pass.cc @@ -21,20 +21,22 @@ namespace mindspore::lite { bool CheckFusion(kernel::LiteKernel *kernel) { auto pre_flag = - std::all_of(kernel->in_kernels().begin(), kernel->in_kernels().end(), [](const kernel::LiteKernel *kernel) { - return kernel->Type() == schema::PrimitiveType_Nchw2Nhwc && kernel->out_kernels().size() == 1; + std::all_of(kernel->in_kernels().begin(), kernel->in_kernels().end(), [](const kernel::LiteKernel *in_kernel) { + return in_kernel->Type() == schema::PrimitiveType_Nchw2Nhwc && in_kernel->out_kernels().size() == 1; }); if (!pre_flag) { return false; } - auto post_flag = - std::all_of(kernel->out_kernels().begin(), kernel->out_kernels().end(), [](const kernel::LiteKernel *kernel) { - return kernel->Type() == schema::PrimitiveType_Nhwc2Nchw && kernel->in_kernels().size() == 1; - }); + auto post_flag = std::all_of( + kernel->out_kernels().begin(), kernel->out_kernels().end(), + [](const kernel::LiteKernel *out_kernel) { return out_kernel->Type() == schema::PrimitiveType_Nhwc2Nchw; }); return post_flag; } bool CheckFormatFusion(kernel::LiteKernel *kernel) { + if (kernel->out_kernels().empty()) { + return false; + } if (kernel->Type() == schema::PrimitiveType_Nhwc2Nchw) { return std::all_of( kernel->out_kernels().begin(), kernel->out_kernels().end(), @@ -159,38 +161,26 @@ int TransFormAxis(int axis) { } } -int NPUFusionPass::AddFusion(kernel::LiteKernel *kernel) { - if (!CheckFusion(kernel)) { - return RET_OK; - } +void NPUFusionPass::UpdateKernel(kernel::LiteKernel *kernel) { UpdatePreTensors(kernel); UpdatePostTensors(kernel); UpdatePreKernels(kernel); UpdatePostKernels(kernel); +} + +int NPUFusionPass::CommonFusion(kernel::LiteKernel *kernel) { + UpdateKernel(kernel); return RET_OK; } int NPUFusionPass::ConcatFusion(kernel::LiteKernel *kernel) { - if (!CheckFusion(kernel)) { - return RET_OK; - } - UpdatePreTensors(kernel); - UpdatePostTensors(kernel); - UpdatePreKernels(kernel); - UpdatePostKernels(kernel); + UpdateKernel(kernel); auto concat_param = reinterpret_cast(kernel->op_parameter()); concat_param->axis_ = TransFormAxis(concat_param->axis_); return RET_OK; } int NPUFusionPass::FormatFusion(kernel::LiteKernel *kernel) { - if (kernel->out_kernels().empty()) { - return RET_OK; - } - if (!CheckFormatFusion(kernel)) { - return RET_OK; - } - auto pre_kernel = kernel->in_kernels()[0]; auto in_tensor = kernel->in_tensors()[0]; auto out_tensor = kernel->out_tensors()[0]; @@ -237,17 +227,28 @@ int NPUFusionPass::FormatFusion(kernel::LiteKernel *kernel) { } int NPUFusionPass::Run() { - for (auto kernel : *kernels) { + for (size_t i = 0; i < kernels->size(); i++) { + auto kernel = (*kernels)[i]; + if (kernel->Type() == schema::PrimitiveType_Nchw2Nhwc || kernel->Type() == schema::PrimitiveType_Nchw2Nhwc) { + if (CheckFormatFusion(kernel)) { + i--; + FormatFusion(kernel); + } + continue; + } + if (!CheckFusion(kernel)) { + continue; + } switch (kernel->Type()) { case schema::PrimitiveType_Concat: + i -= kernel->in_kernels().size(); ConcatFusion(kernel); continue; case schema::PrimitiveType_Add: case schema::PrimitiveType_Activation: - AddFusion(kernel); - continue; - case schema::PrimitiveType_Nchw2Nhwc: - FormatFusion(kernel); + case schema::PrimitiveType_Eltwise: + i -= kernel->in_kernels().size(); + CommonFusion(kernel); continue; default: continue; diff --git a/mindspore/lite/src/runtime/agent/npu/optimizer/npu_fusion_pass.h b/mindspore/lite/src/runtime/agent/npu/optimizer/npu_fusion_pass.h index ede9818749..f895b66dac 100644 --- a/mindspore/lite/src/runtime/agent/npu/optimizer/npu_fusion_pass.h +++ b/mindspore/lite/src/runtime/agent/npu/optimizer/npu_fusion_pass.h @@ -33,11 +33,12 @@ class NPUFusionPass : public NPUBasePass { int Run() override; protected: - void RemoveAndFreeKernel(kernel::LiteKernel *cur_kernel); void UpdatePreKernels(kernel::LiteKernel *kernel); void UpdatePostKernels(kernel::LiteKernel *kernel); + void RemoveAndFreeKernel(kernel::LiteKernel *cur_kernel); + void UpdateKernel(kernel::LiteKernel *kernel); + int CommonFusion(kernel::LiteKernel *kernel); int ConcatFusion(kernel::LiteKernel *kernel); - int AddFusion(kernel::LiteKernel *kernel); int FormatFusion(kernel::LiteKernel *kernel); private: diff --git a/mindspore/lite/src/runtime/agent/npu/optimizer/npu_insert_transform_pass.cc b/mindspore/lite/src/runtime/agent/npu/optimizer/npu_insert_transform_pass.cc index 823a0a0e07..e321787f4a 100644 --- a/mindspore/lite/src/runtime/agent/npu/optimizer/npu_insert_transform_pass.cc +++ b/mindspore/lite/src/runtime/agent/npu/optimizer/npu_insert_transform_pass.cc @@ -21,7 +21,9 @@ namespace mindspore::lite { using kernel::KERNEL_ARCH::kNPU; enum InsertState { InsertNone, PreInsert, PostInsert }; -std::set npu_insert_nodes = {schema::PrimitiveType_Concat, schema::PrimitiveType_Add}; +std::set npu_insert_nodes = {schema::PrimitiveType_Concat, schema::PrimitiveType_Add, + schema::PrimitiveType_Eltwise, + schema::PrimitiveType_Activation}; int GetInsertState(kernel::LiteKernel *kernel) { if (npu_insert_nodes.find(kernel->Type()) == npu_insert_nodes.end()) { @@ -42,16 +44,18 @@ int GetInsertState(kernel::LiteKernel *kernel) { return InsertNone; } -int NPUInsertTransformPass::InsertPreNode(const InnerContext *context, kernel::LiteKernel *cur_kernel, - std::vector *all_kernels, +int NPUInsertTransformPass::InsertPreNode(const InnerContext *context, kernel::LiteKernel *kernel, + std::vector *trans_kernels, std::vector *all_tensors) { - for (auto kernel : cur_kernel->in_kernels()) { - if (kernel->Type() == schema::PrimitiveType_Nchw2Nhwc) { + for (auto in_kernel : kernel->in_kernels()) { + if (in_kernel->Type() == schema::PrimitiveType_Nchw2Nhwc) { continue; } - auto nhwc_shape = cur_kernel->out_tensors()[0]->shape(); + auto nhwc_shape = in_kernel->out_tensors()[0]->shape(); std::vector nchw_shape = {nhwc_shape[0], nhwc_shape[3], nhwc_shape[1], nhwc_shape[2]}; - auto nh2nc_tensor = new Tensor(kernel->out_tensors()[0]->data_type(), nchw_shape, schema::Format_NHWC, Tensor::VAR); + + auto nh2nc_tensor = + new Tensor(in_kernel->out_tensors()[0]->data_type(), nchw_shape, schema::Format_NHWC, Tensor::VAR); std::vector nh2nc_tensors = {nh2nc_tensor}; all_tensors->push_back(nh2nc_tensors[0]); @@ -59,34 +63,36 @@ int NPUInsertTransformPass::InsertPreNode(const InnerContext *context, kernel::L std::vector nc2nh_tensors = {nc2nh_tensor}; all_tensors->push_back(nc2nh_tensors[0]); - auto nh2nc_name = kernel->name() + "_nh2nc_" + std::to_string(total++); - auto *nh2nc_kernel = NPUPassUtils::CreateNhwc2NchwKernel(kernel->out_tensors(), nh2nc_tensors, context, nh2nc_name); - all_kernels->push_back(nh2nc_kernel); + auto nh2nc_name = in_kernel->name() + "_nh2nc_" + std::to_string(total++); + auto *nh2nc_kernel = + NPUPassUtils::CreateNhwc2NchwKernel(in_kernel->out_tensors(), nh2nc_tensors, context, nh2nc_name); + trans_kernels->push_back(nh2nc_kernel); insert_primitive_.push_back(nh2nc_kernel->GetPrimitive()); - auto nc2nh_name = kernel->name() + "_nc2nh_" + std::to_string(total++); + + auto nc2nh_name = in_kernel->name() + "_nc2nh_" + std::to_string(total++); auto *nc2nh_kernel = NPUPassUtils::CreateNchw2NhwcKernel(nh2nc_tensors, nc2nh_tensors, context, nc2nh_name); - all_kernels->push_back(nc2nh_kernel); + trans_kernels->push_back(nc2nh_kernel); insert_primitive_.push_back(nc2nh_kernel->GetPrimitive()); - NPUPassUtils::UpdateKernel(nh2nc_kernel, {kernel}, {nc2nh_kernel}, kernel->out_tensors(), nh2nc_tensors); - NPUPassUtils::UpdateKernel(nc2nh_kernel, {nh2nc_kernel}, {cur_kernel}, nh2nc_tensors, nc2nh_tensors); - NPUPassUtils::UpdateNH2NCTransNodePreKernel(kernel, nh2nc_kernel, cur_kernel); - NPUPassUtils::UpdateNC2NHTransNodeAfterKernel(kernel, nc2nh_kernel, cur_kernel); + + NPUPassUtils::UpdateKernel(nh2nc_kernel, {in_kernel}, {nc2nh_kernel}, in_kernel->out_tensors(), nh2nc_tensors); + NPUPassUtils::UpdateKernel(nc2nh_kernel, {nh2nc_kernel}, {kernel}, nh2nc_tensors, nc2nh_tensors); + NPUPassUtils::UpdateNH2NCTransNodePreKernel(in_kernel, nh2nc_kernel, kernel); + NPUPassUtils::UpdateNC2NHTransNodeAfterKernel(in_kernel, nc2nh_kernel, kernel); } return RET_OK; } -int NPUInsertTransformPass::InsertPostNode(const InnerContext *context, kernel::LiteKernel *cur_kernel, - std::vector *all_kernels, +int NPUInsertTransformPass::InsertPostNode(const InnerContext *context, kernel::LiteKernel *kernel, + std::vector *trans_kernels, std::vector *all_tensors) { - for (auto out_kernel : cur_kernel->out_kernels()) { + for (auto out_kernel : kernel->out_kernels()) { if (out_kernel->Type() == schema::PrimitiveType_Nhwc2Nchw) { continue; } - auto nhwc_shape = cur_kernel->out_tensors()[0]->shape(); + auto nhwc_shape = kernel->out_tensors()[0]->shape(); std::vector nchw_shape = {nhwc_shape[0], nhwc_shape[3], nhwc_shape[1], nhwc_shape[2]}; - auto nh2nc_tensor = - new Tensor(cur_kernel->out_tensors()[0]->data_type(), nchw_shape, schema::Format_NHWC, Tensor::VAR); + auto nh2nc_tensor = new Tensor(kernel->out_tensors()[0]->data_type(), nchw_shape, schema::Format_NHWC, Tensor::VAR); std::vector nh2nc_tensors = {nh2nc_tensor}; all_tensors->push_back(nh2nc_tensors[0]); @@ -94,19 +100,20 @@ int NPUInsertTransformPass::InsertPostNode(const InnerContext *context, kernel:: std::vector nc2nh_tensors = {nc2nh_tensor}; all_tensors->push_back(nc2nh_tensors[0]); - auto nh2nc_name = cur_kernel->name() + "_nh2nc_" + std::to_string(total++); - auto *nh2nc_kernel = - NPUPassUtils::CreateNhwc2NchwKernel(cur_kernel->out_tensors(), nh2nc_tensors, context, nh2nc_name); - all_kernels->push_back(nh2nc_kernel); + auto nh2nc_name = kernel->name() + "_nh2nc_" + std::to_string(total++); + auto *nh2nc_kernel = NPUPassUtils::CreateNhwc2NchwKernel(kernel->out_tensors(), nh2nc_tensors, context, nh2nc_name); + trans_kernels->push_back(nh2nc_kernel); insert_primitive_.push_back(nh2nc_kernel->GetPrimitive()); - auto nc2nh_name = cur_kernel->name() + "_nc2nh_" + std::to_string(total++); + + auto nc2nh_name = kernel->name() + "_nc2nh_" + std::to_string(total++); auto *nc2nh_kernel = NPUPassUtils::CreateNchw2NhwcKernel(nh2nc_tensors, nc2nh_tensors, context, nc2nh_name); - all_kernels->push_back(nc2nh_kernel); + trans_kernels->push_back(nc2nh_kernel); insert_primitive_.push_back(nc2nh_kernel->GetPrimitive()); - NPUPassUtils::UpdateKernel(nh2nc_kernel, {cur_kernel}, {nc2nh_kernel}, cur_kernel->out_tensors(), nh2nc_tensors); + + NPUPassUtils::UpdateKernel(nh2nc_kernel, {kernel}, {nc2nh_kernel}, kernel->out_tensors(), nh2nc_tensors); NPUPassUtils::UpdateKernel(nc2nh_kernel, {nh2nc_kernel}, {out_kernel}, nh2nc_tensors, nc2nh_tensors); - NPUPassUtils::UpdateNH2NCTransNodePreKernel(cur_kernel, nh2nc_kernel, out_kernel); - NPUPassUtils::UpdateNC2NHTransNodeAfterKernel(cur_kernel, nc2nh_kernel, out_kernel); + NPUPassUtils::UpdateNH2NCTransNodePreKernel(kernel, nh2nc_kernel, out_kernel); + NPUPassUtils::UpdateNC2NHTransNodeAfterKernel(kernel, nc2nh_kernel, out_kernel); } return RET_OK; } diff --git a/mindspore/lite/src/runtime/agent/npu/optimizer/npu_insert_transform_pass.h b/mindspore/lite/src/runtime/agent/npu/optimizer/npu_insert_transform_pass.h index 0671a0a753..78bf57978d 100644 --- a/mindspore/lite/src/runtime/agent/npu/optimizer/npu_insert_transform_pass.h +++ b/mindspore/lite/src/runtime/agent/npu/optimizer/npu_insert_transform_pass.h @@ -41,11 +41,11 @@ class NPUInsertTransformPass : public NPUBasePass { int Run() override; private: - int InsertPreNode(const InnerContext *context, kernel::LiteKernel *cur_kernel, - std::vector *all_kernels, std::vector *all_tensors); + int InsertPreNode(const InnerContext *context, kernel::LiteKernel *kernel, + std::vector *trans_kernels, std::vector *all_tensors); - int InsertPostNode(const InnerContext *context, kernel::LiteKernel *cur_kernel, - std::vector *all_kernels, std::vector *all_tensors); + int InsertPostNode(const InnerContext *context, kernel::LiteKernel *kernel, + std::vector *trans_kernels, std::vector *all_tensors); private: int total = 0; diff --git a/mindspore/lite/src/runtime/agent/npu/optimizer/npu_pass_utils.cc b/mindspore/lite/src/runtime/agent/npu/optimizer/npu_pass_utils.cc index b0992fccc9..1e91038d86 100644 --- a/mindspore/lite/src/runtime/agent/npu/optimizer/npu_pass_utils.cc +++ b/mindspore/lite/src/runtime/agent/npu/optimizer/npu_pass_utils.cc @@ -100,25 +100,25 @@ void NPUPassUtils::UpdateKernel(kernel::LiteKernel *kernel, const std::vectorset_out_kernels(out_kernels); } -void NPUPassUtils::UpdateNH2NCTransNodePreKernel(kernel::LiteKernel *kernel, kernel::LiteKernel *trans_kernel, - kernel::LiteKernel *after_kernel) { +void NPUPassUtils::UpdateNH2NCTransNodePreKernel(kernel::LiteKernel *pre_kernel, kernel::LiteKernel *trans_kernel, + kernel::LiteKernel *kernel) { std::vector out_kernels; - for (auto out_kernel : kernel->out_kernels()) { - if (out_kernel == after_kernel) { + for (auto out_kernel : pre_kernel->out_kernels()) { + if (out_kernel == kernel) { out_kernels.push_back(trans_kernel); } else { out_kernels.push_back(out_kernel); } } - UpdateKernel(kernel, kernel->in_kernels(), out_kernels, kernel->in_tensors(), kernel->out_tensors()); + pre_kernel->set_out_kernels(out_kernels); } void NPUPassUtils::UpdateNC2NHTransNodePreKernel(kernel::LiteKernel *kernel, kernel::LiteKernel *trans_kernel, - kernel::LiteKernel *next_kernel) { + kernel::LiteKernel *post_kernel) { std::vector cur_out_kernels; for (auto out_kernel : kernel->out_kernels()) { - if (out_kernel == next_kernel) { + if (out_kernel == post_kernel) { cur_out_kernels.push_back(trans_kernel); } else { cur_out_kernels.push_back(out_kernel); @@ -130,45 +130,47 @@ void NPUPassUtils::UpdateNC2NHTransNodePreKernel(kernel::LiteKernel *kernel, ker std::vector nchw_shape = {nhwc_shape[0], nhwc_shape[3], nhwc_shape[1], nhwc_shape[2]}; kernel_out_tensor->set_format(schema::Format_NCHW); kernel_out_tensor->set_shape(nchw_shape); - UpdateKernel(kernel, kernel->in_kernels(), cur_out_kernels, kernel->in_tensors(), {kernel_out_tensor}); + kernel->set_out_kernels(cur_out_kernels); + kernel->set_out_tensors({kernel_out_tensor}); } void NPUPassUtils::UpdateNH2NCTransNodeAfterKernel(kernel::LiteKernel *kernel, kernel::LiteKernel *trans_kernel, - kernel::LiteKernel *before_kernel) { + kernel::LiteKernel *pre_kernel) { std::vector cur_kernel_in_tensors = {trans_kernel->out_tensors()[0]}; for (int i = 1; i < kernel->in_tensors().size(); i++) { cur_kernel_in_tensors.push_back(kernel->in_tensors()[i]); } std::vector cur_in_kernels = {trans_kernel}; - for (int i = 0; i < kernel->in_kernels().size(); i++) { + for (int i = 1; i < kernel->in_kernels().size(); i++) { auto in_kernel = kernel->in_kernels()[i]; if (in_kernel != kernel) { cur_in_kernels.push_back(in_kernel); } } - UpdateKernel(kernel, cur_in_kernels, kernel->out_kernels(), cur_kernel_in_tensors, kernel->out_tensors()); + kernel->set_in_kernels(cur_in_kernels); + kernel->set_in_tensors({cur_kernel_in_tensors}); } void NPUPassUtils::UpdateNC2NHTransNodeAfterKernel(kernel::LiteKernel *kernel, kernel::LiteKernel *trans_kernel, - kernel::LiteKernel *next_kernel) { - std::vector next_in_tensors; - for (auto next_in_tensor : next_kernel->in_tensors()) { - if (next_in_tensor != kernel->out_tensors()[0]) { - next_in_tensors.push_back(next_in_tensor); + kernel::LiteKernel *post_kernel) { + std::vector post_in_tensors; + for (auto post_in_tensor : post_kernel->in_tensors()) { + if (post_in_tensor != kernel->out_tensors()[0]) { + post_in_tensors.push_back(post_in_tensor); } else { - next_in_tensors.push_back(trans_kernel->out_tensors()[0]); + post_in_tensors.push_back(trans_kernel->out_tensors()[0]); } } - next_kernel->set_in_tensors(next_in_tensors); - std::vector next_in_kernels; - for (auto in_kernel : next_kernel->in_kernels()) { + post_kernel->set_in_tensors(post_in_tensors); + std::vector post_in_kernels; + for (auto in_kernel : post_kernel->in_kernels()) { if (in_kernel == kernel) { - next_in_kernels.push_back(trans_kernel); + post_in_kernels.push_back(trans_kernel); } else { - next_in_kernels.push_back(in_kernel); + post_in_kernels.push_back(in_kernel); } } - NPUPassUtils::UpdateKernel(next_kernel, next_in_kernels, next_kernel->out_kernels(), next_in_tensors, - next_kernel->out_tensors()); + post_kernel->set_in_kernels(post_in_kernels); + post_kernel->set_in_tensors({post_in_tensors}); } } // namespace mindspore::lite diff --git a/mindspore/lite/src/runtime/agent/npu/optimizer/npu_pass_utils.h b/mindspore/lite/src/runtime/agent/npu/optimizer/npu_pass_utils.h index b7ff59e8b2..c5a3cc1eab 100644 --- a/mindspore/lite/src/runtime/agent/npu/optimizer/npu_pass_utils.h +++ b/mindspore/lite/src/runtime/agent/npu/optimizer/npu_pass_utils.h @@ -35,17 +35,17 @@ class NPUPassUtils { const std::vector &out_kernels, const std::vector &in_tensors, const std::vector &out_tensors); - static void UpdateNH2NCTransNodePreKernel(kernel::LiteKernel *kernel, kernel::LiteKernel *trans_kernel, - kernel::LiteKernel *after_kernel); + static void UpdateNH2NCTransNodePreKernel(kernel::LiteKernel *pre_kernel, kernel::LiteKernel *trans_kernel, + kernel::LiteKernel *kernel); static void UpdateNC2NHTransNodePreKernel(kernel::LiteKernel *kernel, kernel::LiteKernel *trans_kernel, - kernel::LiteKernel *next_kernel); + kernel::LiteKernel *post_kernel); static void UpdateNH2NCTransNodeAfterKernel(kernel::LiteKernel *kernel, kernel::LiteKernel *trans_kernel, - kernel::LiteKernel *before_kernel); + kernel::LiteKernel *pre_kernel); static void UpdateNC2NHTransNodeAfterKernel(kernel::LiteKernel *kernel, kernel::LiteKernel *trans_kernel, - kernel::LiteKernel *next_kernel); + kernel::LiteKernel *post_kernel); private: static PrimitiveC *CreateNchw2NhwcPrimitive(); diff --git a/mindspore/lite/src/runtime/agent/npu/optimizer/npu_transform_pass.cc b/mindspore/lite/src/runtime/agent/npu/optimizer/npu_transform_pass.cc index a858c8a643..2779bb7ae8 100644 --- a/mindspore/lite/src/runtime/agent/npu/optimizer/npu_transform_pass.cc +++ b/mindspore/lite/src/runtime/agent/npu/optimizer/npu_transform_pass.cc @@ -19,51 +19,53 @@ #include "src/runtime/agent/npu/npu_manager.h" #include "src/runtime/agent/npu/optimizer/npu_pass_utils.h" namespace mindspore::lite { -using kernel::KERNEL_ARCH::kCPU; using kernel::KERNEL_ARCH::kNPU; int NPUTransformPass::InsertPreNode(const InnerContext *context, kernel::LiteKernel *kernel, - std::vector *all_kernels, + std::vector *trans_kernels, std::vector *all_tensors) { bool is_input_kernel = kernel->in_kernels().empty(); if (is_input_kernel || kernel->in_kernels()[0]->desc().arch != kNPU || npu_trans_nodes.find(kernel->in_kernels()[0]->Type()) == npu_trans_nodes.end()) { - kernel::LiteKernel *before_kernel = nullptr; + kernel::LiteKernel *pre_kernel = nullptr; if (!is_input_kernel) { - before_kernel = kernel->in_kernels()[0]; + pre_kernel = kernel->in_kernels()[0]; } - // Create pre transform kernel out tensors. + + // Create pre transform kernel's out tensor. auto nhwc_shape = kernel->in_tensors()[0]->shape(); std::vector nchw_shape = {nhwc_shape[0], nhwc_shape[3], nhwc_shape[1], nhwc_shape[2]}; auto tensor = new Tensor(kernel->in_tensors()[0]->data_type(), nchw_shape, schema::Format_NCHW, Tensor::VAR); std::vector pre_trans_out_tensors = {tensor}; all_tensors->push_back(pre_trans_out_tensors[0]); - // Replace the output tensor of the previous node + + // Create pre transform kernel: Nhwc2Nchw auto name = kernel->name() + "_pre_trans" + "_Nhwc2Nchw_" + std::to_string(total++); - auto *pre_trans_kernel = + auto *trans_kernel = NPUPassUtils::CreateNhwc2NchwKernel({kernel->in_tensors()[0]}, pre_trans_out_tensors, context, name); - // Insert Nhwc2Nchw into the front of the current queue - all_kernels->push_back(pre_trans_kernel); - insert_primitive_.push_back(pre_trans_kernel->GetPrimitive()); - // Replace the output kernel of the previous node + + trans_kernels->push_back(trans_kernel); + insert_primitive_.push_back(trans_kernel->GetPrimitive()); + + // Set in_kernels, out_kernels, in_tensors,out_tensors for transform kernel std::vector pre_trans_in_kernel; if (is_input_kernel) { pre_trans_in_kernel = {}; } else { - pre_trans_in_kernel = {before_kernel}; + pre_trans_in_kernel = {pre_kernel}; } - NPUPassUtils::UpdateKernel(pre_trans_kernel, pre_trans_in_kernel, {kernel}, {kernel->in_tensors()[0]}, + NPUPassUtils::UpdateKernel(trans_kernel, pre_trans_in_kernel, {kernel}, {kernel->in_tensors()[0]}, pre_trans_out_tensors); - if (before_kernel != nullptr) { - NPUPassUtils::UpdateNH2NCTransNodePreKernel(before_kernel, pre_trans_kernel, kernel); + if (pre_kernel != nullptr) { + NPUPassUtils::UpdateNH2NCTransNodePreKernel(pre_kernel, trans_kernel, kernel); } - NPUPassUtils::UpdateNH2NCTransNodeAfterKernel(kernel, pre_trans_kernel, before_kernel); + NPUPassUtils::UpdateNH2NCTransNodeAfterKernel(kernel, trans_kernel, pre_kernel); } return RET_OK; } int NPUTransformPass::InsertPostNode(const InnerContext *context, kernel::LiteKernel *kernel, - std::vector *all_kernels, + std::vector *trans_kernels, std::vector *all_tensors) { // Model output does not insert operator if (kernel->out_kernels().empty()) { @@ -71,27 +73,30 @@ int NPUTransformPass::InsertPostNode(const InnerContext *context, kernel::LiteKe } // Single output multiple references for (int i = 0; i < kernel->out_kernels().size(); i++) { - auto next_kernel = kernel->out_kernels().at(i); - if (next_kernel->desc().arch == kNPU && npu_trans_nodes.find(next_kernel->Type()) != npu_trans_nodes.end()) { + auto post_kernel = kernel->out_kernels().at(i); + if (post_kernel->desc().arch == kNPU && npu_trans_nodes.find(post_kernel->Type()) != npu_trans_nodes.end()) { continue; } - // Change format the output of the current kernel nhwc->nchw + + // Create post transform kernel's out tensor. auto tensor = new Tensor(kernel->out_tensors()[0]->data_type(), kernel->out_tensors()[0]->shape(), schema::Format_NHWC, Tensor::VAR); std::vector post_trans_out_tensors = {tensor}; all_tensors->push_back(post_trans_out_tensors[0]); - // Use the output tensor of the current node as the input tensor of the post-conversion operator + + // Create post transform kernel: Nchw2Nhwc auto name = kernel->name() + "_post_trans" + "_Nchw2Nhwc" + std::to_string(total++); auto *post_trans_kernel = NPUPassUtils::CreateNchw2NhwcKernel(kernel->out_tensors(), post_trans_out_tensors, context, name); - // Replace the input tensor of the next node - NPUPassUtils::UpdateKernel(post_trans_kernel, {kernel}, {next_kernel}, kernel->out_tensors(), + + // Set in_kernels, out_kernels, in_tensors,out_tensors for transform kernel + NPUPassUtils::UpdateKernel(post_trans_kernel, {kernel}, {post_kernel}, kernel->out_tensors(), post_trans_out_tensors); insert_primitive_.push_back(post_trans_kernel->GetPrimitive()); - // Directly insert in the back, will not affect the topological sort - all_kernels->push_back(post_trans_kernel); - NPUPassUtils::UpdateNC2NHTransNodePreKernel(kernel, post_trans_kernel, next_kernel); - NPUPassUtils::UpdateNC2NHTransNodeAfterKernel(kernel, post_trans_kernel, next_kernel); + + trans_kernels->push_back(post_trans_kernel); + NPUPassUtils::UpdateNC2NHTransNodePreKernel(kernel, post_trans_kernel, post_kernel); + NPUPassUtils::UpdateNC2NHTransNodeAfterKernel(kernel, post_trans_kernel, post_kernel); } return RET_OK; } diff --git a/mindspore/lite/src/runtime/agent/npu/optimizer/npu_transform_pass.h b/mindspore/lite/src/runtime/agent/npu/optimizer/npu_transform_pass.h index 8f6115e00d..6a13c4c01f 100644 --- a/mindspore/lite/src/runtime/agent/npu/optimizer/npu_transform_pass.h +++ b/mindspore/lite/src/runtime/agent/npu/optimizer/npu_transform_pass.h @@ -43,10 +43,10 @@ class NPUTransformPass : public NPUBasePass { private: int InsertPreNode(const InnerContext *context, kernel::LiteKernel *kernel, - std::vector *all_kernels, std::vector *all_tensors); + std::vector *trans_kernels, std::vector *all_tensors); int InsertPostNode(const InnerContext *context, kernel::LiteKernel *kernel, - std::vector *all_kernels, std::vector *all_tensors); + std::vector *trans_kernels, std::vector *all_tensors); private: int total = 0;