|
|
|
@ -172,32 +172,33 @@ void NPUPassUtils::UpdateNC2NHPostKernelInTensors(kernel::LiteKernel *kernel, ke
|
|
|
|
|
|
|
|
|
|
// Rewires post_kernel so that it consumes trans_kernel instead of kernel.
//
// kernel       - the kernel that used to feed post_kernel; may be nullptr, in which case the input tensor that has
//                no producing in-kernel (KernelInputFromKernel() returns nullptr for it) is the one replaced.
// trans_kernel - the inserted transpose kernel; its first output tensor becomes the new input of post_kernel.
// post_kernel  - the kernel whose in_tensors and in_kernels are updated.
//
// If no input tensor of post_kernel is found to come from kernel, a warning is logged and post_kernel is left
// unchanged.
void NPUPassUtils::UpdateNC2NHTransNodePostKernel(kernel::LiteKernel *kernel, kernel::LiteKernel *trans_kernel,
                                                  kernel::LiteKernel *post_kernel) {
  // The input tensor should be replaced with the output tensor of trans_kernel.
  auto post_in_tensors = post_kernel->in_tensors();
  Tensor *old_in_tensor = nullptr;
  // find out which input tensor of post_kernel should be updated
  for (size_t i = 0; i < post_in_tensors.size(); ++i) {
    if (KernelInputFromKernel(post_kernel, i) == kernel) {
      old_in_tensor = post_in_tensors.at(i);
      break;
    }
  }
  if (old_in_tensor == nullptr) {
    MS_LOG(WARNING) << "Could not find in tensor index";
    return;
  }
  // std::replace rewrites every occurrence, so a tensor that post_kernel consumes more than once stays consistent.
  std::replace(post_in_tensors.begin(), post_in_tensors.end(), old_in_tensor, trans_kernel->out_tensors().at(0));
  post_kernel->set_in_tensors(post_in_tensors);

  // For post_kernel after trans, kernel in in_kernels should be replaced with trans_kernel.
  auto post_in_kernels = post_kernel->in_kernels();
  std::replace(post_in_kernels.begin(), post_in_kernels.end(), kernel, trans_kernel);
  post_kernel->set_in_kernels(post_in_kernels);
}
|
|
|
|
|
|
|
|
|
|
bool NPUPassUtils::IsNhwc2Nchw(const kernel::LiteKernel *kernel) {
|
|
|
|
|
if (kernel == nullptr) {
|
|
|
|
|
return false;
|
|
|
|
|
}
|
|
|
|
|
if (kernel->Type() != schema::PrimitiveType_Transpose) {
|
|
|
|
|
return false;
|
|
|
|
|
}
|
|
|
|
@ -215,6 +216,9 @@ bool NPUPassUtils::IsNhwc2Nchw(const kernel::LiteKernel *kernel) {
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
bool NPUPassUtils::IsNchw2Nhwc(const kernel::LiteKernel *kernel) {
|
|
|
|
|
if (kernel == nullptr) {
|
|
|
|
|
return false;
|
|
|
|
|
}
|
|
|
|
|
if (kernel->Type() != schema::PrimitiveType_Transpose) {
|
|
|
|
|
return false;
|
|
|
|
|
}
|
|
|
|
@ -230,5 +234,22 @@ bool NPUPassUtils::IsNchw2Nhwc(const kernel::LiteKernel *kernel) {
|
|
|
|
|
}
|
|
|
|
|
return false;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
// Given a kernel and the index of one of its input tensors, returns the entry of kernel->in_kernels() whose
// out_tensors() contains that tensor.
//
// Returns nullptr when:
//   - kernel is nullptr,
//   - in_tensor_index is not a valid index into kernel->in_tensors(),
//   - no input kernel outputs the tensor (e.g. the tensor is a graph input).
kernel::LiteKernel *NPUPassUtils::KernelInputFromKernel(const kernel::LiteKernel *kernel, size_t in_tensor_index) {
  if (kernel == nullptr) {
    return nullptr;
  }
  const auto &in_tensors = kernel->in_tensors();
  // Check the index explicitly instead of relying on vector::at(), which throws (or aborts when exceptions are
  // disabled) on an out-of-range index.
  if (in_tensor_index >= in_tensors.size()) {
    return nullptr;
  }
  auto tensor = in_tensors[in_tensor_index];

  const auto &in_kernels = kernel->in_kernels();
  // True when the candidate input kernel lists `tensor` among its outputs; null entries never match.
  auto output_contain = [tensor](const kernel::LiteKernel *in_kernel) {
    if (in_kernel == nullptr) {
      return false;
    }
    const auto &out_tensors = in_kernel->out_tensors();
    return std::find(out_tensors.begin(), out_tensors.end(), tensor) != out_tensors.end();
  };
  auto it = std::find_if(in_kernels.begin(), in_kernels.end(), output_contain);
  if (it == in_kernels.end()) {
    return nullptr;
  }
  return *it;
}
|
|
|
|
|
} // namespace mindspore::lite
|
|
|
|
|