From e1d2f17d6270ba041229d1546d1bc4f07cf21f63 Mon Sep 17 00:00:00 2001
From: jianghui58
Date: Mon, 26 Oct 2020 20:04:41 +0800
Subject: [PATCH] fix matmul_fp32 creator bug && support batch_matmul quantize

---
 .../kernel/arm/base/fullconnection_base.cc    |  2 +-
 .../src/runtime/kernel/arm/fp32/matmul.cc     | 58 +++++++++++++++++
 .../runtime/kernel/arm/int8/matmul_int8.cc    | 63 -------------------
 .../tools/converter/quantizer/quantize_util.h |  4 +-
 4 files changed, 61 insertions(+), 66 deletions(-)

diff --git a/mindspore/lite/src/runtime/kernel/arm/base/fullconnection_base.cc b/mindspore/lite/src/runtime/kernel/arm/base/fullconnection_base.cc
index 926f4c2e02..621c6e0fba 100644
--- a/mindspore/lite/src/runtime/kernel/arm/base/fullconnection_base.cc
+++ b/mindspore/lite/src/runtime/kernel/arm/base/fullconnection_base.cc
@@ -38,7 +38,7 @@ kernel::LiteKernel *CpuFullConnectionFp32KernelCreator(const std::vectordata_c();
diff --git a/mindspore/lite/src/runtime/kernel/arm/fp32/matmul.cc b/mindspore/lite/src/runtime/kernel/arm/fp32/matmul.cc
index 43da981730..ad151773b3 100644
--- a/mindspore/lite/src/runtime/kernel/arm/fp32/matmul.cc
+++ b/mindspore/lite/src/runtime/kernel/arm/fp32/matmul.cc
@@ -18,12 +18,18 @@
 #include "include/errorcode.h"
 #include "nnacl/fp32/matmul.h"
 #include "src/runtime/runtime_api.h"
+#include "src/kernel_registry.h"
+#include "src/runtime/kernel/arm/base/dequant.h"
 
 using mindspore::lite::RET_ERROR;
 using mindspore::lite::RET_INPUT_TENSOR_ERROR;
 using mindspore::lite::RET_MEMORY_FAILED;
 using mindspore::lite::RET_OK;
+using mindspore::lite::KernelRegistrar;
+using mindspore::lite::RET_ERROR;
+using mindspore::schema::PrimitiveType_MatMul;
+
 
 namespace mindspore::kernel {
 MatmulCPUKernel::~MatmulCPUKernel() { FreeTmpBuffer(); }
 
@@ -328,4 +334,56 @@ void MatmulCPUKernel::eval() {
   }
 }
 
+kernel::LiteKernel *CpuMatmulFp32KernelCreator(const std::vector<lite::Tensor *> &inputs,
+                                               const std::vector<lite::Tensor *> &outputs, OpParameter *opParameter,
+                                               const lite::InnerContext *ctx, const kernel::KernelKey &desc,
+                                               const mindspore::lite::PrimitiveC *primitive) {
+  MS_ASSERT(opParameter != nullptr);
+  MS_ASSERT(desc.type == schema::PrimitiveType_MatMul);
+
+  auto *weight_tensor = inputs.at(kWeightIndex);
+  auto *restore_data = weight_tensor->data_c();
+  bool dequant_flag = !weight_tensor->GetQuantParams().empty() && weight_tensor->GetQuantParams().front().inited &&
+                      restore_data != nullptr;
+  if (dequant_flag) {
+    auto *dequant_weight = kernel::DequantUtil::DequantWeight(weight_tensor);
+    if (dequant_weight == nullptr) {
+      MS_LOG(ERROR) << "dequant data is nullptr.";
+      free(opParameter);
+      return nullptr;
+    }
+    weight_tensor->set_data(dequant_weight);
+  }
+
+  auto kernel = new (std::nothrow) MatmulCPUKernel(opParameter, inputs, outputs, ctx, primitive);
+  if (kernel == nullptr) {
+    MS_LOG(ERROR) << "kernel is nullptr.";
+    if (dequant_flag) {
+      weight_tensor->FreeData();
+      weight_tensor->set_data(restore_data);
+    }
+    free(opParameter);
+    return nullptr;
+  }
+  auto ret = kernel->Init();
+  if (ret != RET_OK) {
+    delete kernel;
+    MS_LOG(ERROR) << "Init kernel failed, name: " << opParameter->name_ << ", type: "
+                  << schema::EnumNamePrimitiveType(static_cast<schema::PrimitiveType>(opParameter->type_));
+    if (dequant_flag) {
+      weight_tensor->FreeData();
+      weight_tensor->set_data(restore_data);
+    }
+    return nullptr;
+  }
+
+  if (dequant_flag) {
+    weight_tensor->FreeData();
+    weight_tensor->set_data(restore_data);
+  }
+
+  return kernel;
+}
+
+REG_KERNEL(kCPU, kNumberTypeFloat32, PrimitiveType_MatMul, CpuMatmulFp32KernelCreator)
 }  // namespace mindspore::kernel
