!4216 Add arm op Sub supporting int8 and testcases
Merge pull request !4216 from wangminggui/masterpull/4216/MERGE
commit
c0215d4445
@ -0,0 +1,205 @@
|
||||
/**
|
||||
* Copyright 2020 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#include "src/runtime/kernel/arm/int8/sub_int8.h"
|
||||
#include <limits>
|
||||
#include <algorithm>
|
||||
#include "src/runtime/kernel/arm/nnacl/arithmetic_common.h"
|
||||
#include "src/runtime/kernel/arm/nnacl/quantization/quantize.h"
|
||||
#include "src/runtime/runtime_api.h"
|
||||
#include "src/kernel_registry.h"
|
||||
#include "include/errorcode.h"
|
||||
|
||||
using mindspore::lite::KernelRegistrar;
|
||||
using mindspore::lite::RET_OK;
|
||||
using mindspore::schema::PrimitiveType_Sub;
|
||||
|
||||
namespace mindspore::kernel {
|
||||
|
||||
int SubInt8CPUKernel::Init() {
|
||||
lite::tensor::Tensor *input0 = in_tensors_.at(0);
|
||||
lite::tensor::Tensor *input1 = in_tensors_.at(1);
|
||||
lite::tensor::Tensor *output = out_tensors_.at(0);
|
||||
MS_ASSERT(input0);
|
||||
MS_ASSERT(input1);
|
||||
MS_ASSERT(output);
|
||||
|
||||
broadcast_ = input0->ElementsNum() != input1->ElementsNum();
|
||||
|
||||
param_.in0_args_.scale_ = input0->GetQuantParams().front().scale;
|
||||
param_.in0_args_.zp_ = -input0->GetQuantParams().front().zeroPoint;
|
||||
param_.in1_args_.scale_ = input1->GetQuantParams().front().scale;
|
||||
param_.in1_args_.zp_ = -input1->GetQuantParams().front().zeroPoint;
|
||||
param_.out_args_.scale_ = output->GetQuantParams().front().scale;
|
||||
param_.out_args_.zp_ = output->GetQuantParams().front().zeroPoint;
|
||||
|
||||
const int left_shift = 20;
|
||||
const double twice_max_input_scale = 2 * std::max(param_.in0_args_.scale_, param_.in1_args_.scale_);
|
||||
const double real_input0_multiplier = param_.in0_args_.scale_ / twice_max_input_scale;
|
||||
const double real_input1_multiplier = param_.in1_args_.scale_ / twice_max_input_scale;
|
||||
const double real_output_multiplier = twice_max_input_scale / ((1 << left_shift) * param_.out_args_.scale_);
|
||||
|
||||
QuantizeMultiplierSmallerThanOne(real_input0_multiplier, ¶m_.input0_multiplier_, ¶m_.input0_shift_);
|
||||
QuantizeMultiplierSmallerThanOne(real_input1_multiplier, ¶m_.input1_multiplier_, ¶m_.input1_shift_);
|
||||
QuantizeMultiplierSmallerThanOne(real_output_multiplier, ¶m_.output_multiplier_, ¶m_.output_shift_);
|
||||
|
||||
param_.output_activation_min_ = std::numeric_limits<int8_t>::min();
|
||||
param_.output_activation_max_ = std::numeric_limits<int8_t>::max();
|
||||
|
||||
int left_shift0 = -param_.input0_shift_ > 0 ? -param_.input0_shift_ : 0;
|
||||
param_.right_shift0_ = -param_.input0_shift_ > 0 ? 0 : param_.input0_shift_;
|
||||
|
||||
int left_shift1 = -param_.input1_shift_ > 0 ? -param_.input1_shift_ : 0;
|
||||
param_.right_shift1_ = -param_.input1_shift_ > 0 ? 0 : param_.input1_shift_;
|
||||
|
||||
param_.left_shift_out_ = -param_.output_shift_ > 0 ? -param_.output_shift_ : 0;
|
||||
param_.right_shift_out_ = -param_.output_shift_ > 0 ? 0 : param_.output_shift_;
|
||||
|
||||
param_.left_shift_result0_ = (1 << left_shift) * ((1 << left_shift0));
|
||||
param_.left_shift_result1_ = (1 << left_shift) * ((1 << left_shift1));
|
||||
|
||||
MS_ASSERT(left_shift + left_shift0 == left_shift);
|
||||
MS_ASSERT(left_shift + left_shift1 == left_shift);
|
||||
|
||||
if (!InferShapeDone()) {
|
||||
return RET_OK;
|
||||
}
|
||||
return ReSize();
|
||||
}
|
||||
|
||||
int SubInt8CPUKernel::ReSize() {
|
||||
if (broadcast_) {
|
||||
if (tile0_data_ != nullptr) {
|
||||
if (context_ != nullptr && context_->allocator != nullptr) {
|
||||
context_->allocator->Free(tile0_data_);
|
||||
} else {
|
||||
free(tile0_data_);
|
||||
}
|
||||
}
|
||||
if (tile1_data_ != nullptr) {
|
||||
if (context_ != nullptr && context_->allocator != nullptr) {
|
||||
context_->allocator->Free(tile1_data_);
|
||||
} else {
|
||||
free(tile1_data_);
|
||||
}
|
||||
}
|
||||
|
||||
if (context_ != nullptr && context_->allocator != nullptr) {
|
||||
tile0_data_ = static_cast<int8_t *>(context_->allocator->Malloc(out_tensors_.at(0)->Size()));
|
||||
tile1_data_ = static_cast<int8_t *>(context_->allocator->Malloc(out_tensors_.at(0)->Size()));
|
||||
} else {
|
||||
tile0_data_ = static_cast<int8_t *>(malloc(sizeof(int8_t) * out_tensors_.at(0)->Size()));
|
||||
tile1_data_ = static_cast<int8_t *>(malloc(sizeof(int8_t) * out_tensors_.at(0)->Size()));
|
||||
}
|
||||
|
||||
if (tile0_data_ == nullptr || tile1_data_ == nullptr) {
|
||||
MS_LOG(ERROR) << "malloc memroy fail!";
|
||||
return RET_ERROR;
|
||||
}
|
||||
}
|
||||
return RET_OK;
|
||||
}
|
||||
|
||||
int SubInt8CPUKernel::DoExecute(int task_id) {
|
||||
auto input0_data_ = static_cast<int8_t *>(in_tensors_.at(0)->Data());
|
||||
auto input1_data_ = static_cast<int8_t *>(in_tensors_.at(1)->Data());
|
||||
auto output_data_ = static_cast<int8_t *>(out_tensors_.at(0)->Data());
|
||||
auto element_num = out_tensors_[0]->ElementsNum();
|
||||
|
||||
MS_ASSERT(op_parameter_->thread_num_ != 0);
|
||||
int stride = UP_DIV(element_num, op_parameter_->thread_num_);
|
||||
int count = MSMIN(stride, element_num - stride * task_id);
|
||||
|
||||
auto ret = RET_OK;
|
||||
if (broadcast_) {
|
||||
ret = SubInt8(tile0_data_ + task_id * count, tile1_data_ + task_id * count, output_data_ + task_id * count, count,
|
||||
¶m_);
|
||||
} else {
|
||||
ret = SubInt8(input0_data_ + task_id * count, input1_data_ + task_id * count, output_data_ + task_id * count, count,
|
||||
¶m_);
|
||||
}
|
||||
|
||||
if (ret != RET_OK) {
|
||||
MS_LOG(ERROR) << "Subint8 function error error_code[" << ret << "]";
|
||||
return RET_ERROR;
|
||||
}
|
||||
return RET_OK;
|
||||
}
|
||||
|
||||
int SubInt8Run(int task_id, LiteParallelGroupEnv *penv, void *cdata) {
|
||||
auto sub_kernel = reinterpret_cast<SubInt8CPUKernel *>(cdata);
|
||||
auto ret = sub_kernel->DoExecute(task_id);
|
||||
if (ret != RET_OK) {
|
||||
MS_LOG(ERROR) << "SubInt8 DoExecute error task_id[" << task_id << "] error_code[" << ret << "]";
|
||||
return RET_ERROR;
|
||||
}
|
||||
return RET_OK;
|
||||
}
|
||||
|
||||
int SubInt8CPUKernel::Run() {
|
||||
auto ret = Prepare();
|
||||
if (ret != RET_OK) {
|
||||
MS_LOG(ERROR) << "Prepare failed.";
|
||||
return RET_ERROR;
|
||||
}
|
||||
|
||||
if (broadcast_) {
|
||||
ArithmeticParameter tile_para = {0};
|
||||
tile_para.ndim_ = out_tensors_.at(0)->shape().size();
|
||||
for (size_t i = 0; i < tile_para.ndim_; i++) {
|
||||
tile_para.in_shape0_[i] = in_tensors_.at(0)->DimensionSize(i);
|
||||
tile_para.in_shape1_[i] = in_tensors_.at(1)->DimensionSize(i);
|
||||
tile_para.out_shape_[i] = out_tensors_.at(0)->DimensionSize(i);
|
||||
}
|
||||
TileDimensionsUint8(static_cast<uint8_t *>(in_tensors_.at(0)->Data()),
|
||||
static_cast<uint8_t *>(in_tensors_.at(1)->Data()), reinterpret_cast<uint8_t *>(tile0_data_),
|
||||
reinterpret_cast<uint8_t *>(tile1_data_), &tile_para);
|
||||
}
|
||||
ret = LiteBackendParallelLaunch(SubInt8Run, this, op_parameter_->thread_num_);
|
||||
|
||||
if (ret != RET_OK) {
|
||||
MS_LOG(ERROR) << "SubInt8Run function error error_code[" << ret << "]";
|
||||
return RET_ERROR;
|
||||
}
|
||||
return RET_OK;
|
||||
}
|
||||
|
||||
kernel::LiteKernel *CpuSubInt8KernelCreator(const std::vector<lite::tensor::Tensor *> &inputs,
|
||||
const std::vector<lite::tensor::Tensor *> &outputs, OpParameter *parameter,
|
||||
const lite::Context *ctx, const KernelKey &desc,
|
||||
const lite::Primitive *primitive) {
|
||||
if (parameter == nullptr || ctx == nullptr) {
|
||||
MS_LOG(ERROR) << "parameter or ctx is nullptr";
|
||||
return nullptr;
|
||||
}
|
||||
MS_ASSERT(desc.type == PrimitiveType_Sub);
|
||||
auto *kernel = new (std::nothrow) SubInt8CPUKernel(parameter, inputs, outputs, ctx, primitive);
|
||||
if (kernel == nullptr) {
|
||||
MS_LOG(ERROR) << "kernel is nullptr.";
|
||||
return nullptr;
|
||||
}
|
||||
auto ret = kernel->Init();
|
||||
if (ret != RET_OK) {
|
||||
MS_LOG(ERROR) << "Init kernel failed, name: " << parameter->name_
|
||||
<< ", type: " << schema::EnumNamePrimitiveType(static_cast<schema::PrimitiveType>(parameter->type_));
|
||||
delete kernel;
|
||||
return nullptr;
|
||||
}
|
||||
return kernel;
|
||||
}
|
||||
|
||||
REG_KERNEL(kCPU, kNumberTypeInt8, PrimitiveType_Sub, CpuSubInt8KernelCreator)
|
||||
} // namespace mindspore::kernel
|
@ -0,0 +1,46 @@
|
||||
/**
|
||||
* Copyright 2020 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
#ifndef MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_INT8_SUB_INT8_H_
|
||||
#define MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_INT8_SUB_INT8_H_
|
||||
|
||||
#include <vector>
|
||||
#include "src/lite_kernel.h"
|
||||
#include "src/runtime/kernel/arm/nnacl/int8/sub_int8.h"
|
||||
#include "src/runtime/runtime_api.h"
|
||||
|
||||
namespace mindspore::kernel {
|
||||
class SubInt8CPUKernel : public LiteKernel {
|
||||
public:
|
||||
explicit SubInt8CPUKernel(OpParameter *parameter, const std::vector<lite::tensor::Tensor *> &inputs,
|
||||
const std::vector<lite::tensor::Tensor *> &outputs, const lite::Context *ctx,
|
||||
const lite::Primitive *primitive)
|
||||
: LiteKernel(parameter, inputs, outputs, ctx, primitive) {}
|
||||
~SubInt8CPUKernel() override {}
|
||||
|
||||
int Init() override;
|
||||
int ReSize() override;
|
||||
int Run() override;
|
||||
int DoExecute(int task_id);
|
||||
|
||||
private:
|
||||
SubQuantArg param_;
|
||||
int8_t *tile0_data_ = nullptr;
|
||||
int8_t *tile1_data_ = nullptr;
|
||||
bool broadcast_ = false;
|
||||
};
|
||||
} // namespace mindspore::kernel
|
||||
|
||||
#endif // MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_INT8_SUB_INT8_H_
|
@ -0,0 +1,104 @@
|
||||
/**
|
||||
* Copyright 2020 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#include "nnacl/int8/sub_int8.h"
|
||||
#ifdef ENABLE_NEON
|
||||
#include <arm_neon.h>
|
||||
#include "nnacl/add_int8.h"
|
||||
#endif
|
||||
#include "nnacl/quantization/fixed_point.h"
|
||||
|
||||
#ifdef ENABLE_NEON
|
||||
|
||||
int16x4_t ClacSumHalfWord(int32x4_t scaled_input0, int32x4_t scaled_input1, int32x4_t left_shift_out_vec,
|
||||
int32x4_t output_multiplier_vec, SubQuantArg *para) {
|
||||
int32x4_t raw_data = vsubq_s32(scaled_input0, scaled_input1);
|
||||
|
||||
raw_data = RoundingDivideByPOTInt32x4(vqrdmulhq_s32(vmulq_s32(raw_data, left_shift_out_vec), output_multiplier_vec),
|
||||
para->right_shift_out_);
|
||||
raw_data = vaddq_s32(raw_data, vdupq_n_s32(para->out_args_.zp_));
|
||||
raw_data = vmaxq_s32(raw_data, vdupq_n_s32(para->output_activation_min_));
|
||||
raw_data = vminq_s32(raw_data, vdupq_n_s32(para->output_activation_max_));
|
||||
return vqmovn_s32(raw_data);
|
||||
}
|
||||
|
||||
void SubInt8NEON(int8_t *input0_data, int8_t *input1_data, int8_t *output_data, int64_t real_dst_count,
|
||||
SubQuantArg *para, int *index) {
|
||||
int32x4_t left_shift_result0_vec = vdupq_n_s32(para->left_shift_result0_);
|
||||
int32x4_t left_shift_result1_vec = vdupq_n_s32(para->left_shift_result1_);
|
||||
int32x4_t input0_multiplier_vec = vdupq_n_s32(para->input0_multiplier_);
|
||||
int32x4_t input1_multiplier_vec = vdupq_n_s32(para->input1_multiplier_);
|
||||
int32x4_t output_multiplier_vec = vdupq_n_s32(para->output_multiplier_);
|
||||
int32x4_t left_shift_out_vec = vdupq_n_s32((1 << para->left_shift_out_));
|
||||
int32x4_t right_shift0_vec = vdupq_n_s32(-para->right_shift0_);
|
||||
int32x4_t right_shift1_vec = vdupq_n_s32(-para->right_shift1_);
|
||||
|
||||
for (; (*index) <= real_dst_count - 8; (*index) += 8) {
|
||||
int16x8_t input0_val = LoadAndAddOffset(input0_data, *index, para->in0_args_.zp_);
|
||||
int16x8_t input1_val = LoadAndAddOffset(input1_data, *index, para->in1_args_.zp_);
|
||||
|
||||
int32x4_t input0_low = vmovl_s16(vget_low_s16(input0_val));
|
||||
int32x4_t input0_high = vmovl_s16(vget_high_s16(input0_val));
|
||||
int32x4_t input1_low = vmovl_s16(vget_low_s16(input1_val));
|
||||
int32x4_t input1_high = vmovl_s16(vget_high_s16(input1_val));
|
||||
|
||||
int32x4_t scaled_input0_low =
|
||||
ClacScaledInput(input0_low, left_shift_result0_vec, input0_multiplier_vec, right_shift0_vec);
|
||||
int32x4_t scaled_input0_high =
|
||||
ClacScaledInput(input0_high, left_shift_result0_vec, input0_multiplier_vec, right_shift0_vec);
|
||||
int32x4_t scaled_input1_low =
|
||||
ClacScaledInput(input1_low, left_shift_result1_vec, input1_multiplier_vec, right_shift1_vec);
|
||||
int32x4_t scaled_input1_high =
|
||||
ClacScaledInput(input1_high, left_shift_result1_vec, input1_multiplier_vec, right_shift1_vec);
|
||||
|
||||
int16x4_t sum_low =
|
||||
ClacSumHalfWord(scaled_input0_low, scaled_input1_low, left_shift_out_vec, output_multiplier_vec, para);
|
||||
int16x4_t sum_high =
|
||||
ClacSumHalfWord(scaled_input0_high, scaled_input1_high, left_shift_out_vec, output_multiplier_vec, para);
|
||||
|
||||
int16x8_t res_s16 = vcombine_s16(sum_low, sum_high);
|
||||
int8x8_t res_u8_n0 = vqmovn_s16(res_s16);
|
||||
vst1_s8(output_data + *index, res_u8_n0);
|
||||
}
|
||||
}
|
||||
#endif
|
||||
|
||||
int SubInt8(int8_t *input0_data, int8_t *input1_data, int8_t *output_data, int64_t real_dst_count, SubQuantArg *para) {
|
||||
int index = 0;
|
||||
#ifdef ENABLE_NEON
|
||||
SubInt8NEON(input0_data, input1_data, output_data, real_dst_count, para, &index);
|
||||
#endif
|
||||
for (; index < real_dst_count; ++index) {
|
||||
const int32_t input0_val = para->in0_args_.zp_ + input0_data[index];
|
||||
const int32_t input1_val = para->in1_args_.zp_ + input1_data[index];
|
||||
const int32_t shifted_input0_val = input0_val * para->left_shift_result0_;
|
||||
const int32_t shifted_input1_val = input1_val * para->left_shift_result1_;
|
||||
const int32_t scaled_input0_val = RoundingDivideByPOT(
|
||||
SaturatingRoundingDoublingHighMul(shifted_input0_val, para->input0_multiplier_), para->right_shift0_);
|
||||
const int32_t scaled_input1_val = RoundingDivideByPOT(
|
||||
SaturatingRoundingDoublingHighMul(shifted_input1_val, para->input1_multiplier_), para->right_shift1_);
|
||||
|
||||
const int32_t raw_data = scaled_input0_val - scaled_input1_val;
|
||||
const int32_t raw_output =
|
||||
RoundingDivideByPOT(SaturatingRoundingDoublingHighMul(raw_data * (1 << (unsigned int)para->left_shift_out_),
|
||||
para->output_multiplier_),
|
||||
para->right_shift_out_) +
|
||||
para->out_args_.zp_;
|
||||
|
||||
output_data[index] = (int8_t)MSMAX(para->output_activation_min_, MSMIN(raw_output, para->output_activation_max_));
|
||||
}
|
||||
return 0;
|
||||
}
|
@ -0,0 +1,25 @@
|
||||
/**
|
||||
* Copyright 2020 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#ifndef MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_NNACL_INT8_SUB_INT8_H_
|
||||
#define MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_NNACL_INT8_SUB_INT8_H_
|
||||
|
||||
#include "nnacl/op_base.h"
|
||||
#include "nnacl/quantization/quantize.h"
|
||||
|
||||
int SubInt8(int8_t *input0_data, int8_t *input1_data, int8_t *output_data, int64_t real_dst_count, SubQuantArg *para);
|
||||
|
||||
#endif // MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_NNACL_INT8_SUB_INT8_H_
|
File diff suppressed because it is too large
Load Diff
@ -0,0 +1,74 @@
|
||||
/**
|
||||
* Copyright 2020 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#include <iostream>
|
||||
#include <memory>
|
||||
#include "common/common_test.h"
|
||||
#include "mindspore/lite/src/runtime/kernel/arm/int8/sub_int8.h"
|
||||
#include "mindspore/lite/src/kernel_registry.h"
|
||||
#include "mindspore/lite/include/context.h"
|
||||
|
||||
namespace mindspore {
|
||||
class TestSubInt8 : public mindspore::CommonTest {
|
||||
public:
|
||||
TestSubInt8() {}
|
||||
};
|
||||
|
||||
TEST_F(TestSubInt8, SubInt8) {
|
||||
lite::tensor::Tensor in_tensor0(kNumberTypeInt8, {1, 1, 2, 5});
|
||||
lite::tensor::Tensor in_tensor1(kNumberTypeInt8, {1, 1, 1, 5});
|
||||
lite::tensor::Tensor out_tensor(kNumberTypeInt8, {1, 1, 2, 5});
|
||||
|
||||
int8_t input_data0[] = {105, 35, -27, 0, -63, 99, 16, 122, 67, -49};
|
||||
int8_t input_data1[] = {24, -38, -115, 106, -98};
|
||||
int8_t output_data[10] = {0};
|
||||
in_tensor0.SetData(input_data0);
|
||||
in_tensor1.SetData(input_data1);
|
||||
out_tensor.SetData(output_data);
|
||||
|
||||
const lite::tensor::QuantArg quant_in0 = {0.00784314f, 0}; // -1.0--1.0 -> 0--255
|
||||
const lite::tensor::QuantArg quant_in1 = {0.00784314f, 0};
|
||||
const lite::tensor::QuantArg quant_out = {0.00784314f, 0};
|
||||
in_tensor0.AddQuantParam(quant_in0);
|
||||
in_tensor1.AddQuantParam(quant_in1);
|
||||
out_tensor.AddQuantParam(quant_out);
|
||||
|
||||
std::vector<lite::tensor::Tensor *> inputs = {&in_tensor0, &in_tensor1};
|
||||
std::vector<lite::tensor::Tensor *> outputs = {&out_tensor};
|
||||
|
||||
OpParameter parameter = {};
|
||||
kernel::KernelKey desc = {kernel::KERNEL_ARCH::kCPU, kNumberTypeInt8, schema::PrimitiveType_Sub};
|
||||
|
||||
auto creator = lite::KernelRegistry::GetInstance()->GetCreator(desc);
|
||||
ASSERT_NE(creator, nullptr);
|
||||
|
||||
auto ctx = std::make_shared<lite::Context>();
|
||||
auto kernel = creator(inputs, outputs, reinterpret_cast<OpParameter *>(¶meter), ctx.get(), desc, nullptr);
|
||||
ASSERT_NE(kernel, nullptr);
|
||||
|
||||
auto ret = kernel->Run();
|
||||
EXPECT_EQ(0, ret);
|
||||
|
||||
int8_t expect0[10] = {81, 73, 88, -106, 35, 75, 54, 127, -39, 49};
|
||||
for (int i = 0; i < 10; ++i) {
|
||||
EXPECT_EQ(output_data[i], expect0[i]);
|
||||
}
|
||||
|
||||
in_tensor0.SetData(nullptr);
|
||||
in_tensor1.SetData(nullptr);
|
||||
out_tensor.SetData(nullptr);
|
||||
}
|
||||
} // namespace mindspore
|
Loading…
Reference in new issue