parent 54b8d53780
commit 44fb47ed84

@@ -0,0 +1,235 @@
/**
 * Copyright 2020 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#include "nnacl/fp16/lstm_fp16.h"
#include <string.h>
#include <float.h>     // FLT_EPSILON, used by the smooth-factor checks below
#include <arm_neon.h>  // float16x8_t and the fp16 intrinsics used below
#include "nnacl/fp16/activation_fp16.h"
#include "nnacl/fp16/arithmetic_fp16.h"

void InitGateFp16(float16_t *gate_buffer, const float16_t *bias, const LstmParameter *lstm_parm) {
  int gate_offest = 0;
  for (int l = 0; l < 4; l++) {
    int batch_offest = gate_offest;
    int bias_offest = l * lstm_parm->hidden_size_;
    for (int b = 0; b < lstm_parm->batch_; b++) {
      memcpy(gate_buffer + batch_offest, bias + bias_offest, lstm_parm->hidden_size_ * sizeof(float16_t));
      batch_offest += lstm_parm->hidden_size_;
    }
    gate_offest += lstm_parm->batch_ * lstm_parm->hidden_size_;
  }
}
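
// The matrix multiply-accumulate below keeps the weight in [col, inner_size] layout so every output
// element is a dot product of two contiguous rows: the vector loop accumulates eight fp16 products per
// iteration with vfmaq_f16, the vadd_f16/vpadd_f16 sequence folds the eight lanes down to a single sum,
// and the scalar tail loop handles the remaining inner_size % 8 elements.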
// input: [row, inner_size]; weight: [col, inner_size]; output: [row, col]
void MatMulAccFp16(float16_t *output, const float16_t *input, const float16_t *weight, int rows, int cols,
                   int inner_size) {
  for (int r = 0; r < rows; r++) {
    for (int c = 0; c < cols; c++) {
      float16_t res = 0;
      const float16_t *input_col = input + r * inner_size;
      const float16_t *weight_col = weight + c * inner_size;
      int index = 0;
      float16x8_t out = vdupq_n_f16(0.0f);
      for (; index <= inner_size - 8; index += 8) {
        float16x8_t in_0 = vld1q_f16(input_col + index);
        float16x8_t in_1 = vld1q_f16(weight_col + index);
        out = vfmaq_f16(out, in_1, in_0);
      }
      float16x4_t add2 = vadd_f16(vget_low_f16(out), vget_high_f16(out));
      float16x4_t add4 = vpadd_f16(add2, add2);
      float16x4_t add8 = vpadd_f16(add4, add4);
      res += vget_lane_f16(add8, 0);
      for (; index < inner_size; index++) {
        res += input_col[index] * weight_col[index];
      }
      output[r * cols + c] += res;
    }
  }
}

void ElementMulAccFp16(const float16_t *input0, const float16_t *input1, float16_t *output, int element_size) {
  int index = 0;
  for (; index <= element_size - 8; index += 8) {
    float16x8_t in_0 = vld1q_f16(input0 + index);
    float16x8_t in_1 = vld1q_f16(input1 + index);
    float16x8_t out = vld1q_f16(output + index);
    out = vfmaq_f16(out, in_1, in_0);
    vst1q_f16(output + index, out);
  }
  for (; index < element_size; index++) {
    output[index] += input0[index] * input1[index];
  }
}

int ElementOptMulAccFp16(const float16_t *input0, const float16_t input1, float16_t *output, const int element_size) {
  int index = 0;
  for (; index <= element_size - 8; index += 8) {
    float16x8_t vin0 = vld1q_f16(input0 + index);
    float16x8_t vout = vld1q_f16(output + index);
    vout = vfmaq_n_f16(vout, vin0, input1);
    vst1q_f16(output + index, vout);
  }
  for (; index < element_size; index++) {
    output[index] += input0[index] * input1;
  }
  return NNACL_OK;
}

void UpdataStateFp16(float16_t *cell_state, float16_t *forget_gate, const float16_t *input_gate,
                     const float16_t *cell_gate, float16_t *state_buffer, int batch, int hidden_size,
                     float16_t smooth) {
  if (!(smooth >= -FLT_EPSILON && smooth <= FLT_EPSILON)) {  // smooth * old_cell_state
    memcpy(state_buffer, cell_state, batch * hidden_size * sizeof(float16_t));
    ArithmeticParameter parameter;
    parameter.in_elements_num0_ = batch * hidden_size;
    parameter.in_elements_num1_ = 1;
    ElementOptMulFp16(state_buffer, &smooth, state_buffer, batch * hidden_size, &parameter);
  }

  ElementMulFp16(forget_gate, cell_state, cell_state, batch * hidden_size);
  ElementMulAccFp16(input_gate, cell_gate, cell_state, batch * hidden_size);

  if (!(smooth >= -FLT_EPSILON && smooth <= FLT_EPSILON)) {  // (1 - smooth) * new_cell_state
    ElementOptMulAccFp16(cell_state, 1 - smooth, state_buffer, batch * hidden_size);
  }
}

void UpdataOutputFp16(const float16_t *cell_state, float16_t *output_gate, float16_t *hidden_state,
                      float16_t *state_buffer_in, int batch, int hidden_size, float16_t smooth) {
  float16_t *state_buffer = state_buffer_in + batch * hidden_size;
  if (!(smooth >= -FLT_EPSILON && smooth <= FLT_EPSILON)) {
    memcpy(state_buffer, hidden_state, batch * hidden_size * sizeof(float16_t));
    ArithmeticParameter parameter;
    parameter.in_elements_num0_ = batch * hidden_size;
    parameter.in_elements_num1_ = 1;
    ElementOptMulFp16(state_buffer, &smooth, state_buffer, batch * hidden_size, &parameter);
  }

  TanhFp16(cell_state, hidden_state, batch * hidden_size);
  ElementMulFp16(hidden_state, output_gate, hidden_state, batch * hidden_size);

  if (!(smooth >= -FLT_EPSILON && smooth <= FLT_EPSILON)) {
    ElementOptMulAccFp16(hidden_state, 1 - smooth, state_buffer, batch * hidden_size);
  }
}
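
// LstmStepUnitFp16 runs one time step for the whole batch. The gate buffer, the bias vector and the
// packed weights all use the same gate order: input gate at block 0, output gate at block 1, forget gate
// at block 2 and cell (candidate) gate at block 3, each block spanning batch_ * hidden_size_ elements
// (or the matching weight slice) -- which is why the pointer arithmetic below multiplies by 0, 1, 2 and 3.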

void LstmStepUnitFp16(float16_t *output, const float16_t *input, const float16_t *input_input_weight,
                      const float16_t *input_forget_weight, const float16_t *input_cell_weight,
                      const float16_t *input_output_weight, const float16_t *state_input_weight,
                      const float16_t *state_forget_weight, const float16_t *state_cell_weight,
                      const float16_t *state_output_weight, const float16_t *bias, float16_t *hidden_state,
                      float16_t *cell_state, float16_t *gate_buffer, float16_t *state_buffer,
                      const LstmParameter *lstm_parm) {
  InitGateFp16(gate_buffer, bias, lstm_parm);

  float16_t *input_gate = gate_buffer;
  float16_t *forget_gate = gate_buffer + lstm_parm->batch_ * lstm_parm->hidden_size_ * 2;
  float16_t *cell_gate = gate_buffer + lstm_parm->batch_ * lstm_parm->hidden_size_ * 3;
  float16_t *output_gate = gate_buffer + lstm_parm->batch_ * lstm_parm->hidden_size_ * 1;

  // input * weight
  MatMulAccFp16(input_gate, input, input_input_weight, lstm_parm->batch_, lstm_parm->hidden_size_,
                lstm_parm->input_size_);
  MatMulAccFp16(forget_gate, input, input_forget_weight, lstm_parm->batch_, lstm_parm->hidden_size_,
                lstm_parm->input_size_);
  MatMulAccFp16(cell_gate, input, input_cell_weight, lstm_parm->batch_, lstm_parm->hidden_size_,
                lstm_parm->input_size_);
  MatMulAccFp16(output_gate, input, input_output_weight, lstm_parm->batch_, lstm_parm->hidden_size_,
                lstm_parm->input_size_);

  // state * weight
  MatMulAccFp16(input_gate, hidden_state, state_input_weight, lstm_parm->batch_, lstm_parm->hidden_size_,
                lstm_parm->hidden_size_);
  MatMulAccFp16(forget_gate, hidden_state, state_forget_weight, lstm_parm->batch_, lstm_parm->hidden_size_,
                lstm_parm->hidden_size_);
  MatMulAccFp16(cell_gate, hidden_state, state_cell_weight, lstm_parm->batch_, lstm_parm->hidden_size_,
                lstm_parm->hidden_size_);
  MatMulAccFp16(output_gate, hidden_state, state_output_weight, lstm_parm->batch_, lstm_parm->hidden_size_,
                lstm_parm->hidden_size_);

  // update input_gate
  SigmoidFp16(input_gate, input_gate, lstm_parm->batch_ * lstm_parm->hidden_size_);

  // update forget_gate
  SigmoidFp16(forget_gate, forget_gate, lstm_parm->batch_ * lstm_parm->hidden_size_);

  // update cell_gate
  TanhFp16(cell_gate, cell_gate, lstm_parm->batch_ * lstm_parm->hidden_size_);
  // update cell state
  UpdataStateFp16(cell_state, forget_gate, input_gate, cell_gate, state_buffer, lstm_parm->batch_,
                  lstm_parm->hidden_size_, lstm_parm->smooth_);

  // update output_gate
  SigmoidFp16(output_gate, output_gate, lstm_parm->batch_ * lstm_parm->hidden_size_);
  // update output
  UpdataOutputFp16(cell_state, output_gate, hidden_state, state_buffer, lstm_parm->batch_, lstm_parm->hidden_size_,
                   lstm_parm->smooth_);
  memcpy(output, hidden_state, lstm_parm->batch_ * lstm_parm->hidden_size_ * sizeof(float16_t));

  if (!(lstm_parm->smooth_ >= -FLT_EPSILON && lstm_parm->smooth_ <= FLT_EPSILON)) {
    memcpy(cell_state, state_buffer, lstm_parm->batch_ * lstm_parm->hidden_size_ * sizeof(float16_t));
    memcpy(hidden_state, state_buffer + lstm_parm->batch_ * lstm_parm->hidden_size_,
           lstm_parm->batch_ * lstm_parm->hidden_size_ * sizeof(float16_t));
  }
}

void LstmFp16(float16_t *output, const float16_t *input, const float16_t *weight_i, const float16_t *weight_h,
              const float16_t *bias, float16_t *hidden_state, float16_t *cell_state, float16_t *gate_buffer,
              float16_t *state_buffer, const LstmParameter *lstm_parm) {
  // forward
  const float16_t *input_input_weight = weight_i;
  const float16_t *input_forget_weight = weight_i + lstm_parm->input_size_ * lstm_parm->hidden_size_ * 2;
  const float16_t *input_cell_weight = weight_i + lstm_parm->input_size_ * lstm_parm->hidden_size_ * 3;
  const float16_t *input_output_weight = weight_i + lstm_parm->input_size_ * lstm_parm->hidden_size_ * 1;

  const float16_t *state_input_weight = weight_h;
  const float16_t *state_forget_weight = weight_h + lstm_parm->hidden_size_ * lstm_parm->hidden_size_ * 2;
  const float16_t *state_cell_weight = weight_h + lstm_parm->hidden_size_ * lstm_parm->hidden_size_ * 3;
  const float16_t *state_output_weight = weight_h + lstm_parm->hidden_size_ * lstm_parm->hidden_size_ * 1;

  for (int t = 0; t < lstm_parm->seq_len_; t++) {
    const float16_t *input_ptr = input + t * lstm_parm->input_step_;
    float16_t *output_ptr = output + t * lstm_parm->output_step_;
    LstmStepUnitFp16(output_ptr, input_ptr, input_input_weight, input_forget_weight, input_cell_weight,
                     input_output_weight, state_input_weight, state_forget_weight, state_cell_weight,
                     state_output_weight, bias, hidden_state, cell_state, gate_buffer, state_buffer, lstm_parm);
  }

  // backward
  if (lstm_parm->bidirectional_) {
    input_input_weight = weight_i + lstm_parm->input_size_ * lstm_parm->hidden_size_ * 4;
    input_forget_weight = weight_i + lstm_parm->input_size_ * lstm_parm->hidden_size_ * 6;
    input_cell_weight = weight_i + lstm_parm->input_size_ * lstm_parm->hidden_size_ * 7;
    input_output_weight = weight_i + lstm_parm->input_size_ * lstm_parm->hidden_size_ * 5;

    state_input_weight = weight_h + lstm_parm->hidden_size_ * lstm_parm->hidden_size_ * 4;
    state_forget_weight = weight_h + lstm_parm->hidden_size_ * lstm_parm->hidden_size_ * 6;
    state_cell_weight = weight_h + lstm_parm->hidden_size_ * lstm_parm->hidden_size_ * 7;
    state_output_weight = weight_h + lstm_parm->hidden_size_ * lstm_parm->hidden_size_ * 5;

    float16_t *backward_output = output + lstm_parm->batch_ * lstm_parm->hidden_size_;
    const float16_t *backward_bias = bias + 4 * lstm_parm->hidden_size_;
    float16_t *backward_cell_state = cell_state + lstm_parm->batch_ * lstm_parm->hidden_size_;
    float16_t *backward_hidden_state = hidden_state + lstm_parm->batch_ * lstm_parm->hidden_size_;
    for (int t = lstm_parm->seq_len_ - 1; t >= 0; t--) {
      const float16_t *input_ptr = input + t * lstm_parm->input_step_;
      float16_t *output_ptr = backward_output + t * lstm_parm->output_step_;
      LstmStepUnitFp16(output_ptr, input_ptr, input_input_weight, input_forget_weight, input_cell_weight,
                       input_output_weight, state_input_weight, state_forget_weight, state_cell_weight,
                       state_output_weight, backward_bias, backward_hidden_state, backward_cell_state, gate_buffer,
                       state_buffer, lstm_parm);
    }
  }
}
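
The routine above is the whole-sequence driver. A minimal calling sketch follows; it is not part of the commit: the helper name RunLstmExample and the concrete sizes are illustrative, while the buffer sizes mirror what LstmFp16CPUKernel::InitBuffer allocates later in this diff (4 * batch * hidden fp16 values for the gate buffer; the 2 * batch * hidden smoothing buffer is only needed when smooth_ is non-zero, so it is left NULL here). Error checks are omitted for brevity.

// Illustrative only: unidirectional, smooth_ == 0, weights and bias already packed in the I, O, F, G order above.
#include <stdlib.h>
#include <string.h>
#include "nnacl/fp16/lstm_fp16.h"

void RunLstmExample(const float16_t *input, const float16_t *weight_i, const float16_t *weight_h,
                    const float16_t *bias, float16_t *output) {
  LstmParameter p;
  memset(&p, 0, sizeof(p));
  p.seq_len_ = 8;
  p.batch_ = 1;
  p.input_size_ = 16;
  p.hidden_size_ = 32;
  p.bidirectional_ = false;
  p.smooth_ = 0.0f;
  p.input_step_ = p.batch_ * p.input_size_;    // elements consumed per time step
  p.output_step_ = p.batch_ * p.hidden_size_;  // doubled when bidirectional

  float16_t *hidden_state = (float16_t *)calloc(p.batch_ * p.hidden_size_, sizeof(float16_t));
  float16_t *cell_state = (float16_t *)calloc(p.batch_ * p.hidden_size_, sizeof(float16_t));
  float16_t *gate_buffer = (float16_t *)malloc(4 * p.batch_ * p.hidden_size_ * sizeof(float16_t));
  // state_buffer is only touched when smooth_ != 0; the kernel below likewise leaves it null in that case.
  LstmFp16(output, input, weight_i, weight_h, bias, hidden_state, cell_state, gate_buffer, NULL, &p);

  free(hidden_state);
  free(cell_state);
  free(gate_buffer);
}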
@@ -0,0 +1,38 @@
/**
 * Copyright 2021 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#ifndef MINDSPORE_LITE_NNACL_FP16_LSTM_H_
#define MINDSPORE_LITE_NNACL_FP16_LSTM_H_

#include "nnacl/lstm_parameter.h"
#ifdef __cplusplus
extern "C" {
#endif
void MatMulAccFp16(float16_t *output, const float16_t *input, const float16_t *weight, int rows, int cols,
                   int inner_size);

void ElementMulAccFp16(const float16_t *input0, const float16_t *input1, float16_t *output, int element_size);

int ElementOptMulAccFp16(const float16_t *input0, const float16_t input1, float16_t *output, const int element_size);

void LstmFp16(float16_t *output, const float16_t *input, const float16_t *weight_i, const float16_t *weight_h,
              const float16_t *bias, float16_t *hidden_state, float16_t *cell_state, float16_t *gate_buffer,
              float16_t *state_buffer, const LstmParameter *lstm_parm);
#ifdef __cplusplus
}
#endif

#endif  // MINDSPORE_LITE_NNACL_FP16_LSTM_H_
@@ -0,0 +1,39 @@
/**
 * Copyright 2021 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
#ifndef MINDSPORE_LITE_NNACL_LSTM_PARAMETER_H_
#define MINDSPORE_LITE_NNACL_LSTM_PARAMETER_H_

#include "nnacl/op_base.h"

typedef struct LstmParameter {
  // Primitive parameter
  OpParameter op_parameter_;
  // shape correlative
  int input_size_;
  int hidden_size_;  // output_size
  int seq_len_;
  int batch_;
  // other parameter
  int input_step_;
  int output_step_;
  bool bidirectional_;
  // smooth factor for hidden/cell state calculation:
  // output_hidden = old_hidden * smooth + new_hidden * (1 - smooth)
  // output_cell = old_cell * smooth + new_cell * (1 - smooth)
  float smooth_;
} LstmParameter;

#endif  // MINDSPORE_LITE_NNACL_LSTM_PARAMETER_H_
@@ -0,0 +1,203 @@
/**
 * Copyright 2021 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#include "src/runtime/kernel/arm/fp16/lstm_fp16.h"
#include <vector>
#include "schema/model_generated.h"
#include "src/kernel_registry.h"
#include "include/errorcode.h"
#include "nnacl/fp16/lstm_fp16.h"

using mindspore::kernel::KERNEL_ARCH::kCPU;
using mindspore::lite::KernelRegistrar;
using mindspore::lite::RET_ERROR;
using mindspore::lite::RET_OK;
using mindspore::schema::PrimitiveType_Lstm;

namespace mindspore::kernel {
void LstmFp16CPUKernel::FreeTmpBuffer() {
  if (gate_buffer_ != nullptr) {
    free(gate_buffer_);
    gate_buffer_ = nullptr;
  }
  if (state_buffer_ != nullptr) {
    free(state_buffer_);
    state_buffer_ = nullptr;
  }
  if (weight_i_ptr_ != nullptr) {
    free(weight_i_ptr_);
    weight_i_ptr_ = nullptr;
  }
  if (weight_h_ptr_ != nullptr) {
    free(weight_h_ptr_);
    weight_h_ptr_ = nullptr;
  }
  if (bias_ptr_ != nullptr) {
    free(bias_ptr_);
    bias_ptr_ = nullptr;
  }
}

int LstmFp16CPUKernel::InitParam() {
  auto input = in_tensors_.front();
  MS_ASSERT(input != nullptr);
  std::vector<int> in_shape = input->shape();
  lstm_param_->seq_len_ = in_shape.at(0);
  lstm_param_->batch_ = in_shape.at(1);
  lstm_param_->input_size_ = in_shape.at(2);

  auto weight_i = in_tensors_.at(1);
  MS_ASSERT(weight_i != nullptr);
  std::vector<int> w_shape = weight_i->shape();
  lstm_param_->hidden_size_ = w_shape.at(1) / 4;

  lstm_param_->input_step_ = lstm_param_->batch_ * lstm_param_->input_size_;
  lstm_param_->output_step_ = lstm_param_->bidirectional_ ? 2 * lstm_param_->batch_ * lstm_param_->hidden_size_
                                                          : lstm_param_->batch_ * lstm_param_->hidden_size_;
  return RET_OK;
}
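
// InitBuffer allocates the scratch space consumed by LstmFp16: gate_buffer_ holds the four gate
// activations (4 * batch * hidden_size fp16 values), and state_buffer_ holds the smoothed cell and
// hidden states (2 * batch * hidden_size); the latter is only needed when smooth_ is non-zero.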

int LstmFp16CPUKernel::InitBuffer() {
  gate_buffer_ =
    reinterpret_cast<float16_t *>(malloc(4 * lstm_param_->batch_ * lstm_param_->hidden_size_ * sizeof(float16_t)));
  if (gate_buffer_ == nullptr) {
    MS_LOG(ERROR) << "Lstm fp16 malloc gate_buffer error.";
    return RET_ERROR;
  }
  if (!(lstm_param_->smooth_ >= -FLT_EPSILON && lstm_param_->smooth_ <= FLT_EPSILON)) {
    int buffer_size = 2 * lstm_param_->batch_ * lstm_param_->hidden_size_ * sizeof(float16_t);
    state_buffer_ = reinterpret_cast<float16_t *>(malloc(buffer_size));
    if (state_buffer_ == nullptr) {
      MS_LOG(ERROR) << "Lstm fp16 malloc state_buffer error.";
      return RET_ERROR;
    }
  }
  return RET_OK;
}
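
// InitWeightBias converts the fp32 weight tensors to fp16 copies and folds the two halves of each
// direction's bias (the input-side and recurrent-side biases in the usual layout) into a single
// 4 * hidden_size vector per direction, since InitGateFp16 adds only one bias block per gate.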

int LstmFp16CPUKernel::InitWeightBias() {
  // copy weight_i and weight_h
  auto weight_i = in_tensors_.at(1);
  MS_ASSERT(weight_i != nullptr);
  weight_i_ptr_ = reinterpret_cast<float16_t *>(malloc(weight_i->ElementsNum() * sizeof(float16_t)));
  if (weight_i_ptr_ == nullptr) {
    MS_LOG(ERROR) << "Lstm fp16 malloc weight_i_ptr_ error.";
    return RET_ERROR;
  }
  auto weight_i_data = reinterpret_cast<float *>(weight_i->data_c());
  for (size_t i = 0; i < weight_i->ElementsNum(); i++) {
    weight_i_ptr_[i] = (float16_t)weight_i_data[i];
  }

  auto weight_h = in_tensors_.at(2);
  MS_ASSERT(weight_h != nullptr);
  weight_h_ptr_ = reinterpret_cast<float16_t *>(malloc(weight_h->ElementsNum() * sizeof(float16_t)));
  if (weight_h_ptr_ == nullptr) {
    MS_LOG(ERROR) << "Lstm fp16 malloc weight_h_ error.";
    return RET_ERROR;
  }
  auto weight_h_data = reinterpret_cast<float *>(weight_h->data_c());
  for (size_t i = 0; i < weight_h->ElementsNum(); i++) {
    weight_h_ptr_[i] = (float16_t)weight_h_data[i];
  }

  std::vector<int> w_shape = weight_i->shape();
  auto hidden_size = w_shape.at(1) / 4;
  // init bias
  int bias_num = lstm_param_->bidirectional_ ? 2 * 4 * hidden_size : 4 * hidden_size;
  bias_ptr_ = reinterpret_cast<float16_t *>(malloc(bias_num * sizeof(float16_t)));
  if (bias_ptr_ == nullptr) {
    MS_LOG(ERROR) << "Lstm fp16 malloc bias_ptr_ error.";
    return RET_ERROR;
  }

  auto bias_data = reinterpret_cast<float *>(in_tensors_.at(3)->data_c());
  const int state_bias_offset = 4 * hidden_size;
  for (int i = 0; i < state_bias_offset; i++) {
    bias_ptr_[i] = (float16_t)(bias_data[i] + bias_data[i + state_bias_offset]);
  }
  if (lstm_param_->bidirectional_) {
    bias_data += 4 * hidden_size * 2;
    auto backward_bias = bias_ptr_ + 4 * hidden_size;
    for (int i = 0; i < state_bias_offset; i++) {
      backward_bias[i] = (float16_t)(bias_data[i] + bias_data[i + state_bias_offset]);
    }
  }
  return RET_OK;
}

int LstmFp16CPUKernel::Init() {
  FreeTmpBuffer();
  auto ret = InitWeightBias();
  if (ret != RET_OK) {
    MS_LOG(ERROR) << "Lstm fp16 InitWeightBias error.";
    FreeTmpBuffer();
    return RET_ERROR;
  }

  if (!InferShapeDone()) {
    return RET_OK;
  }
  return ReSize();
}

int LstmFp16CPUKernel::ReSize() {
  auto ret = InitParam();
  if (ret != RET_OK) {
    MS_LOG(ERROR) << "Lstm fp16 InitParam error.";
    return RET_ERROR;
  }

  ret = InitBuffer();
  if (ret != RET_OK) {
    MS_LOG(ERROR) << "Lstm fp16 InitBuffer error.";
    FreeTmpBuffer();
    return RET_ERROR;
  }
  return RET_OK;
}
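
// Run copies the incoming hidden/cell state tensors (inputs 4 and 5) into output tensors 1 and 2, then
// points LstmFp16 at those copies, so the states are updated in place without touching the input tensors.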

int LstmFp16CPUKernel::Run() {
  auto input = in_tensors_.at(kInputIndex);
  MS_ASSERT(input != nullptr);
  auto hidden_state = in_tensors_.at(4);
  MS_ASSERT(hidden_state != nullptr);
  auto cell_state = in_tensors_.at(5);
  MS_ASSERT(cell_state != nullptr);
  auto output = out_tensors_.at(0);
  MS_ASSERT(output != nullptr);

  auto input_ptr = reinterpret_cast<float16_t *>(input->data_c());
  MS_ASSERT(input_ptr);
  auto output_ptr = reinterpret_cast<float16_t *>(output->data_c());
  MS_ASSERT(output_ptr);
  auto output_hidden_state = out_tensors_[1];
  memcpy(output_hidden_state->data_c(), hidden_state->data_c(), hidden_state->ElementsNum() * sizeof(float16_t));
  auto output_cell_state = out_tensors_[2];
  memcpy(output_cell_state->data_c(), cell_state->data_c(), cell_state->ElementsNum() * sizeof(float16_t));

  MS_ASSERT(weight_h_ptr_);
  MS_ASSERT(weight_i_ptr_);
  MS_ASSERT(bias_ptr_);
  MS_ASSERT(gate_buffer_);
  LstmFp16(output_ptr, input_ptr, weight_i_ptr_, weight_h_ptr_, bias_ptr_,
           reinterpret_cast<float16_t *>(output_hidden_state->data_c()),
           reinterpret_cast<float16_t *>(output_cell_state->data_c()), gate_buffer_, state_buffer_, lstm_param_);
  return RET_OK;
}

REG_KERNEL(kCPU, kNumberTypeFloat16, PrimitiveType_Lstm, LiteKernelCreator<LstmFp16CPUKernel>)
}  // namespace mindspore::kernel
@@ -0,0 +1,55 @@
/**
 * Copyright 2021 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#ifndef MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_FP16_LSTM_H_
#define MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_FP16_LSTM_H_

#include <vector>
#include "src/lite_kernel.h"
#include "nnacl/lstm_parameter.h"

namespace mindspore::kernel {
class LstmFp16CPUKernel : public LiteKernel {
 public:
  LstmFp16CPUKernel(OpParameter *parameter, const std::vector<lite::Tensor *> &inputs,
                    const std::vector<lite::Tensor *> &outputs, const lite::InnerContext *ctx,
                    const mindspore::lite::PrimitiveC *primitive)
      : LiteKernel(parameter, inputs, outputs, ctx, primitive) {
    lstm_param_ = reinterpret_cast<LstmParameter *>(op_parameter_);
  }

  ~LstmFp16CPUKernel() override { FreeTmpBuffer(); }

  int Init() override;
  int ReSize() override;
  int Run() override;

 private:
  void FreeTmpBuffer();
  int InitParam();
  int InitBuffer();
  int InitWeightBias();

  float16_t *gate_buffer_ = nullptr;
  float16_t *state_buffer_ = nullptr;
  float16_t *weight_i_ptr_ = nullptr;
  float16_t *weight_h_ptr_ = nullptr;
  float16_t *bias_ptr_ = nullptr;
  LstmParameter *lstm_param_ = nullptr;
};
}  // namespace mindspore::kernel

#endif  // MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_FP16_LSTM_H_