|
|
|
@ -21,7 +21,9 @@ namespace mindspore::lite {
|
|
|
|
|
using kernel::KERNEL_ARCH::kNPU;
|
|
|
|
|
// Result codes used by GetInsertState(), which hands them back as a plain int.
// The numeric values below are the implicit defaults, written out explicitly.
enum InsertState {
  InsertNone = 0,  // NOTE(review): presumably "no transform nodes to insert" - confirm against caller
  PreInsert = 1,   // NOTE(review): presumably selects InsertPreNode - confirm against caller
  PostInsert = 2   // NOTE(review): presumably selects InsertPostNode - confirm against caller
};
|
|
|
|
|
|
|
|
|
|
std::set<mindspore::schema::PrimitiveType> npu_insert_nodes = {schema::PrimitiveType_Concat, schema::PrimitiveType_Add};
|
|
|
|
|
// Primitive types that GetInsertState() checks a kernel's Type() against.
std::set<mindspore::schema::PrimitiveType> npu_insert_nodes = {
  schema::PrimitiveType_Concat,
  schema::PrimitiveType_Add,
  schema::PrimitiveType_Eltwise,
  schema::PrimitiveType_Activation};
|
|
|
|
|
|
|
|
|
|
int GetInsertState(kernel::LiteKernel *kernel) {
|
|
|
|
|
if (npu_insert_nodes.find(kernel->Type()) == npu_insert_nodes.end()) {
|
|
|
|
@ -42,16 +44,18 @@ int GetInsertState(kernel::LiteKernel *kernel) {
|
|
|
|
|
return InsertNone;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
int NPUInsertTransformPass::InsertPreNode(const InnerContext *context, kernel::LiteKernel *cur_kernel,
|
|
|
|
|
std::vector<kernel::LiteKernel *> *all_kernels,
|
|
|
|
|
int NPUInsertTransformPass::InsertPreNode(const InnerContext *context, kernel::LiteKernel *kernel,
|
|
|
|
|
std::vector<kernel::LiteKernel *> *trans_kernels,
|
|
|
|
|
std::vector<Tensor *> *all_tensors) {
|
|
|
|
|
for (auto kernel : cur_kernel->in_kernels()) {
|
|
|
|
|
if (kernel->Type() == schema::PrimitiveType_Nchw2Nhwc) {
|
|
|
|
|
for (auto in_kernel : kernel->in_kernels()) {
|
|
|
|
|
if (in_kernel->Type() == schema::PrimitiveType_Nchw2Nhwc) {
|
|
|
|
|
continue;
|
|
|
|
|
}
|
|
|
|
|
auto nhwc_shape = cur_kernel->out_tensors()[0]->shape();
|
|
|
|
|
auto nhwc_shape = in_kernel->out_tensors()[0]->shape();
|
|
|
|
|
std::vector<int> nchw_shape = {nhwc_shape[0], nhwc_shape[3], nhwc_shape[1], nhwc_shape[2]};
|
|
|
|
|
auto nh2nc_tensor = new Tensor(kernel->out_tensors()[0]->data_type(), nchw_shape, schema::Format_NHWC, Tensor::VAR);
|
|
|
|
|
|
|
|
|
|
auto nh2nc_tensor =
|
|
|
|
|
new Tensor(in_kernel->out_tensors()[0]->data_type(), nchw_shape, schema::Format_NHWC, Tensor::VAR);
|
|
|
|
|
std::vector<Tensor *> nh2nc_tensors = {nh2nc_tensor};
|
|
|
|
|
all_tensors->push_back(nh2nc_tensors[0]);
|
|
|
|
|
|
|
|
|
@ -59,34 +63,36 @@ int NPUInsertTransformPass::InsertPreNode(const InnerContext *context, kernel::L
|
|
|
|
|
std::vector<Tensor *> nc2nh_tensors = {nc2nh_tensor};
|
|
|
|
|
all_tensors->push_back(nc2nh_tensors[0]);
|
|
|
|
|
|
|
|
|
|
auto nh2nc_name = kernel->name() + "_nh2nc_" + std::to_string(total++);
|
|
|
|
|
auto *nh2nc_kernel = NPUPassUtils::CreateNhwc2NchwKernel(kernel->out_tensors(), nh2nc_tensors, context, nh2nc_name);
|
|
|
|
|
all_kernels->push_back(nh2nc_kernel);
|
|
|
|
|
auto nh2nc_name = in_kernel->name() + "_nh2nc_" + std::to_string(total++);
|
|
|
|
|
auto *nh2nc_kernel =
|
|
|
|
|
NPUPassUtils::CreateNhwc2NchwKernel(in_kernel->out_tensors(), nh2nc_tensors, context, nh2nc_name);
|
|
|
|
|
trans_kernels->push_back(nh2nc_kernel);
|
|
|
|
|
insert_primitive_.push_back(nh2nc_kernel->GetPrimitive());
|
|
|
|
|
auto nc2nh_name = kernel->name() + "_nc2nh_" + std::to_string(total++);
|
|
|
|
|
|
|
|
|
|
auto nc2nh_name = in_kernel->name() + "_nc2nh_" + std::to_string(total++);
|
|
|
|
|
auto *nc2nh_kernel = NPUPassUtils::CreateNchw2NhwcKernel(nh2nc_tensors, nc2nh_tensors, context, nc2nh_name);
|
|
|
|
|
all_kernels->push_back(nc2nh_kernel);
|
|
|
|
|
trans_kernels->push_back(nc2nh_kernel);
|
|
|
|
|
insert_primitive_.push_back(nc2nh_kernel->GetPrimitive());
|
|
|
|
|
NPUPassUtils::UpdateKernel(nh2nc_kernel, {kernel}, {nc2nh_kernel}, kernel->out_tensors(), nh2nc_tensors);
|
|
|
|
|
NPUPassUtils::UpdateKernel(nc2nh_kernel, {nh2nc_kernel}, {cur_kernel}, nh2nc_tensors, nc2nh_tensors);
|
|
|
|
|
NPUPassUtils::UpdateNH2NCTransNodePreKernel(kernel, nh2nc_kernel, cur_kernel);
|
|
|
|
|
NPUPassUtils::UpdateNC2NHTransNodeAfterKernel(kernel, nc2nh_kernel, cur_kernel);
|
|
|
|
|
|
|
|
|
|
NPUPassUtils::UpdateKernel(nh2nc_kernel, {in_kernel}, {nc2nh_kernel}, in_kernel->out_tensors(), nh2nc_tensors);
|
|
|
|
|
NPUPassUtils::UpdateKernel(nc2nh_kernel, {nh2nc_kernel}, {kernel}, nh2nc_tensors, nc2nh_tensors);
|
|
|
|
|
NPUPassUtils::UpdateNH2NCTransNodePreKernel(in_kernel, nh2nc_kernel, kernel);
|
|
|
|
|
NPUPassUtils::UpdateNC2NHTransNodeAfterKernel(in_kernel, nc2nh_kernel, kernel);
|
|
|
|
|
}
|
|
|
|
|
return RET_OK;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
int NPUInsertTransformPass::InsertPostNode(const InnerContext *context, kernel::LiteKernel *cur_kernel,
|
|
|
|
|
std::vector<kernel::LiteKernel *> *all_kernels,
|
|
|
|
|
int NPUInsertTransformPass::InsertPostNode(const InnerContext *context, kernel::LiteKernel *kernel,
|
|
|
|
|
std::vector<kernel::LiteKernel *> *trans_kernels,
|
|
|
|
|
std::vector<Tensor *> *all_tensors) {
|
|
|
|
|
for (auto out_kernel : cur_kernel->out_kernels()) {
|
|
|
|
|
for (auto out_kernel : kernel->out_kernels()) {
|
|
|
|
|
if (out_kernel->Type() == schema::PrimitiveType_Nhwc2Nchw) {
|
|
|
|
|
continue;
|
|
|
|
|
}
|
|
|
|
|
auto nhwc_shape = cur_kernel->out_tensors()[0]->shape();
|
|
|
|
|
auto nhwc_shape = kernel->out_tensors()[0]->shape();
|
|
|
|
|
std::vector<int> nchw_shape = {nhwc_shape[0], nhwc_shape[3], nhwc_shape[1], nhwc_shape[2]};
|
|
|
|
|
|
|
|
|
|
auto nh2nc_tensor =
|
|
|
|
|
new Tensor(cur_kernel->out_tensors()[0]->data_type(), nchw_shape, schema::Format_NHWC, Tensor::VAR);
|
|
|
|
|
auto nh2nc_tensor = new Tensor(kernel->out_tensors()[0]->data_type(), nchw_shape, schema::Format_NHWC, Tensor::VAR);
|
|
|
|
|
std::vector<Tensor *> nh2nc_tensors = {nh2nc_tensor};
|
|
|
|
|
all_tensors->push_back(nh2nc_tensors[0]);
|
|
|
|
|
|
|
|
|
@ -94,19 +100,20 @@ int NPUInsertTransformPass::InsertPostNode(const InnerContext *context, kernel::
|
|
|
|
|
std::vector<Tensor *> nc2nh_tensors = {nc2nh_tensor};
|
|
|
|
|
all_tensors->push_back(nc2nh_tensors[0]);
|
|
|
|
|
|
|
|
|
|
auto nh2nc_name = cur_kernel->name() + "_nh2nc_" + std::to_string(total++);
|
|
|
|
|
auto *nh2nc_kernel =
|
|
|
|
|
NPUPassUtils::CreateNhwc2NchwKernel(cur_kernel->out_tensors(), nh2nc_tensors, context, nh2nc_name);
|
|
|
|
|
all_kernels->push_back(nh2nc_kernel);
|
|
|
|
|
auto nh2nc_name = kernel->name() + "_nh2nc_" + std::to_string(total++);
|
|
|
|
|
auto *nh2nc_kernel = NPUPassUtils::CreateNhwc2NchwKernel(kernel->out_tensors(), nh2nc_tensors, context, nh2nc_name);
|
|
|
|
|
trans_kernels->push_back(nh2nc_kernel);
|
|
|
|
|
insert_primitive_.push_back(nh2nc_kernel->GetPrimitive());
|
|
|
|
|
auto nc2nh_name = cur_kernel->name() + "_nc2nh_" + std::to_string(total++);
|
|
|
|
|
|
|
|
|
|
auto nc2nh_name = kernel->name() + "_nc2nh_" + std::to_string(total++);
|
|
|
|
|
auto *nc2nh_kernel = NPUPassUtils::CreateNchw2NhwcKernel(nh2nc_tensors, nc2nh_tensors, context, nc2nh_name);
|
|
|
|
|
all_kernels->push_back(nc2nh_kernel);
|
|
|
|
|
trans_kernels->push_back(nc2nh_kernel);
|
|
|
|
|
insert_primitive_.push_back(nc2nh_kernel->GetPrimitive());
|
|
|
|
|
NPUPassUtils::UpdateKernel(nh2nc_kernel, {cur_kernel}, {nc2nh_kernel}, cur_kernel->out_tensors(), nh2nc_tensors);
|
|
|
|
|
|
|
|
|
|
NPUPassUtils::UpdateKernel(nh2nc_kernel, {kernel}, {nc2nh_kernel}, kernel->out_tensors(), nh2nc_tensors);
|
|
|
|
|
NPUPassUtils::UpdateKernel(nc2nh_kernel, {nh2nc_kernel}, {out_kernel}, nh2nc_tensors, nc2nh_tensors);
|
|
|
|
|
NPUPassUtils::UpdateNH2NCTransNodePreKernel(cur_kernel, nh2nc_kernel, out_kernel);
|
|
|
|
|
NPUPassUtils::UpdateNC2NHTransNodeAfterKernel(cur_kernel, nc2nh_kernel, out_kernel);
|
|
|
|
|
NPUPassUtils::UpdateNH2NCTransNodePreKernel(kernel, nh2nc_kernel, out_kernel);
|
|
|
|
|
NPUPassUtils::UpdateNC2NHTransNodeAfterKernel(kernel, nc2nh_kernel, out_kernel);
|
|
|
|
|
}
|
|
|
|
|
return RET_OK;
|
|
|
|
|
}
|
|
|
|
|