fix matmul_fp32 creator bug && support batch_matmul quantize

pull/7783/head
jianghui58 4 years ago
parent 4c129b12e4
commit e1d2f17d62

@@ -38,7 +38,7 @@ kernel::LiteKernel *CpuFullConnectionFp32KernelCreator(const std::vector<lite::Tensor *> &inputs,
                                                        const kernel::KernelKey &desc,
                                                        const mindspore::lite::PrimitiveC *primitive) {
   MS_ASSERT(opParameter != nullptr);
-  MS_ASSERT(desc.type == schema::PrimitiveType_Concat);
+  MS_ASSERT(desc.type == schema::PrimitiveType_FullConnection);
   auto *weight_tensor = inputs.at(kWeightIndex);
   // data of second tensor of fc may be nullptr
   auto *restore_data = weight_tensor->data_c();
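The fp32 FullConnection creator asserted PrimitiveType_Concat, a copy-paste slip: debug builds would abort on every valid FullConnection request, while release builds compiled the check away entirely. A minimal sketch of the failure mode, using stand-alone stand-ins for MS_ASSERT and the schema enum rather than the real headers:

```cpp
#include <cassert>

// Stand-ins for schema::PrimitiveType and MS_ASSERT, for illustration only.
enum PrimitiveType { PrimitiveType_Concat, PrimitiveType_FullConnection };

void CheckCreatorDesc(PrimitiveType desc_type) {
  // Before the fix this line asserted PrimitiveType_Concat, so any valid
  // FullConnection request tripped the assertion in debug builds.
  assert(desc_type == PrimitiveType_FullConnection);
}
```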

@@ -18,12 +18,18 @@
 #include "include/errorcode.h"
 #include "nnacl/fp32/matmul.h"
 #include "src/runtime/runtime_api.h"
+#include "src/kernel_registry.h"
+#include "src/runtime/kernel/arm/base/dequant.h"
 using mindspore::lite::RET_ERROR;
 using mindspore::lite::RET_INPUT_TENSOR_ERROR;
 using mindspore::lite::RET_MEMORY_FAILED;
 using mindspore::lite::RET_OK;
+using mindspore::lite::KernelRegistrar;
+using mindspore::lite::RET_ERROR;
+using mindspore::schema::PrimitiveType_MatMul;
 namespace mindspore::kernel {
 MatmulCPUKernel::~MatmulCPUKernel() { FreeTmpBuffer(); }
@@ -328,4 +334,56 @@ void MatmulCPUKernel::eval() {
   }
 }
+
+kernel::LiteKernel *CpuMatmulFp32KernelCreator(const std::vector<lite::Tensor *> &inputs,
+                                               const std::vector<lite::Tensor *> &outputs, OpParameter *opParameter,
+                                               const lite::InnerContext *ctx, const kernel::KernelKey &desc,
+                                               const mindspore::lite::PrimitiveC *primitive) {
+  MS_ASSERT(opParameter != nullptr);
+  MS_ASSERT(desc.type == schema::PrimitiveType_MatMul);
+  auto *weight_tensor = inputs.at(kWeightIndex);
+  auto *restore_data = weight_tensor->data_c();
+  bool dequant_flag = !weight_tensor->GetQuantParams().empty() && weight_tensor->GetQuantParams().front().inited &&
+                      restore_data != nullptr;
+  if (dequant_flag) {
+    auto *dequant_weight = kernel::DequantUtil::DequantWeight(weight_tensor);
+    if (dequant_weight == nullptr) {
+      MS_LOG(ERROR) << "dequant data is nullptr.";
+      free(opParameter);
+      return nullptr;
+    }
+    weight_tensor->set_data(dequant_weight);
+  }
+  auto kernel = new (std::nothrow) MatmulCPUKernel(opParameter, inputs, outputs, ctx, primitive);
+  if (kernel == nullptr) {
+    MS_LOG(ERROR) << "kernel is nullptr.";
+    if (dequant_flag) {
+      weight_tensor->FreeData();
+      weight_tensor->set_data(restore_data);
+    }
+    free(opParameter);
+    return nullptr;
+  }
+  auto ret = kernel->Init();
+  if (ret != RET_OK) {
+    delete kernel;
+    MS_LOG(ERROR) << "Init kernel failed, name: " << opParameter->name_ << ", type: "
+                  << schema::EnumNamePrimitiveType(static_cast<schema::PrimitiveType>(opParameter->type_));
+    if (dequant_flag) {
+      weight_tensor->FreeData();
+      weight_tensor->set_data(restore_data);
+    }
+    return nullptr;
+  }
+  if (dequant_flag) {
+    weight_tensor->FreeData();
+    weight_tensor->set_data(restore_data);
+  }
+  return kernel;
+}
+
+REG_KERNEL(kCPU, kNumberTypeFloat32, PrimitiveType_MatMul, CpuMatmulFp32KernelCreator)
 } // namespace mindspore::kernel
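The creator added above repeats the `weight_tensor->FreeData(); weight_tensor->set_data(restore_data);` cleanup on three exit paths. A hedged sketch of the same dequant-then-restore lifecycle expressed with an RAII guard; `WeightRestorer` is a hypothetical helper, not part of the MindSpore Lite API:

```cpp
#include <functional>
#include <utility>

// Hypothetical RAII helper: runs the restore action on every exit path,
// mirroring the repeated FreeData()/set_data(restore_data) cleanup above.
class WeightRestorer {
 public:
  explicit WeightRestorer(std::function<void()> restore) : restore_(std::move(restore)) {}
  ~WeightRestorer() {
    if (restore_) restore_();
  }
  WeightRestorer(const WeightRestorer &) = delete;
  WeightRestorer &operator=(const WeightRestorer &) = delete;

 private:
  std::function<void()> restore_;
};

// Usage sketch, assuming the tensor API from the diff:
//   WeightRestorer guard([&] {
//     if (dequant_flag) {
//       weight_tensor->FreeData();
//       weight_tensor->set_data(restore_data);
//     }
//   });
//   // ...early returns no longer need to repeat the cleanup...
```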

@@ -19,16 +19,10 @@
 #include "nnacl/common_func.h"
 #include "src/runtime/runtime_api.h"
 #include "include/errorcode.h"
-#include "src/kernel_registry.h"
-#include "src/runtime/kernel/arm/base/dequant.h"
 using mindspore::lite::RET_MEMORY_FAILED;
 using mindspore::lite::RET_OK;
-using mindspore::lite::KernelRegistrar;
-using mindspore::lite::RET_ERROR;
-using mindspore::schema::PrimitiveType_MatMul;
 namespace mindspore::kernel {
 MatmulInt8CPUKernel::~MatmulInt8CPUKernel() { FreeTmpBuffer(); }
@@ -199,61 +193,4 @@ int MatmulInt8CPUKernel::Run() {
   }
   return RET_OK;
 }
-
-kernel::LiteKernel *CpuMatmulInt8KernelCreator(const std::vector<lite::Tensor *> &inputs,
-                                               const std::vector<lite::Tensor *> &outputs, OpParameter *opParameter,
-                                               const lite::InnerContext *ctx, const kernel::KernelKey &desc,
-                                               const mindspore::lite::PrimitiveC *primitive) {
-  MS_ASSERT(opParameter != nullptr);
-  MS_ASSERT(desc.type == schema::PrimitiveType_Concat);
-  auto *weight_tensor = inputs.at(kWeightIndex);
-  auto *restore_data = weight_tensor->data_c();
-  bool is_const_quant_weight = !weight_tensor->GetQuantParams().empty() &&
-                               weight_tensor->GetQuantParams().front().inited && restore_data != nullptr;
-  if (is_const_quant_weight) {
-    auto *dequant_weight = kernel::DequantUtil::DequantWeight(weight_tensor);
-    if (dequant_weight == nullptr) {
-      MS_LOG(ERROR) << "dequant data is nullptr.";
-      free(opParameter);
-      return nullptr;
-    }
-    weight_tensor->set_data(dequant_weight);
-  }
-  auto input_tensor = inputs.at(kInputIndex);
-  auto data_type = input_tensor->data_type();
-  kernel::LiteKernel *kernel = nullptr;
-  if (data_type == kNumberTypeInt8) {
-    kernel = new (std::nothrow) MatmulInt8CPUKernel(opParameter, inputs, outputs, ctx, primitive);
-  }
-  if (kernel == nullptr) {
-    MS_LOG(ERROR) << "kernel is nullptr.";
-    if (is_const_quant_weight) {
-      weight_tensor->FreeData();
-      weight_tensor->set_data(restore_data);
-    }
-    free(opParameter);
-    return nullptr;
-  }
-  auto ret = kernel->Init();
-  if (ret != RET_OK) {
-    delete kernel;
-    MS_LOG(ERROR) << "Init kernel failed, name: " << opParameter->name_ << ", type: "
-                  << schema::EnumNamePrimitiveType(static_cast<schema::PrimitiveType>(opParameter->type_));
-    if (is_const_quant_weight) {
-      weight_tensor->FreeData();
-      weight_tensor->set_data(restore_data);
-    }
-    return nullptr;
-  }
-  if (is_const_quant_weight) {
-    weight_tensor->FreeData();
-    weight_tensor->set_data(restore_data);
-  }
-  return kernel;
-}
-
-REG_KERNEL(kCPU, kNumberTypeInt8, PrimitiveType_MatMul, CpuMatmulInt8KernelCreator)
 } // namespace mindspore::kernel
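This hunk deletes the int8-specific creator and its REG_KERNEL registration wholesale; note the deleted code carried the same copy-paste assertion on PrimitiveType_Concat that the first hunk fixes for FullConnection. The commit's actual replacement lies outside the hunks shown here. One plausible consolidation, sketched with simplified stand-in types rather than the real lite kernel API, is a single creator that dispatches on the input tensor's data type:

```cpp
#include <new>

// Simplified stand-ins for the lite kernel types, for illustration only.
enum DataType { kNumberTypeFloat32, kNumberTypeInt8 };
struct Kernel { virtual ~Kernel() = default; };
struct MatmulFp32Kernel : Kernel {};
struct MatmulInt8Kernel : Kernel {};

// Hypothetical shared creator: picks the concrete kernel from the dtype,
// instead of registering one creator per dtype.
Kernel *CreateMatmulKernel(DataType input_dtype) {
  switch (input_dtype) {
    case kNumberTypeFloat32:
      return new (std::nothrow) MatmulFp32Kernel();
    case kNumberTypeInt8:
      return new (std::nothrow) MatmulInt8Kernel();
    default:
      return nullptr;  // unsupported input dtype
  }
}
```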

@@ -132,12 +132,12 @@ template <typename T>
 STATUS QuantFilter(ParamValueLitePtr weight, std::shared_ptr<PrimitiveC> primitive_c, QuantType quantType,
                    int quant_max, int quant_min, size_t bitNum, bool per_channel) {
   auto dims = weight->tensor_shape();
+  auto op_type = (schema::PrimitiveType)primitive_c->Type();
   if (per_channel) {
-    if (dims.size() != 4 && dims.size() != 2) {
+    if (dims.size() != 4 && dims.size() != 2 && op_type != schema::PrimitiveType_MatMul) {
       MS_LOG(INFO) << "weight dims size: " << dims.size() << " switch to per-layer quant mode.";
       per_channel = false;
     } else {
-      auto op_type = (schema::PrimitiveType)primitive_c->Type();
       if (dims.size() == 2 && op_type != schema::PrimitiveType_FullConnection) {
         MS_LOG(INFO) << "weight dims size is 2 but op_type is not FullConnection, switch to per-layer quant mode.";
         per_channel = false;
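Hoisting `op_type` out of the else branch and adding the `PrimitiveType_MatMul` escape hatch is what enables batch_matmul weight quantization: a 3-D MatMul weight no longer falls back to per-layer mode just because its rank is neither 2 nor 4. A compact restatement of the gate after this commit, as a hypothetical free function with stubbed enum values:

```cpp
#include <vector>

// Stubbed enum values, for illustration only.
enum PrimitiveType { PrimitiveType_MatMul, PrimitiveType_FullConnection, PrimitiveType_Other };

// Restates the per-channel gate from QuantFilter after this commit.
bool UsePerChannel(const std::vector<int> &dims, PrimitiveType op_type) {
  if (dims.size() != 4 && dims.size() != 2 && op_type != PrimitiveType_MatMul) {
    return false;  // unusual rank and not MatMul: per-layer quant
  }
  if (dims.size() == 2 && op_type != PrimitiveType_FullConnection) {
    return false;  // 2-D weight but not FullConnection: per-layer quant
  }
  return true;  // per-channel quant (now includes 3-D batch_matmul weights)
}
```

Note that a 2-D MatMul weight still drops to per-layer mode via the second check; only the rank gate in the first check was relaxed.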
