!4148 fix bug of fp16 conv3x3 op

Merge pull request !4148 from fuzhiye/tmp
pull/4148/MERGE
mindspore-ci-bot 5 years ago committed by Gitee
commit 46022afbc6

@ -215,7 +215,7 @@ void Conv3x3Fp16(float16_t *input_data, float16_t *transed_weight, const float16
for (int c = 0; c < output_channel; c++) {
int oc8_block = c / C8NUM;
int oc8_res = c % C8NUM;
int src_offset = oc8_block * C8NUM * out_w_block * out_h_block * tile_num +
int src_offset = oc8_block * C8NUM * out_w_block * out_h_block * C4NUM * C4NUM +
C8NUM * (h * out_w_block * output_unit + w) + oc8_res;
int dst_offset = (h * output_w + w) * output_channel + c;
(output_data + dst_offset)[0] = (tmp_out + src_offset)[0];

@ -508,25 +508,25 @@ void Conv3x3Fp16OutputUnit(const float16_t *gemm_out, const float16_t *bias_data
void Conv3x3Fp16OutputTransform(const float16_t *gemm_out, float16_t *out_data, const float16_t *bias_data,
int start_index, int real_cal_num, int out_w_block, ConvParameter *conv_param) {
int output_channel = conv_param->output_channel_;
int output_w = conv_param->output_w_;
int output_h = conv_param->output_h_;
int out_h_block = UP_DIV(output_h, C4NUM);
int oc8 = UP_DIV(output_channel, C8NUM);
// todo outputw --> out_w_block * out_unit
for (int i = 0; i < real_cal_num; i++) {
int out_w_index = (start_index + i) % out_w_block;
int out_h_index = (start_index + i) / out_w_block;
int src_tile_offset = i * oc8 * C8NUM * 36;
int dst_tile_offset = 8 * (out_w_index * 4 + out_h_index * 4 * output_w);
int dst_tile_offset = C8NUM * (out_w_index * C4NUM + out_h_index * C4NUM * out_w_block * C4NUM);
for (int j = 0; j < oc8; j++) {
int src_oc8_offset = src_tile_offset + j * 36 * C8NUM;
int dst_oc8_offset = dst_tile_offset + j * C8NUM * output_h * output_w;
int dst_oc8_offset = dst_tile_offset + j * C8NUM * out_h_block * out_w_block * C4NUM * C4NUM;
const float16_t *src_ptr = gemm_out + src_oc8_offset;
const float16_t *bias_ptr = bias_data + j * C8NUM;
float16_t *dst_ptr = out_data + dst_oc8_offset;
// output transform
Conv3x3Fp16OutputUnit(src_ptr, bias_ptr, dst_ptr, output_w);
Conv3x3Fp16OutputUnit(src_ptr, bias_ptr, dst_ptr, out_w_block * C4NUM);
}
}
}

Loading…
Cancel
Save