!5441 [MS][LITE][Develop]Refactor arithmetic fp16 kernel

Merge pull request !5441 from sunsuodong/fix_arithmetic
pull/5441/MERGE
mindspore-ci-bot 5 years ago committed by Gitee
commit 85b1dae578

@ -23,19 +23,25 @@
#include "schema/model_generated.h" #include "schema/model_generated.h"
namespace mindspore::kernel { namespace mindspore::kernel {
class ArithmeticFP16CPUKernel : public LiteKernel { typedef int (*ArithmeticFuncFp16)(float16_t *input0, float16_t *input1, float16_t *output, int element_size);
typedef int (*ArithmeticRun)(float16_t *input0, float16_t *input1, float16_t *output, int element_size); typedef int (*ArithmeticOptFuncFp16)(float16_t *input0, float16_t *input1, float16_t *output, int element_size,
typedef int (*ArithmeticOptRun)(float16_t *input0, float16_t *input1, float16_t *output, int element_size,
ArithmeticParameter *param); ArithmeticParameter *param);
typedef struct {
int primitive_type_;
int activation_type_;
ArithmeticFuncFp16 func_;
ArithmeticOptFuncFp16 opt_func_;
} ARITHMETIC_FUNC_INFO_FP16;
class ArithmeticFP16CPUKernel : public LiteKernel {
public: public:
ArithmeticFP16CPUKernel(OpParameter *parameter, const std::vector<lite::tensor::Tensor *> &inputs, ArithmeticFP16CPUKernel(OpParameter *parameter, const std::vector<lite::tensor::Tensor *> &inputs,
const std::vector<lite::tensor::Tensor *> &outputs, const lite::Context *ctx, const std::vector<lite::tensor::Tensor *> &outputs, const lite::Context *ctx,
const mindspore::lite::PrimitiveC *primitive) const mindspore::lite::PrimitiveC *primitive)
: LiteKernel(parameter, inputs, outputs, ctx, primitive) { : LiteKernel(parameter, inputs, outputs, ctx, primitive) {
arithmeticParameter_ = reinterpret_cast<ArithmeticParameter *>(parameter); param_ = reinterpret_cast<ArithmeticParameter *>(parameter);
} }
~ArithmeticFP16CPUKernel() override; ~ArithmeticFP16CPUKernel() = default;
int Init() override; int Init() override;
int ReSize() override; int ReSize() override;
@ -48,14 +54,15 @@ class ArithmeticFP16CPUKernel : public LiteKernel {
void FreeTmpBuffer(); void FreeTmpBuffer();
int outside_; int outside_;
int break_pos_; int break_pos_;
int out_thread_stride_; bool is_input0_fp32_ = false;
int out_count_; bool is_input1_fp32_ = false;
bool is_output_fp32_ = false;
float16_t *input0_fp16_ = nullptr; float16_t *input0_fp16_ = nullptr;
float16_t *input1_fp16_ = nullptr; float16_t *input1_fp16_ = nullptr;
float16_t *output_fp16_ = nullptr; float16_t *output_fp16_ = nullptr;
ArithmeticParameter *arithmeticParameter_ = nullptr; ArithmeticParameter *param_ = nullptr;
ArithmeticRun arithmetic_run_ = nullptr; ArithmeticFuncFp16 arithmetic_func_ = nullptr;
ArithmeticOptRun arithmetic_opt_run_ = nullptr; ArithmeticOptFuncFp16 arithmetic_opt_func_ = nullptr;
}; };
} // namespace mindspore::kernel } // namespace mindspore::kernel
#endif // MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_FP16_ARITHMETIC_FP16_H_ #endif // MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_FP16_ARITHMETIC_FP16_H_

Loading…
Cancel
Save