diff --git a/mindspore/lite/src/runtime/kernel/arm/int8/matmul_int8.cc b/mindspore/lite/src/runtime/kernel/arm/int8/matmul_int8.cc
index 1d0a4ed497..1d4e96f8af 100644
--- a/mindspore/lite/src/runtime/kernel/arm/int8/matmul_int8.cc
+++ b/mindspore/lite/src/runtime/kernel/arm/int8/matmul_int8.cc
@@ -19,16 +19,10 @@
 #include "nnacl/common_func.h"
 #include "src/runtime/runtime_api.h"
 #include "include/errorcode.h"
-#include "src/kernel_registry.h"
-#include "src/runtime/kernel/arm/base/dequant.h"
 
 using mindspore::lite::RET_MEMORY_FAILED;
 using mindspore::lite::RET_OK;
 
-using mindspore::lite::KernelRegistrar;
-using mindspore::lite::RET_ERROR;
-using mindspore::schema::PrimitiveType_MatMul;
-
 namespace mindspore::kernel {
 MatmulInt8CPUKernel::~MatmulInt8CPUKernel() { FreeTmpBuffer(); }
 
@@ -199,61 +193,4 @@ int MatmulInt8CPUKernel::Run() {
   }
   return RET_OK;
 }
-kernel::LiteKernel *CpuMatmulInt8KernelCreator(const std::vector<lite::Tensor *> &inputs,
-                                               const std::vector<lite::Tensor *> &outputs, OpParameter *opParameter,
-                                               const lite::InnerContext *ctx, const kernel::KernelKey &desc,
-                                               const mindspore::lite::PrimitiveC *primitive) {
-  MS_ASSERT(opParameter != nullptr);
-  MS_ASSERT(desc.type == schema::PrimitiveType_Concat);
-
-  auto *weight_tensor = inputs.at(kWeightIndex);
-  auto *restore_data = weight_tensor->data_c();
-  bool is_const_quant_weight = !weight_tensor->GetQuantParams().empty() &&
-                               weight_tensor->GetQuantParams().front().inited && restore_data != nullptr;
-  if (is_const_quant_weight) {
-    auto *dequant_weight = kernel::DequantUtil::DequantWeight(weight_tensor);
-    if (dequant_weight == nullptr) {
-      MS_LOG(ERROR) << "dequant data is nullptr.";
-      free(opParameter);
-      return nullptr;
-    }
-    weight_tensor->set_data(dequant_weight);
-  }
-
-  auto input_tensor = inputs.at(kInputIndex);
-  auto data_type = input_tensor->data_type();
-  kernel::LiteKernel *kernel = nullptr;
-  if (data_type == kNumberTypeInt8) {
-    kernel = new (std::nothrow) MatmulInt8CPUKernel(opParameter, inputs, outputs, ctx, primitive);
-  }
-  if (kernel == nullptr) {
-    MS_LOG(ERROR) << "kernel is nullptr.";
-    if (is_const_quant_weight) {
-      weight_tensor->FreeData();
-      weight_tensor->set_data(restore_data);
-    }
-    free(opParameter);
-    return nullptr;
-  }
-  auto ret = kernel->Init();
-  if (ret != RET_OK) {
-    delete kernel;
-    MS_LOG(ERROR) << "Init kernel failed, name: " << opParameter->name_ << ", type: "
-                  << schema::EnumNamePrimitiveType(static_cast<schema::PrimitiveType>(opParameter->type_));
-    if (is_const_quant_weight) {
-      weight_tensor->FreeData();
-      weight_tensor->set_data(restore_data);
-    }
-    return nullptr;
-  }
-
-  if (is_const_quant_weight) {
-    weight_tensor->FreeData();
-    weight_tensor->set_data(restore_data);
-  }
-
-  return kernel;
-}
-REG_KERNEL(kCPU, kNumberTypeInt8, PrimitiveType_MatMul, CpuMatmulInt8KernelCreator)
-
 }  // namespace mindspore::kernel
diff --git a/mindspore/lite/tools/converter/quantizer/quantize_util.h b/mindspore/lite/tools/converter/quantizer/quantize_util.h
index ed1fa0cb85..eb85e93457 100644
--- a/mindspore/lite/tools/converter/quantizer/quantize_util.h
+++ b/mindspore/lite/tools/converter/quantizer/quantize_util.h
@@ -132,12 +132,12 @@ template <typename T>
 STATUS QuantFilter(ParamValueLitePtr weight, std::shared_ptr<PrimitiveC> primitive_c, QuantType quantType,
                    int quant_max, int quant_min, size_t bitNum, bool per_channel) {
   auto dims = weight->tensor_shape();
+  auto op_type = (schema::PrimitiveType)primitive_c->Type();
   if (per_channel) {
-    if (dims.size() != 4 && dims.size() != 2) {
+    if (dims.size() != 4 && dims.size() != 2 && op_type != schema::PrimitiveType_MatMul) {
       MS_LOG(INFO) << "weight dims size: " << dims.size() << " switch to per-layer quant mode.";
       per_channel = false;
     } else {
-      auto op_type = (schema::PrimitiveType)primitive_c->Type();
       if (dims.size() == 2 && op_type != schema::PrimitiveType_FullConnection) {
         MS_LOG(INFO) << "weight dims size is 2 but op_type is not FullConnection, switch to per-layer quant mode.";
         per_channel = false;