|
|
|
@ -23,8 +23,10 @@ namespace mindspore::lite {
|
|
|
|
|
using kernel::KERNEL_ARCH::kNPU;
|
|
|
|
|
// Result of GetInsertState: where transpose kernels should be inserted relative to a kernel --
// nowhere (InsertNone), before it (PreInsert), after it (PostInsert), or on both sides (BothInsert).
enum InsertState { InsertNone, PreInsert, PostInsert, BothInsert };
|
|
|
|
|
std::set<mindspore::schema::PrimitiveType> npu_insert_nodes = {
|
|
|
|
|
schema::PrimitiveType_Concat, schema::PrimitiveType_AddFusion, schema::PrimitiveType_Eltwise,
|
|
|
|
|
schema::PrimitiveType_Activation};
|
|
|
|
|
schema::PrimitiveType_Concat, schema::PrimitiveType_AddFusion, schema::PrimitiveType_Eltwise,
|
|
|
|
|
schema::PrimitiveType_Activation, schema::PrimitiveType_Split, schema::PrimitiveType_PadFusion,
|
|
|
|
|
schema::PrimitiveType_StridedSlice, schema::PrimitiveType_Activation};
|
|
|
|
|
|
|
|
|
|
// The goal of this pass is to minimize the number of subgraphs generated
|
|
|
|
|
// by inserting nchw2nhwc or nhwc2nchw before or after the operator (e.g. concat, add, etc.) together with
|
|
|
|
|
// fusion pass. If transpose inserted are more than half of input output, we will insert remaining input
|
|
|
|
@ -44,7 +46,7 @@ std::set<mindspore::schema::PrimitiveType> npu_insert_nodes = {
|
|
|
|
|
// so we won't insert nc2nh or nh2nc when the op's in kernels and out kernels contain no nc2nh or nh2nc.
|
|
|
|
|
// This pass should be run after npu_transform_pass, which inserts transposes for nchw-input-limited ops like conv2d.
|
|
|
|
|
|
|
|
|
|
int GetInsertState(kernel::LiteKernel *kernel) {
|
|
|
|
|
int NPUInsertTransformPass::GetInsertState(kernel::LiteKernel *kernel) {
|
|
|
|
|
// filter out irrelevant kernel
|
|
|
|
|
if (npu_insert_nodes.find(kernel->Type()) == npu_insert_nodes.end()) {
|
|
|
|
|
return InsertNone;
|
|
|
|
@ -52,15 +54,17 @@ int GetInsertState(kernel::LiteKernel *kernel) {
|
|
|
|
|
|
|
|
|
|
// current kernel is target kernel
|
|
|
|
|
// use out kernels to count how many out lines from current kernel
|
|
|
|
|
std::vector<Tensor *> in_tensors = NPUPassUtils::GetNonConstInputs(kernel);
|
|
|
|
|
size_t in_out_tensor_num =
|
|
|
|
|
kernel->in_tensors().size() + std::max(kernel->out_kernels().size(), static_cast<size_t>(1));
|
|
|
|
|
in_tensors.size() +
|
|
|
|
|
std::max(std::max(kernel->out_kernels().size(), static_cast<size_t>(1)), kernel->out_tensors().size());
|
|
|
|
|
size_t transpose_input_num = 0;
|
|
|
|
|
size_t transpose_output_num = 0;
|
|
|
|
|
bool need_pre_insert = false;
|
|
|
|
|
bool need_post_insert = false;
|
|
|
|
|
// count number of input tensor from nc2nh and output tensor to nh2nc
|
|
|
|
|
for (size_t i = 0; i < kernel->in_tensors().size(); ++i) {
|
|
|
|
|
auto in_kernel = NPUPassUtils::KernelInputFromKernel(kernel, i);
|
|
|
|
|
for (size_t i = 0; i < in_tensors.size(); ++i) {
|
|
|
|
|
auto in_kernel = NPUPassUtils::KernelInputFromKernel(kernel, in_tensors.at(i));
|
|
|
|
|
if (NPUPassUtils::IsNchw2Nhwc(in_kernel)) {
|
|
|
|
|
transpose_input_num++;
|
|
|
|
|
} else {
|
|
|
|
@ -81,21 +85,22 @@ int GetInsertState(kernel::LiteKernel *kernel) {
|
|
|
|
|
// won't insert any thing if num of transpose tensor is smaller than half of total input output.
|
|
|
|
|
// won't insert if total input output are all transpose tensor, the fusion pass will handle this.
|
|
|
|
|
size_t transpose_tensor_num = transpose_input_num + transpose_output_num;
|
|
|
|
|
if (transpose_tensor_num <= in_out_tensor_num / 2 || transpose_tensor_num == in_out_tensor_num) {
|
|
|
|
|
if (transpose_tensor_num == 0 || transpose_tensor_num * 2 < in_out_tensor_num ||
|
|
|
|
|
transpose_tensor_num == in_out_tensor_num) {
|
|
|
|
|
return InsertNone;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
InsertState ret;
|
|
|
|
|
if (need_pre_insert && !need_post_insert) {
|
|
|
|
|
return PreInsert;
|
|
|
|
|
}
|
|
|
|
|
if (need_pre_insert && need_post_insert) {
|
|
|
|
|
return BothInsert;
|
|
|
|
|
}
|
|
|
|
|
if (!need_pre_insert && need_post_insert) {
|
|
|
|
|
return PostInsert;
|
|
|
|
|
ret = PreInsert;
|
|
|
|
|
} else if (need_pre_insert && need_post_insert) {
|
|
|
|
|
ret = BothInsert;
|
|
|
|
|
} else if (!need_pre_insert && need_post_insert) {
|
|
|
|
|
ret = PostInsert;
|
|
|
|
|
} else {
|
|
|
|
|
ret = InsertNone;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
return InsertNone;
|
|
|
|
|
return ret;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
int NPUInsertTransformPass::InsertNode(kernel::LiteKernel *kernel, kernel::LiteKernel *post_kernel,
|
|
|
|
@ -200,13 +205,20 @@ int NPUInsertTransformPass::InsertForOutputTensor(kernel::LiteKernel *kernel, ke
|
|
|
|
|
int NPUInsertTransformPass::InsertPreNodes(kernel::LiteKernel *kernel,
|
|
|
|
|
std::vector<kernel::LiteKernel *> *trans_kernels) {
|
|
|
|
|
int ret = RET_OK;
|
|
|
|
|
for (size_t i = 0; i < kernel->in_tensors().size(); ++i) {
|
|
|
|
|
auto pre_kernel = NPUPassUtils::KernelInputFromKernel(kernel, i);
|
|
|
|
|
auto in_tensors = NPUPassUtils::GetNonConstInputs(kernel);
|
|
|
|
|
for (auto tensor : in_tensors) {
|
|
|
|
|
auto pre_kernel = NPUPassUtils::KernelInputFromKernel(kernel, tensor);
|
|
|
|
|
if (NPUPassUtils::IsNchw2Nhwc(pre_kernel)) {
|
|
|
|
|
continue;
|
|
|
|
|
}
|
|
|
|
|
// if this tensor is input of graph, pre_kernel is nullptr.
|
|
|
|
|
ret = InsertForInputTensor(kernel, i, pre_kernel, trans_kernels);
|
|
|
|
|
auto it = find(kernel->in_tensors().begin(), kernel->in_tensors().end(), tensor);
|
|
|
|
|
if (it == kernel->in_tensors().end()) {
|
|
|
|
|
MS_LOG(ERROR) << "Find in tensor index error";
|
|
|
|
|
return RET_ERROR;
|
|
|
|
|
}
|
|
|
|
|
size_t index = it - kernel->in_tensors().begin();
|
|
|
|
|
ret = InsertForInputTensor(kernel, index, pre_kernel, trans_kernels);
|
|
|
|
|
if (ret != RET_OK) {
|
|
|
|
|
MS_LOG(ERROR) << "Insert nhwc2nchw kernel and nchw2nhwc kernel before kernel " << kernel->name() << " failed.";
|
|
|
|
|
return ret;
|
|
|
|
@ -249,59 +261,63 @@ int NPUInsertTransformPass::InsertPostNodes(kernel::LiteKernel *kernel,
|
|
|
|
|
|
|
|
|
|
int NPUInsertTransformPass::Run() {
|
|
|
|
|
std::vector<kernel::LiteKernel *> insert_kernels;
|
|
|
|
|
for (size_t i = 0; i < all_kernels_->size(); i++) {
|
|
|
|
|
auto kernel = (*all_kernels_)[i];
|
|
|
|
|
if (kernel->desc().arch != kNPU) {
|
|
|
|
|
continue;
|
|
|
|
|
}
|
|
|
|
|
auto insert_state = GetInsertState(kernel);
|
|
|
|
|
insert_kernels.clear();
|
|
|
|
|
// If the every output kernel is nhwc2nchw, insert
|
|
|
|
|
// modify loop index add post_kernels.size() to the next kernel in the origin vector
|
|
|
|
|
switch (insert_state) {
|
|
|
|
|
case PreInsert: {
|
|
|
|
|
auto ret = InsertPreNodes(kernel, &insert_kernels);
|
|
|
|
|
if (ret != RET_OK) {
|
|
|
|
|
MS_LOG(ERROR) << "Insert nhwc2nchw kernel and nchw2nhwc kernel before kernel " << kernel->name()
|
|
|
|
|
<< " failed.";
|
|
|
|
|
return RET_ERROR;
|
|
|
|
|
}
|
|
|
|
|
all_kernels_->insert(all_kernels_->begin() + i, insert_kernels.begin(), insert_kernels.end());
|
|
|
|
|
i += insert_kernels.size();
|
|
|
|
|
break;
|
|
|
|
|
for (int j = 0; j < 2; ++j) {
|
|
|
|
|
for (size_t i = 0; i < all_kernels_->size(); i++) {
|
|
|
|
|
auto kernel = (*all_kernels_)[i];
|
|
|
|
|
if (kernel->desc().arch != kNPU) {
|
|
|
|
|
continue;
|
|
|
|
|
}
|
|
|
|
|
case PostInsert: {
|
|
|
|
|
auto ret = InsertPostNodes(kernel, &insert_kernels);
|
|
|
|
|
if (ret != RET_OK) {
|
|
|
|
|
MS_LOG(ERROR) << "Insert nhwc2nchw kernel and nchw2nhwc kernel after kernel " << kernel->name() << " failed.";
|
|
|
|
|
return RET_ERROR;
|
|
|
|
|
auto insert_state = GetInsertState(kernel);
|
|
|
|
|
insert_kernels.clear();
|
|
|
|
|
// If the every output kernel is nhwc2nchw, insert
|
|
|
|
|
// modify loop index add post_kernels.size() to the next kernel in the origin vector
|
|
|
|
|
switch (insert_state) {
|
|
|
|
|
case PreInsert: {
|
|
|
|
|
auto ret = InsertPreNodes(kernel, &insert_kernels);
|
|
|
|
|
if (ret != RET_OK) {
|
|
|
|
|
MS_LOG(ERROR) << "Insert nhwc2nchw kernel and nchw2nhwc kernel before kernel " << kernel->name()
|
|
|
|
|
<< " failed.";
|
|
|
|
|
return RET_ERROR;
|
|
|
|
|
}
|
|
|
|
|
all_kernels_->insert(all_kernels_->begin() + i, insert_kernels.begin(), insert_kernels.end());
|
|
|
|
|
i += insert_kernels.size();
|
|
|
|
|
break;
|
|
|
|
|
}
|
|
|
|
|
all_kernels_->insert(all_kernels_->begin() + i + 1, insert_kernels.begin(), insert_kernels.end());
|
|
|
|
|
i += insert_kernels.size();
|
|
|
|
|
break;
|
|
|
|
|
}
|
|
|
|
|
case BothInsert: {
|
|
|
|
|
auto ret = InsertPreNodes(kernel, &insert_kernels);
|
|
|
|
|
if (ret != RET_OK) {
|
|
|
|
|
MS_LOG(ERROR) << "Insert nhwc2nchw kernel and nchw2nhwc kernel before kernel " << kernel->name()
|
|
|
|
|
<< " failed.";
|
|
|
|
|
return RET_ERROR;
|
|
|
|
|
case PostInsert: {
|
|
|
|
|
auto ret = InsertPostNodes(kernel, &insert_kernels);
|
|
|
|
|
if (ret != RET_OK) {
|
|
|
|
|
MS_LOG(ERROR) << "Insert nhwc2nchw kernel and nchw2nhwc kernel after kernel " << kernel->name()
|
|
|
|
|
<< " failed.";
|
|
|
|
|
return RET_ERROR;
|
|
|
|
|
}
|
|
|
|
|
all_kernels_->insert(all_kernels_->begin() + i + 1, insert_kernels.begin(), insert_kernels.end());
|
|
|
|
|
i += insert_kernels.size();
|
|
|
|
|
break;
|
|
|
|
|
}
|
|
|
|
|
all_kernels_->insert(all_kernels_->begin() + i, insert_kernels.begin(), insert_kernels.end());
|
|
|
|
|
i += insert_kernels.size();
|
|
|
|
|
case BothInsert: {
|
|
|
|
|
auto ret = InsertPreNodes(kernel, &insert_kernels);
|
|
|
|
|
if (ret != RET_OK) {
|
|
|
|
|
MS_LOG(ERROR) << "Insert nhwc2nchw kernel and nchw2nhwc kernel before kernel " << kernel->name()
|
|
|
|
|
<< " failed.";
|
|
|
|
|
return RET_ERROR;
|
|
|
|
|
}
|
|
|
|
|
all_kernels_->insert(all_kernels_->begin() + i, insert_kernels.begin(), insert_kernels.end());
|
|
|
|
|
i += insert_kernels.size();
|
|
|
|
|
|
|
|
|
|
insert_kernels.clear();
|
|
|
|
|
ret = InsertPostNodes(kernel, &insert_kernels);
|
|
|
|
|
if (ret != RET_OK) {
|
|
|
|
|
MS_LOG(ERROR) << "Insert nhwc2nchw kernel and nchw2nhwc kernel after kernel " << kernel->name() << " failed.";
|
|
|
|
|
return RET_ERROR;
|
|
|
|
|
insert_kernels.clear();
|
|
|
|
|
ret = InsertPostNodes(kernel, &insert_kernels);
|
|
|
|
|
if (ret != RET_OK) {
|
|
|
|
|
MS_LOG(ERROR) << "Insert nhwc2nchw kernel and nchw2nhwc kernel after kernel " << kernel->name()
|
|
|
|
|
<< " failed.";
|
|
|
|
|
return RET_ERROR;
|
|
|
|
|
}
|
|
|
|
|
all_kernels_->insert(all_kernels_->begin() + i + 1, insert_kernels.begin(), insert_kernels.end());
|
|
|
|
|
i += insert_kernels.size();
|
|
|
|
|
break;
|
|
|
|
|
}
|
|
|
|
|
all_kernels_->insert(all_kernels_->begin() + i + 1, insert_kernels.begin(), insert_kernels.end());
|
|
|
|
|
i += insert_kernels.size();
|
|
|
|
|
break;
|
|
|
|
|
default:
|
|
|
|
|
MS_LOG(DEBUG) << "Insert Nothing on kernel " << kernel->name();
|
|
|
|
|
}
|
|
|
|
|
default:
|
|
|
|
|
MS_LOG(DEBUG) << "Insert Nothing on kernel " << kernel->name();
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
return RET_OK;
|
|
|
|
|