!7777 fix arithmetic bug

Merge pull request !7777 from fuzhiye/tmp
pull/7777/MERGE
mindspore-ci-bot committed 4 years ago via Gitee
commit 87ed2da6df

File diff suppressed because it is too large.

@@ -64,17 +64,17 @@ int ElementOptMaximumFp16(float16_t *input0, float16_t *input1, float16_t *outpu
                           ArithmeticParameter *param);
 int ElementOptMinimumFp16(float16_t *input0, float16_t *input1, float16_t *output, int element_size,
                           ArithmeticParameter *param);
-int ElementOptNotEqualFp16(float16_t *input0, float16_t *input1, float16_t *output, int element_size,
+int ElementOptNotEqualFp16(float16_t *input0, float16_t *input1, uint8_t *output, int element_size,
                            ArithmeticParameter *param);
-int ElementOptEqualFp16(float16_t *input0, float16_t *input1, float16_t *output, int element_size,
+int ElementOptEqualFp16(float16_t *input0, float16_t *input1, uint8_t *output, int element_size,
                         ArithmeticParameter *param);
-int ElementOptLessFp16(float16_t *input0, float16_t *input1, float16_t *output, int element_size,
+int ElementOptLessFp16(float16_t *input0, float16_t *input1, uint8_t *output, int element_size,
                        ArithmeticParameter *param);
-int ElementOptLessEqualFp16(float16_t *input0, float16_t *input1, float16_t *output, int element_size,
+int ElementOptLessEqualFp16(float16_t *input0, float16_t *input1, uint8_t *output, int element_size,
                             ArithmeticParameter *param);
-int ElementOptGreaterFp16(float16_t *input0, float16_t *input1, float16_t *output, int element_size,
+int ElementOptGreaterFp16(float16_t *input0, float16_t *input1, uint8_t *output, int element_size,
                           ArithmeticParameter *param);
-int ElementOptGreaterEqualFp16(float16_t *input0, float16_t *input1, float16_t *output, int element_size,
+int ElementOptGreaterEqualFp16(float16_t *input0, float16_t *input1, uint8_t *output, int element_size,
                                ArithmeticParameter *param);
 int ElementMulFp16(float16_t *input0, float16_t *input1, float16_t *output, int element_size);
@@ -104,12 +104,12 @@ int ElementSquaredDifferenceFp16(float16_t *input0, float16_t *input1, float16_t
 int ElementMaximumFp16(float16_t *input0, float16_t *input1, float16_t *output, int element_size);
 int ElementMinimumFp16(float16_t *input0, float16_t *input1, float16_t *output, int element_size);
-int ElementNotEqualFp16(float16_t *input0, float16_t *input1, float16_t *output, int element_size);
-int ElementEqualFp16(float16_t *input0, float16_t *input1, float16_t *output, int element_size);
-int ElementLessFp16(float16_t *input0, float16_t *input1, float16_t *output, int element_size);
-int ElementLessEqual(float16_t *input0, float16_t *input1, float16_t *output, int element_size);
-int ElementGreaterFp16(float16_t *input0, float16_t *input1, float16_t *output, int element_size);
-int ElementGreaterEqualFp16(float16_t *input0, float16_t *input1, float16_t *output, int element_size);
+int ElementNotEqualFp16(float16_t *input0, float16_t *input1, uint8_t *output, int element_size);
+int ElementEqualFp16(float16_t *input0, float16_t *input1, uint8_t *output, int element_size);
+int ElementLessFp16(float16_t *input0, float16_t *input1, uint8_t *output, int element_size);
+int ElementLessEqual(float16_t *input0, float16_t *input1, uint8_t *output, int element_size);
+int ElementGreaterFp16(float16_t *input0, float16_t *input1, uint8_t *output, int element_size);
+int ElementGreaterEqualFp16(float16_t *input0, float16_t *input1, uint8_t *output, int element_size);
 void TileDimensionsFp16(float16_t *data0, float16_t *data1, float16_t *tile_data0, float16_t *tile_data1,
                         ArithmeticParameter *param);
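
With this hunk, every fp16 comparison writes its result as a uint8_t boolean mask (0 or 1) instead of float16 values. A minimal sketch of the new contract, assuming an ARM toolchain where float16_t is available; buffer names and values are hypothetical:

// Sketch only; float16_t requires an fp16-capable target (e.g. via arm_neon.h).
float16_t a[4] = {1, 2, 3, 4};
float16_t b[4] = {1, 9, 3, 0};
uint8_t mask[4];                     // output is now a 0/1 byte mask, not float16
ElementNotEqualFp16(a, b, mask, 4);  // mask = {0, 1, 0, 1}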

@@ -14,6 +14,19 @@
  * limitations under the License.
  */
 #include "nnacl/fp16/cast_fp16.h"
+
+void BoolToFloat16(const bool *input, float16_t *output, int number) {
+  for (int i = 0; i < number; ++i) {
+    output[i] = (float16_t)input[i];
+  }
+}
+
+void Uint8ToFloat16(const uint8_t *input, float16_t *output, int number) {
+  for (int i = 0; i < number; ++i) {
+    output[i] = (float16_t)input[i];
+  }
+}
+
 #ifndef ENABLE_ARM64
 void Float32ToFloat16(const float *input, float16_t *output, int number) {
   for (int i = 0; i < number; ++i) {

@@ -22,6 +22,8 @@
 #ifdef __cplusplus
 extern "C" {
 #endif
+void BoolToFloat16(const bool *input, float16_t *output, int number);
+void Uint8ToFloat16(const uint8_t *input, float16_t *output, int number);
 void Float32ToFloat16(const float *input, float16_t *output, int number);
 void Float16ToFloat32(const float16_t *input, float *output, int number);
 #ifdef __cplusplus
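
The two new casts run in the reverse direction: they let a bool or uint8_t comparison mask be fed back into fp16 arithmetic. A small sketch (values made up):

uint8_t mask[4] = {0, 1, 0, 1};    // e.g. produced by ElementNotEqualFp16
float16_t as_fp16[4];
Uint8ToFloat16(mask, as_fp16, 4);  // as_fp16 = {0.0, 1.0, 0.0, 1.0}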

@@ -0,0 +1,109 @@
/**
 * Copyright 2020 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
#include <string.h>
#include <math.h>
#include "nnacl/fp32/arithmetic_compare.h"

// equal:
int ElementEqualFp32(const float *input0, const float *input1, uint8_t *output, int element_size) {
  for (int i = 0; i < element_size; i++) {
    output[i] = input0[i] == input1[i];
  }
  return NNACL_OK;
}

int ElementEqualInt32(const int32_t *input0, const int32_t *input1, uint8_t *output, int element_size) {
  for (int i = 0; i < element_size; i++) {
    output[i] = input0[i] == input1[i];
  }
  return NNACL_OK;
}

// not equal
int ElementNotEqualFp32(const float *input0, const float *input1, uint8_t *output, int element_size) {
  for (int i = 0; i < element_size; i++) {
    output[i] = input0[i] != input1[i];
  }
  return NNACL_OK;
}

int ElementNotEqualInt32(const int32_t *input0, const int32_t *input1, uint8_t *output, int element_size) {
  for (int i = 0; i < element_size; i++) {
    output[i] = input0[i] != input1[i];
  }
  return NNACL_OK;
}

// less
int ElementLessFp32(const float *input0, const float *input1, uint8_t *output, int element_size) {
  for (int i = 0; i < element_size; i++) {
    output[i] = input0[i] < input1[i];
  }
  return NNACL_OK;
}

int ElementLessInt32(const int32_t *input0, const int32_t *input1, uint8_t *output, int element_size) {
  for (int i = 0; i < element_size; i++) {
    output[i] = input0[i] < input1[i];
  }
  return NNACL_OK;
}

// less equal
int ElementLessEqualFp32(const float *input0, const float *input1, uint8_t *output, int element_size) {
  for (int i = 0; i < element_size; i++) {
    output[i] = input0[i] <= input1[i];
  }
  return NNACL_OK;
}

int ElementLessEqualInt32(const int32_t *input0, const int32_t *input1, uint8_t *output, int element_size) {
  for (int i = 0; i < element_size; i++) {
    output[i] = input0[i] <= input1[i];
  }
  return NNACL_OK;
}

// greater
int ElementGreaterFp32(const float *input0, const float *input1, uint8_t *output, int element_size) {
  for (int i = 0; i < element_size; i++) {
    output[i] = input0[i] > input1[i];
  }
  return NNACL_OK;
}

int ElementGreaterInt32(const int32_t *input0, const int32_t *input1, uint8_t *output, int element_size) {
  for (int i = 0; i < element_size; i++) {
    output[i] = input0[i] > input1[i];
  }
  return NNACL_OK;
}

// greater equal
int ElementGreaterEqualFp32(const float *input0, const float *input1, uint8_t *output, int element_size) {
  for (int i = 0; i < element_size; i++) {
    output[i] = input0[i] >= input1[i];
  }
  return NNACL_OK;
}

int ElementGreaterEqualInt32(const int32_t *input0, const int32_t *input1, uint8_t *output, int element_size) {
  for (int i = 0; i < element_size; i++) {
    output[i] = input0[i] >= input1[i];
  }
  return NNACL_OK;
}
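
The new fp32 comparison primitives all share the same shape: an element-wise predicate written into a uint8_t mask. A minimal, self-contained usage sketch (a hypothetical driver, not part of the patch):

#include <stdio.h>
#include "nnacl/fp32/arithmetic_compare.h"

int main(void) {
  float a[4] = {1.0f, 2.0f, 3.0f, 4.0f};
  float b[4] = {4.0f, 2.0f, 1.0f, 4.0f};
  uint8_t out[4];
  if (ElementLessFp32(a, b, out, 4) == NNACL_OK) {
    for (int i = 0; i < 4; ++i) {
      printf("%d", out[i]);  // prints 1000: only a[0] < b[0]
    }
  }
  return 0;
}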

@@ -0,0 +1,50 @@
/**
 * Copyright 2020 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
#ifndef MINDSPORE_LITE_NNACL_ARITHMETIC_COMPARE_H_
#define MINDSPORE_LITE_NNACL_ARITHMETIC_COMPARE_H_

#ifdef ENABLE_NEON
#include <arm_neon.h>
#endif
#include "nnacl/op_base.h"
#include "nnacl/errorcode.h"

#ifdef __cplusplus
extern "C" {
#endif
int ElementEqualFp32(const float *input0, const float *input1, uint8_t *output, int element_size);
int ElementEqualInt32(const int32_t *input0, const int32_t *input1, uint8_t *output, int element_size);

int ElementNotEqualFp32(const float *input0, const float *input1, uint8_t *output, int element_size);
int ElementNotEqualInt32(const int32_t *input0, const int32_t *input1, uint8_t *output, int element_size);

int ElementLessFp32(const float *input0, const float *input1, uint8_t *output, int element_size);
int ElementLessInt32(const int32_t *input0, const int32_t *input1, uint8_t *output, int element_size);

int ElementLessEqualFp32(const float *input0, const float *input1, uint8_t *output, int element_size);
int ElementLessEqualInt32(const int32_t *input0, const int32_t *input1, uint8_t *output, int element_size);

int ElementGreaterFp32(const float *input0, const float *input1, uint8_t *output, int element_size);
int ElementGreaterInt32(const int32_t *input0, const int32_t *input1, uint8_t *output, int element_size);

int ElementGreaterEqualFp32(const float *input0, const float *input1, uint8_t *output, int element_size);
int ElementGreaterEqualInt32(const int32_t *input0, const int32_t *input1, uint8_t *output, int element_size);
#ifdef __cplusplus
}
#endif

#endif // MINDSPORE_LITE_NNACL_ARITHMETIC_COMPARE_H_

@@ -22,101 +22,87 @@
 #define ACCURACY_DATA 0.00000001
 
-int ElementNotEqualInt8(int8_t *input0, int8_t *input1, int8_t *output, int element_size,
+int ElementNotEqualInt8(int8_t *input0, int8_t *input1, uint8_t *output, int element_size,
                         ArithmeticQuantArg *quant_arg) {
   float in0_bias = -quant_arg->in0_args_.zp_ * quant_arg->in0_args_.scale_;
   float in1_bias = -quant_arg->in1_args_.zp_ * quant_arg->in1_args_.scale_;
-  const float output_inverse_scale = 1.f / quant_arg->out_args_.scale_;
-  float out_zp = quant_arg->out_args_.zp_;
   for (int index = 0; index < element_size; ++index) {
     float in0_real = input0[index] * quant_arg->in0_args_.scale_ + in0_bias;
     float in1_real = input1[index] * quant_arg->in1_args_.scale_ + in1_bias;
     float minus_inputs = in0_real - in1_real;
-    float out_real = (float)true;
+    bool out_real = true;
     if (minus_inputs >= -ACCURACY_DATA && minus_inputs <= ACCURACY_DATA) {
-      out_real = (float)false;
+      out_real = false;
     }
-    output[index] = (int8_t)(out_real * output_inverse_scale + out_zp);
+    output[index] = (uint8_t)out_real;
   }
   return NNACL_OK;
 }
 
-int ElementEqualInt8(int8_t *input0, int8_t *input1, int8_t *output, int element_size, ArithmeticQuantArg *quant_arg) {
+int ElementEqualInt8(int8_t *input0, int8_t *input1, uint8_t *output, int element_size, ArithmeticQuantArg *quant_arg) {
   float in0_bias = -quant_arg->in0_args_.zp_ * quant_arg->in0_args_.scale_;
   float in1_bias = -quant_arg->in1_args_.zp_ * quant_arg->in1_args_.scale_;
-  const float output_inverse_scale = 1.f / quant_arg->out_args_.scale_;
-  float out_zp = quant_arg->out_args_.zp_;
   for (int index = 0; index < element_size; ++index) {
     float in0_real = input0[index] * quant_arg->in0_args_.scale_ + in0_bias;
     float in1_real = input1[index] * quant_arg->in1_args_.scale_ + in1_bias;
     float minus_inputs = in0_real - in1_real;
-    float out_real = (float)false;
+    bool out_real = false;
     if (minus_inputs >= -ACCURACY_DATA && minus_inputs <= ACCURACY_DATA) {
-      out_real = (float)true;
+      out_real = true;
     }
-    output[index] = (int8_t)(out_real * output_inverse_scale + out_zp);
+    output[index] = (uint8_t)out_real;
   }
   return NNACL_OK;
 }
 
-int ElementLessInt8(int8_t *input0, int8_t *input1, int8_t *output, int element_size, ArithmeticQuantArg *quant_arg) {
+int ElementLessInt8(int8_t *input0, int8_t *input1, uint8_t *output, int element_size, ArithmeticQuantArg *quant_arg) {
   float in0_bias = -quant_arg->in0_args_.zp_ * quant_arg->in0_args_.scale_;
   float in1_bias = -quant_arg->in1_args_.zp_ * quant_arg->in1_args_.scale_;
-  const float output_inverse_scale = 1.f / quant_arg->out_args_.scale_;
-  float out_zp = quant_arg->out_args_.zp_;
   for (int index = 0; index < element_size; ++index) {
     float in0_real = input0[index] * quant_arg->in0_args_.scale_ + in0_bias;
     float in1_real = input1[index] * quant_arg->in1_args_.scale_ + in1_bias;
-    float out_real = (float)(in0_real < in1_real);
-    output[index] = (int8_t)(out_real * output_inverse_scale + out_zp);
+    bool out_real = in0_real < in1_real;
+    output[index] = (uint8_t)out_real;
   }
   return NNACL_OK;
 }
 
-int ElementLessEqualInt8(int8_t *input0, int8_t *input1, int8_t *output, int element_size,
+int ElementLessEqualInt8(int8_t *input0, int8_t *input1, uint8_t *output, int element_size,
                          ArithmeticQuantArg *quant_arg) {
   float in0_bias = -quant_arg->in0_args_.zp_ * quant_arg->in0_args_.scale_;
   float in1_bias = -quant_arg->in1_args_.zp_ * quant_arg->in1_args_.scale_;
-  const float output_inverse_scale = 1.f / quant_arg->out_args_.scale_;
-  float out_zp = quant_arg->out_args_.zp_;
   for (int index = 0; index < element_size; ++index) {
     float in0_real = input0[index] * quant_arg->in0_args_.scale_ + in0_bias;
     float in1_real = input1[index] * quant_arg->in1_args_.scale_ + in1_bias;
-    float out_real = (float)(in0_real <= in1_real);
-    output[index] = (int8_t)(out_real * output_inverse_scale + out_zp);
+    bool out_real = in0_real <= in1_real;
+    output[index] = (uint8_t)out_real;
   }
   return NNACL_OK;
 }
 
-int ElementGreaterInt8(int8_t *input0, int8_t *input1, int8_t *output, int element_size,
+int ElementGreaterInt8(int8_t *input0, int8_t *input1, uint8_t *output, int element_size,
                        ArithmeticQuantArg *quant_arg) {
   float in0_bias = -quant_arg->in0_args_.zp_ * quant_arg->in0_args_.scale_;
   float in1_bias = -quant_arg->in1_args_.zp_ * quant_arg->in1_args_.scale_;
-  const float output_inverse_scale = 1.f / quant_arg->out_args_.scale_;
-  float out_zp = quant_arg->out_args_.zp_;
   for (int index = 0; index < element_size; ++index) {
     float in0_real = input0[index] * quant_arg->in0_args_.scale_ + in0_bias;
     float in1_real = input1[index] * quant_arg->in1_args_.scale_ + in1_bias;
-    float out_real = (float)(in0_real > in1_real);
-    output[index] = (int8_t)(out_real * output_inverse_scale + out_zp);
+    bool out_real = in0_real > in1_real;
+    output[index] = (uint8_t)out_real;
   }
   return NNACL_OK;
 }
 
-int ElementGreaterEqualInt8(int8_t *input0, int8_t *input1, int8_t *output, int element_size,
+int ElementGreaterEqualInt8(int8_t *input0, int8_t *input1, uint8_t *output, int element_size,
                             ArithmeticQuantArg *quant_arg) {
   float in0_bias = -quant_arg->in0_args_.zp_ * quant_arg->in0_args_.scale_;
   float in1_bias = -quant_arg->in1_args_.zp_ * quant_arg->in1_args_.scale_;
-  const float output_inverse_scale = 1.f / quant_arg->out_args_.scale_;
-  float out_zp = quant_arg->out_args_.zp_;
   for (int index = 0; index < element_size; ++index) {
     float in0_real = input0[index] * quant_arg->in0_args_.scale_ + in0_bias;
     float in1_real = input1[index] * quant_arg->in1_args_.scale_ + in1_bias;
-    float out_real = (float)(in0_real >= in1_real);
-    output[index] = (int8_t)(out_real * output_inverse_scale + out_zp);
+    bool out_real = in0_real >= in1_real;
+    output[index] = (uint8_t)out_real;
   }
   return NNACL_OK;
 }
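
Note the behavioural change for quantized comparisons: the inputs are still dequantized with their own scale and zero-point, but the result is now written as a raw 0/1 byte, and the output quantization parameters are no longer applied. A hedged sketch of the new contract (the include path is assumed; QuantArg field names are taken from the code above, values are made up):

// Sketch only. With scale 0.5 and zp 0: a = {5.0, 10.0}, b = {5.0, 15.0}.
#include "nnacl/int8/arithmetic_int8.h"

void demo_int8_compare(void) {
  int8_t a[2] = {10, 20};
  int8_t b[2] = {10, 30};
  uint8_t out[2];
  ArithmeticQuantArg q;
  q.in0_args_.scale_ = 0.5f;  q.in0_args_.zp_ = 0;
  q.in1_args_.scale_ = 0.5f;  q.in1_args_.zp_ = 0;
  q.out_args_.scale_ = 1.0f;  q.out_args_.zp_ = 0;  // ignored by the compare ops now
  ElementEqualInt8(a, b, out, 2, &q);                // out = {1, 0}
}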

@@ -22,19 +22,20 @@
 #ifdef __cplusplus
 extern "C" {
 #endif
-int ElementNotEqualInt8(int8_t *input0, int8_t *input1, int8_t *output, int element_size,
+int ElementNotEqualInt8(int8_t *input0, int8_t *input1, uint8_t *output, int element_size,
                         ArithmeticQuantArg *quant_arg);
-int ElementEqualInt8(int8_t *input0, int8_t *input1, int8_t *output, int element_size, ArithmeticQuantArg *quant_arg);
+int ElementEqualInt8(int8_t *input0, int8_t *input1, uint8_t *output, int element_size, ArithmeticQuantArg *quant_arg);
-int ElementLessInt8(int8_t *input0, int8_t *input1, int8_t *output, int element_size, ArithmeticQuantArg *quant_arg);
+int ElementLessInt8(int8_t *input0, int8_t *input1, uint8_t *output, int element_size, ArithmeticQuantArg *quant_arg);
-int ElementLessEqualInt8(int8_t *input0, int8_t *input1, int8_t *output, int element_size,
+int ElementLessEqualInt8(int8_t *input0, int8_t *input1, uint8_t *output, int element_size,
                          ArithmeticQuantArg *quant_arg);
-int ElementGreaterInt8(int8_t *input0, int8_t *input1, int8_t *output, int element_size, ArithmeticQuantArg *quant_arg);
+int ElementGreaterInt8(int8_t *input0, int8_t *input1, uint8_t *output, int element_size,
+                       ArithmeticQuantArg *quant_arg);
-int ElementGreaterEqualInt8(int8_t *input0, int8_t *input1, int8_t *output, int element_size,
+int ElementGreaterEqualInt8(int8_t *input0, int8_t *input1, uint8_t *output, int element_size,
                             ArithmeticQuantArg *quant_arg);
 #ifdef __cplusplus

@@ -0,0 +1,216 @@
/**
 * Copyright 2020 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
#include "src/runtime/kernel/arm/fp16/arithmetic_compare_fp16.h"
#include "src/runtime/kernel/arm/fp16/common_fp16.h"
#include "nnacl/fp16/arithmetic_fp16.h"
#include "nnacl/fp16/cast_fp16.h"
#include "schema/model_generated.h"
#include "src/kernel_registry.h"
#include "src/runtime/runtime_api.h"
#include "include/errorcode.h"

using mindspore::kernel::KERNEL_ARCH::kCPU;
using mindspore::lite::KernelRegistrar;
using mindspore::lite::RET_ERROR;
using mindspore::lite::RET_OK;
using mindspore::schema::PrimitiveType_Equal;
using mindspore::schema::PrimitiveType_Greater;
using mindspore::schema::PrimitiveType_GreaterEqual;
using mindspore::schema::PrimitiveType_Less;
using mindspore::schema::PrimitiveType_LessEqual;
using mindspore::schema::PrimitiveType_NotEqual;

namespace mindspore::kernel {
ARITHMETIC_COMPARE_FUNC_INFO_FP16 arithmetic_cp_fun_table_fp16[] = {
  {PrimitiveType_NotEqual, schema::ActivationType_NO_ACTIVATION, ElementNotEqualFp16, ElementOptNotEqualFp16},
  {PrimitiveType_Equal, schema::ActivationType_NO_ACTIVATION, ElementEqualFp16, ElementOptEqualFp16},
  {PrimitiveType_Less, schema::ActivationType_NO_ACTIVATION, ElementLessFp16, ElementOptLessFp16},
  {PrimitiveType_LessEqual, schema::ActivationType_NO_ACTIVATION, ElementLessEqual, ElementOptLessEqualFp16},
  {PrimitiveType_Greater, schema::ActivationType_NO_ACTIVATION, ElementGreaterFp16, ElementOptGreaterFp16},
  {PrimitiveType_GreaterEqual, schema::ActivationType_NO_ACTIVATION, ElementGreaterEqualFp16,
   ElementOptGreaterEqualFp16}};

ArithmeticCompareFuncFp16 GetArithmeticCompareFun(int primitive_type, int activation_type) {
  for (size_t i = 0; i < sizeof(arithmetic_cp_fun_table_fp16); i++) {
    if (arithmetic_cp_fun_table_fp16[i].primitive_type_ == primitive_type &&
        arithmetic_cp_fun_table_fp16[i].activation_type_ == activation_type) {
      return arithmetic_cp_fun_table_fp16[i].func_;
    }
  }
  return nullptr;
}

ArithmeticCompareOptFuncFp16 GetOptimizedArithmeticCompareFun(int primitive_type, int activation_type) {
  for (size_t i = 0; i < sizeof(arithmetic_cp_fun_table_fp16); i++) {
    if (arithmetic_cp_fun_table_fp16[i].primitive_type_ == primitive_type &&
        arithmetic_cp_fun_table_fp16[i].activation_type_ == activation_type) {
      return arithmetic_cp_fun_table_fp16[i].opt_func_;
    }
  }
  return nullptr;
}

int ArithmeticCompareFP16CPUKernel::Init() {
  if (!InferShapeDone()) {
    return RET_OK;
  }
  return ReSize();
}

int ArithmeticCompareFP16CPUKernel::ReSize() {
  param_->in_elements_num0_ = in_tensors_[0]->ElementsNum();
  param_->in_elements_num1_ = in_tensors_[1]->ElementsNum();
  param_->out_elements_num_ = out_tensors_[0]->ElementsNum();

  if (param_->in_elements_num0_ == 1 || param_->in_elements_num1_ == 1) {
    param_->broadcasting_ = false;
    arithmetic_opt_func_ = GetOptimizedArithmeticCompareFun(param_->op_parameter_.type_, param_->activation_type_);
  } else {
    arithmetic_func_ = GetArithmeticCompareFun(param_->op_parameter_.type_, param_->activation_type_);
  }
  if (arithmetic_opt_func_ == nullptr && arithmetic_func_ == nullptr) {
    MS_LOG(ERROR) << "arithmetic_opt_func_ and arithmetic_func_ function is nullptr!";
    return RET_ERROR;
  }
  if (param_->broadcasting_) {
    outside_ = 1;
    for (int i = param_->ndim_ - 1; i >= 0; --i) {
      if (param_->in_shape0_[i] != param_->in_shape1_[i]) {
        break_pos_ = i;
        break;
      }
      outside_ *= param_->out_shape_[i];
    }
    ComputeStrides(param_->in_shape0_, param_->in_strides0_, param_->ndim_);
    ComputeStrides(param_->in_shape1_, param_->in_strides1_, param_->ndim_);
    ComputeStrides(param_->out_shape_, param_->out_strides_, param_->ndim_);
  }
  return RET_OK;
}

int ArithmeticCompareFP16CPUKernel::BroadcastRun(float16_t *input0, float16_t *input1, uint8_t *output, int dim,
                                                 int out_count, int cur_offset) {
  if (dim > break_pos_) {
    return arithmetic_func_(input0 + cur_offset, input1 + cur_offset, output + cur_offset, out_count);
  }
  for (int i = 0; i < param_->out_shape_[dim]; ++i) {
    int pos0 = param_->in_shape0_[dim] == 1 ? 0 : i;
    int pos1 = param_->in_shape1_[dim] == 1 ? 0 : i;
    int ret = BroadcastRun(input0 + pos0 * param_->in_strides0_[dim], input1 + pos1 * param_->in_strides1_[dim],
                           output + i * param_->out_strides_[dim], dim + 1, out_count, cur_offset);
    if (ret != RET_OK) {
      return RET_ERROR;
    }
  }
  return RET_OK;
}

int ArithmeticCompareFP16CPUKernel::DoArithmetic(int task_id) {
  int stride_per_thread = UP_DIV(param_->broadcasting_ ? outside_ : param_->out_elements_num_, context_->thread_num_);
  int cur_offset = stride_per_thread * task_id;
  int cur_count = param_->broadcasting_ ? MSMIN(stride_per_thread, outside_ - cur_offset)
                                        : MSMIN(stride_per_thread, param_->out_elements_num_ - cur_offset);

  int ret = RET_OK;
  if (param_->broadcasting_) {
    ret = BroadcastRun(input0_fp16_, input1_fp16_, output_fp16_, 0, cur_count, cur_offset);
  } else if (param_->in_elements_num0_ == 1) {
    ret = arithmetic_opt_func_(input0_fp16_, input1_fp16_ + cur_offset, output_fp16_ + cur_offset, cur_count, param_);
  } else if (param_->in_elements_num1_ == 1) {
    ret = arithmetic_opt_func_(input0_fp16_ + cur_offset, input1_fp16_, output_fp16_ + cur_offset, cur_count, param_);
  } else {
    ret = arithmetic_func_(input0_fp16_ + cur_offset, input1_fp16_ + cur_offset, output_fp16_ + cur_offset, cur_count);
  }
  if (ret != RET_OK) {
    MS_LOG(ERROR) << "DoArithmetic failed, ret = " << ret;
  }
  return ret;
}

static int ArithmeticsRunFp16(void *cdata, int task_id) {
  auto arithmetic_kernel = reinterpret_cast<ArithmeticCompareFP16CPUKernel *>(cdata);
  auto ret = arithmetic_kernel->DoArithmetic(task_id);
  if (ret != RET_OK) {
    MS_LOG(ERROR) << "ArithmeticsRunFp16 error task_id[" << task_id << "] ret[" << ret << "]";
  }
  return ret;
}

int ArithmeticCompareFP16CPUKernel::Run() {
  auto output_tensor = out_tensors_.at(0);
  is_input0_fp32_ = in_tensors_.at(0)->data_type() == kNumberTypeFloat32;
  is_input1_fp32_ = in_tensors_.at(1)->data_type() == kNumberTypeFloat32;

  input0_fp16_ = ConvertInputFp32toFp16(in_tensors_.at(0), context_);
  input1_fp16_ = ConvertInputFp32toFp16(in_tensors_.at(1), context_);
  output_fp16_ = reinterpret_cast<uint8_t *>(output_tensor->MutableData());
  if (input0_fp16_ == nullptr || input1_fp16_ == nullptr || output_fp16_ == nullptr) {
    MS_LOG(ERROR) << "Memory allocation failed";
    FreeTmpBuffer();
    return RET_ERROR;
  }
  auto ret = ParallelLaunch(this->context_->thread_pool_, ArithmeticsRunFp16, this, context_->thread_num_);
  if (ret != RET_OK) {
    MS_LOG(ERROR) << "ArithmeticsRunFp16 run error error_code[" << ret << "]";
  }
  FreeTmpBuffer();
  return ret;
}

void ArithmeticCompareFP16CPUKernel::FreeTmpBuffer() {
  if (is_input0_fp32_) {
    context_->allocator->Free(input0_fp16_);
    input0_fp16_ = nullptr;
  }
  if (is_input1_fp32_) {
    context_->allocator->Free(input1_fp16_);
    input1_fp16_ = nullptr;
  }
}

kernel::LiteKernel *CpuArithmeticCompareFp16KernelCreator(const std::vector<lite::Tensor *> &inputs,
                                                          const std::vector<lite::Tensor *> &outputs,
                                                          OpParameter *parameter, const lite::InnerContext *ctx,
                                                          const kernel::KernelKey &desc,
                                                          const mindspore::lite::PrimitiveC *primitive) {
  if (parameter == nullptr) {
    MS_LOG(ERROR) << "input parameter is null!";
    return nullptr;
  }
  auto kernel = new (std::nothrow) ArithmeticCompareFP16CPUKernel(parameter, inputs, outputs, ctx, primitive);
  if (kernel == nullptr) {
    MS_LOG(ERROR) << "Create kernel failed, name: " << parameter->name_;
    free(parameter);
    return nullptr;
  }
  auto ret = kernel->Init();
  if (ret != RET_OK) {
    MS_LOG(ERROR) << "Init kernel failed, name: " << parameter->name_
                  << ", type: " << schema::EnumNamePrimitiveType(static_cast<schema::PrimitiveType>(parameter->type_));
    delete kernel;
    return nullptr;
  }
  return kernel;
}

REG_KERNEL(kCPU, kNumberTypeFloat16, PrimitiveType_NotEqual, CpuArithmeticCompareFp16KernelCreator)
REG_KERNEL(kCPU, kNumberTypeFloat16, PrimitiveType_Equal, CpuArithmeticCompareFp16KernelCreator)
REG_KERNEL(kCPU, kNumberTypeFloat16, PrimitiveType_Less, CpuArithmeticCompareFp16KernelCreator)
REG_KERNEL(kCPU, kNumberTypeFloat16, PrimitiveType_LessEqual, CpuArithmeticCompareFp16KernelCreator)
REG_KERNEL(kCPU, kNumberTypeFloat16, PrimitiveType_Greater, CpuArithmeticCompareFp16KernelCreator)
REG_KERNEL(kCPU, kNumberTypeFloat16, PrimitiveType_GreaterEqual, CpuArithmeticCompareFp16KernelCreator)
} // namespace mindspore::kernel
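
One detail worth flagging in the lookup helpers above: sizeof(arithmetic_cp_fun_table_fp16) is the table size in bytes, not its element count, so the loop bound overshoots the array (the same pattern appears in GetArithmeticCompareFun in the fp32 kernel below). A hypothetical bounds-safe variant of the loop, not part of the patch:

// Sketch only: divide by the element size to get the entry count.
const size_t kFp16TableLen =
    sizeof(arithmetic_cp_fun_table_fp16) / sizeof(arithmetic_cp_fun_table_fp16[0]);
for (size_t i = 0; i < kFp16TableLen; i++) {
  if (arithmetic_cp_fun_table_fp16[i].primitive_type_ == primitive_type &&
      arithmetic_cp_fun_table_fp16[i].activation_type_ == activation_type) {
    return arithmetic_cp_fun_table_fp16[i].func_;
  }
}
return nullptr;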

@@ -0,0 +1,67 @@
/**
 * Copyright 2020 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
#ifndef MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_FP16_ARITHMETIC_COMPARE_FP16_H_
#define MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_FP16_ARITHMETIC_COMPARE_FP16_H_

#include <vector>
#include "src/lite_kernel.h"
#include "nnacl/fp16/arithmetic_fp16.h"
#include "schema/model_generated.h"

namespace mindspore::kernel {
typedef int (*ArithmeticCompareFuncFp16)(float16_t *input0, float16_t *input1, uint8_t *output, int element_size);
typedef int (*ArithmeticCompareOptFuncFp16)(float16_t *input0, float16_t *input1, uint8_t *output, int element_size,
                                            ArithmeticParameter *param);
typedef struct {
  int primitive_type_;
  int activation_type_;
  ArithmeticCompareFuncFp16 func_;
  ArithmeticCompareOptFuncFp16 opt_func_;
} ARITHMETIC_COMPARE_FUNC_INFO_FP16;

class ArithmeticCompareFP16CPUKernel : public LiteKernel {
 public:
  ArithmeticCompareFP16CPUKernel(OpParameter *parameter, const std::vector<lite::Tensor *> &inputs,
                                 const std::vector<lite::Tensor *> &outputs, const lite::InnerContext *ctx,
                                 const mindspore::lite::PrimitiveC *primitive)
      : LiteKernel(parameter, inputs, outputs, ctx, primitive) {
    param_ = reinterpret_cast<ArithmeticParameter *>(parameter);
  }
  ~ArithmeticCompareFP16CPUKernel() = default;

  int Init() override;
  int ReSize() override;
  int Run() override;
  int DoArithmetic(int task_id);
  int BroadcastRun(float16_t *input0, float16_t *input1, uint8_t *output, int dim, int out_count,
                   int out_thread_stride);

 private:
  void FreeTmpBuffer();
  int outside_;
  int break_pos_;
  bool is_input0_fp32_ = false;
  bool is_input1_fp32_ = false;
  float16_t *input0_fp16_ = nullptr;
  float16_t *input1_fp16_ = nullptr;
  uint8_t *output_fp16_ = nullptr;
  ArithmeticParameter *param_ = nullptr;
  ArithmeticCompareFuncFp16 arithmetic_func_ = nullptr;
  ArithmeticCompareOptFuncFp16 arithmetic_opt_func_ = nullptr;
};
} // namespace mindspore::kernel

#endif // MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_FP16_ARITHMETIC_COMPARE_FP16_H_

@@ -68,15 +68,7 @@ ARITHMETIC_FUNC_INFO_FP16 arithmetic_fun_table_fp16[] = {
   {PrimitiveType_SquaredDifference, schema::ActivationType_NO_ACTIVATION, ElementSquaredDifferenceFp16,
    ElementOptSquaredDifferenceFp16},
   {PrimitiveType_Maximum, schema::ActivationType_NO_ACTIVATION, ElementMaximumFp16, ElementOptMaximumFp16},
-  {PrimitiveType_Minimum, schema::ActivationType_NO_ACTIVATION, ElementMinimumFp16, ElementOptMinimumFp16},
-  {PrimitiveType_NotEqual, schema::ActivationType_NO_ACTIVATION, ElementNotEqualFp16, ElementOptNotEqualFp16},
-  {PrimitiveType_Equal, schema::ActivationType_NO_ACTIVATION, ElementEqualFp16, ElementOptEqualFp16},
-  {PrimitiveType_Less, schema::ActivationType_NO_ACTIVATION, ElementLessFp16, ElementOptLessFp16},
-  {PrimitiveType_LessEqual, schema::ActivationType_NO_ACTIVATION, ElementLessEqual, ElementOptLessEqualFp16},
-  {PrimitiveType_Greater, schema::ActivationType_NO_ACTIVATION, ElementGreaterFp16, ElementOptGreaterFp16},
-  {PrimitiveType_GreaterEqual, schema::ActivationType_NO_ACTIVATION, ElementGreaterEqualFp16,
-   ElementOptGreaterEqualFp16},
-};
+  {PrimitiveType_Minimum, schema::ActivationType_NO_ACTIVATION, ElementMinimumFp16, ElementOptMinimumFp16}};
 
 ArithmeticFuncFp16 GetArithmeticFun(int primitive_type, int activation_type) {
   for (size_t i = 0; i < sizeof(arithmetic_fun_table_fp16); i++) {
@@ -67,6 +67,12 @@ int CastFp16CPUKernel::DoCast(int thread_id) {
   auto offset = thread_id * stride_;
   auto output_data = out_tensors_.at(0)->MutableData();
   switch (input->data_type()) {
+    case kNumberTypeBool:
+      BoolToFloat16(reinterpret_cast<bool *>(input->MutableData()) + offset,
+                    reinterpret_cast<float16_t *>(output_data) + offset, data_num);
+    case kNumberTypeUInt8:
+      Uint8ToFloat16(reinterpret_cast<uint8_t *>(input->MutableData()) + offset,
+                     reinterpret_cast<float16_t *>(output_data) + offset, data_num);
     case kNumberTypeFloat32:
       Float32ToFloat16(reinterpret_cast<float *>(input->MutableData()) + offset,
                        reinterpret_cast<float16_t *>(output_data) + offset, data_num);
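
As displayed (the hunk window ends before any break), the two new case labels appear to fall through, so a Bool input would continue into the UInt8 and Float32 conversions and overwrite the output. If the full file really does lack terminators, the intended form is presumably:

// Hypothetical fix; only the break is new relative to the hunk above.
case kNumberTypeBool:
  BoolToFloat16(reinterpret_cast<bool *>(input->MutableData()) + offset,
                reinterpret_cast<float16_t *>(output_data) + offset, data_num);
  break;  // prevent falling through into the UInt8 path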

@@ -339,12 +339,6 @@ REG_KERNEL(kCPU, kNumberTypeFloat32, PrimitiveType_FloorMod, CpuArithmeticFp32Ke
 REG_KERNEL(kCPU, kNumberTypeInt, PrimitiveType_FloorDiv, CpuArithmeticFp32KernelCreator)
 REG_KERNEL(kCPU, kNumberTypeInt, PrimitiveType_FloorMod, CpuArithmeticFp32KernelCreator)
 REG_KERNEL(kCPU, kNumberTypeFloat32, PrimitiveType_SquaredDifference, CpuArithmeticFp32KernelCreator)
-REG_KERNEL(kCPU, kNumberTypeFloat32, PrimitiveType_Equal, CpuArithmeticFp32KernelCreator)
-REG_KERNEL(kCPU, kNumberTypeFloat32, PrimitiveType_NotEqual, CpuArithmeticFp32KernelCreator)
-REG_KERNEL(kCPU, kNumberTypeFloat32, PrimitiveType_Less, CpuArithmeticFp32KernelCreator)
-REG_KERNEL(kCPU, kNumberTypeFloat32, PrimitiveType_LessEqual, CpuArithmeticFp32KernelCreator)
-REG_KERNEL(kCPU, kNumberTypeFloat32, PrimitiveType_Greater, CpuArithmeticFp32KernelCreator)
-REG_KERNEL(kCPU, kNumberTypeFloat32, PrimitiveType_GreaterEqual, CpuArithmeticFp32KernelCreator)
 REG_KERNEL(kCPU, kNumberTypeFloat32, PrimitiveType_Eltwise, CpuArithmeticFp32KernelCreator)
 } // namespace mindspore::kernel

@@ -0,0 +1,127 @@
/**
 * Copyright 2020 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
#include "src/runtime/kernel/arm/fp32/arithmetic_compare.h"
#include "src/kernel_registry.h"
#include "nnacl/fp32/arithmetic_compare.h"

using mindspore::lite::KernelRegistrar;
using mindspore::lite::RET_ERROR;
using mindspore::lite::RET_OK;
using mindspore::schema::PrimitiveType_Equal;
using mindspore::schema::PrimitiveType_Greater;
using mindspore::schema::PrimitiveType_GreaterEqual;
using mindspore::schema::PrimitiveType_Less;
using mindspore::schema::PrimitiveType_LessEqual;
using mindspore::schema::PrimitiveType_NotEqual;

namespace mindspore::kernel {
namespace {
typedef struct {
  int primitive_type_;
  ArithmeticCompareFp32Func func_;
} TYPE_FUNC_INFO;
}  // namespace

ArithmeticCompareFp32Func ArithmeticCompareCPUKernel::GetArithmeticCompareFun(int primitive_type) {
  TYPE_FUNC_INFO type_func_table[] = {
    {PrimitiveType_Equal, ElementEqualFp32},     {PrimitiveType_NotEqual, ElementNotEqualFp32},
    {PrimitiveType_Less, ElementLessFp32},       {PrimitiveType_LessEqual, ElementLessEqualFp32},
    {PrimitiveType_Greater, ElementGreaterFp32}, {PrimitiveType_GreaterEqual, ElementGreaterEqualFp32}};
  for (size_t i = 0; i < sizeof(type_func_table); i++) {
    if (type_func_table[i].primitive_type_ == primitive_type) {
      return type_func_table[i].func_;
    }
  }
  return nullptr;
}

int ArithmeticCompareCPUKernel::Init() {
  if (!InferShapeDone()) {
    return RET_OK;
  }
  return ReSize();
}

int ArithmeticCompareCPUKernel::ReSize() { return RET_OK; }

int ArithmeticCompareCPUKernel::DoExecute(int task_id) {
  int elements_num = in_tensors_.at(0)->ElementsNum();
  int stride = UP_DIV(elements_num, op_parameter_->thread_num_);
  int offset = task_id * stride;
  int count = MSMIN(stride, elements_num - offset);
  if (count <= 0) {
    return RET_OK;
  }
  if (func_ == nullptr) {
    MS_LOG(ERROR) << "Run function is null! ";
    return RET_ERROR;
  }
  // two inputs have the same shape, support broadcast later
  auto *input0_ptr = reinterpret_cast<float *>(in_tensors_.at(0)->MutableData());
  auto *input1_ptr = reinterpret_cast<float *>(in_tensors_.at(1)->MutableData());
  auto *output_ptr = reinterpret_cast<uint8_t *>(out_tensors_.at(0)->MutableData());
  auto ret = func_(input0_ptr + offset, input1_ptr + offset, output_ptr + offset, count);
  if (ret != RET_OK) {
    MS_LOG(ERROR) << "Run failed, illegal input! ";
  }
  return ret;
}

int ArithmeticCompareRun(void *cdata, int task_id) {
  auto kernel = reinterpret_cast<ArithmeticCompareCPUKernel *>(cdata);
  auto ret = kernel->DoExecute(task_id);
  if (ret != RET_OK) {
    MS_LOG(ERROR) << "ArithmeticSelfRuns error task_id[" << task_id << "] error_code[" << ret << "]";
  }
  return ret;
}

int ArithmeticCompareCPUKernel::Run() {
  auto ret = ParallelLaunch(this->context_->thread_pool_, ArithmeticCompareRun, this, op_parameter_->thread_num_);
  if (ret != RET_OK) {
    MS_LOG(ERROR) << "ArithmeticSelfRun error error_code[" << ret << "]";
  }
  return ret;
}

kernel::LiteKernel *CpuArithmeticCompareFp32KernelCreator(const std::vector<lite::Tensor *> &inputs,
                                                          const std::vector<lite::Tensor *> &outputs,
                                                          OpParameter *parameter, const lite::InnerContext *ctx,
                                                          const kernel::KernelKey &desc,
                                                          const mindspore::lite::PrimitiveC *primitive) {
  auto *kernel = new (std::nothrow) ArithmeticCompareCPUKernel(parameter, inputs, outputs, ctx, primitive);
  if (kernel == nullptr) {
    MS_LOG(ERROR) << "new ArithmeticSelfCPUKernel fail!";
    free(parameter);
    return nullptr;
  }
  auto ret = kernel->Init();
  if (ret != RET_OK) {
    MS_LOG(ERROR) << "Init kernel failed, name: " << parameter->name_
                  << ", type: " << schema::EnumNamePrimitiveType(static_cast<schema::PrimitiveType>(parameter->type_));
    delete kernel;
    return nullptr;
  }
  return kernel;
}

REG_KERNEL(kCPU, kNumberTypeFloat32, PrimitiveType_Equal, CpuArithmeticCompareFp32KernelCreator)
REG_KERNEL(kCPU, kNumberTypeFloat32, PrimitiveType_NotEqual, CpuArithmeticCompareFp32KernelCreator)
REG_KERNEL(kCPU, kNumberTypeFloat32, PrimitiveType_Less, CpuArithmeticCompareFp32KernelCreator)
REG_KERNEL(kCPU, kNumberTypeFloat32, PrimitiveType_LessEqual, CpuArithmeticCompareFp32KernelCreator)
REG_KERNEL(kCPU, kNumberTypeFloat32, PrimitiveType_Greater, CpuArithmeticCompareFp32KernelCreator)
REG_KERNEL(kCPU, kNumberTypeFloat32, PrimitiveType_GreaterEqual, CpuArithmeticCompareFp32KernelCreator)
} // namespace mindspore::kernel
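
DoExecute above splits the flat element range evenly across threads with UP_DIV and clamps the last slice with MSMIN. A quick worked sketch of that arithmetic (hypothetical numbers, plain C++ stand-ins for the macros):

#include <algorithm>
#include <cstdio>

// Mirrors the slicing in ArithmeticCompareCPUKernel::DoExecute.
int main() {
  const int elements_num = 10, thread_num = 4;
  const int stride = (elements_num + thread_num - 1) / thread_num;  // UP_DIV -> 3
  for (int task_id = 0; task_id < thread_num; ++task_id) {
    const int offset = task_id * stride;
    const int count = std::min(stride, elements_num - offset);      // MSMIN clamps the tail
    if (count <= 0) continue;
    std::printf("task %d: offset %d, count %d\n", task_id, offset, count);
    // -> (0,3) (3,3) (6,3) (9,1)
  }
  return 0;
}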

@@ -0,0 +1,46 @@
/**
 * Copyright 2020 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
#ifndef MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_FP32_ARITHMETIC_COMPARE_H_
#define MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_FP32_ARITHMETIC_COMPARE_H_

#include <vector>
#include "src/runtime/kernel/arm/fp32/arithmetic.h"

namespace mindspore::kernel {
typedef int (*ArithmeticCompareFp32Func)(const float *input0, const float *input1, uint8_t *output, int element_size);

class ArithmeticCompareCPUKernel : public ArithmeticCPUKernel {
 public:
  explicit ArithmeticCompareCPUKernel(OpParameter *parameter, const std::vector<lite::Tensor *> &inputs,
                                      const std::vector<lite::Tensor *> &outputs, const lite::InnerContext *ctx,
                                      const mindspore::lite::PrimitiveC *primitive)
      : ArithmeticCPUKernel(parameter, inputs, outputs, ctx, primitive) {
    func_ = GetArithmeticCompareFun(parameter->type_);
  }
  ~ArithmeticCompareCPUKernel() override = default;

  int Init() override;
  int ReSize() override;
  int Run() override;
  virtual int DoExecute(int task_id);

 private:
  ArithmeticCompareFp32Func GetArithmeticCompareFun(int primitive_type);
  ArithmeticCompareFp32Func func_;
};

int ArithmeticCompareRun(void *cdata, int task_id);
} // namespace mindspore::kernel

#endif // MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_FP32_ARITHMETIC_COMPARE_H_

@@ -211,7 +211,6 @@ kernel::LiteKernel *CpuGroupConvFp32KernelCreator(const std::vector<lite::Tensor
   filter_shape = {new_out_channel, kernel_h, kernel_w, new_in_channel};
   bias_shape = {new_out_channel};
   auto *origin_weight = reinterpret_cast<float *>(inputs.at(kWeightIndex)->data_c());
-  auto *origin_bias = reinterpret_cast<float *>(inputs.at(kBiasIndex)->data_c());
 
   for (int i = 0; i < group; ++i) {
     std::vector<lite::Tensor *> new_inputs;
@@ -234,6 +233,7 @@ kernel::LiteKernel *CpuGroupConvFp32KernelCreator(const std::vector<lite::Tensor
     // if has bias, set new bias
     if (has_bias) {
+      auto *origin_bias = reinterpret_cast<float *>(inputs.at(kBiasIndex)->data_c());
       auto bias_tensor = new (std::nothrow)
         lite::Tensor(inputs.at(kBiasIndex)->data_type(), bias_shape, Format_NHWC, lite::Tensor::Category::CONST_TENSOR);
       bias_tensor->MallocData();

@@ -104,7 +104,7 @@ int ArithmeticInt8CPUKernel::ReSize() { return RET_OK; }
 int ArithmeticInt8CPUKernel::DoArithmetic(int thread_id) {
   auto input0_data = reinterpret_cast<int8_t *>(in_tensors_[0]->MutableData());
   auto input1_data1 = reinterpret_cast<int8_t *>(in_tensors_[1]->MutableData());
-  auto output_data = reinterpret_cast<int8_t *>(out_tensors_[0]->MutableData());
+  auto output_data = reinterpret_cast<uint8_t *>(out_tensors_[0]->MutableData());
   auto element_num = out_tensors_[0]->ElementsNum();
   auto param = reinterpret_cast<ArithmeticParameter *>(op_parameter_);
   if (param->broadcasting_ && arithmetic_run_ != nullptr) {

@@ -24,7 +24,7 @@
 namespace mindspore::kernel {
 class ArithmeticInt8CPUKernel : public LiteKernel {
-  typedef int (*ArithmeticRunInt8)(int8_t *input0, int8_t *input1, int8_t *output, int element_size,
+  typedef int (*ArithmeticRunInt8)(int8_t *input0, int8_t *input1, uint8_t *output, int element_size,
                                    ArithmeticQuantArg *quant_arg);
 
  public:
