|
|
|
@ -18,7 +18,6 @@ limitations under the License. */
|
|
|
|
|
#include "neon_util.h"
|
|
|
|
|
|
|
|
|
|
namespace paddle {
|
|
|
|
|
|
|
|
|
|
namespace neon {
|
|
|
|
|
|
|
|
|
|
#if defined(__ARM_NEON__) || defined(__ARM_NEON)
|
|
|
|
@ -26,17 +25,20 @@ namespace neon {
|
|
|
|
|
template <int filterSize, int stride>
|
|
|
|
|
struct DepthwiseConvKernel {};
|
|
|
|
|
|
|
|
|
|
inline float32_t conv3x3(float32x4_t r0,
|
|
|
|
|
float32x4_t r1,
|
|
|
|
|
float32x4_t r2,
|
|
|
|
|
inline float32_t conv3x3(const float* r0,
|
|
|
|
|
const float* r1,
|
|
|
|
|
const float* r2,
|
|
|
|
|
float32x4_t k0,
|
|
|
|
|
float32x4_t k1,
|
|
|
|
|
float32x4_t k2) {
|
|
|
|
|
float32x4_t tmp;
|
|
|
|
|
tmp = vmulq_f32(r0, k0);
|
|
|
|
|
tmp = vmlaq_f32(tmp, r1, k1);
|
|
|
|
|
tmp = vmlaq_f32(tmp, r2, k2);
|
|
|
|
|
return vaddvq_f32(tmp);
|
|
|
|
|
float32_t tmp[12];
|
|
|
|
|
vst1q_f32(&(tmp[0]), k0);
|
|
|
|
|
vst1q_f32(&(tmp[4]), k1);
|
|
|
|
|
vst1q_f32(&(tmp[8]), k2);
|
|
|
|
|
float32_t sum0 = r0[0] * tmp[0] + r0[1] * tmp[1] + r0[2] * tmp[2];
|
|
|
|
|
float32_t sum1 = r1[0] * tmp[4] + r1[1] * tmp[5] + r1[2] * tmp[6];
|
|
|
|
|
float32_t sum2 = r2[0] * tmp[8] + r2[1] * tmp[9] + r2[2] * tmp[10];
|
|
|
|
|
return sum0 + sum1 + sum2;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
inline float32_t conv4x4(float32x4_t r0,
|
|
|
|
@ -136,10 +138,7 @@ struct DepthwiseConvKernel<3, 1> {
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
for (int r = 0; r < remain; r++) {
|
|
|
|
|
float32x4_t i0 = vld1q_f32(r0);
|
|
|
|
|
float32x4_t i1 = vld1q_f32(r1);
|
|
|
|
|
float32x4_t i2 = vld1q_f32(r2);
|
|
|
|
|
*outputData = conv3x3(i0, i1, i2, k[0], k[1], k[2]);
|
|
|
|
|
*outputData = conv3x3(r0, r1, r2, k[0], k[1], k[2]);
|
|
|
|
|
r0++;
|
|
|
|
|
r1++;
|
|
|
|
|
r2++;
|
|
|
|
@ -243,10 +242,7 @@ struct DepthwiseConvKernel<3, 2> {
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
for (int r = 0; r < remain; r++) {
|
|
|
|
|
float32x4_t i0 = vld1q_f32(r0);
|
|
|
|
|
float32x4_t i1 = vld1q_f32(r1);
|
|
|
|
|
float32x4_t i2 = vld1q_f32(r2);
|
|
|
|
|
*outputData = conv3x3(i0, i1, i2, k[0], k[1], k[2]);
|
|
|
|
|
*outputData = conv3x3(r0, r1, r2, k[0], k[1], k[2]);
|
|
|
|
|
r0 += 2;
|
|
|
|
|
r1 += 2;
|
|
|
|
|
r2 += 2;
|
|
|
|
|