|
|
|
@ -23,39 +23,46 @@
|
|
|
|
|
#include "schema/model_generated.h"
|
|
|
|
|
|
|
|
|
|
namespace mindspore::kernel {
|
|
|
|
|
class ArithmeticFP16CPUKernel : public LiteKernel {
|
|
|
|
|
// Pointer to a function that takes two float16_t input pointers, one float16_t output pointer and an int
// element_size, and returns int.
// NOTE(review): the meaning of the int return value is not visible here (presumably a status code - confirm).
// NOTE(review): the signature is identical to ArithmeticFuncFp16 declared further down; this looks like the
// pre-rename spelling of the same alias - confirm which of the two is meant to remain.
typedef int (*ArithmeticRun)(float16_t *input0, float16_t *input1, float16_t *output, int element_size);
|
|
|
|
|
// Pointer to a function with the same leading parameters as ArithmeticRun (two float16_t inputs, one float16_t
// output, int element_size) plus a trailing ArithmeticParameter pointer; returns int.
// NOTE(review): the declaration is split across lines separated by lone '|' lines, which are not valid C++ -
// they look like rendering residue; confirm before compiling.
// NOTE(review): the signature is identical to ArithmeticOptFuncFp16 declared further down.
typedef int (*ArithmeticOptRun)(float16_t *input0, float16_t *input1, float16_t *output, int element_size,
|
|
|
|
|
ArithmeticParameter *param);
|
|
|
|
|
// Pointer to a function that takes two float16_t input pointers, one float16_t output pointer and an int
// element_size, and returns int. Used as the type of ARITHMETIC_FUNC_INFO_FP16::func_ and of the kernel's
// arithmetic_func_ member.
// NOTE(review): the meaning of the int return value is not visible here (presumably a status code - confirm).
typedef int (*ArithmeticFuncFp16)(float16_t *input0, float16_t *input1, float16_t *output, int element_size);
|
|
|
|
|
// Pointer to a function with the same leading parameters as ArithmeticFuncFp16 plus a trailing
// ArithmeticParameter pointer; returns int. Used as the type of ARITHMETIC_FUNC_INFO_FP16::opt_func_ and of the
// kernel's arithmetic_opt_func_ member.
// NOTE(review): the declaration is split across lines separated by lone '|' lines, which are not valid C++ -
// they look like rendering residue; confirm before compiling.
typedef int (*ArithmeticOptFuncFp16)(float16_t *input0, float16_t *input1, float16_t *output, int element_size,
|
|
|
|
|
ArithmeticParameter *param);
|
|
|
|
|
// Plain record grouping two integer type codes with two function pointers.
// NOTE(review): presumably one row of a lookup table that selects the FP16 routine for a given
// primitive/activation combination - the table itself is not visible in this header; confirm in the .cc file.
typedef struct {
|
|
|
|
|
// Integer code for the primitive (operator) kind. NOTE(review): which enum this comes from is not shown here.
int primitive_type_;
|
|
|
|
|
// Integer code for the activation kind. NOTE(review): which enum this comes from is not shown here.
int activation_type_;
|
|
|
|
|
// Routine with the plain (no ArithmeticParameter) signature.
ArithmeticFuncFp16 func_;
|
|
|
|
|
// Routine with the extended signature that also receives an ArithmeticParameter pointer.
ArithmeticOptFuncFp16 opt_func_;
|
|
|
|
|
} ARITHMETIC_FUNC_INFO_FP16;
|
|
|
|
|
|
|
|
|
|
// LiteKernel subclass declaring the FP16 arithmetic kernel: construction from an OpParameter, the
// Init/ReSize/Run overrides, two helper entry points (DoArithmetic, BroadcastRun) and private state.
// NOTE(review): this text appears to be an unmarked diff rendering - several declarations occur twice in
// slightly different forms (destructor, parameter pointer member, function-pointer members, the last
// BroadcastRun parameter line) and lone '|' lines separate every statement. Confirm which variant of each
// pair is intended before building; as written the duplicates would not compile.
class ArithmeticFP16CPUKernel : public LiteKernel {
|
|
|
|
|
public:
|
|
|
|
|
// Forwards every argument unchanged to the LiteKernel constructor, then stores `parameter` reinterpreted as
// an ArithmeticParameter pointer.
ArithmeticFP16CPUKernel(OpParameter *parameter, const std::vector<lite::tensor::Tensor *> &inputs,
|
|
|
|
|
const std::vector<lite::tensor::Tensor *> &outputs, const lite::Context *ctx,
|
|
|
|
|
const mindspore::lite::PrimitiveC *primitive)
|
|
|
|
|
: LiteKernel(parameter, inputs, outputs, ctx, primitive) {
|
|
|
|
|
// NOTE(review): the next two assignments store the same cast pointer in two different members
// (arithmeticParameter_ and param_); this looks like the old and new name of a single member.
arithmeticParameter_ = reinterpret_cast<ArithmeticParameter *>(parameter);
|
|
|
|
|
param_ = reinterpret_cast<ArithmeticParameter *>(parameter);
|
|
|
|
|
}
|
|
|
|
|
// NOTE(review): the destructor is declared twice - once as an out-of-line override and once as defaulted.
// A member function cannot be redeclared inside the class body, so only one of these two lines can stay.
~ArithmeticFP16CPUKernel() override;
|
|
|
|
|
~ArithmeticFP16CPUKernel() = default;
|
|
|
|
|
|
|
|
|
|
|
// Overrides of the LiteKernel interface; definitions are not in this header.
int Init() override;
|
|
|
|
|
int ReSize() override;
|
|
|
|
|
int Run() override;
|
|
|
|
|
// Non-virtual helper taking a task index. NOTE(review): presumably invoked once per parallel task - the
// caller is not visible here; confirm in the .cc file.
int DoArithmetic(int task_id);
|
|
|
|
|
// Non-virtual helper taking two float16_t inputs, a float16_t output, a dimension index, an output count and
// an output stride. NOTE(review): semantics of dim/out_count/out_thread_stride are not visible in this header.
// NOTE(review): the closing parameter line `int out_thread_stride);` appears twice below; the second copy is
// a stray duplicate as far as C++ syntax is concerned.
int BroadcastRun(float16_t *input0, float16_t *input1, float16_t *output, int dim, int out_count,
|
|
|
|
|
int out_thread_stride);
|
|
|
|
|
int out_thread_stride);
|
|
|
|
|
|
|
|
|
|
|
private:
|
|
|
|
|
// Declared only; NOTE(review): by its name it releases temporary buffers - which ones is not visible here.
void FreeTmpBuffer();
|
|
|
|
|
// The four int members below have no in-class initializer, so their values are indeterminate until
// something assigns them. NOTE(review): where they are assigned is not visible in this header.
int outside_;
|
|
|
|
|
int break_pos_;
|
|
|
|
|
int out_thread_stride_;
|
|
|
|
|
int out_count_;
|
|
|
|
|
// Per-tensor flags, all defaulting to false. NOTE(review): presumably record that the corresponding tensor
// holds fp32 data and needs conversion - confirm against the implementation.
bool is_input0_fp32_ = false;
|
|
|
|
|
bool is_input1_fp32_ = false;
|
|
|
|
|
bool is_output_fp32_ = false;
|
|
|
|
|
// float16_t buffer pointers, all defaulting to nullptr. NOTE(review): ownership (borrowed tensor data vs.
// buffers allocated by this kernel) is not visible in this header.
float16_t *input0_fp16_ = nullptr;
|
|
|
|
|
float16_t *input1_fp16_ = nullptr;
|
|
|
|
|
float16_t *output_fp16_ = nullptr;
|
|
|
|
|
// NOTE(review): the next three members (arithmeticParameter_, arithmetic_run_, arithmetic_opt_run_) mirror
// the three after them (param_, arithmetic_func_, arithmetic_opt_func_) with identical pointee/signature
// types; they look like the pre- and post-rename versions of the same state.
ArithmeticParameter *arithmeticParameter_ = nullptr;
|
|
|
|
|
ArithmeticRun arithmetic_run_ = nullptr;
|
|
|
|
|
ArithmeticOptRun arithmetic_opt_run_ = nullptr;
|
|
|
|
|
// Set in the constructor to the OpParameter pointer reinterpreted as ArithmeticParameter.
ArithmeticParameter *param_ = nullptr;
|
|
|
|
|
// Function pointers defaulting to nullptr. NOTE(review): where they are assigned is not visible in this header.
ArithmeticFuncFp16 arithmetic_func_ = nullptr;
|
|
|
|
|
ArithmeticOptFuncFp16 arithmetic_opt_func_ = nullptr;
|
|
|
|
|
};
|
|
|
|
|
} // namespace mindspore::kernel
|
|
|
|
|
#endif // MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_FP16_ARITHMETIC_FP16_H_
|
|
|
|
|