|
|
|
@ -33,7 +33,7 @@ int MatmulBaseInt8Run(void *cdata, int task_id) {
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
int MatmulBaseInt8CPUKernel::RunImpl(int task_id) {
|
|
|
|
|
int stride = thread_stride_ * C4NUM;
|
|
|
|
|
int stride = thread_stride_ * col_tile_;
|
|
|
|
|
int cur_stride = task_id * stride;
|
|
|
|
|
int res_stride = param_->col_ - cur_stride;
|
|
|
|
|
int cur_oc = MSMIN(stride, res_stride);
|
|
|
|
@ -155,16 +155,23 @@ void MatmulBaseInt8CPUKernel::InitQuantParam() {
|
|
|
|
|
void MatmulBaseInt8CPUKernel::InitParameter() {
|
|
|
|
|
param_->a_const_ = (in_tensors_[0]->data_c() != nullptr);
|
|
|
|
|
param_->b_const_ = (in_tensors_[1]->data_c() != nullptr);
|
|
|
|
|
#ifdef ENABLE_ARM32
|
|
|
|
|
row_tile_ = C4NUM;
|
|
|
|
|
col_tile_ = C2NUM;
|
|
|
|
|
#else
|
|
|
|
|
row_tile_ = C4NUM;
|
|
|
|
|
col_tile_ = C4NUM;
|
|
|
|
|
#endif
|
|
|
|
|
return;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
void MatmulBaseInt8CPUKernel::ResizeParameter() {
|
|
|
|
|
param_->row_align_ = UP_ROUND(param_->row_, C4NUM);
|
|
|
|
|
param_->col_align_ = UP_ROUND(param_->col_, C4NUM);
|
|
|
|
|
param_->row_align_ = UP_ROUND(param_->row_, row_tile_);
|
|
|
|
|
param_->col_align_ = UP_ROUND(param_->col_, col_tile_);
|
|
|
|
|
param_->deep_16_ = UP_ROUND(param_->deep_, C16NUM);
|
|
|
|
|
|
|
|
|
|
thread_count_ = MSMIN(op_parameter_->thread_num_, UP_DIV(param_->col_align_, C4NUM));
|
|
|
|
|
thread_stride_ = UP_DIV(UP_DIV(param_->col_align_, C4NUM), thread_count_);
|
|
|
|
|
thread_count_ = MSMIN(op_parameter_->thread_num_, UP_DIV(param_->col_align_, col_tile_));
|
|
|
|
|
thread_stride_ = UP_DIV(UP_DIV(param_->col_align_, col_tile_), thread_count_);
|
|
|
|
|
return;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
@ -195,11 +202,19 @@ void MatmulBaseInt8CPUKernel::TransferB() {
|
|
|
|
|
auto current_b_pack = pack_b_ptr_ + i * param_->col_align_ * param_->deep_16_;
|
|
|
|
|
auto current_sums = weight_bias_sums_ + i * param_->col_align_;
|
|
|
|
|
if (param_->b_transpose_) {
|
|
|
|
|
#ifdef ENABLE_ARM32
|
|
|
|
|
RowMajor2Row2x16MajorInt8(current_weight, current_b_pack, param_->col_, param_->deep_);
|
|
|
|
|
#else
|
|
|
|
|
RowMajor2Row16x4MajorInt8(current_weight, current_b_pack, param_->col_, param_->deep_);
|
|
|
|
|
#endif
|
|
|
|
|
CalcWeightBiasSums(current_weight, param_->deep_, param_->col_, quant_.input_.zp_, quant_.filter_zp_, bias_ptr_,
|
|
|
|
|
current_sums, ColMajor, filter_per_channel_);
|
|
|
|
|
} else {
|
|
|
|
|
#ifdef ENABLE_ARM32
|
|
|
|
|
RowMajor2Col16x2MajorInt8(current_weight, current_b_pack, param_->deep_, param_->col_);
|
|
|
|
|
#else
|
|
|
|
|
RowMajor2Col16x4MajorInt8(current_weight, param_->deep_, param_->col_, current_b_pack);
|
|
|
|
|
#endif
|
|
|
|
|
CalcWeightBiasSums(current_weight, param_->deep_, param_->col_, quant_.input_.zp_, quant_.filter_zp_, bias_ptr_,
|
|
|
|
|
current_sums, RowMajor, false);
|
|
|
|
|
}
|
|
|
|
|