add avx support for fp32 winograd conv

pull/9918/head
lixian 4 years ago
parent d0bd77c1ef
commit 25f3866be7

@@ -53,7 +53,7 @@ NoC8Steps:
     movq $4, %r12
     imul %r10, %r12
     imul %rbx, %r12
-    movq $48, %r13
+    movq $32, %r13
     imul %r10, %r13
 NoWinoSteps:
     movq $4, %rax
@@ -827,32 +827,35 @@ LoopRow:
     jmp WriteEnd
 WriteWino:
     movq %rdx, %rax
-    addq %r13, %rdx
-    movq %rdx, %r15
-    addq %r13, %rdx
-    movq %rdx, -80(%rsp)
-    vmovups %ymm4, (%rax)
-    vmovups %ymm5, (%r15)
-    addq %r12, %rax
-    addq %r12, %r15
-    vmovups %ymm6, (%rax)
-    vmovups %ymm7, (%r15)
-    addq %r12, %rax
-    addq %r12, %r15
-    vmovups %ymm8, (%rax)
-    vmovups %ymm9, (%r15)
-    addq %r12, %rax
-    addq %r12, %r15
-    vmovups %ymm10, (%rax)
-    vmovups %ymm11, (%r15)
-    addq %r12, %rax
-    addq %r12, %r15
-    vmovups %ymm12, (%rax)
-    vmovups %ymm13, (%r15)
-    addq %r12, %rax
-    addq %r12, %r15
-    vmovups %ymm14, (%rax)
-    vmovups %ymm15, (%r15)
+    addq %r13, %rax
+    movq %rax, -80(%rsp)
+    vmovups %ymm4, (%rdx)
+    addq %r12, %rdx
+    vmovups %ymm6, (%rdx)
+    addq %r12, %rdx
+    vmovups %ymm8, (%rdx)
+    addq %r12, %rdx
+    vmovups %ymm10, (%rdx)
+    addq %r12, %rdx
+    vmovups %ymm12, (%rdx)
+    addq %r12, %rdx
+    vmovups %ymm14, (%rdx)
+    cmpq $-8, %rbx
+    je WriteEnd
+    movq %rax, %rdx
+    addq %r13, %rax
+    movq %rax, -80(%rsp)
+    vmovups %ymm5, (%rdx)
+    addq %r12, %rdx
+    vmovups %ymm7, (%rdx)
+    addq %r12, %rdx
+    vmovups %ymm9, (%rdx)
+    addq %r12, %rdx
+    vmovups %ymm11, (%rdx)
+    addq %r12, %rdx
+    vmovups %ymm13, (%rdx)
+    addq %r12, %rdx
+    vmovups %ymm15, (%rdx)
     jmp WriteEnd
 Write16:
     movq %rdx, %rax
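
For readers who don't want to trace the registers: assuming, as the surrounding code suggests, that ymm4-ymm15 hold a 6x16 accumulator tile, %r12 the per-row byte step and %r13 the byte offset between consecutive C8 channel blocks, the rewritten WriteWino stores the six rows of the first 8 output channels (ymm4, ymm6, ..., ymm14) in one pass and, unless the tail check on %rbx fires, the next 8 channels (ymm5, ..., ymm15) in a second pass, instead of interleaving both blocks per row. A rough C sketch of that store order (strides given in floats here, not bytes; write_wino_tile_sketch is a made-up name, not part of the patch):

#include <string.h>

/* Sketch only: mirrors the new store order above, not the actual kernel. */
static void write_wino_tile_sketch(float *dst, const float acc[12][8], /* ymm4..ymm15 */
                                   size_t row_stride, size_t block_stride, int last_block) {
  float *p = dst;
  for (int r = 0; r < 6; ++r) {  /* ymm4, ymm6, ..., ymm14: channels 0-7, rows 0-5 */
    memcpy(p, acc[2 * r], 8 * sizeof(float));
    p += row_stride;
  }
  if (last_block) return;        /* corresponds to: cmpq $-8, %rbx; je WriteEnd */
  p = dst + block_stride;
  for (int r = 0; r < 6; ++r) {  /* ymm5, ymm7, ..., ymm15: channels 8-15, rows 0-5 */
    memcpy(p, acc[2 * r + 1], 8 * sizeof(float));
    p += row_stride;
  }
}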

@@ -73,6 +73,12 @@ void ConvWinogardFp32(const float *input_data, const float *trans_weight, const
   int output_count = out_w_block * out_h_block;
   const int tile_num = C12NUM;
   int output_tile_count = UP_DIV(output_count, tile_num);
+#ifdef ENABLE_AVX
+  const int col_tile = C16NUM;
+#else
+  const int col_tile = C8NUM;
+#endif
+  int oc_tile = UP_DIV(conv_param->output_channel_, col_tile);
   int oc8 = UP_DIV(conv_param->output_channel_, C8NUM);
   int input_unit_square = conv_param->input_unit_ * conv_param->input_unit_;
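
Context for the new col_tile/oc_tile pair: under ENABLE_AVX the transformed weights are packed in panels of 16 output channels (C16NUM), otherwise 8, while the gemm output stays C8-blocked (oc8 * C8NUM). The MatMulOpt call in the next hunk strides trans_weight by in_channel * oc_tile * col_tile per transformed point accordingly. A minimal sketch of how these constants size the packed weight buffer, matching trans_matrix_data_size in ConvolutionWinogradCPUKernel::InitWeightBias further below; winograd_weight_size is a hypothetical helper and the macros are restated only to keep the sketch self-contained:

#include <stddef.h>

#define C8NUM 8
#define C16NUM 16
#define UP_DIV(x, y) (((x) + (y) - (1)) / (y))

static size_t winograd_weight_size(int input_unit, int in_channel, int out_channel) {
#ifdef ENABLE_AVX
  const int col_tile = C16NUM;  /* AVX 6x16 kernel consumes 16 output channels per panel */
#else
  const int col_tile = C8NUM;
#endif
  int oc_tile = UP_DIV(out_channel, col_tile);
  /* one in_channel x (oc_tile * col_tile) block per transformed point */
  return (size_t)input_unit * input_unit * in_channel * oc_tile * col_tile * sizeof(float);
}
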
@@ -101,13 +107,15 @@ void ConvWinogardFp32(const float *input_data, const float *trans_weight, const
       float *dst_ptr = gemm_out + task_id * gemm_out_offset;
       float *tmp_col_ptr = col_buffer + task_id * col_buffer_offset;
       for (int i = 0; i < input_unit_square; ++i) {
-#if defined(ENABLE_ARM32) || defined(ENABLE_SSE)
+#ifdef ENABLE_AVX
+        RowMajor2Col6Major(src_ptr + i * C12NUM * in_channel, tmp_col_ptr, C12NUM, in_channel);
+#elif defined(ENABLE_ARM32) || defined(ENABLE_SSE)
         RowMajor2Col4Major(src_ptr + i * C12NUM * in_channel, tmp_col_ptr, C12NUM, in_channel);
 #else
         RowMajor2Col12Major(src_ptr + i * C12NUM * in_channel, tmp_col_ptr, C12NUM, in_channel);
 #endif
-        MatMulOpt(tmp_col_ptr, trans_weight + i * in_channel * oc8 * C8NUM, dst_ptr + i * C8NUM, NULL, 0, in_channel,
-                  cal_num, oc8 * C8NUM, input_unit_square, 2);
+        MatMulOpt(tmp_col_ptr, trans_weight + i * in_channel * oc_tile * col_tile, dst_ptr + i * C8NUM, NULL, 0,
+                  in_channel, cal_num, oc8 * C8NUM, input_unit_square, 2);
       }
       // step 4 : output transform
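
The AVX branch repacks each tile of up to C12NUM rows with RowMajor2Col6Major so the 6x16 micro-kernel can stream it. A plain-C illustration of the Col6 layout it is expected to produce, consistent with the ai indexing in MatMul6x16 further down (a reference sketch under that assumption, not the optimized nnacl routine; here col plays the role of deep/in_channel):

#define C6NUM 6

static void row_major_to_col6_ref(const float *src, float *dst, int row, int col) {
  for (int r = 0; r < row; ++r) {
    int r6div = r / C6NUM, r6mod = r % C6NUM;
    for (int c = 0; c < col; ++c) {
      /* 6-row panels; inside a panel the layout is [col][6] */
      dst[r6div * col * C6NUM + c * C6NUM + r6mod] = src[r * col + c];
    }
  }
}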

@@ -793,8 +793,6 @@ void MatMul12x8(const float *a, const float *b, float *dst, const float *bias, A
   return;
 }
-#ifdef ENABLE_AVX
-#ifdef WIN32
 void MatMul6x16(const float *a, const float *b, float *dst, const float *bias, ActType act_type, int deep, int row,
                 int col, int stride, int out_type) {
   if (out_type == OutType_Nhwc) {
@@ -815,11 +813,29 @@ void MatMul6x16(const float *a, const float *b, float *dst, const float *bias, A
         dst[ci] = value;
       }
     }
-  }
+  } else {
+    for (int i = 0; i < row; ++i) {
+      int dst_r_offset = i * col * stride;
+      int r6div = i / C6NUM, r6mod = i % C6NUM;
+      for (int j = 0; j < col; ++j) {
+        int b16div = j / C16NUM, b16mod = j % C16NUM;
+        int c8div = j / C8NUM, c8mod = j % C8NUM;
+        size_t ci = dst_r_offset + c8div * C8NUM * stride + c8mod;
+        float value = 0;
+        for (int d = 0; d < deep; ++d) {
+          size_t ai = r6div * deep * C6NUM + d * C6NUM + r6mod;
+          size_t bi = b16div * deep * C16NUM + d * C16NUM + b16mod;
+          value = value + a[ai] * b[bi];
+        }
+        if (bias != NULL) value += bias[j];
+        if (act_type == ActType_Relu6) value = MSMIN(6.0f, value);
+        if (act_type != ActType_No) value = MSMAX(0.0f, value);
+        dst[ci] = value;
+      }
+    }
+  }
   return;
 }
-#endif
-#endif
 void MatMul4x8(const float *a, const float *b, float *dst, const float *bias, ActType act_type, int deep, int row,
                int col, int stride, int out_type) {
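
The reference fallback added above reads the packed weights in 16-channel panels, depth-major within each panel (the bi = b16div * deep * C16NUM + d * C16NUM + b16mod term). For anyone reproducing the math on the WIN32 path, a sketch of packing a row-major [deep x col] matrix into that layout; pack_b_col16_ref is an illustrative helper, not the nnacl packer:

#include <string.h>

#define C16NUM 16

static void pack_b_col16_ref(const float *src, float *dst, int deep, int col) {
  int col_up = ((col + C16NUM - 1) / C16NUM) * C16NUM;
  memset(dst, 0, (size_t)deep * col_up * sizeof(float));  /* zero-pad the channel tail */
  for (int d = 0; d < deep; ++d) {
    for (int j = 0; j < col; ++j) {
      int b16div = j / C16NUM, b16mod = j % C16NUM;
      dst[b16div * deep * C16NUM + d * C16NUM + b16mod] = src[d * col + j];
    }
  }
}
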
@@ -862,16 +878,14 @@ void MatMulOpt(const float *a, const float *b, float *c, const float *bias, ActT
     MatmulFloatNeon32Opt(a, b, c, bias, (int)act_type, deep, row, col, stride, (int)(out_type));
   }
 #elif ENABLE_AVX
-  if (out_type == OutType_Nhwc) {
+  if (out_type == OutType_C8) {
+    MatmulFloatSse64(a, b, c, bias, (int)act_type, deep, row, col, stride, 0, 0);
+  } else {
 #ifdef WIN32
     MatMul6x16(a, b, c, bias, act_type, deep, row, col, stride, out_type);
 #else
     MatmulFloatAvxOpt(a, b, c, bias, (int)act_type, deep, row, col, stride, (int)(out_type));
 #endif
-  } else if (out_type == OutType_C8) {
-    MatmulFloatSse64(a, b, c, bias, (int)act_type, deep, row, col, stride, 0, 0);
-  } else {
-    MatmulFloatSse64Opt(a, b, c, bias, (int)act_type, deep, row, col, stride, (int)(out_type));
   }
 #elif ENABLE_SSE
   if (out_type == OutType_C8) {

@@ -49,8 +49,12 @@ int ConvolutionWinogradCPUKernel::InitWeightBias() {
   conv_param_->output_channel_ = out_channel;
   int oc4 = UP_DIV(out_channel, C4NUM);
+#ifdef ENABLE_AVX
+  const int oc_block = C16NUM;
+#else
   const int oc_block = C8NUM;
-  int oc_block_num = UP_DIV(out_channel, C8NUM);
+#endif
+  int oc_block_num = UP_DIV(out_channel, oc_block);
   // set data
   auto trans_matrix_data_size = input_unit_ * input_unit_ * in_channel * oc_block_num * oc_block * sizeof(float);
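
For a concrete sense of the padding this implies (numbers are illustrative, not taken from the patch): with input_unit_ = 8, in_channel = 16 and out_channel = 20, the AVX build rounds the output channels up to 2 blocks of 16, the other builds to 3 blocks of 8:

#include <stdio.h>

#define UP_DIV(x, y) (((x) + (y) - (1)) / (y))

int main(void) {
  int input_unit = 8, in_channel = 16, out_channel = 20;          /* illustrative only */
  int avx_blocks = UP_DIV(out_channel, 16);                       /* 2 -> padded to 32 channels */
  int c8_blocks = UP_DIV(out_channel, 8);                         /* 3 -> padded to 24 channels */
  printf("AVX: %d blocks, %zu bytes\n", avx_blocks,
         (size_t)input_unit * input_unit * in_channel * avx_blocks * 16 * sizeof(float));  /* 131072 */
  printf("C8:  %d blocks, %zu bytes\n", c8_blocks,
         (size_t)input_unit * input_unit * in_channel * c8_blocks * 8 * sizeof(float));    /* 98304 */
  return 0;
}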
