diff --git a/mindspore/lite/nnacl/x86_64_sse/ConvDwFp32IndirectRow.c b/mindspore/lite/nnacl/x86_64_sse/ConvDwFp32IndirectRow.c index 924c21b93f..5a11b441f8 100644 --- a/mindspore/lite/nnacl/x86_64_sse/ConvDwFp32IndirectRow.c +++ b/mindspore/lite/nnacl/x86_64_sse/ConvDwFp32IndirectRow.c @@ -49,7 +49,7 @@ void ConvDwFp32Avx5x5(float *output, float **input, const float *weights, const __m256 in4 = _mm256_loadu_ps(in[k + 3]); __m256 w4 = _mm256_loadu_ps(w + 24); out1 = _mm256_fmadd_ps(in3, w3, out1); - __m256 in5 = _mm256_loadu_ps(in[k + 8]); + __m256 in5 = _mm256_loadu_ps(in[k + 4]); __m256 w5 = _mm256_loadu_ps(w + 32); out1 = _mm256_fmadd_ps(in4, w4, out1); w += 40; @@ -68,7 +68,10 @@ void ConvDwFp32Avx5x5(float *output, float **input, const float *weights, const __m256 zero = _mm256_setzero_ps(); out1 = _mm256_max_ps(out1, zero); } - if (c == C8NUM) { + if (c > C8NUM || c8_mod == 0) { + _mm256_storeu_ps(output, out1); + output += C8NUM; + } else { __m128 tmp; switch (c8_mod) { case 1: @@ -105,10 +108,7 @@ void ConvDwFp32Avx5x5(float *output, float **input, const float *weights, const _mm256_storeu_ps(output, out1); break; } - output += c8_mod == 0 ? C8NUM : c8_mod; - } else { - _mm256_storeu_ps(output, out1); - output += C8NUM; + output += c8_mod; } } }