!10227 [ms][lite][cpu] win x86 avx optimize

From: @lzkcode
Reviewed-by: 
Signed-off-by:
pull/10227/MERGE
mindspore-ci-bot 4 years ago committed by Gitee
commit 3f0aeaa8fc

@ -1,15 +1,16 @@
#ifdef ENABLE_AVX
#ifndef WIN32
.text
.align 4
.global ConvDwFp32Avx3x3
#ifndef __APPLE__
#ifndef WIN32
.type ConvDwFp32Avx3x3, %function
#endif
#endif
// void ConvDwFp32Avx3x3(float *output, float **input, const float *weights, const float *bias, int channels, int output_width,
// size_t input_stride, size_t relu)
// void ConvDwFp32Avx3x3(float *output, float **input, const float *weights, const float *bias, size_t channels, size_t output_width,
// size_t input_stride, size_t relu, size_t relu6)
// in linux x64 platform:
// rdi: output
// rsi: input
// rdx: weights
@ -20,6 +21,16 @@
// 16: relu
// 24: relu6
// in win x64 platform: "shadow space" needs to be opened up for first four parameters ==> 32 bytes
// rcx: output
// rdx: input
// r8: weights
// r9: bias
// 40: channels
// 48: output_width
// 56: input_stride
// 64: relu
// 72: relu6
ConvDwFp32Avx3x3:
pushq %r15
pushq %r14
@ -27,14 +38,34 @@ ConvDwFp32Avx3x3:
pushq %r12
pushq %rbx
pushq %rbp
pushq %r9
pushq %r8
pushq %rcx
pushq %rdx
pushq %rsi
pushq %rdi
pushq %r9 // -56
pushq %r8 // -64
pushq %rcx // -72
pushq %rdx // -80
pushq %rsi // -88
pushq %rdi // -96
addq $96, %rsp
#ifdef WIN32
movq %rcx, %rdi
movq %rdx, %rsi
movq %r8, %rdx
movq %r9, %rcx
movq 40(%rsp), %r8 // channels
movq 48(%rsp), %r9 // output_width
mov %rdx, -80(%rsp)
mov %rcx, -72(%rsp)
mov %r9, -56(%rsp)
mov %r8, -64(%rsp)
movq 56(%rsp), %rbp // input_stride
movq %rbp, 8(%rsp)
movq 64(%rsp), %rbp // relu
movq %rbp, 16(%rsp)
movq 72(%rsp), %rbp // relu6
movq %rbp, 24(%rsp)
#endif
movq $6, %rax
vcvtsi2ss %rax, %xmm15, %xmm15
vshufps $0, %xmm15, %xmm15, %xmm15
@ -270,4 +301,3 @@ End:
popq %r15
retq
#endif
#endif

@ -1,14 +1,16 @@
#ifdef ENABLE_AVX
#ifndef WIN32
.text
.align 4
.global MatmulFloatAvxOpt
#ifndef __APPLE__
#ifndef WIN32
.type MatmulFloatAvxOpt, %function
#endif
#endif
// void MatmulFloatNeon32Opt(const float *a, const float *b, float *c, const float *bias, int act_type, int depth
// void MatmulFloatAvxOpt(const float *a, const float *b, float *c, const float *bias, int act_type, int depth
// int row, int col, size_t stride, size_t writeMode)
// parameters pass in Linux x64 platform:
// rdi: a
// rsi: b
// rdx: c
@ -20,6 +22,18 @@
// 24: stride
// 32: writeNhwc/writeWino
// parameters pass in win x64 platform: "shadow space" needs to be opened up for first four parameters ==> 32 bytes
// rcx: a
// rdx: b
// r8: c
// r9: bias
// 40: act_type
// 48: depth
// 56: row
// 64: col
// 72: stride
// 80: writeMode
MatmulFloatAvxOpt:
// rbx, rsp, rbp, r12-r15 must be saved according to x86 calling convention
pushq %r15
@ -28,14 +42,37 @@ MatmulFloatAvxOpt:
pushq %r12
pushq %rbx
pushq %rbp
pushq %r9
pushq %r8
pushq %rcx
pushq %rdx
pushq %rsi
pushq %rdi
addq $96, %rsp
pushq %r9 // -56
pushq %r8 // -64
pushq %rcx // -72
pushq %rdx // -80
pushq %rsi // -88
pushq %rdi // -96
pushq %rsi // -104 rsi
pushq %rdi // -112 rdi
addq $112, %rsp
#ifdef WIN32
movq %rcx, %rdi
movq %rdx, %rsi
movq %r8, %rdx
movq %r9, %rcx
movq 40(%rsp), %r8 // act_type
movq 48(%rsp), %r9 // depth
movq %r9, -56(%rsp) // r9
movq %rcx, -72(%rsp) // rcx
movq %rdx, -80(%rsp) // rdx
movq %rsi, -88(%rsp) // rsi
movq %rdi, -96(%rsp) // rdi
movq 56(%rsp), %rbp // row
movq %rbp, 8(%rsp)
movq 64(%rsp), %rbp // col
movq %rbp, 16(%rsp)
movq 72(%rsp), %rbp // stride
movq %rbp, 24(%rsp)
movq 80(%rsp), %rbp // writeMode
movq %rbp, 32(%rsp)
#endif
movq 8(%rsp), %rbp
movq 16(%rsp), %rbx
movq 24(%rsp), %r10
@ -926,10 +963,12 @@ LoopRow:
jmp LoopRow
LoopRowEnd:
subq $96, %rsp
subq $112, %rsp
popq %rdi
popq %rsi
popq %rdx
popq %rdx
popq %rdx
popq %rcx
popq %r8
popq %r9
@ -941,4 +980,3 @@ LoopRowEnd:
popq %r15
retq
#endif
#endif

