diff --git a/mindspore/lite/src/runtime/agent/npu/optimizer/npu_fusion_pass.cc b/mindspore/lite/src/runtime/agent/npu/optimizer/npu_fusion_pass.cc
index fcfcd15f79..cf2bbf330f 100644
--- a/mindspore/lite/src/runtime/agent/npu/optimizer/npu_fusion_pass.cc
+++ b/mindspore/lite/src/runtime/agent/npu/optimizer/npu_fusion_pass.cc
@@ -246,31 +246,31 @@ int NPUFusionPass::FormatFusion(kernel::LiteKernel *kernel) {
 
 int NPUFusionPass::Run() {
   for (size_t i = 0; i < kernels->size(); i++) {
     auto kernel = (*kernels)[i];
-    if (NPUPassUtils::IsNchw2Nhwc(kernel) || NPUPassUtils::IsNhwc2Nchw(kernel)) {
-      if (CheckFormatFusion(kernel)) {
-        i--;
-        FormatFusion(kernel);
+    if (CheckFusion(kernel)) {
+      switch (kernel->Type()) {
+        case schema::PrimitiveType_Concat:
+          i -= kernel->in_kernels().size();
+          ConcatFusion(kernel);
+          continue;
+        case schema::PrimitiveType_Add:
+        case schema::PrimitiveType_Activation:
+        case schema::PrimitiveType_Eltwise:
+          i -= kernel->in_kernels().size();
+          CommonFusion(kernel);
+          continue;
+        default:
+          continue;
       }
-      continue;
-    }
-    if (!CheckFusion(kernel)) {
-      continue;
     }
-    switch (kernel->Type()) {
-      case schema::PrimitiveType_Concat:
-        i -= kernel->in_kernels().size();
-        ConcatFusion(kernel);
-        continue;
-      case schema::PrimitiveType_Add:
-      case schema::PrimitiveType_Activation:
-      case schema::PrimitiveType_Eltwise:
-        i -= kernel->in_kernels().size();
-        CommonFusion(kernel);
-        continue;
-      default:
-        continue;
+  }
+  for (size_t i = 0; i < kernels->size(); ++i) {
+    auto kernel = (*kernels)[i];
+    if (CheckFormatFusion(kernel)) {
+      i--;
+      FormatFusion(kernel);
     }
   }
+  return RET_OK;
 }
 }  // namespace mindspore::lite
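Review note: NPUFusionPass::Run() now does op fusion and format fusion in two separate scans over the kernel vector, and the index adjustments (i -= kernel->in_kernels().size() before ConcatFusion/CommonFusion, i-- around FormatFusion) keep the scan position valid while fusion erases kernels from the vector being iterated. A minimal, self-contained sketch of that index-maintenance pattern (toy char elements stand in for kernels; this is an illustration, not the pass itself):

#include <cstddef>
#include <iostream>
#include <vector>

int main() {
  // "Fusing" element 'F' erases its two predecessors, mirroring how
  // ConcatFusion/CommonFusion remove a kernel's in_kernels from *kernels.
  std::vector<char> kernels = {'a', 'b', 'c', 'F', 'd', 'e'};
  for (size_t i = 0; i < kernels.size(); i++) {
    if (kernels[i] == 'F') {
      size_t fused = 2;  // stands in for kernel->in_kernels().size()
      kernels.erase(kernels.begin() + (i - fused), kernels.begin() + i);
      i -= fused;  // 'F' shifted left by `fused`; realign the index to it
      continue;    // the for-increment then resumes right after 'F'
    }
    std::cout << kernels[i];  // visits a b c d e: nothing skipped or repeated
  }
  std::cout << '\n';
}

Stepping back by exactly the number of erased elements means every surviving kernel is visited once, which is the same bookkeeping the second loop performs with i-- around FormatFusion.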
diff --git a/mindspore/lite/src/runtime/agent/npu/optimizer/npu_insert_transform_pass.cc b/mindspore/lite/src/runtime/agent/npu/optimizer/npu_insert_transform_pass.cc
index b403f158d3..8eb75aebac 100644
--- a/mindspore/lite/src/runtime/agent/npu/optimizer/npu_insert_transform_pass.cc
+++ b/mindspore/lite/src/runtime/agent/npu/optimizer/npu_insert_transform_pass.cc
@@ -20,31 +20,81 @@
 namespace mindspore::lite {
 using kernel::KERNEL_ARCH::kNPU;
-enum InsertState { InsertNone, PreInsert, PostInsert };
-
+enum InsertState { InsertNone, PreInsert, PostInsert, BothInsert };
 std::set<schema::PrimitiveType> npu_insert_nodes = {schema::PrimitiveType_Concat, schema::PrimitiveType_Add,
                                                     schema::PrimitiveType_Eltwise, schema::PrimitiveType_Activation};
+// The goal of this pass is to minimize the number of subgraphs generated by inserting nchw2nhwc or nhwc2nchw
+// before or after operators such as concat and add, in cooperation with the fusion pass. If more than half of an
+// op's inputs and outputs already carry transposes, we insert transposes on the remaining inputs and outputs so
+// that the fusion pass can later remove all of them. Otherwise, we insert nothing.
+//
+// Typically concat accepts an output of nchw2nhwc. We wrap the other inputs with nh2nc and nc2nh so that all
+// inputs to concat have the same format, and then the fusion pass removes all the nchw2nhwc ops.
+// e.g.
+// original:     (conv->nchw2nhwc, add(format nhwc)) -> concat -> (nhwc2nchw->conv)
+// current pass: (conv->nchw2nhwc, add->nhwc2nchw->nchw2nhwc) -> concat -> (nhwc2nchw->conv)
+// fusion pass:  (conv, add->nhwc2nchw) -> concat -> conv
+// The original graph needs 2 CPU subgraphs; after the two passes, only 1 CPU subgraph remains.
+//
+// Note:
+// Such ops require all of their inputs to have the same format, which may be nchw, nhwc or another format.
+// Their inputs and outputs may not be 4d, or may already have consistent formats, so we insert nothing when the
+// op's in kernels and out kernels contain no nc2nh or nh2nc.
+// This pass should run after npu_transform_pass, which inserts transposes for nchw-input-limited ops like conv2d.
 int GetInsertState(kernel::LiteKernel *kernel) {
+  // filter out irrelevant kernels
   if (npu_insert_nodes.find(kernel->Type()) == npu_insert_nodes.end()) {
     return InsertNone;
  }
-  auto pre_flag = std::all_of(kernel->in_kernels().begin(), kernel->in_kernels().end(),
-                              [](const kernel::LiteKernel *kernel) { return NPUPassUtils::IsNchw2Nhwc(kernel); });
-  auto post_flag = std::all_of(kernel->out_kernels().begin(), kernel->out_kernels().end(),
-                               [](const kernel::LiteKernel *kernel) { return NPUPassUtils::IsNhwc2Nchw(kernel); });
-  if (pre_flag && !post_flag) {
-    return PostInsert;
+
+  // the current kernel is a target kernel;
+  // use out_kernels to count how many output edges leave the current kernel
+  size_t in_out_tensor_num = kernel->in_tensors().size() + kernel->out_kernels().size();
+  size_t transpose_input_num = 0;
+  size_t transpose_output_num = 0;
+  bool need_pre_insert = false;
+  bool need_post_insert = false;
+  // count the input tensors coming from nc2nh kernels and the output tensors going to nh2nc kernels
+  for (size_t i = 0; i < kernel->in_tensors().size(); ++i) {
+    auto in_kernel = NPUPassUtils::KernelInputFromKernel(kernel, i);
+    if (NPUPassUtils::IsNchw2Nhwc(in_kernel)) {
+      transpose_input_num++;
+    } else {
+      need_pre_insert = true;
+    }
+  }
+  for (const auto out_kernel : kernel->out_kernels()) {
+    if (NPUPassUtils::IsNhwc2Nchw(out_kernel)) {
+      transpose_output_num++;
+    } else {
+      need_post_insert = true;
+    }
+  }
+
+  // insert nothing if the number of transpose tensors is no more than half of the total inputs and outputs.
+  // insert nothing if all inputs and outputs already carry transposes; the fusion pass handles that case by itself.
+  size_t transpose_tensor_num = transpose_input_num + transpose_output_num;
+  if (transpose_tensor_num <= in_out_tensor_num / 2 || transpose_tensor_num == in_out_tensor_num) {
+    return InsertNone;
  }
-  if (!pre_flag && post_flag) {
+
+  if (need_pre_insert && !need_post_insert) {
     return PreInsert;
   }
+  if (need_pre_insert && need_post_insert) {
+    return BothInsert;
+  }
+  if (!need_pre_insert && need_post_insert) {
+    return PostInsert;
+  }
+  return InsertNone;
 }
 
 int NPUInsertTransformPass::InsertNode(kernel::LiteKernel *kernel, kernel::LiteKernel *post_kernel,
-                                       std::vector<kernel::LiteKernel *> *trans_kernels) {
+                                       size_t post_input_index, std::vector<kernel::LiteKernel *> *trans_kernels) {
   // Kernel and post_kernel can't be nullptr at the same time.
   std::string kernel_name;
   Tensor *in_tensor = nullptr;
@@ -54,7 +104,7 @@ int NPUInsertTransformPass::InsertNode(kernel::LiteKernel *kernel, kernel::LiteK
   if (post_kernel != nullptr) {
     out_kernels.push_back(post_kernel);
     kernel_name = post_kernel->name() + "_pre";
-    in_tensor = post_kernel->in_tensors()[0];
+    in_tensor = post_kernel->in_tensors().at(post_input_index);
   }
   std::vector<kernel::LiteKernel *> in_kernels;
   // If kernel equals nullptr, post_kernel is the input of whole graph.
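Review note: the GetInsertState rewrite in the first hunk above is effectively a majority vote over the op's input and output edges. Here is the decision rule in isolation, with the graph walking stripped away (Decide and its plain counters are illustrative names, not part of the pass):

#include <cstddef>

enum InsertState { InsertNone, PreInsert, PostInsert, BothInsert };

// transpose_in/out: edges already wrapped by nchw2nhwc / nhwc2nchw;
// total_in/out: kernel->in_tensors().size() and kernel->out_kernels().size().
InsertState Decide(size_t transpose_in, size_t total_in, size_t transpose_out, size_t total_out) {
  size_t total = total_in + total_out;
  size_t transposed = transpose_in + transpose_out;
  // act only when transposes are a strict majority but not yet unanimous
  if (transposed <= total / 2 || transposed == total) {
    return InsertNone;
  }
  bool need_pre = transpose_in < total_in;     // some input still lacks a transpose
  bool need_post = transpose_out < total_out;  // some output still lacks one
  if (need_pre && need_post) return BothInsert;
  if (need_pre) return PreInsert;
  return need_post ? PostInsert : InsertNone;
}

Worked example: a concat with four inputs, three fed by nchw2nhwc, whose single output feeds an nhwc2nchw, gives Decide(3, 4, 1, 1). Four of five edges are transposed and 4 > 5 / 2, so it returns PreInsert; wrapping the one remaining input lets the fusion pass eliminate all five transposes.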
@@ -99,87 +149,134 @@ int NPUInsertTransformPass::InsertNode(kernel::LiteKernel *kernel, kernel::LiteK
   }
   if (post_kernel != nullptr) {
     NPUPassUtils::UpdateNC2NHTransNodePostKernel(kernel, nc2nh_kernel, post_kernel);
+  } else {
+    // a nullptr post_kernel means nc2nh_tensor is a graph output; keep the graph output tensor name unchanged
+    auto graph_output_name = in_tensor->tensor_name();
+    in_tensor->set_tensor_name(graph_output_name + "_before_" + name_);
+    nc2nh_tensor->set_tensor_name(graph_output_name);
   }
   return RET_OK;
 }
+
+int NPUInsertTransformPass::InsertForInputTensor(kernel::LiteKernel *kernel, size_t in_tensor_index,
+                                                 kernel::LiteKernel *pre_kernel,
+                                                 std::vector<kernel::LiteKernel *> *trans_kernels) {
+  // insert transpose nodes before target ops
+  return InsertNode(pre_kernel, kernel, in_tensor_index, trans_kernels);
+}
+
+int NPUInsertTransformPass::InsertForOutputTensor(kernel::LiteKernel *kernel, kernel::LiteKernel *post_kernel,
+                                                  size_t post_in_tensor_index,
+                                                  std::vector<kernel::LiteKernel *> *trans_kernels) {
+  // insert transpose nodes after target ops
+  return InsertNode(kernel, post_kernel, post_in_tensor_index, trans_kernels);
+}
+
 int NPUInsertTransformPass::InsertPreNodes(kernel::LiteKernel *kernel,
                                            std::vector<kernel::LiteKernel *> *trans_kernels) {
-  if (kernel->in_kernels().size() != kernel->in_tensors().size()) {
-    MS_LOG(DEBUG) << "The input tensors of kernel may be the input of whole graph or const tensor.";
-    return RET_OK;
-  }
-  if (kernel->in_kernels().empty()) {
-    auto ret = InsertNode(nullptr, kernel, trans_kernels);
-    if (ret != RET_OK) {
-      MS_LOG(ERROR) << "Insert nhwc2nchw kernel and nchw2nhwc kernel before kernel " << kernel->name() << " failed.";
-      return RET_ERROR;
-    }
-  }
-  for (auto in_kernel : kernel->in_kernels()) {
-    if (NPUPassUtils::IsNchw2Nhwc(in_kernel)) {
+  int ret = RET_OK;
+  for (size_t i = 0; i < kernel->in_tensors().size(); ++i) {
+    auto pre_kernel = NPUPassUtils::KernelInputFromKernel(kernel, i);
+    if (NPUPassUtils::IsNchw2Nhwc(pre_kernel)) {
       continue;
     }
-    auto ret = InsertNode(in_kernel, kernel, trans_kernels);
+    // if this tensor is a graph input, pre_kernel is nullptr.
+    ret = InsertForInputTensor(kernel, i, pre_kernel, trans_kernels);
     if (ret != RET_OK) {
       MS_LOG(ERROR) << "Insert nhwc2nchw kernel and nchw2nhwc kernel before kernel " << kernel->name() << " failed.";
-      return RET_ERROR;
+      return ret;
     }
   }
-  return RET_OK;
+  return ret;
 }
 
 int NPUInsertTransformPass::InsertPostNodes(kernel::LiteKernel *kernel,
                                             std::vector<kernel::LiteKernel *> *trans_kernels) {
-  if (kernel->out_kernels().empty()) {
-    auto ret = InsertNode(kernel, nullptr, trans_kernels);
+  int ret = RET_OK;
+
+  for (const auto post_kernel : kernel->out_kernels()) {
+    if (NPUPassUtils::IsNhwc2Nchw(post_kernel)) {
+      continue;
+    }
+    auto post_kernel_in_tensors = post_kernel->in_tensors();
+    // the kernel's out tensor is one of post_kernel's input tensors
+    auto it = std::find(post_kernel_in_tensors.begin(), post_kernel_in_tensors.end(), kernel->out_tensors().at(0));
+    if (it == post_kernel_in_tensors.end()) {
+      return RET_ERROR;
+    }
+    size_t input_index = it - post_kernel_in_tensors.begin();
+    ret = InsertForOutputTensor(kernel, post_kernel, input_index, trans_kernels);
     if (ret != RET_OK) {
       MS_LOG(ERROR) << "Insert nhwc2nchw kernel and nchw2nhwc kernel after kernel " << kernel->name() << " failed.";
-      return RET_ERROR;
+      return ret;
     }
   }
-  for (auto out_kernel : kernel->out_kernels()) {
-    if (NPUPassUtils::IsNhwc2Nchw(out_kernel)) {
-      continue;
-    }
-    auto ret = InsertNode(kernel, out_kernel, trans_kernels);
+  if (kernel->out_tensors().size() > kernel->out_kernels().size()) {
+    // the kernel's output is a graph output
+    ret = InsertForOutputTensor(kernel, nullptr, 0, trans_kernels);
     if (ret != RET_OK) {
       MS_LOG(ERROR) << "Insert nhwc2nchw kernel and nchw2nhwc kernel after kernel " << kernel->name() << " failed.";
-      return RET_ERROR;
+      return ret;
     }
   }
-  return RET_OK;
+  return ret;
 }
 
 int NPUInsertTransformPass::Run() {
+  std::vector<kernel::LiteKernel *> insert_kernels;
   for (size_t i = 0; i < all_kernels_->size(); i++) {
     auto kernel = (*all_kernels_)[i];
     if (kernel->desc().arch != kNPU) {
       continue;
     }
     auto insert_state = GetInsertState(kernel);
+    insert_kernels.clear();
     // If the every output kernel is nhwc2nchw, insert
     // modify loop index add post_kernels.size() to the next kernel in the origin vector
-    if (insert_state == PreInsert) {
-      std::vector<kernel::LiteKernel *> pre_kernels;
-      auto ret = InsertPreNodes(kernel, &pre_kernels);
-      if (ret != RET_OK) {
-        MS_LOG(ERROR) << "Insert nhwc2nchw kernel and nchw2nhwc kernel before kernel " << kernel->name() << " failed.";
-        return RET_ERROR;
+    switch (insert_state) {
+      case PreInsert: {
+        auto ret = InsertPreNodes(kernel, &insert_kernels);
+        if (ret != RET_OK) {
+          MS_LOG(ERROR) << "Insert nhwc2nchw kernel and nchw2nhwc kernel before kernel " << kernel->name()
+                        << " failed.";
+          return RET_ERROR;
+        }
+        all_kernels_->insert(all_kernels_->begin() + i, insert_kernels.begin(), insert_kernels.end());
+        i += insert_kernels.size();
+        break;
       }
-      all_kernels_->insert(all_kernels_->begin() + i, pre_kernels.begin(), pre_kernels.end());
-      i += pre_kernels.size();
-    }
+      case PostInsert: {
+        auto ret = InsertPostNodes(kernel, &insert_kernels);
+        if (ret != RET_OK) {
+          MS_LOG(ERROR) << "Insert nhwc2nchw kernel and nchw2nhwc kernel after kernel " << kernel->name() << " failed.";
+          return RET_ERROR;
+        }
+        all_kernels_->insert(all_kernels_->begin() + i + 1, insert_kernels.begin(), insert_kernels.end());
+        i += insert_kernels.size();
+        break;
+      }
+      case BothInsert: {
+        auto ret = InsertPreNodes(kernel, &insert_kernels);
+        if (ret != RET_OK) {
+          MS_LOG(ERROR) << "Insert nhwc2nchw kernel and nchw2nhwc kernel before kernel " << kernel->name()
failed."; + return RET_ERROR; + } + all_kernels_->insert(all_kernels_->begin() + i, insert_kernels.begin(), insert_kernels.end()); + i += insert_kernels.size(); - if (insert_state == PostInsert) { - std::vector post_kernels; - auto ret = InsertPostNodes(kernel, &post_kernels); - if (ret != RET_OK) { - MS_LOG(ERROR) << "Insert nhwc2nchw kernel and nchw2nhwc kernel after kernel " << kernel->name() << " failed."; - return RET_ERROR; + insert_kernels.clear(); + ret = InsertPostNodes(kernel, &insert_kernels); + if (ret != RET_OK) { + MS_LOG(ERROR) << "Insert nhwc2nchw kernel and nchw2nhwc kernel after kernel " << kernel->name() << " failed."; + return RET_ERROR; + } + all_kernels_->insert(all_kernels_->begin() + i + 1, insert_kernels.begin(), insert_kernels.end()); + i += insert_kernels.size(); + break; } - all_kernels_->insert(all_kernels_->begin() + i + 1, post_kernels.begin(), post_kernels.end()); - i += post_kernels.size(); + default: + MS_LOG(DEBUG) << "Insert Nothing on kernel " << kernel->name(); } } return RET_OK; diff --git a/mindspore/lite/src/runtime/agent/npu/optimizer/npu_insert_transform_pass.h b/mindspore/lite/src/runtime/agent/npu/optimizer/npu_insert_transform_pass.h index 8479d222f2..adc2d09027 100644 --- a/mindspore/lite/src/runtime/agent/npu/optimizer/npu_insert_transform_pass.h +++ b/mindspore/lite/src/runtime/agent/npu/optimizer/npu_insert_transform_pass.h @@ -45,8 +45,13 @@ class NPUInsertTransformPass : public NPUBasePass { int InsertPostNodes(kernel::LiteKernel *kernel, std::vector *trans_kernels); - int InsertNode(kernel::LiteKernel *kernel, kernel::LiteKernel *post_kernel, + int InsertNode(kernel::LiteKernel *kernel, kernel::LiteKernel *post_kernel, size_t post_input_index, std::vector *trans_kernels); + int InsertForInputTensor(kernel::LiteKernel *kernel, size_t in_tensor_index, kernel::LiteKernel *pre_kernel, + std::vector *trans_kernels); + + int InsertForOutputTensor(kernel::LiteKernel *kernel, kernel::LiteKernel *post_kernel, size_t post_in_tensor_index, + std::vector *trans_kernels); private: int total = 0; diff --git a/mindspore/lite/src/runtime/agent/npu/optimizer/npu_pass_utils.cc b/mindspore/lite/src/runtime/agent/npu/optimizer/npu_pass_utils.cc index a41bfedab9..b9299e5930 100644 --- a/mindspore/lite/src/runtime/agent/npu/optimizer/npu_pass_utils.cc +++ b/mindspore/lite/src/runtime/agent/npu/optimizer/npu_pass_utils.cc @@ -172,32 +172,33 @@ void NPUPassUtils::UpdateNC2NHPostKernelInTensors(kernel::LiteKernel *kernel, ke void NPUPassUtils::UpdateNC2NHTransNodePostKernel(kernel::LiteKernel *kernel, kernel::LiteKernel *trans_kernel, kernel::LiteKernel *post_kernel) { - // For post_kernel after trans, kernel should be replaced with trans_kernel. + // The input tensor should be replaced with the output tensor of trans_kernel. 
diff --git a/mindspore/lite/src/runtime/agent/npu/optimizer/npu_insert_transform_pass.h b/mindspore/lite/src/runtime/agent/npu/optimizer/npu_insert_transform_pass.h
index 8479d222f2..adc2d09027 100644
--- a/mindspore/lite/src/runtime/agent/npu/optimizer/npu_insert_transform_pass.h
+++ b/mindspore/lite/src/runtime/agent/npu/optimizer/npu_insert_transform_pass.h
@@ -45,8 +45,13 @@ class NPUInsertTransformPass : public NPUBasePass {
   int InsertPostNodes(kernel::LiteKernel *kernel, std::vector<kernel::LiteKernel *> *trans_kernels);
 
-  int InsertNode(kernel::LiteKernel *kernel, kernel::LiteKernel *post_kernel,
+  int InsertNode(kernel::LiteKernel *kernel, kernel::LiteKernel *post_kernel, size_t post_input_index,
                  std::vector<kernel::LiteKernel *> *trans_kernels);
+  int InsertForInputTensor(kernel::LiteKernel *kernel, size_t in_tensor_index, kernel::LiteKernel *pre_kernel,
+                           std::vector<kernel::LiteKernel *> *trans_kernels);
+
+  int InsertForOutputTensor(kernel::LiteKernel *kernel, kernel::LiteKernel *post_kernel, size_t post_in_tensor_index,
+                            std::vector<kernel::LiteKernel *> *trans_kernels);
 
  private:
   int total = 0;
diff --git a/mindspore/lite/src/runtime/agent/npu/optimizer/npu_pass_utils.cc b/mindspore/lite/src/runtime/agent/npu/optimizer/npu_pass_utils.cc
index a41bfedab9..b9299e5930 100644
--- a/mindspore/lite/src/runtime/agent/npu/optimizer/npu_pass_utils.cc
+++ b/mindspore/lite/src/runtime/agent/npu/optimizer/npu_pass_utils.cc
@@ -172,32 +172,33 @@ void NPUPassUtils::UpdateNC2NHPostKernelInTensors(kernel::LiteKernel *kernel, ke
 void NPUPassUtils::UpdateNC2NHTransNodePostKernel(kernel::LiteKernel *kernel, kernel::LiteKernel *trans_kernel,
                                                   kernel::LiteKernel *post_kernel) {
-  // For post_kernel after trans, kernel should be replaced with trans_kernel.
+  // The input tensor should be replaced with the output tensor of trans_kernel.
   auto post_in_tensors = post_kernel->in_tensors();
-  if (kernel == nullptr) {
-    post_in_tensors[0] = trans_kernel->out_tensors()[0];
-  } else {
-    for (size_t i = 0; i < post_in_tensors.size(); i++) {
-      if (post_in_tensors[i] == kernel->out_tensors()[0]) {
-        post_in_tensors[i] = trans_kernel->out_tensors()[0];
-        break;
-      }
+  Tensor *old_in_tensor = nullptr;
+  // find out which input tensor of post_kernel should be updated
+  for (size_t i = 0; i < post_in_tensors.size(); ++i) {
+    if (KernelInputFromKernel(post_kernel, i) == kernel) {
+      old_in_tensor = post_in_tensors.at(i);
+      break;
     }
   }
+  if (old_in_tensor == nullptr) {
+    MS_LOG(WARNING) << "Could not find in tensor index";
+    return;
+  }
+  std::replace(post_in_tensors.begin(), post_in_tensors.end(), old_in_tensor, trans_kernel->out_tensors().at(0));
   post_kernel->set_in_tensors(post_in_tensors);
-  // The input tensor should be replaced with the output tensor of trans_kernel.
-  std::vector<kernel::LiteKernel *> post_in_kernels = post_kernel->in_kernels();
-  for (size_t i = 0; i < post_in_kernels.size(); i++) {
-    if (post_in_kernels[i] == kernel) {
-      post_in_kernels[i] = trans_kernel;
-      break;
-    }
-  }
+  // In post_kernel's in_kernels, kernel should be replaced with trans_kernel.
+  auto post_in_kernels = post_kernel->in_kernels();
+  std::replace(post_in_kernels.begin(), post_in_kernels.end(), kernel, trans_kernel);
   post_kernel->set_in_kernels(post_in_kernels);
 }
 
 bool NPUPassUtils::IsNhwc2Nchw(const kernel::LiteKernel *kernel) {
+  if (kernel == nullptr) {
+    return false;
+  }
   if (kernel->Type() != schema::PrimitiveType_Transpose) {
     return false;
   }
@@ -215,6 +216,9 @@ bool NPUPassUtils::IsNhwc2Nchw(const kernel::LiteKernel *kernel) {
 }
 
 bool NPUPassUtils::IsNchw2Nhwc(const kernel::LiteKernel *kernel) {
+  if (kernel == nullptr) {
+    return false;
+  }
   if (kernel->Type() != schema::PrimitiveType_Transpose) {
     return false;
   }
@@ -230,5 +234,22 @@
   }
   return false;
 }
-
+kernel::LiteKernel *NPUPassUtils::KernelInputFromKernel(const kernel::LiteKernel *kernel, size_t in_tensor_index) {
+  // Given a kernel and an input tensor index, find the kernel that outputs this tensor.
+  // If the input tensor is a graph input, return nullptr.
+  if (kernel == nullptr) {
+    return nullptr;
+  }
+  auto tensor = kernel->in_tensors().at(in_tensor_index);
+  auto in_kernels = kernel->in_kernels();
+  auto output_contain = [tensor](const kernel::LiteKernel *kernel) {
+    auto out_tensors = kernel->out_tensors();
+    return std::find(out_tensors.begin(), out_tensors.end(), tensor) != out_tensors.end();
+  };
+  auto it = std::find_if(in_kernels.begin(), in_kernels.end(), output_contain);
+  if (it == in_kernels.end()) {
+    return nullptr;
+  }
+  return *it;
+}
 }  // namespace mindspore::lite
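Review note on KernelInputFromKernel above: in_kernels() is not guaranteed to be index-aligned with in_tensors(), so the helper resolves the producer of one specific input tensor, returning nullptr for graph inputs and const tensors that no upstream kernel outputs; that is also why IsNhwc2Nchw/IsNchw2Nhwc now tolerate a nullptr argument. A stripped-down sketch of the same lookup on hypothetical Node/Tensor stand-ins (not the LiteKernel API):

#include <algorithm>
#include <cstddef>
#include <vector>

struct Tensor {};
struct Node {
  std::vector<Tensor *> in_tensors;
  std::vector<Node *> in_nodes;  // producers, in no particular order
  std::vector<Tensor *> out_tensors;
};

// Same contract as NPUPassUtils::KernelInputFromKernel: scan in_nodes for the
// one whose out_tensors contains in_tensors[index]; nullptr means the tensor
// is a graph input (or a const tensor) that no upstream node produces.
Node *InputFrom(const Node *node, size_t index) {
  Tensor *t = node->in_tensors.at(index);
  auto produces_t = [t](const Node *n) {
    return std::find(n->out_tensors.begin(), n->out_tensors.end(), t) != n->out_tensors.end();
  };
  auto it = std::find_if(node->in_nodes.begin(), node->in_nodes.end(), produces_t);
  return it == node->in_nodes.end() ? nullptr : *it;
}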
diff --git a/mindspore/lite/src/runtime/agent/npu/optimizer/npu_pass_utils.h b/mindspore/lite/src/runtime/agent/npu/optimizer/npu_pass_utils.h
index c3a8247f76..a92356b428 100644
--- a/mindspore/lite/src/runtime/agent/npu/optimizer/npu_pass_utils.h
+++ b/mindspore/lite/src/runtime/agent/npu/optimizer/npu_pass_utils.h
@@ -52,6 +52,7 @@ class NPUPassUtils {
   static bool IsNhwc2Nchw(const kernel::LiteKernel *kernel);
 
   static bool IsNchw2Nhwc(const kernel::LiteKernel *kernel);
+  static kernel::LiteKernel *KernelInputFromKernel(const kernel::LiteKernel *kernel, size_t in_tensor_index);
 
  private:
  static PrimitiveC *CreateTransposePrimitive();
diff --git a/mindspore/lite/test/models_onnx_fp16.cfg b/mindspore/lite/test/models_onnx_fp16.cfg
index 03213a338c..3e6823471a 100644
--- a/mindspore/lite/test/models_onnx_fp16.cfg
+++ b/mindspore/lite/test/models_onnx_fp16.cfg
@@ -26,7 +26,7 @@ crnn_lite_lstm_v2.onnx;32,32,32,1 0.3
 psenet_lite_mbv2.onnx;1,32,32,3 0.6
 super-resolution-10.onnx;1,224,224,1 4.5
 tinyyolov2-8.onnx;1,416,416,3 5.5
-ml_2012_ocr_cn.onnx -1
+#ml_2012_ocr_cn.onnx -1
 #ml_2012_ocr_cn_noLSTM.onnx 1
 candy-9.onnx 5
 mosaic-9.onnx 4