!12365 NPU: insert transpose by input tensor

From: @zhaozhenlong
pull/12365/MERGE
Committed by mindspore-ci-bot via Gitee
commit 0026f2b84d

@@ -246,31 +246,31 @@ int NPUFusionPass::FormatFusion(kernel::LiteKernel *kernel) {
 int NPUFusionPass::Run() {
   for (size_t i = 0; i < kernels->size(); i++) {
     auto kernel = (*kernels)[i];
-    if (NPUPassUtils::IsNchw2Nhwc(kernel) || NPUPassUtils::IsNhwc2Nchw(kernel)) {
-      if (CheckFormatFusion(kernel)) {
-        i--;
-        FormatFusion(kernel);
-      }
-      continue;
-    }
-    if (CheckFusion(kernel)) {
-      switch (kernel->Type()) {
-        case schema::PrimitiveType_Concat:
-          i -= kernel->in_kernels().size();
-          ConcatFusion(kernel);
-          continue;
-        case schema::PrimitiveType_Add:
-        case schema::PrimitiveType_Activation:
-        case schema::PrimitiveType_Eltwise:
-          i -= kernel->in_kernels().size();
-          CommonFusion(kernel);
-          continue;
-        default:
-          continue;
-      }
-      continue;
-    }
+    if (!CheckFusion(kernel)) {
+      continue;
+    }
+    switch (kernel->Type()) {
+      case schema::PrimitiveType_Concat:
+        i -= kernel->in_kernels().size();
+        ConcatFusion(kernel);
+        continue;
+      case schema::PrimitiveType_Add:
+      case schema::PrimitiveType_Activation:
+      case schema::PrimitiveType_Eltwise:
+        i -= kernel->in_kernels().size();
+        CommonFusion(kernel);
+        continue;
+      default:
+        continue;
+    }
+  }
+  for (size_t i = 0; i < kernels->size(); ++i) {
+    auto kernel = (*kernels)[i];
+    if (CheckFormatFusion(kernel)) {
+      i--;
+      FormatFusion(kernel);
+    }
   }
   return RET_OK;
 }
 } // namespace mindspore::lite
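The hunk above splits Run() into two sweeps: value fusions (Concat/Add/Activation/Eltwise) first, then a separate pass that folds transpose pairs via FormatFusion, presumably so that pairs exposed by the earlier fusions are also caught. The `i--` and `i -= n` adjustments exist because fusion erases kernels from the vector being iterated. A minimal, self-contained sketch of that index-compensation pattern, with plain ints standing in for kernels (none of these names come from the patch):

#include <cstdio>
#include <vector>

int main() {
  std::vector<int> nodes = {1, 2, -3, 4, -5, 6};  // negative values play "fusable" kernels
  for (size_t i = 0; i < nodes.size(); i++) {
    if (nodes[i] < 0) {
      nodes.erase(nodes.begin() + i);  // "fusion" shrinks the vector in place ...
      i--;                             // ... so step back to revisit position i
    }
  }
  for (int n : nodes) printf("%d ", n);  // prints: 1 2 4 6
  printf("\n");
  return 0;
}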

@@ -45,8 +45,13 @@ class NPUInsertTransformPass : public NPUBasePass {
   int InsertPostNodes(kernel::LiteKernel *kernel, std::vector<kernel::LiteKernel *> *trans_kernels);
-  int InsertNode(kernel::LiteKernel *kernel, kernel::LiteKernel *post_kernel,
+  int InsertNode(kernel::LiteKernel *kernel, kernel::LiteKernel *post_kernel, size_t post_input_index,
                  std::vector<kernel::LiteKernel *> *trans_kernels);
+  int InsertForInputTensor(kernel::LiteKernel *kernel, size_t in_tensor_index, kernel::LiteKernel *pre_kernel,
+                           std::vector<kernel::LiteKernel *> *trans_kernels);
+  int InsertForOutputTensor(kernel::LiteKernel *kernel, kernel::LiteKernel *post_kernel, size_t post_in_tensor_index,
+                            std::vector<kernel::LiteKernel *> *trans_kernels);
 
  private:
   int total = 0;
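The header change adds an explicit post_input_index to InsertNode and introduces per-input-tensor and per-output-tensor entry points, matching the PR title. One likely reason for carrying the slot index explicitly rather than searching by tensor: when the same tensor feeds two inputs of one node (e.g. an Add computing x + x), a pointer search cannot tell the slots apart. A toy illustration of that ambiguity (ToyTensor/ToyNode are hypothetical types, not MindSpore APIs):

#include <cstdio>
#include <vector>

struct ToyTensor {};
struct ToyNode {
  std::vector<ToyTensor *> inputs;
};

int main() {
  ToyTensor x;
  ToyNode add{{&x, &x}};  // the same tensor feeds both inputs, as in y = x + x
  // Searching by tensor pointer always stops at slot 0, even when slot 1 is meant:
  size_t found = 0;
  for (size_t i = 0; i < add.inputs.size(); ++i) {
    if (add.inputs[i] == &x) {
      found = i;
      break;
    }
  }
  printf("pointer search matched slot %zu\n", found);  // always 0
  // An explicit index, like the new post_input_index parameter, is unambiguous:
  ToyTensor transposed;
  size_t post_input_index = 1;
  add.inputs[post_input_index] = &transposed;  // rewire exactly the intended slot
  return 0;
}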

@@ -172,32 +172,33 @@ void NPUPassUtils::UpdateNC2NHPostKernelInTensors(kernel::LiteKernel *kernel, ke
 void NPUPassUtils::UpdateNC2NHTransNodePostKernel(kernel::LiteKernel *kernel, kernel::LiteKernel *trans_kernel,
                                                   kernel::LiteKernel *post_kernel) {
-  // For post_kernel after trans, kernel should be replaced with trans_kernel.
+  // The input tensor should be replaced with the output tensor of trans_kernel.
   auto post_in_tensors = post_kernel->in_tensors();
-  if (kernel == nullptr) {
-    post_in_tensors[0] = trans_kernel->out_tensors()[0];
-  } else {
-    for (size_t i = 0; i < post_in_tensors.size(); i++) {
-      if (post_in_tensors[i] == kernel->out_tensors()[0]) {
-        post_in_tensors[i] = trans_kernel->out_tensors()[0];
-        break;
-      }
-    }
-  }
+  Tensor *old_in_tensor = nullptr;
+  // Find out which input tensor of post_kernel should be updated.
+  for (size_t i = 0; i < post_in_tensors.size(); ++i) {
+    if (KernelInputFromKernel(post_kernel, i) == kernel) {
+      old_in_tensor = post_in_tensors.at(i);
+      break;
+    }
+  }
+  if (old_in_tensor == nullptr) {
+    MS_LOG(WARNING) << "Could not find in tensor index";
+    return;
+  }
+  std::replace(post_in_tensors.begin(), post_in_tensors.end(), old_in_tensor, trans_kernel->out_tensors().at(0));
   post_kernel->set_in_tensors(post_in_tensors);
-  // The input tensor should be replaced with the output tensor of trans_kernel.
-  std::vector<kernel::LiteKernel *> post_in_kernels = post_kernel->in_kernels();
-  for (size_t i = 0; i < post_in_kernels.size(); i++) {
-    if (post_in_kernels[i] == kernel) {
-      post_in_kernels[i] = trans_kernel;
-      break;
-    }
-  }
+  // For post_kernel after trans, kernel in in_kernels should be replaced with trans_kernel.
+  auto post_in_kernels = post_kernel->in_kernels();
+  std::replace(post_in_kernels.begin(), post_in_kernels.end(), kernel, trans_kernel);
   post_kernel->set_in_kernels(post_in_kernels);
 }
 bool NPUPassUtils::IsNhwc2Nchw(const kernel::LiteKernel *kernel) {
+  if (kernel == nullptr) {
+    return false;
+  }
   if (kernel->Type() != schema::PrimitiveType_Transpose) {
     return false;
   }
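Worth noting about the rewrite above: the old loop replaced only the first matching entry and then broke, while std::replace rewrites every matching element, which is what is wanted when post_kernel consumes the same producer through more than one input. A minimal, self-contained demonstration (ints stand in for kernel pointers; illustrative only):

#include <algorithm>
#include <cstdio>
#include <vector>

int main() {
  std::vector<int> in_kernels = {7, 7, 3};  // 7 plays the old producer kernel, present twice
  std::replace(in_kernels.begin(), in_kernels.end(), 7, 9);  // 9 plays trans_kernel
  for (int k : in_kernels) printf("%d ", k);  // prints: 9 9 3 -- both occurrences rewired
  printf("\n");
  return 0;
}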
@@ -215,6 +216,9 @@ bool NPUPassUtils::IsNhwc2Nchw(const kernel::LiteKernel *kernel) {
 }
 bool NPUPassUtils::IsNchw2Nhwc(const kernel::LiteKernel *kernel) {
+  if (kernel == nullptr) {
+    return false;
+  }
   if (kernel->Type() != schema::PrimitiveType_Transpose) {
     return false;
   }
@@ -230,5 +234,22 @@ bool NPUPassUtils::IsNchw2Nhwc(const kernel::LiteKernel *kernel) {
   }
   return false;
 }
+kernel::LiteKernel *NPUPassUtils::KernelInputFromKernel(const kernel::LiteKernel *kernel, size_t in_tensor_index) {
+  // Given a kernel and an input tensor index, find the kernel that outputs this tensor.
+  // If the input tensor is a graph input, return nullptr.
+  if (kernel == nullptr) {
+    return nullptr;
+  }
+  auto tensor = kernel->in_tensors().at(in_tensor_index);
+  auto in_kernels = kernel->in_kernels();
+  auto output_contain = [tensor](const kernel::LiteKernel *kernel) {
+    auto out_tensors = kernel->out_tensors();
+    return std::find(out_tensors.begin(), out_tensors.end(), tensor) != out_tensors.end();
+  };
+  auto it = std::find_if(in_kernels.begin(), in_kernels.end(), output_contain);
+  if (it == in_kernels.end()) {
+    return nullptr;
+  }
+  return *it;
+}
 } // namespace mindspore::lite
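For readers unfamiliar with the new helper, its producer lookup reduces to: walk the node's input kernels and return the one whose outputs contain the given tensor; a graph input has no producer, so nullptr is returned. The same logic in a self-contained toy (ToyKernel/find_producer are illustrative names, not MindSpore APIs):

#include <algorithm>
#include <cassert>
#include <cstddef>
#include <vector>

struct ToyTensor {};
struct ToyKernel {
  std::vector<ToyTensor *> in_tensors;
  std::vector<ToyTensor *> out_tensors;
  std::vector<ToyKernel *> in_kernels;
};

// Return the input kernel whose outputs contain input tensor in_index,
// or nullptr when that tensor is a graph input (no producer).
ToyKernel *find_producer(const ToyKernel &kernel, size_t in_index) {
  ToyTensor *tensor = kernel.in_tensors.at(in_index);
  auto outputs_tensor = [tensor](const ToyKernel *pre) {
    const auto &outs = pre->out_tensors;
    return std::find(outs.begin(), outs.end(), tensor) != outs.end();
  };
  auto it = std::find_if(kernel.in_kernels.begin(), kernel.in_kernels.end(), outputs_tensor);
  return it == kernel.in_kernels.end() ? nullptr : *it;
}

int main() {
  ToyTensor graph_in, mid;
  ToyKernel conv{{&graph_in}, {&mid}, {}};
  ToyKernel relu{{&mid}, {}, {&conv}};
  assert(find_producer(relu, 0) == &conv);    // mid is produced by conv
  assert(find_producer(conv, 0) == nullptr);  // graph_in has no producer
  return 0;
}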

@@ -52,6 +52,7 @@ class NPUPassUtils {
   static bool IsNhwc2Nchw(const kernel::LiteKernel *kernel);
   static bool IsNchw2Nhwc(const kernel::LiteKernel *kernel);
+  static kernel::LiteKernel *KernelInputFromKernel(const kernel::LiteKernel *kernel, size_t in_tensor_index);
 
  private:
   static PrimitiveC *CreateTransposePrimitive();

@@ -26,7 +26,7 @@ crnn_lite_lstm_v2.onnx;32,32,32,1 0.3
 psenet_lite_mbv2.onnx;1,32,32,3 0.6
 super-resolution-10.onnx;1,224,224,1 4.5
 tinyyolov2-8.onnx;1,416,416,3 5.5
-ml_2012_ocr_cn.onnx -1
+#ml_2012_ocr_cn.onnx -1
 #ml_2012_ocr_cn_noLSTM.onnx 1
 candy-9.onnx 5
 mosaic-9.onnx 4
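For context on this last hunk: judging from the surrounding entries, each line of the benchmark config names a model file, an optional semicolon-separated input shape, and an accuracy tolerance (e.g. tinyyolov2-8.onnx;1,416,416,3 5.5), and a leading # comments an entry out. The change therefore disables ml_2012_ocr_cn.onnx for this model run, alongside the already-disabled ml_2012_ocr_cn_noLSTM.onnx. This reading of the format is inferred, not stated in the patch.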