@ -681,47 +681,6 @@ void ConvDwFp32IndirectRow(float *output, float **input, const float *weights, c
#endif
#ifdef ENABLE_AVX
#ifdef WIN32
// Scalar C fallback for the depthwise-convolution row kernel, compiled on
// WIN32 where the AVX assembly implementation is not built.
//
// For each of `output_width` output pixels it gathers `kernel` input-row
// pointers from the indirection buffer, seeds the per-channel accumulators
// with `bias`, and accumulates in[k][c] * w[c + k*C8NUM] over all taps.
//
// Parameters:
//   output       - destination, advanced by `channels` floats per pixel
//   input        - indirection buffer of input-row pointers; advanced by
//                  `input_stride` pointers per output pixel
//   weights      - tap weights, laid out in C8NUM-channel blocks per tap
//   bias         - per-channel bias, `channels` floats
//   channels     - number of channels per pixel
//   output_width - number of output pixels; must be > 0 (do/while loop)
//   input_stride - stride (in pointers, not bytes) through `input`
//   relu/relu6   - activation flags, applied per pixel via ReluFp32C8/Relu6Fp32C8
//   kernel       - number of taps (e.g. 9 for 3x3); sizes a VLA
void ConvDwFp32IndirectRow(float *output, float **input, const float *weights, const float *bias, int channels,
                           int output_width, int input_stride, bool relu, bool relu6, int kernel) {
  do {
    // Snapshot this pixel's `kernel` input-row pointers (variable-length array).
    float *in[kernel];
    for (int k = 0; k < kernel; k++) {
      in[k] = input[k];
    }
    // Advance the indirection buffer to the next pixel's pointer group.
    input = input + input_stride;
    size_t c = channels;
    const float *w = weights;
    float *out = output;
    // Seed accumulators with the bias before adding the weighted taps.
    memcpy(out, bias, channels * sizeof(float));
    // Main loop: process channels in blocks of C8NUM.
    for (; c >= C8NUM; c -= C8NUM) {
      for (int i = 0; i < C8NUM; i++) {
        for (int k = 0; k < kernel; k++) {
          out[i] += in[k][i] * w[i + k * C8NUM];
        }
      }
      // Weights for one pixel are laid out as `kernel` consecutive
      // C8NUM-wide blocks, so step a full kernel's worth per channel block.
      w += kernel * C8NUM;
      out += C8NUM;
      for (int k = 0; k < kernel; k++) {
        in[k] += C8NUM;
      }
    }
    // Tail: the remaining c < C8NUM channels.
    // NOTE(review): `c` is size_t while `i` is int — signed/unsigned
    // comparison; fine as long as channels >= 0, but worth confirming.
    for (int i = 0; i < c; i++) {
      for (int k = 0; k < kernel; k++) {
        out[i] += in[k][i] * w[i + k * C8NUM];
      }
    }
    // Activations are applied in place over the whole pixel's channels.
    if (relu) {
      ReluFp32C8(output, output, channels);
    }
    if (relu6) {
      Relu6Fp32C8(output, output, channels);
    }
    output += channels;
  } while (--output_width != 0);
}
#else
void ConvDwFp32IndirectRow(float *output, float **input, const float *weights, const float *bias, int channels,
int output_width, int input_stride, bool relu, bool relu6, int kernel) {
if (kernel == 9) {
@ -729,7 +688,6 @@ void ConvDwFp32IndirectRow(float *output, float **input, const float *weights, c
}
}
#endif
#endif
void ConvDwIndirection(float *output_data, float **indirect_buffer, const float *weight_data, const float *bias_data,
float *zero_ptr, const ConvParameter *conv_param, int task_id) {

@ -67,10 +67,8 @@ void ConvDwFp32Indirect5x5(float *output, float **input, const float *weights, c
#endif
#ifdef ENABLE_AVX
#ifndef WIN32
void ConvDwFp32Avx3x3(float *output, float **input, const float *weights, const float *bias, int channels,
int output_width, size_t input_stride, size_t relu, size_t relu6);
#endif
void ConvDwFp32Avx3x3(float *output, float **input, const float *weights, const float *bias, size_t channels,
size_t output_width, size_t input_stride, size_t relu, size_t relu6);
#endif
void ConvDwFp32IndirectRow(float *output, float **input, const float *weights, const float *bias, int channels,

@ -883,11 +883,7 @@ void MatMulOpt(const float *a, const float *b, float *c, const float *bias, ActT
if (out_type == OutType_C8) {
MatmulFloatSse64(a, b, c, bias, (int)act_type, deep, row, col, stride, 0, 0);
} else {
#ifdef WIN32
MatMul6x16(a, b, c, bias, act_type, deep, row, col, stride, out_type);
#else
MatmulFloatAvxOpt(a, b, c, bias, (int)act_type, deep, row, col, stride, (int)(out_type));
#endif
MatmulFloatAvxOpt(a, b, c, bias, (size_t)act_type, deep, row, col, stride, (size_t)(out_type));
}
#elif ENABLE_SSE
if (out_type == OutType_C8) {

@ -62,8 +62,8 @@ void MatmulFloatSse64(const float *a, const float *b, float *c, const float *bia
void MatmulFloatSse64Opt(const float *a, const float *b, float *c, const float *bias, int act_type, int depth, int row,
int col, int stride, int write_mode);
#ifdef ENABLE_AVX
void MatmulFloatAvxOpt(const float *a, const float *b, float *c, const float *bias, int act_type, int depth, int row,
int col, int stride, int write_mode);
void MatmulFloatAvxOpt(const float *a, const float *b, float *c, const float *bias, size_t act_type, size_t depth,
size_t row, size_t col, size_t stride, size_t write_mode);
#endif
#endif

Loading…
Cancel
Save