|
|
|
@ -15,6 +15,7 @@
|
|
|
|
|
*/
|
|
|
|
|
#include "nnacl/fp32/softmax_fp32.h"
|
|
|
|
|
#include <math.h>
|
|
|
|
|
#include <float.h>
|
|
|
|
|
#include "nnacl/fp32/exp_fp32.h"
|
|
|
|
|
|
|
|
|
|
void SoftmaxNorm(const float *src, float *dst, int batch, int channel) {
|
|
|
|
@ -22,15 +23,15 @@ void SoftmaxNorm(const float *src, float *dst, int batch, int channel) {
|
|
|
|
|
for (int i = 0; i < batch; i++, cur_batch_offset += channel) {
|
|
|
|
|
int j = 0;
|
|
|
|
|
#ifdef ENABLE_ARM64
|
|
|
|
|
float32x4_t max4 = vld1q_f32(src + cur_batch_offset);
|
|
|
|
|
j += C4NUM;
|
|
|
|
|
for (; j < channel - C4NUM; j += C4NUM) {
|
|
|
|
|
float32x4_t max4 = vdupq_n_f32(-FLT_MAX);
|
|
|
|
|
int count = (channel / C4NUM) * C4NUM;
|
|
|
|
|
for (; j < count; j += C4NUM) {
|
|
|
|
|
float32x4_t input4 = vld1q_f32(src + cur_batch_offset + j);
|
|
|
|
|
max4 = vmaxq_f32(max4, input4);
|
|
|
|
|
}
|
|
|
|
|
float max = channel >= C4NUM ? vmaxvq_f32(max4) : src[cur_batch_offset];
|
|
|
|
|
float max = vmaxvq_f32(max4);
|
|
|
|
|
#else
|
|
|
|
|
float max = src[cur_batch_offset];
|
|
|
|
|
float max = -FLT_MAX;
|
|
|
|
|
#endif
|
|
|
|
|
for (; j < channel; j++) {
|
|
|
|
|
float input = src[cur_batch_offset + j];
|
|
|
|
@ -40,7 +41,8 @@ void SoftmaxNorm(const float *src, float *dst, int batch, int channel) {
|
|
|
|
|
}
|
|
|
|
|
int k = 0;
|
|
|
|
|
#ifdef ENABLE_NEON
|
|
|
|
|
for (; k < channel - C4NUM; k += C4NUM) {
|
|
|
|
|
int count2 = (channel / C4NUM) * C4NUM;
|
|
|
|
|
for (; k < count2; k += C4NUM) {
|
|
|
|
|
float32x4_t input4 = vld1q_f32(src + cur_batch_offset + k);
|
|
|
|
|
float32x4_t output4 = vsubq_f32(input4, vdupq_n_f32(max));
|
|
|
|
|
vst1q_f32(dst + cur_batch_offset + k, output4);
|
|
|
|
@ -60,7 +62,8 @@ void SumAndDiv(const float *src, float *dst, int batch, int channel) {
|
|
|
|
|
int j = 0;
|
|
|
|
|
#ifdef ENABLE_NEON
|
|
|
|
|
float32x4_t sum4 = vdupq_n_f32(0);
|
|
|
|
|
for (; j < channel - C4NUM; j += C4NUM) {
|
|
|
|
|
int count = (channel / C4NUM) * C4NUM;
|
|
|
|
|
for (; j < count; j += C4NUM) {
|
|
|
|
|
sum4 = vaddq_f32(sum4, vld1q_f32(src + cur_batch_offset + j));
|
|
|
|
|
}
|
|
|
|
|
sum = sum4[0] + sum4[1] + sum4[2] + sum4[3];
|
|
|
|
@ -71,7 +74,7 @@ void SumAndDiv(const float *src, float *dst, int batch, int channel) {
|
|
|
|
|
int k = 0;
|
|
|
|
|
#ifdef ENABLE_NEON
|
|
|
|
|
const float div = 1.0f / sum;
|
|
|
|
|
for (; k < channel - C4NUM; k += C4NUM) {
|
|
|
|
|
for (; k < count; k += C4NUM) {
|
|
|
|
|
vst1q_f32(dst + cur_batch_offset + k, vmulq_n_f32(vld1q_f32(src + cur_batch_offset + k), div));
|
|
|
|
|
}
|
|
|
|
|
#endif
|
|
|
|
|