!13673 [ms][lite][cpu] gelu optimize

From: @lzkcode
Reviewed-by: 
Signed-off-by:
pull/13673/MERGE
mindspore-ci-bot 4 years ago committed by Gitee
commit 461a0625cc

@ -168,3 +168,40 @@ int HardTanhFp16(const float16_t *src, int length, float16_t *dst, float min_val
}
return NNACL_OK;
}
int GeluFp16(const float16_t *src, int length, float16_t *dst, bool approximate) {
if (src == NULL || dst == NULL) {
return NNACL_ERR;
}
int i = 0;
if (approximate) {
// dst = 0.5 * x * (1 + tanh((2 / pi) ^ 0.5 * (x + 0.044715x^3)))
#ifdef ENABLE_NEON
int C8 = UP_ROUND(length, C8NUM);
for (; i < C8; i += C8NUM) {
float16x8_t in = vld1q_f16(src + i);
float16x8_t res =
0.5 * in * (1.0 + MS_TANHX8_F16(((float16_t)0.79788456080287 + (float16_t)0.035677408136 * in * in) * in));
vst1q_f16(dst + i, res);
}
#endif
for (; i < length; i++) {
dst[i] =
0.5 * src[i] *
(1.0 + TanhOptFp16(((float16_t)0.79788456080287 + (float16_t)0.035677408136 * src[i] * src[i]) * src[i]));
}
} else {
#ifdef ENABLE_NEON
int C8 = UP_ROUND(length, C8NUM);
for (; i < C8; i += C8NUM) {
float16x8_t in = vld1q_f16(src + i);
float16x8_t res = 0.5 * in * (1.0 + MS_ERFX8_F16(in / (float16_t)1.4142135623730951f));
vst1q_f16(dst + i, res);
}
#endif
for (; i < length; i++) {
dst[i] = 0.5 * src[i] * (1.0 + erf(src[i] / 1.4142135623730951f));
}
}
return NNACL_OK;
}

@ -34,6 +34,7 @@ int TanhFp16(const float16_t *src, float16_t *dst, int ele_num);
int HSwishFp16(const float16_t *src, float16_t *dst, int ele_num);
int SwishFp16(const float16_t *src, float16_t *dst, int ele_num);
int HardTanhFp16(const float16_t *src, int length, float16_t *dst, float min_val, float max_val);
int GeluFp16(const float16_t *src, int length, float16_t *dst, bool approximate);
#ifdef __cplusplus
}
#endif

