[MSLITE][Develop] optimize arm cpu fp32 op lstm: use matmul calculate function

pull/11989/head
yangruoqi713 4 years ago
parent 01a0cdf5f0
commit 7aea132188

@@ -19,20 +19,7 @@
 #include <float.h>
 #include "nnacl/fp32/activation_fp32.h"
 #include "nnacl/fp32/arithmetic_fp32.h"
-#include "nnacl/fp32/mul_fp32.h"
-void InitGate(float *gate_buffer, const float *bias, const LstmParameter *lstm_parm) {
-  int gate_offest = 0;
-  for (int l = 0; l < 4; l++) {
-    int batch_offest = gate_offest;
-    int bias_offest = l * lstm_parm->hidden_size_;
-    for (int b = 0; b < lstm_parm->batch_; b++) {
-      memcpy(gate_buffer + batch_offest, bias + bias_offest, lstm_parm->hidden_size_ * sizeof(float));
-      batch_offest += lstm_parm->hidden_size_;
-    }
-    gate_offest += lstm_parm->batch_ * lstm_parm->hidden_size_;
-  }
-}
+#include "nnacl/fp32/matmul_fp32.h"
 // input: [row, inner_size]; weight: [col, inner_size]; output: [row, col]
 void MatMulAcc(float *output, const float *input, const float *weight, int rows, int cols, int inner_size) {
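
Note: the deleted InitGate seeded each gate's slice of gate_buffer with a broadcast copy of its bias, so the following accumulate-matmul yielded input * weight^T + bias; after this change the bias is folded into the matmul call itself (MatMulOpt takes bias directly, and the batch-1 path keeps the memcpy-then-MatMulAcc scheme). A naive reference of the computation both schemes produce, using the [col, inner_size] weight layout from the comment above (MatMulBiasRef is illustrative only, not part of the patch):

void MatMulBiasRef(float *output, const float *input, const float *weight, const float *bias, int rows, int cols,
                   int inner_size) {
  for (int r = 0; r < rows; r++) {
    for (int c = 0; c < cols; c++) {
      float acc = bias[c];  // was: bias broadcast by InitGate, then MatMulAcc accumulates on top
      for (int d = 0; d < inner_size; d++) {
        acc += input[r * inner_size + d] * weight[c * inner_size + d];
      }
      output[r * cols + c] = acc;
    }
  }
}
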
@@ -134,106 +121,131 @@ void UpdataOutput(const float *cell_state, const float *output_gate, float *hidden
   }
 }
-void LstmStepUnit(float *output, const float *input, const float *input_input_weight, const float *input_forget_weight,
-                  const float *input_cell_weight, const float *input_output_weight, const float *state_input_weight,
-                  const float *state_forget_weight, const float *state_cell_weight, const float *state_output_weight,
-                  const float *bias, float *hidden_state, float *cell_state, float *gate_buffer, float *state_buffer,
-                  const LstmParameter *lstm_parm) {
-  InitGate(gate_buffer, bias, lstm_parm);
+void LstmMatmul(float *c, const float *a, const float *b, const float *bias, int row, int deep, int col, bool is_vec) {
+  if (is_vec) {
+    memcpy(c, bias, col * sizeof(float));
+    MatMulAcc(c, a, b, row, col, deep);
+  } else {
+    MatMulOpt(a, b, c, bias, ActType_No, deep, row, col, col, OutType_Nhwc);
+  }
+}
+void PackLstmInput(float *dst, const float *src, int row, int deep) {
+#ifdef ENABLE_AVX
+  RowMajor2Col6Major(src, dst, row, deep);
+#elif defined(ENABLE_SSE)
+  RowMajor2Col4Major(src, dst, row, deep);
+#else
+  RowMajor2Col12Major(src, dst, row, deep);
+#endif
+}
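
Note: each branch packs the left-hand matrix into the column-tiled layout the corresponding MatMulOpt kernel expects; the tile height is ISA-specific (6 rows with AVX, 4 with SSE, 12 on the ARM/NEON default path). A packed buffer therefore needs the row count rounded up to that tile. A minimal sizing sketch, assuming nnacl's UP_ROUND macro (PackedLhsSize is an illustrative name, not part of the patch):

#define UP_ROUND(x, y) (((x) + (y) - 1) / (y) * (y))

// Element count of the packed LHS for a [row, deep] matrix.
int PackedLhsSize(int row, int deep, int row_tile) {
  return UP_ROUND(row, row_tile) * deep;
}
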
+void UpdateGate(float *gate_buffer, const float *input, const float *weight, const float *bias, int row, int deep,
+                int col, int col_align, bool is_vec) {
+  const float *input_weight = weight;
+  const float *forget_weight = weight + deep * col * 2;
+  const float *cell_weight = weight + deep * col * 3;
+  const float *output_weight = weight + deep * col;
+  const float *input_bias = bias;
+  const float *forget_bias = bias + col_align * 2;
+  const float *cell_bias = bias + col_align * 3;
+  const float *output_bias = bias + col_align;
   float *input_gate = gate_buffer;
-  float *forget_gate = gate_buffer + lstm_parm->batch_ * lstm_parm->hidden_size_ * 2;
-  float *cell_gate = gate_buffer + lstm_parm->batch_ * lstm_parm->hidden_size_ * 3;
-  float *output_gate = gate_buffer + lstm_parm->batch_ * lstm_parm->hidden_size_ * 1;
+  float *forget_gate = gate_buffer + row * col * 2;
+  float *cell_gate = gate_buffer + row * col * 3;
+  float *output_gate = gate_buffer + row * col;
+  LstmMatmul(input_gate, input, input_weight, input_bias, row, deep, col, is_vec);
+  LstmMatmul(forget_gate, input, forget_weight, forget_bias, row, deep, col, is_vec);
+  LstmMatmul(cell_gate, input, cell_weight, cell_bias, row, deep, col, is_vec);
+  LstmMatmul(output_gate, input, output_weight, output_bias, row, deep, col, is_vec);
+}
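
Note: the strides above encode the gate order the weights are stored in: gate k starts at weight + k * deep * col with k = 0 input, 1 output, 2 forget, 3 cell (the iofc order ONNX also uses), while each bias vector is padded to col_align for SIMD alignment, hence the different stride. A sketch of the mapping (GateWeight/GateBias are illustrative helpers, not in the patch):

enum LstmGate { GATE_INPUT = 0, GATE_OUTPUT = 1, GATE_FORGET = 2, GATE_CELL = 3 };

const float *GateWeight(const float *weight, enum LstmGate g, int deep, int col) {
  return weight + (int)g * deep * col;  // gates packed back to back, [col, deep] each
}
const float *GateBias(const float *bias, enum LstmGate g, int col_align) {
  return bias + (int)g * col_align;  // bias padded to the aligned column count
}
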
+void LstmStepUnit(float *output, const float *input, const float *input_weight, const float *state_weight,
+                  const float *bias, float *hidden_state, float *cell_state, float *gate_buffer, float *state_buffer,
+                  float *matmul_buffer[2], const LstmParameter *lstm_param) {
+  bool is_vec = lstm_param->batch_ == 1;
   // input * weight
-  MatMulAcc(input_gate, input, input_input_weight, lstm_parm->batch_, lstm_parm->hidden_size_, lstm_parm->input_size_);
-  MatMulAcc(forget_gate, input, input_forget_weight, lstm_parm->batch_, lstm_parm->hidden_size_,
-            lstm_parm->input_size_);
-  MatMulAcc(cell_gate, input, input_cell_weight, lstm_parm->batch_, lstm_parm->hidden_size_, lstm_parm->input_size_);
-  MatMulAcc(output_gate, input, input_output_weight, lstm_parm->batch_, lstm_parm->hidden_size_,
-            lstm_parm->input_size_);
+  if (is_vec) {
+    UpdateGate(gate_buffer, input, input_weight, bias, lstm_param->batch_, lstm_param->input_size_,
+               lstm_param->hidden_size_, lstm_param->col_align_, is_vec);
+  } else {
+    // pack input for matmul
+    PackLstmInput(matmul_buffer[0], input, lstm_param->batch_, lstm_param->input_size_);
+    UpdateGate(gate_buffer, matmul_buffer[0], input_weight, bias, lstm_param->batch_, lstm_param->input_size_,
+               lstm_param->hidden_size_, lstm_param->col_align_, is_vec);
+  }
   // state * weight
-  MatMulAcc(input_gate, hidden_state, state_input_weight, lstm_parm->batch_, lstm_parm->hidden_size_,
-            lstm_parm->hidden_size_);
-  MatMulAcc(forget_gate, hidden_state, state_forget_weight, lstm_parm->batch_, lstm_parm->hidden_size_,
-            lstm_parm->hidden_size_);
-  MatMulAcc(cell_gate, hidden_state, state_cell_weight, lstm_parm->batch_, lstm_parm->hidden_size_,
-            lstm_parm->hidden_size_);
-  MatMulAcc(output_gate, hidden_state, state_output_weight, lstm_parm->batch_, lstm_parm->hidden_size_,
-            lstm_parm->hidden_size_);
+  float *state_gate = gate_buffer + lstm_param->batch_ * lstm_param->hidden_size_ * 4;
+  const float *state_bias = bias + lstm_param->col_align_ * 4;
+  if (is_vec) {
+    UpdateGate(state_gate, hidden_state, state_weight, state_bias, lstm_param->batch_, lstm_param->hidden_size_,
+               lstm_param->hidden_size_, lstm_param->col_align_, is_vec);
+  } else {
+    // pack state for matmul
+    PackLstmInput(matmul_buffer[1], hidden_state, lstm_param->batch_, lstm_param->hidden_size_);
+    UpdateGate(state_gate, matmul_buffer[1], state_weight, state_bias, lstm_param->batch_, lstm_param->hidden_size_,
+               lstm_param->hidden_size_, lstm_param->col_align_, is_vec);
+  }
+  ElementAdd(gate_buffer, state_gate, gate_buffer, 4 * lstm_param->batch_ * lstm_param->hidden_size_);
+  float *input_gate = gate_buffer;
+  float *forget_gate = gate_buffer + lstm_param->batch_ * lstm_param->hidden_size_ * 2;
+  float *cell_gate = gate_buffer + lstm_param->batch_ * lstm_param->hidden_size_ * 3;
+  float *output_gate = gate_buffer + lstm_param->batch_ * lstm_param->hidden_size_;
   // update input_gate
-  Sigmoid(input_gate, lstm_parm->batch_ * lstm_parm->hidden_size_, input_gate);
+  Sigmoid(input_gate, lstm_param->batch_ * lstm_param->hidden_size_, input_gate);
   // update forget_gate
-  Sigmoid(forget_gate, lstm_parm->batch_ * lstm_parm->hidden_size_, forget_gate);
+  Sigmoid(forget_gate, lstm_param->batch_ * lstm_param->hidden_size_, forget_gate);
   // update cell_gate
-  Tanh(cell_gate, lstm_parm->batch_ * lstm_parm->hidden_size_, cell_gate);
+  Tanh(cell_gate, lstm_param->batch_ * lstm_param->hidden_size_, cell_gate);
   // update cell state
-  UpdataState(cell_state, forget_gate, input_gate, cell_gate, state_buffer, lstm_parm->batch_, lstm_parm->hidden_size_,
-              lstm_parm->smooth_);
+  UpdataState(cell_state, forget_gate, input_gate, cell_gate, state_buffer, lstm_param->batch_,
+              lstm_param->hidden_size_, lstm_param->smooth_);
   // update output_gate
-  Sigmoid(output_gate, lstm_parm->batch_ * lstm_parm->hidden_size_, output_gate);
+  Sigmoid(output_gate, lstm_param->batch_ * lstm_param->hidden_size_, output_gate);
   // update output
-  UpdataOutput(cell_state, output_gate, hidden_state, state_buffer, lstm_parm->batch_, lstm_parm->hidden_size_,
-               lstm_parm->smooth_);
-  memcpy(output, hidden_state, lstm_parm->batch_ * lstm_parm->hidden_size_ * sizeof(float));
-  if (!(lstm_parm->smooth_ >= -FLT_EPSILON && lstm_parm->smooth_ <= FLT_EPSILON)) {
-    memcpy(cell_state, state_buffer, lstm_parm->batch_ * lstm_parm->hidden_size_ * sizeof(float));
-    memcpy(hidden_state, state_buffer + lstm_parm->batch_ * lstm_parm->hidden_size_,
-           lstm_parm->batch_ * lstm_parm->hidden_size_ * sizeof(float));
+  UpdataOutput(cell_state, output_gate, hidden_state, state_buffer, lstm_param->batch_, lstm_param->hidden_size_,
+               lstm_param->smooth_);
+  memcpy(output, hidden_state, lstm_param->batch_ * lstm_param->hidden_size_ * sizeof(float));
+  if (!(lstm_param->smooth_ >= -FLT_EPSILON && lstm_param->smooth_ <= FLT_EPSILON)) {
+    memcpy(cell_state, state_buffer, lstm_param->batch_ * lstm_param->hidden_size_ * sizeof(float));
+    memcpy(hidden_state, state_buffer + lstm_param->batch_ * lstm_param->hidden_size_,
+           lstm_param->batch_ * lstm_param->hidden_size_ * sizeof(float));
   }
 }
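
Note: once the two UpdateGate calls and the ElementAdd have produced the pre-activation gates, the rest of LstmStepUnit is the standard LSTM cell update. A naive batch-1 reference of that math, ignoring the smooth_ blending path (LstmStepRef and sigmoidf are illustrative, not part of the patch):

#include <math.h>

static float sigmoidf(float x) { return 1.0f / (1.0f + expf(-x)); }

// gates: 4 * hidden pre-activation values in the buffer order above
// (input, output, forget, cell at offsets 0, 1, 2, 3 times hidden).
void LstmStepRef(const float *gates, float *cell_state, float *hidden_state, int hidden) {
  const float *i = gates;
  const float *o = gates + hidden;
  const float *f = gates + hidden * 2;
  const float *g = gates + hidden * 3;
  for (int k = 0; k < hidden; k++) {
    cell_state[k] = sigmoidf(f[k]) * cell_state[k] + sigmoidf(i[k]) * tanhf(g[k]);
    hidden_state[k] = sigmoidf(o[k]) * tanhf(cell_state[k]);
  }
}
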
 void Lstm(float *output, const float *input, const float *weight_i, const float *weight_h, const float *bias,
-          float *hidden_state, float *cell_state, float *gate_buffer, float *state_buffer,
-          const LstmParameter *lstm_parm) {
+          float *hidden_state, float *cell_state, float *gate_buffer, float *state_buffer, float *matmul_buffer[2],
+          const LstmParameter *lstm_param) {
   // forward
-  const float *input_input_weight = weight_i;
-  const float *input_forget_weight = weight_i + lstm_parm->input_size_ * lstm_parm->hidden_size_ * 2;
-  const float *input_cell_weight = weight_i + lstm_parm->input_size_ * lstm_parm->hidden_size_ * 3;
-  const float *input_output_weight = weight_i + lstm_parm->input_size_ * lstm_parm->hidden_size_ * 1;
-  const float *state_input_weight = weight_h;
-  const float *state_forget_weight = weight_h + lstm_parm->hidden_size_ * lstm_parm->hidden_size_ * 2;
-  const float *state_cell_weight = weight_h + lstm_parm->hidden_size_ * lstm_parm->hidden_size_ * 3;
-  const float *state_output_weight = weight_h + lstm_parm->hidden_size_ * lstm_parm->hidden_size_ * 1;
-  for (int t = 0; t < lstm_parm->seq_len_; t++) {
-    const float *input_ptr = input + t * lstm_parm->input_step_;
-    float *output_ptr = output + t * lstm_parm->output_step_;
-    LstmStepUnit(output_ptr, input_ptr, input_input_weight, input_forget_weight, input_cell_weight,
-                 input_output_weight, state_input_weight, state_forget_weight, state_cell_weight, state_output_weight,
-                 bias, hidden_state, cell_state, gate_buffer, state_buffer, lstm_parm);
+  for (int t = 0; t < lstm_param->seq_len_; t++) {
+    const float *input_ptr = input + t * lstm_param->input_step_;
+    float *output_ptr = output + t * lstm_param->output_step_;
+    LstmStepUnit(output_ptr, input_ptr, weight_i, weight_h, bias, hidden_state, cell_state, gate_buffer, state_buffer,
+                 matmul_buffer, lstm_param);
   }
   // backward
-  if (lstm_parm->bidirectional_) {
-    input_input_weight = weight_i + lstm_parm->input_size_ * lstm_parm->hidden_size_ * 4;
-    input_forget_weight = weight_i + lstm_parm->input_size_ * lstm_parm->hidden_size_ * 6;
-    input_cell_weight = weight_i + lstm_parm->input_size_ * lstm_parm->hidden_size_ * 7;
-    input_output_weight = weight_i + lstm_parm->input_size_ * lstm_parm->hidden_size_ * 5;
-    state_input_weight = weight_h + lstm_parm->hidden_size_ * lstm_parm->hidden_size_ * 4;
-    state_forget_weight = weight_h + lstm_parm->hidden_size_ * lstm_parm->hidden_size_ * 6;
-    state_cell_weight = weight_h + lstm_parm->hidden_size_ * lstm_parm->hidden_size_ * 7;
-    state_output_weight = weight_h + lstm_parm->hidden_size_ * lstm_parm->hidden_size_ * 5;
-    float *backward_output = output + lstm_parm->batch_ * lstm_parm->hidden_size_;
-    const float *backward_bias = bias + 4 * lstm_parm->hidden_size_;
-    float *backward_cell_state = cell_state + lstm_parm->batch_ * lstm_parm->hidden_size_;
-    float *backward_hidden_state = hidden_state + lstm_parm->batch_ * lstm_parm->hidden_size_;
-    for (int t = lstm_parm->seq_len_ - 1; t >= 0; t--) {
-      const float *input_ptr = input + t * lstm_parm->input_step_;
-      float *output_ptr = backward_output + t * lstm_parm->output_step_;
-      LstmStepUnit(output_ptr, input_ptr, input_input_weight, input_forget_weight, input_cell_weight,
-                   input_output_weight, state_input_weight, state_forget_weight, state_cell_weight,
-                   state_output_weight, backward_bias, backward_hidden_state, backward_cell_state, gate_buffer,
-                   state_buffer, lstm_parm);
+  if (lstm_param->bidirectional_) {
+    const float *backward_weight_i = weight_i + 4 * lstm_param->col_align_ * lstm_param->input_size_;
+    const float *backward_weight_h = weight_h + 4 * lstm_param->col_align_ * lstm_param->hidden_size_;
+    const float *backward_bias = bias + 8 * lstm_param->hidden_size_;
+    float *backward_output = output + lstm_param->batch_ * lstm_param->hidden_size_;
+    float *backward_cell_state = cell_state + lstm_param->batch_ * lstm_param->hidden_size_;
+    float *backward_hidden_state = hidden_state + lstm_param->batch_ * lstm_param->hidden_size_;
+    for (int t = lstm_param->seq_len_ - 1; t >= 0; t--) {
+      const float *input_ptr = input + t * lstm_param->input_step_;
+      float *output_ptr = backward_output + t * lstm_param->output_step_;
+      LstmStepUnit(output_ptr, input_ptr, backward_weight_i, backward_weight_h, backward_bias, backward_hidden_state,
+                   backward_cell_state, gate_buffer, state_buffer, matmul_buffer, lstm_param);
     }
   }
 }
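
Note: callers of Lstm now pass two pack buffers in addition to the gate buffer. From the offsets used in LstmStepUnit, gate_buffer must hold the four input-side gates plus the four state-side gates, and each matmul_buffer holds one packed LHS. An illustrative sizing helper (the struct and names are assumptions, not the kernel's actual fields):

typedef struct {
  int gate_elems;     // 8 * batch * hidden: 4 input gates + 4 state gates
  int pack_in_elems;  // matmul_buffer[0]: packed input, row_align * input_size
  int pack_st_elems;  // matmul_buffer[1]: packed hidden state, row_align * hidden_size
} LstmRunBufferSizes;

LstmRunBufferSizes LstmBufferSizes(int batch, int input_size, int hidden_size, int row_align) {
  LstmRunBufferSizes s;
  s.gate_elems = 8 * batch * hidden_size;
  s.pack_in_elems = row_align * input_size;
  s.pack_st_elems = row_align * hidden_size;
  return s;
}
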

@@ -28,7 +28,7 @@ void ElementMulAcc(const float *input0, const float *input1, float *output, int
 int ElementOptMulAcc(const float *input0, const float input1, float *output, const int element_size);
 void Lstm(float *output, const float *input, const float *weight_i, const float *weight_h, const float *bias,
-          float *hidden_state, float *cell_state, float *gate_buffer, float *state_buffer,
+          float *hidden_state, float *cell_state, float *gate_buffer, float *state_buffer, float *matmul_buffer[2],
           const LstmParameter *lstm_parm);
 #ifdef __cplusplus
 }

@@ -34,6 +34,8 @@ typedef struct LstmParameter {
   // output_hidden = old_hidden * smooth + new_hidden * (1 - smooth)
   // output_cell = old_cell * smooth + new_cell * (1 - smooth)
   float smooth_;
+  int col_align_;
+  int row_align_;
 } LstmParameter;
 #endif // MINDSPORE_LITE_NNACL_LSTM_PARAMETER_H_
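
Note: the two new fields cache the SIMD-aligned column and row counts so they are computed once at init rather than per step. A plausible initialization, assuming nnacl's UP_ROUND macro and ISA-dependent tiles (InitLstmAlign is an illustrative name, not in the patch):

#define UP_ROUND(x, y) (((x) + (y) - 1) / (y) * (y))

void InitLstmAlign(LstmParameter *param, int row_tile, int col_tile) {
  param->row_align_ = UP_ROUND(param->batch_, row_tile);        // rows of the packed LHS
  param->col_align_ = UP_ROUND(param->hidden_size_, col_tile);  // padded per-gate bias/column count
}
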

@@ -84,9 +84,8 @@ std::vector<size_t> GetLinkedPostNodeIdx(const lite::Model *model, const size_t
 bool IsPackedOp(schema::PrimitiveType op_type) {
   static std::vector<schema::PrimitiveType> packed_ops = {
-      schema::PrimitiveType_Conv2D,          schema::PrimitiveType_DeConv2D,
-      schema::PrimitiveType_DepthwiseConv2D, schema::PrimitiveType_DeDepthwiseConv2D,
-      schema::PrimitiveType_MatMul,          schema::PrimitiveType_Lstm};
+      schema::PrimitiveType_Conv2D, schema::PrimitiveType_DeConv2D, schema::PrimitiveType_DepthwiseConv2D,
+      schema::PrimitiveType_DeDepthwiseConv2D, schema::PrimitiveType_MatMul};
   return IsContain(packed_ops, op_type);
 }
 } // namespace lite

File diff suppressed because it is too large

@@ -39,8 +39,9 @@ class LstmCPUKernel : public LiteKernel {
  private:
   void FreeTmpBuffer();
+  void FreeRunBuffer();
   int InitParam();
-  int InitBuffer();
+  int MallocRunBuffer();
   int InitWeightBias();
   float *gate_buffer_ = nullptr;
@@ -48,6 +49,10 @@ class LstmCPUKernel : public LiteKernel {
   float *weight_i_ptr_ = nullptr;
   float *weight_h_ptr_ = nullptr;
   float *bias_ptr_ = nullptr;
+  float *matmul_buffer_[2];
+  int row_tile_ = 0;
+  int col_tile_ = 0;
+  bool is_vec_ = false;
   LstmParameter *lstm_param_ = nullptr;
 };
 } // namespace mindspore::kernel
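
Note: the new row_tile_/col_tile_ members mirror the packing in PackLstmInput. A hypothetical sketch of how they could be chosen (the row tiles follow the pack routines above; the 8-wide column tile is an assumption and may differ per ISA):

void GetLstmTiles(int *row_tile, int *col_tile) {
#ifdef ENABLE_AVX
  *row_tile = 6;   // RowMajor2Col6Major
#elif defined(ENABLE_SSE)
  *row_tile = 4;   // RowMajor2Col4Major
#else
  *row_tile = 12;  // RowMajor2Col12Major
#endif
  *col_tile = 8;   // assumed MatMulOpt column tile
}
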
