[MSLITE] int8 conv1x1 parallel by oc

pull/8811/head
ling 4 years ago
parent 6f6925805e
commit b7b5623506

@ -375,7 +375,7 @@ int Convolution1x1Int8CPUKernel::InitParam() {
thread_stride_hw_ = UP_DIV(hw_thread_count, thread_count_hw_);
thread_count_oc_ = MSMIN(op_parameter_->thread_num_, oc_thread_count);
thread_stride_oc_ = UP_DIV(oc_thread_count, thread_count_oc_);
parallel_by_oc_ = hw_thread_count < op_parameter_->thread_num_;
parallel_by_oc_ = oc_thread_count > op_parameter_->thread_num_;
return RET_OK;
}
@ -521,11 +521,12 @@ int Convolution1x1Int8CPUKernel::RunArm64OptOc(int task_id) {
int32_t *cur_left_shift = filter_peroc_ ? left_shift_ + cur_stride : conv_param_->conv_quant_arg_.left_shift_;
int32_t *cur_right_shift = filter_peroc_ ? right_shift_ + cur_stride : conv_param_->conv_quant_arg_.right_shift_;
int32_t *cur_multiplier = filter_peroc_ ? multiplier_ + cur_stride : conv_param_->conv_quant_arg_.quant_multiplier_;
int32_t *cur_zp = filter_peroc_ ? filter_zp_ptr_ + cur_stride : filter_zp_ptr_;
Conv1x1Int8Opt(packed_input_, packed_weight_ + cur_stride * matmul_param_->deep_4_, output_ptr_ + cur_stride,
input_sum_, reinterpret_cast<int32_t *>(bias_data_) + cur_stride, matmul_param_->row_, cur_oc,
matmul_param_->deep_4_, cur_left_shift, cur_right_shift, cur_multiplier, conv_param_, matmul_func_,
filter_zp_ptr_);
cur_zp);
return RET_OK;
}

Loading…
Cancel
Save