commit
87ed2da6df
File diff suppressed because it is too large
Load Diff
@ -0,0 +1,109 @@
|
||||
/**
|
||||
* Copyright 2020 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#include <string.h>
|
||||
#include <math.h>
|
||||
#include "nnacl/fp32/arithmetic_compare.h"
|
||||
|
||||
// equal:
|
||||
int ElementEqualFp32(const float *input0, const float *input1, uint8_t *output, int element_size) {
|
||||
for (int i = 0; i < element_size; i++) {
|
||||
output[i] = input0[i] == input1[i];
|
||||
}
|
||||
return NNACL_OK;
|
||||
}
|
||||
|
||||
int ElementEqualInt32(const int32_t *input0, const int32_t *input1, uint8_t *output, int element_size) {
|
||||
for (int i = 0; i < element_size; i++) {
|
||||
output[i] = input0[i] == input1[i];
|
||||
}
|
||||
return NNACL_OK;
|
||||
}
|
||||
|
||||
// not equal
|
||||
int ElementNotEqualFp32(const float *input0, const float *input1, uint8_t *output, int element_size) {
|
||||
for (int i = 0; i < element_size; i++) {
|
||||
output[i] = input0[i] != input1[i];
|
||||
}
|
||||
return NNACL_OK;
|
||||
}
|
||||
|
||||
int ElementNotEqualInt32(const int32_t *input0, const int32_t *input1, uint8_t *output, int element_size) {
|
||||
for (int i = 0; i < element_size; i++) {
|
||||
output[i] = input0[i] != input1[i];
|
||||
}
|
||||
return NNACL_OK;
|
||||
}
|
||||
|
||||
// less
|
||||
int ElementLessFp32(const float *input0, const float *input1, uint8_t *output, int element_size) {
|
||||
for (int i = 0; i < element_size; i++) {
|
||||
output[i] = input0[i] < input1[i];
|
||||
}
|
||||
return NNACL_OK;
|
||||
}
|
||||
|
||||
int ElementLessInt32(const int32_t *input0, const int32_t *input1, uint8_t *output, int element_size) {
|
||||
for (int i = 0; i < element_size; i++) {
|
||||
output[i] = input0[i] < input1[i];
|
||||
}
|
||||
return NNACL_OK;
|
||||
}
|
||||
|
||||
// less equal
|
||||
int ElementLessEqualFp32(const float *input0, const float *input1, uint8_t *output, int element_size) {
|
||||
for (int i = 0; i < element_size; i++) {
|
||||
output[i] = input0[i] <= input1[i];
|
||||
}
|
||||
return NNACL_OK;
|
||||
}
|
||||
|
||||
int ElementLessEqualInt32(const int32_t *input0, const int32_t *input1, uint8_t *output, int element_size) {
|
||||
for (int i = 0; i < element_size; i++) {
|
||||
output[i] = input0[i] <= input1[i];
|
||||
}
|
||||
return NNACL_OK;
|
||||
}
|
||||
|
||||
// greater
|
||||
int ElementGreaterFp32(const float *input0, const float *input1, uint8_t *output, int element_size) {
|
||||
for (int i = 0; i < element_size; i++) {
|
||||
output[i] = input0[i] > input1[i];
|
||||
}
|
||||
return NNACL_OK;
|
||||
}
|
||||
|
||||
int ElementGreaterInt32(const int32_t *input0, const int32_t *input1, uint8_t *output, int element_size) {
|
||||
for (int i = 0; i < element_size; i++) {
|
||||
output[i] = input0[i] > input1[i];
|
||||
}
|
||||
return NNACL_OK;
|
||||
}
|
||||
|
||||
// greater equal
|
||||
int ElementGreaterEqualFp32(const float *input0, const float *input1, uint8_t *output, int element_size) {
|
||||
for (int i = 0; i < element_size; i++) {
|
||||
output[i] = input0[i] >= input1[i];
|
||||
}
|
||||
return NNACL_OK;
|
||||
}
|
||||
|
||||
int ElementGreaterEqualInt32(const int32_t *input0, const int32_t *input1, uint8_t *output, int element_size) {
|
||||
for (int i = 0; i < element_size; i++) {
|
||||
output[i] = input0[i] >= input1[i];
|
||||
}
|
||||
return NNACL_OK;
|
||||
}
|
@ -0,0 +1,50 @@
|
||||
/**
|
||||
* Copyright 2020 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#ifndef MINDSPORE_LITE_NNACL_ARITHMETIC_COMPARE_H_
|
||||
#define MINDSPORE_LITE_NNACL_ARITHMETIC_COMPARE_H_
|
||||
|
||||
#ifdef ENABLE_NEON
|
||||
#include <arm_neon.h>
|
||||
#endif
|
||||
#include "nnacl/op_base.h"
|
||||
#include "nnacl/errorcode.h"
|
||||
|
||||
#ifdef __cplusplus
|
||||
extern "C" {
|
||||
#endif
|
||||
int ElementEqualFp32(const float *input0, const float *input1, uint8_t *output, int element_size);
|
||||
int ElementEqualInt32(const int32_t *input0, const int32_t *input1, uint8_t *output, int element_size);
|
||||
|
||||
int ElementNotEqualFp32(const float *input0, const float *input1, uint8_t *output, int element_size);
|
||||
int ElementNotEqualInt32(const int32_t *input0, const int32_t *input1, uint8_t *output, int element_size);
|
||||
|
||||
int ElementLessFp32(const float *input0, const float *input1, uint8_t *output, int element_size);
|
||||
int ElementLessInt32(const int32_t *input0, const int32_t *input1, uint8_t *output, int element_size);
|
||||
|
||||
int ElementLessEqualFp32(const float *input0, const float *input1, uint8_t *output, int element_size);
|
||||
int ElementLessEqualInt32(const int32_t *input0, const int32_t *input1, uint8_t *output, int element_size);
|
||||
|
||||
int ElementGreaterFp32(const float *input0, const float *input1, uint8_t *output, int element_size);
|
||||
int ElementGreaterInt32(const int32_t *input0, const int32_t *input1, uint8_t *output, int element_size);
|
||||
|
||||
int ElementGreaterEqualFp32(const float *input0, const float *input1, uint8_t *output, int element_size);
|
||||
int ElementGreaterEqualInt32(const int32_t *input0, const int32_t *input1, uint8_t *output, int element_size);
|
||||
#ifdef __cplusplus
|
||||
}
|
||||
#endif
|
||||
|
||||
#endif // MINDSPORE_LITE_NNACL_ARITHMETIC_COMPARE_H_
|
@ -0,0 +1,216 @@
|
||||
/**
|
||||
* Copyright 2020 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#include "src/runtime/kernel/arm/fp16/arithmetic_compare_fp16.h"
|
||||
#include "src/runtime/kernel/arm/fp16/common_fp16.h"
|
||||
#include "nnacl/fp16/arithmetic_fp16.h"
|
||||
#include "nnacl/fp16/cast_fp16.h"
|
||||
#include "schema/model_generated.h"
|
||||
#include "src/kernel_registry.h"
|
||||
#include "src/runtime/runtime_api.h"
|
||||
#include "include/errorcode.h"
|
||||
|
||||
using mindspore::kernel::KERNEL_ARCH::kCPU;
|
||||
using mindspore::lite::KernelRegistrar;
|
||||
using mindspore::lite::RET_ERROR;
|
||||
using mindspore::lite::RET_OK;
|
||||
using mindspore::schema::PrimitiveType_Equal;
|
||||
using mindspore::schema::PrimitiveType_Greater;
|
||||
using mindspore::schema::PrimitiveType_GreaterEqual;
|
||||
using mindspore::schema::PrimitiveType_Less;
|
||||
using mindspore::schema::PrimitiveType_LessEqual;
|
||||
using mindspore::schema::PrimitiveType_NotEqual;
|
||||
|
||||
namespace mindspore::kernel {
|
||||
// Dispatch table for the fp16 comparison kernels.
// Each row is {primitive type, activation type, element-wise function, "opt" function}; the lookup helpers
// in this file scan it for a row whose primitive and activation types both match.
// NOTE(review): the LessEqual row references ElementLessEqual (no Fp16 suffix) while every other row uses a
// *Fp16 name - confirm this matches the declaration in nnacl/fp16/arithmetic_fp16.h.
ARITHMETIC_COMPARE_FUNC_INFO_FP16 arithmetic_cp_fun_table_fp16[] = {
  {PrimitiveType_NotEqual, schema::ActivationType_NO_ACTIVATION, ElementNotEqualFp16, ElementOptNotEqualFp16},
  {PrimitiveType_Equal, schema::ActivationType_NO_ACTIVATION, ElementEqualFp16, ElementOptEqualFp16},
  {PrimitiveType_Less, schema::ActivationType_NO_ACTIVATION, ElementLessFp16, ElementOptLessFp16},
  {PrimitiveType_LessEqual, schema::ActivationType_NO_ACTIVATION, ElementLessEqual, ElementOptLessEqualFp16},
  {PrimitiveType_Greater, schema::ActivationType_NO_ACTIVATION, ElementGreaterFp16, ElementOptGreaterFp16},
  {PrimitiveType_GreaterEqual, schema::ActivationType_NO_ACTIVATION, ElementGreaterEqualFp16,
   ElementOptGreaterEqualFp16},
};
|
||||
|
||||
// Looks up the element-wise fp16 comparison function for the given primitive and activation type.
// Returns the func_ member of the first matching row of arithmetic_cp_fun_table_fp16, or nullptr when no
// row matches.
ArithmeticCompareFuncFp16 GetArithmeticCompareFun(int primitive_type, int activation_type) {
  // Iterate over the table entries themselves. The previous bound, sizeof(table), is the table size in
  // bytes rather than the number of rows, so an unmatched lookup walked far past the end of the array.
  for (const auto &entry : arithmetic_cp_fun_table_fp16) {
    if (entry.primitive_type_ == primitive_type && entry.activation_type_ == activation_type) {
      return entry.func_;
    }
  }
  return nullptr;
}
|
||||
|
||||
// Looks up the "opt" fp16 comparison function for the given primitive and activation type.
// Returns the opt_func_ member of the first matching row of arithmetic_cp_fun_table_fp16, or nullptr when
// no row matches.
ArithmeticCompareOptFuncFp16 GetOptimizedArithmeticCompareFun(int primitive_type, int activation_type) {
  // Iterate over the table entries themselves. The previous bound, sizeof(table), is the table size in
  // bytes rather than the number of rows, so an unmatched lookup walked far past the end of the array.
  for (const auto &entry : arithmetic_cp_fun_table_fp16) {
    if (entry.primitive_type_ == primitive_type && entry.activation_type_ == activation_type) {
      return entry.opt_func_;
    }
  }
  return nullptr;
}
|
||||
|
||||
int ArithmeticCompareFP16CPUKernel::Init() {
|
||||
if (!InferShapeDone()) {
|
||||
return RET_OK;
|
||||
}
|
||||
return ReSize();
|
||||
}
|
||||
|
||||
int ArithmeticCompareFP16CPUKernel::ReSize() {
|
||||
param_->in_elements_num0_ = in_tensors_[0]->ElementsNum();
|
||||
param_->in_elements_num1_ = in_tensors_[1]->ElementsNum();
|
||||
param_->out_elements_num_ = out_tensors_[0]->ElementsNum();
|
||||
|
||||
if (param_->in_elements_num0_ == 1 || param_->in_elements_num1_ == 1) {
|
||||
param_->broadcasting_ = false;
|
||||
arithmetic_opt_func_ = GetOptimizedArithmeticCompareFun(param_->op_parameter_.type_, param_->activation_type_);
|
||||
} else {
|
||||
arithmetic_func_ = GetArithmeticCompareFun(param_->op_parameter_.type_, param_->activation_type_);
|
||||
}
|
||||
if (arithmetic_opt_func_ == nullptr && arithmetic_func_ == nullptr) {
|
||||
MS_LOG(ERROR) << "arithmetic_opt_func_ and arithmetic_func_ function is nullptr!";
|
||||
return RET_ERROR;
|
||||
}
|
||||
if (param_->broadcasting_) {
|
||||
outside_ = 1;
|
||||
for (int i = param_->ndim_ - 1; i >= 0; --i) {
|
||||
if (param_->in_shape0_[i] != param_->in_shape1_[i]) {
|
||||
break_pos_ = i;
|
||||
break;
|
||||
}
|
||||
outside_ *= param_->out_shape_[i];
|
||||
}
|
||||
ComputeStrides(param_->in_shape0_, param_->in_strides0_, param_->ndim_);
|
||||
ComputeStrides(param_->in_shape1_, param_->in_strides1_, param_->ndim_);
|
||||
ComputeStrides(param_->out_shape_, param_->out_strides_, param_->ndim_);
|
||||
}
|
||||
return RET_OK;
|
||||
}
|
||||
|
||||
// Recursive broadcast driver.
// While dim <= break_pos_, iterates over out_shape_[dim]; an input whose extent in that dimension is 1
// always uses index 0 (it is broadcast), otherwise it advances together with the output. Once dim exceeds
// break_pos_, arithmetic_func_ is applied to out_count elements starting at cur_offset.
// Returns RET_ERROR as soon as a nested call does not return RET_OK.
int ArithmeticCompareFP16CPUKernel::BroadcastRun(float16_t *input0, float16_t *input1, uint8_t *output, int dim,
                                                 int out_count, int cur_offset) {
  if (dim > break_pos_) {
    return arithmetic_func_(input0 + cur_offset, input1 + cur_offset, output + cur_offset, out_count);
  }
  for (int idx = 0; idx < param_->out_shape_[dim]; ++idx) {
    const int lhs_index = (param_->in_shape0_[dim] == 1) ? 0 : idx;
    const int rhs_index = (param_->in_shape1_[dim] == 1) ? 0 : idx;
    float16_t *lhs = input0 + lhs_index * param_->in_strides0_[dim];
    float16_t *rhs = input1 + rhs_index * param_->in_strides1_[dim];
    uint8_t *dst = output + idx * param_->out_strides_[dim];
    if (BroadcastRun(lhs, rhs, dst, dim + 1, out_count, cur_offset) != RET_OK) {
      return RET_ERROR;
    }
  }
  return RET_OK;
}
|
||||
|
||||
int ArithmeticCompareFP16CPUKernel::DoArithmetic(int task_id) {
|
||||
int stride_per_thread = UP_DIV(param_->broadcasting_ ? outside_ : param_->out_elements_num_, context_->thread_num_);
|
||||
int cur_offset = stride_per_thread * task_id;
|
||||
int cur_count = param_->broadcasting_ ? MSMIN(stride_per_thread, outside_ - cur_offset)
|
||||
: MSMIN(stride_per_thread, param_->out_elements_num_ - cur_offset);
|
||||
|
||||
int ret = RET_OK;
|
||||
if (param_->broadcasting_) {
|
||||
ret = BroadcastRun(input0_fp16_, input1_fp16_, output_fp16_, 0, cur_count, cur_offset);
|
||||
} else if (param_->in_elements_num0_ == 1) {
|
||||
ret = arithmetic_opt_func_(input0_fp16_, input1_fp16_ + cur_offset, output_fp16_ + cur_offset, cur_count, param_);
|
||||
} else if (param_->in_elements_num1_ == 1) {
|
||||
ret = arithmetic_opt_func_(input0_fp16_ + cur_offset, input1_fp16_, output_fp16_ + cur_offset, cur_count, param_);
|
||||
} else {
|
||||
ret = arithmetic_func_(input0_fp16_ + cur_offset, input1_fp16_ + cur_offset, output_fp16_ + cur_offset, cur_count);
|
||||
}
|
||||
if (ret != RET_OK) {
|
||||
MS_LOG(ERROR) << "DoArithmetic failed, ret = " << ret;
|
||||
}
|
||||
return ret;
|
||||
}
|
||||
|
||||
static int ArithmeticsRunFp16(void *cdata, int task_id) {
|
||||
auto arithmetic_kernel = reinterpret_cast<ArithmeticCompareFP16CPUKernel *>(cdata);
|
||||
auto ret = arithmetic_kernel->DoArithmetic(task_id);
|
||||
if (ret != RET_OK) {
|
||||
MS_LOG(ERROR) << "ArithmeticsRunFp16 error task_id[" << task_id << "] ret[" << ret << "]";
|
||||
}
|
||||
return ret;
|
||||
}
|
||||
|
||||
int ArithmeticCompareFP16CPUKernel::Run() {
|
||||
auto output_tensor = out_tensors_.at(0);
|
||||
is_input0_fp32_ = in_tensors_.at(0)->data_type() == kNumberTypeFloat32;
|
||||
is_input1_fp32_ = in_tensors_.at(1)->data_type() == kNumberTypeFloat32;
|
||||
|
||||
input0_fp16_ = ConvertInputFp32toFp16(in_tensors_.at(0), context_);
|
||||
input1_fp16_ = ConvertInputFp32toFp16(in_tensors_.at(1), context_);
|
||||
output_fp16_ = reinterpret_cast<uint8_t *>(output_tensor->MutableData());
|
||||
if (input0_fp16_ == nullptr || input1_fp16_ == nullptr || output_fp16_ == nullptr) {
|
||||
MS_LOG(ERROR) << "Memory allocation failed";
|
||||
FreeTmpBuffer();
|
||||
return RET_ERROR;
|
||||
}
|
||||
auto ret = ParallelLaunch(this->context_->thread_pool_, ArithmeticsRunFp16, this, context_->thread_num_);
|
||||
if (ret != RET_OK) {
|
||||
MS_LOG(ERROR) << "ArithmeticsRunFp16 run error error_code[" << ret << "]";
|
||||
}
|
||||
FreeTmpBuffer();
|
||||
return ret;
|
||||
}
|
||||
|
||||
void ArithmeticCompareFP16CPUKernel::FreeTmpBuffer() {
|
||||
if (is_input0_fp32_) {
|
||||
context_->allocator->Free(input0_fp16_);
|
||||
input0_fp16_ = nullptr;
|
||||
}
|
||||
if (is_input1_fp32_) {
|
||||
context_->allocator->Free(input1_fp16_);
|
||||
input1_fp16_ = nullptr;
|
||||
}
|
||||
}
|
||||
|
||||
// Factory registered for the fp16 comparison primitives.
// Returns a newly created, initialised ArithmeticCompareFP16CPUKernel, or nullptr when parameter is null,
// allocation fails (parameter is freed in that case) or Init() fails (the kernel is deleted).
kernel::LiteKernel *CpuArithmeticCompareFp16KernelCreator(const std::vector<lite::Tensor *> &inputs,
                                                          const std::vector<lite::Tensor *> &outputs,
                                                          OpParameter *parameter, const lite::InnerContext *ctx,
                                                          const kernel::KernelKey &desc,
                                                          const mindspore::lite::PrimitiveC *primitive) {
  if (parameter == nullptr) {
    MS_LOG(ERROR) << "input parameter is null!";
    return nullptr;
  }

  auto *created = new (std::nothrow) ArithmeticCompareFP16CPUKernel(parameter, inputs, outputs, ctx, primitive);
  if (created == nullptr) {
    MS_LOG(ERROR) << "Create kernel failed, name: " << parameter->name_;
    free(parameter);
    return nullptr;
  }

  if (created->Init() == RET_OK) {
    return created;
  }
  MS_LOG(ERROR) << "Init kernel failed, name: " << parameter->name_
                << ", type: " << schema::EnumNamePrimitiveType(static_cast<schema::PrimitiveType>(parameter->type_));
  delete created;
  return nullptr;
}
|
||||
|
||||
REG_KERNEL(kCPU, kNumberTypeFloat16, PrimitiveType_NotEqual, CpuArithmeticCompareFp16KernelCreator)
|
||||
REG_KERNEL(kCPU, kNumberTypeFloat16, PrimitiveType_Equal, CpuArithmeticCompareFp16KernelCreator)
|
||||
REG_KERNEL(kCPU, kNumberTypeFloat16, PrimitiveType_Less, CpuArithmeticCompareFp16KernelCreator)
|
||||
REG_KERNEL(kCPU, kNumberTypeFloat16, PrimitiveType_LessEqual, CpuArithmeticCompareFp16KernelCreator)
|
||||
REG_KERNEL(kCPU, kNumberTypeFloat16, PrimitiveType_Greater, CpuArithmeticCompareFp16KernelCreator)
|
||||
REG_KERNEL(kCPU, kNumberTypeFloat16, PrimitiveType_GreaterEqual, CpuArithmeticCompareFp16KernelCreator)
|
||||
} // namespace mindspore::kernel
|
@ -0,0 +1,67 @@
|
||||
/**
|
||||
* Copyright 2020 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#ifndef MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_FP16_ARITHMETIC_COMPARE_FP16_H_
|
||||
#define MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_FP16_ARITHMETIC_COMPARE_FP16_H_
|
||||
|
||||
#include <vector>
|
||||
#include "src/lite_kernel.h"
|
||||
#include "nnacl/fp16/arithmetic_fp16.h"
|
||||
#include "schema/model_generated.h"
|
||||
|
||||
namespace mindspore::kernel {
|
||||
typedef int (*ArithmeticCompareFuncFp16)(float16_t *input0, float16_t *input1, uint8_t *output, int element_size);
|
||||
typedef int (*ArithmeticCompareOptFuncFp16)(float16_t *input0, float16_t *input1, uint8_t *output, int element_size,
|
||||
ArithmeticParameter *param);
|
||||
typedef struct {
|
||||
int primitive_type_;
|
||||
int activation_type_;
|
||||
ArithmeticCompareFuncFp16 func_;
|
||||
ArithmeticCompareOptFuncFp16 opt_func_;
|
||||
} ARITHMETIC_COMPARE_FUNC_INFO_FP16;
|
||||
|
||||
class ArithmeticCompareFP16CPUKernel : public LiteKernel {
|
||||
public:
|
||||
ArithmeticCompareFP16CPUKernel(OpParameter *parameter, const std::vector<lite::Tensor *> &inputs,
|
||||
const std::vector<lite::Tensor *> &outputs, const lite::InnerContext *ctx,
|
||||
const mindspore::lite::PrimitiveC *primitive)
|
||||
: LiteKernel(parameter, inputs, outputs, ctx, primitive) {
|
||||
param_ = reinterpret_cast<ArithmeticParameter *>(parameter);
|
||||
}
|
||||
~ArithmeticCompareFP16CPUKernel() = default;
|
||||
|
||||
int Init() override;
|
||||
int ReSize() override;
|
||||
int Run() override;
|
||||
int DoArithmetic(int task_id);
|
||||
int BroadcastRun(float16_t *input0, float16_t *input1, uint8_t *output, int dim, int out_count,
|
||||
int out_thread_stride);
|
||||
|
||||
private:
|
||||
void FreeTmpBuffer();
|
||||
int outside_;
|
||||
int break_pos_;
|
||||
bool is_input0_fp32_ = false;
|
||||
bool is_input1_fp32_ = false;
|
||||
float16_t *input0_fp16_ = nullptr;
|
||||
float16_t *input1_fp16_ = nullptr;
|
||||
uint8_t *output_fp16_ = nullptr;
|
||||
ArithmeticParameter *param_ = nullptr;
|
||||
ArithmeticCompareFuncFp16 arithmetic_func_ = nullptr;
|
||||
ArithmeticCompareOptFuncFp16 arithmetic_opt_func_ = nullptr;
|
||||
};
|
||||
} // namespace mindspore::kernel
|
||||
#endif // MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_FP16_ARITHMETIC_COMPARE_FP16_H_
|
@ -0,0 +1,127 @@
|
||||
/**
|
||||
* Copyright 2020 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
#include "src/runtime/kernel/arm/fp32/arithmetic_compare.h"
|
||||
#include "src/kernel_registry.h"
|
||||
#include "nnacl/fp32/arithmetic_compare.h"
|
||||
|
||||
using mindspore::lite::KernelRegistrar;
|
||||
using mindspore::lite::RET_ERROR;
|
||||
using mindspore::lite::RET_OK;
|
||||
using mindspore::schema::PrimitiveType_Equal;
|
||||
using mindspore::schema::PrimitiveType_Greater;
|
||||
using mindspore::schema::PrimitiveType_GreaterEqual;
|
||||
using mindspore::schema::PrimitiveType_Less;
|
||||
using mindspore::schema::PrimitiveType_LessEqual;
|
||||
using mindspore::schema::PrimitiveType_NotEqual;
|
||||
|
||||
namespace mindspore::kernel {
|
||||
namespace {
|
||||
typedef struct {
|
||||
int primitive_type_;
|
||||
ArithmeticCompareFp32Func func_;
|
||||
} TYPE_FUNC_INFO;
|
||||
} // namespace
|
||||
|
||||
// Maps a comparison primitive type to its fp32 element-wise implementation.
// Returns nullptr when primitive_type is not one of the six comparison primitives listed below.
ArithmeticCompareFp32Func ArithmeticCompareCPUKernel::GetArithmeticCompareFun(int primitive_type) {
  TYPE_FUNC_INFO type_func_table[] = {
    {PrimitiveType_Equal, ElementEqualFp32},     {PrimitiveType_NotEqual, ElementNotEqualFp32},
    {PrimitiveType_Less, ElementLessFp32},       {PrimitiveType_LessEqual, ElementLessEqualFp32},
    {PrimitiveType_Greater, ElementGreaterFp32}, {PrimitiveType_GreaterEqual, ElementGreaterEqualFp32}};
  // Iterate over the entries themselves. The previous bound, sizeof(table), is the table size in bytes
  // rather than the number of rows, so an unmatched lookup read past the end of the local array.
  for (const auto &entry : type_func_table) {
    if (entry.primitive_type_ == primitive_type) {
      return entry.func_;
    }
  }
  return nullptr;
}
|
||||
|
||||
int ArithmeticCompareCPUKernel::Init() {
|
||||
if (!InferShapeDone()) {
|
||||
return RET_OK;
|
||||
}
|
||||
return ReSize();
|
||||
}
|
||||
|
||||
// Nothing is precomputed per shape for the fp32 comparison kernel; always returns RET_OK.
int ArithmeticCompareCPUKernel::ReSize() { return RET_OK; }
|
||||
|
||||
int ArithmeticCompareCPUKernel::DoExecute(int task_id) {
|
||||
int elements_num = in_tensors_.at(0)->ElementsNum();
|
||||
int stride = UP_DIV(elements_num, op_parameter_->thread_num_);
|
||||
int offset = task_id * stride;
|
||||
int count = MSMIN(stride, elements_num - offset);
|
||||
if (count <= 0) {
|
||||
return RET_OK;
|
||||
}
|
||||
if (func_ == nullptr) {
|
||||
MS_LOG(ERROR) << "Run function is null! ";
|
||||
return RET_ERROR;
|
||||
}
|
||||
// two inputs have the same shape, support broadcast later
|
||||
auto *input0_ptr = reinterpret_cast<float *>(in_tensors_.at(0)->MutableData());
|
||||
auto *input1_ptr = reinterpret_cast<float *>(in_tensors_.at(1)->MutableData());
|
||||
auto *output_ptr = reinterpret_cast<uint8_t *>(out_tensors_.at(0)->MutableData());
|
||||
auto ret = func_(input0_ptr + offset, input1_ptr + offset, output_ptr + offset, count);
|
||||
if (ret != RET_OK) {
|
||||
MS_LOG(ERROR) << "Run failed, illegal input! ";
|
||||
}
|
||||
return ret;
|
||||
}
|
||||
|
||||
int ArithmeticCompareRun(void *cdata, int task_id) {
|
||||
auto kernel = reinterpret_cast<ArithmeticCompareCPUKernel *>(cdata);
|
||||
auto ret = kernel->DoExecute(task_id);
|
||||
if (ret != RET_OK) {
|
||||
MS_LOG(ERROR) << "ArithmeticSelfRuns error task_id[" << task_id << "] error_code[" << ret << "]";
|
||||
}
|
||||
return ret;
|
||||
}
|
||||
|
||||
int ArithmeticCompareCPUKernel::Run() {
|
||||
auto ret = ParallelLaunch(this->context_->thread_pool_, ArithmeticCompareRun, this, op_parameter_->thread_num_);
|
||||
if (ret != RET_OK) {
|
||||
MS_LOG(ERROR) << "ArithmeticSelfRun error error_code[" << ret << "]";
|
||||
}
|
||||
return ret;
|
||||
}
|
||||
|
||||
// Factory registered for the fp32 comparison primitives.
// Returns a newly created, initialised ArithmeticCompareCPUKernel, or nullptr when parameter is null,
// allocation fails (parameter is freed in that case) or Init() fails (the kernel is deleted).
kernel::LiteKernel *CpuArithmeticCompareFp32KernelCreator(const std::vector<lite::Tensor *> &inputs,
                                                          const std::vector<lite::Tensor *> &outputs,
                                                          OpParameter *parameter, const lite::InnerContext *ctx,
                                                          const kernel::KernelKey &desc,
                                                          const mindspore::lite::PrimitiveC *primitive) {
  // The kernel constructor reads parameter->type_, so a null parameter must be rejected before
  // construction (the fp16 creator performs the same check).
  if (parameter == nullptr) {
    MS_LOG(ERROR) << "input parameter is null!";
    return nullptr;
  }
  auto *kernel = new (std::nothrow) ArithmeticCompareCPUKernel(parameter, inputs, outputs, ctx, primitive);
  if (kernel == nullptr) {
    // The message previously named ArithmeticSelfCPUKernel, which is not the class created here.
    MS_LOG(ERROR) << "new ArithmeticCompareCPUKernel fail!";
    free(parameter);
    return nullptr;
  }
  auto ret = kernel->Init();
  if (ret != RET_OK) {
    MS_LOG(ERROR) << "Init kernel failed, name: " << parameter->name_
                  << ", type: " << schema::EnumNamePrimitiveType(static_cast<schema::PrimitiveType>(parameter->type_));
    delete kernel;
    return nullptr;
  }
  return kernel;
}
|
||||
|
||||
REG_KERNEL(kCPU, kNumberTypeFloat32, PrimitiveType_Equal, CpuArithmeticCompareFp32KernelCreator)
|
||||
REG_KERNEL(kCPU, kNumberTypeFloat32, PrimitiveType_NotEqual, CpuArithmeticCompareFp32KernelCreator)
|
||||
REG_KERNEL(kCPU, kNumberTypeFloat32, PrimitiveType_Less, CpuArithmeticCompareFp32KernelCreator)
|
||||
REG_KERNEL(kCPU, kNumberTypeFloat32, PrimitiveType_LessEqual, CpuArithmeticCompareFp32KernelCreator)
|
||||
REG_KERNEL(kCPU, kNumberTypeFloat32, PrimitiveType_Greater, CpuArithmeticCompareFp32KernelCreator)
|
||||
REG_KERNEL(kCPU, kNumberTypeFloat32, PrimitiveType_GreaterEqual, CpuArithmeticCompareFp32KernelCreator)
|
||||
} // namespace mindspore::kernel
|
@ -0,0 +1,46 @@
|
||||
/**
|
||||
* Copyright 2020 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
#ifndef MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_FP32_ARITHMETIC_COMPARE_H_
|
||||
#define MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_FP32_ARITHMETIC_COMPARE_H_
|
||||
|
||||
#include <vector>
|
||||
#include "src/runtime/kernel/arm/fp32/arithmetic.h"
|
||||
|
||||
namespace mindspore::kernel {
|
||||
// Element-wise fp32 comparison: writes one uint8_t result per element for element_size elements.
using ArithmeticCompareFp32Func = int (*)(const float *input0, const float *input1, uint8_t *output,
                                          int element_size);
|
||||
class ArithmeticCompareCPUKernel : public ArithmeticCPUKernel {
|
||||
public:
|
||||
explicit ArithmeticCompareCPUKernel(OpParameter *parameter, const std::vector<lite::Tensor *> &inputs,
|
||||
const std::vector<lite::Tensor *> &outputs, const lite::InnerContext *ctx,
|
||||
const mindspore::lite::PrimitiveC *primitive)
|
||||
: ArithmeticCPUKernel(parameter, inputs, outputs, ctx, primitive) {
|
||||
func_ = GetArithmeticCompareFun(parameter->type_);
|
||||
}
|
||||
~ArithmeticCompareCPUKernel() override = default;
|
||||
|
||||
int Init() override;
|
||||
int ReSize() override;
|
||||
int Run() override;
|
||||
virtual int DoExecute(int task_id);
|
||||
|
||||
private:
|
||||
ArithmeticCompareFp32Func GetArithmeticCompareFun(int primitive_type);
|
||||
ArithmeticCompareFp32Func func_;
|
||||
};
|
||||
// Task callback for ParallelLaunch; cdata is expected to point to an ArithmeticCompareCPUKernel.
int ArithmeticCompareRun(void *cdata, int task_id);
|
||||
} // namespace mindspore::kernel
|
||||
|
||||
#endif // MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_FP32_ARITHMETIC_COMPARE_H_
|
Loading…
Reference in new issue