|
|
|
@ -14,17 +14,19 @@
|
|
|
|
|
* limitations under the License.
|
|
|
|
|
*/
|
|
|
|
|
|
|
|
|
|
#include "src/kernel_registry.h"
|
|
|
|
|
#include "src/ops/nhwc2nchw.h"
|
|
|
|
|
#include "src/ops/nchw2nhwc.h"
|
|
|
|
|
#include "src/runtime/agent/npu/optimizer/npu_pass_utils.h"
|
|
|
|
|
#include "src/ops/transpose.h"
|
|
|
|
|
#include "nnacl/transpose.h"
|
|
|
|
|
#include "src/ops/populate/populate_register.h"
|
|
|
|
|
#include "src/runtime/kernel/arm/fp32/transpose_fp32.h"
|
|
|
|
|
|
|
|
|
|
namespace mindspore::lite {
|
|
|
|
|
using kernel::KERNEL_ARCH::kCPU;
|
|
|
|
|
using kernel::KERNEL_ARCH::kNPU;
|
|
|
|
|
PrimitiveC *NPUPassUtils::CreateNchw2NhwcPrimitive() {
|
|
|
|
|
PrimitiveC *NPUPassUtils::CreateTransposePrimitive() {
|
|
|
|
|
flatbuffers::FlatBufferBuilder fbb(1024);
|
|
|
|
|
auto val_offset = schema::CreateNchw2Nhwc(fbb);
|
|
|
|
|
auto prim_offset = schema::CreatePrimitive(fbb, schema::PrimitiveType_Nchw2Nhwc, val_offset.o);
|
|
|
|
|
auto prim_offset = schema::CreatePrimitive(fbb, schema::PrimitiveType_Transpose, val_offset.o);
|
|
|
|
|
fbb.Finish(prim_offset);
|
|
|
|
|
auto buf = fbb.GetBufferPointer();
|
|
|
|
|
if (buf == nullptr) {
|
|
|
|
@ -39,56 +41,72 @@ PrimitiveC *NPUPassUtils::CreateNchw2NhwcPrimitive() {
|
|
|
|
|
return nullptr;
|
|
|
|
|
}
|
|
|
|
|
memcpy(primitive_buf, buf, fbb.GetSize());
|
|
|
|
|
auto *primitive = PrimitiveC::NewPrimitiveC<Nchw2Nhwc>(flatbuffers::GetRoot<schema::Primitive>(primitive_buf));
|
|
|
|
|
auto *primitive = PrimitiveC::NewPrimitiveC<Transpose>(flatbuffers::GetRoot<schema::Primitive>(primitive_buf));
|
|
|
|
|
free(primitive_buf);
|
|
|
|
|
fbb.Clear();
|
|
|
|
|
return primitive;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
PrimitiveC *NPUPassUtils::CreateNhwc2NchwPrimitive() {
|
|
|
|
|
flatbuffers::FlatBufferBuilder fbb(1024);
|
|
|
|
|
auto val_offset = schema::CreateNhwc2Nchw(fbb);
|
|
|
|
|
auto prim_offset = schema::CreatePrimitive(fbb, schema::PrimitiveType_Nhwc2Nchw, val_offset.o);
|
|
|
|
|
fbb.Finish(prim_offset);
|
|
|
|
|
auto buf = fbb.GetBufferPointer();
|
|
|
|
|
if (buf == nullptr) {
|
|
|
|
|
MS_LOG(ERROR) << "GetBufferPointer return nullptr";
|
|
|
|
|
fbb.Clear();
|
|
|
|
|
kernel::LiteKernel *NPUPassUtils::CreateNchw2NhwcKernel(const std::vector<Tensor *> &in_tensors,
|
|
|
|
|
const std::vector<Tensor *> &out_tensors,
|
|
|
|
|
const InnerContext *ctx, const std::string &name) {
|
|
|
|
|
kernel::KernelKey key{kCPU, kNumberTypeFloat32, schema::PrimitiveType_Transpose};
|
|
|
|
|
auto nchw2nhwc_primitive = CreateTransposePrimitive();
|
|
|
|
|
auto *transpose_param = reinterpret_cast<TransposeParameter *>(malloc(sizeof(TransposeParameter)));
|
|
|
|
|
if (transpose_param == nullptr) {
|
|
|
|
|
MS_LOG(ERROR) << "malloc TransposeParameter failed.";
|
|
|
|
|
return nullptr;
|
|
|
|
|
}
|
|
|
|
|
auto primitive_buf = reinterpret_cast<char *>(malloc(fbb.GetSize()));
|
|
|
|
|
if (primitive_buf == nullptr) {
|
|
|
|
|
MS_LOG(ERROR) << "Malloc primitive buffer failed.";
|
|
|
|
|
fbb.Clear();
|
|
|
|
|
memset(transpose_param, 0, sizeof(TransposeParameter));
|
|
|
|
|
transpose_param->op_parameter_.type_ = nchw2nhwc_primitive->Type();
|
|
|
|
|
transpose_param->perm_[0] = 0;
|
|
|
|
|
transpose_param->perm_[1] = 2;
|
|
|
|
|
transpose_param->perm_[2] = 3;
|
|
|
|
|
transpose_param->perm_[3] = 1;
|
|
|
|
|
transpose_param->num_axes_ = 4;
|
|
|
|
|
|
|
|
|
|
auto kernel = new (std::nothrow) kernel::TransposeCPUKernel(reinterpret_cast<OpParameter *>(transpose_param),
|
|
|
|
|
in_tensors, out_tensors, ctx, nchw2nhwc_primitive);
|
|
|
|
|
if (kernel != nullptr) {
|
|
|
|
|
kernel->set_desc(key);
|
|
|
|
|
} else {
|
|
|
|
|
MS_LOG(ERROR) << "New Nchw2Nhwc Kernel failed.";
|
|
|
|
|
return nullptr;
|
|
|
|
|
}
|
|
|
|
|
memcpy(primitive_buf, buf, fbb.GetSize());
|
|
|
|
|
auto *primitive = PrimitiveC::NewPrimitiveC<Nhwc2Nchw>(flatbuffers::GetRoot<schema::Primitive>(primitive_buf));
|
|
|
|
|
free(primitive_buf);
|
|
|
|
|
fbb.Clear();
|
|
|
|
|
return primitive;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
kernel::LiteKernel *NPUPassUtils::CreateNchw2NhwcKernel(const std::vector<Tensor *> &in_tensors,
|
|
|
|
|
const std::vector<Tensor *> &out_tensors,
|
|
|
|
|
const InnerContext *ctx, const std::string &name) {
|
|
|
|
|
kernel::KernelKey key{kCPU, kNumberTypeFloat32, schema::PrimitiveType_Nchw2Nhwc};
|
|
|
|
|
auto nchw2nhwc_primitive = CreateNchw2NhwcPrimitive();
|
|
|
|
|
auto *nchw2nhwc_kernel =
|
|
|
|
|
KernelRegistry::GetInstance()->GetKernel(in_tensors, out_tensors, nchw2nhwc_primitive, ctx, key);
|
|
|
|
|
nchw2nhwc_kernel->set_name(name);
|
|
|
|
|
return nchw2nhwc_kernel;
|
|
|
|
|
kernel->set_name(name);
|
|
|
|
|
return kernel;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
kernel::LiteKernel *NPUPassUtils::CreateNhwc2NchwKernel(const std::vector<Tensor *> &in_tensors,
|
|
|
|
|
const std::vector<Tensor *> &out_tensors,
|
|
|
|
|
const InnerContext *ctx, const std::string &name) {
|
|
|
|
|
kernel::KernelKey key{kCPU, kNumberTypeFloat32, schema::PrimitiveType_Nhwc2Nchw};
|
|
|
|
|
auto nhwc2nchw_primitive = CreateNhwc2NchwPrimitive();
|
|
|
|
|
auto *nhwc2nchw_kernel =
|
|
|
|
|
KernelRegistry::GetInstance()->GetKernel(in_tensors, out_tensors, nhwc2nchw_primitive, ctx, key);
|
|
|
|
|
nhwc2nchw_kernel->set_name(name);
|
|
|
|
|
return nhwc2nchw_kernel;
|
|
|
|
|
kernel::KernelKey key{kCPU, kNumberTypeFloat32, schema::PrimitiveType_Transpose};
|
|
|
|
|
auto nhwc2nchw_primitive = CreateTransposePrimitive();
|
|
|
|
|
auto *transpose_param = reinterpret_cast<TransposeParameter *>(malloc(sizeof(TransposeParameter)));
|
|
|
|
|
if (transpose_param == nullptr) {
|
|
|
|
|
MS_LOG(ERROR) << "malloc TransposeParameter failed.";
|
|
|
|
|
return nullptr;
|
|
|
|
|
}
|
|
|
|
|
memset(transpose_param, 0, sizeof(TransposeParameter));
|
|
|
|
|
transpose_param->op_parameter_.type_ = nhwc2nchw_primitive->Type();
|
|
|
|
|
transpose_param->perm_[0] = 0;
|
|
|
|
|
transpose_param->perm_[1] = 3;
|
|
|
|
|
transpose_param->perm_[2] = 1;
|
|
|
|
|
transpose_param->perm_[3] = 2;
|
|
|
|
|
transpose_param->num_axes_ = 4;
|
|
|
|
|
|
|
|
|
|
auto kernel = new (std::nothrow) kernel::TransposeCPUKernel(reinterpret_cast<OpParameter *>(transpose_param),
|
|
|
|
|
in_tensors, out_tensors, ctx, nhwc2nchw_primitive);
|
|
|
|
|
if (kernel != nullptr) {
|
|
|
|
|
kernel->set_desc(key);
|
|
|
|
|
} else {
|
|
|
|
|
MS_LOG(ERROR) << "New Nhwc2Nchw Kernel failed.";
|
|
|
|
|
return nullptr;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
kernel->set_name(name);
|
|
|
|
|
return kernel;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
void NPUPassUtils::UpdateKernel(kernel::LiteKernel *kernel, const std::vector<kernel::LiteKernel *> &in_kernels,
|
|
|
|
@ -173,4 +191,39 @@ void NPUPassUtils::UpdateNC2NHTransNodeAfterKernel(kernel::LiteKernel *kernel, k
|
|
|
|
|
post_kernel->set_in_kernels(post_in_kernels);
|
|
|
|
|
post_kernel->set_in_tensors({post_in_tensors});
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
bool NPUPassUtils::IsNhwc2Nchw(kernel::LiteKernel *kernel) {
|
|
|
|
|
if (kernel->Type() != schema::PrimitiveType_Transpose) {
|
|
|
|
|
return false;
|
|
|
|
|
}
|
|
|
|
|
auto parameter = reinterpret_cast<TransposeParameter *>(kernel->op_parameter());
|
|
|
|
|
if (parameter->num_axes_ != 4) {
|
|
|
|
|
return false;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
std::vector<int> perm = {parameter->perm_[0], parameter->perm_[1], parameter->perm_[2], parameter->perm_[3]};
|
|
|
|
|
std::vector<int> nh2nc_perm = {0, 3, 1, 2};
|
|
|
|
|
if (nh2nc_perm == perm) {
|
|
|
|
|
return true;
|
|
|
|
|
}
|
|
|
|
|
return false;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
bool NPUPassUtils::IsNchw2Nhwc(kernel::LiteKernel *kernel) {
|
|
|
|
|
if (kernel->Type() != schema::PrimitiveType_Transpose) {
|
|
|
|
|
return false;
|
|
|
|
|
}
|
|
|
|
|
auto parameter = reinterpret_cast<TransposeParameter *>(kernel->op_parameter());
|
|
|
|
|
if (parameter->num_axes_ != 4) {
|
|
|
|
|
return false;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
std::vector<int> perm = {parameter->perm_[0], parameter->perm_[1], parameter->perm_[2], parameter->perm_[3]};
|
|
|
|
|
std::vector<int> nh2nc_perm = {0, 2, 3, 1};
|
|
|
|
|
if (nh2nc_perm == perm) {
|
|
|
|
|
return true;
|
|
|
|
|
}
|
|
|
|
|
return false;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
} // namespace mindspore::lite
|
|
|
|
|