diff --git a/mindspore/lite/nnacl/assembly/avx/ConvDwFp32Avx3x3.S b/mindspore/lite/nnacl/assembly/avx/ConvDwFp32Avx3x3.S index 832df26aa9..dc78bdd5d9 100644 --- a/mindspore/lite/nnacl/assembly/avx/ConvDwFp32Avx3x3.S +++ b/mindspore/lite/nnacl/assembly/avx/ConvDwFp32Avx3x3.S @@ -1,15 +1,16 @@ #ifdef ENABLE_AVX -#ifndef WIN32 - .text .align 4 .global ConvDwFp32Avx3x3 #ifndef __APPLE__ +#ifndef WIN32 .type ConvDwFp32Avx3x3, %function #endif +#endif -// void ConvDwFp32Avx3x3(float *output, float **input, const float *weights, const float *bias, int channels, int output_width, -// size_t input_stride, size_t relu) +// void ConvDwFp32Avx3x3(float *output, float **input, const float *weights, const float *bias, size_t channels, size_t output_width, +// size_t input_stride, size_t relu, size_t relu6) +// in linux x64 platform: // rdi: output // rsi: input // rdx: weights @@ -20,6 +21,16 @@ // 16: relu // 24: relu6 +// in win x64 platform: "shadow space" needs to be opened up for first four parameters ==> 32 bytes +// rcx: output +// rdx: input +// r8: weights +// r9: bias +// 40: channels +// 48: output_width +// 56: input_stride +// 64: relu +// 72: relu6 ConvDwFp32Avx3x3: pushq %r15 pushq %r14 @@ -27,14 +38,34 @@ ConvDwFp32Avx3x3: pushq %r12 pushq %rbx pushq %rbp - pushq %r9 - pushq %r8 - pushq %rcx - pushq %rdx - pushq %rsi - pushq %rdi + pushq %r9 // -56 + pushq %r8 // -64 + pushq %rcx // -72 + pushq %rdx // -80 + pushq %rsi // -88 + pushq %rdi // -96 addq $96, %rsp +#ifdef WIN32 + movq %rcx, %rdi + movq %rdx, %rsi + movq %r8, %rdx + movq %r9, %rcx + movq 40(%rsp), %r8 // channels + movq 48(%rsp), %r9 // output_width + + mov %rdx, -80(%rsp) + mov %rcx, -72(%rsp) + mov %r9, -56(%rsp) + mov %r8, -64(%rsp) + movq 56(%rsp), %rbp // input_stride + movq %rbp, 8(%rsp) + movq 64(%rsp), %rbp // relu + movq %rbp, 16(%rsp) + movq 72(%rsp), %rbp // relu6 + movq %rbp, 24(%rsp) +#endif + movq $6, %rax vcvtsi2ss %rax, %xmm15, %xmm15 vshufps $0, %xmm15, %xmm15, %xmm15 @@ -270,4 
+301,3 @@ End: popq %r15 retq #endif -#endif diff --git a/mindspore/lite/nnacl/assembly/avx/MatmulAvx.S b/mindspore/lite/nnacl/assembly/avx/MatmulAvx.S index 9e39fe8fd5..98d70f67d4 100644 --- a/mindspore/lite/nnacl/assembly/avx/MatmulAvx.S +++ b/mindspore/lite/nnacl/assembly/avx/MatmulAvx.S @@ -1,14 +1,16 @@ #ifdef ENABLE_AVX -#ifndef WIN32 .text .align 4 .global MatmulFloatAvxOpt #ifndef __APPLE__ +#ifndef WIN32 .type MatmulFloatAvxOpt, %function #endif +#endif -// void MatmulFloatNeon32Opt(const float *a, const float *b, float *c, const float *bias, int act_type, int depth +// void MatmulFloatAvxOpt(const float *a, const float *b, float *c, const float *bias, int act_type, int depth // int row, int col, size_t stride, size_t writeMode) +// parameters pass in Linux x64 platform: // rdi: a // rsi: b // rdx: c @@ -20,6 +22,18 @@ // 24: stride // 32: writeNhwc/writeWino +// parameters pass in win x64 platform: "shadow space" needs to be opened up for first four parameters ==> 32 bytes +// rcx: a +// rdx: b +// r8: c +// r9: bias +// 40: act_type +// 48: depth +// 56: row +// 64: col +// 72: stride +// 80: writeMode + MatmulFloatAvxOpt: // rbx, rsp, rbp, r12-r15 must be saved according to x86 calling convention pushq %r15 @@ -28,14 +42,37 @@ MatmulFloatAvxOpt: pushq %r12 pushq %rbx pushq %rbp - pushq %r9 - pushq %r8 - pushq %rcx - pushq %rdx - pushq %rsi - pushq %rdi - addq $96, %rsp + pushq %r9 // -56 + pushq %r8 // -64 + pushq %rcx // -72 + pushq %rdx // -80 + pushq %rsi // -88 + pushq %rdi // -96 + pushq %rsi // -104 rsi + pushq %rdi // -112 rdi + addq $112, %rsp +#ifdef WIN32 + movq %rcx, %rdi + movq %rdx, %rsi + movq %r8, %rdx + movq %r9, %rcx + movq 40(%rsp), %r8 // act_type + movq 48(%rsp), %r9 // depth + movq %r9, -56(%rsp) // r9 + movq %rcx, -72(%rsp) // rcx + movq %rdx, -80(%rsp) // rdx + movq %rsi, -88(%rsp) // rsi + movq %rdi, -96(%rsp) // rdi + movq 56(%rsp), %rbp // row + movq %rbp, 8(%rsp) + movq 64(%rsp), %rbp // col + movq %rbp, 16(%rsp) + movq 
72(%rsp), %rbp // stride + movq %rbp, 24(%rsp) + movq 80(%rsp), %rbp // writeMode + movq %rbp, 32(%rsp) +#endif movq 8(%rsp), %rbp movq 16(%rsp), %rbx movq 24(%rsp), %r10 @@ -926,10 +963,12 @@ LoopRow: jmp LoopRow LoopRowEnd: - subq $96, %rsp + subq $112, %rsp popq %rdi popq %rsi popq %rdx + popq %rdx + popq %rdx popq %rcx popq %r8 popq %r9 @@ -941,4 +980,3 @@ LoopRowEnd: popq %r15 retq #endif -#endif diff --git a/mindspore/lite/nnacl/fp32/conv_depthwise_fp32.c b/mindspore/lite/nnacl/fp32/conv_depthwise_fp32.c index 77ad4b9140..a5f7847ee3 100644 --- a/mindspore/lite/nnacl/fp32/conv_depthwise_fp32.c +++ b/mindspore/lite/nnacl/fp32/conv_depthwise_fp32.c @@ -681,47 +681,6 @@ void ConvDwFp32IndirectRow(float *output, float **input, const float *weights, c #endif #ifdef ENABLE_AVX -#ifdef WIN32 -void ConvDwFp32IndirectRow(float *output, float **input, const float *weights, const float *bias, int channels, - int output_width, int input_stride, bool relu, bool relu6, int kernel) { - do { - float *in[kernel]; - for (int k = 0; k < kernel; k++) { - in[k] = input[k]; - } - input = input + input_stride; - - size_t c = channels; - const float *w = weights; - float *out = output; - memcpy(out, bias, channels * sizeof(float)); - for (; c >= C8NUM; c -= C8NUM) { - for (int i = 0; i < C8NUM; i++) { - for (int k = 0; k < kernel; k++) { - out[i] += in[k][i] * w[i + k * C8NUM]; - } - } - w += kernel * C8NUM; - out += C8NUM; - for (int k = 0; k < kernel; k++) { - in[k] += C8NUM; - } - } - for (int i = 0; i < c; i++) { - for (int k = 0; k < kernel; k++) { - out[i] += in[k][i] * w[i + k * C8NUM]; - } - } - if (relu) { - ReluFp32C8(output, output, channels); - } - if (relu6) { - Relu6Fp32C8(output, output, channels); - } - output += channels; - } while (--output_width != 0); -} -#else void ConvDwFp32IndirectRow(float *output, float **input, const float *weights, const float *bias, int channels, int output_width, int input_stride, bool relu, bool relu6, int kernel) { if (kernel == 9) { @@ 
-729,7 +688,6 @@ void ConvDwFp32IndirectRow(float *output, float **input, const float *weights, c } } #endif -#endif void ConvDwIndirection(float *output_data, float **indirect_buffer, const float *weight_data, const float *bias_data, float *zero_ptr, const ConvParameter *conv_param, int task_id) { diff --git a/mindspore/lite/nnacl/fp32/conv_depthwise_fp32.h b/mindspore/lite/nnacl/fp32/conv_depthwise_fp32.h index 3ea0e52dfb..6efe5e7b6e 100644 --- a/mindspore/lite/nnacl/fp32/conv_depthwise_fp32.h +++ b/mindspore/lite/nnacl/fp32/conv_depthwise_fp32.h @@ -67,10 +67,8 @@ void ConvDwFp32Indirect5x5(float *output, float **input, const float *weights, c #endif #ifdef ENABLE_AVX -#ifndef WIN32 -void ConvDwFp32Avx3x3(float *output, float **input, const float *weights, const float *bias, int channels, - int output_width, size_t input_stride, size_t relu, size_t relu6); -#endif +void ConvDwFp32Avx3x3(float *output, float **input, const float *weights, const float *bias, size_t channels, + size_t output_width, size_t input_stride, size_t relu, size_t relu6); #endif void ConvDwFp32IndirectRow(float *output, float **input, const float *weights, const float *bias, int channels, diff --git a/mindspore/lite/nnacl/fp32/matmul_fp32.c b/mindspore/lite/nnacl/fp32/matmul_fp32.c index 2e9efc96a4..4a4c85ba4e 100644 --- a/mindspore/lite/nnacl/fp32/matmul_fp32.c +++ b/mindspore/lite/nnacl/fp32/matmul_fp32.c @@ -883,11 +883,7 @@ void MatMulOpt(const float *a, const float *b, float *c, const float *bias, ActT if (out_type == OutType_C8) { MatmulFloatSse64(a, b, c, bias, (int)act_type, deep, row, col, stride, 0, 0); } else { -#ifdef WIN32 - MatMul6x16(a, b, c, bias, act_type, deep, row, col, stride, out_type); -#else - MatmulFloatAvxOpt(a, b, c, bias, (int)act_type, deep, row, col, stride, (int)(out_type)); -#endif + MatmulFloatAvxOpt(a, b, c, bias, (size_t)act_type, deep, row, col, stride, (size_t)(out_type)); } #elif ENABLE_SSE if (out_type == OutType_C8) { diff --git 
a/mindspore/lite/nnacl/fp32/matmul_fp32.h b/mindspore/lite/nnacl/fp32/matmul_fp32.h index 25651c8e66..b2864b7921 100644 --- a/mindspore/lite/nnacl/fp32/matmul_fp32.h +++ b/mindspore/lite/nnacl/fp32/matmul_fp32.h @@ -62,8 +62,8 @@ void MatmulFloatSse64(const float *a, const float *b, float *c, const float *bia void MatmulFloatSse64Opt(const float *a, const float *b, float *c, const float *bias, int act_type, int depth, int row, int col, int stride, int write_mode); #ifdef ENABLE_AVX -void MatmulFloatAvxOpt(const float *a, const float *b, float *c, const float *bias, int act_type, int depth, int row, - int col, int stride, int write_mode); +void MatmulFloatAvxOpt(const float *a, const float *b, float *c, const float *bias, size_t act_type, size_t depth, + size_t row, size_t col, size_t stride, size_t write_mode); #endif #endif