!12181 [MSLITE] int8 matmul base

From: @ling_qiao_min
Reviewed-by: 
Signed-off-by:
pull/12181/MERGE
mindspore-ci-bot, committed by Gitee 4 years ago
commit 2f1d4f9ef9

@@ -182,40 +182,6 @@ void MatMulInt8_16x4(const int8_t *a, const int8_t *b, int *dst, int row_4, int
return;
}
void MatMulInt8_16x4_r(const int8_t *a, const int8_t *b, int8_t *dst, size_t row, size_t col, size_t deep_16,
size_t stride, const int32_t *input_sum, const int32_t *bias, int32_t *left_shift,
int32_t *right_shift, int32_t *multiplier, int32_t output_zp, int32_t mini, int32_t maxi,
bool peroc) {
/* supports per-layer and weight per-channel quantization */
/* row4x16-major * row16x4-major => (int8)row-major */
for (int r = 0; r < row; r++) {
for (int c = 0; c < col; c++) {
int r4div = r / C4NUM, r4mod = r % C4NUM;
int c4div = c / C4NUM, c4mod = c % C4NUM;
size_t ci = r * stride + c;
int32_t value = 0;
for (int d = 0; d < deep_16; d++) {
int d16div = d / C16NUM, d16mod = d % C16NUM;
size_t ai = r4div * deep_16 * C4NUM + d16div * C4NUM * C16NUM + r4mod * C16NUM + d16mod;
size_t bi = c4div * deep_16 * C4NUM + d16div * C4NUM * C16NUM + c4mod * C16NUM + d16mod;
value = value + a[ai] * b[bi];
}
int32_t cur_input_sum =
peroc ? input_sum[c4div * UP_ROUND(row, C4NUM) * C4NUM + r * C4NUM + c4mod] : input_sum[r];
value -= cur_input_sum;
value += bias[c];
int32_t cur_left_shift = peroc ? left_shift[c] : left_shift[0];
int32_t cur_right_shift = peroc ? right_shift[c] : right_shift[0];
int32_t cur_multiplier = peroc ? multiplier[c] : multiplier[0];
value = MultiplyByQuantizedMultiplier(value, cur_multiplier, cur_left_shift, cur_right_shift) + output_zp;
value = MSMIN(maxi, value);
value = MSMAX(mini, value);
dst[ci] = (int8_t)value;
}
}
return;
}
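
The reference kernel above funnels each int32 accumulator through MultiplyByQuantizedMultiplier to requantize it back to int8. A minimal sketch of the gemmlowp-style fixed-point math this relies on, assuming nnacl's convention that right_shift is stored non-positive; the helper bodies here are illustrative, not the verbatim nnacl source:

#include <stdint.h>

/* (a * b * 2) as the high 32 bits with rounding; saturates the single overflow case. */
static int32_t SatRoundingDoublingHighMulSketch(int32_t a, int32_t b) {
  if (a == INT32_MIN && b == INT32_MIN) {
    return INT32_MAX;
  }
  int64_t ab = (int64_t)a * (int64_t)b;
  int64_t nudge = (ab >= 0) ? (1ll << 30) : (1 - (1ll << 30));
  return (int32_t)((ab + nudge) / (1ll << 31));
}

/* Arithmetic right shift with round-to-nearest, ties away from zero. */
static int32_t RoundingDivideByPOTSketch(int32_t x, int exponent) {
  int32_t mask = ((int32_t)1 << exponent) - 1;
  int32_t remainder = x & mask;
  int32_t threshold = (mask >> 1) + ((x < 0) ? 1 : 0);
  return (x >> exponent) + ((remainder > threshold) ? 1 : 0);
}

/* value * multiplier * 2^(left_shift + right_shift - 31), rounded. */
static int32_t MultiplyByQuantizedMultiplierSketch(int32_t value, int32_t multiplier, int32_t left_shift,
                                                   int32_t right_shift) {
  return RoundingDivideByPOTSketch(SatRoundingDoublingHighMulSketch(value * (1 << left_shift), multiplier),
                                   -right_shift);
}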
void MatMulInt8_4x2_r(const int8_t *a, const int8_t *b, int8_t *dst, size_t row, size_t col, size_t deep_16,
size_t stride, const int32_t *input_sum, const int32_t *bias, int32_t *left_shift,
int32_t *right_shift, int32_t *multiplier, int32_t output_zp, int32_t mini, int32_t maxi,
@@ -353,6 +319,105 @@ void MatMulInt8_4x16_r(const int8_t *a, const int8_t *b, int8_t *dst, size_t row
return;
}
#ifdef ENABLE_ARM64
void PackInput4x4AndInputSumPert_arm64(const int8_t *src_ic, int8_t *pack_ic, int32_t *input_sum_r, size_t src_stride,
size_t ic_4div, size_t ic_4res, int32_t filter_zp) {
asm volatile(
"dup v2.4s, wzr \n"
"mov x14, %[input_sum_r] \n"
"dup v3.4s, %w[filter_zp] \n"
"mov x10, %[src_ic] \n"
"mov x11, %[pack_ic] \n"
"mov x15, #0 \n"
"1: \n"
"cmp x15, %[ic_4div] \n"
"add x15, x15, #4\n"
"mov x12, x10 \n"
"add x10, x10, #4\n"
"blt 2f \n"
"cmp %[ic_4res], #0\n"
"beq 6f \n"
"cmp %[ic_4res], #1\n"
"beq 3f \n"
"cmp %[ic_4res], #2\n"
"beq 4f \n"
"cmp %[ic_4res], #3\n"
"beq 5f \n"
"2: \n"
"ld1 {v0.s}[0], [x12], %[src_stride]\n"
"ld1 {v0.s}[1], [x12], %[src_stride]\n"
"ld1 {v0.s}[2], [x12], %[src_stride]\n"
"ld1 {v0.s}[3], [x12], %[src_stride]\n"
"st1 {v0.16b}, [x11], #16\n"
"saddlp v1.8h, v0.16b \n"
"saddlp v0.4s, v1.8h \n"
"add v2.4s, v2.4s, v0.4s \n"
"b 1b \n"
"3: \n" /* ic res 1 */
"dup v0.4s, wzr \n"
"ld1 {v0.b}[0], [x12], %[src_stride]\n"
"ld1 {v0.b}[4], [x12], %[src_stride]\n"
"ld1 {v0.b}[8], [x12], %[src_stride]\n"
"ld1 {v0.b}[12], [x12], %[src_stride]\n"
"st1 {v0.16b}, [x11], #16\n"
"saddlp v1.8h, v0.16b \n"
"saddlp v0.4s, v1.8h \n"
"add v2.4s, v2.4s, v0.4s \n"
"b 6f \n"
"4: \n" /* ic res 2 */
"dup v0.4s, wzr \n"
"ld1 {v0.h}[0], [x12], %[src_stride]\n"
"ld1 {v0.h}[2], [x12], %[src_stride]\n"
"ld1 {v0.h}[4], [x12], %[src_stride]\n"
"ld1 {v0.h}[6], [x12], %[src_stride]\n"
"st1 {v0.16b}, [x11], #16\n"
"saddlp v1.8h, v0.16b \n"
"saddlp v0.4s, v1.8h \n"
"add v2.4s, v2.4s, v0.4s \n"
"b 6f \n"
"5: \n" /* ic res 3 */
"dup v0.4s, wzr \n"
"add x13, x12, #2 \n"
"ld1 {v0.h}[0], [x12], %[src_stride]\n"
"ld1 {v0.b}[2], [x13], %[src_stride]\n"
"ld1 {v0.h}[2], [x12], %[src_stride]\n"
"ld1 {v0.b}[6], [x13], %[src_stride]\n"
"ld1 {v0.h}[4], [x12], %[src_stride]\n"
"ld1 {v0.b}[10], [x13], %[src_stride]\n"
"ld1 {v0.h}[6], [x12], %[src_stride]\n"
"ld1 {v0.b}[14], [x13], %[src_stride]\n"
"st1 {v0.16b}, [x11], #16\n"
"saddlp v1.8h, v0.16b \n"
"saddlp v0.4s, v1.8h \n"
"add v2.4s, v2.4s, v0.4s \n"
"b 6f \n"
"6: \n"
"mul v2.4s, v2.4s, v3.4s \n"
"st1 {v2.4s}, [x14], #16 \n"
:
: [ src_ic ] "r"(src_ic), [ pack_ic ] "r"(pack_ic), [ input_sum_r ] "r"(input_sum_r),
[ src_stride ] "r"(src_stride), [ ic_4div ] "r"(ic_4div), [ ic_4res ] "r"(ic_4res), [ filter_zp ] "r"(filter_zp)
: "x10", "x11", "x12", "x13", "x14", "x15", "v0", "v1", "v2", "v3");
return;
}
#endif
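
For readability, a hedged plain-C equivalent of the assembly helper above (the non-ARM fallback below follows the same pattern); the _ref name is illustrative:

#include <stddef.h>
#include <stdint.h>

/* Packs four spatial rows into consecutive 4x4 (row x channel) int8 tiles,
 * zero-padding the channel tail, then emits each row's channel sum scaled by
 * the filter zero point. A sketch, not a drop-in replacement for the asm. */
static void PackInput4x4AndInputSumPert_ref(const int8_t *src_ic, int8_t *pack_ic, int32_t *input_sum_r,
                                            size_t src_stride, size_t ic_4div, size_t ic_4res,
                                            int32_t filter_zp) {
  int32_t sums[4] = {0, 0, 0, 0};
  size_t ic_total = ic_4div + (ic_4res > 0 ? 4 : 0); /* channel count padded to a multiple of 4 */
  for (size_t ic = 0; ic < ic_total; ic += 4) {
    for (size_t r = 0; r < 4; r++) {   /* four spatial positions */
      for (size_t c = 0; c < 4; c++) { /* four channels per tile */
        size_t src_c = ic + c;
        int8_t v = (src_c < ic_4div + ic_4res) ? src_ic[r * src_stride + src_c] : 0;
        pack_ic[(ic / 4) * 16 + r * 4 + c] = v;
        sums[r] += v;
      }
    }
  }
  for (size_t r = 0; r < 4; r++) {
    input_sum_r[r] = sums[r] * filter_zp;
  }
}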
void PackInput4x4AndInputSumPert(const int8_t *src_input, int8_t *packed_input, int32_t *input_sum,
size_t input_channel, size_t plane_size, int32_t filter_zp) {
int ic4 = UP_ROUND(input_channel, C4NUM);
@@ -370,99 +435,7 @@ void PackInput4x4AndInputSumPert(const int8_t *src_input, int8_t *packed_input,
#ifdef ENABLE_ARM64
size_t src_stride = input_channel;
size_t ic_4res = input_channel - ic_4div;
asm volatile(
"dup v2.4s, wzr \n"
"mov x14, %[input_sum_r] \n"
"dup v3.4s, %w[filter_zp] \n"
"mov x10, %[src_ic] \n"
"mov x11, %[pack_ic] \n"
"mov x15, #0 \n"
"1: \n"
"cmp x15, %[ic_4div] \n"
"add x15, x15, #4\n"
"mov x12, x10 \n"
"add x10, x10, #4\n"
"blt 2f \n"
"cmp %[ic_4res], #0\n"
"beq 6f \n"
"cmp %[ic_4res], #1\n"
"beq 3f \n"
"cmp %[ic_4res], #2\n"
"beq 4f \n"
"cmp %[ic_4res], #3\n"
"beq 5f \n"
"2: \n"
"ld1 {v0.s}[0], [x12], %[src_stride]\n"
"ld1 {v0.s}[1], [x12], %[src_stride]\n"
"ld1 {v0.s}[2], [x12], %[src_stride]\n"
"ld1 {v0.s}[3], [x12], %[src_stride]\n"
"st1 {v0.16b}, [x11], #16\n"
"saddlp v1.8h, v0.16b \n"
"saddlp v0.4s, v1.8h \n"
"add v2.4s, v2.4s, v0.4s \n"
"b 1b \n"
"3: \n" /* ic res 1 */
"dup v0.4s, wzr \n"
"ld1 {v0.b}[0], [x12], %[src_stride]\n"
"ld1 {v0.b}[4], [x12], %[src_stride]\n"
"ld1 {v0.b}[8], [x12], %[src_stride]\n"
"ld1 {v0.b}[12], [x12], %[src_stride]\n"
"st1 {v0.16b}, [x11], #16\n"
"saddlp v1.8h, v0.16b \n"
"saddlp v0.4s, v1.8h \n"
"add v2.4s, v2.4s, v0.4s \n"
"b 6f \n"
"4: \n" /* ic res 2 */
"dup v0.4s, wzr \n"
"ld1 {v0.h}[0], [x12], %[src_stride]\n"
"ld1 {v0.h}[2], [x12], %[src_stride]\n"
"ld1 {v0.h}[4], [x12], %[src_stride]\n"
"ld1 {v0.h}[6], [x12], %[src_stride]\n"
"st1 {v0.16b}, [x11], #16\n"
"saddlp v1.8h, v0.16b \n"
"saddlp v0.4s, v1.8h \n"
"add v2.4s, v2.4s, v0.4s \n"
"b 6f \n"
"5: \n" /* ic res 3 */
"dup v0.4s, wzr \n"
"add x13, x12, #2 \n"
"ld1 {v0.h}[0], [x12], %[src_stride]\n"
"ld1 {v0.b}[2], [x13], %[src_stride]\n"
"ld1 {v0.h}[2], [x12], %[src_stride]\n"
"ld1 {v0.b}[6], [x13], %[src_stride]\n"
"ld1 {v0.h}[4], [x12], %[src_stride]\n"
"ld1 {v0.b}[10], [x13], %[src_stride]\n"
"ld1 {v0.h}[6], [x12], %[src_stride]\n"
"ld1 {v0.b}[14], [x13], %[src_stride]\n"
"st1 {v0.16b}, [x11], #16\n"
"saddlp v1.8h, v0.16b \n"
"saddlp v0.4s, v1.8h \n"
"add v2.4s, v2.4s, v0.4s \n"
"b 6f \n"
"6: \n"
"mul v2.4s, v2.4s, v3.4s \n"
"st1 {v2.4s}, [x14], #16 \n"
:
: [ src_ic ] "r"(src_ic), [ pack_ic ] "r"(pack_ic), [ input_sum_r ] "r"(input_sum_r),
[ src_stride ] "r"(src_stride), [ ic_4div ] "r"(ic_4div), [ ic_4res ] "r"(ic_4res), [ filter_zp ] "r"(filter_zp)
: "x10", "x11", "x12", "x13", "x14", "x15", "v0", "v1", "v2", "v3");
PackInput4x4AndInputSumPert_arm64(src_ic, pack_ic, input_sum_r, src_stride, ic_4div, ic_4res, filter_zp);
#else
int32_t tmp_sum_value[4] = {0};
for (int ici = 0; ici < ic_4div; ici += C4NUM) {

@@ -25,12 +25,9 @@
extern "C" {
#endif
/* 4x16 16x4 -> 4x4 */
/* matmul */
void MatMulInt8_16x4(const int8_t *a, const int8_t *b, int *dst, int row_4, int col_4, int deep_16,
const int *input_sum, const int *bias);
void MatMulInt8_16x4_r(const int8_t *a, const int8_t *b, int8_t *dst, size_t row, size_t col, size_t deep_16,
size_t stride, const int32_t *input_sum, const int32_t *bias, int32_t *left_shift,
int32_t *right_shift, int32_t *multiplier, int32_t output_zp, int32_t mini, int32_t maxi,
bool per_channel);
void RowMajor2Row16x4MajorInt8(int8_t *src_ptr, int8_t *dst_ptr, int row, int col);
void RowMajor2Col16x4MajorInt8(int8_t *src, int row, int col, int8_t *dst);
void CalcInputSums(int8_t *input, int row, int col, int weight_zp, int *dst, DataOrder order);
@@ -41,6 +38,7 @@ void MatmulInt8Opt(const int8_t *a, const int8_t *b, int8_t *dst, int row, int c
int32_t *right_shift, size_t stride, size_t filter_peroc, int32_t *filter_zp);
/* 8x4 4x8 -> 8x8 */
/* optimize conv */
void RowMajor2Row8x4MajorInt8(const int8_t *src_ptr, int8_t *dst_ptr, int row, int col);
void MatMulInt8_8x8_r(const int8_t *a, const int8_t *b, int8_t *dst, size_t row, size_t col, size_t deep_4,
size_t stride, const int32_t *input_sum, const int32_t *bias, int32_t *left_shift,
@@ -48,6 +46,7 @@ void MatMulInt8_8x8_r(const int8_t *a, const int8_t *b, int8_t *dst, size_t row,
size_t per_channel);
/* 4x16 16x2 -> 4x2 */
/* arm32 conv1x1 */
void RowMajor2Row2x16MajorInt8(int8_t *src_ptr, int8_t *dst_ptr, int row, int col);
void MatMulInt8_4x2_r(const int8_t *a, const int8_t *b, int8_t *dst, size_t row, size_t col, size_t deep_16,
size_t stride, const int32_t *input_sum, const int32_t *bias, int32_t *left_shift,
@@ -55,6 +54,7 @@ void MatMulInt8_4x2_r(const int8_t *a, const int8_t *b, int8_t *dst, size_t row,
bool peroc);
/* 4x4 4x16 -> 4x16 */
/* optimize conv1x1 */
void RowMajor2Row4x16MajorInt8(const int8_t *src_ptr, int8_t *dst_ptr, int row, int col);
void PackInput4x4AndInputSumPert(const int8_t *src_input, int8_t *packed_input, int32_t *input_sum,
size_t input_channel, size_t plane_size, int32_t filter_zp);
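
Putting the 4x16 * 16x4 declarations together, a hedged usage sketch with identity scaling; it mirrors the unit test added later in this change, and the shapes are illustrative:

void MatMul16x4UsageSketch(void) {
  enum { ROW = 4, COL = 4, DEEP = 8, ROW4 = 4, COL4 = 4, DEEP16 = 16 };
  int8_t a[ROW * DEEP] = {0};     /* row-major activations */
  int8_t b[DEEP * COL] = {0};     /* row-major weights */
  int8_t out[ROW * COL] = {0};
  int8_t a_pack[ROW4 * DEEP16] = {0};
  int8_t b_pack[COL4 * DEEP16] = {0};
  int32_t input_sums[ROW4] = {0}; /* row sums * weight zp; zero when zp == 0 */
  int32_t bias[COL4] = {0};
  int multiplier = 0, ls = 0, rs = 0;
  QuantizeRoundParameterWithDoublePrecision(1.0f, &multiplier, &ls, &rs);
  RowMajor2Row16x4MajorInt8(a, a_pack, ROW, DEEP); /* pack A to row4x16-major */
  RowMajor2Col16x4MajorInt8(b, DEEP, COL, b_pack); /* pack B to col16x4-major */
  MatMulInt8_16x4_r(a_pack, b_pack, out, ROW, COL, DEEP16, COL, input_sums, bias,
                    &ls, &rs, &multiplier, 0, INT8_MIN, INT8_MAX, false);
}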

@@ -66,17 +66,6 @@ typedef struct PreluQuantArg {
QuantArg out_quant_args_;
} PreluQuantArg;
typedef struct MatmulQuantArg {
QuantArg input;
QuantArg weight;
QuantArg output;
int32_t out_act_min;
int32_t out_act_max;
int32_t left_shift;
int32_t right_shift;
int32_t quant_multiplier;
} MatmulQuantArg;
typedef struct CropQuantArg {
QuantArg in_args_;
QuantArg out_args_;

@@ -73,4 +73,15 @@ typedef struct MatmulQuantParameter {
int32_t *quant_multiplier_;
} MatmulQuantParameter;
typedef struct MatmulQuantArg {
QuantArg input;
QuantArg weight;
QuantArg output;
int32_t out_act_min;
int32_t out_act_max;
int32_t left_shift;
int32_t right_shift;
int32_t quant_multiplier;
} MatmulQuantArg;
#endif // MINDSPORE_LITE_NNACL_MATMUL_H_

@@ -67,10 +67,5 @@ int FullconnectionCPUKernel::ReSize() {
return MatmulFp32BaseCPUKernel::ReSize();
}
int FullconnectionCPUKernel::Run() {
MatmulFp32BaseCPUKernel::Run();
return RET_OK;
}
REG_KERNEL(kCPU, kNumberTypeFloat32, PrimitiveType_FullConnection, LiteKernelCreator<FullconnectionCPUKernel>)
} // namespace mindspore::kernel

@@ -33,7 +33,6 @@ class FullconnectionCPUKernel : public MatmulFp32BaseCPUKernel {
~FullconnectionCPUKernel() = default;
int Init() override;
int ReSize() override;
int Run() override;
};
} // namespace mindspore::kernel
#endif // MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_FP32_FULLCONNECTION_H_

@@ -18,52 +18,19 @@
#define MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_INT8_FULLCONNECTION_INT8_H_
#include <vector>
#include "src/lite_kernel.h"
#include "include/errorcode.h"
#include "mindspore/lite/nnacl/int8/quantize.h"
#include "nnacl/common_func.h"
#include "nnacl/int8/common_func_int8.h"
#include "nnacl/int8/matmul_int8.h"
#include "src/runtime/kernel/arm/int8/matmul_base_int8.h"
namespace mindspore::kernel {
class FullconnectionInt8CPUKernel : public LiteKernel {
class FullconnectionInt8CPUKernel : public MatmulBaseInt8CPUKernel {
public:
FullconnectionInt8CPUKernel(OpParameter *parameter, const std::vector<lite::Tensor *> &inputs,
const std::vector<lite::Tensor *> &outputs, const mindspore::lite::InnerContext *ctx,
const mindspore::lite::PrimitiveC *primitive)
: LiteKernel(parameter, inputs, outputs, ctx, primitive) {
fc_param_ = reinterpret_cast<MatMulParameter *>(op_parameter_);
}
~FullconnectionInt8CPUKernel() override {
FreeTmpBuffer();
FreeQuantParam();
}
: MatmulBaseInt8CPUKernel(parameter, inputs, outputs, ctx, primitive) {}
~FullconnectionInt8CPUKernel() override = default;
int Init() override;
int ReSize() override;
int Run() override;
public:
int RunImpl(int task_id);
private:
void InitParam();
void FreeTmpBuffer();
void FreeQuantParam();
int MallocQuantParam();
private:
MatMulParameter *fc_param_ = nullptr;
MatmulQuantParameter quant_;
int thread_count_ = 1;
int thread_stride_ = 0;
int8_t *pack_a_ptr_ = nullptr;
int8_t *pack_b_ptr_ = nullptr;
int8_t *c_ptr_ = nullptr;
int *input_sums_ = nullptr;
int *weight_bias_sums_ = nullptr;
int *bias_ptr_ = nullptr;
bool filter_per_channel_ = true;
};
} // namespace mindspore::kernel

@@ -0,0 +1,82 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_INT8_MATMUL_BASE_INT8_H_
#define MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_INT8_MATMUL_BASE_INT8_H_
#include <vector>
#include "include/errorcode.h"
#include "include/context.h"
#include "src/lite_kernel.h"
#include "nnacl/matmul_parameter.h"
#include "nnacl/common_func.h"
#include "nnacl/int8/quantize.h"
#include "nnacl/int8/common_func_int8.h"
#include "nnacl/int8/matmul_int8.h"
namespace mindspore::kernel {
class MatmulBaseInt8CPUKernel : public LiteKernel {
public:
MatmulBaseInt8CPUKernel(OpParameter *parameter, const std::vector<lite::Tensor *> &inputs,
const std::vector<lite::Tensor *> &outputs, const lite::InnerContext *ctx,
const mindspore::lite::PrimitiveC *primitive)
: LiteKernel(parameter, inputs, outputs, ctx, primitive) {
param_ = reinterpret_cast<MatMulParameter *>(op_parameter_);
}
~MatmulBaseInt8CPUKernel() override;
int Init() override;
int ReSize() override;
int Run() override;
public:
int RunImpl(int task_id);
protected:
void InitParameter();
private:
void ResizeParameter();
int InitBias();
private:
int InitTmpBuffer();
void FreeTmpBuffer();
void TransferA();
void TransferB();
private:
int MallocQuantParam();
void FreeQuantParam();
void InitQuantParam();
protected:
MatMulParameter *param_ = nullptr;
MatmulQuantParameter quant_;
int thread_count_ = 1;
int thread_stride_ = 0;
int8_t *pack_a_ptr_ = nullptr;
int8_t *pack_b_ptr_ = nullptr;
int *input_sums_ = nullptr;
int *weight_bias_sums_ = nullptr;
int *bias_ptr_ = nullptr;
bool filter_per_channel_ = true;
int8_t *batch_b_ptr_ = nullptr;
int8_t *batch_c_ptr_ = nullptr;
int *batch_sums_ = nullptr;
};
} // namespace mindspore::kernel
#endif // MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_INT8_MATMUL_BASE_INT8_H_
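
A hedged sketch of how this base class is expected to fan work out, inferred from the pre-refactor MatmulInt8CPUKernel::Run() removed later in this diff (the new matmul_base_int8.cc body is not shown in this excerpt); the free-function name is illustrative:

// Thread entry: each task handles its own stripe of output columns.
int MatmulBaseInt8Run(void *cdata, int task_id) {
  auto op = reinterpret_cast<mindspore::kernel::MatmulBaseInt8CPUKernel *>(cdata);
  auto ret = op->RunImpl(task_id);
  if (ret != mindspore::lite::RET_OK) {
    MS_LOG(ERROR) << "MatmulBaseInt8Run error task_id[" << task_id << "] error_code[" << ret << "]";
    return ret;
  }
  return mindspore::lite::RET_OK;
}
// Run() then loops over param_->batch: pack the current slice of A, point
// batch_b_ptr_ / batch_c_ptr_ / batch_sums_ at that batch's packed B, output,
// and weight-bias sums, and call
// ParallelLaunch(context_->thread_pool_, MatmulBaseInt8Run, this, thread_count_).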

@@ -22,46 +22,27 @@
#include "src/kernel_registry.h"
using mindspore::lite::KernelRegistrar;
using mindspore::lite::RET_MEMORY_FAILED;
using mindspore::lite::RET_OK;
using mindspore::schema::PrimitiveType_MatMul;
namespace mindspore::kernel {
MatmulInt8CPUKernel::~MatmulInt8CPUKernel() { FreeTmpBuffer(); }
int MatmulInt8CPUKernel::Init() {
InitParameter();
void MatmulInt8CPUKernel::FreeTmpBuffer() {
if (a_r4x16_ptr_ != nullptr) {
context_->allocator->Free(a_r4x16_ptr_);
a_r4x16_ptr_ = nullptr;
}
if (input_sums_ != nullptr) {
context_->allocator->Free(input_sums_);
input_sums_ = nullptr;
}
if (b_c16x4_batch_ != nullptr) {
context_->allocator->Free(b_c16x4_batch_);
b_c16x4_batch_ = nullptr;
}
if (weight_bias_sums_batch_ != nullptr) {
context_->allocator->Free(weight_bias_sums_batch_);
weight_bias_sums_batch_ = nullptr;
}
if (bias_ptr_ != nullptr) {
context_->allocator->Free(bias_ptr_);
bias_ptr_ = nullptr;
auto ret = MatmulBaseInt8CPUKernel::Init();
if (ret != RET_OK) {
MS_LOG(ERROR) << "MatmulBaseInt8CPUKernel::Init failed";
return ret;
}
return;
}
int MatmulInt8CPUKernel::Init() {
if (!InferShapeDone()) {
return RET_OK;
}
return ReSize();
}
int MatmulInt8CPUKernel::ReSize() {
FreeTmpBuffer();
int batch = 1;
auto x_shape = in_tensors_.at(0)->shape();
auto o_shape = out_tensors_.at(0)->shape();
@@ -69,159 +50,19 @@ int MatmulInt8CPUKernel::ReSize() {
for (size_t i = 0; i < x_shape.size() - 2; ++i) {
batch *= x_shape[i];
}
params_->batch = batch;
param_->batch = batch;
MS_ASSERT(o_shape.size() >= 2);
params_->row_ = o_shape[o_shape.size() - 2];
params_->col_ = o_shape[o_shape.size() - 1];
params_->deep_ = params_->a_transpose_ ? x_shape[x_shape.size() - 2] : x_shape[x_shape.size() - 1];
params_->row_4_ = UP_ROUND(params_->row_, 4);
params_->col_4_ = UP_ROUND(params_->col_, 4);
params_->deep_16_ = UP_ROUND(params_->deep_, 16);
a_r4x16_ptr_ =
reinterpret_cast<int8_t *>(context_->allocator->Malloc(params_->row_4_ * params_->deep_16_ * sizeof(int8_t)));
if (!a_r4x16_ptr_) return RET_MEMORY_FAILED;
memset(a_r4x16_ptr_, 0, params_->row_4_ * params_->deep_16_ * sizeof(int8_t));
input_sums_ = reinterpret_cast<int *>(context_->allocator->Malloc(params_->row_4_ * sizeof(int)));
if (!input_sums_) return RET_MEMORY_FAILED;
memset(input_sums_, 0, params_->row_4_ * sizeof(int));
b_c16x4_batch_ = reinterpret_cast<int8_t *>(
context_->allocator->Malloc(params_->batch * params_->col_4_ * params_->deep_16_ * sizeof(int8_t)));
if (!b_c16x4_batch_) return RET_MEMORY_FAILED;
memset(b_c16x4_batch_, 0, params_->batch * params_->col_4_ * params_->deep_16_ * sizeof(int8_t));
weight_bias_sums_batch_ =
reinterpret_cast<int *>(context_->allocator->Malloc(params_->batch * params_->col_4_ * sizeof(int)));
if (!weight_bias_sums_batch_) return RET_MEMORY_FAILED;
memset(weight_bias_sums_batch_, 0, params_->batch * params_->col_4_ * sizeof(int));
if (in_tensors_.size() == 3) {
auto bias_size = params_->col_4_ * sizeof(int);
bias_ptr_ = reinterpret_cast<int *>(context_->allocator->Malloc(bias_size));
if (!bias_ptr_) return RET_MEMORY_FAILED;
memcpy(bias_ptr_, in_tensors_[2]->data_c(), bias_size);
} else {
bias_ptr_ = NULL;
}
thread_count_ = MSMIN(op_parameter_->thread_num_, UP_DIV(params_->col_4_, 4));
thread_stride_ = UP_DIV(UP_DIV(params_->col_4_, 4), thread_count_);
auto input_tensor = in_tensors_.at(0);
auto params = input_tensor->quant_params();
MS_ASSERT(params.size() == 1);
quant_params_.input.zp_ = params.front().zeroPoint;
quant_params_.input.scale_ = params.front().scale;
auto weight_tensor = in_tensors_.at(1);
params = weight_tensor->quant_params();
MS_ASSERT(params.size() == 1);
quant_params_.weight.zp_ = params.front().zeroPoint;
quant_params_.weight.scale_ = params.front().scale;
auto output_tensor = out_tensors_.at(0);
params = output_tensor->quant_params();
MS_ASSERT(params.size() == 1);
quant_params_.output.zp_ = params.front().zeroPoint;
quant_params_.output.scale_ = params.front().scale;
params_->b_const_ = (in_tensors_.at(1)->data_c() != nullptr);
if (params_->b_const_) {
auto b_ptr = reinterpret_cast<int8_t *>(in_tensors_.at(1)->data_c());
for (int i = 0; i < params_->batch; ++i) {
auto cur_b = b_ptr + i * params_->deep_ * params_->col_;
auto cur_b_pack = b_c16x4_batch_ + i * params_->col_4_ * params_->deep_16_;
auto cur_sums = weight_bias_sums_batch_ + i * params_->col_4_;
if (params_->b_transpose_) {
RowMajor2Row16x4MajorInt8(cur_b, cur_b_pack, params_->col_, params_->deep_);
CalcWeightBiasSums(cur_b, params_->deep_, params_->col_, quant_params_.input.zp_, &quant_params_.weight.zp_,
bias_ptr_, cur_sums, ColMajor, false);
} else {
RowMajor2Col16x4MajorInt8(cur_b, params_->deep_, params_->col_, cur_b_pack);
CalcWeightBiasSums(cur_b, params_->deep_, params_->col_, quant_params_.input.zp_, &quant_params_.weight.zp_,
bias_ptr_, cur_sums, RowMajor, false);
}
}
}
double real_multiplier = quant_params_.input.scale_ * quant_params_.weight.scale_ / quant_params_.output.scale_;
QuantizeRoundParameterWithDoublePrecision(real_multiplier, &quant_params_.quant_multiplier, &quant_params_.left_shift,
&quant_params_.right_shift);
return RET_OK;
}
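
The requantization constants above come from splitting real_multiplier = S_input * S_weight / S_output into a Q31 integer multiplier and a power-of-two shift. A hedged sketch of that decomposition (the standard gemmlowp recipe; QuantizeRoundParameterWithDoublePrecision's exact body may differ):

#include <math.h>
#include <stdint.h>

/* Split real_multiplier into q31 * 2^shift with q31 in [2^30, 2^31). */
static void DecomposeMultiplierSketch(double real_multiplier, int32_t *q31, int *shift) {
  if (real_multiplier == 0.0) {
    *q31 = 0;
    *shift = 0;
    return;
  }
  double q = frexp(real_multiplier, shift); /* real = q * 2^shift, q in [0.5, 1) */
  int64_t q_fixed = (int64_t)round(q * (double)(1ll << 31));
  if (q_fixed == (1ll << 31)) { /* q rounded up to exactly 1.0 */
    q_fixed /= 2;
    ++(*shift);
  }
  *q31 = (int32_t)q_fixed; /* a positive shift becomes left_shift, a negative one right_shift */
}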
int MatmulInt8CPUKernel::RunImpl(int task_id) {
int cur_oc = MSMIN(thread_stride_, UP_DIV(params_->col_4_, 4) - task_id * thread_stride_);
if (cur_oc <= 0) {
return RET_OK;
}
int cur_oc_res = MSMIN(thread_stride_ * C4NUM, params_->col_ - task_id * thread_stride_ * C4NUM);
auto cur_b = b_c16x4_ptr_ + task_id * thread_stride_ * 4 * params_->deep_16_;
auto cur_bias = weight_bias_sums_ + task_id * thread_stride_ * 4;
auto cur_c = c_ptr_ + task_id * thread_stride_ * 4;
auto &p = quant_params_;
#ifdef ENABLE_ARM64
MatmulInt8Neon64(a_r4x16_ptr_, cur_b, cur_c, params_->row_4_, cur_oc * C4NUM, params_->deep_16_, input_sums_,
cur_bias, INT8_MIN, INT8_MAX, p.output.zp_, &p.quant_multiplier, &p.left_shift, &p.right_shift,
params_->row_, cur_oc_res, params_->col_ * sizeof(int8_t), false);
#else
MatMulInt8_16x4_r(a_r4x16_ptr_, cur_b, cur_c, params_->row_, cur_oc_res, params_->deep_16_, params_->col_,
input_sums_, cur_bias, &p.left_shift, &p.right_shift, &p.quant_multiplier, p.output.zp_, INT8_MIN,
INT8_MAX, false);
#endif
return RET_OK;
}
param_->row_ = o_shape[o_shape.size() - 2];
param_->col_ = o_shape[o_shape.size() - 1];
param_->deep_ = param_->a_transpose_ ? x_shape[x_shape.size() - 2] : x_shape[x_shape.size() - 1];
int MatmulInt8Run(void *cdata, int task_id) {
auto op = reinterpret_cast<MatmulInt8CPUKernel *>(cdata);
auto ret = op->RunImpl(task_id);
auto ret = MatmulBaseInt8CPUKernel::ReSize();
if (ret != RET_OK) {
MS_LOG(ERROR) << "MatmulInt8Run error task_id[" << task_id << "] error_code[" << ret << "]";
MS_LOG(ERROR) << "MatmulBaseInt8CPUKernel::ReSize failed";
return ret;
}
return RET_OK;
}
int MatmulInt8CPUKernel::Run() {
auto a_ptr = reinterpret_cast<int8_t *>(in_tensors_.at(0)->data_c());
auto c_ptr = reinterpret_cast<int8_t *>(out_tensors_.at(0)->data_c());
auto a_stride = params_->row_ * params_->deep_;
auto b_stride = params_->deep_ * params_->col_;
auto c_stride = params_->row_ * params_->col_;
if (!params_->b_const_) {
auto b_ptr = reinterpret_cast<int8_t *>(in_tensors_.at(1)->data_c());
for (int i = 0; i < params_->batch; ++i) {
auto cur_b = b_ptr + i * b_stride;
auto cur_b_pack = b_c16x4_batch_ + i * params_->col_4_ * params_->deep_16_;
auto cur_sums = weight_bias_sums_batch_ + i * params_->col_4_;
if (params_->b_transpose_) {
RowMajor2Row16x4MajorInt8(cur_b, cur_b_pack, params_->col_, params_->deep_);
CalcWeightBiasSums(cur_b, params_->deep_, params_->col_, quant_params_.input.zp_, &quant_params_.weight.zp_,
bias_ptr_, cur_sums, ColMajor, false);
} else {
RowMajor2Col16x4MajorInt8(cur_b, params_->deep_, params_->col_, cur_b_pack);
CalcWeightBiasSums(cur_b, params_->deep_, params_->col_, quant_params_.input.zp_, &quant_params_.weight.zp_,
bias_ptr_, cur_sums, RowMajor, false);
}
}
}
for (int i = 0; i < params_->batch; ++i) {
auto cur_a_ptr = a_ptr + i * a_stride;
if (params_->a_transpose_) {
RowMajor2Col16x4MajorInt8(cur_a_ptr, params_->deep_, params_->row_, a_r4x16_ptr_);
CalcInputSums(cur_a_ptr, params_->row_, params_->deep_, quant_params_.weight.zp_, input_sums_, ColMajor);
} else {
RowMajor2Row16x4MajorInt8(cur_a_ptr, a_r4x16_ptr_, params_->row_, params_->deep_);
CalcInputSums(cur_a_ptr, params_->row_, params_->deep_, quant_params_.weight.zp_, input_sums_, RowMajor);
}
b_c16x4_ptr_ = b_c16x4_batch_ + i * params_->col_4_ * params_->deep_16_;
weight_bias_sums_ = weight_bias_sums_batch_ + i * params_->col_4_;
c_ptr_ = c_ptr + i * c_stride;
auto ret = ParallelLaunch(this->context_->thread_pool_, MatmulInt8Run, this, thread_count_);
if (ret != RET_OK) {
MS_LOG(ERROR) << "MatmulInt8Run error: [" << ret << "]";
return ret;
}
}
return RET_OK;
}
REG_KERNEL(kCPU, kNumberTypeInt8, PrimitiveType_MatMul, LiteKernelCreator<MatmulInt8CPUKernel>)
} // namespace mindspore::kernel

@@ -22,39 +22,18 @@
#include "nnacl/matmul_parameter.h"
#include "mindspore/lite/nnacl/int8/quantize.h"
#include "src/lite_kernel.h"
#include "src/runtime/kernel/arm/int8/matmul_base_int8.h"
using mindspore::lite::InnerContext;
namespace mindspore::kernel {
class MatmulInt8CPUKernel : public LiteKernel {
class MatmulInt8CPUKernel : public MatmulBaseInt8CPUKernel {
public:
MatmulInt8CPUKernel(OpParameter *parameter, const std::vector<lite::Tensor *> &inputs,
const std::vector<lite::Tensor *> &outputs, const InnerContext *ctx,
const std::vector<lite::Tensor *> &outputs, const lite::InnerContext *ctx,
const mindspore::lite::PrimitiveC *primitive)
: LiteKernel(parameter, inputs, outputs, ctx, primitive) {
params_ = reinterpret_cast<MatMulParameter *>(op_parameter_);
}
~MatmulInt8CPUKernel() override;
: MatmulBaseInt8CPUKernel(parameter, inputs, outputs, ctx, primitive) {}
~MatmulInt8CPUKernel() override = default;
int Init() override;
int ReSize() override;
int Run() override;
int RunImpl(int task_id);
private:
void FreeTmpBuffer();
private:
MatMulParameter *params_ = nullptr;
MatmulQuantArg quant_params_;
int8_t *a_r4x16_ptr_ = nullptr;
int8_t *b_c16x4_ptr_ = nullptr;
int8_t *c_ptr_ = nullptr;
int8_t *b_c16x4_batch_ = nullptr;
int *bias_ptr_ = nullptr;
int *input_sums_ = nullptr;
int *weight_bias_sums_ = nullptr;
int *weight_bias_sums_batch_ = nullptr;
int thread_stride_ = 0;
int thread_count_ = 0;
};
} // namespace mindspore::kernel

@@ -599,9 +599,9 @@ function Run_x86() {
echo 'export LD_LIBRARY_PATH=$LD_LIBRARY_PATH:./lib:./third_party/libjpeg-turbo/lib:./third_party/opencv/lib;./benchmark/benchmark --modelFile='${ms_models_path}'/'${model_name}'.ms --inDataFile=/home/workspace/mindspore_dataset/mslite/models/hiai/input_output/input/'${model_name}'.ms.bin --benchmarkDataFile=/home/workspace/mindspore_dataset/mslite/models/hiai/input_output/output/'${model_name}'.ms.out --accuracyThreshold=${accuracy_limit}' >> "${run_x86_log_file}"
export LD_LIBRARY_PATH=$LD_LIBRARY_PATH:./lib:./third_party/libjpeg-turbo/lib:./third_party/opencv/lib;./benchmark/benchmark --modelFile=${ms_models_path}/${model_name}_weightquant.ms --inDataFile=/home/workspace/mindspore_dataset/mslite/models/hiai/input_output/input/${model_name}.ms.bin --benchmarkDataFile=/home/workspace/mindspore_dataset/mslite/models/hiai/input_output/output/${model_name}.ms.out --accuracyThreshold=${accuracy_limit} >> "${run_x86_log_file}"
if [ $? = 0 ]; then
run_result='x86: '${model_name}'[weight quant] pass'; echo ${run_result} >> ${run_benchmark_result_file}
run_result='x86: '${model_name}'[weight_quant] pass'; echo ${run_result} >> ${run_benchmark_result_file}
else
run_result='x86: '${model_name}'[weight quant] failed'; echo ${run_result} >> ${run_benchmark_result_file}; return 1
run_result='x86: '${model_name}'[weight_quant] failed'; echo ${run_result} >> ${run_benchmark_result_file}; return 1
fi
done < ${models_mindspore_weightquant_config}

@@ -79,58 +79,6 @@ void MMInt8TestInit(std::vector<lite::Tensor *> *inputs, std::vector<lite::Tenso
delete[] weight_data;
}
TEST_F(TestMatmulInt8, simple) {
#define ROW 10
#define COL 15
#define DEPTH 10
#define ROW4 UP_ROUND(ROW, 4)
#define COL4 UP_ROUND(COL, 4)
#define DEPTH16 UP_ROUND(DEPTH, 16)
int8_t a[ROW * DEPTH] = {-3, -3, 0, -2, -4, -2, 1, 0, -1, 0, 5, 1, 3, 4, 4, -3, -5, 2, -2, 4,
4, 5, 1, -1, 5, 5, 2, -1, 0, 4, -4, 2, 5, -2, 5, 3, -1, 2, -4, 5,
-5, 4, 5, 3, 5, 4, -2, 5, 5, -5, -5, -5, 2, -4, -3, 3, -3, -5, 5, 0,
2, -4, 4, 2, -5, 3, -1, 3, -3, 2, -5, -4, 0, -5, 2, 4, 0, -5, -1, 4,
3, 5, 5, 2, -5, -5, -4, -5, 3, 3, 3, 0, -2, 0, -2, -3, -2, 3, 5, -5};
int8_t b[DEPTH * COL] = {1, 2, -2, -5, -4, 2, 3, 2, -5, 4, -5, 4, 1, -2, 1, 5, 5, 5, 2, 5, -3, -3,
-1, -3, -1, 0, -4, 0, 1, -2, -2, -3, -5, 1, 1, 0, 4, 5, -3, -1, 4, 3, 5, 4,
2, 4, -3, -4, 1, 4, -4, 5, -1, -2, 3, 5, 5, 2, 1, -4, 1, 2, -3, 0, -2, 4,
-3, -3, 1, 3, 4, -1, 3, 1, -5, -1, 2, 0, 0, 5, -1, -5, 5, -5, 0, 3, -3, 4,
3, 1, -3, -3, 2, -2, -3, -3, 3, 4, 2, -1, 2, 0, -2, 4, 5, 3, -1, -3, -2, -1,
4, 3, -5, 1, 0, 0, -1, -4, -3, -2, 5, 3, 2, 1, -4, 1, 4, 5, -1, 2, -2, 2,
1, -2, 5, 2, -4, -4, 1, 1, 2, -1, -5, -4, 4, 1, -3, 4, -1, -4};
int8_t correct[ROW * COL] = {
-36, -33, 11, 4, -12, -7, 11, 0, 37, -30, -13, -2, -30, -3, 29, 46, -13, -84, -8, 6, 39, 26,
-67, -48, 57, 12, 32, 44, -24, -85, 22, 32, -8, -8, 20, 10, -45, 12, -69, 36, 22, -37, 58, 27,
-24, -11, -22, -50, 26, 50, 28, -56, -42, -23, -1, 70, -58, 54, 35, -61, 54, 40, -11, 35, 43, 3,
7, 30, -7, -13, 73, -3, 26, 26, -11, -37, 0, 19, 34, -4, 0, -22, 71, 8, -25, -6, -5, 31,
8, 63, -25, -55, -62, -17, 23, 1, 36, 12, -38, 2, 11, 27, 18, 5, 4, -59, -17, 1, 25, 9,
13, -77, 13, 9, -11, 26, -52, 42, 28, 6, 44, 4, 2, 26, 19, -31, 46, 23, -57, 15, -31, 39,
40, -9, 8, 38, 40, 27, -19, -47, 14, 50, 14, 18, 0, -59, 39, -48, -47, 35};
int8_t output[ROW * COL] = {0};
int8_t *a_r4x16 = new int8_t[ROW4 * DEPTH16];
memset(a_r4x16, 0, ROW4 * DEPTH16);
int8_t *b_c16x4 = new int8_t[COL4 * DEPTH16];
memset(b_c16x4, 0, COL4 * DEPTH16);
RowMajor2Row16x4MajorInt8(a, a_r4x16, ROW, DEPTH);
RowMajor2Col16x4MajorInt8(b, DEPTH, COL, b_c16x4);
int a_sums[ROW4] = {0};
int bias[COL4] = {0};
int multiplier, ls, rs;
QuantizeRoundParameterWithDoublePrecision(1.0f, &multiplier, &ls, &rs);
#ifdef ENABLE_ARM64
MatmulInt8Neon64(a_r4x16, b_c16x4, output, ROW4, COL4, DEPTH16, a_sums, bias, INT8_MIN, INT8_MAX, 0, &multiplier, &ls,
&rs, ROW, COL, COL, false);
#else
MatMulInt8_16x4_r(a_r4x16, b_c16x4, output, ROW, COL, DEPTH16, COL, a_sums, bias, &ls, &rs, &multiplier, 0, INT8_MIN,
INT8_MAX, false);
#endif
ASSERT_EQ(0, CompareOutputData(output, correct, ROW * COL, 0.1));
delete[] a_r4x16;
delete[] b_c16x4;
}
TEST_F(TestMatmulInt8, mmtest1) {
float in[] = {6.583835634764597, 11.337275140963907, -4.125256949459629, 10.994337291530833,
19.086065139532636, 3.620842999158455, 13.167624585590346, -18.326739299407755,
