|
|
@ -140,18 +140,21 @@ int MatmulInt8CPUKernel::Run() {
|
|
|
|
|
|
|
|
|
|
|
|
if (params_->a_transpose_) {
|
|
|
|
if (params_->a_transpose_) {
|
|
|
|
RowMajor2Col16x4Major(cur_a_ptr, params_->deep_, params_->row_, a_r4x16_ptr_, d16_);
|
|
|
|
RowMajor2Col16x4Major(cur_a_ptr, params_->deep_, params_->row_, a_r4x16_ptr_, d16_);
|
|
|
|
|
|
|
|
CalcInputSums(cur_a_ptr, params_->row_, params_->deep_, quant_params_.weight.zp_, input_sums_, ColMajor);
|
|
|
|
} else {
|
|
|
|
} else {
|
|
|
|
RowMajor2Row4x16Major(cur_a_ptr, params_->row_, params_->deep_, a_r4x16_ptr_, d16_);
|
|
|
|
RowMajor2Row4x16Major(cur_a_ptr, params_->row_, params_->deep_, a_r4x16_ptr_, d16_);
|
|
|
|
|
|
|
|
CalcInputSums(cur_a_ptr, params_->row_, params_->deep_, quant_params_.weight.zp_, input_sums_, RowMajor);
|
|
|
|
}
|
|
|
|
}
|
|
|
|
if (params_->b_transpose_) {
|
|
|
|
if (params_->b_transpose_) {
|
|
|
|
RowMajor2Row4x16Major(cur_b_ptr, params_->col_, params_->deep_, b_c16x4_ptr_, d16_);
|
|
|
|
RowMajor2Row4x16Major(cur_b_ptr, params_->col_, params_->deep_, b_c16x4_ptr_, d16_);
|
|
|
|
|
|
|
|
CalcWeightBiasSums(cur_b_ptr, params_->deep_, params_->col_, quant_params_.input.zp_, quant_params_.weight.zp_,
|
|
|
|
|
|
|
|
NULL, weight_bias_sums_, ColMajor);
|
|
|
|
} else {
|
|
|
|
} else {
|
|
|
|
RowMajor2Col16x4Major(cur_b_ptr, params_->deep_, params_->col_, b_c16x4_ptr_, d16_);
|
|
|
|
RowMajor2Col16x4Major(cur_b_ptr, params_->deep_, params_->col_, b_c16x4_ptr_, d16_);
|
|
|
|
|
|
|
|
CalcWeightBiasSums(cur_b_ptr, params_->deep_, params_->col_, quant_params_.input.zp_, quant_params_.weight.zp_,
|
|
|
|
|
|
|
|
NULL, weight_bias_sums_, RowMajor);
|
|
|
|
}
|
|
|
|
}
|
|
|
|
c_ptr_ = c_ptr + i * c_stride;
|
|
|
|
c_ptr_ = c_ptr + i * c_stride;
|
|
|
|
auto &q = quant_params_;
|
|
|
|
|
|
|
|
CalcInputSums(cur_a_ptr, params_->row_, params_->deep_, q.weight.zp_, input_sums_);
|
|
|
|
|
|
|
|
CalcWeightBiasSums(cur_b_ptr, params_->deep_, params_->col_, q.input.zp_, q.weight.zp_, NULL, weight_bias_sums_);
|
|
|
|
|
|
|
|
ret = ParallelLaunch(THREAD_POOL_DEFAULT, MatmulInt8Run, this, thread_count_);
|
|
|
|
ret = ParallelLaunch(THREAD_POOL_DEFAULT, MatmulInt8Run, this, thread_count_);
|
|
|
|
if (ret != RET_OK) {
|
|
|
|
if (ret != RET_OK) {
|
|
|
|
MS_LOG(ERROR) << "MatmulInt8Run error: [" << ret << "]";
|
|
|
|
MS_LOG(ERROR) << "MatmulInt8Run error: [" << ret << "]";
|
|
|
|