@@ -17,6 +17,7 @@
 #ifdef ENABLE_SSE
 #include <x86intrin.h>
 #include "nnacl/fp32/conv_depthwise_fp32.h"
+#include "nnacl/intrinsics/sse/sse_common.h"
 
 void ConvDwFp32Border(float *dst, const float *src, const float *weight, const float *bias, size_t height, size_t width,
                       size_t in_kh_step, size_t in_kw_step, size_t kernel_w_step, size_t relu, size_t relu6) {
@@ -123,18 +124,16 @@ void ConvDwFp32Center(float *dst, const float *src, const float *weight, const f
     int c2 = DOWN_DIV(width, C2NUM) * C2NUM;
     int c1 = 0;
     // c4 loop
-    for (; c1 < c4; c1 += C4NUM) {
-      const float *src_kh = src_w;
-      const float *weight_kh = weight;
+    for (; c1 < c4; c1 += C4NUM, dst_w += C4NUM * block_channel, src_w += C4NUM * in_sw_step) {
+      const float *src_kh = src_w, *weight_kh = weight;
       __m128 dst_w_ma1 = _mm_setzero_ps();
       __m128 dst_w_ma2 = _mm_setzero_ps();
       __m128 dst_w_ma3 = _mm_setzero_ps();
       __m128 dst_w_ma4 = _mm_setzero_ps();
-      for (int kh = 0; kh < kernel_h; kh++) {
-        const float *src_kw = src_kh;
-        const float *weight_kw = weight_kh;
-        for (int kw = 0; kw < kernel_w; kw++) {
+      for (int kh = 0; kh < kernel_h; kh++, src_kh += in_kh_step, weight_kh += kernel_w * C4NUM) {
+        const float *src_kw = src_kh, *weight_kw = weight_kh;
+        for (int kw = 0; kw < kernel_w; kw++, src_kw += in_kw_step, weight_kw += C4NUM) {
           __m128 src_kw_ma1 = _mm_loadu_ps(src_kw);
           __m128 weight_kw_ma1 = _mm_loadu_ps(weight_kw);
           __m128 tmp_ma1 = _mm_mul_ps(src_kw_ma1, weight_kw_ma1);
           dst_w_ma1 = _mm_add_ps(dst_w_ma1, tmp_ma1);
@@ -154,13 +153,9 @@ void ConvDwFp32Center(float *dst, const float *src, const float *weight, const f
           __m128 weight_kw_ma4 = _mm_loadu_ps(weight_kw);
           __m128 tmp_ma4 = _mm_mul_ps(src_kw_ma4, weight_kw_ma4);
           dst_w_ma4 = _mm_add_ps(dst_w_ma4, tmp_ma4);
-          src_kw += in_kw_step;
-          weight_kw += C4NUM;
         } // kernel_w loop
-        src_kh += in_kh_step;
-        weight_kh += kernel_w * C4NUM;
       } // kernel_h loop
 
       // add bias relu
       __m128 bias_ma = _mm_loadu_ps(bias);
       dst_w_ma1 = _mm_add_ps(dst_w_ma1, bias_ma);
@@ -168,39 +163,23 @@ void ConvDwFp32Center(float *dst, const float *src, const float *weight, const f
       dst_w_ma3 = _mm_add_ps(dst_w_ma3, bias_ma);
       dst_w_ma4 = _mm_add_ps(dst_w_ma4, bias_ma);
 
-      __m128 zero_ma = _mm_setzero_ps();
-      if (relu || relu6) {
-        dst_w_ma1 = _mm_max_ps(zero_ma, dst_w_ma1);
-        dst_w_ma2 = _mm_max_ps(zero_ma, dst_w_ma2);
-        dst_w_ma3 = _mm_max_ps(zero_ma, dst_w_ma3);
-        dst_w_ma4 = _mm_max_ps(zero_ma, dst_w_ma4);
-        if (relu6) {
-          __m128 const_ma = _mm_set_ps(6.0f, 6.0f, 6.0f, 6.0f);
-          dst_w_ma1 = _mm_min_ps(const_ma, dst_w_ma1);
-          dst_w_ma2 = _mm_min_ps(const_ma, dst_w_ma2);
-          dst_w_ma3 = _mm_min_ps(const_ma, dst_w_ma3);
-          dst_w_ma4 = _mm_min_ps(const_ma, dst_w_ma4);
-        }
-      }
+      ActBlock4(&dst_w_ma1, &dst_w_ma2, &dst_w_ma3, &dst_w_ma4, relu, relu6);
+
       _mm_storeu_ps(dst_w, dst_w_ma1);
       _mm_storeu_ps(dst_w + block_channel, dst_w_ma2);
       _mm_storeu_ps(dst_w + 2 * block_channel, dst_w_ma3);
       _mm_storeu_ps(dst_w + 3 * block_channel, dst_w_ma4);
-      dst_w += C4NUM * block_channel;
-      src_w += C4NUM * in_sw_step;
     } // dst_width loop
 
     // c2 loop
-    for (; c1 < c2; c1 += C2NUM) {
-      const float *src_kh = src_w;
-      const float *weight_kh = weight;
+    for (; c1 < c2; c1 += C2NUM, dst_w += C2NUM * block_channel, src_w += C2NUM * in_sw_step) {
+      const float *src_kh = src_w, *weight_kh = weight;
       __m128 dst_w_ma1 = _mm_setzero_ps();
       __m128 dst_w_ma2 = _mm_setzero_ps();
-      for (int kh = 0; kh < kernel_h; kh++) {
-        const float *src_kw = src_kh;
-        const float *weight_kw = weight_kh;
-        for (int kw = 0; kw < kernel_w; kw++) {
+      for (int kh = 0; kh < kernel_h; kh++, src_kh += in_kh_step, weight_kh += kernel_w * C4NUM) {
+        const float *src_kw = src_kh, *weight_kw = weight_kh;
+        for (int kw = 0; kw < kernel_w; kw++, src_kw += in_kw_step, weight_kw += C4NUM) {
           __m128 src_kw_ma1 = _mm_loadu_ps(src_kw);
           __m128 weight_kw_ma1 = _mm_loadu_ps(weight_kw);
           __m128 tmp_ma1 = _mm_mul_ps(src_kw_ma1, weight_kw_ma1);
           dst_w_ma1 = _mm_add_ps(dst_w_ma1, tmp_ma1);
@@ -210,68 +189,38 @@ void ConvDwFp32Center(float *dst, const float *src, const float *weight, const f
           __m128 weight_kw_ma2 = _mm_loadu_ps(weight_kw);
           __m128 tmp_ma2 = _mm_mul_ps(src_kw_ma2, weight_kw_ma2);
           dst_w_ma2 = _mm_add_ps(dst_w_ma2, tmp_ma2);
-          src_kw += in_kw_step;
-          weight_kw += C4NUM;
         } // kernel_w loop
-        src_kh += in_kh_step;
-        weight_kh += kernel_w * C4NUM;
       } // kernel_h loop
 
       // add bias relu
       __m128 bias_ma = _mm_loadu_ps(bias);
       dst_w_ma1 = _mm_add_ps(dst_w_ma1, bias_ma);
       dst_w_ma2 = _mm_add_ps(dst_w_ma2, bias_ma);
 
-      __m128 zero_ma = _mm_setzero_ps();
-      if (relu || relu6) {
-        dst_w_ma1 = _mm_max_ps(zero_ma, dst_w_ma1);
-        dst_w_ma2 = _mm_max_ps(zero_ma, dst_w_ma2);
-        if (relu6) {
-          __m128 const_ma = _mm_set_ps(6.0f, 6.0f, 6.0f, 6.0f);
-          dst_w_ma1 = _mm_min_ps(const_ma, dst_w_ma1);
-          dst_w_ma2 = _mm_min_ps(const_ma, dst_w_ma2);
-        }
-      }
+      ActBlock2(&dst_w_ma1, &dst_w_ma2, relu, relu6);
       _mm_storeu_ps(dst_w, dst_w_ma1);
       _mm_storeu_ps(dst_w + block_channel, dst_w_ma2);
-      dst_w += C2NUM * block_channel;
-      src_w += C2NUM * in_sw_step;
     }
 
     // remaining
-    for (; c1 < width; c1++) {
-      const float *src_kh = src_w;
-      const float *weight_kh = weight;
+    for (; c1 < width; c1++, dst_w += block_channel, src_w += in_sw_step) {
+      const float *src_kh = src_w, *weight_kh = weight;
       __m128 dst_w_ma1 = _mm_setzero_ps();
-      for (int kh = 0; kh < kernel_h; kh++) {
-        const float *src_kw = src_kh;
-        const float *weight_kw = weight_kh;
-        for (int kw = 0; kw < kernel_w; kw++) {
+      for (int kh = 0; kh < kernel_h; kh++, src_kh += in_kh_step, weight_kh += kernel_w * C4NUM) {
+        const float *src_kw = src_kh, *weight_kw = weight_kh;
+        for (int kw = 0; kw < kernel_w; kw++, src_kw += in_kw_step, weight_kw += C4NUM) {
           __m128 src_kw_ma1 = _mm_loadu_ps(src_kw);
           __m128 weight_kw_ma1 = _mm_loadu_ps(weight_kw);
           __m128 tmp_ma1 = _mm_mul_ps(src_kw_ma1, weight_kw_ma1);
           dst_w_ma1 = _mm_add_ps(dst_w_ma1, tmp_ma1);
-          src_kw += in_kw_step;
-          weight_kw += C4NUM;
         } // kernel_w loop
-        src_kh += in_kh_step;
-        weight_kh += kernel_w * C4NUM;
      } // kernel_h loop
 
       // add bias relu
       __m128 bias_ma = _mm_loadu_ps(bias);
       dst_w_ma1 = _mm_add_ps(dst_w_ma1, bias_ma);
-      __m128 zero_ma = _mm_setzero_ps();
-      if (relu || relu6) {
-        dst_w_ma1 = _mm_max_ps(zero_ma, dst_w_ma1);
-        if (relu6) {
-          __m128 const_ma = _mm_set_ps(6.0f, 6.0f, 6.0f, 6.0f);
-          dst_w_ma1 = _mm_min_ps(const_ma, dst_w_ma1);
-        }
-      }
+      ActBlock1(&dst_w_ma1, relu, relu6);
       _mm_storeu_ps(dst_w, dst_w_ma1);
-      dst_w += block_channel;
-      src_w += in_sw_step;
    }
    dst_h += out_h_step;
    src_h += in_sh_step;
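
Note: the ActBlock helpers come from nnacl/intrinsics/sse/sse_common.h, which this diff includes but never shows. Their names and signatures are taken from the call sites above; the bodies below are a minimal sketch inferred from the inline max/min sequences the calls replace, not the actual header contents.

/* Sketch only: assumed behavior of the sse_common.h ActBlock helpers,
 * reconstructed from the inline activation code removed above. */
#include <x86intrin.h>
#include <stddef.h>

static inline void ActBlock1(__m128 *v, size_t relu, size_t relu6) {
  if (relu || relu6) {
    *v = _mm_max_ps(_mm_setzero_ps(), *v); /* relu: clamp below at 0 */
  }
  if (relu6) {
    *v = _mm_min_ps(_mm_set_ps(6.0f, 6.0f, 6.0f, 6.0f), *v); /* relu6: clamp above at 6 */
  }
}

static inline void ActBlock2(__m128 *v1, __m128 *v2, size_t relu, size_t relu6) {
  ActBlock1(v1, relu, relu6);
  ActBlock1(v2, relu, relu6);
}

static inline void ActBlock4(__m128 *v1, __m128 *v2, __m128 *v3, __m128 *v4, size_t relu, size_t relu6) {
  ActBlock1(v1, relu, relu6);
  ActBlock1(v2, relu, relu6);
  ActBlock1(v3, relu, relu6);
  ActBlock1(v4, relu, relu6);
}

The rest of the change is behavior-preserving cleanup: the src_kw/weight_kw, src_kh/weight_kh, and dst_w/src_w advances move out of the loop bodies into the increment clause of their for statements, so each iteration strides exactly as before.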