From 44fb47ed8470fb94d8e00dfa16b5e8321cd120cc Mon Sep 17 00:00:00 2001 From: yangruoqi713 Date: Fri, 22 Jan 2021 15:21:17 +0800 Subject: [PATCH] [MSLITE][Develop] add cpu fp16 op: lstm --- mindspore/lite/nnacl/fp16/lstm_fp16.c | 235 ++++++++++++++++++ mindspore/lite/nnacl/fp16/lstm_fp16.h | 38 +++ mindspore/lite/nnacl/fp32/lstm_fp32.h | 21 +- mindspore/lite/nnacl/lstm_parameter.h | 39 +++ .../src/runtime/kernel/arm/fp16/lstm_fp16.cc | 203 +++++++++++++++ .../src/runtime/kernel/arm/fp16/lstm_fp16.h | 55 ++++ .../src/runtime/kernel/arm/fp32/lstm_fp32.cc | 42 ++-- .../src/runtime/kernel/arm/fp32/lstm_fp32.h | 4 +- 8 files changed, 594 insertions(+), 43 deletions(-) create mode 100644 mindspore/lite/nnacl/fp16/lstm_fp16.c create mode 100644 mindspore/lite/nnacl/fp16/lstm_fp16.h create mode 100644 mindspore/lite/nnacl/lstm_parameter.h create mode 100644 mindspore/lite/src/runtime/kernel/arm/fp16/lstm_fp16.cc create mode 100644 mindspore/lite/src/runtime/kernel/arm/fp16/lstm_fp16.h diff --git a/mindspore/lite/nnacl/fp16/lstm_fp16.c b/mindspore/lite/nnacl/fp16/lstm_fp16.c new file mode 100644 index 0000000000..dc7effd265 --- /dev/null +++ b/mindspore/lite/nnacl/fp16/lstm_fp16.c @@ -0,0 +1,235 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */
+
+#include "nnacl/fp16/lstm_fp16.h"
+#include <string.h>
+#include "nnacl/fp16/activation_fp16.h"
+#include "nnacl/fp16/arithmetic_fp16.h"
+
+void InitGateFp16(float16_t *gate_buffer, const float16_t *bias, const LstmParameter *lstm_parm) {
+  int gate_offest = 0;
+  for (int l = 0; l < 4; l++) {
+    int batch_offest = gate_offest;
+    int bias_offest = l * lstm_parm->hidden_size_;
+    for (int b = 0; b < lstm_parm->batch_; b++) {
+      memcpy(gate_buffer + batch_offest, bias + bias_offest, lstm_parm->hidden_size_ * sizeof(float16_t));
+      batch_offest += lstm_parm->hidden_size_;
+    }
+    gate_offest += lstm_parm->batch_ * lstm_parm->hidden_size_;
+  }
+}
+
+// input: [row, inner_size]; weight: [col, inner_size]; output: [row, col]
+void MatMulAccFp16(float16_t *output, const float16_t *input, const float16_t *weight, int rows, int cols,
+                   int inner_size) {
+  for (int r = 0; r < rows; r++) {
+    for (int c = 0; c < cols; c++) {
+      float16_t res = 0;
+      const float16_t *input_col = input + r * inner_size;
+      const float16_t *weight_col = weight + c * inner_size;
+      int index = 0;
+      float16x8_t out = vdupq_n_f16(0.0f);
+      for (; index <= inner_size - 8; index += 8) {
+        float16x8_t in_0 = vld1q_f16(input_col + index);
+        float16x8_t in_1 = vld1q_f16(weight_col + index);
+        out = vfmaq_f16(out, in_1, in_0);
+      }
+      float16x4_t add2 = vadd_f16(vget_low_f16(out), vget_high_f16(out));
+      float16x4_t add4 = vpadd_f16(add2, add2);
+      float16x4_t add8 = vpadd_f16(add4, add4);
+      res += vget_lane_f16(add8, 0);
+      for (; index < inner_size; index++) {
+        res += input_col[index] * weight_col[index];
+      }
+      output[r * cols + c] += res;
+    }
+  }
+}
+
+void ElementMulAccFp16(const float16_t *input0, const float16_t *input1, float16_t *output, int element_size) {
+  int index = 0;
+  for (; index <= element_size - 8; index += 8) {
+    float16x8_t in_0 = vld1q_f16(input0 + index);
+    float16x8_t in_1 = vld1q_f16(input1 + index);
+    float16x8_t out = vld1q_f16(output + index);
+    out = vfmaq_f16(out, in_1, in_0);
+    vst1q_f16(output + index, out);
+  }
+  for (; index < element_size; index++) {
+    output[index] += input0[index] * input1[index];
+  }
+}
+
+int ElementOptMulAccFp16(const float16_t *input0, const float16_t input1, float16_t *output, const int element_size) {
+  int index = 0;
+  for (; index <= element_size - 8; index += 8) {
+    float16x8_t vin0 = vld1q_f16(input0 + index);
+    float16x8_t vout = vld1q_f16(output + index);
+    vout = vfmaq_n_f16(vout, vin0, input1);
+    vst1q_f16(output + index, vout);
+  }
+  for (; index < element_size; index++) {
+    output[index] += input0[index] * input1;
+  }
+  return NNACL_OK;
+}
+
+void UpdataStateFp16(float16_t *cell_state, float16_t *forget_gate, const float16_t *input_gate,
+                     const float16_t *cell_gate, float16_t *state_buffer, int batch, int hidden_size,
+                     float16_t smooth) {
+  if (!(smooth >= -FLT_EPSILON && smooth <= FLT_EPSILON)) {  // smooth * old_cell_state
+    memcpy(state_buffer, cell_state, batch * hidden_size * sizeof(float16_t));
+    ArithmeticParameter parameter;
+    parameter.in_elements_num0_ = batch * hidden_size;
+    parameter.in_elements_num1_ = 1;
+    ElementOptMulFp16(state_buffer, &smooth, state_buffer, batch * hidden_size, &parameter);
+  }
+
+  ElementMulFp16(forget_gate, cell_state, cell_state, batch * hidden_size);
+  ElementMulAccFp16(input_gate, cell_gate, cell_state, batch * hidden_size);
+
+  if (!(smooth >= -FLT_EPSILON && smooth <= FLT_EPSILON)) {  // (1 - smooth) * new_cell_state
+    ElementOptMulAccFp16(cell_state, 1 - smooth, state_buffer, batch * hidden_size);
+  }
+}
+
+void UpdataOutputFp16(const float16_t *cell_state, float16_t *output_gate, float16_t *hidden_state,
+                      float16_t *state_buffer_in, int batch, int hidden_size, float16_t smooth) {
+  float16_t *state_buffer = state_buffer_in + batch * hidden_size;
+  if (!(smooth >= -FLT_EPSILON && smooth <= FLT_EPSILON)) {
+    memcpy(state_buffer, hidden_state, batch * hidden_size * sizeof(float16_t));
+    ArithmeticParameter parameter;
+    parameter.in_elements_num0_ = batch * hidden_size;
+    parameter.in_elements_num1_ = 1;
+    ElementOptMulFp16(state_buffer, &smooth, state_buffer, batch * hidden_size, &parameter);
+  }
+
+  TanhFp16(cell_state, hidden_state, batch * hidden_size);
+  ElementMulFp16(hidden_state, output_gate, hidden_state, batch * hidden_size);
+
+  if (!(smooth >= -FLT_EPSILON && smooth <= FLT_EPSILON)) {
+    ElementOptMulAccFp16(hidden_state, 1 - smooth, state_buffer, batch * hidden_size);
+  }
+}
+
+void LstmStepUnitFp16(float16_t *output, const float16_t *input, const float16_t *input_input_weight,
+                      const float16_t *input_forget_weight, const float16_t *input_cell_weight,
+                      const float16_t *input_output_weight, const float16_t *state_input_weight,
+                      const float16_t *state_forget_weight, const float16_t *state_cell_weight,
+                      const float16_t *state_output_weight, const float16_t *bias, float16_t *hidden_state,
+                      float16_t *cell_state, float16_t *gate_buffer, float16_t *state_buffer,
+                      const LstmParameter *lstm_parm) {
+  InitGateFp16(gate_buffer, bias, lstm_parm);
+
+  float16_t *input_gate = gate_buffer;
+  float16_t *forget_gate = gate_buffer + lstm_parm->batch_ * lstm_parm->hidden_size_ * 2;
+  float16_t *cell_gate = gate_buffer + lstm_parm->batch_ * lstm_parm->hidden_size_ * 3;
+  float16_t *output_gate = gate_buffer + lstm_parm->batch_ * lstm_parm->hidden_size_ * 1;
+
+  // input * weight
+  MatMulAccFp16(input_gate, input, input_input_weight, lstm_parm->batch_, lstm_parm->hidden_size_,
+                lstm_parm->input_size_);
+  MatMulAccFp16(forget_gate, input, input_forget_weight, lstm_parm->batch_, lstm_parm->hidden_size_,
+                lstm_parm->input_size_);
+  MatMulAccFp16(cell_gate, input, input_cell_weight, lstm_parm->batch_, lstm_parm->hidden_size_,
+                lstm_parm->input_size_);
+  MatMulAccFp16(output_gate, input, input_output_weight, lstm_parm->batch_, lstm_parm->hidden_size_,
+                lstm_parm->input_size_);
+
+  // state * weight
+  MatMulAccFp16(input_gate, hidden_state, state_input_weight, lstm_parm->batch_, lstm_parm->hidden_size_,
+                lstm_parm->hidden_size_);
+  MatMulAccFp16(forget_gate, hidden_state, state_forget_weight, lstm_parm->batch_, lstm_parm->hidden_size_,
+                lstm_parm->hidden_size_);
+  MatMulAccFp16(cell_gate, hidden_state, state_cell_weight, lstm_parm->batch_, lstm_parm->hidden_size_,
+                lstm_parm->hidden_size_);
+  MatMulAccFp16(output_gate, hidden_state, state_output_weight, lstm_parm->batch_, lstm_parm->hidden_size_,
+                lstm_parm->hidden_size_);
+
+  // update input_gate
+  SigmoidFp16(input_gate, input_gate, lstm_parm->batch_ * lstm_parm->hidden_size_);
+
+  // update forget_gate
+  SigmoidFp16(forget_gate, forget_gate, lstm_parm->batch_ * lstm_parm->hidden_size_);
+
+  // update cell_gate
+  TanhFp16(cell_gate, cell_gate, lstm_parm->batch_ * lstm_parm->hidden_size_);
+  // update cell state
+  UpdataStateFp16(cell_state, forget_gate, input_gate, cell_gate, state_buffer, lstm_parm->batch_,
+                  lstm_parm->hidden_size_, lstm_parm->smooth_);
+
+  // update output_gate
+  SigmoidFp16(output_gate, output_gate, lstm_parm->batch_ * lstm_parm->hidden_size_);
+  // update output
+  UpdataOutputFp16(cell_state, output_gate, hidden_state, state_buffer, lstm_parm->batch_,
lstm_parm->hidden_size_, + lstm_parm->smooth_); + memcpy(output, hidden_state, lstm_parm->batch_ * lstm_parm->hidden_size_ * sizeof(float16_t)); + + if (!(lstm_parm->smooth_ >= -FLT_EPSILON && lstm_parm->smooth_ <= FLT_EPSILON)) { + memcpy(cell_state, state_buffer, lstm_parm->batch_ * lstm_parm->hidden_size_ * sizeof(float16_t)); + memcpy(hidden_state, state_buffer + lstm_parm->batch_ * lstm_parm->hidden_size_, + lstm_parm->batch_ * lstm_parm->hidden_size_ * sizeof(float16_t)); + } +} + +void LstmFp16(float16_t *output, const float16_t *input, const float16_t *weight_i, const float16_t *weight_h, + const float16_t *bias, float16_t *hidden_state, float16_t *cell_state, float16_t *gate_buffer, + float16_t *state_buffer, const LstmParameter *lstm_parm) { + // forward + const float16_t *input_input_weight = weight_i; + const float16_t *input_forget_weight = weight_i + lstm_parm->input_size_ * lstm_parm->hidden_size_ * 2; + const float16_t *input_cell_weight = weight_i + lstm_parm->input_size_ * lstm_parm->hidden_size_ * 3; + const float16_t *input_output_weight = weight_i + lstm_parm->input_size_ * lstm_parm->hidden_size_ * 1; + + const float16_t *state_input_weight = weight_h; + const float16_t *state_forget_weight = weight_h + lstm_parm->hidden_size_ * lstm_parm->hidden_size_ * 2; + const float16_t *state_cell_weight = weight_h + lstm_parm->hidden_size_ * lstm_parm->hidden_size_ * 3; + const float16_t *state_output_weight = weight_h + lstm_parm->hidden_size_ * lstm_parm->hidden_size_ * 1; + + for (int t = 0; t < lstm_parm->seq_len_; t++) { + const float16_t *input_ptr = input + t * lstm_parm->input_step_; + float16_t *output_ptr = output + t * lstm_parm->output_step_; + LstmStepUnitFp16(output_ptr, input_ptr, input_input_weight, input_forget_weight, input_cell_weight, + input_output_weight, state_input_weight, state_forget_weight, state_cell_weight, + state_output_weight, bias, hidden_state, cell_state, gate_buffer, state_buffer, lstm_parm); + } + + // backward + if (lstm_parm->bidirectional_) { + input_input_weight = weight_i + lstm_parm->input_size_ * lstm_parm->hidden_size_ * 4; + input_forget_weight = weight_i + lstm_parm->input_size_ * lstm_parm->hidden_size_ * 6; + input_cell_weight = weight_i + lstm_parm->input_size_ * lstm_parm->hidden_size_ * 7; + input_output_weight = weight_i + lstm_parm->input_size_ * lstm_parm->hidden_size_ * 5; + + state_input_weight = weight_h + lstm_parm->hidden_size_ * lstm_parm->hidden_size_ * 4; + state_forget_weight = weight_h + lstm_parm->hidden_size_ * lstm_parm->hidden_size_ * 6; + state_cell_weight = weight_h + lstm_parm->hidden_size_ * lstm_parm->hidden_size_ * 7; + state_output_weight = weight_h + lstm_parm->hidden_size_ * lstm_parm->hidden_size_ * 5; + + float16_t *backward_output = output + lstm_parm->batch_ * lstm_parm->hidden_size_; + const float16_t *backward_bias = bias + 4 * lstm_parm->hidden_size_; + float16_t *backward_cell_state = cell_state + lstm_parm->batch_ * lstm_parm->hidden_size_; + float16_t *backward_hidden_state = hidden_state + lstm_parm->batch_ * lstm_parm->hidden_size_; + for (int t = lstm_parm->seq_len_ - 1; t >= 0; t--) { + const float16_t *input_ptr = input + t * lstm_parm->input_step_; + float16_t *output_ptr = backward_output + t * lstm_parm->output_step_; + LstmStepUnitFp16(output_ptr, input_ptr, input_input_weight, input_forget_weight, input_cell_weight, + input_output_weight, state_input_weight, state_forget_weight, state_cell_weight, + state_output_weight, backward_bias, backward_hidden_state, backward_cell_state, 
gate_buffer, + state_buffer, lstm_parm); + } + } +} diff --git a/mindspore/lite/nnacl/fp16/lstm_fp16.h b/mindspore/lite/nnacl/fp16/lstm_fp16.h new file mode 100644 index 0000000000..d047402e42 --- /dev/null +++ b/mindspore/lite/nnacl/fp16/lstm_fp16.h @@ -0,0 +1,38 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_LITE_NNACL_FP16_LSTM_H_ +#define MINDSPORE_LITE_NNACL_FP16_LSTM_H_ + +#include "nnacl/lstm_parameter.h" +#ifdef __cplusplus +extern "C" { +#endif +void MatMulAccFp16(float16_t *output, const float16_t *input, const float16_t *weight, int rows, int cols, + int inner_size); + +void ElementMulAccFp16(const float16_t *input0, const float16_t *input1, float16_t *output, int element_size); + +int ElementOptMulAccFp16(const float16_t *input0, const float16_t input1, float16_t *output, const int element_size); + +void LstmFp16(float16_t *output, const float16_t *input, const float16_t *weight_i, const float16_t *weight_h, + const float16_t *bias, float16_t *hidden_state, float16_t *cell_state, float16_t *gate_buffer, + float16_t *state_buffer, const LstmParameter *lstm_parm); +#ifdef __cplusplus +} +#endif + +#endif // MINDSPORE_LITE_NNACL_FP16_LSTM_H_ diff --git a/mindspore/lite/nnacl/fp32/lstm_fp32.h b/mindspore/lite/nnacl/fp32/lstm_fp32.h index 265e56058e..207fd2bd24 100644 --- a/mindspore/lite/nnacl/fp32/lstm_fp32.h +++ b/mindspore/lite/nnacl/fp32/lstm_fp32.h @@ -17,26 +17,7 @@ #ifndef MINDSPORE_LITE_NNACL_FP32_LSTM_H_ #define MINDSPORE_LITE_NNACL_FP32_LSTM_H_ -#include "nnacl/op_base.h" - -typedef struct LstmParameter { - // Primitive parameter - OpParameter op_parameter_; - // shape correlative - int input_size_; - int hidden_size_; // output_size - int seq_len_; - int batch_; - // other parameter - int input_step_; - int output_step_; - bool bidirectional_; - // smooth factor for hidden/cell state calculation: - // output_hidden = old_hidden * smooth + new_hidden * (1 - smooth) - // output_cell = old_cell * smooth + new_cell * (1 - smooth) - float smooth_; -} LstmParameter; - +#include "nnacl/lstm_parameter.h" #ifdef __cplusplus extern "C" { #endif diff --git a/mindspore/lite/nnacl/lstm_parameter.h b/mindspore/lite/nnacl/lstm_parameter.h new file mode 100644 index 0000000000..e9bc1621d8 --- /dev/null +++ b/mindspore/lite/nnacl/lstm_parameter.h @@ -0,0 +1,39 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */
+#ifndef MINDSPORE_LITE_NNACL_LSTM_PARAMETER_H_
+#define MINDSPORE_LITE_NNACL_LSTM_PARAMETER_H_
+
+#include "nnacl/op_base.h"
+
+typedef struct LstmParameter {
+  // Primitive parameter
+  OpParameter op_parameter_;
+  // shape correlative
+  int input_size_;
+  int hidden_size_;  // output_size
+  int seq_len_;
+  int batch_;
+  // other parameter
+  int input_step_;
+  int output_step_;
+  bool bidirectional_;
+  // smooth factor for hidden/cell state calculation:
+  // output_hidden = old_hidden * smooth + new_hidden * (1 - smooth)
+  // output_cell = old_cell * smooth + new_cell * (1 - smooth)
+  float smooth_;
+} LstmParameter;
+
+#endif  // MINDSPORE_LITE_NNACL_LSTM_PARAMETER_H_
diff --git a/mindspore/lite/src/runtime/kernel/arm/fp16/lstm_fp16.cc b/mindspore/lite/src/runtime/kernel/arm/fp16/lstm_fp16.cc
new file mode 100644
index 0000000000..81a0880781
--- /dev/null
+++ b/mindspore/lite/src/runtime/kernel/arm/fp16/lstm_fp16.cc
@@ -0,0 +1,203 @@
+/**
+ * Copyright 2021 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "src/runtime/kernel/arm/fp16/lstm_fp16.h"
+#include <vector>
+#include "schema/model_generated.h"
+#include "src/kernel_registry.h"
+#include "include/errorcode.h"
+#include "nnacl/fp16/lstm_fp16.h"
+
+using mindspore::kernel::KERNEL_ARCH::kCPU;
+using mindspore::lite::KernelRegistrar;
+using mindspore::lite::RET_ERROR;
+using mindspore::lite::RET_OK;
+using mindspore::schema::PrimitiveType_Lstm;
+
+namespace mindspore::kernel {
+void LstmFp16CPUKernel::FreeTmpBuffer() {
+  if (gate_buffer_ != nullptr) {
+    free(gate_buffer_);
+    gate_buffer_ = nullptr;
+  }
+  if (state_buffer_ != nullptr) {
+    free(state_buffer_);
+    state_buffer_ = nullptr;
+  }
+  if (weight_i_ptr_ != nullptr) {
+    free(weight_i_ptr_);
+    weight_i_ptr_ = nullptr;
+  }
+  if (weight_h_ptr_ != nullptr) {
+    free(weight_h_ptr_);
+    weight_h_ptr_ = nullptr;
+  }
+  if (bias_ptr_ != nullptr) {
+    free(bias_ptr_);
+    bias_ptr_ = nullptr;
+  }
+}
+
+int LstmFp16CPUKernel::InitParam() {
+  auto input = in_tensors_.front();
+  MS_ASSERT(input != nullptr);
+  std::vector<int> in_shape = input->shape();
+  lstm_param_->seq_len_ = in_shape.at(0);
+  lstm_param_->batch_ = in_shape.at(1);
+  lstm_param_->input_size_ = in_shape.at(2);
+
+  auto weight_i = in_tensors_.at(1);
+  MS_ASSERT(weight_i != nullptr);
+  std::vector<int> w_shape = weight_i->shape();
+  lstm_param_->hidden_size_ = w_shape.at(1) / 4;
+
+  lstm_param_->input_step_ = lstm_param_->batch_ * lstm_param_->input_size_;
+  lstm_param_->output_step_ = lstm_param_->bidirectional_ ? 2 * lstm_param_->batch_ * lstm_param_->hidden_size_
+                                                          : lstm_param_->batch_ * lstm_param_->hidden_size_;
+  return RET_OK;
+}
+
+int LstmFp16CPUKernel::InitBuffer() {
+  gate_buffer_ =
+    reinterpret_cast<float16_t *>(malloc(4 * lstm_param_->batch_ * lstm_param_->hidden_size_ * sizeof(float16_t)));
+  if (gate_buffer_ == nullptr) {
+    MS_LOG(ERROR) << "Lstm fp16 malloc gate_buffer error.";
+    return RET_ERROR;
+  }
+  if (!(lstm_param_->smooth_ >= -FLT_EPSILON && lstm_param_->smooth_ <= FLT_EPSILON)) {
+    int buffer_size = 2 * lstm_param_->batch_ * lstm_param_->hidden_size_ * sizeof(float16_t);
+    state_buffer_ = reinterpret_cast<float16_t *>(malloc(buffer_size));
+    if (state_buffer_ == nullptr) {
+      MS_LOG(ERROR) << "Lstm fp16 malloc state_buffer error.";
+      return RET_ERROR;
+    }
+  }
+  return RET_OK;
+}
+
+int LstmFp16CPUKernel::InitWeightBias() {
+  // copy weight_i and weight_h
+  auto weight_i = in_tensors_.at(1);
+  MS_ASSERT(weight_i != nullptr);
+  weight_i_ptr_ = reinterpret_cast<float16_t *>(malloc(weight_i->ElementsNum() * sizeof(float16_t)));
+  if (weight_i_ptr_ == nullptr) {
+    MS_LOG(ERROR) << "Lstm fp16 malloc weight_i_ptr_ error.";
+    return RET_ERROR;
+  }
+  auto weight_i_data = reinterpret_cast<float *>(weight_i->data_c());
+  for (size_t i = 0; i < weight_i->ElementsNum(); i++) {
+    weight_i_ptr_[i] = (float16_t)weight_i_data[i];
+  }
+
+  auto weight_h = in_tensors_.at(2);
+  MS_ASSERT(weight_h != nullptr);
+  weight_h_ptr_ = reinterpret_cast<float16_t *>(malloc(weight_h->ElementsNum() * sizeof(float16_t)));
+  if (weight_h_ptr_ == nullptr) {
+    MS_LOG(ERROR) << "Lstm fp16 malloc weight_h_ error.";
+    return RET_ERROR;
+  }
+  auto weight_h_data = reinterpret_cast<float *>(weight_h->data_c());
+  for (size_t i = 0; i < weight_h->ElementsNum(); i++) {
+    weight_h_ptr_[i] = (float16_t)weight_h_data[i];
+  }
+
+  std::vector<int> w_shape = weight_i->shape();
+  auto hidden_size = w_shape.at(1) / 4;
+  // init bias
+  int bias_num = lstm_param_->bidirectional_ ? 2 * 4 * hidden_size : 4 * hidden_size;
+  bias_ptr_ = reinterpret_cast<float16_t *>(malloc(bias_num * sizeof(float16_t)));
+  if (bias_ptr_ == nullptr) {
+    MS_LOG(ERROR) << "Lstm fp16 malloc bias_ptr_ error.";
+    return RET_ERROR;
+  }
+
+  auto bias_data = reinterpret_cast<float *>(in_tensors_.at(3)->data_c());
+  const int state_bias_offset = 4 * hidden_size;
+  for (int i = 0; i < state_bias_offset; i++) {
+    bias_ptr_[i] = (float16_t)(bias_data[i] + bias_data[i + state_bias_offset]);
+  }
+  if (lstm_param_->bidirectional_) {
+    bias_data += 4 * hidden_size * 2;
+    auto backward_bias = bias_ptr_ + 4 * hidden_size;
+    for (int i = 0; i < state_bias_offset; i++) {
+      backward_bias[i] = (float16_t)(bias_data[i] + bias_data[i + state_bias_offset]);
+    }
+  }
+  return RET_OK;
+}
+
+int LstmFp16CPUKernel::Init() {
+  FreeTmpBuffer();
+  auto ret = InitWeightBias();
+  if (ret != RET_OK) {
+    MS_LOG(ERROR) << "Lstm fp16 InitWeightBias error.";
+    FreeTmpBuffer();
+    return RET_ERROR;
+  }
+
+  if (!InferShapeDone()) {
+    return RET_OK;
+  }
+  return ReSize();
+}
+
+int LstmFp16CPUKernel::ReSize() {
+  auto ret = InitParam();
+  if (ret != RET_OK) {
+    MS_LOG(ERROR) << "Lstm fp16 InitParam error.";
+    return RET_ERROR;
+  }
+
+  ret = InitBuffer();
+  if (ret != RET_OK) {
+    MS_LOG(ERROR) << "Lstm fp16 InitBuffer error.";
+    FreeTmpBuffer();
+    return RET_ERROR;
+  }
+  return RET_OK;
+}
+
+int LstmFp16CPUKernel::Run() {
+  auto input = in_tensors_.at(kInputIndex);
+  MS_ASSERT(input != nullptr);
+  auto hidden_state = in_tensors_.at(4);
+  MS_ASSERT(hidden_state != nullptr);
+  auto cell_state = in_tensors_.at(5);
+  MS_ASSERT(cell_state != nullptr);
+  auto output = out_tensors_.at(0);
+  MS_ASSERT(output != nullptr);
+
+  auto input_ptr = reinterpret_cast<float16_t *>(input->data_c());
+  MS_ASSERT(input_ptr);
+  auto output_ptr = reinterpret_cast<float16_t *>(output->data_c());
+  MS_ASSERT(output_ptr);
+  auto output_hidden_state = out_tensors_[1];
+  memcpy(output_hidden_state->data_c(), hidden_state->data_c(), hidden_state->ElementsNum() * sizeof(float16_t));
+  auto output_cell_state = out_tensors_[2];
+  memcpy(output_cell_state->data_c(), cell_state->data_c(), cell_state->ElementsNum() * sizeof(float16_t));
+
+  MS_ASSERT(weight_h_ptr_);
+  MS_ASSERT(weight_i_ptr_);
+  MS_ASSERT(bias_ptr_);
+  MS_ASSERT(gate_buffer_);
+  LstmFp16(output_ptr, input_ptr, weight_i_ptr_, weight_h_ptr_, bias_ptr_,
+           reinterpret_cast<float16_t *>(output_hidden_state->data_c()),
+           reinterpret_cast<float16_t *>(output_cell_state->data_c()), gate_buffer_, state_buffer_, lstm_param_);
+  return RET_OK;
+}
+
+REG_KERNEL(kCPU, kNumberTypeFloat16, PrimitiveType_Lstm, LiteKernelCreator<LstmFp16CPUKernel>)
+}  // namespace mindspore::kernel
diff --git a/mindspore/lite/src/runtime/kernel/arm/fp16/lstm_fp16.h b/mindspore/lite/src/runtime/kernel/arm/fp16/lstm_fp16.h
new file mode 100644
index 0000000000..4527213e37
--- /dev/null
+++ b/mindspore/lite/src/runtime/kernel/arm/fp16/lstm_fp16.h
@@ -0,0 +1,55 @@
+/**
+ * Copyright 2021 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_FP16_LSTM_H_
+#define MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_FP16_LSTM_H_
+
+#include <vector>
+#include "src/lite_kernel.h"
+#include "nnacl/lstm_parameter.h"
+
+namespace mindspore::kernel {
+class LstmFp16CPUKernel : public LiteKernel {
+ public:
+  LstmFp16CPUKernel(OpParameter *parameter, const std::vector<lite::Tensor *> &inputs,
+                    const std::vector<lite::Tensor *> &outputs, const lite::InnerContext *ctx,
+                    const mindspore::lite::PrimitiveC *primitive)
+      : LiteKernel(parameter, inputs, outputs, ctx, primitive) {
+    lstm_param_ = reinterpret_cast<LstmParameter *>(op_parameter_);
+  }
+
+  ~LstmFp16CPUKernel() override { FreeTmpBuffer(); }
+
+  int Init() override;
+  int ReSize() override;
+  int Run() override;
+
+ private:
+  void FreeTmpBuffer();
+  int InitParam();
+  int InitBuffer();
+  int InitWeightBias();
+
+  float16_t *gate_buffer_ = nullptr;
+  float16_t *state_buffer_ = nullptr;
+  float16_t *weight_i_ptr_ = nullptr;
+  float16_t *weight_h_ptr_ = nullptr;
+  float16_t *bias_ptr_ = nullptr;
+  LstmParameter *lstm_param_ = nullptr;
+};
+}  // namespace mindspore::kernel
+
+#endif  // MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_FP16_LSTM_H_
diff --git a/mindspore/lite/src/runtime/kernel/arm/fp32/lstm_fp32.cc b/mindspore/lite/src/runtime/kernel/arm/fp32/lstm_fp32.cc
index 298d535bf5..3c2c2f4fc1 100644
--- a/mindspore/lite/src/runtime/kernel/arm/fp32/lstm_fp32.cc
+++ b/mindspore/lite/src/runtime/kernel/arm/fp32/lstm_fp32.cc
@@ -55,29 +55,29 @@ int LstmCPUKernel::InitParam() {
   auto input = in_tensors_.front();
   MS_ASSERT(input != nullptr);
   std::vector<int> in_shape = input->shape();
-  lstm_parm_->seq_len_ = in_shape.at(0);
-  lstm_parm_->batch_ = in_shape.at(1);
-  lstm_parm_->input_size_ = in_shape.at(2);
+  lstm_param_->seq_len_ = in_shape.at(0);
+  lstm_param_->batch_ = in_shape.at(1);
+  lstm_param_->input_size_ = in_shape.at(2);
 
   auto weight_i = in_tensors_.at(1);
   MS_ASSERT(weight_i != nullptr);
   std::vector<int> w_shape = weight_i->shape();
-  lstm_parm_->hidden_size_ = w_shape.at(1) / 4;
+  lstm_param_->hidden_size_ = w_shape.at(1) / 4;
 
-  lstm_parm_->input_step_ = lstm_parm_->batch_ * lstm_parm_->input_size_;
-  lstm_parm_->output_step_ = lstm_parm_->bidirectional_ ? 2 * lstm_parm_->batch_ * lstm_parm_->hidden_size_
-                                                        : lstm_parm_->batch_ * lstm_parm_->hidden_size_;
+  lstm_param_->input_step_ = lstm_param_->batch_ * lstm_param_->input_size_;
+  lstm_param_->output_step_ = lstm_param_->bidirectional_ ? 2 * lstm_param_->batch_ * lstm_param_->hidden_size_
+                                                          : lstm_param_->batch_ * lstm_param_->hidden_size_;
   return RET_OK;
 }
 
 int LstmCPUKernel::InitBuffer() {
-  gate_buffer_ = reinterpret_cast<float *>(malloc(4 * lstm_parm_->batch_ * lstm_parm_->hidden_size_ * sizeof(float)));
+  gate_buffer_ = reinterpret_cast<float *>(malloc(4 * lstm_param_->batch_ * lstm_param_->hidden_size_ * sizeof(float)));
   if (gate_buffer_ == nullptr) {
     MS_LOG(ERROR) << "LstmCPUKernel malloc gate_buffer error.";
     return RET_ERROR;
   }
-  if (!(lstm_parm_->smooth_ >= -FLT_EPSILON && lstm_parm_->smooth_ <= FLT_EPSILON)) {
-    int buffer_size = 2 * lstm_parm_->batch_ * lstm_parm_->hidden_size_ * sizeof(float);
+  if (!(lstm_param_->smooth_ >= -FLT_EPSILON && lstm_param_->smooth_ <= FLT_EPSILON)) {
+    int buffer_size = 2 * lstm_param_->batch_ * lstm_param_->hidden_size_ * sizeof(float);
     state_buffer_ = reinterpret_cast<float *>(malloc(buffer_size));
     if (state_buffer_ == nullptr) {
       MS_LOG(ERROR) << "LstmCPUKernel malloc state_buffer error.";
@@ -96,7 +96,7 @@ int LstmCPUKernel::InitWeightBias() {
     MS_LOG(ERROR) << "LstmCPUKernel malloc weight_i_ptr_ error.";
     return RET_ERROR;
   }
-  memcpy(weight_i_ptr_, weight_i->MutableData(), weight_i->ElementsNum() * sizeof(float));
+  memcpy(weight_i_ptr_, weight_i->data_c(), weight_i->ElementsNum() * sizeof(float));
 
   auto weight_h = in_tensors_.at(2);
   MS_ASSERT(weight_h != nullptr);
@@ -105,24 +105,24 @@ int LstmCPUKernel::InitWeightBias() {
     MS_LOG(ERROR) << "LstmCPUKernel malloc weight_h_ error.";
     return RET_ERROR;
   }
-  memcpy(weight_h_ptr_, weight_h->MutableData(), weight_h->ElementsNum() * sizeof(float));
+  memcpy(weight_h_ptr_, weight_h->data_c(), weight_h->ElementsNum() * sizeof(float));
 
   std::vector<int> w_shape = weight_i->shape();
   auto hidden_size = w_shape.at(1) / 4;
   // init bias
-  int bias_num = lstm_parm_->bidirectional_ ? 2 * 4 * hidden_size : 4 * hidden_size;
+  int bias_num = lstm_param_->bidirectional_ ? 2 * 4 * hidden_size : 4 * hidden_size;
   bias_ptr_ = reinterpret_cast<float *>(malloc(bias_num * sizeof(float)));
   if (bias_ptr_ == nullptr) {
     MS_LOG(ERROR) << "LstmCPUKernel malloc bias_ptr_ error.";
     return RET_ERROR;
   }
 
-  auto bias_data = reinterpret_cast<float *>(in_tensors_.at(3)->MutableData());
+  auto bias_data = reinterpret_cast<float *>(in_tensors_.at(3)->data_c());
   const int state_bias_offset = 4 * hidden_size;
   for (int i = 0; i < state_bias_offset; i++) {
     bias_ptr_[i] = bias_data[i] + bias_data[i + state_bias_offset];
   }
-  if (lstm_parm_->bidirectional_) {
+  if (lstm_param_->bidirectional_) {
     bias_data += 4 * hidden_size * 2;
     auto backward_bias = bias_ptr_ + 4 * hidden_size;
     for (int i = 0; i < state_bias_offset; i++) {
@@ -173,22 +173,22 @@ int LstmCPUKernel::Run() {
   auto output = out_tensors_.at(0);
   MS_ASSERT(output != nullptr);
 
-  auto input_ptr = reinterpret_cast<float *>(input->MutableData());
+  auto input_ptr = reinterpret_cast<float *>(input->data_c());
   MS_ASSERT(input_ptr);
-  auto output_ptr = reinterpret_cast<float *>(output->MutableData());
+  auto output_ptr = reinterpret_cast<float *>(output->data_c());
   MS_ASSERT(output_ptr);
   auto output_hidden_state = out_tensors_[1];
-  memcpy(output_hidden_state->MutableData(), hidden_state->MutableData(), hidden_state->ElementsNum() * sizeof(float));
+  memcpy(output_hidden_state->data_c(), hidden_state->data_c(), hidden_state->ElementsNum() * sizeof(float));
   auto output_cell_state = out_tensors_[2];
-  memcpy(output_cell_state->MutableData(), cell_state->MutableData(), cell_state->ElementsNum() * sizeof(float));
+  memcpy(output_cell_state->data_c(), cell_state->data_c(), cell_state->ElementsNum() * sizeof(float));
 
   MS_ASSERT(weight_h_ptr_);
   MS_ASSERT(weight_i_ptr_);
   MS_ASSERT(bias_ptr_);
   MS_ASSERT(gate_buffer_);
   Lstm(output_ptr, input_ptr, weight_i_ptr_, weight_h_ptr_, bias_ptr_,
-       reinterpret_cast<float *>(output_hidden_state->MutableData()),
-       reinterpret_cast<float *>(output_cell_state->MutableData()), gate_buffer_, state_buffer_, lstm_parm_);
+       reinterpret_cast<float *>(output_hidden_state->data_c()), reinterpret_cast<float *>(output_cell_state->data_c()),
+       gate_buffer_, state_buffer_, lstm_param_);
   return RET_OK;
 }
 
diff --git a/mindspore/lite/src/runtime/kernel/arm/fp32/lstm_fp32.h b/mindspore/lite/src/runtime/kernel/arm/fp32/lstm_fp32.h
index a2ced62b73..0980762ca4 100644
--- a/mindspore/lite/src/runtime/kernel/arm/fp32/lstm_fp32.h
+++ b/mindspore/lite/src/runtime/kernel/arm/fp32/lstm_fp32.h
@@ -28,7 +28,7 @@ class LstmCPUKernel : public LiteKernel {
                 const std::vector<lite::Tensor *> &outputs, const lite::InnerContext *ctx,
                 const mindspore::lite::PrimitiveC *primitive)
       : LiteKernel(parameter, inputs, outputs, ctx, primitive) {
-    lstm_parm_ = reinterpret_cast<LstmParameter *>(op_parameter_);
+    lstm_param_ = reinterpret_cast<LstmParameter *>(op_parameter_);
   }
 
   ~LstmCPUKernel() override { FreeTmpBuffer(); }
@@ -48,7 +48,7 @@ class LstmCPUKernel : public LiteKernel {
   float *weight_i_ptr_ = nullptr;
   float *weight_h_ptr_ = nullptr;
   float *bias_ptr_ = nullptr;
-  LstmParameter *lstm_parm_ = nullptr;
+  LstmParameter *lstm_param_ = nullptr;
 };
 }  // namespace mindspore::kernel
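For reference, the per-element state update that LstmStepUnitFp16 performs each time step reduces to the formulas documented in lstm_parameter.h (new_cell = forget_gate * old_cell + input_gate * cell_gate; new_hidden = output_gate * tanh(new_cell); with a non-zero smooth factor the stored states are blended as old * smooth + new * (1 - smooth)). The sketch below restates that math in plain scalar C rather than float16_t/NEON; the function and variable names are illustrative only and are not part of the patch.

#include <math.h>
#include <stdio.h>

/* Scalar mirror of UpdataStateFp16 / UpdataOutputFp16 for a single element. */
static void lstm_state_update(float *cell, float *hidden, float input_gate, float forget_gate,
                              float cell_gate, float output_gate, float smooth) {
  float new_cell = forget_gate * (*cell) + input_gate * cell_gate;  /* cell update */
  float new_hidden = output_gate * tanhf(new_cell);                 /* output update */
  /* smooth == 0 keeps only the new state, matching the kernel's default path. */
  *cell = (*cell) * smooth + new_cell * (1.0f - smooth);
  *hidden = (*hidden) * smooth + new_hidden * (1.0f - smooth);
}

int main(void) {
  float cell = 0.5f, hidden = 0.2f;
  lstm_state_update(&cell, &hidden, 0.6f, 0.7f, 0.1f, 0.9f, 0.0f);
  printf("cell=%f hidden=%f\n", cell, hidden);
  return 0;
}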