|
|
@ -15,11 +15,9 @@
|
|
|
|
*/
|
|
|
|
*/
|
|
|
|
#include "src/runtime/kernel/arm/base/matmul_base.h"
|
|
|
|
#include "src/runtime/kernel/arm/base/matmul_base.h"
|
|
|
|
#include "src/runtime/kernel/arm/fp32/matmul.h"
|
|
|
|
#include "src/runtime/kernel/arm/fp32/matmul.h"
|
|
|
|
#include "src/runtime/kernel/arm/int8/matmul_int8.h"
|
|
|
|
|
|
|
|
#include "src/kernel_registry.h"
|
|
|
|
#include "src/kernel_registry.h"
|
|
|
|
#include "include/errorcode.h"
|
|
|
|
#include "include/errorcode.h"
|
|
|
|
#include "include/context.h"
|
|
|
|
#include "include/context.h"
|
|
|
|
#include "src/runtime/kernel/arm/base/dequant.h"
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
using mindspore::lite::KernelRegistrar;
|
|
|
|
using mindspore::lite::KernelRegistrar;
|
|
|
|
using mindspore::lite::RET_ERROR;
|
|
|
|
using mindspore::lite::RET_ERROR;
|
|
|
@ -34,35 +32,14 @@ kernel::LiteKernel *CpuMatmulKernelCreator(const std::vector<lite::Tensor *> &in
|
|
|
|
MS_ASSERT(opParameter != nullptr);
|
|
|
|
MS_ASSERT(opParameter != nullptr);
|
|
|
|
MS_ASSERT(desc.type == schema::PrimitiveType_Concat);
|
|
|
|
MS_ASSERT(desc.type == schema::PrimitiveType_Concat);
|
|
|
|
|
|
|
|
|
|
|
|
auto *weight_tensor = inputs.at(kWeightIndex);
|
|
|
|
|
|
|
|
auto *restore_data = weight_tensor->data_c();
|
|
|
|
|
|
|
|
auto is_const_quant_weight =
|
|
|
|
|
|
|
|
(restore_data != nullptr) &&
|
|
|
|
|
|
|
|
((weight_tensor->data_type() == kNumberTypeInt8 || weight_tensor->data_type() == kNumberTypeInt16));
|
|
|
|
|
|
|
|
if (is_const_quant_weight) {
|
|
|
|
|
|
|
|
auto *dequant_weight = kernel::DequantUtil::DequantWeight(weight_tensor);
|
|
|
|
|
|
|
|
if (dequant_weight == nullptr) {
|
|
|
|
|
|
|
|
MS_LOG(ERROR) << "dequant data is nullptr.";
|
|
|
|
|
|
|
|
free(opParameter);
|
|
|
|
|
|
|
|
return nullptr;
|
|
|
|
|
|
|
|
}
|
|
|
|
|
|
|
|
weight_tensor->SetData(dequant_weight);
|
|
|
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
auto input_tensor = inputs.at(kInputIndex);
|
|
|
|
auto input_tensor = inputs.at(kInputIndex);
|
|
|
|
auto data_type = input_tensor->data_type();
|
|
|
|
auto data_type = input_tensor->data_type();
|
|
|
|
kernel::LiteKernel *kernel = nullptr;
|
|
|
|
kernel::LiteKernel *kernel = nullptr;
|
|
|
|
if (data_type == kNumberTypeInt8) {
|
|
|
|
if (data_type == kNumberTypeFloat32) {
|
|
|
|
kernel = new (std::nothrow) MatmulInt8CPUKernel(opParameter, inputs, outputs, ctx, primitive);
|
|
|
|
|
|
|
|
} else {
|
|
|
|
|
|
|
|
kernel = new (std::nothrow) MatmulCPUKernel(opParameter, inputs, outputs, ctx, primitive);
|
|
|
|
kernel = new (std::nothrow) MatmulCPUKernel(opParameter, inputs, outputs, ctx, primitive);
|
|
|
|
}
|
|
|
|
}
|
|
|
|
if (kernel == nullptr) {
|
|
|
|
if (kernel == nullptr) {
|
|
|
|
MS_LOG(ERROR) << "kernel is nullptr.";
|
|
|
|
MS_LOG(ERROR) << "kernel is nullptr.";
|
|
|
|
if (is_const_quant_weight) {
|
|
|
|
|
|
|
|
weight_tensor->FreeData();
|
|
|
|
|
|
|
|
weight_tensor->SetData(restore_data);
|
|
|
|
|
|
|
|
}
|
|
|
|
|
|
|
|
free(opParameter);
|
|
|
|
free(opParameter);
|
|
|
|
return nullptr;
|
|
|
|
return nullptr;
|
|
|
|
}
|
|
|
|
}
|
|
|
@ -71,21 +48,9 @@ kernel::LiteKernel *CpuMatmulKernelCreator(const std::vector<lite::Tensor *> &in
|
|
|
|
delete kernel;
|
|
|
|
delete kernel;
|
|
|
|
MS_LOG(ERROR) << "Init kernel failed, name: " << opParameter->name_ << ", type: "
|
|
|
|
MS_LOG(ERROR) << "Init kernel failed, name: " << opParameter->name_ << ", type: "
|
|
|
|
<< schema::EnumNamePrimitiveType(static_cast<schema::PrimitiveType>(opParameter->type_));
|
|
|
|
<< schema::EnumNamePrimitiveType(static_cast<schema::PrimitiveType>(opParameter->type_));
|
|
|
|
if (is_const_quant_weight) {
|
|
|
|
|
|
|
|
weight_tensor->FreeData();
|
|
|
|
|
|
|
|
weight_tensor->SetData(restore_data);
|
|
|
|
|
|
|
|
}
|
|
|
|
|
|
|
|
return nullptr;
|
|
|
|
return nullptr;
|
|
|
|
}
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
if (is_const_quant_weight) {
|
|
|
|
|
|
|
|
weight_tensor->FreeData();
|
|
|
|
|
|
|
|
weight_tensor->SetData(restore_data);
|
|
|
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
return kernel;
|
|
|
|
return kernel;
|
|
|
|
}
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
REG_KERNEL(kCPU, kNumberTypeFloat32, PrimitiveType_MatMul, CpuMatmulKernelCreator)
|
|
|
|
REG_KERNEL(kCPU, kNumberTypeFloat32, PrimitiveType_MatMul, CpuMatmulKernelCreator)
|
|
|
|
REG_KERNEL(kCPU, kNumberTypeInt8, PrimitiveType_MatMul, CpuMatmulKernelCreator)
|
|
|
|
|
|
|
|
} // namespace mindspore::kernel
|
|
|
|
} // namespace mindspore::kernel
|
|
|
|