diff --git a/mindspore/lite/nnacl/fp32/exp_fp32.c b/mindspore/lite/nnacl/fp32/exp_fp32.c index a59f8925c3..20fdcaa17f 100644 --- a/mindspore/lite/nnacl/fp32/exp_fp32.c +++ b/mindspore/lite/nnacl/fp32/exp_fp32.c @@ -16,6 +16,7 @@ #include "nnacl/fp32/exp_fp32.h" #include +#include #include "nnacl/errorcode.h" int Exp(const float *input_data, float *output_data, ExpParameter *parameter, int task_id) { @@ -35,3 +36,40 @@ int Exp(const float *input_data, float *output_data, ExpParameter *parameter, in } return NNACL_OK; } + +void ExpFp32(const float *src, float *dst, int num) { + int i = 0; + const float param[] = {log(2.0f), 1.0f / 120, 1.0f / 24, 1.0f / 6, 1.0f / 2, 1.0f}; +#ifdef ENABLE_ARM64 + float32x4_t maxv = vdupq_n_f32(88.0f); + float32x4_t minv = vdupq_n_f32(-88.0f); + float32x4_t param0 = vdupq_n_f32(log(2.0f)); + float32x4_t param1 = vdupq_n_f32(1.0f / 120); + float32x4_t param2 = vdupq_n_f32(1.0f / 24); + float32x4_t param3 = vdupq_n_f32(1.0f / 6); + float32x4_t param4 = vdupq_n_f32(0.5f); + float32x4_t param5 = vdupq_n_f32(1.0f); + for (; i < num - C4NUM; i += C4NUM) { + float32x4_t input4 = vmaxq_f32(minv, vminq_f32(maxv, vld1q_f32(src + i))); + int32x4_t integer4 = vcvtq_s32_f32(vdivq_f32(input4, param0)); + float32x4_t decimal4 = vsubq_f32(input4, vmulq_f32(vcvtq_f32_s32(integer4), param0)); + int32x4_t int_exp4 = vshlq_s32(vaddq_s32(integer4, vdupq_n_s32(127)), vdupq_n_s32(23)); + vst1q_f32(dst + i, vld1q_f32((float32_t *)(&int_exp4))); + float32x4_t decimal_exp4 = vaddq_f32(param2, vmulq_f32(decimal4, param1)); + decimal_exp4 = vmulq_f32(decimal4, vaddq_f32(param3, vmulq_f32(decimal4, decimal_exp4))); + decimal_exp4 = vaddq_f32(param5, vmulq_f32(decimal4, vaddq_f32(param4, decimal_exp4))); + decimal_exp4 = vaddq_f32(param5, vmulq_f32(decimal4, decimal_exp4)); + vst1q_f32(dst + i, vmulq_f32(vld1q_f32(dst + i), decimal_exp4)); + } +#endif + for (; i < num; ++i) { + float input = MSMAX(-88.0f, MSMIN(88.0f, src[i])); + int integer = input / param[0]; + float decimal = input - integer * param[0]; + int int_exp = (integer + 127) << 23; + memcpy(dst + i, &int_exp, sizeof(float)); + float decimal_exp = + 1.0f + decimal * (1.0f + decimal * (0.5f + decimal * (param[3] + decimal * (param[2] + decimal * param[1])))); + dst[i] *= decimal_exp; + } +} diff --git a/mindspore/lite/nnacl/fp32/exp_fp32.h b/mindspore/lite/nnacl/fp32/exp_fp32.h index 2ada5325ac..a69d3c891a 100644 --- a/mindspore/lite/nnacl/fp32/exp_fp32.h +++ b/mindspore/lite/nnacl/fp32/exp_fp32.h @@ -34,6 +34,7 @@ typedef struct ExpParameter { extern "C" { #endif int Exp(const float *input_data, float *output_data, ExpParameter *parameter, int task_id); +void ExpFp32(const float *src, float *dst, int num); #ifdef __cplusplus } #endif diff --git a/mindspore/lite/nnacl/fp32/softmax_fp32.c b/mindspore/lite/nnacl/fp32/softmax_fp32.c index aa2d2a0248..a28f6c63ab 100644 --- a/mindspore/lite/nnacl/fp32/softmax_fp32.c +++ b/mindspore/lite/nnacl/fp32/softmax_fp32.c @@ -13,9 +13,79 @@ * See the License for the specific language governing permissions and * limitations under the License. */ - #include "nnacl/fp32/softmax_fp32.h" #include +#include "nnacl/fp32/exp_fp32.h" + +void SoftmaxNorm(const float *src, float *dst, int batch, int channel) { + int cur_batch_offset = 0; + for (int i = 0; i < batch; i++, cur_batch_offset += channel) { + int j = 0; +#ifdef ENABLE_ARM64 + float32x4_t max4 = vld1q_f32(src + cur_batch_offset); + j += C4NUM; + for (; j < channel - C4NUM; j += C4NUM) { + float32x4_t input4 = vld1q_f32(src + cur_batch_offset + j); + max4 = vmaxq_f32(max4, input4); + } + float max = channel >= C4NUM ? vmaxvq_f32(max4) : src[cur_batch_offset]; +#else + float max = src[cur_batch_offset]; +#endif + for (; j < channel; j++) { + float input = src[cur_batch_offset + j]; + if (input > max) { + max = input; + } + } + int k = 0; +#ifdef ENABLE_NEON + for (; k < channel - C4NUM; k += C4NUM) { + float32x4_t input4 = vld1q_f32(src + cur_batch_offset + k); + float32x4_t output4 = vsubq_f32(input4, vdupq_n_f32(max)); + vst1q_f32(dst + cur_batch_offset + k, output4); + } +#endif + for (; k < channel; k++) { + int offset = cur_batch_offset + k; + dst[offset] = src[offset] - max; + } + } +} + +void SumAndDiv(const float *src, float *dst, int batch, int channel) { + int cur_batch_offset = 0; + for (int i = 0; i < batch; i++, cur_batch_offset += channel) { + float sum = 0; + int j = 0; +#ifdef ENABLE_NEON + float32x4_t sum4 = vdupq_n_f32(0); + for (; j < channel - C4NUM; j += C4NUM) { + sum4 = vaddq_f32(sum4, vld1q_f32(src + cur_batch_offset + j)); + } + sum = sum4[0] + sum4[1] + sum4[2] + sum4[3]; +#endif + for (; j < channel; j++) { + sum += src[cur_batch_offset + j]; + } + int k = 0; +#ifdef ENABLE_NEON + float div = 1.0f / sum; + for (; k < channel - C4NUM; k += C4NUM) { + vst1q_f32(dst + cur_batch_offset + k, vmulq_n_f32(vld1q_f32(src + cur_batch_offset + k), div)); + } +#endif + for (; k < channel; k++) { + dst[cur_batch_offset + k] = src[cur_batch_offset + k] / sum; + } + } +} + +void SoftmaxLastAxis(const float *src, float *dst, int batch, int channel) { + SoftmaxNorm(src, dst, batch, channel); + ExpFp32(dst, dst, batch * channel); + SumAndDiv(dst, dst, batch, channel); +} // output = exp(input) / reduce_sum(exp(input), axis) void Softmax(const float *input_ptr, float *output_ptr, float *sum_data, SoftmaxParameter *parameter) { diff --git a/mindspore/lite/nnacl/fp32/softmax_fp32.h b/mindspore/lite/nnacl/fp32/softmax_fp32.h index 3d2f220242..6c7e0ce263 100644 --- a/mindspore/lite/nnacl/fp32/softmax_fp32.h +++ b/mindspore/lite/nnacl/fp32/softmax_fp32.h @@ -23,6 +23,7 @@ extern "C" { #endif void Softmax(const float *input_ptr, float *output_ptr, float *sum_data, SoftmaxParameter *parameter); +void SoftmaxLastAxis(const float *src, float *dst, int batch, int channel); #ifdef __cplusplus } #endif diff --git a/mindspore/lite/src/runtime/kernel/arm/fp32/softmax_fp32.cc b/mindspore/lite/src/runtime/kernel/arm/fp32/softmax_fp32.cc index addf9ad0d8..c2643ae75c 100644 --- a/mindspore/lite/src/runtime/kernel/arm/fp32/softmax_fp32.cc +++ b/mindspore/lite/src/runtime/kernel/arm/fp32/softmax_fp32.cc @@ -70,12 +70,41 @@ int SoftmaxCPUKernel::ReSize() { return RET_OK; } -int SoftmaxCPUKernel::Run() { - memset(sum_data_, 0, in_plane_size_ * out_plane_size_ * sizeof(float)); +int SoftmaxCPUKernel::DoSoftmaxLastAxis(int task_id) { + int unit = UP_DIV(out_plane_size_, context_->thread_num_); + int begin = task_id * unit; + int end = MSMIN(begin + unit, out_plane_size_); + int channel = softmax_param_->input_shape_[softmax_param_->axis_]; + int offset = begin * channel; auto input_ptr = reinterpret_cast(in_tensors_.at(kInputIndex)->MutableData()); auto output_ptr = reinterpret_cast(out_tensors_.at(kOutputIndex)->MutableData()); - Softmax(input_ptr, output_ptr, sum_data_, softmax_param_); + SoftmaxLastAxis(input_ptr + offset, output_ptr + offset, end - begin, channel); return RET_OK; } +int SoftmaxLastAxisRun(void *cdata, int task_id) { + auto kernel = reinterpret_cast(cdata); + auto ret = kernel->DoSoftmaxLastAxis(task_id); + if (ret != RET_OK) { + MS_LOG(ERROR) << "DoSoftmaxLastAxis error task_id: " << task_id << ", ret: " << ret; + } + return ret; +} + +int SoftmaxCPUKernel::Run() { + auto input_ptr = reinterpret_cast(in_tensors_.at(kInputIndex)->MutableData()); + auto output_ptr = reinterpret_cast(out_tensors_.at(kOutputIndex)->MutableData()); + int ret = RET_OK; + if (in_plane_size_ == 1) { + ret = ParallelLaunch(this->context_->thread_pool_, SoftmaxLastAxisRun, this, context_->thread_num_); + if (ret != RET_OK) { + MS_LOG(ERROR) << "SoftmaxCPUKernel ParallelLaunch failed, ret: " << ret; + } + } else { + memset(sum_data_, 0, in_plane_size_ * out_plane_size_ * sizeof(float)); + Softmax(input_ptr, output_ptr, sum_data_, softmax_param_); + } + return ret; +} + } // namespace mindspore::kernel diff --git a/mindspore/lite/src/runtime/kernel/arm/fp32/softmax_fp32.h b/mindspore/lite/src/runtime/kernel/arm/fp32/softmax_fp32.h index cf23b56635..6cd6c791fa 100644 --- a/mindspore/lite/src/runtime/kernel/arm/fp32/softmax_fp32.h +++ b/mindspore/lite/src/runtime/kernel/arm/fp32/softmax_fp32.h @@ -37,6 +37,7 @@ class SoftmaxCPUKernel : public SoftmaxBaseCPUKernel { int Init() override; int ReSize() override; int Run() override; + int DoSoftmaxLastAxis(int task_id); private: float *sum_data_ = nullptr; diff --git a/mindspore/lite/test/ut/src/runtime/kernel/arm/fp32/softmax_tests.cc b/mindspore/lite/test/ut/src/runtime/kernel/arm/fp32/softmax_tests.cc new file mode 100644 index 0000000000..d6752eb6fb --- /dev/null +++ b/mindspore/lite/test/ut/src/runtime/kernel/arm/fp32/softmax_tests.cc @@ -0,0 +1,59 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include +#include +#include "common/common_test.h" +#include "nnacl/softmax_parameter.h" +#include "mindspore/lite/src/kernel_registry.h" + +namespace mindspore { +class TestSoftmaxFp32 : public mindspore::CommonTest { + public: + TestSoftmaxFp32() {} +}; + +TEST_F(TestSoftmaxFp32, 001) { + lite::Tensor in_tensor(kNumberTypeFloat32, {2, 1, 1, 5}); + lite::Tensor out_tensor(kNumberTypeFloat32, {2, 1, 1, 5}); + float input_data[] = {1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f}; + float output_data[10] = {0}; + in_tensor.set_data(input_data); + out_tensor.set_data(output_data); + std::vector inputs = {&in_tensor}; + std::vector outputs = {&out_tensor}; + + SoftmaxParameter parameter = {{}, -1, 10, 4, {2, 1, 1, 5}}; + kernel::KernelKey desc = {kernel::KERNEL_ARCH::kCPU, kNumberTypeFloat32, schema::PrimitiveType_SoftMax}; + + auto creator = lite::KernelRegistry::GetInstance()->GetCreator(desc); + ASSERT_NE(creator, nullptr); + + auto ctx = std::make_shared(); + ASSERT_EQ(lite::RET_OK, ctx->Init()); + auto kernel = creator(inputs, outputs, reinterpret_cast(¶meter), ctx.get(), desc, nullptr); + ASSERT_NE(kernel, nullptr); + + auto ret = kernel->Run(); + EXPECT_EQ(0, ret); + + float expect[] = {0.2f, 0.2f, 0.2f, 0.2f, 0.2f, 0.2f, 0.2f, 0.2f, 0.2f, 0.2f}; + for (size_t i = 0; i < sizeof(expect) / sizeof(expect[0]); ++i) { + EXPECT_EQ(output_data[i], expect[i]); + } + in_tensor.set_data(nullptr); + out_tensor.set_data(nullptr); +} +} // namespace mindspore