From 872bcfaffa8ebc7406819508f5f3390b82097117 Mon Sep 17 00:00:00 2001 From: sunsuodong Date: Fri, 7 Aug 2020 16:46:21 +0800 Subject: [PATCH] add int8 relu6 kernel --- .../src/runtime/kernel/arm/int8/activation.cc | 5 +- .../arm/int8/{relu_int8.cc => relux_int8.cc} | 22 ++++---- .../arm/int8/{relu_int8.h => relux_int8.h} | 45 +++++++++++++--- .../nnacl/int8/{relu_int8.h => relux_int8.h} | 9 ++-- ...relu_int8_tests.cc => relux_int8_tests.cc} | 52 +++++++++++++++++-- 5 files changed, 107 insertions(+), 26 deletions(-) rename mindspore/lite/src/runtime/kernel/arm/int8/{relu_int8.cc => relux_int8.cc} (75%) rename mindspore/lite/src/runtime/kernel/arm/int8/{relu_int8.h => relux_int8.h} (52%) rename mindspore/lite/src/runtime/kernel/arm/nnacl/int8/{relu_int8.h => relux_int8.h} (87%) rename mindspore/lite/test/ut/src/runtime/kernel/arm/int8/{relu_int8_tests.cc => relux_int8_tests.cc} (55%) diff --git a/mindspore/lite/src/runtime/kernel/arm/int8/activation.cc b/mindspore/lite/src/runtime/kernel/arm/int8/activation.cc index 32d933d5e8..70dfce455a 100644 --- a/mindspore/lite/src/runtime/kernel/arm/int8/activation.cc +++ b/mindspore/lite/src/runtime/kernel/arm/int8/activation.cc @@ -15,7 +15,7 @@ */ #include "src/runtime/kernel/arm/fp32/activation.h" -#include "src/runtime/kernel/arm/int8/relu_int8.h" +#include "src/runtime/kernel/arm/int8/relux_int8.h" #include "src/runtime/kernel/arm/int8/hswish_int8.h" #include "schema/model_generated.h" #include "src/kernel_registry.h" @@ -44,6 +44,9 @@ kernel::LiteKernel *CpuActivationInt8KernelCreator(const std::vector(inputs_.at(0)->Data()); auto output_addr = reinterpret_cast(outputs_.at(0)->Data()); auto length = inputs_.at(0)->ElementsNum(); @@ -54,24 +54,24 @@ int ReluInt8CPUKernel::DoActivation(int task_id) { int stride = UP_DIV(length, thread_count_); int count = MSMIN(stride, length - stride * task_id); - ReluInt8(input_addr + stride * task_id, count, output_addr + stride * task_id, &quant_arg_); + ReluXInt8(input_addr + stride * task_id, count, output_addr + stride * task_id, &quant_arg_); return RET_OK; } -int ReluInt8Run(int task_id, LiteParallelGroupEnv *penv, void *cdata) { - auto activation_kernel = reinterpret_cast(cdata); +int ReluXInt8Run(int task_id, LiteParallelGroupEnv *penv, void *cdata) { + auto activation_kernel = reinterpret_cast(cdata); auto error_code = activation_kernel->DoActivation(task_id); if (error_code != RET_OK) { - MS_LOG(ERROR) << "ReluInt8Run error task_id[" << task_id << "] error_code[" << error_code << "]"; + MS_LOG(ERROR) << "ReluXInt8Run error task_id[" << task_id << "] error_code[" << error_code << "]"; return RET_ERROR; } return RET_OK; } -int ReluInt8CPUKernel::Run() { - int error_code = LiteBackendParallelLaunch(ReluInt8Run, this, thread_count_); +int ReluXInt8CPUKernel::Run() { + int error_code = LiteBackendParallelLaunch(ReluXInt8Run, this, thread_count_); if (error_code != RET_OK) { - MS_LOG(ERROR) << "ReluInt8Run function error error_code[" << error_code << "]"; + MS_LOG(ERROR) << "ReluXInt8Run function error error_code[" << error_code << "]"; return RET_ERROR; } return RET_OK; diff --git a/mindspore/lite/src/runtime/kernel/arm/int8/relu_int8.h b/mindspore/lite/src/runtime/kernel/arm/int8/relux_int8.h similarity index 52% rename from mindspore/lite/src/runtime/kernel/arm/int8/relu_int8.h rename to mindspore/lite/src/runtime/kernel/arm/int8/relux_int8.h index 5f8f81a3a5..44cffe4bc0 100644 --- a/mindspore/lite/src/runtime/kernel/arm/int8/relu_int8.h +++ b/mindspore/lite/src/runtime/kernel/arm/int8/relux_int8.h @@ -20,27 +20,60 @@ #include #include "src/lite_kernel.h" #include "src/runtime/kernel/arm/nnacl/fp32/activation.h" -#include "src/runtime/kernel/arm/nnacl/int8/relu_int8.h" +#include "src/runtime/kernel/arm/nnacl/int8/relux_int8.h" namespace mindspore::kernel { -class ReluInt8CPUKernel : public LiteKernel { +class ReluXInt8CPUKernel : public LiteKernel { public: - ReluInt8CPUKernel(OpParameter *parameter, const std::vector &inputs, - const std::vector &outputs, const lite::Context *ctx) + ReluXInt8CPUKernel(OpParameter *parameter, const std::vector &inputs, + const std::vector &outputs, const lite::Context *ctx) : LiteKernel(parameter, inputs, outputs), thread_count_(ctx->thread_num_) { type_ = (reinterpret_cast(parameter))->type_; } - ~ReluInt8CPUKernel() override = default; + ~ReluXInt8CPUKernel() override = default; int Init() override; int ReSize() override; int Run() override; int DoActivation(int task_id); + ReluXQuantArg quant_arg_; + private: int thread_count_; int type_; - ReluQuantArg quant_arg_; +}; + +class ReluInt8CPUKernel : public ReluXInt8CPUKernel { + public: + ReluInt8CPUKernel(OpParameter *parameter, const std::vector &inputs, + const std::vector &outputs, const lite::Context *ctx) + : ReluXInt8CPUKernel(parameter, inputs, outputs, ctx) {} + + ~ReluInt8CPUKernel() override = default; + + int Init() override { + auto ret = ReluXInt8CPUKernel::Init(); + quant_arg_.quantized_output_min = quant_arg_.output_arg.zp_; + quant_arg_.quantized_output_max = CHAR_MAX; + return ret; + }; +}; + +class Relu6Int8CPUKernel : public ReluXInt8CPUKernel { + public: + Relu6Int8CPUKernel(OpParameter *parameter, const std::vector &inputs, + const std::vector &outputs, const lite::Context *ctx) + : ReluXInt8CPUKernel(parameter, inputs, outputs, ctx) {} + + ~Relu6Int8CPUKernel() override = default; + + int Init() override { + auto ret = ReluXInt8CPUKernel::Init(); + quant_arg_.quantized_output_min = QuantizeToInt8(0, quant_arg_.output_arg.scale_, quant_arg_.output_arg.zp_); + quant_arg_.quantized_output_max = QuantizeToInt8(6, quant_arg_.output_arg.scale_, quant_arg_.output_arg.zp_); + return ret; + }; }; } // namespace mindspore::kernel diff --git a/mindspore/lite/src/runtime/kernel/arm/nnacl/int8/relu_int8.h b/mindspore/lite/src/runtime/kernel/arm/nnacl/int8/relux_int8.h similarity index 87% rename from mindspore/lite/src/runtime/kernel/arm/nnacl/int8/relu_int8.h rename to mindspore/lite/src/runtime/kernel/arm/nnacl/int8/relux_int8.h index 88d5adcb13..a1c3673150 100644 --- a/mindspore/lite/src/runtime/kernel/arm/nnacl/int8/relu_int8.h +++ b/mindspore/lite/src/runtime/kernel/arm/nnacl/int8/relux_int8.h @@ -21,15 +21,17 @@ #include "src/runtime/kernel/arm/nnacl/errorcode.h" #include "src/runtime/kernel/arm/nnacl/quantization/fixed_point.h" -struct ReluQuantArg { +struct ReluXQuantArg { QuantArg input_arg; QuantArg output_arg; int input_multiplier_; int left_shift_; int right_shift_; + int quantized_output_min; + int quantized_output_max; }; -inline void ReluInt8(const int8_t *src, int length, int8_t *dst, ReluQuantArg *arg) { +inline void ReluXInt8(const int8_t *src, int length, int8_t *dst, ReluXQuantArg *arg) { for (int i = 0; i < length; ++i) { if (src[i] <= arg->input_arg.zp_) { dst[i] = arg->output_arg.zp_; @@ -39,8 +41,7 @@ inline void ReluInt8(const int8_t *src, int length, int8_t *dst, ReluQuantArg *a const int32_t scaled_input = SaturatingRoundingDoublingHighMul(input_val, arg->input_multiplier_); const int32_t shifted_input = RoundingDivideByPOT(scaled_input * (1 << arg->left_shift_), -arg->right_shift_); const int32_t output = shifted_input + arg->output_arg.zp_; - dst[i] = (int8_t)output; + dst[i] = (int8_t)MSMIN(output, arg->quantized_output_max); } } - #endif // MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_NNACL_INT8_RELU_INT8_H_ diff --git a/mindspore/lite/test/ut/src/runtime/kernel/arm/int8/relu_int8_tests.cc b/mindspore/lite/test/ut/src/runtime/kernel/arm/int8/relux_int8_tests.cc similarity index 55% rename from mindspore/lite/test/ut/src/runtime/kernel/arm/int8/relu_int8_tests.cc rename to mindspore/lite/test/ut/src/runtime/kernel/arm/int8/relux_int8_tests.cc index 99ca0098a8..910e746cf4 100644 --- a/mindspore/lite/test/ut/src/runtime/kernel/arm/int8/relu_int8_tests.cc +++ b/mindspore/lite/test/ut/src/runtime/kernel/arm/int8/relux_int8_tests.cc @@ -17,17 +17,17 @@ #include #include #include "common/common_test.h" -#include "mindspore/lite/src/runtime/kernel/arm/int8/relu_int8.h" +#include "mindspore/lite/src/runtime/kernel/arm/int8/relux_int8.h" #include "mindspore/lite/src/kernel_registry.h" #include "mindspore/lite/include/context.h" namespace mindspore { -class TestReluInt8 : public mindspore::Common { +class TestReluXInt8 : public mindspore::Common { public: - TestReluInt8() {} + TestReluXInt8() {} }; -TEST_F(TestReluInt8, Relu) { +TEST_F(TestReluXInt8, Relu) { lite::tensor::Tensor in_tensor(kNumberTypeInt8, {2, 2}); lite::tensor::Tensor out_tensor(kNumberTypeInt8, {2, 2}); @@ -68,4 +68,48 @@ TEST_F(TestReluInt8, Relu) { in_tensor.SetData(nullptr); out_tensor.SetData(nullptr); } + +TEST_F(TestReluXInt8, Relu6) { + lite::tensor::Tensor in_tensor(kNumberTypeInt8, {2, 4}); + lite::tensor::Tensor out_tensor(kNumberTypeInt8, {2, 4}); + + // -2.5f, -1.5f, 1.25f, 3.0f, 4.5f, 6.0f, 6.5f, 9.0f + int8_t input_data[] = {-118, -98, -44, -10, 19, 49, 59, 108}; + int8_t output_data[8] = {0}; + in_tensor.SetData(input_data); + out_tensor.SetData(output_data); + + const lite::tensor::QuantArg quant_in = {0.0509804f, -69}; // -3.0 -- 10.0 + const lite::tensor::QuantArg quant_out = {0.0392157f, -128}; // 0.0 -- 10.0 + in_tensor.AddQuantParam(quant_in); + out_tensor.AddQuantParam(quant_out); + + std::vector inputs = {&in_tensor}; + std::vector outputs = {&out_tensor}; + + ActivationParameter parameter = {0}; + parameter.op_parameter_.type_ = schema::PrimitiveType_Activation; + parameter.type_ = schema::ActivationType_RELU6; + + kernel::KernelKey desc = {kernel::KERNEL_ARCH::kCPU, kNumberTypeInt8, schema::PrimitiveType_Activation}; + + auto creator = lite::KernelRegistry::GetInstance()->GetCreator(desc); + ASSERT_NE(creator, nullptr); + + auto ctx = std::make_shared(); + auto kernel = creator(inputs, outputs, reinterpret_cast(¶meter), ctx.get(), desc); + ASSERT_NE(kernel, nullptr); + + auto ret = kernel->Run(); + EXPECT_EQ(0, ret); + + // 0.0f, 0.0f, 1.25f, 3.0f, 4.5f, 6.0f, 6.0f, 6.0f + int8_t expect[8] = {-128, -128, -96, -52, -14, 25, 25, 25}; + for (int i = 0; i < sizeof(expect); ++i) { + EXPECT_EQ(output_data[i], expect[i]); + } + + in_tensor.SetData(nullptr); + out_tensor.SetData(nullptr); +} } // namespace mindspore