|
|
|
@ -508,25 +508,25 @@ void Conv3x3Fp16OutputUnit(const float16_t *gemm_out, const float16_t *bias_data
|
|
|
|
|
/*
 * Applies the 3x3 convolution output transform to a run of GEMM result tiles.
 *
 * For each of the `real_cal_num` tiles starting at linear tile index
 * `start_index`, and for each group of C8NUM output channels, this calls
 * Conv3x3Fp16OutputUnit() on the matching slice of `gemm_out` and writes the
 * result into `out_data`.
 *
 * gemm_out      GEMM results; each tile occupies oc8 * C8NUM * 36 elements.
 * out_data      destination buffer, addressed with a row stride of
 *               out_w_block * C4NUM (the padded width), not output_w_.
 *               NOTE(review): the caller presumably allocates it with the
 *               padded height/width and crops afterwards -- confirm.
 * bias_data     bias values, C8NUM per channel group.
 * start_index   linear index of the first tile to process.
 * real_cal_num  number of tiles to process.
 * out_w_block   number of tiles per output row; must be non-zero because it
 *               is used as a divisor.
 * conv_param    supplies output_channel_ and output_h_.
 */
void Conv3x3Fp16OutputTransform(const float16_t *gemm_out, float16_t *out_data, const float16_t *bias_data,
                                int start_index, int real_cal_num, int out_w_block, ConvParameter *conv_param) {
  int output_channel = conv_param->output_channel_;
  int output_h = conv_param->output_h_;
  int out_h_block = UP_DIV(output_h, C4NUM);
  int oc8 = UP_DIV(output_channel, C8NUM);

  // The destination is indexed by the block-aligned extent (blocks * C4NUM)
  // rather than the raw output width, so every tile is written at full size.
  int dst_w_stride = out_w_block * C4NUM;
  int dst_plane_size = out_h_block * C4NUM * dst_w_stride;

  for (int i = 0; i < real_cal_num; i++) {
    // Split the linear tile index into its column / row position.
    int out_w_index = (start_index + i) % out_w_block;
    int out_h_index = (start_index + i) / out_w_block;
    int src_tile_offset = i * oc8 * C8NUM * 36;
    int dst_tile_offset = C8NUM * (out_w_index * C4NUM + out_h_index * C4NUM * dst_w_stride);

    for (int j = 0; j < oc8; j++) {
      int src_oc8_offset = src_tile_offset + j * 36 * C8NUM;
      int dst_oc8_offset = dst_tile_offset + j * C8NUM * dst_plane_size;
      const float16_t *src_ptr = gemm_out + src_oc8_offset;
      const float16_t *bias_ptr = bias_data + j * C8NUM;
      float16_t *dst_ptr = out_data + dst_oc8_offset;

      // output transform
      Conv3x3Fp16OutputUnit(src_ptr, bias_ptr, dst_ptr, dst_w_stride);
    }
  }
}
|
|
|
|
|