concat output fix npu pass

pull/13572/head
zhaozhenlong 4 years ago
parent 0fb8cd888d
commit e43005242c

@ -141,6 +141,23 @@ void UpdatePreTensors(kernel::LiteKernel *cur_kernel) {
void UpdatePostTensors(kernel::LiteKernel *cur_kernel) {
auto tensor = cur_kernel->out_tensors()[0];
// in case: node->nh2nc->nc2nh(graph output) --->>> node->nc2nh, node out_tensor should be put to nc2nh out tensors
auto out_kernels = cur_kernel->out_kernels();
if (out_kernels.size() == 1 && out_kernels[0]->out_kernels().size() == 1 &&
out_kernels[0]->out_kernels()[0]->out_kernels().empty() &&
out_kernels[0]->out_kernels()[0]->type_str() == "Transpose") {
auto nc_tensor = out_kernels[0]->out_tensors()[0]; // nh2nc's out tensor
cur_kernel->set_out_tensors({nc_tensor});
auto post_post_kernel = out_kernels[0]->out_kernels()[0];
// set the nc2nh kernel's in_tensors and out_tensors
auto post_post_k_in_tensors = post_post_kernel->in_tensors();
post_post_k_in_tensors[0] = nc_tensor;
post_post_kernel->set_in_tensors(post_post_k_in_tensors);
post_post_kernel->set_out_tensors({tensor});
return;
}
tensor->set_format(schema::Format_NCHW);
auto nhwc_shape = tensor->shape();
tensor->set_shape({nhwc_shape[0], nhwc_shape[3], nhwc_shape[1], nhwc_shape[2]});

@ -14,6 +14,7 @@
* limitations under the License.
*/
#include "src/runtime/agent/npu/optimizer/npu_insert_transform_pass.h"
#include <algorithm>
#include <set>
#include <string>
#include "src/runtime/agent/npu/optimizer/npu_pass_utils.h"
@ -51,7 +52,8 @@ int GetInsertState(kernel::LiteKernel *kernel) {
// current kernel is target kernel
// use out kernels to count how many out lines from current kernel
size_t in_out_tensor_num = kernel->in_tensors().size() + kernel->out_kernels().size();
size_t in_out_tensor_num =
kernel->in_tensors().size() + std::max(kernel->out_kernels().size(), static_cast<size_t>(1));
size_t transpose_input_num = 0;
size_t transpose_output_num = 0;
bool need_pre_insert = false;
@ -65,6 +67,9 @@ int GetInsertState(kernel::LiteKernel *kernel) {
need_pre_insert = true;
}
}
if (kernel->out_kernels().empty()) {
need_post_insert = true;
}
for (const auto out_kernel : kernel->out_kernels()) {
if (NPUPassUtils::IsNhwc2Nchw(out_kernel)) {
transpose_output_num++;

@ -106,6 +106,9 @@ void NPUPassUtils::UpdateNH2NCTransNodePreKernel(kernel::LiteKernel *pre_kernel,
break;
}
}
if (out_kernels.empty()) {
out_kernels.push_back(trans_kernel);
}
pre_kernel->set_out_kernels(out_kernels);
}

@ -156,8 +156,8 @@ int NPUTransformPass::InsertPostNodes(kernel::LiteKernel *kernel, std::vector<ke
nc2nh_out_tensors[0] = out_tensor;
// Create post transform kernel: Nchw2Nhwc
auto *post_trans_kernel =
NPUPassUtils::CreateNchw2NhwcKernel({nc2nh_tensor, nc2nh_perm_tensor}, nc2nh_out_tensors, context_, name);
auto *post_trans_kernel = NPUPassUtils::CreateNchw2NhwcKernel(
{nc2nh_tensor, nc2nh_perm_tensor}, nc2nh_out_tensors, context_, name + "_" + std::to_string(i));
// Set in_kernels, out_kernels, in_tensors, out_tensors for transform kernel
NPUPassUtils::UpdateKernel(post_trans_kernel, {kernel}, {post_insert_kernel}, post_trans_kernel->in_tensors(),
post_trans_kernel->out_tensors());

@ -70,3 +70,4 @@ ml_video_edit_v10_best_model_nomean_20200723 8
ml_edu_kit_hand_key_position.onnx 2
#ml_video_edit_oneclick_adaptis.pb #too many subgraphs
densenet.tflite 3
resnet_v2_101_299.tflite 1

Loading…
Cancel
Save