|
|
|
@ -23,15 +23,14 @@ namespace mindspore::lite {
|
|
|
|
|
bool CheckFusion(kernel::LiteKernel *kernel) {
|
|
|
|
|
auto pre_flag =
|
|
|
|
|
std::all_of(kernel->in_kernels().begin(), kernel->in_kernels().end(), [](const kernel::LiteKernel *in_kernel) {
|
|
|
|
|
return NPUPassUtils::IsNchw2Nhwc(const_cast<kernel::LiteKernel *>(in_kernel)) &&
|
|
|
|
|
in_kernel->out_kernels().size() == 1;
|
|
|
|
|
return NPUPassUtils::IsNchw2Nhwc(in_kernel) && in_kernel->out_kernels().size() == 1;
|
|
|
|
|
});
|
|
|
|
|
if (!pre_flag) {
|
|
|
|
|
return false;
|
|
|
|
|
}
|
|
|
|
|
auto post_flag =
|
|
|
|
|
std::all_of(kernel->out_kernels().begin(), kernel->out_kernels().end(), [](const kernel::LiteKernel *out_kernel) {
|
|
|
|
|
return NPUPassUtils::IsNhwc2Nchw(const_cast<kernel::LiteKernel *>(out_kernel));
|
|
|
|
|
return NPUPassUtils::IsNhwc2Nchw(out_kernel) && (!out_kernel->out_kernels().empty());
|
|
|
|
|
});
|
|
|
|
|
return post_flag;
|
|
|
|
|
}
|
|
|
|
@ -40,16 +39,16 @@ bool CheckFormatFusion(kernel::LiteKernel *kernel) {
|
|
|
|
|
if (kernel->out_kernels().empty()) {
|
|
|
|
|
return false;
|
|
|
|
|
}
|
|
|
|
|
if (NPUPassUtils::IsNhwc2Nchw(const_cast<kernel::LiteKernel *>(kernel))) {
|
|
|
|
|
if (NPUPassUtils::IsNhwc2Nchw(kernel)) {
|
|
|
|
|
return std::all_of(kernel->out_kernels().begin(), kernel->out_kernels().end(),
|
|
|
|
|
[](const kernel::LiteKernel *kernel) {
|
|
|
|
|
return NPUPassUtils::IsNchw2Nhwc(const_cast<kernel::LiteKernel *>(kernel));
|
|
|
|
|
return NPUPassUtils::IsNchw2Nhwc(kernel) && (!kernel->out_kernels().empty());
|
|
|
|
|
});
|
|
|
|
|
}
|
|
|
|
|
if (NPUPassUtils::IsNchw2Nhwc(const_cast<kernel::LiteKernel *>(kernel))) {
|
|
|
|
|
if (NPUPassUtils::IsNchw2Nhwc(kernel)) {
|
|
|
|
|
return std::all_of(kernel->out_kernels().begin(), kernel->out_kernels().end(),
|
|
|
|
|
[](const kernel::LiteKernel *kernel) {
|
|
|
|
|
return NPUPassUtils::IsNhwc2Nchw(const_cast<kernel::LiteKernel *>(kernel));
|
|
|
|
|
return NPUPassUtils::IsNhwc2Nchw(kernel) && (!kernel->out_kernels().empty());
|
|
|
|
|
});
|
|
|
|
|
}
|
|
|
|
|
return false;
|
|
|
|
@ -230,8 +229,7 @@ int NPUFusionPass::FormatFusion(kernel::LiteKernel *kernel) {
|
|
|
|
|
int NPUFusionPass::Run() {
|
|
|
|
|
for (size_t i = 0; i < kernels->size(); i++) {
|
|
|
|
|
auto kernel = (*kernels)[i];
|
|
|
|
|
if (NPUPassUtils::IsNchw2Nhwc(const_cast<kernel::LiteKernel *>(kernel)) ||
|
|
|
|
|
NPUPassUtils::IsNhwc2Nchw(const_cast<kernel::LiteKernel *>(kernel))) {
|
|
|
|
|
if (NPUPassUtils::IsNchw2Nhwc(kernel) || NPUPassUtils::IsNhwc2Nchw(kernel)) {
|
|
|
|
|
if (CheckFormatFusion(kernel)) {
|
|
|
|
|
i--;
|
|
|
|
|
FormatFusion(kernel);
|
|
|
|
|