@ -134,50 +134,21 @@ float TanhOpt(float src) {
int Tanh(const float *src, int length, float *dst) {
int i = 0;
#if defined(ENABLE_ARM) || defined(ENABLE_SSE) || defined(ENABLE_AVX)
const int cnt = 6;
float data[] = {378.0f, 17325.0f, 135135.0f, 28.0f, 3150.0f, 62370.0f};
#endif
#if defined(ENABLE_AVX)
MS_FLOAT32X8 neg_one_8 = {-1.0f, -1.0f, -1.0f, -1.0f, -1.0f, -1.0f, -1.0f, -1.0f};
MS_FLOAT32X8 pos_one_8 = {1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f};
MS_FLOAT32X8 param256[6];
for (int j = 0; j < cnt; ++j) {
param256[j] = MS_MOV256_F32(data[j]);
}
for (; i < length - 8; i += 8) {
MS_FLOAT32X8 input = MS_LD256_F32(src + i);
MS_FLOAT32X8 square = input * input;
MS_FLOAT32X8 a = (((square + param256[0]) * square + param256[1]) * square + param256[2]) * input;
MS_FLOAT32X8 b = ((param256[3] * square + param256[4]) * square + param256[5]) * square + param256[2];
MS_ST256_F32(dst + i, MS_MIN256_F32(MS_MAX256_F32(a / b, neg_one_8), pos_one_8));
MS_ST256_F32(dst + i, MS_TANHX8_F32(input));
}
#endif
#if defined(ENABLE_ARM) || defined(ENABLE_SSE)
MS_FLOAT32X4 param[6];
MS_FLOAT32X4 neg_one = {-1.0f, -1.0f, -1.0f, -1.0f};
MS_FLOAT32X4 pos_one = {1.0f, 1.0f, 1.0f, 1.0f};
for (int j = 0; j < cnt; ++j) {
param[j] = MS_MOVQ_F32(data[j]);
}
for (; i < length - 4; i += 4) {
MS_FLOAT32X4 input = MS_LDQ_F32(src + i);
MS_FLOAT32X4 square = input * input;
MS_FLOAT32X4 a = (((square + param[0]) * square + param[1]) * square + param[2]) * input;
MS_FLOAT32X4 b = ((param[3] * square + param[4]) * square + param[5]) * square + param[2];
MS_STQ_F32(dst + i, MS_MINQ_F32(MS_MAXQ_F32(a / b, neg_one), pos_one));
MS_STQ_F32(dst + i, MS_TANHX4_F32(input));
}
#endif
for (; i < length; ++i) {
float input = src[i];
float square = input * input;
float a = (((square + 378.0f) * square + 17325.0f) * square + 135135.0f) * input;
float b = ((28.0f * square + 3150.0f) * square + 62370.0f) * square + 135135.0f;
dst[i] = a / b;
dst[i] = MSMAX(dst[i], -1);
dst[i] = MSMIN(dst[i], 1);
dst[i] = TanhOpt(src[i]);
}
return NNACL_OK;
}
@ -249,10 +220,44 @@ int HardTanh(const float *src, int length, float *dst, float min_val, float max_
return NNACL_OK;
}
int Gelu(const float *src, int length, float *dst) {
for (int i = 0; i < length; ++i) {
float tanh_res = TanhOpt(sqrt(2 / M_PI) * (src[i] + 0.044715 * pow(src[i], 3)));
dst[i] = 0.5f * src[i] * (1 + tanh_res);
int Gelu(const float *src, int length, float *dst, bool approximate) {
if (src == NULL || dst == NULL) {
return NNACL_ERR;
}
int i = 0;
if (approximate) {
// dst = 0.5 * x * (1 + tanh((2 / pi) ^ 0.5 * (x + 0.044715x^3)))
#if defined(ENABLE_AVX)
int C8 = UP_ROUND(length, C8NUM);
for (; i < C8; i += C8NUM) {
MS_FLOAT32X8 in = MS_LD256_F32(src + i);
MS_FLOAT32X8 res = 0.5 * in * (1.0 + MS_TANHX8_F32((0.79788456080287f + 0.035677408136f * in * in) * in));
MS_ST256_F32(dst + i, res);
}
#endif
#if defined(ENABLE_SSE) || defined(ENABLE_ARM)
int C4 = UP_ROUND(length, C4NUM);
for (; i < C4; i += C4NUM) {
MS_FLOAT32X4 in = MS_LDQ_F32(src + i);
MS_FLOAT32X4 res = 0.5 * in * (1.0 + MS_TANHX4_F32((0.79788456080287f + 0.035677408136f * in * in) * in));
MS_STQ_F32(dst + i, res);
}
#endif
for (; i < length; i++) {
dst[i] = 0.5 * src[i] * (1.0 + TanhOpt((0.79788456080287f + 0.035677408136f * src[i] * src[i]) * src[i]));
}
} else {
#if defined(ENABLE_AVX) || defined(ENABLE_SSE) || defined(ENABLE_ARM)
int C4 = UP_ROUND(length, C4NUM);
for (; i < C4; i += C4NUM) {
MS_FLOAT32X4 in = MS_LDQ_F32(src + i);
MS_FLOAT32X4 res = 0.5 * in * (1.0 + MS_ERFX4_F32(in / 1.4142135623730951f));
MS_STQ_F32(dst + i, res);
}
#endif
for (; i < length; i++) {
dst[i] = 0.5 * src[i] * (1.0 + erf(src[i] / 1.4142135623730951f));
}
}
return NNACL_OK;
}

@ -40,7 +40,7 @@ int HSigmoid(const float *src, int length, float *dst);
int Swish(const float *src, int length, float *dst);
int HSwish(const float *src, int length, float *dst);
int HardTanh(const float *src, int length, float *dst, float min_val, float max_val);
int Gelu(const float *src, int length, float *dst);
int Gelu(const float *src, int length, float *dst, bool approximate);
float TanhOpt(float src);
#ifdef __cplusplus

@ -17,11 +17,6 @@
#include "nnacl/fp32/conv_depthwise_fp32.h"
#include "nnacl/common_func.h"
#include "nnacl/fp32/common_func_fp32.h"
#include "nnacl/fp32/winograd_transform.h"
#include "nnacl/intrinsics/ms_simd_instructions.h"
#ifdef ENABLE_ARM64
#include <arm_neon.h>
#endif
#if !defined(ENABLE_ARM) && !defined(ENABLE_SSE)
void ConvDwFp32Row(float *output_ptr, const float *input_ptr, const float *weight_ptr, int num_pixels,

@ -1,39 +0,0 @@
/**
* Copyright 2019 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "nnacl/fp32/gelu_fp32.h"
#include "nnacl/gelu_parameter.h"
#include <string.h>
#include <math.h>
#include "nnacl/errorcode.h"
int DoGeLU(const float *src, float *out, int64_t real_dst_count, const GeLUParameter *param) {
if (src == NULL || out == NULL) {
return NNACL_ERR;
}
if (param->approximate_) {
for (int i = 0; i < real_dst_count; i++) {
out[i] = 0.5 * src[i] * (1.0 + tanh(0.7978845608028654 * (src[i] + 0.044715 * pow(src[i], 3))));
}
} else {
for (int i = 0; i < real_dst_count; i++) {
out[i] = 0.5 * src[i] * (1.0 + erf(src[i] / 1.4142135623730951));
}
}
return NNACL_OK;
}

@ -1,31 +0,0 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_LITE_NNACL_FP32_GELU_H_
#define MINDSPORE_LITE_NNACL_FP32_GELU_H_
#include "nnacl/op_base.h"
#include "nnacl/gelu_parameter.h"
#ifdef __cplusplus
extern "C" {
#endif
int DoGeLU(const float *src, float *out, int64_t real_dst_count, const GeLUParameter *param);
#ifdef __cplusplus
}
#endif
#endif // MINDSPORE_LITE_NNACL_FP32_GELU_H_

@ -16,6 +16,7 @@
#ifndef MINDSPORE_LITE_NNACL_INTRINSICS_MS_SIMD_INSTRUCTIONS_H_
#define MINDSPORE_LITE_NNACL_INTRINSICS_MS_SIMD_INSTRUCTIONS_H_
#include <math.h>
#ifdef ENABLE_ARM
#include <arm_neon.h>
#endif
@ -170,4 +171,56 @@ inline static float32x4_t vrecp(float32x4_t v) {
MS_STQ_F32(output_ptr + 6 * num, dst##7); \
MS_STQ_F32(output_ptr + 7 * num, dst##8);
static inline MS_FLOAT32X4 MS_TANHX4_F32(MS_FLOAT32X4 src) {
static const float data[] = {378.0f, 17325.0f, 135135.0f, 28.0f, 3150.0f, 62370.0f};
static const MS_FLOAT32X4 neg = {-1.0f, -1.0f, -1.0f, -1.0f};
static const MS_FLOAT32X4 pos = {1.0f, 1.0f, 1.0f, 1.0f};
MS_FLOAT32X4 square = src * src;
MS_FLOAT32X4 a = (((square + data[0]) * square + data[1]) * square + data[2]) * src;
MS_FLOAT32X4 b = ((data[3] * square + data[4]) * square + data[5]) * square + data[2];
return MS_MINQ_F32(MS_MAXQ_F32(a / b, neg), pos);
}
#ifdef ENABLE_AVX
static inline MS_FLOAT32X8 MS_TANHX8_F32(MS_FLOAT32X8 src) {
static const float data[] = {378.0f, 17325.0f, 135135.0f, 28.0f, 3150.0f, 62370.0f};
static const MS_FLOAT32X8 neg = {-1.0f, -1.0f, -1.0f, -1.0f, -1.0f, -1.0f, -1.0f, -1.0f};
static const MS_FLOAT32X8 pos = {1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f};
MS_FLOAT32X8 square = src * src;
MS_FLOAT32X8 a = (((square + data[0]) * square + data[1]) * square + data[2]) * src;
MS_FLOAT32X8 b = ((data[3] * square + data[4]) * square + data[5]) * square + data[2];
return MS_MIN256_F32(MS_MAX256_F32(a / b, neg), pos);
}
#endif
static inline MS_FLOAT32X4 MS_ERFX4_F32(MS_FLOAT32X4 src) {
MS_FLOAT32X4 dst;
dst[0] = erff(src[0]);
dst[1] = erff(src[1]);
dst[2] = erff(src[2]);
dst[3] = erff(src[3]);
return dst;
}
#ifdef ENABLE_ARM64
static inline float16x8_t MS_TANHX8_F16(float16x8_t src) {
float32x4_t src_low = vcvt_f32_f16(vget_low_f16(src));
float32x4_t src_high = vcvt_f32_f16(vget_high_f16(src));
return vcombine_f16(vcvt_f16_f32(MS_TANHX4_F32(src_low)), vcvt_f16_f32(MS_TANHX4_F32(src_high)));
}
static inline float16x8_t MS_ERFX8_F16(float16x8_t src) {
float16x8_t dst;
dst[0] = erff(src[0]);
dst[1] = erff(src[1]);
dst[2] = erff(src[2]);
dst[3] = erff(src[3]);
dst[4] = erff(src[4]);
dst[5] = erff(src[5]);
dst[6] = erff(src[6]);
dst[7] = erff(src[7]);
return dst;
}
#endif
#endif // MINDSPORE_LITE_NNACL_INTRINSICS_MS_SIMD_INSTRUCTIONS_H_

@ -21,7 +21,7 @@
#include <stdlib.h>
#include <stdbool.h>
#include <string.h>
#if defined(ENBALE_AVX) || defined(ENABLE_SSE) || defined(ENABLE_ARM)
#if defined(ENABLE_AVX) || defined(ENABLE_SSE) || defined(ENABLE_ARM)
#include "nnacl/intrinsics/ms_simd_instructions.h"
#endif

@ -25,6 +25,7 @@ using mindspore::kernel::KERNEL_ARCH::kCPU;
using mindspore::lite::KernelRegistrar;
using mindspore::lite::RET_ERROR;
using mindspore::lite::RET_OK;
using mindspore::schema::ActivationType_GELU;
using mindspore::schema::ActivationType_HSWISH;
using mindspore::schema::ActivationType_LEAKY_RELU;
using mindspore::schema::ActivationType_RELU;
@ -73,6 +74,8 @@ int ActivationFp16CPUKernel::DoActivation(int task_id) {
} else if (type_ == schema::ActivationType_HARD_TANH) {
error_code =
HardTanhFp16(fp16_input_ + stride * task_id, count, fp16_output_ + stride * task_id, min_val_, max_val_);
} else if (type_ == schema::ActivationType_GELU) {
error_code = GeluFp16(fp16_input_ + stride * task_id, count, fp16_output_ + stride * task_id, true);
} else {
MS_LOG(ERROR) << "Activation fp16 not support type: " << type_;
return RET_ERROR;

@ -79,7 +79,7 @@ int ActivationCPUKernel::DoActivation(int task_id) {
} else if (type_ == schema::ActivationType_HARD_TANH) {
ret = HardTanh(input_addr + stride * task_id, count, output_addr + stride * task_id, min_val_, max_val_);
} else if (type_ == schema::ActivationType_GELU) {
ret = Gelu(input_addr + stride * task_id, count, output_addr + stride * task_id);
ret = Gelu(input_addr + stride * task_id, count, output_addr + stride * task_id, true);
} else {
MS_LOG(ERROR) << "Activation type error";
return RET_ERROR;

Loading…
Cancel
Save