!4748 Add lite op Power supporting int8 and testcase

Merge pull request !4748 from wangminggui/master
pull/4748/MERGE
mindspore-ci-bot 5 years ago committed by Gitee
commit c170ccbf33

@ -0,0 +1,83 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "src/runtime/kernel/arm/base/power_base.h"
#include <vector>
#include "src/runtime/kernel/arm/int8/power_int8.h"
#include "src/runtime/kernel/arm/fp32/power.h"
#include "schema/model_generated.h"
#include "src/kernel_registry.h"
#include "include/errorcode.h"
using mindspore::lite::KernelRegistrar;
using mindspore::lite::RET_ERROR;
using mindspore::lite::RET_OK;
using mindspore::schema::PrimitiveType_Power;
namespace mindspore::kernel {
int PowerBaseCPUKernel::Init() { return RET_OK; }
int PowerBaseCPUKernel::ReSize() { return RET_OK; }
kernel::LiteKernel *CpuPowerInt8KernelCreator(const std::vector<lite::tensor::Tensor *> &inputs,
const std::vector<lite::tensor::Tensor *> &outputs,
OpParameter *opParameter, const lite::Context *ctx,
const kernel::KernelKey &desc,
const mindspore::lite::PrimitiveC *primitive) {
if (opParameter == nullptr) {
MS_LOG(ERROR) << "Input opParameter is nullptr!";
return nullptr;
}
MS_ASSERT(desc.type == schema::PrimitiveType_Power);
auto *kernel = new (std::nothrow) PowerInt8CPUKernel(opParameter, inputs, outputs, ctx, primitive);
if (kernel == nullptr) {
MS_LOG(ERROR) << "new PowerInt8CPUKernel fail!";
return nullptr;
}
auto ret = kernel->Init();
if (ret != RET_OK) {
delete kernel;
MS_LOG(ERROR) << "Init kernel failed, name: " << opParameter->name_ << ", type: "
<< schema::EnumNamePrimitiveType(static_cast<schema::PrimitiveType>(opParameter->type_));
return nullptr;
}
return kernel;
}
kernel::LiteKernel *CpuPowerFp32KernelCreator(const std::vector<lite::tensor::Tensor *> &inputs,
const std::vector<lite::tensor::Tensor *> &outputs,
OpParameter *opParameter, const lite::Context *ctx,
const kernel::KernelKey &desc,
const mindspore::lite::PrimitiveC *primitive) {
MS_ASSERT(opParameter != nullptr);
MS_ASSERT(desc.type == schema::PrimitiveType_Power);
PowerCPUKernel *kernel = new (std::nothrow) PowerCPUKernel(opParameter, inputs, outputs, ctx, primitive);
if (kernel == nullptr) {
MS_LOG(ERROR) << "new PowerCPUKernel fail!";
return nullptr;
}
auto ret = kernel->Init();
if (ret != RET_OK) {
delete kernel;
MS_LOG(ERROR) << "Init kernel failed, name: " << opParameter->name_ << ", type: "
<< schema::EnumNamePrimitiveType(static_cast<schema::PrimitiveType>(opParameter->type_));
return nullptr;
}
return kernel;
}
REG_KERNEL(kCPU, kNumberTypeInt8, PrimitiveType_Power, CpuPowerInt8KernelCreator)
REG_KERNEL(kCPU, kNumberTypeFloat32, PrimitiveType_Power, CpuPowerFp32KernelCreator)
} // namespace mindspore::kernel

@ -0,0 +1,44 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_BASE_POWER_BASE_H_
#define MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_BASE_POWER_BASE_H_
#include <vector>
#include "src/lite_kernel.h"
#include "src/runtime/kernel/arm/nnacl/power_parameter.h"
namespace mindspore::kernel {
class PowerBaseCPUKernel : public LiteKernel {
public:
PowerBaseCPUKernel(OpParameter *parameter, const std::vector<lite::tensor::Tensor *> &inputs,
const std::vector<lite::tensor::Tensor *> &outputs, const lite::Context *ctx,
const mindspore::lite::PrimitiveC *primitive)
: LiteKernel(parameter, inputs, outputs, ctx, primitive) {
param_ = reinterpret_cast<PowerParameter *>(op_parameter_);
}
~PowerBaseCPUKernel() = default;
int Init() override;
int ReSize() override;
int Run() override { return 0; }
protected:
PowerParameter *param_;
};
} // namespace mindspore::kernel
#endif // MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_BASE_POWER_BASE_H_

