parent
7ea0a14795
commit
1cd66fb729
@ -0,0 +1,116 @@
|
||||
/**
|
||||
* Copyright 2020 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#ifdef ENABLE_AVX
|
||||
|
||||
#include <x86intrin.h>
|
||||
#include "nnacl/fp32/conv_depthwise_fp32.h"
|
||||
|
||||
void ConvDwFp32Avx5x5(float *output, float **input, const float *weights, const float *bias, size_t channels,
|
||||
size_t output_width, size_t input_stride, size_t relu, size_t relu6) {
|
||||
input_stride /= sizeof(float *);
|
||||
size_t c8 = UP_DIV(channels, C8NUM) * C8NUM;
|
||||
size_t c8_mod = channels % C8NUM;
|
||||
int kernel = 25;
|
||||
for (int i = 0; i < output_width; ++i) {
|
||||
float *in[kernel];
|
||||
for (int k = 0; k < kernel; k++) {
|
||||
in[k] = input[k];
|
||||
}
|
||||
input += input_stride;
|
||||
size_t c = c8;
|
||||
const float *w = weights;
|
||||
const float *bias1 = bias;
|
||||
for (; c >= C8NUM; c -= C8NUM) {
|
||||
__m256 out1 = _mm256_loadu_ps(bias1);
|
||||
bias1 += 8;
|
||||
for (int k = 0; k < kernel; k += 5) {
|
||||
__m256 in1 = _mm256_loadu_ps(in[k]);
|
||||
__m256 w1 = _mm256_loadu_ps(w);
|
||||
__m256 in2 = _mm256_loadu_ps(in[k + 1]);
|
||||
__m256 w2 = _mm256_loadu_ps(w + 8);
|
||||
out1 = _mm256_fmadd_ps(in1, w1, out1);
|
||||
__m256 in3 = _mm256_loadu_ps(in[k + 2]);
|
||||
__m256 w3 = _mm256_loadu_ps(w + 16);
|
||||
out1 = _mm256_fmadd_ps(in2, w2, out1);
|
||||
__m256 in4 = _mm256_loadu_ps(in[k + 3]);
|
||||
__m256 w4 = _mm256_loadu_ps(w + 24);
|
||||
out1 = _mm256_fmadd_ps(in3, w3, out1);
|
||||
__m256 in5 = _mm256_loadu_ps(in[k + 8]);
|
||||
__m256 w5 = _mm256_loadu_ps(w + 32);
|
||||
out1 = _mm256_fmadd_ps(in4, w4, out1);
|
||||
w += 40;
|
||||
in[k] += C8NUM;
|
||||
in[k + 1] += C8NUM;
|
||||
in[k + 2] += C8NUM;
|
||||
in[k + 3] += C8NUM;
|
||||
in[k + 4] += C8NUM;
|
||||
out1 = _mm256_fmadd_ps(in5, w5, out1);
|
||||
}
|
||||
if (relu6 != 0) {
|
||||
__m256 relu6_data = _mm256_set1_ps(6.0);
|
||||
out1 = _mm256_min_ps(out1, relu6_data);
|
||||
}
|
||||
if (relu != 0 || relu6 != 0) {
|
||||
__m256 zero = _mm256_setzero_ps();
|
||||
out1 = _mm256_max_ps(out1, zero);
|
||||
}
|
||||
if (c == C8NUM) {
|
||||
__m128 tmp;
|
||||
switch (c8_mod) {
|
||||
case 1:
|
||||
_mm_store_ss(output, _mm256_castps256_ps128(out1));
|
||||
break;
|
||||
case 2:
|
||||
_mm_storel_pi((__m64 *)output, _mm256_castps256_ps128(out1));
|
||||
break;
|
||||
case 3:
|
||||
tmp = _mm256_castps256_ps128(out1);
|
||||
_mm_storel_pi((__m64 *)output, tmp);
|
||||
tmp = _mm_unpackhi_ps(tmp, tmp);
|
||||
_mm_store_ss(output + 2, tmp);
|
||||
break;
|
||||
case 4:
|
||||
_mm_storeu_ps(output, _mm256_castps256_ps128(out1));
|
||||
break;
|
||||
case 5:
|
||||
_mm_storeu_ps(output, _mm256_castps256_ps128(out1));
|
||||
_mm_store_ss(output + 4, _mm256_extractf128_ps(out1, 1));
|
||||
break;
|
||||
case 6:
|
||||
_mm_storeu_ps(output, _mm256_castps256_ps128(out1));
|
||||
_mm_storel_pi((__m64 *)(output + 4), _mm256_extractf128_ps(out1, 1));
|
||||
break;
|
||||
case 7:
|
||||
_mm_storeu_ps(output, _mm256_castps256_ps128(out1));
|
||||
tmp = _mm256_extractf128_ps(out1, 1);
|
||||
_mm_storel_pi((__m64 *)(output + 4), tmp);
|
||||
tmp = _mm_unpackhi_ps(tmp, tmp);
|
||||
_mm_store_ss(output + 6, tmp);
|
||||
break;
|
||||
default:
|
||||
_mm256_storeu_ps(output, out1);
|
||||
break;
|
||||
}
|
||||
output += c8_mod == 0 ? C8NUM : c8_mod;
|
||||
} else {
|
||||
_mm256_storeu_ps(output, out1);
|
||||
output += C8NUM;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
#endif
|
Loading…
Reference in new issue