[MSLITE][Develop] add cpu fp16 op: lstm

pull/11610/head
yangruoqi713 4 years ago
parent 54b8d53780
commit 44fb47ed84

@@ -0,0 +1,235 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "nnacl/fp16/lstm_fp16.h"
#include <string.h>
#include "nnacl/fp16/activation_fp16.h"
#include "nnacl/fp16/arithmetic_fp16.h"
void InitGateFp16(float16_t *gate_buffer, const float16_t *bias, const LstmParameter *lstm_parm) {
int gate_offset = 0;
for (int l = 0; l < 4; l++) {
int batch_offset = gate_offset;
int bias_offset = l * lstm_parm->hidden_size_;
for (int b = 0; b < lstm_parm->batch_; b++) {
memcpy(gate_buffer + batch_offset, bias + bias_offset, lstm_parm->hidden_size_ * sizeof(float16_t));
batch_offset += lstm_parm->hidden_size_;
}
gate_offset += lstm_parm->batch_ * lstm_parm->hidden_size_;
}
}
// input: [row, inner_size]; weight: [col, inner_size]; output: [row, col]
void MatMulAccFp16(float16_t *output, const float16_t *input, const float16_t *weight, int rows, int cols,
int inner_size) {
for (int r = 0; r < rows; r++) {
for (int c = 0; c < cols; c++) {
float16_t res = 0;
const float16_t *input_col = input + r * inner_size;
const float16_t *weight_col = weight + c * inner_size;
int index = 0;
float16x8_t out = vdupq_n_f16(0.0f);
for (; index <= inner_size - 8; index += 8) {
float16x8_t in_0 = vld1q_f16(input_col + index);
float16x8_t in_1 = vld1q_f16(weight_col + index);
out = vfmaq_f16(out, in_1, in_0);
}
float16x4_t add2 = vadd_f16(vget_low_f16(out), vget_high_f16(out));
float16x4_t add4 = vpadd_f16(add2, add2);
float16x4_t add8 = vpadd_f16(add4, add4);
res += vget_lane_f16(add8, 0);
for (; index < inner_size; index++) {
res += input_col[index] * weight_col[index];
}
output[r * cols + c] += res;
}
}
}
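// Editorial usage sketch (not part of the patch; names and sizes are illustrative).
// MatMulAccFp16 accumulates into its destination, so `gate` must already hold the
// bias values written by InitGateFp16 above.
static void ExampleInputGateMatMulFp16(const float16_t *x,     // [batch, input_size]  = [2, 16], row-major
                                       const float16_t *w_xi,  // [hidden, input_size] = [8, 16], row-major
                                       float16_t *gate) {      // [batch, hidden]      = [2, 8], pre-filled with bias
  MatMulAccFp16(gate, x, w_xi, 2, 8, 16);  // rows = batch, cols = hidden, inner_size = input_size
}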
void ElementMulAccFp16(const float16_t *input0, const float16_t *input1, float16_t *output, int element_size) {
int index = 0;
for (; index <= element_size - 8; index += 8) {
float16x8_t in_0 = vld1q_f16(input0 + index);
float16x8_t in_1 = vld1q_f16(input1 + index);
float16x8_t out = vld1q_f16(output + index);
out = vfmaq_f16(out, in_1, in_0);
vst1q_f16(output + index, out);
}
for (; index < element_size; index++) {
output[index] += input0[index] * input1[index];
}
}
int ElementOptMulAccFp16(const float16_t *input0, const float16_t input1, float16_t *output, const int element_size) {
int index = 0;
for (; index <= element_size - 8; index += 8) {
float16x8_t vin0 = vld1q_f16(input0 + index);
float16x8_t vout = vld1q_f16(output + index);
vout = vfmaq_n_f16(vout, vin0, input1);
vst1q_f16(output + index, vout);
}
for (; index < element_size; index++) {
output[index] += input0[index] * input1;
}
return NNACL_OK;
}
void UpdataStateFp16(float16_t *cell_state, float16_t *forget_gate, const float16_t *input_gate,
const float16_t *cell_gate, float16_t *state_buffer, int batch, int hidden_size,
float16_t smooth) {
if (!(smooth >= -FLT_EPSILON && smooth <= FLT_EPSILON)) { // smooth * old_cell_state
memcpy(state_buffer, cell_state, batch * hidden_size * sizeof(float16_t));
ArithmeticParameter parameter;
parameter.in_elements_num0_ = batch * hidden_size;
parameter.in_elements_num1_ = 1;
ElementOptMulFp16(state_buffer, &smooth, state_buffer, batch * hidden_size, &parameter);
}
ElementMulFp16(forget_gate, cell_state, cell_state, batch * hidden_size);
ElementMulAccFp16(input_gate, cell_gate, cell_state, batch * hidden_size);
if (!(smooth >= -FLT_EPSILON && smooth <= FLT_EPSILON)) { // (1 - smooth) * new_cell_state
ElementOptMulAccFp16(cell_state, 1 - smooth, state_buffer, batch * hidden_size);
}
}
void UpdataOutputFp16(const float16_t *cell_state, float16_t *output_gate, float16_t *hidden_state,
float16_t *state_buffer_in, int batch, int hidden_size, float16_t smooth) {
float16_t *state_buffer = state_buffer_in + batch * hidden_size;
if (!(smooth >= -FLT_EPSILON && smooth <= FLT_EPSILON)) {
memcpy(state_buffer, hidden_state, batch * hidden_size * sizeof(float16_t));
ArithmeticParameter parameter;
parameter.in_elements_num0_ = batch * hidden_size;
parameter.in_elements_num1_ = 1;
ElementOptMulFp16(state_buffer, &smooth, state_buffer, batch * hidden_size, &parameter);
}
TanhFp16(cell_state, hidden_state, batch * hidden_size);
ElementMulFp16(hidden_state, output_gate, hidden_state, batch * hidden_size);
if (!(smooth >= -FLT_EPSILON && smooth <= FLT_EPSILON)) {
ElementOptMulAccFp16(hidden_state, 1 - smooth, state_buffer, batch * hidden_size);
}
}
void LstmStepUnitFp16(float16_t *output, const float16_t *input, const float16_t *input_input_weight,
const float16_t *input_forget_weight, const float16_t *input_cell_weight,
const float16_t *input_output_weight, const float16_t *state_input_weight,
const float16_t *state_forget_weight, const float16_t *state_cell_weight,
const float16_t *state_output_weight, const float16_t *bias, float16_t *hidden_state,
float16_t *cell_state, float16_t *gate_buffer, float16_t *state_buffer,
const LstmParameter *lstm_parm) {
InitGateFp16(gate_buffer, bias, lstm_parm);
float16_t *input_gate = gate_buffer;
float16_t *forget_gate = gate_buffer + lstm_parm->batch_ * lstm_parm->hidden_size_ * 2;
float16_t *cell_gate = gate_buffer + lstm_parm->batch_ * lstm_parm->hidden_size_ * 3;
float16_t *output_gate = gate_buffer + lstm_parm->batch_ * lstm_parm->hidden_size_ * 1;
// input * weight
MatMulAccFp16(input_gate, input, input_input_weight, lstm_parm->batch_, lstm_parm->hidden_size_,
lstm_parm->input_size_);
MatMulAccFp16(forget_gate, input, input_forget_weight, lstm_parm->batch_, lstm_parm->hidden_size_,
lstm_parm->input_size_);
MatMulAccFp16(cell_gate, input, input_cell_weight, lstm_parm->batch_, lstm_parm->hidden_size_,
lstm_parm->input_size_);
MatMulAccFp16(output_gate, input, input_output_weight, lstm_parm->batch_, lstm_parm->hidden_size_,
lstm_parm->input_size_);
// state * weight
MatMulAccFp16(input_gate, hidden_state, state_input_weight, lstm_parm->batch_, lstm_parm->hidden_size_,
lstm_parm->hidden_size_);
MatMulAccFp16(forget_gate, hidden_state, state_forget_weight, lstm_parm->batch_, lstm_parm->hidden_size_,
lstm_parm->hidden_size_);
MatMulAccFp16(cell_gate, hidden_state, state_cell_weight, lstm_parm->batch_, lstm_parm->hidden_size_,
lstm_parm->hidden_size_);
MatMulAccFp16(output_gate, hidden_state, state_output_weight, lstm_parm->batch_, lstm_parm->hidden_size_,
lstm_parm->hidden_size_);
// update input_gate
SigmoidFp16(input_gate, input_gate, lstm_parm->batch_ * lstm_parm->hidden_size_);
// update forget_gate
SigmoidFp16(forget_gate, forget_gate, lstm_parm->batch_ * lstm_parm->hidden_size_);
// update cell_gate
TanhFp16(cell_gate, cell_gate, lstm_parm->batch_ * lstm_parm->hidden_size_);
// update cell state
UpdataStateFp16(cell_state, forget_gate, input_gate, cell_gate, state_buffer, lstm_parm->batch_,
lstm_parm->hidden_size_, lstm_parm->smooth_);
// update output_gate
SigmoidFp16(output_gate, output_gate, lstm_parm->batch_ * lstm_parm->hidden_size_);
// update output
UpdataOutputFp16(cell_state, output_gate, hidden_state, state_buffer, lstm_parm->batch_, lstm_parm->hidden_size_,
lstm_parm->smooth_);
memcpy(output, hidden_state, lstm_parm->batch_ * lstm_parm->hidden_size_ * sizeof(float16_t));
if (!(lstm_parm->smooth_ >= -FLT_EPSILON && lstm_parm->smooth_ <= FLT_EPSILON)) {
memcpy(cell_state, state_buffer, lstm_parm->batch_ * lstm_parm->hidden_size_ * sizeof(float16_t));
memcpy(hidden_state, state_buffer + lstm_parm->batch_ * lstm_parm->hidden_size_,
lstm_parm->batch_ * lstm_parm->hidden_size_ * sizeof(float16_t));
}
}
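// Editorial note summarizing the step above (standard LSTM cell equations; the code
// computes them with the fp16 kernels rather than this scalar form):
//   i = sigmoid(W_xi * x + W_hi * h + b_i)   input gate
//   f = sigmoid(W_xf * x + W_hf * h + b_f)   forget gate
//   g = tanh   (W_xg * x + W_hg * h + b_g)   cell gate
//   o = sigmoid(W_xo * x + W_ho * h + b_o)   output gate
//   c' = f * c + i * g                       UpdataStateFp16
//   h' = o * tanh(c')                        UpdataOutputFp16
// When smooth_ is non-zero, state_buffer additionally holds the blended values
//   smooth * c + (1 - smooth) * c'   and   smooth * h + (1 - smooth) * h'
// which are copied back into cell_state and hidden_state at the end of the step.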
void LstmFp16(float16_t *output, const float16_t *input, const float16_t *weight_i, const float16_t *weight_h,
const float16_t *bias, float16_t *hidden_state, float16_t *cell_state, float16_t *gate_buffer,
float16_t *state_buffer, const LstmParameter *lstm_parm) {
// forward
const float16_t *input_input_weight = weight_i;
const float16_t *input_forget_weight = weight_i + lstm_parm->input_size_ * lstm_parm->hidden_size_ * 2;
const float16_t *input_cell_weight = weight_i + lstm_parm->input_size_ * lstm_parm->hidden_size_ * 3;
const float16_t *input_output_weight = weight_i + lstm_parm->input_size_ * lstm_parm->hidden_size_ * 1;
const float16_t *state_input_weight = weight_h;
const float16_t *state_forget_weight = weight_h + lstm_parm->hidden_size_ * lstm_parm->hidden_size_ * 2;
const float16_t *state_cell_weight = weight_h + lstm_parm->hidden_size_ * lstm_parm->hidden_size_ * 3;
const float16_t *state_output_weight = weight_h + lstm_parm->hidden_size_ * lstm_parm->hidden_size_ * 1;
for (int t = 0; t < lstm_parm->seq_len_; t++) {
const float16_t *input_ptr = input + t * lstm_parm->input_step_;
float16_t *output_ptr = output + t * lstm_parm->output_step_;
LstmStepUnitFp16(output_ptr, input_ptr, input_input_weight, input_forget_weight, input_cell_weight,
input_output_weight, state_input_weight, state_forget_weight, state_cell_weight,
state_output_weight, bias, hidden_state, cell_state, gate_buffer, state_buffer, lstm_parm);
}
// backward
if (lstm_parm->bidirectional_) {
input_input_weight = weight_i + lstm_parm->input_size_ * lstm_parm->hidden_size_ * 4;
input_forget_weight = weight_i + lstm_parm->input_size_ * lstm_parm->hidden_size_ * 6;
input_cell_weight = weight_i + lstm_parm->input_size_ * lstm_parm->hidden_size_ * 7;
input_output_weight = weight_i + lstm_parm->input_size_ * lstm_parm->hidden_size_ * 5;
state_input_weight = weight_h + lstm_parm->hidden_size_ * lstm_parm->hidden_size_ * 4;
state_forget_weight = weight_h + lstm_parm->hidden_size_ * lstm_parm->hidden_size_ * 6;
state_cell_weight = weight_h + lstm_parm->hidden_size_ * lstm_parm->hidden_size_ * 7;
state_output_weight = weight_h + lstm_parm->hidden_size_ * lstm_parm->hidden_size_ * 5;
float16_t *backward_output = output + lstm_parm->batch_ * lstm_parm->hidden_size_;
const float16_t *backward_bias = bias + 4 * lstm_parm->hidden_size_;
float16_t *backward_cell_state = cell_state + lstm_parm->batch_ * lstm_parm->hidden_size_;
float16_t *backward_hidden_state = hidden_state + lstm_parm->batch_ * lstm_parm->hidden_size_;
for (int t = lstm_parm->seq_len_ - 1; t >= 0; t--) {
const float16_t *input_ptr = input + t * lstm_parm->input_step_;
float16_t *output_ptr = backward_output + t * lstm_parm->output_step_;
LstmStepUnitFp16(output_ptr, input_ptr, input_input_weight, input_forget_weight, input_cell_weight,
input_output_weight, state_input_weight, state_forget_weight, state_cell_weight,
state_output_weight, backward_bias, backward_hidden_state, backward_cell_state, gate_buffer,
state_buffer, lstm_parm);
}
}
}
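// Editorial note on the memory layout the functions above assume (offsets in elements):
//   gate_buffer : four consecutive [batch, hidden] blocks -> input, output, forget, cell gate
//   weight_i    : per direction, four [hidden, input_size] blocks in the order
//                 input (0), output (1), forget (2), cell (3); when bidirectional,
//                 the backward direction follows at block offsets 4..7
//   weight_h    : same gate order, with [hidden, hidden] blocks
//   bias        : 4 * hidden per direction, same gate order (input, output, forget, cell)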

@@ -0,0 +1,38 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_LITE_NNACL_FP16_LSTM_H_
#define MINDSPORE_LITE_NNACL_FP16_LSTM_H_
#include "nnacl/lstm_parameter.h"
#ifdef __cplusplus
extern "C" {
#endif
void MatMulAccFp16(float16_t *output, const float16_t *input, const float16_t *weight, int rows, int cols,
int inner_size);
void ElementMulAccFp16(const float16_t *input0, const float16_t *input1, float16_t *output, int element_size);
int ElementOptMulAccFp16(const float16_t *input0, const float16_t input1, float16_t *output, const int element_size);
void LstmFp16(float16_t *output, const float16_t *input, const float16_t *weight_i, const float16_t *weight_h,
const float16_t *bias, float16_t *hidden_state, float16_t *cell_state, float16_t *gate_buffer,
float16_t *state_buffer, const LstmParameter *lstm_parm);
#ifdef __cplusplus
}
#endif
#endif // MINDSPORE_LITE_NNACL_FP16_LSTM_H_

@@ -17,26 +17,7 @@
#ifndef MINDSPORE_LITE_NNACL_FP32_LSTM_H_
#define MINDSPORE_LITE_NNACL_FP32_LSTM_H_
#include "nnacl/op_base.h"
-typedef struct LstmParameter {
-// Primitive parameter
-OpParameter op_parameter_;
-// shape correlative
-int input_size_;
-int hidden_size_; // output_size
-int seq_len_;
-int batch_;
-// other parameter
-int input_step_;
-int output_step_;
-bool bidirectional_;
-// smooth factor for hidden/cell state calculation:
-// output_hidden = old_hidden * smooth + new_hidden * (1 - smooth)
-// output_cell = old_cell * smooth + new_cell * (1 - smooth)
-float smooth_;
-} LstmParameter;
+#include "nnacl/lstm_parameter.h"
#ifdef __cplusplus
extern "C" {
#endif

@@ -0,0 +1,39 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_LITE_NNACL_LSTM_PARAMETER_H_
#define MINDSPORE_LITE_NNACL_LSTM_PARAMETER_H_
#include "nnacl/op_base.h"
typedef struct LstmParameter {
// Primitive parameter
OpParameter op_parameter_;
// shape correlative
int input_size_;
int hidden_size_; // output_size
int seq_len_;
int batch_;
// other parameter
int input_step_;
int output_step_;
bool bidirectional_;
// smooth factor for hidden/cell state calculation:
// output_hidden = old_hidden * smooth + new_hidden * (1 - smooth)
// output_cell = old_cell * smooth + new_cell * (1 - smooth)
float smooth_;
} LstmParameter;
#endif // MINDSPORE_LITE_NNACL_LSTM_PARAMETER_H_
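A sketch of how these fields get filled at runtime (this mirrors LstmFp16CPUKernel::InitParam later in this commit; the concrete sizes and the helper name are illustrative only):

#include <stdbool.h>
#include "nnacl/lstm_parameter.h"

static void FillLstmParameterExample(LstmParameter *p) {
  // input tensor shape [seq_len, batch, input_size] = [10, 2, 16]; weight_i shape[1] = 4 * hidden
  p->seq_len_ = 10;
  p->batch_ = 2;
  p->input_size_ = 16;
  p->hidden_size_ = 8;  // w_shape[1] / 4
  p->bidirectional_ = false;
  p->smooth_ = 0.0f;    // zero disables the old/new state blending
  p->input_step_ = p->batch_ * p->input_size_;  // elements consumed per time step
  p->output_step_ = p->bidirectional_ ? 2 * p->batch_ * p->hidden_size_
                                      : p->batch_ * p->hidden_size_;  // elements produced per time step
}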

@@ -0,0 +1,203 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "src/runtime/kernel/arm/fp16/lstm_fp16.h"
#include <vector>
#include "schema/model_generated.h"
#include "src/kernel_registry.h"
#include "include/errorcode.h"
#include "nnacl/fp16/lstm_fp16.h"
using mindspore::kernel::KERNEL_ARCH::kCPU;
using mindspore::lite::KernelRegistrar;
using mindspore::lite::RET_ERROR;
using mindspore::lite::RET_OK;
using mindspore::schema::PrimitiveType_Lstm;
namespace mindspore::kernel {
void LstmFp16CPUKernel::FreeTmpBuffer() {
if (gate_buffer_ != nullptr) {
free(gate_buffer_);
gate_buffer_ = nullptr;
}
if (state_buffer_ != nullptr) {
free(state_buffer_);
state_buffer_ = nullptr;
}
if (weight_i_ptr_ != nullptr) {
free(weight_i_ptr_);
weight_i_ptr_ = nullptr;
}
if (weight_h_ptr_ != nullptr) {
free(weight_h_ptr_);
weight_h_ptr_ = nullptr;
}
if (bias_ptr_ != nullptr) {
free(bias_ptr_);
bias_ptr_ = nullptr;
}
}
int LstmFp16CPUKernel::InitParam() {
auto input = in_tensors_.front();
MS_ASSERT(input != nullptr);
std::vector<int> in_shape = input->shape();
lstm_param_->seq_len_ = in_shape.at(0);
lstm_param_->batch_ = in_shape.at(1);
lstm_param_->input_size_ = in_shape.at(2);
auto weight_i = in_tensors_.at(1);
MS_ASSERT(weight_i != nullptr);
std::vector<int> w_shape = weight_i->shape();
lstm_param_->hidden_size_ = w_shape.at(1) / 4;
lstm_param_->input_step_ = lstm_param_->batch_ * lstm_param_->input_size_;
lstm_param_->output_step_ = lstm_param_->bidirectional_ ? 2 * lstm_param_->batch_ * lstm_param_->hidden_size_
: lstm_param_->batch_ * lstm_param_->hidden_size_;
return RET_OK;
}
int LstmFp16CPUKernel::InitBuffer() {
gate_buffer_ =
reinterpret_cast<float16_t *>(malloc(4 * lstm_param_->batch_ * lstm_param_->hidden_size_ * sizeof(float16_t)));
if (gate_buffer_ == nullptr) {
MS_LOG(ERROR) << "Lstm fp16 malloc gate_buffer error.";
return RET_ERROR;
}
if (!(lstm_param_->smooth_ >= -FLT_EPSILON && lstm_param_->smooth_ <= FLT_EPSILON)) {
int buffer_size = 2 * lstm_param_->batch_ * lstm_param_->hidden_size_ * sizeof(float16_t);
state_buffer_ = reinterpret_cast<float16_t *>(malloc(buffer_size));
if (state_buffer_ == nullptr) {
MS_LOG(ERROR) << "Lstm fp16 malloc state_buffer error.";
return RET_ERROR;
}
}
return RET_OK;
}
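// Editorial sizing note: gate_buffer_ holds the four gate activations, hence
// 4 * batch * hidden elements; state_buffer_ is only needed when smooth_ is
// non-zero and holds two [batch, hidden] blocks, the blended cell state followed
// by the blended hidden state (see UpdataStateFp16 / UpdataOutputFp16).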
int LstmFp16CPUKernel::InitWeightBias() {
// copy weight_i and weight_h
auto weight_i = in_tensors_.at(1);
MS_ASSERT(weight_i != nullptr);
weight_i_ptr_ = reinterpret_cast<float16_t *>(malloc(weight_i->ElementsNum() * sizeof(float16_t)));
if (weight_i_ptr_ == nullptr) {
MS_LOG(ERROR) << "Lstm fp16 malloc weight_i_ptr_ error.";
return RET_ERROR;
}
auto weight_i_data = reinterpret_cast<float *>(weight_i->data_c());
for (size_t i = 0; i < weight_i->ElementsNum(); i++) {
weight_i_ptr_[i] = (float16_t)weight_i_data[i];
}
auto weight_h = in_tensors_.at(2);
MS_ASSERT(weight_h != nullptr);
weight_h_ptr_ = reinterpret_cast<float16_t *>(malloc(weight_h->ElementsNum() * sizeof(float16_t)));
if (weight_h_ptr_ == nullptr) {
MS_LOG(ERROR) << "Lstm fp16 malloc weight_h_ error.";
return RET_ERROR;
}
auto weight_h_data = reinterpret_cast<float *>(weight_h->data_c());
for (size_t i = 0; i < weight_h->ElementsNum(); i++) {
weight_h_ptr_[i] = (float16_t)weight_h_data[i];
}
std::vector<int> w_shape = weight_i->shape();
auto hidden_size = w_shape.at(1) / 4;
// init bias
int bias_num = lstm_param_->bidirectional_ ? 2 * 4 * hidden_size : 4 * hidden_size;
bias_ptr_ = reinterpret_cast<float16_t *>(malloc(bias_num * sizeof(float16_t)));
if (bias_ptr_ == nullptr) {
MS_LOG(ERROR) << "Lstm fp16 malloc bias_ptr_ error.";
return RET_ERROR;
}
auto bias_data = reinterpret_cast<float *>(in_tensors_.at(3)->data_c());
const int state_bias_offset = 4 * hidden_size;
for (int i = 0; i < state_bias_offset; i++) {
bias_ptr_[i] = (float16_t)(bias_data[i] + bias_data[i + state_bias_offset]);
}
if (lstm_param_->bidirectional_) {
bias_data += 4 * hidden_size * 2;
auto backward_bias = bias_ptr_ + 4 * hidden_size;
for (int i = 0; i < state_bias_offset; i++) {
backward_bias[i] = (float16_t)(bias_data[i] + bias_data[i + state_bias_offset]);
}
}
return RET_OK;
}
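// Editorial sketch of the bias folding performed above (plain float for clarity;
// the kernel additionally converts the result to float16_t). The bias tensor
// appears to hold the input-weight bias followed by the recurrent (state) bias,
// 4 * hidden each per direction; summing them once lets each gate apply a single bias.
static void FoldLstmBiasExample(const float *bias_in,  // [8 * hidden] for one direction
                                float *bias_out,       // [4 * hidden]
                                int hidden) {
  for (int i = 0; i < 4 * hidden; i++) {
    bias_out[i] = bias_in[i] + bias_in[i + 4 * hidden];
  }
}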
int LstmFp16CPUKernel::Init() {
FreeTmpBuffer();
auto ret = InitWeightBias();
if (ret != RET_OK) {
MS_LOG(ERROR) << "Lstm fp16 InitWeightBias error.";
FreeTmpBuffer();
return RET_ERROR;
}
if (!InferShapeDone()) {
return RET_OK;
}
return ReSize();
}
int LstmFp16CPUKernel::ReSize() {
auto ret = InitParam();
if (ret != RET_OK) {
MS_LOG(ERROR) << "Lstm fp16 InitParam error.";
return RET_ERROR;
}
ret = InitBuffer();
if (ret != RET_OK) {
MS_LOG(ERROR) << "Lstm fp16 InitBuffer error.";
FreeTmpBuffer();
return RET_ERROR;
}
return RET_OK;
}
int LstmFp16CPUKernel::Run() {
auto input = in_tensors_.at(kInputIndex);
MS_ASSERT(input != nullptr);
auto hidden_state = in_tensors_.at(4);
MS_ASSERT(hidden_state != nullptr);
auto cell_state = in_tensors_.at(5);
MS_ASSERT(cell_state != nullptr);
auto output = out_tensors_.at(0);
MS_ASSERT(output != nullptr);
auto input_ptr = reinterpret_cast<float16_t *>(input->data_c());
MS_ASSERT(input_ptr);
auto output_ptr = reinterpret_cast<float16_t *>(output->data_c());
MS_ASSERT(output_ptr);
auto output_hidden_state = out_tensors_[1];
memcpy(output_hidden_state->data_c(), hidden_state->data_c(), hidden_state->ElementsNum() * sizeof(float16_t));
auto output_cell_state = out_tensors_[2];
memcpy(output_cell_state->data_c(), cell_state->data_c(), cell_state->ElementsNum() * sizeof(float16_t));
MS_ASSERT(weight_h_ptr_);
MS_ASSERT(weight_i_ptr_);
MS_ASSERT(bias_ptr_);
MS_ASSERT(gate_buffer_);
LstmFp16(output_ptr, input_ptr, weight_i_ptr_, weight_h_ptr_, bias_ptr_,
reinterpret_cast<float16_t *>(output_hidden_state->data_c()),
reinterpret_cast<float16_t *>(output_cell_state->data_c()), gate_buffer_, state_buffer_, lstm_param_);
return RET_OK;
}
REG_KERNEL(kCPU, kNumberTypeFloat16, PrimitiveType_Lstm, LiteKernelCreator<LstmFp16CPUKernel>)
} // namespace mindspore::kernel

@@ -0,0 +1,55 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_FP16_LSTM_H_
#define MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_FP16_LSTM_H_
#include <vector>
#include "src/lite_kernel.h"
#include "nnacl/lstm_parameter.h"
namespace mindspore::kernel {
class LstmFp16CPUKernel : public LiteKernel {
public:
LstmFp16CPUKernel(OpParameter *parameter, const std::vector<lite::Tensor *> &inputs,
const std::vector<lite::Tensor *> &outputs, const lite::InnerContext *ctx,
const mindspore::lite::PrimitiveC *primitive)
: LiteKernel(parameter, inputs, outputs, ctx, primitive) {
lstm_param_ = reinterpret_cast<LstmParameter *>(op_parameter_);
}
~LstmFp16CPUKernel() override { FreeTmpBuffer(); }
int Init() override;
int ReSize() override;
int Run() override;
private:
void FreeTmpBuffer();
int InitParam();
int InitBuffer();
int InitWeightBias();
float16_t *gate_buffer_ = nullptr;
float16_t *state_buffer_ = nullptr;
float16_t *weight_i_ptr_ = nullptr;
float16_t *weight_h_ptr_ = nullptr;
float16_t *bias_ptr_ = nullptr;
LstmParameter *lstm_param_ = nullptr;
};
} // namespace mindspore::kernel
#endif // MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_FP16_LSTM_H_

@@ -55,29 +55,29 @@ int LstmCPUKernel::InitParam() {
auto input = in_tensors_.front();
MS_ASSERT(input != nullptr);
std::vector<int> in_shape = input->shape();
-lstm_parm_->seq_len_ = in_shape.at(0);
-lstm_parm_->batch_ = in_shape.at(1);
-lstm_parm_->input_size_ = in_shape.at(2);
+lstm_param_->seq_len_ = in_shape.at(0);
+lstm_param_->batch_ = in_shape.at(1);
+lstm_param_->input_size_ = in_shape.at(2);
auto weight_i = in_tensors_.at(1);
MS_ASSERT(weight_i != nullptr);
std::vector<int> w_shape = weight_i->shape();
-lstm_parm_->hidden_size_ = w_shape.at(1) / 4;
+lstm_param_->hidden_size_ = w_shape.at(1) / 4;
-lstm_parm_->input_step_ = lstm_parm_->batch_ * lstm_parm_->input_size_;
-lstm_parm_->output_step_ = lstm_parm_->bidirectional_ ? 2 * lstm_parm_->batch_ * lstm_parm_->hidden_size_
-: lstm_parm_->batch_ * lstm_parm_->hidden_size_;
+lstm_param_->input_step_ = lstm_param_->batch_ * lstm_param_->input_size_;
+lstm_param_->output_step_ = lstm_param_->bidirectional_ ? 2 * lstm_param_->batch_ * lstm_param_->hidden_size_
+: lstm_param_->batch_ * lstm_param_->hidden_size_;
return RET_OK;
}
int LstmCPUKernel::InitBuffer() {
-gate_buffer_ = reinterpret_cast<float *>(malloc(4 * lstm_parm_->batch_ * lstm_parm_->hidden_size_ * sizeof(float)));
+gate_buffer_ = reinterpret_cast<float *>(malloc(4 * lstm_param_->batch_ * lstm_param_->hidden_size_ * sizeof(float)));
if (gate_buffer_ == nullptr) {
MS_LOG(ERROR) << "LstmCPUKernel malloc gate_buffer error.";
return RET_ERROR;
}
-if (!(lstm_parm_->smooth_ >= -FLT_EPSILON && lstm_parm_->smooth_ <= FLT_EPSILON)) {
-int buffer_size = 2 * lstm_parm_->batch_ * lstm_parm_->hidden_size_ * sizeof(float);
+if (!(lstm_param_->smooth_ >= -FLT_EPSILON && lstm_param_->smooth_ <= FLT_EPSILON)) {
+int buffer_size = 2 * lstm_param_->batch_ * lstm_param_->hidden_size_ * sizeof(float);
state_buffer_ = reinterpret_cast<float *>(malloc(buffer_size));
if (state_buffer_ == nullptr) {
MS_LOG(ERROR) << "LstmCPUKernel malloc state_buffer error.";
@@ -96,7 +96,7 @@ int LstmCPUKernel::InitWeightBias() {
MS_LOG(ERROR) << "LstmCPUKernel malloc weight_i_ptr_ error.";
return RET_ERROR;
}
-memcpy(weight_i_ptr_, weight_i->MutableData(), weight_i->ElementsNum() * sizeof(float));
+memcpy(weight_i_ptr_, weight_i->data_c(), weight_i->ElementsNum() * sizeof(float));
auto weight_h = in_tensors_.at(2);
MS_ASSERT(weight_h != nullptr);
@@ -105,24 +105,24 @@ int LstmCPUKernel::InitWeightBias() {
MS_LOG(ERROR) << "LstmCPUKernel malloc weight_h_ error.";
return RET_ERROR;
}
-memcpy(weight_h_ptr_, weight_h->MutableData(), weight_h->ElementsNum() * sizeof(float));
+memcpy(weight_h_ptr_, weight_h->data_c(), weight_h->ElementsNum() * sizeof(float));
std::vector<int> w_shape = weight_i->shape();
auto hidden_size = w_shape.at(1) / 4;
// init bias
-int bias_num = lstm_parm_->bidirectional_ ? 2 * 4 * hidden_size : 4 * hidden_size;
+int bias_num = lstm_param_->bidirectional_ ? 2 * 4 * hidden_size : 4 * hidden_size;
bias_ptr_ = reinterpret_cast<float *>(malloc(bias_num * sizeof(float)));
if (bias_ptr_ == nullptr) {
MS_LOG(ERROR) << "LstmCPUKernel malloc bias_ptr_ error.";
return RET_ERROR;
}
-auto bias_data = reinterpret_cast<float *>(in_tensors_.at(3)->MutableData());
+auto bias_data = reinterpret_cast<float *>(in_tensors_.at(3)->data_c());
const int state_bias_offset = 4 * hidden_size;
for (int i = 0; i < state_bias_offset; i++) {
bias_ptr_[i] = bias_data[i] + bias_data[i + state_bias_offset];
}
-if (lstm_parm_->bidirectional_) {
+if (lstm_param_->bidirectional_) {
bias_data += 4 * hidden_size * 2;
auto backward_bias = bias_ptr_ + 4 * hidden_size;
for (int i = 0; i < state_bias_offset; i++) {
@@ -173,22 +173,22 @@ int LstmCPUKernel::Run() {
auto output = out_tensors_.at(0);
MS_ASSERT(output != nullptr);
-auto input_ptr = reinterpret_cast<float *>(input->MutableData());
+auto input_ptr = reinterpret_cast<float *>(input->data_c());
MS_ASSERT(input_ptr);
-auto output_ptr = reinterpret_cast<float *>(output->MutableData());
+auto output_ptr = reinterpret_cast<float *>(output->data_c());
MS_ASSERT(output_ptr);
auto output_hidden_state = out_tensors_[1];
-memcpy(output_hidden_state->MutableData(), hidden_state->MutableData(), hidden_state->ElementsNum() * sizeof(float));
+memcpy(output_hidden_state->data_c(), hidden_state->data_c(), hidden_state->ElementsNum() * sizeof(float));
auto output_cell_state = out_tensors_[2];
-memcpy(output_cell_state->MutableData(), cell_state->MutableData(), cell_state->ElementsNum() * sizeof(float));
+memcpy(output_cell_state->data_c(), cell_state->data_c(), cell_state->ElementsNum() * sizeof(float));
MS_ASSERT(weight_h_ptr_);
MS_ASSERT(weight_i_ptr_);
MS_ASSERT(bias_ptr_);
MS_ASSERT(gate_buffer_);
Lstm(output_ptr, input_ptr, weight_i_ptr_, weight_h_ptr_, bias_ptr_,
-reinterpret_cast<float *>(output_hidden_state->MutableData()),
-reinterpret_cast<float *>(output_cell_state->MutableData()), gate_buffer_, state_buffer_, lstm_parm_);
+reinterpret_cast<float *>(output_hidden_state->data_c()), reinterpret_cast<float *>(output_cell_state->data_c()),
+gate_buffer_, state_buffer_, lstm_param_);
return RET_OK;
}

@@ -28,7 +28,7 @@ class LstmCPUKernel : public LiteKernel {
const std::vector<lite::Tensor *> &outputs, const lite::InnerContext *ctx,
const mindspore::lite::PrimitiveC *primitive)
: LiteKernel(parameter, inputs, outputs, ctx, primitive) {
-lstm_parm_ = reinterpret_cast<LstmParameter *>(op_parameter_);
+lstm_param_ = reinterpret_cast<LstmParameter *>(op_parameter_);
}
~LstmCPUKernel() override { FreeTmpBuffer(); }
@@ -48,7 +48,7 @@ class LstmCPUKernel : public LiteKernel {
float *weight_i_ptr_ = nullptr;
float *weight_h_ptr_ = nullptr;
float *bias_ptr_ = nullptr;
-LstmParameter *lstm_parm_ = nullptr;
+LstmParameter *lstm_param_ = nullptr;
};
} // namespace mindspore::kernel