@ -45,6 +45,9 @@ int SoftmaxBaseCPUKernel::ReSize() {
auto in_dims = in_shape.size();
int ele_size = 1;
softmax_param_->n_dim_ = in_dims;
if (softmax_param_->axis_ == -1) {
softmax_param_->axis_ += in_dims;
}
for (size_t i = 0; i < in_dims; i++) {
softmax_param_->input_shape_[i] = in_shape[i];
ele_size *= in_shape[i];

@ -76,27 +76,4 @@ int PowerCPUKernel::RunImpl(int task_id) {
return RET_OK;
}
kernel::LiteKernel *CpuPowerFp32KernelCreator(const std::vector<lite::tensor::Tensor *> &inputs,
const std::vector<lite::tensor::Tensor *> &outputs,
OpParameter *opParameter, const lite::Context *ctx,
const kernel::KernelKey &desc,
const mindspore::lite::PrimitiveC *primitive) {
MS_ASSERT(opParameter != nullptr);
MS_ASSERT(desc.type == schema::PrimitiveType_Power);
PowerCPUKernel *kernel = new (std::nothrow) PowerCPUKernel(opParameter, inputs, outputs, ctx, primitive);
if (kernel == nullptr) {
MS_LOG(ERROR) << "new PowerCPUKernel fail!";
return nullptr;
}
auto ret = kernel->Init();
if (ret != RET_OK) {
delete kernel;
MS_LOG(ERROR) << "Init kernel failed, name: " << opParameter->name_ << ", type: "
<< schema::EnumNamePrimitiveType(static_cast<schema::PrimitiveType>(opParameter->type_));
return nullptr;
}
return kernel;
}
REG_KERNEL(kCPU, kNumberTypeFloat32, PrimitiveType_Power, CpuPowerFp32KernelCreator)
} // namespace mindspore::kernel

@ -21,14 +21,15 @@
#include "include/context.h"
#include "src/lite_kernel.h"
#include "src/runtime/kernel/arm/nnacl/power.h"
#include "src/runtime/kernel/arm/base/power_base.h"
namespace mindspore::kernel {
class PowerCPUKernel : public LiteKernel {
class PowerCPUKernel : public PowerBaseCPUKernel {
public:
PowerCPUKernel(OpParameter *param, const std::vector<lite::tensor::Tensor *> &inputs,
const std::vector<lite::tensor::Tensor *> &outputs, const lite::Context *ctx,
const mindspore::lite::PrimitiveC *primitive)
: LiteKernel(param, inputs, outputs, ctx, primitive),
: PowerBaseCPUKernel(param, inputs, outputs, ctx, primitive),
ctx_(ctx),
thread_count_(ctx->thread_num_),
power_(reinterpret_cast<PowerParameter *>(op_parameter_)->power_),

@ -48,10 +48,6 @@ int SoftmaxCPUKernel::ReSize() {
}
auto n_dim = softmax_param_->n_dim_;
auto axis = softmax_param_->axis_;
if (axis == -1) {
softmax_param_->axis_ += n_dim;
axis = softmax_param_->axis_;
}
auto in_shape = in_tensors_.front()->shape();
int out_plane_size = 1;
for (int i = 0; i < axis; ++i) {

@ -0,0 +1,112 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "src/runtime/kernel/arm/int8/power_int8.h"
#include <limits>
#include "src/runtime/kernel/arm/nnacl/int8/power_int8.h"
#include "include/errorcode.h"
#include "src/runtime/runtime_api.h"
using mindspore::kernel::KERNEL_ARCH::kCPU;
using mindspore::lite::RET_ERROR;
using mindspore::lite::RET_OK;
namespace mindspore::kernel {
int PowerInt8CPUKernel::Init() {
auto ret = PowerBaseCPUKernel::Init();
if (ret != RET_OK) {
return ret;
}
auto input = in_tensors_.at(0);
auto output = out_tensors_.at(0);
MS_ASSERT(input);
MS_ASSERT(output);
auto in_quant_args = input->GetQuantParams();
param_->quant_arg_.in_args_.scale_ = in_quant_args.front().scale;
param_->quant_arg_.in_args_.zp_ = in_quant_args.front().zeroPoint;
auto out_quant_args = output->GetQuantParams();
param_->quant_arg_.out_args_.scale_ = out_quant_args.front().scale;
param_->quant_arg_.out_args_.zp_ = out_quant_args.front().zeroPoint;
param_->quant_arg_.output_activation_max_ = std::numeric_limits<int8_t>::max();
param_->quant_arg_.output_activation_min_ = std::numeric_limits<int8_t>::min();
if (!InferShapeDone()) {
return RET_OK;
}
return ReSize();
}
int PowerInt8CPUKernel::ReSize() { return PowerBaseCPUKernel::ReSize(); }
int PowerInt8CPUKernel::DoPower(int task_id) {
const int8_t *input_data = reinterpret_cast<const int8_t *>(in_tensors_[0]->Data());
int8_t *output_data = reinterpret_cast<int8_t *>(out_tensors_[0]->Data());
auto size = in_tensors_[0]->ElementsNum();
int stride = UP_DIV(size, op_parameter_->thread_num_);
int count = MSMIN(stride, size - stride * task_id);
int8_t *exp_ptr = nullptr;
param_->broadcast_ = true;
if (in_tensors_.size() == 2) {
auto exp_tensor = in_tensors_.at(1);
auto exp_quant_args = exp_tensor->GetQuantParams();
param_->quant_arg_.exp_args_.scale_ = exp_quant_args.front().scale;
param_->quant_arg_.exp_args_.zp_ = exp_quant_args.front().zeroPoint;
exp_ptr = reinterpret_cast<int8_t *>(exp_tensor->Data());
param_->broadcast_ = false;
if (in_tensors_[0]->Size() != in_tensors_[1]->Size()) {
MS_LOG(ERROR) << "Power input size " << in_tensors_[0]->Size() << " is not equal to exponent size "
<< in_tensors_[1]->Size();
return RET_ERROR;
}
}
if (!param_->broadcast_) {
exp_ptr = exp_ptr + stride * task_id;
}
auto ret = PowerInt8(input_data + stride * task_id, exp_ptr, output_data + stride * task_id, count, param_);
if (ret != RET_OK) {
MS_LOG(ERROR) << "PowerInt8 error ,task_id[" << task_id << "] error_code[" << ret << "]";
}
return ret;
}
int PowerInt8Run(int task_id, LiteParallelGroupEnv *penv, void *cdata) {
auto power_kernel = reinterpret_cast<PowerInt8CPUKernel *>(cdata);
auto ret = power_kernel->DoPower(task_id);
if (ret != RET_OK) {
MS_LOG(ERROR) << "DoPower error, task_id[" << task_id << "] error_code[" << ret << "]";
}
return ret;
}
int PowerInt8CPUKernel::Run() {
auto ret = Prepare();
if (ret != RET_OK) {
MS_LOG(ERROR) << "Prepare failed.";
return ret;
}
ret = LiteBackendParallelLaunch(PowerInt8Run, this, op_parameter_->thread_num_);
if (ret != RET_OK) {
MS_LOG(ERROR) << "PowerInt8Run error, error_code[" << ret << "]";
}
return ret;
}
} // namespace mindspore::kernel

@ -0,0 +1,41 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_INT8_POWER_INT8_H_
#define MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_INT8_POWER_INT8_H_
#include <vector>
#include "src/runtime/kernel/arm/base/power_base.h"
#include "src/runtime/kernel/arm/nnacl/quantization/quantize.h"
namespace mindspore::kernel {
class PowerInt8CPUKernel : public PowerBaseCPUKernel {
public:
PowerInt8CPUKernel(OpParameter *parameter, const std::vector<lite::tensor::Tensor *> &inputs,
const std::vector<lite::tensor::Tensor *> &outputs, const lite::Context *ctx,
const mindspore::lite::PrimitiveC *primitive)
: PowerBaseCPUKernel(parameter, inputs, outputs, ctx, primitive) {}
~PowerInt8CPUKernel() {
}
int Init() override;
int ReSize() override;
int Run() override;
int DoPower(int task_id);
};
} // namespace mindspore::kernel
#endif // MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_INT8_POWER_INT8_H_

@ -22,7 +22,7 @@ int LocalResponseNorm(float *input_ptr, int out_size, int channel, float *output
int i, j, k;
int left, right;
float depth_radius = param->depth_radius_;
int depth_radius = param->depth_radius_;
float bias = param->bias_;
float alpha = param->alpha_;
float beta = param->beta_;

@ -0,0 +1,46 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "nnacl/int8/power_int8.h"
int PowerInt8(const int8_t *input, int8_t *exp_ptr, int8_t *output, int count, PowerParameter *param) {
double input_scale = param->quant_arg_.in_args_.scale_;
int input_zp = param->quant_arg_.in_args_.zp_;
double output_scale = param->quant_arg_.out_args_.scale_;
int output_zp = param->quant_arg_.out_args_.zp_;
int act_min = param->quant_arg_.output_activation_min_;
int act_max = param->quant_arg_.output_activation_max_;
if (param->broadcast_) {
for (int i = 0; i < count; ++i) {
float input_val = input_scale * (input[i] - input_zp);
float output_val = pow(param->scale_ * input_val + param->shift_, param->power_);
int32_t output_scaled = round(output_val / output_scale) + output_zp;
output[i] = (int8_t)MSMAX(act_min, MSMIN(output_scaled, act_max));
}
} else {
double exp_scale = param->quant_arg_.exp_args_.scale_;
int exp_zp = param->quant_arg_.exp_args_.zp_;
for (int i = 0; i < count; ++i) {
float input_val = input_scale * (input[i] - input_zp);
float exp_val = exp_scale * (exp_ptr[i] - exp_zp);
float output_val = pow(param->scale_ * input_val + param->shift_, exp_val);
int32_t output_scaled = round(output_val / output_scale) + output_zp;
output[i] = (int8_t)MSMAX(act_min, MSMIN(output_scaled, act_max));
}
}
return 0;
}

@ -0,0 +1,32 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_NNACL_INT8_POWER_INT8_H_
#define MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_NNACL_INT8_POWER_INT8_H_
#include "nnacl/op_base.h"
#include "nnacl/power_parameter.h"
#include "nnacl/quantization/quantize.h"
#ifdef __cplusplus
extern "C" {
#endif
int PowerInt8(const int8_t *input_ptr, int8_t *exp_ptr, int8_t *output_ptr, int count, PowerParameter *parameter);
#ifdef __cplusplus
}
#endif
#endif // MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_NNACL_INT8_POWER_INT8_H_

@ -20,19 +20,17 @@
#include "nnacl/errorcode.h"
int SliceInt8NoParallel(const int8_t *input, int8_t *output, SliceParameter *param) {
int input_scale = param->quant_arg_.in_args_.scale_;
int input_zp = -param->quant_arg_.in_args_.zp_;
int output_scale = param->quant_arg_.in_args_.scale_;
int output_zp = -param->quant_arg_.out_args_.zp_;
double input_scale = param->quant_arg_.in_args_.scale_;
int input_zp = param->quant_arg_.in_args_.zp_;
double output_scale = param->quant_arg_.out_args_.scale_;
int output_zp = param->quant_arg_.out_args_.zp_;
int act_min = param->quant_arg_.output_activation_min_;
int act_max = param->quant_arg_.output_activation_max_;
int equal_quant = 0;
double multiplier = 0;
double multiplier = input_scale / output_scale;
if (input_scale == output_scale && input_zp == output_zp) {
equal_quant = 1;
} else {
multiplier = input_scale / output_scale;
}
int32_t end_n = param->begin_[0] + param->size_[0];
@ -57,7 +55,7 @@ int SliceInt8NoParallel(const int8_t *input, int8_t *output, SliceParameter *par
memcpy(output + out_offset, input + in_offset, unit_size);
} else {
for (c = 0; c < unit_count; ++c) {
int32_t output_val = round(multiplier * (input[in_offset + c] + input_zp)) + output_zp;
int32_t output_val = round(multiplier * (input[in_offset + c] - input_zp)) + output_zp;
output[c + out_offset] = (int8_t)MSMAX(act_min, MSMIN(output_val, act_max));
}
}
@ -69,10 +67,10 @@ int SliceInt8NoParallel(const int8_t *input, int8_t *output, SliceParameter *par
}
int SliceInt8(const int8_t *input, int8_t *output, SliceParameter *param) {
int input_scale = param->quant_arg_.in_args_.scale_;
int input_zp = -param->quant_arg_.in_args_.zp_;
int output_scale = param->quant_arg_.in_args_.scale_;
int output_zp = -param->quant_arg_.out_args_.zp_;
double input_scale = param->quant_arg_.in_args_.scale_;
int input_zp = param->quant_arg_.in_args_.zp_;
double output_scale = param->quant_arg_.out_args_.scale_;
int output_zp = param->quant_arg_.out_args_.zp_;
int act_min = param->quant_arg_.output_activation_min_;
int act_max = param->quant_arg_.output_activation_max_;
@ -92,11 +90,9 @@ int SliceInt8(const int8_t *input, int8_t *output, SliceParameter *param) {
int n, h, w, c;
int equal_quant = 0;
double multiplier = 0;
double multiplier = input_scale / output_scale;
if (input_scale == output_scale && input_zp == output_zp) {
equal_quant = 1;
} else {
multiplier = input_scale / output_scale;
}
for (n = 0; n < param->size_[0]; ++n) {
@ -116,7 +112,7 @@ int SliceInt8(const int8_t *input, int8_t *output, SliceParameter *param) {
memcpy(output + out_offset, input + in_offset, unit_size);
} else {
for (c = 0; c < out_dim3; ++c) {
int32_t output_val = round(multiplier * (input[in_offset + c] + input_zp)) + output_zp;
int32_t output_val = round(multiplier * (input[in_offset + c] - input_zp)) + output_zp;
output[c + out_offset] = (int8_t)MSMAX(act_min, MSMIN(output_val, act_max));
}
}

@ -18,13 +18,7 @@
#define MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_NNACL_POWER_H_
#include <math.h>
#include "nnacl/op_base.h"
typedef struct PowerParameter {
OpParameter op_parameter_;
float power_;
float scale_;
float shift_;
} PowerParameter;
#include "nnacl/power_parameter.h"
#ifdef __cplusplus
extern "C" {

@ -0,0 +1,32 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_NNACL_POWER_PARAMETER_H_
#define MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_NNACL_POWER_PARAMETER_H_
#include "nnacl/op_base.h"
#include "nnacl/quantization/quantize.h"
typedef struct PowerParameter {
OpParameter op_parameter_;
PowerQuantArg quant_arg_;
float power_;
float scale_;
float shift_;
bool broadcast_;
} PowerParameter;
#endif // MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_NNACL_POWER_PARAMETER_H_

@ -251,6 +251,14 @@ typedef struct SliceQuantArg {
int output_activation_max_;
} SliceQuantArg;
typedef struct PowerQuantArg {
QuantArg in_args_;
QuantArg exp_args_;
QuantArg out_args_;
int output_activation_min_;
int output_activation_max_;
} PowerQuantArg;
#ifdef __cplusplus
extern "C" {
#endif

@ -0,0 +1,153 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include <iostream>
#include <memory>
#include "common/common_test.h"
#include "mindspore/lite/src/runtime/kernel/arm/int8/power_int8.h"
#include "mindspore/lite/src/runtime/kernel/arm/nnacl/power_parameter.h"
#include "mindspore/lite/src/kernel_registry.h"
namespace mindspore {
class TestPowerInt8 : public mindspore::CommonTest {
public:
TestPowerInt8() {}
};
TEST_F(TestPowerInt8, PowerInt8) {
std::vector<lite::tensor::Tensor *> inputs_tensor;
std::vector<lite::tensor::Tensor *> outputs_tensor;
PowerParameter op_param;
op_param.op_parameter_.type_ = schema::PrimitiveType_Power;
op_param.power_ = 2;
op_param.scale_ = 1;
op_param.shift_ = 0;
lite::tensor::QuantArg input_quant_arg;
input_quant_arg.scale = 0.0156863;
input_quant_arg.zeroPoint = -128;
lite::tensor::QuantArg output_quant_arg;
output_quant_arg.scale = 0.0627451;
output_quant_arg.zeroPoint = -128;
std::vector<int8_t> input = {-64, -1, 63, 127};
std::vector<int> in_shape = {1, 1, 1, 4};
lite::tensor::Tensor input0_tensor;
TypeId tid_int8 = kNumberTypeInt8;
inputs_tensor.push_back(&input0_tensor);
input0_tensor.SetData(input.data());
input0_tensor.set_shape(in_shape);
input0_tensor.AddQuantParam(input_quant_arg);
input0_tensor.set_data_type(tid_int8);
std::vector<int8_t> output(4);
std::vector<int> output_shape = {1, 1, 1, 4};
lite::tensor::Tensor output0_tensor;
outputs_tensor.push_back(&output0_tensor);
output0_tensor.SetData(output.data());
output0_tensor.AddQuantParam(output_quant_arg);
output0_tensor.set_data_type(tid_int8);
auto ctx = std::make_shared<lite::Context>();
kernel::KernelKey desc = {kernel::KERNEL_ARCH::kCPU, kNumberTypeInt8, schema::PrimitiveType_Power};
auto creator = lite::KernelRegistry::GetInstance()->GetCreator(desc);
ASSERT_NE(creator, nullptr);
kernel::LiteKernel *kernel =
creator(inputs_tensor, outputs_tensor, reinterpret_cast<OpParameter *>(&op_param), ctx.get(), desc, nullptr);
ASSERT_NE(kernel, nullptr);
auto output_tensor_shape = output0_tensor.shape();
kernel->Run();
std::vector<int8_t> except_result = {-112, -65, 15, 127};
CompareOutputData(output.data(), except_result.data(), input.size(), 0.000001);
input0_tensor.SetData(nullptr);
output0_tensor.SetData(nullptr);
}
TEST_F(TestPowerInt8, normal) {
std::vector<lite::tensor::Tensor *> inputs_tensor;
std::vector<lite::tensor::Tensor *> outputs_tensor;
PowerParameter op_param;
op_param.op_parameter_.type_ = schema::PrimitiveType_Power;
op_param.scale_ = 1;
op_param.shift_ = 0;
lite::tensor::QuantArg input_quant_arg;
input_quant_arg.scale = 0.0156863;
input_quant_arg.zeroPoint = -128;
lite::tensor::QuantArg exp_quant_arg;
exp_quant_arg.scale = 0.0156863;
exp_quant_arg.zeroPoint = -128;
lite::tensor::QuantArg output_quant_arg;
output_quant_arg.scale = 0.0352941;
output_quant_arg.zeroPoint = -128;
std::vector<int8_t> input = {-64, -1, 63, 127};
std::vector<int> in_shape = {1, 1, 1, 4};
std::vector<int8_t> input1 = {127, 63, -1, -64};
std::vector<int> in_shape1 = {1, 1, 1, 4};
lite::tensor::Tensor input0_tensor, input1_tensor;
TypeId tid_int8 = kNumberTypeInt8;
inputs_tensor.push_back(&input0_tensor);
inputs_tensor.push_back(&input1_tensor);
input0_tensor.SetData(input.data());
input0_tensor.set_shape(in_shape);
input0_tensor.AddQuantParam(input_quant_arg);
input0_tensor.set_data_type(tid_int8);
input1_tensor.SetData(input1.data());
input1_tensor.set_shape(in_shape1);
input1_tensor.AddQuantParam(exp_quant_arg);
input1_tensor.set_data_type(tid_int8);
std::vector<int8_t> output(4);
std::vector<int> output_shape = {1, 1, 1, 4};
lite::tensor::Tensor output0_tensor;
outputs_tensor.push_back(&output0_tensor);
output0_tensor.SetData(output.data());
output0_tensor.AddQuantParam(output_quant_arg);
output0_tensor.set_data_type(tid_int8);
auto ctx = std::make_shared<lite::Context>();
kernel::KernelKey desc = {kernel::KERNEL_ARCH::kCPU, kNumberTypeInt8, schema::PrimitiveType_Power};
auto creator = lite::KernelRegistry::GetInstance()->GetCreator(desc);
ASSERT_NE(creator, nullptr);
kernel::LiteKernel *kernel =
creator(inputs_tensor, outputs_tensor, reinterpret_cast<OpParameter *>(&op_param), ctx.get(), desc, nullptr);
ASSERT_NE(kernel, nullptr);
auto output_tensor_shape = output0_tensor.shape();
kernel->Run();
std::vector<int8_t> except_result = {-99, 95, 124, -14};
CompareOutputData(output.data(), except_result.data(), input.size(), 0.000001);
input0_tensor.SetData(nullptr);
output0_tensor.SetData(nullptr);
}
} // namespace mindspore
Loading…
Cancel
Save