From 62e93f158d38eb1165222f4157fc65c5d809de36 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E7=8E=8B=E6=98=8E=E8=B4=B5?= Date: Mon, 10 Aug 2020 17:03:05 +0800 Subject: [PATCH] Add arm op sub for int8 and testcases --- .../src/runtime/kernel/arm/int8/sub_int8.cc | 205 ++++ .../src/runtime/kernel/arm/int8/sub_int8.h | 46 + .../src/runtime/kernel/arm/nnacl/add_int8.h | 2 + .../runtime/kernel/arm/nnacl/int8/sub_int8.cc | 104 ++ .../runtime/kernel/arm/nnacl/int8/sub_int8.h | 25 + .../kernel/arm/nnacl/quantization/quantize.h | 20 + .../kernel/arm/fp32/arithmetic_fp32_tests.cc | 1016 +++++++++++++++++ .../runtime/kernel/arm/int8/sub_int_tests.cc | 74 ++ 8 files changed, 1492 insertions(+) create mode 100644 mindspore/lite/src/runtime/kernel/arm/int8/sub_int8.cc create mode 100644 mindspore/lite/src/runtime/kernel/arm/int8/sub_int8.h create mode 100644 mindspore/lite/src/runtime/kernel/arm/nnacl/int8/sub_int8.cc create mode 100644 mindspore/lite/src/runtime/kernel/arm/nnacl/int8/sub_int8.h create mode 100644 mindspore/lite/test/ut/src/runtime/kernel/arm/fp32/arithmetic_fp32_tests.cc create mode 100644 mindspore/lite/test/ut/src/runtime/kernel/arm/int8/sub_int_tests.cc diff --git a/mindspore/lite/src/runtime/kernel/arm/int8/sub_int8.cc b/mindspore/lite/src/runtime/kernel/arm/int8/sub_int8.cc new file mode 100644 index 0000000000..adbd58b53d --- /dev/null +++ b/mindspore/lite/src/runtime/kernel/arm/int8/sub_int8.cc @@ -0,0 +1,205 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "src/runtime/kernel/arm/int8/sub_int8.h" +#include +#include +#include "src/runtime/kernel/arm/nnacl/arithmetic_common.h" +#include "src/runtime/kernel/arm/nnacl/quantization/quantize.h" +#include "src/runtime/runtime_api.h" +#include "src/kernel_registry.h" +#include "include/errorcode.h" + +using mindspore::lite::KernelRegistrar; +using mindspore::lite::RET_OK; +using mindspore::schema::PrimitiveType_Sub; + +namespace mindspore::kernel { + +int SubInt8CPUKernel::Init() { + lite::tensor::Tensor *input0 = in_tensors_.at(0); + lite::tensor::Tensor *input1 = in_tensors_.at(1); + lite::tensor::Tensor *output = out_tensors_.at(0); + MS_ASSERT(input0); + MS_ASSERT(input1); + MS_ASSERT(output); + + broadcast_ = input0->ElementsNum() != input1->ElementsNum(); + + param_.in0_args_.scale_ = input0->GetQuantParams().front().scale; + param_.in0_args_.zp_ = -input0->GetQuantParams().front().zeroPoint; + param_.in1_args_.scale_ = input1->GetQuantParams().front().scale; + param_.in1_args_.zp_ = -input1->GetQuantParams().front().zeroPoint; + param_.out_args_.scale_ = output->GetQuantParams().front().scale; + param_.out_args_.zp_ = output->GetQuantParams().front().zeroPoint; + + const int left_shift = 20; + const double twice_max_input_scale = 2 * std::max(param_.in0_args_.scale_, param_.in1_args_.scale_); + const double real_input0_multiplier = param_.in0_args_.scale_ / twice_max_input_scale; + const double real_input1_multiplier = param_.in1_args_.scale_ / twice_max_input_scale; + const double real_output_multiplier = twice_max_input_scale / ((1 << left_shift) * param_.out_args_.scale_); + + QuantizeMultiplierSmallerThanOne(real_input0_multiplier, ¶m_.input0_multiplier_, ¶m_.input0_shift_); + QuantizeMultiplierSmallerThanOne(real_input1_multiplier, ¶m_.input1_multiplier_, ¶m_.input1_shift_); + QuantizeMultiplierSmallerThanOne(real_output_multiplier, ¶m_.output_multiplier_, ¶m_.output_shift_); + + param_.output_activation_min_ = std::numeric_limits::min(); + param_.output_activation_max_ = std::numeric_limits::max(); + + int left_shift0 = -param_.input0_shift_ > 0 ? -param_.input0_shift_ : 0; + param_.right_shift0_ = -param_.input0_shift_ > 0 ? 0 : param_.input0_shift_; + + int left_shift1 = -param_.input1_shift_ > 0 ? -param_.input1_shift_ : 0; + param_.right_shift1_ = -param_.input1_shift_ > 0 ? 0 : param_.input1_shift_; + + param_.left_shift_out_ = -param_.output_shift_ > 0 ? -param_.output_shift_ : 0; + param_.right_shift_out_ = -param_.output_shift_ > 0 ? 0 : param_.output_shift_; + + param_.left_shift_result0_ = (1 << left_shift) * ((1 << left_shift0)); + param_.left_shift_result1_ = (1 << left_shift) * ((1 << left_shift1)); + + MS_ASSERT(left_shift + left_shift0 == left_shift); + MS_ASSERT(left_shift + left_shift1 == left_shift); + + if (!InferShapeDone()) { + return RET_OK; + } + return ReSize(); +} + +int SubInt8CPUKernel::ReSize() { + if (broadcast_) { + if (tile0_data_ != nullptr) { + if (context_ != nullptr && context_->allocator != nullptr) { + context_->allocator->Free(tile0_data_); + } else { + free(tile0_data_); + } + } + if (tile1_data_ != nullptr) { + if (context_ != nullptr && context_->allocator != nullptr) { + context_->allocator->Free(tile1_data_); + } else { + free(tile1_data_); + } + } + + if (context_ != nullptr && context_->allocator != nullptr) { + tile0_data_ = static_cast(context_->allocator->Malloc(out_tensors_.at(0)->Size())); + tile1_data_ = static_cast(context_->allocator->Malloc(out_tensors_.at(0)->Size())); + } else { + tile0_data_ = static_cast(malloc(sizeof(int8_t) * out_tensors_.at(0)->Size())); + tile1_data_ = static_cast(malloc(sizeof(int8_t) * out_tensors_.at(0)->Size())); + } + + if (tile0_data_ == nullptr || tile1_data_ == nullptr) { + MS_LOG(ERROR) << "malloc memroy fail!"; + return RET_ERROR; + } + } + return RET_OK; +} + +int SubInt8CPUKernel::DoExecute(int task_id) { + auto input0_data_ = static_cast(in_tensors_.at(0)->Data()); + auto input1_data_ = static_cast(in_tensors_.at(1)->Data()); + auto output_data_ = static_cast(out_tensors_.at(0)->Data()); + auto element_num = out_tensors_[0]->ElementsNum(); + + MS_ASSERT(op_parameter_->thread_num_ != 0); + int stride = UP_DIV(element_num, op_parameter_->thread_num_); + int count = MSMIN(stride, element_num - stride * task_id); + + auto ret = RET_OK; + if (broadcast_) { + ret = SubInt8(tile0_data_ + task_id * count, tile1_data_ + task_id * count, output_data_ + task_id * count, count, + ¶m_); + } else { + ret = SubInt8(input0_data_ + task_id * count, input1_data_ + task_id * count, output_data_ + task_id * count, count, + ¶m_); + } + + if (ret != RET_OK) { + MS_LOG(ERROR) << "Subint8 function error error_code[" << ret << "]"; + return RET_ERROR; + } + return RET_OK; +} + +int SubInt8Run(int task_id, LiteParallelGroupEnv *penv, void *cdata) { + auto sub_kernel = reinterpret_cast(cdata); + auto ret = sub_kernel->DoExecute(task_id); + if (ret != RET_OK) { + MS_LOG(ERROR) << "SubInt8 DoExecute error task_id[" << task_id << "] error_code[" << ret << "]"; + return RET_ERROR; + } + return RET_OK; +} + +int SubInt8CPUKernel::Run() { + auto ret = Prepare(); + if (ret != RET_OK) { + MS_LOG(ERROR) << "Prepare failed."; + return RET_ERROR; + } + + if (broadcast_) { + ArithmeticParameter tile_para = {0}; + tile_para.ndim_ = out_tensors_.at(0)->shape().size(); + for (size_t i = 0; i < tile_para.ndim_; i++) { + tile_para.in_shape0_[i] = in_tensors_.at(0)->DimensionSize(i); + tile_para.in_shape1_[i] = in_tensors_.at(1)->DimensionSize(i); + tile_para.out_shape_[i] = out_tensors_.at(0)->DimensionSize(i); + } + TileDimensionsUint8(static_cast(in_tensors_.at(0)->Data()), + static_cast(in_tensors_.at(1)->Data()), reinterpret_cast(tile0_data_), + reinterpret_cast(tile1_data_), &tile_para); + } + ret = LiteBackendParallelLaunch(SubInt8Run, this, op_parameter_->thread_num_); + + if (ret != RET_OK) { + MS_LOG(ERROR) << "SubInt8Run function error error_code[" << ret << "]"; + return RET_ERROR; + } + return RET_OK; +} + +kernel::LiteKernel *CpuSubInt8KernelCreator(const std::vector &inputs, + const std::vector &outputs, OpParameter *parameter, + const lite::Context *ctx, const KernelKey &desc, + const lite::Primitive *primitive) { + if (parameter == nullptr || ctx == nullptr) { + MS_LOG(ERROR) << "parameter or ctx is nullptr"; + return nullptr; + } + MS_ASSERT(desc.type == PrimitiveType_Sub); + auto *kernel = new (std::nothrow) SubInt8CPUKernel(parameter, inputs, outputs, ctx, primitive); + if (kernel == nullptr) { + MS_LOG(ERROR) << "kernel is nullptr."; + return nullptr; + } + auto ret = kernel->Init(); + if (ret != RET_OK) { + MS_LOG(ERROR) << "Init kernel failed, name: " << parameter->name_ + << ", type: " << schema::EnumNamePrimitiveType(static_cast(parameter->type_)); + delete kernel; + return nullptr; + } + return kernel; +} + +REG_KERNEL(kCPU, kNumberTypeInt8, PrimitiveType_Sub, CpuSubInt8KernelCreator) +} // namespace mindspore::kernel diff --git a/mindspore/lite/src/runtime/kernel/arm/int8/sub_int8.h b/mindspore/lite/src/runtime/kernel/arm/int8/sub_int8.h new file mode 100644 index 0000000000..49ee856e35 --- /dev/null +++ b/mindspore/lite/src/runtime/kernel/arm/int8/sub_int8.h @@ -0,0 +1,46 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_INT8_SUB_INT8_H_ +#define MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_INT8_SUB_INT8_H_ + +#include +#include "src/lite_kernel.h" +#include "src/runtime/kernel/arm/nnacl/int8/sub_int8.h" +#include "src/runtime/runtime_api.h" + +namespace mindspore::kernel { +class SubInt8CPUKernel : public LiteKernel { + public: + explicit SubInt8CPUKernel(OpParameter *parameter, const std::vector &inputs, + const std::vector &outputs, const lite::Context *ctx, + const lite::Primitive *primitive) + : LiteKernel(parameter, inputs, outputs, ctx, primitive) {} + ~SubInt8CPUKernel() override {} + + int Init() override; + int ReSize() override; + int Run() override; + int DoExecute(int task_id); + + private: + SubQuantArg param_; + int8_t *tile0_data_ = nullptr; + int8_t *tile1_data_ = nullptr; + bool broadcast_ = false; +}; +} // namespace mindspore::kernel + +#endif // MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_INT8_SUB_INT8_H_ diff --git a/mindspore/lite/src/runtime/kernel/arm/nnacl/add_int8.h b/mindspore/lite/src/runtime/kernel/arm/nnacl/add_int8.h index 28ba83720a..42549a118a 100644 --- a/mindspore/lite/src/runtime/kernel/arm/nnacl/add_int8.h +++ b/mindspore/lite/src/runtime/kernel/arm/nnacl/add_int8.h @@ -48,6 +48,8 @@ void AddInt8(int8_t *input0_data, int8_t *input1_data, int8_t *output_data, int6 #ifdef ENABLE_NEON #include int16x8_t LoadAndAddOffset(int8_t *data, int index, int offset); +int32x4_t ClacScaledInput(int32x4_t input, int32x4_t left_shift_result_vec, int32x4_t input_multiplier_vec, + int32x4_t right_shift_vec); #endif #endif // MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_NNACL_ADD_INT8_H_ diff --git a/mindspore/lite/src/runtime/kernel/arm/nnacl/int8/sub_int8.cc b/mindspore/lite/src/runtime/kernel/arm/nnacl/int8/sub_int8.cc new file mode 100644 index 0000000000..4bfbb236bd --- /dev/null +++ b/mindspore/lite/src/runtime/kernel/arm/nnacl/int8/sub_int8.cc @@ -0,0 +1,104 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "nnacl/int8/sub_int8.h" +#ifdef ENABLE_NEON +#include +#include "nnacl/add_int8.h" +#endif +#include "nnacl/quantization/fixed_point.h" + +#ifdef ENABLE_NEON + +int16x4_t ClacSumHalfWord(int32x4_t scaled_input0, int32x4_t scaled_input1, int32x4_t left_shift_out_vec, + int32x4_t output_multiplier_vec, SubQuantArg *para) { + int32x4_t raw_data = vsubq_s32(scaled_input0, scaled_input1); + + raw_data = RoundingDivideByPOTInt32x4(vqrdmulhq_s32(vmulq_s32(raw_data, left_shift_out_vec), output_multiplier_vec), + para->right_shift_out_); + raw_data = vaddq_s32(raw_data, vdupq_n_s32(para->out_args_.zp_)); + raw_data = vmaxq_s32(raw_data, vdupq_n_s32(para->output_activation_min_)); + raw_data = vminq_s32(raw_data, vdupq_n_s32(para->output_activation_max_)); + return vqmovn_s32(raw_data); +} + +void SubInt8NEON(int8_t *input0_data, int8_t *input1_data, int8_t *output_data, int64_t real_dst_count, + SubQuantArg *para, int *index) { + int32x4_t left_shift_result0_vec = vdupq_n_s32(para->left_shift_result0_); + int32x4_t left_shift_result1_vec = vdupq_n_s32(para->left_shift_result1_); + int32x4_t input0_multiplier_vec = vdupq_n_s32(para->input0_multiplier_); + int32x4_t input1_multiplier_vec = vdupq_n_s32(para->input1_multiplier_); + int32x4_t output_multiplier_vec = vdupq_n_s32(para->output_multiplier_); + int32x4_t left_shift_out_vec = vdupq_n_s32((1 << para->left_shift_out_)); + int32x4_t right_shift0_vec = vdupq_n_s32(-para->right_shift0_); + int32x4_t right_shift1_vec = vdupq_n_s32(-para->right_shift1_); + + for (; (*index) <= real_dst_count - 8; (*index) += 8) { + int16x8_t input0_val = LoadAndAddOffset(input0_data, *index, para->in0_args_.zp_); + int16x8_t input1_val = LoadAndAddOffset(input1_data, *index, para->in1_args_.zp_); + + int32x4_t input0_low = vmovl_s16(vget_low_s16(input0_val)); + int32x4_t input0_high = vmovl_s16(vget_high_s16(input0_val)); + int32x4_t input1_low = vmovl_s16(vget_low_s16(input1_val)); + int32x4_t input1_high = vmovl_s16(vget_high_s16(input1_val)); + + int32x4_t scaled_input0_low = + ClacScaledInput(input0_low, left_shift_result0_vec, input0_multiplier_vec, right_shift0_vec); + int32x4_t scaled_input0_high = + ClacScaledInput(input0_high, left_shift_result0_vec, input0_multiplier_vec, right_shift0_vec); + int32x4_t scaled_input1_low = + ClacScaledInput(input1_low, left_shift_result1_vec, input1_multiplier_vec, right_shift1_vec); + int32x4_t scaled_input1_high = + ClacScaledInput(input1_high, left_shift_result1_vec, input1_multiplier_vec, right_shift1_vec); + + int16x4_t sum_low = + ClacSumHalfWord(scaled_input0_low, scaled_input1_low, left_shift_out_vec, output_multiplier_vec, para); + int16x4_t sum_high = + ClacSumHalfWord(scaled_input0_high, scaled_input1_high, left_shift_out_vec, output_multiplier_vec, para); + + int16x8_t res_s16 = vcombine_s16(sum_low, sum_high); + int8x8_t res_u8_n0 = vqmovn_s16(res_s16); + vst1_s8(output_data + *index, res_u8_n0); + } +} +#endif + +int SubInt8(int8_t *input0_data, int8_t *input1_data, int8_t *output_data, int64_t real_dst_count, SubQuantArg *para) { + int index = 0; +#ifdef ENABLE_NEON + SubInt8NEON(input0_data, input1_data, output_data, real_dst_count, para, &index); +#endif + for (; index < real_dst_count; ++index) { + const int32_t input0_val = para->in0_args_.zp_ + input0_data[index]; + const int32_t input1_val = para->in1_args_.zp_ + input1_data[index]; + const int32_t shifted_input0_val = input0_val * para->left_shift_result0_; + const int32_t shifted_input1_val = input1_val * para->left_shift_result1_; + const int32_t scaled_input0_val = RoundingDivideByPOT( + SaturatingRoundingDoublingHighMul(shifted_input0_val, para->input0_multiplier_), para->right_shift0_); + const int32_t scaled_input1_val = RoundingDivideByPOT( + SaturatingRoundingDoublingHighMul(shifted_input1_val, para->input1_multiplier_), para->right_shift1_); + + const int32_t raw_data = scaled_input0_val - scaled_input1_val; + const int32_t raw_output = + RoundingDivideByPOT(SaturatingRoundingDoublingHighMul(raw_data * (1 << (unsigned int)para->left_shift_out_), + para->output_multiplier_), + para->right_shift_out_) + + para->out_args_.zp_; + + output_data[index] = (int8_t)MSMAX(para->output_activation_min_, MSMIN(raw_output, para->output_activation_max_)); + } + return 0; +} diff --git a/mindspore/lite/src/runtime/kernel/arm/nnacl/int8/sub_int8.h b/mindspore/lite/src/runtime/kernel/arm/nnacl/int8/sub_int8.h new file mode 100644 index 0000000000..60d8544196 --- /dev/null +++ b/mindspore/lite/src/runtime/kernel/arm/nnacl/int8/sub_int8.h @@ -0,0 +1,25 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_NNACL_INT8_SUB_INT8_H_ +#define MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_NNACL_INT8_SUB_INT8_H_ + +#include "nnacl/op_base.h" +#include "nnacl/quantization/quantize.h" + +int SubInt8(int8_t *input0_data, int8_t *input1_data, int8_t *output_data, int64_t real_dst_count, SubQuantArg *para); + +#endif // MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_NNACL_INT8_SUB_INT8_H_ diff --git a/mindspore/lite/src/runtime/kernel/arm/nnacl/quantization/quantize.h b/mindspore/lite/src/runtime/kernel/arm/nnacl/quantization/quantize.h index 1473896a71..ada5292fc4 100644 --- a/mindspore/lite/src/runtime/kernel/arm/nnacl/quantization/quantize.h +++ b/mindspore/lite/src/runtime/kernel/arm/nnacl/quantization/quantize.h @@ -173,6 +173,26 @@ typedef struct QuantMulArg { int right_shift_; } QuantMulArg; +typedef struct SubQuantArg { + QuantArg in0_args_; + QuantArg in1_args_; + QuantArg out_args_; + int output_activation_min_; + int output_activation_max_; + int input0_multiplier_; + int input1_multiplier_; + int output_multiplier_; + int input0_shift_; + int input1_shift_; + int output_shift_; + int left_shift_result0_; + int left_shift_result1_; + int right_shift0_; + int right_shift1_; + int left_shift_out_; + int right_shift_out_; +} SubQuantArg; + void QuantizeMultiplier(double double_multiplier, int32_t *quantized_multiplier, int *shift); inline void QuantizeMultiplierSmallerThanOne(double double_multiplier, int32_t *quantized_multiplier, diff --git a/mindspore/lite/test/ut/src/runtime/kernel/arm/fp32/arithmetic_fp32_tests.cc b/mindspore/lite/test/ut/src/runtime/kernel/arm/fp32/arithmetic_fp32_tests.cc new file mode 100644 index 0000000000..514e64c1f6 --- /dev/null +++ b/mindspore/lite/test/ut/src/runtime/kernel/arm/fp32/arithmetic_fp32_tests.cc @@ -0,0 +1,1016 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include +#include +#include "utils/log_adapter.h" +#include "common/common_test.h" +#include "src/common/file_utils.h" +#include "mindspore/lite/src/runtime/kernel/arm/nnacl/fp32/arithmetic.h" +#include "mindspore/lite/src/kernel_registry.h" +#include "mindspore/lite/src/lite_kernel.h" + +namespace mindspore { + +class TestArithmeticTestFp32 : public mindspore::CommonTest { + public: + TestArithmeticTestFp32() {} +}; + +TEST_F(TestArithmeticTestFp32, AddTest) { + auto add_param = new ArithmeticParameter(); + add_param->ndim_ = 4; + add_param->in_shape0_[0] = 1; + add_param->in_shape0_[1] = 2; + add_param->in_shape0_[2] = 3; + add_param->in_shape0_[3] = 4; + add_param->in_shape1_[0] = 1; + add_param->in_shape1_[1] = 1; + add_param->in_shape1_[2] = 1; + add_param->in_shape1_[3] = 4; + add_param->out_shape_[0] = 1; + add_param->out_shape_[1] = 2; + add_param->out_shape_[2] = 3; + add_param->out_shape_[3] = 4; + + /* 1x2x3x4 NHWC */ + std::vector in = {12.216284, 3.3466918, 15.327419, 5.234958, 0.804376, 9.952188, 14.727955, -8.080715, + 13.71383, 8.055829, 6.5845337, -9.25232, -4.24519, 11.550042, 9.262012, 1.2780352, + 6.7263746, -3.9301445, 3.764492, -8.602078, -3.3558068, 13.619035, -2.6694393, 3.2008505}; + auto in_ptr = in.data(); + std::vector add = {0.9035316, 0.022212252, 0.3038014, 0.3478275}; + auto add_ptr = add.data(); + std::vector correct_out = {13.119816, 3.368904, 15.631221, 5.5827856, 1.7079077, 9.9744, + 15.031756, -7.7328877, 14.617362, 8.078041, 6.888335, -8.904492, + -3.3416586, 11.572254, 9.565813, 1.6258626, 7.629906, -3.9079323, + 4.0682936, -8.254251, -2.4522753, 13.641247, -2.365638, 3.548678}; + auto correct_out_ptr = correct_out.data(); + + int size = 1 * 2 * 3 * 4; + auto out = new float[size]; + + auto tile_data0 = new float[size]; + auto tile_data1 = new float[size]; + BroadcastAdd(in_ptr, add_ptr, tile_data0, tile_data1, out, size, add_param); + CompareOutputData(out, correct_out_ptr, size, 0.00001); + + delete[] out; + delete[] tile_data0; + delete[] tile_data1; + delete add_param; +} + +TEST_F(TestArithmeticTestFp32, MulTest) { + auto mul_param = new ArithmeticParameter(); + mul_param->ndim_ = 4; + mul_param->in_shape0_[0] = 1; + mul_param->in_shape0_[1] = 2; + mul_param->in_shape0_[2] = 3; + mul_param->in_shape0_[3] = 4; + mul_param->in_shape1_[0] = 1; + mul_param->in_shape1_[1] = 1; + mul_param->in_shape1_[2] = 1; + mul_param->in_shape1_[3] = 4; + mul_param->out_shape_[0] = 1; + mul_param->out_shape_[1] = 2; + mul_param->out_shape_[2] = 3; + mul_param->out_shape_[3] = 4; + + /* 1x2x3x4 NHWC */ + std::vector in = {12.216284, 3.3466918, 15.327419, 5.234958, 0.804376, 9.952188, 14.727955, -8.080715, + 13.71383, 8.055829, 6.5845337, -9.25232, -4.24519, 11.550042, 9.262012, 1.2780352, + 6.7263746, -3.9301445, 3.764492, -8.602078, -3.3558068, 13.619035, -2.6694393, 3.2008505}; + auto in_ptr = in.data(); + std::vector add = {0.16771512, 0.7336843, 0.6768286, 0.4453379}; + auto add_ptr = add.data(); + std::vector correct_out = {2.0488555, 2.4554152, 10.374036, 2.3313253, 0.13490601, 7.3017635, + 9.968302, -3.5986485, 2.3000166, 5.910435, 4.4566007, -4.120409, + -0.71198255, 8.474085, 6.2687945, 0.5691575, 1.1281147, -2.8834853, + 2.547916, -3.8308315, -0.56281954, 9.992072, -1.8067529, 1.42546}; + auto correct_out_ptr = correct_out.data(); + + int size = 1 * 2 * 3 * 4; + auto out = new float[size]; + + auto tile_data0 = new float[size]; + auto tile_data1 = new float[size]; + BroadcastMul(in_ptr, add_ptr, tile_data0, tile_data1, out, size, mul_param); + CompareOutputData(out, correct_out_ptr, size, 0.00001); + + delete[] out; + delete[] tile_data0; + delete[] tile_data1; + delete mul_param; +} + +TEST_F(TestArithmeticTestFp32, DivTest) { + auto div_param = new ArithmeticParameter(); + div_param->ndim_ = 4; + div_param->in_shape0_[0] = 1; + div_param->in_shape0_[1] = 2; + div_param->in_shape0_[2] = 3; + div_param->in_shape0_[3] = 4; + div_param->in_shape1_[0] = 1; + div_param->in_shape1_[1] = 1; + div_param->in_shape1_[2] = 1; + div_param->in_shape1_[3] = 4; + div_param->out_shape_[0] = 1; + div_param->out_shape_[1] = 2; + div_param->out_shape_[2] = 3; + div_param->out_shape_[3] = 4; + + /* 1x2x3x4 NHWC */ + std::vector in = {12.216284, 3.3466918, 15.327419, 5.234958, 0.804376, 9.952188, 14.727955, -8.080715, + 13.71383, 8.055829, 6.5845337, -9.25232, -4.24519, 11.550042, 9.262012, 1.2780352, + 6.7263746, -3.9301445, 3.764492, -8.602078, -3.3558068, 13.619035, -2.6694393, 3.2008505}; + auto in_ptr = in.data(); + std::vector add = {1.6771512, -7.336843, 0.6768286, 4.453379}; + auto add_ptr = add.data(); + std::vector correct_out = {7.28394912, -0.45614875, 22.64593872, 1.17550247, 0.47960852, -1.35646735, + 21.76024329, -1.8145132, 8.17685967, -1.09799665, 9.72850985, -2.07759546, + -2.53119099, -1.5742523, 13.68442764, 0.28698101, 4.01059523, 0.53567243, + 5.56195764, -1.93158453, -2.000897, -1.85625275, -3.94404034, 0.71874648}; + auto correct_out_ptr = correct_out.data(); + + int size = 1 * 1 * 3 * 4; + auto out = new float[size]; + + auto tile_data0 = new float[size]; + auto tile_data1 = new float[size]; + BroadcastDiv(in_ptr, add_ptr, tile_data0, tile_data1, out, size, div_param); + CompareOutputData(out, correct_out_ptr, size, 0.00001); + + delete[] out; + delete[] tile_data0; + delete[] tile_data1; + delete div_param; +} + +TEST_F(TestArithmeticTestFp32, FloorDivTest) { + auto fdiv_param = new ArithmeticParameter(); + fdiv_param->ndim_ = 4; + fdiv_param->in_shape0_[0] = 1; + fdiv_param->in_shape0_[1] = 1; + fdiv_param->in_shape0_[2] = 3; + fdiv_param->in_shape0_[3] = 4; + fdiv_param->in_shape1_[0] = 1; + fdiv_param->in_shape1_[1] = 1; + fdiv_param->in_shape1_[2] = 1; + fdiv_param->in_shape1_[3] = 4; + fdiv_param->out_shape_[0] = 1; + fdiv_param->out_shape_[1] = 1; + fdiv_param->out_shape_[2] = 3; + fdiv_param->out_shape_[3] = 4; + + /* 1x2x3x4 NHWC */ + std::vector in = {1.1, -1.1, 3.123, -5.432, 0.1234, -0.0312, 12.1, 21.1, 9.1, 9.0, -100, 0.1}; + auto in_ptr = in.data(); + std::vector add = {1, 3, 2, 0.3}; + auto add_ptr = add.data(); + std::vector correct_out = {1, -1, 1, -19, 0, -1, 6, 70, 9, 3, -50, 0}; + auto correct_out_ptr = correct_out.data(); + + int size = 1 * 1 * 3 * 4; + auto out = new float[size]; + + auto tile_data0 = new float[size]; + auto tile_data1 = new float[size]; + int ret = BroadcastFloorDiv(in_ptr, add_ptr, tile_data0, tile_data1, out, size, fdiv_param); + EXPECT_EQ(ret, 0); + CompareOutputData(out, correct_out_ptr, size, 0.00001); + + delete[] out; + delete[] tile_data0; + delete[] tile_data1; + delete fdiv_param; +} + +TEST_F(TestArithmeticTestFp32, FloorModTest) { + auto fmod_param = new ArithmeticParameter(); + fmod_param->ndim_ = 4; + fmod_param->in_shape0_[0] = 1; + fmod_param->in_shape0_[1] = 1; + fmod_param->in_shape0_[2] = 3; + fmod_param->in_shape0_[3] = 4; + fmod_param->in_shape1_[0] = 1; + fmod_param->in_shape1_[1] = 1; + fmod_param->in_shape1_[2] = 1; + fmod_param->in_shape1_[3] = 4; + fmod_param->out_shape_[0] = 1; + fmod_param->out_shape_[1] = 1; + fmod_param->out_shape_[2] = 3; + fmod_param->out_shape_[3] = 4; + + /* 1x2x3x4 NHWC */ + std::vector in = {1.1, -1.1, 3.123, -5.432, 0.1234, -0.0312, 12.1, 21.1, 9.1, 9.0, -100, 0.1}; + auto in_ptr = in.data(); + std::vector add = {1, 3, 2, 0.3}; + auto add_ptr = add.data(); + std::vector correct_out = {0.100000, 1.900000, 1.123000, 0.268000, 0.123400, 2.968800, + 0.100000, 0.100000, 0.100000, 0.000000, 0.000000, 0.100000}; + auto correct_out_ptr = correct_out.data(); + + int size = 1 * 1 * 3 * 4; + auto out = new float[size]; + + auto tile_data0 = new float[size]; + auto tile_data1 = new float[size]; + int ret = BroadcastFloorMod(in_ptr, add_ptr, tile_data0, tile_data1, out, size, fmod_param); + EXPECT_EQ(ret, 0); + CompareOutputData(out, correct_out_ptr, size, 0.00001); + + delete[] out; + delete[] tile_data0; + delete[] tile_data1; + delete fmod_param; +} + +TEST_F(TestArithmeticTestFp32, LogicalAndTest) { + auto logical_and_param = new ArithmeticParameter(); + logical_and_param->ndim_ = 4; + logical_and_param->in_shape0_[0] = 1; + logical_and_param->in_shape0_[1] = 2; + logical_and_param->in_shape0_[2] = 3; + logical_and_param->in_shape0_[3] = 4; + logical_and_param->in_shape1_[0] = 1; + logical_and_param->in_shape1_[1] = 1; + logical_and_param->in_shape1_[2] = 1; + logical_and_param->in_shape1_[3] = 4; + logical_and_param->out_shape_[0] = 1; + logical_and_param->out_shape_[1] = 2; + logical_and_param->out_shape_[2] = 3; + logical_and_param->out_shape_[3] = 4; + + /* 1x2x3x4 NHWC */ + std::vector in = {12.216284, 3.3466918, 15.327419, 5.234958, 0, 9.952188, 14.727955, -8.080715, + 13.71383, 8.055829, 6.5845337, -9.25232, -4.24519, 11.550042, 9.262012, 1.2780352, + 6.7263746, -3.9301445, 3.764492, -8.602078, -3.3558068, 13.619035, -2.6694393, 3.2008505}; + auto in_ptr = in.data(); + std::vector add = {1.6771512, -7.336843, 0, 4.453379}; + auto add_ptr = add.data(); + std::vector correct_out = {1, 1, 0, 1, 0, 1, 0, 1, 1, 1, 0, 1, 1, 1, 0, 1, 1, 1, 0, 1, 1, 1, 0, 1}; + auto correct_out_ptr = correct_out.data(); + int size = 1 * 2 * 3 * 4; + + auto out = new float[size]; + auto tile_data0 = new float[size]; + auto tile_data1 = new float[size]; + BroadcastLogicalAnd(in_ptr, add_ptr, tile_data0, tile_data1, out, size, logical_and_param); + CompareOutputData(out, correct_out_ptr, size, 0.00001); + + delete[] out; + delete[] tile_data0; + delete[] tile_data1; + delete logical_and_param; +} + +TEST_F(TestArithmeticTestFp32, LogicalOrTest) { + auto logical_or_param = new ArithmeticParameter(); + logical_or_param->ndim_ = 4; + logical_or_param->in_shape0_[0] = 1; + logical_or_param->in_shape0_[1] = 2; + logical_or_param->in_shape0_[2] = 3; + logical_or_param->in_shape0_[3] = 4; + logical_or_param->in_shape1_[0] = 1; + logical_or_param->in_shape1_[1] = 1; + logical_or_param->in_shape1_[2] = 1; + logical_or_param->in_shape1_[3] = 4; + logical_or_param->out_shape_[0] = 1; + logical_or_param->out_shape_[1] = 2; + logical_or_param->out_shape_[2] = 3; + logical_or_param->out_shape_[3] = 4; + + /* 1x2x3x4 NHWC */ + std::vector in = {12.216284, 3.3466918, 15.327419, 5.234958, 0.804376, 0, 14.727955, -8.080715, + 13.71383, 8.055829, 6.5845337, -9.25232, -4.24519, 11.550042, 9.262012, 1.2780352, + 6.7263746, -3.9301445, 3.764492, 0, -3.3558068, 13.619035, -2.6694393, 3.2008505}; + + auto in_ptr = in.data(); + std::vector add = {1.6771512, 0, 0.6768286, 0}; + auto add_ptr = add.data(); + std::vector correct_out = {1, 1, 1, 1, 1, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 1, 1, 1, 1}; + auto correct_out_ptr = correct_out.data(); + + int size = 1 * 2 * 3 * 4; + + auto out = new float[size]; + auto tile_data0 = new float[size]; + auto tile_data1 = new float[size]; + BroadcastLogicalOr(in_ptr, add_ptr, tile_data0, tile_data1, out, size, logical_or_param); + CompareOutputData(out, correct_out_ptr, size, 0.00001); + + delete[] out; + delete[] tile_data0; + delete[] tile_data1; + delete logical_or_param; +} + +TEST_F(TestArithmeticTestFp32, MaximumTest) { + auto maximum_param = new ArithmeticParameter(); + maximum_param->ndim_ = 4; + maximum_param->in_shape0_[0] = 1; + maximum_param->in_shape0_[1] = 2; + maximum_param->in_shape0_[2] = 3; + maximum_param->in_shape0_[3] = 4; + maximum_param->in_shape1_[0] = 1; + maximum_param->in_shape1_[1] = 1; + maximum_param->in_shape1_[2] = 1; + maximum_param->in_shape1_[3] = 4; + maximum_param->out_shape_[0] = 1; + maximum_param->out_shape_[1] = 2; + maximum_param->out_shape_[2] = 3; + maximum_param->out_shape_[3] = 4; + + /* 1x2x3x4 NHWC */ + std::vector in = {12.216284, 3.3466918, 15.327419, 5.234958, 0.804376, 0, 14.727955, -8.080715, + 13.71383, 8.055829, 6.5845337, -9.25232, -4.24519, 11.550042, 9.262012, 1.2780352, + 6.7263746, -3.9301445, 3.764492, 0, -3.3558068, 13.619035, -2.6694393, 3.2008505}; + + auto in_ptr = in.data(); + std::vector add = {1.6771512, 6.34876, 3.6768286, 2.936284}; + auto add_ptr = add.data(); + std::vector correct_out = {12.216284, 6.34876, 15.327419, 5.234958, 1.6771512, 6.34876, + 14.727955, 2.936284, 13.71383, 8.055829, 6.5845337, 2.936284, + 1.6771512, 11.550042, 9.262012, 2.936284, 6.7263746, 6.34876, + 3.764492, 2.93628, 1.6771512, 13.619035, 3.6768286, 3.2008505}; + auto correct_out_ptr = correct_out.data(); + + int size = 1 * 2 * 3 * 4; + + auto out = new float[size]; + auto tile_data0 = new float[size]; + auto tile_data1 = new float[size]; + BroadcastMaximum(in_ptr, add_ptr, tile_data0, tile_data1, out, size, maximum_param); + CompareOutputData(out, correct_out_ptr, size, 0.00001); + + delete[] out; + delete[] tile_data0; + delete[] tile_data1; + delete maximum_param; +} + +TEST_F(TestArithmeticTestFp32, MinimumTest) { + auto minimum_param = new ArithmeticParameter(); + minimum_param->ndim_ = 4; + minimum_param->in_shape0_[0] = 1; + minimum_param->in_shape0_[1] = 2; + minimum_param->in_shape0_[2] = 3; + minimum_param->in_shape0_[3] = 4; + minimum_param->in_shape1_[0] = 1; + minimum_param->in_shape1_[1] = 1; + minimum_param->in_shape1_[2] = 1; + minimum_param->in_shape1_[3] = 4; + minimum_param->out_shape_[0] = 1; + minimum_param->out_shape_[1] = 2; + minimum_param->out_shape_[2] = 3; + minimum_param->out_shape_[3] = 4; + + /* 1x2x3x4 NHWC */ + std::vector in = {12.216284, 3.3466918, 15.327419, 5.234958, 0.804376, 0, 14.727955, -8.080715, + 13.71383, 8.055829, 6.5845337, -9.25232, -4.24519, 11.550042, 9.262012, 1.2780352, + 6.7263746, -3.9301445, 3.764492, 0, -3.3558068, 13.619035, -2.6694393, 3.2008505}; + + auto in_ptr = in.data(); + std::vector add = {1.6771512, 6.34876, 3.6768286, 2.936284}; + auto add_ptr = add.data(); + std::vector correct_out = {1.6771512, 3.3466918, 3.6768286, 2.936284, 0.804376, 0, + 3.6768286, -8.080715, 1.6771512, 6.34876, 3.6768286, -9.25232, + -4.24519, 6.34876, 3.6768286, 1.2780352, 1.6771512, -3.9301445, + 3.6768286, 0, -3.3558068, 6.34876, -2.6694393, 2.936284}; + auto correct_out_ptr = correct_out.data(); + + int size = 1 * 2 * 3 * 4; + + auto out = new float[size]; + auto tile_data0 = new float[size]; + auto tile_data1 = new float[size]; + BroadcastMinimum(in_ptr, add_ptr, tile_data0, tile_data1, out, size, minimum_param); + CompareOutputData(out, correct_out_ptr, size, 0.00001); + + delete[] out; + delete[] tile_data0; + delete[] tile_data1; + delete minimum_param; +} + +TEST_F(TestArithmeticTestFp32, SquaredDifferenceTest) { + auto add_param = new ArithmeticParameter(); + add_param->ndim_ = 3; + add_param->in_shape0_[0] = 2; + add_param->in_shape0_[1] = 3; + add_param->in_shape0_[2] = 2; + add_param->in_shape1_[0] = 2; + add_param->in_shape1_[1] = 1; + add_param->in_shape1_[2] = 2; + add_param->out_shape_[0] = 2; + add_param->out_shape_[1] = 3; + add_param->out_shape_[2] = 2; + + /* 1x2x3x4 NHWC */ + std::vector in = {10, 11, 12, 13, 14, 15, 20, 21, 22, 23, 24, 25}; + auto in_ptr = in.data(); + std::vector add = {30, 31, 32, 33}; + auto add_ptr = add.data(); + std::vector correct_out = {400, 400, 324, 324, 256, 256, 144, 144, 100, 100, 64, 64}; + auto correct_out_ptr = correct_out.data(); + + int size = 2 * 3 * 2; + auto out = new float[size]; + + auto tile_data0 = new float[size]; + auto tile_data1 = new float[size]; + BroadcastSub(in_ptr, add_ptr, tile_data0, tile_data1, out, size, add_param); + ElementMul(out, out, out, size); + CompareOutputData(out, correct_out_ptr, size, 0.00001); + + delete[] out; + delete[] tile_data0; + delete[] tile_data1; + delete add_param; +} + +TEST_F(TestArithmeticTestFp32, MulFp32) { + std::vector inputs_tensor; + std::vector outputs_tensor; + + ArithmeticParameter mul_param; + mul_param.broadcasting_ = true; + mul_param.op_parameter_.type_ = schema::PrimitiveType_Mul; + mul_param.ndim_ = 4; + mul_param.in_shape0_[0] = 1; + mul_param.in_shape0_[1] = 2; + mul_param.in_shape0_[2] = 3; + mul_param.in_shape0_[3] = 4; + mul_param.in_shape1_[0] = 1; + mul_param.in_shape1_[1] = 1; + mul_param.in_shape1_[2] = 1; + mul_param.in_shape1_[3] = 4; + mul_param.out_shape_[0] = 1; + mul_param.out_shape_[1] = 2; + mul_param.out_shape_[2] = 3; + mul_param.out_shape_[3] = 4; + + /* 1x2x3x4 NHWC */ + std::vector input0 = {12.216284, 3.3466918, 15.327419, 5.234958, 0.804376, 9.952188, + 14.727955, -8.080715, 13.71383, 8.055829, 6.5845337, -9.25232, + -4.24519, 11.550042, 9.262012, 1.2780352, 6.7263746, -3.9301445, + 3.764492, -8.602078, -3.3558068, 13.619035, -2.6694393, 3.2008505}; + std::vector input0_shape = {1, 2, 3, 4}; + std::vector input1 = {0.16771512, 0.7336843, 0.6768286, 0.4453379}; + std::vector input1_shape = {1, 1, 1, 4}; + + lite::tensor::Tensor input0_tensor; + lite::tensor::Tensor input1_tensor; + input0_tensor.set_data_type(kNumberTypeFloat32); + input0_tensor.SetData(input0.data()); + input1_tensor.SetData(input1.data()); + input0_tensor.set_shape(input0_shape); + input1_tensor.set_shape(input1_shape); + inputs_tensor.push_back(&input0_tensor); + inputs_tensor.push_back(&input1_tensor); + + std::vector output(24); + std::vector output_shape = {1, 2, 3, 4}; + + lite::tensor::Tensor output0_tensor; + outputs_tensor.push_back(&output0_tensor); + output0_tensor.SetData(output.data()); + output0_tensor.set_shape(output_shape); + + kernel::KernelKey desc = {kernel::KERNEL_ARCH::kCPU, kNumberTypeFloat32, schema::PrimitiveType_Eltwise}; + auto creator = lite::KernelRegistry::GetInstance()->GetCreator(desc); + ASSERT_NE(creator, nullptr); + lite::Context ctx; + ctx.thread_num_ = 3; + kernel::LiteKernel *kernel = + creator(inputs_tensor, outputs_tensor, reinterpret_cast(&mul_param), &ctx, desc, nullptr); + ASSERT_NE(kernel, nullptr); + auto output_tensor_shape = output0_tensor.shape(); + kernel->Run(); + + std::vector correct_out = {2.0488555, 2.4554152, 10.374036, 2.3313253, 0.13490601, 7.3017635, + 9.968302, -3.5986485, 2.3000166, 5.910435, 4.4566007, -4.120409, + -0.71198255, 8.474085, 6.2687945, 0.5691575, 1.1281147, -2.8834853, + 2.547916, -3.8308315, -0.56281954, 9.992072, -1.8067529, 1.42546}; + auto correct_out_ptr = correct_out.data(); + + CompareOutputData(output.data(), correct_out_ptr, 24, 0.00001); + + input0_tensor.SetData(nullptr); + input1_tensor.SetData(nullptr); + output0_tensor.SetData(nullptr); +} + +TEST_F(TestArithmeticTestFp32, MulReluFp32) { + std::vector inputs_tensor; + std::vector outputs_tensor; + + ArithmeticParameter mul_param; + mul_param.broadcasting_ = true; + mul_param.op_parameter_.type_ = schema::PrimitiveType_Mul; + mul_param.ndim_ = 4; + mul_param.activation_type_ = schema::ActivationType_RELU; + mul_param.in_shape0_[0] = 1; + mul_param.in_shape0_[1] = 2; + mul_param.in_shape0_[2] = 3; + mul_param.in_shape0_[3] = 4; + mul_param.in_shape1_[0] = 1; + mul_param.in_shape1_[1] = 1; + mul_param.in_shape1_[2] = 1; + mul_param.in_shape1_[3] = 4; + mul_param.out_shape_[0] = 1; + mul_param.out_shape_[1] = 2; + mul_param.out_shape_[2] = 3; + mul_param.out_shape_[3] = 4; + + /* 1x2x3x4 NHWC */ + std::vector input0 = {12.216284, 3.3466918, 15.327419, 5.234958, 0.804376, 9.952188, + 14.727955, -8.080715, 13.71383, 8.055829, 6.5845337, -9.25232, + -4.24519, 11.550042, 9.262012, 1.2780352, 6.7263746, -3.9301445, + 3.764492, -8.602078, -3.3558068, 13.619035, -2.6694393, 3.2008505}; + std::vector input0_shape = {1, 2, 3, 4}; + std::vector input1 = {0.16771512, 0.7336843, 0.6768286, 0.4453379}; + std::vector input1_shape = {1, 1, 1, 4}; + + lite::tensor::Tensor input0_tensor; + lite::tensor::Tensor input1_tensor; + input0_tensor.set_data_type(kNumberTypeFloat32); + input0_tensor.SetData(input0.data()); + input1_tensor.SetData(input1.data()); + input0_tensor.set_shape(input0_shape); + input1_tensor.set_shape(input1_shape); + inputs_tensor.push_back(&input0_tensor); + inputs_tensor.push_back(&input1_tensor); + + std::vector output(24); + std::vector output_shape = {1, 2, 3, 4}; + + lite::tensor::Tensor output0_tensor; + outputs_tensor.push_back(&output0_tensor); + output0_tensor.SetData(output.data()); + output0_tensor.set_shape(output_shape); + + kernel::KernelKey desc = {kernel::KERNEL_ARCH::kCPU, kNumberTypeFloat32, schema::PrimitiveType_Eltwise}; + auto creator = lite::KernelRegistry::GetInstance()->GetCreator(desc); + ASSERT_NE(creator, nullptr); + lite::Context ctx; + ctx.thread_num_ = 3; + kernel::LiteKernel *kernel = + creator(inputs_tensor, outputs_tensor, reinterpret_cast(&mul_param), &ctx, desc, nullptr); + ASSERT_NE(kernel, nullptr); + auto output_tensor_shape = output0_tensor.shape(); + kernel->Run(); + + std::vector correct_out = {2.0488555, 2.4554152, 10.374036, 2.3313253, 0.13490601, 7.3017635, + 9.968302, 0, 2.3000166, 5.910435, 4.4566007, 0, + 0, 8.474085, 6.2687945, 0.5691575, 1.1281147, 0, + 2.547916, 0, 0, 9.992072, 0, 1.42546}; + auto correct_out_ptr = correct_out.data(); + + CompareOutputData(output.data(), correct_out_ptr, 24, 0.00001); + + input0_tensor.SetData(nullptr); + input1_tensor.SetData(nullptr); + output0_tensor.SetData(nullptr); +} + +TEST_F(TestArithmeticTestFp32, MulRelu6Fp32) { + std::vector inputs_tensor; + std::vector outputs_tensor; + + ArithmeticParameter mul_param; + mul_param.broadcasting_ = true; + mul_param.op_parameter_.type_ = schema::PrimitiveType_Mul; + mul_param.ndim_ = 4; + mul_param.activation_type_ = schema::ActivationType_RELU6; + mul_param.in_shape0_[0] = 1; + mul_param.in_shape0_[1] = 2; + mul_param.in_shape0_[2] = 3; + mul_param.in_shape0_[3] = 4; + mul_param.in_shape1_[0] = 1; + mul_param.in_shape1_[1] = 1; + mul_param.in_shape1_[2] = 1; + mul_param.in_shape1_[3] = 4; + mul_param.out_shape_[0] = 1; + mul_param.out_shape_[1] = 2; + mul_param.out_shape_[2] = 3; + mul_param.out_shape_[3] = 4; + + /* 1x2x3x4 NHWC */ + std::vector input0 = {12.216284, 3.3466918, 15.327419, 5.234958, 0.804376, 9.952188, + 14.727955, -8.080715, 13.71383, 8.055829, 6.5845337, -9.25232, + -4.24519, 11.550042, 9.262012, 1.2780352, 6.7263746, -3.9301445, + 3.764492, -8.602078, -3.3558068, 13.619035, -2.6694393, 3.2008505}; + std::vector input0_shape = {1, 2, 3, 4}; + std::vector input1 = {0.16771512, 0.7336843, 0.6768286, 0.4453379}; + std::vector input1_shape = {1, 1, 1, 4}; + + lite::tensor::Tensor input0_tensor; + lite::tensor::Tensor input1_tensor; + input0_tensor.set_data_type(kNumberTypeFloat32); + input0_tensor.SetData(input0.data()); + input1_tensor.SetData(input1.data()); + input0_tensor.set_shape(input0_shape); + input1_tensor.set_shape(input1_shape); + inputs_tensor.push_back(&input0_tensor); + inputs_tensor.push_back(&input1_tensor); + + std::vector output(24); + std::vector output_shape = {1, 2, 3, 4}; + + lite::tensor::Tensor output0_tensor; + outputs_tensor.push_back(&output0_tensor); + output0_tensor.SetData(output.data()); + output0_tensor.set_shape(output_shape); + + kernel::KernelKey desc = {kernel::KERNEL_ARCH::kCPU, kNumberTypeFloat32, schema::PrimitiveType_Eltwise}; + auto creator = lite::KernelRegistry::GetInstance()->GetCreator(desc); + ASSERT_NE(creator, nullptr); + lite::Context ctx; + ctx.thread_num_ = 3; + kernel::LiteKernel *kernel = + creator(inputs_tensor, outputs_tensor, reinterpret_cast(&mul_param), &ctx, desc, nullptr); + ASSERT_NE(kernel, nullptr); + auto output_tensor_shape = output0_tensor.shape(); + kernel->Run(); + + std::vector correct_out = {2.0488555, 2.4554152, 6, 2.3313253, 0.13490601, 6, 6, 0, + 2.3000166, 5.910435, 4.4566007, 0, 0, 6, 6, 0.5691575, + 1.1281147, 0, 2.547916, 0, 0, 6, 0, 1.42546}; + auto correct_out_ptr = correct_out.data(); + + CompareOutputData(output.data(), correct_out_ptr, 24, 0.00001); + + input0_tensor.SetData(nullptr); + input1_tensor.SetData(nullptr); + output0_tensor.SetData(nullptr); +} + +TEST_F(TestArithmeticTestFp32, AddReluFp32) { + std::vector inputs_tensor; + std::vector outputs_tensor; + + ArithmeticParameter add_param; + add_param.broadcasting_ = true; + add_param.op_parameter_.type_ = schema::PrimitiveType_Add; + add_param.ndim_ = 4; + add_param.activation_type_ = schema::ActivationType_RELU; + add_param.in_shape0_[0] = 1; + add_param.in_shape0_[1] = 2; + add_param.in_shape0_[2] = 3; + add_param.in_shape0_[3] = 4; + add_param.in_shape1_[0] = 1; + add_param.in_shape1_[1] = 1; + add_param.in_shape1_[2] = 1; + add_param.in_shape1_[3] = 4; + add_param.out_shape_[0] = 1; + add_param.out_shape_[1] = 2; + add_param.out_shape_[2] = 3; + add_param.out_shape_[3] = 4; + + /* 1x2x3x4 NHWC */ + std::vector input0 = {12.216284, 3.3466918, 15.327419, 5.234958, 0.804376, 9.952188, + 14.727955, -8.080715, 13.71383, 8.055829, 6.5845337, -9.25232, + -4.24519, 11.550042, 9.262012, 1.2780352, 6.7263746, -3.9301445, + 3.764492, -8.602078, -3.3558068, 13.619035, -2.6694393, 3.2008505}; + std::vector input0_shape = {1, 2, 3, 4}; + std::vector input1 = {0.9035316, 0.022212252, 0.3038014, 0.3478275}; + std::vector input1_shape = {1, 1, 1, 4}; + + lite::tensor::Tensor input0_tensor; + lite::tensor::Tensor input1_tensor; + input0_tensor.set_data_type(kNumberTypeFloat32); + input0_tensor.SetData(input0.data()); + input1_tensor.SetData(input1.data()); + input0_tensor.set_shape(input0_shape); + input1_tensor.set_shape(input1_shape); + inputs_tensor.push_back(&input0_tensor); + inputs_tensor.push_back(&input1_tensor); + + std::vector output(24); + std::vector output_shape = {1, 2, 3, 4}; + + lite::tensor::Tensor output0_tensor; + outputs_tensor.push_back(&output0_tensor); + output0_tensor.SetData(output.data()); + output0_tensor.set_shape(output_shape); + + kernel::KernelKey desc = {kernel::KERNEL_ARCH::kCPU, kNumberTypeFloat32, schema::PrimitiveType_Eltwise}; + auto creator = lite::KernelRegistry::GetInstance()->GetCreator(desc); + ASSERT_NE(creator, nullptr); + lite::Context ctx; + ctx.thread_num_ = 3; + kernel::LiteKernel *kernel = + creator(inputs_tensor, outputs_tensor, reinterpret_cast(&add_param), &ctx, desc, nullptr); + ASSERT_NE(kernel, nullptr); + auto output_tensor_shape = output0_tensor.shape(); + kernel->Run(); + + std::vector correct_out = { + 13.119816, 3.368904, 15.631221, 5.5827856, 1.7079077, 9.9744, 15.031756, 0, 14.617362, 8.078041, 6.888335, 0, 0, + 11.572254, 9.565813, 1.6258626, 7.629906, 0, 4.0682936, 0, 0, 13.641247, 0, 3.548678}; + auto correct_out_ptr = correct_out.data(); + + CompareOutputData(output.data(), correct_out_ptr, 24, 0.00001); + + input0_tensor.SetData(nullptr); + input1_tensor.SetData(nullptr); + output0_tensor.SetData(nullptr); +} + +TEST_F(TestArithmeticTestFp32, AddRelu6Fp32) { + std::vector inputs_tensor; + std::vector outputs_tensor; + + ArithmeticParameter add_param; + add_param.broadcasting_ = true; + add_param.op_parameter_.type_ = schema::PrimitiveType_Add; + add_param.ndim_ = 4; + add_param.activation_type_ = schema::ActivationType_RELU6; + add_param.in_shape0_[0] = 1; + add_param.in_shape0_[1] = 2; + add_param.in_shape0_[2] = 3; + add_param.in_shape0_[3] = 4; + add_param.in_shape1_[0] = 1; + add_param.in_shape1_[1] = 1; + add_param.in_shape1_[2] = 1; + add_param.in_shape1_[3] = 4; + add_param.out_shape_[0] = 1; + add_param.out_shape_[1] = 2; + add_param.out_shape_[2] = 3; + add_param.out_shape_[3] = 4; + + /* 1x2x3x4 NHWC */ + std::vector input0 = {12.216284, 3.3466918, 15.327419, 5.234958, 0.804376, 9.952188, + 14.727955, -8.080715, 13.71383, 8.055829, 6.5845337, -9.25232, + -4.24519, 11.550042, 9.262012, 1.2780352, 6.7263746, -3.9301445, + 3.764492, -8.602078, -3.3558068, 13.619035, -2.6694393, 3.2008505}; + std::vector input0_shape = {1, 2, 3, 4}; + std::vector input1 = {0.9035316, 0.022212252, 0.3038014, 0.3478275}; + std::vector input1_shape = {1, 1, 1, 4}; + + lite::tensor::Tensor input0_tensor; + lite::tensor::Tensor input1_tensor; + input0_tensor.set_data_type(kNumberTypeFloat32); + input0_tensor.SetData(input0.data()); + input1_tensor.SetData(input1.data()); + input0_tensor.set_shape(input0_shape); + input1_tensor.set_shape(input1_shape); + inputs_tensor.push_back(&input0_tensor); + inputs_tensor.push_back(&input1_tensor); + + std::vector output(24); + std::vector output_shape = {1, 2, 3, 4}; + + lite::tensor::Tensor output0_tensor; + outputs_tensor.push_back(&output0_tensor); + output0_tensor.SetData(output.data()); + output0_tensor.set_shape(output_shape); + + kernel::KernelKey desc = {kernel::KERNEL_ARCH::kCPU, kNumberTypeFloat32, schema::PrimitiveType_Eltwise}; + auto creator = lite::KernelRegistry::GetInstance()->GetCreator(desc); + ASSERT_NE(creator, nullptr); + lite::Context ctx; + ctx.thread_num_ = 3; + kernel::LiteKernel *kernel = + creator(inputs_tensor, outputs_tensor, reinterpret_cast(&add_param), &ctx, desc, nullptr); + ASSERT_NE(kernel, nullptr); + auto output_tensor_shape = output0_tensor.shape(); + kernel->Run(); + + std::vector correct_out = {6, 3.368904, 6, 5.5827856, 1.7079077, 6, 6, 0, 6, 6, 6, 0, + 0, 6, 6, 1.6258626, 6, 0, 4.0682936, 0, 0, 6, 0, 3.548678}; + auto correct_out_ptr = correct_out.data(); + + CompareOutputData(output.data(), correct_out_ptr, 24, 0.00001); + + input0_tensor.SetData(nullptr); + input1_tensor.SetData(nullptr); + output0_tensor.SetData(nullptr); +} + +TEST_F(TestArithmeticTestFp32, DivReluFp32) { + std::vector inputs_tensor; + std::vector outputs_tensor; + + ArithmeticParameter div_param; + div_param.broadcasting_ = true; + div_param.op_parameter_.type_ = schema::PrimitiveType_Div; + div_param.ndim_ = 4; + div_param.activation_type_ = schema::ActivationType_RELU; + div_param.in_shape0_[0] = 1; + div_param.in_shape0_[1] = 2; + div_param.in_shape0_[2] = 3; + div_param.in_shape0_[3] = 4; + div_param.in_shape1_[0] = 1; + div_param.in_shape1_[1] = 1; + div_param.in_shape1_[2] = 1; + div_param.in_shape1_[3] = 4; + div_param.out_shape_[0] = 1; + div_param.out_shape_[1] = 2; + div_param.out_shape_[2] = 3; + div_param.out_shape_[3] = 4; + + /* 1x2x3x4 NHWC */ + std::vector input0 = {12.216284, 3.3466918, 15.327419, 5.234958, 0.804376, 9.952188, + 14.727955, -8.080715, 13.71383, 8.055829, 6.5845337, -9.25232, + -4.24519, 11.550042, 9.262012, 1.2780352, 6.7263746, -3.9301445, + 3.764492, -8.602078, -3.3558068, 13.619035, -2.6694393, 3.2008505}; + std::vector input0_shape = {1, 2, 3, 4}; + std::vector input1 = {1.6771512, -7.336843, 0.6768286, 4.453379}; + std::vector input1_shape = {1, 1, 1, 4}; + + lite::tensor::Tensor input0_tensor; + lite::tensor::Tensor input1_tensor; + input0_tensor.set_data_type(kNumberTypeFloat32); + input0_tensor.SetData(input0.data()); + input1_tensor.SetData(input1.data()); + input0_tensor.set_shape(input0_shape); + input1_tensor.set_shape(input1_shape); + inputs_tensor.push_back(&input0_tensor); + inputs_tensor.push_back(&input1_tensor); + + std::vector output(24); + std::vector output_shape = {1, 2, 3, 4}; + + lite::tensor::Tensor output0_tensor; + outputs_tensor.push_back(&output0_tensor); + output0_tensor.SetData(output.data()); + output0_tensor.set_shape(output_shape); + + kernel::KernelKey desc = {kernel::KERNEL_ARCH::kCPU, kNumberTypeFloat32, schema::PrimitiveType_Eltwise}; + auto creator = lite::KernelRegistry::GetInstance()->GetCreator(desc); + ASSERT_NE(creator, nullptr); + lite::Context ctx; + ctx.thread_num_ = 3; + kernel::LiteKernel *kernel = + creator(inputs_tensor, outputs_tensor, reinterpret_cast(&div_param), &ctx, desc, nullptr); + ASSERT_NE(kernel, nullptr); + auto output_tensor_shape = output0_tensor.shape(); + kernel->Run(); + + std::vector correct_out = {7.28394912, 0, 22.64593872, 1.17550247, 0.47960852, 0, + 21.76024329, 0, 8.17685967, 0, 9.72850985, 0, + 0, 0, 13.68442764, 0.28698101, 4.01059523, 0.53567243, + 5.56195764, 0, 0, 0, 0, 0.71874648}; + auto correct_out_ptr = correct_out.data(); + + CompareOutputData(output.data(), correct_out_ptr, 24, 0.00001); + + input0_tensor.SetData(nullptr); + input1_tensor.SetData(nullptr); + output0_tensor.SetData(nullptr); +} + +TEST_F(TestArithmeticTestFp32, DivRelu6Fp32) { + std::vector inputs_tensor; + std::vector outputs_tensor; + + ArithmeticParameter div_param; + div_param.broadcasting_ = true; + div_param.op_parameter_.type_ = schema::PrimitiveType_Div; + div_param.ndim_ = 4; + div_param.activation_type_ = schema::ActivationType_RELU6; + div_param.in_shape0_[0] = 1; + div_param.in_shape0_[1] = 2; + div_param.in_shape0_[2] = 3; + div_param.in_shape0_[3] = 4; + div_param.in_shape1_[0] = 1; + div_param.in_shape1_[1] = 1; + div_param.in_shape1_[2] = 1; + div_param.in_shape1_[3] = 4; + div_param.out_shape_[0] = 1; + div_param.out_shape_[1] = 2; + div_param.out_shape_[2] = 3; + div_param.out_shape_[3] = 4; + + /* 1x2x3x4 NHWC */ + std::vector input0 = {12.216284, 3.3466918, 15.327419, 5.234958, 0.804376, 9.952188, + 14.727955, -8.080715, 13.71383, 8.055829, 6.5845337, -9.25232, + -4.24519, 11.550042, 9.262012, 1.2780352, 6.7263746, -3.9301445, + 3.764492, -8.602078, -3.3558068, 13.619035, -2.6694393, 3.2008505}; + std::vector input0_shape = {1, 2, 3, 4}; + std::vector input1 = {1.6771512, -7.336843, 0.6768286, 4.453379}; + std::vector input1_shape = {1, 1, 1, 4}; + + lite::tensor::Tensor input0_tensor; + lite::tensor::Tensor input1_tensor; + input0_tensor.set_data_type(kNumberTypeFloat32); + input0_tensor.SetData(input0.data()); + input1_tensor.SetData(input1.data()); + input0_tensor.set_shape(input0_shape); + input1_tensor.set_shape(input1_shape); + inputs_tensor.push_back(&input0_tensor); + inputs_tensor.push_back(&input1_tensor); + + std::vector output(24); + std::vector output_shape = {1, 2, 3, 4}; + + lite::tensor::Tensor output0_tensor; + outputs_tensor.push_back(&output0_tensor); + output0_tensor.SetData(output.data()); + output0_tensor.set_shape(output_shape); + + kernel::KernelKey desc = {kernel::KERNEL_ARCH::kCPU, kNumberTypeFloat32, schema::PrimitiveType_Eltwise}; + auto creator = lite::KernelRegistry::GetInstance()->GetCreator(desc); + ASSERT_NE(creator, nullptr); + lite::Context ctx; + ctx.thread_num_ = 3; + kernel::LiteKernel *kernel = + creator(inputs_tensor, outputs_tensor, reinterpret_cast(&div_param), &ctx, desc, nullptr); + ASSERT_NE(kernel, nullptr); + auto output_tensor_shape = output0_tensor.shape(); + kernel->Run(); + + std::vector correct_out = {6, 0, 6, 1.17550247, 0.47960852, 0, 6, 0, 6, 0, 6, 0, + 0, 0, 6, 0.28698101, 4.01059523, 0.53567243, 5.56195764, 0, 0, 0, 0, 0.71874648}; + auto correct_out_ptr = correct_out.data(); + + CompareOutputData(output.data(), correct_out_ptr, 24, 0.00001); + + input0_tensor.SetData(nullptr); + input1_tensor.SetData(nullptr); + output0_tensor.SetData(nullptr); +} + +TEST_F(TestArithmeticTestFp32, EqualFp32) { + std::vector inputs_tensor; + std::vector outputs_tensor; + + ArithmeticParameter equal_param; + equal_param.broadcasting_ = true; + equal_param.op_parameter_.type_ = schema::PrimitiveType_Equal; + equal_param.ndim_ = 4; + equal_param.in_shape0_[0] = 1; + equal_param.in_shape0_[1] = 2; + equal_param.in_shape0_[2] = 3; + equal_param.in_shape0_[3] = 4; + equal_param.in_shape1_[0] = 1; + equal_param.in_shape1_[1] = 1; + equal_param.in_shape1_[2] = 1; + equal_param.in_shape1_[3] = 4; + equal_param.out_shape_[0] = 1; + equal_param.out_shape_[1] = 2; + equal_param.out_shape_[2] = 3; + equal_param.out_shape_[3] = 4; + + /* 1x2x3x4 NHWC */ + std::vector input0 = {12.216284, 3.3466918, 15.327419, 5.234958, 0.804376, 9.952188, + 14.727955, -8.080715, 13.71383, 8.055829, 6.5845337, -9.25232, + -4.24519, 11.550042, 9.262012, 1.2780352, 6.7263746, -3.9301445, + 3.764492, -8.602078, -3.3558068, 13.619035, -2.6694393, 3.2008505}; + std::vector input0_shape = {1, 2, 3, 4}; + std::vector input1 = {0.16771512, 3.3466918, 0.6768286, 3.2008505}; + std::vector input1_shape = {1, 1, 1, 4}; + + lite::tensor::Tensor input0_tensor; + lite::tensor::Tensor input1_tensor; + input0_tensor.set_data_type(kNumberTypeFloat32); + input0_tensor.SetData(input0.data()); + input1_tensor.SetData(input1.data()); + input0_tensor.set_shape(input0_shape); + input1_tensor.set_shape(input1_shape); + inputs_tensor.push_back(&input0_tensor); + inputs_tensor.push_back(&input1_tensor); + + std::vector output(24); + std::vector output_shape = {1, 2, 3, 4}; + + lite::tensor::Tensor output0_tensor; + outputs_tensor.push_back(&output0_tensor); + output0_tensor.SetData(output.data()); + output0_tensor.set_shape(output_shape); + + kernel::KernelKey desc = {kernel::KERNEL_ARCH::kCPU, kNumberTypeFloat32, schema::PrimitiveType_Eltwise}; + auto creator = lite::KernelRegistry::GetInstance()->GetCreator(desc); + ASSERT_NE(creator, nullptr); + lite::Context ctx; + ctx.thread_num_ = 3; + kernel::LiteKernel *kernel = + creator(inputs_tensor, outputs_tensor, reinterpret_cast(&equal_param), &ctx, desc, nullptr); + ASSERT_NE(kernel, nullptr); + auto output_tensor_shape = output0_tensor.shape(); + kernel->Run(); + + std::vector correct_out = {0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1}; + auto correct_out_ptr = correct_out.data(); + + CompareOutputData(output.data(), correct_out_ptr, 24, 0.00001); + + input0_tensor.SetData(nullptr); + input1_tensor.SetData(nullptr); + output0_tensor.SetData(nullptr); +} +} // namespace mindspore diff --git a/mindspore/lite/test/ut/src/runtime/kernel/arm/int8/sub_int_tests.cc b/mindspore/lite/test/ut/src/runtime/kernel/arm/int8/sub_int_tests.cc new file mode 100644 index 0000000000..28966e03b9 --- /dev/null +++ b/mindspore/lite/test/ut/src/runtime/kernel/arm/int8/sub_int_tests.cc @@ -0,0 +1,74 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include +#include +#include "common/common_test.h" +#include "mindspore/lite/src/runtime/kernel/arm/int8/sub_int8.h" +#include "mindspore/lite/src/kernel_registry.h" +#include "mindspore/lite/include/context.h" + +namespace mindspore { +class TestSubInt8 : public mindspore::CommonTest { + public: + TestSubInt8() {} +}; + +TEST_F(TestSubInt8, SubInt8) { + lite::tensor::Tensor in_tensor0(kNumberTypeInt8, {1, 1, 2, 5}); + lite::tensor::Tensor in_tensor1(kNumberTypeInt8, {1, 1, 1, 5}); + lite::tensor::Tensor out_tensor(kNumberTypeInt8, {1, 1, 2, 5}); + + int8_t input_data0[] = {105, 35, -27, 0, -63, 99, 16, 122, 67, -49}; + int8_t input_data1[] = {24, -38, -115, 106, -98}; + int8_t output_data[10] = {0}; + in_tensor0.SetData(input_data0); + in_tensor1.SetData(input_data1); + out_tensor.SetData(output_data); + + const lite::tensor::QuantArg quant_in0 = {0.00784314f, 0}; // -1.0--1.0 -> 0--255 + const lite::tensor::QuantArg quant_in1 = {0.00784314f, 0}; + const lite::tensor::QuantArg quant_out = {0.00784314f, 0}; + in_tensor0.AddQuantParam(quant_in0); + in_tensor1.AddQuantParam(quant_in1); + out_tensor.AddQuantParam(quant_out); + + std::vector inputs = {&in_tensor0, &in_tensor1}; + std::vector outputs = {&out_tensor}; + + OpParameter parameter = {}; + kernel::KernelKey desc = {kernel::KERNEL_ARCH::kCPU, kNumberTypeInt8, schema::PrimitiveType_Sub}; + + auto creator = lite::KernelRegistry::GetInstance()->GetCreator(desc); + ASSERT_NE(creator, nullptr); + + auto ctx = std::make_shared(); + auto kernel = creator(inputs, outputs, reinterpret_cast(¶meter), ctx.get(), desc, nullptr); + ASSERT_NE(kernel, nullptr); + + auto ret = kernel->Run(); + EXPECT_EQ(0, ret); + + int8_t expect0[10] = {81, 73, 88, -106, 35, 75, 54, 127, -39, 49}; + for (int i = 0; i < 10; ++i) { + EXPECT_EQ(output_data[i], expect0[i]); + } + + in_tensor0.SetData(nullptr); + in_tensor1.SetData(nullptr); + out_tensor.SetData(nullptr); +} +} // namespace mindspore